diff --git a/acl/src/main/java/org/springframework/security/acls/AclEntryVoter.java b/acl/src/main/java/org/springframework/security/acls/AclEntryVoter.java index 0d5ebcc4e4..4e694b0247 100644 --- a/acl/src/main/java/org/springframework/security/acls/AclEntryVoter.java +++ b/acl/src/main/java/org/springframework/security/acls/AclEntryVoter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; import java.lang.reflect.InvocationTargetException; @@ -24,6 +25,7 @@ import java.util.List; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.access.AuthorizationServiceException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.vote.AbstractAclVoter; @@ -39,6 +41,7 @@ import org.springframework.security.acls.model.Sid; import org.springframework.security.acls.model.SidRetrievalStrategy; import org.springframework.security.core.Authentication; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; /** @@ -92,59 +95,45 @@ import org.springframework.util.StringUtils; *

* All comparisons and prefixes are case sensitive. * - * * @author Ben Alex */ public class AclEntryVoter extends AbstractAclVoter { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(AclEntryVoter.class); - // ~ Instance fields - // ================================================================================================ + private final AclService aclService; + + private final String processConfigAttribute; + + private final List requirePermission; - private AclService aclService; private ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy = new ObjectIdentityRetrievalStrategyImpl(); + private SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); + private String internalMethod; - private String processConfigAttribute; - private List requirePermission; - // ~ Constructors - // =================================================================================================== - - public AclEntryVoter(AclService aclService, String processConfigAttribute, - Permission[] requirePermission) { + public AclEntryVoter(AclService aclService, String processConfigAttribute, Permission[] requirePermission) { Assert.notNull(processConfigAttribute, "A processConfigAttribute is mandatory"); Assert.notNull(aclService, "An AclService is mandatory"); - - if ((requirePermission == null) || (requirePermission.length == 0)) { - throw new IllegalArgumentException( - "One or more requirePermission entries is mandatory"); - } - + Assert.isTrue(!ObjectUtils.isEmpty(requirePermission), "One or more requirePermission entries is mandatory"); this.aclService = aclService; this.processConfigAttribute = processConfigAttribute; this.requirePermission = Arrays.asList(requirePermission); } - // ~ Methods - // ======================================================================================================== - /** * Optionally specifies a method of the domain object that will be used to obtain a * contained domain object. That contained domain object will be used for the ACL * evaluation. This is useful if a domain object contains a parent that an ACL * evaluation should be targeted for, instead of the child domain object (which * perhaps is being created and as such does not yet have any ACL permissions) - * * @return null to use the domain object, or the name of a method (that * requires no arguments) that should be invoked to obtain an Object * which will be the domain object used for ACL evaluation */ protected String getInternalMethod() { - return internalMethod; + return this.internalMethod; } public void setInternalMethod(String internalMethod) { @@ -152,13 +141,11 @@ public class AclEntryVoter extends AbstractAclVoter { } protected String getProcessConfigAttribute() { - return processConfigAttribute; + return this.processConfigAttribute; } - public void setObjectIdentityRetrievalStrategy( - ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { - Assert.notNull(objectIdentityRetrievalStrategy, - "ObjectIdentityRetrievalStrategy required"); + public void setObjectIdentityRetrievalStrategy(ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { + Assert.notNull(objectIdentityRetrievalStrategy, "ObjectIdentityRetrievalStrategy required"); this.objectIdentityRetrievalStrategy = objectIdentityRetrievalStrategy; } @@ -167,103 +154,60 @@ public class AclEntryVoter extends AbstractAclVoter { this.sidRetrievalStrategy = sidRetrievalStrategy; } + @Override public boolean supports(ConfigAttribute attribute) { - return (attribute.getAttribute() != null) - && attribute.getAttribute().equals(getProcessConfigAttribute()); + return (attribute.getAttribute() != null) && attribute.getAttribute().equals(getProcessConfigAttribute()); } - public int vote(Authentication authentication, MethodInvocation object, - Collection attributes) { - + @Override + public int vote(Authentication authentication, MethodInvocation object, Collection attributes) { for (ConfigAttribute attr : attributes) { - - if (!this.supports(attr)) { + if (!supports(attr)) { continue; } + // Need to make an access decision on this invocation // Attempt to locate the domain object instance to process Object domainObject = getDomainObjectInstance(object); // If domain object is null, vote to abstain if (domainObject == null) { - if (logger.isDebugEnabled()) { - logger.debug("Voting to abstain - domainObject is null"); - } - + logger.debug("Voting to abstain - domainObject is null"); return ACCESS_ABSTAIN; } // Evaluate if we are required to use an inner domain object - if (StringUtils.hasText(internalMethod)) { - try { - Class clazz = domainObject.getClass(); - Method method = clazz.getMethod(internalMethod, new Class[0]); - domainObject = method.invoke(domainObject); - } - catch (NoSuchMethodException nsme) { - throw new AuthorizationServiceException("Object of class '" - + domainObject.getClass() - + "' does not provide the requested internalMethod: " - + internalMethod); - } - catch (IllegalAccessException iae) { - logger.debug("IllegalAccessException", iae); - - throw new AuthorizationServiceException( - "Problem invoking internalMethod: " + internalMethod - + " for object: " + domainObject); - } - catch (InvocationTargetException ite) { - logger.debug("InvocationTargetException", ite); - - throw new AuthorizationServiceException( - "Problem invoking internalMethod: " + internalMethod - + " for object: " + domainObject); - } + if (StringUtils.hasText(this.internalMethod)) { + domainObject = invokeInternalMethod(domainObject); } // Obtain the OID applicable to the domain object - ObjectIdentity objectIdentity = objectIdentityRetrievalStrategy - .getObjectIdentity(domainObject); + ObjectIdentity objectIdentity = this.objectIdentityRetrievalStrategy.getObjectIdentity(domainObject); // Obtain the SIDs applicable to the principal - List sids = sidRetrievalStrategy.getSids(authentication); + List sids = this.sidRetrievalStrategy.getSids(authentication); Acl acl; try { // Lookup only ACLs for SIDs we're interested in - acl = aclService.readAclById(objectIdentity, sids); + acl = this.aclService.readAclById(objectIdentity, sids); } - catch (NotFoundException nfe) { - if (logger.isDebugEnabled()) { - logger.debug("Voting to deny access - no ACLs apply for this principal"); - } - + catch (NotFoundException ex) { + logger.debug("Voting to deny access - no ACLs apply for this principal"); return ACCESS_DENIED; } try { - if (acl.isGranted(requirePermission, sids, false)) { - if (logger.isDebugEnabled()) { - logger.debug("Voting to grant access"); - } - + if (acl.isGranted(this.requirePermission, sids, false)) { + logger.debug("Voting to grant access"); return ACCESS_GRANTED; } - else { - if (logger.isDebugEnabled()) { - logger.debug("Voting to deny access - ACLs returned, but insufficient permissions for this principal"); - } - - return ACCESS_DENIED; - } + logger.debug("Voting to deny access - ACLs returned, but insufficient permissions for this principal"); + return ACCESS_DENIED; } - catch (NotFoundException nfe) { - if (logger.isDebugEnabled()) { - logger.debug("Voting to deny access - no ACLs apply for this principal"); - } - + catch (NotFoundException ex) { + logger.debug("Voting to deny access - no ACLs apply for this principal"); return ACCESS_DENIED; } } @@ -271,4 +215,27 @@ public class AclEntryVoter extends AbstractAclVoter { // No configuration attribute matched, so abstain return ACCESS_ABSTAIN; } + + private Object invokeInternalMethod(Object domainObject) { + try { + Class domainObjectType = domainObject.getClass(); + Method method = domainObjectType.getMethod(this.internalMethod, new Class[0]); + return method.invoke(domainObject); + } + catch (NoSuchMethodException ex) { + throw new AuthorizationServiceException("Object of class '" + domainObject.getClass() + + "' does not provide the requested internalMethod: " + this.internalMethod); + } + catch (IllegalAccessException ex) { + logger.debug("IllegalAccessException", ex); + throw new AuthorizationServiceException( + "Problem invoking internalMethod: " + this.internalMethod + " for object: " + domainObject); + } + catch (InvocationTargetException ex) { + logger.debug("InvocationTargetException", ex); + throw new AuthorizationServiceException( + "Problem invoking internalMethod: " + this.internalMethod + " for object: " + domainObject); + } + } + } diff --git a/acl/src/main/java/org/springframework/security/acls/AclPermissionCacheOptimizer.java b/acl/src/main/java/org/springframework/security/acls/AclPermissionCacheOptimizer.java index de440a5115..663dcbf174 100644 --- a/acl/src/main/java/org/springframework/security/acls/AclPermissionCacheOptimizer.java +++ b/acl/src/main/java/org/springframework/security/acls/AclPermissionCacheOptimizer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; import java.util.ArrayList; @@ -21,6 +22,8 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.access.PermissionCacheOptimizer; import org.springframework.security.acls.domain.ObjectIdentityRetrievalStrategyImpl; import org.springframework.security.acls.domain.SidRetrievalStrategyImpl; @@ -38,45 +41,42 @@ import org.springframework.security.core.Authentication; * @since 3.1 */ public class AclPermissionCacheOptimizer implements PermissionCacheOptimizer { + private final Log logger = LogFactory.getLog(getClass()); + private final AclService aclService; + private SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); + private ObjectIdentityRetrievalStrategy oidRetrievalStrategy = new ObjectIdentityRetrievalStrategyImpl(); public AclPermissionCacheOptimizer(AclService aclService) { this.aclService = aclService; } + @Override public void cachePermissionsFor(Authentication authentication, Collection objects) { if (objects.isEmpty()) { return; } - List oidsToCache = new ArrayList<>(objects.size()); - for (Object domainObject : objects) { - if (domainObject == null) { - continue; + if (domainObject != null) { + ObjectIdentity oid = this.oidRetrievalStrategy.getObjectIdentity(domainObject); + oidsToCache.add(oid); } - ObjectIdentity oid = oidRetrievalStrategy.getObjectIdentity(domainObject); - oidsToCache.add(oid); } - - List sids = sidRetrievalStrategy.getSids(authentication); - - if (logger.isDebugEnabled()) { - logger.debug("Eagerly loading Acls for " + oidsToCache.size() + " objects"); - } - - aclService.readAclsById(oidsToCache, sids); + List sids = this.sidRetrievalStrategy.getSids(authentication); + this.logger.debug(LogMessage.of(() -> "Eagerly loading Acls for " + oidsToCache.size() + " objects")); + this.aclService.readAclsById(oidsToCache, sids); } - public void setObjectIdentityRetrievalStrategy( - ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { + public void setObjectIdentityRetrievalStrategy(ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { this.oidRetrievalStrategy = objectIdentityRetrievalStrategy; } public void setSidRetrievalStrategy(SidRetrievalStrategy sidRetrievalStrategy) { this.sidRetrievalStrategy = sidRetrievalStrategy; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/AclPermissionEvaluator.java b/acl/src/main/java/org/springframework/security/acls/AclPermissionEvaluator.java index b6d362a305..15907a22ca 100644 --- a/acl/src/main/java/org/springframework/security/acls/AclPermissionEvaluator.java +++ b/acl/src/main/java/org/springframework/security/acls/AclPermissionEvaluator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; import java.io.Serializable; @@ -22,6 +23,8 @@ import java.util.Locale; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.acls.domain.DefaultPermissionFactory; import org.springframework.security.acls.domain.ObjectIdentityRetrievalStrategyImpl; @@ -51,9 +54,13 @@ public class AclPermissionEvaluator implements PermissionEvaluator { private final Log logger = LogFactory.getLog(getClass()); private final AclService aclService; + private ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy = new ObjectIdentityRetrievalStrategyImpl(); + private ObjectIdentityGenerator objectIdentityGenerator = new ObjectIdentityRetrievalStrategyImpl(); + private SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); + private PermissionFactory permissionFactory = new DefaultPermissionFactory(); public AclPermissionEvaluator(AclService aclService) { @@ -65,100 +72,72 @@ public class AclPermissionEvaluator implements PermissionEvaluator { * the ACL configuration. If the domain object is null, returns false (this can always * be overridden using a null check in the expression itself). */ - public boolean hasPermission(Authentication authentication, Object domainObject, - Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Object domainObject, Object permission) { if (domainObject == null) { return false; } - - ObjectIdentity objectIdentity = objectIdentityRetrievalStrategy - .getObjectIdentity(domainObject); - + ObjectIdentity objectIdentity = this.objectIdentityRetrievalStrategy.getObjectIdentity(domainObject); return checkPermission(authentication, objectIdentity, permission); } - public boolean hasPermission(Authentication authentication, Serializable targetId, - String targetType, Object permission) { - ObjectIdentity objectIdentity = objectIdentityGenerator.createObjectIdentity( - targetId, targetType); - - return checkPermission(authentication, objectIdentity, permission); - } - - private boolean checkPermission(Authentication authentication, ObjectIdentity oid, + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, Object permission) { + ObjectIdentity objectIdentity = this.objectIdentityGenerator.createObjectIdentity(targetId, targetType); + return checkPermission(authentication, objectIdentity, permission); + } + + private boolean checkPermission(Authentication authentication, ObjectIdentity oid, Object permission) { // Obtain the SIDs applicable to the principal - List sids = sidRetrievalStrategy.getSids(authentication); + List sids = this.sidRetrievalStrategy.getSids(authentication); List requiredPermission = resolvePermission(permission); - - final boolean debug = logger.isDebugEnabled(); - - if (debug) { - logger.debug("Checking permission '" + permission + "' for object '" + oid - + "'"); - } - + this.logger.debug(LogMessage.of(() -> "Checking permission '" + permission + "' for object '" + oid + "'")); try { // Lookup only ACLs for SIDs we're interested in - Acl acl = aclService.readAclById(oid, sids); - + Acl acl = this.aclService.readAclById(oid, sids); if (acl.isGranted(requiredPermission, sids, false)) { - if (debug) { - logger.debug("Access is granted"); - } - + this.logger.debug("Access is granted"); return true; } - - if (debug) { - logger.debug("Returning false - ACLs returned, but insufficient permissions for this principal"); - } - + this.logger.debug("Returning false - ACLs returned, but insufficient permissions for this principal"); } catch (NotFoundException nfe) { - if (debug) { - logger.debug("Returning false - no ACLs apply for this principal"); - } + this.logger.debug("Returning false - no ACLs apply for this principal"); } - return false; - } List resolvePermission(Object permission) { if (permission instanceof Integer) { - return Arrays.asList(permissionFactory.buildFromMask((Integer) permission)); + return Arrays.asList(this.permissionFactory.buildFromMask((Integer) permission)); } - if (permission instanceof Permission) { return Arrays.asList((Permission) permission); } - if (permission instanceof Permission[]) { return Arrays.asList((Permission[]) permission); } - if (permission instanceof String) { String permString = (String) permission; - Permission p; - - try { - p = permissionFactory.buildFromName(permString); - } - catch (IllegalArgumentException notfound) { - p = permissionFactory.buildFromName(permString.toUpperCase(Locale.ENGLISH)); - } - + Permission p = buildPermission(permString); if (p != null) { return Arrays.asList(p); } - } throw new IllegalArgumentException("Unsupported permission: " + permission); } - public void setObjectIdentityRetrievalStrategy( - ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { + private Permission buildPermission(String permString) { + try { + return this.permissionFactory.buildFromName(permString); + } + catch (IllegalArgumentException notfound) { + return this.permissionFactory.buildFromName(permString.toUpperCase(Locale.ENGLISH)); + } + } + + public void setObjectIdentityRetrievalStrategy(ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { this.objectIdentityRetrievalStrategy = objectIdentityRetrievalStrategy; } @@ -173,4 +152,5 @@ public class AclPermissionEvaluator implements PermissionEvaluator { public void setPermissionFactory(PermissionFactory permissionFactory) { this.permissionFactory = permissionFactory; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/AbstractAclProvider.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/AbstractAclProvider.java index 4e7db49438..14fb8730d7 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/AbstractAclProvider.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/AbstractAclProvider.java @@ -32,6 +32,7 @@ import org.springframework.security.acls.model.Sid; import org.springframework.security.acls.model.SidRetrievalStrategy; import org.springframework.security.core.Authentication; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * Abstract {@link AfterInvocationProvider} which provides commonly-used ACL-related @@ -40,64 +41,52 @@ import org.springframework.util.Assert; * @author Ben Alex */ public abstract class AbstractAclProvider implements AfterInvocationProvider { - // ~ Instance fields - // ================================================================================================ protected final AclService aclService; - protected Class processDomainObjectClass = Object.class; - protected ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy = new ObjectIdentityRetrievalStrategyImpl(); - protected SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); - protected String processConfigAttribute; - protected final List requirePermission; - // ~ Constructors - // =================================================================================================== + protected String processConfigAttribute; + + protected Class processDomainObjectClass = Object.class; + + protected ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy = new ObjectIdentityRetrievalStrategyImpl(); + + protected SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); + + protected final List requirePermission; public AbstractAclProvider(AclService aclService, String processConfigAttribute, List requirePermission) { Assert.hasText(processConfigAttribute, "A processConfigAttribute is mandatory"); Assert.notNull(aclService, "An AclService is mandatory"); - - if (requirePermission == null || requirePermission.isEmpty()) { - throw new IllegalArgumentException( - "One or more requirePermission entries is mandatory"); - } - + Assert.isTrue(!ObjectUtils.isEmpty(requirePermission), "One or more requirePermission entries is mandatory"); this.aclService = aclService; this.processConfigAttribute = processConfigAttribute; this.requirePermission = requirePermission; } - // ~ Methods - // ======================================================================================================== - protected Class getProcessDomainObjectClass() { - return processDomainObjectClass; + return this.processDomainObjectClass; } protected boolean hasPermission(Authentication authentication, Object domainObject) { // Obtain the OID applicable to the domain object - ObjectIdentity objectIdentity = objectIdentityRetrievalStrategy - .getObjectIdentity(domainObject); + ObjectIdentity objectIdentity = this.objectIdentityRetrievalStrategy.getObjectIdentity(domainObject); // Obtain the SIDs applicable to the principal - List sids = sidRetrievalStrategy.getSids(authentication); + List sids = this.sidRetrievalStrategy.getSids(authentication); try { // Lookup only ACLs for SIDs we're interested in - Acl acl = aclService.readAclById(objectIdentity, sids); - - return acl.isGranted(requirePermission, sids, false); + Acl acl = this.aclService.readAclById(objectIdentity, sids); + return acl.isGranted(this.requirePermission, sids, false); } - catch (NotFoundException ignore) { + catch (NotFoundException ex) { return false; } } - public void setObjectIdentityRetrievalStrategy( - ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { - Assert.notNull(objectIdentityRetrievalStrategy, - "ObjectIdentityRetrievalStrategy required"); + public void setObjectIdentityRetrievalStrategy(ObjectIdentityRetrievalStrategy objectIdentityRetrievalStrategy) { + Assert.notNull(objectIdentityRetrievalStrategy, "ObjectIdentityRetrievalStrategy required"); this.objectIdentityRetrievalStrategy = objectIdentityRetrievalStrategy; } @@ -107,8 +96,7 @@ public abstract class AbstractAclProvider implements AfterInvocationProvider { } public void setProcessDomainObjectClass(Class processDomainObjectClass) { - Assert.notNull(processDomainObjectClass, - "processDomainObjectClass cannot be set to null"); + Assert.notNull(processDomainObjectClass, "processDomainObjectClass cannot be set to null"); this.processDomainObjectClass = processDomainObjectClass; } @@ -117,19 +105,20 @@ public abstract class AbstractAclProvider implements AfterInvocationProvider { this.sidRetrievalStrategy = sidRetrievalStrategy; } + @Override public boolean supports(ConfigAttribute attribute) { - return processConfigAttribute.equals(attribute.getAttribute()); + return this.processConfigAttribute.equals(attribute.getAttribute()); } /** * This implementation supports any type of class, because it does not query the * presented secure object. - * * @param clazz the secure object - * * @return always true */ + @Override public boolean supports(Class clazz) { return true; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProvider.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProvider.java index 2e7e2a35ed..fb788322dc 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProvider.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.afterinvocation; import java.util.Collection; @@ -20,6 +21,8 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AuthorizationServiceException; import org.springframework.security.access.ConfigAttribute; @@ -60,33 +63,21 @@ import org.springframework.security.core.Authentication; * @author Ben Alex * @author Paulo Neves */ -public class AclEntryAfterInvocationCollectionFilteringProvider extends - AbstractAclProvider { - // ~ Static fields/initializers - // ===================================================================================== +public class AclEntryAfterInvocationCollectionFilteringProvider extends AbstractAclProvider { - protected static final Log logger = LogFactory - .getLog(AclEntryAfterInvocationCollectionFilteringProvider.class); - - // ~ Constructors - // =================================================================================================== + protected static final Log logger = LogFactory.getLog(AclEntryAfterInvocationCollectionFilteringProvider.class); public AclEntryAfterInvocationCollectionFilteringProvider(AclService aclService, List requirePermission) { super(aclService, "AFTER_ACL_COLLECTION_READ", requirePermission); } - // ~ Methods - // ======================================================================================================== - + @Override @SuppressWarnings("unchecked") - public Object decide(Authentication authentication, Object object, - Collection config, Object returnedObject) - throws AccessDeniedException { - + public Object decide(Authentication authentication, Object object, Collection config, + Object returnedObject) throws AccessDeniedException { if (returnedObject == null) { logger.debug("Return object is null, skipping"); - return null; } @@ -96,44 +87,34 @@ public class AclEntryAfterInvocationCollectionFilteringProvider extends } // Need to process the Collection for this invocation - Filterer filterer; - - if (returnedObject instanceof Collection) { - filterer = new CollectionFilterer((Collection) returnedObject); - } - else if (returnedObject.getClass().isArray()) { - filterer = new ArrayFilterer((Object[]) returnedObject); - } - else { - throw new AuthorizationServiceException( - "A Collection or an array (or null) was required as the " - + "returnedObject, but the returnedObject was: " - + returnedObject); - } + Filterer filterer = getFilterer(returnedObject); // Locate unauthorised Collection elements for (Object domainObject : filterer) { // Ignore nulls or entries which aren't instances of the configured domain // object class - if (domainObject == null - || !getProcessDomainObjectClass().isAssignableFrom( - domainObject.getClass())) { + if (domainObject == null || !getProcessDomainObjectClass().isAssignableFrom(domainObject.getClass())) { continue; } - if (!hasPermission(authentication, domainObject)) { filterer.remove(domainObject); - - if (logger.isDebugEnabled()) { - logger.debug("Principal is NOT authorised for element: " - + domainObject); - } + logger.debug(LogMessage.of(() -> "Principal is NOT authorised for element: " + domainObject)); } } - return filterer.getFilteredObject(); } - return returnedObject; } + + private Filterer getFilterer(Object returnedObject) { + if (returnedObject instanceof Collection) { + return new CollectionFilterer((Collection) returnedObject); + } + if (returnedObject.getClass().isArray()) { + return new ArrayFilterer((Object[]) returnedObject); + } + throw new AuthorizationServiceException("A Collection or an array (or null) was required as the " + + "returnedObject, but the returnedObject was: " + returnedObject); + } + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProvider.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProvider.java index 9a27f21572..7659cee298 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProvider.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.afterinvocation; import java.util.Collection; @@ -20,6 +21,7 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; @@ -58,50 +60,34 @@ import org.springframework.security.core.SpringSecurityMessageSource; *

* All comparisons and prefixes are case sensitive. */ -public class AclEntryAfterInvocationProvider extends AbstractAclProvider implements - MessageSourceAware { - // ~ Static fields/initializers - // ===================================================================================== +public class AclEntryAfterInvocationProvider extends AbstractAclProvider implements MessageSourceAware { - protected static final Log logger = LogFactory - .getLog(AclEntryAfterInvocationProvider.class); - - // ~ Instance fields - // ================================================================================================ + protected static final Log logger = LogFactory.getLog(AclEntryAfterInvocationProvider.class); protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); - // ~ Constructors - // =================================================================================================== - - public AclEntryAfterInvocationProvider(AclService aclService, - List requirePermission) { + public AclEntryAfterInvocationProvider(AclService aclService, List requirePermission) { this(aclService, "AFTER_ACL_READ", requirePermission); } - public AclEntryAfterInvocationProvider(AclService aclService, - String processConfigAttribute, List requirePermission) { + public AclEntryAfterInvocationProvider(AclService aclService, String processConfigAttribute, + List requirePermission) { super(aclService, processConfigAttribute, requirePermission); } - // ~ Methods - // ======================================================================================================== - - public Object decide(Authentication authentication, Object object, - Collection config, Object returnedObject) - throws AccessDeniedException { + @Override + public Object decide(Authentication authentication, Object object, Collection config, + Object returnedObject) throws AccessDeniedException { if (returnedObject == null) { // AclManager interface contract prohibits nulls // As they have permission to null/nothing, grant access logger.debug("Return object is null, skipping"); - return null; } if (!getProcessDomainObjectClass().isAssignableFrom(returnedObject.getClass())) { logger.debug("Return object is not applicable for this provider, skipping"); - return returnedObject; } @@ -109,24 +95,24 @@ public class AclEntryAfterInvocationProvider extends AbstractAclProvider impleme if (!this.supports(attr)) { continue; } - // Need to make an access decision on this invocation + // Need to make an access decision on this invocation if (hasPermission(authentication, returnedObject)) { return returnedObject; } logger.debug("Denying access"); - - throw new AccessDeniedException(messages.getMessage( - "AclEntryAfterInvocationProvider.noPermission", new Object[] { - authentication.getName(), returnedObject }, + throw new AccessDeniedException(this.messages.getMessage("AclEntryAfterInvocationProvider.noPermission", + new Object[] { authentication.getName(), returnedObject }, "Authentication {0} has NO permissions to the domain object {1}")); } return returnedObject; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/ArrayFilterer.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/ArrayFilterer.java index 1ffc3c7049..c9c9d01fd5 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/ArrayFilterer.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/ArrayFilterer.java @@ -25,6 +25,8 @@ import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; + /** * A filter used to filter arrays. * @@ -32,91 +34,70 @@ import org.apache.commons.logging.LogFactory; * @author Paulo Neves */ class ArrayFilterer implements Filterer { - // ~ Static fields/initializers - // ===================================================================================== protected static final Log logger = LogFactory.getLog(ArrayFilterer.class); - // ~ Instance fields - // ================================================================================================ - private final Set removeList; - private final T[] list; - // ~ Constructors - // =================================================================================================== + private final T[] list; ArrayFilterer(T[] list) { this.list = list; - // Collect the removed objects to a HashSet so that // it is fast to lookup them when a filtered array // is constructed. - removeList = new HashSet<>(); + this.removeList = new HashSet<>(); } - // ~ Methods - // ======================================================================================================== - - /** - * - * @see org.springframework.security.acls.afterinvocation.Filterer#getFilteredObject() - */ + @Override @SuppressWarnings("unchecked") public T[] getFilteredObject() { // Recreate an array of same type and filter the removed objects. - int originalSize = list.length; - int sizeOfResultingList = originalSize - removeList.size(); - T[] filtered = (T[]) Array.newInstance(list.getClass().getComponentType(), - sizeOfResultingList); - - for (int i = 0, j = 0; i < list.length; i++) { - T object = list[i]; - - if (!removeList.contains(object)) { + int originalSize = this.list.length; + int sizeOfResultingList = originalSize - this.removeList.size(); + T[] filtered = (T[]) Array.newInstance(this.list.getClass().getComponentType(), sizeOfResultingList); + for (int i = 0, j = 0; i < this.list.length; i++) { + T object = this.list[i]; + if (!this.removeList.contains(object)) { filtered[j] = object; j++; } } - - if (logger.isDebugEnabled()) { - logger.debug("Original array contained " + originalSize - + " elements; now contains " + sizeOfResultingList + " elements"); - } - + logger.debug(LogMessage.of(() -> "Original array contained " + originalSize + " elements; now contains " + + sizeOfResultingList + " elements")); return filtered; } - /** - * - * @see org.springframework.security.acls.afterinvocation.Filterer#iterator() - */ + @Override public Iterator iterator() { - return new Iterator() { - private int index = 0; + return new ArrayFiltererIterator(); + } - public boolean hasNext() { - return index < list.length; - } - - public T next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - return list[index++]; - } - - public void remove() { - throw new UnsupportedOperationException(); - } - }; + @Override + public void remove(T object) { + this.removeList.add(object); } /** - * - * @see org.springframework.security.acls.afterinvocation.Filterer#remove(java.lang.Object) + * Iterator for {@link ArrayFilterer} elements. */ - public void remove(T object) { - removeList.add(object); + private class ArrayFiltererIterator implements Iterator { + + private int index = 0; + + @Override + public boolean hasNext() { + return this.index < ArrayFilterer.this.list.length; + } + + @Override + public T next() { + if (hasNext()) { + return ArrayFilterer.this.list[this.index++]; + } + throw new NoSuchElementException(); + } + } + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/CollectionFilterer.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/CollectionFilterer.java index b937dd32d0..8322a9c1aa 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/CollectionFilterer.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/CollectionFilterer.java @@ -16,14 +16,16 @@ package org.springframework.security.acls.afterinvocation; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - import java.util.Collection; import java.util.HashSet; import java.util.Iterator; import java.util.Set; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; + /** * A filter used to filter Collections. * @@ -31,24 +33,15 @@ import java.util.Set; * @author Paulo Neves */ class CollectionFilterer implements Filterer { - // ~ Static fields/initializers - // ===================================================================================== protected static final Log logger = LogFactory.getLog(CollectionFilterer.class); - // ~ Instance fields - // ================================================================================================ - private final Collection collection; private final Set removeList; - // ~ Constructors - // =================================================================================================== - CollectionFilterer(Collection collection) { this.collection = collection; - // We create a Set of objects to be removed from the Collection, // as ConcurrentModificationException prevents removal during // iteration, and making a new Collection to be returned is @@ -56,47 +49,30 @@ class CollectionFilterer implements Filterer { // to the method may not necessarily be re-constructable (as // the Collection(collection) constructor is not guaranteed and // manually adding may lose sort order or other capabilities) - removeList = new HashSet<>(); + this.removeList = new HashSet<>(); } - // ~ Methods - // ======================================================================================================== - - /** - * - * @see org.springframework.security.acls.afterinvocation.Filterer#getFilteredObject() - */ + @Override public Object getFilteredObject() { // Now the Iterator has ended, remove Objects from Collection - Iterator removeIter = removeList.iterator(); - - int originalSize = collection.size(); - + Iterator removeIter = this.removeList.iterator(); + int originalSize = this.collection.size(); while (removeIter.hasNext()) { - collection.remove(removeIter.next()); + this.collection.remove(removeIter.next()); } - - if (logger.isDebugEnabled()) { - logger.debug("Original collection contained " + originalSize - + " elements; now contains " + collection.size() + " elements"); - } - - return collection; + logger.debug(LogMessage.of(() -> "Original collection contained " + originalSize + " elements; now contains " + + this.collection.size() + " elements")); + return this.collection; } - /** - * - * @see org.springframework.security.acls.afterinvocation.Filterer#iterator() - */ + @Override public Iterator iterator() { - return collection.iterator(); + return this.collection.iterator(); } - /** - * - * @see org.springframework.security.acls.afterinvocation.Filterer#remove(java.lang.Object) - */ + @Override public void remove(T object) { - removeList.add(object); + this.removeList.add(object); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/Filterer.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/Filterer.java index 2f9b35b575..f41bfa0bc7 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/Filterer.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/Filterer.java @@ -25,27 +25,24 @@ import java.util.Iterator; * @author Paulo Neves */ interface Filterer extends Iterable { - // ~ Methods - // ======================================================================================================== /** * Gets the filtered collection or array. - * * @return the filtered collection or array */ Object getFilteredObject(); /** * Returns an iterator over the filtered collection or array. - * * @return an Iterator */ + @Override Iterator iterator(); /** * Removes the given object from the resulting list. - * * @param object the object to be removed */ void remove(T object); + } diff --git a/acl/src/main/java/org/springframework/security/acls/afterinvocation/package-info.java b/acl/src/main/java/org/springframework/security/acls/afterinvocation/package-info.java index ab6c3e4251..4f4d7e647a 100644 --- a/acl/src/main/java/org/springframework/security/acls/afterinvocation/package-info.java +++ b/acl/src/main/java/org/springframework/security/acls/afterinvocation/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * After-invocation providers for collection and array filtering. Consider using a {@code PostFilter} annotation in - * preference. + * After-invocation providers for collection and array filtering. Consider using a + * {@code PostFilter} annotation in preference. */ package org.springframework.security.acls.afterinvocation; - diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AbstractPermission.java b/acl/src/main/java/org/springframework/security/acls/domain/AbstractPermission.java index ce679f4b49..147f643966 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AbstractPermission.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AbstractPermission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Permission; @@ -25,18 +26,13 @@ import org.springframework.security.acls.model.Permission; */ public abstract class AbstractPermission implements Permission { - // ~ Instance fields - // ================================================================================================ - protected final char code; + protected int mask; - // ~ Constructors - // =================================================================================================== /** * Sets the permission mask and uses the '*' character to represent active bits when * represented as a bit pattern string. - * * @param mask the integer bit mask for the permission */ protected AbstractPermission(int mask) { @@ -46,7 +42,6 @@ public abstract class AbstractPermission implements Permission { /** * Sets the permission mask and uses the specified character for active bits. - * * @param mask the integer bit mask for the permission * @param code the character to print for each active bit in the mask (see * {@link Permission#getPattern()}) @@ -56,36 +51,36 @@ public abstract class AbstractPermission implements Permission { this.code = code; } - // ~ Methods - // ======================================================================================================== - - public final boolean equals(Object arg0) { - if (arg0 == null) { + @Override + public final boolean equals(Object obj) { + if (obj == null) { return false; } - - if (!(arg0 instanceof Permission)) { + if (!(obj instanceof Permission)) { return false; } - - Permission rhs = (Permission) arg0; - - return (this.mask == rhs.getMask()); - } - - public final int getMask() { - return mask; - } - - public String getPattern() { - return AclFormattingUtils.printBinary(mask, code); - } - - public final String toString() { - return this.getClass().getSimpleName() + "[" + getPattern() + "=" + mask + "]"; + Permission other = (Permission) obj; + return (this.mask == other.getMask()); } + @Override public final int hashCode() { return this.mask; } + + @Override + public final String toString() { + return this.getClass().getSimpleName() + "[" + getPattern() + "=" + this.mask + "]"; + } + + @Override + public final int getMask() { + return this.mask; + } + + @Override + public String getPattern() { + return AclFormattingUtils.printBinary(this.mask, this.code); + } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AccessControlEntryImpl.java b/acl/src/main/java/org/springframework/security/acls/domain/AccessControlEntryImpl.java index fb41f101d6..dfc87a5f1f 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AccessControlEntryImpl.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AccessControlEntryImpl.java @@ -13,42 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; +import java.io.Serializable; + import org.springframework.security.acls.model.AccessControlEntry; import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.AuditableAccessControlEntry; import org.springframework.security.acls.model.Permission; import org.springframework.security.acls.model.Sid; - import org.springframework.util.Assert; -import java.io.Serializable; - /** * An immutable default implementation of AccessControlEntry. * * @author Ben Alex */ -public class AccessControlEntryImpl implements AccessControlEntry, - AuditableAccessControlEntry { - // ~ Instance fields - // ================================================================================================ +public class AccessControlEntryImpl implements AccessControlEntry, AuditableAccessControlEntry { private final Acl acl; + private Permission permission; + private final Serializable id; + private final Sid sid; + private boolean auditFailure = false; + private boolean auditSuccess = false; + private final boolean granting; - // ~ Constructors - // =================================================================================================== - - public AccessControlEntryImpl(Serializable id, Acl acl, Sid sid, - Permission permission, boolean granting, boolean auditSuccess, - boolean auditFailure) { + public AccessControlEntryImpl(Serializable id, Acl acl, Sid sid, Permission permission, boolean granting, + boolean auditSuccess, boolean auditFailure) { Assert.notNull(acl, "Acl required"); Assert.notNull(sid, "Sid required"); Assert.notNull(permission, "Permission required"); @@ -61,78 +60,66 @@ public class AccessControlEntryImpl implements AccessControlEntry, this.auditFailure = auditFailure; } - // ~ Methods - // ======================================================================================================== - @Override public boolean equals(Object arg0) { if (!(arg0 instanceof AccessControlEntryImpl)) { return false; } - - AccessControlEntryImpl rhs = (AccessControlEntryImpl) arg0; - + AccessControlEntryImpl other = (AccessControlEntryImpl) arg0; if (this.acl == null) { - if (rhs.getAcl() != null) { + if (other.getAcl() != null) { return false; } // Both this.acl and rhs.acl are null and thus equal } else { // this.acl is non-null - if (rhs.getAcl() == null) { + if (other.getAcl() == null) { return false; } // Both this.acl and rhs.acl are non-null, so do a comparison if (this.acl.getObjectIdentity() == null) { - if (rhs.acl.getObjectIdentity() != null) { + if (other.acl.getObjectIdentity() != null) { return false; } // Both this.acl and rhs.acl are null and thus equal } else { // Both this.acl.objectIdentity and rhs.acl.objectIdentity are non-null - if (!this.acl.getObjectIdentity() - .equals(rhs.getAcl().getObjectIdentity())) { + if (!this.acl.getObjectIdentity().equals(other.getAcl().getObjectIdentity())) { return false; } } } - if (this.id == null) { - if (rhs.id != null) { + if (other.id != null) { return false; } // Both this.id and rhs.id are null and thus equal } else { // this.id is non-null - if (rhs.id == null) { + if (other.id == null) { return false; } - // Both this.id and rhs.id are non-null - if (!this.id.equals(rhs.id)) { + if (!this.id.equals(other.id)) { return false; } } - - if ((this.auditFailure != rhs.isAuditFailure()) - || (this.auditSuccess != rhs.isAuditSuccess()) - || (this.granting != rhs.isGranting()) - || !this.permission.equals(rhs.getPermission()) - || !this.sid.equals(rhs.getSid())) { + if ((this.auditFailure != other.isAuditFailure()) || (this.auditSuccess != other.isAuditSuccess()) + || (this.granting != other.isGranting()) || !this.permission.equals(other.getPermission()) + || !this.sid.equals(other.getSid())) { return false; } - return true; } @Override public int hashCode() { int result = this.permission.hashCode(); - result = 31 * result + (this.id != null ? this.id.hashCode() : 0); + result = 31 * result + ((this.id != null) ? this.id.hashCode() : 0); result = 31 * result + (this.sid.hashCode()); result = 31 * result + (this.auditFailure ? 1 : 0); result = 31 * result + (this.auditSuccess ? 1 : 0); @@ -142,37 +129,37 @@ public class AccessControlEntryImpl implements AccessControlEntry, @Override public Acl getAcl() { - return acl; + return this.acl; } @Override public Serializable getId() { - return id; + return this.id; } @Override public Permission getPermission() { - return permission; + return this.permission; } @Override public Sid getSid() { - return sid; + return this.sid; } @Override public boolean isAuditFailure() { - return auditFailure; + return this.auditFailure; } @Override public boolean isAuditSuccess() { - return auditSuccess; + return this.auditSuccess; } @Override public boolean isGranting() { - return granting; + return this.granting; } void setAuditFailure(boolean auditFailure) { @@ -199,7 +186,7 @@ public class AccessControlEntryImpl implements AccessControlEntry, sb.append("auditSuccess: ").append(this.auditSuccess).append("; "); sb.append("auditFailure: ").append(this.auditFailure); sb.append("]"); - return sb.toString(); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategy.java b/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategy.java index 0ac6d86fd5..fa243b0fcf 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategy.java @@ -25,15 +25,13 @@ import org.springframework.security.acls.model.Acl; * @author Ben Alex */ public interface AclAuthorizationStrategy { - // ~ Static fields/initializers - // ===================================================================================== int CHANGE_OWNERSHIP = 0; + int CHANGE_AUDITING = 1; + int CHANGE_GENERAL = 2; - // ~ Methods - // ======================================================================================================== - void securityCheck(Acl acl, int changeType); + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImpl.java b/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImpl.java index 47693fc023..34e62babb5 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImpl.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImpl.java @@ -16,6 +16,10 @@ package org.springframework.security.acls.domain; +import java.util.Arrays; +import java.util.List; +import java.util.Set; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.Sid; @@ -26,10 +30,6 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; -import java.util.Arrays; -import java.util.List; -import java.util.Set; - /** * Default implementation of {@link AclAuthorizationStrategy}. *

@@ -45,21 +45,18 @@ import java.util.Set; * @author Ben Alex */ public class AclAuthorizationStrategyImpl implements AclAuthorizationStrategy { - // ~ Instance fields - // ================================================================================================ private final GrantedAuthority gaGeneralChanges; - private final GrantedAuthority gaModifyAuditing; - private final GrantedAuthority gaTakeOwnership; - private SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); - // ~ Constructors - // =================================================================================================== + private final GrantedAuthority gaModifyAuditing; + + private final GrantedAuthority gaTakeOwnership; + + private SidRetrievalStrategy sidRetrievalStrategy = new SidRetrievalStrategyImpl(); /** * Constructor. The only mandatory parameter relates to the system-wide * {@link GrantedAuthority} instances that can be held to always permit ACL changes. - * * @param auths the GrantedAuthoritys that have special permissions * (index 0 is the authority needed to change ownership, index 1 is the authority * needed to modify auditing details, index 2 is the authority needed to change other @@ -71,53 +68,33 @@ public class AclAuthorizationStrategyImpl implements AclAuthorizationStrategy { Assert.isTrue(auths != null && (auths.length == 3 || auths.length == 1), "One or three GrantedAuthority instances required"); if (auths.length == 3) { - gaTakeOwnership = auths[0]; - gaModifyAuditing = auths[1]; - gaGeneralChanges = auths[2]; + this.gaTakeOwnership = auths[0]; + this.gaModifyAuditing = auths[1]; + this.gaGeneralChanges = auths[2]; } else { - gaTakeOwnership = gaModifyAuditing = gaGeneralChanges = auths[0]; + this.gaTakeOwnership = auths[0]; + this.gaModifyAuditing = auths[0]; + this.gaGeneralChanges = auths[0]; } } - // ~ Methods - // ======================================================================================================== - + @Override public void securityCheck(Acl acl, int changeType) { if ((SecurityContextHolder.getContext() == null) || (SecurityContextHolder.getContext().getAuthentication() == null) - || !SecurityContextHolder.getContext().getAuthentication() - .isAuthenticated()) { - throw new AccessDeniedException( - "Authenticated principal required to operate with ACLs"); + || !SecurityContextHolder.getContext().getAuthentication().isAuthenticated()) { + throw new AccessDeniedException("Authenticated principal required to operate with ACLs"); } - - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); - + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); // Check if authorized by virtue of ACL ownership Sid currentUser = createCurrentUser(authentication); - if (currentUser.equals(acl.getOwner()) && ((changeType == CHANGE_GENERAL) || (changeType == CHANGE_OWNERSHIP))) { return; } - // Not authorized by ACL ownership; try via adminstrative permissions - GrantedAuthority requiredAuthority; - - if (changeType == CHANGE_AUDITING) { - requiredAuthority = this.gaModifyAuditing; - } - else if (changeType == CHANGE_GENERAL) { - requiredAuthority = this.gaGeneralChanges; - } - else if (changeType == CHANGE_OWNERSHIP) { - requiredAuthority = this.gaTakeOwnership; - } - else { - throw new IllegalArgumentException("Unknown change type"); - } + GrantedAuthority requiredAuthority = getRequiredAuthority(changeType); // Iterate this principal's authorities to determine right Set authorities = AuthorityUtils.authorityListToSet(authentication.getAuthorities()); @@ -126,8 +103,7 @@ public class AclAuthorizationStrategyImpl implements AclAuthorizationStrategy { } // Try to get permission via ACEs within the ACL - List sids = sidRetrievalStrategy.getSids(authentication); - + List sids = this.sidRetrievalStrategy.getSids(authentication); if (acl.isGranted(Arrays.asList(BasePermission.ADMINISTRATION), sids, false)) { return; } @@ -136,9 +112,21 @@ public class AclAuthorizationStrategyImpl implements AclAuthorizationStrategy { "Principal does not have required ACL permissions to perform requested operation"); } + private GrantedAuthority getRequiredAuthority(int changeType) { + if (changeType == CHANGE_AUDITING) { + return this.gaModifyAuditing; + } + if (changeType == CHANGE_GENERAL) { + return this.gaGeneralChanges; + } + if (changeType == CHANGE_OWNERSHIP) { + return this.gaTakeOwnership; + } + throw new IllegalArgumentException("Unknown change type"); + } + /** * Creates a principal-like sid from the authentication information. - * * @param authentication the authentication information that can provide principal and * thus the sid's id will be dependant on the value inside * @return a sid with the ID taken from the authentication information @@ -151,4 +139,5 @@ public class AclAuthorizationStrategyImpl implements AclAuthorizationStrategy { Assert.notNull(sidRetrievalStrategy, "SidRetrievalStrategy required"); this.sidRetrievalStrategy = sidRetrievalStrategy; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AclFormattingUtils.java b/acl/src/main/java/org/springframework/security/acls/domain/AclFormattingUtils.java index c9809204ae..36bdb28a2d 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AclFormattingUtils.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AclFormattingUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Permission; @@ -30,9 +31,7 @@ public abstract class AclFormattingUtils { Assert.notNull(removeBits, "Bits To Remove string required"); Assert.isTrue(original.length() == removeBits.length(), "Original and Bits To Remove strings must be identical length"); - char[] replacement = new char[original.length()]; - for (int i = 0; i < original.length(); i++) { if (removeBits.charAt(i) == Permission.RESERVED_OFF) { replacement[i] = original.charAt(i); @@ -41,7 +40,6 @@ public abstract class AclFormattingUtils { replacement[i] = Permission.RESERVED_OFF; } } - return new String(replacement); } @@ -50,9 +48,7 @@ public abstract class AclFormattingUtils { Assert.notNull(extraBits, "Extra Bits string required"); Assert.isTrue(original.length() == extraBits.length(), "Original and Extra Bits strings must be identical length"); - char[] replacement = new char[extraBits.length()]; - for (int i = 0; i < extraBits.length(); i++) { if (extraBits.charAt(i) == Permission.RESERVED_OFF) { replacement[i] = original.charAt(i); @@ -61,7 +57,6 @@ public abstract class AclFormattingUtils { replacement[i] = extraBits.charAt(i); } } - return new String(replacement); } @@ -70,9 +65,7 @@ public abstract class AclFormattingUtils { * bit being denoted by character '*'. *

* Inactive bits will be denoted by character {@link Permission#RESERVED_OFF}. - * * @param i the integer bit mask to print the active bits for - * * @return a 32-character representation of the bit mask */ public static String printBinary(int i) { @@ -84,29 +77,23 @@ public abstract class AclFormattingUtils { * bit being denoted by the passed character. *

* Inactive bits will be denoted by character {@link Permission#RESERVED_OFF}. - * * @param mask the integer bit mask to print the active bits for * @param code the character to print when an active bit is detected - * * @return a 32-character representation of the bit mask */ public static String printBinary(int mask, char code) { - Assert.doesNotContain(Character.toString(code), - Character.toString(Permission.RESERVED_ON), + Assert.doesNotContain(Character.toString(code), Character.toString(Permission.RESERVED_ON), () -> Permission.RESERVED_ON + " is a reserved character code"); - Assert.doesNotContain(Character.toString(code), - Character.toString(Permission.RESERVED_OFF), + Assert.doesNotContain(Character.toString(code), Character.toString(Permission.RESERVED_OFF), () -> Permission.RESERVED_OFF + " is a reserved character code"); - - return printBinary(mask, Permission.RESERVED_ON, Permission.RESERVED_OFF) - .replace(Permission.RESERVED_ON, code); + return printBinary(mask, Permission.RESERVED_ON, Permission.RESERVED_OFF).replace(Permission.RESERVED_ON, code); } private static String printBinary(int i, char on, char off) { String s = Integer.toBinaryString(i); String pattern = Permission.THIRTY_TWO_RESERVED_OFF; String temp2 = pattern.substring(0, pattern.length() - s.length()) + s; - return temp2.replace('0', off).replace('1', on); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AclImpl.java b/acl/src/main/java/org/springframework/security/acls/domain/AclImpl.java index 472b33eaf4..5f7a0532a3 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AclImpl.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AclImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import java.io.Serializable; @@ -31,6 +32,7 @@ import org.springframework.security.acls.model.PermissionGrantingStrategy; import org.springframework.security.acls.model.Sid; import org.springframework.security.acls.model.UnloadedSidException; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * Base implementation of Acl. @@ -38,35 +40,38 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { - // ~ Instance fields - // ================================================================================================ private Acl parentAcl; - private transient AclAuthorizationStrategy aclAuthorizationStrategy; - private transient PermissionGrantingStrategy permissionGrantingStrategy; - private final List aces = new ArrayList<>(); - private ObjectIdentity objectIdentity; - private Serializable id; - private Sid owner; // OwnershipAcl - private List loadedSids = null; // includes all SIDs the WHERE clause covered, - // even if there was no ACE for a SID - private boolean entriesInheriting = true; - // ~ Constructors - // =================================================================================================== + private transient AclAuthorizationStrategy aclAuthorizationStrategy; + + private transient PermissionGrantingStrategy permissionGrantingStrategy; + + private final List aces = new ArrayList<>(); + + private ObjectIdentity objectIdentity; + + private Serializable id; + + // OwnershipAcl + private Sid owner; + + // includes all SIDs the WHERE clause covered, even if there was no ACE for a SID + private List loadedSids = null; + + private boolean entriesInheriting = true; /** * Minimal constructor, which should be used * {@link org.springframework.security.acls.model.MutableAclService#createAcl(ObjectIdentity)} * . - * * @param objectIdentity the object identity this ACL relates to (required) * @param id the primary key assigned to this ACL (required) * @param aclAuthorizationStrategy authorization strategy (required) * @param auditLogger audit logger (required) */ - public AclImpl(ObjectIdentity objectIdentity, Serializable id, - AclAuthorizationStrategy aclAuthorizationStrategy, AuditLogger auditLogger) { + public AclImpl(ObjectIdentity objectIdentity, Serializable id, AclAuthorizationStrategy aclAuthorizationStrategy, + AuditLogger auditLogger) { Assert.notNull(objectIdentity, "Object Identity required"); Assert.notNull(id, "Id required"); Assert.notNull(aclAuthorizationStrategy, "AclAuthorizationStrategy required"); @@ -74,14 +79,12 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { this.objectIdentity = objectIdentity; this.id = id; this.aclAuthorizationStrategy = aclAuthorizationStrategy; - this.permissionGrantingStrategy = new DefaultPermissionGrantingStrategy( - auditLogger); + this.permissionGrantingStrategy = new DefaultPermissionGrantingStrategy(auditLogger); } /** * Full constructor, which should be used by persistence tools that do not provide * field-level access features. - * * @param objectIdentity the object identity this ACL relates to * @param id the primary key assigned to this ACL * @param aclAuthorizationStrategy authorization strategy @@ -93,15 +96,13 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { * @param entriesInheriting if ACEs from the parent should inherit into this ACL * @param owner the owner (required) */ - public AclImpl(ObjectIdentity objectIdentity, Serializable id, - AclAuthorizationStrategy aclAuthorizationStrategy, - PermissionGrantingStrategy grantingStrategy, Acl parentAcl, - List loadedSids, boolean entriesInheriting, Sid owner) { + public AclImpl(ObjectIdentity objectIdentity, Serializable id, AclAuthorizationStrategy aclAuthorizationStrategy, + PermissionGrantingStrategy grantingStrategy, Acl parentAcl, List loadedSids, boolean entriesInheriting, + Sid owner) { Assert.notNull(objectIdentity, "Object Identity required"); Assert.notNull(id, "Id required"); Assert.notNull(aclAuthorizationStrategy, "AclAuthorizationStrategy required"); Assert.notNull(owner, "Owner required"); - this.objectIdentity = objectIdentity; this.id = id; this.aclAuthorizationStrategy = aclAuthorizationStrategy; @@ -120,16 +121,11 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { private AclImpl() { } - // ~ Methods - // ======================================================================================================== - @Override public void deleteAce(int aceIndex) throws NotFoundException { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_GENERAL); + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_GENERAL); verifyAceIndexExists(aceIndex); - - synchronized (aces) { + synchronized (this.aces) { this.aces.remove(aceIndex); } } @@ -139,32 +135,26 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { throw new NotFoundException("aceIndex must be greater than or equal to zero"); } if (aceIndex >= this.aces.size()) { - throw new NotFoundException( - "aceIndex must refer to an index of the AccessControlEntry list. " - + "List size is " + aces.size() + ", index was " + aceIndex); + throw new NotFoundException("aceIndex must refer to an index of the AccessControlEntry list. " + + "List size is " + this.aces.size() + ", index was " + aceIndex); } } @Override - public void insertAce(int atIndexLocation, Permission permission, Sid sid, - boolean granting) throws NotFoundException { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_GENERAL); + public void insertAce(int atIndexLocation, Permission permission, Sid sid, boolean granting) + throws NotFoundException { + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_GENERAL); Assert.notNull(permission, "Permission required"); Assert.notNull(sid, "Sid required"); if (atIndexLocation < 0) { - throw new NotFoundException( - "atIndexLocation must be greater than or equal to zero"); + throw new NotFoundException("atIndexLocation must be greater than or equal to zero"); } if (atIndexLocation > this.aces.size()) { throw new NotFoundException( "atIndexLocation must be less than or equal to the size of the AccessControlEntry collection"); } - - AccessControlEntryImpl ace = new AccessControlEntryImpl(null, this, sid, - permission, granting, false, false); - - synchronized (aces) { + AccessControlEntryImpl ace = new AccessControlEntryImpl(null, this, sid, permission, granting, false, false); + synchronized (this.aces) { this.aces.add(atIndexLocation, ace); } } @@ -173,7 +163,7 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { public List getEntries() { // Can safely return AccessControlEntry directly, as they're immutable outside the // ACL package - return new ArrayList<>(aces); + return new ArrayList<>(this.aces); } @Override @@ -183,33 +173,29 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { @Override public ObjectIdentity getObjectIdentity() { - return objectIdentity; + return this.objectIdentity; } @Override public boolean isEntriesInheriting() { - return entriesInheriting; + return this.entriesInheriting; } /** * Delegates to the {@link PermissionGrantingStrategy}. - * * @throws UnloadedSidException if the passed SIDs are unknown to this ACL because the * ACL was only loaded for a subset of SIDs * @see DefaultPermissionGrantingStrategy */ @Override - public boolean isGranted(List permission, List sids, - boolean administrativeMode) throws NotFoundException, UnloadedSidException { + public boolean isGranted(List permission, List sids, boolean administrativeMode) + throws NotFoundException, UnloadedSidException { Assert.notEmpty(permission, "Permissions required"); Assert.notEmpty(sids, "SIDs required"); - if (!this.isSidLoaded(sids)) { throw new UnloadedSidException("ACL was not loaded for one or more SID"); } - - return permissionGrantingStrategy.isGranted(this, permission, sids, - administrativeMode); + return this.permissionGrantingStrategy.isGranted(this, permission, sids, administrativeMode); } @Override @@ -223,16 +209,13 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { // This ACL applies to a SID subset only. Iterate to check it applies. for (Sid sid : sids) { boolean found = false; - - for (Sid loadedSid : loadedSids) { + for (Sid loadedSid : this.loadedSids) { if (sid.equals(loadedSid)) { // this SID is OK found = true; - break; // out of loadedSids for loop } } - if (!found) { return false; } @@ -243,15 +226,13 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { @Override public void setEntriesInheriting(boolean entriesInheriting) { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_GENERAL); + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_GENERAL); this.entriesInheriting = entriesInheriting; } @Override public void setOwner(Sid newOwner) { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_OWNERSHIP); Assert.notNull(newOwner, "Owner required"); this.owner = newOwner; } @@ -263,38 +244,32 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { @Override public void setParent(Acl newParent) { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_GENERAL); - Assert.isTrue(newParent == null || !newParent.equals(this), - "Cannot be the parent of yourself"); + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_GENERAL); + Assert.isTrue(newParent == null || !newParent.equals(this), "Cannot be the parent of yourself"); this.parentAcl = newParent; } @Override public Acl getParentAcl() { - return parentAcl; + return this.parentAcl; } @Override public void updateAce(int aceIndex, Permission permission) throws NotFoundException { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_GENERAL); + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_GENERAL); verifyAceIndexExists(aceIndex); - - synchronized (aces) { - AccessControlEntryImpl ace = (AccessControlEntryImpl) aces.get(aceIndex); + synchronized (this.aces) { + AccessControlEntryImpl ace = (AccessControlEntryImpl) this.aces.get(aceIndex); ace.setPermission(permission); } } @Override public void updateAuditing(int aceIndex, boolean auditSuccess, boolean auditFailure) { - aclAuthorizationStrategy.securityCheck(this, - AclAuthorizationStrategy.CHANGE_AUDITING); + this.aclAuthorizationStrategy.securityCheck(this, AclAuthorizationStrategy.CHANGE_AUDITING); verifyAceIndexExists(aceIndex); - - synchronized (aces) { - AccessControlEntryImpl ace = (AccessControlEntryImpl) aces.get(aceIndex); + synchronized (this.aces) { + AccessControlEntryImpl ace = (AccessControlEntryImpl) this.aces.get(aceIndex); ace.setAuditSuccess(auditSuccess); ace.setAuditFailure(auditFailure); } @@ -302,57 +277,35 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { @Override public boolean equals(Object obj) { - if (obj instanceof AclImpl) { - AclImpl rhs = (AclImpl) obj; - if (this.aces.equals(rhs.aces)) { - if ((this.parentAcl == null && rhs.parentAcl == null) - || (this.parentAcl != null && this.parentAcl - .equals(rhs.parentAcl))) { - if ((this.objectIdentity == null && rhs.objectIdentity == null) - || (this.objectIdentity != null && this.objectIdentity - .equals(rhs.objectIdentity))) { - if ((this.id == null && rhs.id == null) - || (this.id != null && this.id.equals(rhs.id))) { - if ((this.owner == null && rhs.owner == null) - || (this.owner != null && this.owner - .equals(rhs.owner))) { - if (this.entriesInheriting == rhs.entriesInheriting) { - if ((this.loadedSids == null && rhs.loadedSids == null)) { - return true; - } - if (this.loadedSids != null - && (this.loadedSids.size() == rhs.loadedSids - .size())) { - for (int i = 0; i < this.loadedSids.size(); i++) { - if (!this.loadedSids.get(i).equals( - rhs.loadedSids.get(i))) { - return false; - } - } - return true; - } - } - } - } - } - } - } + if (obj == this) { + return true; } - return false; + if (obj == null || !(obj instanceof AclImpl)) { + return false; + } + AclImpl other = (AclImpl) obj; + boolean result = true; + result = result && this.aces.equals(other.aces); + result = result && ObjectUtils.nullSafeEquals(this.parentAcl, other.parentAcl); + result = result && ObjectUtils.nullSafeEquals(this.objectIdentity, other.objectIdentity); + result = result && ObjectUtils.nullSafeEquals(this.id, other.id); + result = result && ObjectUtils.nullSafeEquals(this.owner, other.owner); + result = result && this.entriesInheriting == other.entriesInheriting; + result = result && ObjectUtils.nullSafeEquals(this.loadedSids, other.loadedSids); + return result; } @Override public int hashCode() { - int result = this.parentAcl != null ? this.parentAcl.hashCode() : 0; + int result = (this.parentAcl != null) ? this.parentAcl.hashCode() : 0; result = 31 * result + this.aclAuthorizationStrategy.hashCode(); - result = 31 * result + (this.permissionGrantingStrategy != null ? - this.permissionGrantingStrategy.hashCode() : - 0); - result = 31 * result + (this.aces != null ? this.aces.hashCode() : 0); + result = 31 * result + + ((this.permissionGrantingStrategy != null) ? this.permissionGrantingStrategy.hashCode() : 0); + result = 31 * result + ((this.aces != null) ? this.aces.hashCode() : 0); result = 31 * result + this.objectIdentity.hashCode(); result = 31 * result + this.id.hashCode(); - result = 31 * result + (this.owner != null ? this.owner.hashCode() : 0); - result = 31 * result + (this.loadedSids != null ? this.loadedSids.hashCode() : 0); + result = 31 * result + ((this.owner != null) ? this.owner.hashCode() : 0); + result = 31 * result + ((this.loadedSids != null) ? this.loadedSids.hashCode() : 0); result = 31 * result + (this.entriesInheriting ? 1 : 0); return result; } @@ -364,33 +317,23 @@ public class AclImpl implements Acl, MutableAcl, AuditableAcl, OwnershipAcl { sb.append("id: ").append(this.id).append("; "); sb.append("objectIdentity: ").append(this.objectIdentity).append("; "); sb.append("owner: ").append(this.owner).append("; "); - int count = 0; - - for (AccessControlEntry ace : aces) { + for (AccessControlEntry ace : this.aces) { count++; - if (count == 1) { sb.append("\n"); } - sb.append(ace).append("\n"); } - if (count == 0) { sb.append("no ACEs; "); } - sb.append("inheriting: ").append(this.entriesInheriting).append("; "); - sb.append("parent: ").append( - (this.parentAcl == null) ? "Null" : this.parentAcl.getObjectIdentity() - .toString()); + sb.append("parent: ").append((this.parentAcl == null) ? "Null" : this.parentAcl.getObjectIdentity().toString()); sb.append("; "); - sb.append("aclAuthorizationStrategy: ").append(this.aclAuthorizationStrategy) - .append("; "); + sb.append("aclAuthorizationStrategy: ").append(this.aclAuthorizationStrategy).append("; "); sb.append("permissionGrantingStrategy: ").append(this.permissionGrantingStrategy); sb.append("]"); - return sb.toString(); } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/AuditLogger.java b/acl/src/main/java/org/springframework/security/acls/domain/AuditLogger.java index ff5bb69733..20edb6aa23 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/AuditLogger.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/AuditLogger.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.AccessControlEntry; @@ -21,11 +22,9 @@ import org.springframework.security.acls.model.AccessControlEntry; * Used by AclImpl to log audit events. * * @author Ben Alex - * */ public interface AuditLogger { - // ~ Methods - // ======================================================================================================== void logIfNeeded(boolean granted, AccessControlEntry ace); + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/BasePermission.java b/acl/src/main/java/org/springframework/security/acls/domain/BasePermission.java index 02981215f8..5da94f08d4 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/BasePermission.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/BasePermission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Permission; @@ -28,10 +29,15 @@ import org.springframework.security.acls.model.Permission; * @author Ben Alex */ public class BasePermission extends AbstractPermission { + public static final Permission READ = new BasePermission(1 << 0, 'R'); // 1 + public static final Permission WRITE = new BasePermission(1 << 1, 'W'); // 2 + public static final Permission CREATE = new BasePermission(1 << 2, 'C'); // 4 + public static final Permission DELETE = new BasePermission(1 << 3, 'D'); // 8 + public static final Permission ADMINISTRATION = new BasePermission(1 << 4, 'A'); // 16 protected BasePermission(int mask) { @@ -41,4 +47,5 @@ public class BasePermission extends AbstractPermission { protected BasePermission(int mask, char code) { super(mask, code); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/ConsoleAuditLogger.java b/acl/src/main/java/org/springframework/security/acls/domain/ConsoleAuditLogger.java index 1ceac22e15..744ec34148 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/ConsoleAuditLogger.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/ConsoleAuditLogger.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.AccessControlEntry; import org.springframework.security.acls.model.AuditableAccessControlEntry; - import org.springframework.util.Assert; /** @@ -26,15 +26,12 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class ConsoleAuditLogger implements AuditLogger { - // ~ Methods - // ======================================================================================================== + @Override public void logIfNeeded(boolean granted, AccessControlEntry ace) { Assert.notNull(ace, "AccessControlEntry required"); - if (ace instanceof AuditableAccessControlEntry) { AuditableAccessControlEntry auditableAce = (AuditableAccessControlEntry) ace; - if (granted && auditableAce.isAuditSuccess()) { System.out.println("GRANTED due to ACE: " + ace); } @@ -43,4 +40,5 @@ public class ConsoleAuditLogger implements AuditLogger { } } } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/CumulativePermission.java b/acl/src/main/java/org/springframework/security/acls/domain/CumulativePermission.java index 49e2c55af4..819ce4e5f3 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/CumulativePermission.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/CumulativePermission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Permission; @@ -37,27 +38,23 @@ public class CumulativePermission extends AbstractPermission { public CumulativePermission clear(Permission permission) { this.mask &= ~permission.getMask(); - this.pattern = AclFormattingUtils.demergePatterns(this.pattern, - permission.getPattern()); - + this.pattern = AclFormattingUtils.demergePatterns(this.pattern, permission.getPattern()); return this; } public CumulativePermission clear() { this.mask = 0; this.pattern = THIRTY_TWO_RESERVED_OFF; - return this; } public CumulativePermission set(Permission permission) { this.mask |= permission.getMask(); - this.pattern = AclFormattingUtils.mergePatterns(this.pattern, - permission.getPattern()); - + this.pattern = AclFormattingUtils.mergePatterns(this.pattern, permission.getPattern()); return this; } + @Override public String getPattern() { return this.pattern; } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionFactory.java b/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionFactory.java index 8061a916b5..8a68843d89 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionFactory.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import java.lang.reflect.Field; @@ -38,7 +39,9 @@ import org.springframework.util.Assert; * @since 2.0.3 */ public class DefaultPermissionFactory implements PermissionFactory { + private final Map registeredPermissionsByInteger = new HashMap<>(); + private final Map registeredPermissionsByName = new HashMap<>(); /** @@ -57,7 +60,6 @@ public class DefaultPermissionFactory implements PermissionFactory { /** * Registers a map of named Permission instances. - * * @param namedPermissions the map of Permissions, keyed by name. */ public DefaultPermissionFactory(Map namedPermissions) { @@ -71,27 +73,22 @@ public class DefaultPermissionFactory implements PermissionFactory { *

* These permissions will be registered under the name of the field. See * {@link BasePermission} for an example. - * * @param clazz a {@link Permission} class with public static fields to register */ protected void registerPublicPermissions(Class clazz) { Assert.notNull(clazz, "Class required"); - Field[] fields = clazz.getFields(); - for (Field field : fields) { try { Object fieldValue = field.get(null); - if (Permission.class.isAssignableFrom(fieldValue.getClass())) { // Found a Permission static field Permission perm = (Permission) fieldValue; String permissionName = field.getName(); - registerPermission(perm, permissionName); } } - catch (Exception ignore) { + catch (Exception ex) { } } } @@ -99,68 +96,57 @@ public class DefaultPermissionFactory implements PermissionFactory { protected void registerPermission(Permission perm, String permissionName) { Assert.notNull(perm, "Permission required"); Assert.hasText(permissionName, "Permission name required"); - Integer mask = perm.getMask(); // Ensure no existing Permission uses this integer or code - Assert.isTrue(!registeredPermissionsByInteger.containsKey(mask), + Assert.isTrue(!this.registeredPermissionsByInteger.containsKey(mask), () -> "An existing Permission already provides mask " + mask); - Assert.isTrue(!registeredPermissionsByName.containsKey(permissionName), + Assert.isTrue(!this.registeredPermissionsByName.containsKey(permissionName), () -> "An existing Permission already provides name '" + permissionName + "'"); // Register the new Permission - registeredPermissionsByInteger.put(mask, perm); - registeredPermissionsByName.put(permissionName, perm); + this.registeredPermissionsByInteger.put(mask, perm); + this.registeredPermissionsByName.put(permissionName, perm); } + @Override public Permission buildFromMask(int mask) { - if (registeredPermissionsByInteger.containsKey(mask)) { + if (this.registeredPermissionsByInteger.containsKey(mask)) { // The requested mask has an exact match against a statically-defined // Permission, so return it - return registeredPermissionsByInteger.get(mask); + return this.registeredPermissionsByInteger.get(mask); } // To get this far, we have to use a CumulativePermission CumulativePermission permission = new CumulativePermission(); - for (int i = 0; i < 32; i++) { int permissionToCheck = 1 << i; - if ((mask & permissionToCheck) == permissionToCheck) { - Permission p = registeredPermissionsByInteger.get(permissionToCheck); - - if (p == null) { - throw new IllegalStateException("Mask '" + permissionToCheck - + "' does not have a corresponding static Permission"); - } + Permission p = this.registeredPermissionsByInteger.get(permissionToCheck); + Assert.state(p != null, + () -> "Mask '" + permissionToCheck + "' does not have a corresponding static Permission"); permission.set(p); } } - return permission; } + @Override public Permission buildFromName(String name) { - Permission p = registeredPermissionsByName.get(name); - - if (p == null) { - throw new IllegalArgumentException("Unknown permission '" + name + "'"); - } - + Permission p = this.registeredPermissionsByName.get(name); + Assert.notNull(p, "Unknown permission '" + name + "'"); return p; } + @Override public List buildFromNames(List names) { if ((names == null) || (names.size() == 0)) { return Collections.emptyList(); } - List permissions = new ArrayList<>(names.size()); - for (String name : names) { permissions.add(buildFromName(name)); } - return permissions; } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionGrantingStrategy.java b/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionGrantingStrategy.java index 3f54678873..2bfc050be2 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionGrantingStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/DefaultPermissionGrantingStrategy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import java.util.List; @@ -61,42 +62,33 @@ public class DefaultPermissionGrantingStrategy implements PermissionGrantingStra * decide how to handle the permission check. Similarly, if any of the SID arguments * presented to the method were not loaded by the ACL, * UnloadedSidException will be thrown. - * * @param permission the exact permissions to scan for (order is important) * @param sids the exact SIDs to scan for (order is important) * @param administrativeMode if true denotes the query is for * administrative purposes and no auditing will be undertaken - * * @return true if one of the permissions has been granted, * false if one of the permissions has been specifically revoked - * * @throws NotFoundException if an exact ACE for one of the permission bit masks and * SID combination could not be found */ - public boolean isGranted(Acl acl, List permission, List sids, - boolean administrativeMode) throws NotFoundException { - - final List aces = acl.getEntries(); - + @Override + public boolean isGranted(Acl acl, List permission, List sids, boolean administrativeMode) + throws NotFoundException { + List aces = acl.getEntries(); AccessControlEntry firstRejection = null; - for (Permission p : permission) { for (Sid sid : sids) { // Attempt to find exact match for this permission mask and SID boolean scanNextSid = true; - for (AccessControlEntry ace : aces) { - - if (isGranted(ace, p) - && ace.getSid().equals(sid)) { + if (isGranted(ace, p) && ace.getSid().equals(sid)) { // Found a matching ACE, so its authorization decision will // prevail if (ace.isGranting()) { // Success if (!administrativeMode) { - auditLogger.logIfNeeded(true, ace); + this.auditLogger.logIfNeeded(true, ace); } - return true; } @@ -107,13 +99,11 @@ public class DefaultPermissionGrantingStrategy implements PermissionGrantingStra // Store first rejection for auditing reasons firstRejection = ace; } - scanNextSid = false; // helps break the loop break; // exit aces loop } } - if (!scanNextSid) { break; // exit SID for loop (now try next permission) } @@ -124,9 +114,8 @@ public class DefaultPermissionGrantingStrategy implements PermissionGrantingStra // We found an ACE to reject the request at this point, as no // other ACEs were found that granted a different permission if (!administrativeMode) { - auditLogger.logIfNeeded(false, firstRejection); + this.auditLogger.logIfNeeded(false, firstRejection); } - return false; } @@ -135,26 +124,22 @@ public class DefaultPermissionGrantingStrategy implements PermissionGrantingStra // We have a parent, so let them try to find a matching ACE return acl.getParentAcl().isGranted(permission, sids, false); } - else { - // We either have no parent, or we're the uppermost parent - throw new NotFoundException( - "Unable to locate a matching ACE for passed permissions and SIDs"); - } + + // We either have no parent, or we're the uppermost parent + throw new NotFoundException("Unable to locate a matching ACE for passed permissions and SIDs"); } /** - * Compares an ACE Permission to the given Permission. - * By default, we compare the Permission masks for exact match. - * Subclasses of this strategy can override this behavior and implement - * more sophisticated comparisons, e.g. a bitwise comparison for ACEs that grant access. - *

{@code
+	 * Compares an ACE Permission to the given Permission. By default, we compare the
+	 * Permission masks for exact match. Subclasses of this strategy can override this
+	 * behavior and implement more sophisticated comparisons, e.g. a bitwise comparison
+	 * for ACEs that grant access. 
{@code
 	 * if (ace.isGranting() && p.getMask() != 0) {
 	 *    return (ace.getPermission().getMask() & p.getMask()) != 0;
 	 * } else {
 	 *    return ace.getPermission().getMask() == p.getMask();
 	 * }
 	 * }
- * * @param ace the ACE from the Acl holding the mask. * @param p the Permission we are checking against. * @return true, if the respective masks are considered to be equal. diff --git a/acl/src/main/java/org/springframework/security/acls/domain/EhCacheBasedAclCache.java b/acl/src/main/java/org/springframework/security/acls/domain/EhCacheBasedAclCache.java index 6cf4658e35..124ec9671d 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/EhCacheBasedAclCache.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/EhCacheBasedAclCache.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import java.io.Serializable; @@ -38,18 +39,14 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class EhCacheBasedAclCache implements AclCache { - // ~ Instance fields - // ================================================================================================ private final Ehcache cache; + private PermissionGrantingStrategy permissionGrantingStrategy; + private AclAuthorizationStrategy aclAuthorizationStrategy; - // ~ Constructors - // =================================================================================================== - - public EhCacheBasedAclCache(Ehcache cache, - PermissionGrantingStrategy permissionGrantingStrategy, + public EhCacheBasedAclCache(Ehcache cache, PermissionGrantingStrategy permissionGrantingStrategy, AclAuthorizationStrategy aclAuthorizationStrategy) { Assert.notNull(cache, "Cache required"); Assert.notNull(permissionGrantingStrategy, "PermissionGrantingStrategy required"); @@ -59,72 +56,55 @@ public class EhCacheBasedAclCache implements AclCache { this.aclAuthorizationStrategy = aclAuthorizationStrategy; } - // ~ Methods - // ======================================================================================================== - + @Override public void evictFromCache(Serializable pk) { Assert.notNull(pk, "Primary key (identifier) required"); - MutableAcl acl = getFromCache(pk); - if (acl != null) { - cache.remove(acl.getId()); - cache.remove(acl.getObjectIdentity()); + this.cache.remove(acl.getId()); + this.cache.remove(acl.getObjectIdentity()); } } + @Override public void evictFromCache(ObjectIdentity objectIdentity) { Assert.notNull(objectIdentity, "ObjectIdentity required"); - MutableAcl acl = getFromCache(objectIdentity); - if (acl != null) { - cache.remove(acl.getId()); - cache.remove(acl.getObjectIdentity()); + this.cache.remove(acl.getId()); + this.cache.remove(acl.getObjectIdentity()); } } + @Override public MutableAcl getFromCache(ObjectIdentity objectIdentity) { Assert.notNull(objectIdentity, "ObjectIdentity required"); - - Element element = null; - try { - element = cache.get(objectIdentity); + Element element = this.cache.get(objectIdentity); + return (element != null) ? initializeTransientFields((MutableAcl) element.getValue()) : null; } - catch (CacheException ignored) { - } - - if (element == null) { + catch (CacheException ex) { return null; } - - return initializeTransientFields((MutableAcl) element.getValue()); } + @Override public MutableAcl getFromCache(Serializable pk) { Assert.notNull(pk, "Primary key (identifier) required"); - - Element element = null; - try { - element = cache.get(pk); + Element element = this.cache.get(pk); + return (element != null) ? initializeTransientFields((MutableAcl) element.getValue()) : null; } - catch (CacheException ignored) { - } - - if (element == null) { + catch (CacheException ex) { return null; } - - return initializeTransientFields((MutableAcl) element.getValue()); } + @Override public void putInCache(MutableAcl acl) { Assert.notNull(acl, "Acl required"); Assert.notNull(acl.getObjectIdentity(), "ObjectIdentity required"); Assert.notNull(acl.getId(), "ID required"); - if (this.aclAuthorizationStrategy == null) { if (acl instanceof AclImpl) { this.aclAuthorizationStrategy = (AclAuthorizationStrategy) FieldUtils @@ -133,30 +113,27 @@ public class EhCacheBasedAclCache implements AclCache { .getProtectedFieldValue("permissionGrantingStrategy", acl); } } - if ((acl.getParentAcl() != null) && (acl.getParentAcl() instanceof MutableAcl)) { putInCache((MutableAcl) acl.getParentAcl()); } - - cache.put(new Element(acl.getObjectIdentity(), acl)); - cache.put(new Element(acl.getId(), acl)); + this.cache.put(new Element(acl.getObjectIdentity(), acl)); + this.cache.put(new Element(acl.getId(), acl)); } private MutableAcl initializeTransientFields(MutableAcl value) { if (value instanceof AclImpl) { - FieldUtils.setProtectedFieldValue("aclAuthorizationStrategy", value, - this.aclAuthorizationStrategy); - FieldUtils.setProtectedFieldValue("permissionGrantingStrategy", value, - this.permissionGrantingStrategy); + FieldUtils.setProtectedFieldValue("aclAuthorizationStrategy", value, this.aclAuthorizationStrategy); + FieldUtils.setProtectedFieldValue("permissionGrantingStrategy", value, this.permissionGrantingStrategy); } - if (value.getParentAcl() != null) { initializeTransientFields((MutableAcl) value.getParentAcl()); } return value; } + @Override public void clearCache() { - cache.removeAll(); + this.cache.removeAll(); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/GrantedAuthoritySid.java b/acl/src/main/java/org/springframework/security/acls/domain/GrantedAuthoritySid.java index c8bf8ad20a..73c1dc0366 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/GrantedAuthoritySid.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/GrantedAuthoritySid.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Sid; import org.springframework.security.core.GrantedAuthority; - import org.springframework.util.Assert; /** @@ -31,14 +31,9 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class GrantedAuthoritySid implements Sid { - // ~ Instance fields - // ================================================================================================ private final String grantedAuthority; - // ~ Constructors - // =================================================================================================== - public GrantedAuthoritySid(String grantedAuthority) { Assert.hasText(grantedAuthority, "GrantedAuthority required"); this.grantedAuthority = grantedAuthority; @@ -46,25 +41,19 @@ public class GrantedAuthoritySid implements Sid { public GrantedAuthoritySid(GrantedAuthority grantedAuthority) { Assert.notNull(grantedAuthority, "GrantedAuthority required"); - Assert.notNull( - grantedAuthority.getAuthority(), + Assert.notNull(grantedAuthority.getAuthority(), "This Sid is only compatible with GrantedAuthoritys that provide a non-null getAuthority()"); this.grantedAuthority = grantedAuthority.getAuthority(); } - // ~ Methods - // ======================================================================================================== - @Override public boolean equals(Object object) { if ((object == null) || !(object instanceof GrantedAuthoritySid)) { return false; } - // Delegate to getGrantedAuthority() to perform actual comparison (both should be // identical) - return ((GrantedAuthoritySid) object).getGrantedAuthority().equals( - this.getGrantedAuthority()); + return ((GrantedAuthoritySid) object).getGrantedAuthority().equals(this.getGrantedAuthority()); } @Override @@ -73,11 +62,12 @@ public class GrantedAuthoritySid implements Sid { } public String getGrantedAuthority() { - return grantedAuthority; + return this.grantedAuthority; } @Override public String toString() { return "GrantedAuthoritySid[" + this.grantedAuthority + "]"; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/IdentityUnavailableException.java b/acl/src/main/java/org/springframework/security/acls/domain/IdentityUnavailableException.java index 5157646e1e..c18fc33606 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/IdentityUnavailableException.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/IdentityUnavailableException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; /** @@ -21,12 +22,9 @@ package org.springframework.security.acls.domain; * @author Ben Alex */ public class IdentityUnavailableException extends RuntimeException { - // ~ Constructors - // =================================================================================================== /** * Constructs an IdentityUnavailableException with the specified message. - * * @param msg the detail message */ public IdentityUnavailableException(String msg) { @@ -36,11 +34,11 @@ public class IdentityUnavailableException extends RuntimeException { /** * Constructs an IdentityUnavailableException with the specified message * and root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public IdentityUnavailableException(String msg, Throwable t) { - super(msg, t); + public IdentityUnavailableException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityImpl.java b/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityImpl.java index 484b43a3a6..aafa3fff3f 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityImpl.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import java.io.Serializable; @@ -31,19 +32,14 @@ import org.springframework.util.ClassUtils; * @author Ben Alex */ public class ObjectIdentityImpl implements ObjectIdentity { - // ~ Instance fields - // ================================================================================================ private final String type; - private Serializable identifier; - // ~ Constructors - // =================================================================================================== + private Serializable identifier; public ObjectIdentityImpl(String type, Serializable identifier) { Assert.hasText(type, "Type required"); Assert.notNull(identifier, "identifier required"); - this.identifier = identifier; this.type = type; } @@ -66,36 +62,28 @@ public class ObjectIdentityImpl implements ObjectIdentity { *

* The class name of the object passed will be considered the {@link #type}, so if * more control is required, a different constructor should be used. - * * @param object the domain object instance to create an identity for. - * * @throws IdentityUnavailableException if identity could not be extracted */ public ObjectIdentityImpl(Object object) throws IdentityUnavailableException { Assert.notNull(object, "object cannot be null"); - Class typeClass = ClassUtils.getUserClass(object.getClass()); - type = typeClass.getName(); - - Object result; - - try { - Method method = typeClass.getMethod("getId", new Class[] {}); - result = method.invoke(object); - } - catch (Exception e) { - throw new IdentityUnavailableException( - "Could not extract identity from object " + object, e); - } - + this.type = typeClass.getName(); + Object result = invokeGetIdMethod(object, typeClass); Assert.notNull(result, "getId() is required to return a non-null value"); - Assert.isInstanceOf(Serializable.class, result, - "Getter must provide a return value of type Serializable"); + Assert.isInstanceOf(Serializable.class, result, "Getter must provide a return value of type Serializable"); this.identifier = (Serializable) result; } - // ~ Methods - // ======================================================================================================== + private Object invokeGetIdMethod(Object object, Class typeClass) { + try { + Method method = typeClass.getMethod("getId", new Class[] {}); + return method.invoke(object); + } + catch (Exception ex) { + throw new IdentityUnavailableException("Could not extract identity from object " + object, ex); + } + } /** * Important so caching operates properly. @@ -105,49 +93,42 @@ public class ObjectIdentityImpl implements ObjectIdentity { *

* Numeric identities (Integer and Long values) are considered equal if they are * numerically equal. Other serializable types are evaluated using a simple equality. - * - * @param arg0 object to compare - * + * @param obj object to compare * @return true if the presented object matches this object */ @Override - public boolean equals(Object arg0) { - if (arg0 == null || !(arg0 instanceof ObjectIdentityImpl)) { + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof ObjectIdentityImpl)) { return false; } - - ObjectIdentityImpl other = (ObjectIdentityImpl) arg0; - - if (identifier instanceof Number && other.identifier instanceof Number) { + ObjectIdentityImpl other = (ObjectIdentityImpl) obj; + if (this.identifier instanceof Number && other.identifier instanceof Number) { // Integers and Longs with same value should be considered equal - if (((Number) identifier).longValue() != ((Number) other.identifier) - .longValue()) { + if (((Number) this.identifier).longValue() != ((Number) other.identifier).longValue()) { return false; } } else { // Use plain equality for other serializable types - if (!identifier.equals(other.identifier)) { + if (!this.identifier.equals(other.identifier)) { return false; } } - - return type.equals(other.type); + return this.type.equals(other.type); } @Override public Serializable getIdentifier() { - return identifier; + return this.identifier; } @Override public String getType() { - return type; + return this.type; } /** * Important so caching operates properly. - * * @return the hash */ @Override @@ -163,7 +144,7 @@ public class ObjectIdentityImpl implements ObjectIdentity { sb.append(this.getClass().getName()).append("["); sb.append("Type: ").append(this.type); sb.append("; Identifier: ").append(this.identifier).append("]"); - return sb.toString(); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImpl.java b/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImpl.java index 2fe0538c80..d08ba91d8c 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImpl.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImpl.java @@ -29,16 +29,16 @@ import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; * * @author Ben Alex */ -public class ObjectIdentityRetrievalStrategyImpl implements - ObjectIdentityRetrievalStrategy, ObjectIdentityGenerator { - // ~ Methods - // ======================================================================================================== +public class ObjectIdentityRetrievalStrategyImpl implements ObjectIdentityRetrievalStrategy, ObjectIdentityGenerator { + @Override public ObjectIdentity getObjectIdentity(Object domainObject) { return new ObjectIdentityImpl(domainObject); } + @Override public ObjectIdentity createObjectIdentity(Serializable id, String type) { return new ObjectIdentityImpl(type, id); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/PermissionFactory.java b/acl/src/main/java/org/springframework/security/acls/domain/PermissionFactory.java index 5d99deb201..41b613274e 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/PermissionFactory.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/PermissionFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import java.util.List; @@ -25,16 +26,13 @@ import org.springframework.security.acls.model.Permission; * * @author Ben Alex * @since 2.0.3 - * */ public interface PermissionFactory { /** * Dynamically creates a CumulativePermission or * BasePermission representing the active bits in the passed mask. - * * @param mask to build - * * @return a Permission representing the requested object */ Permission buildFromMask(int mask); @@ -42,4 +40,5 @@ public interface PermissionFactory { Permission buildFromName(String name); List buildFromNames(List names); + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/PrincipalSid.java b/acl/src/main/java/org/springframework/security/acls/domain/PrincipalSid.java index 2680b669c0..373d85a5e9 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/PrincipalSid.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/PrincipalSid.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Sid; import org.springframework.security.core.Authentication; - import org.springframework.util.Assert; /** @@ -31,14 +31,9 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class PrincipalSid implements Sid { - // ~ Instance fields - // ================================================================================================ private final String principal; - // ~ Constructors - // =================================================================================================== - public PrincipalSid(String principal) { Assert.hasText(principal, "Principal required"); this.principal = principal; @@ -47,19 +42,14 @@ public class PrincipalSid implements Sid { public PrincipalSid(Authentication authentication) { Assert.notNull(authentication, "Authentication required"); Assert.notNull(authentication.getPrincipal(), "Principal required"); - this.principal = authentication.getName(); } - // ~ Methods - // ======================================================================================================== - @Override public boolean equals(Object object) { if ((object == null) || !(object instanceof PrincipalSid)) { return false; } - // Delegate to getPrincipal() to perform actual comparison (both should be // identical) return ((PrincipalSid) object).getPrincipal().equals(this.getPrincipal()); @@ -71,11 +61,12 @@ public class PrincipalSid implements Sid { } public String getPrincipal() { - return principal; + return this.principal; } @Override public String toString() { return "PrincipalSid[" + this.principal + "]"; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/SidRetrievalStrategyImpl.java b/acl/src/main/java/org/springframework/security/acls/domain/SidRetrievalStrategyImpl.java index 2cf19291e6..92e0e6a224 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/SidRetrievalStrategyImpl.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/SidRetrievalStrategyImpl.java @@ -51,20 +51,16 @@ public class SidRetrievalStrategyImpl implements SidRetrievalStrategy { this.roleHierarchy = roleHierarchy; } - // ~ Methods - // ======================================================================================================== - + @Override public List getSids(Authentication authentication) { - Collection authorities = roleHierarchy + Collection authorities = this.roleHierarchy .getReachableGrantedAuthorities(authentication.getAuthorities()); List sids = new ArrayList<>(authorities.size() + 1); - sids.add(new PrincipalSid(authentication)); - for (GrantedAuthority authority : authorities) { sids.add(new GrantedAuthoritySid(authority)); } - return sids; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/SpringCacheBasedAclCache.java b/acl/src/main/java/org/springframework/security/acls/domain/SpringCacheBasedAclCache.java index 67410e9643..0ad9813b5e 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/SpringCacheBasedAclCache.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/SpringCacheBasedAclCache.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; +import java.io.Serializable; + import org.springframework.cache.Cache; import org.springframework.security.acls.model.AclCache; import org.springframework.security.acls.model.MutableAcl; @@ -23,8 +26,6 @@ import org.springframework.security.acls.model.PermissionGrantingStrategy; import org.springframework.security.util.FieldUtils; import org.springframework.util.Assert; -import java.io.Serializable; - /** * Simple implementation of {@link org.springframework.security.acls.model.AclCache} that * delegates to {@link Cache} implementation. @@ -39,18 +40,14 @@ import java.io.Serializable; * @since 3.2 */ public class SpringCacheBasedAclCache implements AclCache { - // ~ Instance fields - // ================================================================================================ private final Cache cache; + private PermissionGrantingStrategy permissionGrantingStrategy; + private AclAuthorizationStrategy aclAuthorizationStrategy; - // ~ Constructors - // =================================================================================================== - - public SpringCacheBasedAclCache(Cache cache, - PermissionGrantingStrategy permissionGrantingStrategy, + public SpringCacheBasedAclCache(Cache cache, PermissionGrantingStrategy permissionGrantingStrategy, AclAuthorizationStrategy aclAuthorizationStrategy) { Assert.notNull(cache, "Cache required"); Assert.notNull(permissionGrantingStrategy, "PermissionGrantingStrategy required"); @@ -60,79 +57,72 @@ public class SpringCacheBasedAclCache implements AclCache { this.aclAuthorizationStrategy = aclAuthorizationStrategy; } - // ~ Methods - // ======================================================================================================== - + @Override public void evictFromCache(Serializable pk) { Assert.notNull(pk, "Primary key (identifier) required"); - MutableAcl acl = getFromCache(pk); - if (acl != null) { - cache.evict(acl.getId()); - cache.evict(acl.getObjectIdentity()); + this.cache.evict(acl.getId()); + this.cache.evict(acl.getObjectIdentity()); } } + @Override public void evictFromCache(ObjectIdentity objectIdentity) { Assert.notNull(objectIdentity, "ObjectIdentity required"); - MutableAcl acl = getFromCache(objectIdentity); - if (acl != null) { - cache.evict(acl.getId()); - cache.evict(acl.getObjectIdentity()); + this.cache.evict(acl.getId()); + this.cache.evict(acl.getObjectIdentity()); } } + @Override public MutableAcl getFromCache(ObjectIdentity objectIdentity) { Assert.notNull(objectIdentity, "ObjectIdentity required"); return getFromCache((Object) objectIdentity); } + @Override public MutableAcl getFromCache(Serializable pk) { Assert.notNull(pk, "Primary key (identifier) required"); return getFromCache((Object) pk); } + @Override public void putInCache(MutableAcl acl) { Assert.notNull(acl, "Acl required"); Assert.notNull(acl.getObjectIdentity(), "ObjectIdentity required"); Assert.notNull(acl.getId(), "ID required"); - if ((acl.getParentAcl() != null) && (acl.getParentAcl() instanceof MutableAcl)) { putInCache((MutableAcl) acl.getParentAcl()); } - - cache.put(acl.getObjectIdentity(), acl); - cache.put(acl.getId(), acl); + this.cache.put(acl.getObjectIdentity(), acl); + this.cache.put(acl.getId(), acl); } private MutableAcl getFromCache(Object key) { - Cache.ValueWrapper element = cache.get(key); - + Cache.ValueWrapper element = this.cache.get(key); if (element == null) { return null; } - return initializeTransientFields((MutableAcl) element.get()); } private MutableAcl initializeTransientFields(MutableAcl value) { if (value instanceof AclImpl) { - FieldUtils.setProtectedFieldValue("aclAuthorizationStrategy", value, - this.aclAuthorizationStrategy); - FieldUtils.setProtectedFieldValue("permissionGrantingStrategy", value, - this.permissionGrantingStrategy); + FieldUtils.setProtectedFieldValue("aclAuthorizationStrategy", value, this.aclAuthorizationStrategy); + FieldUtils.setProtectedFieldValue("permissionGrantingStrategy", value, this.permissionGrantingStrategy); } - if (value.getParentAcl() != null) { initializeTransientFields((MutableAcl) value.getParentAcl()); } return value; } + @Override public void clearCache() { - cache.clear(); + this.cache.clear(); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/domain/package-info.java b/acl/src/main/java/org/springframework/security/acls/domain/package-info.java index ff0f6f62b7..8ef86cecdb 100644 --- a/acl/src/main/java/org/springframework/security/acls/domain/package-info.java +++ b/acl/src/main/java/org/springframework/security/acls/domain/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Basic implementation of access control lists (ACLs) interfaces. */ package org.springframework.security.acls.domain; - diff --git a/acl/src/main/java/org/springframework/security/acls/jdbc/AclClassIdUtils.java b/acl/src/main/java/org/springframework/security/acls/jdbc/AclClassIdUtils.java index 65b5d80f54..3255b3a967 100644 --- a/acl/src/main/java/org/springframework/security/acls/jdbc/AclClassIdUtils.java +++ b/acl/src/main/java/org/springframework/security/acls/jdbc/AclClassIdUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import java.io.Serializable; @@ -22,6 +23,7 @@ import java.util.UUID; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.TypeDescriptor; @@ -31,12 +33,16 @@ import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.util.Assert; /** - * Utility class for helping convert database representations of {@link ObjectIdentity#getIdentifier()} into - * the correct Java type as specified by acl_class.class_id_type. + * Utility class for helping convert database representations of + * {@link ObjectIdentity#getIdentifier()} into the correct Java type as specified by + * acl_class.class_id_type. + * * @author paulwheeler */ class AclClassIdUtils { + private static final String DEFAULT_CLASS_ID_TYPE_COLUMN_NAME = "class_id_type"; + private static final Log log = LogFactory.getLog(AclClassIdUtils.class); private ConversionService conversionService; @@ -54,88 +60,85 @@ class AclClassIdUtils { } /** - * Converts the raw type from the database into the right Java type. For most applications the 'raw type' will be Long, for some applications - * it could be String. + * Converts the raw type from the database into the right Java type. For most + * applications the 'raw type' will be Long, for some applications it could be String. * @param identifier The identifier from the database - * @param resultSet Result set of the query + * @param resultSet Result set of the query * @return The identifier in the appropriate target Java type. Typically Long or UUID. * @throws SQLException */ Serializable identifierFrom(Serializable identifier, ResultSet resultSet) throws SQLException { if (isString(identifier) && hasValidClassIdType(resultSet) - && canConvertFromStringTo(classIdTypeFrom(resultSet))) { - - identifier = convertFromStringTo((String) identifier, classIdTypeFrom(resultSet)); - } else { - // Assume it should be a Long type - identifier = convertToLong(identifier); + && canConvertFromStringTo(classIdTypeFrom(resultSet))) { + return convertFromStringTo((String) identifier, classIdTypeFrom(resultSet)); } - - return identifier; + // Assume it should be a Long type + return convertToLong(identifier); } private boolean hasValidClassIdType(ResultSet resultSet) { - boolean hasClassIdType = false; try { - hasClassIdType = classIdTypeFrom(resultSet) != null; - } catch (SQLException e) { - log.debug("Unable to obtain the class id type", e); + return classIdTypeFrom(resultSet) != null; + } + catch (SQLException ex) { + log.debug("Unable to obtain the class id type", ex); + return false; } - return hasClassIdType; } - private Class classIdTypeFrom(ResultSet resultSet) throws SQLException { + private Class classIdTypeFrom(ResultSet resultSet) throws SQLException { return classIdTypeFrom(resultSet.getString(DEFAULT_CLASS_ID_TYPE_COLUMN_NAME)); } private Class classIdTypeFrom(String className) { - Class targetType = null; - if (className != null) { - try { - targetType = Class.forName(className); - } catch (ClassNotFoundException e) { - log.debug("Unable to find class id type on classpath", e); - } + if (className == null) { + return null; + } + try { + return (Class) Class.forName(className); + } + catch (ClassNotFoundException ex) { + log.debug("Unable to find class id type on classpath", ex); + return null; } - return targetType; } private boolean canConvertFromStringTo(Class targetType) { - return conversionService.canConvert(String.class, targetType); + return this.conversionService.canConvert(String.class, targetType); } private T convertFromStringTo(String identifier, Class targetType) { - return conversionService.convert(identifier, targetType); + return this.conversionService.convert(identifier, targetType); } /** - * Converts to a {@link Long}, attempting to use the {@link ConversionService} if available. - * @param identifier The identifier + * Converts to a {@link Long}, attempting to use the {@link ConversionService} if + * available. + * @param identifier The identifier * @return Long version of the identifier * @throws NumberFormatException if the string cannot be parsed to a long. - * @throws org.springframework.core.convert.ConversionException if a conversion exception occurred + * @throws org.springframework.core.convert.ConversionException if a conversion + * exception occurred * @throws IllegalArgumentException if targetType is null */ private Long convertToLong(Serializable identifier) { - Long idAsLong; - if (conversionService.canConvert(identifier.getClass(), Long.class)) { - idAsLong = conversionService.convert(identifier, Long.class); - } else { - idAsLong = Long.valueOf(identifier.toString()); + if (this.conversionService.canConvert(identifier.getClass(), Long.class)) { + return this.conversionService.convert(identifier, Long.class); } - return idAsLong; + return Long.valueOf(identifier.toString()); } private boolean isString(Serializable object) { return object.getClass().isAssignableFrom(String.class); } - public void setConversionService(ConversionService conversionService) { + void setConversionService(ConversionService conversionService) { Assert.notNull(conversionService, "conversionService must not be null"); this.conversionService = conversionService; } private static class StringToLongConverter implements Converter { + @Override public Long convert(String identifierAsString) { if (identifierAsString == null) { @@ -145,9 +148,11 @@ class AclClassIdUtils { } return Long.parseLong(identifierAsString); } + } private static class StringToUUIDConverter implements Converter { + @Override public UUID convert(String identifierAsString) { if (identifierAsString == null) { @@ -157,5 +162,7 @@ class AclClassIdUtils { } return UUID.fromString(identifierAsString); } + } + } diff --git a/acl/src/main/java/org/springframework/security/acls/jdbc/BasicLookupStrategy.java b/acl/src/main/java/org/springframework/security/acls/jdbc/BasicLookupStrategy.java index ebd32632a7..9d4d099b25 100644 --- a/acl/src/main/java/org/springframework/security/acls/jdbc/BasicLookupStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/jdbc/BasicLookupStrategy.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import java.io.Serializable; import java.lang.reflect.Field; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; @@ -68,9 +70,9 @@ import org.springframework.util.Assert; * as it is likely to change in future releases and therefore subclassing is unsupported. *

* There are two SQL queries executed, one in the lookupPrimaryKeys method and - * one in lookupObjectIdentities. These are built from the same select and - * "order by" clause, using a different where clause in each case. In order to use custom - * schema or column names, each of these SQL clauses can be customized, but they must be + * one in lookupObjectIdentities. These are built from the same select and "order + * by" clause, using a different where clause in each case. In order to use custom schema + * or column names, each of these SQL clauses can be customized, but they must be * consistent with each other and with the expected result set generated by the the * default values. * @@ -78,183 +80,151 @@ import org.springframework.util.Assert; */ public class BasicLookupStrategy implements LookupStrategy { - private final static String DEFAULT_SELECT_CLAUSE_COLUMNS = "select acl_object_identity.object_id_identity, " - + "acl_entry.ace_order, " - + "acl_object_identity.id as acl_id, " - + "acl_object_identity.parent_object, " - + "acl_object_identity.entries_inheriting, " - + "acl_entry.id as ace_id, " - + "acl_entry.mask, " - + "acl_entry.granting, " - + "acl_entry.audit_success, " - + "acl_entry.audit_failure, " - + "acl_sid.principal as ace_principal, " - + "acl_sid.sid as ace_sid, " - + "acli_sid.principal as acl_principal, " - + "acli_sid.sid as acl_sid, " - + "acl_class.class "; - private final static String DEFAULT_SELECT_CLAUSE_ACL_CLASS_ID_TYPE_COLUMN = ", acl_class.class_id_type "; - private final static String DEFAULT_SELECT_CLAUSE_FROM = "from acl_object_identity " + private static final String DEFAULT_SELECT_CLAUSE_COLUMNS = "select acl_object_identity.object_id_identity, " + + "acl_entry.ace_order, " + "acl_object_identity.id as acl_id, " + "acl_object_identity.parent_object, " + + "acl_object_identity.entries_inheriting, " + "acl_entry.id as ace_id, " + "acl_entry.mask, " + + "acl_entry.granting, " + "acl_entry.audit_success, " + "acl_entry.audit_failure, " + + "acl_sid.principal as ace_principal, " + "acl_sid.sid as ace_sid, " + + "acli_sid.principal as acl_principal, " + "acli_sid.sid as acl_sid, " + "acl_class.class "; + + private static final String DEFAULT_SELECT_CLAUSE_ACL_CLASS_ID_TYPE_COLUMN = ", acl_class.class_id_type "; + + private static final String DEFAULT_SELECT_CLAUSE_FROM = "from acl_object_identity " + "left join acl_sid acli_sid on acli_sid.id = acl_object_identity.owner_sid " + "left join acl_class on acl_class.id = acl_object_identity.object_id_class " + "left join acl_entry on acl_object_identity.id = acl_entry.acl_object_identity " + "left join acl_sid on acl_entry.sid = acl_sid.id " + "where ( "; - public final static String DEFAULT_SELECT_CLAUSE = DEFAULT_SELECT_CLAUSE_COLUMNS + DEFAULT_SELECT_CLAUSE_FROM; + public static final String DEFAULT_SELECT_CLAUSE = DEFAULT_SELECT_CLAUSE_COLUMNS + DEFAULT_SELECT_CLAUSE_FROM; - public final static String DEFAULT_ACL_CLASS_ID_SELECT_CLAUSE = DEFAULT_SELECT_CLAUSE_COLUMNS + - DEFAULT_SELECT_CLAUSE_ACL_CLASS_ID_TYPE_COLUMN + DEFAULT_SELECT_CLAUSE_FROM; + public static final String DEFAULT_ACL_CLASS_ID_SELECT_CLAUSE = DEFAULT_SELECT_CLAUSE_COLUMNS + + DEFAULT_SELECT_CLAUSE_ACL_CLASS_ID_TYPE_COLUMN + DEFAULT_SELECT_CLAUSE_FROM; - private final static String DEFAULT_LOOKUP_KEYS_WHERE_CLAUSE = "(acl_object_identity.id = ?)"; + private static final String DEFAULT_LOOKUP_KEYS_WHERE_CLAUSE = "(acl_object_identity.id = ?)"; - private final static String DEFAULT_LOOKUP_IDENTITIES_WHERE_CLAUSE = "(acl_object_identity.object_id_identity = ? and acl_class.class = ?)"; + private static final String DEFAULT_LOOKUP_IDENTITIES_WHERE_CLAUSE = "(acl_object_identity.object_id_identity = ? and acl_class.class = ?)"; - public final static String DEFAULT_ORDER_BY_CLAUSE = ") order by acl_object_identity.object_id_identity" + public static final String DEFAULT_ORDER_BY_CLAUSE = ") order by acl_object_identity.object_id_identity" + " asc, acl_entry.ace_order asc"; - // ~ Instance fields - // ================================================================================================ - private final AclAuthorizationStrategy aclAuthorizationStrategy; + private PermissionFactory permissionFactory = new DefaultPermissionFactory(); + private final AclCache aclCache; + private final PermissionGrantingStrategy grantingStrategy; + private final JdbcTemplate jdbcTemplate; + private int batchSize = 50; private final Field fieldAces = FieldUtils.getField(AclImpl.class, "aces"); - private final Field fieldAcl = FieldUtils.getField(AccessControlEntryImpl.class, - "acl"); + + private final Field fieldAcl = FieldUtils.getField(AccessControlEntryImpl.class, "acl"); // SQL Customization fields private String selectClause = DEFAULT_SELECT_CLAUSE; + private String lookupPrimaryKeysWhereClause = DEFAULT_LOOKUP_KEYS_WHERE_CLAUSE; + private String lookupObjectIdentitiesWhereClause = DEFAULT_LOOKUP_IDENTITIES_WHERE_CLAUSE; + private String orderByClause = DEFAULT_ORDER_BY_CLAUSE; private AclClassIdUtils aclClassIdUtils; - // ~ Constructors - // =================================================================================================== - /** * Constructor accepting mandatory arguments - * * @param dataSource to access the database * @param aclCache the cache where fully-loaded elements can be stored * @param aclAuthorizationStrategy authorization strategy (required) */ public BasicLookupStrategy(DataSource dataSource, AclCache aclCache, AclAuthorizationStrategy aclAuthorizationStrategy, AuditLogger auditLogger) { - this(dataSource, aclCache, aclAuthorizationStrategy, - new DefaultPermissionGrantingStrategy(auditLogger)); + this(dataSource, aclCache, aclAuthorizationStrategy, new DefaultPermissionGrantingStrategy(auditLogger)); } /** * Creates a new instance - * * @param dataSource to access the database * @param aclCache the cache where fully-loaded elements can be stored * @param aclAuthorizationStrategy authorization strategy (required) * @param grantingStrategy the PermissionGrantingStrategy */ public BasicLookupStrategy(DataSource dataSource, AclCache aclCache, - AclAuthorizationStrategy aclAuthorizationStrategy, - PermissionGrantingStrategy grantingStrategy) { + AclAuthorizationStrategy aclAuthorizationStrategy, PermissionGrantingStrategy grantingStrategy) { Assert.notNull(dataSource, "DataSource required"); Assert.notNull(aclCache, "AclCache required"); Assert.notNull(aclAuthorizationStrategy, "AclAuthorizationStrategy required"); Assert.notNull(grantingStrategy, "grantingStrategy required"); - jdbcTemplate = new JdbcTemplate(dataSource); + this.jdbcTemplate = new JdbcTemplate(dataSource); this.aclCache = aclCache; this.aclAuthorizationStrategy = aclAuthorizationStrategy; this.grantingStrategy = grantingStrategy; this.aclClassIdUtils = new AclClassIdUtils(); - fieldAces.setAccessible(true); - fieldAcl.setAccessible(true); + this.fieldAces.setAccessible(true); + this.fieldAcl.setAccessible(true); } - // ~ Methods - // ======================================================================================================== - private String computeRepeatingSql(String repeatingSql, int requiredRepetitions) { - assert requiredRepetitions > 0 : "requiredRepetitions must be > 0"; - - final String startSql = selectClause; - - final String endSql = orderByClause; - - StringBuilder sqlStringBldr = new StringBuilder(startSql.length() - + endSql.length() + requiredRepetitions * (repeatingSql.length() + 4)); + Assert.isTrue(requiredRepetitions > 0, "requiredRepetitions must be > 0"); + String startSql = this.selectClause; + String endSql = this.orderByClause; + StringBuilder sqlStringBldr = new StringBuilder( + startSql.length() + endSql.length() + requiredRepetitions * (repeatingSql.length() + 4)); sqlStringBldr.append(startSql); - for (int i = 1; i <= requiredRepetitions; i++) { sqlStringBldr.append(repeatingSql); - if (i != requiredRepetitions) { sqlStringBldr.append(" or "); } } - sqlStringBldr.append(endSql); - return sqlStringBldr.toString(); } @SuppressWarnings("unchecked") private List readAces(AclImpl acl) { try { - return (List) fieldAces.get(acl); + return (List) this.fieldAces.get(acl); } - catch (IllegalAccessException e) { - throw new IllegalStateException("Could not obtain AclImpl.aces field", e); + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not obtain AclImpl.aces field", ex); } } private void setAclOnAce(AccessControlEntryImpl ace, AclImpl acl) { try { - fieldAcl.set(ace, acl); + this.fieldAcl.set(ace, acl); } - catch (IllegalAccessException e) { - throw new IllegalStateException( - "Could not or set AclImpl on AccessControlEntryImpl fields", e); + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not or set AclImpl on AccessControlEntryImpl fields", ex); } } private void setAces(AclImpl acl, List aces) { try { - fieldAces.set(acl, aces); + this.fieldAces.set(acl, aces); } - catch (IllegalAccessException e) { - throw new IllegalStateException("Could not set AclImpl entries", e); + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not set AclImpl entries", ex); } } /** * Locates the primary key IDs specified in "findNow", adding AclImpl instances with * StubAclParents to the "acls" Map. - * * @param acls the AclImpls (with StubAclParents) * @param findNow Long-based primary keys to retrieve * @param sids */ - private void lookupPrimaryKeys(final Map acls, - final Set findNow, final List sids) { + private void lookupPrimaryKeys(final Map acls, final Set findNow, final List sids) { Assert.notNull(acls, "ACLs are required"); Assert.notEmpty(findNow, "Items to find now required"); - - String sql = computeRepeatingSql(lookupPrimaryKeysWhereClause, findNow.size()); - - Set parentsToLookup = jdbcTemplate.query(sql, - ps -> { - int i = 0; - - for (Long toFind : findNow) { - i++; - ps.setLong(i, toFind); - } - }, new ProcessResultSet(acls, sids)); - + String sql = computeRepeatingSql(this.lookupPrimaryKeysWhereClause, findNow.size()); + Set parentsToLookup = this.jdbcTemplate.query(sql, (ps) -> setKeys(ps, findNow), + new ProcessResultSet(acls, sids)); // Lookup the parents, now that our JdbcTemplate has released the database // connection (SEC-547) if (parentsToLookup.size() > 0) { @@ -262,6 +232,14 @@ public class BasicLookupStrategy implements LookupStrategy { } } + private void setKeys(PreparedStatement ps, Set findNow) throws SQLException { + int i = 0; + for (Long toFind : findNow) { + i++; + ps.setLong(i, toFind); + } + } + /** * The main method. *

@@ -271,85 +249,61 @@ public class BasicLookupStrategy implements LookupStrategy { * develop a custom {@link LookupStrategy} implementation instead. *

* The implementation works in batch sizes specified by {@link #batchSize}. - * * @param objects the identities to lookup (required) * @param sids the SIDs for which identities are required (ignored by this * implementation) - * * @return a Map where keys represent the {@link ObjectIdentity} of the * located {@link Acl} and values are the located {@link Acl} (never null * although some entries may be missing; this method should not throw * {@link NotFoundException}, as a chain of {@link LookupStrategy}s may be used to * automatically create entries if required) */ - public final Map readAclsById(List objects, - List sids) { - Assert.isTrue(batchSize >= 1, "BatchSize must be >= 1"); + @Override + public final Map readAclsById(List objects, List sids) { + Assert.isTrue(this.batchSize >= 1, "BatchSize must be >= 1"); Assert.notEmpty(objects, "Objects to lookup required"); - // Map - Map result = new HashMap<>(); // contains - // FULLY - // loaded - // Acl - // objects - + // contains FULLY loaded Acl objects + Map result = new HashMap<>(); Set currentBatchToLoad = new HashSet<>(); - for (int i = 0; i < objects.size(); i++) { final ObjectIdentity oid = objects.get(i); boolean aclFound = false; - // Check we don't already have this ACL in the results if (result.containsKey(oid)) { aclFound = true; } - // Check cache for the present ACL entry if (!aclFound) { - Acl acl = aclCache.getFromCache(oid); - + Acl acl = this.aclCache.getFromCache(oid); // Ensure any cached element supports all the requested SIDs // (they should always, as our base impl doesn't filter on SID) if (acl != null) { - if (acl.isSidLoaded(sids)) { - result.put(acl.getObjectIdentity(), acl); - aclFound = true; - } - else { - throw new IllegalStateException( - "Error: SID-filtered element detected when implementation does not perform SID filtering " - + "- have you added something to the cache manually?"); - } + Assert.state(acl.isSidLoaded(sids), + "Error: SID-filtered element detected when implementation does not perform SID filtering " + + "- have you added something to the cache manually?"); + result.put(acl.getObjectIdentity(), acl); + aclFound = true; } } - // Load the ACL from the database if (!aclFound) { currentBatchToLoad.add(oid); } - // Is it time to load from JDBC the currentBatchToLoad? - if ((currentBatchToLoad.size() == this.batchSize) - || ((i + 1) == objects.size())) { + if ((currentBatchToLoad.size() == this.batchSize) || ((i + 1) == objects.size())) { if (currentBatchToLoad.size() > 0) { - Map loadedBatch = lookupObjectIdentities( - currentBatchToLoad, sids); - + Map loadedBatch = lookupObjectIdentities(currentBatchToLoad, sids); // Add loaded batch (all elements 100% initialized) to results result.putAll(loadedBatch); - // Add the loaded batch to the cache - for (Acl loadedAcl : loadedBatch.values()) { - aclCache.putInCache((AclImpl) loadedAcl); + this.aclCache.putInCache((AclImpl) loadedAcl); } - currentBatchToLoad.clear(); } } } - return result; } @@ -362,39 +316,20 @@ public class BasicLookupStrategy implements LookupStrategy { *

* This subclass is required to return fully valid Acls, including * properly-configured parent ACLs. - * */ - private Map lookupObjectIdentities( - final Collection objectIdentities, List sids) { + private Map lookupObjectIdentities(final Collection objectIdentities, + List sids) { Assert.notEmpty(objectIdentities, "Must provide identities to lookup"); - final Map acls = new HashMap<>(); // contains - // Acls - // with - // StubAclParents + // contains Acls with StubAclParents + Map acls = new HashMap<>(); // Make the "acls" map contain all requested objectIdentities // (including markers to each parent in the hierarchy) - String sql = computeRepeatingSql(lookupObjectIdentitiesWhereClause, - objectIdentities.size()); + String sql = computeRepeatingSql(this.lookupObjectIdentitiesWhereClause, objectIdentities.size()); - Set parentsToLookup = jdbcTemplate.query(sql, - ps -> { - int i = 0; - for (ObjectIdentity oid : objectIdentities) { - // Determine prepared statement values for this iteration - String type = oid.getType(); - - // No need to check for nulls, as guaranteed non-null by - // ObjectIdentity.getIdentifier() interface contract - String identifier = oid.getIdentifier().toString(); - - // Inject values - ps.setString((2 * i) + 1, identifier); - ps.setString((2 * i) + 2, type); - i++; - } - }, new ProcessResultSet(acls, sids)); + Set parentsToLookup = this.jdbcTemplate.query(sql, + (ps) -> setupLookupObjectIdentitiesStatement(ps, objectIdentities), new ProcessResultSet(acls, sids)); // Lookup the parents, now that our JdbcTemplate has released the database // connection (SEC-547) @@ -404,13 +339,9 @@ public class BasicLookupStrategy implements LookupStrategy { // Finally, convert our "acls" containing StubAclParents into true Acls Map resultMap = new HashMap<>(); - for (Acl inputAcl : acls.values()) { - Assert.isInstanceOf(AclImpl.class, inputAcl, - "Map should have contained an AclImpl"); - Assert.isInstanceOf(Long.class, ((AclImpl) inputAcl).getId(), - "Acl.getId() must be Long"); - + Assert.isInstanceOf(AclImpl.class, inputAcl, "Map should have contained an AclImpl"); + Assert.isInstanceOf(Long.class, ((AclImpl) inputAcl).getId(), "Acl.getId() must be Long"); Acl result = convert(acls, (Long) ((AclImpl) inputAcl).getId()); resultMap.put(result.getObjectIdentity(), result); } @@ -418,15 +349,31 @@ public class BasicLookupStrategy implements LookupStrategy { return resultMap; } + private void setupLookupObjectIdentitiesStatement(PreparedStatement ps, Collection objectIdentities) + throws SQLException { + int i = 0; + for (ObjectIdentity oid : objectIdentities) { + // Determine prepared statement values for this iteration + String type = oid.getType(); + + // No need to check for nulls, as guaranteed non-null by + // ObjectIdentity.getIdentifier() interface contract + String identifier = oid.getIdentifier().toString(); + + // Inject values + ps.setString((2 * i) + 1, identifier); + ps.setString((2 * i) + 2, type); + i++; + } + } + /** * The final phase of converting the Map of AclImpl * instances which contain StubAclParents into proper, valid * AclImpls with correct ACL parents. - * * @param inputMap the unconverted AclImpls * @param currentIdentity the currentAcl that we wish to convert (this * may be - * */ private AclImpl convert(Map inputMap, Long currentIdentity) { Assert.notEmpty(inputMap, "InputMap required"); @@ -434,8 +381,7 @@ public class BasicLookupStrategy implements LookupStrategy { // Retrieve this Acl from the InputMap Acl uncastAcl = inputMap.get(currentIdentity); - Assert.isInstanceOf(AclImpl.class, uncastAcl, - "The inputMap contained a non-AclImpl"); + Assert.isInstanceOf(AclImpl.class, uncastAcl, "The inputMap contained a non-AclImpl"); AclImpl inputAcl = (AclImpl) uncastAcl; @@ -448,9 +394,8 @@ public class BasicLookupStrategy implements LookupStrategy { } // Now we have the parent (if there is one), create the true AclImpl - AclImpl result = new AclImpl(inputAcl.getObjectIdentity(), - inputAcl.getId(), aclAuthorizationStrategy, grantingStrategy, - parent, null, inputAcl.isEntriesInheriting(), inputAcl.getOwner()); + AclImpl result = new AclImpl(inputAcl.getObjectIdentity(), inputAcl.getId(), this.aclAuthorizationStrategy, + this.grantingStrategy, parent, null, inputAcl.isEntriesInheriting(), inputAcl.getOwner()); // Copy the "aces" from the input to the destination @@ -477,7 +422,6 @@ public class BasicLookupStrategy implements LookupStrategy { /** * Creates a particular implementation of {@link Sid} depending on the arguments. - * * @param sid the name of the sid representing its unique identifier. In typical ACL * database schema it's located in table {@code acl_sid} table, {@code sid} column. * @param isPrincipal whether it's a user or granted authority like role @@ -487,16 +431,13 @@ public class BasicLookupStrategy implements LookupStrategy { if (isPrincipal) { return new PrincipalSid(sid); } - else { - return new GrantedAuthoritySid(sid); - } + return new GrantedAuthoritySid(sid); } /** * Sets the {@code PermissionFactory} instance which will be used to convert loaded * permission data values to {@code Permission}s. A {@code DefaultPermissionFactory} * will be used by default. - * * @param permissionFactory */ public final void setPermissionFactory(PermissionFactory permissionFactory) { @@ -510,7 +451,6 @@ public class BasicLookupStrategy implements LookupStrategy { /** * The SQL for the select clause. If customizing in order to modify column names, * schema etc, the other SQL customization fields must also be set to match. - * * @param selectClause the select clause, which defaults to * {@link #DEFAULT_SELECT_CLAUSE}. */ @@ -528,8 +468,7 @@ public class BasicLookupStrategy implements LookupStrategy { /** * The SQL for the where clause used in the lookupObjectIdentities method. */ - public final void setLookupObjectIdentitiesWhereClause( - String lookupObjectIdentitiesWhereClause) { + public final void setLookupObjectIdentitiesWhereClause(String lookupObjectIdentitiesWhereClause) { this.lookupObjectIdentitiesWhereClause = lookupObjectIdentitiesWhereClause; } @@ -542,8 +481,9 @@ public class BasicLookupStrategy implements LookupStrategy { public final void setAclClassIdSupported(boolean aclClassIdSupported) { if (aclClassIdSupported) { - Assert.isTrue(this.selectClause.equals(DEFAULT_SELECT_CLAUSE), "Cannot set aclClassIdSupported and override the select clause; " - + "just override the select clause"); + Assert.isTrue(this.selectClause.equals(DEFAULT_SELECT_CLAUSE), + "Cannot set aclClassIdSupported and override the select clause; " + + "just override the select clause"); this.selectClause = DEFAULT_ACL_CLASS_ID_SELECT_CLAUSE; } } @@ -552,11 +492,10 @@ public class BasicLookupStrategy implements LookupStrategy { this.aclClassIdUtils = new AclClassIdUtils(conversionService); } - // ~ Inner Classes - // ================================================================================================== - private class ProcessResultSet implements ResultSetExtractor> { + private final Map acls; + private final List sids; ProcessResultSet(Map acls, List sids) { @@ -575,32 +514,32 @@ public class BasicLookupStrategy implements LookupStrategy { * null) * @throws SQLException */ + @Override public Set extractData(ResultSet rs) throws SQLException { Set parentIdsToLookup = new HashSet<>(); // Set of parent_id Longs while (rs.next()) { // Convert current row into an Acl (albeit with a StubAclParent) - convertCurrentResultIntoObject(acls, rs); + convertCurrentResultIntoObject(this.acls, rs); // Figure out if this row means we need to lookup another parent long parentId = rs.getLong("parent_object"); if (parentId != 0) { // See if it's already in the "acls" - if (acls.containsKey(parentId)) { + if (this.acls.containsKey(parentId)) { continue; // skip this while iteration } // Now try to find it in the cache - MutableAcl cached = aclCache.getFromCache(parentId); - - if ((cached == null) || !cached.isSidLoaded(sids)) { + MutableAcl cached = BasicLookupStrategy.this.aclCache.getFromCache(parentId); + if ((cached == null) || !cached.isSidLoaded(this.sids)) { parentIdsToLookup.add(parentId); } else { // Pop into the acls map, so our convert method doesn't // need to deal with an unsynchronized AclCache - acls.put(cached.getId(), cached); + this.acls.put(cached.getId(), cached); } } } @@ -612,15 +551,12 @@ public class BasicLookupStrategy implements LookupStrategy { /** * Accepts the current ResultSet row, and converts it into an * AclImpl that contains a StubAclParent - * * @param acls the Map we should add the converted Acl to * @param rs the ResultSet focused on a current row - * * @throws SQLException if something goes wrong converting values * @throws ConversionException if can't convert to the desired Java type */ - private void convertCurrentResultIntoObject(Map acls, - ResultSet rs) throws SQLException { + private void convertCurrentResultIntoObject(Map acls, ResultSet rs) throws SQLException { Long id = rs.getLong("acl_id"); // If we already have an ACL for this ID, just create the ACE @@ -629,11 +565,11 @@ public class BasicLookupStrategy implements LookupStrategy { if (acl == null) { // Make an AclImpl and pop it into the Map - // If the Java type is a String, check to see if we can convert it to the target id type, e.g. UUID. + // If the Java type is a String, check to see if we can convert it to the + // target id type, e.g. UUID. Serializable identifier = (Serializable) rs.getObject("object_id_identity"); - identifier = aclClassIdUtils.identifierFrom(identifier, rs); - ObjectIdentity objectIdentity = new ObjectIdentityImpl( - rs.getString("class"), identifier); + identifier = BasicLookupStrategy.this.aclClassIdUtils.identifierFrom(identifier, rs); + ObjectIdentity objectIdentity = new ObjectIdentityImpl(rs.getString("class"), identifier); Acl parentAcl = null; long parentAclId = rs.getLong("parent_object"); @@ -643,11 +579,10 @@ public class BasicLookupStrategy implements LookupStrategy { } boolean entriesInheriting = rs.getBoolean("entries_inheriting"); - Sid owner = createSid(rs.getBoolean("acl_principal"), - rs.getString("acl_sid")); + Sid owner = createSid(rs.getBoolean("acl_principal"), rs.getString("acl_sid")); - acl = new AclImpl(objectIdentity, id, aclAuthorizationStrategy, - grantingStrategy, parentAcl, null, entriesInheriting, owner); + acl = new AclImpl(objectIdentity, id, BasicLookupStrategy.this.aclAuthorizationStrategy, + BasicLookupStrategy.this.grantingStrategy, parentAcl, null, entriesInheriting, owner); acls.put(id, acl); } @@ -657,17 +592,16 @@ public class BasicLookupStrategy implements LookupStrategy { // ACE_SID) if (rs.getString("ace_sid") != null) { Long aceId = rs.getLong("ace_id"); - Sid recipient = createSid(rs.getBoolean("ace_principal"), - rs.getString("ace_sid")); + Sid recipient = createSid(rs.getBoolean("ace_principal"), rs.getString("ace_sid")); int mask = rs.getInt("mask"); - Permission permission = permissionFactory.buildFromMask(mask); + Permission permission = BasicLookupStrategy.this.permissionFactory.buildFromMask(mask); boolean granting = rs.getBoolean("granting"); boolean auditSuccess = rs.getBoolean("audit_success"); boolean auditFailure = rs.getBoolean("audit_failure"); - AccessControlEntryImpl ace = new AccessControlEntryImpl(aceId, acl, - recipient, permission, granting, auditSuccess, auditFailure); + AccessControlEntryImpl ace = new AccessControlEntryImpl(aceId, acl, recipient, permission, granting, + auditSuccess, auditFailure); // Field acesField = FieldUtils.getField(AclImpl.class, "aces"); List aces = readAces((AclImpl) acl); @@ -678,47 +612,57 @@ public class BasicLookupStrategy implements LookupStrategy { } } } + } private static class StubAclParent implements Acl { + private final Long id; StubAclParent(Long id) { this.id = id; } + Long getId() { + return this.id; + } + + @Override public List getEntries() { throw new UnsupportedOperationException("Stub only"); } - public Long getId() { - return id; - } - + @Override public ObjectIdentity getObjectIdentity() { throw new UnsupportedOperationException("Stub only"); } + @Override public Sid getOwner() { throw new UnsupportedOperationException("Stub only"); } + @Override public Acl getParentAcl() { throw new UnsupportedOperationException("Stub only"); } + @Override public boolean isEntriesInheriting() { throw new UnsupportedOperationException("Stub only"); } - public boolean isGranted(List permission, List sids, - boolean administrativeMode) throws NotFoundException, - UnloadedSidException { + @Override + public boolean isGranted(List permission, List sids, boolean administrativeMode) + throws NotFoundException, UnloadedSidException { throw new UnsupportedOperationException("Stub only"); } + @Override public boolean isSidLoaded(List sids) { throw new UnsupportedOperationException("Stub only"); } + } + } diff --git a/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcAclService.java b/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcAclService.java index f2cb89f909..935466f5d1 100644 --- a/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcAclService.java +++ b/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcAclService.java @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import java.io.Serializable; +import java.sql.ResultSet; +import java.sql.SQLException; import java.util.Collections; import java.util.List; import java.util.Map; @@ -24,6 +27,7 @@ import javax.sql.DataSource; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.core.convert.ConversionService; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; @@ -45,34 +49,37 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class JdbcAclService implements AclService { - // ~ Static fields/initializers - // ===================================================================================== protected static final Log log = LogFactory.getLog(JdbcAclService.class); + private static final String DEFAULT_SELECT_ACL_CLASS_COLUMNS = "class.class as class"; - private static final String DEFAULT_SELECT_ACL_CLASS_COLUMNS_WITH_ID_TYPE = DEFAULT_SELECT_ACL_CLASS_COLUMNS + ", class.class_id_type as class_id_type"; - private static final String DEFAULT_SELECT_ACL_WITH_PARENT_SQL = "select obj.object_id_identity as obj_id, " + DEFAULT_SELECT_ACL_CLASS_COLUMNS - + " from acl_object_identity obj, acl_object_identity parent, acl_class class " - + "where obj.parent_object = parent.id and obj.object_id_class = class.id " - + "and parent.object_id_identity = ? and parent.object_id_class = (" - + "select id FROM acl_class where acl_class.class = ?)"; - private static final String DEFAULT_SELECT_ACL_WITH_PARENT_SQL_WITH_CLASS_ID_TYPE = "select obj.object_id_identity as obj_id, " + DEFAULT_SELECT_ACL_CLASS_COLUMNS_WITH_ID_TYPE + + private static final String DEFAULT_SELECT_ACL_CLASS_COLUMNS_WITH_ID_TYPE = DEFAULT_SELECT_ACL_CLASS_COLUMNS + + ", class.class_id_type as class_id_type"; + + private static final String DEFAULT_SELECT_ACL_WITH_PARENT_SQL = "select obj.object_id_identity as obj_id, " + + DEFAULT_SELECT_ACL_CLASS_COLUMNS + " from acl_object_identity obj, acl_object_identity parent, acl_class class " + "where obj.parent_object = parent.id and obj.object_id_class = class.id " + "and parent.object_id_identity = ? and parent.object_id_class = (" + "select id FROM acl_class where acl_class.class = ?)"; - // ~ Instance fields - // ================================================================================================ + private static final String DEFAULT_SELECT_ACL_WITH_PARENT_SQL_WITH_CLASS_ID_TYPE = "select obj.object_id_identity as obj_id, " + + DEFAULT_SELECT_ACL_CLASS_COLUMNS_WITH_ID_TYPE + + " from acl_object_identity obj, acl_object_identity parent, acl_class class " + + "where obj.parent_object = parent.id and obj.object_id_class = class.id " + + "and parent.object_id_identity = ? and parent.object_id_class = (" + + "select id FROM acl_class where acl_class.class = ?)"; protected final JdbcOperations jdbcOperations; - private final LookupStrategy lookupStrategy; - private boolean aclClassIdSupported; - private String findChildrenSql = DEFAULT_SELECT_ACL_WITH_PARENT_SQL; - private AclClassIdUtils aclClassIdUtils; - // ~ Constructors - // =================================================================================================== + private final LookupStrategy lookupStrategy; + + private boolean aclClassIdSupported; + + private String findChildrenSql = DEFAULT_SELECT_ACL_WITH_PARENT_SQL; + + private AclClassIdUtils aclClassIdUtils; public JdbcAclService(DataSource dataSource, LookupStrategy lookupStrategy) { this(new JdbcTemplate(dataSource), lookupStrategy); @@ -86,64 +93,55 @@ public class JdbcAclService implements AclService { this.aclClassIdUtils = new AclClassIdUtils(); } - // ~ Methods - // ======================================================================================================== - + @Override public List findChildren(ObjectIdentity parentIdentity) { Object[] args = { parentIdentity.getIdentifier().toString(), parentIdentity.getType() }; - List objects = jdbcOperations.query(findChildrenSql, args, - (rs, rowNum) -> { - String javaType = rs.getString("class"); - Serializable identifier = (Serializable) rs.getObject("obj_id"); - identifier = aclClassIdUtils.identifierFrom(identifier, rs); - return new ObjectIdentityImpl(javaType, identifier); - }); - - if (objects.isEmpty()) { - return null; - } - - return objects; + List objects = this.jdbcOperations.query(this.findChildrenSql, args, + (rs, rowNum) -> mapObjectIdentityRow(rs)); + return (!objects.isEmpty()) ? objects : null; } - public Acl readAclById(ObjectIdentity object, List sids) - throws NotFoundException { + private ObjectIdentity mapObjectIdentityRow(ResultSet rs) throws SQLException { + String javaType = rs.getString("class"); + Serializable identifier = (Serializable) rs.getObject("obj_id"); + identifier = this.aclClassIdUtils.identifierFrom(identifier, rs); + return new ObjectIdentityImpl(javaType, identifier); + } + + @Override + public Acl readAclById(ObjectIdentity object, List sids) throws NotFoundException { Map map = readAclsById(Collections.singletonList(object), sids); Assert.isTrue(map.containsKey(object), () -> "There should have been an Acl entry for ObjectIdentity " + object); - return map.get(object); } + @Override public Acl readAclById(ObjectIdentity object) throws NotFoundException { return readAclById(object, null); } - public Map readAclsById(List objects) - throws NotFoundException { + @Override + public Map readAclsById(List objects) throws NotFoundException { return readAclsById(objects, null); } - public Map readAclsById(List objects, - List sids) throws NotFoundException { - Map result = lookupStrategy.readAclsById(objects, sids); - + @Override + public Map readAclsById(List objects, List sids) + throws NotFoundException { + Map result = this.lookupStrategy.readAclsById(objects, sids); // Check every requested object identity was found (throw NotFoundException if // needed) for (ObjectIdentity oid : objects) { if (!result.containsKey(oid)) { - throw new NotFoundException( - "Unable to find ACL information for object identity '" + oid - + "'"); + throw new NotFoundException("Unable to find ACL information for object identity '" + oid + "'"); } } - return result; } /** * Allows customization of the SQL query used to find child object identities. - * * @param findChildrenSql */ public void setFindChildrenQuery(String findChildrenSql) { @@ -156,7 +154,8 @@ public class JdbcAclService implements AclService { // Change the default children select if it hasn't been overridden if (this.findChildrenSql.equals(DEFAULT_SELECT_ACL_WITH_PARENT_SQL)) { this.findChildrenSql = DEFAULT_SELECT_ACL_WITH_PARENT_SQL_WITH_CLASS_ID_TYPE; - } else { + } + else { log.debug("Find children statement has already been overridden, so not overridding the default"); } } @@ -167,6 +166,7 @@ public class JdbcAclService implements AclService { } protected boolean isAclClassIdSupported() { - return aclClassIdSupported; + return this.aclClassIdSupported; } + } diff --git a/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcMutableAclService.java b/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcMutableAclService.java index 6625abdc5d..f1c7a7a174 100644 --- a/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcMutableAclService.java +++ b/acl/src/main/java/org/springframework/security/acls/jdbc/JdbcMutableAclService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import java.sql.PreparedStatement; @@ -58,55 +59,58 @@ import org.springframework.util.Assert; * @author Johannes Zlattinger */ public class JdbcMutableAclService extends JdbcAclService implements MutableAclService { + private static final String DEFAULT_INSERT_INTO_ACL_CLASS = "insert into acl_class (class) values (?)"; + private static final String DEFAULT_INSERT_INTO_ACL_CLASS_WITH_ID = "insert into acl_class (class, class_id_type) values (?, ?)"; - // ~ Instance fields - // ================================================================================================ private boolean foreignKeysInDatabase = true; + private final AclCache aclCache; + private String deleteEntryByObjectIdentityForeignKey = "delete from acl_entry where acl_object_identity=?"; + private String deleteObjectIdentityByPrimaryKey = "delete from acl_object_identity where id=?"; + private String classIdentityQuery = "call identity()"; + private String sidIdentityQuery = "call identity()"; + private String insertClass = DEFAULT_INSERT_INTO_ACL_CLASS; + private String insertEntry = "insert into acl_entry " + "(acl_object_identity, ace_order, sid, mask, granting, audit_success, audit_failure)" + "values (?, ?, ?, ?, ?, ?, ?)"; + private String insertObjectIdentity = "insert into acl_object_identity " - + "(object_id_class, object_id_identity, owner_sid, entries_inheriting) " - + "values (?, ?, ?, ?)"; + + "(object_id_class, object_id_identity, owner_sid, entries_inheriting) " + "values (?, ?, ?, ?)"; + private String insertSid = "insert into acl_sid (principal, sid) values (?, ?)"; + private String selectClassPrimaryKey = "select id from acl_class where class=?"; + private String selectObjectIdentityPrimaryKey = "select acl_object_identity.id from acl_object_identity, acl_class " + "where acl_object_identity.object_id_class = acl_class.id and acl_class.class=? " + "and acl_object_identity.object_id_identity = ?"; + private String selectSidPrimaryKey = "select id from acl_sid where principal=? and sid=?"; + private String updateObjectIdentity = "update acl_object_identity set " - + "parent_object = ?, owner_sid = ?, entries_inheriting = ?" - + " where id = ?"; + + "parent_object = ?, owner_sid = ?, entries_inheriting = ?" + " where id = ?"; - // ~ Constructors - // =================================================================================================== - - public JdbcMutableAclService(DataSource dataSource, LookupStrategy lookupStrategy, - AclCache aclCache) { + public JdbcMutableAclService(DataSource dataSource, LookupStrategy lookupStrategy, AclCache aclCache) { super(dataSource, lookupStrategy); Assert.notNull(aclCache, "AclCache required"); this.aclCache = aclCache; } - // ~ Methods - // ======================================================================================================== - - public MutableAcl createAcl(ObjectIdentity objectIdentity) - throws AlreadyExistsException { + @Override + public MutableAcl createAcl(ObjectIdentity objectIdentity) throws AlreadyExistsException { Assert.notNull(objectIdentity, "Object Identity required"); // Check this object identity hasn't already been persisted if (retrieveObjectIdentityPrimaryKey(objectIdentity) != null) { - throw new AlreadyExistsException("Object identity '" + objectIdentity - + "' already exists"); + throw new AlreadyExistsException("Object identity '" + objectIdentity + "' already exists"); } // Need to retrieve the current principal, in order to know who "owns" this ACL @@ -128,22 +132,23 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS /** * Creates a new row in acl_entry for every ACE defined in the passed MutableAcl * object. - * * @param acl containing the ACEs to insert */ protected void createEntries(final MutableAcl acl) { if (acl.getEntries().isEmpty()) { return; } - jdbcOperations.batchUpdate(insertEntry, new BatchPreparedStatementSetter() { + this.jdbcOperations.batchUpdate(this.insertEntry, new BatchPreparedStatementSetter() { + + @Override public int getBatchSize() { return acl.getEntries().size(); } + @Override public void setValues(PreparedStatement stmt, int i) throws SQLException { AccessControlEntry entry_ = acl.getEntries().get(i); - Assert.isTrue(entry_ instanceof AccessControlEntryImpl, - "Unknown ACE class"); + Assert.isTrue(entry_ instanceof AccessControlEntryImpl, "Unknown ACE class"); AccessControlEntryImpl entry = (AccessControlEntryImpl) entry_; stmt.setLong(1, (Long) acl.getId()); @@ -154,6 +159,7 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS stmt.setBoolean(6, entry.isAuditSuccess()); stmt.setBoolean(7, entry.isAuditFailure()); } + }); } @@ -161,7 +167,6 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS * Creates an entry in the acl_object_identity table for the passed ObjectIdentity. * The Sid is also necessary, as acl_object_identity has defined the sid column as * non-null. - * * @param object to represent an acl_object_identity for * @param owner for the SID column (will be created if there is no acl_sid entry for * this particular Sid already) @@ -169,22 +174,20 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS protected void createObjectIdentity(ObjectIdentity object, Sid owner) { Long sidId = createOrRetrieveSidPrimaryKey(owner, true); Long classId = createOrRetrieveClassPrimaryKey(object.getType(), true, object.getIdentifier().getClass()); - jdbcOperations.update(insertObjectIdentity, classId, object.getIdentifier().toString(), sidId, + this.jdbcOperations.update(this.insertObjectIdentity, classId, object.getIdentifier().toString(), sidId, Boolean.TRUE); } /** * Retrieves the primary key from {@code acl_class}, creating a new row if needed and * the {@code allowCreate} property is {@code true}. - * * @param type to find or create an entry for (often the fully-qualified class name) * @param allowCreate true if creation is permitted if not found - * * @return the primary key or null if not found */ protected Long createOrRetrieveClassPrimaryKey(String type, boolean allowCreate, Class idType) { - List classIds = jdbcOperations.queryForList(selectClassPrimaryKey, - new Object[] { type }, Long.class); + List classIds = this.jdbcOperations.queryForList(this.selectClassPrimaryKey, new Object[] { type }, + Long.class); if (!classIds.isEmpty()) { return classIds.get(0); @@ -192,13 +195,13 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS if (allowCreate) { if (!isAclClassIdSupported()) { - jdbcOperations.update(insertClass, type); - } else { - jdbcOperations.update(insertClass, type, idType.getCanonicalName()); + this.jdbcOperations.update(this.insertClass, type); } - Assert.isTrue(TransactionSynchronizationManager.isSynchronizationActive(), - "Transaction must be running"); - return jdbcOperations.queryForObject(classIdentityQuery, Long.class); + else { + this.jdbcOperations.update(this.insertClass, type, idType.getCanonicalName()); + } + Assert.isTrue(TransactionSynchronizationManager.isSynchronizationActive(), "Transaction must be running"); + return this.jdbcOperations.queryForObject(this.classIdentityQuery, Long.class); } return null; @@ -207,33 +210,23 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS /** * Retrieves the primary key from acl_sid, creating a new row if needed and the * allowCreate property is true. - * * @param sid to find or create * @param allowCreate true if creation is permitted if not found - * * @return the primary key or null if not found - * * @throws IllegalArgumentException if the Sid is not a recognized * implementation. */ protected Long createOrRetrieveSidPrimaryKey(Sid sid, boolean allowCreate) { Assert.notNull(sid, "Sid required"); - - String sidName; - boolean sidIsPrincipal = true; - if (sid instanceof PrincipalSid) { - sidName = ((PrincipalSid) sid).getPrincipal(); + String sidName = ((PrincipalSid) sid).getPrincipal(); + return createOrRetrieveSidPrimaryKey(sidName, true, allowCreate); } - else if (sid instanceof GrantedAuthoritySid) { - sidName = ((GrantedAuthoritySid) sid).getGrantedAuthority(); - sidIsPrincipal = false; + if (sid instanceof GrantedAuthoritySid) { + String sidName = ((GrantedAuthoritySid) sid).getGrantedAuthority(); + return createOrRetrieveSidPrimaryKey(sidName, false, allowCreate); } - else { - throw new IllegalArgumentException("Unsupported implementation of Sid"); - } - - return createOrRetrieveSidPrimaryKey(sidName, sidIsPrincipal, allowCreate); + throw new IllegalArgumentException("Unsupported implementation of Sid"); } /** @@ -244,32 +237,24 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS * @param allowCreate true if creation is permitted if not found * @return the primary key or null if not found */ - protected Long createOrRetrieveSidPrimaryKey(String sidName, boolean sidIsPrincipal, - boolean allowCreate) { - - List sidIds = jdbcOperations.queryForList(selectSidPrimaryKey, new Object[] { - sidIsPrincipal, sidName }, Long.class); - + protected Long createOrRetrieveSidPrimaryKey(String sidName, boolean sidIsPrincipal, boolean allowCreate) { + List sidIds = this.jdbcOperations.queryForList(this.selectSidPrimaryKey, + new Object[] { sidIsPrincipal, sidName }, Long.class); if (!sidIds.isEmpty()) { return sidIds.get(0); } - if (allowCreate) { - jdbcOperations.update(insertSid, sidIsPrincipal, sidName); - Assert.isTrue(TransactionSynchronizationManager.isSynchronizationActive(), - "Transaction must be running"); - return jdbcOperations.queryForObject(sidIdentityQuery, Long.class); + this.jdbcOperations.update(this.insertSid, sidIsPrincipal, sidName); + Assert.isTrue(TransactionSynchronizationManager.isSynchronizationActive(), "Transaction must be running"); + return this.jdbcOperations.queryForObject(this.sidIdentityQuery, Long.class); } - return null; } - public void deleteAcl(ObjectIdentity objectIdentity, boolean deleteChildren) - throws ChildrenExistException { + @Override + public void deleteAcl(ObjectIdentity objectIdentity, boolean deleteChildren) throws ChildrenExistException { Assert.notNull(objectIdentity, "Object Identity required"); - Assert.notNull(objectIdentity.getIdentifier(), - "Object Identity doesn't provide an identifier"); - + Assert.notNull(objectIdentity.getIdentifier(), "Object Identity doesn't provide an identifier"); if (deleteChildren) { List children = findChildren(objectIdentity); if (children != null) { @@ -279,14 +264,13 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS } } else { - if (!foreignKeysInDatabase) { + if (!this.foreignKeysInDatabase) { // We need to perform a manual verification for what a FK would normally - // do - // We generally don't do this, in the interests of deadlock management + // do. We generally don't do this, in the interests of deadlock management List children = findChildren(objectIdentity); if (children != null) { - throw new ChildrenExistException("Cannot delete '" + objectIdentity - + "' (has " + children.size() + " children)"); + throw new ChildrenExistException( + "Cannot delete '" + objectIdentity + "' (has " + children.size() + " children)"); } } } @@ -300,17 +284,16 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS deleteObjectIdentity(oidPrimaryKey); // Clear the cache - aclCache.evictFromCache(objectIdentity); + this.aclCache.evictFromCache(objectIdentity); } /** * Deletes all ACEs defined in the acl_entry table belonging to the presented * ObjectIdentity primary key. - * * @param oidPrimaryKey the rows in acl_entry to delete */ protected void deleteEntries(Long oidPrimaryKey) { - jdbcOperations.update(deleteEntryByObjectIdentityForeignKey, oidPrimaryKey); + this.jdbcOperations.update(this.deleteEntryByObjectIdentityForeignKey, oidPrimaryKey); } /** @@ -319,27 +302,24 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS *

* We do not delete any entries from acl_class, even if no classes are using that * class any longer. This is a deadlock avoidance approach. - * * @param oidPrimaryKey to delete the acl_object_identity */ protected void deleteObjectIdentity(Long oidPrimaryKey) { // Delete the acl_object_identity row - jdbcOperations.update(deleteObjectIdentityByPrimaryKey, oidPrimaryKey); + this.jdbcOperations.update(this.deleteObjectIdentityByPrimaryKey, oidPrimaryKey); } /** * Retrieves the primary key from the acl_object_identity table for the passed * ObjectIdentity. Unlike some other methods in this implementation, this method will * NOT create a row (use {@link #createObjectIdentity(ObjectIdentity, Sid)} instead). - * * @param oid to find - * * @return the object identity or null if not found */ protected Long retrieveObjectIdentityPrimaryKey(ObjectIdentity oid) { try { - return jdbcOperations.queryForObject(selectObjectIdentityPrimaryKey, Long.class, - oid.getType(), oid.getIdentifier().toString()); + return this.jdbcOperations.queryForObject(this.selectObjectIdentityPrimaryKey, Long.class, oid.getType(), + oid.getIdentifier().toString()); } catch (DataAccessException notFound) { return null; @@ -352,6 +332,7 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS * dirty state checking, or more likely use ORM capabilities for create, update and * delete operations of {@link MutableAcl}. */ + @Override public MutableAcl updateAcl(MutableAcl acl) throws NotFoundException { Assert.notNull(acl.getId(), "Object Identity doesn't provide an identifier"); @@ -380,37 +361,28 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS clearCacheIncludingChildren(child); } } - aclCache.evictFromCache(objectIdentity); + this.aclCache.evictFromCache(objectIdentity); } /** * Updates an existing acl_object_identity row, with new information presented in the * passed MutableAcl object. Also will create an acl_sid entry if needed for the Sid * that owns the MutableAcl. - * * @param acl to modify (a row must already exist in acl_object_identity) - * * @throws NotFoundException if the ACL could not be found to update. */ protected void updateObjectIdentity(MutableAcl acl) { Long parentId = null; - if (acl.getParentAcl() != null) { - Assert.isInstanceOf(ObjectIdentityImpl.class, acl.getParentAcl() - .getObjectIdentity(), + Assert.isInstanceOf(ObjectIdentityImpl.class, acl.getParentAcl().getObjectIdentity(), "Implementation only supports ObjectIdentityImpl"); - - ObjectIdentityImpl oii = (ObjectIdentityImpl) acl.getParentAcl() - .getObjectIdentity(); + ObjectIdentityImpl oii = (ObjectIdentityImpl) acl.getParentAcl().getObjectIdentity(); parentId = retrieveObjectIdentityPrimaryKey(oii); } - Assert.notNull(acl.getOwner(), "Owner is required in this implementation"); - Long ownerSid = createOrRetrieveSidPrimaryKey(acl.getOwner(), true); - int count = jdbcOperations.update(updateObjectIdentity, parentId, ownerSid, - acl.isEntriesInheriting(), acl.getId()); - + int count = this.jdbcOperations.update(this.updateObjectIdentity, parentId, ownerSid, acl.isEntriesInheriting(), + acl.getId()); if (count != 1) { throw new NotFoundException("Unable to locate ACL to update"); } @@ -419,7 +391,6 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS /** * Sets the query that will be used to retrieve the identity of a newly created row in * the acl_class table. - * * @param classIdentityQuery the query, which should return the identifier. Defaults * to call identity() */ @@ -431,7 +402,6 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS /** * Sets the query that will be used to retrieve the identity of a newly created row in * the acl_sid table. - * * @param sidIdentityQuery the query, which should return the identifier. Defaults to * call identity() */ @@ -440,13 +410,11 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS this.sidIdentityQuery = sidIdentityQuery; } - public void setDeleteEntryByObjectIdentityForeignKeySql( - String deleteEntryByObjectIdentityForeignKey) { + public void setDeleteEntryByObjectIdentityForeignKeySql(String deleteEntryByObjectIdentityForeignKey) { this.deleteEntryByObjectIdentityForeignKey = deleteEntryByObjectIdentityForeignKey; } - public void setDeleteObjectIdentityByPrimaryKeySql( - String deleteObjectIdentityByPrimaryKey) { + public void setDeleteObjectIdentityByPrimaryKeySql(String deleteObjectIdentityByPrimaryKey) { this.deleteObjectIdentityByPrimaryKey = deleteObjectIdentityByPrimaryKey; } @@ -498,9 +466,11 @@ public class JdbcMutableAclService extends JdbcAclService implements MutableAclS // Change the default insert if it hasn't been overridden if (this.insertClass.equals(DEFAULT_INSERT_INTO_ACL_CLASS)) { this.insertClass = DEFAULT_INSERT_INTO_ACL_CLASS_WITH_ID; - } else { + } + else { log.debug("Insert class statement has already been overridden, so not overridding the default"); } } } + } diff --git a/acl/src/main/java/org/springframework/security/acls/jdbc/LookupStrategy.java b/acl/src/main/java/org/springframework/security/acls/jdbc/LookupStrategy.java index dc8a11d419..adc6c6aef1 100644 --- a/acl/src/main/java/org/springframework/security/acls/jdbc/LookupStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/jdbc/LookupStrategy.java @@ -13,32 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; +import java.util.List; +import java.util.Map; + import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.NotFoundException; import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.security.acls.model.Sid; -import java.util.List; -import java.util.Map; - /** * Performs lookups for {@link org.springframework.security.acls.model.AclService}. * * @author Ben Alex */ public interface LookupStrategy { - // ~ Methods - // ======================================================================================================== /** * Perform database-specific optimized lookup. - * * @param objects the identities to lookup (required) * @param sids the SIDs for which identities are required (may be null - * implementations may elect not to provide SID optimisations) - * * @return a Map where keys represent the {@link ObjectIdentity} of the * located {@link Acl} and values are the located {@link Acl} (never null * although some entries may be missing; this method should not throw @@ -46,4 +43,5 @@ public interface LookupStrategy { * automatically create entries if required) */ Map readAclsById(List objects, List sids); + } diff --git a/acl/src/main/java/org/springframework/security/acls/jdbc/package-info.java b/acl/src/main/java/org/springframework/security/acls/jdbc/package-info.java index 154bf41bfc..2c34420915 100644 --- a/acl/src/main/java/org/springframework/security/acls/jdbc/package-info.java +++ b/acl/src/main/java/org/springframework/security/acls/jdbc/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * JDBC-based persistence of ACL information */ package org.springframework.security.acls.jdbc; - diff --git a/acl/src/main/java/org/springframework/security/acls/model/AccessControlEntry.java b/acl/src/main/java/org/springframework/security/acls/model/AccessControlEntry.java index a1bd103c9b..d8c8e286f6 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AccessControlEntry.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AccessControlEntry.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -26,17 +27,13 @@ import java.io.Serializable; *

* * @author Ben Alex - * */ public interface AccessControlEntry extends Serializable { - // ~ Methods - // ======================================================================================================== Acl getAcl(); /** * Obtains an identifier that represents this ACE. - * * @return the identifier, or null if unsaved */ Serializable getId(); @@ -46,10 +43,10 @@ public interface AccessControlEntry extends Serializable { Sid getSid(); /** - * Indicates the permission is being granted to the relevant Sid. If false, - * indicates the permission is being revoked/blocked. - * + * Indicates the permission is being granted to the relevant Sid. If false, indicates + * the permission is being revoked/blocked. * @return true if being granted, false otherwise */ boolean isGranting(); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/Acl.java b/acl/src/main/java/org/springframework/security/acls/model/Acl.java index de48e3a929..8128c9e0c1 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/Acl.java +++ b/acl/src/main/java/org/springframework/security/acls/model/Acl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -63,7 +64,6 @@ public interface Acl extends Serializable { * subset of Sids. The caller is responsible for correctly handling the * result if only a subset of Sids is represented. *

- * * @return the list of entries represented by the Acl, or null if * there are no entries presently associated with this Acl. */ @@ -72,7 +72,6 @@ public interface Acl extends Serializable { /** * Obtains the domain object this Acl provides entries for. This is immutable * once an Acl is created. - * * @return the object identity (never null) */ ObjectIdentity getObjectIdentity(); @@ -80,7 +79,6 @@ public interface Acl extends Serializable { /** * Determines the owner of the Acl. The meaning of ownership varies by * implementation and is unspecified. - * * @return the owner (may be null if the implementation does not use * ownership concepts) */ @@ -102,7 +100,6 @@ public interface Acl extends Serializable { * subset of Sids. The caller is responsible for correctly handling the * result if only a subset of Sids is represented. *

- * * @return the parent Acl (may be null if this Acl does not * have a parent) */ @@ -118,7 +115,6 @@ public interface Acl extends Serializable { * parent for navigation purposes. Thus, this method denotes whether or not the * navigation relationship also extends to the actual inheritance of entries. *

- * * @return true if parent ACL entries inherit into the current Acl */ boolean isEntriesInheriting(); @@ -158,7 +154,6 @@ public interface Acl extends Serializable { * authorization decision for a {@link Sid} that was never loaded in this Acl * . *

- * * @param permission the permission or permissions required (at least one entry * required) * @param sids the security identities held by the principal (at least one entry @@ -166,17 +161,15 @@ public interface Acl extends Serializable { * @param administrativeMode if true denotes the query is for administrative * purposes and no logging or auditing (if supported by the implementation) should be * undertaken - * * @return true if authorization is granted - * * @throws NotFoundException MUST be thrown if an implementation cannot make an * authoritative authorization decision, usually because there is no ACL information * for this particular permission and/or SID * @throws UnloadedSidException thrown if the Acl does not have details for * one or more of the Sids passed as arguments */ - boolean isGranted(List permission, List sids, - boolean administrativeMode) throws NotFoundException, UnloadedSidException; + boolean isGranted(List permission, List sids, boolean administrativeMode) + throws NotFoundException, UnloadedSidException; /** * For efficiency reasons an Acl may be loaded and not contain @@ -191,12 +184,11 @@ public interface Acl extends Serializable { * all Sids. This method denotes whether or not the specified Sids * have been loaded or not. *

- * * @param sids one or more security identities the caller is interest in knowing * whether this Sid supports - * * @return true if every passed Sid is represented by this * Acl instance */ boolean isSidLoaded(List sids); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/AclCache.java b/acl/src/main/java/org/springframework/security/acls/model/AclCache.java index 7c23e5f999..945b1b1a2a 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AclCache.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AclCache.java @@ -13,21 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; -import org.springframework.security.acls.jdbc.JdbcAclService; - import java.io.Serializable; +import org.springframework.security.acls.jdbc.JdbcAclService; + /** * A caching layer for {@link JdbcAclService}. * * @author Ben Alex - * */ public interface AclCache { - // ~ Methods - // ======================================================================================================== void evictFromCache(Serializable pk); @@ -40,4 +38,5 @@ public interface AclCache { void putInCache(MutableAcl acl); void clearCache(); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/AclDataAccessException.java b/acl/src/main/java/org/springframework/security/acls/model/AclDataAccessException.java index 090cecc49f..84d65590b7 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AclDataAccessException.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AclDataAccessException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -26,7 +27,6 @@ public abstract class AclDataAccessException extends RuntimeException { /** * Constructs an AclDataAccessException with the specified message and * root cause. - * * @param msg the detail message * @param cause the root cause */ @@ -37,10 +37,10 @@ public abstract class AclDataAccessException extends RuntimeException { /** * Constructs an AclDataAccessException with the specified message and no * root cause. - * * @param msg the detail message */ public AclDataAccessException(String msg) { super(msg); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/AclService.java b/acl/src/main/java/org/springframework/security/acls/model/AclService.java index 462d33d7f2..2866ec83bd 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AclService.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AclService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.util.List; @@ -24,15 +25,11 @@ import java.util.Map; * @author Ben Alex */ public interface AclService { - // ~ Methods - // ======================================================================================================== /** * Locates all object identities that use the specified parent. This is useful for * administration tools. - * * @param parentIdentity to locate children of - * * @return the children (or null if none were found) */ List findChildren(ObjectIdentity parentIdentity); @@ -44,12 +41,9 @@ public interface AclService { * implementation's potential ability to filter Acl entries based on a * {@link Sid} parameter. *

- * * @param object to locate an {@link Acl} for - * * @return the {@link Acl} for the requested {@link ObjectIdentity} (never * null) - * * @throws NotFoundException if an {@link Acl} was not found for the requested * {@link ObjectIdentity} */ @@ -57,14 +51,11 @@ public interface AclService { /** * Same as {@link #readAclsById(List, List)} except it returns only a single Acl. - * * @param object to locate an {@link Acl} for * @param sids the security identities for which {@link Acl} information is required * (may be null to denote all entries) - * * @return the {@link Acl} for the requested {@link ObjectIdentity} (never * null) - * * @throws NotFoundException if an {@link Acl} was not found for the requested * {@link ObjectIdentity} */ @@ -76,17 +67,13 @@ public interface AclService { * The returned map is keyed on the passed objects, with the values being the * Acl instances. Any unknown objects will not have a map key. *

- * * @param objects the objects to find {@link Acl} information for - * * @return a map with exactly one element for each {@link ObjectIdentity} passed as an * argument (never null) - * * @throws NotFoundException if an {@link Acl} was not found for each requested * {@link ObjectIdentity} */ - Map readAclsById(List objects) - throws NotFoundException; + Map readAclsById(List objects) throws NotFoundException; /** * Obtains all the Acls that apply for the passed Objects, but only @@ -103,17 +90,14 @@ public interface AclService { * Acl instances. Any unknown objects (or objects for which the interested * Sids do not have entries) will not have a map key. *

- * * @param objects the objects to find {@link Acl} information for * @param sids the security identities for which {@link Acl} information is required * (may be null to denote all entries) - * * @return a map with exactly one element for each {@link ObjectIdentity} passed as an * argument (never null) - * * @throws NotFoundException if an {@link Acl} was not found for each requested * {@link ObjectIdentity} */ - Map readAclsById(List objects, List sids) - throws NotFoundException; + Map readAclsById(List objects, List sids) throws NotFoundException; + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/AlreadyExistsException.java b/acl/src/main/java/org/springframework/security/acls/model/AlreadyExistsException.java index d611ab11ee..0f82ee6b34 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AlreadyExistsException.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AlreadyExistsException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -21,12 +22,9 @@ package org.springframework.security.acls.model; * @author Ben Alex */ public class AlreadyExistsException extends AclDataAccessException { - // ~ Constructors - // =================================================================================================== /** * Constructs an AlreadyExistsException with the specified message. - * * @param msg the detail message */ public AlreadyExistsException(String msg) { @@ -36,11 +34,11 @@ public class AlreadyExistsException extends AclDataAccessException { /** * Constructs an AlreadyExistsException with the specified message and * root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public AlreadyExistsException(String msg, Throwable t) { - super(msg, t); + public AlreadyExistsException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/AuditableAccessControlEntry.java b/acl/src/main/java/org/springframework/security/acls/model/AuditableAccessControlEntry.java index 30ea235690..e14a5d1596 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AuditableAccessControlEntry.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AuditableAccessControlEntry.java @@ -13,19 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** * Represents an ACE that provides auditing information. * * @author Ben Alex - * */ public interface AuditableAccessControlEntry extends AccessControlEntry { - // ~ Methods - // ======================================================================================================== boolean isAuditFailure(); boolean isAuditSuccess(); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/AuditableAcl.java b/acl/src/main/java/org/springframework/security/acls/model/AuditableAcl.java index ddf31a954d..3760d78dc3 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/AuditableAcl.java +++ b/acl/src/main/java/org/springframework/security/acls/model/AuditableAcl.java @@ -13,17 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** * A mutable ACL that provides audit capabilities. * * @author Ben Alex - * */ public interface AuditableAcl extends MutableAcl { - // ~ Methods - // ======================================================================================================== void updateAuditing(int aceIndex, boolean auditSuccess, boolean auditFailure); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/ChildrenExistException.java b/acl/src/main/java/org/springframework/security/acls/model/ChildrenExistException.java index 514258ec07..d65f1903bb 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/ChildrenExistException.java +++ b/acl/src/main/java/org/springframework/security/acls/model/ChildrenExistException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -21,12 +22,9 @@ package org.springframework.security.acls.model; * @author Ben Alex */ public class ChildrenExistException extends AclDataAccessException { - // ~ Constructors - // =================================================================================================== /** * Constructs an ChildrenExistException with the specified message. - * * @param msg the detail message */ public ChildrenExistException(String msg) { @@ -36,11 +34,11 @@ public class ChildrenExistException extends AclDataAccessException { /** * Constructs an ChildrenExistException with the specified message and * root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public ChildrenExistException(String msg, Throwable t) { - super(msg, t); + public ChildrenExistException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/MutableAcl.java b/acl/src/main/java/org/springframework/security/acls/model/MutableAcl.java index 0e91c66e25..9231a0cc30 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/MutableAcl.java +++ b/acl/src/main/java/org/springframework/security/acls/model/MutableAcl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -26,41 +27,35 @@ import java.io.Serializable; * @author Ben Alex */ public interface MutableAcl extends Acl { - // ~ Methods - // ======================================================================================================== void deleteAce(int aceIndex) throws NotFoundException; /** * Obtains an identifier that represents this MutableAcl. - * * @return the identifier, or null if unsaved */ Serializable getId(); - void insertAce(int atIndexLocation, Permission permission, Sid sid, boolean granting) - throws NotFoundException; + void insertAce(int atIndexLocation, Permission permission, Sid sid, boolean granting) throws NotFoundException; /** * Changes the present owner to a different owner. - * * @param newOwner the new owner (mandatory; cannot be null) */ void setOwner(Sid newOwner); /** * Change the value returned by {@link Acl#isEntriesInheriting()}. - * * @param entriesInheriting the new value */ void setEntriesInheriting(boolean entriesInheriting); /** * Changes the parent of this ACL. - * * @param newParent the new parent */ void setParent(Acl newParent); void updateAce(int aceIndex, Permission permission) throws NotFoundException; + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/MutableAclService.java b/acl/src/main/java/org/springframework/security/acls/model/MutableAclService.java index 7b14f99e40..cdc229120e 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/MutableAclService.java +++ b/acl/src/main/java/org/springframework/security/acls/model/MutableAclService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -21,41 +22,32 @@ package org.springframework.security.acls.model; * @author Ben Alex */ public interface MutableAclService extends AclService { - // ~ Methods - // ======================================================================================================== /** * Creates an empty Acl object in the database. It will have no entries. * The returned object will then be used to add entries. - * * @param objectIdentity the object identity to create - * * @return an ACL object with its ID set - * * @throws AlreadyExistsException if the passed object identity already has a record */ MutableAcl createAcl(ObjectIdentity objectIdentity) throws AlreadyExistsException; /** * Removes the specified entry from the database. - * * @param objectIdentity the object identity to remove * @param deleteChildren whether to cascade the delete to children - * * @throws ChildrenExistException if the deleteChildren argument was * false but children exist */ - void deleteAcl(ObjectIdentity objectIdentity, boolean deleteChildren) - throws ChildrenExistException; + void deleteAcl(ObjectIdentity objectIdentity, boolean deleteChildren) throws ChildrenExistException; /** * Changes an existing Acl in the database. - * * @param acl to modify - * * @throws NotFoundException if the relevant record could not be found (did you * remember to use {@link #createAcl(ObjectIdentity)} to create the object, rather * than creating it with the new keyword?) */ MutableAcl updateAcl(MutableAcl acl) throws NotFoundException; + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/NotFoundException.java b/acl/src/main/java/org/springframework/security/acls/model/NotFoundException.java index e470d61292..1b0ab639bd 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/NotFoundException.java +++ b/acl/src/main/java/org/springframework/security/acls/model/NotFoundException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -21,12 +22,9 @@ package org.springframework.security.acls.model; * @author Ben Alex */ public class NotFoundException extends AclDataAccessException { - // ~ Constructors - // =================================================================================================== /** * Constructs an NotFoundException with the specified message. - * * @param msg the detail message */ public NotFoundException(String msg) { @@ -36,11 +34,11 @@ public class NotFoundException extends AclDataAccessException { /** * Constructs an NotFoundException with the specified message and root * cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public NotFoundException(String msg, Throwable t) { - super(msg, t); + public NotFoundException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentity.java b/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentity.java index 2edf7cd2e1..995e514056 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentity.java +++ b/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentity.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -32,15 +33,13 @@ import java.io.Serializable; * @author Ben Alex */ public interface ObjectIdentity extends Serializable { - // ~ Methods - // ======================================================================================================== /** * @param obj to be compared - * * @return true if the objects are equal, false otherwise * @see Object#equals(Object) */ + @Override boolean equals(Object obj); /** @@ -53,7 +52,6 @@ public interface ObjectIdentity extends Serializable { * identifier with business meaning, as that business meaning may change in the future * such change will cascade to the ACL subsystem data. *

- * * @return the identifier (unique within this type; never null) */ Serializable getIdentifier(); @@ -62,7 +60,6 @@ public interface ObjectIdentity extends Serializable { * Obtains the "type" metadata for the domain object. This will often be a Java type * name (an interface or a class) – traditionally it is the name of the domain * object implementation class. - * * @return the "type" of the domain object (never null). */ String getType(); @@ -71,5 +68,7 @@ public interface ObjectIdentity extends Serializable { * @return a hash code representation of the ObjectIdentity * @see Object#hashCode() */ + @Override int hashCode(); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityGenerator.java b/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityGenerator.java index 1814295874..e46f3dc09f 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityGenerator.java +++ b/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityGenerator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -30,7 +31,6 @@ import java.io.Serializable; public interface ObjectIdentityGenerator { /** - * * @param id the identifier of the domain object, not null * @param type the type of the object (often a class name), not null * @return the identity constructed using the supplied identifier and type diff --git a/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityRetrievalStrategy.java b/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityRetrievalStrategy.java index cd1b53031e..3838443a5d 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityRetrievalStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/model/ObjectIdentityRetrievalStrategy.java @@ -21,11 +21,9 @@ package org.springframework.security.acls.model; * will be returned for a particular domain object * * @author Ben Alex - * */ public interface ObjectIdentityRetrievalStrategy { - // ~ Methods - // ======================================================================================================== ObjectIdentity getObjectIdentity(Object domainObject); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/OwnershipAcl.java b/acl/src/main/java/org/springframework/security/acls/model/OwnershipAcl.java index edda240b19..de1e0bbace 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/OwnershipAcl.java +++ b/acl/src/main/java/org/springframework/security/acls/model/OwnershipAcl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -25,8 +26,8 @@ package org.springframework.security.acls.model; * @author Ben Alex */ public interface OwnershipAcl extends MutableAcl { - // ~ Methods - // ======================================================================================================== + @Override void setOwner(Sid newOwner); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/Permission.java b/acl/src/main/java/org/springframework/security/acls/model/Permission.java index 68beec8568..99a0d36a76 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/Permission.java +++ b/acl/src/main/java/org/springframework/security/acls/model/Permission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -23,19 +24,15 @@ import java.io.Serializable; * @author Ben Alex */ public interface Permission extends Serializable { - // ~ Static fields/initializers - // ===================================================================================== char RESERVED_ON = '~'; - char RESERVED_OFF = '.'; - String THIRTY_TWO_RESERVED_OFF = "................................"; - // ~ Methods - // ======================================================================================================== + char RESERVED_OFF = '.'; + + String THIRTY_TWO_RESERVED_OFF = "................................"; /** * Returns the bits that represents the permission. - * * @return the bits that represent the permission */ int getMask(); @@ -56,8 +53,8 @@ public interface Permission extends Serializable { * This method is only used for user interface and logging purposes. It is not used in * any permission calculations. Therefore, duplication of characters within the output * is permitted. - * * @return a 32-character bit pattern */ String getPattern(); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/PermissionGrantingStrategy.java b/acl/src/main/java/org/springframework/security/acls/model/PermissionGrantingStrategy.java index 14c8185c18..13ccb68062 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/PermissionGrantingStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/model/PermissionGrantingStrategy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.util.List; @@ -27,10 +28,9 @@ import java.util.List; public interface PermissionGrantingStrategy { /** - * Returns true if the supplied strategy decides that the supplied {@code Acl} - * grants access based on the supplied list of permissions and sids. + * Returns true if the supplied strategy decides that the supplied {@code Acl} grants + * access based on the supplied list of permissions and sids. */ - boolean isGranted(Acl acl, List permission, List sids, - boolean administrativeMode); + boolean isGranted(Acl acl, List permission, List sids, boolean administrativeMode); } diff --git a/acl/src/main/java/org/springframework/security/acls/model/Sid.java b/acl/src/main/java/org/springframework/security/acls/model/Sid.java index 134fc0ed7a..1e16e8ea94 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/Sid.java +++ b/acl/src/main/java/org/springframework/security/acls/model/Sid.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; import java.io.Serializable; @@ -31,24 +32,22 @@ import java.io.Serializable; * @author Ben Alex */ public interface Sid extends Serializable { - // ~ Methods - // ======================================================================================================== /** * Refer to the java.lang.Object documentation for the interface * contract. - * * @param obj to be compared - * * @return true if the objects are equal, false otherwise */ + @Override boolean equals(Object obj); /** * Refer to the java.lang.Object documentation for the interface * contract. - * * @return a hash code representation of this object */ + @Override int hashCode(); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/SidRetrievalStrategy.java b/acl/src/main/java/org/springframework/security/acls/model/SidRetrievalStrategy.java index 3f605440c0..42694840d3 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/SidRetrievalStrategy.java +++ b/acl/src/main/java/org/springframework/security/acls/model/SidRetrievalStrategy.java @@ -27,8 +27,7 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface SidRetrievalStrategy { - // ~ Methods - // ======================================================================================================== List getSids(Authentication authentication); + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/UnloadedSidException.java b/acl/src/main/java/org/springframework/security/acls/model/UnloadedSidException.java index b692ad79ca..fe208d693b 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/UnloadedSidException.java +++ b/acl/src/main/java/org/springframework/security/acls/model/UnloadedSidException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.model; /** @@ -23,12 +24,9 @@ package org.springframework.security.acls.model; * @author Ben Alex */ public class UnloadedSidException extends AclDataAccessException { - // ~ Constructors - // =================================================================================================== /** * Constructs an NotFoundException with the specified message. - * * @param msg the detail message */ public UnloadedSidException(String msg) { @@ -38,11 +36,11 @@ public class UnloadedSidException extends AclDataAccessException { /** * Constructs an NotFoundException with the specified message and root * cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public UnloadedSidException(String msg, Throwable t) { - super(msg, t); + public UnloadedSidException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/acl/src/main/java/org/springframework/security/acls/model/package-info.java b/acl/src/main/java/org/springframework/security/acls/model/package-info.java index c6f8389ae9..98dba043d4 100644 --- a/acl/src/main/java/org/springframework/security/acls/model/package-info.java +++ b/acl/src/main/java/org/springframework/security/acls/model/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Interfaces and shared classes to manage access control lists (ACLs) for domain object instances. + * Interfaces and shared classes to manage access control lists (ACLs) for domain object + * instances. */ package org.springframework.security.acls.model; - diff --git a/acl/src/main/java/org/springframework/security/acls/package-info.java b/acl/src/main/java/org/springframework/security/acls/package-info.java index 75d4de804f..dbd8f749de 100644 --- a/acl/src/main/java/org/springframework/security/acls/package-info.java +++ b/acl/src/main/java/org/springframework/security/acls/package-info.java @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * The Spring Security ACL package which implements instance-based security for domain objects. + * The Spring Security ACL package which implements instance-based security for domain + * objects. *

- * Consider using the annotation based approach ({@code @PreAuthorize}, {@code @PostFilter} annotations) combined - * with a {@link org.springframework.security.acls.AclPermissionEvaluator} in preference to the older and more verbose - * attribute/voter/after-invocation approach from versions before Spring Security 3.0. + * Consider using the annotation based approach ({@code @PreAuthorize}, + * {@code @PostFilter} annotations) combined with a + * {@link org.springframework.security.acls.AclPermissionEvaluator} in preference to the + * older and more verbose attribute/voter/after-invocation approach from versions before + * Spring Security 3.0. */ package org.springframework.security.acls; - diff --git a/acl/src/test/java/org/springframework/security/acls/AclFormattingUtilsTests.java b/acl/src/test/java/org/springframework/security/acls/AclFormattingUtilsTests.java index acf951ef06..6f09d71890 100644 --- a/acl/src/test/java/org/springframework/security/acls/AclFormattingUtilsTests.java +++ b/acl/src/test/java/org/springframework/security/acls/AclFormattingUtilsTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; +import org.junit.Test; + +import org.springframework.security.acls.domain.AclFormattingUtils; +import org.springframework.security.acls.model.Permission; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -import org.junit.Test; -import org.springframework.security.acls.domain.AclFormattingUtils; -import org.springframework.security.acls.model.Permission; - /** * Tests for {@link AclFormattingUtils}. * @@ -29,8 +31,6 @@ import org.springframework.security.acls.model.Permission; */ public class AclFormattingUtilsTests { - // ~ Methods - // ======================================================================================================== @Test public final void testDemergePatternsParametersConstraints() { try { @@ -39,21 +39,18 @@ public class AclFormattingUtilsTests { } catch (IllegalArgumentException expected) { } - try { AclFormattingUtils.demergePatterns("SOME STRING", null); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { AclFormattingUtils.demergePatterns("SOME STRING", "LONGER SOME STRING"); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { AclFormattingUtils.demergePatterns("SOME STRING", "SAME LENGTH"); } @@ -66,13 +63,10 @@ public class AclFormattingUtilsTests { public final void testDemergePatterns() { String original = "...........................A...R"; String removeBits = "...............................R"; - assertThat(AclFormattingUtils.demergePatterns(original, removeBits)).isEqualTo( - "...........................A...."); - - assertThat(AclFormattingUtils.demergePatterns("ABCDEF", "......")).isEqualTo( - "ABCDEF"); - assertThat(AclFormattingUtils.demergePatterns("ABCDEF", "GHIJKL")).isEqualTo( - "......"); + assertThat(AclFormattingUtils.demergePatterns(original, removeBits)) + .isEqualTo("...........................A...."); + assertThat(AclFormattingUtils.demergePatterns("ABCDEF", "......")).isEqualTo("ABCDEF"); + assertThat(AclFormattingUtils.demergePatterns("ABCDEF", "GHIJKL")).isEqualTo("......"); } @Test @@ -83,21 +77,18 @@ public class AclFormattingUtilsTests { } catch (IllegalArgumentException expected) { } - try { AclFormattingUtils.mergePatterns("SOME STRING", null); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { AclFormattingUtils.mergePatterns("SOME STRING", "LONGER SOME STRING"); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { AclFormattingUtils.mergePatterns("SOME STRING", "SAME LENGTH"); } @@ -109,47 +100,37 @@ public class AclFormattingUtilsTests { public final void testMergePatterns() { String original = "...............................R"; String extraBits = "...........................A...."; - assertThat(AclFormattingUtils.mergePatterns(original, extraBits)).isEqualTo( - "...........................A...R"); - - assertThat(AclFormattingUtils.mergePatterns("ABCDEF", "......")).isEqualTo( - "ABCDEF"); - assertThat(AclFormattingUtils.mergePatterns("ABCDEF", "GHIJKL")).isEqualTo( - "GHIJKL"); + assertThat(AclFormattingUtils.mergePatterns(original, extraBits)).isEqualTo("...........................A...R"); + assertThat(AclFormattingUtils.mergePatterns("ABCDEF", "......")).isEqualTo("ABCDEF"); + assertThat(AclFormattingUtils.mergePatterns("ABCDEF", "GHIJKL")).isEqualTo("GHIJKL"); } @Test public final void testBinaryPrints() { - assertThat(AclFormattingUtils.printBinary(15)).isEqualTo( - "............................****"); - + assertThat(AclFormattingUtils.printBinary(15)).isEqualTo("............................****"); try { AclFormattingUtils.printBinary(15, Permission.RESERVED_ON); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException notExpected) { } - try { AclFormattingUtils.printBinary(15, Permission.RESERVED_OFF); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException notExpected) { } - - assertThat(AclFormattingUtils.printBinary(15, 'x')).isEqualTo( - "............................xxxx"); + assertThat(AclFormattingUtils.printBinary(15, 'x')).isEqualTo("............................xxxx"); } @Test public void testPrintBinaryNegative() { - assertThat(AclFormattingUtils.printBinary(0x80000000)).isEqualTo( - "*..............................."); + assertThat(AclFormattingUtils.printBinary(0x80000000)).isEqualTo("*..............................."); } @Test public void testPrintBinaryMinusOne() { - assertThat(AclFormattingUtils.printBinary(0xffffffff)).isEqualTo( - "********************************"); + assertThat(AclFormattingUtils.printBinary(0xffffffff)).isEqualTo("********************************"); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/AclPermissionCacheOptimizerTests.java b/acl/src/test/java/org/springframework/security/acls/AclPermissionCacheOptimizerTests.java index d499a79ca7..844a2d4d86 100644 --- a/acl/src/test/java/org/springframework/security/acls/AclPermissionCacheOptimizerTests.java +++ b/acl/src/test/java/org/springframework/security/acls/AclPermissionCacheOptimizerTests.java @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; -import static org.mockito.Mockito.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import org.junit.Test; + import org.springframework.security.acls.domain.ObjectIdentityImpl; import org.springframework.security.acls.model.AclService; import org.springframework.security.acls.model.ObjectIdentity; @@ -25,9 +29,12 @@ import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; import org.springframework.security.acls.model.SidRetrievalStrategy; import org.springframework.security.core.Authentication; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; /** * @author Luke Taylor @@ -44,13 +51,10 @@ public class AclPermissionCacheOptimizerTests { pco.setObjectIdentityRetrievalStrategy(oidStrat); pco.setSidRetrievalStrategy(sidStrat); Object[] dos = { new Object(), null, new Object() }; - ObjectIdentity[] oids = { new ObjectIdentityImpl("A", "1"), - new ObjectIdentityImpl("A", "2") }; - when(oidStrat.getObjectIdentity(dos[0])).thenReturn(oids[0]); - when(oidStrat.getObjectIdentity(dos[2])).thenReturn(oids[1]); - + ObjectIdentity[] oids = { new ObjectIdentityImpl("A", "1"), new ObjectIdentityImpl("A", "2") }; + given(oidStrat.getObjectIdentity(dos[0])).willReturn(oids[0]); + given(oidStrat.getObjectIdentity(dos[2])).willReturn(oids[1]); pco.cachePermissionsFor(mock(Authentication.class), Arrays.asList(dos)); - // AclService should be invoked with the list of required Oids verify(service).readAclsById(eq(Arrays.asList(oids)), any(List.class)); } @@ -63,9 +67,7 @@ public class AclPermissionCacheOptimizerTests { SidRetrievalStrategy sids = mock(SidRetrievalStrategy.class); pco.setObjectIdentityRetrievalStrategy(oids); pco.setSidRetrievalStrategy(sids); - pco.cachePermissionsFor(mock(Authentication.class), Collections.emptyList()); - verifyZeroInteractions(service, sids, oids); } diff --git a/acl/src/test/java/org/springframework/security/acls/AclPermissionEvaluatorTests.java b/acl/src/test/java/org/springframework/security/acls/AclPermissionEvaluatorTests.java index 47df332908..a29c3ab3be 100644 --- a/acl/src/test/java/org/springframework/security/acls/AclPermissionEvaluatorTests.java +++ b/acl/src/test/java/org/springframework/security/acls/AclPermissionEvaluatorTests.java @@ -13,15 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; -import static org.assertj.core.api.Assertions.*; - -import static org.mockito.Mockito.*; - import java.util.Locale; import org.junit.Test; + import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.AclService; import org.springframework.security.acls.model.ObjectIdentity; @@ -29,8 +27,14 @@ import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; import org.springframework.security.acls.model.SidRetrievalStrategy; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** - * * @author Luke Taylor * @since 3.0 */ @@ -42,14 +46,12 @@ public class AclPermissionEvaluatorTests { AclPermissionEvaluator pe = new AclPermissionEvaluator(service); ObjectIdentity oid = mock(ObjectIdentity.class); ObjectIdentityRetrievalStrategy oidStrategy = mock(ObjectIdentityRetrievalStrategy.class); - when(oidStrategy.getObjectIdentity(any(Object.class))).thenReturn(oid); + given(oidStrategy.getObjectIdentity(any(Object.class))).willReturn(oid); pe.setObjectIdentityRetrievalStrategy(oidStrategy); pe.setSidRetrievalStrategy(mock(SidRetrievalStrategy.class)); Acl acl = mock(Acl.class); - - when(service.readAclById(any(ObjectIdentity.class), anyList())).thenReturn(acl); - when(acl.isGranted(anyList(), anyList(), eq(false))).thenReturn(true); - + given(service.readAclById(any(ObjectIdentity.class), anyList())).willReturn(acl); + given(acl.isGranted(anyList(), anyList(), eq(false))).willReturn(true); assertThat(pe.hasPermission(mock(Authentication.class), new Object(), "READ")).isTrue(); } @@ -57,21 +59,18 @@ public class AclPermissionEvaluatorTests { public void resolvePermissionNonEnglishLocale() { Locale systemLocale = Locale.getDefault(); Locale.setDefault(new Locale("tr")); - AclService service = mock(AclService.class); AclPermissionEvaluator pe = new AclPermissionEvaluator(service); ObjectIdentity oid = mock(ObjectIdentity.class); ObjectIdentityRetrievalStrategy oidStrategy = mock(ObjectIdentityRetrievalStrategy.class); - when(oidStrategy.getObjectIdentity(any(Object.class))).thenReturn(oid); + given(oidStrategy.getObjectIdentity(any(Object.class))).willReturn(oid); pe.setObjectIdentityRetrievalStrategy(oidStrategy); pe.setSidRetrievalStrategy(mock(SidRetrievalStrategy.class)); Acl acl = mock(Acl.class); - - when(service.readAclById(any(ObjectIdentity.class), anyList())).thenReturn(acl); - when(acl.isGranted(anyList(), anyList(), eq(false))).thenReturn(true); - + given(service.readAclById(any(ObjectIdentity.class), anyList())).willReturn(acl); + given(acl.isGranted(anyList(), anyList(), eq(false))).willReturn(true); assertThat(pe.hasPermission(mock(Authentication.class), new Object(), "write")).isTrue(); - Locale.setDefault(systemLocale); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/TargetObject.java b/acl/src/test/java/org/springframework/security/acls/TargetObject.java index c18f3e03ee..eeac7c8d70 100644 --- a/acl/src/test/java/org/springframework/security/acls/TargetObject.java +++ b/acl/src/test/java/org/springframework/security/acls/TargetObject.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; /** diff --git a/acl/src/test/java/org/springframework/security/acls/TargetObjectWithUUID.java b/acl/src/test/java/org/springframework/security/acls/TargetObjectWithUUID.java index 426956f7ea..ee0727fa69 100644 --- a/acl/src/test/java/org/springframework/security/acls/TargetObjectWithUUID.java +++ b/acl/src/test/java/org/springframework/security/acls/TargetObjectWithUUID.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls; import java.util.UUID; @@ -27,10 +28,11 @@ public final class TargetObjectWithUUID { private UUID id; public UUID getId() { - return id; + return this.id; } public void setId(UUID id) { this.id = id; } + } diff --git a/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProviderTests.java b/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProviderTests.java index b85b01d3dd..296e306d5f 100644 --- a/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProviderTests.java +++ b/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationCollectionFilteringProviderTests.java @@ -13,46 +13,54 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.afterinvocation; - -import static org.assertj.core.api.Assertions.*; - -import static org.mockito.Mockito.*; - -import org.junit.Test; -import org.springframework.security.access.ConfigAttribute; -import org.springframework.security.access.SecurityConfig; -import org.springframework.security.acls.model.*; -import org.springframework.security.core.Authentication; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.junit.Test; + +import org.springframework.security.access.ConfigAttribute; +import org.springframework.security.access.SecurityConfig; +import org.springframework.security.acls.model.Acl; +import org.springframework.security.acls.model.AclService; +import org.springframework.security.acls.model.ObjectIdentity; +import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; +import org.springframework.security.acls.model.Permission; +import org.springframework.security.acls.model.SidRetrievalStrategy; +import org.springframework.security.core.Authentication; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + /** * @author Luke Taylor */ @SuppressWarnings({ "unchecked" }) public class AclEntryAfterInvocationCollectionFilteringProviderTests { + @Test public void objectsAreRemovedIfPermissionDenied() { AclService service = mock(AclService.class); Acl acl = mock(Acl.class); - when(acl.isGranted(any(), any(), anyBoolean())).thenReturn( - false); - when(service.readAclById(any(), any())).thenReturn( - acl); + given(acl.isGranted(any(), any(), anyBoolean())).willReturn(false); + given(service.readAclById(any(), any())).willReturn(acl); AclEntryAfterInvocationCollectionFilteringProvider provider = new AclEntryAfterInvocationCollectionFilteringProvider( service, Arrays.asList(mock(Permission.class))); provider.setObjectIdentityRetrievalStrategy(mock(ObjectIdentityRetrievalStrategy.class)); provider.setProcessDomainObjectClass(Object.class); provider.setSidRetrievalStrategy(mock(SidRetrievalStrategy.class)); - Object returned = provider.decide(mock(Authentication.class), new Object(), - SecurityConfig.createList("AFTER_ACL_COLLECTION_READ"), new ArrayList( - Arrays.asList(new Object(), new Object()))); + SecurityConfig.createList("AFTER_ACL_COLLECTION_READ"), + new ArrayList(Arrays.asList(new Object(), new Object()))); assertThat(returned).isInstanceOf(List.class); assertThat(((List) returned)).isEmpty(); returned = provider.decide(mock(Authentication.class), new Object(), @@ -67,11 +75,8 @@ public class AclEntryAfterInvocationCollectionFilteringProviderTests { AclEntryAfterInvocationCollectionFilteringProvider provider = new AclEntryAfterInvocationCollectionFilteringProvider( mock(AclService.class), Arrays.asList(mock(Permission.class))); Object returned = new Object(); - - assertThat(returned) - .isSameAs( - provider.decide(mock(Authentication.class), new Object(), - Collections. emptyList(), returned)); + assertThat(returned).isSameAs(provider.decide(mock(Authentication.class), new Object(), + Collections.emptyList(), returned)); } @Test @@ -79,10 +84,8 @@ public class AclEntryAfterInvocationCollectionFilteringProviderTests { AclService service = mock(AclService.class); AclEntryAfterInvocationCollectionFilteringProvider provider = new AclEntryAfterInvocationCollectionFilteringProvider( service, Arrays.asList(mock(Permission.class))); - assertThat(provider.decide(mock(Authentication.class), new Object(), - SecurityConfig.createList("AFTER_ACL_COLLECTION_READ"), null)) - .isNull(); + SecurityConfig.createList("AFTER_ACL_COLLECTION_READ"), null)).isNull(); verify(service, never()).readAclById(any(ObjectIdentity.class), any(List.class)); } diff --git a/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProviderTests.java b/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProviderTests.java index 322bb3d11f..b044f89c3a 100644 --- a/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProviderTests.java +++ b/acl/src/test/java/org/springframework/security/acls/afterinvocation/AclEntryAfterInvocationProviderTests.java @@ -13,23 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.afterinvocation; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import org.junit.Test; -import org.springframework.security.access.AccessDeniedException; -import org.springframework.security.access.ConfigAttribute; -import org.springframework.security.access.SecurityConfig; -import org.springframework.security.acls.model.*; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.SpringSecurityMessageSource; - import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.junit.Test; + +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.access.ConfigAttribute; +import org.springframework.security.access.SecurityConfig; +import org.springframework.security.acls.model.Acl; +import org.springframework.security.acls.model.AclService; +import org.springframework.security.acls.model.NotFoundException; +import org.springframework.security.acls.model.ObjectIdentity; +import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; +import org.springframework.security.acls.model.Permission; +import org.springframework.security.acls.model.SidRetrievalStrategy; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.SpringSecurityMessageSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + /** * @author Luke Taylor */ @@ -44,74 +58,56 @@ public class AclEntryAfterInvocationProviderTests { } catch (IllegalArgumentException expected) { } - new AclEntryAfterInvocationProvider(mock(AclService.class), - Collections. emptyList()); + new AclEntryAfterInvocationProvider(mock(AclService.class), Collections.emptyList()); } @Test public void accessIsAllowedIfPermissionIsGranted() { AclService service = mock(AclService.class); Acl acl = mock(Acl.class); - when(acl.isGranted(any(List.class), any(List.class), anyBoolean())).thenReturn( - true); - when(service.readAclById(any(), any())).thenReturn( - acl); - AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider( - service, Arrays.asList(mock(Permission.class))); + given(acl.isGranted(any(List.class), any(List.class), anyBoolean())).willReturn(true); + given(service.readAclById(any(), any())).willReturn(acl); + AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider(service, + Arrays.asList(mock(Permission.class))); provider.setMessageSource(new SpringSecurityMessageSource()); provider.setObjectIdentityRetrievalStrategy(mock(ObjectIdentityRetrievalStrategy.class)); provider.setProcessDomainObjectClass(Object.class); provider.setSidRetrievalStrategy(mock(SidRetrievalStrategy.class)); Object returned = new Object(); - - assertThat( - returned) - .isSameAs( - provider.decide(mock(Authentication.class), new Object(), - SecurityConfig.createList("AFTER_ACL_READ"), returned)); + assertThat(returned).isSameAs(provider.decide(mock(Authentication.class), new Object(), + SecurityConfig.createList("AFTER_ACL_READ"), returned)); } @Test public void accessIsGrantedIfNoAttributesDefined() { - AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider( - mock(AclService.class), Arrays.asList(mock(Permission.class))); + AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider(mock(AclService.class), + Arrays.asList(mock(Permission.class))); Object returned = new Object(); - - assertThat( - returned) - .isSameAs( - provider.decide(mock(Authentication.class), new Object(), - Collections. emptyList(), returned)); + assertThat(returned).isSameAs(provider.decide(mock(Authentication.class), new Object(), + Collections.emptyList(), returned)); } @Test public void accessIsGrantedIfObjectTypeNotSupported() { - AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider( - mock(AclService.class), Arrays.asList(mock(Permission.class))); + AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider(mock(AclService.class), + Arrays.asList(mock(Permission.class))); provider.setProcessDomainObjectClass(String.class); // Not a String Object returned = new Object(); - - assertThat( - returned) - .isSameAs( - provider.decide(mock(Authentication.class), new Object(), - SecurityConfig.createList("AFTER_ACL_READ"), returned)); + assertThat(returned).isSameAs(provider.decide(mock(Authentication.class), new Object(), + SecurityConfig.createList("AFTER_ACL_READ"), returned)); } @Test(expected = AccessDeniedException.class) public void accessIsDeniedIfPermissionIsNotGranted() { AclService service = mock(AclService.class); Acl acl = mock(Acl.class); - when(acl.isGranted(any(List.class), any(List.class), anyBoolean())).thenReturn( - false); + given(acl.isGranted(any(List.class), any(List.class), anyBoolean())).willReturn(false); // Try a second time with no permissions found - when(acl.isGranted(any(), any(List.class), anyBoolean())).thenThrow( - new NotFoundException("")); - when(service.readAclById(any(), any())).thenReturn( - acl); - AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider( - service, Arrays.asList(mock(Permission.class))); + given(acl.isGranted(any(), any(List.class), anyBoolean())).willThrow(new NotFoundException("")); + given(service.readAclById(any(), any())).willReturn(acl); + AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider(service, + Arrays.asList(mock(Permission.class))); provider.setProcessConfigAttribute("MY_ATTRIBUTE"); provider.setMessageSource(new SpringSecurityMessageSource()); provider.setObjectIdentityRetrievalStrategy(mock(ObjectIdentityRetrievalStrategy.class)); @@ -119,8 +115,7 @@ public class AclEntryAfterInvocationProviderTests { provider.setSidRetrievalStrategy(mock(SidRetrievalStrategy.class)); try { provider.decide(mock(Authentication.class), new Object(), - SecurityConfig.createList("UNSUPPORTED", "MY_ATTRIBUTE"), - new Object()); + SecurityConfig.createList("UNSUPPORTED", "MY_ATTRIBUTE"), new Object()); fail("Expected Exception"); } catch (AccessDeniedException expected) { @@ -133,12 +128,11 @@ public class AclEntryAfterInvocationProviderTests { @Test public void nullReturnObjectIsIgnored() { AclService service = mock(AclService.class); - AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider( - service, Arrays.asList(mock(Permission.class))); - + AclEntryAfterInvocationProvider provider = new AclEntryAfterInvocationProvider(service, + Arrays.asList(mock(Permission.class))); assertThat(provider.decide(mock(Authentication.class), new Object(), - SecurityConfig.createList("AFTER_ACL_COLLECTION_READ"), null)) - .isNull(); + SecurityConfig.createList("AFTER_ACL_COLLECTION_READ"), null)).isNull(); verify(service, never()).readAclById(any(ObjectIdentity.class), any(List.class)); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/AccessControlImplEntryTests.java b/acl/src/test/java/org/springframework/security/acls/domain/AccessControlImplEntryTests.java index bbf33dfef4..743f8ee3b8 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/AccessControlImplEntryTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/AccessControlImplEntryTests.java @@ -13,18 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import org.junit.Test; + import org.springframework.security.acls.model.AccessControlEntry; import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.AuditableAccessControlEntry; import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.security.acls.model.Sid; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests for {@link AccessControlEntryImpl}. * @@ -32,33 +36,26 @@ import org.springframework.security.acls.model.Sid; */ public class AccessControlImplEntryTests { - // ~ Methods - // ======================================================================================================== - @Test public void testConstructorRequiredFields() { // Check Acl field is present try { - new AccessControlEntryImpl(null, null, new PrincipalSid("johndoe"), - BasePermission.ADMINISTRATION, true, true, true); + new AccessControlEntryImpl(null, null, new PrincipalSid("johndoe"), BasePermission.ADMINISTRATION, true, + true, true); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - // Check Sid field is present try { - new AccessControlEntryImpl(null, mock(Acl.class), null, - BasePermission.ADMINISTRATION, true, true, true); + new AccessControlEntryImpl(null, mock(Acl.class), null, BasePermission.ADMINISTRATION, true, true, true); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - // Check Permission field is present try { - new AccessControlEntryImpl(null, mock(Acl.class), - new PrincipalSid("johndoe"), null, true, true, true); + new AccessControlEntryImpl(null, mock(Acl.class), new PrincipalSid("johndoe"), null, true, true, true); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -69,11 +66,9 @@ public class AccessControlImplEntryTests { public void testAccessControlEntryImplGetters() { Acl mockAcl = mock(Acl.class); Sid sid = new PrincipalSid("johndoe"); - // Create a sample entry - AccessControlEntry ace = new AccessControlEntryImpl(1L, mockAcl, - sid, BasePermission.ADMINISTRATION, true, true, true); - + AccessControlEntry ace = new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.ADMINISTRATION, true, true, + true); // and check every get() method assertThat(ace.getId()).isEqualTo(1L); assertThat(ace.getAcl()).isEqualTo(mockAcl); @@ -88,30 +83,27 @@ public class AccessControlImplEntryTests { public void testEquals() { final Acl mockAcl = mock(Acl.class); final ObjectIdentity oid = mock(ObjectIdentity.class); - - when(mockAcl.getObjectIdentity()).thenReturn(oid); + given(mockAcl.getObjectIdentity()).willReturn(oid); Sid sid = new PrincipalSid("johndoe"); - - AccessControlEntry ace = new AccessControlEntryImpl(1L, mockAcl, - sid, BasePermission.ADMINISTRATION, true, true, true); - + AccessControlEntry ace = new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.ADMINISTRATION, true, true, + true); assertThat(ace).isNotNull(); assertThat(ace).isNotEqualTo(100L); assertThat(ace).isEqualTo(ace); - assertThat(ace).isEqualTo(new AccessControlEntryImpl(1L, mockAcl, sid, + assertThat(ace).isEqualTo( + new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.ADMINISTRATION, true, true, true)); + assertThat(ace).isNotEqualTo( + new AccessControlEntryImpl(2L, mockAcl, sid, BasePermission.ADMINISTRATION, true, true, true)); + assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, new PrincipalSid("scott"), BasePermission.ADMINISTRATION, true, true, true)); - assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(2L, mockAcl, sid, - BasePermission.ADMINISTRATION, true, true, true)); - assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, - new PrincipalSid("scott"), BasePermission.ADMINISTRATION, true, true, - true)); - assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, sid, - BasePermission.WRITE, true, true, true)); - assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, sid, - BasePermission.ADMINISTRATION, false, true, true)); - assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, sid, - BasePermission.ADMINISTRATION, true, false, true)); - assertThat(ace).isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, sid, - BasePermission.ADMINISTRATION, true, true, false)); + assertThat(ace) + .isNotEqualTo(new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.WRITE, true, true, true)); + assertThat(ace).isNotEqualTo( + new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.ADMINISTRATION, false, true, true)); + assertThat(ace).isNotEqualTo( + new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.ADMINISTRATION, true, false, true)); + assertThat(ace).isNotEqualTo( + new AccessControlEntryImpl(1L, mockAcl, sid, BasePermission.ADMINISTRATION, true, true, false)); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImplTests.java b/acl/src/test/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImplTests.java index 0e8d7c12ff..e1b06b7418 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImplTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/AclAuthorizationStrategyImplTests.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.domain; +package org.springframework.security.acls.domain; import java.util.Arrays; @@ -24,6 +24,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.security.acls.model.Acl; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.GrantedAuthority; @@ -31,21 +32,24 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; /** - * * @author Rob Winch * */ @RunWith(MockitoJUnitRunner.class) public class AclAuthorizationStrategyImplTests { + @Mock Acl acl; + GrantedAuthority authority; + AclAuthorizationStrategyImpl strategy; @Before public void setup() { - authority = new SimpleGrantedAuthority("ROLE_AUTH"); - TestingAuthenticationToken authentication = new TestingAuthenticationToken("foo", "bar", Arrays.asList(authority)); + this.authority = new SimpleGrantedAuthority("ROLE_AUTH"); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("foo", "bar", + Arrays.asList(this.authority)); authentication.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(authentication); } @@ -58,15 +62,18 @@ public class AclAuthorizationStrategyImplTests { // gh-4085 @Test public void securityCheckWhenCustomAuthorityThenNameIsUsed() { - strategy = new AclAuthorizationStrategyImpl(new CustomAuthority()); - strategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_GENERAL); + this.strategy = new AclAuthorizationStrategyImpl(new CustomAuthority()); + this.strategy.securityCheck(this.acl, AclAuthorizationStrategy.CHANGE_GENERAL); } @SuppressWarnings("serial") class CustomAuthority implements GrantedAuthority { + @Override public String getAuthority() { - return authority.getAuthority(); + return AclAuthorizationStrategyImplTests.this.authority.getAuthority(); } + } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/AclImplTests.java b/acl/src/test/java/org/springframework/security/acls/domain/AclImplTests.java index 7313048d20..c86776a9c7 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/AclImplTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/AclImplTests.java @@ -13,21 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; -import org.junit.*; -import org.springframework.security.acls.model.*; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.acls.model.AccessControlEntry; +import org.springframework.security.acls.model.Acl; +import org.springframework.security.acls.model.AlreadyExistsException; +import org.springframework.security.acls.model.AuditableAccessControlEntry; +import org.springframework.security.acls.model.AuditableAcl; +import org.springframework.security.acls.model.ChildrenExistException; +import org.springframework.security.acls.model.MutableAcl; +import org.springframework.security.acls.model.MutableAclService; +import org.springframework.security.acls.model.NotFoundException; +import org.springframework.security.acls.model.ObjectIdentity; +import org.springframework.security.acls.model.Permission; +import org.springframework.security.acls.model.PermissionGrantingStrategy; +import org.springframework.security.acls.model.Sid; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.util.FieldUtils; -import java.lang.reflect.Field; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.Mockito.mock; /** * Tests for {@link AclImpl}. @@ -35,33 +55,41 @@ import java.util.*; * @author Andrei Stefan */ public class AclImplTests { + private static final String TARGET_CLASS = "org.springframework.security.acls.TargetObject"; + private static final List READ = Arrays.asList(BasePermission.READ); + private static final List WRITE = Arrays.asList(BasePermission.WRITE); + private static final List CREATE = Arrays.asList(BasePermission.CREATE); + private static final List DELETE = Arrays.asList(BasePermission.DELETE); + private static final List SCOTT = Arrays.asList((Sid) new PrincipalSid("scott")); + private static final List BEN = Arrays.asList((Sid) new PrincipalSid("ben")); - Authentication auth = new TestingAuthenticationToken("joe", "ignored", - "ROLE_ADMINISTRATOR"); - AclAuthorizationStrategy authzStrategy; - PermissionGrantingStrategy pgs; - AuditLogger mockAuditLogger; - ObjectIdentity objectIdentity = new ObjectIdentityImpl(TARGET_CLASS, 100); - private DefaultPermissionFactory permissionFactory; + Authentication auth = new TestingAuthenticationToken("joe", "ignored", "ROLE_ADMINISTRATOR"); - // ~ Methods - // ======================================================================================================== + AclAuthorizationStrategy authzStrategy; + + PermissionGrantingStrategy pgs; + + AuditLogger mockAuditLogger; + + ObjectIdentity objectIdentity = new ObjectIdentityImpl(TARGET_CLASS, 100); + + private DefaultPermissionFactory permissionFactory; @Before public void setUp() { - SecurityContextHolder.getContext().setAuthentication(auth); - authzStrategy = mock(AclAuthorizationStrategy.class); - mockAuditLogger = mock(AuditLogger.class); - pgs = new DefaultPermissionGrantingStrategy(mockAuditLogger); - auth.setAuthenticated(true); - permissionFactory = new DefaultPermissionFactory(); + SecurityContextHolder.getContext().setAuthentication(this.auth); + this.authzStrategy = mock(AclAuthorizationStrategy.class); + this.mockAuditLogger = mock(AuditLogger.class); + this.pgs = new DefaultPermissionGrantingStrategy(this.mockAuditLogger); + this.auth.setAuthenticated(true); + this.permissionFactory = new DefaultPermissionFactory(); } @After @@ -72,44 +100,43 @@ public class AclImplTests { @Test(expected = IllegalArgumentException.class) public void constructorsRejectNullObjectIdentity() { try { - new AclImpl(null, 1, authzStrategy, pgs, null, null, true, new PrincipalSid( - "joe")); + new AclImpl(null, 1, this.authzStrategy, this.pgs, null, null, true, new PrincipalSid("joe")); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - new AclImpl(null, 1, authzStrategy, mockAuditLogger); + new AclImpl(null, 1, this.authzStrategy, this.mockAuditLogger); } @Test(expected = IllegalArgumentException.class) public void constructorsRejectNullId() { try { - new AclImpl(objectIdentity, null, authzStrategy, pgs, null, null, true, + new AclImpl(this.objectIdentity, null, this.authzStrategy, this.pgs, null, null, true, new PrincipalSid("joe")); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - new AclImpl(objectIdentity, null, authzStrategy, mockAuditLogger); + new AclImpl(this.objectIdentity, null, this.authzStrategy, this.mockAuditLogger); } @SuppressWarnings("deprecation") @Test(expected = IllegalArgumentException.class) public void constructorsRejectNullAclAuthzStrategy() { try { - new AclImpl(objectIdentity, 1, null, new DefaultPermissionGrantingStrategy( - mockAuditLogger), null, null, true, new PrincipalSid("joe")); + new AclImpl(this.objectIdentity, 1, null, new DefaultPermissionGrantingStrategy(this.mockAuditLogger), null, + null, true, new PrincipalSid("joe")); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - new AclImpl(objectIdentity, 1, null, mockAuditLogger); + new AclImpl(this.objectIdentity, 1, null, this.mockAuditLogger); } @Test public void insertAceRejectsNullParameters() { - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); try { acl.insertAce(0, null, new GrantedAuthoritySid("ROLE_IGNORED"), true); fail("It should have thrown IllegalArgumentException"); @@ -126,10 +153,9 @@ public class AclImplTests { @Test public void insertAceAddsElementAtCorrectIndex() { - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); MockAclService service = new MockAclService(); - // Insert one permission acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST1"), true); service.updateAcl(acl); @@ -137,9 +163,7 @@ public class AclImplTests { assertThat(acl.getEntries()).hasSize(1); assertThat(acl).isEqualTo(acl.getEntries().get(0).getAcl()); assertThat(BasePermission.READ).isEqualTo(acl.getEntries().get(0).getPermission()); - assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST1")); - + assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST1")); // Add a second permission acl.insertAce(1, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST2"), true); service.updateAcl(acl); @@ -147,71 +171,54 @@ public class AclImplTests { assertThat(acl.getEntries()).hasSize(2); assertThat(acl).isEqualTo(acl.getEntries().get(1).getAcl()); assertThat(BasePermission.READ).isEqualTo(acl.getEntries().get(1).getPermission()); - assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST2")); - + assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST2")); // Add a third permission, after the first one - acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_TEST3"), - false); + acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_TEST3"), false); service.updateAcl(acl); assertThat(acl.getEntries()).hasSize(3); // Check the third entry was added between the two existent ones assertThat(BasePermission.READ).isEqualTo(acl.getEntries().get(0).getPermission()); - assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST1")); + assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST1")); assertThat(BasePermission.WRITE).isEqualTo(acl.getEntries().get(1).getPermission()); - assertThat(acl.getEntries().get(1).getSid()).isEqualTo( new GrantedAuthoritySid( - "ROLE_TEST3")); + assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST3")); assertThat(BasePermission.READ).isEqualTo(acl.getEntries().get(2).getPermission()); - assertThat(acl.getEntries().get(2).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST2")); + assertThat(acl.getEntries().get(2).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST2")); } @Test(expected = NotFoundException.class) public void insertAceFailsForNonExistentElement() { - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); MockAclService service = new MockAclService(); - // Insert one permission acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST1"), true); service.updateAcl(acl); - - acl.insertAce(55, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST2"), - true); + acl.insertAce(55, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST2"), true); } @Test public void deleteAceKeepsInitialOrdering() { - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); MockAclService service = new MockAclService(); - // Add several permissions acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST1"), true); acl.insertAce(1, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST2"), true); acl.insertAce(2, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST3"), true); service.updateAcl(acl); - // Delete first permission and check the order of the remaining permissions is // kept acl.deleteAce(0); assertThat(acl.getEntries()).hasSize(2); - assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST2")); - assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST3")); - + assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST2")); + assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST3")); // Add one more permission and remove the permission in the middle acl.insertAce(2, BasePermission.READ, new GrantedAuthoritySid("ROLE_TEST4"), true); service.updateAcl(acl); acl.deleteAce(1); assertThat(acl.getEntries()).hasSize(2); - assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST2")); - assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid( - "ROLE_TEST4")); - + assertThat(acl.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST2")); + assertThat(acl.getEntries().get(1).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST4")); // Remove remaining permissions acl.deleteAce(1); acl.deleteAce(0); @@ -221,10 +228,10 @@ public class AclImplTests { @Test public void deleteAceFailsForNonExistentElement() { AclAuthorizationStrategyImpl strategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); - MutableAcl acl = new AclImpl(objectIdentity, (1), strategy, pgs, null, null, - true, new PrincipalSid("joe")); + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); + MutableAcl acl = new AclImpl(this.objectIdentity, (1), strategy, this.pgs, null, null, true, + new PrincipalSid("joe")); try { acl.deleteAce(99); fail("It should have thrown NotFoundException"); @@ -235,8 +242,8 @@ public class AclImplTests { @Test public void isGrantingRejectsEmptyParameters() { - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); Sid ben = new PrincipalSid("ben"); try { acl.isGranted(new ArrayList<>(0), Arrays.asList(ben), false); @@ -254,28 +261,21 @@ public class AclImplTests { @Test public void isGrantingGrantsAccessForAclWithNoParent() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_GENERAL", "ROLE_GUEST"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_GENERAL", "ROLE_GUEST"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); ObjectIdentity rootOid = new ObjectIdentityImpl(TARGET_CLASS, 100); - // Create an ACL which owner is not the authenticated principal - MutableAcl rootAcl = new AclImpl(rootOid, 1, authzStrategy, pgs, null, null, - false, new PrincipalSid("joe")); - + MutableAcl rootAcl = new AclImpl(rootOid, 1, this.authzStrategy, this.pgs, null, null, false, + new PrincipalSid("joe")); // Grant some permissions rootAcl.insertAce(0, BasePermission.READ, new PrincipalSid("ben"), false); rootAcl.insertAce(1, BasePermission.WRITE, new PrincipalSid("scott"), true); rootAcl.insertAce(2, BasePermission.WRITE, new PrincipalSid("rod"), false); - rootAcl.insertAce(3, BasePermission.WRITE, new GrantedAuthoritySid( - "WRITE_ACCESS_ROLE"), true); - + rootAcl.insertAce(3, BasePermission.WRITE, new GrantedAuthoritySid("WRITE_ACCESS_ROLE"), true); // Check permissions granting - List permissions = Arrays.asList(BasePermission.READ, - BasePermission.CREATE); - List sids = Arrays.asList(new PrincipalSid("ben"), new GrantedAuthoritySid( - "ROLE_GUEST")); + List permissions = Arrays.asList(BasePermission.READ, BasePermission.CREATE); + List sids = Arrays.asList(new PrincipalSid("ben"), new GrantedAuthoritySid("ROLE_GUEST")); assertThat(rootAcl.isGranted(permissions, sids, false)).isFalse(); try { rootAcl.isGranted(permissions, SCOTT, false); @@ -284,14 +284,14 @@ public class AclImplTests { catch (NotFoundException expected) { } assertThat(rootAcl.isGranted(WRITE, SCOTT, false)).isTrue(); - assertThat(rootAcl.isGranted(WRITE, Arrays.asList(new PrincipalSid("rod"), - new GrantedAuthoritySid("WRITE_ACCESS_ROLE")), false)).isFalse(); - assertThat(rootAcl.isGranted(WRITE, Arrays.asList(new GrantedAuthoritySid( - "WRITE_ACCESS_ROLE"), new PrincipalSid("rod")), false)).isTrue(); + assertThat(rootAcl.isGranted(WRITE, + Arrays.asList(new PrincipalSid("rod"), new GrantedAuthoritySid("WRITE_ACCESS_ROLE")), false)).isFalse(); + assertThat(rootAcl.isGranted(WRITE, + Arrays.asList(new GrantedAuthoritySid("WRITE_ACCESS_ROLE"), new PrincipalSid("rod")), false)).isTrue(); try { // Change the type of the Sid and check the granting process - rootAcl.isGranted(WRITE, Arrays.asList(new GrantedAuthoritySid("rod"), - new PrincipalSid("WRITE_ACCESS_ROLE")), false); + rootAcl.isGranted(WRITE, + Arrays.asList(new GrantedAuthoritySid("rod"), new PrincipalSid("WRITE_ACCESS_ROLE")), false); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { @@ -300,8 +300,7 @@ public class AclImplTests { @Test public void isGrantingGrantsAccessForInheritableAcls() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); ObjectIdentity grandParentOid = new ObjectIdentityImpl(TARGET_CLASS, 100); @@ -309,60 +308,44 @@ public class AclImplTests { ObjectIdentity parentOid2 = new ObjectIdentityImpl(TARGET_CLASS, 102); ObjectIdentity childOid1 = new ObjectIdentityImpl(TARGET_CLASS, 103); ObjectIdentity childOid2 = new ObjectIdentityImpl(TARGET_CLASS, 104); - // Create ACLs PrincipalSid joe = new PrincipalSid("joe"); - MutableAcl grandParentAcl = new AclImpl(grandParentOid, 1, authzStrategy, pgs, - null, null, false, joe); - MutableAcl parentAcl1 = new AclImpl(parentOid1, 2, authzStrategy, pgs, null, - null, true, joe); - MutableAcl parentAcl2 = new AclImpl(parentOid2, 3, authzStrategy, pgs, null, - null, true, joe); - MutableAcl childAcl1 = new AclImpl(childOid1, 4, authzStrategy, pgs, null, null, - true, joe); - MutableAcl childAcl2 = new AclImpl(childOid2, 4, authzStrategy, pgs, null, null, - false, joe); - + MutableAcl grandParentAcl = new AclImpl(grandParentOid, 1, this.authzStrategy, this.pgs, null, null, false, + joe); + MutableAcl parentAcl1 = new AclImpl(parentOid1, 2, this.authzStrategy, this.pgs, null, null, true, joe); + MutableAcl parentAcl2 = new AclImpl(parentOid2, 3, this.authzStrategy, this.pgs, null, null, true, joe); + MutableAcl childAcl1 = new AclImpl(childOid1, 4, this.authzStrategy, this.pgs, null, null, true, joe); + MutableAcl childAcl2 = new AclImpl(childOid2, 4, this.authzStrategy, this.pgs, null, null, false, joe); // Create hierarchies childAcl2.setParent(childAcl1); childAcl1.setParent(parentAcl1); parentAcl2.setParent(grandParentAcl); parentAcl1.setParent(grandParentAcl); - // Add some permissions - grandParentAcl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid( - "ROLE_USER_READ"), true); + grandParentAcl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), true); grandParentAcl.insertAce(1, BasePermission.WRITE, new PrincipalSid("ben"), true); - grandParentAcl - .insertAce(2, BasePermission.DELETE, new PrincipalSid("ben"), false); - grandParentAcl.insertAce(3, BasePermission.DELETE, new PrincipalSid("scott"), - true); + grandParentAcl.insertAce(2, BasePermission.DELETE, new PrincipalSid("ben"), false); + grandParentAcl.insertAce(3, BasePermission.DELETE, new PrincipalSid("scott"), true); parentAcl1.insertAce(0, BasePermission.READ, new PrincipalSid("scott"), true); parentAcl1.insertAce(1, BasePermission.DELETE, new PrincipalSid("scott"), false); parentAcl2.insertAce(0, BasePermission.CREATE, new PrincipalSid("ben"), true); childAcl1.insertAce(0, BasePermission.CREATE, new PrincipalSid("scott"), true); - // Check granting process for parent1 assertThat(parentAcl1.isGranted(READ, SCOTT, false)).isTrue(); - assertThat(parentAcl1.isGranted(READ, - Arrays.asList((Sid) new GrantedAuthoritySid("ROLE_USER_READ")), false)) + assertThat(parentAcl1.isGranted(READ, Arrays.asList((Sid) new GrantedAuthoritySid("ROLE_USER_READ")), false)) .isTrue(); assertThat(parentAcl1.isGranted(WRITE, BEN, false)).isTrue(); assertThat(parentAcl1.isGranted(DELETE, BEN, false)).isFalse(); assertThat(parentAcl1.isGranted(DELETE, SCOTT, false)).isFalse(); - // Check granting process for parent2 assertThat(parentAcl2.isGranted(CREATE, BEN, false)).isTrue(); assertThat(parentAcl2.isGranted(WRITE, BEN, false)).isTrue(); assertThat(parentAcl2.isGranted(DELETE, BEN, false)).isFalse(); - // Check granting process for child1 assertThat(childAcl1.isGranted(CREATE, SCOTT, false)).isTrue(); - assertThat(childAcl1.isGranted(READ, - Arrays.asList((Sid) new GrantedAuthoritySid("ROLE_USER_READ")), false)) + assertThat(childAcl1.isGranted(READ, Arrays.asList((Sid) new GrantedAuthoritySid("ROLE_USER_READ")), false)) .isTrue(); assertThat(childAcl1.isGranted(DELETE, BEN, false)).isFalse(); - // Check granting process for child2 (doesn't inherit the permissions from its // parent) try { @@ -372,8 +355,7 @@ public class AclImplTests { catch (NotFoundException expected) { } try { - childAcl2.isGranted(CREATE, - Arrays.asList((Sid) new PrincipalSid("joe")), false); + childAcl2.isGranted(CREATE, Arrays.asList((Sid) new PrincipalSid("joe")), false); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { @@ -382,30 +364,23 @@ public class AclImplTests { @Test public void updatedAceValuesAreCorrectlyReflectedInAcl() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - false, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, false, + new PrincipalSid("joe")); MockAclService service = new MockAclService(); - - acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), - true); - acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_USER_READ"), - true); + acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), true); + acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_USER_READ"), true); acl.insertAce(2, BasePermission.CREATE, new PrincipalSid("ben"), true); service.updateAcl(acl); - assertThat(BasePermission.READ).isEqualTo(acl.getEntries().get(0).getPermission()); assertThat(BasePermission.WRITE).isEqualTo(acl.getEntries().get(1).getPermission()); assertThat(BasePermission.CREATE).isEqualTo(acl.getEntries().get(2).getPermission()); - // Change each permission acl.updateAce(0, BasePermission.CREATE); acl.updateAce(1, BasePermission.DELETE); acl.updateAce(2, BasePermission.READ); - // Check the change was successfully made assertThat(BasePermission.CREATE).isEqualTo(acl.getEntries().get(0).getPermission()); assertThat(BasePermission.DELETE).isEqualTo(acl.getEntries().get(1).getPermission()); @@ -414,37 +389,22 @@ public class AclImplTests { @Test public void auditableEntryFlagsAreUpdatedCorrectly() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_AUDITING", "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_AUDITING", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - false, new PrincipalSid("joe")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, false, + new PrincipalSid("joe")); MockAclService service = new MockAclService(); - - acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), - true); - acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_USER_READ"), - true); + acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), true); + acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_USER_READ"), true); service.updateAcl(acl); - - assertThat(((AuditableAccessControlEntry) acl.getEntries().get(0)) - .isAuditFailure()) - .isFalse(); - assertThat(((AuditableAccessControlEntry) acl.getEntries().get(1)) - .isAuditFailure()) - .isFalse(); - assertThat(((AuditableAccessControlEntry) acl.getEntries().get(0)) - .isAuditSuccess()) - .isFalse(); - assertThat(((AuditableAccessControlEntry) acl.getEntries().get(1)) - .isAuditSuccess()) - .isFalse(); - + assertThat(((AuditableAccessControlEntry) acl.getEntries().get(0)).isAuditFailure()).isFalse(); + assertThat(((AuditableAccessControlEntry) acl.getEntries().get(1)).isAuditFailure()).isFalse(); + assertThat(((AuditableAccessControlEntry) acl.getEntries().get(0)).isAuditSuccess()).isFalse(); + assertThat(((AuditableAccessControlEntry) acl.getEntries().get(1)).isAuditSuccess()).isFalse(); // Change each permission ((AuditableAcl) acl).updateAuditing(0, true, true); ((AuditableAcl) acl).updateAuditing(1, true, true); - // Check the change was successfuly made assertThat(acl.getEntries()).extracting("auditSuccess").containsOnly(true, true); assertThat(acl.getEntries()).extracting("auditFailure").containsOnly(true, true); @@ -452,86 +412,74 @@ public class AclImplTests { @Test public void gettersAndSettersAreConsistent() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, (100)); ObjectIdentity identity2 = new ObjectIdentityImpl(TARGET_CLASS, (101)); - MutableAcl acl = new AclImpl(identity, 1, authzStrategy, pgs, null, null, true, + MutableAcl acl = new AclImpl(identity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); + MutableAcl parentAcl = new AclImpl(identity2, 2, this.authzStrategy, this.pgs, null, null, true, new PrincipalSid("joe")); - MutableAcl parentAcl = new AclImpl(identity2, 2, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); MockAclService service = new MockAclService(); - acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), - true); - acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_USER_READ"), - true); + acl.insertAce(0, BasePermission.READ, new GrantedAuthoritySid("ROLE_USER_READ"), true); + acl.insertAce(1, BasePermission.WRITE, new GrantedAuthoritySid("ROLE_USER_READ"), true); service.updateAcl(acl); - assertThat(1).isEqualTo(acl.getId()); assertThat(identity).isEqualTo(acl.getObjectIdentity()); assertThat(new PrincipalSid("joe")).isEqualTo(acl.getOwner()); assertThat(acl.getParentAcl()).isNull(); assertThat(acl.isEntriesInheriting()).isTrue(); assertThat(acl.getEntries()).hasSize(2); - acl.setParent(parentAcl); assertThat(parentAcl).isEqualTo(acl.getParentAcl()); - acl.setEntriesInheriting(false); assertThat(acl.isEntriesInheriting()).isFalse(); - acl.setOwner(new PrincipalSid("ben")); assertThat(new PrincipalSid("ben")).isEqualTo(acl.getOwner()); } @Test public void isSidLoadedBehavesAsExpected() { - List loadedSids = Arrays.asList(new PrincipalSid("ben"), - new GrantedAuthoritySid("ROLE_IGNORED")); - MutableAcl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, - loadedSids, true, new PrincipalSid("joe")); - + List loadedSids = Arrays.asList(new PrincipalSid("ben"), new GrantedAuthoritySid("ROLE_IGNORED")); + MutableAcl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, loadedSids, true, + new PrincipalSid("joe")); assertThat(acl.isSidLoaded(loadedSids)).isTrue(); - assertThat(acl.isSidLoaded(Arrays.asList(new GrantedAuthoritySid("ROLE_IGNORED"), - new PrincipalSid("ben")))) - .isTrue(); - assertThat(acl.isSidLoaded(Arrays.asList((Sid) new GrantedAuthoritySid( - "ROLE_IGNORED")))) - .isTrue(); + assertThat(acl.isSidLoaded(Arrays.asList(new GrantedAuthoritySid("ROLE_IGNORED"), new PrincipalSid("ben")))) + .isTrue(); + assertThat(acl.isSidLoaded(Arrays.asList((Sid) new GrantedAuthoritySid("ROLE_IGNORED")))).isTrue(); assertThat(acl.isSidLoaded(BEN)).isTrue(); assertThat(acl.isSidLoaded(null)).isTrue(); assertThat(acl.isSidLoaded(new ArrayList<>(0))).isTrue(); - assertThat(acl.isSidLoaded(Arrays.asList(new GrantedAuthoritySid( - "ROLE_IGNORED"), new GrantedAuthoritySid("ROLE_IGNORED")))) - .isTrue(); - assertThat(acl.isSidLoaded(Arrays.asList(new GrantedAuthoritySid( - "ROLE_GENERAL"), new GrantedAuthoritySid("ROLE_IGNORED")))) - .isFalse(); - assertThat(acl.isSidLoaded(Arrays.asList(new GrantedAuthoritySid( - "ROLE_IGNORED"), new GrantedAuthoritySid("ROLE_GENERAL")))) - .isFalse(); + assertThat(acl.isSidLoaded( + Arrays.asList(new GrantedAuthoritySid("ROLE_IGNORED"), new GrantedAuthoritySid("ROLE_IGNORED")))) + .isTrue(); + assertThat(acl.isSidLoaded( + Arrays.asList(new GrantedAuthoritySid("ROLE_GENERAL"), new GrantedAuthoritySid("ROLE_IGNORED")))) + .isFalse(); + assertThat(acl.isSidLoaded( + Arrays.asList(new GrantedAuthoritySid("ROLE_IGNORED"), new GrantedAuthoritySid("ROLE_GENERAL")))) + .isFalse(); } @Test(expected = NotFoundException.class) public void insertAceRaisesNotFoundExceptionForIndexLessThanZero() { - AclImpl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + AclImpl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); acl.insertAce(-1, mock(Permission.class), mock(Sid.class), true); } @Test(expected = NotFoundException.class) public void deleteAceRaisesNotFoundExceptionForIndexLessThanZero() { - AclImpl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + AclImpl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); acl.deleteAce(-1); } @Test(expected = NotFoundException.class) public void insertAceRaisesNotFoundExceptionForIndexGreaterThanSize() { - AclImpl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + AclImpl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); // Insert at zero, OK. acl.insertAce(0, mock(Permission.class), mock(Sid.class), true); // Size is now 1 @@ -541,8 +489,8 @@ public class AclImplTests { // SEC-1151 @Test(expected = NotFoundException.class) public void deleteAceRaisesNotFoundExceptionForIndexEqualToSize() { - AclImpl acl = new AclImpl(objectIdentity, 1, authzStrategy, pgs, null, null, - true, new PrincipalSid("joe")); + AclImpl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, this.pgs, null, null, true, + new PrincipalSid("joe")); acl.insertAce(0, mock(Permission.class), mock(Sid.class), true); // Size is now 1 acl.deleteAce(1); @@ -551,12 +499,9 @@ public class AclImplTests { // SEC-1795 @Test public void changingParentIsSuccessful() { - AclImpl parentAcl = new AclImpl(objectIdentity, 1L, authzStrategy, - mockAuditLogger); - AclImpl childAcl = new AclImpl(objectIdentity, 2L, authzStrategy, mockAuditLogger); - AclImpl changeParentAcl = new AclImpl(objectIdentity, 3L, authzStrategy, - mockAuditLogger); - + AclImpl parentAcl = new AclImpl(this.objectIdentity, 1L, this.authzStrategy, this.mockAuditLogger); + AclImpl childAcl = new AclImpl(this.objectIdentity, 2L, this.authzStrategy, this.mockAuditLogger); + AclImpl changeParentAcl = new AclImpl(this.objectIdentity, 3L, this.authzStrategy, this.mockAuditLogger); childAcl.setParent(parentAcl); childAcl.setParent(changeParentAcl); } @@ -564,11 +509,12 @@ public class AclImplTests { // SEC-2342 @Test public void maskPermissionGrantingStrategy() { - DefaultPermissionGrantingStrategy maskPgs = new MaskPermissionGrantingStrategy(mockAuditLogger); + DefaultPermissionGrantingStrategy maskPgs = new MaskPermissionGrantingStrategy(this.mockAuditLogger); MockAclService service = new MockAclService(); - AclImpl acl = new AclImpl(objectIdentity, 1, authzStrategy, maskPgs, null, null, - true, new PrincipalSid("joe")); - Permission permission = permissionFactory.buildFromMask(BasePermission.READ.getMask() | BasePermission.WRITE.getMask()); + AclImpl acl = new AclImpl(this.objectIdentity, 1, this.authzStrategy, maskPgs, null, null, true, + new PrincipalSid("joe")); + Permission permission = this.permissionFactory + .buildFromMask(BasePermission.READ.getMask() | BasePermission.WRITE.getMask()); Sid sid = new PrincipalSid("ben"); acl.insertAce(0, permission, sid, true); service.updateAcl(acl); @@ -579,27 +525,21 @@ public class AclImplTests { @Test public void hashCodeWithoutStackOverFlow() throws Exception { - //given Sid sid = new PrincipalSid("pSid"); ObjectIdentity oid = new ObjectIdentityImpl("type", 1); AclAuthorizationStrategy authStrategy = new AclAuthorizationStrategyImpl(new SimpleGrantedAuthority("role")); PermissionGrantingStrategy grantingStrategy = new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()); - - AclImpl acl = new AclImpl(oid, 1L, authStrategy, grantingStrategy, null, null, false, sid); + AclImpl acl = new AclImpl(oid, 1L, authStrategy, grantingStrategy, null, null, false, sid); AccessControlEntryImpl ace = new AccessControlEntryImpl(1L, acl, sid, BasePermission.READ, true, true, true); - - Field fieldAces = FieldUtils.getField(AclImpl.class, "aces"); + Field fieldAces = FieldUtils.getField(AclImpl.class, "aces"); fieldAces.setAccessible(true); List aces = (List) fieldAces.get(acl); aces.add(ace); - //when - then none StackOverFlowError been raised ace.hashCode(); } - // ~ Inner Classes - // ================================================================================================== - private static class MaskPermissionGrantingStrategy extends DefaultPermissionGrantingStrategy { + MaskPermissionGrantingStrategy(AuditLogger auditLogger) { super(auditLogger); } @@ -611,26 +551,28 @@ public class AclImplTests { } return super.isGranted(ace, p); } + } private class MockAclService implements MutableAclService { - public MutableAcl createAcl(ObjectIdentity objectIdentity) - throws AlreadyExistsException { + + @Override + public MutableAcl createAcl(ObjectIdentity objectIdentity) throws AlreadyExistsException { return null; } - public void deleteAcl(ObjectIdentity objectIdentity, boolean deleteChildren) - throws ChildrenExistException { + @Override + public void deleteAcl(ObjectIdentity objectIdentity, boolean deleteChildren) throws ChildrenExistException { } /* * Mock implementation that populates the aces list with fully initialized * AccessControlEntries * - * @see - * org.springframework.security.acls.MutableAclService#updateAcl(org.springframework - * .security.acls.MutableAcl) + * @see org.springframework.security.acls.MutableAclService#updateAcl(org. + * springframework .security.acls.MutableAcl) */ + @Override @SuppressWarnings("unchecked") public MutableAcl updateAcl(MutableAcl acl) throws NotFoundException { List oldAces = acl.getEntries(); @@ -640,45 +582,47 @@ public class AclImplTests { try { newAces = (List) acesField.get(acl); newAces.clear(); - for (int i = 0; i < oldAces.size(); i++) { AccessControlEntry ac = oldAces.get(i); // Just give an ID to all this acl's aces, rest of the fields are just // copied - newAces.add(new AccessControlEntryImpl((i + 1), ac.getAcl(), ac - .getSid(), ac.getPermission(), ac.isGranting(), - ((AuditableAccessControlEntry) ac).isAuditSuccess(), + newAces.add(new AccessControlEntryImpl((i + 1), ac.getAcl(), ac.getSid(), ac.getPermission(), + ac.isGranting(), ((AuditableAccessControlEntry) ac).isAuditSuccess(), ((AuditableAccessControlEntry) ac).isAuditFailure())); } } - catch (IllegalAccessException e) { - e.printStackTrace(); + catch (IllegalAccessException ex) { + ex.printStackTrace(); } - return acl; } + @Override public List findChildren(ObjectIdentity parentIdentity) { return null; } + @Override public Acl readAclById(ObjectIdentity object) throws NotFoundException { return null; } - public Acl readAclById(ObjectIdentity object, List sids) + @Override + public Acl readAclById(ObjectIdentity object, List sids) throws NotFoundException { + return null; + } + + @Override + public Map readAclsById(List objects) throws NotFoundException { + return null; + } + + @Override + public Map readAclsById(List objects, List sids) throws NotFoundException { return null; } - public Map readAclsById(List objects) - throws NotFoundException { - return null; - } - - public Map readAclsById(List objects, - List sids) throws NotFoundException { - return null; - } } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/AclImplementationSecurityCheckTests.java b/acl/src/test/java/org/springframework/security/acls/domain/AclImplementationSecurityCheckTests.java index 0eab9940ac..9a121f71ab 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/AclImplementationSecurityCheckTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/AclImplementationSecurityCheckTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.*; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; -import org.junit.*; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.MutableAcl; @@ -28,6 +30,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.fail; + /** * Test class for {@link AclAuthorizationStrategyImpl} and {@link AclImpl} security * checks. @@ -35,10 +39,8 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Andrei Stefan */ public class AclImplementationSecurityCheckTests { - private static final String TARGET_CLASS = "org.springframework.security.acls.TargetObject"; - // ~ Methods - // ======================================================================================================== + private static final String TARGET_CLASS = "org.springframework.security.acls.TargetObject"; @Before public void setUp() { @@ -52,50 +54,38 @@ public class AclImplementationSecurityCheckTests { @Test public void testSecurityCheckNoACEs() { - Authentication auth = new TestingAuthenticationToken("user", "password", - "ROLE_GENERAL", "ROLE_AUDITING", "ROLE_OWNERSHIP"); + Authentication auth = new TestingAuthenticationToken("user", "password", "ROLE_GENERAL", "ROLE_AUDITING", + "ROLE_OWNERSHIP"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 100L); AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); - - Acl acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, - new ConsoleAuditLogger()); - - aclAuthorizationStrategy.securityCheck(acl, - AclAuthorizationStrategy.CHANGE_GENERAL); - aclAuthorizationStrategy.securityCheck(acl, - AclAuthorizationStrategy.CHANGE_AUDITING); - aclAuthorizationStrategy.securityCheck(acl, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); - + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); + Acl acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, new ConsoleAuditLogger()); + aclAuthorizationStrategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_GENERAL); + aclAuthorizationStrategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_AUDITING); + aclAuthorizationStrategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_OWNERSHIP); // Create another authorization strategy AclAuthorizationStrategy aclAuthorizationStrategy2 = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_ONE"), new SimpleGrantedAuthority( - "ROLE_TWO"), new SimpleGrantedAuthority("ROLE_THREE")); - Acl acl2 = new AclImpl(identity, 1L, aclAuthorizationStrategy2, - new ConsoleAuditLogger()); + new SimpleGrantedAuthority("ROLE_ONE"), new SimpleGrantedAuthority("ROLE_TWO"), + new SimpleGrantedAuthority("ROLE_THREE")); + Acl acl2 = new AclImpl(identity, 1L, aclAuthorizationStrategy2, new ConsoleAuditLogger()); // Check access in case the principal has no authorization rights try { - aclAuthorizationStrategy2.securityCheck(acl2, - AclAuthorizationStrategy.CHANGE_GENERAL); + aclAuthorizationStrategy2.securityCheck(acl2, AclAuthorizationStrategy.CHANGE_GENERAL); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { } try { - aclAuthorizationStrategy2.securityCheck(acl2, - AclAuthorizationStrategy.CHANGE_AUDITING); + aclAuthorizationStrategy2.securityCheck(acl2, AclAuthorizationStrategy.CHANGE_AUDITING); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { } try { - aclAuthorizationStrategy2.securityCheck(acl2, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); + aclAuthorizationStrategy2.securityCheck(acl2, AclAuthorizationStrategy.CHANGE_OWNERSHIP); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { @@ -105,100 +95,71 @@ public class AclImplementationSecurityCheckTests { @Test public void testSecurityCheckWithMultipleACEs() { // Create a simple authentication with ROLE_GENERAL - Authentication auth = new TestingAuthenticationToken("user", "password", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("user", "password", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 100L); // Authorization strategy will require a different role for each access AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); - + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); // Let's give the principal the ADMINISTRATION permission, without // granting access - MutableAcl aclFirstDeny = new AclImpl(identity, 1L, - aclAuthorizationStrategy, new ConsoleAuditLogger()); - aclFirstDeny.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), - false); - + MutableAcl aclFirstDeny = new AclImpl(identity, 1L, aclAuthorizationStrategy, new ConsoleAuditLogger()); + aclFirstDeny.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), false); // The CHANGE_GENERAL test should pass as the principal has ROLE_GENERAL - aclAuthorizationStrategy.securityCheck(aclFirstDeny, - AclAuthorizationStrategy.CHANGE_GENERAL); - + aclAuthorizationStrategy.securityCheck(aclFirstDeny, AclAuthorizationStrategy.CHANGE_GENERAL); // The CHANGE_AUDITING and CHANGE_OWNERSHIP should fail since the // principal doesn't have these authorities, // nor granting access try { - aclAuthorizationStrategy.securityCheck(aclFirstDeny, - AclAuthorizationStrategy.CHANGE_AUDITING); + aclAuthorizationStrategy.securityCheck(aclFirstDeny, AclAuthorizationStrategy.CHANGE_AUDITING); fail("It should have thrown AccessDeniedException"); } catch (AccessDeniedException expected) { } try { - aclAuthorizationStrategy.securityCheck(aclFirstDeny, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); + aclAuthorizationStrategy.securityCheck(aclFirstDeny, AclAuthorizationStrategy.CHANGE_OWNERSHIP); fail("It should have thrown AccessDeniedException"); } catch (AccessDeniedException expected) { } - // Add granting access to this principal - aclFirstDeny.insertAce(1, BasePermission.ADMINISTRATION, new PrincipalSid(auth), - true); + aclFirstDeny.insertAce(1, BasePermission.ADMINISTRATION, new PrincipalSid(auth), true); // and try again for CHANGE_AUDITING - the first ACE's granting flag // (false) will deny this access try { - aclAuthorizationStrategy.securityCheck(aclFirstDeny, - AclAuthorizationStrategy.CHANGE_AUDITING); + aclAuthorizationStrategy.securityCheck(aclFirstDeny, AclAuthorizationStrategy.CHANGE_AUDITING); fail("It should have thrown AccessDeniedException"); } catch (AccessDeniedException expected) { } - // Create another ACL and give the principal the ADMINISTRATION // permission, with granting access - MutableAcl aclFirstAllow = new AclImpl(identity, 1L, - aclAuthorizationStrategy, new ConsoleAuditLogger()); - aclFirstAllow.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), - true); - + MutableAcl aclFirstAllow = new AclImpl(identity, 1L, aclAuthorizationStrategy, new ConsoleAuditLogger()); + aclFirstAllow.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), true); // The CHANGE_AUDITING test should pass as there is one ACE with // granting access - - aclAuthorizationStrategy.securityCheck(aclFirstAllow, - AclAuthorizationStrategy.CHANGE_AUDITING); - + aclAuthorizationStrategy.securityCheck(aclFirstAllow, AclAuthorizationStrategy.CHANGE_AUDITING); // Add a deny ACE and test again for CHANGE_AUDITING - aclFirstAllow.insertAce(1, BasePermission.ADMINISTRATION, new PrincipalSid(auth), - false); + aclFirstAllow.insertAce(1, BasePermission.ADMINISTRATION, new PrincipalSid(auth), false); try { - aclAuthorizationStrategy.securityCheck(aclFirstAllow, - AclAuthorizationStrategy.CHANGE_AUDITING); - + aclAuthorizationStrategy.securityCheck(aclFirstAllow, AclAuthorizationStrategy.CHANGE_AUDITING); } catch (AccessDeniedException notExpected) { fail("It shouldn't have thrown AccessDeniedException"); } - // Create an ACL with no ACE - MutableAcl aclNoACE = new AclImpl(identity, 1L, - aclAuthorizationStrategy, new ConsoleAuditLogger()); + MutableAcl aclNoACE = new AclImpl(identity, 1L, aclAuthorizationStrategy, new ConsoleAuditLogger()); try { - aclAuthorizationStrategy.securityCheck(aclNoACE, - AclAuthorizationStrategy.CHANGE_AUDITING); + aclAuthorizationStrategy.securityCheck(aclNoACE, AclAuthorizationStrategy.CHANGE_AUDITING); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { - } // and still grant access for CHANGE_GENERAL try { - aclAuthorizationStrategy.securityCheck(aclNoACE, - AclAuthorizationStrategy.CHANGE_GENERAL); - + aclAuthorizationStrategy.securityCheck(aclNoACE, AclAuthorizationStrategy.CHANGE_GENERAL); } catch (NotFoundException expected) { fail("It shouldn't have thrown NotFoundException"); @@ -208,64 +169,46 @@ public class AclImplementationSecurityCheckTests { @Test public void testSecurityCheckWithInheritableACEs() { // Create a simple authentication with ROLE_GENERAL - Authentication auth = new TestingAuthenticationToken("user", "password", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("user", "password", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 100); // Authorization strategy will require a different role for each access AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_ONE"), new SimpleGrantedAuthority( - "ROLE_TWO"), new SimpleGrantedAuthority("ROLE_GENERAL")); - + new SimpleGrantedAuthority("ROLE_ONE"), new SimpleGrantedAuthority("ROLE_TWO"), + new SimpleGrantedAuthority("ROLE_GENERAL")); // Let's give the principal an ADMINISTRATION permission, with granting // access - MutableAcl parentAcl = new AclImpl(identity, 1, aclAuthorizationStrategy, - new ConsoleAuditLogger()); - parentAcl.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), - true); - MutableAcl childAcl = new AclImpl(identity, 2, aclAuthorizationStrategy, - new ConsoleAuditLogger()); - + MutableAcl parentAcl = new AclImpl(identity, 1, aclAuthorizationStrategy, new ConsoleAuditLogger()); + parentAcl.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), true); + MutableAcl childAcl = new AclImpl(identity, 2, aclAuthorizationStrategy, new ConsoleAuditLogger()); // Check against the 'child' acl, which doesn't offer any authorization // rights on CHANGE_OWNERSHIP try { - aclAuthorizationStrategy.securityCheck(childAcl, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); + aclAuthorizationStrategy.securityCheck(childAcl, AclAuthorizationStrategy.CHANGE_OWNERSHIP); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { - } - // Link the child with its parent and test again against the // CHANGE_OWNERSHIP right childAcl.setParent(parentAcl); childAcl.setEntriesInheriting(true); try { - aclAuthorizationStrategy.securityCheck(childAcl, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); - + aclAuthorizationStrategy.securityCheck(childAcl, AclAuthorizationStrategy.CHANGE_OWNERSHIP); } catch (NotFoundException expected) { fail("It shouldn't have thrown NotFoundException"); } - // Create a root parent and link it to the middle parent - MutableAcl rootParentAcl = new AclImpl(identity, 1, aclAuthorizationStrategy, - new ConsoleAuditLogger()); - parentAcl = new AclImpl(identity, 1, aclAuthorizationStrategy, - new ConsoleAuditLogger()); - rootParentAcl.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), - true); + MutableAcl rootParentAcl = new AclImpl(identity, 1, aclAuthorizationStrategy, new ConsoleAuditLogger()); + parentAcl = new AclImpl(identity, 1, aclAuthorizationStrategy, new ConsoleAuditLogger()); + rootParentAcl.insertAce(0, BasePermission.ADMINISTRATION, new PrincipalSid(auth), true); parentAcl.setEntriesInheriting(true); parentAcl.setParent(rootParentAcl); childAcl.setParent(parentAcl); try { - aclAuthorizationStrategy.securityCheck(childAcl, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); - + aclAuthorizationStrategy.securityCheck(childAcl, AclAuthorizationStrategy.CHANGE_OWNERSHIP); } catch (NotFoundException expected) { fail("It shouldn't have thrown NotFoundException"); @@ -274,39 +217,34 @@ public class AclImplementationSecurityCheckTests { @Test public void testSecurityCheckPrincipalOwner() { - Authentication auth = new TestingAuthenticationToken("user", "password", - "ROLE_ONE"); + Authentication auth = new TestingAuthenticationToken("user", "password", "ROLE_ONE"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 100); AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); - + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); Acl acl = new AclImpl(identity, 1, aclAuthorizationStrategy, - new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()), null, - null, false, new PrincipalSid(auth)); + new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()), null, null, false, + new PrincipalSid(auth)); try { - aclAuthorizationStrategy.securityCheck(acl, - AclAuthorizationStrategy.CHANGE_GENERAL); + aclAuthorizationStrategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_GENERAL); } catch (AccessDeniedException notExpected) { fail("It shouldn't have thrown AccessDeniedException"); } try { - aclAuthorizationStrategy.securityCheck(acl, - AclAuthorizationStrategy.CHANGE_AUDITING); + aclAuthorizationStrategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_AUDITING); fail("It shouldn't have thrown AccessDeniedException"); } catch (NotFoundException expected) { } try { - aclAuthorizationStrategy.securityCheck(acl, - AclAuthorizationStrategy.CHANGE_OWNERSHIP); + aclAuthorizationStrategy.securityCheck(acl, AclAuthorizationStrategy.CHANGE_OWNERSHIP); } catch (AccessDeniedException notExpected) { fail("It shouldn't have thrown AccessDeniedException"); } } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/AuditLoggerTests.java b/acl/src/test/java/org/springframework/security/acls/domain/AuditLoggerTests.java index 001e0450e3..e2abb35c01 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/AuditLoggerTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/AuditLoggerTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.acls.domain; import java.io.ByteArrayOutputStream; import java.io.PrintStream; @@ -24,72 +22,76 @@ import java.io.PrintStream; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.security.acls.model.AccessControlEntry; import org.springframework.security.acls.model.AuditableAccessControlEntry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Test class for {@link ConsoleAuditLogger}. * * @author Andrei Stefan */ public class AuditLoggerTests { - // ~ Instance fields - // ================================================================================================ - private PrintStream console; - private ByteArrayOutputStream bytes = new ByteArrayOutputStream(); - private ConsoleAuditLogger logger; - private AuditableAccessControlEntry ace; - // ~ Methods - // ======================================================================================================== + private PrintStream console; + + private ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + private ConsoleAuditLogger logger; + + private AuditableAccessControlEntry ace; @Before public void setUp() { - logger = new ConsoleAuditLogger(); - ace = mock(AuditableAccessControlEntry.class); - console = System.out; - System.setOut(new PrintStream(bytes)); + this.logger = new ConsoleAuditLogger(); + this.ace = mock(AuditableAccessControlEntry.class); + this.console = System.out; + System.setOut(new PrintStream(this.bytes)); } @After public void tearDown() { - System.setOut(console); - bytes.reset(); + System.setOut(this.console); + this.bytes.reset(); } @Test public void nonAuditableAceIsIgnored() { AccessControlEntry ace = mock(AccessControlEntry.class); - logger.logIfNeeded(true, ace); - assertThat(bytes.size()).isZero(); + this.logger.logIfNeeded(true, ace); + assertThat(this.bytes.size()).isZero(); } @Test public void successIsNotLoggedIfAceDoesntRequireSuccessAudit() { - when(ace.isAuditSuccess()).thenReturn(false); - logger.logIfNeeded(true, ace); - assertThat(bytes.size()).isZero(); + given(this.ace.isAuditSuccess()).willReturn(false); + this.logger.logIfNeeded(true, this.ace); + assertThat(this.bytes.size()).isZero(); } @Test public void successIsLoggedIfAceRequiresSuccessAudit() { - when(ace.isAuditSuccess()).thenReturn(true); - - logger.logIfNeeded(true, ace); - assertThat(bytes.toString()).startsWith("GRANTED due to ACE"); + given(this.ace.isAuditSuccess()).willReturn(true); + this.logger.logIfNeeded(true, this.ace); + assertThat(this.bytes.toString()).startsWith("GRANTED due to ACE"); } @Test public void failureIsntLoggedIfAceDoesntRequireFailureAudit() { - when(ace.isAuditFailure()).thenReturn(false); - logger.logIfNeeded(false, ace); - assertThat(bytes.size()).isZero(); + given(this.ace.isAuditFailure()).willReturn(false); + this.logger.logIfNeeded(false, this.ace); + assertThat(this.bytes.size()).isZero(); } @Test public void failureIsLoggedIfAceRequiresFailureAudit() { - when(ace.isAuditFailure()).thenReturn(true); - logger.logIfNeeded(false, ace); - assertThat(bytes.toString()).startsWith("DENIED due to ACE"); + given(this.ace.isAuditFailure()).willReturn(true); + this.logger.logIfNeeded(false, this.ace); + assertThat(this.bytes.toString()).startsWith("DENIED due to ACE"); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityImplTests.java b/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityImplTests.java index 93adac3837..309dc8776f 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityImplTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityImplTests.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; -import org.springframework.security.acls.domain.IdentityUnavailableException; -import org.springframework.security.acls.domain.ObjectIdentityImpl; + import org.springframework.security.acls.model.ObjectIdentity; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests for {@link ObjectIdentityImpl}. * @@ -32,9 +33,6 @@ public class ObjectIdentityImplTests { private static final String DOMAIN_CLASS = "org.springframework.security.acls.domain.ObjectIdentityImplTests$MockIdDomainObject"; - // ~ Methods - // ======================================================================================================== - @Test public void constructorsRespectRequiredFields() { // Check one-argument constructor required field @@ -44,7 +42,6 @@ public class ObjectIdentityImplTests { } catch (IllegalArgumentException expected) { } - // Check String-Serializable constructor required field try { new ObjectIdentityImpl("", 1L); @@ -52,7 +49,6 @@ public class ObjectIdentityImplTests { } catch (IllegalArgumentException expected) { } - // Check Serializable parameter is not null try { new ObjectIdentityImpl(DOMAIN_CLASS, null); @@ -60,7 +56,6 @@ public class ObjectIdentityImplTests { } catch (IllegalArgumentException expected) { } - // The correct way of using String-Serializable constructor try { new ObjectIdentityImpl(DOMAIN_CLASS, 1L); @@ -68,7 +63,6 @@ public class ObjectIdentityImplTests { catch (IllegalArgumentException notExpected) { fail("It shouldn't have thrown IllegalArgumentException"); } - // Check the Class-Serializable constructor try { new ObjectIdentityImpl(MockIdDomainObject.class, null); @@ -93,9 +87,7 @@ public class ObjectIdentityImplTests { fail("It should have thrown IdentityUnavailableException"); } catch (IdentityUnavailableException expected) { - } - // getId() should return a non-null value MockIdDomainObject mockId = new MockIdDomainObject(); try { @@ -103,9 +95,7 @@ public class ObjectIdentityImplTests { fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - // getId() should return a Serializable object mockId.setId(new MockIdDomainObject()); try { @@ -114,7 +104,6 @@ public class ObjectIdentityImplTests { } catch (IllegalArgumentException expected) { } - // getId() should return a Serializable object mockId.setId(100L); try { @@ -134,15 +123,13 @@ public class ObjectIdentityImplTests { ObjectIdentity obj = new ObjectIdentityImpl(DOMAIN_CLASS, 1L); MockIdDomainObject mockObj = new MockIdDomainObject(); mockObj.setId(1L); - String string = "SOME_STRING"; assertThat(string).isNotSameAs(obj); assertThat(obj).isNotNull(); assertThat(obj).isNotEqualTo("DIFFERENT_OBJECT_TYPE"); assertThat(obj).isNotEqualTo(new ObjectIdentityImpl(DOMAIN_CLASS, 2L)); assertThat(obj).isNotEqualTo(new ObjectIdentityImpl( - "org.springframework.security.acls.domain.ObjectIdentityImplTests$MockOtherIdDomainObject", - 1L)); + "org.springframework.security.acls.domain.ObjectIdentityImplTests$MockOtherIdDomainObject", 1L)); assertThat(new ObjectIdentityImpl(DOMAIN_CLASS, 1L)).isEqualTo(obj); assertThat(new ObjectIdentityImpl(mockObj)).isEqualTo(obj); } @@ -158,7 +145,6 @@ public class ObjectIdentityImplTests { public void longAndIntegerIdsWithSameValueAreEqualAndHaveSameHashcode() { ObjectIdentity obj = new ObjectIdentityImpl(Object.class, 5L); ObjectIdentity obj2 = new ObjectIdentityImpl(Object.class, 5); - assertThat(obj2).isEqualTo(obj); assertThat(obj2.hashCode()).isEqualTo(obj.hashCode()); } @@ -178,30 +164,32 @@ public class ObjectIdentityImplTests { assertThat(obj).isNotEqualTo(obj2); } - // ~ Inner Classes - // ================================================================================================== - private class MockIdDomainObject { + private Object id; public Object getId() { - return id; + return this.id; } public void setId(Object id) { this.id = id; } + } private class MockOtherIdDomainObject { + private Object id; public Object getId() { - return id; + return this.id; } public void setId(Object id) { this.id = id; } + } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImplTests.java b/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImplTests.java index 67e068c947..b6787f893c 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImplTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/ObjectIdentityRetrievalStrategyImplTests.java @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests for {@link ObjectIdentityRetrievalStrategyImpl} * @@ -28,33 +30,29 @@ import org.springframework.security.acls.model.ObjectIdentityRetrievalStrategy; */ public class ObjectIdentityRetrievalStrategyImplTests { - // ~ Methods - // ======================================================================================================== @Test public void testObjectIdentityCreation() { MockIdDomainObject domain = new MockIdDomainObject(); domain.setId(1); - ObjectIdentityRetrievalStrategy retStrategy = new ObjectIdentityRetrievalStrategyImpl(); ObjectIdentity identity = retStrategy.getObjectIdentity(domain); - assertThat(identity).isNotNull(); assertThat(new ObjectIdentityImpl(domain)).isEqualTo(identity); } - // ~ Inner Classes - // ================================================================================================== @SuppressWarnings("unused") private class MockIdDomainObject { private Object id; public Object getId() { - return id; + return this.id; } public void setId(Object id) { this.id = id; } + } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/PermissionTests.java b/acl/src/test/java/org/springframework/security/acls/domain/PermissionTests.java index f71c50755d..5aef8e4c74 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/PermissionTests.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/PermissionTests.java @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.domain; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.acls.domain; import org.junit.Before; import org.junit.Test; + import org.springframework.security.acls.model.Permission; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests classes associated with Permission. * @@ -32,12 +34,12 @@ public class PermissionTests { @Before public void createPermissionfactory() { - permissionFactory = new DefaultPermissionFactory(); + this.permissionFactory = new DefaultPermissionFactory(); } @Test public void basePermissionTest() { - Permission p = permissionFactory.buildFromName("WRITE"); + Permission p = this.permissionFactory.buildFromName("WRITE"); assertThat(p).isNotNull(); } @@ -45,53 +47,37 @@ public class PermissionTests { public void expectedIntegerValues() { assertThat(BasePermission.READ.getMask()).isEqualTo(1); assertThat(BasePermission.ADMINISTRATION.getMask()).isEqualTo(16); - assertThat( - new CumulativePermission().set(BasePermission.READ) - .set(BasePermission.WRITE).set(BasePermission.CREATE).getMask()) - .isEqualTo(7); - assertThat( - new CumulativePermission().set(BasePermission.READ) - .set(BasePermission.ADMINISTRATION).getMask()) - .isEqualTo(17); + assertThat(new CumulativePermission().set(BasePermission.READ).set(BasePermission.WRITE) + .set(BasePermission.CREATE).getMask()).isEqualTo(7); + assertThat(new CumulativePermission().set(BasePermission.READ).set(BasePermission.ADMINISTRATION).getMask()) + .isEqualTo(17); } @Test public void fromInteger() { - Permission permission = permissionFactory.buildFromMask(7); - permission = permissionFactory.buildFromMask(4); + Permission permission = this.permissionFactory.buildFromMask(7); + permission = this.permissionFactory.buildFromMask(4); } @Test public void stringConversion() { - permissionFactory.registerPublicPermissions(SpecialPermission.class); - - assertThat(BasePermission.READ.toString()) - .isEqualTo("BasePermission[...............................R=1]"); - + this.permissionFactory.registerPublicPermissions(SpecialPermission.class); + assertThat(BasePermission.READ.toString()).isEqualTo("BasePermission[...............................R=1]"); + assertThat(BasePermission.ADMINISTRATION.toString()) + .isEqualTo("BasePermission[...........................A....=16]"); + assertThat(new CumulativePermission().set(BasePermission.READ).toString()) + .isEqualTo("CumulativePermission[...............................R=1]"); assertThat( - BasePermission.ADMINISTRATION.toString()) - .isEqualTo("BasePermission[...........................A....=16]"); - - assertThat( - new CumulativePermission().set(BasePermission.READ).toString()) - .isEqualTo("CumulativePermission[...............................R=1]"); - - assertThat(new CumulativePermission().set(SpecialPermission.ENTER) - .set(BasePermission.ADMINISTRATION).toString()) - .isEqualTo("CumulativePermission[..........................EA....=48]"); - - assertThat(new CumulativePermission().set(BasePermission.ADMINISTRATION) - .set(BasePermission.READ).toString()) - .isEqualTo("CumulativePermission[...........................A...R=17]"); - - assertThat(new CumulativePermission().set(BasePermission.ADMINISTRATION) - .set(BasePermission.READ).clear(BasePermission.ADMINISTRATION) - .toString()) - .isEqualTo("CumulativePermission[...............................R=1]"); - - assertThat(new CumulativePermission().set(BasePermission.ADMINISTRATION) - .set(BasePermission.READ).clear(BasePermission.ADMINISTRATION) - .clear(BasePermission.READ).toString()) - .isEqualTo("CumulativePermission[................................=0]"); + new CumulativePermission().set(SpecialPermission.ENTER).set(BasePermission.ADMINISTRATION).toString()) + .isEqualTo("CumulativePermission[..........................EA....=48]"); + assertThat(new CumulativePermission().set(BasePermission.ADMINISTRATION).set(BasePermission.READ).toString()) + .isEqualTo("CumulativePermission[...........................A...R=17]"); + assertThat(new CumulativePermission().set(BasePermission.ADMINISTRATION).set(BasePermission.READ) + .clear(BasePermission.ADMINISTRATION).toString()) + .isEqualTo("CumulativePermission[...............................R=1]"); + assertThat(new CumulativePermission().set(BasePermission.ADMINISTRATION).set(BasePermission.READ) + .clear(BasePermission.ADMINISTRATION).clear(BasePermission.READ).toString()) + .isEqualTo("CumulativePermission[................................=0]"); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/domain/SpecialPermission.java b/acl/src/test/java/org/springframework/security/acls/domain/SpecialPermission.java index 0aedb10afe..fb11be3008 100644 --- a/acl/src/test/java/org/springframework/security/acls/domain/SpecialPermission.java +++ b/acl/src/test/java/org/springframework/security/acls/domain/SpecialPermission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.domain; import org.springframework.security.acls.model.Permission; @@ -23,10 +24,13 @@ import org.springframework.security.acls.model.Permission; * @author Ben Alex */ public class SpecialPermission extends BasePermission { + public static final Permission ENTER = new SpecialPermission(1 << 5, 'E'); // 32 + public static final Permission LEAVE = new SpecialPermission(1 << 6, 'L'); protected SpecialPermission(int mask, char code) { super(mask, code); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/AbstractBasicLookupStrategyTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/AbstractBasicLookupStrategyTests.java index a2cefecab5..6000bab596 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/AbstractBasicLookupStrategyTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/AbstractBasicLookupStrategyTests.java @@ -13,19 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; -import static org.assertj.core.api.Assertions.*; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import javax.sql.DataSource; import net.sf.ehcache.Cache; import net.sf.ehcache.CacheManager; import net.sf.ehcache.Ehcache; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; -import org.junit.*; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.security.acls.TargetObject; import org.springframework.security.acls.TargetObjectWithUUID; -import org.springframework.security.acls.domain.*; +import org.springframework.security.acls.domain.AclAuthorizationStrategy; +import org.springframework.security.acls.domain.AclAuthorizationStrategyImpl; +import org.springframework.security.acls.domain.BasePermission; +import org.springframework.security.acls.domain.ConsoleAuditLogger; +import org.springframework.security.acls.domain.DefaultPermissionFactory; +import org.springframework.security.acls.domain.DefaultPermissionGrantingStrategy; +import org.springframework.security.acls.domain.EhCacheBasedAclCache; +import org.springframework.security.acls.domain.GrantedAuthoritySid; +import org.springframework.security.acls.domain.ObjectIdentityImpl; +import org.springframework.security.acls.domain.PrincipalSid; import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.AuditableAccessControlEntry; import org.springframework.security.acls.model.MutableAcl; @@ -35,9 +54,8 @@ import org.springframework.security.acls.model.Permission; import org.springframework.security.acls.model.Sid; import org.springframework.security.core.authority.SimpleGrantedAuthority; -import java.util.*; - -import javax.sql.DataSource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** * Tests {@link BasicLookupStrategy} @@ -47,19 +65,18 @@ import javax.sql.DataSource; public abstract class AbstractBasicLookupStrategyTests { protected static final Sid BEN_SID = new PrincipalSid("ben"); + protected static final String TARGET_CLASS = TargetObject.class.getName(); + protected static final String TARGET_CLASS_WITH_UUID = TargetObjectWithUUID.class.getName(); + protected static final UUID OBJECT_IDENTITY_UUID = UUID.randomUUID(); + protected static final Long OBJECT_IDENTITY_LONG_AS_UUID = 110L; - // ~ Instance fields - // ================================================================================================ - private BasicLookupStrategy strategy; - private static CacheManager cacheManager; - // ~ Methods - // ======================================================================================================== + private static CacheManager cacheManager; public abstract JdbcTemplate getJdbcTemplate(); @@ -80,44 +97,41 @@ public abstract class AbstractBasicLookupStrategyTests { @Before public void populateDatabase() { String query = "INSERT INTO acl_sid(ID,PRINCIPAL,SID) VALUES (1,1,'ben');" - + "INSERT INTO acl_class(ID,CLASS) VALUES (2,'" + TARGET_CLASS + "');" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (1,2,100,null,1,1);" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (2,2,101,1,1,1);" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (3,2,102,2,1,1);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (1,1,0,1,1,1,0,0);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (2,1,1,1,2,0,0,0);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (3,2,0,1,8,1,0,0);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (4,3,0,1,8,0,0,0);"; + + "INSERT INTO acl_class(ID,CLASS) VALUES (2,'" + TARGET_CLASS + "');" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (1,2,100,null,1,1);" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (2,2,101,1,1,1);" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (3,2,102,2,1,1);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (1,1,0,1,1,1,0,0);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (2,1,1,1,2,0,0,0);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (3,2,0,1,8,1,0,0);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (4,3,0,1,8,0,0,0);"; getJdbcTemplate().execute(query); } @Before public void initializeBeans() { - strategy = new BasicLookupStrategy(getDataSource(), aclCache(), aclAuthStrategy(), - new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger())); - strategy.setPermissionFactory(new DefaultPermissionFactory()); + this.strategy = new BasicLookupStrategy(getDataSource(), aclCache(), aclAuthStrategy(), + new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger())); + this.strategy.setPermissionFactory(new DefaultPermissionFactory()); } protected AclAuthorizationStrategy aclAuthStrategy() { - return new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_ADMINISTRATOR")); + return new AclAuthorizationStrategyImpl(new SimpleGrantedAuthority("ROLE_ADMINISTRATOR")); } protected EhCacheBasedAclCache aclCache() { - return new EhCacheBasedAclCache(getCache(), - new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()), - new AclAuthorizationStrategyImpl(new SimpleGrantedAuthority("ROLE_USER"))); + return new EhCacheBasedAclCache(getCache(), new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()), + new AclAuthorizationStrategyImpl(new SimpleGrantedAuthority("ROLE_USER"))); } - @After public void emptyDatabase() { String query = "DELETE FROM acl_entry;" + "DELETE FROM acl_object_identity WHERE ID = 9;" - + "DELETE FROM acl_object_identity WHERE ID = 8;" + "DELETE FROM acl_object_identity WHERE ID = 7;" - + "DELETE FROM acl_object_identity WHERE ID = 6;" + "DELETE FROM acl_object_identity WHERE ID = 5;" - + "DELETE FROM acl_object_identity WHERE ID = 4;" + "DELETE FROM acl_object_identity WHERE ID = 3;" - + "DELETE FROM acl_object_identity WHERE ID = 2;" + "DELETE FROM acl_object_identity WHERE ID = 1;" - + "DELETE FROM acl_class;" + "DELETE FROM acl_sid;"; + + "DELETE FROM acl_object_identity WHERE ID = 8;" + "DELETE FROM acl_object_identity WHERE ID = 7;" + + "DELETE FROM acl_object_identity WHERE ID = 6;" + "DELETE FROM acl_object_identity WHERE ID = 5;" + + "DELETE FROM acl_object_identity WHERE ID = 4;" + "DELETE FROM acl_object_identity WHERE ID = 3;" + + "DELETE FROM acl_object_identity WHERE ID = 2;" + "DELETE FROM acl_object_identity WHERE ID = 1;" + + "DELETE FROM acl_class;" + "DELETE FROM acl_sid;"; getJdbcTemplate().execute(query); } @@ -133,9 +147,8 @@ public abstract class AbstractBasicLookupStrategyTests { ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS, 101L); // Deliberately use an integer for the child, to reproduce bug report in SEC-819 ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 102); - Map map = this.strategy - .readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); + .readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); checkEntries(topParentOid, middleParentOid, childOid, map); } @@ -144,15 +157,12 @@ public abstract class AbstractBasicLookupStrategyTests { ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, 100); ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS, 101L); ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 102L); - // Objects were put in cache - strategy.readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); - + this.strategy.readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); // Let's empty the database to force acls retrieval from cache emptyDatabase(); Map map = this.strategy - .readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); - + .readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); checkEntries(topParentOid, middleParentOid, childOid, map); } @@ -161,43 +171,36 @@ public abstract class AbstractBasicLookupStrategyTests { ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, 100L); ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS, 101); ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 102L); - // Set a batch size to allow multiple database queries in order to retrieve all // acls this.strategy.setBatchSize(1); Map map = this.strategy - .readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); + .readAclsById(Arrays.asList(topParentOid, middleParentOid, childOid), null); checkEntries(topParentOid, middleParentOid, childOid, map); } private void checkEntries(ObjectIdentity topParentOid, ObjectIdentity middleParentOid, ObjectIdentity childOid, - Map map) { + Map map) { assertThat(map).hasSize(3); - MutableAcl topParent = (MutableAcl) map.get(topParentOid); MutableAcl middleParent = (MutableAcl) map.get(middleParentOid); MutableAcl child = (MutableAcl) map.get(childOid); - // Check the retrieved versions has IDs assertThat(topParent.getId()).isNotNull(); assertThat(middleParent.getId()).isNotNull(); assertThat(child.getId()).isNotNull(); - // Check their parents were correctly retrieved assertThat(topParent.getParentAcl()).isNull(); assertThat(middleParent.getParentAcl().getObjectIdentity()).isEqualTo(topParentOid); assertThat(child.getParentAcl().getObjectIdentity()).isEqualTo(middleParentOid); - // Check their ACEs were correctly retrieved assertThat(topParent.getEntries()).hasSize(2); assertThat(middleParent.getEntries()).hasSize(1); assertThat(child.getEntries()).hasSize(1); - // Check object identities were correctly retrieved assertThat(topParent.getObjectIdentity()).isEqualTo(topParentOid); assertThat(middleParent.getObjectIdentity()).isEqualTo(middleParentOid); assertThat(child.getObjectIdentity()).isEqualTo(childOid); - // Check each entry assertThat(topParent.isEntriesInheriting()).isTrue(); assertThat(Long.valueOf(1)).isEqualTo(topParent.getId()); @@ -208,14 +211,12 @@ public abstract class AbstractBasicLookupStrategyTests { assertThat(((AuditableAccessControlEntry) topParent.getEntries().get(0)).isAuditFailure()).isFalse(); assertThat(((AuditableAccessControlEntry) topParent.getEntries().get(0)).isAuditSuccess()).isFalse(); assertThat((topParent.getEntries().get(0)).isGranting()).isTrue(); - assertThat(Long.valueOf(2)).isEqualTo(topParent.getEntries().get(1).getId()); assertThat(topParent.getEntries().get(1).getPermission()).isEqualTo(BasePermission.WRITE); assertThat(topParent.getEntries().get(1).getSid()).isEqualTo(new PrincipalSid("ben")); assertThat(((AuditableAccessControlEntry) topParent.getEntries().get(1)).isAuditFailure()).isFalse(); assertThat(((AuditableAccessControlEntry) topParent.getEntries().get(1)).isAuditSuccess()).isFalse(); assertThat(topParent.getEntries().get(1).isGranting()).isFalse(); - assertThat(middleParent.isEntriesInheriting()).isTrue(); assertThat(Long.valueOf(2)).isEqualTo(middleParent.getId()); assertThat(new PrincipalSid("ben")).isEqualTo(middleParent.getOwner()); @@ -225,7 +226,6 @@ public abstract class AbstractBasicLookupStrategyTests { assertThat(((AuditableAccessControlEntry) middleParent.getEntries().get(0)).isAuditFailure()).isFalse(); assertThat(((AuditableAccessControlEntry) middleParent.getEntries().get(0)).isAuditSuccess()).isFalse(); assertThat(middleParent.getEntries().get(0).isGranting()).isTrue(); - assertThat(child.isEntriesInheriting()).isTrue(); assertThat(Long.valueOf(3)).isEqualTo(child.getId()); assertThat(new PrincipalSid("ben")).isEqualTo(child.getOwner()); @@ -241,15 +241,12 @@ public abstract class AbstractBasicLookupStrategyTests { public void testAllParentsAreRetrievedWhenChildIsLoaded() { String query = "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (6,2,103,1,1,1);"; getJdbcTemplate().execute(query); - ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, 100L); ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS, 101L); ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 102L); ObjectIdentity middleParent2Oid = new ObjectIdentityImpl(TARGET_CLASS, 103L); - // Retrieve the child Map map = this.strategy.readAclsById(Arrays.asList(childOid), null); - // Check that the child and all its parents were retrieved assertThat(map.get(childOid)).isNotNull(); assertThat(map.get(childOid).getObjectIdentity()).isEqualTo(childOid); @@ -257,7 +254,6 @@ public abstract class AbstractBasicLookupStrategyTests { assertThat(map.get(middleParentOid).getObjectIdentity()).isEqualTo(middleParentOid); assertThat(map.get(topParentOid)).isNotNull(); assertThat(map.get(topParentOid).getObjectIdentity()).isEqualTo(topParentOid); - // The second parent shouldn't have been retrieved assertThat(map.get(middleParent2Oid)).isNull(); } @@ -268,31 +264,26 @@ public abstract class AbstractBasicLookupStrategyTests { @Test public void testReadAllObjectIdentitiesWhenLastElementIsAlreadyCached() { String query = "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (6,2,105,null,1,1);" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (7,2,106,6,1,1);" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (8,2,107,6,1,1);" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (9,2,108,7,1,1);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (7,6,0,1,1,1,0,0)"; + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (7,2,106,6,1,1);" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (8,2,107,6,1,1);" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (9,2,108,7,1,1);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (7,6,0,1,1,1,0,0)"; getJdbcTemplate().execute(query); - ObjectIdentity grandParentOid = new ObjectIdentityImpl(TARGET_CLASS, 104L); ObjectIdentity parent1Oid = new ObjectIdentityImpl(TARGET_CLASS, 105L); ObjectIdentity parent2Oid = new ObjectIdentityImpl(TARGET_CLASS, 106); ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 107); - // First lookup only child, thus populating the cache with grandParent, // parent1 // and child List checkPermission = Arrays.asList(BasePermission.READ); List sids = Arrays.asList(BEN_SID); List childOids = Arrays.asList(childOid); - - strategy.setBatchSize(6); - Map foundAcls = strategy.readAclsById(childOids, sids); - + this.strategy.setBatchSize(6); + Map foundAcls = this.strategy.readAclsById(childOids, sids); Acl foundChildAcl = foundAcls.get(childOid); assertThat(foundChildAcl).isNotNull(); assertThat(foundChildAcl.isGranted(checkPermission, sids, false)).isTrue(); - // Search for object identities has to be done in the following order: // last // element have to be one which @@ -300,12 +291,11 @@ public abstract class AbstractBasicLookupStrategyTests { // cache List allOids = Arrays.asList(grandParentOid, parent1Oid, parent2Oid, childOid); try { - foundAcls = strategy.readAclsById(allOids, sids); - - } catch (NotFoundException notExpected) { + foundAcls = this.strategy.readAclsById(allOids, sids); + } + catch (NotFoundException notExpected) { fail("It shouldn't have thrown NotFoundException"); } - Acl foundParent2Acl = foundAcls.get(parent2Oid); assertThat(foundParent2Acl).isNotNull(); assertThat(foundParent2Acl.isGranted(checkPermission, sids, false)).isTrue(); @@ -314,26 +304,21 @@ public abstract class AbstractBasicLookupStrategyTests { @Test(expected = IllegalArgumentException.class) public void nullOwnerIsNotSupported() { String query = "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (6,2,104,null,null,1);"; - getJdbcTemplate().execute(query); - ObjectIdentity oid = new ObjectIdentityImpl(TARGET_CLASS, 104L); - - strategy.readAclsById(Arrays.asList(oid), Arrays.asList(BEN_SID)); + this.strategy.readAclsById(Arrays.asList(oid), Arrays.asList(BEN_SID)); } @Test public void testCreatePrincipalSid() { - Sid result = strategy.createSid(true, "sid"); - + Sid result = this.strategy.createSid(true, "sid"); assertThat(result.getClass()).isEqualTo(PrincipalSid.class); assertThat(((PrincipalSid) result).getPrincipal()).isEqualTo("sid"); } @Test public void testCreateGrantedAuthority() { - Sid result = strategy.createSid(false, "sid"); - + Sid result = this.strategy.createSid(false, "sid"); assertThat(result.getClass()).isEqualTo(GrantedAuthoritySid.class); assertThat(((GrantedAuthoritySid) result).getGrantedAuthority()).isEqualTo("sid"); } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/AclClassIdUtilsTest.java b/acl/src/test/java/org/springframework/security/acls/jdbc/AclClassIdUtilsTests.java similarity index 62% rename from acl/src/test/java/org/springframework/security/acls/jdbc/AclClassIdUtilsTest.java rename to acl/src/test/java/org/springframework/security/acls/jdbc/AclClassIdUtilsTests.java index 5dd2e36f31..99660737d7 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/AclClassIdUtilsTest.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/AclClassIdUtilsTests.java @@ -13,38 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.core.convert.ConversionService; - import java.io.Serializable; import java.math.BigInteger; import java.sql.ResultSet; import java.sql.SQLException; import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.core.convert.ConversionService; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; /** * Tests for {@link AclClassIdUtils}. + * * @author paulwheeler */ @RunWith(MockitoJUnitRunner.class) -public class AclClassIdUtilsTest { +public class AclClassIdUtilsTests { private static final Long DEFAULT_IDENTIFIER = 999L; + private static final BigInteger BIGINT_IDENTIFIER = new BigInteger("999"); + private static final String DEFAULT_IDENTIFIER_AS_STRING = DEFAULT_IDENTIFIER.toString(); @Mock private ResultSet resultSet; + @Mock private ConversionService conversionService; @@ -52,124 +57,82 @@ public class AclClassIdUtilsTest { @Before public void setUp() { - aclClassIdUtils = new AclClassIdUtils(); + this.aclClassIdUtils = new AclClassIdUtils(); } @Test public void shouldReturnLongIfIdentifierIsLong() throws SQLException { - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER, resultSet); - - // then + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnLongIfIdentifierIsBigInteger() throws SQLException { - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(BIGINT_IDENTIFIER, resultSet); - - // then + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(BIGINT_IDENTIFIER, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnLongIfClassIdTypeIsNull() throws SQLException { - // given - given(resultSet.getString("class_id_type")).willReturn(null); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willReturn(null); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnLongIfNoClassIdTypeColumn() throws SQLException { - // given - given(resultSet.getString("class_id_type")).willThrow(SQLException.class); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willThrow(SQLException.class); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnLongIfTypeClassNotFound() throws SQLException { - // given - given(resultSet.getString("class_id_type")).willReturn("com.example.UnknownType"); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willReturn("com.example.UnknownType"); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnLongEvenIfCustomConversionServiceDoesNotSupportLongConversion() throws SQLException { - // given - given(resultSet.getString("class_id_type")).willReturn("java.lang.Long"); - given(conversionService.canConvert(String.class, Long.class)).willReturn(false); - aclClassIdUtils.setConversionService(conversionService); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willReturn("java.lang.Long"); + given(this.conversionService.canConvert(String.class, Long.class)).willReturn(false); + this.aclClassIdUtils.setConversionService(this.conversionService); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnLongWhenLongClassIdType() throws SQLException { - // given - given(resultSet.getString("class_id_type")).willReturn("java.lang.Long"); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willReturn("java.lang.Long"); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(DEFAULT_IDENTIFIER_AS_STRING, this.resultSet); assertThat(newIdentifier).isEqualTo(DEFAULT_IDENTIFIER); } @Test public void shouldReturnUUIDWhenUUIDClassIdType() throws SQLException { - // given UUID identifier = UUID.randomUUID(); - given(resultSet.getString("class_id_type")).willReturn("java.util.UUID"); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(identifier.toString(), resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willReturn("java.util.UUID"); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(identifier.toString(), this.resultSet); assertThat(newIdentifier).isEqualTo(identifier); } @Test public void shouldReturnStringWhenStringClassIdType() throws SQLException { - // given String identifier = "MY_STRING_IDENTIFIER"; - given(resultSet.getString("class_id_type")).willReturn("java.lang.String"); - - // when - Serializable newIdentifier = aclClassIdUtils.identifierFrom(identifier, resultSet); - - // then + given(this.resultSet.getString("class_id_type")).willReturn("java.lang.String"); + Serializable newIdentifier = this.aclClassIdUtils.identifierFrom(identifier, this.resultSet); assertThat(newIdentifier).isEqualTo(identifier); } @Test(expected = IllegalArgumentException.class) public void shouldNotAcceptNullConversionServiceInConstruction() { - // when new AclClassIdUtils(null); } @Test(expected = IllegalArgumentException.class) public void shouldNotAcceptNullConversionServiceInSetter() { - // when - aclClassIdUtils.setConversionService(null); + this.aclClassIdUtils.setConversionService(null); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTests.java index 4f5f7c13a7..4e117ee5c0 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import javax.sql.DataSource; import org.junit.AfterClass; import org.junit.BeforeClass; + import org.springframework.jdbc.core.JdbcTemplate; /** @@ -28,8 +30,8 @@ import org.springframework.jdbc.core.JdbcTemplate; * @author Paul Wheeler */ public class BasicLookupStrategyTests extends AbstractBasicLookupStrategyTests { - private static final BasicLookupStrategyTestsDbHelper DATABASE_HELPER = new BasicLookupStrategyTestsDbHelper(); + private static final BasicLookupStrategyTestsDbHelper DATABASE_HELPER = new BasicLookupStrategyTestsDbHelper(); @BeforeClass public static void createDatabase() throws Exception { @@ -50,4 +52,5 @@ public class BasicLookupStrategyTests extends AbstractBasicLookupStrategyTests { public DataSource getDataSource() { return DATABASE_HELPER.getDataSource(); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTestsDbHelper.java b/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTestsDbHelper.java index 8c1f7042dc..32a2547351 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTestsDbHelper.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyTestsDbHelper.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import org.springframework.core.io.ClassPathResource; @@ -23,15 +24,20 @@ import org.springframework.util.FileCopyUtils; /** * Helper class to initialize the database for BasicLookupStrategyTests. + * * @author Andrei Stefan * @author Paul Wheeler */ public class BasicLookupStrategyTestsDbHelper { + private static final String ACL_SCHEMA_SQL_FILE = "createAclSchema.sql"; + private static final String ACL_SCHEMA_SQL_FILE_WITH_ACL_CLASS_ID = "createAclSchemaWithAclClassIdType.sql"; private SingleConnectionDataSource dataSource; + private JdbcTemplate jdbcTemplate; + private boolean withAclClassIdType; public BasicLookupStrategyTestsDbHelper() { @@ -45,28 +51,28 @@ public class BasicLookupStrategyTestsDbHelper { // Use a different connection url so the tests can run in parallel String connectionUrl; String sqlClassPathResource; - if (!withAclClassIdType) { + if (!this.withAclClassIdType) { connectionUrl = "jdbc:hsqldb:mem:lookupstrategytest"; sqlClassPathResource = ACL_SCHEMA_SQL_FILE; - } else { + } + else { connectionUrl = "jdbc:hsqldb:mem:lookupstrategytestWithAclClassIdType"; sqlClassPathResource = ACL_SCHEMA_SQL_FILE_WITH_ACL_CLASS_ID; - } - dataSource = new SingleConnectionDataSource(connectionUrl, "sa", "", true); - dataSource.setDriverClassName("org.hsqldb.jdbcDriver"); - jdbcTemplate = new JdbcTemplate(dataSource); - + this.dataSource = new SingleConnectionDataSource(connectionUrl, "sa", "", true); + this.dataSource.setDriverClassName("org.hsqldb.jdbcDriver"); + this.jdbcTemplate = new JdbcTemplate(this.dataSource); Resource resource = new ClassPathResource(sqlClassPathResource); String sql = new String(FileCopyUtils.copyToByteArray(resource.getInputStream())); - jdbcTemplate.execute(sql); + this.jdbcTemplate.execute(sql); } public JdbcTemplate getJdbcTemplate() { - return jdbcTemplate; + return this.jdbcTemplate; } public SingleConnectionDataSource getDataSource() { - return dataSource; + return this.dataSource; } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyWithAclClassTypeTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyWithAclClassTypeTests.java index 60ca2508f7..532304a57b 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyWithAclClassTypeTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/BasicLookupStrategyWithAclClassTypeTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; import java.util.Arrays; @@ -20,10 +21,12 @@ import java.util.Map; import javax.sql.DataSource; +import junit.framework.Assert; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; + import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.jdbc.core.JdbcTemplate; @@ -34,8 +37,6 @@ import org.springframework.security.acls.domain.ObjectIdentityImpl; import org.springframework.security.acls.model.Acl; import org.springframework.security.acls.model.ObjectIdentity; -import junit.framework.Assert; - /** * Tests {@link BasicLookupStrategy} with Acl Class type id set to UUID. * @@ -67,34 +68,35 @@ public class BasicLookupStrategyWithAclClassTypeTests extends AbstractBasicLooku DATABASE_HELPER.getDataSource().destroy(); } + @Override @Before public void initializeBeans() { super.initializeBeans(); - uuidEnabledStrategy = new BasicLookupStrategy(getDataSource(), aclCache(), aclAuthStrategy(), - new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger())); - uuidEnabledStrategy.setPermissionFactory(new DefaultPermissionFactory()); - uuidEnabledStrategy.setAclClassIdSupported(true); - uuidEnabledStrategy.setConversionService(new DefaultConversionService()); + this.uuidEnabledStrategy = new BasicLookupStrategy(getDataSource(), aclCache(), aclAuthStrategy(), + new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger())); + this.uuidEnabledStrategy.setPermissionFactory(new DefaultPermissionFactory()); + this.uuidEnabledStrategy.setAclClassIdSupported(true); + this.uuidEnabledStrategy.setConversionService(new DefaultConversionService()); } @Before public void populateDatabaseForAclClassTypeTests() { - String query = "INSERT INTO acl_class(ID,CLASS,CLASS_ID_TYPE) VALUES (3,'" - + TARGET_CLASS_WITH_UUID - + "', 'java.util.UUID');" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (4,3,'" - + OBJECT_IDENTITY_UUID.toString() + "',null,1,1);" - + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (5,3,'" - + OBJECT_IDENTITY_LONG_AS_UUID + "',null,1,1);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (5,4,0,1,8,0,0,0);" - + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (6,5,0,1,8,0,0,0);"; + String query = "INSERT INTO acl_class(ID,CLASS,CLASS_ID_TYPE) VALUES (3,'" + TARGET_CLASS_WITH_UUID + + "', 'java.util.UUID');" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (4,3,'" + + OBJECT_IDENTITY_UUID.toString() + "',null,1,1);" + + "INSERT INTO acl_object_identity(ID,OBJECT_ID_CLASS,OBJECT_ID_IDENTITY,PARENT_OBJECT,OWNER_SID,ENTRIES_INHERITING) VALUES (5,3,'" + + OBJECT_IDENTITY_LONG_AS_UUID + "',null,1,1);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (5,4,0,1,8,0,0,0);" + + "INSERT INTO acl_entry(ID,ACL_OBJECT_IDENTITY,ACE_ORDER,SID,MASK,GRANTING,AUDIT_SUCCESS,AUDIT_FAILURE) VALUES (6,5,0,1,8,0,0,0);"; DATABASE_HELPER.getJdbcTemplate().execute(query); } @Test public void testReadObjectIdentityUsingUuidType() { ObjectIdentity oid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, OBJECT_IDENTITY_UUID); - Map foundAcls = uuidEnabledStrategy.readAclsById(Arrays.asList(oid), Arrays.asList(BEN_SID)); + Map foundAcls = this.uuidEnabledStrategy.readAclsById(Arrays.asList(oid), + Arrays.asList(BEN_SID)); Assert.assertEquals(1, foundAcls.size()); Assert.assertNotNull(foundAcls.get(oid)); } @@ -102,7 +104,8 @@ public class BasicLookupStrategyWithAclClassTypeTests extends AbstractBasicLooku @Test public void testReadObjectIdentityUsingLongTypeWithConversionServiceEnabled() { ObjectIdentity oid = new ObjectIdentityImpl(TARGET_CLASS, 100L); - Map foundAcls = uuidEnabledStrategy.readAclsById(Arrays.asList(oid), Arrays.asList(BEN_SID)); + Map foundAcls = this.uuidEnabledStrategy.readAclsById(Arrays.asList(oid), + Arrays.asList(BEN_SID)); Assert.assertEquals(1, foundAcls.size()); Assert.assertNotNull(foundAcls.get(oid)); } @@ -110,6 +113,7 @@ public class BasicLookupStrategyWithAclClassTypeTests extends AbstractBasicLooku @Test(expected = ConversionFailedException.class) public void testReadObjectIdentityUsingNonUuidInDatabase() { ObjectIdentity oid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, OBJECT_IDENTITY_LONG_AS_UUID); - uuidEnabledStrategy.readAclsById(Arrays.asList(oid), Arrays.asList(BEN_SID)); + this.uuidEnabledStrategy.readAclsById(Arrays.asList(oid), Arrays.asList(BEN_SID)); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/DatabaseSeeder.java b/acl/src/test/java/org/springframework/security/acls/jdbc/DatabaseSeeder.java index d8170b3081..eca0b5d635 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/DatabaseSeeder.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/DatabaseSeeder.java @@ -13,34 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; -import org.springframework.core.io.Resource; - -import org.springframework.jdbc.core.JdbcTemplate; - -import org.springframework.util.Assert; -import org.springframework.util.FileCopyUtils; - import java.io.IOException; import javax.sql.DataSource; +import org.springframework.core.io.Resource; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.util.Assert; +import org.springframework.util.FileCopyUtils; + /** * Seeds the database for {@link JdbcMutableAclServiceTests}. * * @author Ben Alex */ public class DatabaseSeeder { - // ~ Constructors - // =================================================================================================== public DatabaseSeeder(DataSource dataSource, Resource resource) throws IOException { Assert.notNull(dataSource, "dataSource required"); Assert.notNull(resource, "resource required"); - JdbcTemplate template = new JdbcTemplate(dataSource); String sql = new String(FileCopyUtils.copyToByteArray(resource.getInputStream())); template.execute(sql); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/EhCacheBasedAclCacheTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/EhCacheBasedAclCacheTests.java index 1e90a4ddc0..d293b50084 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/EhCacheBasedAclCacheTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/EhCacheBasedAclCacheTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.jdbc; -import static org.mockito.Mockito.*; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.acls.jdbc; import java.io.File; import java.io.FileInputStream; @@ -28,7 +26,6 @@ import java.util.List; import net.sf.ehcache.Ehcache; import net.sf.ehcache.Element; - import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -37,7 +34,14 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.security.acls.domain.*; + +import org.springframework.security.acls.domain.AclAuthorizationStrategy; +import org.springframework.security.acls.domain.AclAuthorizationStrategyImpl; +import org.springframework.security.acls.domain.AclImpl; +import org.springframework.security.acls.domain.ConsoleAuditLogger; +import org.springframework.security.acls.domain.DefaultPermissionGrantingStrategy; +import org.springframework.security.acls.domain.EhCacheBasedAclCache; +import org.springframework.security.acls.domain.ObjectIdentityImpl; import org.springframework.security.acls.model.MutableAcl; import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -47,6 +51,12 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.util.FieldUtils; import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + /** * Tests {@link EhCacheBasedAclCache} * @@ -54,10 +64,12 @@ import org.springframework.test.util.ReflectionTestUtils; */ @RunWith(MockitoJUnitRunner.class) public class EhCacheBasedAclCacheTests { + private static final String TARGET_CLASS = "org.springframework.security.acls.TargetObject"; @Mock private Ehcache cache; + @Captor private ArgumentCaptor element; @@ -67,17 +79,14 @@ public class EhCacheBasedAclCacheTests { @Before public void setup() { - myCache = new EhCacheBasedAclCache(cache, new DefaultPermissionGrantingStrategy( - new ConsoleAuditLogger()), new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_USER"))); - + this.myCache = new EhCacheBasedAclCache(this.cache, + new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()), + new AclAuthorizationStrategyImpl(new SimpleGrantedAuthority("ROLE_USER"))); ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 100L); AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); - - acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, - new ConsoleAuditLogger()); + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); + this.acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, new ConsoleAuditLogger()); } @After @@ -87,48 +96,43 @@ public class EhCacheBasedAclCacheTests { @Test(expected = IllegalArgumentException.class) public void constructorRejectsNullParameters() { - new EhCacheBasedAclCache(null, new DefaultPermissionGrantingStrategy( - new ConsoleAuditLogger()), new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_USER"))); + new EhCacheBasedAclCache(null, new DefaultPermissionGrantingStrategy(new ConsoleAuditLogger()), + new AclAuthorizationStrategyImpl(new SimpleGrantedAuthority("ROLE_USER"))); } @Test public void methodsRejectNullParameters() { try { Serializable id = null; - myCache.evictFromCache(id); + this.myCache.evictFromCache(id); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { ObjectIdentity obj = null; - myCache.evictFromCache(obj); + this.myCache.evictFromCache(obj); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { Serializable id = null; - myCache.getFromCache(id); + this.myCache.getFromCache(id); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { ObjectIdentity obj = null; - myCache.getFromCache(obj); + this.myCache.getFromCache(obj); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { MutableAcl acl = null; - myCache.putInCache(acl); + this.myCache.putInCache(acl); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -142,146 +146,107 @@ public class EhCacheBasedAclCacheTests { File file = File.createTempFile("SEC_TEST", ".object"); FileOutputStream fos = new FileOutputStream(file); ObjectOutputStream oos = new ObjectOutputStream(fos); - oos.writeObject(acl); + oos.writeObject(this.acl); oos.close(); - FileInputStream fis = new FileInputStream(file); ObjectInputStream ois = new ObjectInputStream(fis); MutableAcl retrieved = (MutableAcl) ois.readObject(); ois.close(); - - assertThat(retrieved).isEqualTo(acl); - - Object retrieved1 = FieldUtils.getProtectedFieldValue("aclAuthorizationStrategy", - retrieved); + assertThat(retrieved).isEqualTo(this.acl); + Object retrieved1 = FieldUtils.getProtectedFieldValue("aclAuthorizationStrategy", retrieved); assertThat(retrieved1).isNull(); - - Object retrieved2 = FieldUtils.getProtectedFieldValue( - "permissionGrantingStrategy", retrieved); + Object retrieved2 = FieldUtils.getProtectedFieldValue("permissionGrantingStrategy", retrieved); assertThat(retrieved2).isNull(); } @Test public void clearCache() { - myCache.clearCache(); - - verify(cache).removeAll(); + this.myCache.clearCache(); + verify(this.cache).removeAll(); } @Test public void putInCache() { - myCache.putInCache(acl); - - verify(cache, times(2)).put(element.capture()); - assertThat(element.getValue().getKey()).isEqualTo(acl.getId()); - assertThat(element.getValue().getObjectValue()).isEqualTo(acl); - assertThat(element.getAllValues().get(0).getKey()).isEqualTo( - acl.getObjectIdentity()); - assertThat(element.getAllValues().get(0).getObjectValue()).isEqualTo(acl); + this.myCache.putInCache(this.acl); + verify(this.cache, times(2)).put(this.element.capture()); + assertThat(this.element.getValue().getKey()).isEqualTo(this.acl.getId()); + assertThat(this.element.getValue().getObjectValue()).isEqualTo(this.acl); + assertThat(this.element.getAllValues().get(0).getKey()).isEqualTo(this.acl.getObjectIdentity()); + assertThat(this.element.getAllValues().get(0).getObjectValue()).isEqualTo(this.acl); } @Test public void putInCacheAclWithParent() { - Authentication auth = new TestingAuthenticationToken("user", "password", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("user", "password", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - - ObjectIdentity identityParent = new ObjectIdentityImpl(TARGET_CLASS, - 2L); + ObjectIdentity identityParent = new ObjectIdentityImpl(TARGET_CLASS, 2L); AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); - MutableAcl parentAcl = new AclImpl(identityParent, 2L, - aclAuthorizationStrategy, new ConsoleAuditLogger()); - acl.setParent(parentAcl); - - myCache.putInCache(acl); - - verify(cache, times(4)).put(element.capture()); - - List allValues = element.getAllValues(); - + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); + MutableAcl parentAcl = new AclImpl(identityParent, 2L, aclAuthorizationStrategy, new ConsoleAuditLogger()); + this.acl.setParent(parentAcl); + this.myCache.putInCache(this.acl); + verify(this.cache, times(4)).put(this.element.capture()); + List allValues = this.element.getAllValues(); assertThat(allValues.get(0).getKey()).isEqualTo(parentAcl.getObjectIdentity()); assertThat(allValues.get(0).getObjectValue()).isEqualTo(parentAcl); - assertThat(allValues.get(1).getKey()).isEqualTo(parentAcl.getId()); assertThat(allValues.get(1).getObjectValue()).isEqualTo(parentAcl); - - assertThat(allValues.get(2).getKey()).isEqualTo(acl.getObjectIdentity()); - assertThat(allValues.get(2).getObjectValue()).isEqualTo(acl); - - assertThat(allValues.get(3).getKey()).isEqualTo(acl.getId()); - assertThat(allValues.get(3).getObjectValue()).isEqualTo(acl); + assertThat(allValues.get(2).getKey()).isEqualTo(this.acl.getObjectIdentity()); + assertThat(allValues.get(2).getObjectValue()).isEqualTo(this.acl); + assertThat(allValues.get(3).getKey()).isEqualTo(this.acl.getId()); + assertThat(allValues.get(3).getObjectValue()).isEqualTo(this.acl); } @Test public void getFromCacheSerializable() { - when(cache.get(acl.getId())).thenReturn(new Element(acl.getId(), acl)); - - assertThat(myCache.getFromCache(acl.getId())).isEqualTo(acl); + given(this.cache.get(this.acl.getId())).willReturn(new Element(this.acl.getId(), this.acl)); + assertThat(this.myCache.getFromCache(this.acl.getId())).isEqualTo(this.acl); } @Test public void getFromCacheSerializablePopulatesTransient() { - when(cache.get(acl.getId())).thenReturn(new Element(acl.getId(), acl)); - - myCache.putInCache(acl); - - ReflectionTestUtils.setField(acl, "permissionGrantingStrategy", null); - ReflectionTestUtils.setField(acl, "aclAuthorizationStrategy", null); - - MutableAcl fromCache = myCache.getFromCache(acl.getId()); - - assertThat(ReflectionTestUtils.getField(fromCache, "aclAuthorizationStrategy")) - .isNotNull(); - assertThat(ReflectionTestUtils.getField(fromCache, "permissionGrantingStrategy")) - .isNotNull(); + given(this.cache.get(this.acl.getId())).willReturn(new Element(this.acl.getId(), this.acl)); + this.myCache.putInCache(this.acl); + ReflectionTestUtils.setField(this.acl, "permissionGrantingStrategy", null); + ReflectionTestUtils.setField(this.acl, "aclAuthorizationStrategy", null); + MutableAcl fromCache = this.myCache.getFromCache(this.acl.getId()); + assertThat(ReflectionTestUtils.getField(fromCache, "aclAuthorizationStrategy")).isNotNull(); + assertThat(ReflectionTestUtils.getField(fromCache, "permissionGrantingStrategy")).isNotNull(); } @Test public void getFromCacheObjectIdentity() { - when(cache.get(acl.getId())).thenReturn(new Element(acl.getId(), acl)); - - assertThat(myCache.getFromCache(acl.getId())).isEqualTo(acl); + given(this.cache.get(this.acl.getId())).willReturn(new Element(this.acl.getId(), this.acl)); + assertThat(this.myCache.getFromCache(this.acl.getId())).isEqualTo(this.acl); } @Test public void getFromCacheObjectIdentityPopulatesTransient() { - when(cache.get(acl.getObjectIdentity())) - .thenReturn(new Element(acl.getId(), acl)); - - myCache.putInCache(acl); - - ReflectionTestUtils.setField(acl, "permissionGrantingStrategy", null); - ReflectionTestUtils.setField(acl, "aclAuthorizationStrategy", null); - - MutableAcl fromCache = myCache.getFromCache(acl.getObjectIdentity()); - - assertThat(ReflectionTestUtils.getField(fromCache, "aclAuthorizationStrategy")) - .isNotNull(); - assertThat(ReflectionTestUtils.getField(fromCache, "permissionGrantingStrategy")) - .isNotNull(); + given(this.cache.get(this.acl.getObjectIdentity())).willReturn(new Element(this.acl.getId(), this.acl)); + this.myCache.putInCache(this.acl); + ReflectionTestUtils.setField(this.acl, "permissionGrantingStrategy", null); + ReflectionTestUtils.setField(this.acl, "aclAuthorizationStrategy", null); + MutableAcl fromCache = this.myCache.getFromCache(this.acl.getObjectIdentity()); + assertThat(ReflectionTestUtils.getField(fromCache, "aclAuthorizationStrategy")).isNotNull(); + assertThat(ReflectionTestUtils.getField(fromCache, "permissionGrantingStrategy")).isNotNull(); } @Test public void evictCacheSerializable() { - when(cache.get(acl.getObjectIdentity())) - .thenReturn(new Element(acl.getId(), acl)); - - myCache.evictFromCache(acl.getObjectIdentity()); - - verify(cache).remove(acl.getId()); - verify(cache).remove(acl.getObjectIdentity()); + given(this.cache.get(this.acl.getObjectIdentity())).willReturn(new Element(this.acl.getId(), this.acl)); + this.myCache.evictFromCache(this.acl.getObjectIdentity()); + verify(this.cache).remove(this.acl.getId()); + verify(this.cache).remove(this.acl.getObjectIdentity()); } @Test public void evictCacheObjectIdentity() { - when(cache.get(acl.getId())).thenReturn(new Element(acl.getId(), acl)); - - myCache.evictFromCache(acl.getId()); - - verify(cache).remove(acl.getId()); - verify(cache).remove(acl.getObjectIdentity()); + given(this.cache.get(this.acl.getId())).willReturn(new Element(this.acl.getId(), this.acl)); + this.myCache.evictFromCache(this.acl.getId()); + verify(this.cache).remove(this.acl.getId()); + verify(this.cache).remove(this.acl.getObjectIdentity()); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcAclServiceTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcAclServiceTests.java index b039e4d656..49985c2d2e 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcAclServiceTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcAclServiceTests.java @@ -13,14 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import javax.sql.DataSource; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; @@ -32,17 +43,15 @@ import org.springframework.security.acls.model.NotFoundException; import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.security.acls.model.Sid; -import javax.sql.DataSource; -import java.util.*; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.AdditionalMatchers.aryEq; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; /** - * Unit and Integration tests the ACL JdbcAclService using an - * in-memory database. + * Unit and Integration tests the ACL JdbcAclService using an in-memory database. * * @author Nena Raab */ @@ -61,51 +70,48 @@ public class JdbcAclServiceTests { JdbcOperations jdbcOperations; private JdbcAclService aclServiceIntegration; + private JdbcAclService aclService; @Before public void setUp() { - aclService = new JdbcAclService(jdbcOperations, lookupStrategy); - aclServiceIntegration = new JdbcAclService(embeddedDatabase, lookupStrategy); + this.aclService = new JdbcAclService(this.jdbcOperations, this.lookupStrategy); + this.aclServiceIntegration = new JdbcAclService(this.embeddedDatabase, this.lookupStrategy); } @Before public void setUpEmbeddedDatabase() { - embeddedDatabase = new EmbeddedDatabaseBuilder()// - .addScript("createAclSchemaWithAclClassIdType.sql") - .addScript("db/sql/test_data_hierarchy.sql") - .build(); + // @formatter:off + this.embeddedDatabase = new EmbeddedDatabaseBuilder() + .addScript("createAclSchemaWithAclClassIdType.sql") + .addScript("db/sql/test_data_hierarchy.sql") + .build(); + // @formatter:on } @After public void tearDownEmbeddedDatabase() { - embeddedDatabase.shutdown(); + this.embeddedDatabase.shutdown(); } // SEC-1898 @Test(expected = NotFoundException.class) public void readAclByIdMissingAcl() { Map result = new HashMap<>(); - when( - lookupStrategy.readAclsById(anyList(), - anyList())).thenReturn(result); + given(this.lookupStrategy.readAclsById(anyList(), anyList())).willReturn(result); ObjectIdentity objectIdentity = new ObjectIdentityImpl(Object.class, 1); List sids = Arrays.asList(new PrincipalSid("user")); - - aclService.readAclById(objectIdentity, sids); + this.aclService.readAclById(objectIdentity, sids); } @Test public void findOneChildren() { List result = new ArrayList<>(); result.add(new ObjectIdentityImpl(Object.class, "5577")); - Object[] args = {"1", "org.springframework.security.acls.jdbc.JdbcAclServiceTests$MockLongIdDomainObject"}; - when( - jdbcOperations.query(anyString(), - aryEq(args), any(RowMapper.class))).thenReturn(result); + Object[] args = { "1", "org.springframework.security.acls.jdbc.JdbcAclServiceTests$MockLongIdDomainObject" }; + given(this.jdbcOperations.query(anyString(), eq(args), any(RowMapper.class))).willReturn(result); ObjectIdentity objectIdentity = new ObjectIdentityImpl(MockLongIdDomainObject.class, 1L); - - List objectIdentities = aclService.findChildren(objectIdentity); + List objectIdentities = this.aclService.findChildren(objectIdentity); assertThat(objectIdentities.size()).isEqualTo(1); assertThat(objectIdentities.get(0).getIdentifier()).isEqualTo("5577"); } @@ -113,19 +119,14 @@ public class JdbcAclServiceTests { @Test public void findNoChildren() { ObjectIdentity objectIdentity = new ObjectIdentityImpl(MockLongIdDomainObject.class, 1L); - - List objectIdentities = aclService.findChildren(objectIdentity); + List objectIdentities = this.aclService.findChildren(objectIdentity); assertThat(objectIdentities).isNull(); } - // ~ Some integration tests - // ======================================================================================================== - @Test public void findChildrenWithoutIdType() { ObjectIdentity objectIdentity = new ObjectIdentityImpl(MockLongIdDomainObject.class, 4711L); - - List objectIdentities = aclServiceIntegration.findChildren(objectIdentity); + List objectIdentities = this.aclServiceIntegration.findChildren(objectIdentity); assertThat(objectIdentities.size()).isEqualTo(1); assertThat(objectIdentities.get(0).getType()).isEqualTo(MockUntypedIdDomainObject.class.getName()); assertThat(objectIdentities.get(0).getIdentifier()).isEqualTo(5000L); @@ -134,16 +135,14 @@ public class JdbcAclServiceTests { @Test public void findChildrenForUnknownObject() { ObjectIdentity objectIdentity = new ObjectIdentityImpl(Object.class, 33); - - List objectIdentities = aclServiceIntegration.findChildren(objectIdentity); + List objectIdentities = this.aclServiceIntegration.findChildren(objectIdentity); assertThat(objectIdentities).isNull(); } @Test public void findChildrenOfIdTypeLong() { ObjectIdentity objectIdentity = new ObjectIdentityImpl("location", "US-PAL"); - - List objectIdentities = aclServiceIntegration.findChildren(objectIdentity); + List objectIdentities = this.aclServiceIntegration.findChildren(objectIdentity); assertThat(objectIdentities.size()).isEqualTo(2); assertThat(objectIdentities.get(0).getType()).isEqualTo(MockLongIdDomainObject.class.getName()); assertThat(objectIdentities.get(0).getIdentifier()).isEqualTo(4711L); @@ -154,9 +153,8 @@ public class JdbcAclServiceTests { @Test public void findChildrenOfIdTypeString() { ObjectIdentity objectIdentity = new ObjectIdentityImpl("location", "US"); - - aclServiceIntegration.setAclClassIdSupported(true); - List objectIdentities = aclServiceIntegration.findChildren(objectIdentity); + this.aclServiceIntegration.setAclClassIdSupported(true); + List objectIdentities = this.aclServiceIntegration.findChildren(objectIdentity); assertThat(objectIdentities.size()).isEqualTo(1); assertThat(objectIdentities.get(0).getType()).isEqualTo("location"); assertThat(objectIdentities.get(0).getIdentifier()).isEqualTo("US-PAL"); @@ -165,35 +163,40 @@ public class JdbcAclServiceTests { @Test public void findChildrenOfIdTypeUUID() { ObjectIdentity objectIdentity = new ObjectIdentityImpl(MockUntypedIdDomainObject.class, 5000L); - - aclServiceIntegration.setAclClassIdSupported(true); - List objectIdentities = aclServiceIntegration.findChildren(objectIdentity); + this.aclServiceIntegration.setAclClassIdSupported(true); + List objectIdentities = this.aclServiceIntegration.findChildren(objectIdentity); assertThat(objectIdentities.size()).isEqualTo(1); assertThat(objectIdentities.get(0).getType()).isEqualTo("costcenter"); - assertThat(objectIdentities.get(0).getIdentifier()).isEqualTo(UUID.fromString("25d93b3f-c3aa-4814-9d5e-c7c96ced7762")); + assertThat(objectIdentities.get(0).getIdentifier()) + .isEqualTo(UUID.fromString("25d93b3f-c3aa-4814-9d5e-c7c96ced7762")); } - private class MockLongIdDomainObject { + class MockLongIdDomainObject { + private Object id; - public Object getId() { - return id; + Object getId() { + return this.id; } - public void setId(Object id) { + void setId(Object id) { this.id = id; } + } - private class MockUntypedIdDomainObject { + class MockUntypedIdDomainObject { + private Object id; - public Object getId() { - return id; + Object getId() { + return this.id; } - public void setId(Object id) { + void setId(Object id) { this.id = id; } + } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTests.java index 0e54ed8f58..fe44732540 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.jdbc; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.acls.jdbc; import java.util.Arrays; import java.util.List; @@ -25,6 +23,7 @@ import java.util.Map; import javax.sql.DataSource; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.io.ClassPathResource; import org.springframework.jdbc.core.JdbcTemplate; @@ -55,6 +54,11 @@ import org.springframework.test.context.transaction.AfterTransaction; import org.springframework.test.context.transaction.BeforeTransaction; import org.springframework.transaction.annotation.Transactional; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.spy; + /** * Integration tests the ACL system using an in-memory database. * @@ -62,56 +66,49 @@ import org.springframework.transaction.annotation.Transactional; * @author Andrei Stefan */ @ContextConfiguration(locations = { "/jdbcMutableAclServiceTests-context.xml" }) -public class JdbcMutableAclServiceTests extends - AbstractTransactionalJUnit4SpringContextTests { - // ~ Constant fields - // ================================================================================================ +public class JdbcMutableAclServiceTests extends AbstractTransactionalJUnit4SpringContextTests { private static final String TARGET_CLASS = TargetObject.class.getName(); - private final Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_ADMINISTRATOR"); + private final Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_ADMINISTRATOR"); public static final String SELECT_ALL_CLASSES = "SELECT * FROM acl_class WHERE class = ?"; - // ~ Instance fields - // ================================================================================================ + private final ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, 100L); - private final ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, - 100L); - private final ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS, - 101L); - private final ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, - 102L); + private final ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS, 101L); + + private final ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 102L); @Autowired private JdbcMutableAclService jdbcMutableAclService; + @Autowired private AclCache aclCache; + @Autowired private LookupStrategy lookupStrategy; + @Autowired private DataSource dataSource; + @Autowired private JdbcTemplate jdbcTemplate; - // ~ Methods - // ======================================================================================================== - protected String getSqlClassPathResource() { return "createAclSchema.sql"; } protected ObjectIdentity getTopParentOid() { - return topParentOid; + return this.topParentOid; } protected ObjectIdentity getMiddleParentOid() { - return middleParentOid; + return this.middleParentOid; } protected ObjectIdentity getChildOid() { - return childOid; + return this.childOid; } protected String getTargetClass() { @@ -121,162 +118,134 @@ public class JdbcMutableAclServiceTests extends @BeforeTransaction public void createTables() throws Exception { try { - new DatabaseSeeder(dataSource, new ClassPathResource(getSqlClassPathResource())); + new DatabaseSeeder(this.dataSource, new ClassPathResource(getSqlClassPathResource())); // new DatabaseSeeder(dataSource, new // ClassPathResource("createAclSchemaPostgres.sql")); } - catch (Exception e) { - e.printStackTrace(); - throw e; + catch (Exception ex) { + ex.printStackTrace(); + throw ex; } } @AfterTransaction public void clearContextAndData() { SecurityContextHolder.clearContext(); - jdbcTemplate.execute("drop table acl_entry"); - jdbcTemplate.execute("drop table acl_object_identity"); - jdbcTemplate.execute("drop table acl_class"); - jdbcTemplate.execute("drop table acl_sid"); - aclCache.clearCache(); + this.jdbcTemplate.execute("drop table acl_entry"); + this.jdbcTemplate.execute("drop table acl_object_identity"); + this.jdbcTemplate.execute("drop table acl_class"); + this.jdbcTemplate.execute("drop table acl_sid"); + this.aclCache.clearCache(); } @Test @Transactional public void testLifecycle() { - SecurityContextHolder.getContext().setAuthentication(auth); - - MutableAcl topParent = jdbcMutableAclService.createAcl(getTopParentOid()); - MutableAcl middleParent = jdbcMutableAclService.createAcl(getMiddleParentOid()); - MutableAcl child = jdbcMutableAclService.createAcl(getChildOid()); - + SecurityContextHolder.getContext().setAuthentication(this.auth); + MutableAcl topParent = this.jdbcMutableAclService.createAcl(getTopParentOid()); + MutableAcl middleParent = this.jdbcMutableAclService.createAcl(getMiddleParentOid()); + MutableAcl child = this.jdbcMutableAclService.createAcl(getChildOid()); // Specify the inheritance hierarchy middleParent.setParent(topParent); child.setParent(middleParent); - // Now let's add a couple of permissions - topParent.insertAce(0, BasePermission.READ, new PrincipalSid(auth), true); - topParent.insertAce(1, BasePermission.WRITE, new PrincipalSid(auth), false); - middleParent.insertAce(0, BasePermission.DELETE, new PrincipalSid(auth), true); - child.insertAce(0, BasePermission.DELETE, new PrincipalSid(auth), false); - + topParent.insertAce(0, BasePermission.READ, new PrincipalSid(this.auth), true); + topParent.insertAce(1, BasePermission.WRITE, new PrincipalSid(this.auth), false); + middleParent.insertAce(0, BasePermission.DELETE, new PrincipalSid(this.auth), true); + child.insertAce(0, BasePermission.DELETE, new PrincipalSid(this.auth), false); // Explicitly save the changed ACL - jdbcMutableAclService.updateAcl(topParent); - jdbcMutableAclService.updateAcl(middleParent); - jdbcMutableAclService.updateAcl(child); - + this.jdbcMutableAclService.updateAcl(topParent); + this.jdbcMutableAclService.updateAcl(middleParent); + this.jdbcMutableAclService.updateAcl(child); // Let's check if we can read them back correctly - Map map = jdbcMutableAclService.readAclsById(Arrays.asList( - getTopParentOid(), getMiddleParentOid(), getChildOid())); + Map map = this.jdbcMutableAclService + .readAclsById(Arrays.asList(getTopParentOid(), getMiddleParentOid(), getChildOid())); assertThat(map).hasSize(3); - // Replace our current objects with their retrieved versions topParent = (MutableAcl) map.get(getTopParentOid()); middleParent = (MutableAcl) map.get(getMiddleParentOid()); child = (MutableAcl) map.get(getChildOid()); - // Check the retrieved versions has IDs assertThat(topParent.getId()).isNotNull(); assertThat(middleParent.getId()).isNotNull(); assertThat(child.getId()).isNotNull(); - // Check their parents were correctly persisted assertThat(topParent.getParentAcl()).isNull(); assertThat(middleParent.getParentAcl().getObjectIdentity()).isEqualTo(getTopParentOid()); assertThat(child.getParentAcl().getObjectIdentity()).isEqualTo(getMiddleParentOid()); - // Check their ACEs were correctly persisted assertThat(topParent.getEntries()).hasSize(2); assertThat(middleParent.getEntries()).hasSize(1); assertThat(child.getEntries()).hasSize(1); - // Check the retrieved rights are correct List read = Arrays.asList(BasePermission.READ); List write = Arrays.asList(BasePermission.WRITE); List delete = Arrays.asList(BasePermission.DELETE); - List pSid = Arrays.asList((Sid) new PrincipalSid(auth)); - + List pSid = Arrays.asList((Sid) new PrincipalSid(this.auth)); assertThat(topParent.isGranted(read, pSid, false)).isTrue(); assertThat(topParent.isGranted(write, pSid, false)).isFalse(); assertThat(middleParent.isGranted(delete, pSid, false)).isTrue(); assertThat(child.isGranted(delete, pSid, false)).isFalse(); - try { child.isGranted(Arrays.asList(BasePermission.ADMINISTRATION), pSid, false); fail("Should have thrown NotFoundException"); } catch (NotFoundException expected) { - } - // Now check the inherited rights (when not explicitly overridden) also look OK assertThat(child.isGranted(read, pSid, false)).isTrue(); assertThat(child.isGranted(write, pSid, false)).isFalse(); assertThat(child.isGranted(delete, pSid, false)).isFalse(); - // Next change the child so it doesn't inherit permissions from above child.setEntriesInheriting(false); - jdbcMutableAclService.updateAcl(child); - child = (MutableAcl) jdbcMutableAclService.readAclById(getChildOid()); + this.jdbcMutableAclService.updateAcl(child); + child = (MutableAcl) this.jdbcMutableAclService.readAclById(getChildOid()); assertThat(child.isEntriesInheriting()).isFalse(); - // Check the child permissions no longer inherit assertThat(child.isGranted(delete, pSid, true)).isFalse(); - try { child.isGranted(read, pSid, true); fail("Should have thrown NotFoundException"); } catch (NotFoundException expected) { - } - try { child.isGranted(write, pSid, true); fail("Should have thrown NotFoundException"); } catch (NotFoundException expected) { - } - // Let's add an identical permission to the child, but it'll appear AFTER the // current permission, so has no impact - child.insertAce(1, BasePermission.DELETE, new PrincipalSid(auth), true); - + child.insertAce(1, BasePermission.DELETE, new PrincipalSid(this.auth), true); // Let's also add another permission to the child - child.insertAce(2, BasePermission.CREATE, new PrincipalSid(auth), true); - + child.insertAce(2, BasePermission.CREATE, new PrincipalSid(this.auth), true); // Save the changed child - jdbcMutableAclService.updateAcl(child); - child = (MutableAcl) jdbcMutableAclService.readAclById(getChildOid()); + this.jdbcMutableAclService.updateAcl(child); + child = (MutableAcl) this.jdbcMutableAclService.readAclById(getChildOid()); assertThat(child.getEntries()).hasSize(3); - // Output permissions for (int i = 0; i < child.getEntries().size(); i++) { System.out.println(child.getEntries().get(i)); } - // Check the permissions are as they should be - assertThat(child.isGranted(delete, pSid, true)).isFalse(); // as earlier permission - // overrode + assertThat(child.isGranted(delete, pSid, true)).isFalse(); // as earlier + // permission + // overrode assertThat(child.isGranted(Arrays.asList(BasePermission.CREATE), pSid, true)).isTrue(); - // Now check the first ACE (index 0) really is DELETE for our Sid and is // non-granting AccessControlEntry entry = child.getEntries().get(0); assertThat(entry.getPermission().getMask()).isEqualTo(BasePermission.DELETE.getMask()); - assertThat(entry.getSid()).isEqualTo(new PrincipalSid(auth)); + assertThat(entry.getSid()).isEqualTo(new PrincipalSid(this.auth)); assertThat(entry.isGranting()).isFalse(); assertThat(entry.getId()).isNotNull(); - // Now delete that first ACE child.deleteAce(0); - // Save and check it worked - child = jdbcMutableAclService.updateAcl(child); + child = this.jdbcMutableAclService.updateAcl(child); assertThat(child.getEntries()).hasSize(2); assertThat(child.isGranted(delete, pSid, false)).isTrue(); - SecurityContextHolder.clearContext(); } @@ -286,38 +255,31 @@ public class JdbcMutableAclServiceTests extends @Test @Transactional public void deleteAclAlsoDeletesChildren() { - SecurityContextHolder.getContext().setAuthentication(auth); - - jdbcMutableAclService.createAcl(getTopParentOid()); - MutableAcl middleParent = jdbcMutableAclService.createAcl(getMiddleParentOid()); - MutableAcl child = jdbcMutableAclService.createAcl(getChildOid()); + SecurityContextHolder.getContext().setAuthentication(this.auth); + this.jdbcMutableAclService.createAcl(getTopParentOid()); + MutableAcl middleParent = this.jdbcMutableAclService.createAcl(getMiddleParentOid()); + MutableAcl child = this.jdbcMutableAclService.createAcl(getChildOid()); child.setParent(middleParent); - jdbcMutableAclService.updateAcl(middleParent); - jdbcMutableAclService.updateAcl(child); + this.jdbcMutableAclService.updateAcl(middleParent); + this.jdbcMutableAclService.updateAcl(child); // Check the childOid really is a child of middleParentOid - Acl childAcl = jdbcMutableAclService.readAclById(getChildOid()); - + Acl childAcl = this.jdbcMutableAclService.readAclById(getChildOid()); assertThat(childAcl.getParentAcl().getObjectIdentity()).isEqualTo(getMiddleParentOid()); - // Delete the mid-parent and test if the child was deleted, as well - jdbcMutableAclService.deleteAcl(getMiddleParentOid(), true); - + this.jdbcMutableAclService.deleteAcl(getMiddleParentOid(), true); try { - jdbcMutableAclService.readAclById(getMiddleParentOid()); + this.jdbcMutableAclService.readAclById(getMiddleParentOid()); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { - } try { - jdbcMutableAclService.readAclById(getChildOid()); + this.jdbcMutableAclService.readAclById(getChildOid()); fail("It should have thrown NotFoundException"); } catch (NotFoundException expected) { - } - - Acl acl = jdbcMutableAclService.readAclById(getTopParentOid()); + Acl acl = this.jdbcMutableAclService.readAclById(getTopParentOid()); assertThat(acl).isNotNull(); assertThat(getTopParentOid()).isEqualTo(acl.getObjectIdentity()); } @@ -325,21 +287,19 @@ public class JdbcMutableAclServiceTests extends @Test public void constructorRejectsNullParameters() { try { - new JdbcMutableAclService(null, lookupStrategy, aclCache); + new JdbcMutableAclService(null, this.lookupStrategy, this.aclCache); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new JdbcMutableAclService(dataSource, null, aclCache); + new JdbcMutableAclService(this.dataSource, null, this.aclCache); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new JdbcMutableAclService(dataSource, lookupStrategy, null); + new JdbcMutableAclService(this.dataSource, this.lookupStrategy, null); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -349,7 +309,7 @@ public class JdbcMutableAclServiceTests extends @Test public void createAclRejectsNullParameter() { try { - jdbcMutableAclService.createAcl(null); + this.jdbcMutableAclService.createAcl(null); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -359,13 +319,12 @@ public class JdbcMutableAclServiceTests extends @Test @Transactional public void createAclForADuplicateDomainObject() { - SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity duplicateOid = new ObjectIdentityImpl(TARGET_CLASS, - 100L); - jdbcMutableAclService.createAcl(duplicateOid); + SecurityContextHolder.getContext().setAuthentication(this.auth); + ObjectIdentity duplicateOid = new ObjectIdentityImpl(TARGET_CLASS, 100L); + this.jdbcMutableAclService.createAcl(duplicateOid); // Try to add the same object second time try { - jdbcMutableAclService.createAcl(duplicateOid); + this.jdbcMutableAclService.createAcl(duplicateOid); fail("It should have thrown AlreadyExistsException"); } catch (AlreadyExistsException expected) { @@ -376,7 +335,7 @@ public class JdbcMutableAclServiceTests extends @Transactional public void deleteAclRejectsNullParameters() { try { - jdbcMutableAclService.deleteAcl(null, true); + this.jdbcMutableAclService.deleteAcl(null, true); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -386,61 +345,52 @@ public class JdbcMutableAclServiceTests extends @Test @Transactional public void deleteAclWithChildrenThrowsException() { - SecurityContextHolder.getContext().setAuthentication(auth); - MutableAcl parent = jdbcMutableAclService.createAcl(getTopParentOid()); - MutableAcl child = jdbcMutableAclService.createAcl(getMiddleParentOid()); - + SecurityContextHolder.getContext().setAuthentication(this.auth); + MutableAcl parent = this.jdbcMutableAclService.createAcl(getTopParentOid()); + MutableAcl child = this.jdbcMutableAclService.createAcl(getMiddleParentOid()); // Specify the inheritance hierarchy child.setParent(parent); - jdbcMutableAclService.updateAcl(child); - + this.jdbcMutableAclService.updateAcl(child); try { - jdbcMutableAclService.setForeignKeysInDatabase(false); // switch on FK - // checking in the - // class, not database - jdbcMutableAclService.deleteAcl(getTopParentOid(), false); + this.jdbcMutableAclService.setForeignKeysInDatabase(false); // switch on FK + // checking in the + // class, not database + this.jdbcMutableAclService.deleteAcl(getTopParentOid(), false); fail("It should have thrown ChildrenExistException"); } catch (ChildrenExistException expected) { } finally { - jdbcMutableAclService.setForeignKeysInDatabase(true); // restore to the - // default + this.jdbcMutableAclService.setForeignKeysInDatabase(true); // restore to the + // default } } @Test @Transactional public void deleteAclRemovesRowsFromDatabase() { - SecurityContextHolder.getContext().setAuthentication(auth); - MutableAcl child = jdbcMutableAclService.createAcl(getChildOid()); - child.insertAce(0, BasePermission.DELETE, new PrincipalSid(auth), false); - jdbcMutableAclService.updateAcl(child); - + SecurityContextHolder.getContext().setAuthentication(this.auth); + MutableAcl child = this.jdbcMutableAclService.createAcl(getChildOid()); + child.insertAce(0, BasePermission.DELETE, new PrincipalSid(this.auth), false); + this.jdbcMutableAclService.updateAcl(child); // Remove the child and check all related database rows were removed accordingly - jdbcMutableAclService.deleteAcl(getChildOid(), false); - assertThat( - jdbcTemplate.queryForList(SELECT_ALL_CLASSES, - new Object[] { getTargetClass() })).hasSize(1); - assertThat(jdbcTemplate.queryForList("select * from acl_object_identity") - ).isEmpty(); - assertThat(jdbcTemplate.queryForList("select * from acl_entry")).isEmpty(); - + this.jdbcMutableAclService.deleteAcl(getChildOid(), false); + assertThat(this.jdbcTemplate.queryForList(SELECT_ALL_CLASSES, new Object[] { getTargetClass() })).hasSize(1); + assertThat(this.jdbcTemplate.queryForList("select * from acl_object_identity")).isEmpty(); + assertThat(this.jdbcTemplate.queryForList("select * from acl_entry")).isEmpty(); // Check the cache - assertThat(aclCache.getFromCache(getChildOid())).isNull(); - assertThat(aclCache.getFromCache(102L)).isNull(); + assertThat(this.aclCache.getFromCache(getChildOid())).isNull(); + assertThat(this.aclCache.getFromCache(102L)).isNull(); } /** SEC-1107 */ @Test @Transactional public void identityWithIntegerIdIsSupportedByCreateAcl() { - SecurityContextHolder.getContext().setAuthentication(auth); + SecurityContextHolder.getContext().setAuthentication(this.auth); ObjectIdentity oid = new ObjectIdentityImpl(TARGET_CLASS, 101); - jdbcMutableAclService.createAcl(oid); - - assertThat(jdbcMutableAclService.readAclById(new ObjectIdentityImpl( - TARGET_CLASS, 101L))).isNotNull(); + this.jdbcMutableAclService.createAcl(oid); + assertThat(this.jdbcMutableAclService.readAclById(new ObjectIdentityImpl(TARGET_CLASS, 101L))).isNotNull(); } /** @@ -449,32 +399,25 @@ public class JdbcMutableAclServiceTests extends @Test @Transactional public void childrenAreClearedFromCacheWhenParentIsUpdated() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_ADMINISTRATOR"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_ADMINISTRATOR"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity parentOid = new ObjectIdentityImpl(TARGET_CLASS, 104L); ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS, 105L); - - MutableAcl parent = jdbcMutableAclService.createAcl(parentOid); - MutableAcl child = jdbcMutableAclService.createAcl(childOid); - + MutableAcl parent = this.jdbcMutableAclService.createAcl(parentOid); + MutableAcl child = this.jdbcMutableAclService.createAcl(childOid); child.setParent(parent); - jdbcMutableAclService.updateAcl(child); - - parent = (AclImpl) jdbcMutableAclService.readAclById(parentOid); + this.jdbcMutableAclService.updateAcl(child); + parent = (AclImpl) this.jdbcMutableAclService.readAclById(parentOid); parent.insertAce(0, BasePermission.READ, new PrincipalSid("ben"), true); - jdbcMutableAclService.updateAcl(parent); - - parent = (AclImpl) jdbcMutableAclService.readAclById(parentOid); + this.jdbcMutableAclService.updateAcl(parent); + parent = (AclImpl) this.jdbcMutableAclService.readAclById(parentOid); parent.insertAce(1, BasePermission.READ, new PrincipalSid("scott"), true); - jdbcMutableAclService.updateAcl(parent); - - child = (MutableAcl) jdbcMutableAclService.readAclById(childOid); + this.jdbcMutableAclService.updateAcl(parent); + child = (MutableAcl) this.jdbcMutableAclService.readAclById(childOid); parent = (MutableAcl) child.getParentAcl(); - - assertThat(parent.getEntries()).hasSize(2).withFailMessage("Fails because child has a stale reference to its parent"); + assertThat(parent.getEntries()).hasSize(2) + .withFailMessage("Fails because child has a stale reference to its parent"); assertThat(parent.getEntries().get(0).getPermission().getMask()).isEqualTo(1); assertThat(parent.getEntries().get(0).getSid()).isEqualTo(new PrincipalSid("ben")); assertThat(parent.getEntries().get(1).getPermission().getMask()).isEqualTo(1); @@ -487,34 +430,22 @@ public class JdbcMutableAclServiceTests extends @Test @Transactional public void childrenAreClearedFromCacheWhenParentisUpdated2() { - Authentication auth = new TestingAuthenticationToken("system", "secret", - "ROLE_IGNORED"); + Authentication auth = new TestingAuthenticationToken("system", "secret", "ROLE_IGNORED"); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentityImpl rootObject = new ObjectIdentityImpl(TARGET_CLASS, - 1L); - - MutableAcl parent = jdbcMutableAclService.createAcl(rootObject); - MutableAcl child = jdbcMutableAclService.createAcl(new ObjectIdentityImpl( - TARGET_CLASS, 2L)); + ObjectIdentityImpl rootObject = new ObjectIdentityImpl(TARGET_CLASS, 1L); + MutableAcl parent = this.jdbcMutableAclService.createAcl(rootObject); + MutableAcl child = this.jdbcMutableAclService.createAcl(new ObjectIdentityImpl(TARGET_CLASS, 2L)); child.setParent(parent); - jdbcMutableAclService.updateAcl(child); - - parent.insertAce(0, BasePermission.ADMINISTRATION, new GrantedAuthoritySid( - "ROLE_ADMINISTRATOR"), true); - jdbcMutableAclService.updateAcl(parent); - + this.jdbcMutableAclService.updateAcl(child); + parent.insertAce(0, BasePermission.ADMINISTRATION, new GrantedAuthoritySid("ROLE_ADMINISTRATOR"), true); + this.jdbcMutableAclService.updateAcl(parent); parent.insertAce(1, BasePermission.DELETE, new PrincipalSid("terry"), true); - jdbcMutableAclService.updateAcl(parent); - - child = (MutableAcl) jdbcMutableAclService.readAclById(new ObjectIdentityImpl( - TARGET_CLASS, 2L)); - + this.jdbcMutableAclService.updateAcl(parent); + child = (MutableAcl) this.jdbcMutableAclService.readAclById(new ObjectIdentityImpl(TARGET_CLASS, 2L)); parent = (MutableAcl) child.getParentAcl(); - assertThat(parent.getEntries()).hasSize(2); assertThat(parent.getEntries().get(0).getPermission().getMask()).isEqualTo(16); - assertThat(parent.getEntries() - .get(0).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_ADMINISTRATOR")); + assertThat(parent.getEntries().get(0).getSid()).isEqualTo(new GrantedAuthoritySid("ROLE_ADMINISTRATOR")); assertThat(parent.getEntries().get(1).getPermission().getMask()).isEqualTo(8); assertThat(parent.getEntries().get(1).getSid()).isEqualTo(new PrincipalSid("terry")); } @@ -522,56 +453,50 @@ public class JdbcMutableAclServiceTests extends @Test @Transactional public void cumulativePermissions() { - Authentication auth = new TestingAuthenticationToken("ben", "ignored", - "ROLE_ADMINISTRATOR"); + Authentication auth = new TestingAuthenticationToken("ben", "ignored", "ROLE_ADMINISTRATOR"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - - ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, - 110L); - MutableAcl topParent = jdbcMutableAclService.createAcl(topParentOid); - + ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS, 110L); + MutableAcl topParent = this.jdbcMutableAclService.createAcl(topParentOid); // Add an ACE permission entry - Permission cm = new CumulativePermission().set(BasePermission.READ).set( - BasePermission.ADMINISTRATION); + Permission cm = new CumulativePermission().set(BasePermission.READ).set(BasePermission.ADMINISTRATION); assertThat(cm.getMask()).isEqualTo(17); Sid benSid = new PrincipalSid(auth); topParent.insertAce(0, cm, benSid, true); assertThat(topParent.getEntries()).hasSize(1); - // Explicitly save the changed ACL - topParent = jdbcMutableAclService.updateAcl(topParent); - + topParent = this.jdbcMutableAclService.updateAcl(topParent); // Check the mask was retrieved correctly assertThat(topParent.getEntries().get(0).getPermission().getMask()).isEqualTo(17); assertThat(topParent.isGranted(Arrays.asList(cm), Arrays.asList(benSid), true)).isTrue(); - SecurityContextHolder.clearContext(); } @Test public void testProcessingCustomSid() { - CustomJdbcMutableAclService customJdbcMutableAclService = spy(new CustomJdbcMutableAclService( - dataSource, lookupStrategy, aclCache)); + CustomJdbcMutableAclService customJdbcMutableAclService = spy( + new CustomJdbcMutableAclService(this.dataSource, this.lookupStrategy, this.aclCache)); CustomSid customSid = new CustomSid("Custom sid"); - when( - customJdbcMutableAclService.createOrRetrieveSidPrimaryKey("Custom sid", - false, false)).thenReturn(1L); - - Long result = customJdbcMutableAclService.createOrRetrieveSidPrimaryKey( - customSid, false); - + given(customJdbcMutableAclService.createOrRetrieveSidPrimaryKey("Custom sid", false, false)).willReturn(1L); + Long result = customJdbcMutableAclService.createOrRetrieveSidPrimaryKey(customSid, false); assertThat(new Long(1L)).isEqualTo(result); } + protected Authentication getAuth() { + return this.auth; + } + + protected JdbcMutableAclService getJdbcMutableAclService() { + return this.jdbcMutableAclService; + } + /** * This class needed to show how to extend {@link JdbcMutableAclService} for * processing custom {@link Sid} implementations */ private class CustomJdbcMutableAclService extends JdbcMutableAclService { - private CustomJdbcMutableAclService(DataSource dataSource, - LookupStrategy lookupStrategy, AclCache aclCache) { + CustomJdbcMutableAclService(DataSource dataSource, LookupStrategy lookupStrategy, AclCache aclCache) { super(dataSource, lookupStrategy, aclCache); } @@ -591,13 +516,7 @@ public class JdbcMutableAclServiceTests extends } return createOrRetrieveSidPrimaryKey(sidName, isPrincipal, allowCreate); } + } - protected Authentication getAuth() { - return auth; - } - - protected JdbcMutableAclService getJdbcMutableAclService() { - return jdbcMutableAclService; - } } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTestsWithAclClassId.java b/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTestsWithAclClassId.java index 1d45b033aa..ab69977a56 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTestsWithAclClassId.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/JdbcMutableAclServiceTestsWithAclClassId.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.jdbc; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.acls.jdbc; import java.util.UUID; import org.junit.Test; + import org.springframework.security.acls.TargetObjectWithUUID; import org.springframework.security.acls.domain.ObjectIdentityImpl; import org.springframework.security.acls.model.ObjectIdentity; @@ -27,21 +27,24 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.test.context.ContextConfiguration; import org.springframework.transaction.annotation.Transactional; +import static org.assertj.core.api.Assertions.assertThat; + /** - * Integration tests the ACL system using ACL class id type of UUID and using an in-memory database. + * Integration tests the ACL system using ACL class id type of UUID and using an in-memory + * database. + * * @author Paul Wheeler */ -@ContextConfiguration(locations = {"/jdbcMutableAclServiceTestsWithAclClass-context.xml"}) +@ContextConfiguration(locations = { "/jdbcMutableAclServiceTestsWithAclClass-context.xml" }) public class JdbcMutableAclServiceTestsWithAclClassId extends JdbcMutableAclServiceTests { private static final String TARGET_CLASS_WITH_UUID = TargetObjectWithUUID.class.getName(); - private final ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, - UUID.randomUUID()); - private final ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, - UUID.randomUUID()); - private final ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, - UUID.randomUUID()); + private final ObjectIdentity topParentOid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, UUID.randomUUID()); + + private final ObjectIdentity middleParentOid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, UUID.randomUUID()); + + private final ObjectIdentity childOid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, UUID.randomUUID()); @Override protected String getSqlClassPathResource() { @@ -50,17 +53,17 @@ public class JdbcMutableAclServiceTestsWithAclClassId extends JdbcMutableAclServ @Override protected ObjectIdentity getTopParentOid() { - return topParentOid; + return this.topParentOid; } @Override protected ObjectIdentity getMiddleParentOid() { - return middleParentOid; + return this.middleParentOid; } @Override protected ObjectIdentity getChildOid() { - return childOid; + return this.childOid; } @Override @@ -72,12 +75,11 @@ public class JdbcMutableAclServiceTestsWithAclClassId extends JdbcMutableAclServ @Transactional public void identityWithUuidIdIsSupportedByCreateAcl() { SecurityContextHolder.getContext().setAuthentication(getAuth()); - UUID id = UUID.randomUUID(); ObjectIdentity oid = new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, id); getJdbcMutableAclService().createAcl(oid); - - assertThat(getJdbcMutableAclService().readAclById(new ObjectIdentityImpl( - TARGET_CLASS_WITH_UUID, id))).isNotNull(); + assertThat(getJdbcMutableAclService().readAclById(new ObjectIdentityImpl(TARGET_CLASS_WITH_UUID, id))) + .isNotNull(); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/jdbc/SpringCacheBasedAclCacheTests.java b/acl/src/test/java/org/springframework/security/acls/jdbc/SpringCacheBasedAclCacheTests.java index 4d9de2f556..9a3bd62400 100644 --- a/acl/src/test/java/org/springframework/security/acls/jdbc/SpringCacheBasedAclCacheTests.java +++ b/acl/src/test/java/org/springframework/security/acls/jdbc/SpringCacheBasedAclCacheTests.java @@ -13,15 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.jdbc; +import java.util.Map; + import org.junit.After; import org.junit.BeforeClass; import org.junit.Test; + import org.springframework.cache.Cache; import org.springframework.cache.CacheManager; import org.springframework.cache.concurrent.ConcurrentMapCacheManager; -import org.springframework.security.acls.domain.*; +import org.springframework.security.acls.domain.AclAuthorizationStrategy; +import org.springframework.security.acls.domain.AclAuthorizationStrategyImpl; +import org.springframework.security.acls.domain.AclImpl; +import org.springframework.security.acls.domain.AuditLogger; +import org.springframework.security.acls.domain.ConsoleAuditLogger; +import org.springframework.security.acls.domain.DefaultPermissionGrantingStrategy; +import org.springframework.security.acls.domain.ObjectIdentityImpl; +import org.springframework.security.acls.domain.SpringCacheBasedAclCache; import org.springframework.security.acls.model.MutableAcl; import org.springframework.security.acls.model.ObjectIdentity; import org.springframework.security.acls.model.PermissionGrantingStrategy; @@ -31,9 +42,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.util.FieldUtils; -import java.util.Map; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * Tests {@link org.springframework.security.acls.domain.SpringCacheBasedAclCache} @@ -41,6 +50,7 @@ import static org.assertj.core.api.Assertions.*; * @author Marten Deinum */ public class SpringCacheBasedAclCacheTests { + private static final String TARGET_CLASS = "org.springframework.security.acls.TargetObject"; private static CacheManager cacheManager; @@ -76,43 +86,31 @@ public class SpringCacheBasedAclCacheTests { Map realCache = (Map) cache.getNativeCache(); ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 100L); AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); AuditLogger auditLogger = new ConsoleAuditLogger(); - - PermissionGrantingStrategy permissionGrantingStrategy = new DefaultPermissionGrantingStrategy( - auditLogger); - SpringCacheBasedAclCache myCache = new SpringCacheBasedAclCache(cache, - permissionGrantingStrategy, aclAuthorizationStrategy); - MutableAcl acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, - auditLogger); - + PermissionGrantingStrategy permissionGrantingStrategy = new DefaultPermissionGrantingStrategy(auditLogger); + SpringCacheBasedAclCache myCache = new SpringCacheBasedAclCache(cache, permissionGrantingStrategy, + aclAuthorizationStrategy); + MutableAcl acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, auditLogger); assertThat(realCache).isEmpty(); myCache.putInCache(acl); - // Check we can get from cache the same objects we put in assertThat(acl).isEqualTo(myCache.getFromCache(1L)); assertThat(acl).isEqualTo(myCache.getFromCache(identity)); - // Put another object in cache ObjectIdentity identity2 = new ObjectIdentityImpl(TARGET_CLASS, 101L); - MutableAcl acl2 = new AclImpl(identity2, 2L, - aclAuthorizationStrategy, new ConsoleAuditLogger()); - + MutableAcl acl2 = new AclImpl(identity2, 2L, aclAuthorizationStrategy, new ConsoleAuditLogger()); myCache.putInCache(acl2); - // Try to evict an entry that doesn't exist myCache.evictFromCache(3L); myCache.evictFromCache(new ObjectIdentityImpl(TARGET_CLASS, 102L)); assertThat(realCache).hasSize(4); - myCache.evictFromCache(1L); assertThat(realCache).hasSize(2); - // Check the second object inserted assertThat(acl2).isEqualTo(myCache.getFromCache(2L)); assertThat(acl2).isEqualTo(myCache.getFromCache(identity2)); - myCache.evictFromCache(identity2); assertThat(realCache).isEmpty(); } @@ -122,50 +120,36 @@ public class SpringCacheBasedAclCacheTests { public void cacheOperationsAclWithParent() throws Exception { Cache cache = getCache(); Map realCache = (Map) cache.getNativeCache(); - - Authentication auth = new TestingAuthenticationToken("user", "password", - "ROLE_GENERAL"); + Authentication auth = new TestingAuthenticationToken("user", "password", "ROLE_GENERAL"); auth.setAuthenticated(true); SecurityContextHolder.getContext().setAuthentication(auth); - ObjectIdentity identity = new ObjectIdentityImpl(TARGET_CLASS, 1L); - ObjectIdentity identityParent = new ObjectIdentityImpl(TARGET_CLASS, - 2L); + ObjectIdentity identityParent = new ObjectIdentityImpl(TARGET_CLASS, 2L); AclAuthorizationStrategy aclAuthorizationStrategy = new AclAuthorizationStrategyImpl( - new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority( - "ROLE_AUDITING"), new SimpleGrantedAuthority("ROLE_GENERAL")); + new SimpleGrantedAuthority("ROLE_OWNERSHIP"), new SimpleGrantedAuthority("ROLE_AUDITING"), + new SimpleGrantedAuthority("ROLE_GENERAL")); AuditLogger auditLogger = new ConsoleAuditLogger(); - - PermissionGrantingStrategy permissionGrantingStrategy = new DefaultPermissionGrantingStrategy( - auditLogger); - SpringCacheBasedAclCache myCache = new SpringCacheBasedAclCache(cache, - permissionGrantingStrategy, aclAuthorizationStrategy); - - MutableAcl acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, - auditLogger); - MutableAcl parentAcl = new AclImpl(identityParent, 2L, - aclAuthorizationStrategy, auditLogger); - + PermissionGrantingStrategy permissionGrantingStrategy = new DefaultPermissionGrantingStrategy(auditLogger); + SpringCacheBasedAclCache myCache = new SpringCacheBasedAclCache(cache, permissionGrantingStrategy, + aclAuthorizationStrategy); + MutableAcl acl = new AclImpl(identity, 1L, aclAuthorizationStrategy, auditLogger); + MutableAcl parentAcl = new AclImpl(identityParent, 2L, aclAuthorizationStrategy, auditLogger); acl.setParent(parentAcl); - assertThat(realCache).isEmpty(); myCache.putInCache(acl); assertThat(4).isEqualTo(realCache.size()); - // Check we can get from cache the same objects we put in AclImpl aclFromCache = (AclImpl) myCache.getFromCache(1L); assertThat(aclFromCache).isEqualTo(acl); // SEC-951 check transient fields are set on parent - assertThat(FieldUtils.getFieldValue(aclFromCache.getParentAcl(), - "aclAuthorizationStrategy")).isNotNull(); - assertThat(FieldUtils.getFieldValue(aclFromCache.getParentAcl(), - "permissionGrantingStrategy")).isNotNull(); + assertThat(FieldUtils.getFieldValue(aclFromCache.getParentAcl(), "aclAuthorizationStrategy")).isNotNull(); + assertThat(FieldUtils.getFieldValue(aclFromCache.getParentAcl(), "permissionGrantingStrategy")).isNotNull(); assertThat(myCache.getFromCache(identity)).isEqualTo(acl); assertThat(FieldUtils.getFieldValue(aclFromCache, "aclAuthorizationStrategy")).isNotNull(); AclImpl parentAclFromCache = (AclImpl) myCache.getFromCache(2L); assertThat(parentAclFromCache).isEqualTo(parentAcl); - assertThat(FieldUtils.getFieldValue(parentAclFromCache, - "aclAuthorizationStrategy")).isNotNull(); + assertThat(FieldUtils.getFieldValue(parentAclFromCache, "aclAuthorizationStrategy")).isNotNull(); assertThat(myCache.getFromCache(identityParent)).isEqualTo(parentAcl); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/sid/CustomSid.java b/acl/src/test/java/org/springframework/security/acls/sid/CustomSid.java index 21226f1888..92d0ceef48 100644 --- a/acl/src/test/java/org/springframework/security/acls/sid/CustomSid.java +++ b/acl/src/test/java/org/springframework/security/acls/sid/CustomSid.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.sid; import org.springframework.security.acls.model.Sid; /** * This class is example of custom {@link Sid} implementation + * * @author Mikhail Stryzhonok */ public class CustomSid implements Sid { @@ -30,10 +32,11 @@ public class CustomSid implements Sid { } public String getSid() { - return sid; + return this.sid; } public void setSid(String sid) { this.sid = sid; } + } diff --git a/acl/src/test/java/org/springframework/security/acls/sid/SidRetrievalStrategyTests.java b/acl/src/test/java/org/springframework/security/acls/sid/SidRetrievalStrategyTests.java index ba61b41d01..5922f8f670 100644 --- a/acl/src/test/java/org/springframework/security/acls/sid/SidRetrievalStrategyTests.java +++ b/acl/src/test/java/org/springframework/security/acls/sid/SidRetrievalStrategyTests.java @@ -13,14 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.acls.sid; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.acls.sid; import java.util.List; import org.junit.Test; + import org.springframework.security.access.hierarchicalroles.RoleHierarchy; import org.springframework.security.acls.domain.GrantedAuthoritySid; import org.springframework.security.acls.domain.PrincipalSid; @@ -31,6 +30,11 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests for {@link SidRetrievalStrategyImpl} * @@ -39,26 +43,20 @@ import org.springframework.security.core.authority.AuthorityUtils; */ @SuppressWarnings("unchecked") public class SidRetrievalStrategyTests { - Authentication authentication = new TestingAuthenticationToken("scott", "password", - "A", "B", "C"); - // ~ Methods - // ======================================================================================================== + Authentication authentication = new TestingAuthenticationToken("scott", "password", "A", "B", "C"); @Test public void correctSidsAreRetrieved() { SidRetrievalStrategy retrStrategy = new SidRetrievalStrategyImpl(); - List sids = retrStrategy.getSids(authentication); - + List sids = retrStrategy.getSids(this.authentication); assertThat(sids).isNotNull(); assertThat(sids).hasSize(4); assertThat(sids.get(0)).isNotNull(); assertThat(sids.get(0) instanceof PrincipalSid).isTrue(); - for (int i = 1; i < sids.size(); i++) { assertThat(sids.get(i) instanceof GrantedAuthoritySid).isTrue(); } - assertThat(((PrincipalSid) sids.get(0)).getPrincipal()).isEqualTo("scott"); assertThat(((GrantedAuthoritySid) sids.get(1)).getGrantedAuthority()).isEqualTo("A"); assertThat(((GrantedAuthoritySid) sids.get(2)).getGrantedAuthority()).isEqualTo("B"); @@ -69,14 +67,13 @@ public class SidRetrievalStrategyTests { public void roleHierarchyIsUsedWhenSet() { RoleHierarchy rh = mock(RoleHierarchy.class); List rhAuthorities = AuthorityUtils.createAuthorityList("D"); - when(rh.getReachableGrantedAuthorities(anyCollection())) - .thenReturn(rhAuthorities); + given(rh.getReachableGrantedAuthorities(anyCollection())).willReturn(rhAuthorities); SidRetrievalStrategy strat = new SidRetrievalStrategyImpl(rh); - - List sids = strat.getSids(authentication); + List sids = strat.getSids(this.authentication); assertThat(sids).hasSize(2); assertThat(sids.get(0)).isNotNull(); assertThat(sids.get(0) instanceof PrincipalSid).isTrue(); assertThat(((GrantedAuthoritySid) sids.get(1)).getGrantedAuthority()).isEqualTo("D"); } + } diff --git a/acl/src/test/java/org/springframework/security/acls/sid/SidTests.java b/acl/src/test/java/org/springframework/security/acls/sid/SidTests.java index b65c1cb906..3b566e8c92 100644 --- a/acl/src/test/java/org/springframework/security/acls/sid/SidTests.java +++ b/acl/src/test/java/org/springframework/security/acls/sid/SidTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.acls.sid; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +import java.util.Collection; +import java.util.Collections; import org.junit.Test; + import org.springframework.security.acls.domain.GrantedAuthoritySid; import org.springframework.security.acls.domain.PrincipalSid; import org.springframework.security.acls.model.Sid; @@ -29,13 +31,11 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.User; -import java.util.Collection; -import java.util.Collections; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; public class SidTests { - // ~ Methods - // ======================================================================================================== @Test public void testPrincipalSidConstructorsRequiredFields() { // Check one String-argument constructor @@ -46,17 +46,14 @@ public class SidTests { } catch (IllegalArgumentException expected) { } - try { new PrincipalSid(""); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - new PrincipalSid("johndoe"); // throws no exception - // Check one Authentication-argument constructor try { Authentication authentication = null; @@ -65,18 +62,14 @@ public class SidTests { } catch (IllegalArgumentException expected) { } - try { - Authentication authentication = new TestingAuthenticationToken(null, - "password"); + Authentication authentication = new TestingAuthenticationToken(null, "password"); new PrincipalSid(authentication); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - - Authentication authentication = new TestingAuthenticationToken("johndoe", - "password"); + Authentication authentication = new TestingAuthenticationToken("johndoe", "password"); new PrincipalSid(authentication); // throws no exception } @@ -90,25 +83,19 @@ public class SidTests { fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - try { new GrantedAuthoritySid(""); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - try { new GrantedAuthoritySid("ROLE_TEST"); - } catch (IllegalArgumentException notExpected) { fail("It shouldn't have thrown IllegalArgumentException"); } - // Check one GrantedAuthority-argument constructor try { GrantedAuthority ga = null; @@ -116,22 +103,17 @@ public class SidTests { fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - try { GrantedAuthority ga = new SimpleGrantedAuthority(null); new GrantedAuthoritySid(ga); fail("It should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - try { GrantedAuthority ga = new SimpleGrantedAuthority("ROLE_TEST"); new GrantedAuthoritySid(ga); - } catch (IllegalArgumentException notExpected) { fail("It shouldn't have thrown IllegalArgumentException"); @@ -140,18 +122,14 @@ public class SidTests { @Test public void testPrincipalSidEquals() { - Authentication authentication = new TestingAuthenticationToken("johndoe", - "password"); + Authentication authentication = new TestingAuthenticationToken("johndoe", "password"); Sid principalSid = new PrincipalSid(authentication); - assertThat(principalSid.equals(null)).isFalse(); assertThat(principalSid.equals("DIFFERENT_TYPE_OBJECT")).isFalse(); assertThat(principalSid.equals(principalSid)).isTrue(); assertThat(principalSid.equals(new PrincipalSid(authentication))).isTrue(); - assertThat(principalSid.equals(new PrincipalSid( - new TestingAuthenticationToken("johndoe", null)))).isTrue(); - assertThat(principalSid.equals(new PrincipalSid( - new TestingAuthenticationToken("scott", null)))).isFalse(); + assertThat(principalSid.equals(new PrincipalSid(new TestingAuthenticationToken("johndoe", null)))).isTrue(); + assertThat(principalSid.equals(new PrincipalSid(new TestingAuthenticationToken("scott", null)))).isFalse(); assertThat(principalSid.equals(new PrincipalSid("johndoe"))).isTrue(); assertThat(principalSid.equals(new PrincipalSid("scott"))).isFalse(); } @@ -160,59 +138,46 @@ public class SidTests { public void testGrantedAuthoritySidEquals() { GrantedAuthority ga = new SimpleGrantedAuthority("ROLE_TEST"); Sid gaSid = new GrantedAuthoritySid(ga); - assertThat(gaSid.equals(null)).isFalse(); assertThat(gaSid.equals("DIFFERENT_TYPE_OBJECT")).isFalse(); assertThat(gaSid.equals(gaSid)).isTrue(); assertThat(gaSid.equals(new GrantedAuthoritySid(ga))).isTrue(); - assertThat(gaSid.equals(new GrantedAuthoritySid( - new SimpleGrantedAuthority("ROLE_TEST")))).isTrue(); - assertThat(gaSid.equals(new GrantedAuthoritySid( - new SimpleGrantedAuthority("ROLE_NOT_EQUAL")))).isFalse(); + assertThat(gaSid.equals(new GrantedAuthoritySid(new SimpleGrantedAuthority("ROLE_TEST")))).isTrue(); + assertThat(gaSid.equals(new GrantedAuthoritySid(new SimpleGrantedAuthority("ROLE_NOT_EQUAL")))).isFalse(); assertThat(gaSid.equals(new GrantedAuthoritySid("ROLE_TEST"))).isTrue(); assertThat(gaSid.equals(new GrantedAuthoritySid("ROLE_NOT_EQUAL"))).isFalse(); } @Test public void testPrincipalSidHashCode() { - Authentication authentication = new TestingAuthenticationToken("johndoe", - "password"); + Authentication authentication = new TestingAuthenticationToken("johndoe", "password"); Sid principalSid = new PrincipalSid(authentication); - assertThat(principalSid.hashCode()).isEqualTo("johndoe".hashCode()); - assertThat(principalSid.hashCode()).isEqualTo( - new PrincipalSid("johndoe").hashCode()); - assertThat(principalSid.hashCode()).isNotEqualTo( - new PrincipalSid("scott").hashCode()); - assertThat(principalSid.hashCode()).isNotEqualTo(new PrincipalSid( - new TestingAuthenticationToken("scott", "password")).hashCode()); + assertThat(principalSid.hashCode()).isEqualTo(new PrincipalSid("johndoe").hashCode()); + assertThat(principalSid.hashCode()).isNotEqualTo(new PrincipalSid("scott").hashCode()); + assertThat(principalSid.hashCode()) + .isNotEqualTo(new PrincipalSid(new TestingAuthenticationToken("scott", "password")).hashCode()); } @Test public void testGrantedAuthoritySidHashCode() { GrantedAuthority ga = new SimpleGrantedAuthority("ROLE_TEST"); Sid gaSid = new GrantedAuthoritySid(ga); - assertThat(gaSid.hashCode()).isEqualTo("ROLE_TEST".hashCode()); - assertThat(gaSid.hashCode()).isEqualTo( - new GrantedAuthoritySid("ROLE_TEST").hashCode()); - assertThat(gaSid.hashCode()).isNotEqualTo( - new GrantedAuthoritySid("ROLE_TEST_2").hashCode()); - assertThat(gaSid.hashCode()).isNotEqualTo(new GrantedAuthoritySid( - new SimpleGrantedAuthority("ROLE_TEST_2")).hashCode()); + assertThat(gaSid.hashCode()).isEqualTo(new GrantedAuthoritySid("ROLE_TEST").hashCode()); + assertThat(gaSid.hashCode()).isNotEqualTo(new GrantedAuthoritySid("ROLE_TEST_2").hashCode()); + assertThat(gaSid.hashCode()) + .isNotEqualTo(new GrantedAuthoritySid(new SimpleGrantedAuthority("ROLE_TEST_2")).hashCode()); } @Test public void testGetters() { - Authentication authentication = new TestingAuthenticationToken("johndoe", - "password"); + Authentication authentication = new TestingAuthenticationToken("johndoe", "password"); PrincipalSid principalSid = new PrincipalSid(authentication); GrantedAuthority ga = new SimpleGrantedAuthority("ROLE_TEST"); GrantedAuthoritySid gaSid = new GrantedAuthoritySid(ga); - assertThat("johndoe".equals(principalSid.getPrincipal())).isTrue(); assertThat("scott".equals(principalSid.getPrincipal())).isFalse(); - assertThat("ROLE_TEST".equals(gaSid.getGrantedAuthority())).isTrue(); assertThat("ROLE_TEST2".equals(gaSid.getGrantedAuthority())).isFalse(); } @@ -222,7 +187,6 @@ public class SidTests { User user = new User("user", "password", Collections.singletonList(new SimpleGrantedAuthority("ROLE_TEST"))); Authentication authentication = new TestingAuthenticationToken(user, "password"); PrincipalSid principalSid = new PrincipalSid(authentication); - assertThat("user").isEqualTo(principalSid.getPrincipal()); } @@ -230,7 +194,6 @@ public class SidTests { public void getPrincipalWhenPrincipalNotInstanceOfUserDetailsThenReturnsPrincipalName() { Authentication authentication = new TestingAuthenticationToken("token", "password"); PrincipalSid principalSid = new PrincipalSid(authentication); - assertThat("token").isEqualTo(principalSid.getPrincipal()); } @@ -238,11 +201,11 @@ public class SidTests { public void getPrincipalWhenCustomAuthenticationPrincipalThenReturnsPrincipalName() { Authentication authentication = new CustomAuthenticationToken(new CustomToken("token"), null); PrincipalSid principalSid = new PrincipalSid(authentication); - assertThat("token").isEqualTo(principalSid.getPrincipal()); } static class CustomAuthenticationToken extends AbstractAuthenticationToken { + private CustomToken principal; CustomAuthenticationToken(CustomToken principal, Collection authorities) { @@ -262,11 +225,13 @@ public class SidTests { @Override public String getName() { - return principal.getName(); + return this.principal.getName(); } + } static class CustomToken { + private String name; CustomToken(String name) { @@ -274,7 +239,9 @@ public class SidTests { } String getName() { - return name; + return this.name; } + } + } diff --git a/aspects/src/test/java/org/springframework/security/access/intercept/aspectj/aspect/AnnotationSecurityAspectTests.java b/aspects/src/test/java/org/springframework/security/access/intercept/aspectj/aspect/AnnotationSecurityAspectTests.java index f0d7570722..4296c72cec 100644 --- a/aspects/src/test/java/org/springframework/security/access/intercept/aspectj/aspect/AnnotationSecurityAspectTests.java +++ b/aspects/src/test/java/org/springframework/security/access/intercept/aspectj/aspect/AnnotationSecurityAspectTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.intercept.aspectj.aspect; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.access.intercept.aspectj.aspect; import java.util.ArrayList; import java.util.Arrays; @@ -49,39 +48,44 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** - * * @author Luke Taylor * @since 3.0.3 */ public class AnnotationSecurityAspectTests { + private AffirmativeBased adm; - private @Mock AuthenticationManager authman; - private TestingAuthenticationToken anne = new TestingAuthenticationToken("anne", "", - "ROLE_A"); + + @Mock + private AuthenticationManager authman; + + private TestingAuthenticationToken anne = new TestingAuthenticationToken("anne", "", "ROLE_A"); + // private TestingAuthenticationToken bob = new TestingAuthenticationToken("bob", "", // "ROLE_B"); private AspectJMethodSecurityInterceptor interceptor; + private SecuredImpl secured = new SecuredImpl(); + private SecuredImplSubclass securedSub = new SecuredImplSubclass(); + private PrePostSecured prePostSecured = new PrePostSecured(); @Before public final void setUp() { MockitoAnnotations.initMocks(this); - interceptor = new AspectJMethodSecurityInterceptor(); - AccessDecisionVoter[] voters = new AccessDecisionVoter[] { - new RoleVoter(), - new PreInvocationAuthorizationAdviceVoter( - new ExpressionBasedPreInvocationAdvice()) }; - adm = new AffirmativeBased( - Arrays.> asList(voters)); - interceptor.setAccessDecisionManager(adm); - interceptor.setAuthenticationManager(authman); - interceptor - .setSecurityMetadataSource(new SecuredAnnotationSecurityMetadataSource()); + this.interceptor = new AspectJMethodSecurityInterceptor(); + AccessDecisionVoter[] voters = new AccessDecisionVoter[] { new RoleVoter(), + new PreInvocationAuthorizationAdviceVoter(new ExpressionBasedPreInvocationAdvice()) }; + this.adm = new AffirmativeBased(Arrays.>asList(voters)); + this.interceptor.setAccessDecisionManager(this.adm); + this.interceptor.setAuthenticationManager(this.authman); + this.interceptor.setSecurityMetadataSource(new SecuredAnnotationSecurityMetadataSource()); AnnotationSecurityAspect secAspect = AnnotationSecurityAspect.aspectOf(); - secAspect.setSecurityInterceptor(interceptor); + secAspect.setSecurityInterceptor(this.interceptor); } @After @@ -91,59 +95,57 @@ public class AnnotationSecurityAspectTests { @Test public void securedInterfaceMethodAllowsAllAccess() { - secured.securedMethod(); + this.secured.securedMethod(); } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void securedClassMethodDeniesUnauthenticatedAccess() { - secured.securedClassMethod(); + this.secured.securedClassMethod(); } @Test public void securedClassMethodAllowsAccessToRoleA() { - SecurityContextHolder.getContext().setAuthentication(anne); - secured.securedClassMethod(); + SecurityContextHolder.getContext().setAuthentication(this.anne); + this.secured.securedClassMethod(); } @Test(expected = AccessDeniedException.class) public void internalPrivateCallIsIntercepted() { - SecurityContextHolder.getContext().setAuthentication(anne); - + SecurityContextHolder.getContext().setAuthentication(this.anne); try { - secured.publicCallsPrivate(); + this.secured.publicCallsPrivate(); fail("Expected AccessDeniedException"); } catch (AccessDeniedException expected) { } - securedSub.publicCallsPrivate(); + this.securedSub.publicCallsPrivate(); } @Test(expected = AccessDeniedException.class) public void protectedMethodIsIntercepted() { - SecurityContextHolder.getContext().setAuthentication(anne); - - secured.protectedMethod(); + SecurityContextHolder.getContext().setAuthentication(this.anne); + this.secured.protectedMethod(); } @Test public void overriddenProtectedMethodIsNotIntercepted() { // AspectJ doesn't inherit annotations - securedSub.protectedMethod(); + this.securedSub.protectedMethod(); } // SEC-1262 @Test(expected = AccessDeniedException.class) public void denyAllPreAuthorizeDeniesAccess() { configureForElAnnotations(); - SecurityContextHolder.getContext().setAuthentication(anne); - prePostSecured.denyAllMethod(); + SecurityContextHolder.getContext().setAuthentication(this.anne); + this.prePostSecured.denyAllMethod(); } @Test public void postFilterIsApplied() { configureForElAnnotations(); - SecurityContextHolder.getContext().setAuthentication(anne); - List objects = prePostSecured.postFilterMethod(); + SecurityContextHolder.getContext().setAuthentication(this.anne); + List objects = this.prePostSecured.postFilterMethod(); assertThat(objects).hasSize(2); assertThat(objects.contains("apple")).isTrue(); assertThat(objects.contains("aubergine")).isTrue(); @@ -151,64 +153,73 @@ public class AnnotationSecurityAspectTests { private void configureForElAnnotations() { DefaultMethodSecurityExpressionHandler eh = new DefaultMethodSecurityExpressionHandler(); - interceptor - .setSecurityMetadataSource(new PrePostAnnotationSecurityMetadataSource( - new ExpressionBasedAnnotationAttributeFactory(eh))); - interceptor.setAccessDecisionManager(adm); + this.interceptor.setSecurityMetadataSource( + new PrePostAnnotationSecurityMetadataSource(new ExpressionBasedAnnotationAttributeFactory(eh))); + this.interceptor.setAccessDecisionManager(this.adm); AfterInvocationProviderManager aim = new AfterInvocationProviderManager(); - aim.setProviders(Arrays.asList(new PostInvocationAdviceProvider( - new ExpressionBasedPostInvocationAdvice(eh)))); - interceptor.setAfterInvocationManager(aim); - } -} - -interface SecuredInterface { - @Secured("ROLE_X") - void securedMethod(); -} - -class SecuredImpl implements SecuredInterface { - // Not really secured because AspectJ doesn't inherit annotations from interfaces - public void securedMethod() { - } - - @Secured("ROLE_A") - public void securedClassMethod() { - } - - @Secured("ROLE_X") - private void privateMethod() { - } - - @Secured("ROLE_X") - protected void protectedMethod() { - } - - @Secured("ROLE_X") - public void publicCallsPrivate() { - privateMethod(); - } -} - -class SecuredImplSubclass extends SecuredImpl { - protected void protectedMethod() { - } - - public void publicCallsPrivate() { - super.publicCallsPrivate(); - } -} - -class PrePostSecured { - @PreAuthorize("denyAll") - public void denyAllMethod() { - } - - @PostFilter("filterObject.startsWith('a')") - public List postFilterMethod() { - ArrayList objects = new ArrayList<>(); - objects.addAll(Arrays.asList(new String[] { "apple", "banana", "aubergine", - "orange" })); - return objects; + aim.setProviders(Arrays.asList(new PostInvocationAdviceProvider(new ExpressionBasedPostInvocationAdvice(eh)))); + this.interceptor.setAfterInvocationManager(aim); } + + interface SecuredInterface { + + @Secured("ROLE_X") + void securedMethod(); + + } + + static class SecuredImpl implements SecuredInterface { + + // Not really secured because AspectJ doesn't inherit annotations from interfaces + @Override + public void securedMethod() { + } + + @Secured("ROLE_A") + public void securedClassMethod() { + } + + @Secured("ROLE_X") + private void privateMethod() { + } + + @Secured("ROLE_X") + protected void protectedMethod() { + } + + @Secured("ROLE_X") + public void publicCallsPrivate() { + privateMethod(); + } + + } + + static class SecuredImplSubclass extends SecuredImpl { + + @Override + protected void protectedMethod() { + } + + @Override + public void publicCallsPrivate() { + super.publicCallsPrivate(); + } + + } + + static class PrePostSecured { + + @PreAuthorize("denyAll") + public void denyAllMethod() { + } + + @PostFilter("filterObject.startsWith('a')") + public List postFilterMethod() { + ArrayList objects = new ArrayList<>(); + objects.addAll(Arrays.asList(new String[] { "apple", "banana", "aubergine", "orange" })); + return objects; + } + + } + } diff --git a/build.gradle b/build.gradle index ace768def7..fb65d2623f 100644 --- a/build.gradle +++ b/build.gradle @@ -1,6 +1,7 @@ buildscript { dependencies { classpath 'io.spring.gradle:spring-build-conventions:0.0.33.RELEASE' + classpath "io.spring.javaformat:spring-javaformat-gradle-plugin:$springJavaformatVersion" classpath "org.springframework.boot:spring-boot-gradle-plugin:$springBootVersion" classpath 'io.spring.nohttp:nohttp-gradle:0.0.5.RELEASE' classpath "io.freefair.gradle:aspectj-plugin:5.0.1" @@ -34,12 +35,35 @@ subprojects { plugins.withType(JavaPlugin) { project.sourceCompatibility='1.8' } - tasks.withType(JavaCompile) { options.encoding = "UTF-8" } } +allprojects { + apply plugin: 'io.spring.javaformat' + apply plugin: 'checkstyle' + + pluginManager.withPlugin("io.spring.convention.checkstyle", { plugin -> + configure(plugin) { + dependencies { + checkstyle "io.spring.javaformat:spring-javaformat-checkstyle:$springJavaformatVersion" + } + checkstyle { + toolVersion = '8.34' + } + } + }) + + if (project.name.contains('sample')) { + tasks.whenTaskAdded { task -> + if (task.name.contains('format') || task.name.contains('checkFormat') || task.name.contains("checkstyle")) { + task.enabled = false + } + } + } +} + nohttp { allowlistFile = project.file("etc/nohttp/allowlist.lines") } diff --git a/buildSrc/src/main/java/lock/GlobalLockPlugin.java b/buildSrc/src/main/java/lock/GlobalLockPlugin.java index 955e73e681..ee458eb0d4 100644 --- a/buildSrc/src/main/java/lock/GlobalLockPlugin.java +++ b/buildSrc/src/main/java/lock/GlobalLockPlugin.java @@ -9,7 +9,7 @@ import org.gradle.api.Project; public class GlobalLockPlugin implements Plugin { @Override public void apply(Project project) { - project.getTasks().register("writeLocks", GlobalLockTask.class, writeAll -> { + project.getTasks().register("writeLocks", GlobalLockTask.class, (writeAll) -> { writeAll.setDescription("Writes the locks for all projects"); }); } diff --git a/buildSrc/src/main/java/trang/TrangPlugin.java b/buildSrc/src/main/java/trang/TrangPlugin.java index d447b64bb4..02ca3746ad 100644 --- a/buildSrc/src/main/java/trang/TrangPlugin.java +++ b/buildSrc/src/main/java/trang/TrangPlugin.java @@ -10,7 +10,7 @@ import org.gradle.api.Project; public class TrangPlugin implements Plugin { @Override public void apply(Project project) { - project.getTasks().register("rncToXsd", RncToXsd.class, rncToXsd -> { + project.getTasks().register("rncToXsd", RncToXsd.class, (rncToXsd) -> { rncToXsd.setDescription("Converts .rnc to .xsd"); rncToXsd.setGroup("Build"); }); diff --git a/cas/src/main/java/org/springframework/security/cas/SamlServiceProperties.java b/cas/src/main/java/org/springframework/security/cas/SamlServiceProperties.java index 6be415859e..8e300f8f8e 100644 --- a/cas/src/main/java/org/springframework/security/cas/SamlServiceProperties.java +++ b/cas/src/main/java/org/springframework/security/cas/SamlServiceProperties.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas; /** @@ -32,4 +33,5 @@ public final class SamlServiceProperties extends ServiceProperties { super.setArtifactParameter(DEFAULT_SAML_ARTIFACT_PARAMETER); super.setServiceParameter(DEFAULT_SAML_SERVICE_PARAMETER); } + } diff --git a/cas/src/main/java/org/springframework/security/cas/ServiceProperties.java b/cas/src/main/java/org/springframework/security/cas/ServiceProperties.java index e63742222c..caf03dd62a 100644 --- a/cas/src/main/java/org/springframework/security/cas/ServiceProperties.java +++ b/cas/src/main/java/org/springframework/security/cas/ServiceProperties.java @@ -34,9 +34,6 @@ public class ServiceProperties implements InitializingBean { public static final String DEFAULT_CAS_SERVICE_PARAMETER = "service"; - // ~ Instance fields - // ================================================================================================ - private String service; private boolean authenticateAllArtifacts; @@ -47,9 +44,7 @@ public class ServiceProperties implements InitializingBean { private String serviceParameter = DEFAULT_CAS_SERVICE_PARAMETER; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { Assert.hasLength(this.service, "service cannot be empty."); Assert.hasLength(this.artifactParameter, "artifactParameter cannot be empty."); @@ -65,7 +60,6 @@ public class ServiceProperties implements InitializingBean { *

 	 * https://www.mycompany.com/application/login/cas
 	 * 
- * * @return the URL of the service the user is authenticating to */ public final String getService() { @@ -81,7 +75,6 @@ public class ServiceProperties implements InitializingBean { * ticket was generated as a consequence of an explicit login. High security * applications would probably set this to true. Defaults to * false, providing automated single sign on. - * * @return whether to send the renew parameter to CAS */ public final boolean isSendRenew() { @@ -103,7 +96,6 @@ public class ServiceProperties implements InitializingBean { /** * Configures the Request Parameter to look for when attempting to see if a CAS ticket * was sent from the server. - * * @param artifactParameter the id to use. Default is "ticket". */ public final void setArtifactParameter(final String artifactParameter) { @@ -113,7 +105,6 @@ public class ServiceProperties implements InitializingBean { /** * Configures the Request parameter to look for when attempting to send a request to * CAS. - * * @return the service parameter to use. Default is "service". */ public final String getServiceParameter() { @@ -132,11 +123,10 @@ public class ServiceProperties implements InitializingBean { * If true, then any non-null artifact (ticket) should be authenticated. Additionally, * the service will be determined dynamically in order to ensure the service matches * the expected value for this artifact. - * * @param authenticateAllArtifacts */ - public final void setAuthenticateAllArtifacts( - final boolean authenticateAllArtifacts) { + public final void setAuthenticateAllArtifacts(final boolean authenticateAllArtifacts) { this.authenticateAllArtifacts = authenticateAllArtifacts; } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/CasAssertionAuthenticationToken.java b/cas/src/main/java/org/springframework/security/cas/authentication/CasAssertionAuthenticationToken.java index af82fa1183..d04d30d154 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/CasAssertionAuthenticationToken.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/CasAssertionAuthenticationToken.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.authentication; import java.util.ArrayList; import org.jasig.cas.client.validation.Assertion; + import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.SpringSecurityCoreVersion; @@ -37,15 +39,16 @@ public final class CasAssertionAuthenticationToken extends AbstractAuthenticatio public CasAssertionAuthenticationToken(final Assertion assertion, final String ticket) { super(new ArrayList<>()); - this.assertion = assertion; this.ticket = ticket; } + @Override public Object getPrincipal() { return this.assertion.getPrincipal().getName(); } + @Override public Object getCredentials() { return this.ticket; } @@ -53,4 +56,5 @@ public final class CasAssertionAuthenticationToken extends AbstractAuthenticatio public Assertion getAssertion() { return this.assertion; } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationProvider.java b/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationProvider.java index 226a786d6b..3a84c2109a 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationProvider.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationProvider.java @@ -21,10 +21,12 @@ import org.apache.commons.logging.LogFactory; import org.jasig.cas.client.validation.Assertion; import org.jasig.cas.client.validation.TicketValidationException; import org.jasig.cas.client.validation.TicketValidator; + import org.springframework.beans.factory.InitializingBean; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AccountStatusUserDetailsChecker; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.BadCredentialsException; @@ -37,7 +39,11 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.NullAuthoritiesMapper; -import org.springframework.security.core.userdetails.*; +import org.springframework.security.core.userdetails.AuthenticationUserDetailsService; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsByNameServiceWrapper; +import org.springframework.security.core.userdetails.UserDetailsChecker; +import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.util.Assert; /** @@ -54,118 +60,93 @@ import org.springframework.util.Assert; * @author Ben Alex * @author Scott Battaglia */ -public class CasAuthenticationProvider implements AuthenticationProvider, - InitializingBean, MessageSourceAware { - // ~ Static fields/initializers - // ===================================================================================== +public class CasAuthenticationProvider implements AuthenticationProvider, InitializingBean, MessageSourceAware { private static final Log logger = LogFactory.getLog(CasAuthenticationProvider.class); - // ~ Instance fields - // ================================================================================================ - private AuthenticationUserDetailsService authenticationUserDetailsService; private final UserDetailsChecker userDetailsChecker = new AccountStatusUserDetailsChecker(); + protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private StatelessTicketCache statelessTicketCache = new NullStatelessTicketCache(); + private String key; + private TicketValidator ticketValidator; + private ServiceProperties serviceProperties; + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(this.authenticationUserDetailsService, - "An authenticationUserDetailsService must be set"); + Assert.notNull(this.authenticationUserDetailsService, "An authenticationUserDetailsService must be set"); Assert.notNull(this.ticketValidator, "A ticketValidator must be set"); Assert.notNull(this.statelessTicketCache, "A statelessTicketCache must be set"); - Assert.hasText( - this.key, + Assert.hasText(this.key, "A Key is required so CasAuthenticationProvider can identify tokens it previously authenticated"); Assert.notNull(this.messages, "A message source must be set"); } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { if (!supports(authentication.getClass())) { return null; } - if (authentication instanceof UsernamePasswordAuthenticationToken - && (!CasAuthenticationFilter.CAS_STATEFUL_IDENTIFIER - .equals(authentication.getPrincipal().toString()) && !CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER - .equals(authentication.getPrincipal().toString()))) { + && (!CasAuthenticationFilter.CAS_STATEFUL_IDENTIFIER.equals(authentication.getPrincipal().toString()) + && !CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER + .equals(authentication.getPrincipal().toString()))) { // UsernamePasswordAuthenticationToken not CAS related return null; } - // If an existing CasAuthenticationToken, just check we created it if (authentication instanceof CasAuthenticationToken) { - if (this.key.hashCode() == ((CasAuthenticationToken) authentication) - .getKeyHash()) { - return authentication; - } - else { - throw new BadCredentialsException( - messages.getMessage("CasAuthenticationProvider.incorrectKey", - "The presented CasAuthenticationToken does not contain the expected key")); + if (this.key.hashCode() != ((CasAuthenticationToken) authentication).getKeyHash()) { + throw new BadCredentialsException(this.messages.getMessage("CasAuthenticationProvider.incorrectKey", + "The presented CasAuthenticationToken does not contain the expected key")); } + return authentication; } // Ensure credentials are presented - if ((authentication.getCredentials() == null) - || "".equals(authentication.getCredentials())) { - throw new BadCredentialsException(messages.getMessage( - "CasAuthenticationProvider.noServiceTicket", + if ((authentication.getCredentials() == null) || "".equals(authentication.getCredentials())) { + throw new BadCredentialsException(this.messages.getMessage("CasAuthenticationProvider.noServiceTicket", "Failed to provide a CAS service ticket to validate")); } - boolean stateless = false; - - if (authentication instanceof UsernamePasswordAuthenticationToken - && CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER.equals(authentication - .getPrincipal())) { - stateless = true; - } - + boolean stateless = (authentication instanceof UsernamePasswordAuthenticationToken + && CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER.equals(authentication.getPrincipal())); CasAuthenticationToken result = null; if (stateless) { // Try to obtain from cache - result = statelessTicketCache.getByTicketId(authentication.getCredentials() - .toString()); + result = this.statelessTicketCache.getByTicketId(authentication.getCredentials().toString()); } - if (result == null) { result = this.authenticateNow(authentication); result.setDetails(authentication.getDetails()); } - if (stateless) { // Add to cache - statelessTicketCache.putTicketInCache(result); + this.statelessTicketCache.putTicketInCache(result); } - return result; } - private CasAuthenticationToken authenticateNow(final Authentication authentication) - throws AuthenticationException { + private CasAuthenticationToken authenticateNow(final Authentication authentication) throws AuthenticationException { try { - final Assertion assertion = this.ticketValidator.validate(authentication - .getCredentials().toString(), getServiceUrl(authentication)); - final UserDetails userDetails = loadUserByAssertion(assertion); - userDetailsChecker.check(userDetails); - return new CasAuthenticationToken(this.key, userDetails, - authentication.getCredentials(), - authoritiesMapper.mapAuthorities(userDetails.getAuthorities()), - userDetails, assertion); + Assertion assertion = this.ticketValidator.validate(authentication.getCredentials().toString(), + getServiceUrl(authentication)); + UserDetails userDetails = loadUserByAssertion(assertion); + this.userDetailsChecker.check(userDetails); + return new CasAuthenticationToken(this.key, userDetails, authentication.getCredentials(), + this.authoritiesMapper.mapAuthorities(userDetails.getAuthorities()), userDetails, assertion); } - catch (final TicketValidationException e) { - throw new BadCredentialsException(e.getMessage(), e); + catch (TicketValidationException ex) { + throw new BadCredentialsException(ex.getMessage(), ex); } } @@ -174,30 +155,20 @@ public class CasAuthenticationProvider implements AuthenticationProvider, * {@link ServiceAuthenticationDetails}, then * {@link ServiceAuthenticationDetails#getServiceUrl()} is used. Otherwise, the * {@link ServiceProperties#getService()} is used. - * * @param authentication * @return */ private String getServiceUrl(Authentication authentication) { String serviceUrl; if (authentication.getDetails() instanceof ServiceAuthenticationDetails) { - serviceUrl = ((ServiceAuthenticationDetails) authentication.getDetails()) - .getServiceUrl(); - } - else if (serviceProperties == null) { - throw new IllegalStateException( - "serviceProperties cannot be null unless Authentication.getDetails() implements ServiceAuthenticationDetails."); - } - else if (serviceProperties.getService() == null) { - throw new IllegalStateException( - "serviceProperties.getService() cannot be null unless Authentication.getDetails() implements ServiceAuthenticationDetails."); - } - else { - serviceUrl = serviceProperties.getService(); - } - if (logger.isDebugEnabled()) { - logger.debug("serviceUrl = " + serviceUrl); + return ((ServiceAuthenticationDetails) authentication.getDetails()).getServiceUrl(); } + Assert.state(this.serviceProperties != null, + "serviceProperties cannot be null unless Authentication.getDetails() implements ServiceAuthenticationDetails."); + Assert.state(this.serviceProperties.getService() != null, + "serviceProperties.getService() cannot be null unless Authentication.getDetails() implements ServiceAuthenticationDetails."); + serviceUrl = this.serviceProperties.getService(); + logger.debug(LogMessage.format("serviceUrl = %s", serviceUrl)); return serviceUrl; } @@ -205,13 +176,11 @@ public class CasAuthenticationProvider implements AuthenticationProvider, * Template method for retrieving the UserDetails based on the assertion. Default is * to call configured userDetailsService and pass the username. Deployers can override * this method and retrieve the user based on any criteria they desire. - * * @param assertion The CAS Assertion. * @return the UserDetails. */ protected UserDetails loadUserByAssertion(final Assertion assertion) { - final CasAssertionAuthenticationToken token = new CasAssertionAuthenticationToken( - assertion, ""); + final CasAssertionAuthenticationToken token = new CasAssertionAuthenticationToken(assertion, ""); return this.authenticationUserDetailsService.loadUserDetails(token); } @@ -220,8 +189,7 @@ public class CasAuthenticationProvider implements AuthenticationProvider, * Sets the UserDetailsService to use. This is a convenience method to invoke */ public void setUserDetailsService(final UserDetailsService userDetailsService) { - this.authenticationUserDetailsService = new UserDetailsByNameServiceWrapper( - userDetailsService); + this.authenticationUserDetailsService = new UserDetailsByNameServiceWrapper(userDetailsService); } public void setAuthenticationUserDetailsService( @@ -234,7 +202,7 @@ public class CasAuthenticationProvider implements AuthenticationProvider, } protected String getKey() { - return key; + return this.key; } public void setKey(String key) { @@ -242,13 +210,14 @@ public class CasAuthenticationProvider implements AuthenticationProvider, } public StatelessTicketCache getStatelessTicketCache() { - return statelessTicketCache; + return this.statelessTicketCache; } protected TicketValidator getTicketValidator() { - return ticketValidator; + return this.ticketValidator; } + @Override public void setMessageSource(final MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } @@ -265,11 +234,11 @@ public class CasAuthenticationProvider implements AuthenticationProvider, this.authoritiesMapper = authoritiesMapper; } + @Override public boolean supports(final Class authentication) { - return (UsernamePasswordAuthenticationToken.class - .isAssignableFrom(authentication)) + return (UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication)) || (CasAuthenticationToken.class.isAssignableFrom(authentication)) - || (CasAssertionAuthenticationToken.class - .isAssignableFrom(authentication)); + || (CasAssertionAuthenticationToken.class.isAssignableFrom(authentication)); } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationToken.java b/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationToken.java index d3d0827133..8020da0400 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationToken.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/CasAuthenticationToken.java @@ -20,11 +20,13 @@ import java.io.Serializable; import java.util.Collection; import org.jasig.cas.client.validation.Assertion; + import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * Represents a successful CAS Authentication. @@ -32,80 +34,69 @@ import org.springframework.util.Assert; * @author Ben Alex * @author Scott Battaglia */ -public class CasAuthenticationToken extends AbstractAuthenticationToken implements - Serializable { +public class CasAuthenticationToken extends AbstractAuthenticationToken implements Serializable { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ private final Object credentials; - private final Object principal; - private final UserDetails userDetails; - private final int keyHash; - private final Assertion assertion; - // ~ Constructors - // =================================================================================================== + private final Object principal; + + private final UserDetails userDetails; + + private final int keyHash; + + private final Assertion assertion; /** * Constructor. - * - * @param key to identify if this object made by a given - * {@link CasAuthenticationProvider} - * @param principal typically the UserDetails object (cannot be null) + * @param key to identify if this object made by a given + * {@link CasAuthenticationProvider} + * @param principal typically the UserDetails object (cannot be null) * @param credentials the service/proxy ticket ID from CAS (cannot be - * null) + * null) * @param authorities the authorities granted to the user (from the - * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot - * be null) + * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot + * be null) * @param userDetails the user details (from the - * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot - * be null) - * @param assertion the assertion returned from the CAS servers. It contains the - * principal and how to obtain a proxy ticket for the user. + * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot + * be null) + * @param assertion the assertion returned from the CAS servers. It contains the + * principal and how to obtain a proxy ticket for the user. * @throws IllegalArgumentException if a null was passed */ - public CasAuthenticationToken(final String key, final Object principal, - final Object credentials, - final Collection authorities, - final UserDetails userDetails, final Assertion assertion) { + public CasAuthenticationToken(final String key, final Object principal, final Object credentials, + final Collection authorities, final UserDetails userDetails, + final Assertion assertion) { this(extractKeyHash(key), principal, credentials, authorities, userDetails, assertion); } /** * Private constructor for Jackson Deserialization support - * - * @param keyHash hashCode of provided key to identify if this object made by a given - * {@link CasAuthenticationProvider} - * @param principal typically the UserDetails object (cannot be null) + * @param keyHash hashCode of provided key to identify if this object made by a given + * {@link CasAuthenticationProvider} + * @param principal typically the UserDetails object (cannot be null) * @param credentials the service/proxy ticket ID from CAS (cannot be - * null) + * null) * @param authorities the authorities granted to the user (from the - * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot - * be null) + * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot + * be null) * @param userDetails the user details (from the - * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot - * be null) - * @param assertion the assertion returned from the CAS servers. It contains the - * principal and how to obtain a proxy ticket for the user. + * {@link org.springframework.security.core.userdetails.UserDetailsService}) (cannot + * be null) + * @param assertion the assertion returned from the CAS servers. It contains the + * principal and how to obtain a proxy ticket for the user. * @throws IllegalArgumentException if a null was passed * @since 4.2 */ - private CasAuthenticationToken(final Integer keyHash, final Object principal, - final Object credentials, - final Collection authorities, - final UserDetails userDetails, final Assertion assertion) { + private CasAuthenticationToken(final Integer keyHash, final Object principal, final Object credentials, + final Collection authorities, final UserDetails userDetails, + final Assertion assertion) { super(authorities); - - if ((principal == null) - || "".equals(principal) || (credentials == null) - || "".equals(credentials) || (authorities == null) - || (userDetails == null) || (assertion == null)) { - throw new IllegalArgumentException( - "Cannot pass null or empty values to constructor"); + if ((principal == null) || "".equals(principal) || (credentials == null) || "".equals(credentials) + || (authorities == null) || (userDetails == null) || (assertion == null)) { + throw new IllegalArgumentException("Cannot pass null or empty values to constructor"); } - this.keyHash = keyHash; this.principal = principal; this.credentials = credentials; @@ -114,9 +105,6 @@ public class CasAuthenticationToken extends AbstractAuthenticationToken implemen setAuthenticated(true); } - // ~ Methods - // ======================================================================================================== - private static Integer extractKeyHash(String key) { Assert.hasLength(key, "key cannot be null or empty"); return key.hashCode(); @@ -127,21 +115,16 @@ public class CasAuthenticationToken extends AbstractAuthenticationToken implemen if (!super.equals(obj)) { return false; } - if (obj instanceof CasAuthenticationToken) { CasAuthenticationToken test = (CasAuthenticationToken) obj; - if (!this.assertion.equals(test.getAssertion())) { return false; } - if (this.getKeyHash() != test.getKeyHash()) { return false; } - return true; } - return false; } @@ -152,7 +135,7 @@ public class CasAuthenticationToken extends AbstractAuthenticationToken implemen result = 31 * result + this.principal.hashCode(); result = 31 * result + this.userDetails.hashCode(); result = 31 * result + this.keyHash; - result = 31 * result + (this.assertion != null ? this.assertion.hashCode() : 0); + result = 31 * result + ObjectUtils.nullSafeHashCode(this.assertion); return result; } @@ -175,7 +158,7 @@ public class CasAuthenticationToken extends AbstractAuthenticationToken implemen } public UserDetails getUserDetails() { - return userDetails; + return this.userDetails; } @Override @@ -184,7 +167,7 @@ public class CasAuthenticationToken extends AbstractAuthenticationToken implemen sb.append(super.toString()); sb.append(" Assertion: ").append(this.assertion); sb.append(" Credentials (Service/Proxy Ticket): ").append(this.credentials); - return (sb.toString()); } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCache.java b/cas/src/main/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCache.java index e424d512de..037b2a3f0f 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCache.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCache.java @@ -18,74 +18,61 @@ package org.springframework.security.cas.authentication; import net.sf.ehcache.Ehcache; import net.sf.ehcache.Element; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; /** - * Caches tickets using a Spring IoC defined EHCACHE. + * Caches tickets using a Spring IoC defined + * EHCACHE. * * @author Ben Alex */ public class EhCacheBasedTicketCache implements StatelessTicketCache, InitializingBean { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(EhCacheBasedTicketCache.class); - // ~ Instance fields - // ================================================================================================ - private Ehcache cache; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(cache, "cache mandatory"); + Assert.notNull(this.cache, "cache mandatory"); } + @Override public CasAuthenticationToken getByTicketId(final String serviceTicket) { - final Element element = cache.get(serviceTicket); - - if (logger.isDebugEnabled()) { - logger.debug("Cache hit: " + (element != null) + "; service ticket: " - + serviceTicket); - } - - return element == null ? null : (CasAuthenticationToken) element.getValue(); + final Element element = this.cache.get(serviceTicket); + logger.debug(LogMessage.of(() -> "Cache hit: " + (element != null) + "; service ticket: " + serviceTicket)); + return (element != null) ? (CasAuthenticationToken) element.getValue() : null; } public Ehcache getCache() { - return cache; + return this.cache; } + @Override public void putTicketInCache(final CasAuthenticationToken token) { final Element element = new Element(token.getCredentials().toString(), token); - - if (logger.isDebugEnabled()) { - logger.debug("Cache put: " + element.getKey()); - } - - cache.put(element); + logger.debug(LogMessage.of(() -> "Cache put: " + element.getKey())); + this.cache.put(element); } + @Override public void removeTicketFromCache(final CasAuthenticationToken token) { - if (logger.isDebugEnabled()) { - logger.debug("Cache remove: " + token.getCredentials().toString()); - } - + logger.debug(LogMessage.of(() -> "Cache remove: " + token.getCredentials().toString())); this.removeTicketFromCache(token.getCredentials().toString()); } + @Override public void removeTicketFromCache(final String serviceTicket) { - cache.remove(serviceTicket); + this.cache.remove(serviceTicket); } public void setCache(final Ehcache cache) { this.cache = cache; } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/NullStatelessTicketCache.java b/cas/src/main/java/org/springframework/security/cas/authentication/NullStatelessTicketCache.java index b33518114b..4284161a39 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/NullStatelessTicketCache.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/NullStatelessTicketCache.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.authentication; /** @@ -24,7 +25,6 @@ package org.springframework.security.cas.authentication; * are not using the stateless session management. * * @author Scott Battaglia - * * @see CasAuthenticationProvider */ public final class NullStatelessTicketCache implements StatelessTicketCache { @@ -32,6 +32,7 @@ public final class NullStatelessTicketCache implements StatelessTicketCache { /** * @return null since we are not storing any tickets. */ + @Override public CasAuthenticationToken getByTicketId(final String serviceTicket) { return null; } @@ -39,6 +40,7 @@ public final class NullStatelessTicketCache implements StatelessTicketCache { /** * This is a no-op since we are not storing tickets. */ + @Override public void putTicketInCache(final CasAuthenticationToken token) { // nothing to do } @@ -46,6 +48,7 @@ public final class NullStatelessTicketCache implements StatelessTicketCache { /** * This is a no-op since we are not storing tickets. */ + @Override public void removeTicketFromCache(final CasAuthenticationToken token) { // nothing to do } @@ -53,7 +56,9 @@ public final class NullStatelessTicketCache implements StatelessTicketCache { /** * This is a no-op since we are not storing tickets. */ + @Override public void removeTicketFromCache(final String serviceTicket) { // nothing to do } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCache.java b/cas/src/main/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCache.java index 171792448d..b72e824c75 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCache.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCache.java @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.authentication; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.cache.Cache; +import org.springframework.core.log.LogMessage; import org.springframework.util.Assert; /** @@ -28,59 +31,39 @@ import org.springframework.util.Assert; * */ public class SpringCacheBasedTicketCache implements StatelessTicketCache { - // ~ Static fields/initializers - // ===================================================================================== - private static final Log logger = LogFactory - .getLog(SpringCacheBasedTicketCache.class); - - // ~ Instance fields - // ================================================================================================ + private static final Log logger = LogFactory.getLog(SpringCacheBasedTicketCache.class); private final Cache cache; - // ~ Constructors - // =================================================================================================== - public SpringCacheBasedTicketCache(Cache cache) { Assert.notNull(cache, "cache mandatory"); this.cache = cache; } - // ~ Methods - // ======================================================================================================== - + @Override public CasAuthenticationToken getByTicketId(final String serviceTicket) { - final Cache.ValueWrapper element = serviceTicket != null ? cache - .get(serviceTicket) : null; - - if (logger.isDebugEnabled()) { - logger.debug("Cache hit: " + (element != null) + "; service ticket: " - + serviceTicket); - } - - return element == null ? null : (CasAuthenticationToken) element.get(); + final Cache.ValueWrapper element = (serviceTicket != null) ? this.cache.get(serviceTicket) : null; + logger.debug(LogMessage.of(() -> "Cache hit: " + (element != null) + "; service ticket: " + serviceTicket)); + return (element != null) ? (CasAuthenticationToken) element.get() : null; } + @Override public void putTicketInCache(final CasAuthenticationToken token) { String key = token.getCredentials().toString(); - - if (logger.isDebugEnabled()) { - logger.debug("Cache put: " + key); - } - - cache.put(key, token); + logger.debug(LogMessage.of(() -> "Cache put: " + key)); + this.cache.put(key, token); } + @Override public void removeTicketFromCache(final CasAuthenticationToken token) { - if (logger.isDebugEnabled()) { - logger.debug("Cache remove: " + token.getCredentials().toString()); - } - + logger.debug(LogMessage.of(() -> "Cache remove: " + token.getCredentials().toString())); this.removeTicketFromCache(token.getCredentials().toString()); } + @Override public void removeTicketFromCache(final String serviceTicket) { - cache.evict(serviceTicket); + this.cache.evict(serviceTicket); } + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/StatelessTicketCache.java b/cas/src/main/java/org/springframework/security/cas/authentication/StatelessTicketCache.java index 7c848f19a4..74df6bb9df 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/StatelessTicketCache.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/StatelessTicketCache.java @@ -59,7 +59,6 @@ package org.springframework.security.cas.authentication; * @author Ben Alex */ public interface StatelessTicketCache { - // ~ Methods ================================================================ /** * Retrieves the CasAuthenticationToken associated with the specified @@ -68,7 +67,6 @@ public interface StatelessTicketCache { *

* If not found, returns a nullCasAuthenticationToken. *

- * * @return the fully populated authentication token */ CasAuthenticationToken getByTicketId(String serviceTicket); @@ -80,7 +78,6 @@ public interface StatelessTicketCache { * The {@link CasAuthenticationToken#getCredentials()} method is used to retrieve the * service ticket number. *

- * * @param token to be added to the cache */ void putTicketInCache(CasAuthenticationToken token); @@ -91,10 +88,9 @@ public interface StatelessTicketCache { * *

* Implementations should use {@link CasAuthenticationToken#getCredentials()} to - * obtain the ticket and then delegate to the - * {@link #removeTicketFromCache(String)} method. + * obtain the ticket and then delegate to the {@link #removeTicketFromCache(String)} + * method. *

- * * @param token to be removed */ void removeTicketFromCache(CasAuthenticationToken token); @@ -107,8 +103,8 @@ public interface StatelessTicketCache { * This is in case applications wish to provide a session termination capability for * their stateless clients. *

- * * @param serviceTicket to be removed */ void removeTicketFromCache(String serviceTicket); + } diff --git a/cas/src/main/java/org/springframework/security/cas/authentication/package-info.java b/cas/src/main/java/org/springframework/security/cas/authentication/package-info.java index 240f124bb9..8803500a0a 100644 --- a/cas/src/main/java/org/springframework/security/cas/authentication/package-info.java +++ b/cas/src/main/java/org/springframework/security/cas/authentication/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * An {@code AuthenticationProvider} that can process CAS service tickets and proxy tickets. + * An {@code AuthenticationProvider} that can process CAS service tickets and proxy + * tickets. */ package org.springframework.security.cas.authentication; - diff --git a/cas/src/main/java/org/springframework/security/cas/jackson2/AssertionImplMixin.java b/cas/src/main/java/org/springframework/security/cas/jackson2/AssertionImplMixin.java index 3085f92d95..5d922b44e6 100644 --- a/cas/src/main/java/org/springframework/security/cas/jackson2/AssertionImplMixin.java +++ b/cas/src/main/java/org/springframework/security/cas/jackson2/AssertionImplMixin.java @@ -16,38 +16,43 @@ package org.springframework.security.cas.jackson2; -import com.fasterxml.jackson.annotation.*; -import org.jasig.cas.client.authentication.AttributePrincipal; - import java.util.Date; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.jasig.cas.client.authentication.AttributePrincipal; + /** - * Helps in jackson deserialization of class {@link org.jasig.cas.client.validation.AssertionImpl}, which is - * used with {@link org.springframework.security.cas.authentication.CasAuthenticationToken}. - * To use this class we need to register with {@link com.fasterxml.jackson.databind.ObjectMapper}. Type information - * will be stored in @class property. + * Helps in jackson deserialization of class + * {@link org.jasig.cas.client.validation.AssertionImpl}, which is used with + * {@link org.springframework.security.cas.authentication.CasAuthenticationToken}. To use + * this class we need to register with + * {@link com.fasterxml.jackson.databind.ObjectMapper}. Type information will be stored + * in @class property. *

*

  *     ObjectMapper mapper = new ObjectMapper();
  *     mapper.registerModule(new CasJackson2Module());
  * 
* - * * @author Jitendra Singh * @see CasJackson2Module * @see org.springframework.security.jackson2.SecurityJackson2Modules * @since 4.2 */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) -@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, - getterVisibility = JsonAutoDetect.Visibility.NONE, isGetterVisibility = JsonAutoDetect.Visibility.NONE) +@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, + isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonIgnoreProperties(ignoreUnknown = true) class AssertionImplMixin { /** - * Mixin Constructor helps in deserialize {@link org.jasig.cas.client.validation.AssertionImpl} - * + * Mixin Constructor helps in deserialize + * {@link org.jasig.cas.client.validation.AssertionImpl} * @param principal the Principal to associate with the Assertion. * @param validFromDate when the assertion is valid from. * @param validUntilDate when the assertion is valid to. @@ -56,7 +61,9 @@ class AssertionImplMixin { */ @JsonCreator AssertionImplMixin(@JsonProperty("principal") AttributePrincipal principal, - @JsonProperty("validFromDate") Date validFromDate, @JsonProperty("validUntilDate") Date validUntilDate, - @JsonProperty("authenticationDate") Date authenticationDate, @JsonProperty("attributes") Map attributes){ + @JsonProperty("validFromDate") Date validFromDate, @JsonProperty("validUntilDate") Date validUntilDate, + @JsonProperty("authenticationDate") Date authenticationDate, + @JsonProperty("attributes") Map attributes) { } + } diff --git a/cas/src/main/java/org/springframework/security/cas/jackson2/AttributePrincipalImplMixin.java b/cas/src/main/java/org/springframework/security/cas/jackson2/AttributePrincipalImplMixin.java index ddc326704f..775850c3b4 100644 --- a/cas/src/main/java/org/springframework/security/cas/jackson2/AttributePrincipalImplMixin.java +++ b/cas/src/main/java/org/springframework/security/cas/jackson2/AttributePrincipalImplMixin.java @@ -16,15 +16,20 @@ package org.springframework.security.cas.jackson2; -import com.fasterxml.jackson.annotation.*; -import org.jasig.cas.client.proxy.ProxyRetriever; - import java.util.Map; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.jasig.cas.client.proxy.ProxyRetriever; + /** - * Helps in deserialize {@link org.jasig.cas.client.authentication.AttributePrincipalImpl} which is used with - * {@link org.springframework.security.cas.authentication.CasAuthenticationToken}. Type information will be stored - * in property named @class. + * Helps in deserialize {@link org.jasig.cas.client.authentication.AttributePrincipalImpl} + * which is used with + * {@link org.springframework.security.cas.authentication.CasAuthenticationToken}. Type + * information will be stored in property named @class. *

*

  *     ObjectMapper mapper = new ObjectMapper();
@@ -43,16 +48,19 @@ import java.util.Map;
 class AttributePrincipalImplMixin {
 
 	/**
-	 * Mixin Constructor helps in deserialize {@link org.jasig.cas.client.authentication.AttributePrincipalImpl}
-	 *
+	 * Mixin Constructor helps in deserialize
+	 * {@link org.jasig.cas.client.authentication.AttributePrincipalImpl}
 	 * @param name the unique identifier for the principal.
 	 * @param attributes the key/value pairs for this principal.
 	 * @param proxyGrantingTicket the ticket associated with this principal.
-	 * @param proxyRetriever the ProxyRetriever implementation to call back to the CAS server.
+	 * @param proxyRetriever the ProxyRetriever implementation to call back to the CAS
+	 * server.
 	 */
 	@JsonCreator
-	AttributePrincipalImplMixin(@JsonProperty("name") String name, @JsonProperty("attributes") Map attributes,
-										@JsonProperty("proxyGrantingTicket") String proxyGrantingTicket,
-										@JsonProperty("proxyRetriever") ProxyRetriever proxyRetriever) {
+	AttributePrincipalImplMixin(@JsonProperty("name") String name,
+			@JsonProperty("attributes") Map attributes,
+			@JsonProperty("proxyGrantingTicket") String proxyGrantingTicket,
+			@JsonProperty("proxyRetriever") ProxyRetriever proxyRetriever) {
 	}
+
 }
diff --git a/cas/src/main/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixin.java b/cas/src/main/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixin.java
index dba9e0521e..98c9151e72 100644
--- a/cas/src/main/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixin.java
+++ b/cas/src/main/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixin.java
@@ -16,21 +16,27 @@
 
 package org.springframework.security.cas.jackson2;
 
-import com.fasterxml.jackson.annotation.*;
+import java.util.Collection;
+
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
 import org.jasig.cas.client.validation.Assertion;
+
 import org.springframework.security.cas.authentication.CasAuthenticationProvider;
 import org.springframework.security.cas.authentication.CasAuthenticationToken;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.userdetails.UserDetails;
 
-import java.util.Collection;
-
 /**
- * Mixin class which helps in deserialize {@link org.springframework.security.cas.authentication.CasAuthenticationToken}
- * using jackson. Two more dependent classes needs to register along with this mixin class.
+ * Mixin class which helps in deserialize
+ * {@link org.springframework.security.cas.authentication.CasAuthenticationToken} using
+ * jackson. Two more dependent classes needs to register along with this mixin class.
  * 
    - *
  1. {@link org.springframework.security.cas.jackson2.AssertionImplMixin}
  2. - *
  3. {@link org.springframework.security.cas.jackson2.AttributePrincipalImplMixin}
  4. + *
  5. {@link org.springframework.security.cas.jackson2.AssertionImplMixin}
  6. + *
  7. {@link org.springframework.security.cas.jackson2.AttributePrincipalImplMixin}
  8. *
* *

@@ -53,7 +59,6 @@ class CasAuthenticationTokenMixin { /** * Mixin Constructor helps in deserialize {@link CasAuthenticationToken} - * * @param keyHash hashCode of provided key to identify if this object made by a given * {@link CasAuthenticationProvider} * @param principal typically the UserDetails object (cannot be null) @@ -70,8 +75,9 @@ class CasAuthenticationTokenMixin { */ @JsonCreator CasAuthenticationTokenMixin(@JsonProperty("keyHash") Integer keyHash, @JsonProperty("principal") Object principal, - @JsonProperty("credentials") Object credentials, - @JsonProperty("authorities") Collection authorities, - @JsonProperty("userDetails") UserDetails userDetails, @JsonProperty("assertion") Assertion assertion) { + @JsonProperty("credentials") Object credentials, + @JsonProperty("authorities") Collection authorities, + @JsonProperty("userDetails") UserDetails userDetails, @JsonProperty("assertion") Assertion assertion) { } + } diff --git a/cas/src/main/java/org/springframework/security/cas/jackson2/CasJackson2Module.java b/cas/src/main/java/org/springframework/security/cas/jackson2/CasJackson2Module.java index 5d2e99370d..34f19ca10a 100644 --- a/cas/src/main/java/org/springframework/security/cas/jackson2/CasJackson2Module.java +++ b/cas/src/main/java/org/springframework/security/cas/jackson2/CasJackson2Module.java @@ -20,24 +20,26 @@ import com.fasterxml.jackson.core.Version; import com.fasterxml.jackson.databind.module.SimpleModule; import org.jasig.cas.client.authentication.AttributePrincipalImpl; import org.jasig.cas.client.validation.AssertionImpl; + import org.springframework.security.cas.authentication.CasAuthenticationToken; import org.springframework.security.jackson2.SecurityJackson2Modules; /** - * Jackson module for spring-security-cas. This module register {@link AssertionImplMixin}, - * {@link AttributePrincipalImplMixin} and {@link CasAuthenticationTokenMixin}. If no default typing enabled by default then - * it'll enable it because typing info is needed to properly serialize/deserialize objects. In order to use this module just - * add this module into your ObjectMapper configuration. + * Jackson module for spring-security-cas. This module register + * {@link AssertionImplMixin}, {@link AttributePrincipalImplMixin} and + * {@link CasAuthenticationTokenMixin}. If no default typing enabled by default then it'll + * enable it because typing info is needed to properly serialize/deserialize objects. In + * order to use this module just add this module into your ObjectMapper configuration. * *

  *     ObjectMapper mapper = new ObjectMapper();
  *     mapper.registerModule(new CasJackson2Module());
- * 
- * Note: use {@link SecurityJackson2Modules#getModules(ClassLoader)} to get list of all security modules on the classpath. + *
Note: use {@link SecurityJackson2Modules#getModules(ClassLoader)} to get list + * of all security modules on the classpath. * * @author Jitendra Singh. - * @see org.springframework.security.jackson2.SecurityJackson2Modules * @since 4.2 + * @see org.springframework.security.jackson2.SecurityJackson2Modules */ public class CasJackson2Module extends SimpleModule { @@ -52,4 +54,5 @@ public class CasJackson2Module extends SimpleModule { context.setMixInAnnotations(AttributePrincipalImpl.class, AttributePrincipalImplMixin.class); context.setMixInAnnotations(CasAuthenticationToken.class, CasAuthenticationTokenMixin.class); } + } diff --git a/cas/src/main/java/org/springframework/security/cas/package-info.java b/cas/src/main/java/org/springframework/security/cas/package-info.java index 8ce8a88b5a..13fae9057d 100644 --- a/cas/src/main/java/org/springframework/security/cas/package-info.java +++ b/cas/src/main/java/org/springframework/security/cas/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Spring Security support for Jasig's Central Authentication Service (CAS). + * Spring Security support for Jasig's Central Authentication Service + * (CAS). */ package org.springframework.security.cas; - diff --git a/cas/src/main/java/org/springframework/security/cas/userdetails/AbstractCasAssertionUserDetailsService.java b/cas/src/main/java/org/springframework/security/cas/userdetails/AbstractCasAssertionUserDetailsService.java index bb4770eb05..3d8cd9e412 100644 --- a/cas/src/main/java/org/springframework/security/cas/userdetails/AbstractCasAssertionUserDetailsService.java +++ b/cas/src/main/java/org/springframework/security/cas/userdetails/AbstractCasAssertionUserDetailsService.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.userdetails; import org.jasig.cas.client.validation.Assertion; + import org.springframework.security.cas.authentication.CasAssertionAuthenticationToken; import org.springframework.security.core.userdetails.AuthenticationUserDetailsService; import org.springframework.security.core.userdetails.UserDetails; @@ -28,9 +30,10 @@ import org.springframework.security.core.userdetails.UserDetails; * @author Scott Battaglia * @since 3.0 */ -public abstract class AbstractCasAssertionUserDetailsService implements - AuthenticationUserDetailsService { +public abstract class AbstractCasAssertionUserDetailsService + implements AuthenticationUserDetailsService { + @Override public final UserDetails loadUserDetails(final CasAssertionAuthenticationToken token) { return loadUserDetails(token.getAssertion()); } @@ -39,10 +42,10 @@ public abstract class AbstractCasAssertionUserDetailsService implements * Protected template method for construct a * {@link org.springframework.security.core.userdetails.UserDetails} via the supplied * CAS assertion. - * * @param assertion the assertion to use to construct the new UserDetails. CANNOT be * NULL. * @return the newly constructed UserDetails. */ protected abstract UserDetails loadUserDetails(Assertion assertion); + } diff --git a/cas/src/main/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsService.java b/cas/src/main/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsService.java index 93f03eeabe..0e47d1c57f 100644 --- a/cas/src/main/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsService.java +++ b/cas/src/main/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsService.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.userdetails; -import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.security.core.userdetails.User; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; -import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.List; + import org.jasig.cas.client.validation.Assertion; -import java.util.List; -import java.util.ArrayList; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.util.Assert; /** * Populates the {@link org.springframework.security.core.GrantedAuthority}s for a user by @@ -34,8 +36,8 @@ import java.util.ArrayList; * @author Scott Battaglia * @since 3.0 */ -public final class GrantedAuthorityFromAssertionAttributesUserDetailsService extends - AbstractCasAssertionUserDetailsService { +public final class GrantedAuthorityFromAssertionAttributesUserDetailsService + extends AbstractCasAssertionUserDetailsService { private static final String NON_EXISTENT_PASSWORD_VALUE = "NO_PASSWORD"; @@ -43,54 +45,43 @@ public final class GrantedAuthorityFromAssertionAttributesUserDetailsService ext private boolean convertToUpperCase = true; - public GrantedAuthorityFromAssertionAttributesUserDetailsService( - final String[] attributes) { + public GrantedAuthorityFromAssertionAttributesUserDetailsService(final String[] attributes) { Assert.notNull(attributes, "attributes cannot be null."); - Assert.isTrue(attributes.length > 0, - "At least one attribute is required to retrieve roles from."); + Assert.isTrue(attributes.length > 0, "At least one attribute is required to retrieve roles from."); this.attributes = attributes; } @SuppressWarnings("unchecked") @Override protected UserDetails loadUserDetails(final Assertion assertion) { - final List grantedAuthorities = new ArrayList<>(); - - for (final String attribute : this.attributes) { - final Object value = assertion.getPrincipal().getAttributes().get(attribute); - - if (value == null) { - continue; - } - - if (value instanceof List) { - final List list = (List) value; - - for (final Object o : list) { - grantedAuthorities.add(new SimpleGrantedAuthority( - this.convertToUpperCase ? o.toString().toUpperCase() : o - .toString())); + List grantedAuthorities = new ArrayList<>(); + for (String attribute : this.attributes) { + Object value = assertion.getPrincipal().getAttributes().get(attribute); + if (value != null) { + if (value instanceof List) { + for (Object o : (List) value) { + grantedAuthorities.add(createSimpleGrantedAuthority(o)); + } + } + else { + grantedAuthorities.add(createSimpleGrantedAuthority(value)); } - } - else { - grantedAuthorities.add(new SimpleGrantedAuthority( - this.convertToUpperCase ? value.toString().toUpperCase() : value - .toString())); - } - } + return new User(assertion.getPrincipal().getName(), NON_EXISTENT_PASSWORD_VALUE, true, true, true, true, + grantedAuthorities); + } - return new User(assertion.getPrincipal().getName(), NON_EXISTENT_PASSWORD_VALUE, - true, true, true, true, grantedAuthorities); + private SimpleGrantedAuthority createSimpleGrantedAuthority(Object o) { + return new SimpleGrantedAuthority(this.convertToUpperCase ? o.toString().toUpperCase() : o.toString()); } /** * Converts the returned attribute values to uppercase values. - * * @param convertToUpperCase true if it should convert, false otherwise. */ public void setConvertToUpperCase(final boolean convertToUpperCase) { this.convertToUpperCase = convertToUpperCase; } + } diff --git a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java index 9742fde7d5..25221addf8 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java +++ b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationEntryPoint.java @@ -22,10 +22,11 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.jasig.cas.client.util.CommonUtils; + +import org.springframework.beans.factory.InitializingBean; import org.springframework.security.cas.ServiceProperties; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.AuthenticationEntryPoint; -import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; /** @@ -42,10 +43,8 @@ import org.springframework.util.Assert; * @author Ben Alex * @author Scott Battaglia */ -public class CasAuthenticationEntryPoint implements AuthenticationEntryPoint, - InitializingBean { - // ~ Instance fields - // ================================================================================================ +public class CasAuthenticationEntryPoint implements AuthenticationEntryPoint, InitializingBean { + private ServiceProperties serviceProperties; private String loginUrl; @@ -61,25 +60,19 @@ public class CasAuthenticationEntryPoint implements AuthenticationEntryPoint, */ private boolean encodeServiceUrlWithSessionId = true; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { Assert.hasLength(this.loginUrl, "loginUrl must be specified"); Assert.notNull(this.serviceProperties, "serviceProperties must be specified"); - Assert.notNull(this.serviceProperties.getService(), - "serviceProperties.getService() cannot be null."); + Assert.notNull(this.serviceProperties.getService(), "serviceProperties.getService() cannot be null."); } - public final void commence(final HttpServletRequest servletRequest, - final HttpServletResponse response, - final AuthenticationException authenticationException) throws IOException { - - final String urlEncodedService = createServiceUrl(servletRequest, response); - final String redirectUrl = createRedirectUrl(urlEncodedService); - + @Override + public final void commence(final HttpServletRequest servletRequest, HttpServletResponse response, + AuthenticationException authenticationException) throws IOException { + String urlEncodedService = createServiceUrl(servletRequest, response); + String redirectUrl = createRedirectUrl(urlEncodedService); preCommence(servletRequest, response); - response.sendRedirect(redirectUrl); } @@ -90,42 +83,34 @@ public class CasAuthenticationEntryPoint implements AuthenticationEntryPoint, * @param response the HttpServlet Response * @return the constructed service url. CANNOT be NULL. */ - protected String createServiceUrl(final HttpServletRequest request, - final HttpServletResponse response) { - return CommonUtils.constructServiceUrl(null, response, - this.serviceProperties.getService(), null, - this.serviceProperties.getArtifactParameter(), - this.encodeServiceUrlWithSessionId); + protected String createServiceUrl(HttpServletRequest request, HttpServletResponse response) { + return CommonUtils.constructServiceUrl(null, response, this.serviceProperties.getService(), null, + this.serviceProperties.getArtifactParameter(), this.encodeServiceUrlWithSessionId); } /** * Constructs the Url for Redirection to the CAS server. Default implementation relies * on the CAS client to do the bulk of the work. - * * @param serviceUrl the service url that should be included. * @return the redirect url. CANNOT be NULL. */ - protected String createRedirectUrl(final String serviceUrl) { - return CommonUtils.constructRedirectUrl(this.loginUrl, - this.serviceProperties.getServiceParameter(), serviceUrl, + protected String createRedirectUrl(String serviceUrl) { + return CommonUtils.constructRedirectUrl(this.loginUrl, this.serviceProperties.getServiceParameter(), serviceUrl, this.serviceProperties.isSendRenew(), false); } /** * Template method for you to do your own pre-processing before the redirect occurs. - * * @param request the HttpServletRequest * @param response the HttpServletResponse */ - protected void preCommence(final HttpServletRequest request, - final HttpServletResponse response) { + protected void preCommence(HttpServletRequest request, HttpServletResponse response) { } /** * The enterprise-wide CAS login URL. Usually something like * https://www.mycompany.com/cas/login. - * * @return the enterprise-wide CAS login URL */ public final String getLoginUrl() { @@ -136,22 +121,20 @@ public class CasAuthenticationEntryPoint implements AuthenticationEntryPoint, return this.serviceProperties; } - public final void setLoginUrl(final String loginUrl) { + public final void setLoginUrl(String loginUrl) { this.loginUrl = loginUrl; } - public final void setServiceProperties(final ServiceProperties serviceProperties) { + public final void setServiceProperties(ServiceProperties serviceProperties) { this.serviceProperties = serviceProperties; } /** * Sets whether to encode the service url with the session id or not. - * * @param encodeServiceUrlWithSessionId whether to encode the service url with the * session id or not. */ - public final void setEncodeServiceUrlWithSessionId( - final boolean encodeServiceUrlWithSessionId) { + public final void setEncodeServiceUrlWithSessionId(boolean encodeServiceUrlWithSessionId) { this.encodeServiceUrlWithSessionId = encodeServiceUrlWithSessionId; } @@ -163,4 +146,5 @@ public class CasAuthenticationEntryPoint implements AuthenticationEntryPoint, protected boolean getEncodeServiceUrlWithSessionId() { return this.encodeServiceUrlWithSessionId; } + } diff --git a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java index 7ff21e2480..42339c0c9d 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java +++ b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java @@ -26,6 +26,8 @@ import javax.servlet.http.HttpServletResponse; import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage; import org.jasig.cas.client.util.CommonUtils; import org.jasig.cas.client.validation.TicketValidator; + +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -45,7 +47,8 @@ import org.springframework.util.Assert; /** * Processes a CAS service ticket, obtains proxy granting tickets, and processes proxy - * tickets.

Service Tickets

+ * tickets. + *

Service Tickets

*

* A service ticket consists of an opaque ticket string. It arrives at this filter by the * user's browser successfully authenticating using CAS, and then receiving a HTTP @@ -171,10 +174,10 @@ import org.springframework.util.Assert; * @author Rob Winch */ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFilter { - // ~ Static fields/initializers - // ===================================================================================== - /** Used to identify a CAS request for a stateful user agent, such as a web browser. */ + /** + * Used to identify a CAS request for a stateful user agent, such as a web browser. + */ public static final String CAS_STATEFUL_IDENTIFIER = "_cas_stateful_"; /** @@ -201,72 +204,47 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil private AuthenticationFailureHandler proxyFailureHandler = new SimpleUrlAuthenticationFailureHandler(); - // ~ Constructors - // =================================================================================================== - public CasAuthenticationFilter() { super("/login/cas"); setAuthenticationFailureHandler(new SimpleUrlAuthenticationFailureHandler()); } - // ~ Methods - // ======================================================================================================== - @Override - protected final void successfulAuthentication(HttpServletRequest request, - HttpServletResponse response, FilterChain chain, Authentication authResult) - throws IOException, ServletException { - boolean continueFilterChain = proxyTicketRequest( - serviceTicketRequest(request, response), request); + protected final void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, + FilterChain chain, Authentication authResult) throws IOException, ServletException { + boolean continueFilterChain = proxyTicketRequest(serviceTicketRequest(request, response), request); if (!continueFilterChain) { super.successfulAuthentication(request, response, chain, authResult); return; } - - if (logger.isDebugEnabled()) { - logger.debug("Authentication success. Updating SecurityContextHolder to contain: " - + authResult); - } - + this.logger.debug( + LogMessage.format("Authentication success. Updating SecurityContextHolder to contain: %s", authResult)); SecurityContextHolder.getContext().setAuthentication(authResult); - - // Fire event if (this.eventPublisher != null) { - eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent( - authResult, this.getClass())); + this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); } - chain.doFilter(request, response); } @Override - public Authentication attemptAuthentication(final HttpServletRequest request, - final HttpServletResponse response) throws AuthenticationException, - IOException { + public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) + throws AuthenticationException, IOException { // if the request is a proxy request process it and return null to indicate the // request has been processed if (proxyReceptorRequest(request)) { - logger.debug("Responding to proxy receptor request"); - CommonUtils.readAndRespondToProxyReceptorRequest(request, response, - this.proxyGrantingTicketStorage); + this.logger.debug("Responding to proxy receptor request"); + CommonUtils.readAndRespondToProxyReceptorRequest(request, response, this.proxyGrantingTicketStorage); return null; } - - final boolean serviceTicketRequest = serviceTicketRequest(request, response); - final String username = serviceTicketRequest ? CAS_STATEFUL_IDENTIFIER - : CAS_STATELESS_IDENTIFIER; + boolean serviceTicketRequest = serviceTicketRequest(request, response); + String username = serviceTicketRequest ? CAS_STATEFUL_IDENTIFIER : CAS_STATELESS_IDENTIFIER; String password = obtainArtifact(request); - if (password == null) { - logger.debug("Failed to obtain an artifact (cas ticket)"); + this.logger.debug("Failed to obtain an artifact (cas ticket)"); password = ""; } - - final UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken( - username, password); - - authRequest.setDetails(authenticationDetailsSource.buildDetails(request)); - + UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken(username, password); + authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); return this.getAuthenticationManager().authenticate(authRequest); } @@ -276,19 +254,19 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * @return if present the artifact from the {@link HttpServletRequest}, else null */ protected String obtainArtifact(HttpServletRequest request) { - return request.getParameter(artifactParameter); + return request.getParameter(this.artifactParameter); } /** * Overridden to provide proxying capabilities. */ - protected boolean requiresAuthentication(final HttpServletRequest request, - final HttpServletResponse response) { + @Override + protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response) { final boolean serviceTicketRequest = serviceTicketRequest(request, response); final boolean result = serviceTicketRequest || proxyReceptorRequest(request) || (proxyTicketRequest(serviceTicketRequest, request)); - if (logger.isDebugEnabled()) { - logger.debug("requiresAuthentication = " + result); + if (this.logger.isDebugEnabled()) { + this.logger.debug("requiresAuthentication = " + result); } return result; } @@ -297,8 +275,7 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * Sets the {@link AuthenticationFailureHandler} for proxy requests. * @param proxyFailureHandler */ - public final void setProxyAuthenticationFailureHandler( - AuthenticationFailureHandler proxyFailureHandler) { + public final void setProxyAuthenticationFailureHandler(AuthenticationFailureHandler proxyFailureHandler) { Assert.notNull(proxyFailureHandler, "proxyFailureHandler cannot be null"); this.proxyFailureHandler = proxyFailureHandler; } @@ -308,18 +285,15 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * proxy ticket authentication failures and service ticket failures. */ @Override - public final void setAuthenticationFailureHandler( - AuthenticationFailureHandler failureHandler) { - super.setAuthenticationFailureHandler(new CasAuthenticationFailureHandler( - failureHandler)); + public final void setAuthenticationFailureHandler(AuthenticationFailureHandler failureHandler) { + super.setAuthenticationFailureHandler(new CasAuthenticationFailureHandler(failureHandler)); } public final void setProxyReceptorUrl(final String proxyReceptorUrl) { this.proxyReceptorMatcher = new AntPathRequestMatcher("/**" + proxyReceptorUrl); } - public final void setProxyGrantingTicketStorage( - final ProxyGrantingTicketStorage proxyGrantingTicketStorage) { + public final void setProxyGrantingTicketStorage(final ProxyGrantingTicketStorage proxyGrantingTicketStorage) { this.proxyGrantingTicketStorage = proxyGrantingTicketStorage; } @@ -335,12 +309,9 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * @param response * @return */ - private boolean serviceTicketRequest(final HttpServletRequest request, - final HttpServletResponse response) { + private boolean serviceTicketRequest(HttpServletRequest request, HttpServletResponse response) { boolean result = super.requiresAuthentication(request, response); - if (logger.isDebugEnabled()) { - logger.debug("serviceTicketRequest = " + result); - } + this.logger.debug(LogMessage.format("serviceTicketRequest = %s", result)); return result; } @@ -349,16 +320,12 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * @param request * @return */ - private boolean proxyTicketRequest(final boolean serviceTicketRequest, - final HttpServletRequest request) { + private boolean proxyTicketRequest(boolean serviceTicketRequest, HttpServletRequest request) { if (serviceTicketRequest) { return false; } - final boolean result = authenticateAllArtifacts - && obtainArtifact(request) != null && !authenticated(); - if (logger.isDebugEnabled()) { - logger.debug("proxyTicketRequest = " + result); - } + boolean result = this.authenticateAllArtifacts && obtainArtifact(request) != null && !authenticated(); + this.logger.debug(LogMessage.format("proxyTicketRequest = %s", result)); return result; } @@ -367,8 +334,7 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * @return */ private boolean authenticated() { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); return authentication != null && authentication.isAuthenticated() && !(authentication instanceof AnonymousAuthenticationToken); } @@ -378,41 +344,33 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil * @param request * @return */ - private boolean proxyReceptorRequest(final HttpServletRequest request) { - final boolean result = proxyReceptorConfigured() - && proxyReceptorMatcher.matches(request); - if (logger.isDebugEnabled()) { - logger.debug("proxyReceptorRequest = " + result); - } + private boolean proxyReceptorRequest(HttpServletRequest request) { + final boolean result = proxyReceptorConfigured() && this.proxyReceptorMatcher.matches(request); + this.logger.debug(LogMessage.format("proxyReceptorRequest = %s", result)); return result; } /** * Determines if the {@link CasAuthenticationFilter} is configured to handle the proxy * receptor requests. - * * @return */ private boolean proxyReceptorConfigured() { - final boolean result = this.proxyGrantingTicketStorage != null - && proxyReceptorMatcher != null; - if (logger.isDebugEnabled()) { - logger.debug("proxyReceptorConfigured = " + result); - } + final boolean result = this.proxyGrantingTicketStorage != null && this.proxyReceptorMatcher != null; + this.logger.debug(LogMessage.format("proxyReceptorConfigured = %s", result)); return result; } /** * A wrapper for the AuthenticationFailureHandler that will flex the * {@link AuthenticationFailureHandler} that is used. The value - * {@link CasAuthenticationFilter#setProxyAuthenticationFailureHandler(AuthenticationFailureHandler) + * {@link CasAuthenticationFilter#setProxyAuthenticationFailureHandler(AuthenticationFailureHandler)} * will be used for proxy requests that fail. The value * {@link CasAuthenticationFilter#setAuthenticationFailureHandler(AuthenticationFailureHandler)} * will be used for service tickets that fail. - * - * @author Rob Winch */ private class CasAuthenticationFailureHandler implements AuthenticationFailureHandler { + private final AuthenticationFailureHandler serviceTicketFailureHandler; CasAuthenticationFailureHandler(AuthenticationFailureHandler failureHandler) { @@ -420,16 +378,17 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil this.serviceTicketFailureHandler = failureHandler; } - public void onAuthenticationFailure(HttpServletRequest request, - HttpServletResponse response, AuthenticationException exception) - throws IOException, ServletException { + @Override + public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, + AuthenticationException exception) throws IOException, ServletException { if (serviceTicketRequest(request, response)) { - serviceTicketFailureHandler.onAuthenticationFailure(request, response, - exception); + this.serviceTicketFailureHandler.onAuthenticationFailure(request, response, exception); } else { - proxyFailureHandler.onAuthenticationFailure(request, response, exception); + CasAuthenticationFilter.this.proxyFailureHandler.onAuthenticationFailure(request, response, exception); } } + } + } diff --git a/cas/src/main/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetails.java b/cas/src/main/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetails.java index cac9cd1643..2171df6cfc 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetails.java +++ b/cas/src/main/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetails.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.web.authentication; import java.net.MalformedURLException; @@ -34,16 +35,11 @@ import org.springframework.util.Assert; */ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails implements ServiceAuthenticationDetails { + private static final long serialVersionUID = 6192409090610517700L; - // ~ Instance fields - // ================================================================================================ - private final String serviceUrl; - // ~ Constructors - // =================================================================================================== - /** * Creates a new instance * @param request the current {@link HttpServletRequest} to obtain the @@ -52,33 +48,23 @@ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails * string from containing the artifact name and value. This can be created using * {@link #createArtifactPattern(String)}. */ - DefaultServiceAuthenticationDetails(String casService, HttpServletRequest request, - Pattern artifactPattern) throws MalformedURLException { + DefaultServiceAuthenticationDetails(String casService, HttpServletRequest request, Pattern artifactPattern) + throws MalformedURLException { super(request); URL casServiceUrl = new URL(casService); int port = getServicePort(casServiceUrl); final String query = getQueryString(request, artifactPattern); - this.serviceUrl = UrlUtils.buildFullRequestUrl(casServiceUrl.getProtocol(), - casServiceUrl.getHost(), port, request.getRequestURI(), query); + this.serviceUrl = UrlUtils.buildFullRequestUrl(casServiceUrl.getProtocol(), casServiceUrl.getHost(), port, + request.getRequestURI(), query); } - // ~ Methods - // ======================================================================================================== - /** * Returns the current URL minus the artifact parameter and its value, if present. * @see org.springframework.security.cas.web.authentication.ServiceAuthenticationDetails#getServiceUrl() */ - public String getServiceUrl() { - return serviceUrl; - } - @Override - public int hashCode() { - final int prime = 31; - int result = super.hashCode(); - result = prime * result + serviceUrl.hashCode(); - return result; + public String getServiceUrl() { + return this.serviceUrl; } @Override @@ -90,7 +76,15 @@ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails return false; } ServiceAuthenticationDetails that = (ServiceAuthenticationDetails) obj; - return serviceUrl.equals(that.getServiceUrl()); + return this.serviceUrl.equals(that.getServiceUrl()); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + this.serviceUrl.hashCode(); + return result; } @Override @@ -98,7 +92,7 @@ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails StringBuilder result = new StringBuilder(); result.append(super.toString()); result.append("ServiceUrl: "); - result.append(serviceUrl); + result.append(this.serviceUrl); return result.toString(); } @@ -109,13 +103,12 @@ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails * @return the query String minus the artifactParameterName and the corresponding * value. */ - private String getQueryString(final HttpServletRequest request, - final Pattern artifactPattern) { + private String getQueryString(final HttpServletRequest request, final Pattern artifactPattern) { final String query = request.getQueryString(); if (query == null) { return null; } - final String result = artifactPattern.matcher(query).replaceFirst(""); + String result = artifactPattern.matcher(query).replaceFirst(""); if (result.length() == 0) { return null; } @@ -127,7 +120,6 @@ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails * Creates a {@link Pattern} that can be passed into the constructor. This allows the * {@link Pattern} to be reused for every instance of * {@link DefaultServiceAuthenticationDetails}. - * * @param artifactParameterName * @return */ @@ -150,4 +142,5 @@ final class DefaultServiceAuthenticationDetails extends WebAuthenticationDetails } return port; } -} \ No newline at end of file + +} diff --git a/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetails.java b/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetails.java index 80f9e23803..e14da3d70e 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetails.java +++ b/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetails.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.web.authentication; import java.io.Serializable; @@ -28,15 +29,14 @@ import org.springframework.security.core.Authentication; * {@link ServiceProperties#getService()}. * * @author Rob Winch - * * @see ServiceAuthenticationDetailsSource */ public interface ServiceAuthenticationDetails extends Serializable { /** * Gets the absolute service url (i.e. https://example.com/service/). - * * @return the service url. Cannot be null. */ String getServiceUrl(); -} \ No newline at end of file + +} diff --git a/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetailsSource.java b/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetailsSource.java index c323bb491b..375952373f 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetailsSource.java +++ b/cas/src/main/java/org/springframework/security/cas/web/authentication/ServiceAuthenticationDetailsSource.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.web.authentication; import java.net.MalformedURLException; @@ -34,22 +35,16 @@ import org.springframework.util.Assert; * * @author Rob Winch */ -public class ServiceAuthenticationDetailsSource implements - AuthenticationDetailsSource { - // ~ Instance fields - // ================================================================================================ +public class ServiceAuthenticationDetailsSource + implements AuthenticationDetailsSource { private final Pattern artifactPattern; private ServiceProperties serviceProperties; - // ~ Constructors - // =================================================================================================== - /** * Creates an implementation that uses the specified ServiceProperties and the default * CAS artifactParameterName. - * * @param serviceProperties The ServiceProperties to use to construct the serviceUrl. */ public ServiceAuthenticationDetailsSource(ServiceProperties serviceProperties) { @@ -58,35 +53,31 @@ public class ServiceAuthenticationDetailsSource implements /** * Creates an implementation that uses the specified artifactParameterName - * * @param serviceProperties The ServiceProperties to use to construct the serviceUrl. * @param artifactParameterName the artifactParameterName that is removed from the * current URL. The result becomes the service url. Cannot be null and cannot be an * empty String. */ - public ServiceAuthenticationDetailsSource(ServiceProperties serviceProperties, - String artifactParameterName) { + public ServiceAuthenticationDetailsSource(ServiceProperties serviceProperties, String artifactParameterName) { Assert.notNull(serviceProperties, "serviceProperties cannot be null"); this.serviceProperties = serviceProperties; - this.artifactPattern = DefaultServiceAuthenticationDetails - .createArtifactPattern(artifactParameterName); + this.artifactPattern = DefaultServiceAuthenticationDetails.createArtifactPattern(artifactParameterName); } - // ~ Methods - // ======================================================================================================== - /** * @param context the {@code HttpServletRequest} object. * @return the {@code ServiceAuthenticationDetails} containing information about the * current request */ + @Override public ServiceAuthenticationDetails buildDetails(HttpServletRequest context) { try { - return new DefaultServiceAuthenticationDetails( - serviceProperties.getService(), context, artifactPattern); + return new DefaultServiceAuthenticationDetails(this.serviceProperties.getService(), context, + this.artifactPattern); } - catch (MalformedURLException e) { - throw new RuntimeException(e); + catch (MalformedURLException ex) { + throw new RuntimeException(ex); } } -} \ No newline at end of file + +} diff --git a/cas/src/main/java/org/springframework/security/cas/web/authentication/package-info.java b/cas/src/main/java/org/springframework/security/cas/web/authentication/package-info.java index b7be180138..ecd447dbac 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/authentication/package-info.java +++ b/cas/src/main/java/org/springframework/security/cas/web/authentication/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Authentication processing mechanisms which respond to the submission of authentication * credentials using CAS. */ package org.springframework.security.cas.web.authentication; - diff --git a/cas/src/main/java/org/springframework/security/cas/web/package-info.java b/cas/src/main/java/org/springframework/security/cas/web/package-info.java index 21781af725..903fdb8d4c 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/package-info.java +++ b/cas/src/main/java/org/springframework/security/cas/web/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Authenticates standard web browser users via CAS. */ package org.springframework.security.cas.web; - diff --git a/cas/src/test/java/org/springframework/security/cas/authentication/AbstractStatelessTicketCacheTests.java b/cas/src/test/java/org/springframework/security/cas/authentication/AbstractStatelessTicketCacheTests.java index ec009058a7..7f1233b7d5 100644 --- a/cas/src/test/java/org/springframework/security/cas/authentication/AbstractStatelessTicketCacheTests.java +++ b/cas/src/test/java/org/springframework/security/cas/authentication/AbstractStatelessTicketCacheTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.authentication; import java.util.ArrayList; @@ -20,12 +21,11 @@ import java.util.List; import org.jasig.cas.client.validation.Assertion; import org.jasig.cas.client.validation.AssertionImpl; -import org.springframework.security.cas.authentication.CasAuthenticationToken; + import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.User; /** - * * @author Scott Battaglia * @since 2.0 * @@ -35,14 +35,11 @@ public abstract class AbstractStatelessTicketCacheTests { protected CasAuthenticationToken getToken() { List proxyList = new ArrayList<>(); proxyList.add("https://localhost/newPortal/login/cas"); - User user = new User("rod", "password", true, true, true, true, AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); final Assertion assertion = new AssertionImpl("rod"); - return new CasAuthenticationToken("key", user, "ST-0-ER94xMJmn6pha35CQRoZ", - AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), user, - assertion); + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), user, assertion); } } diff --git a/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationProviderTests.java b/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationProviderTests.java index c39fdf8cfd..d5bef694f1 100644 --- a/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationProviderTests.java +++ b/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationProviderTests.java @@ -16,13 +16,14 @@ package org.springframework.security.cas.authentication; -import static org.mockito.Mockito.*; -import static org.assertj.core.api.Assertions.*; +import java.util.HashMap; +import java.util.Map; import org.jasig.cas.client.validation.Assertion; import org.jasig.cas.client.validation.AssertionImpl; import org.jasig.cas.client.validation.TicketValidator; -import org.junit.*; +import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -39,7 +40,13 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.web.authentication.WebAuthenticationDetails; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Tests {@link CasAuthenticationProvider}. @@ -49,8 +56,6 @@ import java.util.*; */ @SuppressWarnings("unchecked") public class CasAuthenticationProviderTests { - // ~ Methods - // ======================================================================================================== private UserDetails makeUserDetails() { return new User("user", "password", true, true, true, true, @@ -66,7 +71,6 @@ public class CasAuthenticationProviderTests { final ServiceProperties serviceProperties = new ServiceProperties(); serviceProperties.setSendRenew(false); serviceProperties.setService("http://test.com"); - return serviceProperties; } @@ -75,41 +79,30 @@ public class CasAuthenticationProviderTests { CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); - StatelessTicketCache cache = new MockStatelessTicketCache(); cap.setStatelessTicketCache(cache); cap.setServiceProperties(makeServiceProperties()); - cap.setTicketValidator(new MockTicketValidator(true)); cap.afterPropertiesSet(); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( CasAuthenticationFilter.CAS_STATEFUL_IDENTIFIER, "ST-123"); token.setDetails("details"); - Authentication result = cap.authenticate(token); - // Confirm ST-123 was NOT added to the cache assertThat(cache.getByTicketId("ST-456") == null).isTrue(); - if (!(result instanceof CasAuthenticationToken)) { fail("Should have returned a CasAuthenticationToken"); } - CasAuthenticationToken casResult = (CasAuthenticationToken) result; assertThat(casResult.getPrincipal()).isEqualTo(makeUserDetailsFromAuthoritiesPopulator()); assertThat(casResult.getCredentials()).isEqualTo("ST-123"); - assertThat(casResult.getAuthorities()).contains( - new SimpleGrantedAuthority("ROLE_A")); - assertThat(casResult.getAuthorities()).contains( - new SimpleGrantedAuthority("ROLE_B")); + assertThat(casResult.getAuthorities()).contains(new SimpleGrantedAuthority("ROLE_A")); + assertThat(casResult.getAuthorities()).contains(new SimpleGrantedAuthority("ROLE_B")); assertThat(casResult.getKeyHash()).isEqualTo(cap.getKey().hashCode()); assertThat(casResult.getDetails()).isEqualTo("details"); - // Now confirm the CasAuthenticationToken is automatically re-accepted. // To ensure TicketValidator not called again, set it to deliver an exception... cap.setTicketValidator(new MockTicketValidator(false)); - Authentication laterResult = cap.authenticate(result); assertThat(laterResult).isEqualTo(result); } @@ -119,34 +112,26 @@ public class CasAuthenticationProviderTests { CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); - StatelessTicketCache cache = new MockStatelessTicketCache(); cap.setStatelessTicketCache(cache); cap.setTicketValidator(new MockTicketValidator(true)); cap.setServiceProperties(makeServiceProperties()); cap.afterPropertiesSet(); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER, "ST-456"); token.setDetails("details"); - Authentication result = cap.authenticate(token); - // Confirm ST-456 was added to the cache assertThat(cache.getByTicketId("ST-456") != null).isTrue(); - if (!(result instanceof CasAuthenticationToken)) { fail("Should have returned a CasAuthenticationToken"); } - assertThat(result.getPrincipal()).isEqualTo(makeUserDetailsFromAuthoritiesPopulator()); assertThat(result.getCredentials()).isEqualTo("ST-456"); assertThat(result.getDetails()).isEqualTo("details"); - // Now try to authenticate again. To ensure TicketValidator not // called again, set it to deliver an exception... cap.setTicketValidator(new MockTicketValidator(false)); - // Previously created UsernamePasswordAuthenticationToken is OK Authentication newResult = cap.authenticate(token); assertThat(newResult.getPrincipal()).isEqualTo(makeUserDetailsFromAuthoritiesPopulator()); @@ -157,26 +142,20 @@ public class CasAuthenticationProviderTests { public void authenticateAllNullService() throws Exception { String serviceUrl = "https://service/context"; ServiceAuthenticationDetails details = mock(ServiceAuthenticationDetails.class); - when(details.getServiceUrl()).thenReturn(serviceUrl); + given(details.getServiceUrl()).willReturn(serviceUrl); TicketValidator validator = mock(TicketValidator.class); - when(validator.validate(any(String.class), any(String.class))).thenReturn( - new AssertionImpl("rod")); - + given(validator.validate(any(String.class), any(String.class))).willReturn(new AssertionImpl("rod")); ServiceProperties serviceProperties = makeServiceProperties(); serviceProperties.setAuthenticateAllArtifacts(true); - CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); - cap.setTicketValidator(validator); cap.setServiceProperties(serviceProperties); cap.afterPropertiesSet(); - String ticket = "ST-456"; UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER, ticket); - Authentication result = cap.authenticate(token); } @@ -184,44 +163,34 @@ public class CasAuthenticationProviderTests { public void authenticateAllAuthenticationIsSuccessful() throws Exception { String serviceUrl = "https://service/context"; ServiceAuthenticationDetails details = mock(ServiceAuthenticationDetails.class); - when(details.getServiceUrl()).thenReturn(serviceUrl); + given(details.getServiceUrl()).willReturn(serviceUrl); TicketValidator validator = mock(TicketValidator.class); - when(validator.validate(any(String.class), any(String.class))).thenReturn( - new AssertionImpl("rod")); - + given(validator.validate(any(String.class), any(String.class))).willReturn(new AssertionImpl("rod")); ServiceProperties serviceProperties = makeServiceProperties(); serviceProperties.setAuthenticateAllArtifacts(true); - CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); - cap.setTicketValidator(validator); cap.setServiceProperties(serviceProperties); cap.afterPropertiesSet(); - String ticket = "ST-456"; UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( CasAuthenticationFilter.CAS_STATELESS_IDENTIFIER, ticket); - Authentication result = cap.authenticate(token); verify(validator).validate(ticket, serviceProperties.getService()); - serviceProperties.setAuthenticateAllArtifacts(true); result = cap.authenticate(token); verify(validator, times(2)).validate(ticket, serviceProperties.getService()); - token.setDetails(details); result = cap.authenticate(token); verify(validator).validate(ticket, serviceUrl); - serviceProperties.setAuthenticateAllArtifacts(false); serviceProperties.setService(null); cap.setServiceProperties(serviceProperties); cap.afterPropertiesSet(); result = cap.authenticate(token); verify(validator, times(2)).validate(ticket, serviceUrl); - token.setDetails(new WebAuthenticationDetails(new MockHttpServletRequest())); try { cap.authenticate(token); @@ -229,7 +198,6 @@ public class CasAuthenticationProviderTests { } catch (IllegalStateException success) { } - cap.setServiceProperties(null); cap.afterPropertiesSet(); try { @@ -245,16 +213,13 @@ public class CasAuthenticationProviderTests { CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); - StatelessTicketCache cache = new MockStatelessTicketCache(); cap.setStatelessTicketCache(cache); cap.setTicketValidator(new MockTicketValidator(true)); cap.setServiceProperties(makeServiceProperties()); cap.afterPropertiesSet(); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( CasAuthenticationFilter.CAS_STATEFUL_IDENTIFIER, ""); - cap.authenticate(token); } @@ -264,17 +229,13 @@ public class CasAuthenticationProviderTests { CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); - StatelessTicketCache cache = new MockStatelessTicketCache(); cap.setStatelessTicketCache(cache); cap.setTicketValidator(new MockTicketValidator(true)); cap.setServiceProperties(makeServiceProperties()); cap.afterPropertiesSet(); - - CasAuthenticationToken token = new CasAuthenticationToken("WRONG_KEY", - makeUserDetails(), "credentials", + CasAuthenticationToken token = new CasAuthenticationToken("WRONG_KEY", makeUserDetails(), "credentials", AuthorityUtils.createAuthorityList("XX"), makeUserDetails(), assertion); - cap.authenticate(token); } @@ -329,7 +290,6 @@ public class CasAuthenticationProviderTests { cap.setTicketValidator(new MockTicketValidator(true)); cap.setServiceProperties(makeServiceProperties()); cap.afterPropertiesSet(); - // TODO disabled because why do we need to expose this? // assertThat(cap.getUserDetailsService() != null).isTrue(); assertThat(cap.getKey()).isEqualTo("qwerty"); @@ -346,18 +306,14 @@ public class CasAuthenticationProviderTests { cap.setTicketValidator(new MockTicketValidator(true)); cap.setServiceProperties(makeServiceProperties()); cap.afterPropertiesSet(); - - TestingAuthenticationToken token = new TestingAuthenticationToken("user", - "password", "ROLE_A"); + TestingAuthenticationToken token = new TestingAuthenticationToken("user", "password", "ROLE_A"); assertThat(cap.supports(TestingAuthenticationToken.class)).isFalse(); - // Try it anyway assertThat(cap.authenticate(token)).isNull(); } @Test - public void ignoresUsernamePasswordAuthenticationTokensWithoutCasIdentifiersAsPrincipal() - throws Exception { + public void ignoresUsernamePasswordAuthenticationTokensWithoutCasIdentifiersAsPrincipal() throws Exception { CasAuthenticationProvider cap = new CasAuthenticationProvider(); cap.setAuthenticationUserDetailsService(new MockAuthoritiesPopulator()); cap.setKey("qwerty"); @@ -365,10 +321,8 @@ public class CasAuthenticationProviderTests { cap.setTicketValidator(new MockTicketValidator(true)); cap.setServiceProperties(makeServiceProperties()); cap.afterPropertiesSet(); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "some_normal_user", "password", - AuthorityUtils.createAuthorityList("ROLE_A")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("some_normal_user", + "password", AuthorityUtils.createAuthorityList("ROLE_A")); assertThat(cap.authenticate(token)).isNull(); } @@ -379,49 +333,57 @@ public class CasAuthenticationProviderTests { assertThat(cap.supports(CasAuthenticationToken.class)).isTrue(); } - // ~ Inner Classes - // ================================================================================================== - private class MockAuthoritiesPopulator implements AuthenticationUserDetailsService { - public UserDetails loadUserDetails(final Authentication token) - throws UsernameNotFoundException { + @Override + public UserDetails loadUserDetails(final Authentication token) throws UsernameNotFoundException { return makeUserDetailsFromAuthoritiesPopulator(); } + } private class MockStatelessTicketCache implements StatelessTicketCache { + private Map cache = new HashMap<>(); + @Override public CasAuthenticationToken getByTicketId(String serviceTicket) { - return cache.get(serviceTicket); + return this.cache.get(serviceTicket); } + @Override public void putTicketInCache(CasAuthenticationToken token) { - cache.put(token.getCredentials().toString(), token); + this.cache.put(token.getCredentials().toString(), token); } + @Override public void removeTicketFromCache(CasAuthenticationToken token) { throw new UnsupportedOperationException("mock method not implemented"); } + @Override public void removeTicketFromCache(String serviceTicket) { throw new UnsupportedOperationException("mock method not implemented"); } + } private class MockTicketValidator implements TicketValidator { + private boolean returnTicket; MockTicketValidator(boolean returnTicket) { this.returnTicket = returnTicket; } + @Override public Assertion validate(final String ticket, final String service) { - if (returnTicket) { + if (this.returnTicket) { return new AssertionImpl("rod"); } throw new BadCredentialsException("As requested from mock"); } + } + } diff --git a/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationTokenTests.java b/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationTokenTests.java index 920b8d9c18..21278296c5 100644 --- a/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationTokenTests.java +++ b/cas/src/test/java/org/springframework/security/cas/authentication/CasAuthenticationTokenTests.java @@ -16,15 +16,13 @@ package org.springframework.security.cas.authentication; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.Collections; import java.util.List; import org.jasig.cas.client.validation.Assertion; import org.jasig.cas.client.validation.AssertionImpl; import org.junit.Test; + import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; @@ -32,6 +30,9 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link CasAuthenticationToken}. * @@ -39,64 +40,52 @@ import org.springframework.security.core.userdetails.UserDetails; */ public class CasAuthenticationTokenTests { - private final List ROLES = AuthorityUtils.createAuthorityList( - "ROLE_ONE", "ROLE_TWO"); + private final List ROLES = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); private UserDetails makeUserDetails() { return makeUserDetails("user"); } private UserDetails makeUserDetails(final String name) { - return new User(name, "password", true, true, true, true, ROLES); + return new User(name, "password", true, true, true, true, this.ROLES); } @Test public void testConstructorRejectsNulls() { final Assertion assertion = new AssertionImpl("test"); try { - new CasAuthenticationToken(null, makeUserDetails(), "Password", ROLES, - makeUserDetails(), assertion); + new CasAuthenticationToken(null, makeUserDetails(), "Password", this.ROLES, makeUserDetails(), assertion); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new CasAuthenticationToken("key", null, "Password", ROLES, makeUserDetails(), - assertion); + new CasAuthenticationToken("key", null, "Password", this.ROLES, makeUserDetails(), assertion); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new CasAuthenticationToken("key", makeUserDetails(), null, ROLES, - makeUserDetails(), assertion); + new CasAuthenticationToken("key", makeUserDetails(), null, this.ROLES, makeUserDetails(), assertion); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new CasAuthenticationToken("key", makeUserDetails(), "Password", ROLES, - makeUserDetails(), null); + new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, makeUserDetails(), null); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new CasAuthenticationToken("key", makeUserDetails(), "Password", ROLES, null, - assertion); + new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, null, assertion); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { new CasAuthenticationToken("key", makeUserDetails(), "Password", - AuthorityUtils.createAuthorityList("ROLE_1", null), makeUserDetails(), - assertion); + AuthorityUtils.createAuthorityList("ROLE_1", null), makeUserDetails(), assertion); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -112,13 +101,10 @@ public class CasAuthenticationTokenTests { @Test public void testEqualsWhenEqual() { final Assertion assertion = new AssertionImpl("test"); - - CasAuthenticationToken token1 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - - CasAuthenticationToken token2 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - + CasAuthenticationToken token1 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); + CasAuthenticationToken token2 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); assertThat(token2).isEqualTo(token1); } @@ -126,18 +112,15 @@ public class CasAuthenticationTokenTests { public void testGetters() { // Build the proxy list returned in the ticket from CAS final Assertion assertion = new AssertionImpl("test"); - CasAuthenticationToken token = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); + CasAuthenticationToken token = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); assertThat(token.getKeyHash()).isEqualTo("key".hashCode()); assertThat(token.getPrincipal()).isEqualTo(makeUserDetails()); assertThat(token.getCredentials()).isEqualTo("Password"); - assertThat(token.getAuthorities()).contains( - new SimpleGrantedAuthority("ROLE_ONE")); - assertThat(token.getAuthorities()).contains( - new SimpleGrantedAuthority("ROLE_TWO")); + assertThat(token.getAuthorities()).contains(new SimpleGrantedAuthority("ROLE_ONE")); + assertThat(token.getAuthorities()).contains(new SimpleGrantedAuthority("ROLE_TWO")); assertThat(token.getAssertion()).isEqualTo(assertion); - assertThat(token.getUserDetails().getUsername()).isEqualTo( - makeUserDetails().getUsername()); + assertThat(token.getUserDetails().getUsername()).isEqualTo(makeUserDetails().getUsername()); } @Test @@ -147,46 +130,36 @@ public class CasAuthenticationTokenTests { fail("Should have thrown NoSuchMethodException"); } catch (NoSuchMethodException expected) { - } } @Test public void testNotEqualsDueToAbstractParentEqualsCheck() { final Assertion assertion = new AssertionImpl("test"); - - CasAuthenticationToken token1 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - - CasAuthenticationToken token2 = new CasAuthenticationToken("key", - makeUserDetails("OTHER_NAME"), "Password", ROLES, makeUserDetails(), - assertion); - + CasAuthenticationToken token1 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); + CasAuthenticationToken token2 = new CasAuthenticationToken("key", makeUserDetails("OTHER_NAME"), "Password", + this.ROLES, makeUserDetails(), assertion); assertThat(!token1.equals(token2)).isTrue(); } @Test public void testNotEqualsDueToDifferentAuthenticationClass() { final Assertion assertion = new AssertionImpl("test"); - - CasAuthenticationToken token1 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - - UsernamePasswordAuthenticationToken token2 = new UsernamePasswordAuthenticationToken( - "Test", "Password", ROLES); + CasAuthenticationToken token1 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); + UsernamePasswordAuthenticationToken token2 = new UsernamePasswordAuthenticationToken("Test", "Password", + this.ROLES); assertThat(!token1.equals(token2)).isTrue(); } @Test public void testNotEqualsDueToKey() { final Assertion assertion = new AssertionImpl("test"); - - CasAuthenticationToken token1 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - - CasAuthenticationToken token2 = new CasAuthenticationToken("DIFFERENT_KEY", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - + CasAuthenticationToken token1 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); + CasAuthenticationToken token2 = new CasAuthenticationToken("DIFFERENT_KEY", makeUserDetails(), "Password", + this.ROLES, makeUserDetails(), assertion); assertThat(!token1.equals(token2)).isTrue(); } @@ -194,21 +167,18 @@ public class CasAuthenticationTokenTests { public void testNotEqualsDueToAssertion() { final Assertion assertion = new AssertionImpl("test"); final Assertion assertion2 = new AssertionImpl("test"); - - CasAuthenticationToken token1 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); - - CasAuthenticationToken token2 = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion2); - + CasAuthenticationToken token1 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); + CasAuthenticationToken token2 = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion2); assertThat(!token1.equals(token2)).isTrue(); } @Test public void testSetAuthenticated() { final Assertion assertion = new AssertionImpl("test"); - CasAuthenticationToken token = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); + CasAuthenticationToken token = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); assertThat(token.isAuthenticated()).isTrue(); token.setAuthenticated(false); assertThat(!token.isAuthenticated()).isTrue(); @@ -217,10 +187,10 @@ public class CasAuthenticationTokenTests { @Test public void testToString() { final Assertion assertion = new AssertionImpl("test"); - CasAuthenticationToken token = new CasAuthenticationToken("key", - makeUserDetails(), "Password", ROLES, makeUserDetails(), assertion); + CasAuthenticationToken token = new CasAuthenticationToken("key", makeUserDetails(), "Password", this.ROLES, + makeUserDetails(), assertion); String result = token.toString(); - assertThat( - result.lastIndexOf("Credentials (Service/Proxy Ticket):") != -1).isTrue(); + assertThat(result.lastIndexOf("Credentials (Service/Proxy Ticket):") != -1).isTrue(); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCacheTests.java b/cas/src/test/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCacheTests.java index d50013caed..513158a479 100644 --- a/cas/src/test/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCacheTests.java +++ b/cas/src/test/java/org/springframework/security/cas/authentication/EhCacheBasedTicketCacheTests.java @@ -16,17 +16,15 @@ package org.springframework.security.cas.authentication; -import net.sf.ehcache.Ehcache; -import net.sf.ehcache.CacheManager; import net.sf.ehcache.Cache; - -import org.junit.Test; -import org.junit.BeforeClass; +import net.sf.ehcache.CacheManager; +import net.sf.ehcache.Ehcache; import org.junit.AfterClass; -import org.springframework.security.cas.authentication.CasAuthenticationToken; -import org.springframework.security.cas.authentication.EhCacheBasedTicketCache; +import org.junit.BeforeClass; +import org.junit.Test; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** * Tests {@link EhCacheBasedTicketCache}. @@ -34,10 +32,9 @@ import static org.assertj.core.api.Assertions.*; * @author Ben Alex */ public class EhCacheBasedTicketCacheTests extends AbstractStatelessTicketCacheTests { + private static CacheManager cacheManager; - // ~ Methods - // ======================================================================================================== @BeforeClass public static void initCacheManaer() { cacheManager = CacheManager.create(); @@ -55,17 +52,13 @@ public class EhCacheBasedTicketCacheTests extends AbstractStatelessTicketCacheTe EhCacheBasedTicketCache cache = new EhCacheBasedTicketCache(); cache.setCache(cacheManager.getCache("castickets")); cache.afterPropertiesSet(); - final CasAuthenticationToken token = getToken(); - // Check it gets stored in the cache cache.putTicketInCache(token); assertThat(cache.getByTicketId("ST-0-ER94xMJmn6pha35CQRoZ")).isEqualTo(token); - // Check it gets removed from the cache cache.removeTicketFromCache(getToken()); assertThat(cache.getByTicketId("ST-0-ER94xMJmn6pha35CQRoZ")).isNull(); - // Check it doesn't return values for null or unknown service tickets assertThat(cache.getByTicketId(null)).isNull(); assertThat(cache.getByTicketId("UNKNOWN_SERVICE_TICKET")).isNull(); @@ -74,17 +67,15 @@ public class EhCacheBasedTicketCacheTests extends AbstractStatelessTicketCacheTe @Test public void testStartupDetectsMissingCache() throws Exception { EhCacheBasedTicketCache cache = new EhCacheBasedTicketCache(); - try { cache.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - Ehcache myCache = cacheManager.getCache("castickets"); cache.setCache(myCache); assertThat(cache.getCache()).isEqualTo(myCache); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/authentication/NullStatelessTicketCacheTests.java b/cas/src/test/java/org/springframework/security/cas/authentication/NullStatelessTicketCacheTests.java index 1643bb0394..c48644644b 100644 --- a/cas/src/test/java/org/springframework/security/cas/authentication/NullStatelessTicketCacheTests.java +++ b/cas/src/test/java/org/springframework/security/cas/authentication/NullStatelessTicketCacheTests.java @@ -13,14 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.authentication; import org.junit.Test; -import org.springframework.security.cas.authentication.CasAuthenticationToken; -import org.springframework.security.cas.authentication.NullStatelessTicketCache; -import org.springframework.security.cas.authentication.StatelessTicketCache; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * Test cases for the @link {@link NullStatelessTicketCache} @@ -34,14 +32,15 @@ public class NullStatelessTicketCacheTests extends AbstractStatelessTicketCacheT @Test public void testGetter() { - assertThat(cache.getByTicketId(null)).isNull(); - assertThat(cache.getByTicketId("test")).isNull(); + assertThat(this.cache.getByTicketId(null)).isNull(); + assertThat(this.cache.getByTicketId("test")).isNull(); } @Test public void testInsertAndGet() { final CasAuthenticationToken token = getToken(); - cache.putTicketInCache(token); - assertThat(cache.getByTicketId((String) token.getCredentials())).isNull(); + this.cache.putTicketInCache(token); + assertThat(this.cache.getByTicketId((String) token.getCredentials())).isNull(); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCacheTests.java b/cas/src/test/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCacheTests.java index 57ccf6136c..607ed39260 100644 --- a/cas/src/test/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCacheTests.java +++ b/cas/src/test/java/org/springframework/security/cas/authentication/SpringCacheBasedTicketCacheTests.java @@ -18,10 +18,11 @@ package org.springframework.security.cas.authentication; import org.junit.BeforeClass; import org.junit.Test; + import org.springframework.cache.CacheManager; import org.springframework.cache.concurrent.ConcurrentMapCacheManager; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * Tests @@ -31,10 +32,8 @@ import static org.assertj.core.api.Assertions.*; * @since 3.2 */ public class SpringCacheBasedTicketCacheTests extends AbstractStatelessTicketCacheTests { - private static CacheManager cacheManager; - // ~ Methods - // ======================================================================================================== + private static CacheManager cacheManager; @BeforeClass public static void initCacheManaer() { @@ -44,19 +43,14 @@ public class SpringCacheBasedTicketCacheTests extends AbstractStatelessTicketCac @Test public void testCacheOperation() throws Exception { - SpringCacheBasedTicketCache cache = new SpringCacheBasedTicketCache( - cacheManager.getCache("castickets")); - + SpringCacheBasedTicketCache cache = new SpringCacheBasedTicketCache(cacheManager.getCache("castickets")); final CasAuthenticationToken token = getToken(); - // Check it gets stored in the cache cache.putTicketInCache(token); assertThat(cache.getByTicketId("ST-0-ER94xMJmn6pha35CQRoZ")).isEqualTo(token); - // Check it gets removed from the cache cache.removeTicketFromCache(getToken()); assertThat(cache.getByTicketId("ST-0-ER94xMJmn6pha35CQRoZ")).isNull(); - // Check it doesn't return values for null or unknown service tickets assertThat(cache.getByTicketId(null)).isNull(); assertThat(cache.getByTicketId("UNKNOWN_SERVICE_TICKET")).isNull(); @@ -66,4 +60,5 @@ public class SpringCacheBasedTicketCacheTests extends AbstractStatelessTicketCac public void testStartupDetectsMissingCache() throws Exception { new SpringCacheBasedTicketCache(null); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixinTests.java b/cas/src/test/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixinTests.java index 933be2cf14..4eed00cbe9 100644 --- a/cas/src/test/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixinTests.java +++ b/cas/src/test/java/org/springframework/security/cas/jackson2/CasAuthenticationTokenMixinTests.java @@ -47,15 +47,20 @@ import static org.assertj.core.api.Assertions.assertThat; public class CasAuthenticationTokenMixinTests { private static final String KEY = "casKey"; + private static final String PASSWORD = "\"1234\""; + private static final Date START_DATE = new Date(); + private static final Date END_DATE = new Date(); public static final String AUTHORITY_JSON = "{\"@class\": \"org.springframework.security.core.authority.SimpleGrantedAuthority\", \"authority\": \"ROLE_USER\"}"; - public static final String AUTHORITIES_SET_JSON = "[\"java.util.Collections$UnmodifiableSet\", [" + AUTHORITY_JSON + "]]"; + public static final String AUTHORITIES_SET_JSON = "[\"java.util.Collections$UnmodifiableSet\", [" + AUTHORITY_JSON + + "]]"; - public static final String AUTHORITIES_ARRAYLIST_JSON = "[\"java.util.Collections$UnmodifiableRandomAccessList\", [" + AUTHORITY_JSON + "]]"; + public static final String AUTHORITIES_ARRAYLIST_JSON = "[\"java.util.Collections$UnmodifiableRandomAccessList\", [" + + AUTHORITY_JSON + "]]"; // @formatter:off public static final String USER_JSON = "{" @@ -69,31 +74,19 @@ public class CasAuthenticationTokenMixinTests { + "\"authorities\": " + AUTHORITIES_SET_JSON + "}"; // @formatter:on - private static final String CAS_TOKEN_JSON = "{" - + "\"@class\": \"org.springframework.security.cas.authentication.CasAuthenticationToken\", " - + "\"keyHash\": " + KEY.hashCode() + "," - + "\"principal\": " + USER_JSON + ", " - + "\"credentials\": " + PASSWORD + ", " - + "\"authorities\": " + AUTHORITIES_ARRAYLIST_JSON + "," - + "\"userDetails\": " + USER_JSON +"," - + "\"authenticated\": true, " - + "\"details\": null," - + "\"assertion\": {" - + "\"@class\": \"org.jasig.cas.client.validation.AssertionImpl\", " - + "\"principal\": {" - + "\"@class\": \"org.jasig.cas.client.authentication.AttributePrincipalImpl\", " - + "\"name\": \"assertName\", " - + "\"attributes\": {\"@class\": \"java.util.Collections$EmptyMap\"}, " - + "\"proxyGrantingTicket\": null, " - + "\"proxyRetriever\": null" - + "}, " + + "\"@class\": \"org.springframework.security.cas.authentication.CasAuthenticationToken\", " + + "\"keyHash\": " + KEY.hashCode() + "," + "\"principal\": " + USER_JSON + ", " + "\"credentials\": " + + PASSWORD + ", " + "\"authorities\": " + AUTHORITIES_ARRAYLIST_JSON + "," + "\"userDetails\": " + USER_JSON + + "," + "\"authenticated\": true, " + "\"details\": null," + "\"assertion\": {" + + "\"@class\": \"org.jasig.cas.client.validation.AssertionImpl\", " + "\"principal\": {" + + "\"@class\": \"org.jasig.cas.client.authentication.AttributePrincipalImpl\", " + + "\"name\": \"assertName\", " + "\"attributes\": {\"@class\": \"java.util.Collections$EmptyMap\"}, " + + "\"proxyGrantingTicket\": null, " + "\"proxyRetriever\": null" + "}, " + "\"validFromDate\": [\"java.util.Date\", " + START_DATE.getTime() + "], " + "\"validUntilDate\": [\"java.util.Date\", " + END_DATE.getTime() + "]," + "\"authenticationDate\": [\"java.util.Date\", " + START_DATE.getTime() + "], " - + "\"attributes\": {\"@class\": \"java.util.Collections$EmptyMap\"}" + - "}" - + "}"; + + "\"attributes\": {\"@class\": \"java.util.Collections$EmptyMap\"}" + "}" + "}"; private static final String CAS_TOKEN_CLEARED_JSON = CAS_TOKEN_JSON.replaceFirst(PASSWORD, "null"); @@ -101,35 +94,36 @@ public class CasAuthenticationTokenMixinTests { @Before public void setup() { - mapper = new ObjectMapper(); + this.mapper = new ObjectMapper(); ClassLoader loader = getClass().getClassLoader(); - mapper.registerModules(SecurityJackson2Modules.getModules(loader)); + this.mapper.registerModules(SecurityJackson2Modules.getModules(loader)); } @Test public void serializeCasAuthenticationTest() throws JsonProcessingException, JSONException { CasAuthenticationToken token = createCasAuthenticationToken(); - String actualJson = mapper.writeValueAsString(token); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(CAS_TOKEN_JSON, actualJson, true); } @Test - public void serializeCasAuthenticationTestAfterEraseCredentialInvoked() throws JsonProcessingException, JSONException { + public void serializeCasAuthenticationTestAfterEraseCredentialInvoked() + throws JsonProcessingException, JSONException { CasAuthenticationToken token = createCasAuthenticationToken(); token.eraseCredentials(); - String actualJson = mapper.writeValueAsString(token); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(CAS_TOKEN_CLEARED_JSON, actualJson, true); } @Test public void deserializeCasAuthenticationTestAfterEraseCredentialInvoked() throws Exception { - CasAuthenticationToken token = mapper.readValue(CAS_TOKEN_CLEARED_JSON, CasAuthenticationToken.class); + CasAuthenticationToken token = this.mapper.readValue(CAS_TOKEN_CLEARED_JSON, CasAuthenticationToken.class); assertThat(((UserDetails) token.getPrincipal()).getPassword()).isNull(); } @Test public void deserializeCasAuthenticationTest() throws IOException { - CasAuthenticationToken token = mapper.readValue(CAS_TOKEN_JSON, CasAuthenticationToken.class); + CasAuthenticationToken token = this.mapper.readValue(CAS_TOKEN_JSON, CasAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getPrincipal()).isNotNull().isInstanceOf(User.class); assertThat(((User) token.getPrincipal()).getUsername()).isEqualTo("admin"); @@ -137,9 +131,8 @@ public class CasAuthenticationTokenMixinTests { assertThat(token.getUserDetails()).isNotNull().isInstanceOf(User.class); assertThat(token.getAssertion()).isNotNull().isInstanceOf(AssertionImpl.class); assertThat(token.getKeyHash()).isEqualTo(KEY.hashCode()); - assertThat(token.getUserDetails().getAuthorities()) - .extracting(GrantedAuthority::getAuthority) - .containsOnly("ROLE_USER"); + assertThat(token.getUserDetails().getAuthorities()).extracting(GrantedAuthority::getAuthority) + .containsOnly("ROLE_USER"); assertThat(token.getAssertion().getAuthenticationDate()).isEqualTo(START_DATE); assertThat(token.getAssertion().getValidFromDate()).isEqualTo(START_DATE); assertThat(token.getAssertion().getValidUntilDate()).isEqualTo(END_DATE); @@ -149,9 +142,12 @@ public class CasAuthenticationTokenMixinTests { private CasAuthenticationToken createCasAuthenticationToken() { User principal = new User("admin", "1234", Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"))); - Collection authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER")); - Assertion assertion = new AssertionImpl(new AttributePrincipalImpl("assertName"), START_DATE, END_DATE, START_DATE, Collections.emptyMap()); + Collection authorities = Collections + .singletonList(new SimpleGrantedAuthority("ROLE_USER")); + Assertion assertion = new AssertionImpl(new AttributePrincipalImpl("assertName"), START_DATE, END_DATE, + START_DATE, Collections.emptyMap()); return new CasAuthenticationToken(KEY, principal, principal.getPassword(), authorities, new User("admin", "1234", authorities), assertion); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsServiceTests.java b/cas/src/test/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsServiceTests.java index b4cbad40cf..e33e8a441e 100644 --- a/cas/src/test/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsServiceTests.java +++ b/cas/src/test/java/org/springframework/security/cas/userdetails/GrantedAuthorityFromAssertionAttributesUserDetailsServiceTests.java @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.userdetails; -import static org.assertj.core.api.Assertions.*; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import org.jasig.cas.client.authentication.AttributePrincipal; -import org.jasig.cas.client.validation.Assertion; -import org.junit.Test; -import org.springframework.security.cas.authentication.CasAssertionAuthenticationToken; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.userdetails.UserDetails; - import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Set; +import org.jasig.cas.client.authentication.AttributePrincipal; +import org.jasig.cas.client.validation.Assertion; +import org.junit.Test; + +import org.springframework.security.cas.authentication.CasAssertionAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.userdetails.UserDetails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * @author Luke Taylor */ @@ -50,14 +51,13 @@ public class GrantedAuthorityFromAssertionAttributesUserDetailsServiceTests { attributes.put("c", "role_c"); attributes.put("d", null); attributes.put("someother", "unused"); - when(assertion.getPrincipal()).thenReturn(principal); - when(principal.getAttributes()).thenReturn(attributes); - when(principal.getName()).thenReturn("somebody"); - CasAssertionAuthenticationToken token = new CasAssertionAuthenticationToken( - assertion, "ticket"); + given(assertion.getPrincipal()).willReturn(principal); + given(principal.getAttributes()).willReturn(attributes); + given(principal.getName()).willReturn("somebody"); + CasAssertionAuthenticationToken token = new CasAssertionAuthenticationToken(assertion, "ticket"); UserDetails user = uds.loadUserDetails(token); Set roles = AuthorityUtils.authorityListToSet(user.getAuthorities()); - assertThat(roles).containsExactlyInAnyOrder( - "role_a1", "role_a2", "role_b", "role_c"); + assertThat(roles).containsExactlyInAnyOrder("role_a1", "role_a2", "role_b", "role_c"); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationEntryPointTests.java b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationEntryPointTests.java index 35d3f7fdc6..825542cb79 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationEntryPointTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationEntryPointTests.java @@ -16,16 +16,17 @@ package org.springframework.security.cas.web; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.net.URLEncoder; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.cas.ServiceProperties; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link CasAuthenticationEntryPoint}. * @@ -33,13 +34,10 @@ import org.springframework.security.cas.ServiceProperties; */ public class CasAuthenticationEntryPointTests { - // ~ Methods - // ======================================================================================================== @Test public void testDetectsMissingLoginFormUrl() throws Exception { CasAuthenticationEntryPoint ep = new CasAuthenticationEntryPoint(); ep.setServiceProperties(new ServiceProperties()); - try { ep.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); @@ -53,14 +51,12 @@ public class CasAuthenticationEntryPointTests { public void testDetectsMissingServiceProperties() throws Exception { CasAuthenticationEntryPoint ep = new CasAuthenticationEntryPoint(); ep.setLoginUrl("https://cas/login"); - try { ep.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - assertThat(expected.getMessage()).isEqualTo( - "serviceProperties must be specified"); + assertThat(expected.getMessage()).isEqualTo("serviceProperties must be specified"); } } @@ -69,7 +65,6 @@ public class CasAuthenticationEntryPointTests { CasAuthenticationEntryPoint ep = new CasAuthenticationEntryPoint(); ep.setLoginUrl("https://cas/login"); assertThat(ep.getLoginUrl()).isEqualTo("https://cas/login"); - ep.setServiceProperties(new ServiceProperties()); assertThat(ep.getServiceProperties() != null).isTrue(); } @@ -79,22 +74,17 @@ public class CasAuthenticationEntryPointTests { ServiceProperties sp = new ServiceProperties(); sp.setSendRenew(false); sp.setService("https://mycompany.com/bigWebApp/login/cas"); - CasAuthenticationEntryPoint ep = new CasAuthenticationEntryPoint(); ep.setLoginUrl("https://cas/login"); ep.setServiceProperties(sp); - MockHttpServletRequest request = new MockHttpServletRequest(); request.setRequestURI("/some_path"); - MockHttpServletResponse response = new MockHttpServletResponse(); - ep.afterPropertiesSet(); ep.commence(request, response, null); - - assertThat("https://cas/login?service=" + URLEncoder.encode( - "https://mycompany.com/bigWebApp/login/cas", "UTF-8")).isEqualTo( - response.getRedirectedUrl()); + assertThat( + "https://cas/login?service=" + URLEncoder.encode("https://mycompany.com/bigWebApp/login/cas", "UTF-8")) + .isEqualTo(response.getRedirectedUrl()); } @Test @@ -102,20 +92,17 @@ public class CasAuthenticationEntryPointTests { ServiceProperties sp = new ServiceProperties(); sp.setSendRenew(true); sp.setService("https://mycompany.com/bigWebApp/login/cas"); - CasAuthenticationEntryPoint ep = new CasAuthenticationEntryPoint(); ep.setLoginUrl("https://cas/login"); ep.setServiceProperties(sp); - MockHttpServletRequest request = new MockHttpServletRequest(); request.setRequestURI("/some_path"); - MockHttpServletResponse response = new MockHttpServletResponse(); - ep.afterPropertiesSet(); ep.commence(request, response, null); assertThat("https://cas/login?service=" - + URLEncoder.encode("https://mycompany.com/bigWebApp/login/cas", "UTF-8") - + "&renew=true").isEqualTo(response.getRedirectedUrl()); + + URLEncoder.encode("https://mycompany.com/bigWebApp/login/cas", "UTF-8") + "&renew=true") + .isEqualTo(response.getRedirectedUrl()); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java index 700e24240f..e704ecd3ff 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java @@ -16,9 +16,12 @@ package org.springframework.security.cas.web; +import javax.servlet.FilterChain; + import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage; import org.junit.After; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -32,10 +35,13 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; -import javax.servlet.FilterChain; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests {@link CasAuthenticationFilter}. @@ -44,8 +50,6 @@ import static org.mockito.Mockito.*; * @author Rob Winch */ public class CasAuthenticationFilterTests { - // ~ Methods - // ======================================================================================================== @After public void tearDown() { @@ -65,26 +69,20 @@ public class CasAuthenticationFilterTests { MockHttpServletRequest request = new MockHttpServletRequest(); request.setServletPath("/login/cas"); request.addParameter("ticket", "ST-0-ER94xMJmn6pha35CQRoZ"); - CasAuthenticationFilter filter = new CasAuthenticationFilter(); - filter.setAuthenticationManager(a -> a); - + filter.setAuthenticationManager((a) -> a); assertThat(filter.requiresAuthentication(request, new MockHttpServletResponse())).isTrue(); - - Authentication result = filter.attemptAuthentication(request, - new MockHttpServletResponse()); + Authentication result = filter.attemptAuthentication(request, new MockHttpServletResponse()); assertThat(result != null).isTrue(); } @Test(expected = AuthenticationException.class) public void testNullServiceTicketHandledGracefully() throws Exception { CasAuthenticationFilter filter = new CasAuthenticationFilter(); - filter.setAuthenticationManager(a -> { + filter.setAuthenticationManager((a) -> { throw new BadCredentialsException("Rejected"); }); - - filter.attemptAuthentication(new MockHttpServletRequest(), - new MockHttpServletResponse()); + filter.attemptAuthentication(new MockHttpServletRequest(), new MockHttpServletResponse()); } @Test @@ -94,7 +92,6 @@ public class CasAuthenticationFilterTests { filter.setFilterProcessesUrl(url); MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); } @@ -104,7 +101,6 @@ public class CasAuthenticationFilterTests { CasAuthenticationFilter filter = new CasAuthenticationFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); filter.setProxyReceptorUrl(request.getServletPath()); @@ -119,30 +115,25 @@ public class CasAuthenticationFilterTests { public void testRequiresAuthenticationAuthAll() { ServiceProperties properties = new ServiceProperties(); properties.setAuthenticateAllArtifacts(true); - String url = "/login/cas"; CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setFilterProcessesUrl(url); filter.setServiceProperties(properties); MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath(url); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - request.setServletPath("/other"); assertThat(filter.requiresAuthentication(request, response)).isFalse(); request.setParameter(properties.getArtifactParameter(), "value"); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - SecurityContextHolder.getContext().setAuthentication( - new AnonymousAuthenticationToken("key", "principal", AuthorityUtils - .createAuthorityList("ROLE_ANONYMOUS"))); + SecurityContextHolder.getContext().setAuthentication(new AnonymousAuthenticationToken("key", "principal", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("un", "principal")); + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("un", "principal")); assertThat(filter.requiresAuthentication(request, response)).isTrue(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("un", "principal", "ROLE_ANONYMOUS")); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("un", "principal", "ROLE_ANONYMOUS")); assertThat(filter.requiresAuthentication(request, response)).isFalse(); } @@ -151,7 +142,6 @@ public class CasAuthenticationFilterTests { CasAuthenticationFilter filter = new CasAuthenticationFilter(); MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - request.setServletPath("/pgtCallback"); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyReceptorUrl(request.getServletPath()); @@ -162,9 +152,8 @@ public class CasAuthenticationFilterTests { public void testDoFilterAuthenticateAll() throws Exception { AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class); AuthenticationManager manager = mock(AuthenticationManager.class); - Authentication authentication = new TestingAuthenticationToken("un", "pwd", - "ROLE_USER"); - when(manager.authenticate(any(Authentication.class))).thenReturn(authentication); + Authentication authentication = new TestingAuthenticationToken("un", "pwd", "ROLE_USER"); + given(manager.authenticate(any(Authentication.class))).willReturn(authentication); ServiceProperties serviceProperties = new ServiceProperties(); serviceProperties.setAuthenticateAllArtifacts(true); MockHttpServletRequest request = new MockHttpServletRequest(); @@ -172,20 +161,17 @@ public class CasAuthenticationFilterTests { request.setServletPath("/authenticate"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - CasAuthenticationFilter filter = new CasAuthenticationFilter(); filter.setServiceProperties(serviceProperties); filter.setAuthenticationSuccessHandler(successHandler); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setAuthenticationManager(manager); filter.afterPropertiesSet(); - filter.doFilter(request, response, chain); - assertThat(SecurityContextHolder - .getContext().getAuthentication()).isNotNull().withFailMessage("Authentication should not be null"); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull() + .withFailMessage("Authentication should not be null"); verify(chain).doFilter(request, response); verifyZeroInteractions(successHandler); - // validate for when the filterProcessUrl matches filter.setFilterProcessesUrl(request.getServletPath()); SecurityContextHolder.clearContext(); @@ -201,12 +187,11 @@ public class CasAuthenticationFilterTests { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - request.setServletPath("/pgtCallback"); filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); filter.setProxyReceptorUrl(request.getServletPath()); - filter.doFilter(request, response, chain); verifyZeroInteractions(chain); } + } diff --git a/cas/src/test/java/org/springframework/security/cas/web/ServicePropertiesTests.java b/cas/src/test/java/org/springframework/security/cas/web/ServicePropertiesTests.java index b9ddcf1942..cc61ac93be 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/ServicePropertiesTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/ServicePropertiesTests.java @@ -16,20 +16,20 @@ package org.springframework.security.cas.web; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; + import org.springframework.security.cas.SamlServiceProperties; import org.springframework.security.cas.ServiceProperties; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link ServiceProperties}. * * @author Ben Alex */ public class ServicePropertiesTests { - // ~ Methods - // ======================================================================================================== @Test(expected = IllegalArgumentException.class) public void detectsMissingService() throws Exception { @@ -68,11 +68,10 @@ public class ServicePropertiesTests { assertThat(sp.getArtifactParameter()).isEqualTo("notticket"); sp.setServiceParameter("notservice"); assertThat(sp.getServiceParameter()).isEqualTo("notservice"); - sp.setService("https://mycompany.com/service"); assertThat(sp.getService()).isEqualTo("https://mycompany.com/service"); - sp.afterPropertiesSet(); } } + } diff --git a/cas/src/test/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetailsTests.java b/cas/src/test/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetailsTests.java index 3343a2bbb2..d7d95ea3a7 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetailsTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/authentication/DefaultServiceAuthenticationDetailsTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.cas.web.authentication; import java.util.regex.Pattern; @@ -30,13 +31,16 @@ import org.springframework.security.web.util.UrlUtils; import static org.assertj.core.api.Assertions.assertThat; /** - * * @author Rob Winch */ public class DefaultServiceAuthenticationDetailsTests { + private DefaultServiceAuthenticationDetails details; + private MockHttpServletRequest request; + private Pattern artifactPattern; + private String casServiceUrl; private ConfigurableApplicationContext context; @@ -51,7 +55,6 @@ public class DefaultServiceAuthenticationDetailsTests { this.request.setRequestURI("/cas-sample/secure/"); this.artifactPattern = DefaultServiceAuthenticationDetails .createArtifactPattern(ServiceProperties.DEFAULT_CAS_ARTIFACT_PARAMETER); - } @After @@ -63,17 +66,14 @@ public class DefaultServiceAuthenticationDetailsTests { @Test public void getServiceUrlNullQuery() throws Exception { - this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, - this.request, this.artifactPattern); - assertThat(this.details.getServiceUrl()) - .isEqualTo(UrlUtils.buildFullRequestUrl(this.request)); + this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, this.request, this.artifactPattern); + assertThat(this.details.getServiceUrl()).isEqualTo(UrlUtils.buildFullRequestUrl(this.request)); } @Test public void getServiceUrlTicketOnlyParam() throws Exception { this.request.setQueryString("ticket=123"); - this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, - this.request, this.artifactPattern); + this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, this.request, this.artifactPattern); String serviceUrl = this.details.getServiceUrl(); this.request.setQueryString(null); assertThat(serviceUrl).isEqualTo(UrlUtils.buildFullRequestUrl(this.request)); @@ -82,8 +82,7 @@ public class DefaultServiceAuthenticationDetailsTests { @Test public void getServiceUrlTicketFirstMultiParam() throws Exception { this.request.setQueryString("ticket=123&other=value"); - this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, - this.request, this.artifactPattern); + this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, this.request, this.artifactPattern); String serviceUrl = this.details.getServiceUrl(); this.request.setQueryString("other=value"); assertThat(serviceUrl).isEqualTo(UrlUtils.buildFullRequestUrl(this.request)); @@ -92,8 +91,7 @@ public class DefaultServiceAuthenticationDetailsTests { @Test public void getServiceUrlTicketLastMultiParam() throws Exception { this.request.setQueryString("other=value&ticket=123"); - this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, - this.request, this.artifactPattern); + this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, this.request, this.artifactPattern); String serviceUrl = this.details.getServiceUrl(); this.request.setQueryString("other=value"); assertThat(serviceUrl).isEqualTo(UrlUtils.buildFullRequestUrl(this.request)); @@ -102,8 +100,7 @@ public class DefaultServiceAuthenticationDetailsTests { @Test public void getServiceUrlTicketMiddleMultiParam() throws Exception { this.request.setQueryString("other=value&ticket=123&last=this"); - this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, - this.request, this.artifactPattern); + this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, this.request, this.artifactPattern); String serviceUrl = this.details.getServiceUrl(); this.request.setQueryString("other=value&last=this"); assertThat(serviceUrl).isEqualTo(UrlUtils.buildFullRequestUrl(this.request)); @@ -113,10 +110,8 @@ public class DefaultServiceAuthenticationDetailsTests { public void getServiceUrlDoesNotUseHostHeader() throws Exception { this.casServiceUrl = "https://example.com/j_spring_security_cas"; this.request.setServerName("evil.com"); - this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, - this.request, this.artifactPattern); - assertThat(this.details.getServiceUrl()) - .isEqualTo("https://example.com/cas-sample/secure/"); + this.details = new DefaultServiceAuthenticationDetails(this.casServiceUrl, this.request, this.artifactPattern); + assertThat(this.details.getServiceUrl()).isEqualTo("https://example.com/cas-sample/secure/"); } @Test @@ -125,15 +120,13 @@ public class DefaultServiceAuthenticationDetailsTests { this.request.setServerName("evil.com"); ServiceAuthenticationDetails details = loadServiceAuthenticationDetails( "defaultserviceauthenticationdetails-explicit.xml"); - assertThat(details.getServiceUrl()) - .isEqualTo("https://example.com/cas-sample/secure/"); + assertThat(details.getServiceUrl()).isEqualTo("https://example.com/cas-sample/secure/"); } - private ServiceAuthenticationDetails loadServiceAuthenticationDetails( - String resourceName) { + private ServiceAuthenticationDetails loadServiceAuthenticationDetails(String resourceName) { this.context = new GenericXmlApplicationContext(getClass(), resourceName); - ServiceAuthenticationDetailsSource source = this.context - .getBean(ServiceAuthenticationDetailsSource.class); + ServiceAuthenticationDetailsSource source = this.context.getBean(ServiceAuthenticationDetailsSource.class); return source.buildDetails(this.request); } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderBuilderSecurityBuilderTests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderBuilderSecurityBuilderTests.java index 9ff0675bc2..36d1d636ee 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderBuilderSecurityBuilderTests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderBuilderSecurityBuilderTests.java @@ -16,8 +16,16 @@ package org.springframework.security.config.annotation.authentication.ldap; +import java.io.IOException; +import java.net.ServerSocket; +import java.util.Collections; +import java.util.List; + +import javax.naming.directory.SearchControls; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -39,16 +47,14 @@ import org.springframework.security.ldap.userdetails.LdapAuthoritiesPopulator; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; -import java.io.IOException; -import java.net.ServerSocket; -import java.util.List; -import javax.naming.directory.SearchControls; -import static java.util.Collections.singleton; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; public class LdapAuthenticationProviderBuilderSecurityBuilderTests { + + static Integer port; + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -72,37 +78,13 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { assertThat(ReflectionTestUtils.getField(getAuthoritiesMapper(provider), "prefix")).isEqualTo("ROLE_"); } - @EnableWebSecurity - static class DefaultLdapConfig extends BaseLdapProviderConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .userDnPatterns("uid={0},ou=people"); - } - // @formatter:on - } - @Test public void groupRolesCustom() { this.spring.register(GroupRolesConfig.class).autowire(); LdapAuthenticationProvider provider = ldapProvider(); - assertThat(ReflectionTestUtils.getField(getAuthoritiesPopulator(provider), "groupRoleAttribute")).isEqualTo("group"); - } - - @EnableWebSecurity - static class GroupRolesConfig extends BaseLdapProviderConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .userDnPatterns("uid={0},ou=people") - .groupRoleAttribute("group"); - } - // @formatter:on + assertThat(ReflectionTestUtils.getField(getAuthoritiesPopulator(provider), "groupRoleAttribute")) + .isEqualTo("group"); } @Test @@ -110,20 +92,8 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { this.spring.register(GroupSearchConfig.class).autowire(); LdapAuthenticationProvider provider = ldapProvider(); - assertThat(ReflectionTestUtils.getField(getAuthoritiesPopulator(provider), "groupSearchFilter")).isEqualTo("ou=groupName"); - } - - @EnableWebSecurity - static class GroupSearchConfig extends BaseLdapProviderConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .userDnPatterns("uid={0},ou=people") - .groupSearchFilter("ou=groupName"); - } - // @formatter:on + assertThat(ReflectionTestUtils.getField(getAuthoritiesPopulator(provider), "groupSearchFilter")) + .isEqualTo("ou=groupName"); } @Test @@ -135,20 +105,6 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { .extracting("searchScope").isEqualTo(SearchControls.SUBTREE_SCOPE); } - @EnableWebSecurity - static class GroupSubtreeSearchConfig extends BaseLdapProviderConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .userDnPatterns("uid={0},ou=people") - .groupSearchFilter("ou=groupName") - .groupSearchSubtree(true); - } - // @formatter:on - } - @Test public void rolePrefixCustom() { this.spring.register(RolePrefixConfig.class).autowire(); @@ -157,39 +113,13 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { assertThat(ReflectionTestUtils.getField(getAuthoritiesMapper(provider), "prefix")).isEqualTo("role_"); } - @EnableWebSecurity - static class RolePrefixConfig extends BaseLdapProviderConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .userDnPatterns("uid={0},ou=people") - .rolePrefix("role_"); - } - // @formatter:on - } - @Test public void bindAuthentication() throws Exception { this.spring.register(BindAuthenticationConfig.class).autowire(); this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withUsername("bob").withAuthorities(singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); - } - - @EnableWebSecurity - static class BindAuthenticationConfig extends BaseLdapServerConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .groupSearchBase("ou=groups") - .groupSearchFilter("(member={0})") - .userDnPatterns("uid={0},ou=people"); - } - // @formatter:on + .andExpect(authenticated().withUsername("bob") + .withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); } // SEC-2472 @@ -198,26 +128,13 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { this.spring.register(PasswordEncoderConfig.class).autowire(); this.mockMvc.perform(formLogin().user("bcrypt").password("password")) - .andExpect(authenticated().withUsername("bcrypt").withAuthorities(singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); - } - - @EnableWebSecurity - static class PasswordEncoderConfig extends BaseLdapServerConfig { - // @formatter:off - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .ldapAuthentication() - .contextSource(contextSource()) - .passwordEncoder(new BCryptPasswordEncoder()) - .groupSearchBase("ou=groups") - .groupSearchFilter("(member={0})") - .userDnPatterns("uid={0},ou=people"); - } - // @formatter:on + .andExpect(authenticated().withUsername("bcrypt") + .withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); } private LdapAuthenticationProvider ldapProvider() { - return ((List) ReflectionTestUtils.getField(authenticationManager, "providers")).get(0); + return ((List) ReflectionTestUtils.getField(this.authenticationManager, + "providers")).get(0); } private LdapAuthoritiesPopulator getAuthoritiesPopulator(LdapAuthenticationProvider provider) { @@ -228,23 +145,150 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { return (GrantedAuthoritiesMapper) ReflectionTestUtils.getField(provider, "authoritiesMapper"); } + static int getPort() throws IOException { + if (port == null) { + ServerSocket socket = new ServerSocket(0); + port = socket.getLocalPort(); + socket.close(); + } + return port; + } + @EnableWebSecurity - static abstract class BaseLdapServerConfig extends BaseLdapProviderConfig { + static class DefaultLdapConfig extends BaseLdapProviderConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .userDnPatterns("uid={0},ou=people"); + // @formatter:on + } + + } + + @EnableWebSecurity + static class GroupRolesConfig extends BaseLdapProviderConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .userDnPatterns("uid={0},ou=people") + .groupRoleAttribute("group"); + // @formatter:on + } + + } + + @EnableWebSecurity + static class GroupSearchConfig extends BaseLdapProviderConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .userDnPatterns("uid={0},ou=people") + .groupSearchFilter("ou=groupName"); + // @formatter:on + } + + } + + @EnableWebSecurity + static class GroupSubtreeSearchConfig extends BaseLdapProviderConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .userDnPatterns("uid={0},ou=people") + .groupSearchFilter("ou=groupName") + .groupSearchSubtree(true); + // @formatter:on + } + + } + + @EnableWebSecurity + static class RolePrefixConfig extends BaseLdapProviderConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .userDnPatterns("uid={0},ou=people") + .rolePrefix("role_"); + // @formatter:on + } + + } + + @EnableWebSecurity + static class BindAuthenticationConfig extends BaseLdapServerConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .groupSearchBase("ou=groups") + .groupSearchFilter("(member={0})") + .userDnPatterns("uid={0},ou=people"); + // @formatter:on + } + + } + + @EnableWebSecurity + static class PasswordEncoderConfig extends BaseLdapServerConfig { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .ldapAuthentication() + .contextSource(contextSource()) + .passwordEncoder(new BCryptPasswordEncoder()) + .groupSearchBase("ou=groups") + .groupSearchFilter("(member={0})") + .userDnPatterns("uid={0},ou=people"); + // @formatter:on + } + + } + + @EnableWebSecurity + abstract static class BaseLdapServerConfig extends BaseLdapProviderConfig { + @Bean - public ApacheDSContainer ldapServer() throws Exception { - ApacheDSContainer apacheDSContainer = new ApacheDSContainer("dc=springframework,dc=org", "classpath:/test-server.ldif"); + ApacheDSContainer ldapServer() throws Exception { + ApacheDSContainer apacheDSContainer = new ApacheDSContainer("dc=springframework,dc=org", + "classpath:/test-server.ldif"); apacheDSContainer.setPort(getPort()); return apacheDSContainer; } + } @EnableWebSecurity @EnableGlobalAuthentication @Import(ObjectPostProcessorConfiguration.class) - static abstract class BaseLdapProviderConfig extends WebSecurityConfigurerAdapter { + abstract static class BaseLdapProviderConfig extends WebSecurityConfigurerAdapter { @Bean - public BaseLdapPathContextSource contextSource() throws Exception { + BaseLdapPathContextSource contextSource() throws Exception { DefaultSpringSecurityContextSource contextSource = new DefaultSpringSecurityContextSource( "ldap://127.0.0.1:" + getPort() + "/dc=springframework,dc=org"); contextSource.setUserDn("uid=admin,ou=system"); @@ -254,22 +298,14 @@ public class LdapAuthenticationProviderBuilderSecurityBuilderTests { } @Bean - public AuthenticationManager authenticationManager(AuthenticationManagerBuilder auth) throws Exception { + AuthenticationManager authenticationManager(AuthenticationManagerBuilder auth) throws Exception { configure(auth); return auth.build(); } - abstract protected void configure(AuthenticationManagerBuilder auth) throws Exception; + @Override + protected abstract void configure(AuthenticationManagerBuilder auth) throws Exception; + } - static Integer port; - - static int getPort() throws IOException { - if (port == null) { - ServerSocket socket = new ServerSocket(0); - port = socket.getLocalPort(); - socket.close(); - } - return port; - } } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderConfigurerTests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderConfigurerTests.java index 3cf29d6338..ea34da288e 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderConfigurerTests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/LdapAuthenticationProviderConfigurerTests.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.authentication.ldap; +import java.util.Collections; + import org.junit.Rule; import org.junit.Test; @@ -27,13 +29,15 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.test.web.servlet.MockMvc; -import static java.util.Collections.singleton; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; public class LdapAuthenticationProviderConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -41,7 +45,8 @@ public class LdapAuthenticationProviderConfigurerTests { private MockMvc mockMvc; @Test - public void authenticationManagerSupportMultipleDefaultLdapContextsWithPortsDynamicallyAllocated() throws Exception { + public void authenticationManagerSupportMultipleDefaultLdapContextsWithPortsDynamicallyAllocated() + throws Exception { this.spring.register(MultiLdapAuthenticationProvidersConfig.class).autowire(); this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) @@ -52,39 +57,68 @@ public class LdapAuthenticationProviderConfigurerTests { public void authenticationManagerSupportMultipleLdapContextWithDefaultRolePrefix() throws Exception { this.spring.register(MultiLdapAuthenticationProvidersConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withUsername("bob").withAuthorities(singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS")))); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bob") + .password("bobspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher expectedUser = authenticated() + .withUsername("bob") + .withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_DEVELOPERS"))); + // @formatter:on + this.mockMvc.perform(request).andExpect(expectedUser); } @Test public void authenticationManagerSupportMultipleLdapContextWithCustomRolePrefix() throws Exception { this.spring.register(MultiLdapWithCustomRolePrefixAuthenticationProvidersConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withUsername("bob").withAuthorities(singleton(new SimpleGrantedAuthority("ROL_DEVELOPERS")))); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bob") + .password("bobspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher expectedUser = authenticated() + .withUsername("bob") + .withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROL_DEVELOPERS"))); + // @formatter:on + this.mockMvc.perform(request).andExpect(expectedUser); } @Test public void authenticationManagerWhenPortZeroThenAuthenticates() throws Exception { this.spring.register(LdapWithRandomPortConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withUsername("bob")); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bob") + .password("bobspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher expectedUser = authenticated() + .withUsername("bob"); + // @formatter:on + this.mockMvc.perform(request).andExpect(expectedUser); } @Test public void authenticationManagerWhenSearchSubtreeThenNestedGroupFound() throws Exception { this.spring.register(GroupSubtreeSearchConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("ben").password("benspassword")) - .andExpect(authenticated().withUsername("ben").withAuthorities( - AuthorityUtils.createAuthorityList("ROLE_SUBMANAGERS", "ROLE_MANAGERS", "ROLE_DEVELOPERS"))); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("ben") + .password("benspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher expectedUser = authenticated() + .withUsername("ben") + .withAuthorities( + AuthorityUtils.createAuthorityList("ROLE_SUBMANAGERS", "ROLE_MANAGERS", "ROLE_DEVELOPERS")); + // @formatter:on + this.mockMvc.perform(request).andExpect(expectedUser); } @EnableWebSecurity static class MultiLdapAuthenticationProvidersConfig extends WebSecurityConfigurerAdapter { - // @formatter:off + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupSearchBase("ou=groups") @@ -95,14 +129,17 @@ public class LdapAuthenticationProviderConfigurerTests { .groupSearchBase("ou=groups") .groupSearchFilter("(member={0})") .userDnPatterns("uid={0},ou=people"); + // @formatter:on } - // @formatter:on + } @EnableWebSecurity static class MultiLdapWithCustomRolePrefixAuthenticationProvidersConfig extends WebSecurityConfigurerAdapter { - // @formatter:off + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupSearchBase("ou=groups") @@ -115,14 +152,17 @@ public class LdapAuthenticationProviderConfigurerTests { .groupSearchFilter("(member={0})") .userDnPatterns("uid={0},ou=people") .rolePrefix("RUOLO_"); + // @formatter:on } - // @formatter:on + } @EnableWebSecurity static class LdapWithRandomPortConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupSearchBase("ou=groups") @@ -130,20 +170,26 @@ public class LdapAuthenticationProviderConfigurerTests { .userDnPatterns("uid={0},ou=people") .contextSource() .port(0); + // @formatter:on } + } @EnableWebSecurity static class GroupSubtreeSearchConfig extends BaseLdapProviderConfig { - // @formatter:off + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupSearchBase("ou=groups") .groupSearchFilter("(member={0})") .groupSearchSubtree(true) .userDnPatterns("uid={0},ou=people"); + // @formatter:on } - // @formatter:on + } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTests.java index 527d0f0462..3369a4ed8a 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTests.java @@ -16,8 +16,13 @@ package org.springframework.security.config.annotation.authentication.ldap; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.ldap.core.DirContextOperations; import org.springframework.ldap.core.support.LdapContextSource; @@ -31,13 +36,11 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.ldap.DefaultSpringSecurityContextSource; import org.springframework.security.ldap.userdetails.DefaultLdapAuthoritiesPopulator; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.security.web.FilterChainProxy; import org.springframework.test.web.servlet.MockMvc; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; @@ -56,22 +59,35 @@ public class NamespaceLdapAuthenticationProviderTests { public void ldapAuthenticationProvider() throws Exception { this.spring.register(LdapAuthenticationProviderConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withUsername("bob")); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bob") + .password("bobspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher user = authenticated() + .withUsername("bob"); + // @formatter:on + this.mockMvc.perform(request).andExpect(user); } @Test public void ldapAuthenticationProviderCustom() throws Exception { this.spring.register(CustomLdapAuthenticationProviderConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withAuthorities(Collections.singleton(new SimpleGrantedAuthority("PREFIX_DEVELOPERS")))); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bob") + .password("bobspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher user = authenticated() + .withAuthorities(Collections.singleton(new SimpleGrantedAuthority("PREFIX_DEVELOPERS"))); + // @formatter:on + this.mockMvc.perform(request).andExpect(user); } // SEC-2490 @Test public void ldapAuthenticationProviderCustomLdapAuthoritiesPopulator() throws Exception { - LdapContextSource contextSource = new DefaultSpringSecurityContextSource("ldap://blah.example.com:789/dc=springframework,dc=org"); + LdapContextSource contextSource = new DefaultSpringSecurityContextSource( + "ldap://blah.example.com:789/dc=springframework,dc=org"); CustomAuthoritiesPopulatorConfig.LAP = new DefaultLdapAuthoritiesPopulator(contextSource, null) { @Override protected Set getAdditionalRoles(DirContextOperations user, String username) { @@ -81,15 +97,27 @@ public class NamespaceLdapAuthenticationProviderTests { this.spring.register(CustomAuthoritiesPopulatorConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bob").password("bobspassword")) - .andExpect(authenticated().withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_EXTRA")))); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bob") + .password("bobspassword"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher user = authenticated() + .withAuthorities(Collections.singleton(new SimpleGrantedAuthority("ROLE_EXTRA"))); + // @formatter:on + this.mockMvc.perform(request).andExpect(user); } @Test public void ldapAuthenticationProviderPasswordCompare() throws Exception { this.spring.register(PasswordCompareLdapConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("bcrypt").password("password")) - .andExpect(authenticated().withUsername("bcrypt")); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin() + .user("bcrypt") + .password("password"); + SecurityMockMvcResultMatchers.AuthenticatedMatcher user = authenticated().withUsername("bcrypt"); + // @formatter:on + this.mockMvc.perform(request).andExpect(user); } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTestsConfigs.java b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTestsConfigs.java index c349ce6a1b..535bfa5496 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTestsConfigs.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/authentication/ldap/NamespaceLdapAuthenticationProviderTestsConfigs.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.ldap; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; @@ -27,23 +28,28 @@ import org.springframework.security.ldap.userdetails.PersonContextMapper; * */ public class NamespaceLdapAuthenticationProviderTestsConfigs { + @EnableWebSecurity static class LdapAuthenticationProviderConfig extends WebSecurityConfigurerAdapter { - // @formatter:off + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupSearchBase("ou=groups") .userDnPatterns("uid={0},ou=people"); // ldap-server@user-dn-pattern + // @formatter:on } - // @formatter:on + } @EnableWebSecurity - static class CustomLdapAuthenticationProviderConfig extends - WebSecurityConfigurerAdapter { - // @formatter:off + static class CustomLdapAuthenticationProviderConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupRoleAttribute("cn") // ldap-authentication-provider@group-role-attribute @@ -60,31 +66,36 @@ public class NamespaceLdapAuthenticationProviderTestsConfigs { .managerDn("uid=admin,ou=system") // ldap-server@manager-dn .managerPassword("secret") // ldap-server@manager-password .port(33399) // ldap-server@port - .root("dc=springframework,dc=org") // ldap-server@root + .root("dc=springframework,dc=org"); // ldap-server@root // .url("ldap://localhost:33389/dc-springframework,dc=org") this overrides root and port and is used for external - ; + // @formatter:on } - // @formatter:on + } @EnableWebSecurity static class CustomAuthoritiesPopulatorConfig extends WebSecurityConfigurerAdapter { + static LdapAuthoritiesPopulator LAP; - // @formatter:off + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .userSearchFilter("(uid={0})") .ldapAuthoritiesPopulator(LAP); + // @formatter:on } - // @formatter:on + } @EnableWebSecurity static class PasswordCompareLdapConfig extends WebSecurityConfigurerAdapter { - // @formatter:off + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .ldapAuthentication() .groupSearchBase("ou=groups") @@ -92,7 +103,9 @@ public class NamespaceLdapAuthenticationProviderTestsConfigs { .passwordCompare() .passwordEncoder(new BCryptPasswordEncoder()) // ldap-authentication-provider/password-compare/password-encoder@ref .passwordAttribute("userPassword"); // ldap-authentication-provider/password-compare@password-attribute + // @formatter:on } - // @formatter:on + } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java index 377a3abd0b..73b562445c 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.rsocket; import io.rsocket.ConnectionSetupPayload; @@ -26,15 +27,15 @@ public class HelloHandler implements SocketAcceptor { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { - return Mono.just( - new RSocket() { - @Override - public Mono requestResponse(Payload payload) { - String data = payload.getDataUtf8(); - payload.release(); - System.out.println("Got " + data); - return Mono.just(ByteBufPayload.create("Hello " + data)); - } - }); + return Mono.just(new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + String data = payload.getDataUtf8(); + payload.release(); + System.out.println("Got " + data); + return Mono.just(ByteBufPayload.create("Hello " + data)); + } + }); } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketITests.java index 32acc219cf..cb9ba6c939 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketITests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketITests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; import io.rsocket.core.RSocketServer; +import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -46,7 +47,7 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rob Winch @@ -54,6 +55,7 @@ import static org.assertj.core.api.Assertions.assertThatCode; @ContextConfiguration @RunWith(SpringRunner.class) public class HelloRSocketITests { + @Autowired RSocketMessageHandler handler; @@ -69,14 +71,16 @@ public class HelloRSocketITests { @Before public void setup() { + // @formatter:off this.server = RSocketServer.create() .payloadDecoder(PayloadDecoder.ZERO_COPY) - .interceptors((registry) -> { - registry.forSocketAcceptor(this.interceptor); - }) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) .acceptor(this.handler.responder()) .bind(TcpServerTransport.create("localhost", 0)) .block(); + // @formatter:on } @After @@ -88,38 +92,45 @@ public class HelloRSocketITests { @Test public void retrieveMonoWhenSecureThenDenied() throws Exception { + // @formatter:off this.requester = RSocketRequester.builder() - .rsocketStrategies(this.handler.getRSocketStrategies()) - .connectTcp("localhost", this.server.address().getPort()) - .block(); - + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", this.server.address().getPort()) + .block(); + // @formatter:on String data = "rob"; - assertThatCode(() -> this.requester.route("secure.retrieve-mono") - .data(data) - .retrieveMono(String.class) - .block() - ) - .isNotNull(); + // @formatter:off + assertThatExceptionOfType(Exception.class).isThrownBy( + () -> this.requester.route("secure.retrieve-mono") + .data(data) + .retrieveMono(String.class) + .block() + ) + .matches((ex) -> ex instanceof RejectedSetupException + || ex.getClass().toString().contains("ReactiveException")); + // @formatter:on // FIXME: https://github.com/rsocket/rsocket-java/issues/686 - // .isInstanceOf(RejectedSetupException.class); assertThat(this.controller.payloads).isEmpty(); } @Test public void retrieveMonoWhenAuthorizedThenGranted() throws Exception { UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password"); + // @formatter:off this.requester = RSocketRequester.builder() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .rsocketStrategies(this.handler.getRSocketStrategies()) - .connectTcp("localhost", this.server.address().getPort()) - .block(); + .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", this.server.address().getPort()) + .block(); + // @formatter:on String data = "rob"; + // @formatter:off String hiRob = this.requester.route("secure.retrieve-mono") - .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .data(data) - .retrieveMono(String.class) - .block(); - + .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data(data) + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); assertThat(this.controller.payloads).containsOnly(data); } @@ -129,37 +140,39 @@ public class HelloRSocketITests { static class Config { @Bean - public ServerController controller() { + ServerController controller() { return new ServerController(); } @Bean - public RSocketMessageHandler messageHandler() { + RSocketMessageHandler messageHandler() { RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setRSocketStrategies(rsocketStrategies()); return handler; } @Bean - public RSocketStrategies rsocketStrategies() { - return RSocketStrategies.builder() - .encoder(new BasicAuthenticationEncoder()) - .build(); + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build(); } @Bean MapReactiveUserDetailsService uds() { + // @formatter:off UserDetails rob = User.withDefaultPasswordEncoder() .username("rob") .password("password") .roles("USER", "ADMIN") .build(); + // @formatter:on return new MapReactiveUserDetailsService(rob); } + } @Controller static class ServerController { + private List payloads = new ArrayList<>(); @MessageMapping("**") @@ -171,6 +184,7 @@ public class HelloRSocketITests { private void add(String p) { this.payloads.add(p); } + } } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java index db183363a1..16fb7da016 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/JwtITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.rsocket; import java.util.ArrayList; @@ -21,6 +22,7 @@ import java.util.List; import io.rsocket.core.RSocketServer; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import org.junit.After; @@ -51,11 +53,10 @@ import org.springframework.test.context.junit4.SpringRunner; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import static io.rsocket.metadata.WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Matchers.any; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -63,6 +64,7 @@ import static org.mockito.Mockito.when; @ContextConfiguration @RunWith(SpringRunner.class) public class JwtITests { + @Autowired RSocketMessageHandler handler; @@ -81,14 +83,16 @@ public class JwtITests { @Before public void setup() { + // @formatter:off this.server = RSocketServer.create() - .payloadDecoder(PayloadDecoder.ZERO_COPY) - .interceptors((registry) -> { - registry.forSocketAcceptor(this.interceptor); - }) - .acceptor(this.handler.responder()) - .bind(TcpServerTransport.create("localhost", 0)) - .block(); + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) + .acceptor(this.handler.responder()) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + // @formatter:on } @After @@ -100,88 +104,72 @@ public class JwtITests { @Test public void routeWhenBearerThenAuthorized() { - BearerTokenMetadata credentials = - new BearerTokenMetadata("token"); - when(this.decoder.decode(any())).thenReturn(Mono.just(jwt())); + BearerTokenMetadata credentials = new BearerTokenMetadata("token"); + given(this.decoder.decode(any())).willReturn(Mono.just(jwt())); + // @formatter:off this.requester = requester() - .setupMetadata(credentials.getToken(), BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE) - .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) - .block(); - + .setupMetadata(credentials.getToken(), BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); String hiRob = this.requester.route("secure.retrieve-mono") - .data("rob") - .retrieveMono(String.class) - .block(); - + .data("rob") + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); } @Test public void routeWhenAuthenticationBearerThenAuthorized() { - MimeType authenticationMimeType = MimeTypeUtils.parseMimeType(MESSAGE_RSOCKET_AUTHENTICATION.getString()); - - BearerTokenMetadata credentials = - new BearerTokenMetadata("token"); - when(this.decoder.decode(any())).thenReturn(Mono.just(jwt())); - this.requester = requester() - .setupMetadata(credentials, authenticationMimeType) + MimeType authenticationMimeType = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()); + BearerTokenMetadata credentials = new BearerTokenMetadata("token"); + given(this.decoder.decode(any())).willReturn(Mono.just(jwt())); + // @formatter:off + this.requester = requester().setupMetadata(credentials, authenticationMimeType) .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) .block(); - String hiRob = this.requester.route("secure.retrieve-mono") .data("rob") - .retrieveMono(String.class) - .block(); - + .retrieveMono(String.class).block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); } private Jwt jwt() { - return TestJwts.jwt() - .claim(IdTokenClaimNames.ISS, "https://issuer.example.com") - .claim(IdTokenClaimNames.SUB, "rob") - .claim(IdTokenClaimNames.AUD, Arrays.asList("client-id")) - .build(); + return TestJwts.jwt().claim(IdTokenClaimNames.ISS, "https://issuer.example.com") + .claim(IdTokenClaimNames.SUB, "rob").claim(IdTokenClaimNames.AUD, Arrays.asList("client-id")).build(); } private RSocketRequester.Builder requester() { - return RSocketRequester.builder() - .rsocketStrategies(this.handler.getRSocketStrategies()); + return RSocketRequester.builder().rsocketStrategies(this.handler.getRSocketStrategies()); } - @Configuration @EnableRSocketSecurity static class Config { @Bean - public ServerController controller() { + ServerController controller() { return new ServerController(); } @Bean - public RSocketMessageHandler messageHandler() { + RSocketMessageHandler messageHandler() { RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setRSocketStrategies(rsocketStrategies()); return handler; } @Bean - public RSocketStrategies rsocketStrategies() { - return RSocketStrategies.builder() - .encoder(new BearerTokenAuthenticationEncoder()) - .build(); + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new BearerTokenAuthenticationEncoder()).build(); } @Bean PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { - rsocket - .authorizePayload(authorize -> - authorize - .anyRequest().authenticated() - .anyExchange().permitAll() - ) - .jwt(Customizer.withDefaults()); + rsocket.authorizePayload((authorize) -> authorize.anyRequest().authenticated().anyExchange().permitAll()) + .jwt(Customizer.withDefaults()); return rsocket.build(); } @@ -189,16 +177,19 @@ public class JwtITests { ReactiveJwtDecoder jwtDecoder() { return mock(ReactiveJwtDecoder.class); } + } @Controller static class ServerController { + private List payloads = new ArrayList<>(); @MessageMapping("**") String connect(String payload) { return "Hi " + payload; } + } } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java index 1bf93f18a7..eb7e8a551b 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerConnectionITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.rsocket; import java.util.ArrayList; @@ -20,6 +21,7 @@ import java.util.List; import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -48,7 +50,7 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rob Winch @@ -60,6 +62,7 @@ import static org.assertj.core.api.Assertions.assertThatCode; @ContextConfiguration @RunWith(SpringRunner.class) public class RSocketMessageHandlerConnectionITests { + @Autowired RSocketMessageHandler handler; @@ -75,14 +78,16 @@ public class RSocketMessageHandlerConnectionITests { @Before public void setup() { + // @formatter:off this.server = RSocketServer.create() .payloadDecoder(PayloadDecoder.ZERO_COPY) - .interceptors((registry) -> { - registry.forSocketAcceptor(this.interceptor); - }) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) .acceptor(this.handler.responder()) .bind(TcpServerTransport.create("localhost", 0)) .block(); + // @formatter:on } @After @@ -94,182 +99,179 @@ public class RSocketMessageHandlerConnectionITests { @Test public void routeWhenAuthorized() { - UsernamePasswordMetadata credentials = - new UsernamePasswordMetadata("user", "password"); - this.requester = requester() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) - .block(); - + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); + // @formatter:off + this.requester = requester().setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); String hiRob = this.requester.route("secure.retrieve-mono") - .data("rob") - .retrieveMono(String.class) - .block(); - + .data("rob") + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); } @Test public void routeWhenNotAuthorized() { UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); - this.requester = requester() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) - .block(); - - assertThatCode(() -> this.requester.route("secure.admin.retrieve-mono") + // @formatter:off + this.requester = requester().setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); + assertThatExceptionOfType(ApplicationErrorException.class).isThrownBy(() -> this.requester + .route("secure.admin.retrieve-mono") .data("data") .retrieveMono(String.class) - .block()) - .isInstanceOf(ApplicationErrorException.class); + .block() + ); + // @formatter:on } @Test public void routeWhenStreamCredentialsAuthorized() { UsernamePasswordMetadata connectCredentials = new UsernamePasswordMetadata("user", "password"); - this.requester = requester() - .setupMetadata(connectCredentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) - .block(); - + // @formatter:off + this.requester = requester().setupMetadata(connectCredentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); String hiRob = this.requester.route("secure.admin.retrieve-mono") - .metadata(new UsernamePasswordMetadata("admin", "password"), UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .data("rob") - .retrieveMono(String.class) - .block(); - + .metadata(new UsernamePasswordMetadata("admin", "password"), + UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data("rob") + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); } @Test public void routeWhenStreamCredentialsHaveAuthority() { UsernamePasswordMetadata connectCredentials = new UsernamePasswordMetadata("user", "password"); - this.requester = requester() - .setupMetadata(connectCredentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) - .block(); - + // @formatter:off + this.requester = requester().setupMetadata(connectCredentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + .block(); String hiUser = this.requester.route("secure.authority.retrieve-mono") - .metadata(new UsernamePasswordMetadata("admin", "password"), UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .data("Felipe") - .retrieveMono(String.class) - .block(); - + .metadata(new UsernamePasswordMetadata("admin", "password"), + UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data("Felipe") + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiUser).isEqualTo("Hi Felipe"); } @Test public void connectWhenNotAuthenticated() { - this.requester = requester() - .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) + // @formatter:off + this.requester = requester().connectTcp(this.server.address().getHostName(), this.server.address().getPort()) .block(); - - assertThatCode(() -> this.requester.route("retrieve-mono") - .data("data") - .retrieveMono(String.class) - .block()) - .isNotNull(); + assertThatExceptionOfType(Exception.class) + .isThrownBy(() -> this.requester.route("retrieve-mono") + .data("data") + .retrieveMono(String.class) + .block() + ) + .matches((ex) -> ex instanceof RejectedSetupException + || ex.getClass().toString().contains("ReactiveException")); + // @formatter:on // FIXME: https://github.com/rsocket/rsocket-java/issues/686 - // .isInstanceOf(RejectedSetupException.class); } @Test public void connectWhenNotAuthorized() { UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("evil", "password"); - this.requester = requester() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + // @formatter:off + this.requester = requester().setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) .block(); - - assertThatCode(() -> this.requester.route("retrieve-mono") - .data("data") - .retrieveMono(String.class) - .block()) - .isNotNull(); -// FIXME: https://github.com/rsocket/rsocket-java/issues/686 -// .isInstanceOf(RejectedSetupException.class); + assertThatExceptionOfType(Exception.class) + .isThrownBy(() -> this.requester.route("retrieve-mono") + .data("data") + .retrieveMono(String.class) + .block() + ) + .matches((ex) -> ex instanceof RejectedSetupException + || ex.getClass().toString().contains("ReactiveException")); + // @formatter:on + // FIXME: https://github.com/rsocket/rsocket-java/issues/686 } @Test public void connectionDenied() { UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); - this.requester = requester() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + // @formatter:off + this.requester = requester().setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) .block(); - - assertThatCode(() -> this.requester.route("prohibit") - .data("data") - .retrieveMono(String.class) - .block()) - .isInstanceOf(ApplicationErrorException.class); + assertThatExceptionOfType(ApplicationErrorException.class) + .isThrownBy(() -> this.requester.route("prohibit") + .data("data") + .retrieveMono(String.class) + .block() + ); + // @formatter:on } @Test public void connectWithAnyRole() { - UsernamePasswordMetadata credentials = - new UsernamePasswordMetadata("user", "password"); - this.requester = requester() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); + // @formatter:off + this.requester = requester().setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) .block(); - String hiRob = this.requester.route("anyroute") .data("rob") .retrieveMono(String.class) .block(); - + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); } @Test public void connectWithAnyAuthority() { - UsernamePasswordMetadata credentials = - new UsernamePasswordMetadata("admin", "password"); - this.requester = requester() - .setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("admin", "password"); + // @formatter:off + this.requester = requester().setupMetadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) .connectTcp(this.server.address().getHostName(), this.server.address().getPort()) .block(); - String hiEbert = this.requester.route("management.users") .data("admin") .retrieveMono(String.class) .block(); - + // @formatter:on assertThat(hiEbert).isEqualTo("Hi admin"); } private RSocketRequester.Builder requester() { - return RSocketRequester.builder() - .rsocketStrategies(this.handler.getRSocketStrategies()); + return RSocketRequester.builder().rsocketStrategies(this.handler.getRSocketStrategies()); } - @Configuration @EnableRSocketSecurity static class Config { @Bean - public ServerController controller() { + ServerController controller() { return new ServerController(); } @Bean - public RSocketMessageHandler messageHandler() { + RSocketMessageHandler messageHandler() { RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setRSocketStrategies(rsocketStrategies()); return handler; } @Bean - public RSocketStrategies rsocketStrategies() { - return RSocketStrategies.builder() - .encoder(new BasicAuthenticationEncoder()) - .build(); + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build(); } @Bean MapReactiveUserDetailsService uds() { + // @formatter:off UserDetails admin = User.withDefaultPasswordEncoder() .username("admin") .password("password") @@ -280,41 +282,44 @@ public class RSocketMessageHandlerConnectionITests { .password("password") .roles("USER", "SETUP") .build(); - UserDetails evil = User.withDefaultPasswordEncoder() .username("evil") .password("password") .roles("EVIL") .build(); + // @formatter:on return new MapReactiveUserDetailsService(admin, user, evil); } @Bean PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { - rsocket - .authorizePayload(authorize -> - authorize - .setup().hasRole("SETUP") - .route("secure.admin.*").hasRole("ADMIN") - .route("secure.**").hasRole("USER") - .route("secure.authority.*").hasAuthority("ROLE_USER") - .route("management.*").hasAnyAuthority("ROLE_ADMIN") - .route("prohibit").denyAll() - .anyRequest().permitAll() - ) - .basicAuthentication(Customizer.withDefaults()); + // @formatter:off + rsocket.authorizePayload((authorize) -> authorize + .setup().hasRole("SETUP") + .route("secure.admin.*").hasRole("ADMIN") + .route("secure.**").hasRole("USER") + .route("secure.authority.*").hasAuthority("ROLE_USER") + .route("management.*").hasAnyAuthority("ROLE_ADMIN") + .route("prohibit").denyAll() + .anyRequest().permitAll() + ) + .basicAuthentication(Customizer.withDefaults()); + // @formatter:on return rsocket.build(); } + } @Controller static class ServerController { + private List payloads = new ArrayList<>(); @MessageMapping("**") String connect(String payload) { return "Hi " + payload; } + } } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java index 6f351e54e2..774d0b611a 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/RSocketMessageHandlerITests.java @@ -52,7 +52,7 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rob Winch @@ -60,6 +60,7 @@ import static org.assertj.core.api.Assertions.assertThatCode; @ContextConfiguration @RunWith(SpringRunner.class) public class RSocketMessageHandlerITests { + @Autowired RSocketMessageHandler handler; @@ -75,20 +76,22 @@ public class RSocketMessageHandlerITests { @Before public void setup() { + // @formatter:off this.server = RSocketServer.create() .payloadDecoder(PayloadDecoder.ZERO_COPY) - .interceptors((registry) -> { - registry.forSocketAcceptor(this.interceptor); - }) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) .acceptor(this.handler.responder()) .bind(TcpServerTransport.create("localhost", 0)) .block(); - this.requester = RSocketRequester.builder() - // .rsocketFactory(factory -> factory.addRequesterPlugin(payloadInterceptor)) + // .rsocketFactory((factory) -> + // factory.addRequesterPlugin(payloadInterceptor)) .rsocketStrategies(this.handler.getRSocketStrategies()) .connectTcp("localhost", this.server.address().getPort()) .block(); + // @formatter:on } @After @@ -101,13 +104,15 @@ public class RSocketMessageHandlerITests { @Test public void retrieveMonoWhenSecureThenDenied() throws Exception { String data = "rob"; - assertThatCode(() -> this.requester.route("secure.retrieve-mono") - .data(data) - .retrieveMono(String.class) - .block() - ).isInstanceOf(ApplicationErrorException.class) - .hasMessageContaining("Access Denied"); - + // @formatter:off + assertThatExceptionOfType(ApplicationErrorException.class).isThrownBy( + () -> this.requester.route("secure.retrieve-mono") + .data(data) + .retrieveMono(String.class) + .block() + ) + .withMessageContaining("Access Denied"); + // @formatter:on assertThat(this.controller.payloads).isEmpty(); } @@ -115,14 +120,15 @@ public class RSocketMessageHandlerITests { public void retrieveMonoWhenAuthenticationFailedThenException() throws Exception { String data = "rob"; UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("invalid", "password"); - assertThatCode(() -> this.requester.route("secure.retrieve-mono") - .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .data(data) + // @formatter:off + assertThatExceptionOfType(ApplicationErrorException.class) + .isThrownBy(() -> this.requester.route("secure.retrieve-mono") + .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE).data(data) .retrieveMono(String.class) .block() - ).isInstanceOf(ApplicationErrorException.class) - .hasMessageContaining("Invalid Credentials"); - + ) + .withMessageContaining("Invalid Credentials"); + // @formatter:on assertThat(this.controller.payloads).isEmpty(); } @@ -130,12 +136,13 @@ public class RSocketMessageHandlerITests { public void retrieveMonoWhenAuthorizedThenGranted() throws Exception { String data = "rob"; UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password"); + // @formatter:off String hiRob = this.requester.route("secure.retrieve-mono") - .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) - .data(data) - .retrieveMono(String.class) - .block(); - + .metadata(credentials, UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE) + .data(data) + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); assertThat(this.controller.payloads).containsOnly(data); } @@ -143,11 +150,12 @@ public class RSocketMessageHandlerITests { @Test public void retrieveMonoWhenPublicThenGranted() throws Exception { String data = "rob"; + // @formatter:off String hiRob = this.requester.route("retrieve-mono") - .data(data) - .retrieveMono(String.class) - .block(); - + .data(data) + .retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); assertThat(this.controller.payloads).containsOnly(data); } @@ -155,26 +163,29 @@ public class RSocketMessageHandlerITests { @Test public void retrieveFluxWhenDataFluxAndSecureThenDenied() throws Exception { Flux data = Flux.just("a", "b", "c"); - assertThatCode(() -> this.requester.route("secure.retrieve-flux") + // @formatter:off + assertThatExceptionOfType(ApplicationErrorException.class) + .isThrownBy(() -> this.requester.route("secure.retrieve-flux") .data(data, String.class) .retrieveFlux(String.class) .collectList() .block() - ).isInstanceOf(ApplicationErrorException.class) - .hasMessageContaining("Access Denied"); - + ) + .withMessageContaining("Access Denied"); + // @formatter:on assertThat(this.controller.payloads).isEmpty(); } @Test public void retrieveFluxWhenDataFluxAndPublicThenGranted() throws Exception { Flux data = Flux.just("a", "b", "c"); + // @formatter:off List hi = this.requester.route("retrieve-flux") - .data(data, String.class) - .retrieveFlux(String.class) - .collectList() - .block(); - + .data(data, String.class) + .retrieveFlux(String.class) + .collectList() + .block(); + // @formatter:on assertThat(hi).containsOnly("hello a", "hello b", "hello c"); assertThat(this.controller.payloads).containsOnlyElementsOf(data.collectList().block()); } @@ -182,35 +193,33 @@ public class RSocketMessageHandlerITests { @Test public void retrieveFluxWhenDataStringAndSecureThenDenied() throws Exception { String data = "a"; - assertThatCode(() -> this.requester.route("secure.hello") - .data(data) - .retrieveFlux(String.class) - .collectList() - .block() - ).isInstanceOf(ApplicationErrorException.class) - .hasMessageContaining("Access Denied"); - + assertThatExceptionOfType(ApplicationErrorException.class).isThrownBy( + () -> this.requester.route("secure.hello").data(data).retrieveFlux(String.class).collectList().block()) + .withMessageContaining("Access Denied"); assertThat(this.controller.payloads).isEmpty(); } @Test public void sendWhenSecureThenDenied() throws Exception { String data = "hi"; + // @formatter:off this.requester.route("secure.send") - .data(data) - .send() - .block(); - + .data(data) + .send() + .block(); + // @formatter:on assertThat(this.controller.payloads).isEmpty(); } @Test public void sendWhenPublicThenGranted() throws Exception { String data = "hi"; + // @formatter:off this.requester.route("send") - .data(data) - .send() - .block(); + .data(data) + .send() + .block(); + // @formatter:on assertThat(this.controller.awaitPayloads()).containsOnly("hi"); } @@ -219,26 +228,25 @@ public class RSocketMessageHandlerITests { static class Config { @Bean - public ServerController controller() { + ServerController controller() { return new ServerController(); } @Bean - public RSocketMessageHandler messageHandler() { + RSocketMessageHandler messageHandler() { RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setRSocketStrategies(rsocketStrategies()); return handler; } @Bean - public RSocketStrategies rsocketStrategies() { - return RSocketStrategies.builder() - .encoder(new BasicAuthenticationEncoder()) - .build(); + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build(); } @Bean MapReactiveUserDetailsService uds() { + // @formatter:off UserDetails rob = User.withDefaultPasswordEncoder() .username("rob") .password("password") @@ -249,45 +257,44 @@ public class RSocketMessageHandlerITests { .password("password") .roles("USER") .build(); + // @formatter:on return new MapReactiveUserDetailsService(rob, rossen); } @Bean PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { - rsocket - .authorizePayload(authorize -> { - authorize - .route("secure.*").authenticated() - .anyExchange().permitAll(); - }) - .basicAuthentication(Customizer.withDefaults()); + // @formatter:off + rsocket.authorizePayload( + (authorize) -> authorize + .route("secure.*").authenticated() + .anyExchange().permitAll() + ) + .basicAuthentication(Customizer.withDefaults()); + // @formatter:on return rsocket.build(); } + } @Controller static class ServerController { + private List payloads = new ArrayList<>(); - @MessageMapping({"secure.retrieve-mono", "retrieve-mono"}) + @MessageMapping({ "secure.retrieve-mono", "retrieve-mono" }) String retrieveMono(String payload) { add(payload); return "Hi " + payload; } - @MessageMapping({"secure.retrieve-flux", "retrieve-flux"}) + @MessageMapping({ "secure.retrieve-flux", "retrieve-flux" }) Flux retrieveFlux(Flux payload) { - return payload.doOnNext(this::add) - .map(p -> "hello " + p); + return payload.doOnNext(this::add).map((p) -> "hello " + p); } - @MessageMapping({"secure.send", "send"}) + @MessageMapping({ "secure.send", "send" }) Mono send(Mono payload) { - return payload - .doOnNext(this::add) - .then(Mono.fromRunnable(() -> { - doNotifyAll(); - })); + return payload.doOnNext(this::add).then(Mono.fromRunnable(() -> doNotifyAll())); } private synchronized void doNotifyAll() { @@ -302,6 +309,7 @@ public class RSocketMessageHandlerITests { private void add(String p) { this.payloads.add(p); } + } } diff --git a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/SimpleAuthenticationITests.java b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/SimpleAuthenticationITests.java index dec6fb2367..a2b195102f 100644 --- a/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/SimpleAuthenticationITests.java +++ b/config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/SimpleAuthenticationITests.java @@ -22,6 +22,7 @@ import java.util.List; import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import org.junit.After; @@ -50,9 +51,8 @@ import org.springframework.test.context.junit4.SpringRunner; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import static io.rsocket.metadata.WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rob Winch @@ -60,6 +60,7 @@ import static org.assertj.core.api.Assertions.assertThatCode; @ContextConfiguration @RunWith(SpringRunner.class) public class SimpleAuthenticationITests { + @Autowired RSocketMessageHandler handler; @@ -75,14 +76,16 @@ public class SimpleAuthenticationITests { @Before public void setup() { + // @formatter:off this.server = RSocketServer.create() .payloadDecoder(PayloadDecoder.ZERO_COPY) - .interceptors((registry) -> { - registry.forSocketAcceptor(this.interceptor); - }) + .interceptors((registry) -> + registry.forSocketAcceptor(this.interceptor) + ) .acceptor(this.handler.responder()) .bind(TcpServerTransport.create("localhost", 0)) .block(); + // @formatter:on } @After @@ -94,38 +97,42 @@ public class SimpleAuthenticationITests { @Test public void retrieveMonoWhenSecureThenDenied() throws Exception { + // @formatter:off this.requester = RSocketRequester.builder() - .rsocketStrategies(this.handler.getRSocketStrategies()) - .connectTcp("localhost", this.server.address().getPort()) - .block(); - + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", this.server.address().getPort()) + .block(); + // @formatter:on String data = "rob"; - assertThatCode(() -> this.requester.route("secure.retrieve-mono") - .data(data) - .retrieveMono(String.class) + // @formatter:off + assertThatExceptionOfType(ApplicationErrorException.class) + .isThrownBy(() -> this.requester.route("secure.retrieve-mono") + .data(data).retrieveMono(String.class) .block() - ) - .isInstanceOf(ApplicationErrorException.class); + ); + // @formatter:on assertThat(this.controller.payloads).isEmpty(); } @Test public void retrieveMonoWhenAuthorizedThenGranted() { - MimeType authenticationMimeType = MimeTypeUtils.parseMimeType(MESSAGE_RSOCKET_AUTHENTICATION.getString()); - + MimeType authenticationMimeType = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()); UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("rob", "password"); + // @formatter:off this.requester = RSocketRequester.builder() - .setupMetadata(credentials, authenticationMimeType) - .rsocketStrategies(this.handler.getRSocketStrategies()) - .connectTcp("localhost", this.server.address().getPort()) - .block(); + .setupMetadata(credentials, authenticationMimeType) + .rsocketStrategies(this.handler.getRSocketStrategies()) + .connectTcp("localhost", this.server.address().getPort()) + .block(); + // @formatter:on String data = "rob"; + // @formatter:off String hiRob = this.requester.route("secure.retrieve-mono") - .metadata(credentials, authenticationMimeType) - .data(data) - .retrieveMono(String.class) - .block(); - + .metadata(credentials, authenticationMimeType) + .data(data).retrieveMono(String.class) + .block(); + // @formatter:on assertThat(hiRob).isEqualTo("Hi rob"); assertThat(this.controller.payloads).containsOnly(data); } @@ -135,49 +142,46 @@ public class SimpleAuthenticationITests { static class Config { @Bean - public ServerController controller() { + ServerController controller() { return new ServerController(); } @Bean - public RSocketMessageHandler messageHandler() { + RSocketMessageHandler messageHandler() { RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setRSocketStrategies(rsocketStrategies()); return handler; } @Bean - public RSocketStrategies rsocketStrategies() { - return RSocketStrategies.builder() - .encoder(new SimpleAuthenticationEncoder()) - .build(); + RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder().encoder(new SimpleAuthenticationEncoder()).build(); } @Bean PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) { - rsocket - .authorizePayload(authorize -> - authorize - .anyRequest().authenticated() - .anyExchange().permitAll() - ) + rsocket.authorizePayload((authorize) -> authorize.anyRequest().authenticated().anyExchange().permitAll()) .simpleAuthentication(Customizer.withDefaults()); return rsocket.build(); } @Bean MapReactiveUserDetailsService uds() { + // @formatter:off UserDetails rob = User.withDefaultPasswordEncoder() .username("rob") .password("password") .roles("USER", "ADMIN") .build(); + // @formatter:on return new MapReactiveUserDetailsService(rob); } + } @Controller static class ServerController { + private List payloads = new ArrayList<>(); @MessageMapping("**") @@ -189,6 +193,7 @@ public class SimpleAuthenticationITests { private void add(String p) { this.payloads.add(p); } + } } diff --git a/config/src/integration-test/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParserTests.java b/config/src/integration-test/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParserTests.java index f2a954fed1..01edcd9f47 100644 --- a/config/src/integration-test/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParserTests.java +++ b/config/src/integration-test/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParserTests.java @@ -35,84 +35,74 @@ import org.springframework.security.ldap.userdetails.InetOrgPersonContextMapper; import static org.assertj.core.api.Assertions.assertThat; public class LdapProviderBeanDefinitionParserTests { + InMemoryXmlApplicationContext appCtx; @After public void closeAppContext() { - if (appCtx != null) { - appCtx.close(); - appCtx = null; + if (this.appCtx != null) { + this.appCtx.close(); + this.appCtx = null; } } @Test public void simpleProviderAuthenticatesCorrectly() { - appCtx = new InMemoryXmlApplicationContext("" - + "" - + " " - + "" - ); + this.appCtx = new InMemoryXmlApplicationContext("" + + "" + " " + + ""); - AuthenticationManager authenticationManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, AuthenticationManager.class); - Authentication auth = authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("ben", "benspassword")); + AuthenticationManager authenticationManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, + AuthenticationManager.class); + Authentication auth = authenticationManager + .authenticate(new UsernamePasswordAuthenticationToken("ben", "benspassword")); UserDetails ben = (UserDetails) auth.getPrincipal(); assertThat(ben.getAuthorities()).hasSize(3); } @Test public void multipleProvidersAreSupported() { - appCtx = new InMemoryXmlApplicationContext("" - + "" - + " " + this.appCtx = new InMemoryXmlApplicationContext("" + + "" + " " + " " - + "" - ); + + ""); - ProviderManager providerManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, ProviderManager.class); + ProviderManager providerManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, ProviderManager.class); assertThat(providerManager.getProviders()).hasSize(2); - assertThat(providerManager.getProviders()) - .extracting("authoritiesPopulator.groupSearchFilter") + assertThat(providerManager.getProviders()).extracting("authoritiesPopulator.groupSearchFilter") .containsExactly("member={0}", "uniqueMember={0}"); } @Test(expected = ApplicationContextException.class) public void missingServerEltCausesConfigException() { - new InMemoryXmlApplicationContext("" - + " " - + "" - ); + new InMemoryXmlApplicationContext( + "" + " " + ""); } @Test public void supportsPasswordComparisonAuthentication() { - appCtx = new InMemoryXmlApplicationContext("" - + "" - + " " - + " " - + " " - + "" - ); + this.appCtx = new InMemoryXmlApplicationContext("" + + "" + " " + + " " + " " + ""); - AuthenticationManager authenticationManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, AuthenticationManager.class); - Authentication auth = authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("ben", "benspassword")); + AuthenticationManager authenticationManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, + AuthenticationManager.class); + Authentication auth = authenticationManager + .authenticate(new UsernamePasswordAuthenticationToken("ben", "benspassword")); assertThat(auth).isNotNull(); } @Test public void supportsPasswordComparisonAuthenticationWithPasswordEncoder() { - appCtx = new InMemoryXmlApplicationContext("" - + "" - + " " - + " " - + " " - + " " - + " " - + "" - + "" - ); + this.appCtx = new InMemoryXmlApplicationContext("" + + "" + " " + + " " + " " + + " " + " " + "" + + ""); - AuthenticationManager authenticationManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, AuthenticationManager.class); + AuthenticationManager authenticationManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, + AuthenticationManager.class); Authentication auth = authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("ben", "ben")); assertThat(auth).isNotNull(); @@ -121,58 +111,52 @@ public class LdapProviderBeanDefinitionParserTests { // SEC-2472 @Test public void supportsCryptoPasswordEncoder() { - appCtx = new InMemoryXmlApplicationContext("" - + "" - + " " - + " " - + " " - + " " - + " " - + "" - + "" - ); + this.appCtx = new InMemoryXmlApplicationContext("" + + "" + " " + + " " + " " + " " + + " " + "" + + ""); - AuthenticationManager authenticationManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, AuthenticationManager.class); - Authentication auth = authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("bcrypt", "password")); + AuthenticationManager authenticationManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, + AuthenticationManager.class); + Authentication auth = authenticationManager + .authenticate(new UsernamePasswordAuthenticationToken("bcrypt", "password")); assertThat(auth).isNotNull(); } @Test public void inetOrgContextMapperIsSupported() { - appCtx = new InMemoryXmlApplicationContext("" - + "" - + " " - + "" - ); + this.appCtx = new InMemoryXmlApplicationContext( + "" + + "" + + " " + + ""); - ProviderManager providerManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, ProviderManager.class); + ProviderManager providerManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, ProviderManager.class); assertThat(providerManager.getProviders()).hasSize(1); - assertThat(providerManager.getProviders()) - .extracting("userDetailsContextMapper") - .allSatisfy(contextMapper -> assertThat(contextMapper).isInstanceOf(InetOrgPersonContextMapper.class)); + assertThat(providerManager.getProviders()).extracting("userDetailsContextMapper").allSatisfy( + (contextMapper) -> assertThat(contextMapper).isInstanceOf(InetOrgPersonContextMapper.class)); } @Test public void ldapAuthenticationProviderWorksWithPlaceholders() { System.setProperty("udp", "people"); System.setProperty("gsf", "member"); - appCtx = new InMemoryXmlApplicationContext("" - + "" + this.appCtx = new InMemoryXmlApplicationContext("" + "" + " " + "" - + "" - ); + + ""); - ProviderManager providerManager = appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, ProviderManager.class); + ProviderManager providerManager = this.appCtx.getBean(BeanIds.AUTHENTICATION_MANAGER, ProviderManager.class); assertThat(providerManager.getProviders()).hasSize(1); AuthenticationProvider authenticationProvider = providerManager.getProviders().get(0); - assertThat(authenticationProvider) - .extracting("authenticator.userDnFormat") - .satisfies(messageFormats -> assertThat(messageFormats).isEqualTo(new MessageFormat[]{new MessageFormat("uid={0},ou=people")})); - assertThat(authenticationProvider) - .extracting("authoritiesPopulator.groupSearchFilter") - .satisfies(searchFilter -> assertThat(searchFilter).isEqualTo("member={0}")); + assertThat(authenticationProvider).extracting("authenticator.userDnFormat") + .satisfies((messageFormats) -> assertThat(messageFormats) + .isEqualTo(new MessageFormat[] { new MessageFormat("uid={0},ou=people") })); + assertThat(authenticationProvider).extracting("authoritiesPopulator.groupSearchFilter") + .satisfies((searchFilter) -> assertThat(searchFilter).isEqualTo("member={0}")); } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParserTests.java b/config/src/integration-test/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParserTests.java index 564352400b..032a388401 100644 --- a/config/src/integration-test/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParserTests.java +++ b/config/src/integration-test/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParserTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.ldap; import java.io.IOException; @@ -35,22 +36,22 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Rob Winch */ public class LdapServerBeanDefinitionParserTests { + InMemoryXmlApplicationContext appCtx; @After public void closeAppContext() { - if (appCtx != null) { - appCtx.close(); - appCtx = null; + if (this.appCtx != null) { + this.appCtx.close(); + this.appCtx = null; } } @Test public void embeddedServerCreationContainsExpectedContextSourceAndData() { - appCtx = new InMemoryXmlApplicationContext( - ""); + this.appCtx = new InMemoryXmlApplicationContext(""); - DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) appCtx + DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) this.appCtx .getBean(BeanIds.CONTEXT_SOURCE); // Check data is loaded @@ -62,17 +63,14 @@ public class LdapServerBeanDefinitionParserTests { public void useOfUrlAttributeCreatesCorrectContextSource() throws Exception { int port = getDefaultPort(); // Create second "server" with a url pointing at embedded one - appCtx = new InMemoryXmlApplicationContext( - "" - + ""); + this.appCtx = new InMemoryXmlApplicationContext("" + ""); // Check the default context source is still there. - appCtx.getBean(BeanIds.CONTEXT_SOURCE); + this.appCtx.getBean(BeanIds.CONTEXT_SOURCE); - DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) appCtx + DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) this.appCtx .getBean("blah"); // Check data is loaded as before @@ -82,9 +80,9 @@ public class LdapServerBeanDefinitionParserTests { @Test public void loadingSpecificLdifFileIsSuccessful() { - appCtx = new InMemoryXmlApplicationContext( + this.appCtx = new InMemoryXmlApplicationContext( ""); - DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) appCtx + DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) this.appCtx .getBean(BeanIds.CONTEXT_SOURCE); LdapTemplate template = new LdapTemplate(contextSource); @@ -93,8 +91,8 @@ public class LdapServerBeanDefinitionParserTests { @Test public void defaultLdifFileIsSuccessful() { - appCtx = new InMemoryXmlApplicationContext(""); - ApacheDSContainer dsContainer = appCtx.getBean(ApacheDSContainer.class); + this.appCtx = new InMemoryXmlApplicationContext(""); + ApacheDSContainer dsContainer = this.appCtx.getBean(ApacheDSContainer.class); assertThat(ReflectionTestUtils.getField(dsContainer, "ldifResources")).isEqualTo("classpath*:*.ldif"); } @@ -104,4 +102,5 @@ public class LdapServerBeanDefinitionParserTests { return server.getLocalPort(); } } + } diff --git a/config/src/integration-test/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParserTests.java b/config/src/integration-test/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParserTests.java index caff4c8757..e9c8c32a17 100644 --- a/config/src/integration-test/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParserTests.java +++ b/config/src/integration-test/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParserTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.ldap; import java.util.Set; @@ -36,11 +37,6 @@ import org.springframework.security.ldap.userdetails.PersonContextMapper; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; -import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.INET_ORG_PERSON_MAPPER_CLASS; -import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.LDAP_AUTHORITIES_POPULATOR_CLASS; -import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.LDAP_SEARCH_CLASS; -import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.LDAP_USER_MAPPER_CLASS; -import static org.springframework.security.config.ldap.LdapUserServiceBeanDefinitionParser.PERSON_MAPPER_CLASS; /** * @author Luke Taylor @@ -48,36 +44,45 @@ import static org.springframework.security.config.ldap.LdapUserServiceBeanDefini * @author Eddú Meléndez */ public class LdapUserServiceBeanDefinitionParserTests { + private InMemoryXmlApplicationContext appCtx; @After public void closeAppContext() { - if (appCtx != null) { - appCtx.close(); - appCtx = null; + if (this.appCtx != null) { + this.appCtx.close(); + this.appCtx = null; } } @Test public void beanClassNamesAreCorrect() { - assertThat(FilterBasedLdapUserSearch.class.getName()).isEqualTo(LDAP_SEARCH_CLASS); - assertThat(PersonContextMapper.class.getName()).isEqualTo(PERSON_MAPPER_CLASS); - assertThat(InetOrgPersonContextMapper.class.getName()).isEqualTo(INET_ORG_PERSON_MAPPER_CLASS); - assertThat(LdapUserDetailsMapper.class.getName()).isEqualTo(LDAP_USER_MAPPER_CLASS); - assertThat(DefaultLdapAuthoritiesPopulator.class.getName()).isEqualTo(LDAP_AUTHORITIES_POPULATOR_CLASS); - assertThat(new LdapUserServiceBeanDefinitionParser().getBeanClassName(mock(Element.class))).isEqualTo(LdapUserDetailsService.class.getName()); + assertThat(FilterBasedLdapUserSearch.class.getName()) + .isEqualTo(LdapUserServiceBeanDefinitionParser.LDAP_SEARCH_CLASS); + assertThat(PersonContextMapper.class.getName()) + .isEqualTo(LdapUserServiceBeanDefinitionParser.PERSON_MAPPER_CLASS); + assertThat(InetOrgPersonContextMapper.class.getName()) + .isEqualTo(LdapUserServiceBeanDefinitionParser.INET_ORG_PERSON_MAPPER_CLASS); + assertThat(LdapUserDetailsMapper.class.getName()) + .isEqualTo(LdapUserServiceBeanDefinitionParser.LDAP_USER_MAPPER_CLASS); + assertThat(DefaultLdapAuthoritiesPopulator.class.getName()) + .isEqualTo(LdapUserServiceBeanDefinitionParser.LDAP_AUTHORITIES_POPULATOR_CLASS); + assertThat(new LdapUserServiceBeanDefinitionParser().getBeanClassName(mock(Element.class))) + .isEqualTo(LdapUserDetailsService.class.getName()); } @Test public void minimalConfigurationIsParsedOk() { - setContext(""); + setContext( + ""); } @Test public void userServiceReturnsExpectedData() { - setContext(""); + setContext( + ""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails ben = uds.loadUserByUsername("ben"); Set authorities = AuthorityUtils.authorityListToSet(ben.getAuthorities()); @@ -87,12 +92,11 @@ public class LdapUserServiceBeanDefinitionParserTests { @Test public void differentUserSearchBaseWorksAsExpected() { - setContext(""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails joe = uds.loadUserByUsername("Joe Smeth"); assertThat(joe.getUsername()).isEqualTo("Joe Smeth"); @@ -100,27 +104,26 @@ public class LdapUserServiceBeanDefinitionParserTests { @Test public void rolePrefixIsSupported() { - setContext("" - + ""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails ben = uds.loadUserByUsername("ben"); assertThat(AuthorityUtils.authorityListToSet(ben.getAuthorities())).contains("PREFIX_DEVELOPERS"); - uds = (UserDetailsService) appCtx.getBean("ldapUDSNoPrefix"); + uds = (UserDetailsService) this.appCtx.getBean("ldapUDSNoPrefix"); ben = uds.loadUserByUsername("ben"); assertThat(AuthorityUtils.authorityListToSet(ben.getAuthorities())).contains("DEVELOPERS"); } @Test public void differentGroupRoleAttributeWorksAsExpected() { - setContext(""); + setContext( + ""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails ben = uds.loadUserByUsername("ben"); Set authorities = AuthorityUtils.authorityListToSet(ben.getAuthorities()); @@ -131,18 +134,18 @@ public class LdapUserServiceBeanDefinitionParserTests { @Test public void isSupportedByAuthenticationProviderElement() { - setContext("" - + "" - + " " - + " " - + " " + ""); + setContext( + "" + + "" + " " + + " " + " " + + ""); } @Test public void personContextMapperIsSupported() { setContext("" + ""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails ben = uds.loadUserByUsername("ben"); assertThat(ben instanceof Person).isTrue(); } @@ -151,7 +154,7 @@ public class LdapUserServiceBeanDefinitionParserTests { public void inetOrgContextMapperIsSupported() { setContext("" + ""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails ben = uds.loadUserByUsername("ben"); assertThat(ben instanceof InetOrgPerson).isTrue(); } @@ -160,15 +163,15 @@ public class LdapUserServiceBeanDefinitionParserTests { public void externalContextMapperIsSupported() { setContext("" + "" - + ""); + + ""); - UserDetailsService uds = (UserDetailsService) appCtx.getBean("ldapUDS"); + UserDetailsService uds = (UserDetailsService) this.appCtx.getBean("ldapUDS"); UserDetails ben = uds.loadUserByUsername("ben"); assertThat(ben instanceof InetOrgPerson).isTrue(); } private void setContext(String context) { - appCtx = new InMemoryXmlApplicationContext(context); + this.appCtx = new InMemoryXmlApplicationContext(context); } + } diff --git a/config/src/main/java/org/springframework/security/config/BeanIds.java b/config/src/main/java/org/springframework/security/config/BeanIds.java index 85027d2c73..fcf2e5fc1d 100644 --- a/config/src/main/java/org/springframework/security/config/BeanIds.java +++ b/config/src/main/java/org/springframework/security/config/BeanIds.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; /** @@ -25,6 +26,7 @@ package org.springframework.security.config; * @author Luke Taylor */ public abstract class BeanIds { + private static final String PREFIX = "org.springframework.security."; /** @@ -33,29 +35,31 @@ public abstract class BeanIds { */ public static final String AUTHENTICATION_MANAGER = PREFIX + "authenticationManager"; - /** External alias for FilterChainProxy bean, for use in web.xml files */ + /** + * External alias for FilterChainProxy bean, for use in web.xml files + */ public static final String SPRING_SECURITY_FILTER_CHAIN = "springSecurityFilterChain"; - public static final String CONTEXT_SOURCE_SETTING_POST_PROCESSOR = PREFIX - + "contextSettingPostProcessor"; + public static final String CONTEXT_SOURCE_SETTING_POST_PROCESSOR = PREFIX + "contextSettingPostProcessor"; public static final String USER_DETAILS_SERVICE = PREFIX + "userDetailsService"; - public static final String USER_DETAILS_SERVICE_FACTORY = PREFIX - + "userDetailsServiceFactory"; - public static final String METHOD_ACCESS_MANAGER = PREFIX - + "defaultMethodAccessManager"; + public static final String USER_DETAILS_SERVICE_FACTORY = PREFIX + "userDetailsServiceFactory"; + + public static final String METHOD_ACCESS_MANAGER = PREFIX + "defaultMethodAccessManager"; public static final String FILTER_CHAIN_PROXY = PREFIX + "filterChainProxy"; + public static final String FILTER_CHAINS = PREFIX + "filterChains"; - public static final String METHOD_SECURITY_METADATA_SOURCE_ADVISOR = PREFIX - + "methodSecurityMetadataSourceAdvisor"; - public static final String EMBEDDED_APACHE_DS = PREFIX - + "apacheDirectoryServerContainer"; - public static final String EMBEDDED_UNBOUNDID = PREFIX - + "unboundidServerContainer"; + public static final String METHOD_SECURITY_METADATA_SOURCE_ADVISOR = PREFIX + "methodSecurityMetadataSourceAdvisor"; + + public static final String EMBEDDED_APACHE_DS = PREFIX + "apacheDirectoryServerContainer"; + + public static final String EMBEDDED_UNBOUNDID = PREFIX + "unboundidServerContainer"; + public static final String CONTEXT_SOURCE = PREFIX + "securityContextSource"; public static final String DEBUG_FILTER = PREFIX + "debugFilter"; + } diff --git a/config/src/main/java/org/springframework/security/config/Customizer.java b/config/src/main/java/org/springframework/security/config/Customizer.java index 048ec1f3e0..fc68b127e0 100644 --- a/config/src/main/java/org/springframework/security/config/Customizer.java +++ b/config/src/main/java/org/springframework/security/config/Customizer.java @@ -28,17 +28,17 @@ public interface Customizer { /** * Performs the customizations on the input argument. - * * @param t the input argument */ void customize(T t); /** * Returns a {@link Customizer} that does not alter the input argument. - * * @return a {@link Customizer} that does not alter the input argument. */ static Customizer withDefaults() { - return t -> {}; + return (t) -> { + }; } + } diff --git a/config/src/main/java/org/springframework/security/config/DebugBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/DebugBeanDefinitionParser.java index e7327e5ab0..b96b2ee872 100644 --- a/config/src/main/java/org/springframework/security/config/DebugBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/DebugBeanDefinitionParser.java @@ -13,24 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.config.debug.SecurityDebugBeanFactoryPostProcessor; -import org.w3c.dom.Element; /** * @author Luke Taylor */ public class DebugBeanDefinitionParser implements BeanDefinitionParser { - public BeanDefinition parse(Element element, ParserContext parserContext) { - RootBeanDefinition debugPP = new RootBeanDefinition( - SecurityDebugBeanFactoryPostProcessor.class); - parserContext.getReaderContext().registerWithGeneratedName(debugPP); + @Override + public BeanDefinition parse(Element element, ParserContext parserContext) { + RootBeanDefinition debugPP = new RootBeanDefinition(SecurityDebugBeanFactoryPostProcessor.class); + parserContext.getReaderContext().registerWithGeneratedName(debugPP); return null; } + } diff --git a/config/src/main/java/org/springframework/security/config/Elements.java b/config/src/main/java/org/springframework/security/config/Elements.java index 35a4d0fa17..20ad615769 100644 --- a/config/src/main/java/org/springframework/security/config/Elements.java +++ b/config/src/main/java/org/springframework/security/config/Elements.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; /** @@ -23,61 +24,113 @@ package org.springframework.security.config; public abstract class Elements { public static final String ACCESS_DENIED_HANDLER = "access-denied-handler"; + public static final String AUTHENTICATION_MANAGER = "authentication-manager"; + public static final String AFTER_INVOCATION_PROVIDER = "after-invocation-provider"; + public static final String USER_SERVICE = "user-service"; + public static final String JDBC_USER_SERVICE = "jdbc-user-service"; + public static final String FILTER_CHAIN_MAP = "filter-chain-map"; + public static final String INTERCEPT_METHODS = "intercept-methods"; + public static final String INTERCEPT_URL = "intercept-url"; + public static final String AUTHENTICATION_PROVIDER = "authentication-provider"; + public static final String HTTP = "http"; + public static final String LDAP_PROVIDER = "ldap-authentication-provider"; + public static final String LDAP_SERVER = "ldap-server"; + public static final String LDAP_USER_SERVICE = "ldap-user-service"; + public static final String PROTECT_POINTCUT = "protect-pointcut"; + public static final String EXPRESSION_HANDLER = "expression-handler"; + public static final String INVOCATION_HANDLING = "pre-post-annotation-handling"; + public static final String INVOCATION_ATTRIBUTE_FACTORY = "invocation-attribute-factory"; + public static final String PRE_INVOCATION_ADVICE = "pre-invocation-advice"; + public static final String POST_INVOCATION_ADVICE = "post-invocation-advice"; + public static final String PROTECT = "protect"; + public static final String SESSION_MANAGEMENT = "session-management"; + public static final String CONCURRENT_SESSIONS = "concurrency-control"; + public static final String LOGOUT = "logout"; + public static final String FORM_LOGIN = "form-login"; + public static final String OPENID_LOGIN = "openid-login"; + public static final String OPENID_ATTRIBUTE_EXCHANGE = "attribute-exchange"; + public static final String OPENID_ATTRIBUTE = "openid-attribute"; + public static final String BASIC_AUTH = "http-basic"; + public static final String REMEMBER_ME = "remember-me"; + public static final String ANONYMOUS = "anonymous"; + public static final String FILTER_CHAIN = "filter-chain"; + public static final String GLOBAL_METHOD_SECURITY = "global-method-security"; + public static final String PASSWORD_ENCODER = "password-encoder"; + public static final String PORT_MAPPINGS = "port-mappings"; + public static final String PORT_MAPPING = "port-mapping"; + public static final String CUSTOM_FILTER = "custom-filter"; + public static final String REQUEST_CACHE = "request-cache"; + public static final String X509 = "x509"; + public static final String JEE = "jee"; + public static final String FILTER_SECURITY_METADATA_SOURCE = "filter-security-metadata-source"; + public static final String METHOD_SECURITY_METADATA_SOURCE = "method-security-metadata-source"; + public static final String LDAP_PASSWORD_COMPARE = "password-compare"; + public static final String DEBUG = "debug"; + public static final String HTTP_FIREWALL = "http-firewall"; + public static final String HEADERS = "headers"; + public static final String CORS = "cors"; + public static final String CSRF = "csrf"; public static final String OAUTH2_RESOURCE_SERVER = "oauth2-resource-server"; + public static final String JWT = "jwt"; + public static final String OPAQUE_TOKEN = "opaque-token"; public static final String WEBSOCKET_MESSAGE_BROKER = "websocket-message-broker"; + public static final String INTERCEPT_MESSAGE = "intercept-message"; public static final String OAUTH2_LOGIN = "oauth2-login"; + public static final String OAUTH2_CLIENT = "oauth2-client"; + public static final String CLIENT_REGISTRATIONS = "client-registrations"; + } diff --git a/config/src/main/java/org/springframework/security/config/SecurityNamespaceHandler.java b/config/src/main/java/org/springframework/security/config/SecurityNamespaceHandler.java index fc044e24e2..9a11086870 100644 --- a/config/src/main/java/org/springframework/security/config/SecurityNamespaceHandler.java +++ b/config/src/main/java/org/springframework/security/config/SecurityNamespaceHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import java.util.HashMap; @@ -59,150 +60,140 @@ import org.springframework.util.ClassUtils; * @since 2.0 */ public final class SecurityNamespaceHandler implements NamespaceHandler { + private static final String FILTER_CHAIN_PROXY_CLASSNAME = "org.springframework.security.web.FilterChainProxy"; + private static final String MESSAGE_CLASSNAME = "org.springframework.messaging.Message"; + private final Log logger = LogFactory.getLog(getClass()); + private final Map parsers = new HashMap<>(); + private final BeanDefinitionDecorator interceptMethodsBDD = new InterceptMethodsBeanDefinitionDecorator(); + private BeanDefinitionDecorator filterChainMapBDD; public SecurityNamespaceHandler() { String coreVersion = SpringSecurityCoreVersion.getVersion(); - Package pkg = SpringSecurityCoreVersion.class.getPackage(); - if (pkg == null || coreVersion == null) { - logger.info("Couldn't determine package version information."); + this.logger.info("Couldn't determine package version information."); return; } - String version = pkg.getImplementationVersion(); - logger.info("Spring Security 'config' module version is " + version); - + this.logger.info("Spring Security 'config' module version is " + version); if (version.compareTo(coreVersion) != 0) { - logger.error("You are running with different versions of the Spring Security 'core' and 'config' modules"); + this.logger.error( + "You are running with different versions of the Spring Security 'core' and 'config' modules"); } } + @Override public BeanDefinition parse(Element element, ParserContext pc) { if (!namespaceMatchesVersion(element)) { - pc.getReaderContext() - .fatal("You cannot use a spring-security-2.0.xsd or spring-security-3.0.xsd or spring-security-3.1.xsd schema or spring-security-3.2.xsd schema or spring-security-4.0.xsd schema " - + "with Spring Security 5.4. Please update your schema declarations to the 5.4 schema.", - element); + pc.getReaderContext().fatal("You cannot use a spring-security-2.0.xsd or spring-security-3.0.xsd or " + + "spring-security-3.1.xsd schema or spring-security-3.2.xsd schema or spring-security-4.0.xsd schema " + + "with Spring Security 5.4. Please update your schema declarations to the 5.4 schema.", element); } String name = pc.getDelegate().getLocalName(element); - BeanDefinitionParser parser = parsers.get(name); - + BeanDefinitionParser parser = this.parsers.get(name); if (parser == null) { // SEC-1455. Load parsers when required, not just on init(). loadParsers(); } - - if (parser == null) { - if (Elements.HTTP.equals(name) - || Elements.FILTER_SECURITY_METADATA_SOURCE.equals(name) - || Elements.FILTER_CHAIN_MAP.equals(name) - || Elements.FILTER_CHAIN.equals(name)) { - reportMissingWebClasses(name, pc, element); - } - else { - reportUnsupportedNodeType(name, pc, element); - } - - return null; + if (parser != null) { + return parser.parse(element, pc); } - - return parser.parse(element, pc); + if (Elements.HTTP.equals(name) || Elements.FILTER_SECURITY_METADATA_SOURCE.equals(name) + || Elements.FILTER_CHAIN_MAP.equals(name) || Elements.FILTER_CHAIN.equals(name)) { + reportMissingWebClasses(name, pc, element); + } + else { + reportUnsupportedNodeType(name, pc, element); + } + return null; } - public BeanDefinitionHolder decorate(Node node, BeanDefinitionHolder definition, - ParserContext pc) { + @Override + public BeanDefinitionHolder decorate(Node node, BeanDefinitionHolder definition, ParserContext pc) { String name = pc.getDelegate().getLocalName(node); - - // We only handle elements if (node instanceof Element) { + // We only handle elements if (Elements.INTERCEPT_METHODS.equals(name)) { - return interceptMethodsBDD.decorate(node, definition, pc); + return this.interceptMethodsBDD.decorate(node, definition, pc); } - if (Elements.FILTER_CHAIN_MAP.equals(name)) { - if (filterChainMapBDD == null) { + if (this.filterChainMapBDD == null) { loadParsers(); } - if (filterChainMapBDD == null) { + if (this.filterChainMapBDD == null) { reportMissingWebClasses(name, pc, node); } - return filterChainMapBDD.decorate(node, definition, pc); + return this.filterChainMapBDD.decorate(node, definition, pc); } } - reportUnsupportedNodeType(name, pc, node); - return null; } private void reportUnsupportedNodeType(String name, ParserContext pc, Node node) { - pc.getReaderContext().fatal( - "Security namespace does not support decoration of " - + (node instanceof Element ? "element" : "attribute") + " [" - + name + "]", node); + pc.getReaderContext().fatal("Security namespace does not support decoration of " + + ((node instanceof Element) ? "element" : "attribute") + " [" + name + "]", node); } private void reportMissingWebClasses(String nodeName, ParserContext pc, Node node) { String errorMessage = "The classes from the spring-security-web jar " - + "(or one of its dependencies) are not available. You need these to use <" - + nodeName + ">"; + + "(or one of its dependencies) are not available. You need these to use <" + nodeName + ">"; try { ClassUtils.forName(FILTER_CHAIN_PROXY_CLASSNAME, getClass().getClassLoader()); // no details available pc.getReaderContext().fatal(errorMessage, node); } - catch (Throwable cause) { + catch (Throwable ex) { // provide details on why it could not be loaded - pc.getReaderContext().fatal(errorMessage, node, cause); + pc.getReaderContext().fatal(errorMessage, node, ex); } } + @Override public void init() { loadParsers(); } private void loadParsers() { // Parsers - parsers.put(Elements.LDAP_PROVIDER, new LdapProviderBeanDefinitionParser()); - parsers.put(Elements.LDAP_SERVER, new LdapServerBeanDefinitionParser()); - parsers.put(Elements.LDAP_USER_SERVICE, new LdapUserServiceBeanDefinitionParser()); - parsers.put(Elements.USER_SERVICE, new UserServiceBeanDefinitionParser()); - parsers.put(Elements.JDBC_USER_SERVICE, new JdbcUserServiceBeanDefinitionParser()); - parsers.put(Elements.AUTHENTICATION_PROVIDER, - new AuthenticationProviderBeanDefinitionParser()); - parsers.put(Elements.GLOBAL_METHOD_SECURITY, - new GlobalMethodSecurityBeanDefinitionParser()); - parsers.put(Elements.AUTHENTICATION_MANAGER, - new AuthenticationManagerBeanDefinitionParser()); - parsers.put(Elements.METHOD_SECURITY_METADATA_SOURCE, + this.parsers.put(Elements.LDAP_PROVIDER, new LdapProviderBeanDefinitionParser()); + this.parsers.put(Elements.LDAP_SERVER, new LdapServerBeanDefinitionParser()); + this.parsers.put(Elements.LDAP_USER_SERVICE, new LdapUserServiceBeanDefinitionParser()); + this.parsers.put(Elements.USER_SERVICE, new UserServiceBeanDefinitionParser()); + this.parsers.put(Elements.JDBC_USER_SERVICE, new JdbcUserServiceBeanDefinitionParser()); + this.parsers.put(Elements.AUTHENTICATION_PROVIDER, new AuthenticationProviderBeanDefinitionParser()); + this.parsers.put(Elements.GLOBAL_METHOD_SECURITY, new GlobalMethodSecurityBeanDefinitionParser()); + this.parsers.put(Elements.AUTHENTICATION_MANAGER, new AuthenticationManagerBeanDefinitionParser()); + this.parsers.put(Elements.METHOD_SECURITY_METADATA_SOURCE, new MethodSecurityMetadataSourceBeanDefinitionParser()); - - // Only load the web-namespace parsers if the web classes are available - if (ClassUtils.isPresent(FILTER_CHAIN_PROXY_CLASSNAME, getClass() - .getClassLoader())) { - parsers.put(Elements.DEBUG, new DebugBeanDefinitionParser()); - parsers.put(Elements.HTTP, new HttpSecurityBeanDefinitionParser()); - parsers.put(Elements.HTTP_FIREWALL, new HttpFirewallBeanDefinitionParser()); - parsers.put(Elements.FILTER_SECURITY_METADATA_SOURCE, - new FilterInvocationSecurityMetadataSourceParser()); - parsers.put(Elements.FILTER_CHAIN, new FilterChainBeanDefinitionParser()); - filterChainMapBDD = new FilterChainMapBeanDefinitionDecorator(); - parsers.put(Elements.CLIENT_REGISTRATIONS, new ClientRegistrationsBeanDefinitionParser()); + if (ClassUtils.isPresent(FILTER_CHAIN_PROXY_CLASSNAME, getClass().getClassLoader())) { + loadWebParsers(); } - if (ClassUtils.isPresent(MESSAGE_CLASSNAME, getClass().getClassLoader())) { - parsers.put(Elements.WEBSOCKET_MESSAGE_BROKER, - new WebSocketMessageBrokerSecurityBeanDefinitionParser()); + loadWebSocketParsers(); } } + private void loadWebParsers() { + this.parsers.put(Elements.DEBUG, new DebugBeanDefinitionParser()); + this.parsers.put(Elements.HTTP, new HttpSecurityBeanDefinitionParser()); + this.parsers.put(Elements.HTTP_FIREWALL, new HttpFirewallBeanDefinitionParser()); + this.parsers.put(Elements.FILTER_SECURITY_METADATA_SOURCE, new FilterInvocationSecurityMetadataSourceParser()); + this.parsers.put(Elements.FILTER_CHAIN, new FilterChainBeanDefinitionParser()); + this.filterChainMapBDD = new FilterChainMapBeanDefinitionDecorator(); + this.parsers.put(Elements.CLIENT_REGISTRATIONS, new ClientRegistrationsBeanDefinitionParser()); + } + + private void loadWebSocketParsers() { + this.parsers.put(Elements.WEBSOCKET_MESSAGE_BROKER, new WebSocketMessageBrokerSecurityBeanDefinitionParser()); + } + /** * Check that the schema location declared in the source file being parsed matches the * Spring Security version. The old 2.0 schema is not compatible with the 3.1 parser, @@ -212,7 +203,6 @@ public final class SecurityNamespaceHandler implements NamespaceHandler { * using 3.0 as an error too. It might be an error to declare spring-security.xsd as * an alias, but you are only going to find that out when one of the sub parsers * breaks. - * * @param element the element that is to be parsed next * @return true if we find a schema declaration that matches */ @@ -222,8 +212,7 @@ public final class SecurityNamespaceHandler implements NamespaceHandler { } private boolean matchesVersionInternal(Element element) { - String schemaLocation = element.getAttributeNS( - "http://www.w3.org/2001/XMLSchema-instance", "schemaLocation"); + String schemaLocation = element.getAttributeNS("http://www.w3.org/2001/XMLSchema-instance", "schemaLocation"); return schemaLocation.matches("(?m).*spring-security-5\\.4.*.xsd.*") || schemaLocation.matches("(?m).*spring-security.xsd.*") || !schemaLocation.matches("(?m).*spring-security.*"); diff --git a/config/src/main/java/org/springframework/security/config/annotation/AbstractConfiguredSecurityBuilder.java b/config/src/main/java/org/springframework/security/config/annotation/AbstractConfiguredSecurityBuilder.java index 34fe85cb2c..b9d6ebed1e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/AbstractConfiguredSecurityBuilder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/AbstractConfiguredSecurityBuilder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; import java.util.ArrayList; @@ -44,18 +45,18 @@ import org.springframework.web.filter.DelegatingFilterProxy; * filters necessary for session management, form based login, authorization, etc. *

* - * @see WebSecurity - * - * @author Rob Winch - * * @param The object that this builder returns * @param The type of this builder (that is returned by the base class) + * @author Rob Winch + * @see WebSecurity */ public abstract class AbstractConfiguredSecurityBuilder> extends AbstractSecurityBuilder { + private final Log logger = LogFactory.getLog(getClass()); private final LinkedHashMap>, List>> configurers = new LinkedHashMap<>(); + private final List> configurersAddedInInitializing = new ArrayList<>(); private final Map, Object> sharedObjects = new HashMap<>(); @@ -70,11 +71,9 @@ public abstract class AbstractConfiguredSecurityBuilder objectPostProcessor) { + protected AbstractConfiguredSecurityBuilder(ObjectPostProcessor objectPostProcessor) { this(objectPostProcessor, false); } @@ -82,13 +81,11 @@ public abstract class AbstractConfiguredSecurityBuilder objectPostProcessor, + protected AbstractConfiguredSecurityBuilder(ObjectPostProcessor objectPostProcessor, boolean allowConfigurersOfSameType) { Assert.notNull(objectPostProcessor, "objectPostProcessor cannot be null"); this.objectPostProcessor = objectPostProcessor; @@ -98,37 +95,32 @@ public abstract class AbstractConfiguredSecurityBuilder> C apply(C configurer) - throws Exception { - configurer.addObjectPostProcessor(objectPostProcessor); + public > C apply(C configurer) throws Exception { + configurer.addObjectPostProcessor(this.objectPostProcessor); configurer.setBuilder((B) this); add(configurer); return configurer; @@ -138,7 +130,6 @@ public abstract class AbstractConfiguredSecurityBuilder> void add(C configurer) { Assert.notNull(configurer, "configurer cannot be null"); - Class> clazz = (Class>) configurer .getClass(); - synchronized (configurers) { - if (buildState.isConfigured()) { - throw new IllegalStateException("Cannot apply " + configurer - + " to already built object"); + synchronized (this.configurers) { + if (this.buildState.isConfigured()) { + throw new IllegalStateException("Cannot apply " + configurer + " to already built object"); } - List> configs = allowConfigurersOfSameType ? this.configurers - .get(clazz) : null; - if (configs == null) { - configs = new ArrayList<>(1); + List> configs = null; + if (this.allowConfigurersOfSameType) { + configs = this.configurers.get(clazz); } + configs = (configs != null) ? configs : new ArrayList<>(1); configs.add(configurer); this.configurers.put(clazz, configs); - if (buildState.isInitializing()) { + if (this.buildState.isInitializing()) { this.configurersAddedInInitializing.add(configurer); } } @@ -211,7 +197,6 @@ public abstract class AbstractConfiguredSecurityBuildernull if not * found. Note that object hierarchies are not considered. - * * @param clazz * @return the {@link SecurityConfigurer} for further customizations */ @@ -253,17 +236,14 @@ public abstract class AbstractConfiguredSecurityBuilder "Only one configurer expected for type " + clazz + ", but got " + configs); return (C) configs.get(0); } /** * Removes and returns the {@link SecurityConfigurer} by its class name or * null if not found. Note that object hierarchies are not considered. - * * @param clazz * @return */ @@ -273,10 +253,8 @@ public abstract class AbstractConfiguredSecurityBuilder "Only one configurer expected for type " + clazz + ", but got " + configs); return (C) configs.get(0); } @@ -295,7 +273,6 @@ public abstract class AbstractConfiguredSecurityBuilder> configurers = getConfigurers(); - for (SecurityConfigurer configurer : configurers) { configurer.init((B) this); } - - for (SecurityConfigurer configurer : configurersAddedInInitializing) { + for (SecurityConfigurer configurer : this.configurersAddedInInitializing) { configurer.init((B) this); } } @@ -378,7 +345,6 @@ public abstract class AbstractConfiguredSecurityBuilder> configurers = getConfigurers(); - for (SecurityConfigurer configurer : configurers) { configurer.configure((B) this); } @@ -397,8 +363,8 @@ public abstract class AbstractConfiguredSecurityBuilder= CONFIGURING.order; + return this.order >= CONFIGURING.order; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/AbstractSecurityBuilder.java b/config/src/main/java/org/springframework/security/config/annotation/AbstractSecurityBuilder.java index 2f4c4022fe..925307b4c2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/AbstractSecurityBuilder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/AbstractSecurityBuilder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; import java.util.concurrent.atomic.AtomicBoolean; @@ -22,20 +23,16 @@ import java.util.concurrent.atomic.AtomicBoolean; * time. * * @param the type of Object that is being built - * * @author Rob Winch * */ public abstract class AbstractSecurityBuilder implements SecurityBuilder { + private AtomicBoolean building = new AtomicBoolean(); private O object; - /* - * (non-Javadoc) - * - * @see org.springframework.security.config.annotation.SecurityBuilder#build() - */ + @Override public final O build() throws Exception { if (this.building.compareAndSet(false, true)) { this.object = doBuild(); @@ -47,7 +44,6 @@ public abstract class AbstractSecurityBuilder implements SecurityBuilder { /** * Gets the object that was built. If it has not been built yet an Exception is * thrown. - * * @return the Object that was built */ public final O getObject() { @@ -59,10 +55,9 @@ public abstract class AbstractSecurityBuilder implements SecurityBuilder { /** * Subclasses should implement this to perform the build. - * * @return the object that should be returned by {@link #build()}. - * * @throws Exception if an error occurs */ protected abstract O doBuild() throws Exception; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/AlreadyBuiltException.java b/config/src/main/java/org/springframework/security/config/annotation/AlreadyBuiltException.java index b84bbff45e..06ad51cec7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/AlreadyBuiltException.java +++ b/config/src/main/java/org/springframework/security/config/annotation/AlreadyBuiltException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; /** @@ -28,4 +29,5 @@ public class AlreadyBuiltException extends IllegalStateException { } private static final long serialVersionUID = -5891004752785553015L; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/ObjectPostProcessor.java b/config/src/main/java/org/springframework/security/config/annotation/ObjectPostProcessor.java index ca07992749..53a43d1ce9 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/ObjectPostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/annotation/ObjectPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; import org.springframework.beans.factory.Aware; @@ -25,7 +26,6 @@ import org.springframework.beans.factory.InitializingBean; * {@link DisposableBean#destroy()} has been invoked. * * @param the bound of the types of Objects this {@link ObjectPostProcessor} supports. - * * @author Rob Winch * @since 3.2 */ @@ -34,9 +34,9 @@ public interface ObjectPostProcessor { /** * Initialize the object possibly returning a modified instance that should be used * instead. - * * @param object the object to initialize * @return the initialized version of the object */ O postProcess(O object); -} \ No newline at end of file + +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/SecurityBuilder.java b/config/src/main/java/org/springframework/security/config/annotation/SecurityBuilder.java index 1f097537c0..d141d8c3f7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/SecurityBuilder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/SecurityBuilder.java @@ -13,23 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; /** * Interface for building an Object * + * @param The type of the Object being built * @author Rob Winch * @since 3.2 - * - * @param The type of the Object being built */ public interface SecurityBuilder { /** * Builds the object and returns it or null. - * * @return the Object to be built or null if the implementation allows it. * @throws Exception if an error occurred when building the Object */ O build() throws Exception; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurer.java index 4ddba28274..59f51a8dd6 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; /** @@ -21,21 +22,19 @@ package org.springframework.security.config.annotation; * {@link #init(SecurityBuilder)} methods have been invoked, each * {@link #configure(SecurityBuilder)} method is invoked. * - * @see AbstractConfiguredSecurityBuilder - * - * @author Rob Winch - * * @param The object being built by the {@link SecurityBuilder} B * @param The {@link SecurityBuilder} that builds objects of type O. This is also the * {@link SecurityBuilder} that is being configured. + * @author Rob Winch + * @see AbstractConfiguredSecurityBuilder */ public interface SecurityConfigurer> { + /** * Initialize the {@link SecurityBuilder}. Here only shared state should be created * and modified, but not properties on the {@link SecurityBuilder} used for building * the object. This ensures that the {@link #configure(SecurityBuilder)} method uses * the correct shared objects when building. Configurers should be applied here. - * * @param builder * @throws Exception */ @@ -44,9 +43,9 @@ public interface SecurityConfigurer> { /** * Configure the {@link SecurityBuilder} by setting the necessary properties on the * {@link SecurityBuilder}. - * * @param builder * @throws Exception */ void configure(B builder) throws Exception; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurerAdapter.java b/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurerAdapter.java index d5ca809c64..fd25c16d12 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurerAdapter.java +++ b/config/src/main/java/org/springframework/security/config/annotation/SecurityConfigurerAdapter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; import java.util.ArrayList; @@ -20,6 +21,7 @@ import java.util.List; import org.springframework.core.GenericTypeResolver; import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.util.Assert; /** * A base class for {@link SecurityConfigurer} that allows subclasses to only implement @@ -27,29 +29,29 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator; * {@link SecurityConfigurer} and when done gaining access to the {@link SecurityBuilder} * that is being configured. * - * @author Rob Winch - * @author Wallace Wadge - * * @param The Object being built by B * @param The Builder that is building O and is configured by * {@link SecurityConfigurerAdapter} + * @author Rob Winch + * @author Wallace Wadge */ -public abstract class SecurityConfigurerAdapter> - implements SecurityConfigurer { +public abstract class SecurityConfigurerAdapter> implements SecurityConfigurer { + private B securityBuilder; private CompositeObjectPostProcessor objectPostProcessor = new CompositeObjectPostProcessor(); + @Override public void init(B builder) throws Exception { } + @Override public void configure(B builder) throws Exception { } /** * Return the {@link SecurityBuilder} when done using the {@link SecurityConfigurer}. * This is useful for method chaining. - * * @return the {@link SecurityBuilder} for further customizations */ public B and() { @@ -58,21 +60,17 @@ public abstract class SecurityConfigurerAdapter> /** * Gets the {@link SecurityBuilder}. Cannot be null. - * * @return the {@link SecurityBuilder} * @throws IllegalStateException if {@link SecurityBuilder} is null */ protected final B getBuilder() { - if (securityBuilder == null) { - throw new IllegalStateException("securityBuilder cannot be null"); - } - return securityBuilder; + Assert.state(this.securityBuilder != null, "securityBuilder cannot be null"); + return this.securityBuilder; } /** * Performs post processing of an object. The default is to delegate to the * {@link ObjectPostProcessor}. - * * @param object the Object to post process * @return the possibly modified Object to use */ @@ -85,7 +83,6 @@ public abstract class SecurityConfigurerAdapter> * Adds an {@link ObjectPostProcessor} to be used for this * {@link SecurityConfigurerAdapter}. The default implementation does nothing to the * object. - * * @param objectPostProcessor the {@link ObjectPostProcessor} to use */ public void addObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { @@ -95,7 +92,6 @@ public abstract class SecurityConfigurerAdapter> /** * Sets the {@link SecurityBuilder} to be used. This is automatically set when using * {@link AbstractConfiguredSecurityBuilder#apply(SecurityConfigurerAdapter)} - * * @param builder the {@link SecurityBuilder} to set */ public void setBuilder(B builder) { @@ -108,16 +104,16 @@ public abstract class SecurityConfigurerAdapter> * * @author Rob Winch */ - private static final class CompositeObjectPostProcessor implements - ObjectPostProcessor { + private static final class CompositeObjectPostProcessor implements ObjectPostProcessor { + private List> postProcessors = new ArrayList<>(); + @Override @SuppressWarnings({ "rawtypes", "unchecked" }) public Object postProcess(Object object) { - for (ObjectPostProcessor opp : postProcessors) { + for (ObjectPostProcessor opp : this.postProcessors) { Class oppClass = opp.getClass(); - Class oppType = GenericTypeResolver.resolveTypeArgument(oppClass, - ObjectPostProcessor.class); + Class oppType = GenericTypeResolver.resolveTypeArgument(oppClass, ObjectPostProcessor.class); if (oppType == null || oppType.isAssignableFrom(object.getClass())) { object = opp.postProcess(object); } @@ -130,11 +126,12 @@ public abstract class SecurityConfigurerAdapter> * @param objectPostProcessor the {@link ObjectPostProcessor} to add * @return true if the {@link ObjectPostProcessor} was added, else false */ - private boolean addObjectPostProcessor( - ObjectPostProcessor objectPostProcessor) { + private boolean addObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { boolean result = this.postProcessors.add(objectPostProcessor); - postProcessors.sort(AnnotationAwareOrderComparator.INSTANCE); + this.postProcessors.sort(AnnotationAwareOrderComparator.INSTANCE); return result; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/ProviderManagerBuilder.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/ProviderManagerBuilder.java index aab2e9eab6..35e6e0ac4a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/ProviderManagerBuilder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/ProviderManagerBuilder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication; import org.springframework.security.authentication.AuthenticationManager; @@ -23,12 +24,11 @@ import org.springframework.security.config.annotation.SecurityBuilder; /** * Interface for operating on a SecurityBuilder that creates a {@link ProviderManager} * - * @author Rob Winch - * * @param the type of the {@link SecurityBuilder} + * @author Rob Winch */ -public interface ProviderManagerBuilder> extends - SecurityBuilder { +public interface ProviderManagerBuilder> + extends SecurityBuilder { /** * Add authentication based upon the custom {@link AuthenticationProvider} that is @@ -36,10 +36,11 @@ public interface ProviderManagerBuilder> ext * customizations must be done externally and the {@link ProviderManagerBuilder} is * returned immediately. * - * Note that an Exception is thrown if an error occurs when adding the {@link AuthenticationProvider}. - * + * Note that an Exception is thrown if an error occurs when adding the + * {@link AuthenticationProvider}. * @return a {@link ProviderManagerBuilder} to allow further authentication to be * provided to the {@link ProviderManagerBuilder} */ B authenticationProvider(AuthenticationProvider authenticationProvider); + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/builders/AuthenticationManagerBuilder.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/builders/AuthenticationManagerBuilder.java index 36b4a97a28..77cef1a1e5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/builders/AuthenticationManagerBuilder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/builders/AuthenticationManagerBuilder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.builders; import java.util.ArrayList; @@ -20,6 +21,7 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.authentication.AuthenticationEventPublisher; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationProvider; @@ -48,15 +50,19 @@ import org.springframework.util.Assert; * @since 3.2 */ public class AuthenticationManagerBuilder - extends - AbstractConfiguredSecurityBuilder + extends AbstractConfiguredSecurityBuilder implements ProviderManagerBuilder { + private final Log logger = LogFactory.getLog(getClass()); private AuthenticationManager parentAuthenticationManager; + private List authenticationProviders = new ArrayList<>(); + private UserDetailsService defaultUserDetailsService; + private Boolean eraseCredentials; + private AuthenticationEventPublisher eventPublisher; /** @@ -71,18 +77,15 @@ public class AuthenticationManagerBuilder * Allows providing a parent {@link AuthenticationManager} that will be tried if this * {@link AuthenticationManager} was unable to attempt to authenticate the provided * {@link Authentication}. - * * @param authenticationManager the {@link AuthenticationManager} that should be used * if the current {@link AuthenticationManager} was unable to attempt to authenticate * the provided {@link Authentication}. * @return the {@link AuthenticationManagerBuilder} for further adding types of * authentication */ - public AuthenticationManagerBuilder parentAuthenticationManager( - AuthenticationManager authenticationManager) { + public AuthenticationManagerBuilder parentAuthenticationManager(AuthenticationManager authenticationManager) { if (authenticationManager instanceof ProviderManager) { - eraseCredentials(((ProviderManager) authenticationManager) - .isEraseCredentialsAfterAuthentication()); + eraseCredentials(((ProviderManager) authenticationManager).isEraseCredentialsAfterAuthentication()); } this.parentAuthenticationManager = authenticationManager; return this; @@ -90,20 +93,16 @@ public class AuthenticationManagerBuilder /** * Sets the {@link AuthenticationEventPublisher} - * * @param eventPublisher the {@link AuthenticationEventPublisher} to use * @return the {@link AuthenticationManagerBuilder} for further customizations */ - public AuthenticationManagerBuilder authenticationEventPublisher( - AuthenticationEventPublisher eventPublisher) { + public AuthenticationManagerBuilder authenticationEventPublisher(AuthenticationEventPublisher eventPublisher) { Assert.notNull(eventPublisher, "AuthenticationEventPublisher cannot be null"); this.eventPublisher = eventPublisher; return this; } /** - * - * * @param eraseCredentials true if {@link AuthenticationManager} should clear the * credentials from the {@link Authentication} object after authenticating * @return the {@link AuthenticationManagerBuilder} for further customizations @@ -124,7 +123,6 @@ public class AuthenticationManagerBuilder * {@link UserDetailsService}'s may override this {@link UserDetailsService} as the * default. *

- * * @return a {@link InMemoryUserDetailsManagerConfigurer} to allow customization of * the in memory authentication * @throws Exception if an error occurs when adding the in memory authentication @@ -141,8 +139,8 @@ public class AuthenticationManagerBuilder * *

* When using with a persistent data store, it is best to add users external of - * configuration using something like Flyway or Liquibase to create the schema and adding + * configuration using something like Flyway or + * Liquibase to create the schema and adding * users to ensure these steps are only done once and that the optimal SQL is used. *

* @@ -154,13 +152,11 @@ public class AuthenticationManagerBuilder * "https://docs.spring.io/spring-security/site/docs/current/reference/htmlsingle/#user-schema" * >User Schema section of the reference for the default schema. *

- * * @return a {@link JdbcUserDetailsManagerConfigurer} to allow customization of the * JDBC authentication * @throws Exception if an error occurs when adding the JDBC authentication */ - public JdbcUserDetailsManagerConfigurer jdbcAuthentication() - throws Exception { + public JdbcUserDetailsManagerConfigurer jdbcAuthentication() throws Exception { return apply(new JdbcUserDetailsManagerConfigurer<>()); } @@ -175,7 +171,6 @@ public class AuthenticationManagerBuilder * {@link UserDetailsService}'s may override this {@link UserDetailsService} as the * default. *

- * * @return a {@link DaoAuthenticationConfigurer} to allow customization of the DAO * authentication * @throws Exception if an error occurs when adding the {@link UserDetailsService} @@ -184,8 +179,7 @@ public class AuthenticationManagerBuilder public DaoAuthenticationConfigurer userDetailsService( T userDetailsService) throws Exception { this.defaultUserDetailsService = userDetailsService; - return apply(new DaoAuthenticationConfigurer<>( - userDetailsService)); + return apply(new DaoAuthenticationConfigurer<>(userDetailsService)); } /** @@ -196,13 +190,11 @@ public class AuthenticationManagerBuilder *

* This method does NOT ensure that a {@link UserDetailsService} is available * for the {@link #getDefaultUserDetailsService()} method. - * * @return a {@link LdapAuthenticationProviderConfigurer} to allow customization of * the LDAP authentication * @throws Exception if an error occurs when adding the LDAP authentication */ - public LdapAuthenticationProviderConfigurer ldapAuthentication() - throws Exception { + public LdapAuthenticationProviderConfigurer ldapAuthentication() throws Exception { return apply(new LdapAuthenticationProviderConfigurer<>()); } @@ -216,13 +208,13 @@ public class AuthenticationManagerBuilder * This method does NOT ensure that the {@link UserDetailsService} is available * for the {@link #getDefaultUserDetailsService()} method. * - * Note that an {@link Exception} might be thrown if an error occurs when adding the {@link AuthenticationProvider}. - * + * Note that an {@link Exception} might be thrown if an error occurs when adding the + * {@link AuthenticationProvider}. * @return a {@link AuthenticationManagerBuilder} to allow further authentication to * be provided to the {@link AuthenticationManagerBuilder} */ - public AuthenticationManagerBuilder authenticationProvider( - AuthenticationProvider authenticationProvider) { + @Override + public AuthenticationManagerBuilder authenticationProvider(AuthenticationProvider authenticationProvider) { this.authenticationProviders.add(authenticationProvider); return this; } @@ -230,16 +222,16 @@ public class AuthenticationManagerBuilder @Override protected ProviderManager performBuild() throws Exception { if (!isConfigured()) { - logger.debug("No authenticationProviders and no parentAuthenticationManager defined. Returning null."); + this.logger.debug("No authenticationProviders and no parentAuthenticationManager defined. Returning null."); return null; } - ProviderManager providerManager = new ProviderManager(authenticationProviders, - parentAuthenticationManager); - if (eraseCredentials != null) { - providerManager.setEraseCredentialsAfterAuthentication(eraseCredentials); + ProviderManager providerManager = new ProviderManager(this.authenticationProviders, + this.parentAuthenticationManager); + if (this.eraseCredentials != null) { + providerManager.setEraseCredentialsAfterAuthentication(this.eraseCredentials); } - if (eventPublisher != null) { - providerManager.setAuthenticationEventPublisher(eventPublisher); + if (this.eventPublisher != null) { + providerManager.setAuthenticationEventPublisher(this.eventPublisher); } providerManager = postProcess(providerManager); return providerManager; @@ -257,17 +249,16 @@ public class AuthenticationManagerBuilder * {@link SecurityConfigurer} that is last could check this method and provide a * default configuration in the {@link SecurityConfigurer#configure(SecurityBuilder)} * method. - * - * @return true, if {@link AuthenticationManagerBuilder} is configured, otherwise false + * @return true, if {@link AuthenticationManagerBuilder} is configured, otherwise + * false */ public boolean isConfigured() { - return !authenticationProviders.isEmpty() || parentAuthenticationManager != null; + return !this.authenticationProviders.isEmpty() || this.parentAuthenticationManager != null; } /** * Gets the default {@link UserDetailsService} for the * {@link AuthenticationManagerBuilder}. The result may be null in some circumstances. - * * @return the default {@link UserDetailsService} for the * {@link AuthenticationManagerBuilder} */ @@ -278,7 +269,6 @@ public class AuthenticationManagerBuilder /** * Captures the {@link UserDetailsService} from any {@link UserDetailsAwareConfigurer} * . - * * @param configurer the {@link UserDetailsAwareConfigurer} to capture the * {@link UserDetailsService} from. * @return the {@link UserDetailsAwareConfigurer} for further customizations @@ -289,4 +279,5 @@ public class AuthenticationManagerBuilder this.defaultUserDetailsService = configurer.getUserDetailsService(); return super.apply(configurer); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfiguration.java index a9a7b088af..821a88d3d2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfiguration.java @@ -13,10 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configuration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.aop.framework.ProxyFactoryBean; import org.springframework.aop.target.LazyInitTargetSource; import org.springframework.beans.factory.BeanFactoryUtils; @@ -28,6 +36,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationEventPublisher; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.annotation.ObjectPostProcessor; @@ -43,12 +52,6 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.util.Assert; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.ArrayList; -import java.util.concurrent.atomic.AtomicBoolean; - /** * Exports the authentication {@link Configuration} * @@ -68,18 +71,18 @@ public class AuthenticationConfiguration { private boolean authenticationManagerInitialized; - private List globalAuthConfigurers = Collections - .emptyList(); + private List globalAuthConfigurers = Collections.emptyList(); private ObjectPostProcessor objectPostProcessor; @Bean - public AuthenticationManagerBuilder authenticationManagerBuilder( - ObjectPostProcessor objectPostProcessor, ApplicationContext context) { + public AuthenticationManagerBuilder authenticationManagerBuilder(ObjectPostProcessor objectPostProcessor, + ApplicationContext context) { LazyPasswordEncoder defaultPasswordEncoder = new LazyPasswordEncoder(context); - AuthenticationEventPublisher authenticationEventPublisher = getBeanOrNull(context, AuthenticationEventPublisher.class); - - DefaultPasswordEncoderAuthenticationManagerBuilder result = new DefaultPasswordEncoderAuthenticationManagerBuilder(objectPostProcessor, defaultPasswordEncoder); + AuthenticationEventPublisher authenticationEventPublisher = getBeanOrNull(context, + AuthenticationEventPublisher.class); + DefaultPasswordEncoderAuthenticationManagerBuilder result = new DefaultPasswordEncoderAuthenticationManagerBuilder( + objectPostProcessor, defaultPasswordEncoder); if (authenticationEventPublisher != null) { result.authenticationEventPublisher(authenticationEventPublisher); } @@ -93,12 +96,14 @@ public class AuthenticationConfiguration { } @Bean - public static InitializeUserDetailsBeanManagerConfigurer initializeUserDetailsBeanManagerConfigurer(ApplicationContext context) { + public static InitializeUserDetailsBeanManagerConfigurer initializeUserDetailsBeanManagerConfigurer( + ApplicationContext context) { return new InitializeUserDetailsBeanManagerConfigurer(context); } @Bean - public static InitializeAuthenticationProviderBeanManagerConfigurer initializeAuthenticationProviderBeanManagerConfigurer(ApplicationContext context) { + public static InitializeAuthenticationProviderBeanManagerConfigurer initializeAuthenticationProviderBeanManagerConfigurer( + ApplicationContext context) { return new InitializeAuthenticationProviderBeanManagerConfigurer(context); } @@ -110,24 +115,19 @@ public class AuthenticationConfiguration { if (this.buildingAuthenticationManager.getAndSet(true)) { return new AuthenticationManagerDelegator(authBuilder); } - - for (GlobalAuthenticationConfigurerAdapter config : globalAuthConfigurers) { + for (GlobalAuthenticationConfigurerAdapter config : this.globalAuthConfigurers) { authBuilder.apply(config); } - - authenticationManager = authBuilder.build(); - - if (authenticationManager == null) { - authenticationManager = getAuthenticationManagerBean(); + this.authenticationManager = authBuilder.build(); + if (this.authenticationManager == null) { + this.authenticationManager = getAuthenticationManagerBean(); } - this.authenticationManagerInitialized = true; - return authenticationManager; + return this.authenticationManager; } @Autowired(required = false) - public void setGlobalAuthenticationConfigurers( - List configurers) { + public void setGlobalAuthenticationConfigurers(List configurers) { configurers.sort(AnnotationAwareOrderComparator.INSTANCE); this.globalAuthConfigurers = configurers; } @@ -145,40 +145,40 @@ public class AuthenticationConfiguration { @SuppressWarnings("unchecked") private T lazyBean(Class interfaceName) { LazyInitTargetSource lazyTargetSource = new LazyInitTargetSource(); - String[] beanNamesForType = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( - applicationContext, interfaceName); + String[] beanNamesForType = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.applicationContext, + interfaceName); if (beanNamesForType.length == 0) { return null; } - String beanName; - if (beanNamesForType.length > 1) { - List primaryBeanNames = getPrimaryBeanNames(beanNamesForType); - - Assert.isTrue(primaryBeanNames.size() != 0, () -> "Found " + beanNamesForType.length - + " beans for type " + interfaceName + ", but none marked as primary"); - Assert.isTrue(primaryBeanNames.size() == 1, () -> "Found " + primaryBeanNames.size() - + " beans for type " + interfaceName + " marked as primary"); - beanName = primaryBeanNames.get(0); - } else { - beanName = beanNamesForType[0]; - } - + String beanName = getBeanName(interfaceName, beanNamesForType); lazyTargetSource.setTargetBeanName(beanName); - lazyTargetSource.setBeanFactory(applicationContext); + lazyTargetSource.setBeanFactory(this.applicationContext); ProxyFactoryBean proxyFactory = new ProxyFactoryBean(); - proxyFactory = objectPostProcessor.postProcess(proxyFactory); + proxyFactory = this.objectPostProcessor.postProcess(proxyFactory); proxyFactory.setTargetSource(lazyTargetSource); return (T) proxyFactory.getObject(); } + private String getBeanName(Class interfaceName, String[] beanNamesForType) { + if (beanNamesForType.length == 1) { + return beanNamesForType[0]; + } + List primaryBeanNames = getPrimaryBeanNames(beanNamesForType); + Assert.isTrue(primaryBeanNames.size() != 0, () -> "Found " + beanNamesForType.length + " beans for type " + + interfaceName + ", but none marked as primary"); + Assert.isTrue(primaryBeanNames.size() == 1, + () -> "Found " + primaryBeanNames.size() + " beans for type " + interfaceName + " marked as primary"); + return primaryBeanNames.get(0); + } + private List getPrimaryBeanNames(String[] beanNamesForType) { List list = new ArrayList<>(); - if (!(applicationContext instanceof ConfigurableApplicationContext)) { + if (!(this.applicationContext instanceof ConfigurableApplicationContext)) { return Collections.emptyList(); } for (String beanName : beanNamesForType) { - if (((ConfigurableApplicationContext) applicationContext).getBeanFactory() - .getBeanDefinition(beanName).isPrimary()) { + if (((ConfigurableApplicationContext) this.applicationContext).getBeanFactory().getBeanDefinition(beanName) + .isPrimary()) { list.add(beanName); } } @@ -192,16 +192,17 @@ public class AuthenticationConfiguration { private static T getBeanOrNull(ApplicationContext applicationContext, Class type) { try { return applicationContext.getBean(type); - } catch(NoSuchBeanDefinitionException notFound) { + } + catch (NoSuchBeanDefinitionException notFound) { return null; } } - private static class EnableGlobalAuthenticationAutowiredConfigurer extends - GlobalAuthenticationConfigurerAdapter { + private static class EnableGlobalAuthenticationAutowiredConfigurer extends GlobalAuthenticationConfigurerAdapter { + private final ApplicationContext context; - private static final Log logger = LogFactory - .getLog(EnableGlobalAuthenticationAutowiredConfigurer.class); + + private static final Log logger = LogFactory.getLog(EnableGlobalAuthenticationAutowiredConfigurer.class); EnableGlobalAuthenticationAutowiredConfigurer(ApplicationContext context) { this.context = context; @@ -209,12 +210,11 @@ public class AuthenticationConfiguration { @Override public void init(AuthenticationManagerBuilder auth) { - Map beansWithAnnotation = context + Map beansWithAnnotation = this.context .getBeansWithAnnotation(EnableGlobalAuthentication.class); - if (logger.isDebugEnabled()) { - logger.debug("Eagerly initializing " + beansWithAnnotation); - } + logger.debug(LogMessage.format("Eagerly initializing %s", beansWithAnnotation)); } + } /** @@ -225,8 +225,11 @@ public class AuthenticationConfiguration { * @since 4.1.1 */ static final class AuthenticationManagerDelegator implements AuthenticationManager { + private AuthenticationManagerBuilder delegateBuilder; + private AuthenticationManager delegate; + private final Object delegateMonitor = new Object(); AuthenticationManagerDelegator(AuthenticationManagerBuilder delegateBuilder) { @@ -235,19 +238,16 @@ public class AuthenticationConfiguration { } @Override - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + public Authentication authenticate(Authentication authentication) throws AuthenticationException { if (this.delegate != null) { return this.delegate.authenticate(authentication); } - synchronized (this.delegateMonitor) { if (this.delegate == null) { this.delegate = this.delegateBuilder.getObject(); this.delegateBuilder = null; } } - return this.delegate.authenticate(authentication); } @@ -255,46 +255,46 @@ public class AuthenticationConfiguration { public String toString() { return "AuthenticationManagerDelegator [delegate=" + this.delegate + "]"; } + } static class DefaultPasswordEncoderAuthenticationManagerBuilder extends AuthenticationManagerBuilder { + private PasswordEncoder defaultPasswordEncoder; /** * Creates a new instance - * * @param objectPostProcessor the {@link ObjectPostProcessor} instance to use. */ - DefaultPasswordEncoderAuthenticationManagerBuilder( - ObjectPostProcessor objectPostProcessor, PasswordEncoder defaultPasswordEncoder) { + DefaultPasswordEncoderAuthenticationManagerBuilder(ObjectPostProcessor objectPostProcessor, + PasswordEncoder defaultPasswordEncoder) { super(objectPostProcessor); this.defaultPasswordEncoder = defaultPasswordEncoder; } @Override public InMemoryUserDetailsManagerConfigurer inMemoryAuthentication() - throws Exception { - return super.inMemoryAuthentication() - .passwordEncoder(this.defaultPasswordEncoder); + throws Exception { + return super.inMemoryAuthentication().passwordEncoder(this.defaultPasswordEncoder); } @Override - public JdbcUserDetailsManagerConfigurer jdbcAuthentication() - throws Exception { - return super.jdbcAuthentication() - .passwordEncoder(this.defaultPasswordEncoder); + public JdbcUserDetailsManagerConfigurer jdbcAuthentication() throws Exception { + return super.jdbcAuthentication().passwordEncoder(this.defaultPasswordEncoder); } @Override public DaoAuthenticationConfigurer userDetailsService( - T userDetailsService) throws Exception { - return super.userDetailsService(userDetailsService) - .passwordEncoder(this.defaultPasswordEncoder); + T userDetailsService) throws Exception { + return super.userDetailsService(userDetailsService).passwordEncoder(this.defaultPasswordEncoder); } + } static class LazyPasswordEncoder implements PasswordEncoder { + private ApplicationContext applicationContext; + private PasswordEncoder passwordEncoder; LazyPasswordEncoder(ApplicationContext applicationContext) { @@ -307,8 +307,7 @@ public class AuthenticationConfiguration { } @Override - public boolean matches(CharSequence rawPassword, - String encodedPassword) { + public boolean matches(CharSequence rawPassword, String encodedPassword) { return getPasswordEncoder().matches(rawPassword, encodedPassword); } @@ -333,5 +332,7 @@ public class AuthenticationConfiguration { public String toString() { return getPasswordEncoder().toString(); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthentication.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthentication.java index 4a41eb75bf..acc8fef818 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthentication.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthentication.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configuration; import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import org.springframework.context.annotation.Configuration; @@ -81,10 +84,11 @@ import org.springframework.security.config.annotation.web.servlet.configuration. * @author Rob Winch * */ -@Retention(value = java.lang.annotation.RetentionPolicy.RUNTIME) -@Target(value = { java.lang.annotation.ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) @Documented @Import(AuthenticationConfiguration.class) @Configuration public @interface EnableGlobalAuthentication { + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/GlobalAuthenticationConfigurerAdapter.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/GlobalAuthenticationConfigurerAdapter.java index c3d6e253d4..20e4addef7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/GlobalAuthenticationConfigurerAdapter.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/GlobalAuthenticationConfigurerAdapter.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configuration; import org.springframework.core.annotation.Order; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.annotation.SecurityConfigurer; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; -import org.springframework.security.config.annotation.authentication.configuration.AuthenticationConfiguration; /** * A {@link SecurityConfigurer} that can be exposed as a bean to configure the global @@ -31,12 +31,15 @@ import org.springframework.security.config.annotation.authentication.configurati * @author Rob Winch */ @Order(100) -public abstract class GlobalAuthenticationConfigurerAdapter implements - SecurityConfigurer { +public abstract class GlobalAuthenticationConfigurerAdapter + implements SecurityConfigurer { + @Override public void init(AuthenticationManagerBuilder auth) throws Exception { } + @Override public void configure(AuthenticationManagerBuilder auth) throws Exception { } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeAuthenticationProviderBeanManagerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeAuthenticationProviderBeanManagerConfigurer.java index 6aa64b8dfa..3958699f4b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeAuthenticationProviderBeanManagerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeAuthenticationProviderBeanManagerConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configuration; import org.springframework.context.ApplicationContext; @@ -21,26 +22,23 @@ import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; /** - * Lazily initializes the global authentication with an {@link AuthenticationProvider} if it is - * not yet configured and there is only a single Bean of that type. + * Lazily initializes the global authentication with an {@link AuthenticationProvider} if + * it is not yet configured and there is only a single Bean of that type. * * @author Rob Winch * @since 4.1 */ @Order(InitializeAuthenticationProviderBeanManagerConfigurer.DEFAULT_ORDER) -class InitializeAuthenticationProviderBeanManagerConfigurer - extends GlobalAuthenticationConfigurerAdapter { +class InitializeAuthenticationProviderBeanManagerConfigurer extends GlobalAuthenticationConfigurerAdapter { - static final int DEFAULT_ORDER = InitializeUserDetailsBeanManagerConfigurer.DEFAULT_ORDER - - 100; + static final int DEFAULT_ORDER = InitializeUserDetailsBeanManagerConfigurer.DEFAULT_ORDER - 100; private final ApplicationContext context; /** * @param context the ApplicationContext to look up beans. */ - InitializeAuthenticationProviderBeanManagerConfigurer( - ApplicationContext context) { + InitializeAuthenticationProviderBeanManagerConfigurer(ApplicationContext context) { this.context = context; } @@ -49,25 +47,23 @@ class InitializeAuthenticationProviderBeanManagerConfigurer auth.apply(new InitializeAuthenticationProviderManagerConfigurer()); } - class InitializeAuthenticationProviderManagerConfigurer - extends GlobalAuthenticationConfigurerAdapter { + class InitializeAuthenticationProviderManagerConfigurer extends GlobalAuthenticationConfigurerAdapter { + @Override public void configure(AuthenticationManagerBuilder auth) { if (auth.isConfigured()) { return; } - AuthenticationProvider authenticationProvider = getBeanOrNull( - AuthenticationProvider.class); + AuthenticationProvider authenticationProvider = getBeanOrNull(AuthenticationProvider.class); if (authenticationProvider == null) { return; } - - auth.authenticationProvider(authenticationProvider); } /** - * @return a bean of the requested class if there's just a single registered component, null otherwise. + * @return a bean of the requested class if there's just a single registered + * component, null otherwise. */ private T getBeanOrNull(Class type) { String[] beanNames = InitializeAuthenticationProviderBeanManagerConfigurer.this.context @@ -75,9 +71,9 @@ class InitializeAuthenticationProviderBeanManagerConfigurer if (beanNames.length != 1) { return null; } - - return InitializeAuthenticationProviderBeanManagerConfigurer.this.context - .getBean(beanNames[0], type); + return InitializeAuthenticationProviderBeanManagerConfigurer.this.context.getBean(beanNames[0], type); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeUserDetailsBeanManagerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeUserDetailsBeanManagerConfigurer.java index 14013dbf0d..07fff8886b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeUserDetailsBeanManagerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configuration/InitializeUserDetailsBeanManagerConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configuration; import org.springframework.context.ApplicationContext; @@ -20,9 +21,9 @@ import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.security.authentication.dao.DaoAuthenticationProvider; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; +import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.crypto.password.PasswordEncoder; -import org.springframework.security.core.userdetails.UserDetailsPasswordService; /** * Lazily initializes the global authentication with a {@link UserDetailsService} if it is @@ -33,8 +34,7 @@ import org.springframework.security.core.userdetails.UserDetailsPasswordService; * @since 4.1 */ @Order(InitializeUserDetailsBeanManagerConfigurer.DEFAULT_ORDER) -class InitializeUserDetailsBeanManagerConfigurer - extends GlobalAuthenticationConfigurerAdapter { +class InitializeUserDetailsBeanManagerConfigurer extends GlobalAuthenticationConfigurerAdapter { static final int DEFAULT_ORDER = Ordered.LOWEST_PRECEDENCE - 5000; @@ -52,22 +52,19 @@ class InitializeUserDetailsBeanManagerConfigurer auth.apply(new InitializeUserDetailsManagerConfigurer()); } - class InitializeUserDetailsManagerConfigurer - extends GlobalAuthenticationConfigurerAdapter { + class InitializeUserDetailsManagerConfigurer extends GlobalAuthenticationConfigurerAdapter { + @Override public void configure(AuthenticationManagerBuilder auth) throws Exception { if (auth.isConfigured()) { return; } - UserDetailsService userDetailsService = getBeanOrNull( - UserDetailsService.class); + UserDetailsService userDetailsService = getBeanOrNull(UserDetailsService.class); if (userDetailsService == null) { return; } - PasswordEncoder passwordEncoder = getBeanOrNull(PasswordEncoder.class); UserDetailsPasswordService passwordManager = getBeanOrNull(UserDetailsPasswordService.class); - DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setUserDetailsService(userDetailsService); if (passwordEncoder != null) { @@ -77,22 +74,21 @@ class InitializeUserDetailsBeanManagerConfigurer provider.setUserDetailsPasswordService(passwordManager); } provider.afterPropertiesSet(); - auth.authenticationProvider(provider); } /** - * @return a bean of the requested class if there's just a single registered component, null otherwise. + * @return a bean of the requested class if there's just a single registered + * component, null otherwise. */ private T getBeanOrNull(Class type) { - String[] beanNames = InitializeUserDetailsBeanManagerConfigurer.this.context - .getBeanNamesForType(type); + String[] beanNames = InitializeUserDetailsBeanManagerConfigurer.this.context.getBeanNamesForType(type); if (beanNames.length != 1) { return null; } - - return InitializeUserDetailsBeanManagerConfigurer.this.context - .getBean(beanNames[0], type); + return InitializeUserDetailsBeanManagerConfigurer.this.context.getBean(beanNames[0], type); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurer.java index 9c6e2bae85..ac956837b7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configurers.ldap; import java.io.IOException; @@ -52,48 +53,58 @@ import org.springframework.util.ClassUtils; * Configures LDAP {@link AuthenticationProvider} in the {@link ProviderManagerBuilder}. * * @param the {@link ProviderManagerBuilder} type that this is configuring. - * * @author Rob Winch * @author Eddú Meléndez * @since 3.2 */ public class LdapAuthenticationProviderConfigurer> extends SecurityConfigurerAdapter { + private String groupRoleAttribute = "cn"; + private String groupSearchBase = ""; + private boolean groupSearchSubtree = false; + private String groupSearchFilter = "(uniqueMember={0})"; + private String rolePrefix = "ROLE_"; + private String userSearchBase = ""; // only for search + private String userSearchFilter = null; // "uid={0}"; // only for search + private String[] userDnPatterns; + private BaseLdapPathContextSource contextSource; + private ContextSourceBuilder contextSourceBuilder = new ContextSourceBuilder(); + private UserDetailsContextMapper userDetailsContextMapper; + private PasswordEncoder passwordEncoder; + private String passwordAttribute; + private LdapAuthoritiesPopulator ldapAuthoritiesPopulator; + private GrantedAuthoritiesMapper authoritiesMapper; private LdapAuthenticationProvider build() throws Exception { BaseLdapPathContextSource contextSource = getContextSource(); LdapAuthenticator ldapAuthenticator = createLdapAuthenticator(contextSource); - LdapAuthoritiesPopulator authoritiesPopulator = getLdapAuthoritiesPopulator(); - - LdapAuthenticationProvider ldapAuthenticationProvider = new LdapAuthenticationProvider( - ldapAuthenticator, authoritiesPopulator); + LdapAuthenticationProvider ldapAuthenticationProvider = new LdapAuthenticationProvider(ldapAuthenticator, + authoritiesPopulator); ldapAuthenticationProvider.setAuthoritiesMapper(getAuthoritiesMapper()); - if (userDetailsContextMapper != null) { - ldapAuthenticationProvider - .setUserDetailsContextMapper(userDetailsContextMapper); + if (this.userDetailsContextMapper != null) { + ldapAuthenticationProvider.setUserDetailsContextMapper(this.userDetailsContextMapper); } return ldapAuthenticationProvider; } /** * Specifies the {@link LdapAuthoritiesPopulator}. - * * @param ldapAuthoritiesPopulator the {@link LdapAuthoritiesPopulator} the default is * {@link DefaultLdapAuthoritiesPopulator} * @return the {@link LdapAuthenticationProviderConfigurer} for further customizations @@ -106,12 +117,10 @@ public class LdapAuthenticationProviderConfigurer withObjectPostProcessor( - ObjectPostProcessor objectPostProcessor) { + public LdapAuthenticationProviderConfigurer withObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { addObjectPostProcessor(objectPostProcessor); return this; } @@ -119,51 +128,47 @@ public class LdapAuthenticationProviderConfigurer authoritiesMapper(GrantedAuthoritiesMapper grantedAuthoritiesMapper) { + public LdapAuthenticationProviderConfigurer authoritiesMapper( + GrantedAuthoritiesMapper grantedAuthoritiesMapper) { this.authoritiesMapper = grantedAuthoritiesMapper; return this; } /** - * Gets the {@link GrantedAuthoritiesMapper} and defaults to {@link SimpleAuthorityMapper}. - * + * Gets the {@link GrantedAuthoritiesMapper} and defaults to + * {@link SimpleAuthorityMapper}. * @return the {@link GrantedAuthoritiesMapper} * @throws Exception if errors in {@link SimpleAuthorityMapper#afterPropertiesSet()} */ protected GrantedAuthoritiesMapper getAuthoritiesMapper() throws Exception { - if (authoritiesMapper != null) { - return authoritiesMapper; + if (this.authoritiesMapper != null) { + return this.authoritiesMapper; } - SimpleAuthorityMapper simpleAuthorityMapper = new SimpleAuthorityMapper(); simpleAuthorityMapper.setPrefix(this.rolePrefix); simpleAuthorityMapper.afterPropertiesSet(); @@ -173,70 +178,61 @@ public class LdapAuthenticationProviderConfigurer 0) { - ldapAuthenticator.setUserDnPatterns(userDnPatterns); + if (this.userDnPatterns != null && this.userDnPatterns.length > 0) { + ldapAuthenticator.setUserDnPatterns(this.userDnPatterns); } return postProcess(ldapAuthenticator); } /** * Creates {@link PasswordComparisonAuthenticator} - * * @param contextSource the {@link BaseLdapPathContextSource} to use * @return */ private PasswordComparisonAuthenticator createPasswordCompareAuthenticator( BaseLdapPathContextSource contextSource) { - PasswordComparisonAuthenticator ldapAuthenticator = new PasswordComparisonAuthenticator( - contextSource); - if (passwordAttribute != null) { - ldapAuthenticator.setPasswordAttributeName(passwordAttribute); + PasswordComparisonAuthenticator ldapAuthenticator = new PasswordComparisonAuthenticator(contextSource); + if (this.passwordAttribute != null) { + ldapAuthenticator.setPasswordAttributeName(this.passwordAttribute); } - ldapAuthenticator.setPasswordEncoder(passwordEncoder); + ldapAuthenticator.setPasswordEncoder(this.passwordEncoder); return ldapAuthenticator; } /** * Creates a {@link BindAuthenticator} - * * @param contextSource the {@link BaseLdapPathContextSource} to use * @return the {@link BindAuthenticator} to use */ - private BindAuthenticator createBindAuthenticator( - BaseLdapPathContextSource contextSource) { + private BindAuthenticator createBindAuthenticator(BaseLdapPathContextSource contextSource) { return new BindAuthenticator(contextSource); } private LdapUserSearch createUserSearch() { - if (userSearchFilter == null) { + if (this.userSearchFilter == null) { return null; } - return new FilterBasedLdapUserSearch(userSearchBase, userSearchFilter, - contextSource); + return new FilterBasedLdapUserSearch(this.userSearchBase, this.userSearchFilter, this.contextSource); } /** * Specifies the {@link BaseLdapPathContextSource} to be used. If not specified, an * embedded LDAP server will be created using {@link #contextSource()}. - * * @param contextSource the {@link BaseLdapPathContextSource} to use * @return the {@link LdapAuthenticationProviderConfigurer} for further customizations * @see #contextSource() */ - public LdapAuthenticationProviderConfigurer contextSource( - BaseLdapPathContextSource contextSource) { + public LdapAuthenticationProviderConfigurer contextSource(BaseLdapPathContextSource contextSource) { this.contextSource = contextSource; return this; } @@ -244,17 +240,15 @@ public class LdapAuthenticationProviderConfigurer userDnPatterns( - String... userDnPatterns) { + public LdapAuthenticationProviderConfigurer userDnPatterns(String... userDnPatterns) { this.userDnPatterns = userDnPatterns; return this; } @@ -287,7 +279,6 @@ public class LdapAuthenticationProviderConfigurer groupRoleAttribute( - String groupRoleAttribute) { + public LdapAuthenticationProviderConfigurer groupRoleAttribute(String groupRoleAttribute) { this.groupRoleAttribute = groupRoleAttribute; return this; } @@ -323,11 +313,10 @@ public class LdapAuthenticationProviderConfigurergroupSearchBase. + * groupSearchBase. * @return the {@link LdapAuthenticationProviderConfigurer} for further customizations */ public LdapAuthenticationProviderConfigurer groupSearchSubtree(boolean groupSearchSubtree) { @@ -338,12 +327,10 @@ public class LdapAuthenticationProviderConfigurer groupSearchFilter( - String groupSearchFilter) { + public LdapAuthenticationProviderConfigurer groupSearchFilter(String groupSearchFilter) { this.groupSearchFilter = groupSearchFilter; return this; } @@ -351,7 +338,6 @@ public class LdapAuthenticationProviderConfigurer userSearchFilter( - String userSearchFilter) { + public LdapAuthenticationProviderConfigurer userSearchFilter(String userSearchFilter) { this.userSearchFilter = userSearchFilter; return this; } @@ -392,6 +375,21 @@ public class LdapAuthenticationProviderConfigurer and() { @@ -435,6 +431,7 @@ public class LdapAuthenticationProviderConfigurer the type of the {@link ProviderManagerBuilder} that is being configured - * * @author Rob Winch * @since 3.2 */ @@ -40,4 +40,5 @@ public class InMemoryUserDetailsManagerConfigurer())); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/JdbcUserDetailsManagerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/JdbcUserDetailsManagerConfigurer.java index 035bd7877f..1020cd6207 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/JdbcUserDetailsManagerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/JdbcUserDetailsManagerConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configurers.provisioning; import java.util.ArrayList; @@ -40,7 +41,6 @@ import org.springframework.security.provisioning.JdbcUserDetailsManager; * methods have reasonable defaults. * * @param the type of the {@link ProviderManagerBuilder} that is being configured - * * @author Rob Winch * @since 3.2 */ @@ -61,9 +61,9 @@ public class JdbcUserDetailsManagerConfigurer dataSource(DataSource dataSource) { this.dataSource = dataSource; @@ -94,7 +94,6 @@ public class JdbcUserDetailsManagerConfigurer * select username,authority from authorities where username = ? * - * * @param query The query to use for selecting the username, authority by username. * Must contain a single parameter for the username. * @return The {@link JdbcUserDetailsManagerConfigurer} used for additional @@ -116,7 +115,6 @@ public class JdbcUserDetailsManagerConfigurer - * * @param query The query to use for selecting the authorities by group. Must contain * a single parameter for the username. * @return The {@link JdbcUserDetailsManagerConfigurer} used for additional @@ -132,9 +130,9 @@ public class JdbcUserDetailsManagerConfigurer rolePrefix(String rolePrefix) { getUserDetailsService().setRolePrefix(rolePrefix); @@ -143,7 +141,6 @@ public class JdbcUserDetailsManagerConfigurer withDefaultSchema() { - this.initScripts.add(new ClassPathResource( - "org/springframework/security/core/userdetails/jdbc/users.ddl")); + this.initScripts.add(new ClassPathResource("org/springframework/security/core/userdetails/jdbc/users.ddl")); return this; } protected DatabasePopulator getDatabasePopulator() { ResourceDatabasePopulator dbp = new ResourceDatabasePopulator(); - dbp.setScripts(initScripts.toArray(new Resource[0])); + dbp.setScripts(this.initScripts.toArray(new Resource[0])); return dbp; } private DataSourceInitializer getDataSourceInit() { DataSourceInitializer dsi = new DataSourceInitializer(); dsi.setDatabasePopulator(getDatabasePopulator()); - dsi.setDataSource(dataSource); + dsi.setDataSource(this.dataSource); return dsi; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurer.java index 0ae53f7e63..5b9cf34b97 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configurers.provisioning; import java.util.ArrayList; @@ -34,7 +35,6 @@ import org.springframework.security.provisioning.UserDetailsManager; * * @param the type of the {@link SecurityBuilder} that is being configured * @param the type of {@link UserDetailsManagerConfigurer} - * * @author Rob Winch * @since 3.2 */ @@ -51,12 +51,11 @@ public class UserDetailsManagerConfigurer, C /** * Populates the users that have been added. - * * @throws Exception */ @Override protected void initUserDetailsService() throws Exception { - for (UserDetailsBuilder userBuilder : userBuilders) { + for (UserDetailsBuilder userBuilder : this.userBuilders) { getUserDetailsService().createUser(userBuilder.build()); } for (UserDetails userDetails : this.users) { @@ -67,7 +66,6 @@ public class UserDetailsManagerConfigurer, C /** * Allows adding a user to the {@link UserDetailsManager} that is being created. This * method can be invoked multiple times to add multiple users. - * * @param userDetails the user to add. Cannot be null. * @return the {@link UserDetailsBuilder} for further customizations */ @@ -80,7 +78,6 @@ public class UserDetailsManagerConfigurer, C /** * Allows adding a user to the {@link UserDetailsManager} that is being created. This * method can be invoked multiple times to add multiple users. - * * @param userBuilder the user to add. Cannot be null. * @return the {@link UserDetailsBuilder} for further customizations */ @@ -93,7 +90,6 @@ public class UserDetailsManagerConfigurer, C /** * Allows adding a user to the {@link UserDetailsManager} that is being created. This * method can be invoked multiple times to add multiple users. - * * @param username the username for the user being added. Cannot be null. * @return the {@link UserDetailsBuilder} for further customizations */ @@ -109,8 +105,10 @@ public class UserDetailsManagerConfigurer, C * Builds the user to be added. At minimum the username, password, and authorities * should provided. The remaining attributes have reasonable defaults. */ - public class UserDetailsBuilder { + public final class UserDetailsBuilder { + private UserBuilder user; + private final C builder; /** @@ -122,18 +120,16 @@ public class UserDetailsManagerConfigurer, C } /** - * Returns the {@link UserDetailsManagerConfigurer} for method chaining (i.e. to add - * another user) - * + * Returns the {@link UserDetailsManagerConfigurer} for method chaining (i.e. to + * add another user) * @return the {@link UserDetailsManagerConfigurer} for method chaining */ public C and() { - return builder; + return this.builder; } /** * Populates the username. This attribute is required. - * * @param username the username. Cannot be null. * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -145,7 +141,6 @@ public class UserDetailsManagerConfigurer, C /** * Populates the password. This attribute is required. - * * @param password the password. Cannot be null. * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -174,7 +169,6 @@ public class UserDetailsManagerConfigurer, C * This attribute is required, but can also be populated with * {@link #authorities(String...)}. *

- * * @param roles the roles for this user (i.e. USER, ADMIN, etc). Cannot be null, * contain null values or start with "ROLE_" * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate @@ -187,7 +181,6 @@ public class UserDetailsManagerConfigurer, C /** * Populates the authorities. This attribute is required. - * * @param authorities the authorities for this user. Cannot be null, or contain * null values * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate @@ -201,7 +194,6 @@ public class UserDetailsManagerConfigurer, C /** * Populates the authorities. This attribute is required. - * * @param authorities the authorities for this user. Cannot be null, or contain * null values * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate @@ -215,7 +207,6 @@ public class UserDetailsManagerConfigurer, C /** * Populates the authorities. This attribute is required. - * * @param authorities the authorities for this user (i.e. ROLE_USER, ROLE_ADMIN, * etc). Cannot be null, or contain null values * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate @@ -229,7 +220,6 @@ public class UserDetailsManagerConfigurer, C /** * Defines if the account is expired or not. Default is false. - * * @param accountExpired true if the account is expired, false otherwise * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -241,7 +231,6 @@ public class UserDetailsManagerConfigurer, C /** * Defines if the account is locked or not. Default is false. - * * @param accountLocked true if the account is locked, false otherwise * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -253,7 +242,6 @@ public class UserDetailsManagerConfigurer, C /** * Defines if the credentials are expired or not. Default is false. - * * @param credentialsExpired true if the credentials are expired, false otherwise * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -265,7 +253,6 @@ public class UserDetailsManagerConfigurer, C /** * Defines if the account is disabled or not. Default is false. - * * @param disabled true if the account is disabled, false otherwise * @return the {@link UserDetailsBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -278,5 +265,7 @@ public class UserDetailsManagerConfigurer, C UserDetails build() { return this.user.build(); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/AbstractDaoAuthenticationConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/AbstractDaoAuthenticationConfigurer.java index 0cf004160f..6acd120958 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/AbstractDaoAuthenticationConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/AbstractDaoAuthenticationConfigurer.java @@ -13,40 +13,40 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configurers.userdetails; import org.springframework.security.authentication.dao.DaoAuthenticationProvider; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.SecurityBuilder; import org.springframework.security.config.annotation.authentication.ProviderManagerBuilder; +import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.crypto.password.PasswordEncoder; -import org.springframework.security.core.userdetails.UserDetailsPasswordService; /** * Allows configuring a {@link DaoAuthenticationProvider} * - * @author Rob Winch - * @since 3.2 - * * @param the type of the {@link SecurityBuilder} * @param the type of {@link AbstractDaoAuthenticationConfigurer} this is * @param The type of {@link UserDetailsService} that is being used - * + * @author Rob Winch + * @since 3.2 */ -abstract class AbstractDaoAuthenticationConfigurer, C extends AbstractDaoAuthenticationConfigurer, U extends UserDetailsService> +public abstract class AbstractDaoAuthenticationConfigurer, C extends AbstractDaoAuthenticationConfigurer, U extends UserDetailsService> extends UserDetailsAwareConfigurer { + private DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); + private final U userDetailsService; /** * Creates a new instance - * * @param userDetailsService */ - protected AbstractDaoAuthenticationConfigurer(U userDetailsService) { + AbstractDaoAuthenticationConfigurer(U userDetailsService) { this.userDetailsService = userDetailsService; - provider.setUserDetailsService(userDetailsService); + this.provider.setUserDetailsService(userDetailsService); if (userDetailsService instanceof UserDetailsPasswordService) { this.provider.setUserDetailsPasswordService((UserDetailsPasswordService) userDetailsService); } @@ -54,7 +54,6 @@ abstract class AbstractDaoAuthenticationConfigurer The type of {@link ProviderManagerBuilder} this is * @param The type of {@link UserDetailsService} that is being used - * + * @author Rob Winch + * @since 3.2 */ public class DaoAuthenticationConfigurer, U extends UserDetailsService> - extends - AbstractDaoAuthenticationConfigurer, U> { + extends AbstractDaoAuthenticationConfigurer, U> { /** * Creates a new instance @@ -40,4 +38,5 @@ public class DaoAuthenticationConfigurer, U public DaoAuthenticationConfigurer(U userDetailsService) { super(userDetailsService); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/UserDetailsAwareConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/UserDetailsAwareConfigurer.java index 7ba27523ec..a09e64ace8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/UserDetailsAwareConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/authentication/configurers/userdetails/UserDetailsAwareConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configurers.userdetails; import org.springframework.security.authentication.AuthenticationManager; @@ -25,10 +26,9 @@ import org.springframework.security.core.userdetails.UserDetailsService; * Base class that allows access to the {@link UserDetailsService} for using as a default * value with {@link AuthenticationManagerBuilder}. * - * @author Rob Winch - * * @param the type of the {@link ProviderManagerBuilder} * @param the type of {@link UserDetailsService} + * @author Rob Winch */ public abstract class UserDetailsAwareConfigurer, U extends UserDetailsService> extends SecurityConfigurerAdapter { @@ -38,4 +38,5 @@ public abstract class UserDetailsAwareConfigurer the type of the {@link ProviderManagerBuilder} * @param the {@link UserDetailsServiceConfigurer} (or this) * @param the type of UserDetailsService being used to allow for returning the * concrete UserDetailsService. + * @author Rob Winch + * @since 3.2 */ public class UserDetailsServiceConfigurer, C extends UserDetailsServiceConfigurer, U extends UserDetailsService> extends AbstractDaoAuthenticationConfigurer { @@ -45,7 +45,6 @@ public class UserDetailsServiceConfigurer, C @Override public void configure(B builder) throws Exception { initUserDetailsService(); - super.configure(builder); } @@ -55,4 +54,5 @@ public class UserDetailsServiceConfigurer, C */ protected void initUserDetailsService() throws Exception { } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessor.java b/config/src/main/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessor.java index 813e05684c..7792ff44ce 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.configuration; import java.util.ArrayList; @@ -39,24 +40,21 @@ import org.springframework.util.Assert; */ final class AutowireBeanFactoryObjectPostProcessor implements ObjectPostProcessor, DisposableBean, SmartInitializingSingleton { + private final Log logger = LogFactory.getLog(getClass()); + private final AutowireCapableBeanFactory autowireBeanFactory; + private final List disposableBeans = new ArrayList<>(); + private final List smartSingletons = new ArrayList<>(); - AutowireBeanFactoryObjectPostProcessor( - AutowireCapableBeanFactory autowireBeanFactory) { + AutowireBeanFactoryObjectPostProcessor(AutowireCapableBeanFactory autowireBeanFactory) { Assert.notNull(autowireBeanFactory, "autowireBeanFactory cannot be null"); this.autowireBeanFactory = autowireBeanFactory; } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.config.annotation.web.Initializer#initialize(java. - * lang.Object) - */ + @Override @SuppressWarnings("unchecked") public T postProcess(T object) { if (object == null) { @@ -64,13 +62,11 @@ final class AutowireBeanFactoryObjectPostProcessor } T result = null; try { - result = (T) this.autowireBeanFactory.initializeBean(object, - object.toString()); + result = (T) this.autowireBeanFactory.initializeBean(object, object.toString()); } - catch (RuntimeException e) { + catch (RuntimeException ex) { Class type = object.getClass(); - throw new RuntimeException( - "Could not postProcess " + object + " of type " + type, e); + throw new RuntimeException("Could not postProcess " + object + " of type " + type, ex); } this.autowireBeanFactory.autowireBean(object); if (result instanceof DisposableBean) { @@ -82,28 +78,21 @@ final class AutowireBeanFactoryObjectPostProcessor return result; } - /* (non-Javadoc) - * @see org.springframework.beans.factory.SmartInitializingSingleton#afterSingletonsInstantiated() - */ @Override public void afterSingletonsInstantiated() { - for (SmartInitializingSingleton singleton : smartSingletons) { + for (SmartInitializingSingleton singleton : this.smartSingletons) { singleton.afterSingletonsInstantiated(); } } - /* - * (non-Javadoc) - * - * @see org.springframework.beans.factory.DisposableBean#destroy() - */ + @Override public void destroy() { for (DisposableBean disposable : this.disposableBeans) { try { disposable.destroy(); } - catch (Exception error) { - this.logger.error(error); + catch (Exception ex) { + this.logger.error(ex); } } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/configuration/ObjectPostProcessorConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/configuration/ObjectPostProcessorConfiguration.java index 3c61557c4c..cb1ffde00d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/configuration/ObjectPostProcessorConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/configuration/ObjectPostProcessorConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.configuration; import org.springframework.beans.factory.config.AutowireCapableBeanFactory; @@ -31,7 +32,6 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe * * @see EnableWebSecurity * @see EnableGlobalMethodSecurity - * * @author Rob Winch * @since 3.2 */ @@ -41,8 +41,8 @@ public class ObjectPostProcessorConfiguration { @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - public ObjectPostProcessor objectPostProcessor( - AutowireCapableBeanFactory beanFactory) { + public ObjectPostProcessor objectPostProcessor(AutowireCapableBeanFactory beanFactory) { return new AutowireBeanFactoryObjectPostProcessor(beanFactory); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableGlobalMethodSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableGlobalMethodSecurity.java index 7c1fbc49d0..34517d6beb 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableGlobalMethodSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableGlobalMethodSecurity.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import org.springframework.context.annotation.AdviceMode; @@ -28,8 +31,8 @@ import org.springframework.security.config.annotation.authentication.configurati /** *

- * Enables Spring Security global method security similar to the <global-method-security> - * xml support. + * Enables Spring Security global method security similar to the + * <global-method-security> xml support. * *

* More advanced configurations may wish to extend @@ -41,8 +44,8 @@ import org.springframework.security.config.annotation.authentication.configurati * @author Rob Winch * @since 3.2 */ -@Retention(value = java.lang.annotation.RetentionPolicy.RUNTIME) -@Target(value = { java.lang.annotation.ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) @Documented @Import({ GlobalMethodSecuritySelector.class }) @EnableGlobalAuthentication @@ -82,7 +85,6 @@ public @interface EnableGlobalMethodSecurity { * annotation will be upgraded to subclass proxying at the same time. This approach * has no negative impact in practice unless one is explicitly expecting one type of * proxy vs another, e.g. in tests. - * * @return true if CGILIB proxies should be created instead of interface based * proxies, else false */ @@ -92,7 +94,6 @@ public @interface EnableGlobalMethodSecurity { * Indicate how security advice should be applied. The default is * {@link AdviceMode#PROXY}. * @see AdviceMode - * * @return the {@link AdviceMode} to use */ AdviceMode mode() default AdviceMode.PROXY; @@ -101,8 +102,8 @@ public @interface EnableGlobalMethodSecurity { * Indicate the ordering of the execution of the security advisor when multiple * advices are applied at a specific joinpoint. The default is * {@link Ordered#LOWEST_PRECEDENCE}. - * * @return the order the security advisor should be applied */ int order() default Ordered.LOWEST_PRECEDENCE; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurity.java index 3d659bc4eb..8e129695c4 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurity.java @@ -16,36 +16,40 @@ package org.springframework.security.config.annotation.method.configuration; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + import org.springframework.context.annotation.AdviceMode; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.core.Ordered; -import java.lang.annotation.Documented; -import java.lang.annotation.Retention; -import java.lang.annotation.Target; - /** * * @author Rob Winch * @since 5.0 */ -@Retention(value = java.lang.annotation.RetentionPolicy.RUNTIME) -@Target(value = { java.lang.annotation.ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) @Documented -@Import({ ReactiveMethodSecuritySelector.class }) +@Import(ReactiveMethodSecuritySelector.class) @Configuration public @interface EnableReactiveMethodSecurity { + /** - * Indicate whether subclass-based (CGLIB) proxies are to be created as opposed - * to standard Java interface-based proxies. The default is {@code false}. + * Indicate whether subclass-based (CGLIB) proxies are to be created as opposed to + * standard Java interface-based proxies. The default is {@code false}. * Applicable only if {@link #mode()} is set to {@link AdviceMode#PROXY}. - *

Note that setting this attribute to {@code true} will affect all - * Spring-managed beans requiring proxying, not just those marked with {@code @Cacheable}. - * For example, other beans marked with Spring's {@code @Transactional} annotation will - * be upgraded to subclass proxying at the same time. This approach has no negative - * impact in practice unless one is explicitly expecting one type of proxy vs another, - * e.g. in tests. + *

+ * Note that setting this attribute to {@code true} will affect all + * Spring-managed beans requiring proxying, not just those marked with + * {@code @Cacheable}. For example, other beans marked with Spring's + * {@code @Transactional} annotation will be upgraded to subclass proxying at the same + * time. This approach has no negative impact in practice unless one is explicitly + * expecting one type of proxy vs another, e.g. in tests. */ boolean proxyTargetClass() default false; @@ -53,7 +57,6 @@ public @interface EnableReactiveMethodSecurity { * Indicate how security advice should be applied. The default is * {@link AdviceMode#PROXY}. * @see AdviceMode - * * @return the {@link AdviceMode} to use */ AdviceMode mode() default AdviceMode.PROXY; @@ -62,8 +65,8 @@ public @interface EnableReactiveMethodSecurity { * Indicate the ordering of the execution of the security advisor when multiple * advices are applied at a specific joinpoint. The default is * {@link Ordered#LOWEST_PRECEDENCE}. - * * @return the order the security advisor should be applied */ int order() default Ordered.LOWEST_PRECEDENCE; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityAspectJAutoProxyRegistrar.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityAspectJAutoProxyRegistrar.java index 392c58f3d8..b2c44c9280 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityAspectJAutoProxyRegistrar.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityAspectJAutoProxyRegistrar.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import org.springframework.beans.factory.config.BeanDefinition; @@ -36,28 +37,22 @@ import org.springframework.core.type.AnnotationMetadata; * @author Rob Winch * @since 3.2 */ -class GlobalMethodSecurityAspectJAutoProxyRegistrar implements - ImportBeanDefinitionRegistrar { +class GlobalMethodSecurityAspectJAutoProxyRegistrar implements ImportBeanDefinitionRegistrar { /** * Register, escalate, and configure the AspectJ auto proxy creator based on the value * of the @{@link EnableGlobalMethodSecurity#proxyTargetClass()} attribute on the * importing {@code @Configuration} class. */ - public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, - BeanDefinitionRegistry registry) { - - BeanDefinition interceptor = registry - .getBeanDefinition("methodSecurityInterceptor"); - - BeanDefinitionBuilder aspect = BeanDefinitionBuilder - .rootBeanDefinition("org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"); + @Override + public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) { + BeanDefinition interceptor = registry.getBeanDefinition("methodSecurityInterceptor"); + BeanDefinitionBuilder aspect = BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"); aspect.setFactoryMethod("aspectOf"); aspect.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); aspect.addPropertyValue("securityInterceptor", interceptor); - - registry.registerBeanDefinition("annotationSecurityAspect$0", - aspect.getBeanDefinition()); + registry.registerBeanDefinition("annotationSecurityAspect$0", aspect.getBeanDefinition()); } -} \ No newline at end of file +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfiguration.java index eb3207f823..89c713062d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import java.util.ArrayList; @@ -30,7 +31,11 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.context.annotation.*; +import org.springframework.context.annotation.AdviceMode; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.ImportAware; +import org.springframework.context.annotation.Role; import org.springframework.core.annotation.AnnotationAttributes; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.type.AnnotationMetadata; @@ -82,24 +87,34 @@ import org.springframework.util.Assert; */ @Configuration(proxyBeanMethods = false) @Role(BeanDefinition.ROLE_INFRASTRUCTURE) -public class GlobalMethodSecurityConfiguration - implements ImportAware, SmartInitializingSingleton, BeanFactoryAware { - private static final Log logger = LogFactory - .getLog(GlobalMethodSecurityConfiguration.class); +public class GlobalMethodSecurityConfiguration implements ImportAware, SmartInitializingSingleton, BeanFactoryAware { + + private static final Log logger = LogFactory.getLog(GlobalMethodSecurityConfiguration.class); + private ObjectPostProcessor objectPostProcessor = new ObjectPostProcessor() { + + @Override public T postProcess(T object) { throw new IllegalStateException(ObjectPostProcessor.class.getName() - + " is a required bean. Ensure you have used @" - + EnableGlobalMethodSecurity.class.getName()); + + " is a required bean. Ensure you have used @" + EnableGlobalMethodSecurity.class.getName()); } + }; + private DefaultMethodSecurityExpressionHandler defaultMethodExpressionHandler = new DefaultMethodSecurityExpressionHandler(); + private AuthenticationManager authenticationManager; + private AuthenticationManagerBuilder auth; + private boolean disableAuthenticationRegistry; + private AnnotationAttributes enableMethodSecurity; + private BeanFactory context; + private MethodSecurityExpressionHandler expressionHandler; + private MethodSecurityInterceptor methodSecurityInterceptor; /** @@ -117,72 +132,56 @@ public class GlobalMethodSecurityConfiguration * Subclasses can override this method to provide a different * {@link MethodInterceptor}. *

- * @param methodSecurityMetadataSource the default {@link MethodSecurityMetadataSource}. - * + * @param methodSecurityMetadataSource the default + * {@link MethodSecurityMetadataSource}. * @return the {@link MethodInterceptor}. */ @Bean public MethodInterceptor methodSecurityInterceptor(MethodSecurityMetadataSource methodSecurityMetadataSource) { - this.methodSecurityInterceptor = isAspectJ() - ? new AspectJMethodSecurityInterceptor() + this.methodSecurityInterceptor = isAspectJ() ? new AspectJMethodSecurityInterceptor() : new MethodSecurityInterceptor(); - methodSecurityInterceptor.setAccessDecisionManager(accessDecisionManager()); - methodSecurityInterceptor.setAfterInvocationManager(afterInvocationManager()); - methodSecurityInterceptor - .setSecurityMetadataSource(methodSecurityMetadataSource); + this.methodSecurityInterceptor.setAccessDecisionManager(accessDecisionManager()); + this.methodSecurityInterceptor.setAfterInvocationManager(afterInvocationManager()); + this.methodSecurityInterceptor.setSecurityMetadataSource(methodSecurityMetadataSource); RunAsManager runAsManager = runAsManager(); if (runAsManager != null) { - methodSecurityInterceptor.setRunAsManager(runAsManager); + this.methodSecurityInterceptor.setRunAsManager(runAsManager); } - return this.methodSecurityInterceptor; } - /* - * (non-Javadoc) - * - * @see org.springframework.beans.factory.SmartInitializingSingleton# - * afterSingletonsInstantiated() - */ @Override public void afterSingletonsInstantiated() { try { initializeMethodSecurityInterceptor(); } - catch (Exception e) { - throw new RuntimeException(e); + catch (Exception ex) { + throw new RuntimeException(ex); } - - PermissionEvaluator permissionEvaluator = getSingleBeanOrNull( - PermissionEvaluator.class); + PermissionEvaluator permissionEvaluator = getSingleBeanOrNull(PermissionEvaluator.class); if (permissionEvaluator != null) { - this.defaultMethodExpressionHandler - .setPermissionEvaluator(permissionEvaluator); + this.defaultMethodExpressionHandler.setPermissionEvaluator(permissionEvaluator); } - RoleHierarchy roleHierarchy = getSingleBeanOrNull(RoleHierarchy.class); if (roleHierarchy != null) { this.defaultMethodExpressionHandler.setRoleHierarchy(roleHierarchy); } - - AuthenticationTrustResolver trustResolver = getSingleBeanOrNull( - AuthenticationTrustResolver.class); + AuthenticationTrustResolver trustResolver = getSingleBeanOrNull(AuthenticationTrustResolver.class); if (trustResolver != null) { this.defaultMethodExpressionHandler.setTrustResolver(trustResolver); } - - GrantedAuthorityDefaults grantedAuthorityDefaults = getSingleBeanOrNull( - GrantedAuthorityDefaults.class); + GrantedAuthorityDefaults grantedAuthorityDefaults = getSingleBeanOrNull(GrantedAuthorityDefaults.class); if (grantedAuthorityDefaults != null) { - this.defaultMethodExpressionHandler.setDefaultRolePrefix( - grantedAuthorityDefaults.getRolePrefix()); + this.defaultMethodExpressionHandler.setDefaultRolePrefix(grantedAuthorityDefaults.getRolePrefix()); } } private T getSingleBeanOrNull(Class type) { try { - return context.getBean(type); - } catch (NoSuchBeanDefinitionException e) {} + return this.context.getBean(type); + } + catch (NoSuchBeanDefinitionException ex) { + } return null; } @@ -195,14 +194,14 @@ public class GlobalMethodSecurityConfiguration /** * Provide a custom {@link AfterInvocationManager} for the default implementation of - * {@link #methodSecurityInterceptor(MethodSecurityMetadataSource)}. The default is null - * if pre post is not enabled. Otherwise, it returns a {@link AfterInvocationProviderManager}. + * {@link #methodSecurityInterceptor(MethodSecurityMetadataSource)}. The default is + * null if pre post is not enabled. Otherwise, it returns a + * {@link AfterInvocationProviderManager}. * *

* Subclasses should override this method to provide a custom * {@link AfterInvocationManager} *

- * * @return the {@link AfterInvocationManager} to use */ protected AfterInvocationManager afterInvocationManager() { @@ -210,8 +209,7 @@ public class GlobalMethodSecurityConfiguration AfterInvocationProviderManager invocationProviderManager = new AfterInvocationProviderManager(); ExpressionBasedPostInvocationAdvice postAdvice = new ExpressionBasedPostInvocationAdvice( getExpressionHandler()); - PostInvocationAdviceProvider postInvocationAdviceProvider = new PostInvocationAdviceProvider( - postAdvice); + PostInvocationAdviceProvider postInvocationAdviceProvider = new PostInvocationAdviceProvider(postAdvice); List afterInvocationProviders = new ArrayList<>(); afterInvocationProviders.add(postInvocationAdviceProvider); invocationProviderManager.setProviders(afterInvocationProviders); @@ -222,8 +220,8 @@ public class GlobalMethodSecurityConfiguration /** * Provide a custom {@link RunAsManager} for the default implementation of - * {@link #methodSecurityInterceptor(MethodSecurityMetadataSource)}. The default is null. - * + * {@link #methodSecurityInterceptor(MethodSecurityMetadataSource)}. The default is + * null. * @return the {@link RunAsManager} to use */ protected RunAsManager runAsManager() { @@ -239,24 +237,20 @@ public class GlobalMethodSecurityConfiguration *
  • {@link RoleVoter}
  • *
  • {@link AuthenticatedVoter}
  • * - * * @return the {@link AccessDecisionManager} to use */ protected AccessDecisionManager accessDecisionManager() { List> decisionVoters = new ArrayList<>(); if (prePostEnabled()) { - ExpressionBasedPreInvocationAdvice expressionAdvice = - new ExpressionBasedPreInvocationAdvice(); + ExpressionBasedPreInvocationAdvice expressionAdvice = new ExpressionBasedPreInvocationAdvice(); expressionAdvice.setExpressionHandler(getExpressionHandler()); - decisionVoters - .add(new PreInvocationAuthorizationAdviceVoter(expressionAdvice)); + decisionVoters.add(new PreInvocationAuthorizationAdviceVoter(expressionAdvice)); } if (jsr250Enabled()) { decisionVoters.add(new Jsr250Voter()); } RoleVoter roleVoter = new RoleVoter(); - GrantedAuthorityDefaults grantedAuthorityDefaults = - getSingleBeanOrNull(GrantedAuthorityDefaults.class); + GrantedAuthorityDefaults grantedAuthorityDefaults = getSingleBeanOrNull(GrantedAuthorityDefaults.class); if (grantedAuthorityDefaults != null) { roleVoter.setRolePrefix(grantedAuthorityDefaults.getRolePrefix()); } @@ -275,30 +269,27 @@ public class GlobalMethodSecurityConfiguration * Subclasses may override this method to provide a custom * {@link MethodSecurityExpressionHandler} *

    - * * @return the {@link MethodSecurityExpressionHandler} to use */ protected MethodSecurityExpressionHandler createExpressionHandler() { - return defaultMethodExpressionHandler; + return this.defaultMethodExpressionHandler; } /** * Gets the {@link MethodSecurityExpressionHandler} or creates it using * {@link #expressionHandler}. - * * @return a non {@code null} {@link MethodSecurityExpressionHandler} */ protected final MethodSecurityExpressionHandler getExpressionHandler() { - if (expressionHandler == null) { - expressionHandler = createExpressionHandler(); + if (this.expressionHandler == null) { + this.expressionHandler = createExpressionHandler(); } - return expressionHandler; + return this.expressionHandler; } /** * Provides a custom {@link MethodSecurityMetadataSource} that is registered with the * {@link #methodSecurityMetadataSource()}. Default is null. - * * @return a custom {@link MethodSecurityMetadataSource} that is registered with the * {@link #methodSecurityMetadataSource()} */ @@ -312,32 +303,25 @@ public class GlobalMethodSecurityConfiguration * {@link #configure(AuthenticationManagerBuilder)}. If * {@link #configure(AuthenticationManagerBuilder)} was not overridden, then an * {@link AuthenticationManager} is attempted to be autowired by type. - * * @return the {@link AuthenticationManager} to use */ protected AuthenticationManager authenticationManager() throws Exception { - if (authenticationManager == null) { - DefaultAuthenticationEventPublisher eventPublisher = objectPostProcessor + if (this.authenticationManager == null) { + DefaultAuthenticationEventPublisher eventPublisher = this.objectPostProcessor .postProcess(new DefaultAuthenticationEventPublisher()); - auth = new AuthenticationManagerBuilder(objectPostProcessor); - auth.authenticationEventPublisher(eventPublisher); - configure(auth); - if (disableAuthenticationRegistry) { - authenticationManager = getAuthenticationConfiguration() - .getAuthenticationManager(); - } - else { - authenticationManager = auth.build(); - } + this.auth = new AuthenticationManagerBuilder(this.objectPostProcessor); + this.auth.authenticationEventPublisher(eventPublisher); + configure(this.auth); + this.authenticationManager = (this.disableAuthenticationRegistry) + ? getAuthenticationConfiguration().getAuthenticationManager() : this.auth.build(); } - return authenticationManager; + return this.authenticationManager; } /** * Sub classes can override this method to register different types of authentication. * If not overridden, {@link #configure(AuthenticationManagerBuilder)} will attempt to * autowire by type. - * * @param auth the {@link AuthenticationManagerBuilder} used to register different * authentication mechanisms for the global method security. * @throws Exception @@ -351,7 +335,6 @@ public class GlobalMethodSecurityConfiguration * creates a {@link DelegatingMethodSecurityMetadataSource} based upon * {@link #customMethodSecurityMetadataSource()} and the attributes on * {@link EnableGlobalMethodSecurity}. - * * @return the {@link MethodSecurityMetadataSource} */ @Bean @@ -363,17 +346,13 @@ public class GlobalMethodSecurityConfiguration if (customMethodSecurityMetadataSource != null) { sources.add(customMethodSecurityMetadataSource); } - boolean hasCustom = customMethodSecurityMetadataSource != null; boolean isPrePostEnabled = prePostEnabled(); boolean isSecuredEnabled = securedEnabled(); boolean isJsr250Enabled = jsr250Enabled(); - - if (!isPrePostEnabled && !isSecuredEnabled && !isJsr250Enabled && !hasCustom) { - throw new IllegalStateException("In the composition of all global method configuration, " + - "no annotation support was actually activated"); - } - + Assert.state(isPrePostEnabled || isSecuredEnabled || isJsr250Enabled || hasCustom, + "In the composition of all global method configuration, " + + "no annotation support was actually activated"); if (isPrePostEnabled) { sources.add(new PrePostAnnotationSecurityMetadataSource(attributeFactory)); } @@ -381,12 +360,11 @@ public class GlobalMethodSecurityConfiguration sources.add(new SecuredAnnotationSecurityMetadataSource()); } if (isJsr250Enabled) { - GrantedAuthorityDefaults grantedAuthorityDefaults = - getSingleBeanOrNull(GrantedAuthorityDefaults.class); - Jsr250MethodSecurityMetadataSource jsr250MethodSecurityMetadataSource = this.context.getBean(Jsr250MethodSecurityMetadataSource.class); + GrantedAuthorityDefaults grantedAuthorityDefaults = getSingleBeanOrNull(GrantedAuthorityDefaults.class); + Jsr250MethodSecurityMetadataSource jsr250MethodSecurityMetadataSource = this.context + .getBean(Jsr250MethodSecurityMetadataSource.class); if (grantedAuthorityDefaults != null) { - jsr250MethodSecurityMetadataSource.setDefaultRolePrefix( - grantedAuthorityDefaults.getRolePrefix()); + jsr250MethodSecurityMetadataSource.setDefaultRolePrefix(grantedAuthorityDefaults.getRolePrefix()); } sources.add(jsr250MethodSecurityMetadataSource); } @@ -396,7 +374,6 @@ public class GlobalMethodSecurityConfiguration /** * Creates the {@link PreInvocationAuthorizationAdvice} to be used. The default is * {@link ExpressionBasedPreInvocationAdvice}. - * * @return the {@link PreInvocationAuthorizationAdvice} */ @Bean @@ -410,25 +387,23 @@ public class GlobalMethodSecurityConfiguration * Obtains the attributes from {@link EnableGlobalMethodSecurity} if this class was * imported using the {@link EnableGlobalMethodSecurity} annotation. */ + @Override public final void setImportMetadata(AnnotationMetadata importMetadata) { Map annotationAttributes = importMetadata .getAnnotationAttributes(EnableGlobalMethodSecurity.class.getName()); - enableMethodSecurity = AnnotationAttributes.fromMap(annotationAttributes); + this.enableMethodSecurity = AnnotationAttributes.fromMap(annotationAttributes); } @Autowired(required = false) public void setObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { this.objectPostProcessor = objectPostProcessor; - this.defaultMethodExpressionHandler = objectPostProcessor - .postProcess(defaultMethodExpressionHandler); + this.defaultMethodExpressionHandler = objectPostProcessor.postProcess(this.defaultMethodExpressionHandler); } @Autowired(required = false) - public void setMethodSecurityExpressionHandler( - List handlers) { + public void setMethodSecurityExpressionHandler(List handlers) { if (handlers.size() != 1) { - logger.debug("Not autowiring MethodSecurityExpressionHandler since size != 1. Got " - + handlers); + logger.debug("Not autowiring MethodSecurityExpressionHandler since size != 1. Got " + handlers); return; } this.expressionHandler = handlers.get(0); @@ -440,7 +415,7 @@ public class GlobalMethodSecurityConfiguration } private AuthenticationConfiguration getAuthenticationConfiguration() { - return context.getBean(AuthenticationConfiguration.class); + return this.context.getBean(AuthenticationConfiguration.class); } private boolean prePostEnabled() { @@ -455,25 +430,20 @@ public class GlobalMethodSecurityConfiguration return enableMethodSecurity().getBoolean("jsr250Enabled"); } - private int order() { - return (Integer) enableMethodSecurity().get("order"); - } - private boolean isAspectJ() { return enableMethodSecurity().getEnum("mode") == AdviceMode.ASPECTJ; } private AnnotationAttributes enableMethodSecurity() { - if (enableMethodSecurity == null) { + if (this.enableMethodSecurity == null) { // if it is null look at this instance (i.e. a subclass was used) - EnableGlobalMethodSecurity methodSecurityAnnotation = AnnotationUtils - .findAnnotation(getClass(), EnableGlobalMethodSecurity.class); - Assert.notNull(methodSecurityAnnotation, - () -> EnableGlobalMethodSecurity.class.getName() + " is required"); - Map methodSecurityAttrs = AnnotationUtils - .getAnnotationAttributes(methodSecurityAnnotation); + EnableGlobalMethodSecurity methodSecurityAnnotation = AnnotationUtils.findAnnotation(getClass(), + EnableGlobalMethodSecurity.class); + Assert.notNull(methodSecurityAnnotation, () -> EnableGlobalMethodSecurity.class.getName() + " is required"); + Map methodSecurityAttrs = AnnotationUtils.getAnnotationAttributes(methodSecurityAnnotation); this.enableMethodSecurity = AnnotationAttributes.fromMap(methodSecurityAttrs); } return this.enableMethodSecurity; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecuritySelector.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecuritySelector.java index 1ffa8256a2..516835b0a4 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecuritySelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecuritySelector.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import java.util.ArrayList; @@ -36,46 +37,36 @@ import org.springframework.util.ClassUtils; */ final class GlobalMethodSecuritySelector implements ImportSelector { + @Override public String[] selectImports(AnnotationMetadata importingClassMetadata) { Class annoType = EnableGlobalMethodSecurity.class; - Map annotationAttributes = importingClassMetadata - .getAnnotationAttributes(annoType.getName(), false); - AnnotationAttributes attributes = AnnotationAttributes - .fromMap(annotationAttributes); - Assert.notNull(attributes, () -> String.format( - "@%s is not present on importing class '%s' as expected", + Map annotationAttributes = importingClassMetadata.getAnnotationAttributes(annoType.getName(), + false); + AnnotationAttributes attributes = AnnotationAttributes.fromMap(annotationAttributes); + Assert.notNull(attributes, () -> String.format("@%s is not present on importing class '%s' as expected", annoType.getSimpleName(), importingClassMetadata.getClassName())); - // TODO would be nice if could use BeanClassLoaderAware (does not work) - Class importingClass = ClassUtils - .resolveClassName(importingClassMetadata.getClassName(), - ClassUtils.getDefaultClassLoader()); + Class importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), + ClassUtils.getDefaultClassLoader()); boolean skipMethodSecurityConfiguration = GlobalMethodSecurityConfiguration.class .isAssignableFrom(importingClass); - AdviceMode mode = attributes.getEnum("mode"); boolean isProxy = AdviceMode.PROXY == mode; - String autoProxyClassName = isProxy ? AutoProxyRegistrar.class - .getName() : GlobalMethodSecurityAspectJAutoProxyRegistrar.class - .getName(); - + String autoProxyClassName = isProxy ? AutoProxyRegistrar.class.getName() + : GlobalMethodSecurityAspectJAutoProxyRegistrar.class.getName(); boolean jsr250Enabled = attributes.getBoolean("jsr250Enabled"); - List classNames = new ArrayList<>(4); if (isProxy) { classNames.add(MethodSecurityMetadataSourceAdvisorRegistrar.class.getName()); } - classNames.add(autoProxyClassName); - if (!skipMethodSecurityConfiguration) { classNames.add(GlobalMethodSecurityConfiguration.class.getName()); } - if (jsr250Enabled) { classNames.add(Jsr250MetadataSourceConfiguration.class.getName()); } - return classNames.toArray(new String[0]); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/Jsr250MetadataSourceConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/Jsr250MetadataSourceConfiguration.java index 5c98bf48fe..147e8c6070 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/Jsr250MetadataSourceConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/Jsr250MetadataSourceConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import org.springframework.beans.factory.config.BeanDefinition; @@ -27,7 +28,8 @@ class Jsr250MetadataSourceConfiguration { @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - public Jsr250MethodSecurityMetadataSource jsr250MethodSecurityMetadataSource() { + Jsr250MethodSecurityMetadataSource jsr250MethodSecurityMetadataSource() { return new Jsr250MethodSecurityMetadataSource(); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityMetadataSourceAdvisorRegistrar.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityMetadataSourceAdvisorRegistrar.java index 8ef8f1e4af..84a33b00cd 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityMetadataSourceAdvisorRegistrar.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityMetadataSourceAdvisorRegistrar.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import org.springframework.beans.factory.config.BeanDefinition; @@ -24,39 +25,36 @@ import org.springframework.security.access.intercept.aopalliance.MethodSecurityM import org.springframework.util.MultiValueMap; /** - * Creates Spring Security's MethodSecurityMetadataSourceAdvisor only when - * using proxy based method security (i.e. do not do it when using ASPECTJ). - * The conditional logic is controlled through {@link GlobalMethodSecuritySelector}. + * Creates Spring Security's MethodSecurityMetadataSourceAdvisor only when using proxy + * based method security (i.e. do not do it when using ASPECTJ). The conditional logic is + * controlled through {@link GlobalMethodSecuritySelector}. * * @author Rob Winch * @since 4.0.2 * @see GlobalMethodSecuritySelector */ -class MethodSecurityMetadataSourceAdvisorRegistrar implements - ImportBeanDefinitionRegistrar { +class MethodSecurityMetadataSourceAdvisorRegistrar implements ImportBeanDefinitionRegistrar { /** * Register, escalate, and configure the AspectJ auto proxy creator based on the value * of the @{@link EnableGlobalMethodSecurity#proxyTargetClass()} attribute on the * importing {@code @Configuration} class. */ - public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, - BeanDefinitionRegistry registry) { - + @Override + public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) { BeanDefinitionBuilder advisor = BeanDefinitionBuilder .rootBeanDefinition(MethodSecurityMetadataSourceAdvisor.class); advisor.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); advisor.addConstructorArgValue("methodSecurityInterceptor"); advisor.addConstructorArgReference("methodSecurityMetadataSource"); advisor.addConstructorArgValue("methodSecurityMetadataSource"); - - MultiValueMap attributes = importingClassMetadata.getAllAnnotationAttributes(EnableGlobalMethodSecurity.class.getName()); + MultiValueMap attributes = importingClassMetadata + .getAllAnnotationAttributes(EnableGlobalMethodSecurity.class.getName()); Integer order = (Integer) attributes.getFirst("order"); if (order != null) { advisor.addPropertyValue("order", order); } - - registry.registerBeanDefinition("metaDataSourceAdvisor", - advisor.getBeanDefinition()); + registry.registerBeanDefinition("metaDataSourceAdvisor", advisor.getBeanDefinition()); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfiguration.java index b1ba9ae5d8..e66dec4d6e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.method.configuration; +import java.util.Arrays; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.context.annotation.Bean; @@ -23,7 +25,11 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportAware; import org.springframework.context.annotation.Role; import org.springframework.core.type.AnnotationMetadata; -import org.springframework.security.access.expression.method.*; +import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler; +import org.springframework.security.access.expression.method.ExpressionBasedAnnotationAttributeFactory; +import org.springframework.security.access.expression.method.ExpressionBasedPostInvocationAdvice; +import org.springframework.security.access.expression.method.ExpressionBasedPreInvocationAdvice; +import org.springframework.security.access.expression.method.MethodSecurityExpressionHandler; import org.springframework.security.access.intercept.aopalliance.MethodSecurityMetadataSourceAdvisor; import org.springframework.security.access.method.AbstractMethodSecurityMetadataSource; import org.springframework.security.access.method.DelegatingMethodSecurityMetadataSource; @@ -31,8 +37,6 @@ import org.springframework.security.access.prepost.PrePostAdviceReactiveMethodIn import org.springframework.security.access.prepost.PrePostAnnotationSecurityMetadataSource; import org.springframework.security.config.core.GrantedAuthorityDefaults; -import java.util.Arrays; - /** * @author Rob Winch * @author Tadaya Tsuyukubo @@ -40,43 +44,43 @@ import java.util.Arrays; */ @Configuration(proxyBeanMethods = false) class ReactiveMethodSecurityConfiguration implements ImportAware { + private int advisorOrder; private GrantedAuthorityDefaults grantedAuthorityDefaults; @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - public MethodSecurityMetadataSourceAdvisor methodSecurityInterceptor(AbstractMethodSecurityMetadataSource source) { + MethodSecurityMetadataSourceAdvisor methodSecurityInterceptor(AbstractMethodSecurityMetadataSource source) { MethodSecurityMetadataSourceAdvisor advisor = new MethodSecurityMetadataSourceAdvisor( - "securityMethodInterceptor", source, "methodMetadataSource"); - advisor.setOrder(advisorOrder); + "securityMethodInterceptor", source, "methodMetadataSource"); + advisor.setOrder(this.advisorOrder); return advisor; } @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - public DelegatingMethodSecurityMetadataSource methodMetadataSource(MethodSecurityExpressionHandler methodSecurityExpressionHandler) { + DelegatingMethodSecurityMetadataSource methodMetadataSource( + MethodSecurityExpressionHandler methodSecurityExpressionHandler) { ExpressionBasedAnnotationAttributeFactory attributeFactory = new ExpressionBasedAnnotationAttributeFactory( methodSecurityExpressionHandler); PrePostAnnotationSecurityMetadataSource prePostSource = new PrePostAnnotationSecurityMetadataSource( - attributeFactory); + attributeFactory); return new DelegatingMethodSecurityMetadataSource(Arrays.asList(prePostSource)); } @Bean - public PrePostAdviceReactiveMethodInterceptor securityMethodInterceptor(AbstractMethodSecurityMetadataSource source, MethodSecurityExpressionHandler handler) { - - ExpressionBasedPostInvocationAdvice postAdvice = new ExpressionBasedPostInvocationAdvice( - handler); + PrePostAdviceReactiveMethodInterceptor securityMethodInterceptor(AbstractMethodSecurityMetadataSource source, + MethodSecurityExpressionHandler handler) { + ExpressionBasedPostInvocationAdvice postAdvice = new ExpressionBasedPostInvocationAdvice(handler); ExpressionBasedPreInvocationAdvice preAdvice = new ExpressionBasedPreInvocationAdvice(); preAdvice.setExpressionHandler(handler); - return new PrePostAdviceReactiveMethodInterceptor(source, preAdvice, postAdvice); } @Bean @Role(BeanDefinition.ROLE_INFRASTRUCTURE) - public DefaultMethodSecurityExpressionHandler methodSecurityExpressionHandler() { + DefaultMethodSecurityExpressionHandler methodSecurityExpressionHandler() { DefaultMethodSecurityExpressionHandler handler = new DefaultMethodSecurityExpressionHandler(); if (this.grantedAuthorityDefaults != null) { handler.setDefaultRolePrefix(this.grantedAuthorityDefaults.getRolePrefix()); @@ -86,7 +90,8 @@ class ReactiveMethodSecurityConfiguration implements ImportAware { @Override public void setImportMetadata(AnnotationMetadata importMetadata) { - this.advisorOrder = (int) importMetadata.getAnnotationAttributes(EnableReactiveMethodSecurity.class.getName()).get("order"); + this.advisorOrder = (int) importMetadata.getAnnotationAttributes(EnableReactiveMethodSecurity.class.getName()) + .get("order"); } @Autowired(required = false) diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecuritySelector.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecuritySelector.java index 612432215e..17e350e5f2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecuritySelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecuritySelector.java @@ -16,33 +16,32 @@ package org.springframework.security.config.annotation.method.configuration; +import java.util.ArrayList; +import java.util.List; + import org.springframework.context.annotation.AdviceMode; import org.springframework.context.annotation.AdviceModeImportSelector; import org.springframework.context.annotation.AutoProxyRegistrar; -import java.util.ArrayList; -import java.util.List; - /** * @author Rob Winch * @since 5.0 */ -class ReactiveMethodSecuritySelector extends - AdviceModeImportSelector { +class ReactiveMethodSecuritySelector extends AdviceModeImportSelector { @Override protected String[] selectImports(AdviceMode adviceMode) { - switch (adviceMode) { - case PROXY: - return getProxyImports(); - default: - throw new IllegalStateException("AdviceMode " + adviceMode + " is not supported"); + if (adviceMode == AdviceMode.PROXY) { + return getProxyImports(); } + throw new IllegalStateException("AdviceMode " + adviceMode + " is not supported"); } /** - * Return the imports to use if the {@link AdviceMode} is set to {@link AdviceMode#PROXY}. - *

    Take care of adding the necessary JSR-107 import if it is available. + * Return the imports to use if the {@link AdviceMode} is set to + * {@link AdviceMode#PROXY}. + *

    + * Take care of adding the necessary JSR-107 import if it is available. */ private String[] getProxyImports() { List result = new ArrayList<>(); @@ -50,4 +49,5 @@ class ReactiveMethodSecuritySelector extends result.add(ReactiveMethodSecurityConfiguration.class.getName()); return result.toArray(new String[0]); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java index 440186090f..29058c10e6 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/EnableRSocketSecurity.java @@ -16,14 +16,14 @@ package org.springframework.security.config.annotation.rsocket; -import org.springframework.context.annotation.Import; - import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.springframework.context.annotation.Import; + /** * Add this annotation to a {@code Configuration} class to have Spring Security * {@link RSocketSecurity} support added. @@ -36,4 +36,6 @@ import java.lang.annotation.Target; @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) @Import({ RSocketSecurityConfiguration.class, SecuritySocketAcceptorInterceptorConfiguration.class }) -public @interface EnableRSocketSecurity { } +public @interface EnableRSocketSecurity { + +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/PayloadInterceptorOrder.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/PayloadInterceptorOrder.java index eba69bd9c5..4577301714 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/PayloadInterceptorOrder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/PayloadInterceptorOrder.java @@ -21,14 +21,15 @@ import org.springframework.security.config.Customizer; import org.springframework.security.rsocket.api.PayloadInterceptor; /** - * The standard order for {@link PayloadInterceptor} to be - * sorted. The actual values might change, so users should use the {@link #getOrder()} method to - * calculate the position dynamically rather than copy values. + * The standard order for {@link PayloadInterceptor} to be sorted. The actual values might + * change, so users should use the {@link #getOrder()} method to calculate the position + * dynamically rather than copy values. * * @author Rob Winch * @since 5.2 */ public enum PayloadInterceptorOrder implements Ordered { + /** * Where basic authentication is placed. * @see RSocketSecurity#basicAuthentication(Customizer) @@ -62,7 +63,9 @@ public enum PayloadInterceptorOrder implements Ordered { this.order = ordinal() * INTERVAL; } + @Override public int getOrder() { return this.order; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java index 8ab04005e8..4beef6947b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurity.java @@ -16,6 +16,12 @@ package org.springframework.security.config.annotation.rsocket; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import reactor.core.publisher.Mono; + import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.core.ResolvableType; @@ -30,23 +36,18 @@ import org.springframework.security.config.Customizer; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager; import org.springframework.security.rsocket.api.PayloadInterceptor; -import org.springframework.security.rsocket.authentication.AuthenticationPayloadExchangeConverter; -import org.springframework.security.rsocket.core.PayloadSocketAcceptorInterceptor; import org.springframework.security.rsocket.authentication.AnonymousPayloadInterceptor; +import org.springframework.security.rsocket.authentication.AuthenticationPayloadExchangeConverter; import org.springframework.security.rsocket.authentication.AuthenticationPayloadInterceptor; import org.springframework.security.rsocket.authentication.BearerPayloadExchangeConverter; import org.springframework.security.rsocket.authorization.AuthorizationPayloadInterceptor; import org.springframework.security.rsocket.authorization.PayloadExchangeMatcherReactiveAuthorizationManager; +import org.springframework.security.rsocket.core.PayloadSocketAcceptorInterceptor; import org.springframework.security.rsocket.util.matcher.PayloadExchangeAuthorizationContext; import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcher; import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcherEntry; import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatchers; import org.springframework.security.rsocket.util.matcher.RoutePayloadExchangeMatcher; -import reactor.core.publisher.Mono; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; /** * Allows configuring RSocket based security. @@ -56,19 +57,16 @@ import java.util.List; *

      * @EnableRSocketSecurity
      * public class SecurityConfig {
    - *     // @formatter:off
      *     @Bean
      *     PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
      *         rsocket
    - *             .authorizePayload(authorize ->
    + *             .authorizePayload((authorize) ->
      *                 authorize
      *                     .anyRequest().authenticated()
      *             );
      *         return rsocket.build();
      *     }
    - *     // @formatter:on
      *
    - *     // @formatter:off
      *     @Bean
      *     public MapReactiveUserDetailsService userDetailsService() {
      *          UserDetails user = User.withDefaultPasswordEncoder()
    @@ -78,7 +76,6 @@ import java.util.List;
      *               .build();
      *          return new MapReactiveUserDetailsService(user);
      *     }
    - *     // @formatter:on
      * }
      * 
    * @@ -87,11 +84,10 @@ import java.util.List; *
      * @EnableRSocketSecurity
      * public class SecurityConfig {
    - *     // @formatter:off
      *     @Bean
      *     PayloadSocketAcceptorInterceptor rsocketInterceptor(RSocketSecurity rsocket) {
      *         rsocket
    - *             .authorizePayload(authorize ->
    + *             .authorizePayload((authorize) ->
      *                 authorize
      *                     // must have ROLE_SETUP to make connection
      *                     .setup().hasRole("SETUP")
    @@ -102,9 +98,9 @@ import java.util.List;
      *             );
      *         return rsocket.build();
      *     }
    - *     // @formatter:on
      * }
      * 
    + * * @author Rob Winch * @author Jesús Ascama Arias * @author Luis Felipe Vega @@ -129,12 +125,12 @@ public class RSocketSecurity { private ReactiveAuthenticationManager authenticationManager; /** - * Adds a {@link PayloadInterceptor} to be used. This is typically only used - * when using the DSL does not meet a users needs. In order to ensure the - * {@link PayloadInterceptor} is done in the proper order the {@link PayloadInterceptor} should - * either implement {@link org.springframework.core.Ordered} or be annotated with + * Adds a {@link PayloadInterceptor} to be used. This is typically only used when + * using the DSL does not meet a users needs. In order to ensure the + * {@link PayloadInterceptor} is done in the proper order the + * {@link PayloadInterceptor} should either implement + * {@link org.springframework.core.Ordered} or be annotated with * {@link org.springframework.core.annotation.Order}. - * * @param interceptor * @return the builder for additional customizations * @see PayloadInterceptorOrder @@ -150,8 +146,9 @@ public class RSocketSecurity { } /** - * Adds support for validating a username and password using - * Simple Authentication + * Adds support for validating a username and password using Simple + * Authentication * @param simple a customizer * @return RSocketSecurity for additional configuration * @since 5.3 @@ -164,12 +161,106 @@ public class RSocketSecurity { return this; } + /** + * Adds authentication with BasicAuthenticationPayloadExchangeConverter. + * @param basic + * @return this instance + * @deprecated Use {@link #simpleAuthentication(Customizer)} + */ + @Deprecated + public RSocketSecurity basicAuthentication(Customizer basic) { + if (this.basicAuthSpec == null) { + this.basicAuthSpec = new BasicAuthenticationSpec(); + } + basic.customize(this.basicAuthSpec); + return this; + } + + public RSocketSecurity jwt(Customizer jwt) { + if (this.jwtSpec == null) { + this.jwtSpec = new JwtSpec(); + } + jwt.customize(this.jwtSpec); + return this; + } + + public RSocketSecurity authorizePayload(Customizer authorize) { + if (this.authorizePayload == null) { + this.authorizePayload = new AuthorizePayloadsSpec(); + } + authorize.customize(this.authorizePayload); + return this; + } + + public PayloadSocketAcceptorInterceptor build() { + PayloadSocketAcceptorInterceptor interceptor = new PayloadSocketAcceptorInterceptor(payloadInterceptors()); + RSocketMessageHandler handler = getBean(RSocketMessageHandler.class); + interceptor.setDefaultDataMimeType(handler.getDefaultDataMimeType()); + interceptor.setDefaultMetadataMimeType(handler.getDefaultMetadataMimeType()); + return interceptor; + } + + private List payloadInterceptors() { + List result = new ArrayList<>(this.payloadInterceptors); + if (this.basicAuthSpec != null) { + result.add(this.basicAuthSpec.build()); + } + if (this.simpleAuthSpec != null) { + result.add(this.simpleAuthSpec.build()); + } + if (this.jwtSpec != null) { + result.addAll(this.jwtSpec.build()); + } + result.add(anonymous()); + if (this.authorizePayload != null) { + result.add(this.authorizePayload.build()); + } + AnnotationAwareOrderComparator.sort(result); + return result; + } + + private AnonymousPayloadInterceptor anonymous() { + AnonymousPayloadInterceptor result = new AnonymousPayloadInterceptor("anonymousUser"); + result.setOrder(PayloadInterceptorOrder.ANONYMOUS.getOrder()); + return result; + } + + private T getBean(Class beanClass) { + if (this.context == null) { + return null; + } + return this.context.getBean(beanClass); + } + + private T getBeanOrNull(Class beanClass) { + return getBeanOrNull(ResolvableType.forClass(beanClass)); + } + + private T getBeanOrNull(ResolvableType type) { + if (this.context == null) { + return null; + } + String[] names = this.context.getBeanNamesForType(type); + if (names.length == 1) { + return (T) this.context.getBean(names[0]); + } + return null; + } + + protected void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.context = applicationContext; + } + /** * @since 5.3 */ - public class SimpleAuthenticationSpec { + public final class SimpleAuthenticationSpec { + private ReactiveAuthenticationManager authenticationManager; + private SimpleAuthenticationSpec() { + } + public SimpleAuthenticationSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; return this; @@ -190,28 +281,15 @@ public class RSocketSecurity { return result; } - private SimpleAuthenticationSpec() {} } - /** - * Adds authentication with BasicAuthenticationPayloadExchangeConverter. - * - * @param basic - * @return - * @deprecated Use {@link #simpleAuthentication(Customizer)} - */ - @Deprecated - public RSocketSecurity basicAuthentication(Customizer basic) { - if (this.basicAuthSpec == null) { - this.basicAuthSpec = new BasicAuthenticationSpec(); - } - basic.customize(this.basicAuthSpec); - return this; - } + public final class BasicAuthenticationSpec { - public class BasicAuthenticationSpec { private ReactiveAuthenticationManager authenticationManager; + private BasicAuthenticationSpec() { + } + public BasicAuthenticationSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; return this; @@ -231,20 +309,15 @@ public class RSocketSecurity { return result; } - private BasicAuthenticationSpec() {} } - public RSocketSecurity jwt(Customizer jwt) { - if (this.jwtSpec == null) { - this.jwtSpec = new JwtSpec(); - } - jwt.customize(this.jwtSpec); - return this; - } + public final class JwtSpec { - public class JwtSpec { private ReactiveAuthenticationManager authenticationManager; + private JwtSpec() { + } + public JwtSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; return this; @@ -267,73 +340,27 @@ public class RSocketSecurity { AuthenticationPayloadInterceptor legacy = new AuthenticationPayloadInterceptor(manager); legacy.setAuthenticationConverter(new BearerPayloadExchangeConverter()); legacy.setOrder(PayloadInterceptorOrder.AUTHENTICATION.getOrder()); - AuthenticationPayloadInterceptor standard = new AuthenticationPayloadInterceptor(manager); standard.setAuthenticationConverter(new AuthenticationPayloadExchangeConverter()); standard.setOrder(PayloadInterceptorOrder.AUTHENTICATION.getOrder()); - return Arrays.asList(standard, legacy); } - private JwtSpec() {} - } - - public RSocketSecurity authorizePayload(Customizer authorize) { - if (this.authorizePayload == null) { - this.authorizePayload = new AuthorizePayloadsSpec(); - } - authorize.customize(this.authorizePayload); - return this; - } - - public PayloadSocketAcceptorInterceptor build() { - PayloadSocketAcceptorInterceptor interceptor = new PayloadSocketAcceptorInterceptor( - payloadInterceptors()); - RSocketMessageHandler handler = getBean(RSocketMessageHandler.class); - interceptor.setDefaultDataMimeType(handler.getDefaultDataMimeType()); - interceptor.setDefaultMetadataMimeType(handler.getDefaultMetadataMimeType()); - return interceptor; - } - - private List payloadInterceptors() { - List result = new ArrayList<>(this.payloadInterceptors); - - if (this.basicAuthSpec != null) { - result.add(this.basicAuthSpec.build()); - } - if (this.simpleAuthSpec != null) { - result.add(this.simpleAuthSpec.build()); - } - if (this.jwtSpec != null) { - result.addAll(this.jwtSpec.build()); - } - result.add(anonymous()); - - if (this.authorizePayload != null) { - result.add(this.authorizePayload.build()); - } - AnnotationAwareOrderComparator.sort(result); - return result; - } - - private AnonymousPayloadInterceptor anonymous() { - AnonymousPayloadInterceptor result = new AnonymousPayloadInterceptor("anonymousUser"); - result.setOrder(PayloadInterceptorOrder.ANONYMOUS.getOrder()); - return result; } public class AuthorizePayloadsSpec { - private PayloadExchangeMatcherReactiveAuthorizationManager.Builder authzBuilder = - PayloadExchangeMatcherReactiveAuthorizationManager.builder(); + private PayloadExchangeMatcherReactiveAuthorizationManager.Builder authzBuilder = PayloadExchangeMatcherReactiveAuthorizationManager + .builder(); public Access setup() { return matcher(PayloadExchangeMatchers.setup()); } /** - * Matches if {@link org.springframework.security.rsocket.api.PayloadExchangeType#isRequest()} is true, else - * not a match + * Matches if + * {@link org.springframework.security.rsocket.api.PayloadExchangeType#isRequest()} + * is true, else not a match * @return the Access to set up the authorization rule. */ public Access anyRequest() { @@ -356,10 +383,8 @@ public class RSocketSecurity { public Access route(String pattern) { RSocketMessageHandler handler = getBean(RSocketMessageHandler.class); - PayloadExchangeMatcher matcher = new RoutePayloadExchangeMatcher( - handler.getMetadataExtractor(), - handler.getRouteMatcher(), - pattern); + PayloadExchangeMatcher matcher = new RoutePayloadExchangeMatcher(handler.getMetadataExtractor(), + handler.getRouteMatcher(), pattern); return matcher(matcher); } @@ -367,7 +392,7 @@ public class RSocketSecurity { return new Access(matcher); } - public class Access { + public final class Access { private final PayloadExchangeMatcher matcher; @@ -392,8 +417,7 @@ public class RSocketSecurity { } public AuthorizePayloadsSpec permitAll() { - return access((a, ctx) -> Mono - .just(new AuthorizationDecision(true))); + return access((a, ctx) -> Mono.just(new AuthorizationDecision(true))); } public AuthorizePayloadsSpec hasAnyAuthority(String... authorities) { @@ -402,41 +426,17 @@ public class RSocketSecurity { public AuthorizePayloadsSpec access( ReactiveAuthorizationManager authorization) { - AuthorizePayloadsSpec.this.authzBuilder.add(new PayloadExchangeMatcherEntry<>(this.matcher, authorization)); + AuthorizePayloadsSpec.this.authzBuilder + .add(new PayloadExchangeMatcherEntry<>(this.matcher, authorization)); return AuthorizePayloadsSpec.this; } public AuthorizePayloadsSpec denyAll() { - return access((a, ctx) -> Mono - .just(new AuthorizationDecision(false))); + return access((a, ctx) -> Mono.just(new AuthorizationDecision(false))); } + } + } - private T getBean(Class beanClass) { - if (this.context == null) { - return null; - } - return this.context.getBean(beanClass); - } - - private T getBeanOrNull(Class beanClass) { - return getBeanOrNull(ResolvableType.forClass(beanClass)); - } - - private T getBeanOrNull(ResolvableType type) { - if (this.context == null) { - return null; - } - String[] names = this.context.getBeanNamesForType(type); - if (names.length == 1) { - return (T) this.context.getBean(names[0]); - } - return null; - } - - protected void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - this.context = applicationContext; - } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java index fdf9bd31bc..ea5c2f1a32 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java @@ -34,6 +34,7 @@ import org.springframework.security.crypto.password.PasswordEncoder; class RSocketSecurityConfiguration { private static final String BEAN_NAME_PREFIX = "org.springframework.security.config.annotation.rsocket.RSocketSecurityConfiguration."; + private static final String RSOCKET_SECURITY_BEAN_NAME = BEAN_NAME_PREFIX + "rsocketSecurity"; private ReactiveAuthenticationManager authenticationManager; @@ -43,8 +44,7 @@ class RSocketSecurityConfiguration { private PasswordEncoder passwordEncoder; @Autowired(required = false) - void setAuthenticationManager( - ReactiveAuthenticationManager authenticationManager) { + void setAuthenticationManager(ReactiveAuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; } @@ -60,9 +60,8 @@ class RSocketSecurityConfiguration { @Bean(name = RSOCKET_SECURITY_BEAN_NAME) @Scope("prototype") - public RSocketSecurity rsocketSecurity(ApplicationContext context) { - RSocketSecurity security = new RSocketSecurity() - .authenticationManager(authenticationManager()); + RSocketSecurity rsocketSecurity(ApplicationContext context) { + RSocketSecurity security = new RSocketSecurity().authenticationManager(authenticationManager()); security.setApplicationContext(context); return security; } @@ -72,8 +71,8 @@ class RSocketSecurityConfiguration { return this.authenticationManager; } if (this.reactiveUserDetailsService != null) { - UserDetailsRepositoryReactiveAuthenticationManager manager = - new UserDetailsRepositoryReactiveAuthenticationManager(this.reactiveUserDetailsService); + UserDetailsRepositoryReactiveAuthenticationManager manager = new UserDetailsRepositoryReactiveAuthenticationManager( + this.reactiveUserDetailsService); if (this.passwordEncoder != null) { manager.setPasswordEncoder(this.passwordEncoder); } @@ -81,4 +80,5 @@ class RSocketSecurityConfiguration { } return null; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/rsocket/SecuritySocketAcceptorInterceptorConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/rsocket/SecuritySocketAcceptorInterceptorConfiguration.java index cdd007d61e..00019ba782 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/rsocket/SecuritySocketAcceptorInterceptorConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/rsocket/SecuritySocketAcceptorInterceptorConfiguration.java @@ -31,29 +31,31 @@ import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcher. */ @Configuration(proxyBeanMethods = false) class SecuritySocketAcceptorInterceptorConfiguration { + @Bean SecuritySocketAcceptorInterceptor securitySocketAcceptorInterceptor( - ObjectProvider rsocketInterceptor, ObjectProvider rsocketSecurity) { + ObjectProvider rsocketInterceptor, + ObjectProvider rsocketSecurity) { PayloadSocketAcceptorInterceptor delegate = rsocketInterceptor .getIfAvailable(() -> defaultInterceptor(rsocketSecurity)); return new SecuritySocketAcceptorInterceptor(delegate); } - private PayloadSocketAcceptorInterceptor defaultInterceptor( - ObjectProvider rsocketSecurity) { + private PayloadSocketAcceptorInterceptor defaultInterceptor(ObjectProvider rsocketSecurity) { RSocketSecurity rsocket = rsocketSecurity.getIfAvailable(); if (rsocket == null) { throw new NoSuchBeanDefinitionException("No RSocketSecurity defined"); } - rsocket - .basicAuthentication(Customizer.withDefaults()) + // @formatter:off + rsocket.basicAuthentication(Customizer.withDefaults()) .simpleAuthentication(Customizer.withDefaults()) - .authorizePayload(authz -> - authz - .setup().authenticated() - .anyRequest().authenticated() - .matcher(e -> MatchResult.match()).permitAll() + .authorizePayload((authz) -> authz + .setup().authenticated() + .anyRequest().authenticated() + .matcher((e) -> MatchResult.match()).permitAll() ); + // @formatter:on return rsocket.build(); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index 268c0e4a54..e20d71054d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpMethod; @@ -28,22 +33,17 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - /** * A base class for registering {@link RequestMatcher}'s. For example, it might allow for * specifying which {@link RequestMatcher} require a certain level of authorization. * - * * @param The object that is returned or Chained after creating the RequestMatcher - * * @author Rob Winch * @author Ankur Pathak * @since 3.2 */ public abstract class AbstractRequestMatcherRegistry { + private static final String HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME = "mvcHandlerMappingIntrospector"; private static final RequestMatcher ANY_REQUEST = AnyRequestMatcher.INSTANCE; @@ -58,7 +58,6 @@ public abstract class AbstractRequestMatcherRegistry { /** * Gets the {@link ApplicationContext} - * * @return the {@link ApplicationContext} */ protected final ApplicationContext getApplicationContext() { @@ -67,7 +66,6 @@ public abstract class AbstractRequestMatcherRegistry { /** * Maps any request. - * * @return the object that is chained after creating the {@link RequestMatcher} */ public C anyRequest() { @@ -81,10 +79,7 @@ public abstract class AbstractRequestMatcherRegistry { * Maps a {@link List} of * {@link org.springframework.security.web.util.matcher.AntPathRequestMatcher} * instances. - * - * @param method the {@link HttpMethod} to use for any - * {@link HttpMethod}. - * + * @param method the {@link HttpMethod} to use for any {@link HttpMethod}. * @return the object that is chained after creating the {@link RequestMatcher} */ public C antMatchers(HttpMethod method) { @@ -95,12 +90,11 @@ public abstract class AbstractRequestMatcherRegistry { * Maps a {@link List} of * {@link org.springframework.security.web.util.matcher.AntPathRequestMatcher} * instances. - * * @param method the {@link HttpMethod} to use or {@code null} for any * {@link HttpMethod}. - * @param antPatterns the ant patterns to create. If {@code null} or empty, then matches on nothing. + * @param antPatterns the ant patterns to create. If {@code null} or empty, then + * matches on nothing. * {@link org.springframework.security.web.util.matcher.AntPathRequestMatcher} from - * * @return the object that is chained after creating the {@link RequestMatcher} */ public C antMatchers(HttpMethod method, String... antPatterns) { @@ -112,10 +106,8 @@ public abstract class AbstractRequestMatcherRegistry { * Maps a {@link List} of * {@link org.springframework.security.web.util.matcher.AntPathRequestMatcher} * instances that do not care which {@link HttpMethod} is used. - * * @param antPatterns the ant patterns to create * {@link org.springframework.security.web.util.matcher.AntPathRequestMatcher} from - * * @return the object that is chained after creating the {@link RequestMatcher} */ public C antMatchers(String... antPatterns) { @@ -134,7 +126,6 @@ public abstract class AbstractRequestMatcherRegistry { * If the current request will not be processed by Spring MVC, a reasonable default * using the pattern as a ant pattern will be used. *

    - * * @param mvcPatterns the patterns to match on. The rules for matching are defined by * Spring MVC * @return the object that is chained after creating the {@link RequestMatcher}. @@ -152,7 +143,6 @@ public abstract class AbstractRequestMatcherRegistry { * If the current request will not be processed by Spring MVC, a reasonable default * using the pattern as a ant pattern will be used. *

    - * * @param method the HTTP method to match on * @param mvcPatterns the patterns to match on. The rules for matching are defined by * Spring MVC @@ -162,27 +152,24 @@ public abstract class AbstractRequestMatcherRegistry { /** * Creates {@link MvcRequestMatcher} instances for the method and patterns passed in - * * @param method the HTTP method to use or null if any should be used * @param mvcPatterns the Spring MVC patterns to match on * @return a List of {@link MvcRequestMatcher} instances */ - protected final List createMvcMatchers(HttpMethod method, - String... mvcPatterns) { + protected final List createMvcMatchers(HttpMethod method, String... mvcPatterns) { Assert.state(!this.anyRequestConfigured, "Can't configure mvcMatchers after anyRequest"); ObjectPostProcessor opp = this.context.getBean(ObjectPostProcessor.class); if (!this.context.containsBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME)) { - throw new NoSuchBeanDefinitionException("A Bean named " + HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME +" of type " + HandlerMappingIntrospector.class.getName() - + " is required to use MvcRequestMatcher. Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext."); + throw new NoSuchBeanDefinitionException("A Bean named " + HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME + + " of type " + HandlerMappingIntrospector.class.getName() + + " is required to use MvcRequestMatcher. Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext."); } HandlerMappingIntrospector introspector = this.context.getBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, - HandlerMappingIntrospector.class); - List matchers = new ArrayList<>( - mvcPatterns.length); + HandlerMappingIntrospector.class); + List matchers = new ArrayList<>(mvcPatterns.length); for (String mvcPattern : mvcPatterns) { MvcRequestMatcher matcher = new MvcRequestMatcher(introspector, mvcPattern); opp.postProcess(matcher); - if (method != null) { matcher.setMethod(method); } @@ -195,12 +182,10 @@ public abstract class AbstractRequestMatcherRegistry { * Maps a {@link List} of * {@link org.springframework.security.web.util.matcher.RegexRequestMatcher} * instances. - * * @param method the {@link HttpMethod} to use or {@code null} for any * {@link HttpMethod}. * @param regexPatterns the regular expressions to create * {@link org.springframework.security.web.util.matcher.RegexRequestMatcher} from - * * @return the object that is chained after creating the {@link RequestMatcher} */ public C regexMatchers(HttpMethod method, String... regexPatterns) { @@ -212,10 +197,8 @@ public abstract class AbstractRequestMatcherRegistry { * Create a {@link List} of * {@link org.springframework.security.web.util.matcher.RegexRequestMatcher} instances * that do not specify an {@link HttpMethod}. - * * @param regexPatterns the regular expressions to create * {@link org.springframework.security.web.util.matcher.RegexRequestMatcher} from - * * @return the object that is chained after creating the {@link RequestMatcher} */ public C regexMatchers(String... regexPatterns) { @@ -226,9 +209,7 @@ public abstract class AbstractRequestMatcherRegistry { /** * Associates a list of {@link RequestMatcher} instances with the * {@link AbstractConfigAttributeRequestMatcherRegistry} - * * @param requestMatchers the {@link RequestMatcher} instances - * * @return the object that is chained after creating the {@link RequestMatcher} */ public C requestMatchers(RequestMatcher... requestMatchers) { @@ -239,7 +220,6 @@ public abstract class AbstractRequestMatcherRegistry { /** * Subclasses should implement this method for returning the object that is chained to * the creation of the {@link RequestMatcher} instances. - * * @param requestMatchers the {@link RequestMatcher} instances that were created * @return the chained Object for the subclass which allows association of something * else to the {@link RequestMatcher} @@ -254,19 +234,19 @@ public abstract class AbstractRequestMatcherRegistry { */ private static final class RequestMatchers { + private RequestMatchers() { + } + /** * Create a {@link List} of {@link AntPathRequestMatcher} instances. - * * @param httpMethod the {@link HttpMethod} to use or {@code null} for any * {@link HttpMethod}. * @param antPatterns the ant patterns to create {@link AntPathRequestMatcher} * from - * * @return a {@link List} of {@link AntPathRequestMatcher} instances */ - public static List antMatchers(HttpMethod httpMethod, - String... antPatterns) { - String method = httpMethod == null ? null : httpMethod.toString(); + static List antMatchers(HttpMethod httpMethod, String... antPatterns) { + String method = (httpMethod != null) ? httpMethod.toString() : null; List matchers = new ArrayList<>(); for (String pattern : antPatterns) { matchers.add(new AntPathRequestMatcher(pattern, method)); @@ -277,29 +257,24 @@ public abstract class AbstractRequestMatcherRegistry { /** * Create a {@link List} of {@link AntPathRequestMatcher} instances that do not * specify an {@link HttpMethod}. - * * @param antPatterns the ant patterns to create {@link AntPathRequestMatcher} * from - * * @return a {@link List} of {@link AntPathRequestMatcher} instances */ - public static List antMatchers(String... antPatterns) { + static List antMatchers(String... antPatterns) { return antMatchers(null, antPatterns); } /** * Create a {@link List} of {@link RegexRequestMatcher} instances. - * * @param httpMethod the {@link HttpMethod} to use or {@code null} for any * {@link HttpMethod}. * @param regexPatterns the regular expressions to create * {@link RegexRequestMatcher} from - * * @return a {@link List} of {@link RegexRequestMatcher} instances */ - public static List regexMatchers(HttpMethod httpMethod, - String... regexPatterns) { - String method = httpMethod == null ? null : httpMethod.toString(); + static List regexMatchers(HttpMethod httpMethod, String... regexPatterns) { + String method = (httpMethod != null) ? httpMethod.toString() : null; List matchers = new ArrayList<>(); for (String pattern : regexPatterns) { matchers.add(new RegexRequestMatcher(pattern, method)); @@ -310,18 +285,14 @@ public abstract class AbstractRequestMatcherRegistry { /** * Create a {@link List} of {@link RegexRequestMatcher} instances that do not * specify an {@link HttpMethod}. - * * @param regexPatterns the regular expressions to create * {@link RegexRequestMatcher} from - * * @return a {@link List} of {@link RegexRequestMatcher} instances */ - public static List regexMatchers(String... regexPatterns) { + static List regexMatchers(String... regexPatterns) { return regexMatchers(null, regexPatterns); } - private RequestMatchers() { - } } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/HttpSecurityBuilder.java b/config/src/main/java/org/springframework/security/config/annotation/web/HttpSecurityBuilder.java index 510f40cba6..475f2de2f5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/HttpSecurityBuilder.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/HttpSecurityBuilder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; import javax.servlet.Filter; @@ -44,36 +45,29 @@ import org.springframework.security.web.session.ConcurrentSessionFilter; import org.springframework.security.web.session.SessionManagementFilter; /** - * - * @author Rob Winch - * * @param + * @author Rob Winch */ -public interface HttpSecurityBuilder> extends - SecurityBuilder { +public interface HttpSecurityBuilder> + extends SecurityBuilder { /** * Gets the {@link SecurityConfigurer} by its class name or null if not * found. Note that object hierarchies are not considered. - * * @param clazz the Class of the {@link SecurityConfigurer} to attempt to get. */ - > C getConfigurer( - Class clazz); + > C getConfigurer(Class clazz); /** * Removes the {@link SecurityConfigurer} by its class name or null if * not found. Note that object hierarchies are not considered. - * * @param clazz the Class of the {@link SecurityConfigurer} to attempt to remove. * @return the {@link SecurityConfigurer} that was removed or null if not found */ - > C removeConfigurer( - Class clazz); + > C removeConfigurer(Class clazz); /** * Sets an object that is shared by multiple {@link SecurityConfigurer}. - * * @param sharedType the Class to key the shared object by. * @param object the Object to store */ @@ -81,7 +75,6 @@ public interface HttpSecurityBuilder> extends /** * Gets a shared Object. Note that object heirarchies are not considered. - * * @param sharedType the type of the shared Object * @return the shared Object or null if it is not found */ @@ -89,7 +82,6 @@ public interface HttpSecurityBuilder> extends /** * Allows adding an additional {@link AuthenticationProvider} to be used - * * @param authenticationProvider the {@link AuthenticationProvider} to be added * @return the {@link HttpSecurity} for further customizations */ @@ -97,7 +89,6 @@ public interface HttpSecurityBuilder> extends /** * Allows adding an additional {@link UserDetailsService} to be used - * * @param userDetailsService the {@link UserDetailsService} to be added * @return the {@link HttpSecurity} for further customizations */ @@ -108,7 +99,6 @@ public interface HttpSecurityBuilder> extends * known {@link Filter} instances are either a {@link Filter} listed in * {@link #addFilter(Filter)} or a {@link Filter} that has already been added using * {@link #addFilterAfter(Filter, Class)} or {@link #addFilterBefore(Filter, Class)}. - * * @param filter the {@link Filter} to register after the type {@code afterFilter} * @param afterFilter the Class of the known {@link Filter}. * @return the {@link HttpSecurity} for further customizations @@ -120,7 +110,6 @@ public interface HttpSecurityBuilder> extends * known {@link Filter} instances are either a {@link Filter} listed in * {@link #addFilter(Filter)} or a {@link Filter} that has already been added using * {@link #addFilterAfter(Filter, Class)} or {@link #addFilterBefore(Filter, Class)}. - * * @param filter the {@link Filter} to register before the type {@code beforeFilter} * @param beforeFilter the Class of the known {@link Filter}. * @return the {@link HttpSecurity} for further customizations @@ -140,7 +129,8 @@ public interface HttpSecurityBuilder> extends *
  • {@link LogoutFilter}
  • *
  • {@link X509AuthenticationFilter}
  • *
  • {@link AbstractPreAuthenticatedProcessingFilter}
  • - *
  • CasAuthenticationFilter
  • + *
  • CasAuthenticationFilter
  • *
  • {@link UsernamePasswordAuthenticationFilter}
  • *
  • {@link OpenIDAuthenticationFilter}
  • *
  • {@link org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter}
  • @@ -159,9 +149,9 @@ public interface HttpSecurityBuilder> extends *
  • {@link FilterSecurityInterceptor}
  • *
  • {@link SwitchUserFilter}
  • * - * * @param filter the {@link Filter} to add * @return the {@link HttpSecurity} for further customizations */ H addFilter(Filter filter); + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/WebSecurityConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/WebSecurityConfigurer.java index e97f2cf4da..c7bc0578d5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/WebSecurityConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/WebSecurityConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; import javax.servlet.Filter; @@ -28,17 +29,15 @@ import org.springframework.security.web.SecurityFilterChain; /** * Allows customization to the {@link WebSecurity}. In most instances users will use * {@link EnableWebSecurity} and either create a {@link Configuration} that extends - * {@link WebSecurityConfigurerAdapter} or expose a {@link SecurityFilterChain} bean. - * Both will automatically be applied to the {@link WebSecurity} by the + * {@link WebSecurityConfigurerAdapter} or expose a {@link SecurityFilterChain} bean. Both + * will automatically be applied to the {@link WebSecurity} by the * {@link EnableWebSecurity} annotation. * - * @see WebSecurityConfigurerAdapter - * @see SecurityFilterChain - * * @author Rob Winch * @since 3.2 + * @see WebSecurityConfigurerAdapter + * @see SecurityFilterChain */ -public interface WebSecurityConfigurer> extends - SecurityConfigurer { +public interface WebSecurityConfigurer> extends SecurityConfigurer { } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java index 4b96267a59..9c07581daa 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.builders; import java.io.Serializable; import java.util.Comparator; import java.util.HashMap; import java.util.Map; + import javax.servlet.Filter; import org.springframework.security.web.access.ExceptionTranslationFilter; @@ -44,6 +46,7 @@ import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.session.ConcurrentSessionFilter; import org.springframework.security.web.session.SessionManagementFilter; +import org.springframework.util.Assert; import org.springframework.web.filter.CorsFilter; /** @@ -56,8 +59,11 @@ import org.springframework.web.filter.CorsFilter; @SuppressWarnings("serial") final class FilterComparator implements Comparator, Serializable { + private static final int INITIAL_ORDER = 100; + private static final int ORDER_STEP = 100; + private final Map filterToOrder = new HashMap<>(); FilterComparator() { @@ -70,40 +76,37 @@ final class FilterComparator implements Comparator, Serializable { put(CorsFilter.class, order.next()); put(CsrfFilter.class, order.next()); put(LogoutFilter.class, order.next()); - filterToOrder.put( - "org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter", + this.filterToOrder.put( + "org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter", order.next()); - filterToOrder.put( + this.filterToOrder.put( "org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter", order.next()); put(X509AuthenticationFilter.class, order.next()); put(AbstractPreAuthenticatedProcessingFilter.class, order.next()); - filterToOrder.put("org.springframework.security.cas.web.CasAuthenticationFilter", + this.filterToOrder.put("org.springframework.security.cas.web.CasAuthenticationFilter", order.next()); + this.filterToOrder.put("org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter", order.next()); - filterToOrder.put( - "org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter", - order.next()); - filterToOrder.put( + this.filterToOrder.put( "org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter", order.next()); put(UsernamePasswordAuthenticationFilter.class, order.next()); order.next(); // gh-8105 - filterToOrder.put( - "org.springframework.security.openid.OpenIDAuthenticationFilter", order.next()); + this.filterToOrder.put("org.springframework.security.openid.OpenIDAuthenticationFilter", order.next()); put(DefaultLoginPageGeneratingFilter.class, order.next()); put(DefaultLogoutPageGeneratingFilter.class, order.next()); put(ConcurrentSessionFilter.class, order.next()); put(DigestAuthenticationFilter.class, order.next()); - filterToOrder.put( - "org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationFilter", order.next()); + this.filterToOrder.put( + "org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationFilter", + order.next()); put(BasicAuthenticationFilter.class, order.next()); put(RequestCacheAwareFilter.class, order.next()); put(SecurityContextHolderAwareRequestFilter.class, order.next()); put(JaasApiIntegrationFilter.class, order.next()); put(RememberMeAuthenticationFilter.class, order.next()); put(AnonymousAuthenticationFilter.class, order.next()); - filterToOrder.put( - "org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter", + this.filterToOrder.put("org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter", order.next()); put(SessionManagementFilter.class, order.next()); put(ExceptionTranslationFilter.class, order.next()); @@ -111,6 +114,7 @@ final class FilterComparator implements Comparator, Serializable { put(SwitchUserFilter.class, order.next()); } + @Override public int compare(Filter lhs, Filter rhs) { Integer left = getOrder(lhs.getClass()); Integer right = getOrder(rhs.getClass()); @@ -119,11 +123,10 @@ final class FilterComparator implements Comparator, Serializable { /** * Determines if a particular {@link Filter} is registered to be sorted - * * @param filter * @return */ - public boolean isRegistered(Class filter) { + boolean isRegistered(Class filter) { return getOrder(filter) != null; } @@ -134,14 +137,9 @@ final class FilterComparator implements Comparator, Serializable { * @param afterFilter the {@link Filter} that is already registered and that * {@code filter} should be placed after. */ - public void registerAfter(Class filter, - Class afterFilter) { + void registerAfter(Class filter, Class afterFilter) { Integer position = getOrder(afterFilter); - if (position == null) { - throw new IllegalArgumentException( - "Cannot register after unregistered Filter " + afterFilter); - } - + Assert.notNull(position, () -> "Cannot register after unregistered Filter " + afterFilter); put(filter, position + 1); } @@ -151,14 +149,9 @@ final class FilterComparator implements Comparator, Serializable { * @param atFilter the {@link Filter} that is already registered and that * {@code filter} should be placed at. */ - public void registerAt(Class filter, - Class atFilter) { + void registerAt(Class filter, Class atFilter) { Integer position = getOrder(atFilter); - if (position == null) { - throw new IllegalArgumentException( - "Cannot register after unregistered Filter " + atFilter); - } - + Assert.notNull(position, () -> "Cannot register after unregistered Filter " + atFilter); put(filter, position); } @@ -169,32 +162,26 @@ final class FilterComparator implements Comparator, Serializable { * @param beforeFilter the {@link Filter} that is already registered and that * {@code filter} should be placed before. */ - public void registerBefore(Class filter, - Class beforeFilter) { + void registerBefore(Class filter, Class beforeFilter) { Integer position = getOrder(beforeFilter); - if (position == null) { - throw new IllegalArgumentException( - "Cannot register after unregistered Filter " + beforeFilter); - } - + Assert.notNull(position, () -> "Cannot register after unregistered Filter " + beforeFilter); put(filter, position - 1); } private void put(Class filter, int position) { String className = filter.getName(); - filterToOrder.put(className, position); + this.filterToOrder.put(className, position); } /** * Gets the order of a particular {@link Filter} class taking into consideration * superclasses. - * * @param clazz the {@link Filter} class to determine the sort order * @return the sort order or null if not defined */ private Integer getOrder(Class clazz) { while (clazz != null) { - Integer result = filterToOrder.get(clazz.getName()); + Integer result = this.filterToOrder.get(clazz.getName()); if (result != null) { return result; } @@ -206,6 +193,7 @@ final class FilterComparator implements Comparator, Serializable { private static class Step { private int value; + private final int stepSize; Step(int initialValue, int stepSize) { diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java index b056bdc898..002676eef2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.builders; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import javax.servlet.Filter; +import javax.servlet.http.HttpServletRequest; + import org.springframework.context.ApplicationContext; import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AuthenticationManager; @@ -78,12 +86,6 @@ import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.filter.CorsFilter; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import javax.servlet.Filter; -import javax.servlet.http.HttpServletRequest; - /** * A {@link HttpSecurity} is similar to Spring Security's XML <http> element in the * namespace configuration. It allows configuring web based security for specific http @@ -120,13 +122,15 @@ import javax.servlet.http.HttpServletRequest; * @since 3.2 * @see EnableWebSecurity */ -public final class HttpSecurity extends - AbstractConfiguredSecurityBuilder - implements SecurityBuilder, - HttpSecurityBuilder { +public final class HttpSecurity extends AbstractConfiguredSecurityBuilder + implements SecurityBuilder, HttpSecurityBuilder { + private final RequestMatcherConfigurer requestMatcherConfigurer; + private List filters = new ArrayList<>(); + private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE; + private FilterComparator comparator = new FilterComparator(); /** @@ -139,17 +143,14 @@ public final class HttpSecurity extends */ @SuppressWarnings("unchecked") public HttpSecurity(ObjectPostProcessor objectPostProcessor, - AuthenticationManagerBuilder authenticationBuilder, - Map, Object> sharedObjects) { + AuthenticationManagerBuilder authenticationBuilder, Map, Object> sharedObjects) { super(objectPostProcessor); Assert.notNull(authenticationBuilder, "authenticationBuilder cannot be null"); setSharedObject(AuthenticationManagerBuilder.class, authenticationBuilder); - for (Map.Entry, Object> entry : sharedObjects - .entrySet()) { + for (Map.Entry, Object> entry : sharedObjects.entrySet()) { setSharedObject((Class) entry.getKey(), entry.getValue()); } - ApplicationContext context = (ApplicationContext) sharedObjects - .get(ApplicationContext.class); + ApplicationContext context = (ApplicationContext) sharedObjects.get(ApplicationContext.class); this.requestMatcherConfigurer = new RequestMatcherConfigurer(context); } @@ -231,14 +232,15 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link OpenIDLoginConfigurer} for further customizations. - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @throws Exception + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. * @see OpenIDLoginConfigurer */ + @Deprecated public OpenIDLoginConfigurer openidLogin() throws Exception { return getOrApply(new OpenIDLoginConfigurer<>()); } @@ -258,11 +260,11 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) - * .openidLogin(openidLogin -> + * .openidLogin((openidLogin) -> * openidLogin * .permitAll() * ); @@ -291,48 +293,48 @@ public final class HttpSecurity extends * * @Override * protected void configure(HttpSecurity http) throws Exception { - * http.authorizeRequests(authorizeRequests -> + * http.authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) - * .openidLogin(openidLogin -> + * .openidLogin((openidLogin) -> * openidLogin * .loginPage("/login") * .permitAll() * .authenticationUserDetailsService( * new AutoProvisioningUserDetailsService()) - * .attributeExchange(googleExchange -> + * .attributeExchange((googleExchange) -> * googleExchange * .identifierPattern("https://www.google.com/.*") - * .attribute(emailAttribute -> + * .attribute((emailAttribute) -> * emailAttribute * .name("email") * .type("https://axschema.org/contact/email") * .required(true) * ) - * .attribute(firstnameAttribute -> + * .attribute((firstnameAttribute) -> * firstnameAttribute * .name("firstname") * .type("https://axschema.org/namePerson/first") * .required(true) * ) - * .attribute(lastnameAttribute -> + * .attribute((lastnameAttribute) -> * lastnameAttribute * .name("lastname") * .type("https://axschema.org/namePerson/last") * .required(true) * ) * ) - * .attributeExchange(yahooExchange -> + * .attributeExchange((yahooExchange) -> * yahooExchange * .identifierPattern(".*yahoo.com.*") - * .attribute(emailAttribute -> + * .attribute((emailAttribute) -> * emailAttribute * .name("email") * .type("https://schema.openid.net/contact/email") * .required(true) * ) - * .attribute(fullnameAttribute -> + * .attribute((fullnameAttribute) -> * fullnameAttribute * .name("fullname") * .type("https://axschema.org/namePerson") @@ -352,26 +354,27 @@ public final class HttpSecurity extends * } * } * - * - * @see OpenIDLoginConfigurer - * - * @param openidLoginCustomizer the {@link Customizer} to provide more options for - * the {@link OpenIDLoginConfigurer} - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. + * @param openidLoginCustomizer the {@link Customizer} to provide more options for the + * {@link OpenIDLoginConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. + * @see OpenIDLoginConfigurer */ - public HttpSecurity openidLogin(Customizer> openidLoginCustomizer) throws Exception { + @Deprecated + public HttpSecurity openidLogin(Customizer> openidLoginCustomizer) + throws Exception { openidLoginCustomizer.customize(getOrApply(new OpenIDLoginConfigurer<>())); return HttpSecurity.this; } /** * Adds the Security headers to the response. This is activated by default when using - * {@link WebSecurityConfigurerAdapter}'s default constructor. Accepting the - * default provided by {@link WebSecurityConfigurerAdapter} or only invoking + * {@link WebSecurityConfigurerAdapter}'s default constructor. Accepting the default + * provided by {@link WebSecurityConfigurerAdapter} or only invoking * {@link #headers()} without invoking additional methods on it, is the equivalent of: * *
    @@ -415,9 +418,9 @@ public final class HttpSecurity extends
     	 * 
    * * You can enable only a few of the headers by first invoking - * {@link HeadersConfigurer#defaultsDisabled()} - * and then invoking the appropriate methods on the {@link #headers()} result. - * For example, the following will enable {@link HeadersConfigurer#cacheControl()} and + * {@link HeadersConfigurer#defaultsDisabled()} and then invoking the appropriate + * methods on the {@link #headers()} result. For example, the following will enable + * {@link HeadersConfigurer#cacheControl()} and * {@link HeadersConfigurer#frameOptions()} only. * *
    @@ -439,8 +442,8 @@ public final class HttpSecurity extends
     	 * }
     	 * 
    * - * You can also choose to keep the defaults but explicitly disable a subset of headers. - * For example, the following will enable all the default headers except + * You can also choose to keep the defaults but explicitly disable a subset of + * headers. For example, the following will enable all the default headers except * {@link HeadersConfigurer#frameOptions()}. * *
    @@ -459,7 +462,6 @@ public final class HttpSecurity extends
     	 *     }
     	 * }
     	 * 
    - * * @return the {@link HeadersConfigurer} for further customizations * @throws Exception * @see HeadersConfigurer @@ -474,8 +476,9 @@ public final class HttpSecurity extends * *

    Example Configurations

    * - * Accepting the default provided by {@link WebSecurityConfigurerAdapter} or only invoking - * {@link #headers()} without invoking additional methods on it, is the equivalent of: + * Accepting the default provided by {@link WebSecurityConfigurerAdapter} or only + * invoking {@link #headers()} without invoking additional methods on it, is the + * equivalent of: * *
     	 * @Configuration
    @@ -485,7 +488,7 @@ public final class HttpSecurity extends
     	 *	@Override
     	 *	protected void configure(HttpSecurity http) throws Exception {
     	 *		http
    -	 *			.headers(headers ->
    +	 *			.headers((headers) ->
     	 *				headers
     	 *					.contentTypeOptions(withDefaults())
     	 *					.xssProtection(withDefaults())
    @@ -507,15 +510,15 @@ public final class HttpSecurity extends
     	 *	@Override
     	 *	protected void configure(HttpSecurity http) throws Exception {
     	 * 		http
    -	 * 			.headers(headers -> headers.disable());
    +	 * 			.headers((headers) -> headers.disable());
     	 *	}
     	 * }
     	 * 
    * * You can enable only a few of the headers by first invoking - * {@link HeadersConfigurer#defaultsDisabled()} - * and then invoking the appropriate methods on the {@link #headers()} result. - * For example, the following will enable {@link HeadersConfigurer#cacheControl()} and + * {@link HeadersConfigurer#defaultsDisabled()} and then invoking the appropriate + * methods on the {@link #headers()} result. For example, the following will enable + * {@link HeadersConfigurer#cacheControl()} and * {@link HeadersConfigurer#frameOptions()} only. * *
    @@ -526,7 +529,7 @@ public final class HttpSecurity extends
     	 *	@Override
     	 *	protected void configure(HttpSecurity http) throws Exception {
     	 *		http
    -	 *			.headers(headers ->
    +	 *			.headers((headers) ->
     	 *				headers
     	 *			 		.defaultsDisabled()
     	 *			 		.cacheControl(withDefaults())
    @@ -536,8 +539,8 @@ public final class HttpSecurity extends
     	 * }
     	 * 
    * - * You can also choose to keep the defaults but explicitly disable a subset of headers. - * For example, the following will enable all the default headers except + * You can also choose to keep the defaults but explicitly disable a subset of + * headers. For example, the following will enable all the default headers except * {@link HeadersConfigurer#frameOptions()}. * *
    @@ -548,15 +551,14 @@ public final class HttpSecurity extends
     	 * 	@Override
     	 *  protected void configure(HttpSecurity http) throws Exception {
     	 *  	http
    -	 *  		.headers(headers ->
    +	 *  		.headers((headers) ->
     	 *  			headers
    -	 *  				.frameOptions(frameOptions -> frameOptions.disable())
    +	 *  				.frameOptions((frameOptions) -> frameOptions.disable())
     	 *  		);
     	 * }
     	 * 
    - * - * @param headersCustomizer the {@link Customizer} to provide more options for - * the {@link HeadersConfigurer} + * @param headersCustomizer the {@link Customizer} to provide more options for the + * {@link HeadersConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -570,7 +572,6 @@ public final class HttpSecurity extends * provided, that {@link CorsFilter} is used. Else if corsConfigurationSource is * defined, then that {@link CorsConfiguration} is used. Otherwise, if Spring MVC is * on the classpath a {@link HandlerMappingIntrospector} is used. - * * @return the {@link CorsConfigurer} for customizations * @throws Exception */ @@ -582,8 +583,8 @@ public final class HttpSecurity extends * Adds a {@link CorsFilter} to be used. If a bean by the name of corsFilter is * provided, that {@link CorsFilter} is used. Else if corsConfigurationSource is * defined, then that {@link CorsConfiguration} is used. Otherwise, if Spring MVC is - * on the classpath a {@link HandlerMappingIntrospector} is used. - * You can enable CORS using: + * on the classpath a {@link HandlerMappingIntrospector} is used. You can enable CORS + * using: * *
     	 * @Configuration
    @@ -597,9 +598,8 @@ public final class HttpSecurity extends
     	 *     }
     	 * }
     	 * 
    - * - * @param corsCustomizer the {@link Customizer} to provide more options for - * the {@link CorsConfigurer} + * @param corsCustomizer the {@link Customizer} to provide more options for the + * {@link CorsConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -652,7 +652,6 @@ public final class HttpSecurity extends * Alternatively, * {@link AbstractSecurityWebApplicationInitializer#enableHttpSessionEventPublisher()} * could return true. - * * @return the {@link SessionManagementConfigurer} for further customizations * @throws Exception */ @@ -678,17 +677,17 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .anyRequest().hasRole("USER") * ) - * .formLogin(formLogin -> + * .formLogin((formLogin) -> * formLogin * .permitAll() * ) - * .sessionManagement(sessionManagement -> + * .sessionManagement((sessionManagement) -> * sessionManagement - * .sessionConcurrency(sessionConcurrency -> + * .sessionConcurrency((sessionConcurrency) -> * sessionConcurrency * .maximumSessions(1) * .expiredUrl("/login?expired") @@ -713,13 +712,13 @@ public final class HttpSecurity extends * Alternatively, * {@link AbstractSecurityWebApplicationInitializer#enableHttpSessionEventPublisher()} * could return true. - * - * @param sessionManagementCustomizer the {@link Customizer} to provide more options for - * the {@link SessionManagementConfigurer} + * @param sessionManagementCustomizer the {@link Customizer} to provide more options + * for the {@link SessionManagementConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ - public HttpSecurity sessionManagement(Customizer> sessionManagementCustomizer) throws Exception { + public HttpSecurity sessionManagement( + Customizer> sessionManagementCustomizer) throws Exception { sessionManagementCustomizer.customize(getOrApply(new SessionManagementConfigurer<>())); return HttpSecurity.this; } @@ -758,7 +757,6 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link PortMapperConfigurer} for further customizations * @throws Exception * @see #requiresChannel() @@ -790,11 +788,11 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .requiresChannel(requiresChannel -> + * .requiresChannel((requiresChannel) -> * requiresChannel * .anyRequest().requiresSecure() * ) - * .portMapper(portMapper -> + * .portMapper((portMapper) -> * portMapper * .http(9090).mapsTo(9443) * .http(80).mapsTo(443) @@ -802,21 +800,21 @@ public final class HttpSecurity extends * } * } * - * - * @see #requiresChannel() - * @param portMapperCustomizer the {@link Customizer} to provide more options for - * the {@link PortMapperConfigurer} + * @param portMapperCustomizer the {@link Customizer} to provide more options for the + * {@link PortMapperConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @see #requiresChannel() */ - public HttpSecurity portMapper(Customizer> portMapperCustomizer) throws Exception { + public HttpSecurity portMapper(Customizer> portMapperCustomizer) + throws Exception { portMapperCustomizer.customize(getOrApply(new PortMapperConfigurer<>())); return HttpSecurity.this; } /** - * Configures container based pre authentication. In this case, authentication - * is managed by the Servlet Container. + * Configures container based pre authentication. In this case, authentication is + * managed by the Servlet Container. * *

    Example Configuration

    * @@ -878,7 +876,6 @@ public final class HttpSecurity extends * Last you will need to configure your container to contain the user with the correct * roles. This configuration is specific to the Servlet Container, so consult your * Servlet Container's documentation. - * * @return the {@link JeeConfigurer} for further customizations * @throws Exception */ @@ -887,8 +884,8 @@ public final class HttpSecurity extends } /** - * Configures container based pre authentication. In this case, authentication - * is managed by the Servlet Container. + * Configures container based pre authentication. In this case, authentication is + * managed by the Servlet Container. * *

    Example Configuration

    * @@ -904,11 +901,11 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) - * .jee(jee -> + * .jee((jee) -> * jee * .mappableRoles("USER", "ADMIN") * ); @@ -956,9 +953,8 @@ public final class HttpSecurity extends * Last you will need to configure your container to contain the user with the correct * roles. This configuration is specific to the Servlet Container, so consult your * Servlet Container's documentation. - * - * @param jeeCustomizer the {@link Customizer} to provide more options for - * the {@link JeeConfigurer} + * @param jeeCustomizer the {@link Customizer} to provide more options for the + * {@link JeeConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -989,7 +985,6 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link X509Configurer} for further customizations * @throws Exception */ @@ -1014,7 +1009,7 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -1022,9 +1017,8 @@ public final class HttpSecurity extends * } * } * - * - * @param x509Customizer the {@link Customizer} to provide more options for - * the {@link X509Configurer} + * @param x509Customizer the {@link Customizer} to provide more options for the + * {@link X509Configurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -1062,7 +1056,6 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link RememberMeConfigurer} for further customizations * @throws Exception */ @@ -1088,7 +1081,7 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -1097,13 +1090,13 @@ public final class HttpSecurity extends * } * } * - * - * @param rememberMeCustomizer the {@link Customizer} to provide more options for - * the {@link RememberMeConfigurer} + * @param rememberMeCustomizer the {@link Customizer} to provide more options for the + * {@link RememberMeConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ - public HttpSecurity rememberMe(Customizer> rememberMeCustomizer) throws Exception { + public HttpSecurity rememberMe(Customizer> rememberMeCustomizer) + throws Exception { rememberMeCustomizer.customize(getOrApply(new RememberMeConfigurer<>())); return HttpSecurity.this; } @@ -1167,17 +1160,14 @@ public final class HttpSecurity extends * http.authorizeRequests().antMatchers("/**").hasRole("USER").antMatchers("/admin/**") * .hasRole("ADMIN") * - * - * @see #requestMatcher(RequestMatcher) - * * @return the {@link ExpressionUrlAuthorizationConfigurer} for further customizations * @throws Exception + * @see #requestMatcher(RequestMatcher) */ public ExpressionUrlAuthorizationConfigurer.ExpressionInterceptUrlRegistry authorizeRequests() throws Exception { ApplicationContext context = getContext(); - return getOrApply(new ExpressionUrlAuthorizationConfigurer<>(context)) - .getRegistry(); + return getOrApply(new ExpressionUrlAuthorizationConfigurer<>(context)).getRegistry(); } /** @@ -1198,7 +1188,7 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -1219,7 +1209,7 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/admin/**").hasRole("ADMIN") * .antMatchers("/**").hasRole("USER") @@ -1241,7 +1231,7 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * .antMatchers("/admin/**").hasRole("ADMIN") @@ -1249,19 +1239,18 @@ public final class HttpSecurity extends * } * } * - * - * @see #requestMatcher(RequestMatcher) - * - * @param authorizeRequestsCustomizer the {@link Customizer} to provide more options for - * the {@link ExpressionUrlAuthorizationConfigurer.ExpressionInterceptUrlRegistry} + * @param authorizeRequestsCustomizer the {@link Customizer} to provide more options + * for the {@link ExpressionUrlAuthorizationConfigurer.ExpressionInterceptUrlRegistry} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @see #requestMatcher(RequestMatcher) */ - public HttpSecurity authorizeRequests(Customizer.ExpressionInterceptUrlRegistry> authorizeRequestsCustomizer) + public HttpSecurity authorizeRequests( + Customizer.ExpressionInterceptUrlRegistry> authorizeRequestsCustomizer) throws Exception { ApplicationContext context = getContext(); - authorizeRequestsCustomizer.customize(getOrApply(new ExpressionUrlAuthorizationConfigurer<>(context)) - .getRegistry()); + authorizeRequestsCustomizer + .customize(getOrApply(new ExpressionUrlAuthorizationConfigurer<>(context)).getRegistry()); return HttpSecurity.this; } @@ -1271,7 +1260,6 @@ public final class HttpSecurity extends * a login page. After authentication, Spring Security will redirect the user to the * originally requested protected page (/protected). This is automatically applied * when using {@link WebSecurityConfigurerAdapter}. - * * @return the {@link RequestCacheConfigurer} for further customizations * @throws Exception */ @@ -1298,17 +1286,16 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) - * .requestCache(requestCache -> + * .requestCache((requestCache) -> * requestCache.disable() * ); * } * } * - * * @param requestCacheCustomizer the {@link Customizer} to provide more options for * the {@link RequestCacheConfigurer} * @return the {@link HttpSecurity} for further customizations @@ -1323,7 +1310,6 @@ public final class HttpSecurity extends /** * Allows configuring exception handling. This is automatically applied when using * {@link WebSecurityConfigurerAdapter}. - * * @return the {@link ExceptionHandlingConfigurer} for further customizations * @throws Exception */ @@ -1337,8 +1323,8 @@ public final class HttpSecurity extends * *

    Example Custom Configuration

    * - * The following customization will ensure that users who are denied access are forwarded - * to the page "/errors/access-denied". + * The following customization will ensure that users who are denied access are + * forwarded to the page "/errors/access-denied". * *
     	 * @Configuration
    @@ -1348,25 +1334,25 @@ public final class HttpSecurity extends
     	 * 	@Override
     	 * 	protected void configure(HttpSecurity http) throws Exception {
     	 * 		http
    -	 * 			.authorizeRequests(authorizeRequests ->
    +	 * 			.authorizeRequests((authorizeRequests) ->
     	 * 				authorizeRequests
     	 * 					.antMatchers("/**").hasRole("USER")
     	 * 			)
     	 * 			// sample exception handling customization
    -	 * 			.exceptionHandling(exceptionHandling ->
    +	 * 			.exceptionHandling((exceptionHandling) ->
     	 * 				exceptionHandling
     	 * 					.accessDeniedPage("/errors/access-denied")
     	 * 			);
     	 * 	}
     	 * }
     	 * 
    - * - * @param exceptionHandlingCustomizer the {@link Customizer} to provide more options for - * the {@link ExceptionHandlingConfigurer} + * @param exceptionHandlingCustomizer the {@link Customizer} to provide more options + * for the {@link ExceptionHandlingConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ - public HttpSecurity exceptionHandling(Customizer> exceptionHandlingCustomizer) throws Exception { + public HttpSecurity exceptionHandling( + Customizer> exceptionHandlingCustomizer) throws Exception { exceptionHandlingCustomizer.customize(getOrApply(new ExceptionHandlingConfigurer<>())); return HttpSecurity.this; } @@ -1375,7 +1361,6 @@ public final class HttpSecurity extends * Sets up management of the {@link SecurityContext} on the * {@link SecurityContextHolder} between {@link HttpServletRequest}'s. This is * automatically applied when using {@link WebSecurityConfigurerAdapter}. - * * @return the {@link SecurityContextConfigurer} for further customizations * @throws Exception */ @@ -1398,20 +1383,20 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .securityContext(securityContext -> + * .securityContext((securityContext) -> * securityContext * .securityContextRepository(SCR) * ); * } * } * - * * @param securityContextCustomizer the {@link Customizer} to provide more options for * the {@link SecurityContextConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ - public HttpSecurity securityContext(Customizer> securityContextCustomizer) throws Exception { + public HttpSecurity securityContext(Customizer> securityContextCustomizer) + throws Exception { securityContextCustomizer.customize(getOrApply(new SecurityContextConfigurer<>())); return HttpSecurity.this; } @@ -1420,7 +1405,6 @@ public final class HttpSecurity extends * Integrates the {@link HttpServletRequest} methods with the values found on the * {@link SecurityContext}. This is automatically applied when using * {@link WebSecurityConfigurerAdapter}. - * * @return the {@link ServletApiConfigurer} for further customizations * @throws Exception */ @@ -1441,19 +1425,19 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .servletApi(servletApi -> + * .servletApi((servletApi) -> * servletApi.disable() * ); * } * } * - * - * @param servletApiCustomizer the {@link Customizer} to provide more options for - * the {@link ServletApiConfigurer} + * @param servletApiCustomizer the {@link Customizer} to provide more options for the + * {@link ServletApiConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ - public HttpSecurity servletApi(Customizer> servletApiCustomizer) throws Exception { + public HttpSecurity servletApi(Customizer> servletApiCustomizer) + throws Exception { servletApiCustomizer.customize(getOrApply(new ServletApiConfigurer<>())); return HttpSecurity.this; } @@ -1476,7 +1460,6 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link CsrfConfigurer} for further customizations * @throws Exception */ @@ -1498,13 +1481,12 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .csrf(csrf -> csrf.disable()); + * .csrf((csrf) -> csrf.disable()); * } * } * - * - * @param csrfCustomizer the {@link Customizer} to provide more options for - * the {@link CsrfConfigurer} + * @param csrfCustomizer the {@link Customizer} to provide more options for the + * {@link CsrfConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -1547,7 +1529,6 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link LogoutConfigurer} for further customizations * @throws Exception */ @@ -1576,13 +1557,13 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) * .formLogin(withDefaults()) * // sample logout customization - * .logout(logout -> + * .logout((logout) -> * logout.deleteCookies("remove") * .invalidateHttpSession(false) * .logoutUrl("/custom-logout") @@ -1591,9 +1572,8 @@ public final class HttpSecurity extends * } * } * - * - * @param logoutCustomizer the {@link Customizer} to provide more options for - * the {@link LogoutConfigurer} + * @param logoutCustomizer the {@link Customizer} to provide more options for the + * {@link LogoutConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -1665,7 +1645,6 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link AnonymousConfigurer} for further customizations * @throws Exception */ @@ -1693,13 +1672,13 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) * .formLogin(withDefaults()) * // sample anonymous customization - * .anonymous(anonymous -> + * .anonymous((anonymous) -> * anonymous * .authorities("ROLE_ANON") * ) @@ -1719,13 +1698,13 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) * .formLogin(withDefaults()) * // sample anonymous customization - * .anonymous(anonymous -> + * .anonymous((anonymous) -> * anonymous.disable() * ); * } @@ -1736,9 +1715,8 @@ public final class HttpSecurity extends * } * } * - * - * @param anonymousCustomizer the {@link Customizer} to provide more options for - * the {@link AnonymousConfigurer} + * @param anonymousCustomizer the {@link Customizer} to provide more options for the + * {@link AnonymousConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -1747,7 +1725,6 @@ public final class HttpSecurity extends return HttpSecurity.this; } - /** * Specifies to support form based authentication. If * {@link FormLoginConfigurer#loginPage(String)} is not specified a default login page @@ -1802,11 +1779,9 @@ public final class HttpSecurity extends * } * } * - * - * @see FormLoginConfigurer#loginPage(String) - * * @return the {@link FormLoginConfigurer} for further customizations * @throws Exception + * @see FormLoginConfigurer#loginPage(String) */ public FormLoginConfigurer formLogin() throws Exception { return getOrApply(new FormLoginConfigurer<>()); @@ -1832,7 +1807,7 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -1851,11 +1826,11 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) - * .formLogin(formLogin -> + * .formLogin((formLogin) -> * formLogin * .usernameParameter("username") * .passwordParameter("password") @@ -1866,13 +1841,11 @@ public final class HttpSecurity extends * } * } * - * - * @see FormLoginConfigurer#loginPage(String) - * - * @param formLoginCustomizer the {@link Customizer} to provide more options for - * the {@link FormLoginConfigurer} + * @param formLoginCustomizer the {@link Customizer} to provide more options for the + * {@link FormLoginConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @see FormLoginConfigurer#loginPage(String) */ public HttpSecurity formLogin(Customizer> formLoginCustomizer) throws Exception { formLoginCustomizer.customize(getOrApply(new FormLoginConfigurer<>())); @@ -1880,40 +1853,39 @@ public final class HttpSecurity extends } /** - * Configures authentication support using an SAML 2.0 Service Provider. - *
    + * Configures authentication support using an SAML 2.0 Service Provider.
    *
    * - * The "authentication flow" is implemented using the Web Browser SSO Profile, using POST and REDIRECT bindings, - * as documented in the SAML V2.0 Core,Profiles and Bindings - * specifications. - *
    + * The "authentication flow" is implemented using the Web Browser SSO + * Profile, using POST and REDIRECT bindings, as documented in the + * SAML V2.0 + * Core,Profiles and Bindings specifications.
    *
    * - * As a prerequisite to using this feature, is that you have a SAML v2.0 Identity Provider to provide an assertion. - * The representation of the Service Provider, the relying party, and the remote Identity Provider, the asserting party - * is contained within {@link RelyingPartyRegistration}. - *
    + * As a prerequisite to using this feature, is that you have a SAML v2.0 Identity + * Provider to provide an assertion. The representation of the Service Provider, the + * relying party, and the remote Identity Provider, the asserting party is contained + * within {@link RelyingPartyRegistration}.
    *
    * * {@link RelyingPartyRegistration}(s) are composed within a - * {@link RelyingPartyRegistrationRepository}, - * which is required and must be registered with the {@link ApplicationContext} or - * configured via saml2Login().relyingPartyRegistrationRepository(..). - *
    + * {@link RelyingPartyRegistrationRepository}, which is required and must be + * registered with the {@link ApplicationContext} or configured via + * saml2Login().relyingPartyRegistrationRepository(..).
    *
    * - * The default configuration provides an auto-generated login page at "/login" and - * redirects to "/login?error" when an authentication error occurs. - * The login page will display each of the identity providers with a link - * that is capable of initiating the "authentication flow". - *
    + * The default configuration provides an auto-generated login page at + * "/login" and redirects to + * "/login?error" when an authentication error occurs. The + * login page will display each of the identity providers with a link that is capable + * of initiating the "authentication flow".
    *
    * *

    *

    Example Configuration

    * - * The following example shows the minimal configuration required, using SimpleSamlPhp as the Authentication Provider. + * The following example shows the minimal configuration required, using SimpleSamlPhp + * as the Authentication Provider. * *
     	 * @Configuration
    @@ -1961,50 +1933,48 @@ public final class HttpSecurity extends
     	 * 
    * *

    - * - * @since 5.2 * @return the {@link Saml2LoginConfigurer} for further customizations * @throws Exception + * @since 5.2 */ public Saml2LoginConfigurer saml2Login() throws Exception { return getOrApply(new Saml2LoginConfigurer<>()); } /** - * Configures authentication support using an SAML 2.0 Service Provider. - *
    + * Configures authentication support using an SAML 2.0 Service Provider.
    *
    * - * The "authentication flow" is implemented using the Web Browser SSO Profile, using POST and REDIRECT bindings, - * as documented in the SAML V2.0 Core,Profiles and Bindings - * specifications. - *
    + * The "authentication flow" is implemented using the Web Browser SSO + * Profile, using POST and REDIRECT bindings, as documented in the + * SAML V2.0 + * Core,Profiles and Bindings specifications.
    *
    * - * As a prerequisite to using this feature, is that you have a SAML v2.0 Identity Provider to provide an assertion. - * The representation of the Service Provider, the relying party, and the remote Identity Provider, the asserting party - * is contained within {@link RelyingPartyRegistration}. - *
    + * As a prerequisite to using this feature, is that you have a SAML v2.0 Identity + * Provider to provide an assertion. The representation of the Service Provider, the + * relying party, and the remote Identity Provider, the asserting party is contained + * within {@link RelyingPartyRegistration}.
    *
    * * {@link RelyingPartyRegistration}(s) are composed within a - * {@link RelyingPartyRegistrationRepository}, - * which is required and must be registered with the {@link ApplicationContext} or - * configured via saml2Login().relyingPartyRegistrationRepository(..). - *
    + * {@link RelyingPartyRegistrationRepository}, which is required and must be + * registered with the {@link ApplicationContext} or configured via + * saml2Login().relyingPartyRegistrationRepository(..).
    *
    * - * The default configuration provides an auto-generated login page at "/login" and - * redirects to "/login?error" when an authentication error occurs. - * The login page will display each of the identity providers with a link - * that is capable of initiating the "authentication flow". - *
    + * The default configuration provides an auto-generated login page at + * "/login" and redirects to + * "/login?error" when an authentication error occurs. The + * login page will display each of the identity providers with a link that is capable + * of initiating the "authentication flow".
    *
    * *

    *

    Example Configuration

    * - * The following example shows the minimal configuration required, using SimpleSamlPhp as the Authentication Provider. + * The following example shows the minimal configuration required, using SimpleSamlPhp + * as the Authentication Provider. * *
     	 * @Configuration
    @@ -2052,55 +2022,58 @@ public final class HttpSecurity extends
     	 * 
    * *

    - * - * @since 5.2 - * @param saml2LoginCustomizer the {@link Customizer} to provide more options for - * the {@link Saml2LoginConfigurer} + * @param saml2LoginCustomizer the {@link Customizer} to provide more options for the + * {@link Saml2LoginConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @since 5.2 */ - public HttpSecurity saml2Login(Customizer> saml2LoginCustomizer) throws Exception { + public HttpSecurity saml2Login(Customizer> saml2LoginCustomizer) + throws Exception { saml2LoginCustomizer.customize(getOrApply(new Saml2LoginConfigurer<>())); return HttpSecurity.this; } /** - * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 Provider. - *
    + * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 + * Provider.
    *
    * - * The "authentication flow" is implemented using the Authorization Code Grant, as specified in the - * OAuth 2.0 Authorization Framework - * and OpenID Connect Core 1.0 - * specification. - *
    + * The "authentication flow" is implemented using the Authorization Code + * Grant, as specified in the + * OAuth 2.0 + * Authorization Framework and OpenID Connect + * Core 1.0 specification.
    *
    * - * As a prerequisite to using this feature, you must register a client with a provider. - * The client registration information may than be used for configuring - * a {@link org.springframework.security.oauth2.client.registration.ClientRegistration} using a + * As a prerequisite to using this feature, you must register a client with a + * provider. The client registration information may than be used for configuring a + * {@link org.springframework.security.oauth2.client.registration.ClientRegistration} + * using a * {@link org.springframework.security.oauth2.client.registration.ClientRegistration.Builder}. *
    *
    * - * {@link org.springframework.security.oauth2.client.registration.ClientRegistration}(s) are composed within a + * {@link org.springframework.security.oauth2.client.registration.ClientRegistration}(s) + * are composed within a * {@link org.springframework.security.oauth2.client.registration.ClientRegistrationRepository}, - * which is required and must be registered with the {@link ApplicationContext} or - * configured via oauth2Login().clientRegistrationRepository(..). - *
    + * which is required and must be registered with the {@link ApplicationContext} + * or configured via oauth2Login().clientRegistrationRepository(..).
    *
    * - * The default configuration provides an auto-generated login page at "/login" and - * redirects to "/login?error" when an authentication error occurs. - * The login page will display each of the clients with a link - * that is capable of initiating the "authentication flow". - *
    + * The default configuration provides an auto-generated login page at + * "/login" and redirects to + * "/login?error" when an authentication error occurs. The + * login page will display each of the clients with a link that is capable of + * initiating the "authentication flow".
    *
    * *

    *

    Example Configuration

    * - * The following example shows the minimal configuration required, using Google as the Authentication Provider. + * The following example shows the minimal configuration required, using Google as the + * Authentication Provider. * *
     	 * @Configuration
    @@ -2143,57 +2116,64 @@ public final class HttpSecurity extends
     	 * 
    * *

    - * For more advanced configuration, see {@link OAuth2LoginConfigurer} for available options to customize the defaults. - * - * @since 5.0 - * @see Section 4.1 Authorization Code Grant - * @see Section 3.1 Authorization Code Flow - * @see org.springframework.security.oauth2.client.registration.ClientRegistration - * @see org.springframework.security.oauth2.client.registration.ClientRegistrationRepository + * For more advanced configuration, see {@link OAuth2LoginConfigurer} for available + * options to customize the defaults. * @return the {@link OAuth2LoginConfigurer} for further customizations * @throws Exception + * @since 5.0 + * @see Section 4.1 Authorization Code + * Grant + * @see Section 3.1 + * Authorization Code Flow + * @see org.springframework.security.oauth2.client.registration.ClientRegistration + * @see org.springframework.security.oauth2.client.registration.ClientRegistrationRepository */ public OAuth2LoginConfigurer oauth2Login() throws Exception { return getOrApply(new OAuth2LoginConfigurer<>()); } /** - * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 Provider. - *
    + * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 + * Provider.
    *
    * - * The "authentication flow" is implemented using the Authorization Code Grant, as specified in the - * OAuth 2.0 Authorization Framework - * and OpenID Connect Core 1.0 - * specification. - *
    + * The "authentication flow" is implemented using the Authorization Code + * Grant, as specified in the + * OAuth 2.0 + * Authorization Framework and OpenID Connect + * Core 1.0 specification.
    *
    * - * As a prerequisite to using this feature, you must register a client with a provider. - * The client registration information may than be used for configuring - * a {@link org.springframework.security.oauth2.client.registration.ClientRegistration} using a + * As a prerequisite to using this feature, you must register a client with a + * provider. The client registration information may than be used for configuring a + * {@link org.springframework.security.oauth2.client.registration.ClientRegistration} + * using a * {@link org.springframework.security.oauth2.client.registration.ClientRegistration.Builder}. *
    *
    * - * {@link org.springframework.security.oauth2.client.registration.ClientRegistration}(s) are composed within a + * {@link org.springframework.security.oauth2.client.registration.ClientRegistration}(s) + * are composed within a * {@link org.springframework.security.oauth2.client.registration.ClientRegistrationRepository}, - * which is required and must be registered with the {@link ApplicationContext} or - * configured via oauth2Login().clientRegistrationRepository(..). - *
    + * which is required and must be registered with the {@link ApplicationContext} + * or configured via oauth2Login().clientRegistrationRepository(..).
    *
    * - * The default configuration provides an auto-generated login page at "/login" and - * redirects to "/login?error" when an authentication error occurs. - * The login page will display each of the clients with a link - * that is capable of initiating the "authentication flow". - *
    + * The default configuration provides an auto-generated login page at + * "/login" and redirects to + * "/login?error" when an authentication error occurs. The + * login page will display each of the clients with a link that is capable of + * initiating the "authentication flow".
    *
    * *

    *

    Example Configuration

    * - * The following example shows the minimal configuration required, using Google as the Authentication Provider. + * The following example shows the minimal configuration required, using Google as the + * Authentication Provider. * *
     	 * @Configuration
    @@ -2204,7 +2184,7 @@ public final class HttpSecurity extends
     	 * 		@Override
     	 * 		protected void configure(HttpSecurity http) throws Exception {
     	 * 			http
    -	 * 				.authorizeRequests(authorizeRequests ->
    +	 * 				.authorizeRequests((authorizeRequests) ->
     	 * 					authorizeRequests
     	 * 						.anyRequest().authenticated()
     	 * 				)
    @@ -2237,30 +2217,35 @@ public final class HttpSecurity extends
     	 * 
    * *

    - * For more advanced configuration, see {@link OAuth2LoginConfigurer} for available options to customize the defaults. - * - * @see Section 4.1 Authorization Code Grant - * @see Section 3.1 Authorization Code Flow - * @see org.springframework.security.oauth2.client.registration.ClientRegistration - * @see org.springframework.security.oauth2.client.registration.ClientRegistrationRepository - * - * @param oauth2LoginCustomizer the {@link Customizer} to provide more options for - * the {@link OAuth2LoginConfigurer} + * For more advanced configuration, see {@link OAuth2LoginConfigurer} for available + * options to customize the defaults. + * @param oauth2LoginCustomizer the {@link Customizer} to provide more options for the + * {@link OAuth2LoginConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @see Section 4.1 Authorization Code + * Grant + * @see Section 3.1 + * Authorization Code Flow + * @see org.springframework.security.oauth2.client.registration.ClientRegistration + * @see org.springframework.security.oauth2.client.registration.ClientRegistrationRepository */ - public HttpSecurity oauth2Login(Customizer> oauth2LoginCustomizer) throws Exception { + public HttpSecurity oauth2Login(Customizer> oauth2LoginCustomizer) + throws Exception { oauth2LoginCustomizer.customize(getOrApply(new OAuth2LoginConfigurer<>())); return HttpSecurity.this; } /** * Configures OAuth 2.0 Client support. - * - * @since 5.1 - * @see OAuth 2.0 Authorization Framework * @return the {@link OAuth2ClientConfigurer} for further customizations * @throws Exception + * @since 5.1 + * @see OAuth 2.0 Authorization + * Framework */ public OAuth2ClientConfigurer oauth2Client() throws Exception { OAuth2ClientConfigurer configurer = getOrApply(new OAuth2ClientConfigurer<>()); @@ -2273,7 +2258,8 @@ public final class HttpSecurity extends * *

    Example Configuration

    * - * The following example demonstrates how to enable OAuth 2.0 Client support for all endpoints. + * The following example demonstrates how to enable OAuth 2.0 Client support for all + * endpoints. * *
     	 * @Configuration
    @@ -2282,7 +2268,7 @@ public final class HttpSecurity extends
     	 * 	@Override
     	 * 	protected void configure(HttpSecurity http) throws Exception {
     	 * 		http
    -	 * 			.authorizeRequests(authorizeRequests ->
    +	 * 			.authorizeRequests((authorizeRequests) ->
     	 * 				authorizeRequests
     	 * 					.anyRequest().authenticated()
     	 * 			)
    @@ -2290,29 +2276,32 @@ public final class HttpSecurity extends
     	 *	}
     	 * }
     	 * 
    - * - * @see OAuth 2.0 Authorization Framework - * * @param oauth2ClientCustomizer the {@link Customizer} to provide more options for * the {@link OAuth2ClientConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @see OAuth 2.0 Authorization + * Framework */ - public HttpSecurity oauth2Client(Customizer> oauth2ClientCustomizer) throws Exception { + public HttpSecurity oauth2Client(Customizer> oauth2ClientCustomizer) + throws Exception { oauth2ClientCustomizer.customize(getOrApply(new OAuth2ClientConfigurer<>())); return HttpSecurity.this; } /** * Configures OAuth 2.0 Resource Server support. - * - * @since 5.1 - * @see OAuth 2.0 Authorization Framework * @return the {@link OAuth2ResourceServerConfigurer} for further customizations * @throws Exception + * @since 5.1 + * @see OAuth 2.0 Authorization + * Framework */ public OAuth2ResourceServerConfigurer oauth2ResourceServer() throws Exception { - OAuth2ResourceServerConfigurer configurer = getOrApply(new OAuth2ResourceServerConfigurer<>(getContext())); + OAuth2ResourceServerConfigurer configurer = getOrApply( + new OAuth2ResourceServerConfigurer<>(getContext())); this.postProcess(configurer); return configurer; } @@ -2322,7 +2311,8 @@ public final class HttpSecurity extends * *

    Example Configuration

    * - * The following example demonstrates how to configure a custom JWT authentication converter. + * The following example demonstrates how to configure a custom JWT authentication + * converter. * *
     	 * @Configuration
    @@ -2335,13 +2325,13 @@ public final class HttpSecurity extends
     	 * 	@Override
     	 * 	protected void configure(HttpSecurity http) throws Exception {
     	 * 		http
    -	 * 			.authorizeRequests(authorizeRequests ->
    +	 * 			.authorizeRequests((authorizeRequests) ->
     	 * 				authorizeRequests
     	 * 					.anyRequest().authenticated()
     	 * 			)
    -	 * 			.oauth2ResourceServer(oauth2ResourceServer ->
    +	 * 			.oauth2ResourceServer((oauth2ResourceServer) ->
     	 * 				oauth2ResourceServer
    -	 * 					.jwt(jwt ->
    +	 * 					.jwt((jwt) ->
     	 * 						jwt
     	 * 							.decoder(jwtDecoder())
     	 * 					)
    @@ -2354,17 +2344,18 @@ public final class HttpSecurity extends
     	 * 	}
     	 * }
     	 * 
    - * - * @see OAuth 2.0 Authorization Framework - * - * @param oauth2ResourceServerCustomizer the {@link Customizer} to provide more options for - * the {@link OAuth2ResourceServerConfigurer} + * @param oauth2ResourceServerCustomizer the {@link Customizer} to provide more + * options for the {@link OAuth2ResourceServerConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception + * @see OAuth 2.0 Authorization + * Framework */ - public HttpSecurity oauth2ResourceServer(Customizer> oauth2ResourceServerCustomizer) - throws Exception { - OAuth2ResourceServerConfigurer configurer = getOrApply(new OAuth2ResourceServerConfigurer<>(getContext())); + public HttpSecurity oauth2ResourceServer( + Customizer> oauth2ResourceServerCustomizer) throws Exception { + OAuth2ResourceServerConfigurer configurer = getOrApply( + new OAuth2ResourceServerConfigurer<>(getContext())); this.postProcess(configurer); oauth2ResourceServerCustomizer.customize(configurer); return HttpSecurity.this; @@ -2379,8 +2370,8 @@ public final class HttpSecurity extends * The example below demonstrates how to require HTTPs for every request. Only * requiring HTTPS for some requests is supported, but not recommended since an * application that allows for HTTP introduces many security vulnerabilities. For one - * such example, read about Firesheep. + * such example, read about + * Firesheep. * *
     	 * @Configuration
    @@ -2399,16 +2390,12 @@ public final class HttpSecurity extends
     	 * 	}
     	 * }
     	 * 
    - * - * * @return the {@link ChannelSecurityConfigurer} for further customizations * @throws Exception */ - public ChannelSecurityConfigurer.ChannelRequestMatcherRegistry requiresChannel() - throws Exception { + public ChannelSecurityConfigurer.ChannelRequestMatcherRegistry requiresChannel() throws Exception { ApplicationContext context = getContext(); - return getOrApply(new ChannelSecurityConfigurer<>(context)) - .getRegistry(); + return getOrApply(new ChannelSecurityConfigurer<>(context)).getRegistry(); } /** @@ -2420,8 +2407,8 @@ public final class HttpSecurity extends * The example below demonstrates how to require HTTPs for every request. Only * requiring HTTPS for some requests is supported, but not recommended since an * application that allows for HTTP introduces many security vulnerabilities. For one - * such example, read about Firesheep. + * such example, read about + * Firesheep. * *
     	 * @Configuration
    @@ -2431,29 +2418,28 @@ public final class HttpSecurity extends
     	 * 	@Override
     	 * 	protected void configure(HttpSecurity http) throws Exception {
     	 * 		http
    -	 * 			.authorizeRequests(authorizeRequests ->
    +	 * 			.authorizeRequests((authorizeRequests) ->
     	 * 				authorizeRequests
     	 * 					.antMatchers("/**").hasRole("USER")
     	 * 			)
     	 * 			.formLogin(withDefaults())
    -	 * 			.requiresChannel(requiresChannel ->
    +	 * 			.requiresChannel((requiresChannel) ->
     	 * 				requiresChannel
     	 * 					.anyRequest().requiresSecure()
     	 * 			);
     	 * 	}
     	 * }
     	 * 
    - * * @param requiresChannelCustomizer the {@link Customizer} to provide more options for * the {@link ChannelSecurityConfigurer.ChannelRequestMatcherRegistry} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ - public HttpSecurity requiresChannel(Customizer.ChannelRequestMatcherRegistry> requiresChannelCustomizer) + public HttpSecurity requiresChannel( + Customizer.ChannelRequestMatcherRegistry> requiresChannelCustomizer) throws Exception { ApplicationContext context = getContext(); - requiresChannelCustomizer.customize(getOrApply(new ChannelSecurityConfigurer<>(context)) - .getRegistry()); + requiresChannelCustomizer.customize(getOrApply(new ChannelSecurityConfigurer<>(context)).getRegistry()); return HttpSecurity.this; } @@ -2463,8 +2449,8 @@ public final class HttpSecurity extends *

    Example Configuration

    * * The example below demonstrates how to configure HTTP Basic authentication for an - * application. The default realm is "Realm", but can be - * customized using {@link HttpBasicConfigurer#realmName(String)}. + * application. The default realm is "Realm", but can be customized using + * {@link HttpBasicConfigurer#realmName(String)}. * *
     	 * @Configuration
    @@ -2482,7 +2468,6 @@ public final class HttpSecurity extends
     	 * 	}
     	 * }
     	 * 
    - * * @return the {@link HttpBasicConfigurer} for further customizations * @throws Exception */ @@ -2496,8 +2481,8 @@ public final class HttpSecurity extends *

    Example Configuration

    * * The example below demonstrates how to configure HTTP Basic authentication for an - * application. The default realm is "Realm", but can be - * customized using {@link HttpBasicConfigurer#realmName(String)}. + * application. The default realm is "Realm", but can be customized using + * {@link HttpBasicConfigurer#realmName(String)}. * *
     	 * @Configuration
    @@ -2507,7 +2492,7 @@ public final class HttpSecurity extends
     	 * 	@Override
     	 * 	protected void configure(HttpSecurity http) throws Exception {
     	 * 		http
    -	 * 			.authorizeRequests(authorizeRequests ->
    +	 * 			.authorizeRequests((authorizeRequests) ->
     	 * 				authorizeRequests
     	 * 					.antMatchers("/**").hasRole("USER")
     	 * 			)
    @@ -2515,9 +2500,8 @@ public final class HttpSecurity extends
     	 * 	}
     	 * }
     	 * 
    - * - * @param httpBasicCustomizer the {@link Customizer} to provide more options for - * the {@link HttpBasicConfigurer} + * @param httpBasicCustomizer the {@link Customizer} to provide more options for the + * {@link HttpBasicConfigurer} * @return the {@link HttpSecurity} for further customizations * @throws Exception */ @@ -2526,6 +2510,7 @@ public final class HttpSecurity extends return HttpSecurity.this; } + @Override public void setSharedObject(Class sharedType, C object) { super.setSharedObject(sharedType, object); } @@ -2537,32 +2522,18 @@ public final class HttpSecurity extends @Override protected DefaultSecurityFilterChain performBuild() { - filters.sort(comparator); - return new DefaultSecurityFilterChain(requestMatcher, filters); + this.filters.sort(this.comparator); + return new DefaultSecurityFilterChain(this.requestMatcher, this.filters); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.config.annotation.web.HttpSecurityBuilder#authenticationProvider - * (org.springframework.security.authentication.AuthenticationProvider) - */ - public HttpSecurity authenticationProvider( - AuthenticationProvider authenticationProvider) { + @Override + public HttpSecurity authenticationProvider(AuthenticationProvider authenticationProvider) { getAuthenticationRegistry().authenticationProvider(authenticationProvider); return this; } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.config.annotation.web.HttpSecurityBuilder#userDetailsService - * (org.springframework.security.core.userdetails.UserDetailsService) - */ - public HttpSecurity userDetailsService(UserDetailsService userDetailsService) - throws Exception { + @Override + public HttpSecurity userDetailsService(UserDetailsService userDetailsService) throws Exception { getAuthenticationRegistry().userDetailsService(userDetailsService); return this; } @@ -2571,45 +2542,24 @@ public final class HttpSecurity extends return getSharedObject(AuthenticationManagerBuilder.class); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.config.annotation.web.HttpSecurityBuilder#addFilterAfter(javax - * .servlet.Filter, java.lang.Class) - */ + @Override public HttpSecurity addFilterAfter(Filter filter, Class afterFilter) { - comparator.registerAfter(filter.getClass(), afterFilter); + this.comparator.registerAfter(filter.getClass(), afterFilter); return addFilter(filter); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.config.annotation.web.HttpSecurityBuilder#addFilterBefore( - * javax.servlet.Filter, java.lang.Class) - */ - public HttpSecurity addFilterBefore(Filter filter, - Class beforeFilter) { - comparator.registerBefore(filter.getClass(), beforeFilter); + @Override + public HttpSecurity addFilterBefore(Filter filter, Class beforeFilter) { + this.comparator.registerBefore(filter.getClass(), beforeFilter); return addFilter(filter); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.config.annotation.web.HttpSecurityBuilder#addFilter(javax. - * servlet.Filter) - */ + @Override public HttpSecurity addFilter(Filter filter) { Class filterClass = filter.getClass(); - if (!comparator.isRegistered(filterClass)) { - throw new IllegalArgumentException( - "The Filter class " - + filterClass.getName() - + " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead."); + if (!this.comparator.isRegistered(filterClass)) { + throw new IllegalArgumentException("The Filter class " + filterClass.getName() + + " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead."); } this.filters.add(filter); return this; @@ -2628,7 +2578,6 @@ public final class HttpSecurity extends * deterministic. More concretely, registering multiple Filters in the same location * does not override existing Filters. Instead, do not register Filters you do not * want to use. - * * @param filter the Filter to register * @param atFilter the location of another {@link Filter} that is already registered * (i.e. known) with Spring Security. @@ -2643,14 +2592,15 @@ public final class HttpSecurity extends * Allows specifying which {@link HttpServletRequest} instances this * {@link HttpSecurity} will be invoked on. This method allows for easily invoking the * {@link HttpSecurity} for multiple different {@link RequestMatcher} instances. If - * only a single {@link RequestMatcher} is necessary consider using {@link #mvcMatcher(String)}, - * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, or - * {@link #requestMatcher(RequestMatcher)}. + * only a single {@link RequestMatcher} is necessary consider using + * {@link #mvcMatcher(String)}, {@link #antMatcher(String)}, + * {@link #regexMatcher(String)}, or {@link #requestMatcher(RequestMatcher)}. * *

    - * Invoking {@link #requestMatchers()} will not override previous invocations of {@link #mvcMatcher(String)}}, - * {@link #requestMatchers()}, {@link #antMatcher(String)}, - * {@link #regexMatcher(String)}, and {@link #requestMatcher(RequestMatcher)}. + * Invoking {@link #requestMatchers()} will not override previous invocations of + * {@link #mvcMatcher(String)}}, {@link #requestMatchers()}, + * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, and + * {@link #requestMatcher(RequestMatcher)}. *

    * *

    Example Configurations

    @@ -2743,25 +2693,25 @@ public final class HttpSecurity extends * } * } * - * * @return the {@link RequestMatcherConfigurer} for further customizations */ public RequestMatcherConfigurer requestMatchers() { - return requestMatcherConfigurer; + return this.requestMatcherConfigurer; } /** * Allows specifying which {@link HttpServletRequest} instances this * {@link HttpSecurity} will be invoked on. This method allows for easily invoking the * {@link HttpSecurity} for multiple different {@link RequestMatcher} instances. If - * only a single {@link RequestMatcher} is necessary consider using {@link #mvcMatcher(String)}, - * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, or - * {@link #requestMatcher(RequestMatcher)}. + * only a single {@link RequestMatcher} is necessary consider using + * {@link #mvcMatcher(String)}, {@link #antMatcher(String)}, + * {@link #regexMatcher(String)}, or {@link #requestMatcher(RequestMatcher)}. * *

    - * Invoking {@link #requestMatchers()} will not override previous invocations of {@link #mvcMatcher(String)}}, - * {@link #requestMatchers()}, {@link #antMatcher(String)}, - * {@link #regexMatcher(String)}, and {@link #requestMatcher(RequestMatcher)}. + * Invoking {@link #requestMatchers()} will not override previous invocations of + * {@link #mvcMatcher(String)}}, {@link #requestMatchers()}, + * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, and + * {@link #requestMatcher(RequestMatcher)}. *

    * *

    Example Configurations

    @@ -2777,11 +2727,11 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .requestMatchers(requestMatchers -> + * .requestMatchers((requestMatchers) -> * requestMatchers * .antMatchers("/api/**", "/oauth/**") * ) - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -2800,12 +2750,12 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .requestMatchers(requestMatchers -> + * .requestMatchers((requestMatchers) -> * requestMatchers * .antMatchers("/api/**") * .antMatchers("/oauth/**") * ) - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -2824,15 +2774,15 @@ public final class HttpSecurity extends * @Override * protected void configure(HttpSecurity http) throws Exception { * http - * .requestMatchers(requestMatchers -> + * .requestMatchers((requestMatchers) -> * requestMatchers * .antMatchers("/api/**") * ) - * .requestMatchers(requestMatchers -> + * .requestMatchers((requestMatchers) -> * requestMatchers * .antMatchers("/oauth/**") * ) - * .authorizeRequests(authorizeRequests -> + * .authorizeRequests((authorizeRequests) -> * authorizeRequests * .antMatchers("/**").hasRole("USER") * ) @@ -2840,13 +2790,12 @@ public final class HttpSecurity extends * } * } * - * * @param requestMatcherCustomizer the {@link Customizer} to provide more options for * the {@link RequestMatcherConfigurer} * @return the {@link HttpSecurity} for further customizations */ public HttpSecurity requestMatchers(Customizer requestMatcherCustomizer) { - requestMatcherCustomizer.customize(requestMatcherConfigurer); + requestMatcherCustomizer.customize(this.requestMatcherConfigurer); return HttpSecurity.this; } @@ -2857,10 +2806,10 @@ public final class HttpSecurity extends * *

    * Invoking {@link #requestMatcher(RequestMatcher)} will override previous invocations - * of {@link #requestMatchers()}, {@link #mvcMatcher(String)}, {@link #antMatcher(String)}, - * {@link #regexMatcher(String)}, and {@link #requestMatcher(RequestMatcher)}. + * of {@link #requestMatchers()}, {@link #mvcMatcher(String)}, + * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, and + * {@link #requestMatcher(RequestMatcher)}. *

    - * * @param requestMatcher the {@link RequestMatcher} to use (i.e. new * AntPathRequestMatcher("/admin/**","GET") ) * @return the {@link HttpSecurity} for further customizations @@ -2879,11 +2828,11 @@ public final class HttpSecurity extends * {@link #requestMatchers()} or {@link #requestMatcher(RequestMatcher)}. * *

    - * Invoking {@link #antMatcher(String)} will override previous invocations of {@link #mvcMatcher(String)}}, - * {@link #requestMatchers()}, {@link #antMatcher(String)}, - * {@link #regexMatcher(String)}, and {@link #requestMatcher(RequestMatcher)}. + * Invoking {@link #antMatcher(String)} will override previous invocations of + * {@link #mvcMatcher(String)}}, {@link #requestMatchers()}, + * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, and + * {@link #requestMatcher(RequestMatcher)}. *

    - * * @param antPattern the Ant Pattern to match on (i.e. "/admin/**") * @return the {@link HttpSecurity} for further customizations * @see AntPathRequestMatcher @@ -2894,15 +2843,15 @@ public final class HttpSecurity extends /** * Allows configuring the {@link HttpSecurity} to only be invoked when matching the - * provided Spring MVC pattern. If more advanced configuration is necessary, consider using - * {@link #requestMatchers()} or {@link #requestMatcher(RequestMatcher)}. + * provided Spring MVC pattern. If more advanced configuration is necessary, consider + * using {@link #requestMatchers()} or {@link #requestMatcher(RequestMatcher)}. * *

    - * Invoking {@link #mvcMatcher(String)} will override previous invocations of {@link #mvcMatcher(String)}}, - * {@link #requestMatchers()}, {@link #antMatcher(String)}, - * {@link #regexMatcher(String)}, and {@link #requestMatcher(RequestMatcher)}. + * Invoking {@link #mvcMatcher(String)} will override previous invocations of + * {@link #mvcMatcher(String)}}, {@link #requestMatchers()}, + * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, and + * {@link #requestMatcher(RequestMatcher)}. *

    - * * @param mvcPattern the Spring MVC Pattern to match on (i.e. "/admin/**") * @return the {@link HttpSecurity} for further customizations * @see MvcRequestMatcher @@ -2918,11 +2867,11 @@ public final class HttpSecurity extends * {@link #requestMatchers()} or {@link #requestMatcher(RequestMatcher)}. * *

    - * Invoking {@link #regexMatcher(String)} will override previous invocations of {@link #mvcMatcher(String)}}, - * {@link #requestMatchers()}, {@link #antMatcher(String)}, - * {@link #regexMatcher(String)}, and {@link #requestMatcher(RequestMatcher)}. + * Invoking {@link #regexMatcher(String)} will override previous invocations of + * {@link #mvcMatcher(String)}}, {@link #requestMatchers()}, + * {@link #antMatcher(String)}, {@link #regexMatcher(String)}, and + * {@link #requestMatcher(RequestMatcher)}. *

    - * * @param pattern the Regular Expression to match on (i.e. "/admin/.+") * @return the {@link HttpSecurity} for further customizations * @see RegexRequestMatcher @@ -2931,6 +2880,24 @@ public final class HttpSecurity extends return requestMatcher(new RegexRequestMatcher(pattern, null)); } + /** + * If the {@link SecurityConfigurer} has already been specified get the original, + * otherwise apply the new {@link SecurityConfigurerAdapter}. + * @param configurer the {@link SecurityConfigurer} to apply if one is not found for + * this {@link SecurityConfigurer} class. + * @return the current {@link SecurityConfigurer} for the configurer passed in + * @throws Exception + */ + @SuppressWarnings("unchecked") + private > C getOrApply(C configurer) + throws Exception { + C existingConfig = (C) getConfigurer(configurer.getClass()); + if (existingConfig != null) { + return existingConfig; + } + return apply(configurer); + } + /** * An extension to {@link RequestMatcherConfigurer} that allows optionally configuring * the servlet path. @@ -2945,8 +2912,7 @@ public final class HttpSecurity extends * @param matchers the {@link MvcRequestMatcher} instances to set the servlet path * on if {@link #servletPath(String)} is set. */ - private MvcMatchersRequestMatcherConfigurer(ApplicationContext context, - List matchers) { + private MvcMatchersRequestMatcherConfigurer(ApplicationContext context, List matchers) { super(context); this.matchers = new ArrayList<>(matchers); } @@ -2966,21 +2932,16 @@ public final class HttpSecurity extends * @author Rob Winch * @since 3.2 */ - public class RequestMatcherConfigurer - extends AbstractRequestMatcherRegistry { + public class RequestMatcherConfigurer extends AbstractRequestMatcherRegistry { protected List matchers = new ArrayList<>(); - /** - * @param context - */ - private RequestMatcherConfigurer(ApplicationContext context) { + RequestMatcherConfigurer(ApplicationContext context) { setApplicationContext(context); } @Override - public MvcMatchersRequestMatcherConfigurer mvcMatchers(HttpMethod method, - String... mvcPatterns) { + public MvcMatchersRequestMatcherConfigurer mvcMatchers(HttpMethod method, String... mvcPatterns) { List mvcMatchers = createMvcMatchers(method, mvcPatterns); setMatchers(mvcMatchers); return new MvcMatchersRequestMatcherConfigurer(getContext(), mvcMatchers); @@ -2992,8 +2953,7 @@ public final class HttpSecurity extends } @Override - protected RequestMatcherConfigurer chainRequestMatchers( - List requestMatchers) { + protected RequestMatcherConfigurer chainRequestMatchers(List requestMatchers) { setMatchers(requestMatchers); return this; } @@ -3005,7 +2965,6 @@ public final class HttpSecurity extends /** * Return the {@link HttpSecurity} for further customizations - * * @return the {@link HttpSecurity} for further customizations */ public HttpSecurity and() { @@ -3014,22 +2973,4 @@ public final class HttpSecurity extends } - /** - * If the {@link SecurityConfigurer} has already been specified get the original, - * otherwise apply the new {@link SecurityConfigurerAdapter}. - * - * @param configurer the {@link SecurityConfigurer} to apply if one is not found for - * this {@link SecurityConfigurer} class. - * @return the current {@link SecurityConfigurer} for the configurer passed in - * @throws Exception - */ - @SuppressWarnings("unchecked") - private > C getOrApply( - C configurer) throws Exception { - C existingConfig = (C) getConfigurer(configurer.getClass()); - if (existingConfig != null) { - return existingConfig; - } - return apply(configurer); - } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java index 251556fcaf..78cbb2a4f3 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.builders; import java.util.ArrayList; @@ -72,16 +73,15 @@ import org.springframework.web.filter.DelegatingFilterProxy; * {@link WebSecurityConfigurerAdapter}. *

    * - * @see EnableWebSecurity - * @see WebSecurityConfiguration - * * @author Rob Winch * @author Evgeniy Cheban * @since 3.2 + * @see EnableWebSecurity + * @see WebSecurityConfiguration */ -public final class WebSecurity extends - AbstractConfiguredSecurityBuilder implements - SecurityBuilder, ApplicationContextAware { +public final class WebSecurity extends AbstractConfiguredSecurityBuilder + implements SecurityBuilder, ApplicationContextAware { + private final Log logger = LogFactory.getLog(getClass()); private final List ignoredRequests = new ArrayList<>(); @@ -102,7 +102,7 @@ public final class WebSecurity extends private DefaultWebSecurityExpressionHandler defaultWebSecurityExpressionHandler = new DefaultWebSecurityExpressionHandler(); - private SecurityExpressionHandler expressionHandler = defaultWebSecurityExpressionHandler; + private SecurityExpressionHandler expressionHandler = this.defaultWebSecurityExpressionHandler; private Runnable postBuildAction = () -> { }; @@ -118,12 +118,11 @@ public final class WebSecurity extends /** *

    - * Allows adding {@link RequestMatcher} instances that Spring Security - * should ignore. Web Security provided by Spring Security (including the - * {@link SecurityContext}) will not be available on {@link HttpServletRequest} that - * match. Typically the requests that are registered should be that of only static - * resources. For requests that are dynamic, consider mapping the request to allow all - * users instead. + * Allows adding {@link RequestMatcher} instances that Spring Security should ignore. + * Web Security provided by Spring Security (including the {@link SecurityContext}) + * will not be available on {@link HttpServletRequest} that match. Typically the + * requests that are registered should be that of only static resources. For requests + * that are dynamic, consider mapping the request to allow all users instead. *

    * * Example Usage: @@ -154,18 +153,16 @@ public final class WebSecurity extends * .antMatchers("/static/**"); * // now both URLs that start with /resources/ and /static/ will be ignored * - * * @return the {@link IgnoredRequestConfigurer} to use for registering request that * should be ignored */ public IgnoredRequestConfigurer ignoring() { - return ignoredRequestRegistry; + return this.ignoredRequestRegistry; } /** * Allows customizing the {@link HttpFirewall}. The default is * {@link StrictHttpFirewall}. - * * @param httpFirewall the custom {@link HttpFirewall} * @return the {@link WebSecurity} for further customizations */ @@ -176,10 +173,8 @@ public final class WebSecurity extends /** * Controls debugging support for Spring Security. - * * @param debugEnabled if true, enables debug support with Spring Security. Default is * false. - * * @return the {@link WebSecurity} for further customization. * @see EnableWebSecurity#debug() */ @@ -197,7 +192,6 @@ public final class WebSecurity extends * Typically this method is invoked automatically within the framework from * {@link WebSecurityConfigurerAdapter#init(WebSecurity)} *

    - * * @param securityFilterChainBuilder the builder to use to create the * {@link SecurityFilterChain} instances * @return the {@link WebSecurity} for further customizations @@ -209,15 +203,13 @@ public final class WebSecurity extends } /** - * Set the {@link WebInvocationPrivilegeEvaluator} to be used. If this is not specified, - * then a {@link DefaultWebInvocationPrivilegeEvaluator} will be created when - * {@link #securityInterceptor(FilterSecurityInterceptor)} is non null. - * + * Set the {@link WebInvocationPrivilegeEvaluator} to be used. If this is not + * specified, then a {@link DefaultWebInvocationPrivilegeEvaluator} will be created + * when {@link #securityInterceptor(FilterSecurityInterceptor)} is non null. * @param privilegeEvaluator the {@link WebInvocationPrivilegeEvaluator} to use * @return the {@link WebSecurity} for further customizations */ - public WebSecurity privilegeEvaluator( - WebInvocationPrivilegeEvaluator privilegeEvaluator) { + public WebSecurity privilegeEvaluator(WebInvocationPrivilegeEvaluator privilegeEvaluator) { this.privilegeEvaluator = privilegeEvaluator; return this; } @@ -225,12 +217,10 @@ public final class WebSecurity extends /** * Set the {@link SecurityExpressionHandler} to be used. If this is not specified, * then a {@link DefaultWebSecurityExpressionHandler} will be used. - * * @param expressionHandler the {@link SecurityExpressionHandler} to use * @return the {@link WebSecurity} for further customizations */ - public WebSecurity expressionHandler( - SecurityExpressionHandler expressionHandler) { + public WebSecurity expressionHandler(SecurityExpressionHandler expressionHandler) { Assert.notNull(expressionHandler, "expressionHandler cannot be null"); this.expressionHandler = expressionHandler; return this; @@ -241,7 +231,7 @@ public final class WebSecurity extends * @return the {@link SecurityExpressionHandler} for further customizations */ public SecurityExpressionHandler getExpressionHandler() { - return expressionHandler; + return this.expressionHandler; } /** @@ -249,11 +239,11 @@ public final class WebSecurity extends * @return the {@link WebInvocationPrivilegeEvaluator} for further customizations */ public WebInvocationPrivilegeEvaluator getPrivilegeEvaluator() { - if (privilegeEvaluator != null) { - return privilegeEvaluator; + if (this.privilegeEvaluator != null) { + return this.privilegeEvaluator; } - return filterSecurityInterceptor == null ? null - : new DefaultWebInvocationPrivilegeEvaluator(filterSecurityInterceptor); + return (this.filterSecurityInterceptor != null) + ? new DefaultWebInvocationPrivilegeEvaluator(this.filterSecurityInterceptor) : null; } /** @@ -269,7 +259,6 @@ public final class WebSecurity extends /** * Executes the Runnable immediately after the build takes place - * * @param postBuildAction * @return the {@link WebSecurity} for further customizations */ @@ -280,58 +269,80 @@ public final class WebSecurity extends @Override protected Filter performBuild() throws Exception { - Assert.state( - !securityFilterChainBuilders.isEmpty(), + Assert.state(!this.securityFilterChainBuilders.isEmpty(), () -> "At least one SecurityBuilder needs to be specified. " + "Typically this is done by exposing a SecurityFilterChain bean " + "or by adding a @Configuration that extends WebSecurityConfigurerAdapter. " - + "More advanced users can invoke " - + WebSecurity.class.getSimpleName() + + "More advanced users can invoke " + WebSecurity.class.getSimpleName() + ".addSecurityFilterChainBuilder directly"); - int chainSize = ignoredRequests.size() + securityFilterChainBuilders.size(); - List securityFilterChains = new ArrayList<>( - chainSize); - for (RequestMatcher ignoredRequest : ignoredRequests) { + int chainSize = this.ignoredRequests.size() + this.securityFilterChainBuilders.size(); + List securityFilterChains = new ArrayList<>(chainSize); + for (RequestMatcher ignoredRequest : this.ignoredRequests) { securityFilterChains.add(new DefaultSecurityFilterChain(ignoredRequest)); } - for (SecurityBuilder securityFilterChainBuilder : securityFilterChainBuilders) { + for (SecurityBuilder securityFilterChainBuilder : this.securityFilterChainBuilders) { securityFilterChains.add(securityFilterChainBuilder.build()); } FilterChainProxy filterChainProxy = new FilterChainProxy(securityFilterChains); - if (httpFirewall != null) { - filterChainProxy.setFirewall(httpFirewall); + if (this.httpFirewall != null) { + filterChainProxy.setFirewall(this.httpFirewall); } - if (requestRejectedHandler != null) { - filterChainProxy.setRequestRejectedHandler(requestRejectedHandler); + if (this.requestRejectedHandler != null) { + filterChainProxy.setRequestRejectedHandler(this.requestRejectedHandler); } filterChainProxy.afterPropertiesSet(); Filter result = filterChainProxy; - if (debugEnabled) { - logger.warn("\n\n" - + "********************************************************************\n" + if (this.debugEnabled) { + this.logger.warn("\n\n" + "********************************************************************\n" + "********** Security debugging is enabled. *************\n" + "********** This may include sensitive information. *************\n" + "********** Do not use in a production system! *************\n" + "********************************************************************\n\n"); result = new DebugFilter(filterChainProxy); } - postBuildAction.run(); + this.postBuildAction.run(); return result; } + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.defaultWebSecurityExpressionHandler.setApplicationContext(applicationContext); + try { + this.defaultWebSecurityExpressionHandler.setRoleHierarchy(applicationContext.getBean(RoleHierarchy.class)); + } + catch (NoSuchBeanDefinitionException ex) { + } + try { + this.defaultWebSecurityExpressionHandler + .setPermissionEvaluator(applicationContext.getBean(PermissionEvaluator.class)); + } + catch (NoSuchBeanDefinitionException ex) { + } + this.ignoredRequestRegistry = new IgnoredRequestConfigurer(applicationContext); + try { + this.httpFirewall = applicationContext.getBean(HttpFirewall.class); + } + catch (NoSuchBeanDefinitionException ex) { + } + try { + this.requestRejectedHandler = applicationContext.getBean(RequestRejectedHandler.class); + } + catch (NoSuchBeanDefinitionException ex) { + } + } + /** * An {@link IgnoredRequestConfigurer} that allows optionally configuring the * {@link MvcRequestMatcher#setMethod(HttpMethod)} * * @author Rob Winch */ - public final class MvcMatchersIgnoredRequestConfigurer - extends IgnoredRequestConfigurer { + public final class MvcMatchersIgnoredRequestConfigurer extends IgnoredRequestConfigurer { + private final List mvcMatchers; - private MvcMatchersIgnoredRequestConfigurer(ApplicationContext context, - List mvcMatchers) { + private MvcMatchersIgnoredRequestConfigurer(ApplicationContext context, List mvcMatchers) { super(context); this.mvcMatchers = mvcMatchers; } @@ -342,6 +353,7 @@ public final class WebSecurity extends } return this; } + } /** @@ -351,20 +363,17 @@ public final class WebSecurity extends * @author Rob Winch * @since 3.2 */ - public class IgnoredRequestConfigurer - extends AbstractRequestMatcherRegistry { + public class IgnoredRequestConfigurer extends AbstractRequestMatcherRegistry { - private IgnoredRequestConfigurer(ApplicationContext context) { + IgnoredRequestConfigurer(ApplicationContext context) { setApplicationContext(context); } @Override - public MvcMatchersIgnoredRequestConfigurer mvcMatchers(HttpMethod method, - String... mvcPatterns) { + public MvcMatchersIgnoredRequestConfigurer mvcMatchers(HttpMethod method, String... mvcPatterns) { List mvcMatchers = createMvcMatchers(method, mvcPatterns); WebSecurity.this.ignoredRequests.addAll(mvcMatchers); - return new MvcMatchersIgnoredRequestConfigurer(getApplicationContext(), - mvcMatchers); + return new MvcMatchersIgnoredRequestConfigurer(getApplicationContext(), mvcMatchers); } @Override @@ -373,8 +382,7 @@ public final class WebSecurity extends } @Override - protected IgnoredRequestConfigurer chainRequestMatchers( - List requestMatchers) { + protected IgnoredRequestConfigurer chainRequestMatchers(List requestMatchers) { WebSecurity.this.ignoredRequests.addAll(requestMatchers); return this; } @@ -385,29 +393,7 @@ public final class WebSecurity extends public WebSecurity and() { return WebSecurity.this; } + } - @Override - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - this.defaultWebSecurityExpressionHandler - .setApplicationContext(applicationContext); - - try { - this.defaultWebSecurityExpressionHandler.setRoleHierarchy(applicationContext.getBean(RoleHierarchy.class)); - } catch (NoSuchBeanDefinitionException e) {} - - try { - this.defaultWebSecurityExpressionHandler.setPermissionEvaluator(applicationContext.getBean( - PermissionEvaluator.class)); - } catch(NoSuchBeanDefinitionException e) {} - - this.ignoredRequestRegistry = new IgnoredRequestConfigurer(applicationContext); - try { - this.httpFirewall = applicationContext.getBean(HttpFirewall.class); - } catch(NoSuchBeanDefinitionException e) {} - try { - this.requestRejectedHandler = applicationContext.getBean(RequestRejectedHandler.class); - } catch(NoSuchBeanDefinitionException e) {} - } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/AutowiredWebSecurityConfigurersIgnoreParents.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/AutowiredWebSecurityConfigurersIgnoreParents.java index 9195060ef1..dec674d305 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/AutowiredWebSecurityConfigurersIgnoreParents.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/AutowiredWebSecurityConfigurersIgnoreParents.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import java.util.ArrayList; @@ -34,14 +35,12 @@ import org.springframework.util.Assert; * {@link ApplicationContext} but ignoring the parent. * * @author Rob Winch - * */ -final class AutowiredWebSecurityConfigurersIgnoreParents { +public final class AutowiredWebSecurityConfigurersIgnoreParents { private final ConfigurableListableBeanFactory beanFactory; - AutowiredWebSecurityConfigurersIgnoreParents( - ConfigurableListableBeanFactory beanFactory) { + AutowiredWebSecurityConfigurersIgnoreParents(ConfigurableListableBeanFactory beanFactory) { Assert.notNull(beanFactory, "beanFactory cannot be null"); this.beanFactory = beanFactory; } @@ -49,11 +48,11 @@ final class AutowiredWebSecurityConfigurersIgnoreParents { @SuppressWarnings({ "rawtypes", "unchecked" }) public List> getWebSecurityConfigurers() { List> webSecurityConfigurers = new ArrayList<>(); - Map beansOfType = beanFactory - .getBeansOfType(WebSecurityConfigurer.class); + Map beansOfType = this.beanFactory.getBeansOfType(WebSecurityConfigurer.class); for (Entry entry : beansOfType.entrySet()) { webSecurityConfigurers.add(entry.getValue()); } return webSecurityConfigurers; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java index 9b6ccdbfa5..aa95ed89a5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurity.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import org.springframework.context.annotation.Configuration; @@ -69,13 +72,11 @@ import org.springframework.security.config.annotation.web.WebSecurityConfigurer; * @author Rob Winch * @since 3.2 */ -@Retention(value = java.lang.annotation.RetentionPolicy.RUNTIME) -@Target(value = { java.lang.annotation.ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) @Documented -@Import({ WebSecurityConfiguration.class, - SpringWebMvcImportSelector.class, - OAuth2ImportSelector.class, - HttpSecurityConfiguration.class}) +@Import({ WebSecurityConfiguration.class, SpringWebMvcImportSelector.class, OAuth2ImportSelector.class, + HttpSecurityConfiguration.class }) @EnableGlobalAuthentication @Configuration public @interface EnableWebSecurity { @@ -85,4 +86,5 @@ public @interface EnableWebSecurity { * @return if true, enables debug support with Spring Security */ boolean debug() default false; + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java index 0b4ea90061..69acc39631 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java @@ -16,6 +16,9 @@ package org.springframework.security.config.annotation.web.configuration; +import java.util.HashMap; +import java.util.Map; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -29,9 +32,6 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.DefaultLoginPageConfigurer; import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter; -import java.util.HashMap; -import java.util.Map; - import static org.springframework.security.config.Customizer.withDefaults; /** @@ -42,7 +42,9 @@ import static org.springframework.security.config.Customizer.withDefaults; */ @Configuration(proxyBeanMethods = false) class HttpSecurityConfiguration { + private static final String BEAN_NAME_PREFIX = "org.springframework.security.config.annotation.web.configuration.HttpSecurityConfiguration."; + private static final String HTTPSECURITY_BEAN_NAME = BEAN_NAME_PREFIX + "httpSecurity"; private ObjectPostProcessor objectPostProcessor; @@ -54,7 +56,7 @@ class HttpSecurityConfiguration { private ApplicationContext context; @Autowired - public void setObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { + void setObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { this.objectPostProcessor = objectPostProcessor; } @@ -64,54 +66,50 @@ class HttpSecurityConfiguration { } @Autowired - public void setAuthenticationConfiguration( - AuthenticationConfiguration authenticationConfiguration) { + void setAuthenticationConfiguration(AuthenticationConfiguration authenticationConfiguration) { this.authenticationConfiguration = authenticationConfiguration; } @Autowired - public void setApplicationContext(ApplicationContext context) { + void setApplicationContext(ApplicationContext context) { this.context = context; } @Bean(HTTPSECURITY_BEAN_NAME) @Scope("prototype") - public HttpSecurity httpSecurity() throws Exception { - WebSecurityConfigurerAdapter.LazyPasswordEncoder passwordEncoder = - new WebSecurityConfigurerAdapter.LazyPasswordEncoder(this.context); - - AuthenticationManagerBuilder authenticationBuilder = - new WebSecurityConfigurerAdapter.DefaultPasswordEncoderAuthenticationManagerBuilder(this.objectPostProcessor, passwordEncoder); + HttpSecurity httpSecurity() throws Exception { + WebSecurityConfigurerAdapter.LazyPasswordEncoder passwordEncoder = new WebSecurityConfigurerAdapter.LazyPasswordEncoder( + this.context); + AuthenticationManagerBuilder authenticationBuilder = new WebSecurityConfigurerAdapter.DefaultPasswordEncoderAuthenticationManagerBuilder( + this.objectPostProcessor, passwordEncoder); authenticationBuilder.parentAuthenticationManager(authenticationManager()); - - HttpSecurity http = new HttpSecurity(objectPostProcessor, authenticationBuilder, createSharedObjects()); + HttpSecurity http = new HttpSecurity(this.objectPostProcessor, authenticationBuilder, createSharedObjects()); + // @formatter:off http - .csrf(withDefaults()) - .addFilter(new WebAsyncManagerIntegrationFilter()) - .exceptionHandling(withDefaults()) - .headers(withDefaults()) - .sessionManagement(withDefaults()) - .securityContext(withDefaults()) - .requestCache(withDefaults()) - .anonymous(withDefaults()) - .servletApi(withDefaults()) - .logout(withDefaults()) - .apply(new DefaultLoginPageConfigurer<>()); - + .csrf(withDefaults()) + .addFilter(new WebAsyncManagerIntegrationFilter()) + .exceptionHandling(withDefaults()) + .headers(withDefaults()) + .sessionManagement(withDefaults()) + .securityContext(withDefaults()) + .requestCache(withDefaults()) + .anonymous(withDefaults()) + .servletApi(withDefaults()) + .logout(withDefaults()) + .apply(new DefaultLoginPageConfigurer<>()); + // @formatter:on return http; } private AuthenticationManager authenticationManager() throws Exception { - if (this.authenticationManager != null) { - return this.authenticationManager; - } else { - return this.authenticationConfiguration.getAuthenticationManager(); - } + return (this.authenticationManager != null) ? this.authenticationManager + : this.authenticationConfiguration.getAuthenticationManager(); } private Map, Object> createSharedObjects() { Map, Object> sharedObjects = new HashMap<>(); - sharedObjects.put(ApplicationContext.class, context); + sharedObjects.put(ApplicationContext.class, this.context); return sharedObjects; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 892fb394f0..b48c565a51 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; @@ -33,8 +36,6 @@ import org.springframework.util.ClassUtils; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; -import java.util.List; - /** * {@link Configuration} for OAuth 2.0 Client support. * @@ -53,20 +54,25 @@ final class OAuth2ClientConfiguration { @Override public String[] selectImports(AnnotationMetadata importingClassMetadata) { - boolean webmvcPresent = ClassUtils.isPresent( - "org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader()); - - return webmvcPresent ? - new String[] { "org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" } : - new String[] {}; + if (!ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", + getClass().getClassLoader())) { + return new String[0]; + } + return new String[] { "org.springframework.security.config.annotation.web.configuration." + + "OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" }; } + } @Configuration(proxyBeanMethods = false) static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private OAuth2AuthorizedClientManager authorizedClientManager; @Override @@ -92,7 +98,8 @@ final class OAuth2ClientConfiguration { } @Autowired(required = false) - void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + void setAccessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { this.accessTokenResponseClient = accessTokenResponseClient; } @@ -107,29 +114,31 @@ final class OAuth2ClientConfiguration { if (this.authorizedClientManager != null) { return this.authorizedClientManager; } - OAuth2AuthorizedClientManager authorizedClientManager = null; if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) { if (this.accessTokenResponseClient != null) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials(configurer -> - configurer.accessTokenResponseClient(this.accessTokenResponseClient)) - .password() - .build(); - DefaultOAuth2AuthorizedClientManager defaultAuthorizedClientManager = - new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); + // @formatter:off + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder + .builder() + .authorizationCode() + .refreshToken() + .clientCredentials((configurer) -> configurer.accessTokenResponseClient(this.accessTokenResponseClient)) + .password() + .build(); + // @formatter:on + DefaultOAuth2AuthorizedClientManager defaultAuthorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); defaultAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); authorizedClientManager = defaultAuthorizedClientManager; - } else { + } + else { authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); } } return authorizedClientManager; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java index a999d5c7f8..79eeb478d5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import java.util.LinkedHashSet; @@ -21,15 +22,17 @@ import java.util.Set; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; /** * Used by {@link EnableWebSecurity} to conditionally import: * *
      - *
    • {@link OAuth2ClientConfiguration} when the {@code spring-security-oauth2-client} module is present on the classpath
    • - *
    • {@link SecurityReactorContextConfiguration} when either the {@code spring-security-oauth2-client} or - * {@code spring-security-oauth2-resource-server} module as well as the {@code spring-webflux} module - * are present on the classpath
    • + *
    • {@link OAuth2ClientConfiguration} when the {@code spring-security-oauth2-client} + * module is present on the classpath
    • + *
    • {@link SecurityReactorContextConfiguration} when either the + * {@code spring-security-oauth2-client} or {@code spring-security-oauth2-resource-server} + * module as well as the {@code spring-webflux} module are present on the classpath
    • *
    * * @author Joe Grandja @@ -43,25 +46,25 @@ final class OAuth2ImportSelector implements ImportSelector { @Override public String[] selectImports(AnnotationMetadata importingClassMetadata) { Set imports = new LinkedHashSet<>(); - - boolean oauth2ClientPresent = ClassUtils.isPresent( - "org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader()); + ClassLoader classLoader = getClass().getClassLoader(); + boolean oauth2ClientPresent = ClassUtils + .isPresent("org.springframework.security.oauth2.client.registration.ClientRegistration", classLoader); + boolean webfluxPresent = ClassUtils + .isPresent("org.springframework.web.reactive.function.client.ExchangeFilterFunction", classLoader); + boolean oauth2ResourceServerPresent = ClassUtils + .isPresent("org.springframework.security.oauth2.server.resource.BearerTokenError", classLoader); if (oauth2ClientPresent) { imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration"); } - - boolean webfluxPresent = ClassUtils.isPresent( - "org.springframework.web.reactive.function.client.ExchangeFilterFunction", getClass().getClassLoader()); if (webfluxPresent && oauth2ClientPresent) { - imports.add("org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration"); + imports.add( + "org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration"); } - - boolean oauth2ResourceServerPresent = ClassUtils.isPresent( - "org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader()); if (webfluxPresent && oauth2ResourceServerPresent) { - imports.add("org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration"); + imports.add( + "org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration"); } - - return imports.toArray(new String[0]); + return StringUtils.toStringArray(imports); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java index 8d76982c80..2783cb358b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java @@ -13,10 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.annotation.Bean; @@ -26,28 +40,15 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Hooks; -import reactor.core.publisher.Operators; -import reactor.util.context.Context; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES; /** - * {@link Configuration} that (potentially) adds a "decorating" {@code Publisher} - * for the last operator created in every {@code Mono} or {@code Flux}. + * {@link Configuration} that (potentially) adds a "decorating" {@code Publisher} for the + * last operator created in every {@code Mono} or {@code Flux}. * *

    - * The {@code Publisher} is solely responsible for adding - * the current {@code HttpServletRequest}, {@code HttpServletResponse} and {@code Authentication} - * to the Reactor {@code Context} so that it's accessible in every flow, if required. + * The {@code Publisher} is solely responsible for adding the current + * {@code HttpServletRequest}, {@code HttpServletResponse} and {@code Authentication} to + * the Reactor {@code Context} so that it's accessible in every flow, if required. * * @author Joe Grandja * @author Roman Matiushchenko @@ -63,14 +64,14 @@ class SecurityReactorContextConfiguration { } static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean { + private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR"; @Override public void afterPropertiesSet() throws Exception { - Function, ? extends Publisher> lifter = - Operators.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub)); - - Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, pub -> { + Function, ? extends Publisher> lifter = Operators + .liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub)); + Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, (pub) -> { if (!contextAttributesAvailable()) { // No need to decorate so return original Publisher return pub; @@ -85,7 +86,7 @@ class SecurityReactorContextConfiguration { } CoreSubscriber createSubscriberIfNecessary(CoreSubscriber delegate) { - if (delegate.currentContext().hasKey(SECURITY_CONTEXT_ATTRIBUTES)) { + if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) { // Already enriched. No need to create Subscriber so return original return delegate; } @@ -93,8 +94,8 @@ class SecurityReactorContextConfiguration { } private static boolean contextAttributesAvailable() { - return SecurityContextHolder.getContext().getAuthentication() != null || - RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes; + return SecurityContextHolder.getContext().getAuthentication() != null + || RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes; } private static Map getContextAttributes() { @@ -104,13 +105,12 @@ class SecurityReactorContextConfiguration { if (requestAttributes instanceof ServletRequestAttributes) { ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes; servletRequest = servletRequestAttributes.getRequest(); - servletResponse = servletRequestAttributes.getResponse(); // possible null + servletResponse = servletRequestAttributes.getResponse(); // possible null } Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); if (authentication == null && servletRequest == null) { return Collections.emptyMap(); } - Map contextAttributes = new HashMap<>(); if (servletRequest != null) { contextAttributes.put(HttpServletRequest.class, servletRequest); @@ -124,25 +124,30 @@ class SecurityReactorContextConfiguration { return contextAttributes; } + } static class SecurityReactorContextSubscriber implements CoreSubscriber { + static final String SECURITY_CONTEXT_ATTRIBUTES = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES"; + private final CoreSubscriber delegate; + private final Context context; SecurityReactorContextSubscriber(CoreSubscriber delegate, Map attributes) { this.delegate = delegate; - Context currentContext = this.delegate.currentContext(); - Context context; - if (currentContext.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) { - context = currentContext; - } else { - context = currentContext.put(SECURITY_CONTEXT_ATTRIBUTES, attributes); - } + Context context = getOrPutContext(attributes, this.delegate.currentContext()); this.context = context; } + private Context getOrPutContext(Map attributes, Context currentContext) { + if (currentContext.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) { + return currentContext; + } + return currentContext.put(SECURITY_CONTEXT_ATTRIBUTES, attributes); + } + @Override public Context currentContext() { return this.context; @@ -159,13 +164,15 @@ class SecurityReactorContextConfiguration { } @Override - public void onError(Throwable t) { - this.delegate.onError(t); + public void onError(Throwable ex) { + this.delegate.onError(ex); } @Override public void onComplete() { this.delegate.onComplete(); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SpringWebMvcImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SpringWebMvcImportSelector.java index 1516a3e3f0..fee0fd6162 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SpringWebMvcImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/SpringWebMvcImportSelector.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import org.springframework.context.annotation.ImportSelector; @@ -29,19 +30,13 @@ import org.springframework.util.ClassUtils; */ class SpringWebMvcImportSelector implements ImportSelector { - /* - * (non-Javadoc) - * - * @see org.springframework.context.annotation.ImportSelector#selectImports(org. - * springframework .core.type.AnnotationMetadata) - */ + @Override public String[] selectImports(AnnotationMetadata importingClassMetadata) { - boolean webmvcPresent = ClassUtils.isPresent( - "org.springframework.web.servlet.DispatcherServlet", - getClass().getClassLoader()); - return webmvcPresent - ? new String[] { - "org.springframework.security.config.annotation.web.configuration.WebMvcSecurityConfiguration" } - : new String[] {}; + if (!ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader())) { + return new String[0]; + } + return new String[] { + "org.springframework.security.config.annotation.web.configuration.WebMvcSecurityConfiguration" }; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java index 312a1b170e..db63fcf82b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java @@ -13,30 +13,33 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import java.util.List; + import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.context.expression.BeanFactoryResolver; import org.springframework.expression.BeanResolver; -import org.springframework.security.web.method.annotation.CurrentSecurityContextArgumentResolver; import org.springframework.security.web.method.annotation.AuthenticationPrincipalArgumentResolver; import org.springframework.security.web.method.annotation.CsrfTokenArgumentResolver; +import org.springframework.security.web.method.annotation.CurrentSecurityContextArgumentResolver; import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import org.springframework.web.servlet.support.RequestDataValueProcessor; -import java.util.List; - /** * Used to add a {@link RequestDataValueProcessor} for Spring MVC and Spring Security CSRF * integration. This configuration is added whenever {@link EnableWebMvc} is added by - * SpringWebMvcImportSelector and the DispatcherServlet is present on the - * classpath. It also adds the {@link AuthenticationPrincipalArgumentResolver} as a + * SpringWebMvcImportSelector + * and the DispatcherServlet is present on the classpath. It also adds the + * {@link AuthenticationPrincipalArgumentResolver} as a * {@link HandlerMethodArgumentResolver}. * * @author Rob Winch @@ -44,25 +47,25 @@ import java.util.List; * @since 3.2 */ class WebMvcSecurityConfiguration implements WebMvcConfigurer, ApplicationContextAware { + private BeanResolver beanResolver; @Override @SuppressWarnings("deprecation") public void addArgumentResolvers(List argumentResolvers) { AuthenticationPrincipalArgumentResolver authenticationPrincipalResolver = new AuthenticationPrincipalArgumentResolver(); - authenticationPrincipalResolver.setBeanResolver(beanResolver); + authenticationPrincipalResolver.setBeanResolver(this.beanResolver); argumentResolvers.add(authenticationPrincipalResolver); argumentResolvers .add(new org.springframework.security.web.bind.support.AuthenticationPrincipalArgumentResolver()); - CurrentSecurityContextArgumentResolver currentSecurityContextArgumentResolver = new CurrentSecurityContextArgumentResolver(); - currentSecurityContextArgumentResolver.setBeanResolver(beanResolver); + currentSecurityContextArgumentResolver.setBeanResolver(this.beanResolver); argumentResolvers.add(currentSecurityContextArgumentResolver); argumentResolvers.add(new CsrfTokenArgumentResolver()); } @Bean - public RequestDataValueProcessor requestDataValueProcessor() { + RequestDataValueProcessor requestDataValueProcessor() { return new CsrfRequestDataValueProcessor(); } @@ -70,4 +73,5 @@ class WebMvcSecurityConfiguration implements WebMvcConfigurer, ApplicationContex public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.beanResolver = new BeanFactoryResolver(applicationContext.getAutowireCapableBeanFactory()); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfiguration.java index 31e5e9f6ea..27fc4ba5e1 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfiguration.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import java.util.Collections; import java.util.List; import java.util.Map; + import javax.servlet.Filter; import org.springframework.beans.factory.BeanClassLoaderAware; @@ -48,7 +50,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; - +import org.springframework.util.Assert; /** * Uses a {@link WebSecurity} to create the {@link FilterChainProxy} that performs the web @@ -60,13 +62,13 @@ import org.springframework.security.web.context.AbstractSecurityWebApplicationIn * * @see EnableWebSecurity * @see WebSecurity - * * @author Rob Winch * @author Keesun Baik * @since 3.2 */ @Configuration(proxyBeanMethods = false) public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAware { + private WebSecurity webSecurity; private Boolean debugEnabled; @@ -88,7 +90,7 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa @Bean @DependsOn(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME) public SecurityExpressionHandler webSecurityExpressionHandler() { - return webSecurity.getExpressionHandler(); + return this.webSecurity.getExpressionHandler(); } /** @@ -98,30 +100,26 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa */ @Bean(name = AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME) public Filter springSecurityFilterChain() throws Exception { - boolean hasConfigurers = webSecurityConfigurers != null - && !webSecurityConfigurers.isEmpty(); - boolean hasFilterChain = !securityFilterChains.isEmpty(); - if (hasConfigurers && hasFilterChain) { - throw new IllegalStateException( - "Found WebSecurityConfigurerAdapter as well as SecurityFilterChain." + - "Please select just one."); - } + boolean hasConfigurers = this.webSecurityConfigurers != null && !this.webSecurityConfigurers.isEmpty(); + boolean hasFilterChain = !this.securityFilterChains.isEmpty(); + Assert.state(!(hasConfigurers && hasFilterChain), + "Found WebSecurityConfigurerAdapter as well as SecurityFilterChain. Please select just one."); if (!hasConfigurers && !hasFilterChain) { - WebSecurityConfigurerAdapter adapter = objectObjectPostProcessor + WebSecurityConfigurerAdapter adapter = this.objectObjectPostProcessor .postProcess(new WebSecurityConfigurerAdapter() { }); - webSecurity.apply(adapter); + this.webSecurity.apply(adapter); } - for (SecurityFilterChain securityFilterChain : securityFilterChains) { - webSecurity.addSecurityFilterChainBuilder(() -> securityFilterChain); + for (SecurityFilterChain securityFilterChain : this.securityFilterChains) { + this.webSecurity.addSecurityFilterChainBuilder(() -> securityFilterChain); for (Filter filter : securityFilterChain.getFilters()) { if (filter instanceof FilterSecurityInterceptor) { - webSecurity.securityInterceptor((FilterSecurityInterceptor) filter); + this.webSecurity.securityInterceptor((FilterSecurityInterceptor) filter); break; } } } - return webSecurity.build(); + return this.webSecurity.build(); } /** @@ -132,13 +130,12 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa @Bean @DependsOn(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME) public WebInvocationPrivilegeEvaluator privilegeEvaluator() { - return webSecurity.getPrivilegeEvaluator(); + return this.webSecurity.getPrivilegeEvaluator(); } /** * Sets the {@code } * instances used to create the web configuration. - * * @param objectPostProcessor the {@link ObjectPostProcessor} used to create a * {@link WebSecurity} instance * @param webSecurityConfigurers the @@ -147,33 +144,27 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa * @throws Exception */ @Autowired(required = false) - public void setFilterChainProxySecurityConfigurer( - ObjectPostProcessor objectPostProcessor, + public void setFilterChainProxySecurityConfigurer(ObjectPostProcessor objectPostProcessor, @Value("#{@autowiredWebSecurityConfigurersIgnoreParents.getWebSecurityConfigurers()}") List> webSecurityConfigurers) throws Exception { - webSecurity = objectPostProcessor - .postProcess(new WebSecurity(objectPostProcessor)); - if (debugEnabled != null) { - webSecurity.debug(debugEnabled); + this.webSecurity = objectPostProcessor.postProcess(new WebSecurity(objectPostProcessor)); + if (this.debugEnabled != null) { + this.webSecurity.debug(this.debugEnabled); } - webSecurityConfigurers.sort(AnnotationAwareOrderComparator.INSTANCE); - Integer previousOrder = null; Object previousConfig = null; for (SecurityConfigurer config : webSecurityConfigurers) { Integer order = AnnotationAwareOrderComparator.lookupOrder(config); if (previousOrder != null && previousOrder.equals(order)) { - throw new IllegalStateException( - "@Order on WebSecurityConfigurers must be unique. Order of " - + order + " was already used on " + previousConfig + ", so it cannot be used on " - + config + " too."); + throw new IllegalStateException("@Order on WebSecurityConfigurers must be unique. Order of " + order + + " was already used on " + previousConfig + ", so it cannot be used on " + config + " too."); } previousOrder = order; previousConfig = config; } for (SecurityConfigurer webSecurityConfigurer : webSecurityConfigurers) { - webSecurity.apply(webSecurityConfigurer); + this.webSecurity.apply(webSecurityConfigurer); } this.webSecurityConfigurers = webSecurityConfigurers; } @@ -195,6 +186,22 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa return new AutowiredWebSecurityConfigurersIgnoreParents(beanFactory); } + @Override + public void setImportMetadata(AnnotationMetadata importMetadata) { + Map enableWebSecurityAttrMap = importMetadata + .getAnnotationAttributes(EnableWebSecurity.class.getName()); + AnnotationAttributes enableWebSecurityAttrs = AnnotationAttributes.fromMap(enableWebSecurityAttrMap); + this.debugEnabled = enableWebSecurityAttrs.getBoolean("debug"); + if (this.webSecurity != null) { + this.webSecurity.debug(this.debugEnabled); + } + } + + @Override + public void setBeanClassLoader(ClassLoader classLoader) { + this.beanClassLoader = classLoader; + } + /** * A custom verision of the Spring provided AnnotationAwareOrderComparator that uses * {@link AnnotationUtils#findAnnotation(Class, Class)} to look on super class @@ -204,6 +211,7 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa * @since 3.2 */ private static class AnnotationAwareOrderComparator extends OrderComparator { + private static final AnnotationAwareOrderComparator INSTANCE = new AnnotationAwareOrderComparator(); @Override @@ -216,7 +224,7 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa return ((Ordered) obj).getOrder(); } if (obj != null) { - Class clazz = (obj instanceof Class ? (Class) obj : obj.getClass()); + Class clazz = ((obj instanceof Class) ? (Class) obj : obj.getClass()); Order order = AnnotationUtils.findAnnotation(clazz, Order.class); if (order != null) { return order.value(); @@ -224,33 +232,7 @@ public class WebSecurityConfiguration implements ImportAware, BeanClassLoaderAwa } return Ordered.LOWEST_PRECEDENCE; } + } - /* - * (non-Javadoc) - * - * @see org.springframework.context.annotation.ImportAware#setImportMetadata(org. - * springframework.core.type.AnnotationMetadata) - */ - public void setImportMetadata(AnnotationMetadata importMetadata) { - Map enableWebSecurityAttrMap = importMetadata - .getAnnotationAttributes(EnableWebSecurity.class.getName()); - AnnotationAttributes enableWebSecurityAttrs = AnnotationAttributes - .fromMap(enableWebSecurityAttrMap); - debugEnabled = enableWebSecurityAttrs.getBoolean("debug"); - if (webSecurity != null) { - webSecurity.debug(debugEnabled); - } - } - - /* - * (non-Javadoc) - * - * @see - * org.springframework.beans.factory.BeanClassLoaderAware#setBeanClassLoader(java. - * lang.ClassLoader) - */ - public void setBeanClassLoader(ClassLoader classLoader) { - this.beanClassLoader = classLoader; - } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurerAdapter.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurerAdapter.java index 4667967f99..61609ede2a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurerAdapter.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurerAdapter.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import java.lang.reflect.Field; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -69,31 +69,30 @@ import org.springframework.web.accept.ContentNegotiationStrategy; import org.springframework.web.accept.HeaderContentNegotiationStrategy; /** - * Provides a convenient base class for creating a {@link WebSecurityConfigurer} - * instance. The implementation allows customization by overriding methods. + * Provides a convenient base class for creating a {@link WebSecurityConfigurer} instance. + * The implementation allows customization by overriding methods. * *

    - * Will automatically apply the result of looking up - * {@link AbstractHttpConfigurer} from {@link SpringFactoriesLoader} to allow - * developers to extend the defaults. - * To do this, you must create a class that extends AbstractHttpConfigurer and then create a file in the classpath at "META-INF/spring.factories" that looks something like: + * Will automatically apply the result of looking up {@link AbstractHttpConfigurer} from + * {@link SpringFactoriesLoader} to allow developers to extend the defaults. To do this, + * you must create a class that extends AbstractHttpConfigurer and then create a file in + * the classpath at "META-INF/spring.factories" that looks something like: *

    *
      * org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer = sample.MyClassThatExtendsAbstractHttpConfigurer
    - * 
    - * If you have multiple classes that should be added you can use "," to separate the values. For example: + * If you have multiple classes that should be added you can use "," to separate + * the values. For example: * *
      * org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer = sample.MyClassThatExtendsAbstractHttpConfigurer, sample.OtherThatExtendsAbstractHttpConfigurer
      * 
    * * @see EnableWebSecurity - * * @author Rob Winch */ @Order(100) -public abstract class WebSecurityConfigurerAdapter implements - WebSecurityConfigurer { +public abstract class WebSecurityConfigurerAdapter implements WebSecurityConfigurer { + private final Log logger = LogFactory.getLog(WebSecurityConfigurerAdapter.class); private ApplicationContext context; @@ -101,21 +100,29 @@ public abstract class WebSecurityConfigurerAdapter implements private ContentNegotiationStrategy contentNegotiationStrategy = new HeaderContentNegotiationStrategy(); private ObjectPostProcessor objectPostProcessor = new ObjectPostProcessor() { + @Override public T postProcess(T object) { - throw new IllegalStateException( - ObjectPostProcessor.class.getName() - + " is a required bean. Ensure you have used @EnableWebSecurity and @Configuration"); + throw new IllegalStateException(ObjectPostProcessor.class.getName() + + " is a required bean. Ensure you have used @EnableWebSecurity and @Configuration"); } }; private AuthenticationConfiguration authenticationConfiguration; + private AuthenticationManagerBuilder authenticationBuilder; + private AuthenticationManagerBuilder localConfigureAuthenticationBldr; + private boolean disableLocalConfigureAuthenticationBldr; + private boolean authenticationManagerInitialized; + private AuthenticationManager authenticationManager; + private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); + private HttpSecurity http; + private boolean disableDefaults; /** @@ -129,7 +136,6 @@ public abstract class WebSecurityConfigurerAdapter implements * Creates an instance which allows specifying if the default configuration should be * enabled. Disabling the default configuration should be considered more advanced * usage as it requires more understanding of how the framework is implemented. - * * @param disableDefaults true if the default configuration should be disabled, else * false */ @@ -176,7 +182,6 @@ public abstract class WebSecurityConfigurerAdapter implements * } * * - * * @param auth the {@link AuthenticationManagerBuilder} to use * @throws Exception */ @@ -186,50 +191,45 @@ public abstract class WebSecurityConfigurerAdapter implements /** * Creates the {@link HttpSecurity} or returns the current instance - * * @return the {@link HttpSecurity} * @throws Exception */ @SuppressWarnings({ "rawtypes", "unchecked" }) protected final HttpSecurity getHttp() throws Exception { - if (http != null) { - return http; + if (this.http != null) { + return this.http; } - AuthenticationEventPublisher eventPublisher = getAuthenticationEventPublisher(); - localConfigureAuthenticationBldr.authenticationEventPublisher(eventPublisher); - + this.localConfigureAuthenticationBldr.authenticationEventPublisher(eventPublisher); AuthenticationManager authenticationManager = authenticationManager(); - authenticationBuilder.parentAuthenticationManager(authenticationManager); + this.authenticationBuilder.parentAuthenticationManager(authenticationManager); Map, Object> sharedObjects = createSharedObjects(); - - http = new HttpSecurity(objectPostProcessor, authenticationBuilder, - sharedObjects); - if (!disableDefaults) { - // @formatter:off - http - .csrf().and() - .addFilter(new WebAsyncManagerIntegrationFilter()) - .exceptionHandling().and() - .headers().and() - .sessionManagement().and() - .securityContext().and() - .requestCache().and() - .anonymous().and() - .servletApi().and() - .apply(new DefaultLoginPageConfigurer<>()).and() - .logout(); - // @formatter:on + this.http = new HttpSecurity(this.objectPostProcessor, this.authenticationBuilder, sharedObjects); + if (!this.disableDefaults) { + applyDefaultConfiguration(this.http); ClassLoader classLoader = this.context.getClassLoader(); - List defaultHttpConfigurers = - SpringFactoriesLoader.loadFactories(AbstractHttpConfigurer.class, classLoader); - + List defaultHttpConfigurers = SpringFactoriesLoader + .loadFactories(AbstractHttpConfigurer.class, classLoader); for (AbstractHttpConfigurer configurer : defaultHttpConfigurers) { - http.apply(configurer); + this.http.apply(configurer); } } - configure(http); - return http; + configure(this.http); + return this.http; + } + + private void applyDefaultConfiguration(HttpSecurity http) throws Exception { + http.csrf(); + http.addFilter(new WebAsyncManagerIntegrationFilter()); + http.exceptionHandling(); + http.headers(); + http.sessionManagement(); + http.securityContext(); + http.requestCache(); + http.anonymous(); + http.servletApi(); + http.apply(new DefaultLoginPageConfigurer<>()); + http.logout(); } /** @@ -244,12 +244,11 @@ public abstract class WebSecurityConfigurerAdapter implements * return super.authenticationManagerBean(); * } * - * * @return the {@link AuthenticationManager} * @throws Exception */ public AuthenticationManager authenticationManagerBean() throws Exception { - return new AuthenticationManagerDelegator(authenticationBuilder, context); + return new AuthenticationManagerDelegator(this.authenticationBuilder, this.context); } /** @@ -257,23 +256,21 @@ public abstract class WebSecurityConfigurerAdapter implements * {@link #configure(AuthenticationManagerBuilder)} method is overridden to use the * {@link AuthenticationManagerBuilder} that was passed in. Otherwise, autowire the * {@link AuthenticationManager} by type. - * * @return the {@link AuthenticationManager} to use * @throws Exception */ protected AuthenticationManager authenticationManager() throws Exception { - if (!authenticationManagerInitialized) { - configure(localConfigureAuthenticationBldr); - if (disableLocalConfigureAuthenticationBldr) { - authenticationManager = authenticationConfiguration - .getAuthenticationManager(); + if (!this.authenticationManagerInitialized) { + configure(this.localConfigureAuthenticationBldr); + if (this.disableLocalConfigureAuthenticationBldr) { + this.authenticationManager = this.authenticationConfiguration.getAuthenticationManager(); } else { - authenticationManager = localConfigureAuthenticationBldr.build(); + this.authenticationManager = this.localConfigureAuthenticationBldr.build(); } - authenticationManagerInitialized = true; + this.authenticationManagerInitialized = true; } - return authenticationManager; + return this.authenticationManager; } /** @@ -297,10 +294,8 @@ public abstract class WebSecurityConfigurerAdapter implements * @see #userDetailsService() */ public UserDetailsService userDetailsServiceBean() throws Exception { - AuthenticationManagerBuilder globalAuthBuilder = context - .getBean(AuthenticationManagerBuilder.class); - return new UserDetailsServiceDelegator(Arrays.asList( - localConfigureAuthenticationBldr, globalAuthBuilder)); + AuthenticationManagerBuilder globalAuthBuilder = this.context.getBean(AuthenticationManagerBuilder.class); + return new UserDetailsServiceDelegator(Arrays.asList(this.localConfigureAuthenticationBldr, globalAuthBuilder)); } /** @@ -308,21 +303,18 @@ public abstract class WebSecurityConfigurerAdapter implements * {@link #userDetailsServiceBean()} without interacting with the * {@link ApplicationContext}. Developers should override this method when changing * the instance of {@link #userDetailsServiceBean()}. - * * @return the {@link UserDetailsService} to use */ protected UserDetailsService userDetailsService() { - AuthenticationManagerBuilder globalAuthBuilder = context - .getBean(AuthenticationManagerBuilder.class); - return new UserDetailsServiceDelegator(Arrays.asList( - localConfigureAuthenticationBldr, globalAuthBuilder)); + AuthenticationManagerBuilder globalAuthBuilder = this.context.getBean(AuthenticationManagerBuilder.class); + return new UserDetailsServiceDelegator(Arrays.asList(this.localConfigureAuthenticationBldr, globalAuthBuilder)); } - public void init(final WebSecurity web) throws Exception { - final HttpSecurity http = getHttp(); + @Override + public void init(WebSecurity web) throws Exception { + HttpSecurity http = getHttp(); web.addSecurityFilterChainBuilder(http).postBuildAction(() -> { - FilterSecurityInterceptor securityInterceptor = http - .getSharedObject(FilterSecurityInterceptor.class); + FilterSecurityInterceptor securityInterceptor = http.getSharedObject(FilterSecurityInterceptor.class); web.securityInterceptor(securityInterceptor); }); } @@ -338,6 +330,7 @@ public abstract class WebSecurityConfigurerAdapter implements * {@link #configure(HttpSecurity)} and the {@link HttpSecurity#authorizeRequests} * configuration method. */ + @Override public void configure(WebSecurity web) throws Exception { } @@ -350,25 +343,19 @@ public abstract class WebSecurityConfigurerAdapter implements * http.authorizeRequests().anyRequest().authenticated().and().formLogin().and().httpBasic(); * * - * Any endpoint that requires defense against common vulnerabilities can be specified here, including public ones. - * See {@link HttpSecurity#authorizeRequests} and the `permitAll()` authorization rule - * for more details on public endpoints. - * + * Any endpoint that requires defense against common vulnerabilities can be specified + * here, including public ones. See {@link HttpSecurity#authorizeRequests} and the + * `permitAll()` authorization rule for more details on public endpoints. * @param http the {@link HttpSecurity} to modify * @throws Exception if an error occurs */ - // @formatter:off protected void configure(HttpSecurity http) throws Exception { - logger.debug("Using default configure(HttpSecurity). If subclassed this will potentially override subclass configure(HttpSecurity)."); - - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .formLogin().and() - .httpBasic(); + this.logger.debug("Using default configure(HttpSecurity). " + + "If subclassed this will potentially override subclass configure(HttpSecurity)."); + http.authorizeRequests((requests) -> requests.anyRequest().authenticated()); + http.formLogin(); + http.httpBasic(); } - // @formatter:on /** * Gets the ApplicationContext @@ -381,23 +368,26 @@ public abstract class WebSecurityConfigurerAdapter implements @Autowired public void setApplicationContext(ApplicationContext context) { this.context = context; - ObjectPostProcessor objectPostProcessor = context.getBean(ObjectPostProcessor.class); LazyPasswordEncoder passwordEncoder = new LazyPasswordEncoder(context); + this.authenticationBuilder = new DefaultPasswordEncoderAuthenticationManagerBuilder(objectPostProcessor, + passwordEncoder); + this.localConfigureAuthenticationBldr = new DefaultPasswordEncoderAuthenticationManagerBuilder( + objectPostProcessor, passwordEncoder) { - authenticationBuilder = new DefaultPasswordEncoderAuthenticationManagerBuilder(objectPostProcessor, passwordEncoder); - localConfigureAuthenticationBldr = new DefaultPasswordEncoderAuthenticationManagerBuilder(objectPostProcessor, passwordEncoder) { @Override public AuthenticationManagerBuilder eraseCredentials(boolean eraseCredentials) { - authenticationBuilder.eraseCredentials(eraseCredentials); + WebSecurityConfigurerAdapter.this.authenticationBuilder.eraseCredentials(eraseCredentials); return super.eraseCredentials(eraseCredentials); } @Override - public AuthenticationManagerBuilder authenticationEventPublisher(AuthenticationEventPublisher eventPublisher) { - authenticationBuilder.authenticationEventPublisher(eventPublisher); + public AuthenticationManagerBuilder authenticationEventPublisher( + AuthenticationEventPublisher eventPublisher) { + WebSecurityConfigurerAdapter.this.authenticationBuilder.authenticationEventPublisher(eventPublisher); return super.authenticationEventPublisher(eventPublisher); } + }; } @@ -407,8 +397,7 @@ public abstract class WebSecurityConfigurerAdapter implements } @Autowired(required = false) - public void setContentNegotationStrategy( - ContentNegotiationStrategy contentNegotiationStrategy) { + public void setContentNegotationStrategy(ContentNegotiationStrategy contentNegotiationStrategy) { this.contentNegotiationStrategy = contentNegotiationStrategy; } @@ -418,8 +407,7 @@ public abstract class WebSecurityConfigurerAdapter implements } @Autowired - public void setAuthenticationConfiguration( - AuthenticationConfiguration authenticationConfiguration) { + public void setAuthenticationConfiguration(AuthenticationConfiguration authenticationConfiguration) { this.authenticationConfiguration = authenticationConfiguration; } @@ -432,16 +420,15 @@ public abstract class WebSecurityConfigurerAdapter implements /** * Creates the shared objects - * * @return the shared Objects */ private Map, Object> createSharedObjects() { Map, Object> sharedObjects = new HashMap<>(); - sharedObjects.putAll(localConfigureAuthenticationBldr.getSharedObjects()); + sharedObjects.putAll(this.localConfigureAuthenticationBldr.getSharedObjects()); sharedObjects.put(UserDetailsService.class, userDetailsService()); - sharedObjects.put(ApplicationContext.class, context); - sharedObjects.put(ContentNegotiationStrategy.class, contentNegotiationStrategy); - sharedObjects.put(AuthenticationTrustResolver.class, trustResolver); + sharedObjects.put(ApplicationContext.class, this.context); + sharedObjects.put(ContentNegotiationStrategy.class, this.contentNegotiationStrategy); + sharedObjects.put(AuthenticationTrustResolver.class, this.trustResolver); return sharedObjects; } @@ -453,43 +440,41 @@ public abstract class WebSecurityConfigurerAdapter implements * @since 3.2 */ static final class UserDetailsServiceDelegator implements UserDetailsService { + private List delegateBuilders; + private UserDetailsService delegate; + private final Object delegateMonitor = new Object(); UserDetailsServiceDelegator(List delegateBuilders) { - if (delegateBuilders.contains(null)) { - throw new IllegalArgumentException( - "delegateBuilders cannot contain null values. Got " - + delegateBuilders); - } + Assert.isTrue(!delegateBuilders.contains(null), + () -> "delegateBuilders cannot contain null values. Got " + delegateBuilders); this.delegateBuilders = delegateBuilders; } - public UserDetails loadUserByUsername(String username) - throws UsernameNotFoundException { - if (delegate != null) { - return delegate.loadUserByUsername(username); + @Override + public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException { + if (this.delegate != null) { + return this.delegate.loadUserByUsername(username); } - - synchronized (delegateMonitor) { - if (delegate == null) { - for (AuthenticationManagerBuilder delegateBuilder : delegateBuilders) { - delegate = delegateBuilder.getDefaultUserDetailsService(); - if (delegate != null) { + synchronized (this.delegateMonitor) { + if (this.delegate == null) { + for (AuthenticationManagerBuilder delegateBuilder : this.delegateBuilders) { + this.delegate = delegateBuilder.getDefaultUserDetailsService(); + if (this.delegate != null) { break; } } - - if (delegate == null) { + if (this.delegate == null) { throw new IllegalStateException("UserDetailsService is required."); } this.delegateBuilders = null; } } - - return delegate.loadUserByUsername(username); + return this.delegate.loadUserByUsername(username); } + } /** @@ -500,104 +485,100 @@ public abstract class WebSecurityConfigurerAdapter implements * @since 3.2 */ static final class AuthenticationManagerDelegator implements AuthenticationManager { + private AuthenticationManagerBuilder delegateBuilder; + private AuthenticationManager delegate; + private final Object delegateMonitor = new Object(); + private Set beanNames; - AuthenticationManagerDelegator(AuthenticationManagerBuilder delegateBuilder, - ApplicationContext context) { + AuthenticationManagerDelegator(AuthenticationManagerBuilder delegateBuilder, ApplicationContext context) { Assert.notNull(delegateBuilder, "delegateBuilder cannot be null"); - Field parentAuthMgrField = ReflectionUtils.findField( - AuthenticationManagerBuilder.class, "parentAuthenticationManager"); + Field parentAuthMgrField = ReflectionUtils.findField(AuthenticationManagerBuilder.class, + "parentAuthenticationManager"); ReflectionUtils.makeAccessible(parentAuthMgrField); - beanNames = getAuthenticationManagerBeanNames(context); - validateBeanCycle( - ReflectionUtils.getField(parentAuthMgrField, delegateBuilder), - beanNames); + this.beanNames = getAuthenticationManagerBeanNames(context); + validateBeanCycle(ReflectionUtils.getField(parentAuthMgrField, delegateBuilder), this.beanNames); this.delegateBuilder = delegateBuilder; } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { - if (delegate != null) { - return delegate.authenticate(authentication); + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + if (this.delegate != null) { + return this.delegate.authenticate(authentication); } - - synchronized (delegateMonitor) { - if (delegate == null) { - delegate = this.delegateBuilder.getObject(); + synchronized (this.delegateMonitor) { + if (this.delegate == null) { + this.delegate = this.delegateBuilder.getObject(); this.delegateBuilder = null; } } - - return delegate.authenticate(authentication); + return this.delegate.authenticate(authentication); } - private static Set getAuthenticationManagerBeanNames( - ApplicationContext applicationContext) { - String[] beanNamesForType = BeanFactoryUtils - .beanNamesForTypeIncludingAncestors(applicationContext, - AuthenticationManager.class); + private static Set getAuthenticationManagerBeanNames(ApplicationContext applicationContext) { + String[] beanNamesForType = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(applicationContext, + AuthenticationManager.class); return new HashSet<>(Arrays.asList(beanNamesForType)); } private static void validateBeanCycle(Object auth, Set beanNames) { - if (auth != null && !beanNames.isEmpty()) { - if (auth instanceof Advised) { - Advised advised = (Advised) auth; - TargetSource targetSource = advised.getTargetSource(); - if (targetSource instanceof LazyInitTargetSource) { - LazyInitTargetSource lits = (LazyInitTargetSource) targetSource; - if (beanNames.contains(lits.getTargetBeanName())) { - throw new FatalBeanException( - "A dependency cycle was detected when trying to resolve the AuthenticationManager. Please ensure you have configured authentication."); - } - } - } - beanNames = Collections.emptySet(); + if (auth == null || beanNames.isEmpty() || !(auth instanceof Advised)) { + return; + } + TargetSource targetSource = ((Advised) auth).getTargetSource(); + if (!(targetSource instanceof LazyInitTargetSource)) { + return; + } + LazyInitTargetSource lits = (LazyInitTargetSource) targetSource; + if (beanNames.contains(lits.getTargetBeanName())) { + throw new FatalBeanException( + "A dependency cycle was detected when trying to resolve the AuthenticationManager. " + + "Please ensure you have configured authentication."); } } + } static class DefaultPasswordEncoderAuthenticationManagerBuilder extends AuthenticationManagerBuilder { + private PasswordEncoder defaultPasswordEncoder; /** * Creates a new instance - * * @param objectPostProcessor the {@link ObjectPostProcessor} instance to use. */ - DefaultPasswordEncoderAuthenticationManagerBuilder( - ObjectPostProcessor objectPostProcessor, PasswordEncoder defaultPasswordEncoder) { + DefaultPasswordEncoderAuthenticationManagerBuilder(ObjectPostProcessor objectPostProcessor, + PasswordEncoder defaultPasswordEncoder) { super(objectPostProcessor); this.defaultPasswordEncoder = defaultPasswordEncoder; } @Override public InMemoryUserDetailsManagerConfigurer inMemoryAuthentication() - throws Exception { - return super.inMemoryAuthentication() - .passwordEncoder(this.defaultPasswordEncoder); + throws Exception { + return super.inMemoryAuthentication().passwordEncoder(this.defaultPasswordEncoder); } @Override - public JdbcUserDetailsManagerConfigurer jdbcAuthentication() - throws Exception { - return super.jdbcAuthentication() - .passwordEncoder(this.defaultPasswordEncoder); + public JdbcUserDetailsManagerConfigurer jdbcAuthentication() throws Exception { + return super.jdbcAuthentication().passwordEncoder(this.defaultPasswordEncoder); } @Override public DaoAuthenticationConfigurer userDetailsService( - T userDetailsService) throws Exception { - return super.userDetailsService(userDetailsService) - .passwordEncoder(this.defaultPasswordEncoder); + T userDetailsService) throws Exception { + return super.userDetailsService(userDetailsService).passwordEncoder(this.defaultPasswordEncoder); } + } static class LazyPasswordEncoder implements PasswordEncoder { + private ApplicationContext applicationContext; + private PasswordEncoder passwordEncoder; LazyPasswordEncoder(ApplicationContext applicationContext) { @@ -610,8 +591,7 @@ public abstract class WebSecurityConfigurerAdapter implements } @Override - public boolean matches(CharSequence rawPassword, - String encodedPassword) { + public boolean matches(CharSequence rawPassword, String encodedPassword) { return getPasswordEncoder().matches(rawPassword, encodedPassword); } @@ -635,7 +615,8 @@ public abstract class WebSecurityConfigurerAdapter implements private T getBeanOrNull(Class type) { try { return this.applicationContext.getBean(type); - } catch(NoSuchBeanDefinitionException notFound) { + } + catch (NoSuchBeanDefinitionException ex) { return null; } } @@ -644,5 +625,7 @@ public abstract class WebSecurityConfigurerAdapter implements public String toString() { return getPasswordEncoder().toString(); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java index 0dfd232009..34cd345311 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; +import java.util.Arrays; +import java.util.Collections; + +import javax.servlet.http.HttpServletRequest; + import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; @@ -41,23 +47,17 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.web.accept.ContentNegotiationStrategy; import org.springframework.web.accept.HeaderContentNegotiationStrategy; -import javax.servlet.http.HttpServletRequest; -import java.util.Arrays; -import java.util.Collections; - /** * Base class for configuring {@link AbstractAuthenticationFilterConfigurer}. This is * intended for internal use only. * - * @see FormLoginConfigurer - * @see OpenIDLoginConfigurer - * * @param T refers to "this" for returning the current configurer * @param F refers to the {@link AbstractAuthenticationProcessingFilter} that is being * built - * * @author Rob Winch * @since 3.2 + * @see FormLoginConfigurer + * @see OpenIDLoginConfigurer */ public abstract class AbstractAuthenticationFilterConfigurer, T extends AbstractAuthenticationFilterConfigurer, F extends AbstractAuthenticationProcessingFilter> extends AbstractHttpConfigurer { @@ -67,12 +67,15 @@ public abstract class AbstractAuthenticationFilterConfigurer authenticationDetailsSource; private SavedRequestAwareAuthenticationSuccessHandler defaultSuccessHandler = new SavedRequestAwareAuthenticationSuccessHandler(); + private AuthenticationSuccessHandler successHandler = this.defaultSuccessHandler; private LoginUrlAuthenticationEntryPoint authenticationEntryPoint; private boolean customLoginPage; + private String loginPage; + private String loginProcessingUrl; private AuthenticationFailureHandler failureHandler; @@ -95,8 +98,7 @@ public abstract class AbstractAuthenticationFilterConfigurer exceptionHandling = http - .getConfigurer(ExceptionHandlingConfigurer.class); + ExceptionHandlingConfigurer exceptionHandling = http.getConfigurer(ExceptionHandlingConfigurer.class); if (exceptionHandling == null) { return; } - exceptionHandling.defaultAuthenticationEntryPointFor( - postProcess(authenticationEntryPoint), getAuthenticationEntryPointMatcher(http)); + exceptionHandling.defaultAuthenticationEntryPointFor(postProcess(authenticationEntryPoint), + getAuthenticationEntryPointMatcher(http)); } protected final RequestMatcher getAuthenticationEntryPointMatcher(B http) { - ContentNegotiationStrategy contentNegotiationStrategy = http - .getSharedObject(ContentNegotiationStrategy.class); + ContentNegotiationStrategy contentNegotiationStrategy = http.getSharedObject(ContentNegotiationStrategy.class); if (contentNegotiationStrategy == null) { contentNegotiationStrategy = new HeaderContentNegotiationStrategy(); } - - MediaTypeRequestMatcher mediaMatcher = new MediaTypeRequestMatcher( - contentNegotiationStrategy, MediaType.APPLICATION_XHTML_XML, - new MediaType("image", "*"), MediaType.TEXT_HTML, MediaType.TEXT_PLAIN); + MediaTypeRequestMatcher mediaMatcher = new MediaTypeRequestMatcher(contentNegotiationStrategy, + MediaType.APPLICATION_XHTML_XML, new MediaType("image", "*"), MediaType.TEXT_HTML, + MediaType.TEXT_PLAIN); mediaMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); - RequestMatcher notXRequestedWith = new NegatedRequestMatcher( new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest")); - return new AndRequestMatcher(Arrays.asList(notXRequestedWith, mediaMatcher)); } @@ -280,32 +266,28 @@ public abstract class AbstractAuthenticationFilterConfigurer logoutConfigurer = getBuilder().getConfigurer( - LogoutConfigurer.class); + LogoutConfigurer logoutConfigurer = getBuilder().getConfigurer(LogoutConfigurer.class); if (logoutConfigurer != null && !logoutConfigurer.isCustomLogoutSuccess()) { - logoutConfigurer.logoutSuccessUrl(loginPage + "?logout"); + logoutConfigurer.logoutSuccessUrl(this.loginPage + "?logout"); } } @@ -416,8 +388,8 @@ public abstract class AbstractAuthenticationFilterConfigurer The object that is returned or Chained after creating the RequestMatcher * @author Rob Winch * @since 3.2 - * - * @param The object that is returned or Chained after creating the RequestMatcher - * * @see ChannelSecurityConfigurer * @see UrlAuthorizationConfigurer * @see ExpressionUrlAuthorizationConfigurer */ -public abstract class AbstractConfigAttributeRequestMatcherRegistry extends - AbstractRequestMatcherRegistry { +public abstract class AbstractConfigAttributeRequestMatcherRegistry extends AbstractRequestMatcherRegistry { + private List urlMappings = new ArrayList<>(); + private List unmappedMatchers; /** * Gets the {@link UrlMapping} added by subclasses in * {@link #chainRequestMatchers(java.util.List)}. May be empty. - * * @return the {@link UrlMapping} added by subclasses in * {@link #chainRequestMatchers(java.util.List)} */ final List getUrlMappings() { - return urlMappings; + return this.urlMappings; } /** * Adds a {@link UrlMapping} added by subclasses in * {@link #chainRequestMatchers(java.util.List)} and resets the unmapped * {@link RequestMatcher}'s. - * * @param urlMapping {@link UrlMapping} the mapping to add */ final void addMapping(UrlMapping urlMapping) { @@ -68,11 +67,11 @@ public abstract class AbstractConfigAttributeRequestMatcherRegistry extends /** * Marks the {@link RequestMatcher}'s as unmapped and then calls * {@link #chainRequestMatchersInternal(List)}. - * * @param requestMatchers the {@link RequestMatcher} instances that were created * @return the chained Object for the subclass which allows association of something * else to the {@link RequestMatcher} */ + @Override protected final C chainRequestMatchers(List requestMatchers) { this.unmappedMatchers = requestMatchers; return chainRequestMatchersInternal(requestMatchers); @@ -81,7 +80,6 @@ public abstract class AbstractConfigAttributeRequestMatcherRegistry extends /** * Subclasses should implement this method for returning the object that is chained to * the creation of the {@link RequestMatcher} instances. - * * @param requestMatchers the {@link RequestMatcher} instances that were created * @return the chained Object for the subclass which allows association of something * else to the {@link RequestMatcher} @@ -91,7 +89,6 @@ public abstract class AbstractConfigAttributeRequestMatcherRegistry extends /** * Adds a {@link UrlMapping} added by subclasses in * {@link #chainRequestMatchers(java.util.List)} at a particular index. - * * @param index the index to add a {@link UrlMapping} * @param urlMapping {@link UrlMapping} the mapping to add */ @@ -102,18 +99,12 @@ public abstract class AbstractConfigAttributeRequestMatcherRegistry extends /** * Creates the mapping of {@link RequestMatcher} to {@link Collection} of * {@link ConfigAttribute} instances - * * @return the mapping of {@link RequestMatcher} to {@link Collection} of * {@link ConfigAttribute} instances. Cannot be null. */ final LinkedHashMap> createRequestMap() { - if (unmappedMatchers != null) { - throw new IllegalStateException( - "An incomplete mapping was found for " - + unmappedMatchers - + ". Try completing it with something like requestUrls()..hasRole('USER')"); - } - + Assert.state(this.unmappedMatchers == null, () -> "An incomplete mapping was found for " + this.unmappedMatchers + + ". Try completing it with something like requestUrls()..hasRole('USER')"); LinkedHashMap> requestMap = new LinkedHashMap<>(); for (UrlMapping mapping : getUrlMappings()) { RequestMatcher matcher = mapping.getRequestMatcher(); @@ -128,20 +119,24 @@ public abstract class AbstractConfigAttributeRequestMatcherRegistry extends * {@link ConfigAttribute} instances */ static final class UrlMapping { - private RequestMatcher requestMatcher; - private Collection configAttrs; + + private final RequestMatcher requestMatcher; + + private final Collection configAttrs; UrlMapping(RequestMatcher requestMatcher, Collection configAttrs) { this.requestMatcher = requestMatcher; this.configAttrs = configAttrs; } - public RequestMatcher getRequestMatcher() { - return requestMatcher; + RequestMatcher getRequestMatcher() { + return this.requestMatcher; } - public Collection getConfigAttrs() { - return configAttrs; + Collection getConfigAttrs() { + return this.configAttrs; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractHttpConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractHttpConfigurer.java index 18c26b4236..4144d5e840 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractHttpConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractHttpConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.springframework.security.config.annotation.ObjectPostProcessor; @@ -27,7 +28,6 @@ import org.springframework.security.web.DefaultSecurityFilterChain; * {@link HttpSecurity}. * * @author Rob Winch - * */ public abstract class AbstractHttpConfigurer, B extends HttpSecurityBuilder> extends SecurityConfigurerAdapter { @@ -35,7 +35,6 @@ public abstract class AbstractHttpConfigurer - *
  • - * {@link AuthenticationManager} - *
  • + *
  • {@link AuthenticationManager}
  • * * - * * @param the AbstractInterceptUrlConfigurer * @param the type of {@link HttpSecurityBuilder} that is being configured - * * @author Rob Winch * @since 3.2 * @see ExpressionUrlAuthorizationConfigurer * @see UrlAuthorizationConfigurer */ -abstract class AbstractInterceptUrlConfigurer, H extends HttpSecurityBuilder> +public abstract class AbstractInterceptUrlConfigurer, H extends HttpSecurityBuilder> extends AbstractHttpConfigurer { + private Boolean filterSecurityInterceptorOncePerRequest; private AccessDecisionManager accessDecisionManager; + AbstractInterceptUrlConfigurer() { + } + @Override public void configure(H http) throws Exception { FilterInvocationSecurityMetadataSource metadataSource = createMetadataSource(http); if (metadataSource == null) { return; } - FilterSecurityInterceptor securityInterceptor = createFilterSecurityInterceptor( - http, metadataSource, http.getSharedObject(AuthenticationManager.class)); - if (filterSecurityInterceptorOncePerRequest != null) { - securityInterceptor - .setObserveOncePerRequest(filterSecurityInterceptorOncePerRequest); + FilterSecurityInterceptor securityInterceptor = createFilterSecurityInterceptor(http, metadataSource, + http.getSharedObject(AuthenticationManager.class)); + if (this.filterSecurityInterceptorOncePerRequest != null) { + securityInterceptor.setObserveOncePerRequest(this.filterSecurityInterceptorOncePerRequest); } securityInterceptor = postProcess(securityInterceptor); http.addFilter(securityInterceptor); @@ -91,9 +91,7 @@ abstract class AbstractInterceptUrlConfigurer> getDecisionVoters(H http); - abstract class AbstractInterceptUrlRegistry, T> - extends AbstractConfigAttributeRequestMatcherRegistry { - - /** - * Allows setting the {@link AccessDecisionManager}. If none is provided, a - * default {@link AccessDecisionManager} is created. - * - * @param accessDecisionManager the {@link AccessDecisionManager} to use - * @return the {@link AbstractInterceptUrlConfigurer} for further customization - */ - public R accessDecisionManager(AccessDecisionManager accessDecisionManager) { - AbstractInterceptUrlConfigurer.this.accessDecisionManager = accessDecisionManager; - return getSelf(); - } - - /** - * Allows setting if the {@link FilterSecurityInterceptor} should be only applied - * once per request (i.e. if the filter intercepts on a forward, should it be - * applied again). - * - * @param filterSecurityInterceptorOncePerRequest if the - * {@link FilterSecurityInterceptor} should be only applied once per request - * @return the {@link AbstractInterceptUrlConfigurer} for further customization - */ - public R filterSecurityInterceptorOncePerRequest( - boolean filterSecurityInterceptorOncePerRequest) { - AbstractInterceptUrlConfigurer.this.filterSecurityInterceptorOncePerRequest = filterSecurityInterceptorOncePerRequest; - return getSelf(); - } - - /** - * Returns a reference to the current object with a single suppression of the type - * - * @return a reference to the current object - */ - @SuppressWarnings("unchecked") - private R getSelf() { - return (R) this; - } - } - /** * Creates the default {@code AccessDecisionManager} * @return the default {@code AccessDecisionManager} @@ -162,23 +117,20 @@ abstract class AbstractInterceptUrlConfigurer, T> + extends AbstractConfigAttributeRequestMatcherRegistry { + + AbstractInterceptUrlRegistry() { + } + + /** + * Allows setting the {@link AccessDecisionManager}. If none is provided, a + * default {@link AccessDecisionManager} is created. + * @param accessDecisionManager the {@link AccessDecisionManager} to use + * @return the {@link AbstractInterceptUrlConfigurer} for further customization + */ + public R accessDecisionManager(AccessDecisionManager accessDecisionManager) { + AbstractInterceptUrlConfigurer.this.accessDecisionManager = accessDecisionManager; + return getSelf(); + } + + /** + * Allows setting if the {@link FilterSecurityInterceptor} should be only applied + * once per request (i.e. if the filter intercepts on a forward, should it be + * applied again). + * @param filterSecurityInterceptorOncePerRequest if the + * {@link FilterSecurityInterceptor} should be only applied once per request + * @return the {@link AbstractInterceptUrlConfigurer} for further customization + */ + public R filterSecurityInterceptorOncePerRequest(boolean filterSecurityInterceptorOncePerRequest) { + AbstractInterceptUrlConfigurer.this.filterSecurityInterceptorOncePerRequest = filterSecurityInterceptorOncePerRequest; + return getSelf(); + } + + /** + * Returns a reference to the current object with a single suppression of the type + * @return a reference to the current object + */ + @SuppressWarnings("unchecked") + private R getSelf() { + return (R) this; + } + + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurer.java index d597618b0e..662ff2095b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.List; @@ -39,14 +40,18 @@ import org.springframework.security.web.authentication.AnonymousAuthenticationFi * @author Rob Winch * @since 3.2 */ -public final class AnonymousConfigurer> extends - AbstractHttpConfigurer, H> { +public final class AnonymousConfigurer> + extends AbstractHttpConfigurer, H> { + private String key; + private AuthenticationProvider authenticationProvider; + private AnonymousAuthenticationFilter authenticationFilter; + private Object principal = "anonymousUser"; - private List authorities = AuthorityUtils - .createAuthorityList("ROLE_ANONYMOUS"); + + private List authorities = AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"); /** * Creates a new instance @@ -58,7 +63,6 @@ public final class AnonymousConfigurer> extends /** * Sets the key to identify tokens created for anonymous authentication. Default is a * secure randomly generated key. - * * @param key the key to identify tokens created for anonymous authentication. Default * is a secure randomly generated key. * @return the {@link AnonymousConfigurer} for further customization of anonymous @@ -71,7 +75,6 @@ public final class AnonymousConfigurer> extends /** * Sets the principal for {@link Authentication} objects of anonymous users - * * @param principal used for the {@link Authentication} object of anonymous users * @return the {@link AnonymousConfigurer} for further customization of anonymous * authentication @@ -84,7 +87,6 @@ public final class AnonymousConfigurer> extends /** * Sets the {@link org.springframework.security.core.Authentication#getAuthorities()} * for anonymous users - * * @param authorities Sets the * {@link org.springframework.security.core.Authentication#getAuthorities()} for * anonymous users @@ -99,7 +101,6 @@ public final class AnonymousConfigurer> extends /** * Sets the {@link org.springframework.security.core.Authentication#getAuthorities()} * for anonymous users - * * @param authorities Sets the * {@link org.springframework.security.core.Authentication#getAuthorities()} for * anonymous users (i.e. "ROLE_ANONYMOUS") @@ -114,15 +115,12 @@ public final class AnonymousConfigurer> extends * Sets the {@link AuthenticationProvider} used to validate an anonymous user. If this * is set, no attributes on the {@link AnonymousConfigurer} will be set on the * {@link AuthenticationProvider}. - * * @param authenticationProvider the {@link AuthenticationProvider} used to validate * an anonymous user. Default is {@link AnonymousAuthenticationProvider} - * * @return the {@link AnonymousConfigurer} for further customization of anonymous * authentication */ - public AnonymousConfigurer authenticationProvider( - AuthenticationProvider authenticationProvider) { + public AnonymousConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) { this.authenticationProvider = authenticationProvider; return this; } @@ -131,42 +129,39 @@ public final class AnonymousConfigurer> extends * Sets the {@link AnonymousAuthenticationFilter} used to populate an anonymous user. * If this is set, no attributes on the {@link AnonymousConfigurer} will be set on the * {@link AnonymousAuthenticationFilter}. - * * @param authenticationFilter the {@link AnonymousAuthenticationFilter} used to * populate an anonymous user. - * * @return the {@link AnonymousConfigurer} for further customization of anonymous * authentication */ - public AnonymousConfigurer authenticationFilter( - AnonymousAuthenticationFilter authenticationFilter) { + public AnonymousConfigurer authenticationFilter(AnonymousAuthenticationFilter authenticationFilter) { this.authenticationFilter = authenticationFilter; return this; } @Override public void init(H http) { - if (authenticationProvider == null) { - authenticationProvider = new AnonymousAuthenticationProvider(getKey()); + if (this.authenticationProvider == null) { + this.authenticationProvider = new AnonymousAuthenticationProvider(getKey()); } - if (authenticationFilter == null) { - authenticationFilter = new AnonymousAuthenticationFilter(getKey(), principal, - authorities); + if (this.authenticationFilter == null) { + this.authenticationFilter = new AnonymousAuthenticationFilter(getKey(), this.principal, this.authorities); } - authenticationProvider = postProcess(authenticationProvider); - http.authenticationProvider(authenticationProvider); + this.authenticationProvider = postProcess(this.authenticationProvider); + http.authenticationProvider(this.authenticationProvider); } @Override public void configure(H http) { - authenticationFilter.afterPropertiesSet(); - http.addFilter(authenticationFilter); + this.authenticationFilter.afterPropertiesSet(); + http.addFilter(this.authenticationFilter); } private String getKey() { - if (key == null) { - key = UUID.randomUUID().toString(); + if (this.key == null) { + this.key = UUID.randomUUID().toString(); } - return key; + return this.key; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurer.java index 1d3b99b5ef..19fc8ae0ca 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.Arrays; @@ -73,14 +74,16 @@ import org.springframework.security.web.util.matcher.RequestMatcher; * * * @param the type of {@link HttpSecurityBuilder} that is being configured - * * @author Rob Winch * @since 3.2 */ -public final class ChannelSecurityConfigurer> extends - AbstractHttpConfigurer, H> { +public final class ChannelSecurityConfigurer> + extends AbstractHttpConfigurer, H> { + private ChannelProcessingFilter channelFilter = new ChannelProcessingFilter(); + private LinkedHashMap> requestMap = new LinkedHashMap<>(); + private List channelProcessors; private final ChannelRequestMatcherRegistry REGISTRY; @@ -94,7 +97,7 @@ public final class ChannelSecurityConfigurer> e } public ChannelRequestMatcherRegistry getRegistry() { - return REGISTRY; + return this.REGISTRY; } @Override @@ -102,49 +105,40 @@ public final class ChannelSecurityConfigurer> e ChannelDecisionManagerImpl channelDecisionManager = new ChannelDecisionManagerImpl(); channelDecisionManager.setChannelProcessors(getChannelProcessors(http)); channelDecisionManager = postProcess(channelDecisionManager); - - channelFilter.setChannelDecisionManager(channelDecisionManager); - + this.channelFilter.setChannelDecisionManager(channelDecisionManager); DefaultFilterInvocationSecurityMetadataSource filterInvocationSecurityMetadataSource = new DefaultFilterInvocationSecurityMetadataSource( - requestMap); - channelFilter.setSecurityMetadataSource(filterInvocationSecurityMetadataSource); - - channelFilter = postProcess(channelFilter); - http.addFilter(channelFilter); + this.requestMap); + this.channelFilter.setSecurityMetadataSource(filterInvocationSecurityMetadataSource); + this.channelFilter = postProcess(this.channelFilter); + http.addFilter(this.channelFilter); } private List getChannelProcessors(H http) { - if (channelProcessors != null) { - return channelProcessors; + if (this.channelProcessors != null) { + return this.channelProcessors; } - InsecureChannelProcessor insecureChannelProcessor = new InsecureChannelProcessor(); SecureChannelProcessor secureChannelProcessor = new SecureChannelProcessor(); - PortMapper portMapper = http.getSharedObject(PortMapper.class); if (portMapper != null) { RetryWithHttpEntryPoint httpEntryPoint = new RetryWithHttpEntryPoint(); httpEntryPoint.setPortMapper(portMapper); insecureChannelProcessor.setEntryPoint(httpEntryPoint); - RetryWithHttpsEntryPoint httpsEntryPoint = new RetryWithHttpsEntryPoint(); httpsEntryPoint.setPortMapper(portMapper); secureChannelProcessor.setEntryPoint(httpsEntryPoint); } insecureChannelProcessor = postProcess(insecureChannelProcessor); secureChannelProcessor = postProcess(secureChannelProcessor); - return Arrays. asList(insecureChannelProcessor, - secureChannelProcessor); + return Arrays.asList(insecureChannelProcessor, secureChannelProcessor); } - private ChannelRequestMatcherRegistry addAttribute(String attribute, - List matchers) { + private ChannelRequestMatcherRegistry addAttribute(String attribute, List matchers) { for (RequestMatcher matcher : matchers) { - Collection attrs = Arrays - . asList(new SecurityConfig(attribute)); - requestMap.put(matcher, attrs); + Collection attrs = Arrays.asList(new SecurityConfig(attribute)); + this.requestMap.put(matcher, attrs); } - return REGISTRY; + return this.REGISTRY; } public final class ChannelRequestMatcherRegistry @@ -155,8 +149,7 @@ public final class ChannelSecurityConfigurer> e } @Override - public MvcMatchersRequiresChannelUrl mvcMatchers(HttpMethod method, - String... mvcPatterns) { + public MvcMatchersRequiresChannelUrl mvcMatchers(HttpMethod method, String... mvcPatterns) { List mvcMatchers = createMvcMatchers(method, mvcPatterns); return new MvcMatchersRequiresChannelUrl(mvcMatchers); } @@ -167,19 +160,16 @@ public final class ChannelSecurityConfigurer> e } @Override - protected RequiresChannelUrl chainRequestMatchersInternal( - List requestMatchers) { + protected RequiresChannelUrl chainRequestMatchersInternal(List requestMatchers) { return new RequiresChannelUrl(requestMatchers); } /** * Adds an {@link ObjectPostProcessor} for this class. - * * @param objectPostProcessor * @return the {@link ChannelSecurityConfigurer} for further customizations */ - public ChannelRequestMatcherRegistry withObjectPostProcessor( - ObjectPostProcessor objectPostProcessor) { + public ChannelRequestMatcherRegistry withObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { addObjectPostProcessor(objectPostProcessor); return this; } @@ -190,8 +180,7 @@ public final class ChannelSecurityConfigurer> e * @param channelProcessors * @return the {@link ChannelSecurityConfigurer} for further customizations */ - public ChannelRequestMatcherRegistry channelProcessors( - List channelProcessors) { + public ChannelRequestMatcherRegistry channelProcessors(List channelProcessors) { ChannelSecurityConfigurer.this.channelProcessors = channelProcessors; return this; } @@ -199,12 +188,12 @@ public final class ChannelSecurityConfigurer> e /** * Return the {@link SecurityBuilder} when done using the * {@link SecurityConfigurer}. This is useful for method chaining. - * * @return the type of {@link HttpSecurityBuilder} that is being configured */ public H and() { return ChannelSecurityConfigurer.this.and(); } + } public final class MvcMatchersRequiresChannelUrl extends RequiresChannelUrl { @@ -219,12 +208,14 @@ public final class ChannelSecurityConfigurer> e } return this; } + } public class RequiresChannelUrl { + protected List requestMatchers; - private RequiresChannelUrl(List requestMatchers) { + RequiresChannelUrl(List requestMatchers) { this.requestMatchers = requestMatchers; } @@ -237,7 +228,9 @@ public final class ChannelSecurityConfigurer> e } public ChannelRequestMatcherRegistry requires(String attribute) { - return addAttribute(attribute, requestMatchers); + return addAttribute(attribute, this.requestMatchers); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurer.java index 449061d49f..f4002d799b 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurer.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; @@ -36,11 +38,12 @@ import org.springframework.web.servlet.handler.HandlerMappingIntrospector; * @author Rob Winch * @since 4.1.1 */ -public class CorsConfigurer> - extends AbstractHttpConfigurer, H> { +public class CorsConfigurer> extends AbstractHttpConfigurer, H> { private static final String HANDLER_MAPPING_INTROSPECTOR = "org.springframework.web.servlet.handler.HandlerMappingIntrospector"; + private static final String CORS_CONFIGURATION_SOURCE_BEAN_NAME = "corsConfigurationSource"; + private static final String CORS_FILTER_BEAN_NAME = "corsFilter"; private CorsConfigurationSource configurationSource; @@ -53,8 +56,7 @@ public class CorsConfigurer> public CorsConfigurer() { } - public CorsConfigurer configurationSource( - CorsConfigurationSource configurationSource) { + public CorsConfigurer configurationSource(CorsConfigurationSource configurationSource) { this.configurationSource = configurationSource; return this; } @@ -62,13 +64,9 @@ public class CorsConfigurer> @Override public void configure(H http) { ApplicationContext context = http.getSharedObject(ApplicationContext.class); - CorsFilter corsFilter = getCorsFilter(context); - if (corsFilter == null) { - throw new IllegalStateException( - "Please configure either a " + CORS_FILTER_BEAN_NAME + " bean or a " - + CORS_CONFIGURATION_SOURCE_BEAN_NAME + "bean."); - } + Assert.state(corsFilter != null, () -> "Please configure either a " + CORS_FILTER_BEAN_NAME + " bean or a " + + CORS_CONFIGURATION_SOURCE_BEAN_NAME + "bean."); http.addFilter(corsFilter); } @@ -76,32 +74,27 @@ public class CorsConfigurer> if (this.configurationSource != null) { return new CorsFilter(this.configurationSource); } - - boolean containsCorsFilter = context - .containsBeanDefinition(CORS_FILTER_BEAN_NAME); + boolean containsCorsFilter = context.containsBeanDefinition(CORS_FILTER_BEAN_NAME); if (containsCorsFilter) { return context.getBean(CORS_FILTER_BEAN_NAME, CorsFilter.class); } - - boolean containsCorsSource = context - .containsBean(CORS_CONFIGURATION_SOURCE_BEAN_NAME); + boolean containsCorsSource = context.containsBean(CORS_CONFIGURATION_SOURCE_BEAN_NAME); if (containsCorsSource) { - CorsConfigurationSource configurationSource = context.getBean( - CORS_CONFIGURATION_SOURCE_BEAN_NAME, CorsConfigurationSource.class); + CorsConfigurationSource configurationSource = context.getBean(CORS_CONFIGURATION_SOURCE_BEAN_NAME, + CorsConfigurationSource.class); return new CorsFilter(configurationSource); } - - boolean mvcPresent = ClassUtils.isPresent(HANDLER_MAPPING_INTROSPECTOR, - context.getClassLoader()); + boolean mvcPresent = ClassUtils.isPresent(HANDLER_MAPPING_INTROSPECTOR, context.getClassLoader()); if (mvcPresent) { return MvcCorsFilter.getMvcCorsFilter(context); } return null; } - static class MvcCorsFilter { + private static final String HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME = "mvcHandlerMappingIntrospector"; + /** * This needs to be isolated into a separate class as Spring MVC is an optional * dependency and will potentially cause ClassLoading issues @@ -110,11 +103,16 @@ public class CorsConfigurer> */ private static CorsFilter getMvcCorsFilter(ApplicationContext context) { if (!context.containsBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME)) { - throw new NoSuchBeanDefinitionException(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, "A Bean named " + HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME +" of type " + HandlerMappingIntrospector.class.getName() + throw new NoSuchBeanDefinitionException(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, "A Bean named " + + HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME + " of type " + + HandlerMappingIntrospector.class.getName() + " is required to use MvcRequestMatcher. Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext."); } - HandlerMappingIntrospector mappingIntrospector = context.getBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, HandlerMappingIntrospector.class); + HandlerMappingIntrospector mappingIntrospector = context.getBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, + HandlerMappingIntrospector.class); return new CorsFilter(mappingIntrospector); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java index 33b34bcef5..4e33f4bc7e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.ArrayList; @@ -74,15 +75,20 @@ import org.springframework.util.Assert; * * * @author Rob Winch + * @author Michael Vitz * @since 3.2 */ public final class CsrfConfigurer> extends AbstractHttpConfigurer, H> { - private CsrfTokenRepository csrfTokenRepository = new LazyCsrfTokenRepository( - new HttpSessionCsrfTokenRepository()); + + private CsrfTokenRepository csrfTokenRepository = new LazyCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); + private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER; + private List ignoredCsrfProtectionMatchers = new ArrayList<>(); + private SessionAuthenticationStrategy sessionAuthenticationStrategy; + private final ApplicationContext context; /** @@ -96,12 +102,10 @@ public final class CsrfConfigurer> /** * Specify the {@link CsrfTokenRepository} to use. The default is an * {@link HttpSessionCsrfTokenRepository} wrapped by {@link LazyCsrfTokenRepository}. - * * @param csrfTokenRepository the {@link CsrfTokenRepository} to use * @return the {@link CsrfConfigurer} for further customizations */ - public CsrfConfigurer csrfTokenRepository( - CsrfTokenRepository csrfTokenRepository) { + public CsrfConfigurer csrfTokenRepository(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); this.csrfTokenRepository = csrfTokenRepository; return this; @@ -111,14 +115,11 @@ public final class CsrfConfigurer> * Specify the {@link RequestMatcher} to use for determining when CSRF should be * applied. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other * requests. - * * @param requireCsrfProtectionMatcher the {@link RequestMatcher} to use * @return the {@link CsrfConfigurer} for further customizations */ - public CsrfConfigurer requireCsrfProtectionMatcher( - RequestMatcher requireCsrfProtectionMatcher) { - Assert.notNull(requireCsrfProtectionMatcher, - "requireCsrfProtectionMatcher cannot be null"); + public CsrfConfigurer requireCsrfProtectionMatcher(RequestMatcher requireCsrfProtectionMatcher) { + Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null"); this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher; return this; } @@ -148,8 +149,7 @@ public final class CsrfConfigurer> * @since 4.0 */ public CsrfConfigurer ignoringAntMatchers(String... antPatterns) { - return new IgnoreCsrfProtectionRegistry(this.context).antMatchers(antPatterns) - .and(); + return new IgnoreCsrfProtectionRegistry(this.context).antMatchers(antPatterns).and(); } /** @@ -163,13 +163,14 @@ public final class CsrfConfigurer> *

    *
      *
    • Any GET, HEAD, TRACE, OPTIONS (this is the default)
    • - *
    • We also explicitly state to ignore any request that has a "X-Requested-With: XMLHttpRequest" header
    • + *
    • We also explicitly state to ignore any request that has a "X-Requested-With: + * XMLHttpRequest" header
    • *
    * *
     	 * http
     	 *     .csrf()
    -	 *         .ignoringRequestMatchers(request -> "XMLHttpRequest".equals(request.getHeader("X-Requested-With")))
    +	 *         .ignoringRequestMatchers((request) -> "XMLHttpRequest".equals(request.getHeader("X-Requested-With")))
     	 *         .and()
     	 *     ...
     	 * 
    @@ -177,8 +178,7 @@ public final class CsrfConfigurer> * @since 5.1 */ public CsrfConfigurer ignoringRequestMatchers(RequestMatcher... requestMatchers) { - return new IgnoreCsrfProtectionRegistry(this.context).requestMatchers(requestMatchers) - .and(); + return new IgnoreCsrfProtectionRegistry(this.context).requestMatchers(requestMatchers).and(); } /** @@ -186,17 +186,14 @@ public final class CsrfConfigurer> * Specify the {@link SessionAuthenticationStrategy} to use. The default is a * {@link CsrfAuthenticationStrategy}. *

    - * - * @author Michael Vitz - * @since 5.2 - * - * @param sessionAuthenticationStrategy the {@link SessionAuthenticationStrategy} to use + * @param sessionAuthenticationStrategy the {@link SessionAuthenticationStrategy} to + * use * @return the {@link CsrfConfigurer} for further customizations + * @since 5.2 */ public CsrfConfigurer sessionAuthenticationStrategy( SessionAuthenticationStrategy sessionAuthenticationStrategy) { - Assert.notNull(sessionAuthenticationStrategy, - "sessionAuthenticationStrategy cannot be null"); + Assert.notNull(sessionAuthenticationStrategy, "sessionAuthenticationStrategy cannot be null"); this.sessionAuthenticationStrategy = sessionAuthenticationStrategy; return this; } @@ -215,14 +212,11 @@ public final class CsrfConfigurer> } LogoutConfigurer logoutConfigurer = http.getConfigurer(LogoutConfigurer.class); if (logoutConfigurer != null) { - logoutConfigurer - .addLogoutHandler(new CsrfLogoutHandler(this.csrfTokenRepository)); + logoutConfigurer.addLogoutHandler(new CsrfLogoutHandler(this.csrfTokenRepository)); } - SessionManagementConfigurer sessionConfigurer = http - .getConfigurer(SessionManagementConfigurer.class); + SessionManagementConfigurer sessionConfigurer = http.getConfigurer(SessionManagementConfigurer.class); if (sessionConfigurer != null) { - sessionConfigurer.addSessionAuthenticationStrategy( - getSessionAuthenticationStrategy()); + sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy()); } filter = postProcess(filter); http.addFilter(filter); @@ -231,7 +225,6 @@ public final class CsrfConfigurer> /** * Gets the final {@link RequestMatcher} to use by combining the * {@link #requireCsrfProtectionMatcher(RequestMatcher)} and any {@link #ignore()}. - * * @return the {@link RequestMatcher} to use */ private RequestMatcher getRequireCsrfProtectionMatcher() { @@ -239,22 +232,19 @@ public final class CsrfConfigurer> return this.requireCsrfProtectionMatcher; } return new AndRequestMatcher(this.requireCsrfProtectionMatcher, - new NegatedRequestMatcher( - new OrRequestMatcher(this.ignoredCsrfProtectionMatchers))); + new NegatedRequestMatcher(new OrRequestMatcher(this.ignoredCsrfProtectionMatchers))); } /** * Gets the default {@link AccessDeniedHandler} from the * {@link ExceptionHandlingConfigurer#getAccessDeniedHandler()} or create a * {@link AccessDeniedHandlerImpl} if not available. - * * @param http the {@link HttpSecurityBuilder} * @return the {@link AccessDeniedHandler} */ @SuppressWarnings("unchecked") private AccessDeniedHandler getDefaultAccessDeniedHandler(H http) { - ExceptionHandlingConfigurer exceptionConfig = http - .getConfigurer(ExceptionHandlingConfigurer.class); + ExceptionHandlingConfigurer exceptionConfig = http.getConfigurer(ExceptionHandlingConfigurer.class); AccessDeniedHandler handler = null; if (exceptionConfig != null) { handler = exceptionConfig.getAccessDeniedHandler(); @@ -269,14 +259,12 @@ public final class CsrfConfigurer> * Gets the default {@link InvalidSessionStrategy} from the * {@link SessionManagementConfigurer#getInvalidSessionStrategy()} or null if not * available. - * * @param http the {@link HttpSecurityBuilder} * @return the {@link InvalidSessionStrategy} */ @SuppressWarnings("unchecked") private InvalidSessionStrategy getInvalidSessionStrategy(H http) { - SessionManagementConfigurer sessionManagement = http - .getConfigurer(SessionManagementConfigurer.class); + SessionManagementConfigurer sessionManagement = http.getConfigurer(SessionManagementConfigurer.class); if (sessionManagement == null) { return null; } @@ -292,18 +280,15 @@ public final class CsrfConfigurer> * {@link InvalidSessionAccessDeniedHandler} and the * {@link #getDefaultAccessDeniedHandler(HttpSecurityBuilder)}. Otherwise, only * {@link #getDefaultAccessDeniedHandler(HttpSecurityBuilder)} is used. - * * @param http the {@link HttpSecurityBuilder} * @return the {@link AccessDeniedHandler} */ private AccessDeniedHandler createAccessDeniedHandler(H http) { InvalidSessionStrategy invalidSessionStrategy = getInvalidSessionStrategy(http); - AccessDeniedHandler defaultAccessDeniedHandler = getDefaultAccessDeniedHandler( - http); + AccessDeniedHandler defaultAccessDeniedHandler = getDefaultAccessDeniedHandler(http); if (invalidSessionStrategy == null) { return defaultAccessDeniedHandler; } - InvalidSessionAccessDeniedHandler invalidSessionDeniedHandler = new InvalidSessionAccessDeniedHandler( invalidSessionStrategy); LinkedHashMap, AccessDeniedHandler> handlers = new LinkedHashMap<>(); @@ -312,20 +297,16 @@ public final class CsrfConfigurer> } /** - * Gets the {@link SessionAuthenticationStrategy} to use. If none was set by the user a - * {@link CsrfAuthenticationStrategy} is created. - * - * @author Michael Vitz - * @since 5.2 - * + * Gets the {@link SessionAuthenticationStrategy} to use. If none was set by the user + * a {@link CsrfAuthenticationStrategy} is created. * @return the {@link SessionAuthenticationStrategy} + * @since 5.2 */ private SessionAuthenticationStrategy getSessionAuthenticationStrategy() { - if (sessionAuthenticationStrategy != null) { - return sessionAuthenticationStrategy; - } else { - return new CsrfAuthenticationStrategy(this.csrfTokenRepository); + if (this.sessionAuthenticationStrategy != null) { + return this.sessionAuthenticationStrategy; } + return new CsrfAuthenticationStrategy(this.csrfTokenRepository); } /** @@ -336,23 +317,17 @@ public final class CsrfConfigurer> * @author Rob Winch * @since 4.0 */ - private class IgnoreCsrfProtectionRegistry - extends AbstractRequestMatcherRegistry { + private class IgnoreCsrfProtectionRegistry extends AbstractRequestMatcherRegistry { - /** - * @param context - */ - private IgnoreCsrfProtectionRegistry(ApplicationContext context) { + IgnoreCsrfProtectionRegistry(ApplicationContext context) { setApplicationContext(context); } @Override - public MvcMatchersIgnoreCsrfProtectionRegistry mvcMatchers(HttpMethod method, - String... mvcPatterns) { + public MvcMatchersIgnoreCsrfProtectionRegistry mvcMatchers(HttpMethod method, String... mvcPatterns) { List mvcMatchers = createMvcMatchers(method, mvcPatterns); CsrfConfigurer.this.ignoredCsrfProtectionMatchers.addAll(mvcMatchers); - return new MvcMatchersIgnoreCsrfProtectionRegistry(getApplicationContext(), - mvcMatchers); + return new MvcMatchersIgnoreCsrfProtectionRegistry(getApplicationContext(), mvcMatchers); } @Override @@ -360,16 +335,16 @@ public final class CsrfConfigurer> return mvcMatchers(null, mvcPatterns); } - public CsrfConfigurer and() { + CsrfConfigurer and() { return CsrfConfigurer.this; } @Override - protected IgnoreCsrfProtectionRegistry chainRequestMatchers( - List requestMatchers) { + protected IgnoreCsrfProtectionRegistry chainRequestMatchers(List requestMatchers) { CsrfConfigurer.this.ignoredCsrfProtectionMatchers.addAll(requestMatchers); return this; } + } /** @@ -378,8 +353,8 @@ public final class CsrfConfigurer> * * @author Rob Winch */ - private final class MvcMatchersIgnoreCsrfProtectionRegistry - extends IgnoreCsrfProtectionRegistry { + private final class MvcMatchersIgnoreCsrfProtectionRegistry extends IgnoreCsrfProtectionRegistry { + private final List mvcMatchers; private MvcMatchersIgnoreCsrfProtectionRegistry(ApplicationContext context, @@ -388,11 +363,13 @@ public final class CsrfConfigurer> this.mvcMatchers = mvcMatchers; } - public IgnoreCsrfProtectionRegistry servletPath(String servletPath) { + IgnoreCsrfProtectionRegistry servletPath(String servletPath) { for (MvcRequestMatcher matcher : this.mvcMatchers) { matcher.setServletPath(servletPath); } return this; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java index 251c586f3e..bf144dc56c 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurer.java @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; +import java.util.Collections; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.web.AuthenticationEntryPoint; @@ -22,11 +28,6 @@ import org.springframework.security.web.authentication.ui.DefaultLoginPageGenera import org.springframework.security.web.authentication.ui.DefaultLogoutPageGeneratingFilter; import org.springframework.security.web.csrf.CsrfToken; -import javax.servlet.http.HttpServletRequest; -import java.util.Collections; -import java.util.Map; -import java.util.function.Function; - /** * Adds a Filter that will generate a login page if one is not specified otherwise when * using {@link WebSecurityConfigurerAdapter}. @@ -49,7 +50,8 @@ import java.util.function.Function; * *

    Shared Objects Created

    * - * No shared objects are created. isLogoutRequest

    Shared Objects Used

    + * No shared objects are created. isLogoutRequest + *

    Shared Objects Used

    * * The following shared objects are used: * @@ -60,13 +62,12 @@ import java.util.function.Function; * {@link DefaultLoginPageConfigurer} should be added and how to configure it. * * - * @see WebSecurityConfigurerAdapter - * * @author Rob Winch * @since 3.2 + * @see WebSecurityConfigurerAdapter */ -public final class DefaultLoginPageConfigurer> extends - AbstractHttpConfigurer, H> { +public final class DefaultLoginPageConfigurer> + extends AbstractHttpConfigurer, H> { private DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = new DefaultLoginPageGeneratingFilter(); @@ -74,32 +75,28 @@ public final class DefaultLoginPageConfigurer> @Override public void init(H http) { - Function> hiddenInputs = request -> { - CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); - if (token == null) { - return Collections.emptyMap(); - } - return Collections.singletonMap(token.getParameterName(), token.getToken()); - }; - this.loginPageGeneratingFilter.setResolveHiddenInputs(hiddenInputs); - this.logoutPageGeneratingFilter.setResolveHiddenInputs(hiddenInputs); - http.setSharedObject(DefaultLoginPageGeneratingFilter.class, - loginPageGeneratingFilter); + this.loginPageGeneratingFilter.setResolveHiddenInputs(DefaultLoginPageConfigurer.this::hiddenInputs); + this.logoutPageGeneratingFilter.setResolveHiddenInputs(DefaultLoginPageConfigurer.this::hiddenInputs); + http.setSharedObject(DefaultLoginPageGeneratingFilter.class, this.loginPageGeneratingFilter); + } + + private Map hiddenInputs(HttpServletRequest request) { + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + return (token != null) ? Collections.singletonMap(token.getParameterName(), token.getToken()) + : Collections.emptyMap(); } @Override @SuppressWarnings("unchecked") public void configure(H http) { AuthenticationEntryPoint authenticationEntryPoint = null; - ExceptionHandlingConfigurer exceptionConf = http - .getConfigurer(ExceptionHandlingConfigurer.class); + ExceptionHandlingConfigurer exceptionConf = http.getConfigurer(ExceptionHandlingConfigurer.class); if (exceptionConf != null) { authenticationEntryPoint = exceptionConf.getAuthenticationEntryPoint(); } - - if (loginPageGeneratingFilter.isEnabled() && authenticationEntryPoint == null) { - loginPageGeneratingFilter = postProcess(loginPageGeneratingFilter); - http.addFilter(loginPageGeneratingFilter); + if (this.loginPageGeneratingFilter.isEnabled() && authenticationEntryPoint == null) { + this.loginPageGeneratingFilter = postProcess(this.loginPageGeneratingFilter); + http.addFilter(this.loginPageGeneratingFilter); http.addFilter(this.logoutPageGeneratingFilter); } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurer.java index d1741cf898..f311358e2c 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.LinkedHashMap; @@ -62,8 +63,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher; * @author Rob Winch * @since 3.2 */ -public final class ExceptionHandlingConfigurer> extends - AbstractHttpConfigurer, H> { +public final class ExceptionHandlingConfigurer> + extends AbstractHttpConfigurer, H> { private AuthenticationEntryPoint authenticationEntryPoint; @@ -83,7 +84,6 @@ public final class ExceptionHandlingConfigurer> /** * Shortcut to specify the {@link AccessDeniedHandler} to be used is a specific error * page - * * @param accessDeniedUrl the URL to the access denied page (i.e. /errors/401) * @return the {@link ExceptionHandlingConfigurer} for further customization * @see AccessDeniedHandlerImpl @@ -97,32 +97,29 @@ public final class ExceptionHandlingConfigurer> /** * Specifies the {@link AccessDeniedHandler} to be used - * * @param accessDeniedHandler the {@link AccessDeniedHandler} to be used * @return the {@link ExceptionHandlingConfigurer} for further customization */ - public ExceptionHandlingConfigurer accessDeniedHandler( - AccessDeniedHandler accessDeniedHandler) { + public ExceptionHandlingConfigurer accessDeniedHandler(AccessDeniedHandler accessDeniedHandler) { this.accessDeniedHandler = accessDeniedHandler; return this; } /** - * Sets a default {@link AccessDeniedHandler} to be used which prefers being - * invoked for the provided {@link RequestMatcher}. If only a single default - * {@link AccessDeniedHandler} is specified, it will be what is used for the - * default {@link AccessDeniedHandler}. If multiple default - * {@link AccessDeniedHandler} instances are configured, then a + * Sets a default {@link AccessDeniedHandler} to be used which prefers being invoked + * for the provided {@link RequestMatcher}. If only a single default + * {@link AccessDeniedHandler} is specified, it will be what is used for the default + * {@link AccessDeniedHandler}. If multiple default {@link AccessDeniedHandler} + * instances are configured, then a * {@link RequestMatcherDelegatingAccessDeniedHandler} will be used. - * * @param deniedHandler the {@link AccessDeniedHandler} to use * @param preferredMatcher the {@link RequestMatcher} for this default * {@link AccessDeniedHandler} * @return the {@link ExceptionHandlingConfigurer} for further customizations * @since 5.1 */ - public ExceptionHandlingConfigurer defaultAccessDeniedHandlerFor( - AccessDeniedHandler deniedHandler, RequestMatcher preferredMatcher) { + public ExceptionHandlingConfigurer defaultAccessDeniedHandlerFor(AccessDeniedHandler deniedHandler, + RequestMatcher preferredMatcher) { this.defaultDeniedHandlerMappings.put(preferredMatcher, deniedHandler); return this; } @@ -141,12 +138,10 @@ public final class ExceptionHandlingConfigurer> *

    * If that is not provided defaults to {@link Http403ForbiddenEntryPoint}. *

    - * * @param authenticationEntryPoint the {@link AuthenticationEntryPoint} to use * @return the {@link ExceptionHandlingConfigurer} for further customizations */ - public ExceptionHandlingConfigurer authenticationEntryPoint( - AuthenticationEntryPoint authenticationEntryPoint) { + public ExceptionHandlingConfigurer authenticationEntryPoint(AuthenticationEntryPoint authenticationEntryPoint) { this.authenticationEntryPoint = authenticationEntryPoint; return this; } @@ -158,14 +153,13 @@ public final class ExceptionHandlingConfigurer> * default {@link AuthenticationEntryPoint}. If multiple default * {@link AuthenticationEntryPoint} instances are configured, then a * {@link DelegatingAuthenticationEntryPoint} will be used. - * * @param entryPoint the {@link AuthenticationEntryPoint} to use * @param preferredMatcher the {@link RequestMatcher} for this default * {@link AuthenticationEntryPoint} * @return the {@link ExceptionHandlingConfigurer} for further customizations */ - public ExceptionHandlingConfigurer defaultAuthenticationEntryPointFor( - AuthenticationEntryPoint entryPoint, RequestMatcher preferredMatcher) { + public ExceptionHandlingConfigurer defaultAuthenticationEntryPointFor(AuthenticationEntryPoint entryPoint, + RequestMatcher preferredMatcher) { this.defaultEntryPointMappings.put(preferredMatcher, entryPoint); return this; } @@ -180,7 +174,6 @@ public final class ExceptionHandlingConfigurer> /** * Gets the {@link AccessDeniedHandler} that is configured. - * * @return the {@link AccessDeniedHandler} */ AccessDeniedHandler getAccessDeniedHandler() { @@ -190,8 +183,8 @@ public final class ExceptionHandlingConfigurer> @Override public void configure(H http) { AuthenticationEntryPoint entryPoint = getAuthenticationEntryPoint(http); - ExceptionTranslationFilter exceptionTranslationFilter = new ExceptionTranslationFilter( - entryPoint, getRequestCache(http)); + ExceptionTranslationFilter exceptionTranslationFilter = new ExceptionTranslationFilter(entryPoint, + getRequestCache(http)); AccessDeniedHandler deniedHandler = getAccessDeniedHandler(http); exceptionTranslationFilter.setAccessDeniedHandler(deniedHandler); exceptionTranslationFilter = postProcess(exceptionTranslationFilter); @@ -235,8 +228,7 @@ public final class ExceptionHandlingConfigurer> if (this.defaultDeniedHandlerMappings.size() == 1) { return this.defaultDeniedHandlerMappings.values().iterator().next(); } - return new RequestMatcherDelegatingAccessDeniedHandler( - this.defaultDeniedHandlerMappings, + return new RequestMatcherDelegatingAccessDeniedHandler(this.defaultDeniedHandlerMappings, new AccessDeniedHandlerImpl()); } @@ -249,8 +241,7 @@ public final class ExceptionHandlingConfigurer> } DelegatingAuthenticationEntryPoint entryPoint = new DelegatingAuthenticationEntryPoint( this.defaultEntryPointMappings); - entryPoint.setDefaultEntryPoint(this.defaultEntryPointMappings.values().iterator() - .next()); + entryPoint.setDefaultEntryPoint(this.defaultEntryPointMappings.values().iterator().next()); return entryPoint; } @@ -259,7 +250,6 @@ public final class ExceptionHandlingConfigurer> * {@link #requestCache(org.springframework.security.web.savedrequest.RequestCache)}, * then it is used. Otherwise, an attempt to find a {@link RequestCache} shared object * is made. If that fails, an {@link HttpSessionRequestCache} is used - * * @param http the {@link HttpSecurity} to attempt to fined the shared object * @return the {@link RequestCache} to use */ @@ -270,4 +260,5 @@ public final class ExceptionHandlingConfigurer> } return new HttpSessionRequestCache(); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurer.java index d1df31e7f5..949b833da4 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.ArrayList; @@ -46,7 +47,8 @@ import org.springframework.util.StringUtils; * Adds URL based authorization based upon SpEL expressions to an application. At least * one {@link org.springframework.web.bind.annotation.RequestMapping} needs to be mapped * to {@link ConfigAttribute}'s for this {@link SecurityContextConfigurer} to have - * meaning.

    Security Filters

    + * meaning. + *

    Security Filters

    * * The following Filters are populated * @@ -73,19 +75,23 @@ import org.springframework.util.StringUtils; * * * @param the type of {@link HttpSecurityBuilder} that is being configured - * * @author Rob Winch * @since 3.2 * @see org.springframework.security.config.annotation.web.builders.HttpSecurity#authorizeRequests() */ public final class ExpressionUrlAuthorizationConfigurer> - extends - AbstractInterceptUrlConfigurer, H> { + extends AbstractInterceptUrlConfigurer, H> { + static final String permitAll = "permitAll"; + private static final String denyAll = "denyAll"; + private static final String anonymous = "anonymous"; + private static final String authenticated = "authenticated"; + private static final String fullyAuthenticated = "fullyAuthenticated"; + private static final String rememberMe = "rememberMe"; private final ExpressionInterceptUrlRegistry REGISTRY; @@ -101,73 +107,12 @@ public final class ExpressionUrlAuthorizationConfigurer.AbstractInterceptUrlRegistry { - - /** - * @param context - */ - private ExpressionInterceptUrlRegistry(ApplicationContext context) { - setApplicationContext(context); - } - - @Override - public MvcMatchersAuthorizedUrl mvcMatchers(HttpMethod method, String... mvcPatterns) { - return new MvcMatchersAuthorizedUrl(createMvcMatchers(method, mvcPatterns)); - } - - @Override - public MvcMatchersAuthorizedUrl mvcMatchers(String... patterns) { - return mvcMatchers(null, patterns); - } - - @Override - protected final AuthorizedUrl chainRequestMatchersInternal( - List requestMatchers) { - return new AuthorizedUrl(requestMatchers); - } - - /** - * Allows customization of the {@link SecurityExpressionHandler} to be used. The - * default is {@link DefaultWebSecurityExpressionHandler} - * - * @param expressionHandler the {@link SecurityExpressionHandler} to be used - * @return the {@link ExpressionUrlAuthorizationConfigurer} for further - * customization. - */ - public ExpressionInterceptUrlRegistry expressionHandler( - SecurityExpressionHandler expressionHandler) { - ExpressionUrlAuthorizationConfigurer.this.expressionHandler = expressionHandler; - return this; - } - - /** - * Adds an {@link ObjectPostProcessor} for this class. - * - * @param objectPostProcessor - * @return the {@link ExpressionUrlAuthorizationConfigurer} for further - * customizations - */ - public ExpressionInterceptUrlRegistry withObjectPostProcessor( - ObjectPostProcessor objectPostProcessor) { - addObjectPostProcessor(objectPostProcessor); - return this; - } - - public H and() { - return ExpressionUrlAuthorizationConfigurer.this.and(); - } - + return this.REGISTRY; } /** * Allows registering multiple {@link RequestMatcher} instances to a collection of * {@link ConfigAttribute} instances - * * @param requestMatchers the {@link RequestMatcher} instances to register to the * {@link ConfigAttribute} instances * @param configAttributes the {@link ConfigAttribute} to be mapped by the @@ -176,8 +121,8 @@ public final class ExpressionUrlAuthorizationConfigurer requestMatchers, Collection configAttributes) { for (RequestMatcher requestMatcher : requestMatchers) { - REGISTRY.addMapping(new AbstractConfigAttributeRequestMatcherRegistry.UrlMapping( - requestMatcher, configAttributes)); + this.REGISTRY.addMapping( + new AbstractConfigAttributeRequestMatcherRegistry.UrlMapping(requestMatcher, configAttributes)); } } @@ -192,63 +137,54 @@ public final class ExpressionUrlAuthorizationConfigurer> requestMap = REGISTRY - .createRequestMap(); - if (requestMap.isEmpty()) { - throw new IllegalStateException( - "At least one mapping is required (i.e. authorizeRequests().anyRequest().authenticated())"); - } - return new ExpressionBasedFilterInvocationSecurityMetadataSource(requestMap, - getExpressionHandler(http)); + ExpressionBasedFilterInvocationSecurityMetadataSource createMetadataSource(H http) { + LinkedHashMap> requestMap = this.REGISTRY.createRequestMap(); + Assert.state(!requestMap.isEmpty(), + "At least one mapping is required (i.e. authorizeRequests().anyRequest().authenticated())"); + return new ExpressionBasedFilterInvocationSecurityMetadataSource(requestMap, getExpressionHandler(http)); } private SecurityExpressionHandler getExpressionHandler(H http) { - if (expressionHandler == null) { - DefaultWebSecurityExpressionHandler defaultHandler = new DefaultWebSecurityExpressionHandler(); - AuthenticationTrustResolver trustResolver = http - .getSharedObject(AuthenticationTrustResolver.class); - if (trustResolver != null) { - defaultHandler.setTrustResolver(trustResolver); - } - ApplicationContext context = http.getSharedObject(ApplicationContext.class); - if (context != null) { - String[] roleHiearchyBeanNames = context.getBeanNamesForType(RoleHierarchy.class); - if (roleHiearchyBeanNames.length == 1) { - defaultHandler.setRoleHierarchy(context.getBean(roleHiearchyBeanNames[0], RoleHierarchy.class)); - } - String[] grantedAuthorityDefaultsBeanNames = context.getBeanNamesForType(GrantedAuthorityDefaults.class); - if (grantedAuthorityDefaultsBeanNames.length == 1) { - GrantedAuthorityDefaults grantedAuthorityDefaults = context.getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); - defaultHandler.setDefaultRolePrefix(grantedAuthorityDefaults.getRolePrefix()); - } - String[] permissionEvaluatorBeanNames = context.getBeanNamesForType(PermissionEvaluator.class); - if (permissionEvaluatorBeanNames.length == 1) { - PermissionEvaluator permissionEvaluator = context.getBean(permissionEvaluatorBeanNames[0], PermissionEvaluator.class); - defaultHandler.setPermissionEvaluator(permissionEvaluator); - } - } - - expressionHandler = postProcess(defaultHandler); + if (this.expressionHandler != null) { + return this.expressionHandler; } - - return expressionHandler; + DefaultWebSecurityExpressionHandler defaultHandler = new DefaultWebSecurityExpressionHandler(); + AuthenticationTrustResolver trustResolver = http.getSharedObject(AuthenticationTrustResolver.class); + if (trustResolver != null) { + defaultHandler.setTrustResolver(trustResolver); + } + ApplicationContext context = http.getSharedObject(ApplicationContext.class); + if (context != null) { + String[] roleHiearchyBeanNames = context.getBeanNamesForType(RoleHierarchy.class); + if (roleHiearchyBeanNames.length == 1) { + defaultHandler.setRoleHierarchy(context.getBean(roleHiearchyBeanNames[0], RoleHierarchy.class)); + } + String[] grantedAuthorityDefaultsBeanNames = context.getBeanNamesForType(GrantedAuthorityDefaults.class); + if (grantedAuthorityDefaultsBeanNames.length == 1) { + GrantedAuthorityDefaults grantedAuthorityDefaults = context + .getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); + defaultHandler.setDefaultRolePrefix(grantedAuthorityDefaults.getRolePrefix()); + } + String[] permissionEvaluatorBeanNames = context.getBeanNamesForType(PermissionEvaluator.class); + if (permissionEvaluatorBeanNames.length == 1) { + PermissionEvaluator permissionEvaluator = context.getBean(permissionEvaluatorBeanNames[0], + PermissionEvaluator.class); + defaultHandler.setPermissionEvaluator(permissionEvaluator); + } + } + this.expressionHandler = postProcess(defaultHandler); + return this.expressionHandler; } private static String hasAnyRole(String... authorities) { - String anyAuthorities = StringUtils.arrayToDelimitedString(authorities, - "','ROLE_"); + String anyAuthorities = StringUtils.arrayToDelimitedString(authorities, "','ROLE_"); return "hasAnyRole('ROLE_" + anyAuthorities + "')"; } private static String hasRole(String role) { Assert.notNull(role, "role cannot be null"); - if (role.startsWith("ROLE_")) { - throw new IllegalArgumentException( - "role should not start with 'ROLE_' since it is automatically inserted. Got '" - + role + "'"); - } + Assert.isTrue(!role.startsWith("ROLE_"), + () -> "role should not start with 'ROLE_' since it is automatically inserted. Got '" + role + "'"); return "hasRole('ROLE_" + role + "')"; } @@ -265,16 +201,68 @@ public final class ExpressionUrlAuthorizationConfigurer.AbstractInterceptUrlRegistry { + + private ExpressionInterceptUrlRegistry(ApplicationContext context) { + setApplicationContext(context); + } + + @Override + public MvcMatchersAuthorizedUrl mvcMatchers(HttpMethod method, String... mvcPatterns) { + return new MvcMatchersAuthorizedUrl(createMvcMatchers(method, mvcPatterns)); + } + + @Override + public MvcMatchersAuthorizedUrl mvcMatchers(String... patterns) { + return mvcMatchers(null, patterns); + } + + @Override + protected AuthorizedUrl chainRequestMatchersInternal(List requestMatchers) { + return new AuthorizedUrl(requestMatchers); + } + + /** + * Allows customization of the {@link SecurityExpressionHandler} to be used. The + * default is {@link DefaultWebSecurityExpressionHandler} + * @param expressionHandler the {@link SecurityExpressionHandler} to be used + * @return the {@link ExpressionUrlAuthorizationConfigurer} for further + * customization. + */ + public ExpressionInterceptUrlRegistry expressionHandler( + SecurityExpressionHandler expressionHandler) { + ExpressionUrlAuthorizationConfigurer.this.expressionHandler = expressionHandler; + return this; + } + + /** + * Adds an {@link ObjectPostProcessor} for this class. + * @param objectPostProcessor + * @return the {@link ExpressionUrlAuthorizationConfigurer} for further + * customizations + */ + public ExpressionInterceptUrlRegistry withObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { + addObjectPostProcessor(objectPostProcessor); + return this; + } + + public H and() { + return ExpressionUrlAuthorizationConfigurer.this.and(); + } + + } + /** * An {@link AuthorizedUrl} that allows optionally configuring the * {@link MvcRequestMatcher#setMethod(HttpMethod)} * * @author Rob Winch */ - public class MvcMatchersAuthorizedUrl extends AuthorizedUrl { + public final class MvcMatchersAuthorizedUrl extends AuthorizedUrl { + /** * Creates a new instance - * * @param requestMatchers the {@link RequestMatcher} instances to map */ private MvcMatchersAuthorizedUrl(List requestMatchers) { @@ -287,18 +275,20 @@ public final class ExpressionUrlAuthorizationConfigurer requestMatchers; + private boolean not; /** * Creates a new instance - * * @param requestMatchers the {@link RequestMatcher} instances to map */ - private AuthorizedUrl(List requestMatchers) { + AuthorizedUrl(List requestMatchers) { this.requestMatchers = requestMatchers; } @@ -308,7 +298,6 @@ public final class ExpressionUrlAuthorizationConfigurersubnet. - * * @param ipaddressExpression the ipaddress (i.e. 192.168.1.79) or local subnet * (i.e. 192.168.0/24) * @return the {@link ExpressionUrlAuthorizationConfigurer} for further * customization */ public ExpressionInterceptUrlRegistry hasIpAddress(String ipaddressExpression) { - return access(ExpressionUrlAuthorizationConfigurer - .hasIpAddress(ipaddressExpression)); + return access(ExpressionUrlAuthorizationConfigurer.hasIpAddress(ipaddressExpression)); } /** * Specify that URLs are allowed by anyone. - * * @return the {@link ExpressionUrlAuthorizationConfigurer} for further * customization */ @@ -396,7 +377,6 @@ public final class ExpressionUrlAuthorizationConfigurer> extends *
  • /authenticate?error GET - redirect here for failed authentication attempts
  • *
  • /authenticate?logout GET - redirect here after successfully logging out
  • * - * - * * @param loginPage the login page to redirect to if authentication is required (i.e. * "/login") * @return the {@link FormLoginConfigurer} for additional customization @@ -186,7 +185,6 @@ public final class FormLoginConfigurer> extends /** * The HTTP parameter to look for the username when performing authentication. Default * is "username". - * * @param usernameParameter the HTTP parameter to look for the username when * performing authentication * @return the {@link FormLoginConfigurer} for additional customization @@ -199,7 +197,6 @@ public final class FormLoginConfigurer> extends /** * The HTTP parameter to look for the password when performing authentication. Default * is "password". - * * @param passwordParameter the HTTP parameter to look for the password when * performing authentication * @return the {@link FormLoginConfigurer} for additional customization @@ -211,7 +208,6 @@ public final class FormLoginConfigurer> extends /** * Forward Authentication Failure Handler - * * @param forwardUrl the target URL in case of failure * @return the {@link FormLoginConfigurer} for additional customization */ @@ -222,7 +218,6 @@ public final class FormLoginConfigurer> extends /** * Forward Authentication Success Handler - * * @param forwardUrl the target URL in case of success * @return the {@link FormLoginConfigurer} for additional customization */ @@ -237,13 +232,6 @@ public final class FormLoginConfigurer> extends initDefaultLoginFilter(http); } - /* - * (non-Javadoc) - * - * @see org.springframework.security.config.annotation.web.configurers. - * AbstractAuthenticationFilterConfigurer - * #createLoginProcessingUrlMatcher(java.lang.String) - */ @Override protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingUrl) { return new AntPathRequestMatcher(loginProcessingUrl, "POST"); @@ -251,7 +239,6 @@ public final class FormLoginConfigurer> extends /** * Gets the HTTP parameter that is used to submit the username. - * * @return the HTTP parameter that is used to submit the username */ private String getUsernameParameter() { @@ -260,7 +247,6 @@ public final class FormLoginConfigurer> extends /** * Gets the HTTP parameter that is used to submit the password. - * * @return the HTTP parameter that is used to submit the password */ private String getPasswordParameter() { @@ -270,7 +256,6 @@ public final class FormLoginConfigurer> extends /** * If available, initializes the {@link DefaultLoginPageGeneratingFilter} shared * object. - * * @param http the {@link HttpSecurityBuilder} to use */ private void initDefaultLoginFilter(H http) { @@ -285,4 +270,5 @@ public final class FormLoginConfigurer> extends loginPageGeneratingFilter.setAuthenticationUrl(getLoginProcessingUrl()); } } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurer.java index adcffa0648..9d007ec7fa 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurer.java @@ -29,8 +29,15 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.web.header.HeaderWriter; import org.springframework.security.web.header.HeaderWriterFilter; -import org.springframework.security.web.header.writers.*; +import org.springframework.security.web.header.writers.CacheControlHeadersWriter; +import org.springframework.security.web.header.writers.ContentSecurityPolicyHeaderWriter; +import org.springframework.security.web.header.writers.FeaturePolicyHeaderWriter; +import org.springframework.security.web.header.writers.HpkpHeaderWriter; +import org.springframework.security.web.header.writers.HstsHeaderWriter; +import org.springframework.security.web.header.writers.ReferrerPolicyHeaderWriter; import org.springframework.security.web.header.writers.ReferrerPolicyHeaderWriter.ReferrerPolicy; +import org.springframework.security.web.header.writers.XContentTypeOptionsHeaderWriter; +import org.springframework.security.web.header.writers.XXssProtectionHeaderWriter; import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter; import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter.XFrameOptionsMode; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -63,11 +70,10 @@ import org.springframework.util.Assert; * @author Vedran Pavic * @since 3.2 */ -public class HeadersConfigurer> extends - AbstractHttpConfigurer, H> { - private List headerWriters = new ArrayList<>(); +public class HeadersConfigurer> + extends AbstractHttpConfigurer, H> { - // --- default header writers --- + private List headerWriters = new ArrayList<>(); private final ContentTypeOptionsConfig contentTypeOptions = new ContentTypeOptionsConfig(); @@ -97,7 +103,6 @@ public class HeadersConfigurer> extends /** * Adds a {@link HeaderWriter} instance - * * @param headerWriter the {@link HeaderWriter} instance to add * @return the {@link HeadersConfigurer} for additional customizations */ @@ -108,76 +113,36 @@ public class HeadersConfigurer> extends } /** - * Configures the {@link XContentTypeOptionsHeaderWriter} which inserts the X-Content-Type-Options: * *
     	 * X-Content-Type-Options: nosniff
     	 * 
    - * * @return the {@link ContentTypeOptionsConfig} for additional customizations */ public ContentTypeOptionsConfig contentTypeOptions() { - return contentTypeOptions.enable(); + return this.contentTypeOptions.enable(); } /** - * Configures the {@link XContentTypeOptionsHeaderWriter} which inserts the X-Content-Type-Options: * *
     	 * X-Content-Type-Options: nosniff
     	 * 
    - * - * @param contentTypeOptionsCustomizer the {@link Customizer} to provide more options for - * the {@link ContentTypeOptionsConfig} + * @param contentTypeOptionsCustomizer the {@link Customizer} to provide more options + * for the {@link ContentTypeOptionsConfig} * @return the {@link HeadersConfigurer} for additional customizations */ public HeadersConfigurer contentTypeOptions(Customizer contentTypeOptionsCustomizer) { - contentTypeOptionsCustomizer.customize(contentTypeOptions.enable()); + contentTypeOptionsCustomizer.customize(this.contentTypeOptions.enable()); return HeadersConfigurer.this; } - public final class ContentTypeOptionsConfig { - private XContentTypeOptionsHeaderWriter writer; - - private ContentTypeOptionsConfig() { - enable(); - } - - /** - * Removes the X-XSS-Protection header. - * - * @return {@link HeadersConfigurer} for additional customization. - */ - public HeadersConfigurer disable() { - writer = null; - return and(); - } - - /** - * Allows customizing the {@link HeadersConfigurer} - * @return the {@link HeadersConfigurer} for additional customization - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - /** - * Ensures that Content Type Options is enabled - * - * @return the {@link ContentTypeOptionsConfig} for additional customization - */ - private ContentTypeOptionsConfig enable() { - if (writer == null) { - writer = new XContentTypeOptionsHeaderWriter(); - } - return this; - } - } - /** * Note this is not comprehensive XSS protection! * @@ -186,11 +151,10 @@ public class HeadersConfigurer> extends * "https://blogs.msdn.com/b/ieinternals/archive/2011/01/31/controlling-the-internet-explorer-xss-filter-with-the-x-xss-protection-http-header.aspx" * >X-XSS-Protection header *

    - * * @return the {@link XXssConfig} for additional customizations */ public XXssConfig xssProtection() { - return xssProtection.enable(); + return this.xssProtection.enable(); } /** @@ -201,95 +165,15 @@ public class HeadersConfigurer> extends * "https://blogs.msdn.com/b/ieinternals/archive/2011/01/31/controlling-the-internet-explorer-xss-filter-with-the-x-xss-protection-http-header.aspx" * >X-XSS-Protection header *

    - * - * @param xssCustomizer the {@link Customizer} to provide more options for - * the {@link XXssConfig} + * @param xssCustomizer the {@link Customizer} to provide more options for the + * {@link XXssConfig} * @return the {@link HeadersConfigurer} for additional customizations */ public HeadersConfigurer xssProtection(Customizer xssCustomizer) { - xssCustomizer.customize(xssProtection.enable()); + xssCustomizer.customize(this.xssProtection.enable()); return HeadersConfigurer.this; } - public final class XXssConfig { - private XXssProtectionHeaderWriter writer; - - private XXssConfig() { - enable(); - } - - /** - * If false, will not specify the mode as blocked. In this instance, any content - * will be attempted to be fixed. If true, the content will be replaced with "#". - * - * @param enabled the new value - */ - public XXssConfig block(boolean enabled) { - writer.setBlock(enabled); - return this; - } - - /** - * If true, the header value will contain a value of 1. For example: - * - *
    -		 * X-XSS-Protection: 1
    -		 * 
    - * - * or if {@link XXssProtectionHeaderWriter#setBlock(boolean)} of the given {@link XXssProtectionHeaderWriter} is true - * - * - *
    -		 * X-XSS-Protection: 1; mode=block
    -		 * 
    - * - * If false, will explicitly disable specify that X-XSS-Protection is disabled. - * For example: - * - *
    -		 * X-XSS-Protection: 0
    -		 * 
    - * - * @param enabled the new value - */ - public XXssConfig xssProtectionEnabled(boolean enabled) { - writer.setEnabled(enabled); - return this; - } - - /** - * Disables X-XSS-Protection header (does not include it) - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer disable() { - writer = null; - return and(); - } - - /** - * Allows completing configuration of X-XSS-Protection and continuing - * configuration of headers. - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - /** - * Ensures the X-XSS-Protection header is enabled if it is not already. - * - * @return the {@link XXssConfig} for additional customization - */ - private XXssConfig enable() { - if (writer == null) { - writer = new XXssProtectionHeaderWriter(); - } - return this; - } - } - /** * Allows customizing the {@link CacheControlHeadersWriter}. Specifically it adds the * following headers: @@ -298,11 +182,10 @@ public class HeadersConfigurer> extends *
  • Pragma: no-cache
  • *
  • Expires: 0
  • * - * * @return the {@link CacheControlConfig} for additional customizations */ public CacheControlConfig cacheControl() { - return cacheControl.enable(); + return this.cacheControl.enable(); } /** @@ -313,571 +196,142 @@ public class HeadersConfigurer> extends *
  • Pragma: no-cache
  • *
  • Expires: 0
  • * - * * @param cacheControlCustomizer the {@link Customizer} to provide more options for * the {@link CacheControlConfig} * @return the {@link HeadersConfigurer} for additional customizations */ public HeadersConfigurer cacheControl(Customizer cacheControlCustomizer) { - cacheControlCustomizer.customize(cacheControl.enable()); + cacheControlCustomizer.customize(this.cacheControl.enable()); return HeadersConfigurer.this; } - public final class CacheControlConfig { - private CacheControlHeadersWriter writer; - - private CacheControlConfig() { - enable(); - } - - /** - * Disables Cache Control - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer disable() { - writer = null; - return HeadersConfigurer.this; - } - - /** - * Allows completing configuration of Cache Control and continuing - * configuration of headers. - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - /** - * Ensures the Cache Control headers are enabled if they are not already. - * - * @return the {@link CacheControlConfig} for additional customization - */ - private CacheControlConfig enable() { - if (writer == null) { - writer = new CacheControlHeadersWriter(); - } - return this; - } - } - /** - * Allows customizing the {@link HstsHeaderWriter} which provides support for HTTP Strict Transport Security + * Allows customizing the {@link HstsHeaderWriter} which provides support for + * HTTP Strict Transport Security * (HSTS). - * * @return the {@link HstsConfig} for additional customizations */ public HstsConfig httpStrictTransportSecurity() { - return hsts.enable(); + return this.hsts.enable(); } /** - * Allows customizing the {@link HstsHeaderWriter} which provides support for HTTP Strict Transport Security + * Allows customizing the {@link HstsHeaderWriter} which provides support for + * HTTP Strict Transport Security * (HSTS). - * - * @param hstsCustomizer the {@link Customizer} to provide more options for - * the {@link HstsConfig} + * @param hstsCustomizer the {@link Customizer} to provide more options for the + * {@link HstsConfig} * @return the {@link HeadersConfigurer} for additional customizations */ public HeadersConfigurer httpStrictTransportSecurity(Customizer hstsCustomizer) { - hstsCustomizer.customize(hsts.enable()); + hstsCustomizer.customize(this.hsts.enable()); return HeadersConfigurer.this; } - public final class HstsConfig { - private HstsHeaderWriter writer; - - private HstsConfig() { - enable(); - } - - /** - *

    - * Sets the value (in seconds) for the max-age directive of the - * Strict-Transport-Security header. The default is one year. - *

    - * - *

    - * This instructs browsers how long to remember to keep this domain as a known - * HSTS Host. See Section 6.1.1 for - * additional details. - *

    - * - * @param maxAgeInSeconds the maximum amount of time (in seconds) to consider this - * domain as a known HSTS Host. - * @throws IllegalArgumentException if maxAgeInSeconds is negative - */ - public HstsConfig maxAgeInSeconds(long maxAgeInSeconds) { - writer.setMaxAgeInSeconds(maxAgeInSeconds); - return this; - } - - /** - * Sets the {@link RequestMatcher} used to determine if the - * "Strict-Transport-Security" should be added. If true the header is added, else - * the header is not added. By default the header is added when - * {@link HttpServletRequest#isSecure()} returns true. - * - * @param requestMatcher the {@link RequestMatcher} to use. - * @throws IllegalArgumentException if {@link RequestMatcher} is null - */ - public HstsConfig requestMatcher(RequestMatcher requestMatcher) { - writer.setRequestMatcher(requestMatcher); - return this; - } - - /** - *

    - * If true, subdomains should be considered HSTS Hosts too. The default is true. - *

    - * - *

    - * See Section - * 6.1.2 for additional details. - *

    - * - * @param includeSubDomains true to include subdomains, else false - */ - public HstsConfig includeSubDomains(boolean includeSubDomains) { - writer.setIncludeSubDomains(includeSubDomains); - return this; - } - - /** - *

    - * If true, preload will be included in HSTS Header. The default is false. - *

    - * - *

    - * See Website hstspreload.org - * for additional details. - *

    - * - * @param preload true to include preload, else false - * @since 5.2.0 - * @author Ankur Pathak - */ - public HstsConfig preload(boolean preload) { - writer.setPreload(preload); - return this; - } - - /** - * Disables Strict Transport Security - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer disable() { - writer = null; - return HeadersConfigurer.this; - } - - /** - * Allows completing configuration of Strict Transport Security and continuing - * configuration of headers. - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - /** - * Ensures that Strict-Transport-Security is enabled if it is not already - * - * @return the {@link HstsConfig} for additional customization - */ - private HstsConfig enable() { - if (writer == null) { - writer = new HstsHeaderWriter(); - } - return this; - } - } - /** * Allows customizing the {@link XFrameOptionsHeaderWriter}. - * * @return the {@link FrameOptionsConfig} for additional customizations */ public FrameOptionsConfig frameOptions() { - return frameOptions.enable(); + return this.frameOptions.enable(); } /** * Allows customizing the {@link XFrameOptionsHeaderWriter}. - * * @param frameOptionsCustomizer the {@link Customizer} to provide more options for * the {@link FrameOptionsConfig} * @return the {@link HeadersConfigurer} for additional customizations */ public HeadersConfigurer frameOptions(Customizer frameOptionsCustomizer) { - frameOptionsCustomizer.customize(frameOptions.enable()); + frameOptionsCustomizer.customize(this.frameOptions.enable()); return HeadersConfigurer.this; } - public final class FrameOptionsConfig { - private XFrameOptionsHeaderWriter writer; - - private FrameOptionsConfig() { - enable(); - } - - /** - * Specify to DENY framing any content from this application. - * - * @return the {@link HeadersConfigurer} for additional customization. - */ - public HeadersConfigurer deny() { - writer = new XFrameOptionsHeaderWriter(XFrameOptionsMode.DENY); - return and(); - } - - /** - *

    - * Specify to allow any request that comes from the same origin to frame this - * application. For example, if the application was hosted on example.com, then - * example.com could frame the application, but evil.com could not frame the - * application. - *

    - * - * @return the {@link HeadersConfigurer} for additional customization. - */ - public HeadersConfigurer sameOrigin() { - writer = new XFrameOptionsHeaderWriter(XFrameOptionsMode.SAMEORIGIN); - return and(); - } - - /** - * Prevents the header from being added to the response. - * - * @return the {@link HeadersConfigurer} for additional configuration. - */ - public HeadersConfigurer disable() { - writer = null; - return and(); - } - - /** - * Allows continuing customizing the headers configuration. - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - /** - * Enables FrameOptionsConfig if it is not already enabled. - * - * @return the FrameOptionsConfig for additional customization. - */ - private FrameOptionsConfig enable() { - if (writer == null) { - writer = new XFrameOptionsHeaderWriter(XFrameOptionsMode.DENY); - } - return this; - } - } - /** - * Allows customizing the {@link HpkpHeaderWriter} which provides support for HTTP Public Key Pinning (HPKP). - * + * Allows customizing the {@link HpkpHeaderWriter} which provides support for + * HTTP Public Key Pinning (HPKP). * @return the {@link HpkpConfig} for additional customizations * * @since 4.1 */ public HpkpConfig httpPublicKeyPinning() { - return hpkp.enable(); + return this.hpkp.enable(); } /** - * Allows customizing the {@link HpkpHeaderWriter} which provides support for HTTP Public Key Pinning (HPKP). - * - * @param hpkpCustomizer the {@link Customizer} to provide more options for - * the {@link HpkpConfig} + * Allows customizing the {@link HpkpHeaderWriter} which provides support for + * HTTP Public Key Pinning (HPKP). + * @param hpkpCustomizer the {@link Customizer} to provide more options for the + * {@link HpkpConfig} * @return the {@link HeadersConfigurer} for additional customizations */ public HeadersConfigurer httpPublicKeyPinning(Customizer hpkpCustomizer) { - hpkpCustomizer.customize(hpkp.enable()); + hpkpCustomizer.customize(this.hpkp.enable()); return HeadersConfigurer.this; } - public final class HpkpConfig { - private HpkpHeaderWriter writer; - - private HpkpConfig() {} - - /** - *

    - * Sets the value for the pin- directive of the Public-Key-Pins header. - *

    - * - *

    - * The pin directive specifies a way for web host operators to indicate - * a cryptographic identity that should be bound to a given web host. - * See Section 2.1.1 for additional details. - *

    - * - * @param pins the map of base64-encoded SPKI fingerprint & cryptographic hash algorithm pairs. - * @throws IllegalArgumentException if pins is null - */ - public HpkpConfig withPins(Map pins) { - writer.setPins(pins); - return this; - } - - /** - *

    - * Adds a list of SHA256 hashed pins for the pin- directive of the Public-Key-Pins header. - *

    - * - *

    - * The pin directive specifies a way for web host operators to indicate - * a cryptographic identity that should be bound to a given web host. - * See Section 2.1.1 for additional details. - *

    - * - * @param pins a list of base64-encoded SPKI fingerprints. - * @throws IllegalArgumentException if a pin is null - */ - public HpkpConfig addSha256Pins(String ... pins) { - writer.addSha256Pins(pins); - return this; - } - - /** - *

    - * Sets the value (in seconds) for the max-age directive of the Public-Key-Pins header. - * The default is 60 days. - *

    - * - *

    - * This instructs browsers how long they should regard the host (from whom the message was received) - * as a known pinned host. See Section - * 2.1.2 for additional details. - *

    - * - * @param maxAgeInSeconds the maximum amount of time (in seconds) to regard the host - * as a known pinned host. - * @throws IllegalArgumentException if maxAgeInSeconds is negative - */ - public HpkpConfig maxAgeInSeconds(long maxAgeInSeconds) { - writer.setMaxAgeInSeconds(maxAgeInSeconds); - return this; - } - - /** - *

    - * If true, the pinning policy applies to this pinned host as well as any subdomains - * of the host's domain name. The default is false. - *

    - * - *

    - * See Section 2.1.3 - * for additional details. - *

    - * - * @param includeSubDomains true to include subdomains, else false - */ - public HpkpConfig includeSubDomains(boolean includeSubDomains) { - writer.setIncludeSubDomains(includeSubDomains); - return this; - } - - /** - *

    - * If true, the browser should not terminate the connection with the server. The default is true. - *

    - * - *

    - * See Section 2.1 - * for additional details. - *

    - * - * @param reportOnly true to report only, else false - */ - public HpkpConfig reportOnly(boolean reportOnly) { - writer.setReportOnly(reportOnly); - return this; - } - - /** - *

    - * Sets the URI to which the browser should report pin validation failures. - *

    - * - *

    - * See Section 2.1.4 - * for additional details. - *

    - * - * @param reportUri the URI where the browser should send the report to. - */ - public HpkpConfig reportUri(URI reportUri) { - writer.setReportUri(reportUri); - return this; - } - - /** - *

    - * Sets the URI to which the browser should report pin validation failures. - *

    - * - *

    - * See Section 2.1.4 - * for additional details. - *

    - * - * @param reportUri the URI where the browser should send the report to. - * @throws IllegalArgumentException if the reportUri is not a valid URI - */ - public HpkpConfig reportUri(String reportUri) { - writer.setReportUri(reportUri); - return this; - } - - /** - * Prevents the header from being added to the response. - * - * @return the {@link HeadersConfigurer} for additional configuration. - */ - public HeadersConfigurer disable() { - writer = null; - return and(); - } - - /** - * Allows completing configuration of Public Key Pinning and continuing - * configuration of headers. - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - /** - * Ensures that Public-Key-Pins or Public-Key-Pins-Report-Only is enabled if it is not already - * - * @return the {@link HstsConfig} for additional customization - */ - private HpkpConfig enable() { - if (writer == null) { - writer = new HpkpHeaderWriter(); - } - return this; - } - } - /** *

    - * Allows configuration for Content Security Policy (CSP) Level 2. + * Allows configuration for Content Security + * Policy (CSP) Level 2. *

    * *

    - * Calling this method automatically enables (includes) the Content-Security-Policy header in the response - * using the supplied security policy directive(s). + * Calling this method automatically enables (includes) the Content-Security-Policy + * header in the response using the supplied security policy directive(s). *

    * *

    - * Configuration is provided to the {@link ContentSecurityPolicyHeaderWriter} which supports the writing - * of the two headers as detailed in the W3C Candidate Recommendation: + * Configuration is provided to the {@link ContentSecurityPolicyHeaderWriter} which + * supports the writing of the two headers as detailed in the W3C Candidate + * Recommendation: *

    *
      - *
    • Content-Security-Policy
    • - *
    • Content-Security-Policy-Report-Only
    • + *
    • Content-Security-Policy
    • + *
    • Content-Security-Policy-Report-Only
    • *
    - * - * @see ContentSecurityPolicyHeaderWriter - * @since 4.1 * @return the {@link ContentSecurityPolicyConfig} for additional configuration * @throws IllegalArgumentException if policyDirectives is null or empty + * @since 4.1 + * @see ContentSecurityPolicyHeaderWriter */ public ContentSecurityPolicyConfig contentSecurityPolicy(String policyDirectives) { - this.contentSecurityPolicy.writer = - new ContentSecurityPolicyHeaderWriter(policyDirectives); - return contentSecurityPolicy; + this.contentSecurityPolicy.writer = new ContentSecurityPolicyHeaderWriter(policyDirectives); + return this.contentSecurityPolicy; } /** *

    - * Allows configuration for Content Security Policy (CSP) Level 2. + * Allows configuration for Content Security + * Policy (CSP) Level 2. *

    * *

    - * Calling this method automatically enables (includes) the Content-Security-Policy header in the response - * using the supplied security policy directive(s). + * Calling this method automatically enables (includes) the Content-Security-Policy + * header in the response using the supplied security policy directive(s). *

    * *

    - * Configuration is provided to the {@link ContentSecurityPolicyHeaderWriter} which supports the writing - * of the two headers as detailed in the W3C Candidate Recommendation: + * Configuration is provided to the {@link ContentSecurityPolicyHeaderWriter} which + * supports the writing of the two headers as detailed in the W3C Candidate + * Recommendation: *

    *
      - *
    • Content-Security-Policy
    • - *
    • Content-Security-Policy-Report-Only
    • + *
    • Content-Security-Policy
    • + *
    • Content-Security-Policy-Report-Only
    • *
    - * - * @see ContentSecurityPolicyHeaderWriter * @param contentSecurityCustomizer the {@link Customizer} to provide more options for * the {@link ContentSecurityPolicyConfig} * @return the {@link HeadersConfigurer} for additional customizations + * @see ContentSecurityPolicyHeaderWriter */ - public HeadersConfigurer contentSecurityPolicy(Customizer contentSecurityCustomizer) { + public HeadersConfigurer contentSecurityPolicy( + Customizer contentSecurityCustomizer) { this.contentSecurityPolicy.writer = new ContentSecurityPolicyHeaderWriter(); contentSecurityCustomizer.customize(this.contentSecurityPolicy); - return HeadersConfigurer.this; } - public final class ContentSecurityPolicyConfig { - private ContentSecurityPolicyHeaderWriter writer; - - private ContentSecurityPolicyConfig() { - } - - /** - * Sets the security policy directive(s) to be used in the response header. - * - * @param policyDirectives the security policy directive(s) - * @return the {@link ContentSecurityPolicyConfig} for additional configuration - * @throws IllegalArgumentException if policyDirectives is null or empty - */ - public ContentSecurityPolicyConfig policyDirectives(String policyDirectives) { - this.writer.setPolicyDirectives(policyDirectives); - return this; - } - - /** - * Enables (includes) the Content-Security-Policy-Report-Only header in the response. - * - * @return the {@link ContentSecurityPolicyConfig} for additional configuration - */ - public ContentSecurityPolicyConfig reportOnly() { - this.writer.setReportOnly(true); - return this; - } - - /** - * Allows completing configuration of Content Security Policy and continuing - * configuration of headers. - * - * @return the {@link HeadersConfigurer} for additional configuration - */ - public HeadersConfigurer and() { - return HeadersConfigurer.this; - } - - } - /** * Clears all of the default headers from the response. After doing so, one can add * headers back. For example, if you only want to use Spring Security's cache control @@ -886,15 +340,14 @@ public class HeadersConfigurer> extends *
     	 * http.headers().defaultsDisabled().cacheControl();
     	 * 
    - * * @return the {@link HeadersConfigurer} for additional customization */ public HeadersConfigurer defaultsDisabled() { - contentTypeOptions.disable(); - xssProtection.disable(); - cacheControl.disable(); - hsts.disable(); - frameOptions.disable(); + this.contentTypeOptions.disable(); + this.xssProtection.disable(); + this.cacheControl.disable(); + this.hsts.disable(); + this.frameOptions.disable(); return this; } @@ -906,7 +359,6 @@ public class HeadersConfigurer> extends /** * Creates the {@link HeaderWriter} - * * @return the {@link HeaderWriter} */ private HeaderWriterFilter createHeaderWriterFilter() { @@ -922,21 +374,20 @@ public class HeadersConfigurer> extends /** * Gets the {@link HeaderWriter} instances and possibly initializes with the defaults. - * * @return */ private List getHeaderWriters() { List writers = new ArrayList<>(); - addIfNotNull(writers, contentTypeOptions.writer); - addIfNotNull(writers, xssProtection.writer); - addIfNotNull(writers, cacheControl.writer); - addIfNotNull(writers, hsts.writer); - addIfNotNull(writers, frameOptions.writer); - addIfNotNull(writers, hpkp.writer); - addIfNotNull(writers, contentSecurityPolicy.writer); - addIfNotNull(writers, referrerPolicy.writer); - addIfNotNull(writers, featurePolicy.writer); - writers.addAll(headerWriters); + addIfNotNull(writers, this.contentTypeOptions.writer); + addIfNotNull(writers, this.xssProtection.writer); + addIfNotNull(writers, this.cacheControl.writer); + addIfNotNull(writers, this.hsts.writer); + addIfNotNull(writers, this.frameOptions.writer); + addIfNotNull(writers, this.hpkp.writer); + addIfNotNull(writers, this.contentSecurityPolicy.writer); + addIfNotNull(writers, this.referrerPolicy.writer); + addIfNotNull(writers, this.featurePolicy.writer); + writers.addAll(this.headerWriters); return writers; } @@ -948,26 +399,28 @@ public class HeadersConfigurer> extends /** *

    - * Allows configuration for Referrer Policy. + * Allows configuration for Referrer + * Policy. *

    * *

    - * Configuration is provided to the {@link ReferrerPolicyHeaderWriter} which support the writing - * of the header as detailed in the W3C Technical Report: + * Configuration is provided to the {@link ReferrerPolicyHeaderWriter} which support + * the writing of the header as detailed in the W3C Technical Report: *

    *
      - *
    • Referrer-Policy
    • + *
    • Referrer-Policy
    • *
    * - *

    Default value is:

    + *

    + * Default value is: + *

    * *
     	 * Referrer-Policy: no-referrer
     	 * 
    - * - * @see ReferrerPolicyHeaderWriter - * @since 4.2 * @return the {@link ReferrerPolicyConfig} for additional configuration + * @since 4.2 + * @see ReferrerPolicyHeaderWriter */ public ReferrerPolicyConfig referrerPolicy() { this.referrerPolicy.writer = new ReferrerPolicyHeaderWriter(); @@ -976,21 +429,21 @@ public class HeadersConfigurer> extends /** *

    - * Allows configuration for Referrer Policy. + * Allows configuration for Referrer + * Policy. *

    * *

    - * Configuration is provided to the {@link ReferrerPolicyHeaderWriter} which support the writing - * of the header as detailed in the W3C Technical Report: + * Configuration is provided to the {@link ReferrerPolicyHeaderWriter} which support + * the writing of the header as detailed in the W3C Technical Report: *

    *
      - *
    • Referrer-Policy
    • + *
    • Referrer-Policy
    • *
    - * - * @see ReferrerPolicyHeaderWriter - * @since 4.2 * @return the {@link ReferrerPolicyConfig} for additional configuration * @throws IllegalArgumentException if policy is null or empty + * @since 4.2 + * @see ReferrerPolicyHeaderWriter */ public ReferrerPolicyConfig referrerPolicy(ReferrerPolicy policy) { this.referrerPolicy.writer = new ReferrerPolicyHeaderWriter(policy); @@ -999,21 +452,21 @@ public class HeadersConfigurer> extends /** *

    - * Allows configuration for Referrer Policy. + * Allows configuration for Referrer + * Policy. *

    * *

    - * Configuration is provided to the {@link ReferrerPolicyHeaderWriter} which support the writing - * of the header as detailed in the W3C Technical Report: + * Configuration is provided to the {@link ReferrerPolicyHeaderWriter} which support + * the writing of the header as detailed in the W3C Technical Report: *

    *
      - *
    • Referrer-Policy
    • + *
    • Referrer-Policy
    • *
    - * - * @see ReferrerPolicyHeaderWriter * @param referrerPolicyCustomizer the {@link Customizer} to provide more options for * the {@link ReferrerPolicyConfig} * @return the {@link HeadersConfigurer} for additional customizations + * @see ReferrerPolicyHeaderWriter */ public HeadersConfigurer referrerPolicy(Customizer referrerPolicyCustomizer) { this.referrerPolicy.writer = new ReferrerPolicyHeaderWriter(); @@ -1021,6 +474,553 @@ public class HeadersConfigurer> extends return HeadersConfigurer.this; } + /** + * Allows configuration for Feature + * Policy. + *

    + * Calling this method automatically enables (includes) the {@code Feature-Policy} + * header in the response using the supplied policy directive(s). + *

    + * Configuration is provided to the {@link FeaturePolicyHeaderWriter} which is + * responsible for writing the header. + * @return the {@link FeaturePolicyConfig} for additional configuration + * @throws IllegalArgumentException if policyDirectives is {@code null} or empty + * @since 5.1 + * @see FeaturePolicyHeaderWriter + */ + public FeaturePolicyConfig featurePolicy(String policyDirectives) { + this.featurePolicy.writer = new FeaturePolicyHeaderWriter(policyDirectives); + return this.featurePolicy; + } + + public final class ContentTypeOptionsConfig { + + private XContentTypeOptionsHeaderWriter writer; + + private ContentTypeOptionsConfig() { + enable(); + } + + /** + * Removes the X-XSS-Protection header. + * @return {@link HeadersConfigurer} for additional customization. + */ + public HeadersConfigurer disable() { + this.writer = null; + return and(); + } + + /** + * Allows customizing the {@link HeadersConfigurer} + * @return the {@link HeadersConfigurer} for additional customization + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + /** + * Ensures that Content Type Options is enabled + * @return the {@link ContentTypeOptionsConfig} for additional customization + */ + private ContentTypeOptionsConfig enable() { + if (this.writer == null) { + this.writer = new XContentTypeOptionsHeaderWriter(); + } + return this; + } + + } + + public final class XXssConfig { + + private XXssProtectionHeaderWriter writer; + + private XXssConfig() { + enable(); + } + + /** + * If false, will not specify the mode as blocked. In this instance, any content + * will be attempted to be fixed. If true, the content will be replaced with "#". + * @param enabled the new value + */ + public XXssConfig block(boolean enabled) { + this.writer.setBlock(enabled); + return this; + } + + /** + * If true, the header value will contain a value of 1. For example: + * + *

    +		 * X-XSS-Protection: 1
    +		 * 
    + * + * or if {@link XXssProtectionHeaderWriter#setBlock(boolean)} of the given + * {@link XXssProtectionHeaderWriter} is true + * + * + *
    +		 * X-XSS-Protection: 1; mode=block
    +		 * 
    + * + * If false, will explicitly disable specify that X-XSS-Protection is disabled. + * For example: + * + *
    +		 * X-XSS-Protection: 0
    +		 * 
    + * @param enabled the new value + */ + public XXssConfig xssProtectionEnabled(boolean enabled) { + this.writer.setEnabled(enabled); + return this; + } + + /** + * Disables X-XSS-Protection header (does not include it) + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer disable() { + this.writer = null; + return and(); + } + + /** + * Allows completing configuration of X-XSS-Protection and continuing + * configuration of headers. + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + /** + * Ensures the X-XSS-Protection header is enabled if it is not already. + * @return the {@link XXssConfig} for additional customization + */ + private XXssConfig enable() { + if (this.writer == null) { + this.writer = new XXssProtectionHeaderWriter(); + } + return this; + } + + } + + public final class CacheControlConfig { + + private CacheControlHeadersWriter writer; + + private CacheControlConfig() { + enable(); + } + + /** + * Disables Cache Control + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer disable() { + this.writer = null; + return HeadersConfigurer.this; + } + + /** + * Allows completing configuration of Cache Control and continuing configuration + * of headers. + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + /** + * Ensures the Cache Control headers are enabled if they are not already. + * @return the {@link CacheControlConfig} for additional customization + */ + private CacheControlConfig enable() { + if (this.writer == null) { + this.writer = new CacheControlHeadersWriter(); + } + return this; + } + + } + + public final class HstsConfig { + + private HstsHeaderWriter writer; + + private HstsConfig() { + enable(); + } + + /** + *

    + * Sets the value (in seconds) for the max-age directive of the + * Strict-Transport-Security header. The default is one year. + *

    + * + *

    + * This instructs browsers how long to remember to keep this domain as a known + * HSTS Host. See + * Section 6.1.1 + * for additional details. + *

    + * @param maxAgeInSeconds the maximum amount of time (in seconds) to consider this + * domain as a known HSTS Host. + * @throws IllegalArgumentException if maxAgeInSeconds is negative + */ + public HstsConfig maxAgeInSeconds(long maxAgeInSeconds) { + this.writer.setMaxAgeInSeconds(maxAgeInSeconds); + return this; + } + + /** + * Sets the {@link RequestMatcher} used to determine if the + * "Strict-Transport-Security" should be added. If true the header is added, else + * the header is not added. By default the header is added when + * {@link HttpServletRequest#isSecure()} returns true. + * @param requestMatcher the {@link RequestMatcher} to use. + * @throws IllegalArgumentException if {@link RequestMatcher} is null + */ + public HstsConfig requestMatcher(RequestMatcher requestMatcher) { + this.writer.setRequestMatcher(requestMatcher); + return this; + } + + /** + *

    + * If true, subdomains should be considered HSTS Hosts too. The default is true. + *

    + * + *

    + * See Section + * 6.1.2 for additional details. + *

    + * @param includeSubDomains true to include subdomains, else false + */ + public HstsConfig includeSubDomains(boolean includeSubDomains) { + this.writer.setIncludeSubDomains(includeSubDomains); + return this; + } + + /** + *

    + * If true, preload will be included in HSTS Header. The default is false. + *

    + * + *

    + * See Website hstspreload.org for + * additional details. + *

    + * @param preload true to include preload, else false + * @since 5.2.0 + * @author Ankur Pathak + */ + public HstsConfig preload(boolean preload) { + this.writer.setPreload(preload); + return this; + } + + /** + * Disables Strict Transport Security + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer disable() { + this.writer = null; + return HeadersConfigurer.this; + } + + /** + * Allows completing configuration of Strict Transport Security and continuing + * configuration of headers. + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + /** + * Ensures that Strict-Transport-Security is enabled if it is not already + * @return the {@link HstsConfig} for additional customization + */ + private HstsConfig enable() { + if (this.writer == null) { + this.writer = new HstsHeaderWriter(); + } + return this; + } + + } + + public final class FrameOptionsConfig { + + private XFrameOptionsHeaderWriter writer; + + private FrameOptionsConfig() { + enable(); + } + + /** + * Specify to DENY framing any content from this application. + * @return the {@link HeadersConfigurer} for additional customization. + */ + public HeadersConfigurer deny() { + this.writer = new XFrameOptionsHeaderWriter(XFrameOptionsMode.DENY); + return and(); + } + + /** + *

    + * Specify to allow any request that comes from the same origin to frame this + * application. For example, if the application was hosted on example.com, then + * example.com could frame the application, but evil.com could not frame the + * application. + *

    + * @return the {@link HeadersConfigurer} for additional customization. + */ + public HeadersConfigurer sameOrigin() { + this.writer = new XFrameOptionsHeaderWriter(XFrameOptionsMode.SAMEORIGIN); + return and(); + } + + /** + * Prevents the header from being added to the response. + * @return the {@link HeadersConfigurer} for additional configuration. + */ + public HeadersConfigurer disable() { + this.writer = null; + return and(); + } + + /** + * Allows continuing customizing the headers configuration. + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + /** + * Enables FrameOptionsConfig if it is not already enabled. + * @return the FrameOptionsConfig for additional customization. + */ + private FrameOptionsConfig enable() { + if (this.writer == null) { + this.writer = new XFrameOptionsHeaderWriter(XFrameOptionsMode.DENY); + } + return this; + } + + } + + public final class HpkpConfig { + + private HpkpHeaderWriter writer; + + private HpkpConfig() { + } + + /** + *

    + * Sets the value for the pin- directive of the Public-Key-Pins header. + *

    + * + *

    + * The pin directive specifies a way for web host operators to indicate a + * cryptographic identity that should be bound to a given web host. See + * Section 2.1.1 + * for additional details. + *

    + * @param pins the map of base64-encoded SPKI fingerprint & cryptographic hash + * algorithm pairs. + * @throws IllegalArgumentException if pins is null + */ + public HpkpConfig withPins(Map pins) { + this.writer.setPins(pins); + return this; + } + + /** + *

    + * Adds a list of SHA256 hashed pins for the pin- directive of the Public-Key-Pins + * header. + *

    + * + *

    + * The pin directive specifies a way for web host operators to indicate a + * cryptographic identity that should be bound to a given web host. See + * Section 2.1.1 + * for additional details. + *

    + * @param pins a list of base64-encoded SPKI fingerprints. + * @throws IllegalArgumentException if a pin is null + */ + public HpkpConfig addSha256Pins(String... pins) { + this.writer.addSha256Pins(pins); + return this; + } + + /** + *

    + * Sets the value (in seconds) for the max-age directive of the Public-Key-Pins + * header. The default is 60 days. + *

    + * + *

    + * This instructs browsers how long they should regard the host (from whom the + * message was received) as a known pinned host. See + * Section 2.1.2 + * for additional details. + *

    + * @param maxAgeInSeconds the maximum amount of time (in seconds) to regard the + * host as a known pinned host. + * @throws IllegalArgumentException if maxAgeInSeconds is negative + */ + public HpkpConfig maxAgeInSeconds(long maxAgeInSeconds) { + this.writer.setMaxAgeInSeconds(maxAgeInSeconds); + return this; + } + + /** + *

    + * If true, the pinning policy applies to this pinned host as well as any + * subdomains of the host's domain name. The default is false. + *

    + * + *

    + * See Section + * 2.1.3 for additional details. + *

    + * @param includeSubDomains true to include subdomains, else false + */ + public HpkpConfig includeSubDomains(boolean includeSubDomains) { + this.writer.setIncludeSubDomains(includeSubDomains); + return this; + } + + /** + *

    + * If true, the browser should not terminate the connection with the server. The + * default is true. + *

    + * + *

    + * See Section 2.1 + * for additional details. + *

    + * @param reportOnly true to report only, else false + */ + public HpkpConfig reportOnly(boolean reportOnly) { + this.writer.setReportOnly(reportOnly); + return this; + } + + /** + *

    + * Sets the URI to which the browser should report pin validation failures. + *

    + * + *

    + * See Section + * 2.1.4 for additional details. + *

    + * @param reportUri the URI where the browser should send the report to. + */ + public HpkpConfig reportUri(URI reportUri) { + this.writer.setReportUri(reportUri); + return this; + } + + /** + *

    + * Sets the URI to which the browser should report pin validation failures. + *

    + * + *

    + * See Section + * 2.1.4 for additional details. + *

    + * @param reportUri the URI where the browser should send the report to. + * @throws IllegalArgumentException if the reportUri is not a valid URI + */ + public HpkpConfig reportUri(String reportUri) { + this.writer.setReportUri(reportUri); + return this; + } + + /** + * Prevents the header from being added to the response. + * @return the {@link HeadersConfigurer} for additional configuration. + */ + public HeadersConfigurer disable() { + this.writer = null; + return and(); + } + + /** + * Allows completing configuration of Public Key Pinning and continuing + * configuration of headers. + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + /** + * Ensures that Public-Key-Pins or Public-Key-Pins-Report-Only is enabled if it is + * not already + * @return the {@link HstsConfig} for additional customization + */ + private HpkpConfig enable() { + if (this.writer == null) { + this.writer = new HpkpHeaderWriter(); + } + return this; + } + + } + + public final class ContentSecurityPolicyConfig { + + private ContentSecurityPolicyHeaderWriter writer; + + private ContentSecurityPolicyConfig() { + } + + /** + * Sets the security policy directive(s) to be used in the response header. + * @param policyDirectives the security policy directive(s) + * @return the {@link ContentSecurityPolicyConfig} for additional configuration + * @throws IllegalArgumentException if policyDirectives is null or empty + */ + public ContentSecurityPolicyConfig policyDirectives(String policyDirectives) { + this.writer.setPolicyDirectives(policyDirectives); + return this; + } + + /** + * Enables (includes) the Content-Security-Policy-Report-Only header in the + * response. + * @return the {@link ContentSecurityPolicyConfig} for additional configuration + */ + public ContentSecurityPolicyConfig reportOnly() { + this.writer.setReportOnly(true); + return this; + } + + /** + * Allows completing configuration of Content Security Policy and continuing + * configuration of headers. + * @return the {@link HeadersConfigurer} for additional configuration + */ + public HeadersConfigurer and() { + return HeadersConfigurer.this; + } + + } + public final class ReferrerPolicyConfig { private ReferrerPolicyHeaderWriter writer; @@ -1030,7 +1030,6 @@ public class HeadersConfigurer> extends /** * Sets the policy to be used in the response header. - * * @param policy a referrer policy * @return the {@link ReferrerPolicyConfig} for additional configuration * @throws IllegalArgumentException if policy is null @@ -1046,26 +1045,6 @@ public class HeadersConfigurer> extends } - /** - * Allows configuration for Feature - * Policy. - *

    - * Calling this method automatically enables (includes) the {@code Feature-Policy} - * header in the response using the supplied policy directive(s). - *

    - * Configuration is provided to the {@link FeaturePolicyHeaderWriter} which is - * responsible for writing the header. - * - * @see FeaturePolicyHeaderWriter - * @since 5.1 - * @return the {@link FeaturePolicyConfig} for additional configuration - * @throws IllegalArgumentException if policyDirectives is {@code null} or empty - */ - public FeaturePolicyConfig featurePolicy(String policyDirectives) { - this.featurePolicy.writer = new FeaturePolicyHeaderWriter(policyDirectives); - return featurePolicy; - } - public final class FeaturePolicyConfig { private FeaturePolicyHeaderWriter writer; @@ -1076,7 +1055,6 @@ public class HeadersConfigurer> extends /** * Allows completing configuration of Feature Policy and continuing configuration * of headers. - * * @return the {@link HeadersConfigurer} for additional configuration */ public HeadersConfigurer and() { diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurer.java index e4defae601..e45bb31a06 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.Arrays; @@ -53,8 +54,7 @@ import org.springframework.web.accept.HeaderContentNegotiationStrategy; * The following Filters are populated * *

      - *
    • - * {@link BasicAuthenticationFilter}
    • + *
    • {@link BasicAuthenticationFilter}
    • *
    * *

    Shared Objects Created

    @@ -77,16 +77,18 @@ import org.springframework.web.accept.HeaderContentNegotiationStrategy; * @author Rob Winch * @since 3.2 */ -public final class HttpBasicConfigurer> extends - AbstractHttpConfigurer, B> { +public final class HttpBasicConfigurer> + extends AbstractHttpConfigurer, B> { - private static final RequestHeaderRequestMatcher X_REQUESTED_WITH = new RequestHeaderRequestMatcher("X-Requested-With", - "XMLHttpRequest"); + private static final RequestHeaderRequestMatcher X_REQUESTED_WITH = new RequestHeaderRequestMatcher( + "X-Requested-With", "XMLHttpRequest"); private static final String DEFAULT_REALM = "Realm"; private AuthenticationEntryPoint authenticationEntryPoint; + private AuthenticationDetailsSource authenticationDetailsSource; + private BasicAuthenticationEntryPoint basicAuthEntryPoint = new BasicAuthenticationEntryPoint(); /** @@ -95,12 +97,9 @@ public final class HttpBasicConfigurer> extends */ public HttpBasicConfigurer() { realmName(DEFAULT_REALM); - LinkedHashMap entryPoints = new LinkedHashMap<>(); entryPoints.put(X_REQUESTED_WITH, new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED)); - - DelegatingAuthenticationEntryPoint defaultEntryPoint = new DelegatingAuthenticationEntryPoint( - entryPoints); + DelegatingAuthenticationEntryPoint defaultEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints); defaultEntryPoint.setDefaultEntryPoint(this.basicAuthEntryPoint); this.authenticationEntryPoint = defaultEntryPoint; } @@ -109,7 +108,6 @@ public final class HttpBasicConfigurer> extends * Allows easily changing the realm, but leaving the remaining defaults in place. If * {@link #authenticationEntryPoint(AuthenticationEntryPoint)} has been invoked, * invoking this method will result in an error. - * * @param realmName the HTTP Basic realm to use * @return {@link HttpBasicConfigurer} for additional customization */ @@ -122,14 +120,11 @@ public final class HttpBasicConfigurer> extends /** * The {@link AuthenticationEntryPoint} to be populated on * {@link BasicAuthenticationFilter} in the event that authentication fails. The - * default to use {@link BasicAuthenticationEntryPoint} with the realm - * "Realm". - * + * default to use {@link BasicAuthenticationEntryPoint} with the realm "Realm". * @param authenticationEntryPoint the {@link AuthenticationEntryPoint} to use * @return {@link HttpBasicConfigurer} for additional customization */ - public HttpBasicConfigurer authenticationEntryPoint( - AuthenticationEntryPoint authenticationEntryPoint) { + public HttpBasicConfigurer authenticationEntryPoint(AuthenticationEntryPoint authenticationEntryPoint) { this.authenticationEntryPoint = authenticationEntryPoint; return this; } @@ -137,7 +132,6 @@ public final class HttpBasicConfigurer> extends /** * Specifies a custom {@link AuthenticationDetailsSource} to use for basic * authentication. The default is {@link WebAuthenticationDetailsSource}. - * * @param authenticationDetailsSource the custom {@link AuthenticationDetailsSource} * to use * @return {@link HttpBasicConfigurer} for additional customization @@ -154,47 +148,38 @@ public final class HttpBasicConfigurer> extends } private void registerDefaults(B http) { - ContentNegotiationStrategy contentNegotiationStrategy = http - .getSharedObject(ContentNegotiationStrategy.class); + ContentNegotiationStrategy contentNegotiationStrategy = http.getSharedObject(ContentNegotiationStrategy.class); if (contentNegotiationStrategy == null) { contentNegotiationStrategy = new HeaderContentNegotiationStrategy(); } - - MediaTypeRequestMatcher restMatcher = new MediaTypeRequestMatcher( - contentNegotiationStrategy, MediaType.APPLICATION_ATOM_XML, - MediaType.APPLICATION_FORM_URLENCODED, MediaType.APPLICATION_JSON, - MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_XML, - MediaType.MULTIPART_FORM_DATA, MediaType.TEXT_XML); + MediaTypeRequestMatcher restMatcher = new MediaTypeRequestMatcher(contentNegotiationStrategy, + MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_FORM_URLENCODED, MediaType.APPLICATION_JSON, + MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_XML, MediaType.MULTIPART_FORM_DATA, + MediaType.TEXT_XML); restMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); - MediaTypeRequestMatcher allMatcher = new MediaTypeRequestMatcher(contentNegotiationStrategy, MediaType.ALL); allMatcher.setUseEquals(true); - RequestMatcher notHtmlMatcher = new NegatedRequestMatcher( - new MediaTypeRequestMatcher(contentNegotiationStrategy, - MediaType.TEXT_HTML)); + new MediaTypeRequestMatcher(contentNegotiationStrategy, MediaType.TEXT_HTML)); RequestMatcher restNotHtmlMatcher = new AndRequestMatcher( Arrays.asList(notHtmlMatcher, restMatcher)); - - RequestMatcher preferredMatcher = new OrRequestMatcher(Arrays.asList(X_REQUESTED_WITH, restNotHtmlMatcher, allMatcher)); - + RequestMatcher preferredMatcher = new OrRequestMatcher( + Arrays.asList(X_REQUESTED_WITH, restNotHtmlMatcher, allMatcher)); registerDefaultEntryPoint(http, preferredMatcher); registerDefaultLogoutSuccessHandler(http, preferredMatcher); } private void registerDefaultEntryPoint(B http, RequestMatcher preferredMatcher) { - ExceptionHandlingConfigurer exceptionHandling = http - .getConfigurer(ExceptionHandlingConfigurer.class); + ExceptionHandlingConfigurer exceptionHandling = http.getConfigurer(ExceptionHandlingConfigurer.class); if (exceptionHandling == null) { return; } - exceptionHandling.defaultAuthenticationEntryPointFor( - postProcess(this.authenticationEntryPoint), preferredMatcher); + exceptionHandling.defaultAuthenticationEntryPointFor(postProcess(this.authenticationEntryPoint), + preferredMatcher); } private void registerDefaultLogoutSuccessHandler(B http, RequestMatcher preferredMatcher) { - LogoutConfigurer logout = http - .getConfigurer(LogoutConfigurer.class); + LogoutConfigurer logout = http.getConfigurer(LogoutConfigurer.class); if (logout == null) { return; } @@ -204,13 +189,11 @@ public final class HttpBasicConfigurer> extends @Override public void configure(B http) { - AuthenticationManager authenticationManager = http - .getSharedObject(AuthenticationManager.class); - BasicAuthenticationFilter basicAuthenticationFilter = new BasicAuthenticationFilter( - authenticationManager, this.authenticationEntryPoint); + AuthenticationManager authenticationManager = http.getSharedObject(AuthenticationManager.class); + BasicAuthenticationFilter basicAuthenticationFilter = new BasicAuthenticationFilter(authenticationManager, + this.authenticationEntryPoint); if (this.authenticationDetailsSource != null) { - basicAuthenticationFilter - .setAuthenticationDetailsSource(this.authenticationDetailsSource); + basicAuthenticationFilter.setAuthenticationDetailsSource(this.authenticationDetailsSource); } RememberMeServices rememberMeServices = http.getSharedObject(RememberMeServices.class); if (rememberMeServices != null) { @@ -219,4 +202,5 @@ public final class HttpBasicConfigurer> extends basicAuthenticationFilter = postProcess(basicAuthenticationFilter); http.addFilter(basicAuthenticationFilter); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurer.java index 45abc6b156..bbd91f045c 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.HashSet; @@ -42,15 +43,13 @@ import org.springframework.security.web.authentication.preauth.j2ee.J2eePreAuthe * The following Filters are populated * *
      - *
    • - * {@link J2eePreAuthenticatedProcessingFilter}
    • + *
    • {@link J2eePreAuthenticatedProcessingFilter}
    • *
    * *

    Shared Objects Created

    * *
      - *
    • - * {@link AuthenticationEntryPoint} is populated with an + *
    • {@link AuthenticationEntryPoint} is populated with an * {@link Http403ForbiddenEntryPoint}
    • *
    • A {@link PreAuthenticatedAuthenticationProvider} is populated into * {@link HttpSecurity#authenticationProvider(org.springframework.security.authentication.AuthenticationProvider)} @@ -68,10 +67,12 @@ import org.springframework.security.web.authentication.preauth.j2ee.J2eePreAuthe * @author Rob Winch * @since 3.2 */ -public final class JeeConfigurer> extends - AbstractHttpConfigurer, H> { +public final class JeeConfigurer> extends AbstractHttpConfigurer, H> { + private J2eePreAuthenticatedProcessingFilter j2eePreAuthenticatedProcessingFilter; + private AuthenticationUserDetailsService authenticationUserDetailsService; + private Set mappableRoles = new HashSet<>(); /** @@ -91,7 +92,6 @@ public final class JeeConfigurer> extends *

      * There are no default roles that are mapped. *

      - * * @param mappableRoles the roles to attempt to map to the {@link UserDetails} (i.e. * "ROLE_USER", "ROLE_ADMIN", etc). * @return the {@link JeeConfigurer} for further customizations @@ -117,7 +117,6 @@ public final class JeeConfigurer> extends *

      * There are no default roles that are mapped. *

      - * * @param mappableRoles the roles to attempt to map to the {@link UserDetails} (i.e. * "USER", "ADMIN", etc). * @return the {@link JeeConfigurer} for further customizations @@ -142,7 +141,6 @@ public final class JeeConfigurer> extends *

      * There are no default roles that are mapped. *

      - * * @param mappableRoles the roles to attempt to map to the {@link UserDetails}. * @return the {@link JeeConfigurer} for further customizations * @see SimpleMappableAttributesRetriever @@ -156,7 +154,6 @@ public final class JeeConfigurer> extends * Specifies the {@link AuthenticationUserDetailsService} that is used with the * {@link PreAuthenticatedAuthenticationProvider}. The default is a * {@link PreAuthenticatedGrantedAuthoritiesUserDetailsService}. - * * @param authenticatedUserDetailsService the {@link AuthenticationUserDetailsService} * to use. * @return the {@link JeeConfigurer} for further configuration @@ -172,7 +169,6 @@ public final class JeeConfigurer> extends * {@link J2eePreAuthenticatedProcessingFilter} is provided, all of its attributes * must also be configured manually (i.e. all attributes populated in the * {@link JeeConfigurer} are not used). - * * @param j2eePreAuthenticatedProcessingFilter the * {@link J2eePreAuthenticatedProcessingFilter} to use. * @return the {@link JeeConfigurer} for further configuration @@ -194,21 +190,15 @@ public final class JeeConfigurer> extends @Override public void init(H http) { PreAuthenticatedAuthenticationProvider authenticationProvider = new PreAuthenticatedAuthenticationProvider(); - authenticationProvider - .setPreAuthenticatedUserDetailsService(getUserDetailsService()); + authenticationProvider.setPreAuthenticatedUserDetailsService(getUserDetailsService()); authenticationProvider = postProcess(authenticationProvider); - - // @formatter:off - http - .authenticationProvider(authenticationProvider) - .setSharedObject(AuthenticationEntryPoint.class, new Http403ForbiddenEntryPoint()); - // @formatter:on + http.authenticationProvider(authenticationProvider).setSharedObject(AuthenticationEntryPoint.class, + new Http403ForbiddenEntryPoint()); } @Override public void configure(H http) { - J2eePreAuthenticatedProcessingFilter filter = getFilter(http - .getSharedObject(AuthenticationManager.class)); + J2eePreAuthenticatedProcessingFilter filter = getFilter(http.getSharedObject(AuthenticationManager.class)); http.addFilter(filter); } @@ -218,45 +208,41 @@ public final class JeeConfigurer> extends * @param authenticationManager the {@link AuthenticationManager} to use. * @return the {@link J2eePreAuthenticatedProcessingFilter} to use. */ - private J2eePreAuthenticatedProcessingFilter getFilter( - AuthenticationManager authenticationManager) { - if (j2eePreAuthenticatedProcessingFilter == null) { - j2eePreAuthenticatedProcessingFilter = new J2eePreAuthenticatedProcessingFilter(); - j2eePreAuthenticatedProcessingFilter - .setAuthenticationManager(authenticationManager); - j2eePreAuthenticatedProcessingFilter + private J2eePreAuthenticatedProcessingFilter getFilter(AuthenticationManager authenticationManager) { + if (this.j2eePreAuthenticatedProcessingFilter == null) { + this.j2eePreAuthenticatedProcessingFilter = new J2eePreAuthenticatedProcessingFilter(); + this.j2eePreAuthenticatedProcessingFilter.setAuthenticationManager(authenticationManager); + this.j2eePreAuthenticatedProcessingFilter .setAuthenticationDetailsSource(createWebAuthenticationDetailsSource()); - j2eePreAuthenticatedProcessingFilter = postProcess(j2eePreAuthenticatedProcessingFilter); + this.j2eePreAuthenticatedProcessingFilter = postProcess(this.j2eePreAuthenticatedProcessingFilter); } - return j2eePreAuthenticatedProcessingFilter; + return this.j2eePreAuthenticatedProcessingFilter; } /** * Gets the {@link AuthenticationUserDetailsService} that was specified or defaults to * {@link PreAuthenticatedGrantedAuthoritiesUserDetailsService}. - * * @return the {@link AuthenticationUserDetailsService} to use */ private AuthenticationUserDetailsService getUserDetailsService() { - return authenticationUserDetailsService == null ? new PreAuthenticatedGrantedAuthoritiesUserDetailsService() - : authenticationUserDetailsService; + return (this.authenticationUserDetailsService != null) ? this.authenticationUserDetailsService + : new PreAuthenticatedGrantedAuthoritiesUserDetailsService(); } /** * Creates the {@link J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource} to set * on the {@link J2eePreAuthenticatedProcessingFilter}. It is populated with a * {@link SimpleMappableAttributesRetriever}. - * * @return the {@link J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource} to use. */ private J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource createWebAuthenticationDetailsSource() { J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource detailsSource = new J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource(); SimpleMappableAttributesRetriever rolesRetriever = new SimpleMappableAttributesRetriever(); - rolesRetriever.setMappableAttributes(mappableRoles); + rolesRetriever.setMappableAttributes(this.mappableRoles); detailsSource.setMappableRolesRetriever(rolesRetriever); - detailsSource = postProcess(detailsSource); return detailsSource; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurer.java index 4f5d09833f..a86d8339dc 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.ArrayList; @@ -41,15 +42,15 @@ import org.springframework.util.Assert; /** * Adds logout support. Other {@link SecurityConfigurer} instances may invoke - * {@link #addLogoutHandler(LogoutHandler)} in the {@link #init(HttpSecurityBuilder)} phase. + * {@link #addLogoutHandler(LogoutHandler)} in the {@link #init(HttpSecurityBuilder)} + * phase. * *

      Security Filters

      * * The following Filters are populated * *
        - *
      • - * {@link LogoutFilter}
      • + *
      • {@link LogoutFilter}
      • *
      * *

      Shared Objects Created

      @@ -65,19 +66,26 @@ import org.springframework.util.Assert; * @since 3.2 * @see RememberMeConfigurer */ -public final class LogoutConfigurer> extends - AbstractHttpConfigurer, H> { +public final class LogoutConfigurer> + extends AbstractHttpConfigurer, H> { + private List logoutHandlers = new ArrayList<>(); + private SecurityContextLogoutHandler contextLogoutHandler = new SecurityContextLogoutHandler(); + private String logoutSuccessUrl = "/login?logout"; + private LogoutSuccessHandler logoutSuccessHandler; + private String logoutUrl = "/logout"; + private RequestMatcher logoutRequestMatcher; + private boolean permitAll; + private boolean customLogoutSuccess; - private LinkedHashMap defaultLogoutSuccessHandlerMappings = - new LinkedHashMap<>(); + private LinkedHashMap defaultLogoutSuccessHandlerMappings = new LinkedHashMap<>(); /** * Creates a new instance @@ -87,10 +95,9 @@ public final class LogoutConfigurer> extends } /** - * Adds a {@link LogoutHandler}. - * {@link SecurityContextLogoutHandler} and {@link LogoutSuccessEventPublishingLogoutHandler} are added as - * last {@link LogoutHandler} instances by default. - * + * Adds a {@link LogoutHandler}. {@link SecurityContextLogoutHandler} and + * {@link LogoutSuccessEventPublishingLogoutHandler} are added as last + * {@link LogoutHandler} instances by default. * @param logoutHandler the {@link LogoutHandler} to add * @return the {@link LogoutConfigurer} for further customization */ @@ -101,16 +108,17 @@ public final class LogoutConfigurer> extends } /** - * Specifies if {@link SecurityContextLogoutHandler} should clear the {@link Authentication} at the time of logout. - * @param clearAuthentication true {@link SecurityContextLogoutHandler} should clear the {@link Authentication} (default), or false otherwise. + * Specifies if {@link SecurityContextLogoutHandler} should clear the + * {@link Authentication} at the time of logout. + * @param clearAuthentication true {@link SecurityContextLogoutHandler} should clear + * the {@link Authentication} (default), or false otherwise. * @return the {@link LogoutConfigurer} for further customization */ public LogoutConfigurer clearAuthentication(boolean clearAuthentication) { - contextLogoutHandler.setClearAuthentication(clearAuthentication); + this.contextLogoutHandler.setClearAuthentication(clearAuthentication); return this; } - /** * Configures {@link SecurityContextLogoutHandler} to invalidate the * {@link HttpSession} at the time of logout. @@ -119,7 +127,7 @@ public final class LogoutConfigurer> extends * @return the {@link LogoutConfigurer} for further customization */ public LogoutConfigurer invalidateHttpSession(boolean invalidateHttpSession) { - contextLogoutHandler.setInvalidateHttpSession(invalidateHttpSession); + this.contextLogoutHandler.setInvalidateHttpSession(invalidateHttpSession); return this; } @@ -131,17 +139,15 @@ public final class LogoutConfigurer> extends * *

      * It is considered best practice to use an HTTP POST on any action that changes state - * (i.e. log out) to protect against CSRF attacks. If - * you really want to use an HTTP GET, you can use + * (i.e. log out) to protect against + * CSRF + * attacks. If you really want to use an HTTP GET, you can use * logoutRequestMatcher(new AntPathRequestMatcher(logoutUrl, "GET")); *

      - * - * @see #logoutRequestMatcher(RequestMatcher) - * @see HttpSecurity#csrf() - * * @param logoutUrl the URL that will invoke logout. * @return the {@link LogoutConfigurer} for further customization + * @see #logoutRequestMatcher(RequestMatcher) + * @see HttpSecurity#csrf() */ public LogoutConfigurer logoutUrl(String logoutUrl) { this.logoutRequestMatcher = null; @@ -152,12 +158,10 @@ public final class LogoutConfigurer> extends /** * The RequestMatcher that triggers log out to occur. In most circumstances users will * use {@link #logoutUrl(String)} which helps enforce good practices. - * - * @see #logoutUrl(String) - * * @param logoutRequestMatcher the RequestMatcher used to determine if logout should * occur. * @return the {@link LogoutConfigurer} for further customization + * @see #logoutUrl(String) */ public LogoutConfigurer logoutRequestMatcher(RequestMatcher logoutRequestMatcher) { this.logoutRequestMatcher = logoutRequestMatcher; @@ -168,7 +172,6 @@ public final class LogoutConfigurer> extends * The URL to redirect to after logout has occurred. The default is "/login?logout". * This is a shortcut for invoking {@link #logoutSuccessHandler(LogoutSuccessHandler)} * with a {@link SimpleUrlLogoutSuccessHandler}. - * * @param logoutSuccessUrl the URL to redirect to after logout occurred * @return the {@link LogoutConfigurer} for further customization */ @@ -190,7 +193,6 @@ public final class LogoutConfigurer> extends * Allows specifying the names of cookies to be removed on logout success. This is a * shortcut to easily invoke {@link #addLogoutHandler(LogoutHandler)} with a * {@link CookieClearingLogoutHandler}. - * * @param cookieNamesToClear the names of cookies to be removed on logout success. * @return the {@link LogoutConfigurer} for further customization */ @@ -201,13 +203,11 @@ public final class LogoutConfigurer> extends /** * Sets the {@link LogoutSuccessHandler} to use. If this is specified, * {@link #logoutSuccessUrl(String)} is ignored. - * * @param logoutSuccessHandler the {@link LogoutSuccessHandler} to use after a user * has been logged out. * @return the {@link LogoutConfigurer} for further customizations */ - public LogoutConfigurer logoutSuccessHandler( - LogoutSuccessHandler logoutSuccessHandler) { + public LogoutConfigurer logoutSuccessHandler(LogoutSuccessHandler logoutSuccessHandler) { this.logoutSuccessUrl = null; this.customLogoutSuccess = true; this.logoutSuccessHandler = logoutSuccessHandler; @@ -217,18 +217,17 @@ public final class LogoutConfigurer> extends /** * Sets a default {@link LogoutSuccessHandler} to be used which prefers being invoked * for the provided {@link RequestMatcher}. If no {@link LogoutSuccessHandler} is - * specified a {@link SimpleUrlLogoutSuccessHandler} will be used. - * If any default {@link LogoutSuccessHandler} instances are configured, then a + * specified a {@link SimpleUrlLogoutSuccessHandler} will be used. If any default + * {@link LogoutSuccessHandler} instances are configured, then a * {@link DelegatingLogoutSuccessHandler} will be used that defaults to a * {@link SimpleUrlLogoutSuccessHandler}. - * * @param handler the {@link LogoutSuccessHandler} to use * @param preferredMatcher the {@link RequestMatcher} for this default * {@link LogoutSuccessHandler} * @return the {@link LogoutConfigurer} for further customizations */ - public LogoutConfigurer defaultLogoutSuccessHandlerFor( - LogoutSuccessHandler handler, RequestMatcher preferredMatcher) { + public LogoutConfigurer defaultLogoutSuccessHandlerFor(LogoutSuccessHandler handler, + RequestMatcher preferredMatcher) { Assert.notNull(handler, "handler cannot be null"); Assert.notNull(preferredMatcher, "preferredMatcher cannot be null"); this.defaultLogoutSuccessHandlerMappings.put(preferredMatcher, handler); @@ -238,7 +237,6 @@ public final class LogoutConfigurer> extends /** * Grants access to the {@link #logoutSuccessUrl(String)} and the * {@link #logoutUrl(String)} for every user. - * * @param permitAll if true grants access, else nothing is done * @return the {@link LogoutConfigurer} for further customization. */ @@ -250,7 +248,6 @@ public final class LogoutConfigurer> extends /** * Gets the {@link LogoutSuccessHandler} if not null, otherwise creates a new * {@link SimpleUrlLogoutSuccessHandler} using the {@link #logoutSuccessUrl(String)}. - * * @return the {@link LogoutSuccessHandler} to use */ private LogoutSuccessHandler getLogoutSuccessHandler() { @@ -263,22 +260,22 @@ public final class LogoutConfigurer> extends private LogoutSuccessHandler createDefaultSuccessHandler() { SimpleUrlLogoutSuccessHandler urlLogoutHandler = new SimpleUrlLogoutSuccessHandler(); - urlLogoutHandler.setDefaultTargetUrl(logoutSuccessUrl); - if (defaultLogoutSuccessHandlerMappings.isEmpty()) { + urlLogoutHandler.setDefaultTargetUrl(this.logoutSuccessUrl); + if (this.defaultLogoutSuccessHandlerMappings.isEmpty()) { return urlLogoutHandler; } - DelegatingLogoutSuccessHandler successHandler = new DelegatingLogoutSuccessHandler(defaultLogoutSuccessHandlerMappings); + DelegatingLogoutSuccessHandler successHandler = new DelegatingLogoutSuccessHandler( + this.defaultLogoutSuccessHandlerMappings); successHandler.setDefaultLogoutSuccessHandler(urlLogoutHandler); return successHandler; } @Override public void init(H http) { - if (permitAll) { + if (this.permitAll) { PermitAllSupport.permitAll(http, this.logoutSuccessUrl); PermitAllSupport.permitAll(http, this.getLogoutRequestMatcher(http)); } - DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http .getSharedObject(DefaultLoginPageGeneratingFilter.class); if (loginPageGeneratingFilter != null && !isCustomLogoutSuccess()) { @@ -296,21 +293,19 @@ public final class LogoutConfigurer> extends * Returns true if the logout success has been customized via * {@link #logoutSuccessUrl(String)} or * {@link #logoutSuccessHandler(LogoutSuccessHandler)}. - * * @return true if logout success handling has been customized, else false */ boolean isCustomLogoutSuccess() { - return customLogoutSuccess; + return this.customLogoutSuccess; } /** * Gets the logoutSuccesUrl or null if a * {@link #logoutSuccessHandler(LogoutSuccessHandler)} was configured. - * * @return the logoutSuccessUrl */ private String getLogoutSuccessUrl() { - return logoutSuccessUrl; + return this.logoutSuccessUrl; } /** @@ -318,44 +313,48 @@ public final class LogoutConfigurer> extends * @return the {@link LogoutHandler} instances. Cannot be null. */ List getLogoutHandlers() { - return logoutHandlers; + return this.logoutHandlers; } /** * Creates the {@link LogoutFilter} using the {@link LogoutHandler} instances, the * {@link #logoutSuccessHandler(LogoutSuccessHandler)} and the * {@link #logoutUrl(String)}. - * * @param http the builder to use * @return the {@link LogoutFilter} to use. */ private LogoutFilter createLogoutFilter(H http) { - logoutHandlers.add(contextLogoutHandler); - logoutHandlers.add(postProcess(new LogoutSuccessEventPublishingLogoutHandler())); - LogoutHandler[] handlers = logoutHandlers - .toArray(new LogoutHandler[0]); + this.logoutHandlers.add(this.contextLogoutHandler); + this.logoutHandlers.add(postProcess(new LogoutSuccessEventPublishingLogoutHandler())); + LogoutHandler[] handlers = this.logoutHandlers.toArray(new LogoutHandler[0]); LogoutFilter result = new LogoutFilter(getLogoutSuccessHandler(), handlers); result.setLogoutRequestMatcher(getLogoutRequestMatcher(http)); result = postProcess(result); return result; } - @SuppressWarnings("unchecked") private RequestMatcher getLogoutRequestMatcher(H http) { - if (logoutRequestMatcher != null) { - return logoutRequestMatcher; - } - if (http.getConfigurer(CsrfConfigurer.class) != null) { - this.logoutRequestMatcher = new AntPathRequestMatcher(this.logoutUrl, "POST"); - } - else { - this.logoutRequestMatcher = new OrRequestMatcher( - new AntPathRequestMatcher(this.logoutUrl, "GET"), - new AntPathRequestMatcher(this.logoutUrl, "POST"), - new AntPathRequestMatcher(this.logoutUrl, "PUT"), - new AntPathRequestMatcher(this.logoutUrl, "DELETE") - ); + if (this.logoutRequestMatcher != null) { + return this.logoutRequestMatcher; } + this.logoutRequestMatcher = createLogoutRequestMatcher(http); return this.logoutRequestMatcher; } + + @SuppressWarnings("unchecked") + private RequestMatcher createLogoutRequestMatcher(H http) { + RequestMatcher post = createLogoutRequestMatcher("POST"); + if (http.getConfigurer(CsrfConfigurer.class) != null) { + return post; + } + RequestMatcher get = createLogoutRequestMatcher("GET"); + RequestMatcher put = createLogoutRequestMatcher("PUT"); + RequestMatcher delete = createLogoutRequestMatcher("DELETE"); + return new OrRequestMatcher(get, post, put, delete); + } + + private RequestMatcher createLogoutRequestMatcher(String httpMethod) { + return new AntPathRequestMatcher(this.logoutUrl, httpMethod); + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupport.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupport.java index c500e9e283..3af0eba172 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupport.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupport.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -21,16 +22,20 @@ import org.springframework.security.access.SecurityConfig; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.configurers.AbstractConfigAttributeRequestMatcherRegistry.UrlMapping; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; /** * Configures non-null URL's to grant access to every URL + * * @author Rob Winch * @since 3.2 */ final class PermitAllSupport { - public static void permitAll( - HttpSecurityBuilder> http, String... urls) { + private PermitAllSupport() { + } + + static void permitAll(HttpSecurityBuilder> http, String... urls) { for (String url : urls) { if (url != null) { permitAll(http, new ExactUrlRequestMatcher(url)); @@ -39,61 +44,47 @@ final class PermitAllSupport { } @SuppressWarnings("unchecked") - public static void permitAll( - HttpSecurityBuilder> http, + static void permitAll(HttpSecurityBuilder> http, RequestMatcher... requestMatchers) { ExpressionUrlAuthorizationConfigurer configurer = http .getConfigurer(ExpressionUrlAuthorizationConfigurer.class); - - if (configurer == null) { - throw new IllegalStateException( - "permitAll only works with HttpSecurity.authorizeRequests()"); - } - + Assert.state(configurer != null, "permitAll only works with HttpSecurity.authorizeRequests()"); for (RequestMatcher matcher : requestMatchers) { if (matcher != null) { - configurer - .getRegistry() - .addMapping( - 0, - new UrlMapping( - matcher, - SecurityConfig - .createList(ExpressionUrlAuthorizationConfigurer.permitAll))); + configurer.getRegistry().addMapping(0, new UrlMapping(matcher, + SecurityConfig.createList(ExpressionUrlAuthorizationConfigurer.permitAll))); } } } - private final static class ExactUrlRequestMatcher implements RequestMatcher { + private static final class ExactUrlRequestMatcher implements RequestMatcher { + private String processUrl; private ExactUrlRequestMatcher(String processUrl) { this.processUrl = processUrl; } + @Override public boolean matches(HttpServletRequest request) { String uri = request.getRequestURI(); String query = request.getQueryString(); - if (query != null) { uri += "?" + query; } - if ("".equals(request.getContextPath())) { - return uri.equals(processUrl); + return uri.equals(this.processUrl); } - - return uri.equals(request.getContextPath() + processUrl); + return uri.equals(request.getContextPath() + this.processUrl); } @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("ExactUrl [processUrl='").append(processUrl).append("']"); + sb.append("ExactUrl [processUrl='").append(this.processUrl).append("']"); return sb.toString(); } + } - private PermitAllSupport() { - } -} \ No newline at end of file +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurer.java index 74f77c1b57..60004d0c73 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.HashMap; @@ -31,9 +32,11 @@ import org.springframework.security.web.PortMapperImpl; * @author Rob Winch * @since 3.2 */ -public final class PortMapperConfigurer> extends - AbstractHttpConfigurer, H> { +public final class PortMapperConfigurer> + extends AbstractHttpConfigurer, H> { + private PortMapper portMapper; + private Map httpsPortMappings = new HashMap<>(); /** @@ -70,16 +73,15 @@ public final class PortMapperConfigurer> extend * Gets the {@link PortMapper} to use. If {@link #portMapper(PortMapper)} was not * invoked, builds a {@link PortMapperImpl} using the port mappings specified with * {@link #http(int)}. - * * @return the {@link PortMapper} to use */ private PortMapper getPortMapper() { - if (portMapper == null) { + if (this.portMapper == null) { PortMapperImpl portMapper = new PortMapperImpl(); - portMapper.setPortMappings(httpsPortMappings); + portMapper.setPortMappings(this.httpsPortMappings); this.portMapper = portMapper; } - return portMapper; + return this.portMapper; } /** @@ -90,6 +92,7 @@ public final class PortMapperConfigurer> extend * @since 3.2 */ public final class HttpPortMapping { + private final int httpPort; /** @@ -107,8 +110,10 @@ public final class PortMapperConfigurer> extend * @return the {@link PortMapperConfigurer} for further customization */ public PortMapperConfigurer mapsTo(int httpsPort) { - httpsPortMappings.put(String.valueOf(httpPort), String.valueOf(httpsPort)); + PortMapperConfigurer.this.httpsPortMappings.put(String.valueOf(this.httpPort), String.valueOf(httpsPort)); return PortMapperConfigurer.this; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java index 16e1cdfe51..d6d352995d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.UUID; @@ -34,6 +35,7 @@ import org.springframework.security.web.authentication.rememberme.PersistentToke import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter; import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.util.Assert; /** * Configures Remember Me authentication. This typically involves the user checking a box @@ -79,21 +81,34 @@ import org.springframework.security.web.authentication.ui.DefaultLoginPageGenera */ public final class RememberMeConfigurer> extends AbstractHttpConfigurer, H> { + /** * The default name for remember me parameter name and remember me cookie name */ private static final String DEFAULT_REMEMBER_ME_NAME = "remember-me"; + private AuthenticationSuccessHandler authenticationSuccessHandler; + private String key; + private RememberMeServices rememberMeServices; + private LogoutHandler logoutHandler; + private String rememberMeParameter = DEFAULT_REMEMBER_ME_NAME; + private String rememberMeCookieName = DEFAULT_REMEMBER_ME_NAME; + private String rememberMeCookieDomain; + private PersistentTokenRepository tokenRepository; + private UserDetailsService userDetailsService; + private Integer tokenValiditySeconds; + private Boolean useSecureCookie; + private Boolean alwaysRemember; /** @@ -104,7 +119,6 @@ public final class RememberMeConfigurer> /** * Allows specifying how long (in seconds) a token is valid for - * * @param tokenValiditySeconds * @return {@link RememberMeConfigurer} for further customization * @see AbstractRememberMeServices#setTokenValiditySeconds(int) @@ -122,7 +136,6 @@ public final class RememberMeConfigurer> * By default the cookie will be secure if the request is secure. If you only want to * use remember-me over HTTPS (recommended) you should set this property to * {@code true}. - * * @param useSecureCookie set to {@code true} to always user secure cookies, * {@code false} to disable their use. * @return the {@link RememberMeConfigurer} for further customization @@ -140,13 +153,11 @@ public final class RememberMeConfigurer> * {@link HttpSecurity#getSharedObject(Class)} which is set when using * {@link WebSecurityConfigurerAdapter#configure(AuthenticationManagerBuilder)}. * Alternatively, one can populate {@link #rememberMeServices(RememberMeServices)}. - * * @param userDetailsService the {@link UserDetailsService} to configure * @return the {@link RememberMeConfigurer} for further customization * @see AbstractRememberMeServices */ - public RememberMeConfigurer userDetailsService( - UserDetailsService userDetailsService) { + public RememberMeConfigurer userDetailsService(UserDetailsService userDetailsService) { this.userDetailsService = userDetailsService; return this; } @@ -154,23 +165,19 @@ public final class RememberMeConfigurer> /** * Specifies the {@link PersistentTokenRepository} to use. The default is to use * {@link TokenBasedRememberMeServices} instead. - * * @param tokenRepository the {@link PersistentTokenRepository} to use * @return the {@link RememberMeConfigurer} for further customization */ - public RememberMeConfigurer tokenRepository( - PersistentTokenRepository tokenRepository) { + public RememberMeConfigurer tokenRepository(PersistentTokenRepository tokenRepository) { this.tokenRepository = tokenRepository; return this; } /** * Sets the key to identify tokens created for remember me authentication. Default is - * a secure randomly generated key. - * If {@link #rememberMeServices(RememberMeServices)} is specified and is of type - * {@link AbstractRememberMeServices}, then the default is the key set in - * {@link AbstractRememberMeServices}. - * + * a secure randomly generated key. If {@link #rememberMeServices(RememberMeServices)} + * is specified and is of type {@link AbstractRememberMeServices}, then the default is + * the key set in {@link AbstractRememberMeServices}. * @param key the key to identify tokens created for remember me authentication * @return the {@link RememberMeConfigurer} for further customization */ @@ -181,7 +188,6 @@ public final class RememberMeConfigurer> /** * The HTTP parameter used to indicate to remember the user at time of login. - * * @param rememberMeParameter the HTTP parameter used to indicate to remember the user * @return the {@link RememberMeConfigurer} for further customization */ @@ -193,7 +199,6 @@ public final class RememberMeConfigurer> /** * The name of cookie which store the token for remember me authentication. Defaults * to 'remember-me'. - * * @param rememberMeCookieName the name of cookie which store the token for remember * me authentication * @return the {@link RememberMeConfigurer} for further customization @@ -206,7 +211,6 @@ public final class RememberMeConfigurer> /** * The domain name within which the remember me cookie is visible. - * * @param rememberMeCookieDomain the domain name within which the remember me cookie * is visible. * @return the {@link RememberMeConfigurer} for further customization @@ -224,7 +228,6 @@ public final class RememberMeConfigurer> * be invoked and the {@code doFilter()} method will return immediately, thus allowing * the application to redirect the user to a specific URL, regardless of what the * original request was for. - * * @param authenticationSuccessHandler the strategy to invoke immediately before * returning from {@code doFilter()}. * @return {@link RememberMeConfigurer} for further customization @@ -242,8 +245,7 @@ public final class RememberMeConfigurer> * @return the {@link RememberMeConfigurer} for further customizations * @see RememberMeServices */ - public RememberMeConfigurer rememberMeServices( - RememberMeServices rememberMeServices) { + public RememberMeConfigurer rememberMeServices(RememberMeServices rememberMeServices) { this.rememberMeServices = rememberMeServices; return this; } @@ -253,7 +255,6 @@ public final class RememberMeConfigurer> * not set. *

      * By default this will be set to {@code false}. - * * @param alwaysRemember set to {@code true} to always trigger remember me, * {@code false} to use the remember-me parameter. * @return the {@link RememberMeConfigurer} for further customization @@ -275,36 +276,30 @@ public final class RememberMeConfigurer> if (logoutConfigurer != null && this.logoutHandler != null) { logoutConfigurer.addLogoutHandler(this.logoutHandler); } - - RememberMeAuthenticationProvider authenticationProvider = new RememberMeAuthenticationProvider( - key); + RememberMeAuthenticationProvider authenticationProvider = new RememberMeAuthenticationProvider(key); authenticationProvider = postProcess(authenticationProvider); http.authenticationProvider(authenticationProvider); - initDefaultLoginFilter(http); } @Override public void configure(H http) { RememberMeAuthenticationFilter rememberMeFilter = new RememberMeAuthenticationFilter( - http.getSharedObject(AuthenticationManager.class), - this.rememberMeServices); + http.getSharedObject(AuthenticationManager.class), this.rememberMeServices); if (this.authenticationSuccessHandler != null) { - rememberMeFilter - .setAuthenticationSuccessHandler(this.authenticationSuccessHandler); + rememberMeFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler); } rememberMeFilter = postProcess(rememberMeFilter); http.addFilter(rememberMeFilter); } /** - * Validate rememberMeServices and rememberMeCookieName have not been set at - * the same time. + * Validate rememberMeServices and rememberMeCookieName have not been set at the same + * time. */ private void validateInput() { - if (this.rememberMeServices != null && this.rememberMeCookieName != DEFAULT_REMEMBER_ME_NAME) { - throw new IllegalArgumentException("Can not set rememberMeCookieName " + - "and custom rememberMeServices."); + if (this.rememberMeServices != null && !DEFAULT_REMEMBER_ME_NAME.equals(this.rememberMeCookieName)) { + throw new IllegalArgumentException("Can not set rememberMeCookieName and custom rememberMeServices."); } } @@ -319,7 +314,6 @@ public final class RememberMeConfigurer> /** * If available, initializes the {@link DefaultLoginPageGeneratingFilter} shared * object. - * * @param http the {@link HttpSecurityBuilder} to use */ private void initDefaultLoginFilter(H http) { @@ -337,17 +331,14 @@ public final class RememberMeConfigurer> * @return the {@link RememberMeServices} to use * @throws Exception */ - private RememberMeServices getRememberMeServices(H http, String key) - throws Exception { + private RememberMeServices getRememberMeServices(H http, String key) throws Exception { if (this.rememberMeServices != null) { - if (this.rememberMeServices instanceof LogoutHandler - && this.logoutHandler == null) { + if (this.rememberMeServices instanceof LogoutHandler && this.logoutHandler == null) { this.logoutHandler = (LogoutHandler) this.rememberMeServices; } return this.rememberMeServices; } - AbstractRememberMeServices tokenRememberMeServices = createRememberMeServices( - http, key); + AbstractRememberMeServices tokenRememberMeServices = createRememberMeServices(http, key); tokenRememberMeServices.setParameter(this.rememberMeParameter); tokenRememberMeServices.setCookieName(this.rememberMeCookieName); if (this.rememberMeCookieDomain != null) { @@ -372,49 +363,41 @@ public final class RememberMeConfigurer> * Creates the {@link RememberMeServices} to use when none is provided. The result is * either {@link PersistentTokenRepository} (if a {@link PersistentTokenRepository} is * specified, else {@link TokenBasedRememberMeServices}. - * * @param http the {@link HttpSecurity} to lookup shared objects * @param key the {@link #key(String)} * @return the {@link RememberMeServices} to use */ private AbstractRememberMeServices createRememberMeServices(H http, String key) { - return this.tokenRepository == null - ? createTokenBasedRememberMeServices(http, key) - : createPersistentRememberMeServices(http, key); + return (this.tokenRepository != null) ? createPersistentRememberMeServices(http, key) + : createTokenBasedRememberMeServices(http, key); } /** * Creates {@link TokenBasedRememberMeServices} - * * @param http the {@link HttpSecurity} to lookup shared objects * @param key the {@link #key(String)} * @return the {@link TokenBasedRememberMeServices} */ - private AbstractRememberMeServices createTokenBasedRememberMeServices(H http, - String key) { + private AbstractRememberMeServices createTokenBasedRememberMeServices(H http, String key) { UserDetailsService userDetailsService = getUserDetailsService(http); return new TokenBasedRememberMeServices(key, userDetailsService); } /** * Creates {@link PersistentTokenBasedRememberMeServices} - * * @param http the {@link HttpSecurity} to lookup shared objects * @param key the {@link #key(String)} * @return the {@link PersistentTokenBasedRememberMeServices} */ - private AbstractRememberMeServices createPersistentRememberMeServices(H http, - String key) { + private AbstractRememberMeServices createPersistentRememberMeServices(H http, String key) { UserDetailsService userDetailsService = getUserDetailsService(http); - return new PersistentTokenBasedRememberMeServices(key, userDetailsService, - this.tokenRepository); + return new PersistentTokenBasedRememberMeServices(key, userDetailsService, this.tokenRepository); } /** * Gets the {@link UserDetailsService} to use. Either the explicitly configure * {@link UserDetailsService} from {@link #userDetailsService(UserDetailsService)} or * a shared object from {@link HttpSecurity#getSharedObject(Class)}. - * * @param http {@link HttpSecurity} to get the shared {@link UserDetailsService} * @return the {@link UserDetailsService} to use */ @@ -422,32 +405,30 @@ public final class RememberMeConfigurer> if (this.userDetailsService == null) { this.userDetailsService = http.getSharedObject(UserDetailsService.class); } - if (this.userDetailsService == null) { - throw new IllegalStateException("userDetailsService cannot be null. Invoke " - + RememberMeConfigurer.class.getSimpleName() - + "#userDetailsService(UserDetailsService) or see its javadoc for alternative approaches."); - } + Assert.state(this.userDetailsService != null, + () -> "userDetailsService cannot be null. Invoke " + RememberMeConfigurer.class.getSimpleName() + + "#userDetailsService(UserDetailsService) or see its javadoc for alternative approaches."); return this.userDetailsService; } /** * Gets the key to use for validating remember me tokens. If a value was passed into - * {@link #key(String)}, then that is returned. - * Alternatively, if a key was specified in the - * {@link #rememberMeServices(RememberMeServices)}}, then that is returned. - * If no key was specified in either of those cases, then a secure random string is + * {@link #key(String)}, then that is returned. Alternatively, if a key was specified + * in the {@link #rememberMeServices(RememberMeServices)}}, then that is returned. If + * no key was specified in either of those cases, then a secure random string is * generated. - * * @return the remember me key to use */ private String getKey() { if (this.key == null) { if (this.rememberMeServices instanceof AbstractRememberMeServices) { - this.key = ((AbstractRememberMeServices) rememberMeServices).getKey(); - } else { + this.key = ((AbstractRememberMeServices) this.rememberMeServices).getKey(); + } + else { this.key = UUID.randomUUID().toString(); } } return this.key; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurer.java index 3ac5d73e89..e7eae3f283 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.ArrayList; @@ -69,8 +70,8 @@ import org.springframework.web.accept.HeaderContentNegotiationStrategy; * @since 3.2 * @see RequestCache */ -public final class RequestCacheConfigurer> extends - AbstractHttpConfigurer, H> { +public final class RequestCacheConfigurer> + extends AbstractHttpConfigurer, H> { public RequestCacheConfigurer() { } @@ -79,7 +80,6 @@ public final class RequestCacheConfigurer> exte * Allows explicit configuration of the {@link RequestCache} to be used. Defaults to * try finding a {@link RequestCache} as a shared object. Then falls back to a * {@link HttpSessionRequestCache}. - * * @param requestCache the explicit {@link RequestCache} to use * @return the {@link RequestCacheConfigurer} for further customization */ @@ -102,8 +102,7 @@ public final class RequestCacheConfigurer> exte @Override public void configure(H http) { RequestCache requestCache = getRequestCache(http); - RequestCacheAwareFilter requestCacheFilter = new RequestCacheAwareFilter( - requestCache); + RequestCacheAwareFilter requestCacheFilter = new RequestCacheAwareFilter(requestCache); requestCacheFilter = postProcess(requestCacheFilter); http.addFilter(requestCacheFilter); } @@ -113,7 +112,6 @@ public final class RequestCacheConfigurer> exte * {@link #requestCache(org.springframework.security.web.savedrequest.RequestCache)}, * then it is used. Otherwise, an attempt to find a {@link RequestCache} shared object * is made. If that fails, an {@link HttpSessionRequestCache} is used - * * @param http the {@link HttpSecurity} to attempt to fined the shared object * @return the {@link RequestCache} to use */ @@ -138,21 +136,18 @@ public final class RequestCacheConfigurer> exte } try { return context.getBean(type); - } catch (NoSuchBeanDefinitionException e) { + } + catch (NoSuchBeanDefinitionException ex) { return null; } } @SuppressWarnings("unchecked") private RequestMatcher createDefaultSavedRequestMatcher(H http) { - RequestMatcher notFavIcon = new NegatedRequestMatcher(new AntPathRequestMatcher( - "/**/favicon.*")); - + RequestMatcher notFavIcon = new NegatedRequestMatcher(new AntPathRequestMatcher("/**/favicon.*")); RequestMatcher notXRequestedWith = new NegatedRequestMatcher( new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest")); - boolean isCsrfEnabled = http.getConfigurer(CsrfConfigurer.class) != null; - List matchers = new ArrayList<>(); if (isCsrfEnabled) { RequestMatcher getRequests = new AntPathRequestMatcher("/**", "GET"); @@ -163,7 +158,6 @@ public final class RequestCacheConfigurer> exte matchers.add(notXRequestedWith); matchers.add(notMatchingMediaType(http, MediaType.MULTIPART_FORM_DATA)); matchers.add(notMatchingMediaType(http, MediaType.TEXT_EVENT_STREAM)); - return new AndRequestMatcher(matchers); } @@ -172,9 +166,9 @@ public final class RequestCacheConfigurer> exte if (contentNegotiationStrategy == null) { contentNegotiationStrategy = new HeaderContentNegotiationStrategy(); } - MediaTypeRequestMatcher mediaRequest = new MediaTypeRequestMatcher(contentNegotiationStrategy, mediaType); mediaRequest.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); return new NegatedRequestMatcher(mediaRequest); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurer.java index c514989fe9..2139961a00 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; @@ -58,8 +59,8 @@ import org.springframework.security.web.context.SecurityContextRepository; * @author Rob Winch * @since 3.2 */ -public final class SecurityContextConfigurer> extends - AbstractHttpConfigurer, H> { +public final class SecurityContextConfigurer> + extends AbstractHttpConfigurer, H> { /** * Creates a new instance @@ -73,32 +74,28 @@ public final class SecurityContextConfigurer> e * @param securityContextRepository the {@link SecurityContextRepository} to use * @return the {@link HttpSecurity} for further customizations */ - public SecurityContextConfigurer securityContextRepository( - SecurityContextRepository securityContextRepository) { - getBuilder().setSharedObject(SecurityContextRepository.class, - securityContextRepository); + public SecurityContextConfigurer securityContextRepository(SecurityContextRepository securityContextRepository) { + getBuilder().setSharedObject(SecurityContextRepository.class, securityContextRepository); return this; } @Override @SuppressWarnings("unchecked") public void configure(H http) { - - SecurityContextRepository securityContextRepository = http - .getSharedObject(SecurityContextRepository.class); + SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class); if (securityContextRepository == null) { securityContextRepository = new HttpSessionSecurityContextRepository(); } SecurityContextPersistenceFilter securityContextFilter = new SecurityContextPersistenceFilter( securityContextRepository); - SessionManagementConfigurer sessionManagement = http - .getConfigurer(SessionManagementConfigurer.class); - SessionCreationPolicy sessionCreationPolicy = sessionManagement == null ? null - : sessionManagement.getSessionCreationPolicy(); + SessionManagementConfigurer sessionManagement = http.getConfigurer(SessionManagementConfigurer.class); + SessionCreationPolicy sessionCreationPolicy = (sessionManagement != null) + ? sessionManagement.getSessionCreationPolicy() : null; if (SessionCreationPolicy.ALWAYS == sessionCreationPolicy) { securityContextFilter.setForceEagerSessionCreation(true); } securityContextFilter = postProcess(securityContextFilter); http.addFilter(securityContextFilter); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurer.java index 73dccbfea2..5959d9d08e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.List; @@ -57,8 +58,9 @@ import org.springframework.security.web.servletapi.SecurityContextHolderAwareReq * @author Rob Winch * @since 3.2 */ -public final class ServletApiConfigurer> extends - AbstractHttpConfigurer, H> { +public final class ServletApiConfigurer> + extends AbstractHttpConfigurer, H> { + private SecurityContextHolderAwareRequestFilter securityContextRequestFilter = new SecurityContextHolderAwareRequestFilter(); /** @@ -69,39 +71,36 @@ public final class ServletApiConfigurer> extend } public ServletApiConfigurer rolePrefix(String rolePrefix) { - securityContextRequestFilter.setRolePrefix(rolePrefix); + this.securityContextRequestFilter.setRolePrefix(rolePrefix); return this; } @Override @SuppressWarnings("unchecked") public void configure(H http) { - securityContextRequestFilter.setAuthenticationManager(http - .getSharedObject(AuthenticationManager.class)); - ExceptionHandlingConfigurer exceptionConf = http - .getConfigurer(ExceptionHandlingConfigurer.class); - AuthenticationEntryPoint authenticationEntryPoint = exceptionConf == null ? null - : exceptionConf.getAuthenticationEntryPoint(http); - securityContextRequestFilter - .setAuthenticationEntryPoint(authenticationEntryPoint); + this.securityContextRequestFilter.setAuthenticationManager(http.getSharedObject(AuthenticationManager.class)); + ExceptionHandlingConfigurer exceptionConf = http.getConfigurer(ExceptionHandlingConfigurer.class); + AuthenticationEntryPoint authenticationEntryPoint = (exceptionConf != null) + ? exceptionConf.getAuthenticationEntryPoint(http) : null; + this.securityContextRequestFilter.setAuthenticationEntryPoint(authenticationEntryPoint); LogoutConfigurer logoutConf = http.getConfigurer(LogoutConfigurer.class); - List logoutHandlers = logoutConf == null ? null : logoutConf - .getLogoutHandlers(); - securityContextRequestFilter.setLogoutHandlers(logoutHandlers); - AuthenticationTrustResolver trustResolver = http - .getSharedObject(AuthenticationTrustResolver.class); + List logoutHandlers = (logoutConf != null) ? logoutConf.getLogoutHandlers() : null; + this.securityContextRequestFilter.setLogoutHandlers(logoutHandlers); + AuthenticationTrustResolver trustResolver = http.getSharedObject(AuthenticationTrustResolver.class); if (trustResolver != null) { - securityContextRequestFilter.setTrustResolver(trustResolver); + this.securityContextRequestFilter.setTrustResolver(trustResolver); } ApplicationContext context = http.getSharedObject(ApplicationContext.class); if (context != null) { String[] grantedAuthorityDefaultsBeanNames = context.getBeanNamesForType(GrantedAuthorityDefaults.class); if (grantedAuthorityDefaultsBeanNames.length == 1) { - GrantedAuthorityDefaults grantedAuthorityDefaults = context.getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); - securityContextRequestFilter.setRolePrefix(grantedAuthorityDefaults.getRolePrefix()); + GrantedAuthorityDefaults grantedAuthorityDefaults = context + .getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); + this.securityContextRequestFilter.setRolePrefix(grantedAuthorityDefaults.getRolePrefix()); } } - securityContextRequestFilter = postProcess(securityContextRequestFilter); - http.addFilter(securityContextRequestFilter); + this.securityContextRequestFilter = postProcess(this.securityContextRequestFilter); + http.addFilter(this.securityContextRequestFilter); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java index 5d6d2c22f0..86b9cc0275 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurer.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; @@ -98,21 +100,37 @@ import org.springframework.util.CollectionUtils; */ public final class SessionManagementConfigurer> extends AbstractHttpConfigurer, H> { + private final SessionAuthenticationStrategy DEFAULT_SESSION_FIXATION_STRATEGY = createDefaultSessionFixationProtectionStrategy(); + private SessionAuthenticationStrategy sessionFixationAuthenticationStrategy = this.DEFAULT_SESSION_FIXATION_STRATEGY; + private SessionAuthenticationStrategy sessionAuthenticationStrategy; + private SessionAuthenticationStrategy providedSessionAuthenticationStrategy; + private InvalidSessionStrategy invalidSessionStrategy; + private SessionInformationExpiredStrategy expiredSessionStrategy; + private List sessionAuthenticationStrategies = new ArrayList<>(); + private SessionRegistry sessionRegistry; + private Integer maximumSessions; + private String expiredUrl; + private boolean maxSessionsPreventsLogin; + private SessionCreationPolicy sessionPolicy; + private boolean enableSessionUrlRewriting; + private String invalidSessionUrl; + private String sessionAuthenticationErrorUrl; + private AuthenticationFailureHandler sessionAuthenticationFailureHandler; /** @@ -127,7 +145,6 @@ public final class SessionManagementConfigurer> * {@link SimpleRedirectInvalidSessionStrategy} configured with the attribute value. * When an invalid session ID is submitted, the strategy will be invoked, redirecting * to the configured URL. - * * @param invalidSessionUrl the URL to redirect to when an invalid session is detected * @return the {@link SessionManagementConfigurer} for further customization */ @@ -144,8 +161,7 @@ public final class SessionManagementConfigurer> * submitted. * @return the {@link SessionManagementConfigurer} for further customization */ - public SessionManagementConfigurer invalidSessionStrategy( - InvalidSessionStrategy invalidSessionStrategy) { + public SessionManagementConfigurer invalidSessionStrategy(InvalidSessionStrategy invalidSessionStrategy) { Assert.notNull(invalidSessionStrategy, "invalidSessionStrategy"); this.invalidSessionStrategy = invalidSessionStrategy; return this; @@ -157,12 +173,10 @@ public final class SessionManagementConfigurer> * (402) error code will be returned to the client. Note that this attribute doesn't * apply if the error occurs during a form-based login, where the URL for * authentication failure will take precedence. - * * @param sessionAuthenticationErrorUrl the URL to redirect to * @return the {@link SessionManagementConfigurer} for further customization */ - public SessionManagementConfigurer sessionAuthenticationErrorUrl( - String sessionAuthenticationErrorUrl) { + public SessionManagementConfigurer sessionAuthenticationErrorUrl(String sessionAuthenticationErrorUrl) { this.sessionAuthenticationErrorUrl = sessionAuthenticationErrorUrl; return this; } @@ -173,7 +187,6 @@ public final class SessionManagementConfigurer> * (402) error code will be returned to the client. Note that this attribute doesn't * apply if the error occurs during a form-based login, where the URL for * authentication failure will take precedence. - * * @param sessionAuthenticationFailureHandler the handler to use * @return the {@link SessionManagementConfigurer} for further customization */ @@ -188,14 +201,12 @@ public final class SessionManagementConfigurer> * {@link HttpServletResponse#encodeRedirectURL(String)} or * {@link HttpServletResponse#encodeURL(String)}, otherwise disallows HTTP sessions to * be included in the URL. This prevents leaking information to external domains. - * * @param enableSessionUrlRewriting true if should allow the JSESSIONID to be * rewritten into the URLs, else false (default) * @return the {@link SessionManagementConfigurer} for further customization * @see HttpSessionSecurityContextRepository#setDisableUrlRewriting(boolean) */ - public SessionManagementConfigurer enableSessionUrlRewriting( - boolean enableSessionUrlRewriting) { + public SessionManagementConfigurer enableSessionUrlRewriting(boolean enableSessionUrlRewriting) { this.enableSessionUrlRewriting = enableSessionUrlRewriting; return this; } @@ -205,29 +216,27 @@ public final class SessionManagementConfigurer> * @param sessionCreationPolicy the {@link SessionCreationPolicy} to use. Cannot be * null. * @return the {@link SessionManagementConfigurer} for further customizations - * @see SessionCreationPolicy * @throws IllegalArgumentException if {@link SessionCreationPolicy} is null. + * @see SessionCreationPolicy */ - public SessionManagementConfigurer sessionCreationPolicy( - SessionCreationPolicy sessionCreationPolicy) { + public SessionManagementConfigurer sessionCreationPolicy(SessionCreationPolicy sessionCreationPolicy) { Assert.notNull(sessionCreationPolicy, "sessionCreationPolicy cannot be null"); this.sessionPolicy = sessionCreationPolicy; return this; } /** - * Allows explicitly specifying the {@link SessionAuthenticationStrategy}. - * The default is to use {@link ChangeSessionIdAuthenticationStrategy}. - * If restricting the maximum number of sessions is configured, then + * Allows explicitly specifying the {@link SessionAuthenticationStrategy}. The default + * is to use {@link ChangeSessionIdAuthenticationStrategy}. If restricting the maximum + * number of sessions is configured, then * {@link CompositeSessionAuthenticationStrategy} delegating to - * {@link ConcurrentSessionControlAuthenticationStrategy}, - * the default OR supplied {@code SessionAuthenticationStrategy} and + * {@link ConcurrentSessionControlAuthenticationStrategy}, the default OR supplied + * {@code SessionAuthenticationStrategy} and * {@link RegisterSessionAuthenticationStrategy}. * *

      * NOTE: Supplying a custom {@link SessionAuthenticationStrategy} will override the * default session fixation strategy. - * * @param sessionAuthenticationStrategy * @return the {@link SessionManagementConfigurer} for further customizations */ @@ -240,7 +249,6 @@ public final class SessionManagementConfigurer> /** * Adds an additional {@link SessionAuthenticationStrategy} to be used within the * {@link CompositeSessionAuthenticationStrategy}. - * * @param sessionAuthenticationStrategy * @return the {@link SessionManagementConfigurer} for further customizations */ @@ -252,7 +260,6 @@ public final class SessionManagementConfigurer> /** * Allows changing the default {@link SessionFixationProtectionStrategy}. - * * @return the {@link SessionFixationConfigurer} for further customizations */ public SessionFixationConfigurer sessionFixation() { @@ -261,12 +268,12 @@ public final class SessionManagementConfigurer> /** * Allows configuring session fixation protection. - * * @param sessionFixationCustomizer the {@link Customizer} to provide more options for * the {@link SessionFixationConfigurer} * @return the {@link SessionManagementConfigurer} for further customizations */ - public SessionManagementConfigurer sessionFixation(Customizer sessionFixationCustomizer) { + public SessionManagementConfigurer sessionFixation( + Customizer sessionFixationCustomizer) { sessionFixationCustomizer.customize(new SessionFixationConfigurer()); return this; } @@ -285,12 +292,12 @@ public final class SessionManagementConfigurer> /** * Controls the maximum number of sessions for a user. The default is to allow any * number of users. - * - * @param sessionConcurrencyCustomizer the {@link Customizer} to provide more options for - * the {@link ConcurrencyControlConfigurer} + * @param sessionConcurrencyCustomizer the {@link Customizer} to provide more options + * for the {@link ConcurrencyControlConfigurer} * @return the {@link SessionManagementConfigurer} for further customizations */ - public SessionManagementConfigurer sessionConcurrency(Customizer sessionConcurrencyCustomizer) { + public SessionManagementConfigurer sessionConcurrency( + Customizer sessionConcurrencyCustomizer) { sessionConcurrencyCustomizer.customize(new ConcurrencyControlConfigurer()); return this; } @@ -302,207 +309,46 @@ public final class SessionManagementConfigurer> */ private void setSessionFixationAuthenticationStrategy( SessionAuthenticationStrategy sessionFixationAuthenticationStrategy) { - this.sessionFixationAuthenticationStrategy = postProcess( - sessionFixationAuthenticationStrategy); - } - - /** - * Allows configuring SessionFixation protection - * - * @author Rob Winch - */ - public final class SessionFixationConfigurer { - /** - * Specifies that a new session should be created, but the session attributes from - * the original {@link HttpSession} should not be retained. - * - * @return the {@link SessionManagementConfigurer} for further customizations - */ - public SessionManagementConfigurer newSession() { - SessionFixationProtectionStrategy sessionFixationProtectionStrategy = new SessionFixationProtectionStrategy(); - sessionFixationProtectionStrategy.setMigrateSessionAttributes(false); - setSessionFixationAuthenticationStrategy(sessionFixationProtectionStrategy); - return SessionManagementConfigurer.this; - } - - /** - * Specifies that a new session should be created and the session attributes from - * the original {@link HttpSession} should be retained. - * - * @return the {@link SessionManagementConfigurer} for further customizations - */ - public SessionManagementConfigurer migrateSession() { - setSessionFixationAuthenticationStrategy( - new SessionFixationProtectionStrategy()); - return SessionManagementConfigurer.this; - } - - /** - * Specifies that the Servlet container-provided session fixation protection - * should be used. When a session authenticates, the Servlet method - * {@code HttpServletRequest#changeSessionId()} is called to change the session ID - * and retain all session attributes. - * - * @return the {@link SessionManagementConfigurer} for further customizations - */ - public SessionManagementConfigurer changeSessionId() { - setSessionFixationAuthenticationStrategy( - new ChangeSessionIdAuthenticationStrategy()); - return SessionManagementConfigurer.this; - } - - /** - * Specifies that no session fixation protection should be enabled. This may be - * useful when utilizing other mechanisms for protecting against session fixation. - * For example, if application container session fixation protection is already in - * use. Otherwise, this option is not recommended. - * - * @return the {@link SessionManagementConfigurer} for further customizations - */ - public SessionManagementConfigurer none() { - setSessionFixationAuthenticationStrategy( - new NullAuthenticatedSessionStrategy()); - return SessionManagementConfigurer.this; - } - } - - /** - * Allows configuring controlling of multiple sessions. - * - * @author Rob Winch - */ - public final class ConcurrencyControlConfigurer { - - /** - * Controls the maximum number of sessions for a user. The default is to allow any - * number of users. - * - * @param maximumSessions the maximum number of sessions for a user - * @return the {@link ConcurrencyControlConfigurer} for further customizations - */ - public ConcurrencyControlConfigurer maximumSessions(int maximumSessions) { - SessionManagementConfigurer.this.maximumSessions = maximumSessions; - return this; - } - - /** - * The URL to redirect to if a user tries to access a resource and their session - * has been expired due to too many sessions for the current user. The default is - * to write a simple error message to the response. - * - * @param expiredUrl the URL to redirect to - * @return the {@link ConcurrencyControlConfigurer} for further customizations - */ - public ConcurrencyControlConfigurer expiredUrl(String expiredUrl) { - SessionManagementConfigurer.this.expiredUrl = expiredUrl; - return this; - } - - /** - * Determines the behaviour when an expired session is detected. - * - * @param expiredSessionStrategy the {@link SessionInformationExpiredStrategy} to - * use when an expired session is detected. - * @return the {@link ConcurrencyControlConfigurer} for further customizations - */ - public ConcurrencyControlConfigurer expiredSessionStrategy( - SessionInformationExpiredStrategy expiredSessionStrategy) { - SessionManagementConfigurer.this.expiredSessionStrategy = expiredSessionStrategy; - return this; - } - - /** - * If true, prevents a user from authenticating when the - * {@link #maximumSessions(int)} has been reached. Otherwise (default), the user - * who authenticates is allowed access and an existing user's session is expired. - * The user's who's session is forcibly expired is sent to - * {@link #expiredUrl(String)}. The advantage of this approach is if a user - * accidentally does not log out, there is no need for an administrator to - * intervene or wait till their session expires. - * - * @param maxSessionsPreventsLogin true to have an error at time of - * authentication, else false (default) - * @return the {@link ConcurrencyControlConfigurer} for further customizations - */ - public ConcurrencyControlConfigurer maxSessionsPreventsLogin( - boolean maxSessionsPreventsLogin) { - SessionManagementConfigurer.this.maxSessionsPreventsLogin = maxSessionsPreventsLogin; - return this; - } - - /** - * Controls the {@link SessionRegistry} implementation used. The default is - * {@link SessionRegistryImpl} which is an in memory implementation. - * - * @param sessionRegistry the {@link SessionRegistry} to use - * @return the {@link ConcurrencyControlConfigurer} for further customizations - */ - public ConcurrencyControlConfigurer sessionRegistry( - SessionRegistry sessionRegistry) { - SessionManagementConfigurer.this.sessionRegistry = sessionRegistry; - return this; - } - - /** - * Used to chain back to the {@link SessionManagementConfigurer} - * - * @return the {@link SessionManagementConfigurer} for further customizations - */ - public SessionManagementConfigurer and() { - return SessionManagementConfigurer.this; - } - - private ConcurrencyControlConfigurer() { - } + this.sessionFixationAuthenticationStrategy = postProcess(sessionFixationAuthenticationStrategy); } @Override public void init(H http) { - SecurityContextRepository securityContextRepository = http - .getSharedObject(SecurityContextRepository.class); + SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class); boolean stateless = isStateless(); - if (securityContextRepository == null) { if (stateless) { - http.setSharedObject(SecurityContextRepository.class, - new NullSecurityContextRepository()); + http.setSharedObject(SecurityContextRepository.class, new NullSecurityContextRepository()); } else { HttpSessionSecurityContextRepository httpSecurityRepository = new HttpSessionSecurityContextRepository(); - httpSecurityRepository - .setDisableUrlRewriting(!this.enableSessionUrlRewriting); + httpSecurityRepository.setDisableUrlRewriting(!this.enableSessionUrlRewriting); httpSecurityRepository.setAllowSessionCreation(isAllowSessionCreation()); - AuthenticationTrustResolver trustResolver = http - .getSharedObject(AuthenticationTrustResolver.class); + AuthenticationTrustResolver trustResolver = http.getSharedObject(AuthenticationTrustResolver.class); if (trustResolver != null) { httpSecurityRepository.setTrustResolver(trustResolver); } - http.setSharedObject(SecurityContextRepository.class, - httpSecurityRepository); + http.setSharedObject(SecurityContextRepository.class, httpSecurityRepository); } } - RequestCache requestCache = http.getSharedObject(RequestCache.class); if (requestCache == null) { if (stateless) { http.setSharedObject(RequestCache.class, new NullRequestCache()); } } - http.setSharedObject(SessionAuthenticationStrategy.class, - getSessionAuthenticationStrategy(http)); + http.setSharedObject(SessionAuthenticationStrategy.class, getSessionAuthenticationStrategy(http)); http.setSharedObject(InvalidSessionStrategy.class, getInvalidSessionStrategy()); } @Override public void configure(H http) { - SecurityContextRepository securityContextRepository = http - .getSharedObject(SecurityContextRepository.class); - SessionManagementFilter sessionManagementFilter = new SessionManagementFilter( - securityContextRepository, getSessionAuthenticationStrategy(http)); + SecurityContextRepository securityContextRepository = http.getSharedObject(SecurityContextRepository.class); + SessionManagementFilter sessionManagementFilter = new SessionManagementFilter(securityContextRepository, + getSessionAuthenticationStrategy(http)); if (this.sessionAuthenticationErrorUrl != null) { sessionManagementFilter.setAuthenticationFailureHandler( - new SimpleUrlAuthenticationFailureHandler( - this.sessionAuthenticationErrorUrl)); + new SimpleUrlAuthenticationFailureHandler(this.sessionAuthenticationErrorUrl)); } InvalidSessionStrategy strategy = getInvalidSessionStrategy(); if (strategy != null) { @@ -512,13 +358,11 @@ public final class SessionManagementConfigurer> if (failureHandler != null) { sessionManagementFilter.setAuthenticationFailureHandler(failureHandler); } - AuthenticationTrustResolver trustResolver = http - .getSharedObject(AuthenticationTrustResolver.class); + AuthenticationTrustResolver trustResolver = http.getSharedObject(AuthenticationTrustResolver.class); if (trustResolver != null) { sessionManagementFilter.setTrustResolver(trustResolver); } sessionManagementFilter = postProcess(sessionManagementFilter); - http.addFilter(sessionManagementFilter); if (isConcurrentSessionControlEnabled()) { ConcurrentSessionFilter concurrentSessionFilter = createConcurrencyFilter(http); @@ -531,12 +375,9 @@ public final class SessionManagementConfigurer> private ConcurrentSessionFilter createConcurrencyFilter(H http) { SessionInformationExpiredStrategy expireStrategy = getExpiredSessionStrategy(); SessionRegistry sessionRegistry = getSessionRegistry(http); - ConcurrentSessionFilter concurrentSessionFilter; - if (expireStrategy == null) { - concurrentSessionFilter = new ConcurrentSessionFilter(sessionRegistry); - } else { - concurrentSessionFilter = new ConcurrentSessionFilter(sessionRegistry, expireStrategy); - } + ConcurrentSessionFilter concurrentSessionFilter = (expireStrategy != null) + ? new ConcurrentSessionFilter(sessionRegistry, expireStrategy) + : new ConcurrentSessionFilter(sessionRegistry); LogoutConfigurer logoutConfigurer = http.getConfigurer(LogoutConfigurer.class); if (logoutConfigurer != null) { List logoutHandlers = logoutConfigurer.getLogoutHandlers(); @@ -551,20 +392,16 @@ public final class SessionManagementConfigurer> * Gets the {@link InvalidSessionStrategy} to use. If null and * {@link #invalidSessionUrl} is not null defaults to * {@link SimpleRedirectInvalidSessionStrategy}. - * * @return the {@link InvalidSessionStrategy} to use */ InvalidSessionStrategy getInvalidSessionStrategy() { if (this.invalidSessionStrategy != null) { return this.invalidSessionStrategy; } - if (this.invalidSessionUrl == null) { return null; } - - this.invalidSessionStrategy = new SimpleRedirectInvalidSessionStrategy( - this.invalidSessionUrl); + this.invalidSessionStrategy = new SimpleRedirectInvalidSessionStrategy(this.invalidSessionUrl); return this.invalidSessionStrategy; } @@ -572,13 +409,10 @@ public final class SessionManagementConfigurer> if (this.expiredSessionStrategy != null) { return this.expiredSessionStrategy; } - if (this.expiredUrl == null) { return null; } - - this.expiredSessionStrategy = new SimpleRedirectSessionInformationExpiredStrategy( - this.expiredUrl); + this.expiredSessionStrategy = new SimpleRedirectSessionInformationExpiredStrategy(this.expiredUrl); return this.expiredSessionStrategy; } @@ -586,11 +420,9 @@ public final class SessionManagementConfigurer> if (this.sessionAuthenticationFailureHandler != null) { return this.sessionAuthenticationFailureHandler; } - if (this.sessionAuthenticationErrorUrl == null) { return null; } - this.sessionAuthenticationFailureHandler = new SimpleUrlAuthenticationFailureHandler( this.sessionAuthenticationErrorUrl); return this.sessionAuthenticationFailureHandler; @@ -604,11 +436,8 @@ public final class SessionManagementConfigurer> if (this.sessionPolicy != null) { return this.sessionPolicy; } - - SessionCreationPolicy sessionPolicy = - getBuilder().getSharedObject(SessionCreationPolicy.class); - return sessionPolicy == null ? - SessionCreationPolicy.IF_REQUIRED : sessionPolicy; + SessionCreationPolicy sessionPolicy = getBuilder().getSharedObject(SessionCreationPolicy.class); + return (sessionPolicy != null) ? sessionPolicy : SessionCreationPolicy.IF_REQUIRED; } /** @@ -618,8 +447,7 @@ public final class SessionManagementConfigurer> */ private boolean isAllowSessionCreation() { SessionCreationPolicy sessionPolicy = getSessionCreationPolicy(); - return SessionCreationPolicy.ALWAYS == sessionPolicy - || SessionCreationPolicy.IF_REQUIRED == sessionPolicy; + return SessionCreationPolicy.ALWAYS == sessionPolicy || SessionCreationPolicy.IF_REQUIRED == sessionPolicy; } /** @@ -635,7 +463,6 @@ public final class SessionManagementConfigurer> * Gets the customized {@link SessionAuthenticationStrategy} if * {@link #sessionAuthenticationStrategy(SessionAuthenticationStrategy)} was * specified. Otherwise creates a default {@link SessionAuthenticationStrategy}. - * * @return the {@link SessionAuthenticationStrategy} to use */ private SessionAuthenticationStrategy getSessionAuthenticationStrategy(H http) { @@ -647,8 +474,7 @@ public final class SessionManagementConfigurer> if (this.providedSessionAuthenticationStrategy == null) { // If the user did not provide a SessionAuthenticationStrategy // then default to sessionFixationAuthenticationStrategy - defaultSessionAuthenticationStrategy = postProcess( - this.sessionFixationAuthenticationStrategy); + defaultSessionAuthenticationStrategy = postProcess(this.sessionFixationAuthenticationStrategy); } else { defaultSessionAuthenticationStrategy = this.providedSessionAuthenticationStrategy; @@ -658,10 +484,8 @@ public final class SessionManagementConfigurer> ConcurrentSessionControlAuthenticationStrategy concurrentSessionControlStrategy = new ConcurrentSessionControlAuthenticationStrategy( sessionRegistry); concurrentSessionControlStrategy.setMaximumSessions(this.maximumSessions); - concurrentSessionControlStrategy - .setExceptionIfMaximumExceeded(this.maxSessionsPreventsLogin); - concurrentSessionControlStrategy = postProcess( - concurrentSessionControlStrategy); + concurrentSessionControlStrategy.setExceptionIfMaximumExceeded(this.maxSessionsPreventsLogin); + concurrentSessionControlStrategy = postProcess(concurrentSessionControlStrategy); RegisterSessionAuthenticationStrategy registerSessionStrategy = new RegisterSessionAuthenticationStrategy( sessionRegistry); @@ -690,14 +514,12 @@ public final class SessionManagementConfigurer> return this.sessionRegistry; } - private void registerDelegateApplicationListener(H http, - ApplicationListener delegate) { + private void registerDelegateApplicationListener(H http, ApplicationListener delegate) { DelegatingApplicationListener delegating = getBeanOrNull(DelegatingApplicationListener.class); if (delegating == null) { return; } - SmartApplicationListener smartListener = new GenericApplicationListenerAdapter( - delegate); + SmartApplicationListener smartListener = new GenericApplicationListenerAdapter(delegate); delegating.addListener(smartListener); } @@ -714,7 +536,7 @@ public final class SessionManagementConfigurer> * @return the default {@link SessionAuthenticationStrategy} for session fixation */ private static SessionAuthenticationStrategy createDefaultSessionFixationProtectionStrategy() { - return new ChangeSessionIdAuthenticationStrategy(); + return new ChangeSessionIdAuthenticationStrategy(); } private T getBeanOrNull(Class type) { @@ -725,8 +547,147 @@ public final class SessionManagementConfigurer> try { return context.getBean(type); } - catch (NoSuchBeanDefinitionException e) { + catch (NoSuchBeanDefinitionException ex) { return null; } } + + /** + * Allows configuring SessionFixation protection + * + * @author Rob Winch + */ + public final class SessionFixationConfigurer { + + /** + * Specifies that a new session should be created, but the session attributes from + * the original {@link HttpSession} should not be retained. + * @return the {@link SessionManagementConfigurer} for further customizations + */ + public SessionManagementConfigurer newSession() { + SessionFixationProtectionStrategy sessionFixationProtectionStrategy = new SessionFixationProtectionStrategy(); + sessionFixationProtectionStrategy.setMigrateSessionAttributes(false); + setSessionFixationAuthenticationStrategy(sessionFixationProtectionStrategy); + return SessionManagementConfigurer.this; + } + + /** + * Specifies that a new session should be created and the session attributes from + * the original {@link HttpSession} should be retained. + * @return the {@link SessionManagementConfigurer} for further customizations + */ + public SessionManagementConfigurer migrateSession() { + setSessionFixationAuthenticationStrategy(new SessionFixationProtectionStrategy()); + return SessionManagementConfigurer.this; + } + + /** + * Specifies that the Servlet container-provided session fixation protection + * should be used. When a session authenticates, the Servlet method + * {@code HttpServletRequest#changeSessionId()} is called to change the session ID + * and retain all session attributes. + * @return the {@link SessionManagementConfigurer} for further customizations + */ + public SessionManagementConfigurer changeSessionId() { + setSessionFixationAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy()); + return SessionManagementConfigurer.this; + } + + /** + * Specifies that no session fixation protection should be enabled. This may be + * useful when utilizing other mechanisms for protecting against session fixation. + * For example, if application container session fixation protection is already in + * use. Otherwise, this option is not recommended. + * @return the {@link SessionManagementConfigurer} for further customizations + */ + public SessionManagementConfigurer none() { + setSessionFixationAuthenticationStrategy(new NullAuthenticatedSessionStrategy()); + return SessionManagementConfigurer.this; + } + + } + + /** + * Allows configuring controlling of multiple sessions. + * + * @author Rob Winch + */ + public final class ConcurrencyControlConfigurer { + + private ConcurrencyControlConfigurer() { + } + + /** + * Controls the maximum number of sessions for a user. The default is to allow any + * number of users. + * @param maximumSessions the maximum number of sessions for a user + * @return the {@link ConcurrencyControlConfigurer} for further customizations + */ + public ConcurrencyControlConfigurer maximumSessions(int maximumSessions) { + SessionManagementConfigurer.this.maximumSessions = maximumSessions; + return this; + } + + /** + * The URL to redirect to if a user tries to access a resource and their session + * has been expired due to too many sessions for the current user. The default is + * to write a simple error message to the response. + * @param expiredUrl the URL to redirect to + * @return the {@link ConcurrencyControlConfigurer} for further customizations + */ + public ConcurrencyControlConfigurer expiredUrl(String expiredUrl) { + SessionManagementConfigurer.this.expiredUrl = expiredUrl; + return this; + } + + /** + * Determines the behaviour when an expired session is detected. + * @param expiredSessionStrategy the {@link SessionInformationExpiredStrategy} to + * use when an expired session is detected. + * @return the {@link ConcurrencyControlConfigurer} for further customizations + */ + public ConcurrencyControlConfigurer expiredSessionStrategy( + SessionInformationExpiredStrategy expiredSessionStrategy) { + SessionManagementConfigurer.this.expiredSessionStrategy = expiredSessionStrategy; + return this; + } + + /** + * If true, prevents a user from authenticating when the + * {@link #maximumSessions(int)} has been reached. Otherwise (default), the user + * who authenticates is allowed access and an existing user's session is expired. + * The user's who's session is forcibly expired is sent to + * {@link #expiredUrl(String)}. The advantage of this approach is if a user + * accidentally does not log out, there is no need for an administrator to + * intervene or wait till their session expires. + * @param maxSessionsPreventsLogin true to have an error at time of + * authentication, else false (default) + * @return the {@link ConcurrencyControlConfigurer} for further customizations + */ + public ConcurrencyControlConfigurer maxSessionsPreventsLogin(boolean maxSessionsPreventsLogin) { + SessionManagementConfigurer.this.maxSessionsPreventsLogin = maxSessionsPreventsLogin; + return this; + } + + /** + * Controls the {@link SessionRegistry} implementation used. The default is + * {@link SessionRegistryImpl} which is an in memory implementation. + * @param sessionRegistry the {@link SessionRegistry} to use + * @return the {@link ConcurrencyControlConfigurer} for further customizations + */ + public ConcurrencyControlConfigurer sessionRegistry(SessionRegistry sessionRegistry) { + SessionManagementConfigurer.this.sessionRegistry = sessionRegistry; + return this; + } + + /** + * Used to chain back to the {@link SessionManagementConfigurer} + * @return the {@link SessionManagementConfigurer} for further customizations + */ + public SessionManagementConfigurer and() { + return SessionManagementConfigurer.this; + } + + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurer.java index ec653cb5fb..4dd72aa947 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.ArrayList; @@ -77,97 +78,46 @@ import org.springframework.util.Assert; * The following shared objects are used: * *

        - *
      • - * AuthenticationManager
      • + *
      • AuthenticationManager
      • *
      * * @param the type of {@link HttpSecurityBuilder} that is being configured - * * @author Rob Winch * @since 3.2 * @see ExpressionUrlAuthorizationConfigurer */ -public final class UrlAuthorizationConfigurer> extends - AbstractInterceptUrlConfigurer, H> { - private final StandardInterceptUrlRegistry REGISTRY; +public final class UrlAuthorizationConfigurer> + extends AbstractInterceptUrlConfigurer, H> { + + private final StandardInterceptUrlRegistry registry; public UrlAuthorizationConfigurer(ApplicationContext context) { - this.REGISTRY = new StandardInterceptUrlRegistry(context); + this.registry = new StandardInterceptUrlRegistry(context); } /** * The StandardInterceptUrlRegistry is what users will interact with after applying * the {@link UrlAuthorizationConfigurer}. - * * @return the {@link ExpressionUrlAuthorizationConfigurer} for further customizations */ public StandardInterceptUrlRegistry getRegistry() { - return REGISTRY; + return this.registry; } /** * Adds an {@link ObjectPostProcessor} for this class. - * * @param objectPostProcessor * @return the {@link UrlAuthorizationConfigurer} for further customizations */ - public UrlAuthorizationConfigurer withObjectPostProcessor( - ObjectPostProcessor objectPostProcessor) { + @Override + public UrlAuthorizationConfigurer withObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { addObjectPostProcessor(objectPostProcessor); return this; } - public class StandardInterceptUrlRegistry - extends - ExpressionUrlAuthorizationConfigurer.AbstractInterceptUrlRegistry { - - /** - * @param context - */ - private StandardInterceptUrlRegistry(ApplicationContext context) { - setApplicationContext(context); - } - - @Override - public MvcMatchersAuthorizedUrl mvcMatchers(HttpMethod method, - String... mvcPatterns) { - return new MvcMatchersAuthorizedUrl(createMvcMatchers(method, mvcPatterns)); - } - - @Override - public MvcMatchersAuthorizedUrl mvcMatchers(String... patterns) { - return mvcMatchers(null, patterns); - } - - @Override - protected final AuthorizedUrl chainRequestMatchersInternal( - List requestMatchers) { - return new AuthorizedUrl(requestMatchers); - } - - /** - * Adds an {@link ObjectPostProcessor} for this class. - * - * @param objectPostProcessor - * @return the {@link ExpressionUrlAuthorizationConfigurer} for further - * customizations - */ - public StandardInterceptUrlRegistry withObjectPostProcessor( - ObjectPostProcessor objectPostProcessor) { - addObjectPostProcessor(objectPostProcessor); - return this; - } - - public H and() { - return UrlAuthorizationConfigurer.this.and(); - } - - } - /** * Creates the default {@link AccessDecisionVoter} instances used if an * {@link AccessDecisionManager} was not specified. - * * @param http the builder to use */ @Override @@ -182,13 +132,11 @@ public final class UrlAuthorizationConfigurer> /** * Creates the {@link FilterInvocationSecurityMetadataSource} to use. The * implementation is a {@link DefaultFilterInvocationSecurityMetadataSource}. - * * @param http the builder to use */ @Override FilterInvocationSecurityMetadataSource createMetadataSource(H http) { - return new DefaultFilterInvocationSecurityMetadataSource( - REGISTRY.createRequestMap()); + return new DefaultFilterInvocationSecurityMetadataSource(this.registry.createRequestMap()); } /** @@ -200,34 +148,29 @@ public final class UrlAuthorizationConfigurer> * by the {@link RequestMatcher} instances * @return the {@link ExpressionUrlAuthorizationConfigurer} for further customizations */ - private StandardInterceptUrlRegistry addMapping( - Iterable requestMatchers, + private StandardInterceptUrlRegistry addMapping(Iterable requestMatchers, Collection configAttributes) { for (RequestMatcher requestMatcher : requestMatchers) { - REGISTRY.addMapping(new AbstractConfigAttributeRequestMatcherRegistry.UrlMapping( - requestMatcher, configAttributes)); + this.registry.addMapping( + new AbstractConfigAttributeRequestMatcherRegistry.UrlMapping(requestMatcher, configAttributes)); } - return REGISTRY; + return this.registry; } /** * Creates a String for specifying a user requires a role. - * * @param role the role that should be required which is prepended with ROLE_ * automatically (i.e. USER, ADMIN, etc). It should not start with ROLE_ * @return the {@link ConfigAttribute} expressed as a String */ private static String hasRole(String role) { - Assert.isTrue( - !role.startsWith("ROLE_"), - () -> role - + " should not start with ROLE_ since ROLE_ is automatically prepended when using hasRole. Consider using hasAuthority or access instead."); + Assert.isTrue(!role.startsWith("ROLE_"), () -> role + + " should not start with ROLE_ since ROLE_ is automatically prepended when using hasRole. Consider using hasAuthority or access instead."); return "ROLE_" + role; } /** * Creates a String for specifying that a user requires one of many roles. - * * @param roles the roles that the user should have at least one of (i.e. ADMIN, USER, * etc). Each role should not start with ROLE_ since it is automatically prepended * already. @@ -250,6 +193,45 @@ public final class UrlAuthorizationConfigurer> return authorities; } + public final class StandardInterceptUrlRegistry extends + ExpressionUrlAuthorizationConfigurer.AbstractInterceptUrlRegistry { + + private StandardInterceptUrlRegistry(ApplicationContext context) { + setApplicationContext(context); + } + + @Override + public MvcMatchersAuthorizedUrl mvcMatchers(HttpMethod method, String... mvcPatterns) { + return new MvcMatchersAuthorizedUrl(createMvcMatchers(method, mvcPatterns)); + } + + @Override + public MvcMatchersAuthorizedUrl mvcMatchers(String... patterns) { + return mvcMatchers(null, patterns); + } + + @Override + protected AuthorizedUrl chainRequestMatchersInternal(List requestMatchers) { + return new AuthorizedUrl(requestMatchers); + } + + /** + * Adds an {@link ObjectPostProcessor} for this class. + * @param objectPostProcessor + * @return the {@link ExpressionUrlAuthorizationConfigurer} for further + * customizations + */ + public StandardInterceptUrlRegistry withObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { + addObjectPostProcessor(objectPostProcessor); + return this; + } + + public H and() { + return UrlAuthorizationConfigurer.this.and(); + } + + } + /** * An {@link AuthorizedUrl} that allows optionally configuring the * {@link MvcRequestMatcher#setMethod(HttpMethod)} @@ -257,9 +239,9 @@ public final class UrlAuthorizationConfigurer> * @author Rob Winch */ public final class MvcMatchersAuthorizedUrl extends AuthorizedUrl { + /** * Creates a new instance - * * @param requestMatchers the {@link RequestMatcher} instances to map */ private MvcMatchersAuthorizedUrl(List requestMatchers) { @@ -273,6 +255,7 @@ public final class UrlAuthorizationConfigurer> } return this; } + } /** @@ -283,6 +266,7 @@ public final class UrlAuthorizationConfigurer> * @since 3.2 */ public class AuthorizedUrl { + private final List requestMatchers; /** @@ -290,15 +274,13 @@ public final class UrlAuthorizationConfigurer> * @param requestMatchers the {@link RequestMatcher} instances to map to some * {@link ConfigAttribute} instances. */ - private AuthorizedUrl(List requestMatchers) { - Assert.notEmpty(requestMatchers, - "requestMatchers must contain at least one value"); + AuthorizedUrl(List requestMatchers) { + Assert.notEmpty(requestMatchers, "requestMatchers must contain at least one value"); this.requestMatchers = requestMatchers; } /** * Specifies a user requires a role. - * * @param role the role that should be required which is prepended with ROLE_ * automatically (i.e. USER, ADMIN, etc). It should not start with ROLE_ the * {@link UrlAuthorizationConfigurer} for further customization @@ -309,7 +291,6 @@ public final class UrlAuthorizationConfigurer> /** * Specifies that a user requires one of many roles. - * * @param roles the roles that the user should have at least one of (i.e. ADMIN, * USER, etc). Each role should not start with ROLE_ since it is automatically * prepended already. @@ -321,7 +302,6 @@ public final class UrlAuthorizationConfigurer> /** * Specifies a user requires an authority. - * * @param authority the authority that should be required * @return the {@link UrlAuthorizationConfigurer} for further customization */ @@ -353,12 +333,14 @@ public final class UrlAuthorizationConfigurer> * @return the {@link UrlAuthorizationConfigurer} for further customization */ public StandardInterceptUrlRegistry access(String... attributes) { - addMapping(requestMatchers, SecurityConfig.createList(attributes)); - return UrlAuthorizationConfigurer.this.REGISTRY; + addMapping(this.requestMatchers, SecurityConfig.createList(attributes)); + return UrlAuthorizationConfigurer.this.registry; } protected List getMatchers() { return this.requestMatchers; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/X509Configurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/X509Configurer.java index 7b35c367ad..93e1b09250 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/X509Configurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/X509Configurer.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; +import javax.servlet.http.HttpServletRequest; + import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; @@ -32,8 +35,6 @@ import org.springframework.security.web.authentication.preauth.x509.SubjectDnX50 import org.springframework.security.web.authentication.preauth.x509.X509AuthenticationFilter; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; -import javax.servlet.http.HttpServletRequest; - /** * Adds X509 based pre authentication to an application. Since validating the certificate * happens when the client connects, the requesting and validation of the client @@ -53,8 +54,7 @@ import javax.servlet.http.HttpServletRequest; * The following shared objects are created * *
        - *
      • - * {@link AuthenticationEntryPoint} is populated with an + *
      • {@link AuthenticationEntryPoint} is populated with an * {@link Http403ForbiddenEntryPoint}
      • *
      • A {@link PreAuthenticatedAuthenticationProvider} is populated into * {@link HttpSecurity#authenticationProvider(org.springframework.security.authentication.AuthenticationProvider)} @@ -73,11 +73,15 @@ import javax.servlet.http.HttpServletRequest; * @author Rob Winch * @since 3.2 */ -public final class X509Configurer> extends - AbstractHttpConfigurer, H> { +public final class X509Configurer> + extends AbstractHttpConfigurer, H> { + private X509AuthenticationFilter x509AuthenticationFilter; + private X509PrincipalExtractor x509PrincipalExtractor; + private AuthenticationUserDetailsService authenticationUserDetailsService; + private AuthenticationDetailsSource authenticationDetailsSource; /** @@ -92,19 +96,16 @@ public final class X509Configurer> extends * Allows specifying the entire {@link X509AuthenticationFilter}. If this is * specified, the properties on {@link X509Configurer} will not be populated on the * {@link X509AuthenticationFilter}. - * * @param x509AuthenticationFilter the {@link X509AuthenticationFilter} to use * @return the {@link X509Configurer} for further customizations */ - public X509Configurer x509AuthenticationFilter( - X509AuthenticationFilter x509AuthenticationFilter) { + public X509Configurer x509AuthenticationFilter(X509AuthenticationFilter x509AuthenticationFilter) { this.x509AuthenticationFilter = x509AuthenticationFilter; return this; } /** * Specifies the {@link X509PrincipalExtractor} - * * @param x509PrincipalExtractor the {@link X509PrincipalExtractor} to use * @return the {@link X509Configurer} to use */ @@ -115,7 +116,6 @@ public final class X509Configurer> extends /** * Specifies the {@link AuthenticationDetailsSource} - * * @param authenticationDetailsSource the {@link AuthenticationDetailsSource} to use * @return the {@link X509Configurer} to use */ @@ -129,7 +129,6 @@ public final class X509Configurer> extends * Shortcut for invoking * {@link #authenticationUserDetailsService(AuthenticationUserDetailsService)} with a * {@link UserDetailsByNameServiceWrapper}. - * * @param userDetailsService the {@link UserDetailsService} to use * @return the {@link X509Configurer} for further customizations */ @@ -143,8 +142,8 @@ public final class X509Configurer> extends * Specifies the {@link AuthenticationUserDetailsService} to use. If not specified, * the shared {@link UserDetailsService} will be used to create a * {@link UserDetailsByNameServiceWrapper}. - * - * @param authenticationUserDetailsService the {@link AuthenticationUserDetailsService} to use + * @param authenticationUserDetailsService the + * {@link AuthenticationUserDetailsService} to use * @return the {@link X509Configurer} for further customizations */ public X509Configurer authenticationUserDetailsService( @@ -157,9 +156,8 @@ public final class X509Configurer> extends * Specifies the regex to extract the principal from the certificate. If not * specified, the default expression from {@link SubjectDnX509PrincipalExtractor} is * used. - * * @param subjectPrincipalRegex the regex to extract the user principal from the - * certificate (i.e. "CN=(.*?)(?:,|$)"). + * certificate (i.e. "CN=(.*?)(?:,|$)"). * @return the {@link X509Configurer} for further customizations */ public X509Configurer subjectPrincipalRegex(String subjectPrincipalRegex) { @@ -169,48 +167,42 @@ public final class X509Configurer> extends return this; } - // @formatter:off @Override public void init(H http) { PreAuthenticatedAuthenticationProvider authenticationProvider = new PreAuthenticatedAuthenticationProvider(); authenticationProvider.setPreAuthenticatedUserDetailsService(getAuthenticationUserDetailsService(http)); - - http - .authenticationProvider(authenticationProvider) - .setSharedObject(AuthenticationEntryPoint.class, new Http403ForbiddenEntryPoint()); + http.authenticationProvider(authenticationProvider).setSharedObject(AuthenticationEntryPoint.class, + new Http403ForbiddenEntryPoint()); } - // @formatter:on @Override public void configure(H http) { - X509AuthenticationFilter filter = getFilter(http - .getSharedObject(AuthenticationManager.class)); + X509AuthenticationFilter filter = getFilter(http.getSharedObject(AuthenticationManager.class)); http.addFilter(filter); } private X509AuthenticationFilter getFilter(AuthenticationManager authenticationManager) { - if (x509AuthenticationFilter == null) { - x509AuthenticationFilter = new X509AuthenticationFilter(); - x509AuthenticationFilter.setAuthenticationManager(authenticationManager); - if (x509PrincipalExtractor != null) { - x509AuthenticationFilter.setPrincipalExtractor(x509PrincipalExtractor); + if (this.x509AuthenticationFilter == null) { + this.x509AuthenticationFilter = new X509AuthenticationFilter(); + this.x509AuthenticationFilter.setAuthenticationManager(authenticationManager); + if (this.x509PrincipalExtractor != null) { + this.x509AuthenticationFilter.setPrincipalExtractor(this.x509PrincipalExtractor); } - if (authenticationDetailsSource != null) { - x509AuthenticationFilter - .setAuthenticationDetailsSource(authenticationDetailsSource); + if (this.authenticationDetailsSource != null) { + this.x509AuthenticationFilter.setAuthenticationDetailsSource(this.authenticationDetailsSource); } - x509AuthenticationFilter = postProcess(x509AuthenticationFilter); + this.x509AuthenticationFilter = postProcess(this.x509AuthenticationFilter); } - return x509AuthenticationFilter; + return this.x509AuthenticationFilter; } private AuthenticationUserDetailsService getAuthenticationUserDetailsService( H http) { - if (authenticationUserDetailsService == null) { + if (this.authenticationUserDetailsService == null) { userDetailsService(http.getSharedObject(UserDetailsService.class)); } - return authenticationUserDetailsService; + return this.authenticationUserDetailsService; } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java index ccfff084db..51d151c9ed 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/ImplicitGrantConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.oauth2.client; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; @@ -48,26 +49,26 @@ import org.springframework.util.Assert; *
      • {@link ClientRegistrationRepository}
      • *
      * - * @deprecated It is not recommended to use the implicit flow - * due to the inherent risks of returning access tokens in an HTTP redirect - * without any confirmation that it has been received by the client. - * See reference OAuth 2.0 Implicit Grant. - * + * @deprecated It is not recommended to use the implicit flow due to the inherent risks of + * returning access tokens in an HTTP redirect without any confirmation that it has been + * received by the client. See reference + * OAuth 2.0 Implicit + * Grant. * @author Joe Grandja * @since 5.0 * @see OAuth2AuthorizationRequestRedirectFilter * @see ClientRegistrationRepository */ @Deprecated -public final class ImplicitGrantConfigurer> extends - AbstractHttpConfigurer, B> { +public final class ImplicitGrantConfigurer> + extends AbstractHttpConfigurer, B> { private String authorizationRequestBaseUri; /** * Sets the base {@code URI} used for authorization requests. - * - * @param authorizationRequestBaseUri the base {@code URI} used for authorization requests + * @param authorizationRequestBaseUri the base {@code URI} used for authorization + * requests * @return the {@link ImplicitGrantConfigurer} for further configuration */ public ImplicitGrantConfigurer authorizationRequestBaseUri(String authorizationRequestBaseUri) { @@ -78,11 +79,11 @@ public final class ImplicitGrantConfigurer> ext /** * Sets the repository of client registrations. - * * @param clientRegistrationRepository the repository of client registrations * @return the {@link ImplicitGrantConfigurer} for further configuration */ - public ImplicitGrantConfigurer clientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) { + public ImplicitGrantConfigurer clientRegistrationRepository( + ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); return this; @@ -91,13 +92,14 @@ public final class ImplicitGrantConfigurer> ext @Override public void configure(B http) { OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), this.getAuthorizationRequestBaseUri()); + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), + this.getAuthorizationRequestBaseUri()); http.addFilter(this.postProcess(authorizationRequestFilter)); } private String getAuthorizationRequestBaseUri() { - return this.authorizationRequestBaseUri != null ? - this.authorizationRequestBaseUri : - OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; + return (this.authorizationRequestBaseUri != null) ? this.authorizationRequestBaseUri + : OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index af2f56e0cd..a8447e7d14 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.oauth2.client; import org.springframework.security.authentication.AuthenticationManager; @@ -43,13 +44,15 @@ import org.springframework.util.Assert; * The following configuration options are available: * *
        - *
      • {@link #authorizationCodeGrant()} - support for the OAuth 2.0 Authorization Code Grant
      • + *
      • {@link #authorizationCodeGrant()} - support for the OAuth 2.0 Authorization Code + * Grant
      • *
      * *

      - * Defaults are provided for all configuration options with the only required configuration - * being {@link #clientRegistrationRepository(ClientRegistrationRepository)}. - * Alternatively, a {@link ClientRegistrationRepository} {@code @Bean} may be registered instead. + * Defaults are provided for all configuration options with the only required + * configuration being + * {@link #clientRegistrationRepository(ClientRegistrationRepository)}. Alternatively, a + * {@link ClientRegistrationRepository} {@code @Bean} may be registered instead. * *

      Security Filters

      * @@ -87,18 +90,18 @@ import org.springframework.util.Assert; * @see OAuth2AuthorizedClientRepository * @see AbstractHttpConfigurer */ -public final class OAuth2ClientConfigurer> extends - AbstractHttpConfigurer, B> { +public final class OAuth2ClientConfigurer> + extends AbstractHttpConfigurer, B> { private AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer = new AuthorizationCodeGrantConfigurer(); /** * Sets the repository of client registrations. - * * @param clientRegistrationRepository the repository of client registrations * @return the {@link OAuth2ClientConfigurer} for further configuration */ - public OAuth2ClientConfigurer clientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) { + public OAuth2ClientConfigurer clientRegistrationRepository( + ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); return this; @@ -106,11 +109,11 @@ public final class OAuth2ClientConfigurer> exte /** * Sets the repository for authorized client(s). - * * @param authorizedClientRepository the authorized client repository * @return the {@link OAuth2ClientConfigurer} for further configuration */ - public OAuth2ClientConfigurer authorizedClientRepository(OAuth2AuthorizedClientRepository authorizedClientRepository) { + public OAuth2ClientConfigurer authorizedClientRepository( + OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository); return this; @@ -118,19 +121,19 @@ public final class OAuth2ClientConfigurer> exte /** * Sets the service for authorized client(s). - * * @param authorizedClientService the authorized client service * @return the {@link OAuth2ClientConfigurer} for further configuration */ public OAuth2ClientConfigurer authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) { Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); - this.authorizedClientRepository(new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService)); + this.authorizedClientRepository( + new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService)); return this; } /** - * Returns the {@link AuthorizationCodeGrantConfigurer} for configuring the OAuth 2.0 Authorization Code Grant. - * + * Returns the {@link AuthorizationCodeGrantConfigurer} for configuring the OAuth 2.0 + * Authorization Code Grant. * @return the {@link AuthorizationCodeGrantConfigurer} */ public AuthorizationCodeGrantConfigurer authorizationCodeGrant() { @@ -139,22 +142,35 @@ public final class OAuth2ClientConfigurer> exte /** * Configures the OAuth 2.0 Authorization Code Grant. - * - * @param authorizationCodeGrantCustomizer the {@link Customizer} to provide more options for - * the {@link AuthorizationCodeGrantConfigurer} + * @param authorizationCodeGrantCustomizer the {@link Customizer} to provide more + * options for the {@link AuthorizationCodeGrantConfigurer} * @return the {@link OAuth2ClientConfigurer} for further customizations */ - public OAuth2ClientConfigurer authorizationCodeGrant(Customizer authorizationCodeGrantCustomizer) { + public OAuth2ClientConfigurer authorizationCodeGrant( + Customizer authorizationCodeGrantCustomizer) { authorizationCodeGrantCustomizer.customize(this.authorizationCodeGrantConfigurer); return this; } + @Override + public void init(B builder) { + this.authorizationCodeGrantConfigurer.init(builder); + } + + @Override + public void configure(B builder) { + this.authorizationCodeGrantConfigurer.configure(builder); + } + /** * Configuration options for the OAuth 2.0 Authorization Code Grant. */ - public class AuthorizationCodeGrantConfigurer { + public final class AuthorizationCodeGrantConfigurer { + private OAuth2AuthorizationRequestResolver authorizationRequestResolver; + private AuthorizationRequestRepository authorizationRequestRepository; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; private AuthorizationCodeGrantConfigurer() { @@ -162,11 +178,12 @@ public final class OAuth2ClientConfigurer> exte /** * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s + * @param authorizationRequestResolver the resolver used for resolving + * {@link OAuth2AuthorizationRequest}'s * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration */ - public AuthorizationCodeGrantConfigurer authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) { + public AuthorizationCodeGrantConfigurer authorizationRequestResolver( + OAuth2AuthorizationRequestResolver authorizationRequestResolver) { Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null"); this.authorizationRequestResolver = authorizationRequestResolver; return this; @@ -174,27 +191,26 @@ public final class OAuth2ClientConfigurer> exte /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s + * @param authorizationRequestRepository the repository used for storing + * {@link OAuth2AuthorizationRequest}'s * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration */ public AuthorizationCodeGrantConfigurer authorizationRequestRepository( AuthorizationRequestRepository authorizationRequestRepository) { - Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); this.authorizationRequestRepository = authorizationRequestRepository; return this; } /** - * Sets the client used for requesting the access token credential from the Token Endpoint. - * - * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint + * Sets the client used for requesting the access token credential from the Token + * Endpoint. + * @param accessTokenResponseClient the client used for requesting the access + * token credential from the Token Endpoint * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration */ public AuthorizationCodeGrantConfigurer accessTokenResponseClient( OAuth2AccessTokenResponseClient accessTokenResponseClient) { - Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; return this; @@ -202,7 +218,6 @@ public final class OAuth2ClientConfigurer> exte /** * Returns the {@link OAuth2ClientConfigurer} for further configuration. - * * @return the {@link OAuth2ClientConfigurer} */ public OAuth2ClientConfigurer and() { @@ -210,25 +225,27 @@ public final class OAuth2ClientConfigurer> exte } private void init(B builder) { - OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = - new OAuth2AuthorizationCodeAuthenticationProvider(getAccessTokenResponseClient()); + OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( + getAccessTokenResponseClient()); builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider)); } private void configure(B builder) { - OAuth2AuthorizationRequestRedirectFilter authorizationRequestRedirectFilter = createAuthorizationRequestRedirectFilter(builder); + OAuth2AuthorizationRequestRedirectFilter authorizationRequestRedirectFilter = createAuthorizationRequestRedirectFilter( + builder); builder.addFilter(postProcess(authorizationRequestRedirectFilter)); - OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = createAuthorizationCodeGrantFilter(builder); + OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = createAuthorizationCodeGrantFilter( + builder); builder.addFilter(postProcess(authorizationCodeGrantFilter)); } private OAuth2AuthorizationRequestRedirectFilter createAuthorizationRequestRedirectFilter(B builder) { OAuth2AuthorizationRequestResolver resolver = getAuthorizationRequestResolver(); - OAuth2AuthorizationRequestRedirectFilter authorizationRequestRedirectFilter = - new OAuth2AuthorizationRequestRedirectFilter(resolver); - + OAuth2AuthorizationRequestRedirectFilter authorizationRequestRedirectFilter = new OAuth2AuthorizationRequestRedirectFilter( + resolver); if (this.authorizationRequestRepository != null) { - authorizationRequestRedirectFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository); + authorizationRequestRedirectFilter + .setAuthorizationRequestRepository(this.authorizationRequestRepository); } RequestCache requestCache = builder.getSharedObject(RequestCache.class); if (requestCache != null) { @@ -251,9 +268,7 @@ public final class OAuth2ClientConfigurer> exte AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class); OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter( OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), - OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder), - authenticationManager); - + OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder), authenticationManager); if (this.authorizationRequestRepository != null) { authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } @@ -270,15 +285,7 @@ public final class OAuth2ClientConfigurer> exte } return new DefaultAuthorizationCodeTokenResponseClient(); } + } - @Override - public void init(B builder) { - this.authorizationCodeGrantConfigurer.init(builder); - } - - @Override - public void configure(B builder) { - this.authorizationCodeGrantConfigurer.configure(builder); - } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java index 046c607739..0f1dc7ab8f 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.oauth2.client; +import java.util.Map; + import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.context.ApplicationContext; @@ -27,8 +30,6 @@ import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAut import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.StringUtils; -import java.util.Map; - /** * Utility methods for the OAuth 2.0 Client {@link AbstractHttpConfigurer}'s. * @@ -41,7 +42,8 @@ final class OAuth2ClientConfigurerUtils { } static > ClientRegistrationRepository getClientRegistrationRepository(B builder) { - ClientRegistrationRepository clientRegistrationRepository = builder.getSharedObject(ClientRegistrationRepository.class); + ClientRegistrationRepository clientRegistrationRepository = builder + .getSharedObject(ClientRegistrationRepository.class); if (clientRegistrationRepository == null) { clientRegistrationRepository = getClientRegistrationRepositoryBean(builder); builder.setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); @@ -49,12 +51,15 @@ final class OAuth2ClientConfigurerUtils { return clientRegistrationRepository; } - private static > ClientRegistrationRepository getClientRegistrationRepositoryBean(B builder) { + private static > ClientRegistrationRepository getClientRegistrationRepositoryBean( + B builder) { return builder.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class); } - static > OAuth2AuthorizedClientRepository getAuthorizedClientRepository(B builder) { - OAuth2AuthorizedClientRepository authorizedClientRepository = builder.getSharedObject(OAuth2AuthorizedClientRepository.class); + static > OAuth2AuthorizedClientRepository getAuthorizedClientRepository( + B builder) { + OAuth2AuthorizedClientRepository authorizedClientRepository = builder + .getSharedObject(OAuth2AuthorizedClientRepository.class); if (authorizedClientRepository == null) { authorizedClientRepository = getAuthorizedClientRepositoryBean(builder); if (authorizedClientRepository == null) { @@ -66,34 +71,45 @@ final class OAuth2ClientConfigurerUtils { return authorizedClientRepository; } - private static > OAuth2AuthorizedClientRepository getAuthorizedClientRepositoryBean(B builder) { - Map authorizedClientRepositoryMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( - builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizedClientRepository.class); + private static > OAuth2AuthorizedClientRepository getAuthorizedClientRepositoryBean( + B builder) { + Map authorizedClientRepositoryMap = BeanFactoryUtils + .beansOfTypeIncludingAncestors(builder.getSharedObject(ApplicationContext.class), + OAuth2AuthorizedClientRepository.class); if (authorizedClientRepositoryMap.size() > 1) { - throw new NoUniqueBeanDefinitionException(OAuth2AuthorizedClientRepository.class, authorizedClientRepositoryMap.size(), - "Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() + "' but found " + - authorizedClientRepositoryMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizedClientRepositoryMap.keySet())); + throw new NoUniqueBeanDefinitionException(OAuth2AuthorizedClientRepository.class, + authorizedClientRepositoryMap.size(), + "Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() + + "' but found " + authorizedClientRepositoryMap.size() + ": " + + StringUtils.collectionToCommaDelimitedString(authorizedClientRepositoryMap.keySet())); } - return (!authorizedClientRepositoryMap.isEmpty() ? authorizedClientRepositoryMap.values().iterator().next() : null); + return (!authorizedClientRepositoryMap.isEmpty() ? authorizedClientRepositoryMap.values().iterator().next() + : null); } - - private static > OAuth2AuthorizedClientService getAuthorizedClientService(B builder) { + private static > OAuth2AuthorizedClientService getAuthorizedClientService( + B builder) { OAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientServiceBean(builder); if (authorizedClientService == null) { - authorizedClientService = new InMemoryOAuth2AuthorizedClientService(getClientRegistrationRepository(builder)); + authorizedClientService = new InMemoryOAuth2AuthorizedClientService( + getClientRegistrationRepository(builder)); } return authorizedClientService; } - private static > OAuth2AuthorizedClientService getAuthorizedClientServiceBean(B builder) { - Map authorizedClientServiceMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( - builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizedClientService.class); + private static > OAuth2AuthorizedClientService getAuthorizedClientServiceBean( + B builder) { + Map authorizedClientServiceMap = BeanFactoryUtils + .beansOfTypeIncludingAncestors(builder.getSharedObject(ApplicationContext.class), + OAuth2AuthorizedClientService.class); if (authorizedClientServiceMap.size() > 1) { - throw new NoUniqueBeanDefinitionException(OAuth2AuthorizedClientService.class, authorizedClientServiceMap.size(), - "Expected single matching bean of type '" + OAuth2AuthorizedClientService.class.getName() + "' but found " + - authorizedClientServiceMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizedClientServiceMap.keySet())); + throw new NoUniqueBeanDefinitionException(OAuth2AuthorizedClientService.class, + authorizedClientServiceMap.size(), + "Expected single matching bean of type '" + OAuth2AuthorizedClientService.class.getName() + + "' but found " + authorizedClientServiceMap.size() + ": " + + StringUtils.collectionToCommaDelimitedString(authorizedClientServiceMap.keySet())); } return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 91c8f7f128..a7e0cc46cb 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.oauth2.client; import java.util.ArrayList; @@ -79,17 +80,18 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; /** - * An {@link AbstractHttpConfigurer} for OAuth 2.0 Login, - * which leverages the OAuth 2.0 Authorization Code Grant Flow. + * An {@link AbstractHttpConfigurer} for OAuth 2.0 Login, which leverages the OAuth 2.0 + * Authorization Code Grant Flow. * *

      - * OAuth 2.0 Login provides an application with the capability to have users log in - * by using their existing account at an OAuth 2.0 or OpenID Connect 1.0 Provider. + * OAuth 2.0 Login provides an application with the capability to have users log in by + * using their existing account at an OAuth 2.0 or OpenID Connect 1.0 Provider. * *

      - * Defaults are provided for all configuration options with the only required configuration - * being {@link #clientRegistrationRepository(ClientRegistrationRepository)}. - * Alternatively, a {@link ClientRegistrationRepository} {@code @Bean} may be registered instead. + * Defaults are provided for all configuration options with the only required + * configuration being + * {@link #clientRegistrationRepository(ClientRegistrationRepository)}. Alternatively, a + * {@link ClientRegistrationRepository} {@code @Bean} may be registered instead. * *

      Security Filters

      * @@ -118,8 +120,9 @@ import org.springframework.util.ClassUtils; *
    • {@link ClientRegistrationRepository}
    • *
    • {@link OAuth2AuthorizedClientRepository}
    • *
    • {@link GrantedAuthoritiesMapper}
    • - *
    • {@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not configured - * and {@code DefaultLoginPageGeneratingFilter} is available, then a default login page will be made available
    • + *
    • {@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not + * configured and {@code DefaultLoginPageGeneratingFilter} is available, then a default + * login page will be made available
    • *
    * * @author Joe Grandja @@ -132,23 +135,28 @@ import org.springframework.util.ClassUtils; * @see OAuth2AuthorizedClientRepository * @see AbstractAuthenticationFilterConfigurer */ -public final class OAuth2LoginConfigurer> extends - AbstractAuthenticationFilterConfigurer, OAuth2LoginAuthenticationFilter> { +public final class OAuth2LoginConfigurer> + extends AbstractAuthenticationFilterConfigurer, OAuth2LoginAuthenticationFilter> { private final AuthorizationEndpointConfig authorizationEndpointConfig = new AuthorizationEndpointConfig(); + private final TokenEndpointConfig tokenEndpointConfig = new TokenEndpointConfig(); + private final RedirectionEndpointConfig redirectionEndpointConfig = new RedirectionEndpointConfig(); + private final UserInfoEndpointConfig userInfoEndpointConfig = new UserInfoEndpointConfig(); + private String loginPage; + private String loginProcessingUrl = OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; /** * Sets the repository of client registrations. - * * @param clientRegistrationRepository the repository of client registrations * @return the {@link OAuth2LoginConfigurer} for further configuration */ - public OAuth2LoginConfigurer clientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) { + public OAuth2LoginConfigurer clientRegistrationRepository( + ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); return this; @@ -156,12 +164,12 @@ public final class OAuth2LoginConfigurer> exten /** * Sets the repository for authorized client(s). - * - * @since 5.1 * @param authorizedClientRepository the authorized client repository * @return the {@link OAuth2LoginConfigurer} for further configuration + * @since 5.1 */ - public OAuth2LoginConfigurer authorizedClientRepository(OAuth2AuthorizedClientRepository authorizedClientRepository) { + public OAuth2LoginConfigurer authorizedClientRepository( + OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository); return this; @@ -169,13 +177,13 @@ public final class OAuth2LoginConfigurer> exten /** * Sets the service for authorized client(s). - * * @param authorizedClientService the authorized client service * @return the {@link OAuth2LoginConfigurer} for further configuration */ public OAuth2LoginConfigurer authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) { Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); - this.authorizedClientRepository(new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService)); + this.authorizedClientRepository( + new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService)); return this; } @@ -194,8 +202,8 @@ public final class OAuth2LoginConfigurer> exten } /** - * Returns the {@link AuthorizationEndpointConfig} for configuring the Authorization Server's Authorization Endpoint. - * + * Returns the {@link AuthorizationEndpointConfig} for configuring the Authorization + * Server's Authorization Endpoint. * @return the {@link AuthorizationEndpointConfig} */ public AuthorizationEndpointConfig authorizationEndpoint() { @@ -204,77 +212,19 @@ public final class OAuth2LoginConfigurer> exten /** * Configures the Authorization Server's Authorization Endpoint. - * - * @param authorizationEndpointCustomizer the {@link Customizer} to provide more options for - * the {@link AuthorizationEndpointConfig} + * @param authorizationEndpointCustomizer the {@link Customizer} to provide more + * options for the {@link AuthorizationEndpointConfig} * @return the {@link OAuth2LoginConfigurer} for further customizations */ - public OAuth2LoginConfigurer authorizationEndpoint(Customizer authorizationEndpointCustomizer) { + public OAuth2LoginConfigurer authorizationEndpoint( + Customizer authorizationEndpointCustomizer) { authorizationEndpointCustomizer.customize(this.authorizationEndpointConfig); return this; } /** - * Configuration options for the Authorization Server's Authorization Endpoint. - */ - public class AuthorizationEndpointConfig { - private String authorizationRequestBaseUri; - private OAuth2AuthorizationRequestResolver authorizationRequestResolver; - private AuthorizationRequestRepository authorizationRequestRepository; - - private AuthorizationEndpointConfig() { - } - - /** - * Sets the base {@code URI} used for authorization requests. - * - * @param authorizationRequestBaseUri the base {@code URI} used for authorization requests - * @return the {@link AuthorizationEndpointConfig} for further configuration - */ - public AuthorizationEndpointConfig baseUri(String authorizationRequestBaseUri) { - Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); - this.authorizationRequestBaseUri = authorizationRequestBaseUri; - return this; - } - - /** - * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s. - * - * @since 5.1 - * @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s - * @return the {@link AuthorizationEndpointConfig} for further configuration - */ - public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) { - Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null"); - this.authorizationRequestResolver = authorizationRequestResolver; - return this; - } - - /** - * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s - * @return the {@link AuthorizationEndpointConfig} for further configuration - */ - public AuthorizationEndpointConfig authorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) { - Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); - this.authorizationRequestRepository = authorizationRequestRepository; - return this; - } - - /** - * Returns the {@link OAuth2LoginConfigurer} for further configuration. - * - * @return the {@link OAuth2LoginConfigurer} - */ - public OAuth2LoginConfigurer and() { - return OAuth2LoginConfigurer.this; - } - } - - /** - * Returns the {@link TokenEndpointConfig} for configuring the Authorization Server's Token Endpoint. - * + * Returns the {@link TokenEndpointConfig} for configuring the Authorization Server's + * Token Endpoint. * @return the {@link TokenEndpointConfig} */ public TokenEndpointConfig tokenEndpoint() { @@ -283,7 +233,6 @@ public final class OAuth2LoginConfigurer> exten /** * Configures the Authorization Server's Token Endpoint. - * * @param tokenEndpointCustomizer the {@link Customizer} to provide more options for * the {@link TokenEndpointConfig} * @return the {@link OAuth2LoginConfigurer} for further customizations @@ -295,41 +244,8 @@ public final class OAuth2LoginConfigurer> exten } /** - * Configuration options for the Authorization Server's Token Endpoint. - */ - public class TokenEndpointConfig { - private OAuth2AccessTokenResponseClient accessTokenResponseClient; - - private TokenEndpointConfig() { - } - - /** - * Sets the client used for requesting the access token credential from the Token Endpoint. - * - * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint - * @return the {@link TokenEndpointConfig} for further configuration - */ - public TokenEndpointConfig accessTokenResponseClient( - OAuth2AccessTokenResponseClient accessTokenResponseClient) { - - Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); - this.accessTokenResponseClient = accessTokenResponseClient; - return this; - } - - /** - * Returns the {@link OAuth2LoginConfigurer} for further configuration. - * - * @return the {@link OAuth2LoginConfigurer} - */ - public OAuth2LoginConfigurer and() { - return OAuth2LoginConfigurer.this; - } - } - - /** - * Returns the {@link RedirectionEndpointConfig} for configuring the Client's Redirection Endpoint. - * + * Returns the {@link RedirectionEndpointConfig} for configuring the Client's + * Redirection Endpoint. * @return the {@link RedirectionEndpointConfig} */ public RedirectionEndpointConfig redirectionEndpoint() { @@ -338,50 +254,19 @@ public final class OAuth2LoginConfigurer> exten /** * Configures the Client's Redirection Endpoint. - * - * @param redirectionEndpointCustomizer the {@link Customizer} to provide more options for - * the {@link RedirectionEndpointConfig} + * @param redirectionEndpointCustomizer the {@link Customizer} to provide more options + * for the {@link RedirectionEndpointConfig} * @return the {@link OAuth2LoginConfigurer} for further customizations */ - public OAuth2LoginConfigurer redirectionEndpoint(Customizer redirectionEndpointCustomizer) { + public OAuth2LoginConfigurer redirectionEndpoint( + Customizer redirectionEndpointCustomizer) { redirectionEndpointCustomizer.customize(this.redirectionEndpointConfig); return this; } /** - * Configuration options for the Client's Redirection Endpoint. - */ - public class RedirectionEndpointConfig { - private String authorizationResponseBaseUri; - - private RedirectionEndpointConfig() { - } - - /** - * Sets the {@code URI} where the authorization response will be processed. - * - * @param authorizationResponseBaseUri the {@code URI} where the authorization response will be processed - * @return the {@link RedirectionEndpointConfig} for further configuration - */ - public RedirectionEndpointConfig baseUri(String authorizationResponseBaseUri) { - Assert.hasText(authorizationResponseBaseUri, "authorizationResponseBaseUri cannot be empty"); - this.authorizationResponseBaseUri = authorizationResponseBaseUri; - return this; - } - - /** - * Returns the {@link OAuth2LoginConfigurer} for further configuration. - * - * @return the {@link OAuth2LoginConfigurer} - */ - public OAuth2LoginConfigurer and() { - return OAuth2LoginConfigurer.this; - } - } - - /** - * Returns the {@link UserInfoEndpointConfig} for configuring the Authorization Server's UserInfo Endpoint. - * + * Returns the {@link UserInfoEndpointConfig} for configuring the Authorization + * Server's UserInfo Endpoint. * @return the {@link UserInfoEndpointConfig} */ public UserInfoEndpointConfig userInfoEndpoint() { @@ -390,9 +275,8 @@ public final class OAuth2LoginConfigurer> exten /** * Configures the Authorization Server's UserInfo Endpoint. - * - * @param userInfoEndpointCustomizer the {@link Customizer} to provide more options for - * the {@link UserInfoEndpointConfig} + * @param userInfoEndpointCustomizer the {@link Customizer} to provide more options + * for the {@link UserInfoEndpointConfig} * @return the {@link OAuth2LoginConfigurer} for further customizations */ public OAuth2LoginConfigurer userInfoEndpoint(Customizer userInfoEndpointCustomizer) { @@ -400,96 +284,19 @@ public final class OAuth2LoginConfigurer> exten return this; } - /** - * Configuration options for the Authorization Server's UserInfo Endpoint. - */ - public class UserInfoEndpointConfig { - private OAuth2UserService userService; - private OAuth2UserService oidcUserService; - private Map> customUserTypes = new HashMap<>(); - - private UserInfoEndpointConfig() { - } - - /** - * Sets the OAuth 2.0 service used for obtaining the user attributes of the End-User from the UserInfo Endpoint. - * - * @param userService the OAuth 2.0 service used for obtaining the user attributes of the End-User from the UserInfo Endpoint - * @return the {@link UserInfoEndpointConfig} for further configuration - */ - public UserInfoEndpointConfig userService(OAuth2UserService userService) { - Assert.notNull(userService, "userService cannot be null"); - this.userService = userService; - return this; - } - - /** - * Sets the OpenID Connect 1.0 service used for obtaining the user attributes of the End-User from the UserInfo Endpoint. - * - * @param oidcUserService the OpenID Connect 1.0 service used for obtaining the user attributes of the End-User from the UserInfo Endpoint - * @return the {@link UserInfoEndpointConfig} for further configuration - */ - public UserInfoEndpointConfig oidcUserService(OAuth2UserService oidcUserService) { - Assert.notNull(oidcUserService, "oidcUserService cannot be null"); - this.oidcUserService = oidcUserService; - return this; - } - - /** - * Sets a custom {@link OAuth2User} type and associates it to the provided - * client {@link ClientRegistration#getRegistrationId() registration identifier}. - * - * @deprecated See {@link CustomUserTypesOAuth2UserService} for alternative usage. - * - * @param customUserType a custom {@link OAuth2User} type - * @param clientRegistrationId the client registration identifier - * @return the {@link UserInfoEndpointConfig} for further configuration - */ - @Deprecated - public UserInfoEndpointConfig customUserType(Class customUserType, String clientRegistrationId) { - Assert.notNull(customUserType, "customUserType cannot be null"); - Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); - this.customUserTypes.put(clientRegistrationId, customUserType); - return this; - } - - /** - * Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OAuth2User#getAuthorities()}. - * - * @param userAuthoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities - * @return the {@link UserInfoEndpointConfig} for further configuration - */ - public UserInfoEndpointConfig userAuthoritiesMapper(GrantedAuthoritiesMapper userAuthoritiesMapper) { - Assert.notNull(userAuthoritiesMapper, "userAuthoritiesMapper cannot be null"); - OAuth2LoginConfigurer.this.getBuilder().setSharedObject(GrantedAuthoritiesMapper.class, userAuthoritiesMapper); - return this; - } - - /** - * Returns the {@link OAuth2LoginConfigurer} for further configuration. - * - * @return the {@link OAuth2LoginConfigurer} - */ - public OAuth2LoginConfigurer and() { - return OAuth2LoginConfigurer.this; - } - } - @Override public void init(B http) throws Exception { - OAuth2LoginAuthenticationFilter authenticationFilter = - new OAuth2LoginAuthenticationFilter( + OAuth2LoginAuthenticationFilter authenticationFilter = new OAuth2LoginAuthenticationFilter( OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), - OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()), - this.loginProcessingUrl); + OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()), this.loginProcessingUrl); this.setAuthenticationFilter(authenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); - if (this.loginPage != null) { // Set custom login page super.loginPage(this.loginPage); super.init(http); - } else { + } + else { Map loginUrlToClientName = this.getLoginLinks(); if (loginUrlToClientName.size() == 1) { // Setup auto-redirect to provider login page @@ -498,33 +305,29 @@ public final class OAuth2LoginConfigurer> exten this.updateAccessDefaults(http); String providerLoginPage = loginUrlToClientName.keySet().iterator().next(); this.registerAuthenticationEntryPoint(http, this.getLoginEntryPoint(http, providerLoginPage)); - } else { + } + else { super.init(http); } } - - OAuth2AccessTokenResponseClient accessTokenResponseClient = - this.tokenEndpointConfig.accessTokenResponseClient; + OAuth2AccessTokenResponseClient accessTokenResponseClient = this.tokenEndpointConfig.accessTokenResponseClient; if (accessTokenResponseClient == null) { accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); } - OAuth2UserService oauth2UserService = getOAuth2UserService(); - OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider = - new OAuth2LoginAuthenticationProvider(accessTokenResponseClient, oauth2UserService); + OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider = new OAuth2LoginAuthenticationProvider( + accessTokenResponseClient, oauth2UserService); GrantedAuthoritiesMapper userAuthoritiesMapper = this.getGrantedAuthoritiesMapper(); if (userAuthoritiesMapper != null) { oauth2LoginAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } http.authenticationProvider(this.postProcess(oauth2LoginAuthenticationProvider)); - - boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent( - "org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); - + boolean oidcAuthenticationProviderEnabled = ClassUtils + .isPresent("org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); if (oidcAuthenticationProviderEnabled) { OAuth2UserService oidcUserService = getOidcUserService(); - OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = - new OidcAuthorizationCodeAuthenticationProvider(accessTokenResponseClient, oidcUserService); + OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider = new OidcAuthorizationCodeAuthenticationProvider( + accessTokenResponseClient, oidcUserService); JwtDecoderFactory jwtDecoderFactory = this.getJwtDecoderFactoryBean(); if (jwtDecoderFactory != null) { oidcAuthorizationCodeAuthenticationProvider.setJwtDecoderFactory(jwtDecoderFactory); @@ -533,46 +336,45 @@ public final class OAuth2LoginConfigurer> exten oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper); } http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider)); - } else { + } + else { http.authenticationProvider(new OidcAuthenticationRequestChecker()); } - this.initDefaultLoginFilter(http); } @Override public void configure(B http) throws Exception { OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter; - if (this.authorizationEndpointConfig.authorizationRequestResolver != null) { authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( this.authorizationEndpointConfig.authorizationRequestResolver); - } else { + } + else { String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri; if (authorizationRequestBaseUri == null) { authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; } authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri); + OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), + authorizationRequestBaseUri); } - if (this.authorizationEndpointConfig.authorizationRequestRepository != null) { - authorizationRequestFilter.setAuthorizationRequestRepository( - this.authorizationEndpointConfig.authorizationRequestRepository); + authorizationRequestFilter + .setAuthorizationRequestRepository(this.authorizationEndpointConfig.authorizationRequestRepository); } RequestCache requestCache = http.getSharedObject(RequestCache.class); if (requestCache != null) { authorizationRequestFilter.setRequestCache(requestCache); } http.addFilter(this.postProcess(authorizationRequestFilter)); - OAuth2LoginAuthenticationFilter authenticationFilter = this.getAuthenticationFilter(); if (this.redirectionEndpointConfig.authorizationResponseBaseUri != null) { authenticationFilter.setFilterProcessesUrl(this.redirectionEndpointConfig.authorizationResponseBaseUri); } if (this.authorizationEndpointConfig.authorizationRequestRepository != null) { - authenticationFilter.setAuthorizationRequestRepository( - this.authorizationEndpointConfig.authorizationRequestRepository); + authenticationFilter + .setAuthorizationRequestRepository(this.authorizationEndpointConfig.authorizationRequestRepository); } super.configure(http); } @@ -590,14 +392,15 @@ public final class OAuth2LoginConfigurer> exten throw new NoUniqueBeanDefinitionException(type, names); } if (names.length == 1) { - return (JwtDecoderFactory) this.getBuilder().getSharedObject(ApplicationContext.class).getBean(names[0]); + return (JwtDecoderFactory) this.getBuilder().getSharedObject(ApplicationContext.class) + .getBean(names[0]); } return null; } private GrantedAuthoritiesMapper getGrantedAuthoritiesMapper() { - GrantedAuthoritiesMapper grantedAuthoritiesMapper = - this.getBuilder().getSharedObject(GrantedAuthoritiesMapper.class); + GrantedAuthoritiesMapper grantedAuthoritiesMapper = this.getBuilder() + .getSharedObject(GrantedAuthoritiesMapper.class); if (grantedAuthoritiesMapper == null) { grantedAuthoritiesMapper = this.getGrantedAuthoritiesMapperBean(); if (grantedAuthoritiesMapper != null) { @@ -608,9 +411,8 @@ public final class OAuth2LoginConfigurer> exten } private GrantedAuthoritiesMapper getGrantedAuthoritiesMapperBean() { - Map grantedAuthoritiesMapperMap = - BeanFactoryUtils.beansOfTypeIncludingAncestors( - this.getBuilder().getSharedObject(ApplicationContext.class), + Map grantedAuthoritiesMapperMap = BeanFactoryUtils + .beansOfTypeIncludingAncestors(this.getBuilder().getSharedObject(ApplicationContext.class), GrantedAuthoritiesMapper.class); return (!grantedAuthoritiesMapperMap.isEmpty() ? grantedAuthoritiesMapperMap.values().iterator().next() : null); } @@ -619,53 +421,48 @@ public final class OAuth2LoginConfigurer> exten if (this.userInfoEndpointConfig.oidcUserService != null) { return this.userInfoEndpointConfig.oidcUserService; } - ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OidcUserRequest.class, OidcUser.class); + ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OidcUserRequest.class, + OidcUser.class); OAuth2UserService bean = getBeanOrNull(type); - if (bean == null) { - return new OidcUserService(); - } - - return bean; + return (bean != null) ? bean : new OidcUserService(); } private OAuth2UserService getOAuth2UserService() { if (this.userInfoEndpointConfig.userService != null) { return this.userInfoEndpointConfig.userService; } - ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OAuth2UserRequest.class, OAuth2User.class); + ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OAuth2UserRequest.class, + OAuth2User.class); OAuth2UserService bean = getBeanOrNull(type); - if (bean == null) { - if (!this.userInfoEndpointConfig.customUserTypes.isEmpty()) { - List> userServices = new ArrayList<>(); - userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes)); - userServices.add(new DefaultOAuth2UserService()); - return new DelegatingOAuth2UserService<>(userServices); - } else { - return new DefaultOAuth2UserService(); - } + if (bean != null) { + return bean; } - - return bean; + if (this.userInfoEndpointConfig.customUserTypes.isEmpty()) { + return new DefaultOAuth2UserService(); + } + List> userServices = new ArrayList<>(); + userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes)); + userServices.add(new DefaultOAuth2UserService()); + return new DelegatingOAuth2UserService<>(userServices); } private T getBeanOrNull(ResolvableType type) { ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class); - if (context == null) { - return null; - } - String[] names = context.getBeanNamesForType(type); - if (names.length == 1) { - return (T) context.getBean(names[0]); + if (context != null) { + String[] names = context.getBeanNamesForType(type); + if (names.length == 1) { + return (T) context.getBean(names[0]); + } } return null; } private void initDefaultLoginFilter(B http) { - DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class); + DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http + .getSharedObject(DefaultLoginPageGeneratingFilter.class); if (loginPageGeneratingFilter == null || this.isCustomLoginPage()) { return; } - loginPageGeneratingFilter.setOauth2LoginEnabled(true); loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(this.getLoginLinks()); loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage()); @@ -675,8 +472,8 @@ public final class OAuth2LoginConfigurer> exten @SuppressWarnings("unchecked") private Map getLoginLinks() { Iterable clientRegistrations = null; - ClientRegistrationRepository clientRegistrationRepository = - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()); + ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils + .getClientRegistrationRepository(this.getBuilder()); ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class); if (type != ResolvableType.NONE && ClientRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) { clientRegistrations = (Iterable) clientRegistrationRepository; @@ -684,15 +481,12 @@ public final class OAuth2LoginConfigurer> exten if (clientRegistrations == null) { return Collections.emptyMap(); } - - String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri != null ? - this.authorizationEndpointConfig.authorizationRequestBaseUri : - OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; + String authorizationRequestBaseUri = (this.authorizationEndpointConfig.authorizationRequestBaseUri != null) + ? this.authorizationEndpointConfig.authorizationRequestBaseUri + : OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; Map loginUrlToClientName = new HashMap<>(); - clientRegistrations.forEach(registration -> loginUrlToClientName.put( - authorizationRequestBaseUri + "/" + registration.getRegistrationId(), - registration.getClientName())); - + clientRegistrations.forEach((registration) -> loginUrlToClientName.put( + authorizationRequestBaseUri + "/" + registration.getRegistrationId(), registration.getClientName())); return loginUrlToClientName; } @@ -702,41 +496,244 @@ public final class OAuth2LoginConfigurer> exten RequestMatcher defaultEntryPointMatcher = this.getAuthenticationEntryPointMatcher(http); RequestMatcher defaultLoginPageMatcher = new AndRequestMatcher( new OrRequestMatcher(loginPageMatcher, faviconMatcher), defaultEntryPointMatcher); - RequestMatcher notXRequestedWith = new NegatedRequestMatcher( new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest")); - LinkedHashMap entryPoints = new LinkedHashMap<>(); entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher)), new LoginUrlAuthenticationEntryPoint(providerLoginPage)); - DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints); loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint()); - return loginEntryPoint; } + /** + * Configuration options for the Authorization Server's Authorization Endpoint. + */ + public final class AuthorizationEndpointConfig { + + private String authorizationRequestBaseUri; + + private OAuth2AuthorizationRequestResolver authorizationRequestResolver; + + private AuthorizationRequestRepository authorizationRequestRepository; + + private AuthorizationEndpointConfig() { + } + + /** + * Sets the base {@code URI} used for authorization requests. + * @param authorizationRequestBaseUri the base {@code URI} used for authorization + * requests + * @return the {@link AuthorizationEndpointConfig} for further configuration + */ + public AuthorizationEndpointConfig baseUri(String authorizationRequestBaseUri) { + Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); + this.authorizationRequestBaseUri = authorizationRequestBaseUri; + return this; + } + + /** + * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s. + * @param authorizationRequestResolver the resolver used for resolving + * {@link OAuth2AuthorizationRequest}'s + * @return the {@link AuthorizationEndpointConfig} for further configuration + * @since 5.1 + */ + public AuthorizationEndpointConfig authorizationRequestResolver( + OAuth2AuthorizationRequestResolver authorizationRequestResolver) { + Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null"); + this.authorizationRequestResolver = authorizationRequestResolver; + return this; + } + + /** + * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. + * @param authorizationRequestRepository the repository used for storing + * {@link OAuth2AuthorizationRequest}'s + * @return the {@link AuthorizationEndpointConfig} for further configuration + */ + public AuthorizationEndpointConfig authorizationRequestRepository( + AuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + return this; + } + + /** + * Returns the {@link OAuth2LoginConfigurer} for further configuration. + * @return the {@link OAuth2LoginConfigurer} + */ + public OAuth2LoginConfigurer and() { + return OAuth2LoginConfigurer.this; + } + + } + + /** + * Configuration options for the Authorization Server's Token Endpoint. + */ + public final class TokenEndpointConfig { + + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + + private TokenEndpointConfig() { + } + + /** + * Sets the client used for requesting the access token credential from the Token + * Endpoint. + * @param accessTokenResponseClient the client used for requesting the access + * token credential from the Token Endpoint + * @return the {@link TokenEndpointConfig} for further configuration + */ + public TokenEndpointConfig accessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Returns the {@link OAuth2LoginConfigurer} for further configuration. + * @return the {@link OAuth2LoginConfigurer} + */ + public OAuth2LoginConfigurer and() { + return OAuth2LoginConfigurer.this; + } + + } + + /** + * Configuration options for the Client's Redirection Endpoint. + */ + public final class RedirectionEndpointConfig { + + private String authorizationResponseBaseUri; + + private RedirectionEndpointConfig() { + } + + /** + * Sets the {@code URI} where the authorization response will be processed. + * @param authorizationResponseBaseUri the {@code URI} where the authorization + * response will be processed + * @return the {@link RedirectionEndpointConfig} for further configuration + */ + public RedirectionEndpointConfig baseUri(String authorizationResponseBaseUri) { + Assert.hasText(authorizationResponseBaseUri, "authorizationResponseBaseUri cannot be empty"); + this.authorizationResponseBaseUri = authorizationResponseBaseUri; + return this; + } + + /** + * Returns the {@link OAuth2LoginConfigurer} for further configuration. + * @return the {@link OAuth2LoginConfigurer} + */ + public OAuth2LoginConfigurer and() { + return OAuth2LoginConfigurer.this; + } + + } + + /** + * Configuration options for the Authorization Server's UserInfo Endpoint. + */ + public final class UserInfoEndpointConfig { + + private OAuth2UserService userService; + + private OAuth2UserService oidcUserService; + + private Map> customUserTypes = new HashMap<>(); + + private UserInfoEndpointConfig() { + } + + /** + * Sets the OAuth 2.0 service used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint. + * @param userService the OAuth 2.0 service used for obtaining the user attributes + * of the End-User from the UserInfo Endpoint + * @return the {@link UserInfoEndpointConfig} for further configuration + */ + public UserInfoEndpointConfig userService(OAuth2UserService userService) { + Assert.notNull(userService, "userService cannot be null"); + this.userService = userService; + return this; + } + + /** + * Sets the OpenID Connect 1.0 service used for obtaining the user attributes of + * the End-User from the UserInfo Endpoint. + * @param oidcUserService the OpenID Connect 1.0 service used for obtaining the + * user attributes of the End-User from the UserInfo Endpoint + * @return the {@link UserInfoEndpointConfig} for further configuration + */ + public UserInfoEndpointConfig oidcUserService(OAuth2UserService oidcUserService) { + Assert.notNull(oidcUserService, "oidcUserService cannot be null"); + this.oidcUserService = oidcUserService; + return this; + } + + /** + * Sets a custom {@link OAuth2User} type and associates it to the provided client + * {@link ClientRegistration#getRegistrationId() registration identifier}. + * @deprecated See {@link CustomUserTypesOAuth2UserService} for alternative usage. + * @param customUserType a custom {@link OAuth2User} type + * @param clientRegistrationId the client registration identifier + * @return the {@link UserInfoEndpointConfig} for further configuration + */ + @Deprecated + public UserInfoEndpointConfig customUserType(Class customUserType, + String clientRegistrationId) { + Assert.notNull(customUserType, "customUserType cannot be null"); + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + this.customUserTypes.put(clientRegistrationId, customUserType); + return this; + } + + /** + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OAuth2User#getAuthorities()}. + * @param userAuthoritiesMapper the {@link GrantedAuthoritiesMapper} used for + * mapping the user's authorities + * @return the {@link UserInfoEndpointConfig} for further configuration + */ + public UserInfoEndpointConfig userAuthoritiesMapper(GrantedAuthoritiesMapper userAuthoritiesMapper) { + Assert.notNull(userAuthoritiesMapper, "userAuthoritiesMapper cannot be null"); + OAuth2LoginConfigurer.this.getBuilder().setSharedObject(GrantedAuthoritiesMapper.class, + userAuthoritiesMapper); + return this; + } + + /** + * Returns the {@link OAuth2LoginConfigurer} for further configuration. + * @return the {@link OAuth2LoginConfigurer} + */ + public OAuth2LoginConfigurer and() { + return OAuth2LoginConfigurer.this; + } + + } + private static class OidcAuthenticationRequestChecker implements AuthenticationProvider { @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2LoginAuthenticationToken authorizationCodeAuthentication = - (OAuth2LoginAuthenticationToken) authentication; - - // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope - // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (authorizationCodeAuthentication.getAuthorizationExchange() - .getAuthorizationRequest().getScopes().contains(OidcScopes.OPENID)) { - - OAuth2Error oauth2Error = new OAuth2Error( - "oidc_provider_not_configured", - "An OpenID Connect Authentication Provider has not been configured. " + - "Check to ensure you include the dependency 'spring-security-oauth2-jose'.", - null); + OAuth2LoginAuthenticationToken authorizationCodeAuthentication = (OAuth2LoginAuthenticationToken) authentication; + OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() + .getAuthorizationRequest(); + if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope + // value. + OAuth2Error oauth2Error = new OAuth2Error("oidc_provider_not_configured", + "An OpenID Connect Authentication Provider has not been configured. " + + "Check to ensure you include the dependency 'spring-security-oauth2-jose'.", + null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - return null; } @@ -744,5 +741,7 @@ public final class OAuth2LoginConfigurer> exten public boolean supports(Class authentication) { return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java index e28f27e2f7..d56ca5bb4c 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java @@ -17,6 +17,7 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.server.resource; import java.util.function.Supplier; + import javax.servlet.http.HttpServletRequest; import org.springframework.context.ApplicationContext; @@ -50,22 +51,23 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; - /** * * An {@link AbstractHttpConfigurer} for OAuth 2.0 Resource Server Support. * - * By default, this wires a {@link BearerTokenAuthenticationFilter}, which can be used to parse the request - * for bearer tokens and make an authentication attempt. + * By default, this wires a {@link BearerTokenAuthenticationFilter}, which can be used to + * parse the request for bearer tokens and make an authentication attempt. * *

    * The following configuration options are available: * *

      - *
    • {@link #accessDeniedHandler(AccessDeniedHandler)}
    • - customizes how access denied errors are handled - *
    • {@link #authenticationEntryPoint(AuthenticationEntryPoint)}
    • - customizes how authentication failures are handled - *
    • {@link #bearerTokenResolver(BearerTokenResolver)} - customizes how to resolve a bearer token from the request
    • + *
    • {@link #accessDeniedHandler(AccessDeniedHandler)}
    • - customizes how access + * denied errors are handled + *
    • {@link #authenticationEntryPoint(AuthenticationEntryPoint)}
    • - customizes how + * authentication failures are handled + *
    • {@link #bearerTokenResolver(BearerTokenResolver)} - customizes how to resolve a + * bearer token from the request
    • *
    • {@link #jwt(Customizer)} - enables Jwt-encoded bearer token support
    • *
    • {@link #opaqueToken(Customizer)} - enables opaque bearer token support
    • *
    @@ -74,33 +76,28 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSe * When using {@link #jwt(Customizer)}, either * *
      - *
    • - * supply a Jwk Set Uri via {@link JwtConfigurer#jwkSetUri}, or - *
    • - *
    • - * supply a {@link JwtDecoder} instance via {@link JwtConfigurer#decoder}, or - *
    • - *
    • - * expose a {@link JwtDecoder} bean - *
    • + *
    • supply a Jwk Set Uri via {@link JwtConfigurer#jwkSetUri}, or
    • + *
    • supply a {@link JwtDecoder} instance via {@link JwtConfigurer#decoder}, or
    • + *
    • expose a {@link JwtDecoder} bean
    • *
    * * Also with {@link #jwt(Customizer)} consider * *
      - *
    • - * customizing the conversion from a {@link Jwt} to an {@link org.springframework.security.core.Authentication} with - * {@link JwtConfigurer#jwtAuthenticationConverter(Converter)} - *
    • + *
    • customizing the conversion from a {@link Jwt} to an + * {@link org.springframework.security.core.Authentication} with + * {@link JwtConfigurer#jwtAuthenticationConverter(Converter)}
    • *
    * *

    - * When using {@link #opaqueToken(Customizer)}, supply an introspection endpoint and its authentication configuration + * When using {@link #opaqueToken(Customizer)}, supply an introspection endpoint and its + * authentication configuration *

    * *

    Security Filters

    * - * The following {@code Filter}s are populated when {@link #jwt(Customizer)} is configured: + * The following {@code Filter}s are populated when {@link #jwt(Customizer)} is + * configured: * *
      *
    • {@link BearerTokenAuthenticationFilter}
    • @@ -130,19 +127,23 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSe * @see NimbusJwtDecoder * @see AbstractHttpConfigurer */ -public final class OAuth2ResourceServerConfigurer> extends - AbstractHttpConfigurer, H> { +public final class OAuth2ResourceServerConfigurer> + extends AbstractHttpConfigurer, H> { private final ApplicationContext context; private AuthenticationManagerResolver authenticationManagerResolver; + private BearerTokenResolver bearerTokenResolver; private JwtConfigurer jwtConfigurer; + private OpaqueTokenConfigurer opaqueTokenConfigurer; private AccessDeniedHandler accessDeniedHandler = new BearerTokenAccessDeniedHandler(); + private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint(); + private BearerTokenRequestMatcher requestMatcher = new BearerTokenRequestMatcher(); public OAuth2ResourceServerConfigurer(ApplicationContext context) { @@ -162,8 +163,8 @@ public final class OAuth2ResourceServerConfigurer authenticationManagerResolver - (AuthenticationManagerResolver authenticationManagerResolver) { + public OAuth2ResourceServerConfigurer authenticationManagerResolver( + AuthenticationManagerResolver authenticationManagerResolver) { Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); this.authenticationManagerResolver = authenticationManagerResolver; return this; @@ -176,22 +177,20 @@ public final class OAuth2ResourceServerConfigurer jwt(Customizer jwtCustomizer) { - if ( this.jwtConfigurer == null ) { + if (this.jwtConfigurer == null) { this.jwtConfigurer = new JwtConfigurer(this.context); } jwtCustomizer.customize(this.jwtConfigurer); @@ -202,15 +201,13 @@ public final class OAuth2ResourceServerConfigurer opaqueToken(Customizer opaqueTokenCustomizer) { @@ -224,11 +221,9 @@ public final class OAuth2ResourceServerConfigurer authenticationManager; + resolver = (request) -> authenticationManager; } - BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(resolver); filter.setBearerTokenResolver(bearerTokenResolver); filter.setAuthenticationEntryPoint(this.authenticationEntryPoint); filter = postProcess(filter); - http.addFilter(filter); } private void validateConfiguration() { if (this.authenticationManagerResolver == null) { - if (this.jwtConfigurer == null && this.opaqueTokenConfigurer == null) { - throw new IllegalStateException("Jwt and Opaque Token are the only supported formats for bearer tokens " + - "in Spring Security and neither was found. Make sure to configure JWT " + - "via http.oauth2ResourceServer().jwt() or Opaque Tokens via " + - "http.oauth2ResourceServer().opaqueToken()."); - } - - if (this.jwtConfigurer != null && this.opaqueTokenConfigurer != null) { - throw new IllegalStateException("Spring Security only supports JWTs or Opaque Tokens, not both at the " + - "same time."); - } - } else { - if (this.jwtConfigurer != null || this.opaqueTokenConfigurer != null) { - throw new IllegalStateException("If an authenticationManagerResolver() is configured, then it takes " + - "precedence over any jwt() or opaqueToken() configuration."); - } + Assert.state(this.jwtConfigurer != null || this.opaqueTokenConfigurer != null, + "Jwt and Opaque Token are the only supported formats for bearer tokens " + + "in Spring Security and neither was found. Make sure to configure JWT " + + "via http.oauth2ResourceServer().jwt() or Opaque Tokens via " + + "http.oauth2ResourceServer().opaqueToken()."); + Assert.state(this.jwtConfigurer == null || this.opaqueTokenConfigurer == null, + "Spring Security only supports JWTs or Opaque Tokens, not both at the " + "same time."); + } + else { + Assert.state(this.jwtConfigurer == null && this.opaqueTokenConfigurer == null, + "If an authenticationManagerResolver() is configured, then it takes " + + "precedence over any jwt() or opaqueToken() configuration."); } } + private void registerDefaultAccessDeniedHandler(H http) { + ExceptionHandlingConfigurer exceptionHandling = http.getConfigurer(ExceptionHandlingConfigurer.class); + if (exceptionHandling != null) { + exceptionHandling.defaultAccessDeniedHandlerFor(this.accessDeniedHandler, this.requestMatcher); + } + } + + private void registerDefaultEntryPoint(H http) { + ExceptionHandlingConfigurer exceptionHandling = http.getConfigurer(ExceptionHandlingConfigurer.class); + if (exceptionHandling != null) { + exceptionHandling.defaultAuthenticationEntryPointFor(this.authenticationEntryPoint, this.requestMatcher); + } + } + + private void registerDefaultCsrfOverride(H http) { + CsrfConfigurer csrf = http.getConfigurer(CsrfConfigurer.class); + if (csrf != null) { + csrf.ignoringRequestMatchers(this.requestMatcher); + } + } + + AuthenticationProvider getAuthenticationProvider() { + if (this.jwtConfigurer != null) { + return this.jwtConfigurer.getAuthenticationProvider(); + } + if (this.opaqueTokenConfigurer != null) { + return this.opaqueTokenConfigurer.getAuthenticationProvider(); + } + return null; + } + + AuthenticationManager getAuthenticationManager(H http) { + if (this.jwtConfigurer != null) { + return this.jwtConfigurer.getAuthenticationManager(http); + } + if (this.opaqueTokenConfigurer != null) { + return this.opaqueTokenConfigurer.getAuthenticationManager(http); + } + return http.getSharedObject(AuthenticationManager.class); + } + + BearerTokenResolver getBearerTokenResolver() { + if (this.bearerTokenResolver == null) { + if (this.context.getBeanNamesForType(BearerTokenResolver.class).length > 0) { + this.bearerTokenResolver = this.context.getBean(BearerTokenResolver.class); + } + else { + this.bearerTokenResolver = new DefaultBearerTokenResolver(); + } + } + return this.bearerTokenResolver; + } + public class JwtConfigurer { + private final ApplicationContext context; private AuthenticationManager authenticationManager; + private JwtDecoder decoder; private Converter jwtAuthenticationConverter; @@ -299,13 +342,12 @@ public final class OAuth2ResourceServerConfigurer jwtAuthenticationConverter) { - + public JwtConfigurer jwtAuthenticationConverter( + Converter jwtAuthenticationConverter) { this.jwtAuthenticationConverter = jwtAuthenticationConverter; return this; } @@ -318,19 +360,18 @@ public final class OAuth2ResourceServerConfigurer 0) { this.jwtAuthenticationConverter = this.context.getBean(JwtAuthenticationConverter.class); - } else { + } + else { this.jwtAuthenticationConverter = new JwtAuthenticationConverter(); } } - return this.jwtAuthenticationConverter; } JwtDecoder getJwtDecoder() { - if ( this.decoder == null ) { + if (this.decoder == null) { return this.context.getBean(JwtDecoder.class); } - return this.decoder; } @@ -338,13 +379,9 @@ public final class OAuth2ResourceServerConfigurer jwtAuthenticationConverter = - getJwtAuthenticationConverter(); - - JwtAuthenticationProvider provider = - new JwtAuthenticationProvider(decoder); + Converter jwtAuthenticationConverter = getJwtAuthenticationConverter(); + JwtAuthenticationProvider provider = new JwtAuthenticationProvider(decoder); provider.setJwtAuthenticationConverter(jwtAuthenticationConverter); return postProcess(provider); } @@ -353,18 +390,23 @@ public final class OAuth2ResourceServerConfigurer introspector; OpaqueTokenConfigurer(ApplicationContext context) { @@ -380,8 +422,8 @@ public final class OAuth2ResourceServerConfigurer - new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, this.clientSecret); + this.introspector = () -> new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, + this.clientSecret); return this; } @@ -390,8 +432,8 @@ public final class OAuth2ResourceServerConfigurer - new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, this.clientSecret); + this.introspector = () -> new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, + this.clientSecret); return this; } @@ -420,96 +462,30 @@ public final class OAuth2ResourceServerConfigurer exceptionHandling = http - .getConfigurer(ExceptionHandlingConfigurer.class); - if (exceptionHandling == null) { - return; - } - - exceptionHandling.defaultAccessDeniedHandlerFor( - this.accessDeniedHandler, - this.requestMatcher); - } - - private void registerDefaultEntryPoint(H http) { - ExceptionHandlingConfigurer exceptionHandling = http - .getConfigurer(ExceptionHandlingConfigurer.class); - if (exceptionHandling == null) { - return; - } - - exceptionHandling.defaultAuthenticationEntryPointFor( - this.authenticationEntryPoint, - this.requestMatcher); - } - - private void registerDefaultCsrfOverride(H http) { - CsrfConfigurer csrf = http - .getConfigurer(CsrfConfigurer.class); - if (csrf == null) { - return; - } - - csrf.ignoringRequestMatchers(this.requestMatcher); - } - - AuthenticationProvider getAuthenticationProvider() { - if (this.jwtConfigurer != null) { - return this.jwtConfigurer.getAuthenticationProvider(); - } - - if (this.opaqueTokenConfigurer != null) { - return this.opaqueTokenConfigurer.getAuthenticationProvider(); - } - - return null; - } - - AuthenticationManager getAuthenticationManager(H http) { - if (this.jwtConfigurer != null) { - return this.jwtConfigurer.getAuthenticationManager(http); - } - - if (this.opaqueTokenConfigurer != null) { - return this.opaqueTokenConfigurer.getAuthenticationManager(http); - } - - return http.getSharedObject(AuthenticationManager.class); - } - - BearerTokenResolver getBearerTokenResolver() { - if ( this.bearerTokenResolver == null ) { - if ( this.context.getBeanNamesForType(BearerTokenResolver.class).length > 0 ) { - this.bearerTokenResolver = this.context.getBean(BearerTokenResolver.class); - } else { - this.bearerTokenResolver = new DefaultBearerTokenResolver(); - } - } - - return this.bearerTokenResolver; } private static final class BearerTokenRequestMatcher implements RequestMatcher { + private BearerTokenResolver bearerTokenResolver; @Override public boolean matches(HttpServletRequest request) { try { return this.bearerTokenResolver.resolve(request) != null; - } catch ( OAuth2AuthenticationException e ) { + } + catch (OAuth2AuthenticationException ex) { return false; } } - public void setBearerTokenResolver(BearerTokenResolver tokenResolver) { + void setBearerTokenResolver(BearerTokenResolver tokenResolver) { Assert.notNull(tokenResolver, "resolver cannot be null"); this.bearerTokenResolver = tokenResolver; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurer.java index 4fa74f0053..e54c2863a8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.openid; import java.util.ArrayList; @@ -118,16 +119,22 @@ import org.springframework.security.web.util.matcher.RequestMatcher; *
    * * @author Rob Winch - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @since 3.2 + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ -public final class OpenIDLoginConfigurer> extends - AbstractAuthenticationFilterConfigurer, OpenIDAuthenticationFilter> { +@Deprecated +public final class OpenIDLoginConfigurer> + extends AbstractAuthenticationFilterConfigurer, OpenIDAuthenticationFilter> { + private OpenIDConsumer openIDConsumer; + private ConsumerManager consumerManager; + private AuthenticationUserDetailsService authenticationUserDetailsService; + private List attributeExchangeConfigurers = new ArrayList<>(); /** @@ -139,29 +146,27 @@ public final class OpenIDLoginConfigurer> exten /** * Sets up OpenID attribute exchange for OpenID's matching the specified pattern. - * * @param identifierPattern the regular expression for matching on OpenID's (i.e. * "https://www.google.com/.*", ".*yahoo.com.*", etc) * @return a {@link AttributeExchangeConfigurer} for further customizations of the * attribute exchange */ public AttributeExchangeConfigurer attributeExchange(String identifierPattern) { - AttributeExchangeConfigurer attributeExchangeConfigurer = new AttributeExchangeConfigurer( - identifierPattern); + AttributeExchangeConfigurer attributeExchangeConfigurer = new AttributeExchangeConfigurer(identifierPattern); this.attributeExchangeConfigurers.add(attributeExchangeConfigurer); return attributeExchangeConfigurer; } /** - * Sets up OpenID attribute exchange for OpenIDs matching the specified pattern. - * The default pattern is ".*", it can be specified using + * Sets up OpenID attribute exchange for OpenIDs matching the specified pattern. The + * default pattern is ".*", it can be specified using * {@link AttributeExchangeConfigurer#identifierPattern(String)} - * - * @param attributeExchangeCustomizer the {@link Customizer} to provide more options for - * the {@link AttributeExchangeConfigurer} + * @param attributeExchangeCustomizer the {@link Customizer} to provide more options + * for the {@link AttributeExchangeConfigurer} * @return a {@link OpenIDLoginConfigurer} for further customizations */ - public OpenIDLoginConfigurer attributeExchange(Customizer attributeExchangeCustomizer) { + public OpenIDLoginConfigurer attributeExchange( + Customizer attributeExchangeCustomizer) { AttributeExchangeConfigurer attributeExchangeConfigurer = new AttributeExchangeConfigurer(".*"); attributeExchangeCustomizer.customize(attributeExchangeConfigurer); this.attributeExchangeConfigurers.add(attributeExchangeConfigurer); @@ -171,7 +176,6 @@ public final class OpenIDLoginConfigurer> exten /** * Allows specifying the {@link OpenIDConsumer} to be used. The default is using an * {@link OpenID4JavaConsumer}. - * * @param consumer the {@link OpenIDConsumer} to be used * @return the {@link OpenIDLoginConfigurer} for further customizations */ @@ -188,7 +192,6 @@ public final class OpenIDLoginConfigurer> exten * This is a shortcut for specifying the {@link OpenID4JavaConsumer} with a specific * {@link ConsumerManager} on {@link #consumer(OpenIDConsumer)}. *

    - * * @param consumerManager the {@link ConsumerManager} to use. Cannot be null. * @return the {@link OpenIDLoginConfigurer} for further customizations */ @@ -201,7 +204,6 @@ public final class OpenIDLoginConfigurer> exten * The {@link AuthenticationUserDetailsService} to use. By default a * {@link UserDetailsByNameServiceWrapper} is used with the {@link UserDetailsService} * shared object found with {@link HttpSecurity#getSharedObject(Class)}. - * * @param authenticationUserDetailsService the {@link AuthenticationDetailsSource} to * use * @return the {@link OpenIDLoginConfigurer} for further customizations @@ -216,7 +218,6 @@ public final class OpenIDLoginConfigurer> exten * Specifies the URL used to authenticate OpenID requests. If the * {@link HttpServletRequest} matches this URL the {@link OpenIDAuthenticationFilter} * will attempt to authenticate the request. The default is "/login/openid". - * * @param loginProcessingUrl the URL used to perform authentication * @return the {@link OpenIDLoginConfigurer} for additional customization */ @@ -267,7 +268,6 @@ public final class OpenIDLoginConfigurer> exten *
  • /authenticate?error GET - redirect here for failed authentication attempts
  • *
  • /authenticate?logout GET - redirect here after successfully logging out
  • * - * * @param loginPage the login page to redirect to if authentication is required (i.e. * "/login") * @return the {@link FormLoginConfigurer} for additional customization @@ -280,13 +280,10 @@ public final class OpenIDLoginConfigurer> exten @Override public void init(H http) throws Exception { super.init(http); - OpenIDAuthenticationProvider authenticationProvider = new OpenIDAuthenticationProvider(); - authenticationProvider.setAuthenticationUserDetailsService( - getAuthenticationUserDetailsService(http)); + authenticationProvider.setAuthenticationUserDetailsService(getAuthenticationUserDetailsService(http)); authenticationProvider = postProcess(authenticationProvider); http.authenticationProvider(authenticationProvider); - initDefaultLoginFilter(http); } @@ -296,13 +293,6 @@ public final class OpenIDLoginConfigurer> exten super.configure(http); } - /* - * (non-Javadoc) - * - * @see org.springframework.security.config.annotation.web.configurers. - * AbstractAuthenticationFilterConfigurer - * #createLoginProcessingUrlMatcher(java.lang.String) - */ @Override protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingUrl) { return new AntPathRequestMatcher(loginProcessingUrl); @@ -316,8 +306,7 @@ public final class OpenIDLoginConfigurer> exten */ private OpenIDConsumer getConsumer() throws ConsumerException { if (this.openIDConsumer == null) { - this.openIDConsumer = new OpenID4JavaConsumer(getConsumerManager(), - attributesToFetchFactory()); + this.openIDConsumer = new OpenID4JavaConsumer(getConsumerManager(), attributesToFetchFactory()); } return this.openIDConsumer; } @@ -337,7 +326,6 @@ public final class OpenIDLoginConfigurer> exten /** * Creates an {@link RegexBasedAxFetchListFactory} using the attributes populated by * {@link AttributeExchangeConfigurer} - * * @return the {@link AxFetchListFactory} to use */ private AxFetchListFactory attributesToFetchFactory() { @@ -352,23 +340,19 @@ public final class OpenIDLoginConfigurer> exten * Gets the {@link AuthenticationUserDetailsService} that was configured or defaults * to {@link UserDetailsByNameServiceWrapper} that uses a {@link UserDetailsService} * looked up using {@link HttpSecurity#getSharedObject(Class)} - * * @param http the current {@link HttpSecurity} * @return the {@link AuthenticationUserDetailsService}. */ - private AuthenticationUserDetailsService getAuthenticationUserDetailsService( - H http) { + private AuthenticationUserDetailsService getAuthenticationUserDetailsService(H http) { if (this.authenticationUserDetailsService != null) { return this.authenticationUserDetailsService; } - return new UserDetailsByNameServiceWrapper<>( - http.getSharedObject(UserDetailsService.class)); + return new UserDetailsByNameServiceWrapper<>(http.getSharedObject(UserDetailsService.class)); } /** * If available, initializes the {@link DefaultLoginPageGeneratingFilter} shared * object. - * * @param http the {@link HttpSecurityBuilder} to use */ private void initDefaultLoginFilter(H http) { @@ -382,8 +366,8 @@ public final class OpenIDLoginConfigurer> exten loginPageGeneratingFilter.setLoginPageUrl(getLoginPage()); loginPageGeneratingFilter.setFailureUrl(getFailureUrl()); } - loginPageGeneratingFilter.setOpenIDusernameParameter( - OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD); + loginPageGeneratingFilter + .setOpenIDusernameParameter(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD); } } @@ -393,8 +377,11 @@ public final class OpenIDLoginConfigurer> exten * @author Rob Winch */ public final class AttributeExchangeConfigurer { + private String identifier; + private List attributes = new ArrayList<>(); + private List attributeConfigurers = new ArrayList<>(); /** @@ -418,7 +405,6 @@ public final class OpenIDLoginConfigurer> exten /** * Sets the regular expression for matching on OpenID's (i.e. * "https://www.google.com/.*", ".*yahoo.com.*", etc) - * * @param identifierPattern the regular expression for matching on OpenID's * @return the {@link AttributeExchangeConfigurer} for further customization of * attribute exchange @@ -453,9 +439,8 @@ public final class OpenIDLoginConfigurer> exten } /** - * Adds an {@link OpenIDAttribute} named "default-attribute". - * The name can by updated using {@link AttributeConfigurer#name(String)}. - * + * Adds an {@link OpenIDAttribute} named "default-attribute". The name + * can by updated using {@link AttributeConfigurer#name(String)}. * @param attributeCustomizer the {@link Customizer} to provide more options for * the {@link AttributeConfigurer} * @return a {@link AttributeExchangeConfigurer} for further customizations @@ -486,14 +471,18 @@ public final class OpenIDLoginConfigurer> exten * @since 3.2 */ public final class AttributeConfigurer { + private String name; + private int count = 1; + private boolean required = false; + private String type; /** - * Creates a new instance named "default-attribute". - * The name can by updated using {@link #name(String)}. + * Creates a new instance named "default-attribute". The name can by updated + * using {@link #name(String)}. * * @see AttributeExchangeConfigurer#attribute(String) */ @@ -525,7 +514,6 @@ public final class OpenIDLoginConfigurer> exten * false. Note that as outlined in the OpenID specification, * required attributes are not validated by the OpenID Provider. Developers * should perform any validation in custom code. - * * @param required specifies the attribute is required * @return the {@link AttributeConfigurer} for further customization */ @@ -557,7 +545,6 @@ public final class OpenIDLoginConfigurer> exten /** * Gets the {@link AttributeExchangeConfigurer} for further customization of * the attributes - * * @return the {@link AttributeConfigurer} */ public AttributeExchangeConfigurer and() { @@ -574,6 +561,9 @@ public final class OpenIDLoginConfigurer> exten attribute.setRequired(this.required); return attribute; } + } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index ca7b6d2c08..4adad193d6 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.configurers.saml2; import java.util.LinkedHashMap; import java.util.Map; + import javax.servlet.Filter; import org.springframework.beans.factory.NoSuchBeanDefinitionException; @@ -46,21 +47,22 @@ import org.springframework.security.web.authentication.ui.DefaultLoginPageGenera import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; - -import static org.springframework.util.StringUtils.hasText; +import org.springframework.util.StringUtils; /** - * An {@link AbstractHttpConfigurer} for SAML 2.0 Login, - * which leverages the SAML 2.0 Web Browser Single Sign On (WebSSO) Flow. + * An {@link AbstractHttpConfigurer} for SAML 2.0 Login, which leverages the SAML 2.0 Web + * Browser Single Sign On (WebSSO) Flow. * *

    - * SAML 2.0 Login provides an application with the capability to have users log in - * by using their existing account at an SAML 2.0 Identity Provider. + * SAML 2.0 Login provides an application with the capability to have users log in by + * using their existing account at an SAML 2.0 Identity Provider. * *

    - * Defaults are provided for all configuration options with the only required configuration - * being {@link #relyingPartyRegistrationRepository(RelyingPartyRegistrationRepository)} . - * Alternatively, a {@link RelyingPartyRegistrationRepository} {@code @Bean} may be registered instead. + * Defaults are provided for all configuration options with the only required + * configuration being + * {@link #relyingPartyRegistrationRepository(RelyingPartyRegistrationRepository)} . + * Alternatively, a {@link RelyingPartyRegistrationRepository} {@code @Bean} may be + * registered instead. * *

    Security Filters

    * @@ -87,8 +89,9 @@ import static org.springframework.util.StringUtils.hasText; *
      *
    • {@link RelyingPartyRegistrationRepository} (required)
    • *
    • {@link Saml2AuthenticationRequestFactory} (optional)
    • - *
    • {@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not configured - * and {@code DefaultLoginPageGeneratingFilter} is available, than a default login page will be made available
    • + *
    • {@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not + * configured and {@code DefaultLoginPageGeneratingFilter} is available, than a default + * login page will be made available
    • *
    * * @since 5.2 @@ -98,8 +101,8 @@ import static org.springframework.util.StringUtils.hasText; * @see RelyingPartyRegistrationRepository * @see AbstractAuthenticationFilterConfigurer */ -public final class Saml2LoginConfigurer> extends - AbstractAuthenticationFilterConfigurer, Saml2WebSsoAuthenticationFilter> { +public final class Saml2LoginConfigurer> + extends AbstractAuthenticationFilterConfigurer, Saml2WebSsoAuthenticationFilter> { private String loginPage; @@ -110,14 +113,15 @@ public final class Saml2LoginConfigurer> extend private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; private AuthenticationConverter authenticationConverter; + private AuthenticationManager authenticationManager; private Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter; /** - * Use this {@link AuthenticationConverter} when converting incoming requests to an {@link Authentication}. - * By default the {@link Saml2AuthenticationTokenConverter} is used. - * + * Use this {@link AuthenticationConverter} when converting incoming requests to an + * {@link Authentication}. By default the {@link Saml2AuthenticationTokenConverter} is + * used. * @param authenticationConverter the {@link AuthenticationConverter} to use * @return the {@link Saml2LoginConfigurer} for further configuration * @since 5.4 @@ -129,12 +133,13 @@ public final class Saml2LoginConfigurer> extend } /** - * Allows a configuration of a {@link AuthenticationManager} to be used during SAML 2 authentication. - * If none is specified, the system will create one inject it into the {@link Saml2WebSsoAuthenticationFilter} + * Allows a configuration of a {@link AuthenticationManager} to be used during SAML 2 + * authentication. If none is specified, the system will create one inject it into the + * {@link Saml2WebSsoAuthenticationFilter} * @param authenticationManager the authentication manager to be used * @return the {@link Saml2LoginConfigurer} for further configuration - * @throws IllegalArgumentException if authenticationManager is null - * configure the default manager + * @throws IllegalArgumentException if authenticationManager is null configure the + * default manager * @since 5.3 */ public Saml2LoginConfigurer authenticationManager(AuthenticationManager authenticationManager) { @@ -144,8 +149,9 @@ public final class Saml2LoginConfigurer> extend } /** - * Sets the {@code RelyingPartyRegistrationRepository} of relying parties, each party representing a - * service provider, SP and this host, and identity provider, IDP pair that communicate with each other. + * Sets the {@code RelyingPartyRegistrationRepository} of relying parties, each party + * representing a service provider, SP and this host, and identity provider, IDP pair + * that communicate with each other. * @param repo the repository of relying parties * @return the {@link Saml2LoginConfigurer} for further configuration */ @@ -154,9 +160,6 @@ public final class Saml2LoginConfigurer> extend return this; } - /** - * {@inheritDoc} - */ @Override public Saml2LoginConfigurer loginPage(String loginPage) { Assert.hasText(loginPage, "loginPage cannot be empty"); @@ -164,9 +167,6 @@ public final class Saml2LoginConfigurer> extend return this; } - /** - * {@inheritDoc} - */ @Override public Saml2LoginConfigurer loginProcessingUrl(String loginProcessingUrl) { Assert.hasText(loginProcessingUrl, "loginProcessingUrl cannot be empty"); @@ -175,9 +175,6 @@ public final class Saml2LoginConfigurer> extend return this; } - /** - * {@inheritDoc} - */ @Override protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingUrl) { return new AntPathRequestMatcher(loginProcessingUrl); @@ -186,15 +183,14 @@ public final class Saml2LoginConfigurer> extend /** * {@inheritDoc} * - * Initializes this filter chain for SAML 2 Login. - * The following actions are taken: + * Initializes this filter chain for SAML 2 Login. The following actions are taken: *
      - *
    • The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}
    • - *
    • A {@link Saml2WebSsoAuthenticationFilter is configured}
    • - *
    • The {@code loginProcessingUrl} is set
    • - *
    • A custom login page is configured, or
    • - *
    • A default login page with all SAML 2.0 Identity Providers is configured
    • - *
    • An {@link OpenSamlAuthenticationProvider} is configured
    • + *
    • The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}
    • + *
    • A {@link Saml2WebSsoAuthenticationFilter is configured}
    • + *
    • The {@code loginProcessingUrl} is set
    • + *
    • A custom login page is configured, or
    • + *
    • A default login page with all SAML 2.0 Identity Providers is configured
    • + *
    • An {@link OpenSamlAuthenticationProvider} is configured
    • *
    */ @Override @@ -203,32 +199,24 @@ public final class Saml2LoginConfigurer> extend if (this.relyingPartyRegistrationRepository == null) { this.relyingPartyRegistrationRepository = getSharedOrBean(http, RelyingPartyRegistrationRepository.class); } - - saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter( - getAuthenticationConverter(http), - this.loginProcessingUrl - ); - setAuthenticationFilter(saml2WebSsoAuthenticationFilter); + this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http), + this.loginProcessingUrl); + setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); - - if (hasText(this.loginPage)) { + if (StringUtils.hasText(this.loginPage)) { // Set custom login page super.loginPage(this.loginPage); super.init(http); - } else { - final Map providerUrlMap = - getIdentityProviderUrlMap( - this.authenticationRequestEndpoint.filterProcessingUrl, - this.relyingPartyRegistrationRepository - ); - + } + else { + Map providerUrlMap = getIdentityProviderUrlMap( + this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository); boolean singleProvider = providerUrlMap.size() == 1; if (singleProvider) { // Setup auto-redirect to provider login page // when only 1 IDP is configured this.updateAuthenticationDefaults(); this.updateAccessDefaults(http); - String loginUrl = providerUrlMap.entrySet().iterator().next().getKey(); final LoginUrlAuthenticationEntryPoint entryPoint = new LoginUrlAuthenticationEntryPoint(loginUrl); registerAuthenticationEntryPoint(http, entryPoint); @@ -237,15 +225,15 @@ public final class Saml2LoginConfigurer> extend super.init(http); } } - this.initDefaultLoginFilter(http); } /** * {@inheritDoc} * - * During the {@code configure} phase, a {@link Saml2WebSsoAuthenticationRequestFilter} - * is added to handle SAML 2.0 AuthNRequest redirects + * During the {@code configure} phase, a + * {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0 + * AuthNRequest redirects */ @Override public void configure(B http) throws Exception { @@ -255,7 +243,7 @@ public final class Saml2LoginConfigurer> extend registerDefaultAuthenticationProvider(http); } else { - saml2WebSsoAuthenticationFilter.setAuthenticationManager(this.authenticationManager); + this.saml2WebSsoAuthenticationFilter.setAuthenticationManager(this.authenticationManager); } } @@ -277,44 +265,30 @@ public final class Saml2LoginConfigurer> extend if (csrf == null) { return; } - - csrf.ignoringRequestMatchers( - new AntPathRequestMatcher(loginProcessingUrl) - ); + csrf.ignoringRequestMatchers(new AntPathRequestMatcher(this.loginProcessingUrl)); } private void initDefaultLoginFilter(B http) { - DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class); + DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http + .getSharedObject(DefaultLoginPageGeneratingFilter.class); if (loginPageGeneratingFilter == null || this.isCustomLoginPage()) { return; } - loginPageGeneratingFilter.setSaml2LoginEnabled(true); - loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName( - this.getIdentityProviderUrlMap( - this.authenticationRequestEndpoint.filterProcessingUrl, - this.relyingPartyRegistrationRepository - ) - ); + loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap( + this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository)); loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage()); loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl()); } @SuppressWarnings("unchecked") - private Map getIdentityProviderUrlMap( - String authRequestPrefixUrl, - RelyingPartyRegistrationRepository idpRepo - ) { + private Map getIdentityProviderUrlMap(String authRequestPrefixUrl, + RelyingPartyRegistrationRepository idpRepo) { Map idps = new LinkedHashMap<>(); if (idpRepo instanceof Iterable) { Iterable repo = (Iterable) idpRepo; - repo.forEach( - p -> - idps.put( - authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), - p.getRegistrationId() - ) - ); + repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), + p.getRegistrationId())); } return idps; } @@ -334,8 +308,10 @@ public final class Saml2LoginConfigurer> extend } try { return context.getBean(clazz); - } catch (NoSuchBeanDefinitionException e) {} - return null; + } + catch (NoSuchBeanDefinitionException ex) { + return null; + } } private void setSharedObject(B http, Class clazz, C object) { @@ -345,6 +321,7 @@ public final class Saml2LoginConfigurer> extend } private final class AuthenticationRequestEndpointConfig { + private String filterProcessingUrl = "/saml2/authenticate/{registrationId}"; private AuthenticationRequestEndpointConfig() { @@ -353,28 +330,28 @@ public final class Saml2LoginConfigurer> extend private Filter build(B http) { Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http); - - return postProcess(new Saml2WebSsoAuthenticationRequestFilter( - contextResolver, authenticationRequestResolver)); + return postProcess( + new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver)); } private Saml2AuthenticationRequestFactory getResolver(B http) { Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class); - if (resolver == null ) { + if (resolver == null) { resolver = new OpenSamlAuthenticationRequestFactory(); } return resolver; } private Saml2AuthenticationRequestContextResolver getContextResolver(B http) { - Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class); + Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, + Saml2AuthenticationRequestContextResolver.class); if (resolver == null) { - return new DefaultSaml2AuthenticationRequestContextResolver( - new DefaultRelyingPartyRegistrationResolver( - Saml2LoginConfigurer.this.relyingPartyRegistrationRepository)); + return new DefaultSaml2AuthenticationRequestContextResolver(new DefaultRelyingPartyRegistrationResolver( + Saml2LoginConfigurer.this.relyingPartyRegistrationRepository)); } return resolver; } + } } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java index 79a476ede0..eee7e34f36 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.messaging; import java.util.ArrayList; @@ -40,15 +41,21 @@ import org.springframework.util.StringUtils; * Allows mapping security constraints using {@link MessageMatcher} to the security * expressions. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public class MessageSecurityMetadataSourceRegistry { + private static final String permitAll = "permitAll"; + private static final String denyAll = "denyAll"; + private static final String anonymous = "anonymous"; + private static final String authenticated = "authenticated"; + private static final String fullyAuthenticated = "fullyAuthenticated"; + private static final String rememberMe = "rememberMe"; private SecurityExpressionHandler> expressionHandler = new DefaultMessageSecurityExpressionHandler<>(); @@ -61,7 +68,6 @@ public class MessageSecurityMetadataSourceRegistry { /** * Maps any {@link Message} to a security expression. - * * @return the Expression to associate */ public Constraint anyMessage() { @@ -72,7 +78,6 @@ public class MessageSecurityMetadataSourceRegistry { * Maps any {@link Message} that has a null SimpMessageHeaderAccessor destination * header (i.e. CONNECT, CONNECT_ACK, HEARTBEAT, UNSUBSCRIBE, DISCONNECT, * DISCONNECT_ACK, OTHER) - * * @return the Expression to associate */ public Constraint nullDestMatcher() { @@ -81,7 +86,6 @@ public class MessageSecurityMetadataSourceRegistry { /** * Maps a {@link List} of {@link SimpDestinationMessageMatcher} instances. - * * @param typesToMatch the {@link SimpMessageType} instance to match on * @return the {@link Constraint} associated to the matchers. */ @@ -98,12 +102,10 @@ public class MessageSecurityMetadataSourceRegistry { * Maps a {@link List} of {@link SimpDestinationMessageMatcher} instances without * regard to the {@link SimpMessageType}. If no destination is found on the Message, * then the Matcher returns false. - * * @param patterns the patterns to create * {@link org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher} * from. Uses * {@link MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher)} . - * * @return the {@link Constraint} that is associated to the {@link MessageMatcher} * @see MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher) */ @@ -115,12 +117,10 @@ public class MessageSecurityMetadataSourceRegistry { * Maps a {@link List} of {@link SimpDestinationMessageMatcher} instances that match * on {@code SimpMessageType.MESSAGE}. If no destination is found on the Message, then * the Matcher returns false. - * * @param patterns the patterns to create * {@link org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher} * from. Uses * {@link MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher)}. - * * @return the {@link Constraint} that is associated to the {@link MessageMatcher} * @see MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher) */ @@ -132,12 +132,10 @@ public class MessageSecurityMetadataSourceRegistry { * Maps a {@link List} of {@link SimpDestinationMessageMatcher} instances that match * on {@code SimpMessageType.SUBSCRIBE}. If no destination is found on the Message, * then the Matcher returns false. - * * @param patterns the patterns to create * {@link org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher} * from. Uses * {@link MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher)}. - * * @return the {@link Constraint} that is associated to the {@link MessageMatcher} * @see MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher) */ @@ -148,16 +146,14 @@ public class MessageSecurityMetadataSourceRegistry { /** * Maps a {@link List} of {@link SimpDestinationMessageMatcher} instances. If no * destination is found on the Message, then the Matcher returns false. - * * @param type the {@link SimpMessageType} to match on. If null, the * {@link SimpMessageType} is not considered for matching. * @param patterns the patterns to create * {@link org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher} * from. Uses * {@link MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher)}. - * * @return the {@link Constraint} that is associated to the {@link MessageMatcher} - * @see {@link MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher)} + * @see MessageSecurityMetadataSourceRegistry#simpDestPathMatcher(PathMatcher) */ private Constraint simpDestMatchers(SimpMessageType type, String... patterns) { List matchers = new ArrayList<>(patterns.length); @@ -171,13 +167,11 @@ public class MessageSecurityMetadataSourceRegistry { * The {@link PathMatcher} to be used with the * {@link MessageSecurityMetadataSourceRegistry#simpDestMatchers(String...)}. The * default is to use the default constructor of {@link AntPathMatcher}. - * * @param pathMatcher the {@link PathMatcher} to use. Cannot be null. * @return the {@link MessageSecurityMetadataSourceRegistry} for further * customization. */ - public MessageSecurityMetadataSourceRegistry simpDestPathMatcher( - PathMatcher pathMatcher) { + public MessageSecurityMetadataSourceRegistry simpDestPathMatcher(PathMatcher pathMatcher) { Assert.notNull(pathMatcher, "pathMatcher cannot be null"); this.pathMatcher.setPathMatcher(pathMatcher); this.defaultPathMatcher = false; @@ -185,9 +179,10 @@ public class MessageSecurityMetadataSourceRegistry { } /** - * Determines if the {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set. - * - * @return true if {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set, else false. + * Determines if the {@link #simpDestPathMatcher(PathMatcher)} has been explicitly + * set. + * @return true if {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set, + * else false. */ protected boolean isSimpDestPathMatcherConfigured() { return !this.defaultPathMatcher; @@ -195,7 +190,6 @@ public class MessageSecurityMetadataSourceRegistry { /** * Maps a {@link List} of {@link MessageMatcher} instances to a security expression. - * * @param matchers the {@link MessageMatcher} instances to map. * @return The {@link Constraint} that is associated to the {@link MessageMatcher} * instances @@ -209,14 +203,15 @@ public class MessageSecurityMetadataSourceRegistry { } /** - * The {@link SecurityExpressionHandler} to be used. The - * default is to use {@link DefaultMessageSecurityExpressionHandler}. - * - * @param expressionHandler the {@link SecurityExpressionHandler} to use. Cannot be null. + * The {@link SecurityExpressionHandler} to be used. The default is to use + * {@link DefaultMessageSecurityExpressionHandler}. + * @param expressionHandler the {@link SecurityExpressionHandler} to use. Cannot be + * null. * @return the {@link MessageSecurityMetadataSourceRegistry} for further * customization. */ - public MessageSecurityMetadataSourceRegistry expressionHandler(SecurityExpressionHandler> expressionHandler) { + public MessageSecurityMetadataSourceRegistry expressionHandler( + SecurityExpressionHandler> expressionHandler) { Assert.notNull(expressionHandler, "expressionHandler cannot be null"); this.expressionHandler = expressionHandler; return this; @@ -229,17 +224,15 @@ public class MessageSecurityMetadataSourceRegistry { * This is not exposed so as not to confuse users of the API, which should never * invoke this method. *

    - * * @return the {@link MessageSecurityMetadataSource} to use */ protected MessageSecurityMetadataSource createMetadataSource() { LinkedHashMap, String> matcherToExpression = new LinkedHashMap<>(); - for (Map.Entry entry : this.matcherToExpression - .entrySet()) { + for (Map.Entry entry : this.matcherToExpression.entrySet()) { matcherToExpression.put(entry.getKey().build(), entry.getValue()); } return ExpressionBasedMessageSecurityMetadataSourceFactory - .createExpressionMessageMetadataSource(matcherToExpression, expressionHandler); + .createExpressionMessageMetadataSource(matcherToExpression, this.expressionHandler); } /** @@ -249,167 +242,14 @@ public class MessageSecurityMetadataSourceRegistry { * This is not exposed so as not to confuse users of the API, which should never need * to invoke this method. *

    - * * @return true if a mapping was added, else false */ protected boolean containsMapping() { return !this.matcherToExpression.isEmpty(); } - /** - * Represents the security constraint to be applied to the {@link MessageMatcher} - * instances. - */ - public class Constraint { - private final List messageMatchers; - - /** - * Creates a new instance - * - * @param messageMatchers the {@link MessageMatcher} instances to map to this - * constraint - */ - private Constraint(List messageMatchers) { - Assert.notEmpty(messageMatchers, "messageMatchers cannot be null or empty"); - this.messageMatchers = messageMatchers; - } - - /** - * Shortcut for specifying {@link Message} instances require a particular role. If - * you do not want to have "ROLE_" automatically inserted see - * {@link #hasAuthority(String)}. - * - * @param role the role to require (i.e. USER, ADMIN, etc). Note, it should not - * start with "ROLE_" as this is automatically inserted. - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry hasRole(String role) { - return access(MessageSecurityMetadataSourceRegistry.hasRole(role)); - } - - /** - * Shortcut for specifying {@link Message} instances require any of a number of - * roles. If you do not want to have "ROLE_" automatically inserted see - * {@link #hasAnyAuthority(String...)} - * - * @param roles the roles to require (i.e. USER, ADMIN, etc). Note, it should not - * start with "ROLE_" as this is automatically inserted. - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry hasAnyRole(String... roles) { - return access(MessageSecurityMetadataSourceRegistry.hasAnyRole(roles)); - } - - /** - * Specify that {@link Message} instances require a particular authority. - * - * @param authority the authority to require (i.e. ROLE_USER, ROLE_ADMIN, etc). - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry hasAuthority(String authority) { - return access(MessageSecurityMetadataSourceRegistry.hasAuthority(authority)); - } - - /** - * Specify that {@link Message} instances requires any of a number authorities. - * - * @param authorities the requests require at least one of the authorities (i.e. - * "ROLE_USER","ROLE_ADMIN" would mean either "ROLE_USER" or "ROLE_ADMIN" is - * required). - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry hasAnyAuthority( - String... authorities) { - return access(MessageSecurityMetadataSourceRegistry - .hasAnyAuthority(authorities)); - } - - /** - * Specify that Messages are allowed by anyone. - * - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry permitAll() { - return access(permitAll); - } - - /** - * Specify that Messages are allowed by anonymous users. - * - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry anonymous() { - return access(anonymous); - } - - /** - * Specify that Messages are allowed by users that have been remembered. - * - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - * @see RememberMeConfigurer - */ - public MessageSecurityMetadataSourceRegistry rememberMe() { - return access(rememberMe); - } - - /** - * Specify that Messages are not allowed by anyone. - * - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry denyAll() { - return access(denyAll); - } - - /** - * Specify that Messages are allowed by any authenticated user. - * - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry authenticated() { - return access(authenticated); - } - - /** - * Specify that Messages are allowed by users who have authenticated and were not - * "remembered". - * - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - * @see RememberMeConfigurer - */ - public MessageSecurityMetadataSourceRegistry fullyAuthenticated() { - return access(fullyAuthenticated); - } - - /** - * Allows specifying that Messages are secured by an arbitrary expression - * - * @param attribute the expression to secure the URLs (i.e. - * "hasRole('ROLE_USER') and hasRole('ROLE_SUPER')") - * @return the {@link MessageSecurityMetadataSourceRegistry} for further - * customization - */ - public MessageSecurityMetadataSourceRegistry access(String attribute) { - for (MatcherBuilder messageMatcher : messageMatchers) { - matcherToExpression.put(messageMatcher, attribute); - } - return MessageSecurityMetadataSourceRegistry.this; - } - } - private static String hasAnyRole(String... authorities) { - String anyAuthorities = StringUtils.arrayToDelimitedString(authorities, - "','ROLE_"); + String anyAuthorities = StringUtils.arrayToDelimitedString(authorities, "','ROLE_"); return "hasAnyRole('ROLE_" + anyAuthorities + "')"; } @@ -417,8 +257,7 @@ public class MessageSecurityMetadataSourceRegistry { Assert.notNull(role, "role cannot be null"); if (role.startsWith("ROLE_")) { throw new IllegalArgumentException( - "role should not start with 'ROLE_' since it is automatically inserted. Got '" - + role + "'"); + "role should not start with 'ROLE_' since it is automatically inserted. Got '" + role + "'"); } return "hasRole('ROLE_" + role + "')"; } @@ -432,20 +271,164 @@ public class MessageSecurityMetadataSourceRegistry { return "hasAnyAuthority('" + anyAuthorities + "')"; } - private static class PreBuiltMatcherBuilder implements MatcherBuilder { + /** + * Represents the security constraint to be applied to the {@link MessageMatcher} + * instances. + */ + public final class Constraint { + + private final List messageMatchers; + + /** + * Creates a new instance + * @param messageMatchers the {@link MessageMatcher} instances to map to this + * constraint + */ + private Constraint(List messageMatchers) { + Assert.notEmpty(messageMatchers, "messageMatchers cannot be null or empty"); + this.messageMatchers = messageMatchers; + } + + /** + * Shortcut for specifying {@link Message} instances require a particular role. If + * you do not want to have "ROLE_" automatically inserted see + * {@link #hasAuthority(String)}. + * @param role the role to require (i.e. USER, ADMIN, etc). Note, it should not + * start with "ROLE_" as this is automatically inserted. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry hasRole(String role) { + return access(MessageSecurityMetadataSourceRegistry.hasRole(role)); + } + + /** + * Shortcut for specifying {@link Message} instances require any of a number of + * roles. If you do not want to have "ROLE_" automatically inserted see + * {@link #hasAnyAuthority(String...)} + * @param roles the roles to require (i.e. USER, ADMIN, etc). Note, it should not + * start with "ROLE_" as this is automatically inserted. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry hasAnyRole(String... roles) { + return access(MessageSecurityMetadataSourceRegistry.hasAnyRole(roles)); + } + + /** + * Specify that {@link Message} instances require a particular authority. + * @param authority the authority to require (i.e. ROLE_USER, ROLE_ADMIN, etc). + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry hasAuthority(String authority) { + return access(MessageSecurityMetadataSourceRegistry.hasAuthority(authority)); + } + + /** + * Specify that {@link Message} instances requires any of a number authorities. + * @param authorities the requests require at least one of the authorities (i.e. + * "ROLE_USER","ROLE_ADMIN" would mean either "ROLE_USER" or "ROLE_ADMIN" is + * required). + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry hasAnyAuthority(String... authorities) { + return access(MessageSecurityMetadataSourceRegistry.hasAnyAuthority(authorities)); + } + + /** + * Specify that Messages are allowed by anyone. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry permitAll() { + return access(permitAll); + } + + /** + * Specify that Messages are allowed by anonymous users. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry anonymous() { + return access(anonymous); + } + + /** + * Specify that Messages are allowed by users that have been remembered. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + * @see RememberMeConfigurer + */ + public MessageSecurityMetadataSourceRegistry rememberMe() { + return access(rememberMe); + } + + /** + * Specify that Messages are not allowed by anyone. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry denyAll() { + return access(denyAll); + } + + /** + * Specify that Messages are allowed by any authenticated user. + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry authenticated() { + return access(authenticated); + } + + /** + * Specify that Messages are allowed by users who have authenticated and were not + * "remembered". + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + * @see RememberMeConfigurer + */ + public MessageSecurityMetadataSourceRegistry fullyAuthenticated() { + return access(fullyAuthenticated); + } + + /** + * Allows specifying that Messages are secured by an arbitrary expression + * @param attribute the expression to secure the URLs (i.e. "hasRole('ROLE_USER') + * and hasRole('ROLE_SUPER')") + * @return the {@link MessageSecurityMetadataSourceRegistry} for further + * customization + */ + public MessageSecurityMetadataSourceRegistry access(String attribute) { + for (MatcherBuilder messageMatcher : this.messageMatchers) { + MessageSecurityMetadataSourceRegistry.this.matcherToExpression.put(messageMatcher, attribute); + } + return MessageSecurityMetadataSourceRegistry.this; + } + + } + + private static final class PreBuiltMatcherBuilder implements MatcherBuilder { + private MessageMatcher matcher; private PreBuiltMatcherBuilder(MessageMatcher matcher) { this.matcher = matcher; } + @Override public MessageMatcher build() { - return matcher; + return this.matcher; } + } - private class PathMatcherMessageMatcherBuilder implements MatcherBuilder { + private final class PathMatcherMessageMatcherBuilder implements MatcherBuilder { + private final String pattern; + private final SimpMessageType type; private PathMatcherMessageMatcherBuilder(String pattern, SimpMessageType type) { @@ -453,62 +436,74 @@ public class MessageSecurityMetadataSourceRegistry { this.type = type; } + @Override public MessageMatcher build() { - if (type == null) { - return new SimpDestinationMessageMatcher(pattern, pathMatcher); + if (this.type == null) { + return new SimpDestinationMessageMatcher(this.pattern, + MessageSecurityMetadataSourceRegistry.this.pathMatcher); } - else if (SimpMessageType.MESSAGE == type) { - return SimpDestinationMessageMatcher.createMessageMatcher(pattern, - pathMatcher); + if (SimpMessageType.MESSAGE == this.type) { + return SimpDestinationMessageMatcher.createMessageMatcher(this.pattern, + MessageSecurityMetadataSourceRegistry.this.pathMatcher); } - else if (SimpMessageType.SUBSCRIBE == type) { - return SimpDestinationMessageMatcher.createSubscribeMatcher(pattern, - pathMatcher); + if (SimpMessageType.SUBSCRIBE == this.type) { + return SimpDestinationMessageMatcher.createSubscribeMatcher(this.pattern, + MessageSecurityMetadataSourceRegistry.this.pathMatcher); } - throw new IllegalStateException(type - + " is not supported since it does not have a destination"); + throw new IllegalStateException(this.type + " is not supported since it does not have a destination"); } + } private interface MatcherBuilder { - MessageMatcher build(); - } + MessageMatcher build(); + + } static class DelegatingPathMatcher implements PathMatcher { private PathMatcher delegate = new AntPathMatcher(); + @Override public boolean isPattern(String path) { - return delegate.isPattern(path); + return this.delegate.isPattern(path); } + @Override public boolean match(String pattern, String path) { - return delegate.match(pattern, path); + return this.delegate.match(pattern, path); } + @Override public boolean matchStart(String pattern, String path) { - return delegate.matchStart(pattern, path); + return this.delegate.matchStart(pattern, path); } + @Override public String extractPathWithinPattern(String pattern, String path) { - return delegate.extractPathWithinPattern(pattern, path); + return this.delegate.extractPathWithinPattern(pattern, path); } + @Override public Map extractUriTemplateVariables(String pattern, String path) { - return delegate.extractUriTemplateVariables(pattern, path); + return this.delegate.extractUriTemplateVariables(pattern, path); } + @Override public Comparator getPatternComparator(String path) { - return delegate.getPatternComparator(path); + return this.delegate.getPatternComparator(path); } + @Override public String combine(String pattern1, String pattern2) { - return delegate.combine(pattern1, pattern2); + return this.delegate.combine(pattern1, pattern2); } void setPathMatcher(PathMatcher pathMatcher) { this.delegate = pathMatcher; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurity.java index e3d812db50..be95d27f82 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurity.java @@ -16,16 +16,16 @@ package org.springframework.security.config.annotation.web.reactive; -import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Import; -import org.springframework.security.config.web.server.ServerHttpSecurity; - import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.security.config.web.server.ServerHttpSecurity; + /** * Add this annotation to a {@code Configuration} class to have Spring Security WebFlux * support added. User's can then create one or more {@link ServerHttpSecurity} @@ -47,6 +47,7 @@ import java.lang.annotation.Target; * return new MapReactiveUserDetailsService(user); * } * } + * * * Below is the same as our minimal configuration, but explicitly declaring the * {@code ServerHttpSecurity}. @@ -54,7 +55,6 @@ import java.lang.annotation.Target; *
      * @EnableWebFluxSecurity
      * public class MyExplicitSecurityConfiguration {
    - *     // @formatter:off
      *     @Bean
      *     public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
      *          http
    @@ -65,9 +65,7 @@ import java.lang.annotation.Target;
      *                    .formLogin();
      *          return http.build();
      *     }
    - *     // @formatter:on
      *
    - *     // @formatter:off
      *     @Bean
      *     public MapReactiveUserDetailsService userDetailsService() {
      *          UserDetails user = User.withDefaultPasswordEncoder()
    @@ -77,8 +75,8 @@ import java.lang.annotation.Target;
      *               .build();
      *          return new MapReactiveUserDetailsService(user);
      *     }
    - *     // @formatter:on
      * }
    + * 
    * * @author Rob Winch * @since 5.0 @@ -86,8 +84,9 @@ import java.lang.annotation.Target; @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @Documented -@Import({ServerHttpSecurityConfiguration.class, WebFluxSecurityConfiguration.class, - ReactiveOAuth2ClientImportSelector.class}) +@Import({ ServerHttpSecurityConfiguration.class, WebFluxSecurityConfiguration.class, + ReactiveOAuth2ClientImportSelector.class }) @Configuration public @interface EnableWebFluxSecurity { + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java index 4cf69de47c..5a56f25650 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java @@ -16,6 +16,8 @@ package org.springframework.security.config.annotation.web.reactive; +import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportSelector; @@ -24,16 +26,14 @@ import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClient import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.ClassUtils; import org.springframework.web.reactive.config.WebFluxConfigurer; import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; -import java.util.List; - /** * {@link Configuration} for OAuth 2.0 Client support. * @@ -47,16 +47,17 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { @Override public String[] selectImports(AnnotationMetadata importingClassMetadata) { - boolean oauth2ClientPresent = ClassUtils.isPresent( - "org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader()); - - return oauth2ClientPresent ? - new String[] { "org.springframework.security.config.annotation.web.reactive.ReactiveOAuth2ClientImportSelector$OAuth2ClientWebFluxSecurityConfiguration" } : - new String[] {}; + if (!ClassUtils.isPresent("org.springframework.security.oauth2.client.registration.ClientRegistration", + getClass().getClassLoader())) { + return new String[0]; + } + return new String[] { "org.springframework.security.config.annotation.web.reactive." + + "ReactiveOAuth2ClientImportSelector$OAuth2ClientWebFluxSecurityConfiguration" }; } @Configuration(proxyBeanMethods = false) static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer { + private ReactiveClientRegistrationRepository clientRegistrationRepository; private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; @@ -66,13 +67,8 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { @Override public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) { - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build(); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder().authorizationCode().refreshToken().clientCredentials().password().build(); DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( this.clientRegistrationRepository, getAuthorizedClientRepository()); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); @@ -81,18 +77,17 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { } @Autowired(required = false) - public void setClientRegistrationRepository( - ReactiveClientRegistrationRepository clientRegistrationRepository) { + void setClientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) { this.clientRegistrationRepository = clientRegistrationRepository; } @Autowired(required = false) - public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { this.authorizedClientRepository = authorizedClientRepository; } @Autowired(required = false) - public void setAuthorizedClientService(List authorizedClientService) { + void setAuthorizedClientService(List authorizedClientService) { if (authorizedClientService.size() == 1) { this.authorizedClientService = authorizedClientService.get(0); } @@ -107,5 +102,7 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { } return null; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java index b7d4b3e3c7..676bfa5161 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java @@ -45,7 +45,9 @@ import org.springframework.web.reactive.result.method.annotation.ArgumentResolve */ @Configuration(proxyBeanMethods = false) class ServerHttpSecurityConfiguration { + private static final String BEAN_NAME_PREFIX = "org.springframework.security.config.annotation.web.reactive.HttpSecurityConfiguration."; + private static final String HTTPSECURITY_BEAN_NAME = BEAN_NAME_PREFIX + "httpSecurity"; private ReactiveAdapterRegistry adapterRegistry = new ReactiveAdapterRegistry(); @@ -87,20 +89,22 @@ class ServerHttpSecurityConfiguration { } @Bean - public WebFluxConfigurer authenticationPrincipalArgumentResolverConfigurer( + WebFluxConfigurer authenticationPrincipalArgumentResolverConfigurer( ObjectProvider authenticationPrincipalArgumentResolver) { return new WebFluxConfigurer() { + @Override public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { configurer.addCustomResolver(authenticationPrincipalArgumentResolver.getObject()); } + }; } @Bean - public AuthenticationPrincipalArgumentResolver authenticationPrincipalArgumentResolver() { + AuthenticationPrincipalArgumentResolver authenticationPrincipalArgumentResolver() { AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver( - this.adapterRegistry); + this.adapterRegistry); if (this.beanFactory != null) { resolver.setBeanResolver(new BeanFactoryResolver(this.beanFactory)); } @@ -108,7 +112,7 @@ class ServerHttpSecurityConfiguration { } @Bean - public CurrentSecurityContextArgumentResolver reactiveCurrentSecurityContextArgumentResolver() { + CurrentSecurityContextArgumentResolver reactiveCurrentSecurityContextArgumentResolver() { CurrentSecurityContextArgumentResolver resolver = new CurrentSecurityContextArgumentResolver( this.adapterRegistry); if (this.beanFactory != null) { @@ -119,12 +123,13 @@ class ServerHttpSecurityConfiguration { @Bean(HTTPSECURITY_BEAN_NAME) @Scope("prototype") - public ServerHttpSecurity httpSecurity() { + ServerHttpSecurity httpSecurity() { ContextAwareServerHttpSecurity http = new ContextAwareServerHttpSecurity(); - return http - .authenticationManager(authenticationManager()) + // @formatter:off + return http.authenticationManager(authenticationManager()) .headers().and() .logout().and(); + // @formatter:on } private ReactiveAuthenticationManager authenticationManager() { @@ -132,8 +137,8 @@ class ServerHttpSecurityConfiguration { return this.authenticationManager; } if (this.reactiveUserDetailsService != null) { - UserDetailsRepositoryReactiveAuthenticationManager manager = - new UserDetailsRepositoryReactiveAuthenticationManager(this.reactiveUserDetailsService); + UserDetailsRepositoryReactiveAuthenticationManager manager = new UserDetailsRepositoryReactiveAuthenticationManager( + this.reactiveUserDetailsService); if (this.passwordEncoder != null) { manager.setPasswordEncoder(this.passwordEncoder); } @@ -143,12 +148,13 @@ class ServerHttpSecurityConfiguration { return null; } - private static class ContextAwareServerHttpSecurity extends ServerHttpSecurity implements - ApplicationContextAware { + private static class ContextAwareServerHttpSecurity extends ServerHttpSecurity implements ApplicationContextAware { + @Override - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { super.setApplicationContext(applicationContext); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java index 7fa4b1c6dd..07119e8ee0 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java @@ -40,11 +40,13 @@ import org.springframework.web.reactive.result.view.AbstractView; */ @Configuration(proxyBeanMethods = false) class WebFluxSecurityConfiguration { + public static final int WEB_FILTER_CHAIN_FILTER_ORDER = 0 - 100; private static final String BEAN_NAME_PREFIX = "org.springframework.security.config.annotation.web.reactive.WebFluxSecurityConfiguration."; - private static final String SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME = BEAN_NAME_PREFIX + "WebFilterChainFilter"; + private static final String SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME = BEAN_NAME_PREFIX + + "WebFilterChainFilter"; public static final String REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME = "org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository"; @@ -62,18 +64,18 @@ class WebFluxSecurityConfiguration { } @Bean(SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME) - @Order(value = WEB_FILTER_CHAIN_FILTER_ORDER) - public WebFilterChainProxy springSecurityWebFilterChainFilter() { + @Order(WEB_FILTER_CHAIN_FILTER_ORDER) + WebFilterChainProxy springSecurityWebFilterChainFilter() { return new WebFilterChainProxy(getSecurityWebFilterChains()); } @Bean(name = AbstractView.REQUEST_DATA_VALUE_PROCESSOR_BEAN_NAME) - public CsrfRequestDataValueProcessor requestDataValueProcessor() { + CsrfRequestDataValueProcessor requestDataValueProcessor() { return new CsrfRequestDataValueProcessor(); } @Bean - public static BeanFactoryPostProcessor conversionServicePostProcessor() { + static BeanFactoryPostProcessor conversionServicePostProcessor() { return new RsaKeyConversionServicePostProcessor(); } @@ -96,33 +98,32 @@ class WebFluxSecurityConfiguration { * @return */ private SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { - http - .authorizeExchange() - .anyExchange().authenticated(); - + http.authorizeExchange().anyExchange().authenticated(); if (isOAuth2Present && OAuth2ClasspathGuard.shouldConfigure(this.context)) { OAuth2ClasspathGuard.configure(this.context, http); - } else { - http - .httpBasic().and() - .formLogin(); } - + else { + http.httpBasic(); + http.formLogin(); + } SecurityWebFilterChain result = http.build(); return result; } private static class OAuth2ClasspathGuard { + static void configure(ApplicationContext context, ServerHttpSecurity http) { - http - .oauth2Login().and() - .oauth2Client(); + http.oauth2Login(); + http.oauth2Client(); } static boolean shouldConfigure(ApplicationContext context) { ClassLoader loader = context.getClassLoader(); - Class reactiveClientRegistrationRepositoryClass = ClassUtils.resolveClassName(REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME, loader); + Class reactiveClientRegistrationRepositoryClass = ClassUtils + .resolveClassName(REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME, loader); return context.getBeanNamesForType(reactiveClientRegistrationRepositoryClass).length == 1; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/EnableWebMvcSecurity.java b/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/EnableWebMvcSecurity.java index 1afc2c710b..7b68cf5de7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/EnableWebMvcSecurity.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/EnableWebMvcSecurity.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.servlet.configuration; import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import org.springframework.context.annotation.Configuration; @@ -26,18 +29,18 @@ import org.springframework.security.config.annotation.authentication.configurati /** * Add this annotation to an {@code @Configuration} class to have the Spring Security * configuration integrate with Spring MVC. - * * @deprecated Use EnableWebSecurity instead which will automatically add the Spring MVC * related Security items. * @author Rob Winch * @since 3.2 */ -@Retention(value = java.lang.annotation.RetentionPolicy.RUNTIME) -@Target(value = { java.lang.annotation.ElementType.TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) @Documented @Import(WebMvcSecurityConfiguration.class) @EnableGlobalAuthentication @Configuration @Deprecated public @interface EnableWebMvcSecurity { + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/WebMvcSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/WebMvcSecurityConfiguration.java index ba1d7204fb..bbb4ea2980 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/WebMvcSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/servlet/configuration/WebMvcSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.servlet.configuration; import java.util.List; @@ -30,14 +31,17 @@ import org.springframework.web.servlet.support.RequestDataValueProcessor; /** * Used to add a {@link RequestDataValueProcessor} for Spring MVC and Spring Security CSRF * integration. This configuration is added whenever {@link EnableWebMvc} is added by - * SpringWebMvcImportSelector and the DispatcherServlet is present on the - * classpath. It also adds the {@link AuthenticationPrincipalArgumentResolver} as a + * SpringWebMvcImportSelector + * and the DispatcherServlet is present on the classpath. It also adds the + * {@link AuthenticationPrincipalArgumentResolver} as a * {@link HandlerMethodArgumentResolver}. * * @deprecated This is applied internally using SpringWebMvcImportSelector * @author Rob Winch * @since 3.2 */ +@Deprecated @Configuration(proxyBeanMethods = false) @EnableWebSecurity public class WebMvcSecurityConfiguration implements WebMvcConfigurer { @@ -54,4 +58,5 @@ public class WebMvcSecurityConfiguration implements WebMvcConfigurer { public RequestDataValueProcessor requestDataValueProcessor() { return new CsrfRequestDataValueProcessor(); } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java index 0bb45a614f..f201352f59 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.socket; import java.util.ArrayList; @@ -46,6 +47,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.util.AntPathMatcher; +import org.springframework.util.Assert; import org.springframework.util.PathMatcher; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; @@ -77,14 +79,14 @@ import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsSe * } * * - * * @since 4.0 * @author Rob Winch */ @Order(Ordered.HIGHEST_PRECEDENCE + 100) @Import(ObjectPostProcessorConfiguration.class) -public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends - AbstractWebSocketMessageBrokerConfigurer implements SmartInitializingSingleton { +public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer + implements SmartInitializingSingleton { + private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); private SecurityExpressionHandler> defaultExpressionHandler = new DefaultMessageSecurityExpressionHandler<>(); @@ -93,6 +95,7 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends private ApplicationContext context; + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { } @@ -103,12 +106,12 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends @Override public final void configureClientInboundChannel(ChannelRegistration registration) { - ChannelSecurityInterceptor inboundChannelSecurity = context.getBean(ChannelSecurityInterceptor.class); - registration.setInterceptors(context.getBean(SecurityContextChannelInterceptor.class)); + ChannelSecurityInterceptor inboundChannelSecurity = this.context.getBean(ChannelSecurityInterceptor.class); + registration.setInterceptors(this.context.getBean(SecurityContextChannelInterceptor.class)); if (!sameOriginDisabled()) { - registration.setInterceptors(context.getBean(CsrfChannelInterceptor.class)); + registration.setInterceptors(this.context.getBean(CsrfChannelInterceptor.class)); } - if (inboundRegistry.containsMapping()) { + if (this.inboundRegistry.containsMapping()) { registration.setInterceptors(inboundChannelSecurity); } customizeClientInboundChannel(registration); @@ -116,8 +119,9 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends private PathMatcher getDefaultPathMatcher() { try { - return context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher(); - } catch(NoSuchBeanDefinitionException e) { + return this.context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher(); + } + catch (NoSuchBeanDefinitionException ex) { return new AntPathMatcher(); } } @@ -131,7 +135,6 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends *

    * Subclasses can override this method to disable CSRF protection *

    - * * @return false if a CSRF token is required for connecting, else true */ protected boolean sameOriginDisabled() { @@ -141,7 +144,6 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends /** * Allows subclasses to customize the configuration of the {@link ChannelRegistration} * . - * * @param registration the {@link ChannelRegistration} to customize */ protected void customizeClientInboundChannel(ChannelRegistration registration) { @@ -153,15 +155,14 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends } @Bean - public ChannelSecurityInterceptor inboundChannelSecurity(MessageSecurityMetadataSource messageSecurityMetadataSource) { + public ChannelSecurityInterceptor inboundChannelSecurity( + MessageSecurityMetadataSource messageSecurityMetadataSource) { ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor( messageSecurityMetadataSource); MessageExpressionVoter voter = new MessageExpressionVoter<>(); voter.setExpressionHandler(getMessageExpressionHandler()); - List> voters = new ArrayList<>(); voters.add(voter); - AffirmativeBased manager = new AffirmativeBased(voters); channelSecurityInterceptor.setAccessDecisionManager(manager); return channelSecurityInterceptor; @@ -174,36 +175,17 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends @Bean public MessageSecurityMetadataSource inboundMessageSecurityMetadataSource() { - inboundRegistry.expressionHandler(getMessageExpressionHandler()); - configureInbound(inboundRegistry); - return inboundRegistry.createMetadataSource(); + this.inboundRegistry.expressionHandler(getMessageExpressionHandler()); + configureInbound(this.inboundRegistry); + return this.inboundRegistry.createMetadataSource(); } /** - * * @param messages */ protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { } - private static class WebSocketMessageSecurityMetadataSourceRegistry extends - MessageSecurityMetadataSourceRegistry { - @Override - public MessageSecurityMetadataSource createMetadataSource() { - return super.createMetadataSource(); - } - - @Override - protected boolean containsMapping() { - return super.containsMapping(); - } - - @Override - protected boolean isSimpDestPathMatcherConfigured() { - return super.isSimpDestPathMatcherConfigured(); - } - } - @Autowired public void setApplicationContext(ApplicationContext context) { this.context = context; @@ -223,69 +205,79 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends @Autowired(required = false) public void setObjectPostProcessor(ObjectPostProcessor objectPostProcessor) { - defaultExpressionHandler = objectPostProcessor.postProcess(defaultExpressionHandler); + this.defaultExpressionHandler = objectPostProcessor.postProcess(this.defaultExpressionHandler); } - private SecurityExpressionHandler> getMessageExpressionHandler() { - if (expressionHandler == null) { - return defaultExpressionHandler; + private SecurityExpressionHandler> getMessageExpressionHandler() { + if (this.expressionHandler == null) { + return this.defaultExpressionHandler; } - return expressionHandler; + return this.expressionHandler; } + @Override public void afterSingletonsInstantiated() { if (sameOriginDisabled()) { return; } - String beanName = "stompWebSocketHandlerMapping"; - SimpleUrlHandlerMapping mapping = context.getBean(beanName, - SimpleUrlHandlerMapping.class); + SimpleUrlHandlerMapping mapping = this.context.getBean(beanName, SimpleUrlHandlerMapping.class); Map mappings = mapping.getHandlerMap(); for (Object object : mappings.values()) { if (object instanceof SockJsHttpRequestHandler) { - SockJsHttpRequestHandler sockjsHandler = (SockJsHttpRequestHandler) object; - SockJsService sockJsService = sockjsHandler.getSockJsService(); - if (!(sockJsService instanceof TransportHandlingSockJsService)) { - throw new IllegalStateException( - "sockJsService must be instance of TransportHandlingSockJsService got " - + sockJsService); - } - - TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService) sockJsService; - List handshakeInterceptors = transportHandlingSockJsService - .getHandshakeInterceptors(); - List interceptorsToSet = new ArrayList<>( - handshakeInterceptors.size() + 1); - interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); - interceptorsToSet.addAll(handshakeInterceptors); - - transportHandlingSockJsService - .setHandshakeInterceptors(interceptorsToSet); + setHandshakeInterceptors((SockJsHttpRequestHandler) object); } else if (object instanceof WebSocketHttpRequestHandler) { - WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) object; - List handshakeInterceptors = handler - .getHandshakeInterceptors(); - List interceptorsToSet = new ArrayList<>( - handshakeInterceptors.size() + 1); - interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); - interceptorsToSet.addAll(handshakeInterceptors); - - handler.setHandshakeInterceptors(interceptorsToSet); + setHandshakeInterceptors((WebSocketHttpRequestHandler) object); } else { - throw new IllegalStateException( - "Bean " - + beanName - + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " - + object); + throw new IllegalStateException("Bean " + beanName + " is expected to contain mappings to either a " + + "SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object); } } - - if (inboundRegistry.containsMapping() && !inboundRegistry.isSimpDestPathMatcherConfigured()) { + if (this.inboundRegistry.containsMapping() && !this.inboundRegistry.isSimpDestPathMatcherConfigured()) { PathMatcher pathMatcher = getDefaultPathMatcher(); - inboundRegistry.simpDestPathMatcher(pathMatcher); + this.inboundRegistry.simpDestPathMatcher(pathMatcher); } } + + private void setHandshakeInterceptors(SockJsHttpRequestHandler handler) { + SockJsService sockJsService = handler.getSockJsService(); + Assert.state(sockJsService instanceof TransportHandlingSockJsService, + () -> "sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService); + TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService) sockJsService; + List handshakeInterceptors = transportHandlingSockJsService.getHandshakeInterceptors(); + List interceptorsToSet = new ArrayList<>(handshakeInterceptors.size() + 1); + interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); + interceptorsToSet.addAll(handshakeInterceptors); + transportHandlingSockJsService.setHandshakeInterceptors(interceptorsToSet); + } + + private void setHandshakeInterceptors(WebSocketHttpRequestHandler handler) { + List handshakeInterceptors = handler.getHandshakeInterceptors(); + List interceptorsToSet = new ArrayList<>(handshakeInterceptors.size() + 1); + interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); + interceptorsToSet.addAll(handshakeInterceptors); + handler.setHandshakeInterceptors(interceptorsToSet); + } + + private static class WebSocketMessageSecurityMetadataSourceRegistry extends MessageSecurityMetadataSourceRegistry { + + @Override + public MessageSecurityMetadataSource createMetadataSource() { + return super.createMetadataSource(); + } + + @Override + protected boolean containsMapping() { + return super.containsMapping(); + } + + @Override + protected boolean isSimpDestPathMatcherConfigured() { + return super.isSimpDestPathMatcherConfigured(); + } + + } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/AbstractUserDetailsServiceBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/authentication/AbstractUserDetailsServiceBeanDefinitionParser.java index ae1a6e9d21..883e98bc39 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/AbstractUserDetailsServiceBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/authentication/AbstractUserDetailsServiceBeanDefinitionParser.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; +import org.w3c.dom.Element; + import org.springframework.beans.factory.BeanDefinitionStoreException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -27,76 +30,61 @@ import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.authentication.CachingUserDetailsService; import org.springframework.security.config.BeanIds; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * @author Luke Taylor */ -public abstract class AbstractUserDetailsServiceBeanDefinitionParser implements - BeanDefinitionParser { +public abstract class AbstractUserDetailsServiceBeanDefinitionParser implements BeanDefinitionParser { + static final String CACHE_REF = "cache-ref"; + public static final String CACHING_SUFFIX = ".caching"; protected abstract String getBeanClassName(Element element); - protected abstract void doParse(Element element, ParserContext parserContext, - BeanDefinitionBuilder builder); + protected abstract void doParse(Element element, ParserContext parserContext, BeanDefinitionBuilder builder); + @Override public BeanDefinition parse(Element element, ParserContext parserContext) { - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(getBeanClassName(element)); - + BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(getBeanClassName(element)); doParse(element, parserContext, builder); - RootBeanDefinition userService = (RootBeanDefinition) builder.getBeanDefinition(); - final String beanId = resolveId(element, userService, parserContext); - - parserContext.registerBeanComponent(new BeanComponentDefinition(userService, - beanId)); - + String beanId = resolveId(element, userService, parserContext); + parserContext.registerBeanComponent(new BeanComponentDefinition(userService, beanId)); String cacheRef = element.getAttribute(CACHE_REF); - // Register a caching version of the user service if there's a cache-ref if (StringUtils.hasText(cacheRef)) { BeanDefinitionBuilder cachingUSBuilder = BeanDefinitionBuilder .rootBeanDefinition(CachingUserDetailsService.class); cachingUSBuilder.addConstructorArgReference(beanId); - - cachingUSBuilder.addPropertyValue("userCache", new RuntimeBeanReference( - cacheRef)); + cachingUSBuilder.addPropertyValue("userCache", new RuntimeBeanReference(cacheRef)); BeanDefinition cachingUserService = cachingUSBuilder.getBeanDefinition(); - parserContext.registerBeanComponent(new BeanComponentDefinition( - cachingUserService, beanId + CACHING_SUFFIX)); + parserContext + .registerBeanComponent(new BeanComponentDefinition(cachingUserService, beanId + CACHING_SUFFIX)); } - return null; } - private String resolveId(Element element, AbstractBeanDefinition definition, - ParserContext pc) throws BeanDefinitionStoreException { - + private String resolveId(Element element, AbstractBeanDefinition definition, ParserContext pc) + throws BeanDefinitionStoreException { String id = element.getAttribute("id"); - if (pc.isNested()) { // We're inside an element if (!StringUtils.hasText(id)) { id = pc.getReaderContext().generateBeanName(definition); } BeanDefinition container = pc.getContainingBeanDefinition(); - container.getPropertyValues().add("userDetailsService", - new RuntimeBeanReference(id)); + container.getPropertyValues().add("userDetailsService", new RuntimeBeanReference(id)); } - if (StringUtils.hasText(id)) { return id; } - // If top level, use the default name or throw an exception if already used if (pc.getRegistry().containsBeanDefinition(BeanIds.USER_DETAILS_SERVICE)) { - throw new BeanDefinitionStoreException("No id supplied and another " - + "bean is already registered as " + BeanIds.USER_DETAILS_SERVICE); + throw new BeanDefinitionStoreException( + "No id supplied and another bean is already registered as " + BeanIds.USER_DETAILS_SERVICE); } - return BeanIds.USER_DETAILS_SERVICE; } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParser.java index 57342c9c81..0439549e28 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParser.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; import java.util.List; +import org.w3c.dom.Element; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -36,9 +41,6 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; -import org.w3c.dom.Node; -import org.w3c.dom.NodeList; /** * Registers the central ProviderManager used by the namespace configuration, and allows @@ -48,123 +50,103 @@ import org.w3c.dom.NodeList; * @author Luke Taylor */ public class AuthenticationManagerBeanDefinitionParser implements BeanDefinitionParser { + private static final String ATT_ALIAS = "alias"; + private static final String ATT_REF = "ref"; + private static final String ATT_ERASE_CREDENTIALS = "erase-credentials"; + @Override public BeanDefinition parse(Element element, ParserContext pc) { String id = element.getAttribute("id"); - if (!StringUtils.hasText(id)) { if (pc.getRegistry().containsBeanDefinition(BeanIds.AUTHENTICATION_MANAGER)) { - pc.getReaderContext().warning( - "Overriding globally registered AuthenticationManager", + pc.getReaderContext().warning("Overriding globally registered AuthenticationManager", pc.extractSource(element)); } id = BeanIds.AUTHENTICATION_MANAGER; } - pc.pushContainingComponent(new CompositeComponentDefinition(element.getTagName(), - pc.extractSource(element))); - - BeanDefinitionBuilder providerManagerBldr = BeanDefinitionBuilder - .rootBeanDefinition(ProviderManager.class); - + pc.pushContainingComponent(new CompositeComponentDefinition(element.getTagName(), pc.extractSource(element))); + BeanDefinitionBuilder providerManagerBldr = BeanDefinitionBuilder.rootBeanDefinition(ProviderManager.class); String alias = element.getAttribute(ATT_ALIAS); - List providers = new ManagedList<>(); - NamespaceHandlerResolver resolver = pc.getReaderContext() - .getNamespaceHandlerResolver(); - + NamespaceHandlerResolver resolver = pc.getReaderContext().getNamespaceHandlerResolver(); NodeList children = element.getChildNodes(); - for (int i = 0; i < children.getLength(); i++) { Node node = children.item(i); if (node instanceof Element) { - Element providerElt = (Element) node; - if (StringUtils.hasText(providerElt.getAttribute(ATT_REF))) { - if (providerElt.getAttributes().getLength() > 1) { - pc.getReaderContext().error( - "authentication-provider element cannot be used with other attributes " - + "when using 'ref' attribute", - pc.extractSource(element)); - } - NodeList providerChildren = providerElt.getChildNodes(); - for (int j = 0; j < providerChildren.getLength(); j++) { - if (providerChildren.item(j) instanceof Element) { - pc.getReaderContext().error( - "authentication-provider element cannot have child elements when used " - + "with 'ref' attribute", - pc.extractSource(element)); - } - } - providers.add(new RuntimeBeanReference(providerElt - .getAttribute(ATT_REF))); - } - else { - BeanDefinition provider = resolver.resolve( - providerElt.getNamespaceURI()).parse(providerElt, pc); - Assert.notNull(provider, () -> "Parser for " + providerElt.getNodeName() - + " returned a null bean definition"); - String providerId = pc.getReaderContext().generateBeanName(provider); - pc.registerBeanComponent(new BeanComponentDefinition(provider, - providerId)); - providers.add(new RuntimeBeanReference(providerId)); - } + providers.add(extracted(element, pc, resolver, (Element) node)); } } - if (providers.isEmpty()) { providers.add(new RootBeanDefinition(NullAuthenticationProvider.class)); } - providerManagerBldr.addConstructorArgValue(providers); - if ("false".equals(element.getAttribute(ATT_ERASE_CREDENTIALS))) { - providerManagerBldr.addPropertyValue("eraseCredentialsAfterAuthentication", - false); + providerManagerBldr.addPropertyValue("eraseCredentialsAfterAuthentication", false); } - // Add the default event publisher - BeanDefinition publisher = new RootBeanDefinition( - DefaultAuthenticationEventPublisher.class); + BeanDefinition publisher = new RootBeanDefinition(DefaultAuthenticationEventPublisher.class); String pubId = pc.getReaderContext().generateBeanName(publisher); pc.registerBeanComponent(new BeanComponentDefinition(publisher, pubId)); providerManagerBldr.addPropertyReference("authenticationEventPublisher", pubId); - - pc.registerBeanComponent(new BeanComponentDefinition(providerManagerBldr - .getBeanDefinition(), id)); - + pc.registerBeanComponent(new BeanComponentDefinition(providerManagerBldr.getBeanDefinition(), id)); if (StringUtils.hasText(alias)) { pc.getRegistry().registerAlias(id, alias); - pc.getReaderContext().fireAliasRegistered(id, alias, - pc.extractSource(element)); + pc.getReaderContext().fireAliasRegistered(id, alias, pc.extractSource(element)); } if (!BeanIds.AUTHENTICATION_MANAGER.equals(id)) { pc.getRegistry().registerAlias(id, BeanIds.AUTHENTICATION_MANAGER); - pc.getReaderContext().fireAliasRegistered(id, BeanIds.AUTHENTICATION_MANAGER, - pc.extractSource(element)); + pc.getReaderContext().fireAliasRegistered(id, BeanIds.AUTHENTICATION_MANAGER, pc.extractSource(element)); } - pc.popAndRegisterContainingComponent(); - return null; } + private BeanMetadataElement extracted(Element element, ParserContext pc, NamespaceHandlerResolver resolver, + Element providerElement) { + String ref = providerElement.getAttribute(ATT_REF); + if (!StringUtils.hasText(ref)) { + BeanDefinition provider = resolver.resolve(providerElement.getNamespaceURI()).parse(providerElement, pc); + Assert.notNull(provider, + () -> "Parser for " + providerElement.getNodeName() + " returned a null bean definition"); + String providerId = pc.getReaderContext().generateBeanName(provider); + pc.registerBeanComponent(new BeanComponentDefinition(provider, providerId)); + return new RuntimeBeanReference(providerId); + } + if (providerElement.getAttributes().getLength() > 1) { + pc.getReaderContext().error("authentication-provider element cannot be used with other attributes " + + "when using 'ref' attribute", pc.extractSource(element)); + } + NodeList providerChildren = providerElement.getChildNodes(); + for (int i = 0; i < providerChildren.getLength(); i++) { + if (providerChildren.item(i) instanceof Element) { + pc.getReaderContext().error("authentication-provider element cannot have child elements when used " + + "with 'ref' attribute", pc.extractSource(element)); + } + } + return new RuntimeBeanReference(ref); + } + /** * Provider which doesn't provide any service. Only used to prevent a configuration * exception if the provider list is empty (usually because a child ProviderManager * from the <http> namespace, such as OpenID, is expected to handle the * request). */ - public static final class NullAuthenticationProvider implements - AuthenticationProvider { - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + public static final class NullAuthenticationProvider implements AuthenticationProvider { + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { return null; } + @Override public boolean supports(Class authentication) { return false; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerFactoryBean.java b/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerFactoryBean.java index 2340bd4fc9..ce199e8ea3 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/authentication/AuthenticationManagerFactoryBean.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; +import java.util.Arrays; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; @@ -28,8 +31,6 @@ import org.springframework.security.config.BeanIds; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.crypto.password.PasswordEncoder; -import java.util.Arrays; - /** * Factory bean for the namespace AuthenticationManager, which allows a more meaningful * error message to be reported in the NoSuchBeanDefinitionException, if the user @@ -38,28 +39,27 @@ import java.util.Arrays; * @author Luke Taylor * @since 3.0 */ -public class AuthenticationManagerFactoryBean implements - FactoryBean, BeanFactoryAware { +public class AuthenticationManagerFactoryBean implements FactoryBean, BeanFactoryAware { + private BeanFactory bf; + public static final String MISSING_BEAN_ERROR_MESSAGE = "Did you forget to add a global element " + "to your configuration (with child elements)? Alternatively you can use the " + "authentication-manager-ref attribute on your and elements."; + @Override public AuthenticationManager getObject() throws Exception { try { - return (AuthenticationManager) bf.getBean(BeanIds.AUTHENTICATION_MANAGER); + return (AuthenticationManager) this.bf.getBean(BeanIds.AUTHENTICATION_MANAGER); } - catch (NoSuchBeanDefinitionException e) { - if (!BeanIds.AUTHENTICATION_MANAGER.equals(e.getBeanName())) { - throw e; + catch (NoSuchBeanDefinitionException ex) { + if (!BeanIds.AUTHENTICATION_MANAGER.equals(ex.getBeanName())) { + throw ex; } - UserDetailsService uds = getBeanOrNull(UserDetailsService.class); if (uds == null) { - throw new NoSuchBeanDefinitionException(BeanIds.AUTHENTICATION_MANAGER, - MISSING_BEAN_ERROR_MESSAGE); + throw new NoSuchBeanDefinitionException(BeanIds.AUTHENTICATION_MANAGER, MISSING_BEAN_ERROR_MESSAGE); } - DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setUserDetailsService(uds); PasswordEncoder passwordEncoder = getBeanOrNull(PasswordEncoder.class); @@ -67,27 +67,32 @@ public class AuthenticationManagerFactoryBean implements provider.setPasswordEncoder(passwordEncoder); } provider.afterPropertiesSet(); - return new ProviderManager(Arrays. asList(provider)); + return new ProviderManager(Arrays.asList(provider)); } } + @Override public Class getObjectType() { return ProviderManager.class; } + @Override public boolean isSingleton() { return true; } + @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - bf = beanFactory; + this.bf = beanFactory; } private T getBeanOrNull(Class type) { try { return this.bf.getBean(type); - } catch (NoSuchBeanDefinitionException noUds) { + } + catch (NoSuchBeanDefinitionException noUds) { return null; } } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParser.java index 6b56bff9e5..973edde649 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParser.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -25,7 +28,6 @@ import org.springframework.security.authentication.dao.DaoAuthenticationProvider import org.springframework.security.config.Elements; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** * Wraps a UserDetailsService bean with a DaoAuthenticationProvider and registers the @@ -34,46 +36,35 @@ import org.w3c.dom.Element; * @author Luke Taylor */ public class AuthenticationProviderBeanDefinitionParser implements BeanDefinitionParser { + private static final String ATT_USER_DETAILS_REF = "user-service-ref"; + @Override public BeanDefinition parse(Element element, ParserContext pc) { RootBeanDefinition authProvider = new RootBeanDefinition(DaoAuthenticationProvider.class); authProvider.setSource(pc.extractSource(element)); - Element passwordEncoderElt = DomUtils.getChildElementByTagName(element, Elements.PASSWORD_ENCODER); - PasswordEncoderParser pep = new PasswordEncoderParser(passwordEncoderElt, pc); BeanMetadataElement passwordEncoder = pep.getPasswordEncoder(); if (passwordEncoder != null) { - authProvider.getPropertyValues() - .addPropertyValue("passwordEncoder", passwordEncoder); + authProvider.getPropertyValues().addPropertyValue("passwordEncoder", passwordEncoder); } - - Element userServiceElt = DomUtils.getChildElementByTagName(element, - Elements.USER_SERVICE); + Element userServiceElt = DomUtils.getChildElementByTagName(element, Elements.USER_SERVICE); if (userServiceElt == null) { - userServiceElt = DomUtils.getChildElementByTagName(element, - Elements.JDBC_USER_SERVICE); + userServiceElt = DomUtils.getChildElementByTagName(element, Elements.JDBC_USER_SERVICE); } if (userServiceElt == null) { - userServiceElt = DomUtils.getChildElementByTagName(element, - Elements.LDAP_USER_SERVICE); + userServiceElt = DomUtils.getChildElementByTagName(element, Elements.LDAP_USER_SERVICE); } - String ref = element.getAttribute(ATT_USER_DETAILS_REF); - if (StringUtils.hasText(ref)) { if (userServiceElt != null) { - pc.getReaderContext().error( - "The " + ATT_USER_DETAILS_REF - + " attribute cannot be used in combination with child" - + "elements '" + Elements.USER_SERVICE + "', '" - + Elements.JDBC_USER_SERVICE + "' or '" + pc.getReaderContext() + .error("The " + ATT_USER_DETAILS_REF + " attribute cannot be used in combination with child" + + "elements '" + Elements.USER_SERVICE + "', '" + Elements.JDBC_USER_SERVICE + "' or '" + Elements.LDAP_USER_SERVICE + "'", element); } - - authProvider.getPropertyValues().add("userDetailsService", - new RuntimeBeanReference(ref)); + authProvider.getPropertyValues().add("userDetailsService", new RuntimeBeanReference(ref)); } else { // Use the child elements to create the UserDetailsService @@ -83,17 +74,14 @@ public class AuthenticationProviderBeanDefinitionParser implements BeanDefinitio else { pc.getReaderContext().error("A user-service is required", element); } - // Pinch the cache-ref from the UserDetailService element, if set. - String cacheRef = userServiceElt - .getAttribute(AbstractUserDetailsServiceBeanDefinitionParser.CACHE_REF); + String cacheRef = userServiceElt.getAttribute(AbstractUserDetailsServiceBeanDefinitionParser.CACHE_REF); if (StringUtils.hasText(cacheRef)) { - authProvider.getPropertyValues().addPropertyValue("userCache", - new RuntimeBeanReference(cacheRef)); + authProvider.getPropertyValues().addPropertyValue("userCache", new RuntimeBeanReference(cacheRef)); } } - return authProvider; } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParser.java index 18e35d3286..d1e192a6e7 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParser.java @@ -13,64 +13,59 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; -import org.springframework.security.config.Elements; -import org.springframework.util.StringUtils; +import org.w3c.dom.Element; + import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.xml.ParserContext; - -import org.w3c.dom.Element; +import org.springframework.security.config.Elements; +import org.springframework.util.StringUtils; /** * @author Luke Taylor */ -public class JdbcUserServiceBeanDefinitionParser extends - AbstractUserDetailsServiceBeanDefinitionParser { +public class JdbcUserServiceBeanDefinitionParser extends AbstractUserDetailsServiceBeanDefinitionParser { + static final String ATT_DATA_SOURCE = "data-source-ref"; static final String ATT_USERS_BY_USERNAME_QUERY = "users-by-username-query"; static final String ATT_AUTHORITIES_BY_USERNAME_QUERY = "authorities-by-username-query"; static final String ATT_GROUP_AUTHORITIES_QUERY = "group-authorities-by-username-query"; static final String ATT_ROLE_PREFIX = "role-prefix"; + @Override protected String getBeanClassName(Element element) { return "org.springframework.security.provisioning.JdbcUserDetailsManager"; } - protected void doParse(Element element, ParserContext parserContext, - BeanDefinitionBuilder builder) { + @Override + protected void doParse(Element element, ParserContext parserContext, BeanDefinitionBuilder builder) { String dataSource = element.getAttribute(ATT_DATA_SOURCE); - if (dataSource != null) { builder.addPropertyReference("dataSource", dataSource); } else { - parserContext.getReaderContext().error( - ATT_DATA_SOURCE + " is required for " + Elements.JDBC_USER_SERVICE, + parserContext.getReaderContext().error(ATT_DATA_SOURCE + " is required for " + Elements.JDBC_USER_SERVICE, parserContext.extractSource(element)); } - String usersQuery = element.getAttribute(ATT_USERS_BY_USERNAME_QUERY); String authoritiesQuery = element.getAttribute(ATT_AUTHORITIES_BY_USERNAME_QUERY); String groupAuthoritiesQuery = element.getAttribute(ATT_GROUP_AUTHORITIES_QUERY); String rolePrefix = element.getAttribute(ATT_ROLE_PREFIX); - if (StringUtils.hasText(rolePrefix)) { builder.addPropertyValue("rolePrefix", rolePrefix); } - if (StringUtils.hasText(usersQuery)) { builder.addPropertyValue("usersByUsernameQuery", usersQuery); } - if (StringUtils.hasText(authoritiesQuery)) { builder.addPropertyValue("authoritiesByUsernameQuery", authoritiesQuery); } - if (StringUtils.hasText(groupAuthoritiesQuery)) { builder.addPropertyValue("enableGroups", Boolean.TRUE); - builder.addPropertyValue("groupAuthoritiesByUsernameQuery", - groupAuthoritiesQuery); + builder.addPropertyValue("groupAuthoritiesByUsernameQuery", groupAuthoritiesQuery); } } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/PasswordEncoderParser.java b/config/src/main/java/org/springframework/security/config/authentication/PasswordEncoderParser.java index 0e16f1f1c6..3fced6b18f 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/PasswordEncoderParser.java +++ b/config/src/main/java/org/springframework/security/config/authentication/PasswordEncoderParser.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; -import java.util.HashMap; + +import java.util.Collections; import java.util.Map; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -27,7 +29,6 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * Stateful parser for the <password-encoder> element. @@ -35,19 +36,17 @@ import org.w3c.dom.Element; * @author Luke Taylor */ public class PasswordEncoderParser { + static final String ATT_REF = "ref"; + public static final String ATT_HASH = "hash"; + static final String ATT_BASE_64 = "base64"; + static final String OPT_HASH_BCRYPT = "bcrypt"; - private static final Map> ENCODER_CLASSES; - - static { - ENCODER_CLASSES = new HashMap<>(); - ENCODER_CLASSES.put(OPT_HASH_BCRYPT, BCryptPasswordEncoder.class); - } - - private static final Log logger = LogFactory.getLog(PasswordEncoderParser.class); + private static final Map> ENCODER_CLASSES = Collections.singletonMap(OPT_HASH_BCRYPT, + BCryptPasswordEncoder.class); private BeanMetadataElement passwordEncoder; @@ -63,33 +62,26 @@ public class PasswordEncoderParser { return; } String hash = element.getAttribute(ATT_HASH); - boolean useBase64 = false; - - if (StringUtils.hasText(element.getAttribute(ATT_BASE_64))) { - useBase64 = Boolean.parseBoolean(element.getAttribute(ATT_BASE_64)); - } - + boolean useBase64 = StringUtils.hasText(element.getAttribute(ATT_BASE_64)) + && Boolean.parseBoolean(element.getAttribute(ATT_BASE_64)); String ref = element.getAttribute(ATT_REF); - if (StringUtils.hasText(ref)) { - passwordEncoder = new RuntimeBeanReference(ref); + this.passwordEncoder = new RuntimeBeanReference(ref); } else { - passwordEncoder = createPasswordEncoderBeanDefinition(hash, useBase64); - ((RootBeanDefinition) passwordEncoder).setSource(parserContext - .extractSource(element)); + this.passwordEncoder = createPasswordEncoderBeanDefinition(hash, useBase64); + ((RootBeanDefinition) this.passwordEncoder).setSource(parserContext.extractSource(element)); } } - public static BeanDefinition createPasswordEncoderBeanDefinition(String hash, - boolean useBase64) { + public static BeanDefinition createPasswordEncoderBeanDefinition(String hash, boolean useBase64) { Class beanClass = ENCODER_CLASSES.get(hash); - BeanDefinitionBuilder beanBldr = BeanDefinitionBuilder - .rootBeanDefinition(beanClass); + BeanDefinitionBuilder beanBldr = BeanDefinitionBuilder.rootBeanDefinition(beanClass); return beanBldr.getBeanDefinition(); } public BeanMetadataElement getPasswordEncoder() { - return passwordEncoder; + return this.passwordEncoder; } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParser.java index 9394b24460..eecd64bf28 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParser.java @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.util.List; +import org.w3c.dom.Element; + import org.springframework.beans.factory.BeanDefinitionStoreException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.PropertiesFactoryBean; @@ -32,14 +35,12 @@ import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** * @author Luke Taylor * @author Ben Alex */ -public class UserServiceBeanDefinitionParser extends - AbstractUserDetailsServiceBeanDefinitionParser { +public class UserServiceBeanDefinitionParser extends AbstractUserDetailsServiceBeanDefinitionParser { static final String ATT_PASSWORD = "password"; static final String ATT_NAME = "name"; @@ -51,57 +52,44 @@ public class UserServiceBeanDefinitionParser extends private SecureRandom random; + @Override protected String getBeanClassName(Element element) { return InMemoryUserDetailsManager.class.getName(); } + @Override @SuppressWarnings("unchecked") - protected void doParse(Element element, ParserContext parserContext, - BeanDefinitionBuilder builder) { + protected void doParse(Element element, ParserContext parserContext, BeanDefinitionBuilder builder) { String userProperties = element.getAttribute(ATT_PROPERTIES); List userElts = DomUtils.getChildElementsByTagName(element, ELT_USER); - if (StringUtils.hasText(userProperties)) { - if (!CollectionUtils.isEmpty(userElts)) { throw new BeanDefinitionStoreException( "Use of a properties file and user elements are mutually exclusive"); } - BeanDefinition bd = new RootBeanDefinition(PropertiesFactoryBean.class); bd.getPropertyValues().addPropertyValue("location", userProperties); builder.addConstructorArgValue(bd); - return; } - if (CollectionUtils.isEmpty(userElts)) { - throw new BeanDefinitionStoreException( - "You must supply user definitions, either with <" + ELT_USER - + "> child elements or a " + "properties file (using the '" - + ATT_PROPERTIES + "' attribute)"); + throw new BeanDefinitionStoreException("You must supply user definitions, either with <" + ELT_USER + + "> child elements or a " + "properties file (using the '" + ATT_PROPERTIES + "' attribute)"); } - ManagedList users = new ManagedList<>(); - for (Object elt : userElts) { Element userElt = (Element) elt; String userName = userElt.getAttribute(ATT_NAME); String password = userElt.getAttribute(ATT_PASSWORD); - if (!StringUtils.hasLength(password)) { password = generateRandomPassword(); } - boolean locked = "true".equals(userElt.getAttribute(ATT_LOCKED)); boolean disabled = "true".equals(userElt.getAttribute(ATT_DISABLED)); - BeanDefinitionBuilder authorities = BeanDefinitionBuilder - .rootBeanDefinition(AuthorityUtils.class); + BeanDefinitionBuilder authorities = BeanDefinitionBuilder.rootBeanDefinition(AuthorityUtils.class); authorities.addConstructorArgValue(userElt.getAttribute(ATT_AUTHORITIES)); authorities.setFactoryMethod("commaSeparatedStringToAuthorityList"); - - BeanDefinitionBuilder user = BeanDefinitionBuilder - .rootBeanDefinition(User.class); + BeanDefinitionBuilder user = BeanDefinitionBuilder.rootBeanDefinition(User.class); user.addConstructorArgValue(userName); user.addConstructorArgValue(password); user.addConstructorArgValue(!disabled); @@ -109,23 +97,22 @@ public class UserServiceBeanDefinitionParser extends user.addConstructorArgValue(true); user.addConstructorArgValue(!locked); user.addConstructorArgValue(authorities.getBeanDefinition()); - users.add(user.getBeanDefinition()); } - builder.addConstructorArgValue(users); } private String generateRandomPassword() { - if (random == null) { + if (this.random == null) { try { - random = SecureRandom.getInstance("SHA1PRNG"); + this.random = SecureRandom.getInstance("SHA1PRNG"); } - catch (NoSuchAlgorithmException e) { + catch (NoSuchAlgorithmException ex) { // Shouldn't happen... throw new RuntimeException("Failed find SHA1PRNG algorithm!"); } } - return Long.toString(random.nextLong()); + return Long.toString(this.random.nextLong()); } + } diff --git a/config/src/main/java/org/springframework/security/config/authentication/package-info.java b/config/src/main/java/org/springframework/security/config/authentication/package-info.java index 8c0422b51c..e28c505eb6 100644 --- a/config/src/main/java/org/springframework/security/config/authentication/package-info.java +++ b/config/src/main/java/org/springframework/security/config/authentication/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Parsing of <authentication-manager> and related elements. */ package org.springframework.security.config.authentication; - diff --git a/config/src/main/java/org/springframework/security/config/core/GrantedAuthorityDefaults.java b/config/src/main/java/org/springframework/security/config/core/GrantedAuthorityDefaults.java index 62d548262d..fe338054b9 100644 --- a/config/src/main/java/org/springframework/security/config/core/GrantedAuthorityDefaults.java +++ b/config/src/main/java/org/springframework/security/config/core/GrantedAuthorityDefaults.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.core; import org.springframework.security.core.GrantedAuthority; @@ -33,10 +34,10 @@ public final class GrantedAuthorityDefaults { /** * The default prefix used with role based authorization. Default is "ROLE_". - * * @return the default role prefix */ public String getRolePrefix() { return this.rolePrefix; } + } diff --git a/config/src/main/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBean.java b/config/src/main/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBean.java index 5f038883d6..34c3ab73f8 100644 --- a/config/src/main/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBean.java @@ -16,6 +16,8 @@ package org.springframework.security.config.core.userdetails; +import java.util.Collection; + import org.springframework.beans.factory.FactoryBean; import org.springframework.context.ResourceLoaderAware; import org.springframework.core.io.Resource; @@ -24,22 +26,22 @@ import org.springframework.security.core.userdetails.MapReactiveUserDetailsServi import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.util.InMemoryResource; -import java.util.Collection; - /** - * Constructs an {@link MapReactiveUserDetailsService} from a resource using {@link UserDetailsResourceFactoryBean}. + * Constructs an {@link MapReactiveUserDetailsService} from a resource using + * {@link UserDetailsResourceFactoryBean}. * * @author Rob Winch * @since 5.0 * @see UserDetailsResourceFactoryBean */ public class ReactiveUserDetailsServiceResourceFactoryBean - implements ResourceLoaderAware, FactoryBean { + implements ResourceLoaderAware, FactoryBean { + private UserDetailsResourceFactoryBean userDetails = new UserDetailsResourceFactoryBean(); @Override public MapReactiveUserDetailsService getObject() throws Exception { - Collection users = userDetails.getObject(); + Collection users = this.userDetails.getObject(); return new MapReactiveUserDetailsService(users); } @@ -50,21 +52,22 @@ public class ReactiveUserDetailsServiceResourceFactoryBean @Override public void setResourceLoader(ResourceLoader resourceLoader) { - userDetails.setResourceLoader(resourceLoader); + this.userDetails.setResourceLoader(resourceLoader); } /** - * Sets the location of a Resource that is a Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param resourceLocation the location of the properties file that contains the users (i.e. "classpath:users.properties") + * Sets the location of a Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. + * @param resourceLocation the location of the properties file that contains the users + * (i.e. "classpath:users.properties") */ public void setResourceLocation(String resourceLocation) { this.userDetails.setResourceLocation(resourceLocation); } /** - * Sets a Resource that is a Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. - * + * Sets a Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. * @param resource the Resource to use */ public void setResource(Resource resource) { @@ -72,10 +75,11 @@ public class ReactiveUserDetailsServiceResourceFactoryBean } /** - * Create a ReactiveUserDetailsServiceResourceFactoryBean with the location of a Resource that is a Properties file in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param resourceLocation the location of the properties file that contains the users (i.e. "classpath:users.properties") + * Create a ReactiveUserDetailsServiceResourceFactoryBean with the location of a + * Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. + * @param resourceLocation the location of the properties file that contains the users + * (i.e. "classpath:users.properties") * @return the ReactiveUserDetailsServiceResourceFactoryBean */ public static ReactiveUserDetailsServiceResourceFactoryBean fromResourceLocation(String resourceLocation) { @@ -85,10 +89,10 @@ public class ReactiveUserDetailsServiceResourceFactoryBean } /** - * Create a ReactiveUserDetailsServiceResourceFactoryBean with a Resource that is a Properties file in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param propertiesResource the Resource that is a properties file that contains the users + * Create a ReactiveUserDetailsServiceResourceFactoryBean with a Resource that is a + * Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. + * @param propertiesResource the Resource that is a properties file that contains the + * users * @return the ReactiveUserDetailsServiceResourceFactoryBean */ public static ReactiveUserDetailsServiceResourceFactoryBean fromResource(Resource propertiesResource) { @@ -100,8 +104,8 @@ public class ReactiveUserDetailsServiceResourceFactoryBean /** * Create a ReactiveUserDetailsServiceResourceFactoryBean with a String that is in the * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param users the users in the format defined in {@link UserDetailsResourceFactoryBean} + * @param users the users in the format defined in + * {@link UserDetailsResourceFactoryBean} * @return the ReactiveUserDetailsServiceResourceFactoryBean */ public static ReactiveUserDetailsServiceResourceFactoryBean fromString(String users) { @@ -109,4 +113,5 @@ public class ReactiveUserDetailsServiceResourceFactoryBean result.setResource(new InMemoryResource(users)); return result; } + } diff --git a/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsMapFactoryBean.java b/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsMapFactoryBean.java index 0048eb7e4e..7b36a5ee19 100644 --- a/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsMapFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsMapFactoryBean.java @@ -16,17 +16,19 @@ package org.springframework.security.config.core.userdetails; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + import org.springframework.beans.factory.FactoryBean; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.memory.UserAttribute; import org.springframework.security.core.userdetails.memory.UserAttributeEditor; import org.springframework.util.Assert; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Map; - /** * Creates a {@code Collection} from a @{code Map} in the format of *

    @@ -34,7 +36,8 @@ import java.util.Map; * username=password[,enabled|disabled],roles... * *

    - * The enabled and disabled properties are optional with enabled being the default. For example: + * The enabled and disabled properties are optional with enabled being the default. For + * example: *

    * * user=password,ROLE_USER @@ -46,6 +49,7 @@ import java.util.Map; * @since 5.0 */ public class UserDetailsMapFactoryBean implements FactoryBean> { + private final Map userProperties; public UserDetailsMapFactoryBean(Map userProperties) { @@ -56,28 +60,24 @@ public class UserDetailsMapFactoryBean implements FactoryBean getObject() { Collection users = new ArrayList<>(this.userProperties.size()); - UserAttributeEditor editor = new UserAttributeEditor(); - for (Map.Entry entry : this.userProperties.entrySet()) { - String name = entry.getKey(); - String property = entry.getValue(); + this.userProperties.forEach((name, property) -> { editor.setAsText(property); UserAttribute attr = (UserAttribute) editor.getValue(); - if (attr == null) { - throw new IllegalStateException("The entry with username '" + name - + "' and value '" + property + "' could not be converted to a UserDetails."); - } - UserDetails user = User.withUsername(name) - .password(attr.getPassword()) - .disabled(!attr.isEnabled()) - .authorities(attr.getAuthorities()) - .build(); - users.add(user); - } return users; + Assert.state(attr != null, () -> "The entry with username '" + name + "' and value '" + property + + "' could not be converted to a UserDetails."); + String password = attr.getPassword(); + boolean disabled = !attr.isEnabled(); + List authorities = attr.getAuthorities(); + users.add(User.withUsername(name).password(password).disabled(disabled).authorities(authorities).build()); + }); + return users; + } @Override public Class getObjectType() { return Collection.class; } + } diff --git a/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBean.java b/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBean.java index 850a72c10b..ed74783d5c 100644 --- a/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBean.java @@ -16,6 +16,11 @@ package org.springframework.security.config.core.userdetails; +import java.io.InputStream; +import java.util.Collection; +import java.util.Map; +import java.util.Properties; + import org.springframework.beans.factory.FactoryBean; import org.springframework.context.ResourceLoaderAware; import org.springframework.core.io.DefaultResourceLoader; @@ -25,11 +30,6 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.util.InMemoryResource; import org.springframework.util.Assert; -import java.io.InputStream; -import java.util.Collection; -import java.util.Map; -import java.util.Properties; - /** * Parses a Resource that is a Properties file in the format of: * @@ -37,7 +37,8 @@ import java.util.Properties; * username=password[,enabled|disabled],roles... * * - * The enabled and disabled properties are optional with enabled being the default. For example: + * The enabled and disabled properties are optional with enabled being the default. For + * example: * * * user=password,ROLE_USER @@ -49,6 +50,7 @@ import java.util.Properties; * @since 5.0 */ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, FactoryBean> { + private ResourceLoader resourceLoader = new DefaultResourceLoader(); private String resourceLocation; @@ -65,7 +67,7 @@ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, Fact public Collection getObject() throws Exception { Properties userProperties = new Properties(); Resource resource = getPropertiesResource(); - try(InputStream in = resource.getInputStream()){ + try (InputStream in = resource.getInputStream()) { userProperties.load(in); } return new UserDetailsMapFactoryBean((Map) userProperties).getObject(); @@ -77,17 +79,18 @@ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, Fact } /** - * Sets the location of a Resource that is a Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param resourceLocation the location of the properties file that contains the users (i.e. "classpath:users.properties") + * Sets the location of a Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. + * @param resourceLocation the location of the properties file that contains the users + * (i.e. "classpath:users.properties") */ public void setResourceLocation(String resourceLocation) { this.resourceLocation = resourceLocation; } /** - * Sets a Resource that is a Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. - * + * Sets a Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. * @param resource the Resource to use */ public void setResource(Resource resource) { @@ -95,19 +98,19 @@ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, Fact } private Resource getPropertiesResource() { - Resource result = resource; - if (result == null && resourceLocation != null) { - result = resourceLoader.getResource(resourceLocation); + Resource result = this.resource; + if (result == null && this.resourceLocation != null) { + result = this.resourceLoader.getResource(this.resourceLocation); } Assert.notNull(result, "resource cannot be null if resourceLocation is null"); return result; } /** - * Create a UserDetailsResourceFactoryBean with the location of a Resource that is a Properties file in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param resourceLocation the location of the properties file that contains the users (i.e. "classpath:users.properties") + * Create a UserDetailsResourceFactoryBean with the location of a Resource that is a + * Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. + * @param resourceLocation the location of the properties file that contains the users + * (i.e. "classpath:users.properties") * @return the UserDetailsResourceFactoryBean */ public static UserDetailsResourceFactoryBean fromResourceLocation(String resourceLocation) { @@ -117,10 +120,10 @@ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, Fact } /** - * Create a UserDetailsResourceFactoryBean with a Resource that is a Properties file in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param propertiesResource the Resource that is a properties file that contains the users + * Create a UserDetailsResourceFactoryBean with a Resource that is a Properties file + * in the format defined in {@link UserDetailsResourceFactoryBean}. + * @param propertiesResource the Resource that is a properties file that contains the + * users * @return the UserDetailsResourceFactoryBean */ public static UserDetailsResourceFactoryBean fromResource(Resource propertiesResource) { @@ -131,7 +134,6 @@ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, Fact /** * Creates a UserDetailsResourceFactoryBean with a resource from the provided String - * * @param users the string representing the users * @return the UserDetailsResourceFactoryBean */ @@ -139,4 +141,5 @@ public class UserDetailsResourceFactoryBean implements ResourceLoaderAware, Fact InMemoryResource resource = new InMemoryResource(users); return fromResource(resource); } + } diff --git a/config/src/main/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessor.java b/config/src/main/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessor.java index b3cd7df87c..5174a62bc5 100644 --- a/config/src/main/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessor.java @@ -16,6 +16,7 @@ package org.springframework.security.config.crypto; +import java.beans.PropertyEditor; import java.beans.PropertyEditorSupport; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -39,12 +40,14 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; /** - * Adds {@link RsaKeyConverters} to the configured {@link ConversionService} or {@link PropertyEditor}s + * Adds {@link RsaKeyConverters} to the configured {@link ConversionService} or + * {@link PropertyEditor}s * * @author Josh Cummings * @since 5.2 */ public class RsaKeyConversionServicePostProcessor implements BeanFactoryPostProcessor { + private static final String CONVERSION_SERVICE_BEAN_NAME = "conversionService"; private ResourceLoader resourceLoader = new DefaultResourceLoader(); @@ -54,25 +57,21 @@ public class RsaKeyConversionServicePostProcessor implements BeanFactoryPostProc this.resourceLoader = resourceLoader; } - /** - * {@inheritDoc} - */ @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { if (hasUserDefinedConversionService(beanFactory)) { return; } - Converter pkcs8 = pkcs8(); Converter x509 = x509(); - ConversionService service = beanFactory.getConversionService(); if (service instanceof ConverterRegistry) { ConverterRegistry registry = (ConverterRegistry) service; registry.addConverter(String.class, RSAPrivateKey.class, pkcs8); registry.addConverter(String.class, RSAPublicKey.class, x509); - } else { - beanFactory.addPropertyEditorRegistrar(registry -> { + } + else { + beanFactory.addPropertyEditorRegistrar((registry) -> { registry.registerCustomEditor(RSAPublicKey.class, new ConverterPropertyEditorAdapter<>(x509)); registry.registerCustomEditor(RSAPrivateKey.class, new ConverterPropertyEditorAdapter<>(pkcs8)); }); @@ -80,8 +79,8 @@ public class RsaKeyConversionServicePostProcessor implements BeanFactoryPostProc } private boolean hasUserDefinedConversionService(ConfigurableListableBeanFactory beanFactory) { - return beanFactory.containsBean(CONVERSION_SERVICE_BEAN_NAME) && - beanFactory.isTypeMatch(CONVERSION_SERVICE_BEAN_NAME, ConversionService.class); + return beanFactory.containsBean(CONVERSION_SERVICE_BEAN_NAME) + && beanFactory.isTypeMatch(CONVERSION_SERVICE_BEAN_NAME, ConversionService.class); } private Converter pkcs8() { @@ -97,8 +96,8 @@ public class RsaKeyConversionServicePostProcessor implements BeanFactoryPostProc } private Converter pemInputStreamConverter() { - return source -> source.startsWith("-----") ? - toInputStream(source) : toInputStream(this.resourceLoader.getResource(source)); + return (source) -> source.startsWith("-----") ? toInputStream(source) + : toInputStream(this.resourceLoader.getResource(source)); } private InputStream toInputStream(String raw) { @@ -108,29 +107,32 @@ public class RsaKeyConversionServicePostProcessor implements BeanFactoryPostProc private InputStream toInputStream(Resource resource) { try { return resource.getInputStream(); - } catch (IOException e) { - throw new UncheckedIOException(e); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); } } private Converter autoclose(Converter inputStreamKeyConverter) { - return inputStream -> { + return (inputStream) -> { try (InputStream is = inputStream) { return inputStreamKeyConverter.convert(is); - } catch (IOException e) { - throw new UncheckedIOException(e); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); } }; } private Converter pair(Converter one, Converter two) { - return source -> { + return (source) -> { I intermediary = one.convert(source); return two.convert(intermediary); }; } private static class ConverterPropertyEditorAdapter extends PropertyEditorSupport { + private final Converter converter; ConverterPropertyEditorAdapter(Converter converter) { @@ -146,9 +148,12 @@ public class RsaKeyConversionServicePostProcessor implements BeanFactoryPostProc public void setAsText(String text) throws IllegalArgumentException { if (StringUtils.hasText(text)) { setValue(this.converter.convert(text)); - } else { + } + else { setValue(null); } } + } + } diff --git a/config/src/main/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessor.java b/config/src/main/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessor.java index ab365c3f00..b7fe459ff3 100644 --- a/config/src/main/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessor.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.debug; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -30,14 +32,13 @@ import org.springframework.security.web.debug.DebugFilter; * @author Luke Taylor * @author Rob Winch */ -public class SecurityDebugBeanFactoryPostProcessor implements - BeanDefinitionRegistryPostProcessor { +public class SecurityDebugBeanFactoryPostProcessor implements BeanDefinitionRegistryPostProcessor { + private final Log logger = LogFactory.getLog(getClass()); - public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) - throws BeansException { - logger.warn("\n\n" - + "********************************************************************\n" + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + this.logger.warn("\n\n" + "********************************************************************\n" + "********** Security debugging is enabled. *************\n" + "********** This may include sensitive information. *************\n" + "********** Do not use in a production system! *************\n" @@ -45,21 +46,19 @@ public class SecurityDebugBeanFactoryPostProcessor implements // SPRING_SECURITY_FILTER_CHAIN does not exist yet since it is an alias that has // not been processed, so use FILTER_CHAIN_PROXY if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) { - BeanDefinition fcpBeanDef = registry - .getBeanDefinition(BeanIds.FILTER_CHAIN_PROXY); - BeanDefinitionBuilder debugFilterBldr = BeanDefinitionBuilder - .genericBeanDefinition(DebugFilter.class); + BeanDefinition fcpBeanDef = registry.getBeanDefinition(BeanIds.FILTER_CHAIN_PROXY); + BeanDefinitionBuilder debugFilterBldr = BeanDefinitionBuilder.genericBeanDefinition(DebugFilter.class); debugFilterBldr.addConstructorArgValue(fcpBeanDef); // Remove the alias to SPRING_SECURITY_FILTER_CHAIN, so that it does not // override the new // SPRING_SECURITY_FILTER_CHAIN definition registry.removeAlias(BeanIds.SPRING_SECURITY_FILTER_CHAIN); - registry.registerBeanDefinition(BeanIds.SPRING_SECURITY_FILTER_CHAIN, - debugFilterBldr.getBeanDefinition()); + registry.registerBeanDefinition(BeanIds.SPRING_SECURITY_FILTER_CHAIN, debugFilterBldr.getBeanDefinition()); } } - public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) - throws BeansException { + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { } -} \ No newline at end of file + +} diff --git a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java index 3e82001c40..d3c0ce32f4 100644 --- a/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java @@ -13,10 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import javax.servlet.http.HttpServletRequest; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; @@ -56,31 +68,6 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; - -import javax.servlet.http.HttpServletRequest; -import java.security.SecureRandom; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Function; - -import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER; -import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER; -import static org.springframework.security.config.http.SecurityFilters.BEARER_TOKEN_AUTH_FILTER; -import static org.springframework.security.config.http.SecurityFilters.EXCEPTION_TRANSLATION_FILTER; -import static org.springframework.security.config.http.SecurityFilters.FORM_LOGIN_FILTER; -import static org.springframework.security.config.http.SecurityFilters.LOGIN_PAGE_FILTER; -import static org.springframework.security.config.http.SecurityFilters.LOGOUT_FILTER; -import static org.springframework.security.config.http.SecurityFilters.LOGOUT_PAGE_FILTER; -import static org.springframework.security.config.http.SecurityFilters.OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER; -import static org.springframework.security.config.http.SecurityFilters.OAUTH2_AUTHORIZATION_REQUEST_FILTER; -import static org.springframework.security.config.http.SecurityFilters.OAUTH2_LOGIN_FILTER; -import static org.springframework.security.config.http.SecurityFilters.OPENID_FILTER; -import static org.springframework.security.config.http.SecurityFilters.PRE_AUTH_FILTER; -import static org.springframework.security.config.http.SecurityFilters.REMEMBER_ME_FILTER; -import static org.springframework.security.config.http.SecurityFilters.X509_FILTER; /** * Handles creation of authentication mechanism filters and related beans for <http> @@ -91,16 +78,23 @@ import static org.springframework.security.config.http.SecurityFilters.X509_FILT * @since 3.0 */ final class AuthenticationConfigBuilder { + private final Log logger = LogFactory.getLog(getClass()); private static final String ATT_REALM = "realm"; + private static final String DEF_REALM = "Realm"; static final String OPEN_ID_AUTHENTICATION_PROCESSING_FILTER_CLASS = "org.springframework.security.openid.OpenIDAuthenticationFilter"; + static final String OPEN_ID_AUTHENTICATION_PROVIDER_CLASS = "org.springframework.security.openid.OpenIDAuthenticationProvider"; + private static final String OPEN_ID_CONSUMER_CLASS = "org.springframework.security.openid.OpenID4JavaConsumer"; + static final String OPEN_ID_ATTRIBUTE_CLASS = "org.springframework.security.openid.OpenIDAttribute"; + private static final String OPEN_ID_ATTRIBUTE_FACTORY_CLASS = "org.springframework.security.openid.RegexBasedAxFetchListFactory"; + static final String AUTHENTICATION_PROCESSING_FILTER_CLASS = "org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter"; static final String ATT_AUTH_DETAILS_SOURCE_REF = "authentication-details-source-ref"; @@ -108,52 +102,84 @@ final class AuthenticationConfigBuilder { private static final String ATT_AUTO_CONFIG = "auto-config"; private static final String ATT_ACCESS_DENIED_ERROR_PAGE = "error-page"; + private static final String ATT_ENTRY_POINT_REF = "entry-point-ref"; private static final String ATT_USER_SERVICE_REF = "user-service-ref"; private static final String ATT_KEY = "key"; + private static final String ATT_MAPPABLE_ROLES = "mappable-roles"; + private final Element httpElt; + private final ParserContext pc; private final boolean autoConfig; + private final boolean allowSessionCreation; private RootBeanDefinition anonymousFilter; + private BeanReference anonymousProviderRef; + private BeanDefinition rememberMeFilter; + private String rememberMeServicesId; + private BeanReference rememberMeProviderRef; + private BeanDefinition basicFilter; + private RuntimeBeanReference basicEntryPoint; + private BeanDefinition formEntryPoint; + private BeanDefinition openIDEntryPoint; + private BeanReference openIDProviderRef; + private String formFilterId = null; + private String openIDFilterId = null; + private BeanDefinition x509Filter; + private BeanReference x509ProviderRef; + private BeanDefinition jeeFilter; + private BeanReference jeeProviderRef; + private RootBeanDefinition preAuthEntryPoint; + private BeanMetadataElement mainEntryPoint; + private BeanMetadataElement accessDeniedHandler; private BeanDefinition bearerTokenAuthenticationFilter; private BeanDefinition logoutFilter; + @SuppressWarnings("rawtypes") private ManagedList logoutHandlers; + private BeanDefinition loginPageGenerationFilter; + private BeanDefinition logoutPageGenerationFilter; + private BeanDefinition etf; + private final BeanReference requestCache; + private final BeanReference portMapper; + private final BeanReference portResolver; + private final BeanMetadataElement csrfLogoutHandler; private String loginProcessingUrl; + private String openidLoginProcessingUrl; private String formLoginPage; @@ -161,40 +187,50 @@ final class AuthenticationConfigBuilder { private String openIDLoginPage; private boolean oauth2LoginEnabled; + private boolean defaultAuthorizedClientRepositoryRegistered; + private String oauth2LoginFilterId; + private BeanDefinition oauth2AuthorizationRequestRedirectFilter; + private BeanDefinition oauth2LoginEntryPoint; + private BeanReference oauth2LoginAuthenticationProviderRef; + private BeanReference oauth2LoginOidcAuthenticationProviderRef; + private BeanDefinition oauth2LoginLinks; private boolean oauth2ClientEnabled; + private BeanDefinition authorizationRequestRedirectFilter; + private BeanDefinition authorizationCodeGrantFilter; + private BeanReference authorizationCodeAuthenticationProviderRef; private final List authenticationProviders = new ManagedList<>(); + private final Map defaultDeniedHandlerMappings = new ManagedMap<>(); + private final Map defaultEntryPointMappings = new ManagedMap<>(); + private final List csrfIgnoreRequestMatchers = new ManagedList<>(); - AuthenticationConfigBuilder(Element element, boolean forceAutoConfig, - ParserContext pc, SessionCreationPolicy sessionPolicy, - BeanReference requestCache, BeanReference authenticationManager, - BeanReference sessionStrategy, BeanReference portMapper, - BeanReference portResolver, BeanMetadataElement csrfLogoutHandler) { + AuthenticationConfigBuilder(Element element, boolean forceAutoConfig, ParserContext pc, + SessionCreationPolicy sessionPolicy, BeanReference requestCache, BeanReference authenticationManager, + BeanReference sessionStrategy, BeanReference portMapper, BeanReference portResolver, + BeanMetadataElement csrfLogoutHandler) { this.httpElt = element; this.pc = pc; this.requestCache = requestCache; - autoConfig = forceAutoConfig - | "true".equals(element.getAttribute(ATT_AUTO_CONFIG)); + this.autoConfig = forceAutoConfig | "true".equals(element.getAttribute(ATT_AUTO_CONFIG)); this.allowSessionCreation = sessionPolicy != SessionCreationPolicy.NEVER && sessionPolicy != SessionCreationPolicy.STATELESS; this.portMapper = portMapper; this.portResolver = portResolver; this.csrfLogoutHandler = csrfLogoutHandler; - createAnonymousFilter(); createRememberMeFilter(authenticationManager); createBasicFilter(authenticationManager); @@ -211,69 +247,51 @@ final class AuthenticationConfigBuilder { } void createRememberMeFilter(BeanReference authenticationManager) { - // Parse remember me before logout as RememberMeServices is also a LogoutHandler // implementation. - Element rememberMeElt = DomUtils.getChildElementByTagName(httpElt, - Elements.REMEMBER_ME); - + Element rememberMeElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.REMEMBER_ME); if (rememberMeElt != null) { String key = rememberMeElt.getAttribute(ATT_KEY); - if (!StringUtils.hasText(key)) { key = createKey(); } - - RememberMeBeanDefinitionParser rememberMeParser = new RememberMeBeanDefinitionParser( - key, authenticationManager); - rememberMeFilter = rememberMeParser.parse(rememberMeElt, pc); - rememberMeServicesId = rememberMeParser.getRememberMeServicesId(); + RememberMeBeanDefinitionParser rememberMeParser = new RememberMeBeanDefinitionParser(key, + authenticationManager); + this.rememberMeFilter = rememberMeParser.parse(rememberMeElt, this.pc); + this.rememberMeServicesId = rememberMeParser.getRememberMeServicesId(); createRememberMeProvider(key); } } private void createRememberMeProvider(String key) { - RootBeanDefinition provider = new RootBeanDefinition( - RememberMeAuthenticationProvider.class); - provider.setSource(rememberMeFilter.getSource()); - + RootBeanDefinition provider = new RootBeanDefinition(RememberMeAuthenticationProvider.class); + provider.setSource(this.rememberMeFilter.getSource()); provider.getConstructorArgumentValues().addGenericArgumentValue(key); - - String id = pc.getReaderContext().generateBeanName(provider); - pc.registerBeanComponent(new BeanComponentDefinition(provider, id)); - - rememberMeProviderRef = new RuntimeBeanReference(id); + String id = this.pc.getReaderContext().generateBeanName(provider); + this.pc.registerBeanComponent(new BeanComponentDefinition(provider, id)); + this.rememberMeProviderRef = new RuntimeBeanReference(id); } void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authManager) { - - Element formLoginElt = DomUtils.getChildElementByTagName(httpElt, - Elements.FORM_LOGIN); + Element formLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.FORM_LOGIN); RootBeanDefinition formFilter = null; - - if (formLoginElt != null || autoConfig) { - FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser( - "/login", "POST", AUTHENTICATION_PROCESSING_FILTER_CLASS, - requestCache, sessionStrategy, allowSessionCreation, portMapper, - portResolver); - - parser.parse(formLoginElt, pc); + if (formLoginElt != null || this.autoConfig) { + FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser("/login", "POST", + AUTHENTICATION_PROCESSING_FILTER_CLASS, this.requestCache, sessionStrategy, + this.allowSessionCreation, this.portMapper, this.portResolver); + parser.parse(formLoginElt, this.pc); formFilter = parser.getFilterBean(); - formEntryPoint = parser.getEntryPointBean(); - loginProcessingUrl = parser.getLoginProcessingUrl(); - formLoginPage = parser.getLoginPage(); + this.formEntryPoint = parser.getEntryPointBean(); + this.loginProcessingUrl = parser.getLoginProcessingUrl(); + this.formLoginPage = parser.getLoginPage(); } - if (formFilter != null) { - formFilter.getPropertyValues().addPropertyValue("allowSessionCreation", - allowSessionCreation); - formFilter.getPropertyValues().addPropertyValue("authenticationManager", - authManager); - + formFilter.getPropertyValues().addPropertyValue("allowSessionCreation", this.allowSessionCreation); + formFilter.getPropertyValues().addPropertyValue("authenticationManager", authManager); // Id is required by login page filter - formFilterId = pc.getReaderContext().generateBeanName(formFilter); - pc.registerBeanComponent(new BeanComponentDefinition(formFilter, formFilterId)); - injectRememberMeServicesRef(formFilter, rememberMeServicesId); + this.formFilterId = this.pc.getReaderContext().generateBeanName(formFilter); + this.pc.registerBeanComponent(new BeanComponentDefinition(formFilter, this.formFilterId)); + injectRememberMeServicesRef(formFilter, this.rememberMeServicesId); } } @@ -290,43 +308,39 @@ final class AuthenticationConfigBuilder { return; } this.oauth2LoginEnabled = true; - - OAuth2LoginBeanDefinitionParser parser = new OAuth2LoginBeanDefinitionParser(requestCache, portMapper, - portResolver, sessionStrategy, allowSessionCreation); + OAuth2LoginBeanDefinitionParser parser = new OAuth2LoginBeanDefinitionParser(this.requestCache, this.portMapper, + this.portResolver, sessionStrategy, this.allowSessionCreation); BeanDefinition oauth2LoginFilterBean = parser.parse(oauth2LoginElt, this.pc); - BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository(); registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository); - oauth2LoginFilterBean.getPropertyValues().addPropertyValue("authenticationManager", authManager); // retrieve the other bean result BeanDefinition oauth2LoginAuthProvider = parser.getOAuth2LoginAuthenticationProvider(); - oauth2AuthorizationRequestRedirectFilter = parser.getOAuth2AuthorizationRequestRedirectFilter(); - oauth2LoginEntryPoint = parser.getOAuth2LoginAuthenticationEntryPoint(); + this.oauth2AuthorizationRequestRedirectFilter = parser.getOAuth2AuthorizationRequestRedirectFilter(); + this.oauth2LoginEntryPoint = parser.getOAuth2LoginAuthenticationEntryPoint(); // generate bean name to be registered - String oauth2LoginAuthProviderId = pc.getReaderContext() - .generateBeanName(oauth2LoginAuthProvider); - oauth2LoginFilterId = pc.getReaderContext().generateBeanName(oauth2LoginFilterBean); - String oauth2AuthorizationRequestRedirectFilterId = pc.getReaderContext() - .generateBeanName(oauth2AuthorizationRequestRedirectFilter); - oauth2LoginLinks = parser.getOAuth2LoginLinks(); + String oauth2LoginAuthProviderId = this.pc.getReaderContext().generateBeanName(oauth2LoginAuthProvider); + this.oauth2LoginFilterId = this.pc.getReaderContext().generateBeanName(oauth2LoginFilterBean); + String oauth2AuthorizationRequestRedirectFilterId = this.pc.getReaderContext() + .generateBeanName(this.oauth2AuthorizationRequestRedirectFilter); + this.oauth2LoginLinks = parser.getOAuth2LoginLinks(); // register the component - pc.registerBeanComponent(new BeanComponentDefinition(oauth2LoginFilterBean, oauth2LoginFilterId)); - pc.registerBeanComponent(new BeanComponentDefinition( - oauth2AuthorizationRequestRedirectFilter, oauth2AuthorizationRequestRedirectFilterId)); - pc.registerBeanComponent(new BeanComponentDefinition(oauth2LoginAuthProvider, oauth2LoginAuthProviderId)); + this.pc.registerBeanComponent(new BeanComponentDefinition(oauth2LoginFilterBean, this.oauth2LoginFilterId)); + this.pc.registerBeanComponent(new BeanComponentDefinition(this.oauth2AuthorizationRequestRedirectFilter, + oauth2AuthorizationRequestRedirectFilterId)); + this.pc.registerBeanComponent(new BeanComponentDefinition(oauth2LoginAuthProvider, oauth2LoginAuthProviderId)); - oauth2LoginAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginAuthProviderId); + this.oauth2LoginAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginAuthProviderId); // oidc provider BeanDefinition oauth2LoginOidcAuthProvider = parser.getOAuth2LoginOidcAuthenticationProvider(); - String oauth2LoginOidcAuthProviderId = pc.getReaderContext().generateBeanName(oauth2LoginOidcAuthProvider); - pc.registerBeanComponent(new BeanComponentDefinition( - oauth2LoginOidcAuthProvider, oauth2LoginOidcAuthProviderId)); - oauth2LoginOidcAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginOidcAuthProviderId); + String oauth2LoginOidcAuthProviderId = this.pc.getReaderContext().generateBeanName(oauth2LoginOidcAuthProvider); + this.pc.registerBeanComponent( + new BeanComponentDefinition(oauth2LoginOidcAuthProvider, oauth2LoginOidcAuthProviderId)); + this.oauth2LoginOidcAuthenticationProviderRef = new RuntimeBeanReference(oauth2LoginOidcAuthProviderId); } void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenticationManager) { @@ -335,40 +349,36 @@ final class AuthenticationConfigBuilder { return; } this.oauth2ClientEnabled = true; - - OAuth2ClientBeanDefinitionParser parser = new OAuth2ClientBeanDefinitionParser( - requestCache, authenticationManager); + OAuth2ClientBeanDefinitionParser parser = new OAuth2ClientBeanDefinitionParser(requestCache, + authenticationManager); parser.parse(oauth2ClientElt, this.pc); - BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository(); registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository); - this.authorizationRequestRedirectFilter = parser.getAuthorizationRequestRedirectFilter(); - String authorizationRequestRedirectFilterId = pc.getReaderContext() + String authorizationRequestRedirectFilterId = this.pc.getReaderContext() .generateBeanName(this.authorizationRequestRedirectFilter); - this.pc.registerBeanComponent(new BeanComponentDefinition( - this.authorizationRequestRedirectFilter, authorizationRequestRedirectFilterId)); - + this.pc.registerBeanComponent(new BeanComponentDefinition(this.authorizationRequestRedirectFilter, + authorizationRequestRedirectFilterId)); this.authorizationCodeGrantFilter = parser.getAuthorizationCodeGrantFilter(); - String authorizationCodeGrantFilterId = pc.getReaderContext() + String authorizationCodeGrantFilterId = this.pc.getReaderContext() .generateBeanName(this.authorizationCodeGrantFilter); - this.pc.registerBeanComponent(new BeanComponentDefinition( - this.authorizationCodeGrantFilter, authorizationCodeGrantFilterId)); - + this.pc.registerBeanComponent( + new BeanComponentDefinition(this.authorizationCodeGrantFilter, authorizationCodeGrantFilterId)); BeanDefinition authorizationCodeAuthenticationProvider = parser.getAuthorizationCodeAuthenticationProvider(); - String authorizationCodeAuthenticationProviderId = pc.getReaderContext() + String authorizationCodeAuthenticationProviderId = this.pc.getReaderContext() .generateBeanName(authorizationCodeAuthenticationProvider); - this.pc.registerBeanComponent(new BeanComponentDefinition( - authorizationCodeAuthenticationProvider, authorizationCodeAuthenticationProviderId)); - this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference(authorizationCodeAuthenticationProviderId); + this.pc.registerBeanComponent(new BeanComponentDefinition(authorizationCodeAuthenticationProvider, + authorizationCodeAuthenticationProviderId)); + this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference( + authorizationCodeAuthenticationProviderId); } void registerDefaultAuthorizedClientRepositoryIfNecessary(BeanDefinition defaultAuthorizedClientRepository) { if (!this.defaultAuthorizedClientRepositoryRegistered && defaultAuthorizedClientRepository != null) { - String authorizedClientRepositoryId = pc.getReaderContext() + String authorizedClientRepositoryId = this.pc.getReaderContext() .generateBeanName(defaultAuthorizedClientRepository); - this.pc.registerBeanComponent(new BeanComponentDefinition( - defaultAuthorizedClientRepository, authorizedClientRepositoryId)); + this.pc.registerBeanComponent( + new BeanComponentDefinition(defaultAuthorizedClientRepository, authorizedClientRepositoryId)); this.defaultAuthorizedClientRepositoryRegistered = true; } } @@ -377,143 +387,115 @@ final class AuthenticationConfigBuilder { if (!this.oauth2LoginEnabled && !this.oauth2ClientEnabled) { return; } - - boolean webmvcPresent = ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader()); + boolean webmvcPresent = ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", + getClass().getClassLoader()); if (webmvcPresent) { - this.pc.getReaderContext().registerWithGeneratedName( - new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class)); + this.pc.getReaderContext() + .registerWithGeneratedName(new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class)); } } void createOpenIDLoginFilter(BeanReference sessionStrategy, BeanReference authManager) { - Element openIDLoginElt = DomUtils.getChildElementByTagName(httpElt, - Elements.OPENID_LOGIN); + Element openIDLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OPENID_LOGIN); RootBeanDefinition openIDFilter = null; - if (openIDLoginElt != null) { openIDFilter = parseOpenIDFilter(sessionStrategy, openIDLoginElt); } - if (openIDFilter != null) { - openIDFilter.getPropertyValues().addPropertyValue("allowSessionCreation", - allowSessionCreation); - openIDFilter.getPropertyValues().addPropertyValue("authenticationManager", - authManager); + openIDFilter.getPropertyValues().addPropertyValue("allowSessionCreation", this.allowSessionCreation); + openIDFilter.getPropertyValues().addPropertyValue("authenticationManager", authManager); // Required by login page filter - openIDFilterId = pc.getReaderContext().generateBeanName(openIDFilter); - pc.registerBeanComponent(new BeanComponentDefinition(openIDFilter, - openIDFilterId)); - injectRememberMeServicesRef(openIDFilter, rememberMeServicesId); - + this.openIDFilterId = this.pc.getReaderContext().generateBeanName(openIDFilter); + this.pc.registerBeanComponent(new BeanComponentDefinition(openIDFilter, this.openIDFilterId)); + injectRememberMeServicesRef(openIDFilter, this.rememberMeServicesId); createOpenIDProvider(); } } /** * Parses OpenID 1.0 and 2.0 - related parts of configuration xmls - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @param sessionStrategy sessionStrategy * @param openIDLoginElt the element from the xml file * @return the parsed filter as rootBeanDefinition + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ - private RootBeanDefinition parseOpenIDFilter( BeanReference sessionStrategy, Element openIDLoginElt ) { + @Deprecated + private RootBeanDefinition parseOpenIDFilter(BeanReference sessionStrategy, Element openIDLoginElt) { RootBeanDefinition openIDFilter; - FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser( - "/login/openid", null, - OPEN_ID_AUTHENTICATION_PROCESSING_FILTER_CLASS, requestCache, - sessionStrategy, allowSessionCreation, portMapper, portResolver); + FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser("/login/openid", null, + OPEN_ID_AUTHENTICATION_PROCESSING_FILTER_CLASS, this.requestCache, sessionStrategy, + this.allowSessionCreation, this.portMapper, this.portResolver); - parser.parse(openIDLoginElt, pc); + parser.parse(openIDLoginElt, this.pc); openIDFilter = parser.getFilterBean(); - openIDEntryPoint = parser.getEntryPointBean(); - openidLoginProcessingUrl = parser.getLoginProcessingUrl(); - openIDLoginPage = parser.getLoginPage(); - + this.openIDEntryPoint = parser.getEntryPointBean(); + this.openidLoginProcessingUrl = parser.getLoginProcessingUrl(); + this.openIDLoginPage = parser.getLoginPage(); List attrExElts = DomUtils.getChildElementsByTagName(openIDLoginElt, Elements.OPENID_ATTRIBUTE_EXCHANGE); - if (!attrExElts.isEmpty()) { // Set up the consumer with the required attribute list - BeanDefinitionBuilder consumerBldr = BeanDefinitionBuilder - .rootBeanDefinition(OPEN_ID_CONSUMER_CLASS); - BeanDefinitionBuilder axFactory = BeanDefinitionBuilder - .rootBeanDefinition(OPEN_ID_ATTRIBUTE_FACTORY_CLASS); + BeanDefinitionBuilder consumerBldr = BeanDefinitionBuilder.rootBeanDefinition(OPEN_ID_CONSUMER_CLASS); + BeanDefinitionBuilder axFactory = BeanDefinitionBuilder.rootBeanDefinition(OPEN_ID_ATTRIBUTE_FACTORY_CLASS); ManagedMap> axMap = new ManagedMap<>(); - for (Element attrExElt : attrExElts) { String identifierMatch = attrExElt.getAttribute("identifier-match"); - if (!StringUtils.hasText(identifierMatch)) { if (attrExElts.size() > 1) { - pc.getReaderContext().error( - "You must supply an identifier-match attribute if using more" - + " than one " - + Elements.OPENID_ATTRIBUTE_EXCHANGE - + " element", attrExElt); + this.pc.getReaderContext().error("You must supply an identifier-match attribute if using more" + + " than one " + Elements.OPENID_ATTRIBUTE_EXCHANGE + " element", attrExElt); } // Match anything identifierMatch = ".*"; } - axMap.put(identifierMatch, parseOpenIDAttributes(attrExElt)); } axFactory.addConstructorArgValue(axMap); - consumerBldr.addConstructorArgValue(axFactory.getBeanDefinition()); - openIDFilter.getPropertyValues().addPropertyValue("consumer", - consumerBldr.getBeanDefinition()); + openIDFilter.getPropertyValues().addPropertyValue("consumer", consumerBldr.getBeanDefinition()); } return openIDFilter; } private ManagedList parseOpenIDAttributes(Element attrExElt) { ManagedList attributes = new ManagedList<>(); - for (Element attElt : DomUtils.getChildElementsByTagName(attrExElt, - Elements.OPENID_ATTRIBUTE)) { + for (Element attElt : DomUtils.getChildElementsByTagName(attrExElt, Elements.OPENID_ATTRIBUTE)) { String name = attElt.getAttribute("name"); String type = attElt.getAttribute("type"); String required = attElt.getAttribute("required"); String count = attElt.getAttribute("count"); - BeanDefinitionBuilder attrBldr = BeanDefinitionBuilder - .rootBeanDefinition(OPEN_ID_ATTRIBUTE_CLASS); + BeanDefinitionBuilder attrBldr = BeanDefinitionBuilder.rootBeanDefinition(OPEN_ID_ATTRIBUTE_CLASS); attrBldr.addConstructorArgValue(name); attrBldr.addConstructorArgValue(type); if (StringUtils.hasLength(required)) { attrBldr.addPropertyValue("required", Boolean.valueOf(required)); } - if (StringUtils.hasLength(count)) { attrBldr.addPropertyValue("count", Integer.parseInt(count)); } attributes.add(attrBldr.getBeanDefinition()); } - return attributes; } private void createOpenIDProvider() { - Element openIDLoginElt = DomUtils.getChildElementByTagName(httpElt, - Elements.OPENID_LOGIN); + Element openIDLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OPENID_LOGIN); BeanDefinitionBuilder openIDProviderBuilder = BeanDefinitionBuilder .rootBeanDefinition(OPEN_ID_AUTHENTICATION_PROVIDER_CLASS); - RootBeanDefinition uds = new RootBeanDefinition(); uds.setFactoryBeanName(BeanIds.USER_DETAILS_SERVICE_FACTORY); uds.setFactoryMethodName("authenticationUserDetailsService"); - uds.getConstructorArgumentValues().addGenericArgumentValue( - openIDLoginElt.getAttribute(ATT_USER_SERVICE_REF)); - + uds.getConstructorArgumentValues().addGenericArgumentValue(openIDLoginElt.getAttribute(ATT_USER_SERVICE_REF)); openIDProviderBuilder.addPropertyValue("authenticationUserDetailsService", uds); - BeanDefinition openIDProvider = openIDProviderBuilder.getBeanDefinition(); - openIDProviderRef = new RuntimeBeanReference(pc.getReaderContext() - .registerWithGeneratedName(openIDProvider)); + this.openIDProviderRef = new RuntimeBeanReference( + this.pc.getReaderContext().registerWithGeneratedName(openIDProvider)); } - private void injectRememberMeServicesRef(RootBeanDefinition bean, - String rememberMeServicesId) { + private void injectRememberMeServicesRef(RootBeanDefinition bean, String rememberMeServicesId) { if (rememberMeServicesId != null) { bean.getPropertyValues().addPropertyValue("rememberMeServices", new RuntimeBeanReference(rememberMeServicesId)); @@ -521,207 +503,149 @@ final class AuthenticationConfigBuilder { } void createBasicFilter(BeanReference authManager) { - Element basicAuthElt = DomUtils.getChildElementByTagName(httpElt, - Elements.BASIC_AUTH); - - if (basicAuthElt == null && !autoConfig) { + Element basicAuthElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.BASIC_AUTH); + if (basicAuthElt == null && !this.autoConfig) { // No basic auth, do nothing return; } - - String realm = httpElt.getAttribute(ATT_REALM); + String realm = this.httpElt.getAttribute(ATT_REALM); if (!StringUtils.hasText(realm)) { realm = DEF_REALM; } - - BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder - .rootBeanDefinition(BasicAuthenticationFilter.class); - + BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder.rootBeanDefinition(BasicAuthenticationFilter.class); String entryPointId; - if (basicAuthElt != null) { if (StringUtils.hasText(basicAuthElt.getAttribute(ATT_ENTRY_POINT_REF))) { - basicEntryPoint = new RuntimeBeanReference( - basicAuthElt.getAttribute(ATT_ENTRY_POINT_REF)); + this.basicEntryPoint = new RuntimeBeanReference(basicAuthElt.getAttribute(ATT_ENTRY_POINT_REF)); } - injectAuthenticationDetailsSource(basicAuthElt, filterBuilder); - } - - if (basicEntryPoint == null) { - RootBeanDefinition entryPoint = new RootBeanDefinition( - BasicAuthenticationEntryPoint.class); - entryPoint.setSource(pc.extractSource(httpElt)); + if (this.basicEntryPoint == null) { + RootBeanDefinition entryPoint = new RootBeanDefinition(BasicAuthenticationEntryPoint.class); + entryPoint.setSource(this.pc.extractSource(this.httpElt)); entryPoint.getPropertyValues().addPropertyValue("realmName", realm); - entryPointId = pc.getReaderContext().generateBeanName(entryPoint); - pc.registerBeanComponent(new BeanComponentDefinition(entryPoint, entryPointId)); - basicEntryPoint = new RuntimeBeanReference(entryPointId); + entryPointId = this.pc.getReaderContext().generateBeanName(entryPoint); + this.pc.registerBeanComponent(new BeanComponentDefinition(entryPoint, entryPointId)); + this.basicEntryPoint = new RuntimeBeanReference(entryPointId); } - filterBuilder.addConstructorArgValue(authManager); - filterBuilder.addConstructorArgValue(basicEntryPoint); - basicFilter = filterBuilder.getBeanDefinition(); + filterBuilder.addConstructorArgValue(this.basicEntryPoint); + this.basicFilter = filterBuilder.getBeanDefinition(); } void createBearerTokenAuthenticationFilter(BeanReference authManager) { - Element resourceServerElt = DomUtils.getChildElementByTagName(httpElt, - Elements.OAUTH2_RESOURCE_SERVER); - + Element resourceServerElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OAUTH2_RESOURCE_SERVER); if (resourceServerElt == null) { // No resource server, do nothing return; } - - OAuth2ResourceServerBeanDefinitionParser resourceServerBuilder = - new OAuth2ResourceServerBeanDefinitionParser(authManager, authenticationProviders, - defaultEntryPointMappings, defaultDeniedHandlerMappings, csrfIgnoreRequestMatchers); - bearerTokenAuthenticationFilter = resourceServerBuilder.parse(resourceServerElt, pc); + OAuth2ResourceServerBeanDefinitionParser resourceServerBuilder = new OAuth2ResourceServerBeanDefinitionParser( + authManager, this.authenticationProviders, this.defaultEntryPointMappings, + this.defaultDeniedHandlerMappings, this.csrfIgnoreRequestMatchers); + this.bearerTokenAuthenticationFilter = resourceServerBuilder.parse(resourceServerElt, this.pc); } void createX509Filter(BeanReference authManager) { - Element x509Elt = DomUtils.getChildElementByTagName(httpElt, Elements.X509); + Element x509Elt = DomUtils.getChildElementByTagName(this.httpElt, Elements.X509); RootBeanDefinition filter = null; - if (x509Elt != null) { BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder .rootBeanDefinition(X509AuthenticationFilter.class); - filterBuilder.getRawBeanDefinition().setSource(pc.extractSource(x509Elt)); + filterBuilder.getRawBeanDefinition().setSource(this.pc.extractSource(x509Elt)); filterBuilder.addPropertyValue("authenticationManager", authManager); - String regex = x509Elt.getAttribute("subject-principal-regex"); - if (StringUtils.hasText(regex)) { BeanDefinitionBuilder extractor = BeanDefinitionBuilder .rootBeanDefinition(SubjectDnX509PrincipalExtractor.class); extractor.addPropertyValue("subjectDnRegex", regex); - - filterBuilder.addPropertyValue("principalExtractor", - extractor.getBeanDefinition()); + filterBuilder.addPropertyValue("principalExtractor", extractor.getBeanDefinition()); } - injectAuthenticationDetailsSource(x509Elt, filterBuilder); - filter = (RootBeanDefinition) filterBuilder.getBeanDefinition(); createPrauthEntryPoint(x509Elt); - createX509Provider(); } - - x509Filter = filter; + this.x509Filter = filter; } - private void injectAuthenticationDetailsSource(Element elt, - BeanDefinitionBuilder filterBuilder) { - String authDetailsSourceRef = elt - .getAttribute(AuthenticationConfigBuilder.ATT_AUTH_DETAILS_SOURCE_REF); - + private void injectAuthenticationDetailsSource(Element elt, BeanDefinitionBuilder filterBuilder) { + String authDetailsSourceRef = elt.getAttribute(AuthenticationConfigBuilder.ATT_AUTH_DETAILS_SOURCE_REF); if (StringUtils.hasText(authDetailsSourceRef)) { - filterBuilder.addPropertyReference("authenticationDetailsSource", - authDetailsSourceRef); + filterBuilder.addPropertyReference("authenticationDetailsSource", authDetailsSourceRef); } } private void createX509Provider() { - Element x509Elt = DomUtils.getChildElementByTagName(httpElt, Elements.X509); - BeanDefinition provider = new RootBeanDefinition( - PreAuthenticatedAuthenticationProvider.class); - + Element x509Elt = DomUtils.getChildElementByTagName(this.httpElt, Elements.X509); + BeanDefinition provider = new RootBeanDefinition(PreAuthenticatedAuthenticationProvider.class); RootBeanDefinition uds = new RootBeanDefinition(); uds.setFactoryBeanName(BeanIds.USER_DETAILS_SERVICE_FACTORY); uds.setFactoryMethodName("authenticationUserDetailsService"); - uds.getConstructorArgumentValues().addGenericArgumentValue( - x509Elt.getAttribute(ATT_USER_SERVICE_REF)); - - provider.getPropertyValues().addPropertyValue( - "preAuthenticatedUserDetailsService", uds); - - x509ProviderRef = new RuntimeBeanReference(pc.getReaderContext() - .registerWithGeneratedName(provider)); + uds.getConstructorArgumentValues().addGenericArgumentValue(x509Elt.getAttribute(ATT_USER_SERVICE_REF)); + provider.getPropertyValues().addPropertyValue("preAuthenticatedUserDetailsService", uds); + this.x509ProviderRef = new RuntimeBeanReference(this.pc.getReaderContext().registerWithGeneratedName(provider)); } private void createPrauthEntryPoint(Element source) { - if (preAuthEntryPoint == null) { - preAuthEntryPoint = new RootBeanDefinition(Http403ForbiddenEntryPoint.class); - preAuthEntryPoint.setSource(pc.extractSource(source)); + if (this.preAuthEntryPoint == null) { + this.preAuthEntryPoint = new RootBeanDefinition(Http403ForbiddenEntryPoint.class); + this.preAuthEntryPoint.setSource(this.pc.extractSource(source)); } } void createJeeFilter(BeanReference authManager) { - final String ATT_MAPPABLE_ROLES = "mappable-roles"; - - Element jeeElt = DomUtils.getChildElementByTagName(httpElt, Elements.JEE); + Element jeeElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.JEE); RootBeanDefinition filter = null; - if (jeeElt != null) { BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder .rootBeanDefinition(J2eePreAuthenticatedProcessingFilter.class); - filterBuilder.getRawBeanDefinition().setSource(pc.extractSource(jeeElt)); + filterBuilder.getRawBeanDefinition().setSource(this.pc.extractSource(jeeElt)); filterBuilder.addPropertyValue("authenticationManager", authManager); - BeanDefinitionBuilder adsBldr = BeanDefinitionBuilder .rootBeanDefinition(J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource.class); adsBldr.addPropertyValue("userRoles2GrantedAuthoritiesMapper", - new RootBeanDefinition( - SimpleAttributes2GrantedAuthoritiesMapper.class)); - + new RootBeanDefinition(SimpleAttributes2GrantedAuthoritiesMapper.class)); String roles = jeeElt.getAttribute(ATT_MAPPABLE_ROLES); Assert.hasLength(roles, "roles is expected to have length"); - BeanDefinitionBuilder rolesBuilder = BeanDefinitionBuilder - .rootBeanDefinition(StringUtils.class); + BeanDefinitionBuilder rolesBuilder = BeanDefinitionBuilder.rootBeanDefinition(StringUtils.class); rolesBuilder.addConstructorArgValue(roles); rolesBuilder.setFactoryMethod("commaDelimitedListToSet"); - - RootBeanDefinition mappableRolesRetriever = new RootBeanDefinition( - SimpleMappableAttributesRetriever.class); - mappableRolesRetriever.getPropertyValues().addPropertyValue( - "mappableAttributes", rolesBuilder.getBeanDefinition()); + RootBeanDefinition mappableRolesRetriever = new RootBeanDefinition(SimpleMappableAttributesRetriever.class); + mappableRolesRetriever.getPropertyValues().addPropertyValue("mappableAttributes", + rolesBuilder.getBeanDefinition()); adsBldr.addPropertyValue("mappableRolesRetriever", mappableRolesRetriever); - filterBuilder.addPropertyValue("authenticationDetailsSource", - adsBldr.getBeanDefinition()); - + filterBuilder.addPropertyValue("authenticationDetailsSource", adsBldr.getBeanDefinition()); filter = (RootBeanDefinition) filterBuilder.getBeanDefinition(); - createPrauthEntryPoint(jeeElt); createJeeProvider(); } - - jeeFilter = filter; + this.jeeFilter = filter; } private void createJeeProvider() { - Element jeeElt = DomUtils.getChildElementByTagName(httpElt, Elements.JEE); - BeanDefinition provider = new RootBeanDefinition( - PreAuthenticatedAuthenticationProvider.class); - + Element jeeElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.JEE); + BeanDefinition provider = new RootBeanDefinition(PreAuthenticatedAuthenticationProvider.class); RootBeanDefinition uds; if (StringUtils.hasText(jeeElt.getAttribute(ATT_USER_SERVICE_REF))) { uds = new RootBeanDefinition(); uds.setFactoryBeanName(BeanIds.USER_DETAILS_SERVICE_FACTORY); uds.setFactoryMethodName("authenticationUserDetailsService"); - uds.getConstructorArgumentValues().addGenericArgumentValue( - jeeElt.getAttribute(ATT_USER_SERVICE_REF)); + uds.getConstructorArgumentValues().addGenericArgumentValue(jeeElt.getAttribute(ATT_USER_SERVICE_REF)); } else { - uds = new RootBeanDefinition( - PreAuthenticatedGrantedAuthoritiesUserDetailsService.class); + uds = new RootBeanDefinition(PreAuthenticatedGrantedAuthoritiesUserDetailsService.class); } - - provider.getPropertyValues().addPropertyValue( - "preAuthenticatedUserDetailsService", uds); - - jeeProviderRef = new RuntimeBeanReference(pc.getReaderContext() - .registerWithGeneratedName(provider)); + provider.getPropertyValues().addPropertyValue("preAuthenticatedUserDetailsService", uds); + this.jeeProviderRef = new RuntimeBeanReference(this.pc.getReaderContext().registerWithGeneratedName(provider)); } void createLoginPageFilterIfNeeded() { - boolean needLoginPage = formFilterId != null || openIDFilterId != null || oauth2LoginFilterId != null; - + boolean needLoginPage = this.formFilterId != null || this.openIDFilterId != null + || this.oauth2LoginFilterId != null; // If no login page has been defined, add in the default page generator. - if (needLoginPage && formLoginPage == null && openIDLoginPage == null) { - logger.info("No login page configured. The default internal one will be used. Use the '" - + FormLoginBeanDefinitionParser.ATT_LOGIN_PAGE - + "' attribute to set the URL of the login page."); + if (needLoginPage && this.formLoginPage == null && this.openIDLoginPage == null) { + this.logger.info("No login page configured. The default internal one will be used. Use the '" + + FormLoginBeanDefinitionParser.ATT_LOGIN_PAGE + "' attribute to set the URL of the login page."); BeanDefinitionBuilder loginPageFilter = BeanDefinitionBuilder .rootBeanDefinition(DefaultLoginPageGeneratingFilter.class); loginPageFilter.addPropertyValue("resolveHiddenInputs", new CsrfTokenHiddenInputFunction()); @@ -729,121 +653,101 @@ final class AuthenticationConfigBuilder { BeanDefinitionBuilder logoutPageFilter = BeanDefinitionBuilder .rootBeanDefinition(DefaultLogoutPageGeneratingFilter.class); logoutPageFilter.addPropertyValue("resolveHiddenInputs", new CsrfTokenHiddenInputFunction()); - - if (formFilterId != null) { - loginPageFilter.addConstructorArgReference(formFilterId); - loginPageFilter.addPropertyValue("authenticationUrl", loginProcessingUrl); + if (this.formFilterId != null) { + loginPageFilter.addConstructorArgReference(this.formFilterId); + loginPageFilter.addPropertyValue("authenticationUrl", this.loginProcessingUrl); } - - if (openIDFilterId != null) { - loginPageFilter.addConstructorArgReference(openIDFilterId); - loginPageFilter.addPropertyValue("openIDauthenticationUrl", - openidLoginProcessingUrl); + if (this.openIDFilterId != null) { + loginPageFilter.addConstructorArgReference(this.openIDFilterId); + loginPageFilter.addPropertyValue("openIDauthenticationUrl", this.openidLoginProcessingUrl); } - - if (oauth2LoginFilterId != null) { - loginPageFilter.addConstructorArgReference(oauth2LoginFilterId); + if (this.oauth2LoginFilterId != null) { + loginPageFilter.addConstructorArgReference(this.oauth2LoginFilterId); loginPageFilter.addPropertyValue("Oauth2LoginEnabled", true); - loginPageFilter.addPropertyValue("Oauth2AuthenticationUrlToClientName", oauth2LoginLinks); + loginPageFilter.addPropertyValue("Oauth2AuthenticationUrlToClientName", this.oauth2LoginLinks); } - - loginPageGenerationFilter = loginPageFilter.getBeanDefinition(); + this.loginPageGenerationFilter = loginPageFilter.getBeanDefinition(); this.logoutPageGenerationFilter = logoutPageFilter.getBeanDefinition(); } } void createLogoutFilter() { - Element logoutElt = DomUtils.getChildElementByTagName(httpElt, Elements.LOGOUT); - if (logoutElt != null || autoConfig) { + Element logoutElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.LOGOUT); + if (logoutElt != null || this.autoConfig) { String formLoginPage = this.formLoginPage; if (formLoginPage == null) { formLoginPage = DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL; } - LogoutBeanDefinitionParser logoutParser = new LogoutBeanDefinitionParser( - formLoginPage, rememberMeServicesId, csrfLogoutHandler); - logoutFilter = logoutParser.parse(logoutElt, pc); - logoutHandlers = logoutParser.getLogoutHandlers(); + LogoutBeanDefinitionParser logoutParser = new LogoutBeanDefinitionParser(formLoginPage, + this.rememberMeServicesId, this.csrfLogoutHandler); + this.logoutFilter = logoutParser.parse(logoutElt, this.pc); + this.logoutHandlers = logoutParser.getLogoutHandlers(); } } @SuppressWarnings({ "rawtypes", "unchecked" }) ManagedList getLogoutHandlers() { - if (logoutHandlers == null && rememberMeProviderRef != null) { - logoutHandlers = new ManagedList(); - if (csrfLogoutHandler != null) { - logoutHandlers.add(csrfLogoutHandler); + if (this.logoutHandlers == null && this.rememberMeProviderRef != null) { + this.logoutHandlers = new ManagedList(); + if (this.csrfLogoutHandler != null) { + this.logoutHandlers.add(this.csrfLogoutHandler); } - logoutHandlers.add(new RuntimeBeanReference(rememberMeServicesId)); - logoutHandlers - .add(new RootBeanDefinition(SecurityContextLogoutHandler.class)); + this.logoutHandlers.add(new RuntimeBeanReference(this.rememberMeServicesId)); + this.logoutHandlers.add(new RootBeanDefinition(SecurityContextLogoutHandler.class)); } - return logoutHandlers; + return this.logoutHandlers; } BeanMetadataElement getEntryPointBean() { - return mainEntryPoint; + return this.mainEntryPoint; } BeanMetadataElement getAccessDeniedHandlerBean() { - return accessDeniedHandler; + return this.accessDeniedHandler; } List getCsrfIgnoreRequestMatchers() { - return csrfIgnoreRequestMatchers; + return this.csrfIgnoreRequestMatchers; } void createAnonymousFilter() { - Element anonymousElt = DomUtils.getChildElementByTagName(httpElt, - Elements.ANONYMOUS); - + Element anonymousElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.ANONYMOUS); if (anonymousElt != null && "false".equals(anonymousElt.getAttribute("enabled"))) { return; } - String grantedAuthority = null; String username = null; String key = null; - Object source = pc.extractSource(httpElt); - + Object source = this.pc.extractSource(this.httpElt); if (anonymousElt != null) { grantedAuthority = anonymousElt.getAttribute("granted-authority"); username = anonymousElt.getAttribute("username"); key = anonymousElt.getAttribute(ATT_KEY); - source = pc.extractSource(anonymousElt); + source = this.pc.extractSource(anonymousElt); } - if (!StringUtils.hasText(grantedAuthority)) { grantedAuthority = "ROLE_ANONYMOUS"; } - if (!StringUtils.hasText(username)) { username = "anonymousUser"; } - if (!StringUtils.hasText(key)) { // Generate a random key for the Anonymous provider key = createKey(); } - - anonymousFilter = new RootBeanDefinition(AnonymousAuthenticationFilter.class); - anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(0, key); - anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(1, - username); - anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(2, + this.anonymousFilter = new RootBeanDefinition(AnonymousAuthenticationFilter.class); + this.anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(0, key); + this.anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(1, username); + this.anonymousFilter.getConstructorArgumentValues().addIndexedArgumentValue(2, AuthorityUtils.commaSeparatedStringToAuthorityList(grantedAuthority)); - anonymousFilter.setSource(source); - - RootBeanDefinition anonymousProviderBean = new RootBeanDefinition( - AnonymousAuthenticationProvider.class); - anonymousProviderBean.getConstructorArgumentValues().addIndexedArgumentValue(0, - key); - anonymousProviderBean.setSource(anonymousFilter.getSource()); - String id = pc.getReaderContext().generateBeanName(anonymousProviderBean); - pc.registerBeanComponent(new BeanComponentDefinition(anonymousProviderBean, id)); - - anonymousProviderRef = new RuntimeBeanReference(id); - + this.anonymousFilter.setSource(source); + RootBeanDefinition anonymousProviderBean = new RootBeanDefinition(AnonymousAuthenticationProvider.class); + anonymousProviderBean.getConstructorArgumentValues().addIndexedArgumentValue(0, key); + anonymousProviderBean.setSource(this.anonymousFilter.getSource()); + String id = this.pc.getReaderContext().generateBeanName(anonymousProviderBean); + this.pc.registerBeanComponent(new BeanComponentDefinition(anonymousProviderBean, id)); + this.anonymousProviderRef = new RuntimeBeanReference(id); } private String createKey() { @@ -852,254 +756,208 @@ final class AuthenticationConfigBuilder { } void createExceptionTranslationFilter() { - BeanDefinitionBuilder etfBuilder = BeanDefinitionBuilder - .rootBeanDefinition(ExceptionTranslationFilter.class); - accessDeniedHandler = createAccessDeniedHandler(httpElt, pc); - etfBuilder.addPropertyValue("accessDeniedHandler", accessDeniedHandler); - assert requestCache != null; - mainEntryPoint = selectEntryPoint(); - etfBuilder.addConstructorArgValue(mainEntryPoint); - etfBuilder.addConstructorArgValue(requestCache); - - etf = etfBuilder.getBeanDefinition(); + BeanDefinitionBuilder etfBuilder = BeanDefinitionBuilder.rootBeanDefinition(ExceptionTranslationFilter.class); + this.accessDeniedHandler = createAccessDeniedHandler(this.httpElt, this.pc); + etfBuilder.addPropertyValue("accessDeniedHandler", this.accessDeniedHandler); + Assert.state(this.requestCache != null, "No request cache found"); + this.mainEntryPoint = selectEntryPoint(); + etfBuilder.addConstructorArgValue(this.mainEntryPoint); + etfBuilder.addConstructorArgValue(this.requestCache); + this.etf = etfBuilder.getBeanDefinition(); } - private BeanMetadataElement createAccessDeniedHandler(Element element, - ParserContext pc) { - Element accessDeniedElt = DomUtils.getChildElementByTagName(element, - Elements.ACCESS_DENIED_HANDLER); + private BeanMetadataElement createAccessDeniedHandler(Element element, ParserContext pc) { + Element accessDeniedElt = DomUtils.getChildElementByTagName(element, Elements.ACCESS_DENIED_HANDLER); BeanDefinitionBuilder accessDeniedHandler = BeanDefinitionBuilder .rootBeanDefinition(AccessDeniedHandlerImpl.class); - if (accessDeniedElt != null) { String errorPage = accessDeniedElt.getAttribute("error-page"); String ref = accessDeniedElt.getAttribute("ref"); - if (StringUtils.hasText(errorPage)) { if (StringUtils.hasText(ref)) { pc.getReaderContext() - .error("The attribute " - + ATT_ACCESS_DENIED_ERROR_PAGE + .error("The attribute " + ATT_ACCESS_DENIED_ERROR_PAGE + " cannot be used together with the 'ref' attribute within <" - + Elements.ACCESS_DENIED_HANDLER + ">", - pc.extractSource(accessDeniedElt)); + + Elements.ACCESS_DENIED_HANDLER + ">", pc.extractSource(accessDeniedElt)); } accessDeniedHandler.addPropertyValue("errorPage", errorPage); return accessDeniedHandler.getBeanDefinition(); } - else if (StringUtils.hasText(ref)) { + if (StringUtils.hasText(ref)) { return new RuntimeBeanReference(ref); } - } - if (this.defaultDeniedHandlerMappings.isEmpty()) { return accessDeniedHandler.getBeanDefinition(); } if (this.defaultDeniedHandlerMappings.size() == 1) { return this.defaultDeniedHandlerMappings.values().iterator().next(); } - accessDeniedHandler = BeanDefinitionBuilder .rootBeanDefinition(RequestMatcherDelegatingAccessDeniedHandler.class); accessDeniedHandler.addConstructorArgValue(this.defaultDeniedHandlerMappings); - accessDeniedHandler.addConstructorArgValue - (BeanDefinitionBuilder.rootBeanDefinition(AccessDeniedHandlerImpl.class)); - + accessDeniedHandler + .addConstructorArgValue(BeanDefinitionBuilder.rootBeanDefinition(AccessDeniedHandlerImpl.class)); return accessDeniedHandler.getBeanDefinition(); } private BeanMetadataElement selectEntryPoint() { // We need to establish the main entry point. // First check if a custom entry point bean is set - String customEntryPoint = httpElt.getAttribute(ATT_ENTRY_POINT_REF); - + String customEntryPoint = this.httpElt.getAttribute(ATT_ENTRY_POINT_REF); if (StringUtils.hasText(customEntryPoint)) { return new RuntimeBeanReference(customEntryPoint); } - - if (!defaultEntryPointMappings.isEmpty()) { - if (defaultEntryPointMappings.size() == 1) { - return defaultEntryPointMappings.values().iterator().next(); + if (!this.defaultEntryPointMappings.isEmpty()) { + if (this.defaultEntryPointMappings.size() == 1) { + return this.defaultEntryPointMappings.values().iterator().next(); } BeanDefinitionBuilder delegatingEntryPoint = BeanDefinitionBuilder .rootBeanDefinition(DelegatingAuthenticationEntryPoint.class); - delegatingEntryPoint.addConstructorArgValue(defaultEntryPointMappings); + delegatingEntryPoint.addConstructorArgValue(this.defaultEntryPointMappings); return delegatingEntryPoint.getBeanDefinition(); } - - Element basicAuthElt = DomUtils.getChildElementByTagName(httpElt, - Elements.BASIC_AUTH); - Element formLoginElt = DomUtils.getChildElementByTagName(httpElt, - Elements.FORM_LOGIN); - Element openIDLoginElt = DomUtils.getChildElementByTagName(httpElt, - Elements.OPENID_LOGIN); + Element basicAuthElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.BASIC_AUTH); + Element formLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.FORM_LOGIN); + Element openIDLoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OPENID_LOGIN); // Basic takes precedence if explicit element is used and no others are configured - if (basicAuthElt != null && formLoginElt == null && openIDLoginElt == null && oauth2LoginEntryPoint == null) { - return basicEntryPoint; + if (basicAuthElt != null && formLoginElt == null && openIDLoginElt == null + && this.oauth2LoginEntryPoint == null) { + return this.basicEntryPoint; } - // If formLogin has been enabled either through an element or auto-config, then it // is used if no openID login page // has been set. - - if (formLoginPage != null && openIDLoginPage != null) { - pc.getReaderContext().error( - "Only one login-page can be defined, either for OpenID or form-login, " - + "but not both.", pc.extractSource(openIDLoginElt)); + if (this.formLoginPage != null && this.openIDLoginPage != null) { + this.pc.getReaderContext().error( + "Only one login-page can be defined, either for OpenID or form-login, " + "but not both.", + this.pc.extractSource(openIDLoginElt)); } - - if (formFilterId != null && openIDLoginPage == null) { - // gh-6802 - // If form login was enabled through element and Oauth2 login was enabled from element then use form login - if (formLoginElt != null && oauth2LoginEntryPoint != null) { - return formEntryPoint; + if (this.formFilterId != null && this.openIDLoginPage == null) { + // If form login was enabled through element and Oauth2 login was enabled from + // element then use form login (gh-6802) + if (formLoginElt != null && this.oauth2LoginEntryPoint != null) { + return this.formEntryPoint; } - // If form login was enabled through auto-config, and Oauth2 login was not enabled then use form login - if (oauth2LoginEntryPoint == null) { - return formEntryPoint; + // If form login was enabled through auto-config, and Oauth2 login was not + // enabled then use form login + if (this.oauth2LoginEntryPoint == null) { + return this.formEntryPoint; } } - // Otherwise use OpenID if enabled - if (openIDFilterId != null) { - return openIDEntryPoint; + if (this.openIDFilterId != null) { + return this.openIDEntryPoint; } - // If X.509 or JEE have been enabled, use the preauth entry point. - if (preAuthEntryPoint != null) { - return preAuthEntryPoint; + if (this.preAuthEntryPoint != null) { + return this.preAuthEntryPoint; } - // OAuth2 entry point will not be null if only 1 client registration - if (oauth2LoginEntryPoint != null) { - return oauth2LoginEntryPoint; + if (this.oauth2LoginEntryPoint != null) { + return this.oauth2LoginEntryPoint; } - - pc.getReaderContext() - .error("No AuthenticationEntryPoint could be established. Please " - + "make sure you have a login mechanism configured through the namespace (such as form-login) or " - + "specify a custom AuthenticationEntryPoint with the '" - + ATT_ENTRY_POINT_REF + "' attribute ", pc.extractSource(httpElt)); + this.pc.getReaderContext().error("No AuthenticationEntryPoint could be established. Please " + + "make sure you have a login mechanism configured through the namespace (such as form-login) or " + + "specify a custom AuthenticationEntryPoint with the '" + ATT_ENTRY_POINT_REF + "' attribute ", + this.pc.extractSource(this.httpElt)); return null; } private void createUserDetailsServiceFactory() { - if (pc.getRegistry().containsBeanDefinition(BeanIds.USER_DETAILS_SERVICE_FACTORY)) { + if (this.pc.getRegistry().containsBeanDefinition(BeanIds.USER_DETAILS_SERVICE_FACTORY)) { // Multiple case return; } - RootBeanDefinition bean = new RootBeanDefinition( - UserDetailsServiceFactoryBean.class); + RootBeanDefinition bean = new RootBeanDefinition(UserDetailsServiceFactoryBean.class); bean.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); - pc.registerBeanComponent(new BeanComponentDefinition(bean, - BeanIds.USER_DETAILS_SERVICE_FACTORY)); + this.pc.registerBeanComponent(new BeanComponentDefinition(bean, BeanIds.USER_DETAILS_SERVICE_FACTORY)); } List getFilters() { List filters = new ArrayList<>(); - - if (anonymousFilter != null) { - filters.add(new OrderDecorator(anonymousFilter, ANONYMOUS_FILTER)); + if (this.anonymousFilter != null) { + filters.add(new OrderDecorator(this.anonymousFilter, SecurityFilters.ANONYMOUS_FILTER)); } - - if (rememberMeFilter != null) { - filters.add(new OrderDecorator(rememberMeFilter, REMEMBER_ME_FILTER)); + if (this.rememberMeFilter != null) { + filters.add(new OrderDecorator(this.rememberMeFilter, SecurityFilters.REMEMBER_ME_FILTER)); } - - if (logoutFilter != null) { - filters.add(new OrderDecorator(logoutFilter, LOGOUT_FILTER)); + if (this.logoutFilter != null) { + filters.add(new OrderDecorator(this.logoutFilter, SecurityFilters.LOGOUT_FILTER)); } - - if (x509Filter != null) { - filters.add(new OrderDecorator(x509Filter, X509_FILTER)); + if (this.x509Filter != null) { + filters.add(new OrderDecorator(this.x509Filter, SecurityFilters.X509_FILTER)); } - - if (jeeFilter != null) { - filters.add(new OrderDecorator(jeeFilter, PRE_AUTH_FILTER)); + if (this.jeeFilter != null) { + filters.add(new OrderDecorator(this.jeeFilter, SecurityFilters.PRE_AUTH_FILTER)); } - - if (formFilterId != null) { - filters.add(new OrderDecorator(new RuntimeBeanReference(formFilterId), - FORM_LOGIN_FILTER)); + if (this.formFilterId != null) { + filters.add( + new OrderDecorator(new RuntimeBeanReference(this.formFilterId), SecurityFilters.FORM_LOGIN_FILTER)); } - - if (oauth2LoginFilterId != null) { - filters.add(new OrderDecorator(new RuntimeBeanReference(oauth2LoginFilterId), OAUTH2_LOGIN_FILTER)); - filters.add(new OrderDecorator(oauth2AuthorizationRequestRedirectFilter, OAUTH2_AUTHORIZATION_REQUEST_FILTER)); + if (this.oauth2LoginFilterId != null) { + filters.add(new OrderDecorator(new RuntimeBeanReference(this.oauth2LoginFilterId), + SecurityFilters.OAUTH2_LOGIN_FILTER)); + filters.add(new OrderDecorator(this.oauth2AuthorizationRequestRedirectFilter, + SecurityFilters.OAUTH2_AUTHORIZATION_REQUEST_FILTER)); } - - if (openIDFilterId != null) { - filters.add(new OrderDecorator(new RuntimeBeanReference(openIDFilterId), - OPENID_FILTER)); + if (this.openIDFilterId != null) { + filters.add( + new OrderDecorator(new RuntimeBeanReference(this.openIDFilterId), SecurityFilters.OPENID_FILTER)); } - - if (loginPageGenerationFilter != null) { - filters.add(new OrderDecorator(loginPageGenerationFilter, LOGIN_PAGE_FILTER)); - filters.add(new OrderDecorator(this.logoutPageGenerationFilter, LOGOUT_PAGE_FILTER)); + if (this.loginPageGenerationFilter != null) { + filters.add(new OrderDecorator(this.loginPageGenerationFilter, SecurityFilters.LOGIN_PAGE_FILTER)); + filters.add(new OrderDecorator(this.logoutPageGenerationFilter, SecurityFilters.LOGOUT_PAGE_FILTER)); } - - if (basicFilter != null) { - filters.add(new OrderDecorator(basicFilter, BASIC_AUTH_FILTER)); + if (this.basicFilter != null) { + filters.add(new OrderDecorator(this.basicFilter, SecurityFilters.BASIC_AUTH_FILTER)); } - - if (bearerTokenAuthenticationFilter != null) { - filters.add(new OrderDecorator(bearerTokenAuthenticationFilter, BEARER_TOKEN_AUTH_FILTER)); + if (this.bearerTokenAuthenticationFilter != null) { + filters.add( + new OrderDecorator(this.bearerTokenAuthenticationFilter, SecurityFilters.BEARER_TOKEN_AUTH_FILTER)); } - - if (authorizationCodeGrantFilter != null) { - filters.add(new OrderDecorator(authorizationRequestRedirectFilter, OAUTH2_AUTHORIZATION_REQUEST_FILTER.getOrder() + 1)); - filters.add(new OrderDecorator(authorizationCodeGrantFilter, OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER)); + if (this.authorizationCodeGrantFilter != null) { + filters.add(new OrderDecorator(this.authorizationRequestRedirectFilter, + SecurityFilters.OAUTH2_AUTHORIZATION_REQUEST_FILTER.getOrder() + 1)); + filters.add(new OrderDecorator(this.authorizationCodeGrantFilter, + SecurityFilters.OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER)); } - - filters.add(new OrderDecorator(etf, EXCEPTION_TRANSLATION_FILTER)); - + filters.add(new OrderDecorator(this.etf, SecurityFilters.EXCEPTION_TRANSLATION_FILTER)); return filters; } List getProviders() { List providers = new ArrayList<>(); - - if (anonymousProviderRef != null) { - providers.add(anonymousProviderRef); + if (this.anonymousProviderRef != null) { + providers.add(this.anonymousProviderRef); } - - if (rememberMeProviderRef != null) { - providers.add(rememberMeProviderRef); + if (this.rememberMeProviderRef != null) { + providers.add(this.rememberMeProviderRef); } - - if (openIDProviderRef != null) { - providers.add(openIDProviderRef); + if (this.openIDProviderRef != null) { + providers.add(this.openIDProviderRef); } - - if (x509ProviderRef != null) { - providers.add(x509ProviderRef); + if (this.x509ProviderRef != null) { + providers.add(this.x509ProviderRef); } - - if (jeeProviderRef != null) { - providers.add(jeeProviderRef); + if (this.jeeProviderRef != null) { + providers.add(this.jeeProviderRef); } - - if (oauth2LoginAuthenticationProviderRef != null) { - providers.add(oauth2LoginAuthenticationProviderRef); + if (this.oauth2LoginAuthenticationProviderRef != null) { + providers.add(this.oauth2LoginAuthenticationProviderRef); } - - if (oauth2LoginOidcAuthenticationProviderRef != null) { - providers.add(oauth2LoginOidcAuthenticationProviderRef); + if (this.oauth2LoginOidcAuthenticationProviderRef != null) { + providers.add(this.oauth2LoginOidcAuthenticationProviderRef); } - - if (authorizationCodeAuthenticationProviderRef != null) { - providers.add(authorizationCodeAuthenticationProviderRef); + if (this.authorizationCodeAuthenticationProviderRef != null) { + providers.add(this.authorizationCodeAuthenticationProviderRef); } - providers.addAll(this.authenticationProviders); - return providers; } - private static class CsrfTokenHiddenInputFunction implements - Function> { + private static class CsrfTokenHiddenInputFunction implements Function> { @Override public Map apply(HttpServletRequest request) { @@ -1109,5 +967,7 @@ final class AuthenticationConfigBuilder { } return Collections.singletonMap(token.getParameterName(), token.getToken()); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/http/ChannelAttributeFactory.java b/config/src/main/java/org/springframework/security/config/http/ChannelAttributeFactory.java index c926b83465..6c00919429 100644 --- a/config/src/main/java/org/springframework/security/config/http/ChannelAttributeFactory.java +++ b/config/src/main/java/org/springframework/security/config/http/ChannelAttributeFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.List; @@ -29,14 +30,19 @@ import org.springframework.security.web.access.channel.ChannelDecisionManagerImp * @author Luke Taylor * @since 3.0 */ -public class ChannelAttributeFactory { +public final class ChannelAttributeFactory { + private static final String OPT_REQUIRES_HTTP = "http"; + private static final String OPT_REQUIRES_HTTPS = "https"; + private static final String OPT_ANY_CHANNEL = "any"; + private ChannelAttributeFactory() { + } + public static List createChannelAttributes(String requiredChannel) { String channelConfigAttribute; - if (requiredChannel.equals(OPT_REQUIRES_HTTPS)) { channelConfigAttribute = "REQUIRES_SECURE_CHANNEL"; } @@ -47,10 +53,9 @@ public class ChannelAttributeFactory { channelConfigAttribute = ChannelDecisionManagerImpl.ANY_CHANNEL; } else { - throw new BeanCreationException("Unknown channel attribute " - + requiredChannel); + throw new BeanCreationException("Unknown channel attribute " + requiredChannel); } - return SecurityConfig.createList(channelConfigAttribute); } + } diff --git a/config/src/main/java/org/springframework/security/config/http/CorsBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CorsBeanDefinitionParser.java index b8c67e6c54..8f7ccc1eab 100644 --- a/config/src/main/java/org/springframework/security/config/http/CorsBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CorsBeanDefinitionParser.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -24,7 +27,6 @@ import org.springframework.beans.factory.xml.ParserContext; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.web.filter.CorsFilter; -import org.w3c.dom.Element; /** * Parser for the {@code CorsFilter}. @@ -33,26 +35,25 @@ import org.w3c.dom.Element; * @since 4.1.1 */ public class CorsBeanDefinitionParser { + private static final String HANDLER_MAPPING_INTROSPECTOR = "org.springframework.web.servlet.handler.HandlerMappingIntrospector"; private static final String ATT_SOURCE = "configuration-source-ref"; + private static final String ATT_REF = "ref"; public BeanMetadataElement parse(Element element, ParserContext parserContext) { if (element == null) { return null; } - String filterRef = element.getAttribute(ATT_REF); if (StringUtils.hasText(filterRef)) { return new RuntimeBeanReference(filterRef); } - BeanMetadataElement configurationSource = getSource(element, parserContext); if (configurationSource == null) { throw new BeanCreationException("Could not create CorsFilter"); } - BeanDefinitionBuilder filterBldr = BeanDefinitionBuilder.rootBeanDefinition(CorsFilter.class); filterBldr.addConstructorArgValue(configurationSource); return filterBldr.getBeanDefinition(); @@ -63,13 +64,11 @@ public class CorsBeanDefinitionParser { if (StringUtils.hasText(configurationSourceRef)) { return new RuntimeBeanReference(configurationSourceRef); } - - boolean mvcPresent = ClassUtils.isPresent(HANDLER_MAPPING_INTROSPECTOR, - getClass().getClassLoader()); + boolean mvcPresent = ClassUtils.isPresent(HANDLER_MAPPING_INTROSPECTOR, getClass().getClassLoader()); if (!mvcPresent) { return null; } - return new RootBeanDefinition(HandlerMappingIntrospectorFactoryBean.class); } + } diff --git a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java index 9ef1fed832..58dcd468a8 100644 --- a/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.Arrays; import java.util.HashSet; import java.util.List; + import javax.servlet.http.HttpServletRequest; import org.w3c.dom.Element; @@ -62,76 +64,65 @@ import org.springframework.util.StringUtils; public class CsrfBeanDefinitionParser implements BeanDefinitionParser { private static final String REQUEST_DATA_VALUE_PROCESSOR = "requestDataValueProcessor"; + private static final String DISPATCHER_SERVLET_CLASS_NAME = "org.springframework.web.servlet.DispatcherServlet"; + private static final String ATT_MATCHER = "request-matcher-ref"; + private static final String ATT_REPOSITORY = "token-repository-ref"; private String csrfRepositoryRef; + private BeanDefinition csrfFilter; private String requestMatcherRef; @Override public BeanDefinition parse(Element element, ParserContext pc) { - boolean disabled = element != null - && "true".equals(element.getAttribute("disabled")); + boolean disabled = element != null && "true".equals(element.getAttribute("disabled")); if (disabled) { return null; } - boolean webmvcPresent = ClassUtils.isPresent(DISPATCHER_SERVLET_CLASS_NAME, - getClass().getClassLoader()); + boolean webmvcPresent = ClassUtils.isPresent(DISPATCHER_SERVLET_CLASS_NAME, getClass().getClassLoader()); if (webmvcPresent) { if (!pc.getRegistry().containsBeanDefinition(REQUEST_DATA_VALUE_PROCESSOR)) { - RootBeanDefinition beanDefinition = new RootBeanDefinition( - CsrfRequestDataValueProcessor.class); - BeanComponentDefinition componentDefinition = new BeanComponentDefinition( - beanDefinition, REQUEST_DATA_VALUE_PROCESSOR); + RootBeanDefinition beanDefinition = new RootBeanDefinition(CsrfRequestDataValueProcessor.class); + BeanComponentDefinition componentDefinition = new BeanComponentDefinition(beanDefinition, + REQUEST_DATA_VALUE_PROCESSOR); pc.registerBeanComponent(componentDefinition); } } - if (element != null) { this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY); this.requestMatcherRef = element.getAttribute(ATT_MATCHER); } - if (!StringUtils.hasText(this.csrfRepositoryRef)) { - - RootBeanDefinition csrfTokenRepository = new RootBeanDefinition( - HttpSessionCsrfTokenRepository.class); + RootBeanDefinition csrfTokenRepository = new RootBeanDefinition(HttpSessionCsrfTokenRepository.class); BeanDefinitionBuilder lazyTokenRepository = BeanDefinitionBuilder .rootBeanDefinition(LazyCsrfTokenRepository.class); lazyTokenRepository.addConstructorArgValue(csrfTokenRepository); - this.csrfRepositoryRef = pc.getReaderContext() - .generateBeanName(lazyTokenRepository.getBeanDefinition()); - pc.registerBeanComponent(new BeanComponentDefinition( - lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef)); + this.csrfRepositoryRef = pc.getReaderContext().generateBeanName(lazyTokenRepository.getBeanDefinition()); + pc.registerBeanComponent( + new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef)); } - - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(CsrfFilter.class); + BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class); builder.addConstructorArgReference(this.csrfRepositoryRef); - if (StringUtils.hasText(this.requestMatcherRef)) { builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef); } - this.csrfFilter = builder.getBeanDefinition(); return this.csrfFilter; } /** * Populate the AccessDeniedHandler on the {@link CsrfFilter} - * * @param invalidSessionStrategy the {@link InvalidSessionStrategy} to use * @param defaultDeniedHandler the {@link AccessDeniedHandler} to use */ - void initAccessDeniedHandler(BeanDefinition invalidSessionStrategy, - BeanMetadataElement defaultDeniedHandler) { - BeanMetadataElement accessDeniedHandler = createAccessDeniedHandler( - invalidSessionStrategy, defaultDeniedHandler); - this.csrfFilter.getPropertyValues().addPropertyValue("accessDeniedHandler", - accessDeniedHandler); + void initAccessDeniedHandler(BeanDefinition invalidSessionStrategy, BeanMetadataElement defaultDeniedHandler) { + BeanMetadataElement accessDeniedHandler = createAccessDeniedHandler(invalidSessionStrategy, + defaultDeniedHandler); + this.csrfFilter.getPropertyValues().addPropertyValue("accessDeniedHandler", accessDeniedHandler); } /** @@ -143,15 +134,12 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { * {@link InvalidSessionAccessDeniedHandler} and the * {@link #getDefaultAccessDeniedHandler(HttpSecurityBuilder)}. Otherwise, only * {@link #getDefaultAccessDeniedHandler(HttpSecurityBuilder)} is used. - * * @param invalidSessionStrategy the {@link InvalidSessionStrategy} to use * @param defaultDeniedHandler the {@link AccessDeniedHandler} to use - * * @return the {@link BeanMetadataElement} that is the {@link AccessDeniedHandler} to * populate on the {@link CsrfFilter} */ - private BeanMetadataElement createAccessDeniedHandler( - BeanDefinition invalidSessionStrategy, + private BeanMetadataElement createAccessDeniedHandler(BeanDefinition invalidSessionStrategy, BeanMetadataElement defaultDeniedHandler) { if (invalidSessionStrategy == null) { return defaultDeniedHandler; @@ -160,14 +148,11 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { BeanDefinitionBuilder invalidSessionHandlerBldr = BeanDefinitionBuilder .rootBeanDefinition(InvalidSessionAccessDeniedHandler.class); invalidSessionHandlerBldr.addConstructorArgValue(invalidSessionStrategy); - handlers.put(MissingCsrfTokenException.class, - invalidSessionHandlerBldr.getBeanDefinition()); - + handlers.put(MissingCsrfTokenException.class, invalidSessionHandlerBldr.getBeanDefinition()); BeanDefinitionBuilder deniedBldr = BeanDefinitionBuilder .rootBeanDefinition(DelegatingAccessDeniedHandler.class); deniedBldr.addConstructorArgValue(handlers); deniedBldr.addConstructorArgValue(defaultDeniedHandler); - return deniedBldr.getBeanDefinition(); } @@ -187,43 +172,31 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser { void setIgnoreCsrfRequestMatchers(List requestMatchers) { if (!requestMatchers.isEmpty()) { - BeanMetadataElement requestMatcher; - if (StringUtils.hasText(this.requestMatcherRef)) { - requestMatcher = new RuntimeBeanReference(this.requestMatcherRef); - } else { - requestMatcher = new RootBeanDefinition(DefaultRequiresCsrfMatcher.class); - } - BeanDefinitionBuilder and = BeanDefinitionBuilder - .rootBeanDefinition(AndRequestMatcher.class); - BeanDefinitionBuilder negated = BeanDefinitionBuilder - .rootBeanDefinition(NegatedRequestMatcher.class); - BeanDefinitionBuilder or = BeanDefinitionBuilder - .rootBeanDefinition(OrRequestMatcher.class); + BeanMetadataElement requestMatcher = (!StringUtils.hasText(this.requestMatcherRef)) + ? new RootBeanDefinition(DefaultRequiresCsrfMatcher.class) + : new RuntimeBeanReference(this.requestMatcherRef); + BeanDefinitionBuilder and = BeanDefinitionBuilder.rootBeanDefinition(AndRequestMatcher.class); + BeanDefinitionBuilder negated = BeanDefinitionBuilder.rootBeanDefinition(NegatedRequestMatcher.class); + BeanDefinitionBuilder or = BeanDefinitionBuilder.rootBeanDefinition(OrRequestMatcher.class); or.addConstructorArgValue(requestMatchers); negated.addConstructorArgValue(or.getBeanDefinition()); List ands = new ManagedList<>(); ands.add(requestMatcher); ands.add(negated.getBeanDefinition()); and.addConstructorArgValue(ands); - this.csrfFilter.getPropertyValues() - .add("requireCsrfProtectionMatcher", and.getBeanDefinition()); + this.csrfFilter.getPropertyValues().add("requireCsrfProtectionMatcher", and.getBeanDefinition()); } } private static final class DefaultRequiresCsrfMatcher implements RequestMatcher { - private final HashSet allowedMethods = new HashSet<>( - Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.web.util.matcher.RequestMatcher#matches(javax. - * servlet.http.HttpServletRequest) - */ + private final HashSet allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS")); + @Override public boolean matches(HttpServletRequest request) { return !this.allowedMethods.contains(request.getMethod()); } + } + } diff --git a/config/src/main/java/org/springframework/security/config/http/DefaultFilterChainValidator.java b/config/src/main/java/org/springframework/security/config/http/DefaultFilterChainValidator.java index ac4ec37ed3..423ae18de6 100644 --- a/config/src/main/java/org/springframework/security/config/http/DefaultFilterChainValidator.java +++ b/config/src/main/java/org/springframework/security/config/http/DefaultFilterChainValidator.java @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; import javax.servlet.Filter; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -40,54 +45,44 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter import org.springframework.security.web.jaasapi.JaasApiIntegrationFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.session.SessionManagementFilter; -import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.AnyRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; public class DefaultFilterChainValidator implements FilterChainProxy.FilterChainValidator { + private final Log logger = LogFactory.getLog(getClass()); + @Override public void validate(FilterChainProxy fcp) { for (SecurityFilterChain filterChain : fcp.getFilterChains()) { checkLoginPageIsntProtected(fcp, filterChain.getFilters()); checkFilterStack(filterChain.getFilters()); } - checkPathOrder(new ArrayList<>(fcp.getFilterChains())); - checkForDuplicateMatchers(new ArrayList<>( - fcp.getFilterChains())); + checkForDuplicateMatchers(new ArrayList<>(fcp.getFilterChains())); } private void checkPathOrder(List filterChains) { // Check that the universal pattern is listed at the end, if at all Iterator chains = filterChains.iterator(); - while (chains.hasNext()) { - RequestMatcher matcher = ((DefaultSecurityFilterChain) chains.next()) - .getRequestMatcher(); + RequestMatcher matcher = ((DefaultSecurityFilterChain) chains.next()).getRequestMatcher(); if (AnyRequestMatcher.INSTANCE.equals(matcher) && chains.hasNext()) { - throw new IllegalArgumentException( - "A universal match pattern ('/**') is defined " - + " before other patterns in the filter chain, causing them to be ignored. Please check the " - + "ordering in your namespace or FilterChainProxy bean configuration"); + throw new IllegalArgumentException("A universal match pattern ('/**') is defined " + + " before other patterns in the filter chain, causing them to be ignored. Please check the " + + "ordering in your namespace or FilterChainProxy bean configuration"); } } } private void checkForDuplicateMatchers(List chains) { - while (chains.size() > 1) { - DefaultSecurityFilterChain chain = (DefaultSecurityFilterChain) chains - .remove(0); - + DefaultSecurityFilterChain chain = (DefaultSecurityFilterChain) chains.remove(0); for (SecurityFilterChain test : chains) { - if (chain.getRequestMatcher().equals( - ((DefaultSecurityFilterChain) test).getRequestMatcher())) { - throw new IllegalArgumentException( - "The FilterChainProxy contains two filter chains using the" - + " matcher " - + chain.getRequestMatcher() - + ". If you are using multiple namespace " - + "elements, you must use a 'pattern' attribute to define the request patterns to which they apply."); + if (chain.getRequestMatcher().equals(((DefaultSecurityFilterChain) test).getRequestMatcher())) { + throw new IllegalArgumentException("The FilterChainProxy contains two filter chains using the" + + " matcher " + chain.getRequestMatcher() + ". If you are using multiple namespace " + + "elements, you must use a 'pattern' attribute to define the request patterns to which they apply."); } } } @@ -100,7 +95,6 @@ public class DefaultFilterChainValidator implements FilterChainProxy.FilterChain return (F) f; } } - return null; } @@ -126,8 +120,8 @@ public class DefaultFilterChainValidator implements FilterChainProxy.FilterChain for (int j = i + 1; j < filters.size(); j++) { Filter f2 = filters.get(j); if (clazz.isAssignableFrom(f2.getClass())) { - logger.warn("Possible error: Filters at position " + i + " and " - + j + " are both " + "instances of " + clazz.getName()); + this.logger.warn("Possible error: Filters at position " + i + " and " + j + " are both " + + "instances of " + clazz.getName()); return; } } @@ -139,84 +133,66 @@ public class DefaultFilterChainValidator implements FilterChainProxy.FilterChain * Checks for the common error of having a login page URL protected by the security * interceptor */ - private void checkLoginPageIsntProtected(FilterChainProxy fcp, - List filterStack) { - ExceptionTranslationFilter etf = getFilter(ExceptionTranslationFilter.class, - filterStack); - - if (etf == null - || !(etf.getAuthenticationEntryPoint() instanceof LoginUrlAuthenticationEntryPoint)) { + private void checkLoginPageIsntProtected(FilterChainProxy fcp, List filterStack) { + ExceptionTranslationFilter etf = getFilter(ExceptionTranslationFilter.class, filterStack); + if (etf == null || !(etf.getAuthenticationEntryPoint() instanceof LoginUrlAuthenticationEntryPoint)) { return; } - - String loginPage = ((LoginUrlAuthenticationEntryPoint) etf - .getAuthenticationEntryPoint()).getLoginFormUrl(); - logger.info("Checking whether login URL '" + loginPage - + "' is accessible with your configuration"); + String loginPage = ((LoginUrlAuthenticationEntryPoint) etf.getAuthenticationEntryPoint()).getLoginFormUrl(); + this.logger.info("Checking whether login URL '" + loginPage + "' is accessible with your configuration"); FilterInvocation loginRequest = new FilterInvocation(loginPage, "POST"); List filters = null; - try { filters = fcp.getFilters(loginPage); } - catch (Exception e) { + catch (Exception ex) { // May happen legitimately if a filter-chain request matcher requires more // request data than that provided // by the dummy request used when creating the filter invocation. - logger.info("Failed to obtain filter chain information for the login page. Unable to complete check."); + this.logger.info("Failed to obtain filter chain information for the login page. Unable to complete check."); } - if (filters == null || filters.isEmpty()) { - logger.debug("Filter chain is empty for the login page"); + this.logger.debug("Filter chain is empty for the login page"); return; } - if (getFilter(DefaultLoginPageGeneratingFilter.class, filters) != null) { - logger.debug("Default generated login page is in use"); + this.logger.debug("Default generated login page is in use"); return; } - - FilterSecurityInterceptor fsi = getFilter(FilterSecurityInterceptor.class, - filters); + FilterSecurityInterceptor fsi = getFilter(FilterSecurityInterceptor.class, filters); FilterInvocationSecurityMetadataSource fids = fsi.getSecurityMetadataSource(); - Collection attributes = fids.getAttributes(loginRequest); - if (attributes == null) { - logger.debug("No access attributes defined for login page URL"); + this.logger.debug("No access attributes defined for login page URL"); if (fsi.isRejectPublicInvocations()) { - logger.warn("FilterSecurityInterceptor is configured to reject public invocations." + this.logger.warn("FilterSecurityInterceptor is configured to reject public invocations." + " Your login page may not be accessible."); } return; } - - AnonymousAuthenticationFilter anonPF = getFilter( - AnonymousAuthenticationFilter.class, filters); + AnonymousAuthenticationFilter anonPF = getFilter(AnonymousAuthenticationFilter.class, filters); if (anonPF == null) { - logger.warn("The login page is being protected by the filter chain, but you don't appear to have" + this.logger.warn("The login page is being protected by the filter chain, but you don't appear to have" + " anonymous authentication enabled. This is almost certainly an error."); return; } - // Simulate an anonymous access with the supplied attributes. - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("key", - anonPF.getPrincipal(), anonPF.getAuthorities()); + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("key", anonPF.getPrincipal(), + anonPF.getAuthorities()); try { fsi.getAccessDecisionManager().decide(token, loginRequest, attributes); } - catch (AccessDeniedException e) { - logger.warn("Anonymous access to the login page doesn't appear to be enabled. This is almost certainly " - + "an error. Please check your configuration allows unauthenticated access to the configured " - + "login page. (Simulated access was rejected: " + e + ")"); + catch (AccessDeniedException ex) { + this.logger.warn("Anonymous access to the login page doesn't appear to be enabled. " + + "This is almost certainly an error. Please check your configuration allows unauthenticated " + + "access to the configured login page. (Simulated access was rejected: " + ex + ")"); } - catch (Exception e) { + catch (Exception ex) { // May happen legitimately if a filter-chain request matcher requires more // request data than that provided // by the dummy request used when creating the filter invocation. See SEC-1878 - logger.info( - "Unable to check access to the login page to determine if anonymous access is allowed. This might be an error, but can happen under normal circumstances.", - e); + this.logger.info("Unable to check access to the login page to determine if anonymous access is allowed. " + + "This might be an error, but can happen under normal circumstances.", ex); } } diff --git a/config/src/main/java/org/springframework/security/config/http/FilterChainBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/FilterChainBeanDefinitionParser.java index ae5298b56e..c362bcc0a1 100644 --- a/config/src/main/java/org/springframework/security/config/http/FilterChainBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/FilterChainBeanDefinitionParser.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.Collections; + +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; @@ -24,25 +29,21 @@ import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; - -import java.util.*; /** * @author Luke Taylor */ public class FilterChainBeanDefinitionParser implements BeanDefinitionParser { + private static final String ATT_REQUEST_MATCHER_REF = "request-matcher-ref"; + @Override public BeanDefinition parse(Element elt, ParserContext pc) { MatcherType matcherType = MatcherType.fromElement(elt); String path = elt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN); String requestMatcher = elt.getAttribute(ATT_REQUEST_MATCHER_REF); String filters = elt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS); - - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(DefaultSecurityFilterChain.class); - + BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(DefaultSecurityFilterChain.class); if (StringUtils.hasText(path)) { Assert.isTrue(!StringUtils.hasText(requestMatcher), ""); builder.addConstructorArgValue(matcherType.createMatcher(pc, path, null)); @@ -51,22 +52,18 @@ public class FilterChainBeanDefinitionParser implements BeanDefinitionParser { Assert.isTrue(StringUtils.hasText(requestMatcher), ""); builder.addConstructorArgReference(requestMatcher); } - if (filters.equals(HttpSecurityBeanDefinitionParser.OPT_FILTERS_NONE)) { builder.addConstructorArgValue(Collections.EMPTY_LIST); } else { String[] filterBeanNames = StringUtils.tokenizeToStringArray(filters, ","); - ManagedList filterChain = new ManagedList<>( - filterBeanNames.length); - + ManagedList filterChain = new ManagedList<>(filterBeanNames.length); for (String name : filterBeanNames) { filterChain.add(new RuntimeBeanReference(name)); } - builder.addConstructorArgValue(filterChain); } - return builder.getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/http/FilterChainMapBeanDefinitionDecorator.java b/config/src/main/java/org/springframework/security/config/http/FilterChainMapBeanDefinitionDecorator.java index 35a43022e4..842ac8799f 100644 --- a/config/src/main/java/org/springframework/security/config/http/FilterChainMapBeanDefinitionDecorator.java +++ b/config/src/main/java/org/springframework/security/config/http/FilterChainMapBeanDefinitionDecorator.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.List; +import org.w3c.dom.Element; +import org.w3c.dom.Node; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; @@ -29,8 +33,6 @@ import org.springframework.security.config.Elements; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; -import org.w3c.dom.Node; /** * Sets the filter chain Map for a FilterChainProxy bean declaration. @@ -39,69 +41,48 @@ import org.w3c.dom.Node; */ public class FilterChainMapBeanDefinitionDecorator implements BeanDefinitionDecorator { + @Override @SuppressWarnings("unchecked") - public BeanDefinitionHolder decorate(Node node, BeanDefinitionHolder holder, - ParserContext parserContext) { + public BeanDefinitionHolder decorate(Node node, BeanDefinitionHolder holder, ParserContext parserContext) { BeanDefinition filterChainProxy = holder.getBeanDefinition(); - ManagedList securityFilterChains = new ManagedList<>(); Element elt = (Element) node; - MatcherType matcherType = MatcherType.fromElement(elt); - - List filterChainElts = DomUtils.getChildElementsByTagName(elt, - Elements.FILTER_CHAIN); - + List filterChainElts = DomUtils.getChildElementsByTagName(elt, Elements.FILTER_CHAIN); for (Element chain : filterChainElts) { - String path = chain - .getAttribute(HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN); - String filters = chain - .getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS); - + String path = chain.getAttribute(HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN); + String filters = chain.getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS); if (!StringUtils.hasText(path)) { parserContext.getReaderContext().error( - "The attribute '" - + HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN - + "' must not be empty", elt); + "The attribute '" + HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN + "' must not be empty", + elt); } - if (!StringUtils.hasText(filters)) { parserContext.getReaderContext().error( - "The attribute '" + HttpSecurityBeanDefinitionParser.ATT_FILTERS - + "'must not be empty", elt); + "The attribute '" + HttpSecurityBeanDefinitionParser.ATT_FILTERS + "'must not be empty", elt); } - BeanDefinition matcher = matcherType.createMatcher(parserContext, path, null); - if (filters.equals(HttpSecurityBeanDefinitionParser.OPT_FILTERS_NONE)) { - securityFilterChains.add(createSecurityFilterChain(matcher, - new ManagedList(0))); + securityFilterChains.add(createSecurityFilterChain(matcher, new ManagedList(0))); } else { - String[] filterBeanNames = StringUtils - .tokenizeToStringArray(filters, ","); + String[] filterBeanNames = StringUtils.tokenizeToStringArray(filters, ","); ManagedList filterChain = new ManagedList(filterBeanNames.length); - for (String name : filterBeanNames) { filterChain.add(new RuntimeBeanReference(name)); } - securityFilterChains.add(createSecurityFilterChain(matcher, filterChain)); } } - - filterChainProxy.getConstructorArgumentValues().addGenericArgumentValue( - securityFilterChains); - + filterChainProxy.getConstructorArgumentValues().addGenericArgumentValue(securityFilterChains); return holder; } - private BeanDefinition createSecurityFilterChain(BeanDefinition matcher, - ManagedList filters) { - BeanDefinitionBuilder sfc = BeanDefinitionBuilder - .rootBeanDefinition(DefaultSecurityFilterChain.class); + private BeanDefinition createSecurityFilterChain(BeanDefinition matcher, ManagedList filters) { + BeanDefinitionBuilder sfc = BeanDefinitionBuilder.rootBeanDefinition(DefaultSecurityFilterChain.class); sfc.addConstructorArgValue(matcher); sfc.addConstructorArgValue(filters); return sfc.getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/http/FilterInvocationSecurityMetadataSourceParser.java b/config/src/main/java/org/springframework/security/config/http/FilterInvocationSecurityMetadataSourceParser.java index dde778d41d..abc7929438 100644 --- a/config/src/main/java/org/springframework/security/config/http/FilterInvocationSecurityMetadataSourceParser.java +++ b/config/src/main/java/org/springframework/security/config/http/FilterInvocationSecurityMetadataSourceParser.java @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.http; -import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF; +package org.springframework.security.config.http; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -48,101 +48,83 @@ import org.springframework.util.xml.DomUtils; * @author Luke Taylor */ public class FilterInvocationSecurityMetadataSourceParser implements BeanDefinitionParser { + private static final String ATT_USE_EXPRESSIONS = "use-expressions"; + private static final String ATT_HTTP_METHOD = "method"; + private static final String ATT_PATTERN = "pattern"; + private static final String ATT_ACCESS = "access"; + private static final String ATT_SERVLET_PATH = "servlet-path"; - private static final Log logger = LogFactory - .getLog(FilterInvocationSecurityMetadataSourceParser.class); + private static final Log logger = LogFactory.getLog(FilterInvocationSecurityMetadataSourceParser.class); + + @Override public BeanDefinition parse(Element element, ParserContext parserContext) { - List interceptUrls = DomUtils.getChildElementsByTagName(element, - Elements.INTERCEPT_URL); - + List interceptUrls = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_URL); // Check for attributes that aren't allowed in this context for (Element elt : interceptUrls) { - if (StringUtils.hasLength(elt - .getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL))) { - parserContext.getReaderContext().error( - "The attribute '" - + HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL - + "' isn't allowed here.", elt); + if (StringUtils.hasLength(elt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL))) { + parserContext.getReaderContext().error("The attribute '" + + HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL + "' isn't allowed here.", elt); } - - if (StringUtils.hasLength(elt - .getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS))) { + if (StringUtils.hasLength(elt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS))) { parserContext.getReaderContext().error( - "The attribute '" + HttpSecurityBeanDefinitionParser.ATT_FILTERS - + "' isn't allowed here.", elt); + "The attribute '" + HttpSecurityBeanDefinitionParser.ATT_FILTERS + "' isn't allowed here.", + elt); } - if (StringUtils.hasLength(elt.getAttribute(ATT_SERVLET_PATH))) { - parserContext.getReaderContext().error( - "The attribute '" + ATT_SERVLET_PATH - + "' isn't allowed here.", elt); + parserContext.getReaderContext().error("The attribute '" + ATT_SERVLET_PATH + "' isn't allowed here.", + elt); } } - - BeanDefinition mds = createSecurityMetadataSource(interceptUrls, false, element, - parserContext); - + BeanDefinition mds = createSecurityMetadataSource(interceptUrls, false, element, parserContext); String id = element.getAttribute(AbstractBeanDefinitionParser.ID_ATTRIBUTE); - if (StringUtils.hasText(id)) { parserContext.registerComponent(new BeanComponentDefinition(mds, id)); parserContext.getRegistry().registerBeanDefinition(id, mds); } - return mds; } - static RootBeanDefinition createSecurityMetadataSource(List interceptUrls, - boolean addAllAuth, Element httpElt, ParserContext pc) { + static RootBeanDefinition createSecurityMetadataSource(List interceptUrls, boolean addAllAuth, + Element httpElt, ParserContext pc) { MatcherType matcherType = MatcherType.fromElement(httpElt); boolean useExpressions = isUseExpressions(httpElt); - ManagedMap requestToAttributesMap = parseInterceptUrlsForFilterInvocationRequestMap( matcherType, interceptUrls, useExpressions, addAllAuth, pc); BeanDefinitionBuilder fidsBuilder; - if (useExpressions) { - Element expressionHandlerElt = DomUtils.getChildElementByTagName(httpElt, - Elements.EXPRESSION_HANDLER); - String expressionHandlerRef = expressionHandlerElt == null ? null - : expressionHandlerElt.getAttribute("ref"); - + Element expressionHandlerElt = DomUtils.getChildElementByTagName(httpElt, Elements.EXPRESSION_HANDLER); + String expressionHandlerRef = (expressionHandlerElt != null) ? expressionHandlerElt.getAttribute("ref") + : null; if (StringUtils.hasText(expressionHandlerRef)) { - logger.info("Using bean '" + expressionHandlerRef - + "' as web SecurityExpressionHandler implementation"); + logger.info("Using bean '" + expressionHandlerRef + "' as web " + + "SecurityExpressionHandler implementation"); } else { expressionHandlerRef = registerDefaultExpressionHandler(pc); } - fidsBuilder = BeanDefinitionBuilder .rootBeanDefinition(ExpressionBasedFilterInvocationSecurityMetadataSource.class); fidsBuilder.addConstructorArgValue(requestToAttributesMap); fidsBuilder.addConstructorArgReference(expressionHandlerRef); } else { - fidsBuilder = BeanDefinitionBuilder - .rootBeanDefinition(DefaultFilterInvocationSecurityMetadataSource.class); + fidsBuilder = BeanDefinitionBuilder.rootBeanDefinition(DefaultFilterInvocationSecurityMetadataSource.class); fidsBuilder.addConstructorArgValue(requestToAttributesMap); } - fidsBuilder.getRawBeanDefinition().setSource(pc.extractSource(httpElt)); - return (RootBeanDefinition) fidsBuilder.getBeanDefinition(); } static String registerDefaultExpressionHandler(ParserContext pc) { - BeanDefinition expressionHandler = GrantedAuthorityDefaultsParserUtils.registerWithDefaultRolePrefix(pc, DefaultWebSecurityExpressionHandlerBeanFactory.class); - String expressionHandlerRef = pc.getReaderContext().generateBeanName( - expressionHandler); - pc.registerBeanComponent(new BeanComponentDefinition(expressionHandler, - expressionHandlerRef)); - + BeanDefinition expressionHandler = GrantedAuthorityDefaultsParserUtils.registerWithDefaultRolePrefix(pc, + DefaultWebSecurityExpressionHandlerBeanFactory.class); + String expressionHandlerRef = pc.getReaderContext().generateBeanName(expressionHandler); + pc.registerBeanComponent(new BeanComponentDefinition(expressionHandler, expressionHandlerRef)); return expressionHandlerRef; } @@ -152,47 +134,38 @@ public class FilterInvocationSecurityMetadataSourceParser implements BeanDefinit } private static ManagedMap parseInterceptUrlsForFilterInvocationRequestMap( - MatcherType matcherType, List urlElts, boolean useExpressions, - boolean addAuthenticatedAll, ParserContext parserContext) { - + MatcherType matcherType, List urlElts, boolean useExpressions, boolean addAuthenticatedAll, + ParserContext parserContext) { ManagedMap filterInvocationDefinitionMap = new ManagedMap<>(); - for (Element urlElt : urlElts) { String access = urlElt.getAttribute(ATT_ACCESS); if (!StringUtils.hasText(access)) { continue; } - String path = urlElt.getAttribute(ATT_PATTERN); - String matcherRef = urlElt.getAttribute(ATT_REQUEST_MATCHER_REF); + String matcherRef = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF); boolean hasMatcherRef = StringUtils.hasText(matcherRef); - if (!hasMatcherRef && !StringUtils.hasText(path)) { - parserContext.getReaderContext().error( - "path attribute cannot be empty or null", urlElt); + parserContext.getReaderContext().error("path attribute cannot be empty or null", urlElt); } - String method = urlElt.getAttribute(ATT_HTTP_METHOD); if (!StringUtils.hasText(method)) { method = null; } - String servletPath = urlElt.getAttribute(ATT_SERVLET_PATH); if (!StringUtils.hasText(servletPath)) { servletPath = null; - } else if (!MatcherType.mvc.equals(matcherType)) { - parserContext.getReaderContext().error( - ATT_SERVLET_PATH + " is not applicable for request-matcher: '" + matcherType.name() + "'", urlElt); } - - BeanMetadataElement matcher = hasMatcherRef ? new RuntimeBeanReference(matcherRef) : matcherType.createMatcher(parserContext, path, - method, servletPath); - BeanDefinitionBuilder attributeBuilder = BeanDefinitionBuilder - .rootBeanDefinition(SecurityConfig.class); - + else if (!MatcherType.mvc.equals(matcherType)) { + parserContext.getReaderContext().error( + ATT_SERVLET_PATH + " is not applicable for request-matcher: '" + matcherType.name() + "'", + urlElt); + } + BeanMetadataElement matcher = hasMatcherRef ? new RuntimeBeanReference(matcherRef) + : matcherType.createMatcher(parserContext, path, method, servletPath); + BeanDefinitionBuilder attributeBuilder = BeanDefinitionBuilder.rootBeanDefinition(SecurityConfig.class); if (useExpressions) { - logger.info("Creating access control expression attribute '" + access - + "' for " + path); + logger.info("Creating access control expression attribute '" + access + "' for " + path); // The single expression will be parsed later by the // ExpressionBasedFilterInvocationSecurityMetadataSource attributeBuilder.addConstructorArgValue(new String[] { access }); @@ -203,37 +176,32 @@ public class FilterInvocationSecurityMetadataSourceParser implements BeanDefinit attributeBuilder.addConstructorArgValue(access); attributeBuilder.setFactoryMethod("createListFromCommaDelimitedString"); } - if (filterInvocationDefinitionMap.containsKey(matcher)) { - logger.warn("Duplicate URL defined: " + path - + ". The original attribute values will be overwritten"); + logger.warn("Duplicate URL defined: " + path + ". The original attribute values will be overwritten"); } - - filterInvocationDefinitionMap.put(matcher, - attributeBuilder.getBeanDefinition()); + filterInvocationDefinitionMap.put(matcher, attributeBuilder.getBeanDefinition()); } - if (addAuthenticatedAll && filterInvocationDefinitionMap.isEmpty()) { - - BeanDefinition matcher = matcherType.createMatcher(parserContext, "/**", - null); - BeanDefinitionBuilder attributeBuilder = BeanDefinitionBuilder - .rootBeanDefinition(SecurityConfig.class); + BeanDefinition matcher = matcherType.createMatcher(parserContext, "/**", null); + BeanDefinitionBuilder attributeBuilder = BeanDefinitionBuilder.rootBeanDefinition(SecurityConfig.class); attributeBuilder.addConstructorArgValue(new String[] { "authenticated" }); attributeBuilder.setFactoryMethod("createList"); - filterInvocationDefinitionMap.put(matcher, - attributeBuilder.getBeanDefinition()); + filterInvocationDefinitionMap.put(matcher, attributeBuilder.getBeanDefinition()); } - return filterInvocationDefinitionMap; } - static class DefaultWebSecurityExpressionHandlerBeanFactory extends GrantedAuthorityDefaultsParserUtils.AbstractGrantedAuthorityDefaultsBeanFactory { + static class DefaultWebSecurityExpressionHandlerBeanFactory + extends GrantedAuthorityDefaultsParserUtils.AbstractGrantedAuthorityDefaultsBeanFactory { + private DefaultWebSecurityExpressionHandler handler = new DefaultWebSecurityExpressionHandler(); + @Override public DefaultWebSecurityExpressionHandler getBean() { - handler.setDefaultRolePrefix(this.rolePrefix); - return handler; + this.handler.setDefaultRolePrefix(this.rolePrefix); + return this.handler; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/http/FormLoginBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/FormLoginBeanDefinitionParser.java index 9bd44c7f7c..e29bed8283 100644 --- a/config/src/main/java/org/springframework/security/config/http/FormLoginBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/FormLoginBeanDefinitionParser.java @@ -13,19 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.ParserContext; -import org.springframework.security.web.authentication.*; +import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler; +import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler; +import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; +import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; +import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * @author Luke Taylor @@ -35,45 +41,64 @@ import org.w3c.dom.Element; * @author Shazin Sadakath */ public class FormLoginBeanDefinitionParser { + protected final Log logger = LogFactory.getLog(getClass()); private static final String ATT_LOGIN_URL = "login-processing-url"; static final String ATT_LOGIN_PAGE = "login-page"; + private static final String DEF_LOGIN_PAGE = DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL; private static final String ATT_FORM_LOGIN_TARGET_URL = "default-target-url"; + private static final String ATT_ALWAYS_USE_DEFAULT_TARGET_URL = "always-use-default-target"; + private static final String DEF_FORM_LOGIN_TARGET_URL = "/"; + private static final String ATT_USERNAME_PARAMETER = "username-parameter"; + private static final String ATT_PASSWORD_PARAMETER = "password-parameter"; private static final String ATT_FORM_LOGIN_AUTHENTICATION_FAILURE_URL = "authentication-failure-url"; + private static final String DEF_FORM_LOGIN_AUTHENTICATION_FAILURE_URL = DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL + "?" + DefaultLoginPageGeneratingFilter.ERROR_PARAMETER_NAME; private static final String ATT_SUCCESS_HANDLER_REF = "authentication-success-handler-ref"; + private static final String ATT_FAILURE_HANDLER_REF = "authentication-failure-handler-ref"; + private static final String ATT_FORM_LOGIN_AUTHENTICATION_FAILURE_FORWARD_URL = "authentication-failure-forward-url"; + private static final String ATT_FORM_LOGIN_AUTHENTICATION_SUCCESS_FORWARD_URL = "authentication-success-forward-url"; private final String defaultLoginProcessingUrl; + private final String filterClassName; + private final BeanReference requestCache; + private final BeanReference sessionStrategy; + private final boolean allowSessionCreation; + private final BeanReference portMapper; + private final BeanReference portResolver; private RootBeanDefinition filterBean; + private RootBeanDefinition entryPointBean; + private String loginPage; + private String loginMethod; + private String loginProcessingUrl; - FormLoginBeanDefinitionParser(String defaultLoginProcessingUrl, String loginMethod, - String filterClassName, BeanReference requestCache, - BeanReference sessionStrategy, boolean allowSessionCreation, + FormLoginBeanDefinitionParser(String defaultLoginProcessingUrl, String loginMethod, String filterClassName, + BeanReference requestCache, BeanReference sessionStrategy, boolean allowSessionCreation, BeanReference portMapper, BeanReference portResolver) { this.defaultLoginProcessingUrl = defaultLoginProcessingUrl; this.loginMethod = loginMethod; @@ -98,130 +123,103 @@ public class FormLoginBeanDefinitionParser { String authDetailsSourceRef = null; String authenticationFailureForwardUrl = null; String authenticationSuccessForwardUrl = null; - Object source = null; - if (elt != null) { source = pc.extractSource(elt); loginUrl = elt.getAttribute(ATT_LOGIN_URL); WebConfigUtils.validateHttpRedirect(loginUrl, pc, source); defaultTargetUrl = elt.getAttribute(ATT_FORM_LOGIN_TARGET_URL); WebConfigUtils.validateHttpRedirect(defaultTargetUrl, pc, source); - authenticationFailureUrl = elt - .getAttribute(ATT_FORM_LOGIN_AUTHENTICATION_FAILURE_URL); + authenticationFailureUrl = elt.getAttribute(ATT_FORM_LOGIN_AUTHENTICATION_FAILURE_URL); WebConfigUtils.validateHttpRedirect(authenticationFailureUrl, pc, source); alwaysUseDefault = elt.getAttribute(ATT_ALWAYS_USE_DEFAULT_TARGET_URL); - loginPage = elt.getAttribute(ATT_LOGIN_PAGE); + this.loginPage = elt.getAttribute(ATT_LOGIN_PAGE); successHandlerRef = elt.getAttribute(ATT_SUCCESS_HANDLER_REF); failureHandlerRef = elt.getAttribute(ATT_FAILURE_HANDLER_REF); - authDetailsSourceRef = elt - .getAttribute(AuthenticationConfigBuilder.ATT_AUTH_DETAILS_SOURCE_REF); + authDetailsSourceRef = elt.getAttribute(AuthenticationConfigBuilder.ATT_AUTH_DETAILS_SOURCE_REF); authenticationFailureForwardUrl = elt.getAttribute(ATT_FORM_LOGIN_AUTHENTICATION_FAILURE_FORWARD_URL); WebConfigUtils.validateHttpRedirect(authenticationFailureForwardUrl, pc, source); authenticationSuccessForwardUrl = elt.getAttribute(ATT_FORM_LOGIN_AUTHENTICATION_SUCCESS_FORWARD_URL); WebConfigUtils.validateHttpRedirect(authenticationSuccessForwardUrl, pc, source); - - if (!StringUtils.hasText(loginPage)) { - loginPage = null; + if (!StringUtils.hasText(this.loginPage)) { + this.loginPage = null; } - WebConfigUtils.validateHttpRedirect(loginPage, pc, source); + WebConfigUtils.validateHttpRedirect(this.loginPage, pc, source); usernameParameter = elt.getAttribute(ATT_USERNAME_PARAMETER); passwordParameter = elt.getAttribute(ATT_PASSWORD_PARAMETER); } - - filterBean = createFilterBean(loginUrl, defaultTargetUrl, alwaysUseDefault, - loginPage, authenticationFailureUrl, successHandlerRef, - failureHandlerRef, authDetailsSourceRef, authenticationFailureForwardUrl, authenticationSuccessForwardUrl); - + this.filterBean = createFilterBean(loginUrl, defaultTargetUrl, alwaysUseDefault, this.loginPage, + authenticationFailureUrl, successHandlerRef, failureHandlerRef, authDetailsSourceRef, + authenticationFailureForwardUrl, authenticationSuccessForwardUrl); if (StringUtils.hasText(usernameParameter)) { - filterBean.getPropertyValues().addPropertyValue("usernameParameter", - usernameParameter); + this.filterBean.getPropertyValues().addPropertyValue("usernameParameter", usernameParameter); } if (StringUtils.hasText(passwordParameter)) { - filterBean.getPropertyValues().addPropertyValue("passwordParameter", - passwordParameter); + this.filterBean.getPropertyValues().addPropertyValue("passwordParameter", passwordParameter); } - - filterBean.setSource(source); - + this.filterBean.setSource(source); BeanDefinitionBuilder entryPointBuilder = BeanDefinitionBuilder .rootBeanDefinition(LoginUrlAuthenticationEntryPoint.class); entryPointBuilder.getRawBeanDefinition().setSource(source); - entryPointBuilder.addConstructorArgValue(loginPage != null ? loginPage - : DEF_LOGIN_PAGE); - entryPointBuilder.addPropertyValue("portMapper", portMapper); - entryPointBuilder.addPropertyValue("portResolver", portResolver); - entryPointBean = (RootBeanDefinition) entryPointBuilder.getBeanDefinition(); - + entryPointBuilder.addConstructorArgValue((this.loginPage != null) ? this.loginPage : DEF_LOGIN_PAGE); + entryPointBuilder.addPropertyValue("portMapper", this.portMapper); + entryPointBuilder.addPropertyValue("portResolver", this.portResolver); + this.entryPointBean = (RootBeanDefinition) entryPointBuilder.getBeanDefinition(); return null; } - private RootBeanDefinition createFilterBean(String loginUrl, String defaultTargetUrl, - String alwaysUseDefault, String loginPage, String authenticationFailureUrl, - String successHandlerRef, String failureHandlerRef, - String authDetailsSourceRef, String authenticationFailureForwardUrl, String authenticationSuccessForwardUrl) { - - BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder - .rootBeanDefinition(filterClassName); - + private RootBeanDefinition createFilterBean(String loginUrl, String defaultTargetUrl, String alwaysUseDefault, + String loginPage, String authenticationFailureUrl, String successHandlerRef, String failureHandlerRef, + String authDetailsSourceRef, String authenticationFailureForwardUrl, + String authenticationSuccessForwardUrl) { + BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder.rootBeanDefinition(this.filterClassName); if (!StringUtils.hasText(loginUrl)) { - loginUrl = defaultLoginProcessingUrl; + loginUrl = this.defaultLoginProcessingUrl; } - this.loginProcessingUrl = loginUrl; - BeanDefinitionBuilder matcherBuilder = BeanDefinitionBuilder .rootBeanDefinition("org.springframework.security.web.util.matcher.AntPathRequestMatcher"); matcherBuilder.addConstructorArgValue(loginUrl); - if (loginMethod != null) { + if (this.loginMethod != null) { matcherBuilder.addConstructorArgValue("POST"); } - - filterBuilder.addPropertyValue("requiresAuthenticationRequestMatcher", - matcherBuilder.getBeanDefinition()); - + filterBuilder.addPropertyValue("requiresAuthenticationRequestMatcher", matcherBuilder.getBeanDefinition()); if (StringUtils.hasText(successHandlerRef)) { - filterBuilder.addPropertyReference("authenticationSuccessHandler", - successHandlerRef); - } else if (StringUtils.hasText(authenticationSuccessForwardUrl)) { + filterBuilder.addPropertyReference("authenticationSuccessHandler", successHandlerRef); + } + else if (StringUtils.hasText(authenticationSuccessForwardUrl)) { BeanDefinitionBuilder forwardSuccessHandler = BeanDefinitionBuilder .rootBeanDefinition(ForwardAuthenticationSuccessHandler.class); forwardSuccessHandler.addConstructorArgValue(authenticationSuccessForwardUrl); filterBuilder.addPropertyValue("authenticationSuccessHandler", forwardSuccessHandler.getBeanDefinition()); - } else { + } + else { BeanDefinitionBuilder successHandler = BeanDefinitionBuilder .rootBeanDefinition(SavedRequestAwareAuthenticationSuccessHandler.class); if ("true".equals(alwaysUseDefault)) { - successHandler - .addPropertyValue("alwaysUseDefaultTargetUrl", Boolean.TRUE); + successHandler.addPropertyValue("alwaysUseDefaultTargetUrl", Boolean.TRUE); } - successHandler.addPropertyValue("requestCache", requestCache); - successHandler.addPropertyValue("defaultTargetUrl", StringUtils - .hasText(defaultTargetUrl) ? defaultTargetUrl - : DEF_FORM_LOGIN_TARGET_URL); - filterBuilder.addPropertyValue("authenticationSuccessHandler", - successHandler.getBeanDefinition()); + successHandler.addPropertyValue("requestCache", this.requestCache); + successHandler.addPropertyValue("defaultTargetUrl", + StringUtils.hasText(defaultTargetUrl) ? defaultTargetUrl : DEF_FORM_LOGIN_TARGET_URL); + filterBuilder.addPropertyValue("authenticationSuccessHandler", successHandler.getBeanDefinition()); } - if (StringUtils.hasText(authDetailsSourceRef)) { - filterBuilder.addPropertyReference("authenticationDetailsSource", - authDetailsSourceRef); + filterBuilder.addPropertyReference("authenticationDetailsSource", authDetailsSourceRef); } - - if (sessionStrategy != null) { - filterBuilder.addPropertyValue("sessionAuthenticationStrategy", - sessionStrategy); + if (this.sessionStrategy != null) { + filterBuilder.addPropertyValue("sessionAuthenticationStrategy", this.sessionStrategy); } - if (StringUtils.hasText(failureHandlerRef)) { - filterBuilder.addPropertyReference("authenticationFailureHandler", - failureHandlerRef); - } else if (StringUtils.hasText(authenticationFailureForwardUrl)) { + filterBuilder.addPropertyReference("authenticationFailureHandler", failureHandlerRef); + } + else if (StringUtils.hasText(authenticationFailureForwardUrl)) { BeanDefinitionBuilder forwardFailureHandler = BeanDefinitionBuilder .rootBeanDefinition(ForwardAuthenticationFailureHandler.class); forwardFailureHandler.addConstructorArgValue(authenticationFailureForwardUrl); filterBuilder.addPropertyValue("authenticationFailureHandler", forwardFailureHandler.getBeanDefinition()); - } else { + } + else { BeanDefinitionBuilder failureHandler = BeanDefinitionBuilder .rootBeanDefinition(SimpleUrlAuthenticationFailureHandler.class); if (!StringUtils.hasText(authenticationFailureUrl)) { @@ -233,29 +231,27 @@ public class FormLoginBeanDefinitionParser { authenticationFailureUrl = DEF_FORM_LOGIN_AUTHENTICATION_FAILURE_URL; } } - failureHandler - .addPropertyValue("defaultFailureUrl", authenticationFailureUrl); - failureHandler.addPropertyValue("allowSessionCreation", allowSessionCreation); - filterBuilder.addPropertyValue("authenticationFailureHandler", - failureHandler.getBeanDefinition()); + failureHandler.addPropertyValue("defaultFailureUrl", authenticationFailureUrl); + failureHandler.addPropertyValue("allowSessionCreation", this.allowSessionCreation); + filterBuilder.addPropertyValue("authenticationFailureHandler", failureHandler.getBeanDefinition()); } - return (RootBeanDefinition) filterBuilder.getBeanDefinition(); } RootBeanDefinition getFilterBean() { - return filterBean; + return this.filterBean; } RootBeanDefinition getEntryPointBean() { - return entryPointBean; + return this.entryPointBean; } String getLoginPage() { - return loginPage; + return this.loginPage; } String getLoginProcessingUrl() { - return loginProcessingUrl; + return this.loginProcessingUrl; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/GrantedAuthorityDefaultsParserUtils.java b/config/src/main/java/org/springframework/security/config/http/GrantedAuthorityDefaultsParserUtils.java index 39f1ae1e00..ce32607100 100644 --- a/config/src/main/java/org/springframework/security/config/http/GrantedAuthorityDefaultsParserUtils.java +++ b/config/src/main/java/org/springframework/security/config/http/GrantedAuthorityDefaultsParserUtils.java @@ -27,35 +27,39 @@ import org.springframework.security.config.core.GrantedAuthorityDefaults; * @author Rob Winch * @since 4.2 */ -class GrantedAuthorityDefaultsParserUtils { +final class GrantedAuthorityDefaultsParserUtils { + private GrantedAuthorityDefaultsParserUtils() { + } - static RootBeanDefinition registerWithDefaultRolePrefix(ParserContext pc, Class beanFactoryClass) { + static RootBeanDefinition registerWithDefaultRolePrefix(ParserContext pc, + Class beanFactoryClass) { RootBeanDefinition beanFactoryDefinition = new RootBeanDefinition(beanFactoryClass); String beanFactoryRef = pc.getReaderContext().generateBeanName(beanFactoryDefinition); pc.getRegistry().registerBeanDefinition(beanFactoryRef, beanFactoryDefinition); - RootBeanDefinition bean = new RootBeanDefinition(); bean.setFactoryBeanName(beanFactoryRef); bean.setFactoryMethodName("getBean"); return bean; } - static abstract class AbstractGrantedAuthorityDefaultsBeanFactory implements ApplicationContextAware { + abstract static class AbstractGrantedAuthorityDefaultsBeanFactory implements ApplicationContextAware { + protected String rolePrefix = "ROLE_"; @Override - public final void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - String[] grantedAuthorityDefaultsBeanNames = applicationContext.getBeanNamesForType(GrantedAuthorityDefaults.class); + public final void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + String[] grantedAuthorityDefaultsBeanNames = applicationContext + .getBeanNamesForType(GrantedAuthorityDefaults.class); if (grantedAuthorityDefaultsBeanNames.length == 1) { - GrantedAuthorityDefaults grantedAuthorityDefaults = applicationContext.getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); + GrantedAuthorityDefaults grantedAuthorityDefaults = applicationContext + .getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); this.rolePrefix = grantedAuthorityDefaults.getRolePrefix(); } } abstract Object getBean(); + } - private GrantedAuthorityDefaultsParserUtils() {} } diff --git a/config/src/main/java/org/springframework/security/config/http/HandlerMappingIntrospectorFactoryBean.java b/config/src/main/java/org/springframework/security/config/http/HandlerMappingIntrospectorFactoryBean.java index 3518bf8d0b..302e95834a 100644 --- a/config/src/main/java/org/springframework/security/config/http/HandlerMappingIntrospectorFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/http/HandlerMappingIntrospectorFactoryBean.java @@ -31,15 +31,21 @@ import org.springframework.web.servlet.handler.HandlerMappingIntrospector; * @author Rob Winch * @since 4.1.1 */ -class HandlerMappingIntrospectorFactoryBean implements FactoryBean, ApplicationContextAware { +class HandlerMappingIntrospectorFactoryBean + implements FactoryBean, ApplicationContextAware { + private static final String HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME = "mvcHandlerMappingIntrospector"; private ApplicationContext context; + @Override public HandlerMappingIntrospector getObject() { if (!this.context.containsBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME)) { - throw new NoSuchBeanDefinitionException(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, "A Bean named " + HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME +" of type " + HandlerMappingIntrospector.class.getName() - + " is required to use MvcRequestMatcher. Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext."); + throw new NoSuchBeanDefinitionException(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, + "A Bean named " + HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME + " of type " + + HandlerMappingIntrospector.class.getName() + + " is required to use MvcRequestMatcher. Please ensure Spring Security & Spring " + + "MVC are configured in a shared ApplicationContext."); } return this.context.getBean(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME, HandlerMappingIntrospector.class); } @@ -50,15 +56,8 @@ class HandlerMappingIntrospectorFactoryBean implements FactoryBean headerWriters; + @Override public BeanDefinition parse(Element element, ParserContext parserContext) { - - headerWriters = new ManagedList<>(); - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(HeaderWriterFilter.class); - - boolean disabled = element != null - && "true".equals(resolveAttribute(parserContext, element, "disabled")); + this.headerWriters = new ManagedList<>(); + BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(HeaderWriterFilter.class); + boolean disabled = element != null && "true".equals(resolveAttribute(parserContext, element, "disabled")); boolean defaultsDisabled = element != null && "true".equals(resolveAttribute(parserContext, element, "defaults-disabled")); - boolean addIfNotPresent = element == null || !disabled && !defaultsDisabled; - parseCacheControlElement(addIfNotPresent, element); parseHstsElement(addIfNotPresent, element, parserContext); parseXssElement(addIfNotPresent, element, parserContext); parseFrameOptionsElement(addIfNotPresent, element, parserContext); parseContentTypeOptionsElement(addIfNotPresent, element); - parseHpkpElement(element == null || !disabled, element, parserContext); - parseContentSecurityPolicyElement(disabled, element, parserContext); - parseReferrerPolicyElement(element, parserContext); - parseFeaturePolicyElement(element, parserContext); - parseHeaderElements(element); - - boolean noWriters = headerWriters.isEmpty(); + boolean noWriters = this.headerWriters.isEmpty(); if (disabled && !noWriters) { - parserContext - .getReaderContext() - .error("Cannot specify with child elements.", - element); - } else if (noWriters) { + parserContext.getReaderContext().error("Cannot specify with child elements.", + element); + } + else if (noWriters) { return null; } - - builder.addConstructorArgValue(headerWriters); + builder.addConstructorArgValue(this.headerWriters); return builder.getBeanDefinition(); } /** - * * Resolve the placeholder for a given attribute on a element. - * * @param pc * @param element * @param attributeName @@ -152,10 +165,9 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { } private void parseCacheControlElement(boolean addIfNotPresent, Element element) { - Element cacheControlElement = element == null ? null : DomUtils - .getChildElementByTagName(element, CACHE_CONTROL_ELEMENT); - boolean disabled = "true".equals(getAttribute(cacheControlElement, ATT_DISABLED, - "false")); + Element cacheControlElement = (element != null) + ? DomUtils.getChildElementByTagName(element, CACHE_CONTROL_ELEMENT) : null; + boolean disabled = "true".equals(getAttribute(cacheControlElement, ATT_DISABLED, "false")); if (disabled) { return; } @@ -167,69 +179,59 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { private void addCacheControl() { BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder .genericBeanDefinition(CacheControlHeadersWriter.class); - headerWriters.add(headersWriter.getBeanDefinition()); + this.headerWriters.add(headersWriter.getBeanDefinition()); } - private void parseHstsElement(boolean addIfNotPresent, Element element, - ParserContext context) { - Element hstsElement = element == null ? null : DomUtils.getChildElementByTagName( - element, HSTS_ELEMENT); + private void parseHstsElement(boolean addIfNotPresent, Element element, ParserContext context) { + Element hstsElement = (element != null) ? DomUtils.getChildElementByTagName(element, HSTS_ELEMENT) : null; if (addIfNotPresent || hstsElement != null) { addHsts(addIfNotPresent, hstsElement, context); } } - private void addHsts(boolean addIfNotPresent, Element hstsElement, - ParserContext context) { - BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder - .genericBeanDefinition(HstsHeaderWriter.class); + private void addHsts(boolean addIfNotPresent, Element hstsElement, ParserContext context) { + BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder.genericBeanDefinition(HstsHeaderWriter.class); if (hstsElement != null) { - boolean disabled = "true".equals(getAttribute(hstsElement, ATT_DISABLED, - "false")); + boolean disabled = "true".equals(getAttribute(hstsElement, ATT_DISABLED, "false")); String includeSubDomains = hstsElement.getAttribute(ATT_INCLUDE_SUBDOMAINS); if (StringUtils.hasText(includeSubDomains)) { if (disabled) { - attrNotAllowed(context, ATT_INCLUDE_SUBDOMAINS, ATT_DISABLED, - hstsElement); + attrNotAllowed(context, ATT_INCLUDE_SUBDOMAINS, ATT_DISABLED, hstsElement); } headersWriter.addPropertyValue("includeSubDomains", includeSubDomains); } String maxAgeSeconds = hstsElement.getAttribute(ATT_MAX_AGE_SECONDS); if (StringUtils.hasText(maxAgeSeconds)) { if (disabled) { - attrNotAllowed(context, ATT_MAX_AGE_SECONDS, ATT_DISABLED, - hstsElement); + attrNotAllowed(context, ATT_MAX_AGE_SECONDS, ATT_DISABLED, hstsElement); } headersWriter.addPropertyValue("maxAgeInSeconds", maxAgeSeconds); } String requestMatcherRef = hstsElement.getAttribute(ATT_REQUEST_MATCHER_REF); if (StringUtils.hasText(requestMatcherRef)) { if (disabled) { - attrNotAllowed(context, ATT_REQUEST_MATCHER_REF, ATT_DISABLED, - hstsElement); + attrNotAllowed(context, ATT_REQUEST_MATCHER_REF, ATT_DISABLED, hstsElement); } headersWriter.addPropertyReference("requestMatcher", requestMatcherRef); } String preload = hstsElement.getAttribute(ATT_PRELOAD); if (StringUtils.hasText(preload)) { if (disabled) { - attrNotAllowed(context, ATT_PRELOAD, ATT_DISABLED, - hstsElement); + attrNotAllowed(context, ATT_PRELOAD, ATT_DISABLED, hstsElement); } headersWriter.addPropertyValue("preload", preload); } - - if (disabled == true) { + if (disabled) { return; } } if (addIfNotPresent || hstsElement != null) { - headerWriters.add(headersWriter.getBeanDefinition()); + this.headerWriters.add(headersWriter.getBeanDefinition()); } } private void parseHpkpElement(boolean addIfNotPresent, Element element, ParserContext context) { - Element hpkpElement = element == null ? null : DomUtils.getChildElementByTagName(element, HPKP_ELEMENT); + Element hpkpElement = (element != null) ? DomUtils.getChildElementByTagName(element, HPKP_ELEMENT) : null; if (addIfNotPresent || hpkpElement != null) { addHpkp(addIfNotPresent, hpkpElement, context); } @@ -238,68 +240,54 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { private void addHpkp(boolean addIfNotPresent, Element hpkpElement, ParserContext context) { if (hpkpElement != null) { boolean disabled = "true".equals(getAttribute(hpkpElement, ATT_DISABLED, "false")); - if (disabled) { return; } - BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder.genericBeanDefinition(HpkpHeaderWriter.class); - Element pinsElement = DomUtils.getChildElementByTagName(hpkpElement, PINS_ELEMENT); if (pinsElement != null) { List pinElements = DomUtils.getChildElements(pinsElement); - Map pins = new LinkedHashMap<>(); - for (Element pinElement : pinElements) { String hash = pinElement.getAttribute(ATT_ALGORITHM); if (!StringUtils.hasText(hash)) { hash = "sha256"; } - Node pinValueNode = pinElement.getFirstChild(); if (pinValueNode == null) { context.getReaderContext().warning("Missing value for pin entry.", hpkpElement); continue; } - String fingerprint = pinElement.getFirstChild().getTextContent(); - pins.put(fingerprint, hash); } - headersWriter.addPropertyValue("pins", pins); } - String includeSubDomains = hpkpElement.getAttribute(ATT_INCLUDE_SUBDOMAINS); if (StringUtils.hasText(includeSubDomains)) { headersWriter.addPropertyValue("includeSubDomains", includeSubDomains); } - String maxAgeSeconds = hpkpElement.getAttribute(ATT_MAX_AGE_SECONDS); if (StringUtils.hasText(maxAgeSeconds)) { headersWriter.addPropertyValue("maxAgeInSeconds", maxAgeSeconds); } - String reportOnly = hpkpElement.getAttribute(ATT_REPORT_ONLY); if (StringUtils.hasText(reportOnly)) { headersWriter.addPropertyValue("reportOnly", reportOnly); } - String reportUri = hpkpElement.getAttribute(ATT_REPORT_URI); if (StringUtils.hasText(reportUri)) { headersWriter.addPropertyValue("reportUri", reportUri); } - if (addIfNotPresent) { - headerWriters.add(headersWriter.getBeanDefinition()); + this.headerWriters.add(headersWriter.getBeanDefinition()); } } } private void parseContentSecurityPolicyElement(boolean elementDisabled, Element element, ParserContext context) { - Element contentSecurityPolicyElement = (elementDisabled || element == null) ? null : DomUtils.getChildElementByTagName( - element, CONTENT_SECURITY_POLICY_ELEMENT); + Element contentSecurityPolicyElement = (elementDisabled || element == null) ? null + : DomUtils.getChildElementByTagName(element, CONTENT_SECURITY_POLICY_ELEMENT); if (contentSecurityPolicyElement != null) { addContentSecurityPolicy(contentSecurityPolicyElement, context); } @@ -308,43 +296,42 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { private void addContentSecurityPolicy(Element contentSecurityPolicyElement, ParserContext context) { BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder .genericBeanDefinition(ContentSecurityPolicyHeaderWriter.class); - String policyDirectives = contentSecurityPolicyElement.getAttribute(ATT_POLICY_DIRECTIVES); if (!StringUtils.hasText(policyDirectives)) { - context.getReaderContext().error( - ATT_POLICY_DIRECTIVES + " requires a 'value' to be set.", contentSecurityPolicyElement); - } else { + context.getReaderContext().error(ATT_POLICY_DIRECTIVES + " requires a 'value' to be set.", + contentSecurityPolicyElement); + } + else { headersWriter.addConstructorArgValue(policyDirectives); } - String reportOnly = contentSecurityPolicyElement.getAttribute(ATT_REPORT_ONLY); if (StringUtils.hasText(reportOnly)) { headersWriter.addPropertyValue("reportOnly", reportOnly); } - - headerWriters.add(headersWriter.getBeanDefinition()); + this.headerWriters.add(headersWriter.getBeanDefinition()); } private void parseReferrerPolicyElement(Element element, ParserContext context) { - Element referrerPolicyElement = (element == null) ? null : DomUtils.getChildElementByTagName(element, REFERRER_POLICY_ELEMENT); + Element referrerPolicyElement = (element != null) + ? DomUtils.getChildElementByTagName(element, REFERRER_POLICY_ELEMENT) : null; if (referrerPolicyElement != null) { addReferrerPolicy(referrerPolicyElement, context); } } private void addReferrerPolicy(Element referrerPolicyElement, ParserContext context) { - BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder.genericBeanDefinition(ReferrerPolicyHeaderWriter.class); - + BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder + .genericBeanDefinition(ReferrerPolicyHeaderWriter.class); String policy = referrerPolicyElement.getAttribute(ATT_POLICY); if (StringUtils.hasLength(policy)) { headersWriter.addConstructorArgValue(ReferrerPolicy.get(policy)); } - headerWriters.add(headersWriter.getBeanDefinition()); + this.headerWriters.add(headersWriter.getBeanDefinition()); } private void parseFeaturePolicyElement(Element element, ParserContext context) { - Element featurePolicyElement = (element == null) ? null - : DomUtils.getChildElementByTagName(element, FEATURE_POLICY_ELEMENT); + Element featurePolicyElement = (element != null) + ? DomUtils.getChildElementByTagName(element, FEATURE_POLICY_ELEMENT) : null; if (featurePolicyElement != null) { addFeaturePolicy(featurePolicyElement, context); } @@ -353,51 +340,43 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { private void addFeaturePolicy(Element featurePolicyElement, ParserContext context) { BeanDefinitionBuilder headersWriter = BeanDefinitionBuilder .genericBeanDefinition(FeaturePolicyHeaderWriter.class); - - String policyDirectives = featurePolicyElement - .getAttribute(ATT_POLICY_DIRECTIVES); + String policyDirectives = featurePolicyElement.getAttribute(ATT_POLICY_DIRECTIVES); if (!StringUtils.hasText(policyDirectives)) { - context.getReaderContext().error( - ATT_POLICY_DIRECTIVES + " requires a 'value' to be set.", + context.getReaderContext().error(ATT_POLICY_DIRECTIVES + " requires a 'value' to be set.", featurePolicyElement); } else { headersWriter.addConstructorArgValue(policyDirectives); } - - headerWriters.add(headersWriter.getBeanDefinition()); + this.headerWriters.add(headersWriter.getBeanDefinition()); } - private void attrNotAllowed(ParserContext context, String attrName, - String otherAttrName, Element element) { - context.getReaderContext().error( - "Only one of '" + attrName + "' or '" + otherAttrName + "' can be set.", + private void attrNotAllowed(ParserContext context, String attrName, String otherAttrName, Element element) { + context.getReaderContext().error("Only one of '" + attrName + "' or '" + otherAttrName + "' can be set.", element); } private void parseHeaderElements(Element element) { - List headerElts = element == null ? Collections. emptyList() - : DomUtils.getChildElementsByTagName(element, GENERIC_HEADER_ELEMENT); + List headerElts = (element != null) + ? DomUtils.getChildElementsByTagName(element, GENERIC_HEADER_ELEMENT) : Collections.emptyList(); for (Element headerElt : headerElts) { String headerFactoryRef = headerElt.getAttribute(ATT_REF); if (StringUtils.hasText(headerFactoryRef)) { - headerWriters.add(new RuntimeBeanReference(headerFactoryRef)); + this.headerWriters.add(new RuntimeBeanReference(headerFactoryRef)); } else { - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .genericBeanDefinition(StaticHeadersWriter.class); + BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class); builder.addConstructorArgValue(headerElt.getAttribute(ATT_NAME)); builder.addConstructorArgValue(headerElt.getAttribute(ATT_VALUE)); - headerWriters.add(builder.getBeanDefinition()); + this.headerWriters.add(builder.getBeanDefinition()); } } } private void parseContentTypeOptionsElement(boolean addIfNotPresent, Element element) { - Element contentTypeElt = element == null ? null : DomUtils - .getChildElementByTagName(element, CONTENT_TYPE_ELEMENT); - boolean disabled = "true".equals(getAttribute(contentTypeElt, ATT_DISABLED, - "false")); + Element contentTypeElt = (element != null) ? DomUtils.getChildElementByTagName(element, CONTENT_TYPE_ELEMENT) + : null; + boolean disabled = "true".equals(getAttribute(contentTypeElt, ATT_DISABLED, "false")); if (disabled) { return; } @@ -409,108 +388,93 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { private void addContentTypeOptions() { BeanDefinitionBuilder builder = BeanDefinitionBuilder .genericBeanDefinition(XContentTypeOptionsHeaderWriter.class); - headerWriters.add(builder.getBeanDefinition()); + this.headerWriters.add(builder.getBeanDefinition()); } - private void parseFrameOptionsElement(boolean addIfNotPresent, Element element, - ParserContext parserContext) { - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .genericBeanDefinition(XFrameOptionsHeaderWriter.class); - - Element frameElt = element == null ? null : DomUtils.getChildElementByTagName( - element, FRAME_OPTIONS_ELEMENT); - if (frameElt != null) { - String header = getAttribute(frameElt, ATT_POLICY, null); - boolean disabled = "true" - .equals(getAttribute(frameElt, ATT_DISABLED, "false")); - - if (disabled && header != null) { - this.attrNotAllowed(parserContext, ATT_DISABLED, ATT_POLICY, frameElt); - } - if (!StringUtils.hasText(header)) { - header = "DENY"; - } - - if (ALLOW_FROM.equals(header)) { - String strategyRef = getAttribute(frameElt, ATT_REF, null); - String strategy = getAttribute(frameElt, ATT_STRATEGY, null); - - if (StringUtils.hasText(strategy) && StringUtils.hasText(strategyRef)) { - parserContext.getReaderContext().error( - "Only one of 'strategy' or 'strategy-ref' can be set.", - frameElt); - } - else if (strategyRef != null) { - builder.addConstructorArgReference(strategyRef); - } - else if (strategy != null) { - String value = getAttribute(frameElt, ATT_VALUE, null); - if (!StringUtils.hasText(value)) { - parserContext.getReaderContext().error( - "Strategy requires a 'value' to be set.", frameElt); - } - // static, whitelist, regexp - if ("static".equals(strategy)) { - try { - builder.addConstructorArgValue(new StaticAllowFromStrategy( - new URI(value))); - } - catch (URISyntaxException e) { - parserContext.getReaderContext().error( - "'value' attribute doesn't represent a valid URI.", - frameElt, e); - } - } - else { - BeanDefinitionBuilder allowFromStrategy; - if ("whitelist".equals(strategy)) { - allowFromStrategy = BeanDefinitionBuilder - .rootBeanDefinition(WhiteListedAllowFromStrategy.class); - allowFromStrategy.addConstructorArgValue(StringUtils - .commaDelimitedListToSet(value)); - } - else { - allowFromStrategy = BeanDefinitionBuilder - .rootBeanDefinition(RegExpAllowFromStrategy.class); - allowFromStrategy.addConstructorArgValue(value); - } - String fromParameter = getAttribute(frameElt, ATT_FROM_PARAMETER, - "from"); - allowFromStrategy.addPropertyValue("allowFromParameterName", - fromParameter); - builder.addConstructorArgValue(allowFromStrategy - .getBeanDefinition()); - } - } - else { - parserContext.getReaderContext() - .error("One of 'strategy' and 'strategy-ref' must be set.", - frameElt); - } - } - else { - builder.addConstructorArgValue(header); - } - - if (disabled) { - return; + private void parseFrameOptionsElement(boolean addIfNotPresent, Element element, ParserContext parserContext) { + BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(XFrameOptionsHeaderWriter.class); + Element frameElement = (element != null) ? DomUtils.getChildElementByTagName(element, FRAME_OPTIONS_ELEMENT) + : null; + if (frameElement == null) { + if (addIfNotPresent) { + this.headerWriters.add(builder.getBeanDefinition()); } + return; } - - if (addIfNotPresent || frameElt != null) { - headerWriters.add(builder.getBeanDefinition()); + String header = getAttribute(frameElement, ATT_POLICY, null); + boolean disabled = "true".equals(getAttribute(frameElement, ATT_DISABLED, "false")); + if (disabled && header != null) { + this.attrNotAllowed(parserContext, ATT_DISABLED, ATT_POLICY, frameElement); + } + header = StringUtils.hasText(header) ? header : "DENY"; + if (ALLOW_FROM.equals(header)) { + parseAllowFromFrameOptionsElement(parserContext, builder, frameElement); + } + else { + builder.addConstructorArgValue(header); + } + if (!disabled) { + this.headerWriters.add(builder.getBeanDefinition()); } } - private void parseXssElement(boolean addIfNotPresent, Element element, - ParserContext parserContext) { - Element xssElt = element == null ? null : DomUtils.getChildElementByTagName( - element, XSS_ELEMENT); - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .genericBeanDefinition(XXssProtectionHeaderWriter.class); + private void parseAllowFromFrameOptionsElement(ParserContext parserContext, BeanDefinitionBuilder builder, + Element frameElement) { + String strategyRef = getAttribute(frameElement, ATT_REF, null); + String strategy = getAttribute(frameElement, ATT_STRATEGY, null); + if (StringUtils.hasText(strategy) && StringUtils.hasText(strategyRef)) { + parserContext.getReaderContext().error("Only one of 'strategy' or 'strategy-ref' can be set.", + frameElement); + return; + } + if (strategyRef != null) { + builder.addConstructorArgReference(strategyRef); + return; + } + if (strategy == null) { + parserContext.getReaderContext().error("One of 'strategy' and 'strategy-ref' must be set.", frameElement); + return; + } + String value = getAttribute(frameElement, ATT_VALUE, null); + if (!StringUtils.hasText(value)) { + parserContext.getReaderContext().error("Strategy requires a 'value' to be set.", frameElement); + return; + } + // static, whitelist, regexp + if ("static".equals(strategy)) { + try { + builder.addConstructorArgValue(new StaticAllowFromStrategy(new URI(value))); + } + catch (URISyntaxException ex) { + parserContext.getReaderContext().error("'value' attribute doesn't represent a valid URI.", frameElement, + ex); + } + return; + } + BeanDefinitionBuilder allowFromStrategy = getAllowFromStrategy(strategy, value); + String fromParameter = getAttribute(frameElement, ATT_FROM_PARAMETER, "from"); + allowFromStrategy.addPropertyValue("allowFromParameterName", fromParameter); + builder.addConstructorArgValue(allowFromStrategy.getBeanDefinition()); + } + + private BeanDefinitionBuilder getAllowFromStrategy(String strategy, String value) { + if ("whitelist".equals(strategy)) { + BeanDefinitionBuilder allowFromStrategy = BeanDefinitionBuilder + .rootBeanDefinition(WhiteListedAllowFromStrategy.class); + allowFromStrategy.addConstructorArgValue(StringUtils.commaDelimitedListToSet(value)); + return allowFromStrategy; + } + BeanDefinitionBuilder allowFromStrategy; + allowFromStrategy = BeanDefinitionBuilder.rootBeanDefinition(RegExpAllowFromStrategy.class); + allowFromStrategy.addConstructorArgValue(value); + return allowFromStrategy; + } + + private void parseXssElement(boolean addIfNotPresent, Element element, ParserContext parserContext) { + Element xssElt = (element != null) ? DomUtils.getChildElementByTagName(element, XSS_ELEMENT) : null; + BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(XXssProtectionHeaderWriter.class); if (xssElt != null) { boolean disabled = "true".equals(getAttribute(xssElt, ATT_DISABLED, "false")); - String enabled = xssElt.getAttribute(ATT_ENABLED); if (StringUtils.hasText(enabled)) { if (disabled) { @@ -518,7 +482,6 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { } builder.addPropertyValue("enabled", enabled); } - String block = xssElt.getAttribute(ATT_BLOCK); if (StringUtils.hasText(block)) { if (disabled) { @@ -526,13 +489,12 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { } builder.addPropertyValue("block", block); } - if (disabled) { return; } } if (addIfNotPresent || xssElt != null) { - headerWriters.add(builder.getBeanDefinition()); + this.headerWriters.add(builder.getBeanDefinition()); } } @@ -544,8 +506,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser { if (StringUtils.hasText(value)) { return value; } - else { - return defaultValue; - } + return defaultValue; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java b/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java index 2711ebdb9f..992d5c60a1 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.ArrayList; @@ -70,28 +71,11 @@ import org.springframework.security.web.session.SessionManagementFilter; import org.springframework.security.web.session.SimpleRedirectInvalidSessionStrategy; import org.springframework.security.web.session.SimpleRedirectSessionInformationExpiredStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_FILTERS; -import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_HTTP_METHOD; -import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN; -import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF; -import static org.springframework.security.config.http.HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL; -import static org.springframework.security.config.http.SecurityFilters.CHANNEL_FILTER; -import static org.springframework.security.config.http.SecurityFilters.CONCURRENT_SESSION_FILTER; -import static org.springframework.security.config.http.SecurityFilters.CORS_FILTER; -import static org.springframework.security.config.http.SecurityFilters.CSRF_FILTER; -import static org.springframework.security.config.http.SecurityFilters.FILTER_SECURITY_INTERCEPTOR; -import static org.springframework.security.config.http.SecurityFilters.HEADERS_FILTER; -import static org.springframework.security.config.http.SecurityFilters.JAAS_API_SUPPORT_FILTER; -import static org.springframework.security.config.http.SecurityFilters.REQUEST_CACHE_FILTER; -import static org.springframework.security.config.http.SecurityFilters.SECURITY_CONTEXT_FILTER; -import static org.springframework.security.config.http.SecurityFilters.SERVLET_API_SUPPORT_FILTER; -import static org.springframework.security.config.http.SecurityFilters.SESSION_MANAGEMENT_FILTER; -import static org.springframework.security.config.http.SecurityFilters.WEB_ASYNC_MANAGER_FILTER; - /** * Stateful class which helps HttpSecurityBDP to create the configuration for the * <http> element. @@ -101,87 +85,120 @@ import static org.springframework.security.config.http.SecurityFilters.WEB_ASYNC * @since 3.0 */ class HttpConfigurationBuilder { + private static final String ATT_CREATE_SESSION = "create-session"; private static final String ATT_SESSION_FIXATION_PROTECTION = "session-fixation-protection"; + private static final String OPT_SESSION_FIXATION_NO_PROTECTION = "none"; + private static final String OPT_SESSION_FIXATION_MIGRATE_SESSION = "migrateSession"; + private static final String OPT_CHANGE_SESSION_ID = "changeSessionId"; private static final String ATT_INVALID_SESSION_URL = "invalid-session-url"; + private static final String ATT_SESSION_AUTH_STRATEGY_REF = "session-authentication-strategy-ref"; + private static final String ATT_SESSION_AUTH_ERROR_URL = "session-authentication-error-url"; + private static final String ATT_SECURITY_CONTEXT_REPOSITORY = "security-context-repository-ref"; + private static final String ATT_INVALID_SESSION_STRATEGY_REF = "invalid-session-strategy-ref"; + private static final String ATT_DISABLE_URL_REWRITING = "disable-url-rewriting"; private static final String ATT_ACCESS_MGR = "access-decision-manager-ref"; + private static final String ATT_ONCE_PER_REQUEST = "once-per-request"; private static final String ATT_REF = "ref"; + private static final String ATT_EXPIRY_URL = "expired-url"; + + private static final String ATT_EXPIRED_SESSION_STRATEGY_REF = "expired-session-strategy-ref"; + + private static final String ATT_SESSION_REGISTRY_ALIAS = "session-registry-alias"; + + private static final String ATT_SESSION_REGISTRY_REF = "session-registry-ref"; + + private static final String ATT_SERVLET_API_PROVISION = "servlet-api-provision"; + + private static final String DEF_SERVLET_API_PROVISION = "true"; + + private static final String ATT_JAAS_API_PROVISION = "jaas-api-provision"; + + private static final String DEF_JAAS_API_PROVISION = "false"; + private final Element httpElt; + private final ParserContext pc; + private final SessionCreationPolicy sessionPolicy; + private final List interceptUrls; + private final MatcherType matcherType; private BeanDefinition cpf; + private BeanDefinition securityContextPersistenceFilter; + private BeanReference contextRepoRef; + private BeanReference sessionRegistryRef; + private BeanDefinition concurrentSessionFilter; + private BeanDefinition webAsyncManagerFilter; + private BeanDefinition requestCacheAwareFilter; + private BeanReference sessionStrategyRef; + private RootBeanDefinition sfpf; + private BeanDefinition servApiFilter; + private BeanDefinition jaasApiFilter; + private final BeanReference portMapper; + private final BeanReference portResolver; + private BeanReference fsi; + private BeanReference requestCache; + private BeanDefinition addHeadersFilter; + private BeanMetadataElement corsFilter; + private BeanDefinition csrfFilter; + private BeanMetadataElement csrfLogoutHandler; + private BeanMetadataElement csrfAuthStrategy; private CsrfBeanDefinitionParser csrfParser; private BeanDefinition invalidSession; + private boolean addAllAuth; - HttpConfigurationBuilder(Element element, boolean addAllAuth, - ParserContext pc, BeanReference portMapper, BeanReference portResolver, - BeanReference authenticationManager) { + HttpConfigurationBuilder(Element element, boolean addAllAuth, ParserContext pc, BeanReference portMapper, + BeanReference portResolver, BeanReference authenticationManager) { this.httpElt = element; this.addAllAuth = addAllAuth; this.pc = pc; this.portMapper = portMapper; this.portResolver = portResolver; this.matcherType = MatcherType.fromElement(element); - interceptUrls = DomUtils.getChildElementsByTagName(element, - Elements.INTERCEPT_URL); - - for (Element urlElt : interceptUrls) { - if (StringUtils.hasText(urlElt.getAttribute(ATT_FILTERS))) { - pc.getReaderContext() - .error("The use of \"filters='none'\" is no longer supported. Please define a" - + " separate element for the pattern you want to exclude and use the attribute" - + " \"security='none'\".", pc.extractSource(urlElt)); - } - } - + this.interceptUrls = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_URL); + validateInterceptUrls(pc); String createSession = element.getAttribute(ATT_CREATE_SESSION); - - if (StringUtils.hasText(createSession)) { - sessionPolicy = createPolicy(createSession); - } - else { - sessionPolicy = SessionCreationPolicy.IF_REQUIRED; - } - + this.sessionPolicy = !StringUtils.hasText(createSession) ? SessionCreationPolicy.IF_REQUIRED + : createPolicy(createSession); createCsrfFilter(); createSecurityContextPersistenceFilter(); createSessionManagementFilters(); @@ -195,52 +212,61 @@ class HttpConfigurationBuilder { createCorsFilter(); } + private void validateInterceptUrls(ParserContext pc) { + for (Element element : this.interceptUrls) { + if (StringUtils.hasText(element.getAttribute(HttpSecurityBeanDefinitionParser.ATT_FILTERS))) { + String message = "The use of \"filters='none'\" is no longer supported. Please define a" + + " separate element for the pattern you want to exclude and use the attribute" + + " \"security='none'\"."; + pc.getReaderContext().error(message, pc.extractSource(element)); + } + } + } + private SessionCreationPolicy createPolicy(String createSession) { if ("ifRequired".equals(createSession)) { return SessionCreationPolicy.IF_REQUIRED; } - else if ("always".equals(createSession)) { + if ("always".equals(createSession)) { return SessionCreationPolicy.ALWAYS; } - else if ("never".equals(createSession)) { + if ("never".equals(createSession)) { return SessionCreationPolicy.NEVER; } - else if ("stateless".equals(createSession)) { + if ("stateless".equals(createSession)) { return SessionCreationPolicy.STATELESS; } - - throw new IllegalStateException("Cannot convert " + createSession + " to " - + SessionCreationPolicy.class.getName()); + throw new IllegalStateException( + "Cannot convert " + createSession + " to " + SessionCreationPolicy.class.getName()); } @SuppressWarnings("rawtypes") void setLogoutHandlers(ManagedList logoutHandlers) { if (logoutHandlers != null) { - if (concurrentSessionFilter != null) { - concurrentSessionFilter.getPropertyValues().add("logoutHandlers", - logoutHandlers); + if (this.concurrentSessionFilter != null) { + this.concurrentSessionFilter.getPropertyValues().add("logoutHandlers", logoutHandlers); } - if (servApiFilter != null) { - servApiFilter.getPropertyValues().add("logoutHandlers", logoutHandlers); + if (this.servApiFilter != null) { + this.servApiFilter.getPropertyValues().add("logoutHandlers", logoutHandlers); } } } void setEntryPoint(BeanMetadataElement entryPoint) { - if (servApiFilter != null) { - servApiFilter.getPropertyValues().add("authenticationEntryPoint", entryPoint); + if (this.servApiFilter != null) { + this.servApiFilter.getPropertyValues().add("authenticationEntryPoint", entryPoint); } } void setAccessDeniedHandler(BeanMetadataElement accessDeniedHandler) { - if (csrfParser != null) { - csrfParser.initAccessDeniedHandler(this.invalidSession, accessDeniedHandler); + if (this.csrfParser != null) { + this.csrfParser.initAccessDeniedHandler(this.invalidSession, accessDeniedHandler); } } void setCsrfIgnoreRequestMatchers(List requestMatchers) { - if (csrfParser != null) { - csrfParser.setIgnoreCsrfRequestMatchers(requestMatchers); + if (this.csrfParser != null) { + this.csrfParser.setIgnoreCsrfRequestMatchers(requestMatchers); } } @@ -250,30 +276,25 @@ class HttpConfigurationBuilder { } private void createSecurityContextPersistenceFilter() { - BeanDefinitionBuilder scpf = BeanDefinitionBuilder - .rootBeanDefinition(SecurityContextPersistenceFilter.class); - - String repoRef = httpElt.getAttribute(ATT_SECURITY_CONTEXT_REPOSITORY); - String disableUrlRewriting = httpElt.getAttribute(ATT_DISABLE_URL_REWRITING); + BeanDefinitionBuilder scpf = BeanDefinitionBuilder.rootBeanDefinition(SecurityContextPersistenceFilter.class); + String repoRef = this.httpElt.getAttribute(ATT_SECURITY_CONTEXT_REPOSITORY); + String disableUrlRewriting = this.httpElt.getAttribute(ATT_DISABLE_URL_REWRITING); if (!StringUtils.hasText(disableUrlRewriting)) { disableUrlRewriting = "true"; } - if (StringUtils.hasText(repoRef)) { - if (sessionPolicy == SessionCreationPolicy.ALWAYS) { + if (this.sessionPolicy == SessionCreationPolicy.ALWAYS) { scpf.addPropertyValue("forceEagerSessionCreation", Boolean.TRUE); } } else { BeanDefinitionBuilder contextRepo; - if (sessionPolicy == SessionCreationPolicy.STATELESS) { - contextRepo = BeanDefinitionBuilder - .rootBeanDefinition(NullSecurityContextRepository.class); + if (this.sessionPolicy == SessionCreationPolicy.STATELESS) { + contextRepo = BeanDefinitionBuilder.rootBeanDefinition(NullSecurityContextRepository.class); } else { - contextRepo = BeanDefinitionBuilder - .rootBeanDefinition(HttpSessionSecurityContextRepository.class); - switch (sessionPolicy) { + contextRepo = BeanDefinitionBuilder.rootBeanDefinition(HttpSessionSecurityContextRepository.class); + switch (this.sessionPolicy) { case ALWAYS: contextRepo.addPropertyValue("allowSessionCreation", Boolean.TRUE); scpf.addPropertyValue("forceEagerSessionCreation", Boolean.TRUE); @@ -286,123 +307,97 @@ class HttpConfigurationBuilder { contextRepo.addPropertyValue("allowSessionCreation", Boolean.TRUE); scpf.addPropertyValue("forceEagerSessionCreation", Boolean.FALSE); } - if ("true".equals(disableUrlRewriting)) { contextRepo.addPropertyValue("disableUrlRewriting", Boolean.TRUE); } } - BeanDefinition repoBean = contextRepo.getBeanDefinition(); - repoRef = pc.getReaderContext().generateBeanName(repoBean); - pc.registerBeanComponent(new BeanComponentDefinition(repoBean, repoRef)); + repoRef = this.pc.getReaderContext().generateBeanName(repoBean); + this.pc.registerBeanComponent(new BeanComponentDefinition(repoBean, repoRef)); } - contextRepoRef = new RuntimeBeanReference(repoRef); - scpf.addConstructorArgValue(contextRepoRef); + this.contextRepoRef = new RuntimeBeanReference(repoRef); + scpf.addConstructorArgValue(this.contextRepoRef); - securityContextPersistenceFilter = scpf.getBeanDefinition(); + this.securityContextPersistenceFilter = scpf.getBeanDefinition(); } private void createSessionManagementFilters() { - Element sessionMgmtElt = DomUtils.getChildElementByTagName(httpElt, - Elements.SESSION_MANAGEMENT); + Element sessionMgmtElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.SESSION_MANAGEMENT); Element sessionCtrlElt = null; - String sessionFixationAttribute = null; String invalidSessionUrl = null; String invalidSessionStrategyRef = null; String sessionAuthStratRef = null; String errorUrl = null; - boolean sessionControlEnabled = false; if (sessionMgmtElt != null) { - if (sessionPolicy == SessionCreationPolicy.STATELESS) { - pc.getReaderContext().error( - Elements.SESSION_MANAGEMENT + " cannot be used" - + " in combination with " + ATT_CREATE_SESSION + "='" - + SessionCreationPolicy.STATELESS + "'", - pc.extractSource(sessionMgmtElt)); + if (this.sessionPolicy == SessionCreationPolicy.STATELESS) { + this.pc.getReaderContext() + .error(Elements.SESSION_MANAGEMENT + " cannot be used" + " in combination with " + + ATT_CREATE_SESSION + "='" + SessionCreationPolicy.STATELESS + "'", + this.pc.extractSource(sessionMgmtElt)); } - sessionFixationAttribute = sessionMgmtElt - .getAttribute(ATT_SESSION_FIXATION_PROTECTION); + sessionFixationAttribute = sessionMgmtElt.getAttribute(ATT_SESSION_FIXATION_PROTECTION); invalidSessionUrl = sessionMgmtElt.getAttribute(ATT_INVALID_SESSION_URL); invalidSessionStrategyRef = sessionMgmtElt.getAttribute(ATT_INVALID_SESSION_STRATEGY_REF); - - sessionAuthStratRef = sessionMgmtElt - .getAttribute(ATT_SESSION_AUTH_STRATEGY_REF); + sessionAuthStratRef = sessionMgmtElt.getAttribute(ATT_SESSION_AUTH_STRATEGY_REF); errorUrl = sessionMgmtElt.getAttribute(ATT_SESSION_AUTH_ERROR_URL); - sessionCtrlElt = DomUtils.getChildElementByTagName(sessionMgmtElt, - Elements.CONCURRENT_SESSIONS); + sessionCtrlElt = DomUtils.getChildElementByTagName(sessionMgmtElt, Elements.CONCURRENT_SESSIONS); sessionControlEnabled = sessionCtrlElt != null; - if (StringUtils.hasText(invalidSessionUrl) && StringUtils.hasText(invalidSessionStrategyRef)) { - pc.getReaderContext().error(ATT_INVALID_SESSION_URL + " attribute cannot be used in combination with" + - " the " + ATT_INVALID_SESSION_STRATEGY_REF + " attribute.", sessionMgmtElt); + this.pc.getReaderContext() + .error(ATT_INVALID_SESSION_URL + " attribute cannot be used in combination with" + " the " + + ATT_INVALID_SESSION_STRATEGY_REF + " attribute.", sessionMgmtElt); } - if (sessionControlEnabled) { if (StringUtils.hasText(sessionAuthStratRef)) { - pc.getReaderContext().error( - ATT_SESSION_AUTH_STRATEGY_REF + " attribute cannot be used" - + " in combination with <" - + Elements.CONCURRENT_SESSIONS + ">", - pc.extractSource(sessionCtrlElt)); + this.pc.getReaderContext() + .error(ATT_SESSION_AUTH_STRATEGY_REF + " attribute cannot be used" + + " in combination with <" + Elements.CONCURRENT_SESSIONS + ">", + this.pc.extractSource(sessionCtrlElt)); } createConcurrencyControlFilterAndSessionRegistry(sessionCtrlElt); } } if (!StringUtils.hasText(sessionFixationAttribute)) { - sessionFixationAttribute = OPT_CHANGE_SESSION_ID; + sessionFixationAttribute = OPT_CHANGE_SESSION_ID; } else if (StringUtils.hasText(sessionAuthStratRef)) { - pc.getReaderContext().error( - ATT_SESSION_FIXATION_PROTECTION + " attribute cannot be used" - + " in combination with " + ATT_SESSION_AUTH_STRATEGY_REF, - pc.extractSource(sessionMgmtElt)); + this.pc.getReaderContext().error(ATT_SESSION_FIXATION_PROTECTION + " attribute cannot be used" + + " in combination with " + ATT_SESSION_AUTH_STRATEGY_REF, this.pc.extractSource(sessionMgmtElt)); } - if (sessionPolicy == SessionCreationPolicy.STATELESS) { + if (this.sessionPolicy == SessionCreationPolicy.STATELESS) { // SEC-1424: do nothing return; } - boolean sessionFixationProtectionRequired = !sessionFixationAttribute .equals(OPT_SESSION_FIXATION_NO_PROTECTION); - ManagedList delegateSessionStrategies = new ManagedList<>(); BeanDefinitionBuilder concurrentSessionStrategy; BeanDefinitionBuilder sessionFixationStrategy = null; BeanDefinitionBuilder registerSessionStrategy; - - if (csrfAuthStrategy != null) { - delegateSessionStrategies.add(csrfAuthStrategy); + if (this.csrfAuthStrategy != null) { + delegateSessionStrategies.add(this.csrfAuthStrategy); } - if (sessionControlEnabled) { - assert sessionRegistryRef != null; + Assert.state(this.sessionRegistryRef != null, "No sessionRegistryRef found"); concurrentSessionStrategy = BeanDefinitionBuilder .rootBeanDefinition(ConcurrentSessionControlAuthenticationStrategy.class); - concurrentSessionStrategy.addConstructorArgValue(sessionRegistryRef); - + concurrentSessionStrategy.addConstructorArgValue(this.sessionRegistryRef); String maxSessions = sessionCtrlElt.getAttribute("max-sessions"); - if (StringUtils.hasText(maxSessions)) { - concurrentSessionStrategy - .addPropertyValue("maximumSessions", maxSessions); + concurrentSessionStrategy.addPropertyValue("maximumSessions", maxSessions); } - - String exceptionIfMaximumExceeded = sessionCtrlElt - .getAttribute("error-if-maximum-exceeded"); - + String exceptionIfMaximumExceeded = sessionCtrlElt.getAttribute("error-if-maximum-exceeded"); if (StringUtils.hasText(exceptionIfMaximumExceeded)) { - concurrentSessionStrategy.addPropertyValue("exceptionIfMaximumExceeded", - exceptionIfMaximumExceeded); + concurrentSessionStrategy.addPropertyValue("exceptionIfMaximumExceeded", exceptionIfMaximumExceeded); } delegateSessionStrategies.add(concurrentSessionStrategy.getBeanDefinition()); } - boolean useChangeSessionId = OPT_CHANGE_SESSION_ID - .equals(sessionFixationAttribute); + boolean useChangeSessionId = OPT_CHANGE_SESSION_ID.equals(sessionFixationAttribute); if (sessionFixationProtectionRequired || StringUtils.hasText(invalidSessionUrl)) { if (useChangeSessionId) { sessionFixationStrategy = BeanDefinitionBuilder @@ -414,222 +409,158 @@ class HttpConfigurationBuilder { } delegateSessionStrategies.add(sessionFixationStrategy.getBeanDefinition()); } - if (StringUtils.hasText(sessionAuthStratRef)) { delegateSessionStrategies.add(new RuntimeBeanReference(sessionAuthStratRef)); } - if (sessionControlEnabled) { registerSessionStrategy = BeanDefinitionBuilder .rootBeanDefinition(RegisterSessionAuthenticationStrategy.class); - registerSessionStrategy.addConstructorArgValue(sessionRegistryRef); + registerSessionStrategy.addConstructorArgValue(this.sessionRegistryRef); delegateSessionStrategies.add(registerSessionStrategy.getBeanDefinition()); } - if (delegateSessionStrategies.isEmpty()) { - sfpf = null; + this.sfpf = null; return; } - BeanDefinitionBuilder sessionMgmtFilter = BeanDefinitionBuilder .rootBeanDefinition(SessionManagementFilter.class); - RootBeanDefinition failureHandler = new RootBeanDefinition( - SimpleUrlAuthenticationFailureHandler.class); + RootBeanDefinition failureHandler = new RootBeanDefinition(SimpleUrlAuthenticationFailureHandler.class); if (StringUtils.hasText(errorUrl)) { - failureHandler.getPropertyValues().addPropertyValue("defaultFailureUrl", - errorUrl); + failureHandler.getPropertyValues().addPropertyValue("defaultFailureUrl", errorUrl); } - sessionMgmtFilter - .addPropertyValue("authenticationFailureHandler", failureHandler); - sessionMgmtFilter.addConstructorArgValue(contextRepoRef); - - if (!StringUtils.hasText(sessionAuthStratRef) && sessionFixationStrategy != null - && !useChangeSessionId) { - + sessionMgmtFilter.addPropertyValue("authenticationFailureHandler", failureHandler); + sessionMgmtFilter.addConstructorArgValue(this.contextRepoRef); + if (!StringUtils.hasText(sessionAuthStratRef) && sessionFixationStrategy != null && !useChangeSessionId) { if (sessionFixationProtectionRequired) { sessionFixationStrategy.addPropertyValue("migrateSessionAttributes", - sessionFixationAttribute - .equals(OPT_SESSION_FIXATION_MIGRATE_SESSION)); + sessionFixationAttribute.equals(OPT_SESSION_FIXATION_MIGRATE_SESSION)); } } - if (!delegateSessionStrategies.isEmpty()) { BeanDefinitionBuilder sessionStrategy = BeanDefinitionBuilder .rootBeanDefinition(CompositeSessionAuthenticationStrategy.class); BeanDefinition strategyBean = sessionStrategy.getBeanDefinition(); sessionStrategy.addConstructorArgValue(delegateSessionStrategies); - sessionAuthStratRef = pc.getReaderContext().generateBeanName(strategyBean); - pc.registerBeanComponent(new BeanComponentDefinition(strategyBean, - sessionAuthStratRef)); - + sessionAuthStratRef = this.pc.getReaderContext().generateBeanName(strategyBean); + this.pc.registerBeanComponent(new BeanComponentDefinition(strategyBean, sessionAuthStratRef)); } - - - if (StringUtils.hasText(invalidSessionUrl)) { BeanDefinitionBuilder invalidSessionBldr = BeanDefinitionBuilder .rootBeanDefinition(SimpleRedirectInvalidSessionStrategy.class); invalidSessionBldr.addConstructorArgValue(invalidSessionUrl); - invalidSession = invalidSessionBldr.getBeanDefinition(); - sessionMgmtFilter.addPropertyValue("invalidSessionStrategy", invalidSession); - } else if (StringUtils.hasText(invalidSessionStrategyRef)) { + this.invalidSession = invalidSessionBldr.getBeanDefinition(); + sessionMgmtFilter.addPropertyValue("invalidSessionStrategy", this.invalidSession); + } + else if (StringUtils.hasText(invalidSessionStrategyRef)) { sessionMgmtFilter.addPropertyReference("invalidSessionStrategy", invalidSessionStrategyRef); } - sessionMgmtFilter.addConstructorArgReference(sessionAuthStratRef); - - sfpf = (RootBeanDefinition) sessionMgmtFilter.getBeanDefinition(); - sessionStrategyRef = new RuntimeBeanReference(sessionAuthStratRef); + this.sfpf = (RootBeanDefinition) sessionMgmtFilter.getBeanDefinition(); + this.sessionStrategyRef = new RuntimeBeanReference(sessionAuthStratRef); } private void createConcurrencyControlFilterAndSessionRegistry(Element element) { - final String ATT_EXPIRY_URL = "expired-url"; - final String ATT_EXPIRED_SESSION_STRATEGY_REF = "expired-session-strategy-ref"; - final String ATT_SESSION_REGISTRY_ALIAS = "session-registry-alias"; - final String ATT_SESSION_REGISTRY_REF = "session-registry-ref"; - - CompositeComponentDefinition compositeDef = new CompositeComponentDefinition( - element.getTagName(), pc.extractSource(element)); - pc.pushContainingComponent(compositeDef); - - BeanDefinitionRegistry beanRegistry = pc.getRegistry(); - + CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), + this.pc.extractSource(element)); + this.pc.pushContainingComponent(compositeDef); + BeanDefinitionRegistry beanRegistry = this.pc.getRegistry(); String sessionRegistryId = element.getAttribute(ATT_SESSION_REGISTRY_REF); - if (!StringUtils.hasText(sessionRegistryId)) { // Register an internal SessionRegistryImpl if no external reference supplied. - RootBeanDefinition sessionRegistry = new RootBeanDefinition( - SessionRegistryImpl.class); - sessionRegistryId = pc.getReaderContext().registerWithGeneratedName( - sessionRegistry); - pc.registerComponent(new BeanComponentDefinition(sessionRegistry, - sessionRegistryId)); + RootBeanDefinition sessionRegistry = new RootBeanDefinition(SessionRegistryImpl.class); + sessionRegistryId = this.pc.getReaderContext().registerWithGeneratedName(sessionRegistry); + this.pc.registerComponent(new BeanComponentDefinition(sessionRegistry, sessionRegistryId)); } - String registryAlias = element.getAttribute(ATT_SESSION_REGISTRY_ALIAS); if (StringUtils.hasText(registryAlias)) { beanRegistry.registerAlias(sessionRegistryId, registryAlias); } - - BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder - .rootBeanDefinition(ConcurrentSessionFilter.class); + BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder.rootBeanDefinition(ConcurrentSessionFilter.class); filterBuilder.addConstructorArgReference(sessionRegistryId); - - Object source = pc.extractSource(element); + Object source = this.pc.extractSource(element); filterBuilder.getRawBeanDefinition().setSource(source); filterBuilder.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); - String expiryUrl = element.getAttribute(ATT_EXPIRY_URL); String expiredSessionStrategyRef = element.getAttribute(ATT_EXPIRED_SESSION_STRATEGY_REF); - if (StringUtils.hasText(expiryUrl) && StringUtils.hasText(expiredSessionStrategyRef)) { - pc.getReaderContext().error("Cannot use 'expired-url' attribute and 'expired-session-strategy-ref'" + - " attribute together.", source); + this.pc.getReaderContext().error( + "Cannot use 'expired-url' attribute and 'expired-session-strategy-ref'" + " attribute together.", + source); } - if (StringUtils.hasText(expiryUrl)) { BeanDefinitionBuilder expiredSessionBldr = BeanDefinitionBuilder .rootBeanDefinition(SimpleRedirectSessionInformationExpiredStrategy.class); expiredSessionBldr.addConstructorArgValue(expiryUrl); filterBuilder.addConstructorArgValue(expiredSessionBldr.getBeanDefinition()); - } else if (StringUtils.hasText(expiredSessionStrategyRef)) { + } + else if (StringUtils.hasText(expiredSessionStrategyRef)) { filterBuilder.addConstructorArgReference(expiredSessionStrategyRef); } - - pc.popAndRegisterContainingComponent(); - - concurrentSessionFilter = filterBuilder.getBeanDefinition(); - sessionRegistryRef = new RuntimeBeanReference(sessionRegistryId); + this.pc.popAndRegisterContainingComponent(); + this.concurrentSessionFilter = filterBuilder.getBeanDefinition(); + this.sessionRegistryRef = new RuntimeBeanReference(sessionRegistryId); } private void createWebAsyncManagerFilter() { boolean asyncSupported = ClassUtils.hasMethod(ServletRequest.class, "startAsync"); if (asyncSupported) { - webAsyncManagerFilter = new RootBeanDefinition( - WebAsyncManagerIntegrationFilter.class); + this.webAsyncManagerFilter = new RootBeanDefinition(WebAsyncManagerIntegrationFilter.class); } } // Adds the servlet-api integration filter if required private void createServletApiFilter(BeanReference authenticationManager) { - final String ATT_SERVLET_API_PROVISION = "servlet-api-provision"; - final String DEF_SERVLET_API_PROVISION = "true"; - - String provideServletApi = httpElt.getAttribute(ATT_SERVLET_API_PROVISION); + String provideServletApi = this.httpElt.getAttribute(ATT_SERVLET_API_PROVISION); if (!StringUtils.hasText(provideServletApi)) { provideServletApi = DEF_SERVLET_API_PROVISION; } - if ("true".equals(provideServletApi)) { - servApiFilter = GrantedAuthorityDefaultsParserUtils.registerWithDefaultRolePrefix(pc, SecurityContextHolderAwareRequestFilterBeanFactory.class); - servApiFilter.getPropertyValues().add("authenticationManager", - authenticationManager); + this.servApiFilter = GrantedAuthorityDefaultsParserUtils.registerWithDefaultRolePrefix(this.pc, + SecurityContextHolderAwareRequestFilterBeanFactory.class); + this.servApiFilter.getPropertyValues().add("authenticationManager", authenticationManager); } } // Adds the jaas-api integration filter if required private void createJaasApiFilter() { - final String ATT_JAAS_API_PROVISION = "jaas-api-provision"; - final String DEF_JAAS_API_PROVISION = "false"; - - String provideJaasApi = httpElt.getAttribute(ATT_JAAS_API_PROVISION); + String provideJaasApi = this.httpElt.getAttribute(ATT_JAAS_API_PROVISION); if (!StringUtils.hasText(provideJaasApi)) { provideJaasApi = DEF_JAAS_API_PROVISION; } - if ("true".equals(provideJaasApi)) { - jaasApiFilter = new RootBeanDefinition(JaasApiIntegrationFilter.class); + this.jaasApiFilter = new RootBeanDefinition(JaasApiIntegrationFilter.class); } } private void createChannelProcessingFilter() { ManagedMap channelRequestMap = parseInterceptUrlsForChannelSecurity(); - if (channelRequestMap.isEmpty()) { return; } - - RootBeanDefinition channelFilter = new RootBeanDefinition( - ChannelProcessingFilter.class); + RootBeanDefinition channelFilter = new RootBeanDefinition(ChannelProcessingFilter.class); BeanDefinitionBuilder metadataSourceBldr = BeanDefinitionBuilder .rootBeanDefinition(DefaultFilterInvocationSecurityMetadataSource.class); metadataSourceBldr.addConstructorArgValue(channelRequestMap); - // metadataSourceBldr.addPropertyValue("stripQueryStringFromUrls", matcher - // instanceof AntUrlPathMatcher); - channelFilter.getPropertyValues().addPropertyValue("securityMetadataSource", metadataSourceBldr.getBeanDefinition()); - RootBeanDefinition channelDecisionManager = new RootBeanDefinition( - ChannelDecisionManagerImpl.class); - ManagedList channelProcessors = new ManagedList<>( - 3); - RootBeanDefinition secureChannelProcessor = new RootBeanDefinition( - SecureChannelProcessor.class); - RootBeanDefinition retryWithHttp = new RootBeanDefinition( - RetryWithHttpEntryPoint.class); - RootBeanDefinition retryWithHttps = new RootBeanDefinition( - RetryWithHttpsEntryPoint.class); - - retryWithHttp.getPropertyValues().addPropertyValue("portMapper", portMapper); - retryWithHttp.getPropertyValues().addPropertyValue("portResolver", portResolver); - retryWithHttps.getPropertyValues().addPropertyValue("portMapper", portMapper); - retryWithHttps.getPropertyValues().addPropertyValue("portResolver", portResolver); - secureChannelProcessor.getPropertyValues().addPropertyValue("entryPoint", - retryWithHttps); - RootBeanDefinition inSecureChannelProcessor = new RootBeanDefinition( - InsecureChannelProcessor.class); - inSecureChannelProcessor.getPropertyValues().addPropertyValue("entryPoint", - retryWithHttp); + RootBeanDefinition channelDecisionManager = new RootBeanDefinition(ChannelDecisionManagerImpl.class); + ManagedList channelProcessors = new ManagedList<>(3); + RootBeanDefinition secureChannelProcessor = new RootBeanDefinition(SecureChannelProcessor.class); + RootBeanDefinition retryWithHttp = new RootBeanDefinition(RetryWithHttpEntryPoint.class); + RootBeanDefinition retryWithHttps = new RootBeanDefinition(RetryWithHttpsEntryPoint.class); + retryWithHttp.getPropertyValues().addPropertyValue("portMapper", this.portMapper); + retryWithHttp.getPropertyValues().addPropertyValue("portResolver", this.portResolver); + retryWithHttps.getPropertyValues().addPropertyValue("portMapper", this.portMapper); + retryWithHttps.getPropertyValues().addPropertyValue("portResolver", this.portResolver); + secureChannelProcessor.getPropertyValues().addPropertyValue("entryPoint", retryWithHttps); + RootBeanDefinition inSecureChannelProcessor = new RootBeanDefinition(InsecureChannelProcessor.class); + inSecureChannelProcessor.getPropertyValues().addPropertyValue("entryPoint", retryWithHttp); channelProcessors.add(secureChannelProcessor); channelProcessors.add(inSecureChannelProcessor); - channelDecisionManager.getPropertyValues().addPropertyValue("channelProcessors", - channelProcessors); - - String id = pc.getReaderContext().registerWithGeneratedName( - channelDecisionManager); - channelFilter.getPropertyValues().addPropertyValue("channelDecisionManager", - new RuntimeBeanReference(id)); - cpf = channelFilter; + channelDecisionManager.getPropertyValues().addPropertyValue("channelProcessors", channelProcessors); + String id = this.pc.getReaderContext().registerWithGeneratedName(channelDecisionManager); + channelFilter.getPropertyValues().addPropertyValue("channelDecisionManager", new RuntimeBeanReference(id)); + this.cpf = channelFilter; } /** @@ -638,149 +569,110 @@ class HttpConfigurationBuilder { * path. */ private ManagedMap parseInterceptUrlsForChannelSecurity() { - ManagedMap channelRequestMap = new ManagedMap<>(); - - for (Element urlElt : interceptUrls) { - String path = urlElt.getAttribute(ATT_PATH_PATTERN); - String method = urlElt.getAttribute(ATT_HTTP_METHOD); - String matcherRef = urlElt.getAttribute(ATT_REQUEST_MATCHER_REF); + for (Element urlElt : this.interceptUrls) { + String path = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_PATH_PATTERN); + String method = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_HTTP_METHOD); + String matcherRef = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUEST_MATCHER_REF); boolean hasMatcherRef = StringUtils.hasText(matcherRef); - if (!hasMatcherRef && !StringUtils.hasText(path)) { - pc.getReaderContext().error("pattern attribute cannot be empty or null", - urlElt); + this.pc.getReaderContext().error("pattern attribute cannot be empty or null", urlElt); } - - String requiredChannel = urlElt.getAttribute(ATT_REQUIRES_CHANNEL); - + String requiredChannel = urlElt.getAttribute(HttpSecurityBeanDefinitionParser.ATT_REQUIRES_CHANNEL); if (StringUtils.hasText(requiredChannel)) { - BeanMetadataElement matcher = hasMatcherRef ? new RuntimeBeanReference(matcherRef) : matcherType.createMatcher(pc, path, method); - - RootBeanDefinition channelAttributes = new RootBeanDefinition( - ChannelAttributeFactory.class); - channelAttributes.getConstructorArgumentValues().addGenericArgumentValue( - requiredChannel); + BeanMetadataElement matcher = hasMatcherRef ? new RuntimeBeanReference(matcherRef) + : this.matcherType.createMatcher(this.pc, path, method); + RootBeanDefinition channelAttributes = new RootBeanDefinition(ChannelAttributeFactory.class); + channelAttributes.getConstructorArgumentValues().addGenericArgumentValue(requiredChannel); channelAttributes.setFactoryMethodName("createChannelAttributes"); - channelRequestMap.put(matcher, channelAttributes); } } - return channelRequestMap; } private void createRequestCacheFilter() { - Element requestCacheElt = DomUtils.getChildElementByTagName(httpElt, - Elements.REQUEST_CACHE); - + Element requestCacheElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.REQUEST_CACHE); if (requestCacheElt != null) { - requestCache = new RuntimeBeanReference(requestCacheElt.getAttribute(ATT_REF)); + this.requestCache = new RuntimeBeanReference(requestCacheElt.getAttribute(ATT_REF)); } else { BeanDefinitionBuilder requestCacheBldr; - - if (sessionPolicy == SessionCreationPolicy.STATELESS) { - requestCacheBldr = BeanDefinitionBuilder - .rootBeanDefinition(NullRequestCache.class); + if (this.sessionPolicy == SessionCreationPolicy.STATELESS) { + requestCacheBldr = BeanDefinitionBuilder.rootBeanDefinition(NullRequestCache.class); } else { - requestCacheBldr = BeanDefinitionBuilder - .rootBeanDefinition(HttpSessionRequestCache.class); + requestCacheBldr = BeanDefinitionBuilder.rootBeanDefinition(HttpSessionRequestCache.class); requestCacheBldr.addPropertyValue("createSessionAllowed", - sessionPolicy == SessionCreationPolicy.IF_REQUIRED); - requestCacheBldr.addPropertyValue("portResolver", portResolver); - if (csrfFilter != null) { + this.sessionPolicy == SessionCreationPolicy.IF_REQUIRED); + requestCacheBldr.addPropertyValue("portResolver", this.portResolver); + if (this.csrfFilter != null) { BeanDefinitionBuilder requestCacheMatcherBldr = BeanDefinitionBuilder .rootBeanDefinition(AntPathRequestMatcher.class); requestCacheMatcherBldr.addConstructorArgValue("/**"); requestCacheMatcherBldr.addConstructorArgValue("GET"); - requestCacheBldr.addPropertyValue("requestMatcher", - requestCacheMatcherBldr.getBeanDefinition()); + requestCacheBldr.addPropertyValue("requestMatcher", requestCacheMatcherBldr.getBeanDefinition()); } } - BeanDefinition bean = requestCacheBldr.getBeanDefinition(); - String id = pc.getReaderContext().generateBeanName(bean); - pc.registerBeanComponent(new BeanComponentDefinition(bean, id)); - + String id = this.pc.getReaderContext().generateBeanName(bean); + this.pc.registerBeanComponent(new BeanComponentDefinition(bean, id)); this.requestCache = new RuntimeBeanReference(id); } - - requestCacheAwareFilter = new RootBeanDefinition(RequestCacheAwareFilter.class); - requestCacheAwareFilter.getConstructorArgumentValues().addGenericArgumentValue( - requestCache); + this.requestCacheAwareFilter = new RootBeanDefinition(RequestCacheAwareFilter.class); + this.requestCacheAwareFilter.getConstructorArgumentValues().addGenericArgumentValue(this.requestCache); } private void createFilterSecurityInterceptor(BeanReference authManager) { - boolean useExpressions = FilterInvocationSecurityMetadataSourceParser - .isUseExpressions(httpElt); + boolean useExpressions = FilterInvocationSecurityMetadataSourceParser.isUseExpressions(this.httpElt); RootBeanDefinition securityMds = FilterInvocationSecurityMetadataSourceParser - .createSecurityMetadataSource(interceptUrls, addAllAuth, httpElt, pc); - + .createSecurityMetadataSource(this.interceptUrls, this.addAllAuth, this.httpElt, this.pc); RootBeanDefinition accessDecisionMgr; ManagedList voters = new ManagedList<>(2); - if (useExpressions) { - BeanDefinitionBuilder expressionVoter = BeanDefinitionBuilder - .rootBeanDefinition(WebExpressionVoter.class); + BeanDefinitionBuilder expressionVoter = BeanDefinitionBuilder.rootBeanDefinition(WebExpressionVoter.class); // Read the expression handler from the FISMS - RuntimeBeanReference expressionHandler = (RuntimeBeanReference) securityMds - .getConstructorArgumentValues() + RuntimeBeanReference expressionHandler = (RuntimeBeanReference) securityMds.getConstructorArgumentValues() .getArgumentValue(1, RuntimeBeanReference.class).getValue(); - expressionVoter.addPropertyValue("expressionHandler", expressionHandler); - voters.add(expressionVoter.getBeanDefinition()); } else { - voters.add(GrantedAuthorityDefaultsParserUtils.registerWithDefaultRolePrefix(pc, RoleVoterBeanFactory.class)); + voters.add(GrantedAuthorityDefaultsParserUtils.registerWithDefaultRolePrefix(this.pc, + RoleVoterBeanFactory.class)); voters.add(new RootBeanDefinition(AuthenticatedVoter.class)); } accessDecisionMgr = new RootBeanDefinition(AffirmativeBased.class); accessDecisionMgr.getConstructorArgumentValues().addGenericArgumentValue(voters); - accessDecisionMgr.setSource(pc.extractSource(httpElt)); - + accessDecisionMgr.setSource(this.pc.extractSource(this.httpElt)); // Set up the access manager reference for http - String accessManagerId = httpElt.getAttribute(ATT_ACCESS_MGR); - + String accessManagerId = this.httpElt.getAttribute(ATT_ACCESS_MGR); if (!StringUtils.hasText(accessManagerId)) { - accessManagerId = pc.getReaderContext().generateBeanName(accessDecisionMgr); - pc.registerBeanComponent(new BeanComponentDefinition(accessDecisionMgr, - accessManagerId)); + accessManagerId = this.pc.getReaderContext().generateBeanName(accessDecisionMgr); + this.pc.registerBeanComponent(new BeanComponentDefinition(accessDecisionMgr, accessManagerId)); } - - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(FilterSecurityInterceptor.class); - + BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(FilterSecurityInterceptor.class); builder.addPropertyReference("accessDecisionManager", accessManagerId); builder.addPropertyValue("authenticationManager", authManager); - - if ("false".equals(httpElt.getAttribute(ATT_ONCE_PER_REQUEST))) { + if ("false".equals(this.httpElt.getAttribute(ATT_ONCE_PER_REQUEST))) { builder.addPropertyValue("observeOncePerRequest", Boolean.FALSE); } - builder.addPropertyValue("securityMetadataSource", securityMds); BeanDefinition fsiBean = builder.getBeanDefinition(); - String fsiId = pc.getReaderContext().generateBeanName(fsiBean); - pc.registerBeanComponent(new BeanComponentDefinition(fsiBean, fsiId)); - + String fsiId = this.pc.getReaderContext().generateBeanName(fsiBean); + this.pc.registerBeanComponent(new BeanComponentDefinition(fsiBean, fsiId)); // Create and register a DefaultWebInvocationPrivilegeEvaluator for use with // taglibs etc. - BeanDefinition wipe = new RootBeanDefinition( - DefaultWebInvocationPrivilegeEvaluator.class); - wipe.getConstructorArgumentValues().addGenericArgumentValue( - new RuntimeBeanReference(fsiId)); - - pc.registerBeanComponent(new BeanComponentDefinition(wipe, pc.getReaderContext() - .generateBeanName(wipe))); - + BeanDefinition wipe = new RootBeanDefinition(DefaultWebInvocationPrivilegeEvaluator.class); + wipe.getConstructorArgumentValues().addGenericArgumentValue(new RuntimeBeanReference(fsiId)); + this.pc.registerBeanComponent( + new BeanComponentDefinition(wipe, this.pc.getReaderContext().generateBeanName(wipe))); this.fsi = new RuntimeBeanReference(fsiId); } private void createAddHeadersFilter() { - Element elmt = DomUtils.getChildElementByTagName(httpElt, Elements.HEADERS); - this.addHeadersFilter = new HeadersBeanDefinitionParser().parse(elmt, pc); + Element elmt = DomUtils.getChildElementByTagName(this.httpElt, Elements.HEADERS); + this.addHeadersFilter = new HeadersBeanDefinitionParser().parse(elmt, this.pc); } private void createCorsFilter() { @@ -790,17 +682,15 @@ class HttpConfigurationBuilder { } private void createCsrfFilter() { - Element elmt = DomUtils.getChildElementByTagName(httpElt, Elements.CSRF); - csrfParser = new CsrfBeanDefinitionParser(); - csrfFilter = csrfParser.parse(elmt, pc); - - if (csrfFilter == null) { - csrfParser = null; + Element elmt = DomUtils.getChildElementByTagName(this.httpElt, Elements.CSRF); + this.csrfParser = new CsrfBeanDefinitionParser(); + this.csrfFilter = this.csrfParser.parse(elmt, this.pc); + if (this.csrfFilter == null) { + this.csrfParser = null; return; } - - this.csrfAuthStrategy = csrfParser.getCsrfAuthenticationStrategy(); - this.csrfLogoutHandler = csrfParser.getCsrfLogoutHandler(); + this.csrfAuthStrategy = this.csrfParser.getCsrfAuthenticationStrategy(); + this.csrfLogoutHandler = this.csrfParser.getCsrfLogoutHandler(); } BeanMetadataElement getCsrfLogoutHandler() { @@ -808,85 +698,77 @@ class HttpConfigurationBuilder { } BeanReference getSessionStrategy() { - return sessionStrategyRef; + return this.sessionStrategyRef; } SessionCreationPolicy getSessionCreationPolicy() { - return sessionPolicy; + return this.sessionPolicy; } BeanReference getRequestCache() { - return requestCache; + return this.requestCache; } List getFilters() { List filters = new ArrayList<>(); - - if (cpf != null) { - filters.add(new OrderDecorator(cpf, CHANNEL_FILTER)); + if (this.cpf != null) { + filters.add(new OrderDecorator(this.cpf, SecurityFilters.CHANNEL_FILTER)); } - - if (concurrentSessionFilter != null) { - filters.add(new OrderDecorator(concurrentSessionFilter, - CONCURRENT_SESSION_FILTER)); + if (this.concurrentSessionFilter != null) { + filters.add(new OrderDecorator(this.concurrentSessionFilter, SecurityFilters.CONCURRENT_SESSION_FILTER)); } - - if (webAsyncManagerFilter != null) { - filters.add(new OrderDecorator(webAsyncManagerFilter, - WEB_ASYNC_MANAGER_FILTER)); + if (this.webAsyncManagerFilter != null) { + filters.add(new OrderDecorator(this.webAsyncManagerFilter, SecurityFilters.WEB_ASYNC_MANAGER_FILTER)); } - - filters.add(new OrderDecorator(securityContextPersistenceFilter, - SECURITY_CONTEXT_FILTER)); - - if (servApiFilter != null) { - filters.add(new OrderDecorator(servApiFilter, SERVLET_API_SUPPORT_FILTER)); + filters.add(new OrderDecorator(this.securityContextPersistenceFilter, SecurityFilters.SECURITY_CONTEXT_FILTER)); + if (this.servApiFilter != null) { + filters.add(new OrderDecorator(this.servApiFilter, SecurityFilters.SERVLET_API_SUPPORT_FILTER)); } - - if (jaasApiFilter != null) { - filters.add(new OrderDecorator(jaasApiFilter, JAAS_API_SUPPORT_FILTER)); + if (this.jaasApiFilter != null) { + filters.add(new OrderDecorator(this.jaasApiFilter, SecurityFilters.JAAS_API_SUPPORT_FILTER)); } - - if (sfpf != null) { - filters.add(new OrderDecorator(sfpf, SESSION_MANAGEMENT_FILTER)); + if (this.sfpf != null) { + filters.add(new OrderDecorator(this.sfpf, SecurityFilters.SESSION_MANAGEMENT_FILTER)); } - - filters.add(new OrderDecorator(fsi, FILTER_SECURITY_INTERCEPTOR)); - - if (sessionPolicy != SessionCreationPolicy.STATELESS) { - filters.add(new OrderDecorator(requestCacheAwareFilter, REQUEST_CACHE_FILTER)); + filters.add(new OrderDecorator(this.fsi, SecurityFilters.FILTER_SECURITY_INTERCEPTOR)); + if (this.sessionPolicy != SessionCreationPolicy.STATELESS) { + filters.add(new OrderDecorator(this.requestCacheAwareFilter, SecurityFilters.REQUEST_CACHE_FILTER)); } - if (this.corsFilter != null) { - filters.add(new OrderDecorator(this.corsFilter, CORS_FILTER)); + filters.add(new OrderDecorator(this.corsFilter, SecurityFilters.CORS_FILTER)); } - - if (addHeadersFilter != null) { - filters.add(new OrderDecorator(addHeadersFilter, HEADERS_FILTER)); + if (this.addHeadersFilter != null) { + filters.add(new OrderDecorator(this.addHeadersFilter, SecurityFilters.HEADERS_FILTER)); } - - if (csrfFilter != null) { - filters.add(new OrderDecorator(csrfFilter, CSRF_FILTER)); + if (this.csrfFilter != null) { + filters.add(new OrderDecorator(this.csrfFilter, SecurityFilters.CSRF_FILTER)); } - return filters; } static class RoleVoterBeanFactory extends AbstractGrantedAuthorityDefaultsBeanFactory { + private RoleVoter voter = new RoleVoter(); + @Override public RoleVoter getBean() { - voter.setRolePrefix(this.rolePrefix); - return voter; + this.voter.setRolePrefix(this.rolePrefix); + return this.voter; } + } - static class SecurityContextHolderAwareRequestFilterBeanFactory extends GrantedAuthorityDefaultsParserUtils.AbstractGrantedAuthorityDefaultsBeanFactory { + static class SecurityContextHolderAwareRequestFilterBeanFactory + extends GrantedAuthorityDefaultsParserUtils.AbstractGrantedAuthorityDefaultsBeanFactory { + private SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter(); + @Override public SecurityContextHolderAwareRequestFilter getBean() { - filter.setRolePrefix(this.rolePrefix); - return filter; + this.filter.setRolePrefix(this.rolePrefix); + return this.filter; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/http/HttpFirewallBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/HttpFirewallBeanDefinitionParser.java index 6fd912d81a..2a166c662d 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpFirewallBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpFirewallBeanDefinitionParser.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.config.BeanIds; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * Injects the supplied {@code HttpFirewall} bean reference into the @@ -31,22 +33,17 @@ import org.w3c.dom.Element; */ public class HttpFirewallBeanDefinitionParser implements BeanDefinitionParser { + @Override public BeanDefinition parse(Element element, ParserContext pc) { String ref = element.getAttribute("ref"); - if (!StringUtils.hasText(ref)) { - pc.getReaderContext().error("ref attribute is required", - pc.extractSource(element)); + pc.getReaderContext().error("ref attribute is required", pc.extractSource(element)); } - // Ensure the FCP is registered. - HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, - pc.extractSource(element)); - BeanDefinition filterChainProxy = pc.getRegistry().getBeanDefinition( - BeanIds.FILTER_CHAIN_PROXY); - filterChainProxy.getPropertyValues().addPropertyValue("firewall", - new RuntimeBeanReference(ref)); - + HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, pc.extractSource(element)); + BeanDefinition filterChainProxy = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAIN_PROXY); + filterChainProxy.getPropertyValues().addPropertyValue("firewall", new RuntimeBeanReference(ref)); return null; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java index d9f4a74ee2..970245d134 100644 --- a/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.ArrayList; @@ -42,7 +43,6 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.core.OrderComparator; -import org.springframework.core.Ordered; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.DefaultAuthenticationEventPublisher; import org.springframework.security.authentication.ProviderManager; @@ -65,23 +65,35 @@ import org.springframework.util.xml.DomUtils; * @since 2.0 */ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { - private static final Log logger = LogFactory - .getLog(HttpSecurityBeanDefinitionParser.class); + + private static final Log logger = LogFactory.getLog(HttpSecurityBeanDefinitionParser.class); private static final String ATT_AUTHENTICATION_MANAGER_REF = "authentication-manager-ref"; + static final String ATT_REQUEST_MATCHER_REF = "request-matcher-ref"; + static final String ATT_PATH_PATTERN = "pattern"; + static final String ATT_HTTP_METHOD = "method"; static final String ATT_FILTERS = "filters"; + static final String OPT_FILTERS_NONE = "none"; static final String ATT_REQUIRES_CHANNEL = "requires-channel"; private static final String ATT_REF = "ref"; + private static final String ATT_SECURED = "security"; + private static final String OPT_SECURITY_NONE = "none"; + private static final String ATT_AFTER = "after"; + + private static final String ATT_BEFORE = "before"; + + private static final String ATT_POSITION = "position"; + public HttpSecurityBeanDefinitionParser() { } @@ -97,20 +109,15 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { @SuppressWarnings({ "unchecked" }) @Override public BeanDefinition parse(Element element, ParserContext pc) { - CompositeComponentDefinition compositeDef = new CompositeComponentDefinition( - element.getTagName(), pc.extractSource(element)); + CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), + pc.extractSource(element)); pc.pushContainingComponent(compositeDef); - registerFilterChainProxyIfNecessary(pc, pc.extractSource(element)); - // Obtain the filter chains and add the new chain to it - BeanDefinition listFactoryBean = pc.getRegistry().getBeanDefinition( - BeanIds.FILTER_CHAINS); - List filterChains = (List) listFactoryBean - .getPropertyValues().getPropertyValue("sourceList").getValue(); - + BeanDefinition listFactoryBean = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAINS); + List filterChains = (List) listFactoryBean.getPropertyValues() + .getPropertyValue("sourceList").getValue(); filterChains.add(createFilterChain(element, pc)); - pc.popAndRegisterContainingComponent(); return null; } @@ -120,109 +127,82 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { */ private BeanReference createFilterChain(Element element, ParserContext pc) { boolean secured = !OPT_SECURITY_NONE.equals(element.getAttribute(ATT_SECURED)); - if (!secured) { - if (!StringUtils.hasText(element.getAttribute(ATT_PATH_PATTERN)) - && !StringUtils.hasText(ATT_REQUEST_MATCHER_REF)) { - pc.getReaderContext().error( - "The '" + ATT_SECURED - + "' attribute must be used in combination with" - + " the '" + ATT_PATH_PATTERN + "' or '" - + ATT_REQUEST_MATCHER_REF + "' attributes.", - pc.extractSource(element)); - } - - for (int n = 0; n < element.getChildNodes().getLength(); n++) { - if (element.getChildNodes().item(n) instanceof Element) { - pc.getReaderContext().error( - "If you are using to define an unsecured pattern, " - + "it cannot contain child elements.", - pc.extractSource(element)); + validateSecuredFilterChainElement(element, pc); + for (int i = 0; i < element.getChildNodes().getLength(); i++) { + if (element.getChildNodes().item(i) instanceof Element) { + pc.getReaderContext().error("If you are using to define an unsecured pattern, " + + "it cannot contain child elements.", pc.extractSource(element)); } } - return createSecurityFilterChainBean(element, pc, Collections.emptyList()); } - - final BeanReference portMapper = createPortMapper(element, pc); - final BeanReference portResolver = createPortResolver(portMapper, pc); - + BeanReference portMapper = createPortMapper(element, pc); + BeanReference portResolver = createPortResolver(portMapper, pc); ManagedList authenticationProviders = new ManagedList<>(); - BeanReference authenticationManager = createAuthenticationManager(element, pc, - authenticationProviders); - + BeanReference authenticationManager = createAuthenticationManager(element, pc, authenticationProviders); boolean forceAutoConfig = isDefaultHttpConfig(element); - HttpConfigurationBuilder httpBldr = new HttpConfigurationBuilder(element, - forceAutoConfig, pc, portMapper, portResolver, authenticationManager); - - AuthenticationConfigBuilder authBldr = new AuthenticationConfigBuilder(element, - forceAutoConfig, pc, httpBldr.getSessionCreationPolicy(), - httpBldr.getRequestCache(), authenticationManager, - httpBldr.getSessionStrategy(), portMapper, portResolver, - httpBldr.getCsrfLogoutHandler()); - + HttpConfigurationBuilder httpBldr = new HttpConfigurationBuilder(element, forceAutoConfig, pc, portMapper, + portResolver, authenticationManager); + AuthenticationConfigBuilder authBldr = new AuthenticationConfigBuilder(element, forceAutoConfig, pc, + httpBldr.getSessionCreationPolicy(), httpBldr.getRequestCache(), authenticationManager, + httpBldr.getSessionStrategy(), portMapper, portResolver, httpBldr.getCsrfLogoutHandler()); httpBldr.setLogoutHandlers(authBldr.getLogoutHandlers()); httpBldr.setEntryPoint(authBldr.getEntryPointBean()); httpBldr.setAccessDeniedHandler(authBldr.getAccessDeniedHandlerBean()); httpBldr.setCsrfIgnoreRequestMatchers(authBldr.getCsrfIgnoreRequestMatchers()); - authenticationProviders.addAll(authBldr.getProviders()); - List unorderedFilterChain = new ArrayList<>(); - unorderedFilterChain.addAll(httpBldr.getFilters()); unorderedFilterChain.addAll(authBldr.getFilters()); unorderedFilterChain.addAll(buildCustomFilterList(element, pc)); - unorderedFilterChain.sort(new OrderComparator()); checkFilterChainOrder(unorderedFilterChain, pc, pc.extractSource(element)); - // The list of filter beans List filterChain = new ManagedList<>(); - for (OrderDecorator od : unorderedFilterChain) { filterChain.add(od.bean); } - return createSecurityFilterChainBean(element, pc, filterChain); } - private static boolean isDefaultHttpConfig(Element httpElt) { - return httpElt.getChildNodes().getLength() == 0 - && httpElt.getAttributes().getLength() == 0; + private void validateSecuredFilterChainElement(Element element, ParserContext pc) { + if (!StringUtils.hasText(element.getAttribute(ATT_PATH_PATTERN)) + && !StringUtils.hasText(ATT_REQUEST_MATCHER_REF)) { + String message = "The '" + ATT_SECURED + "' attribute must be used in combination with" + " the '" + + ATT_PATH_PATTERN + "' or '" + ATT_REQUEST_MATCHER_REF + "' attributes."; + pc.getReaderContext().error(message, pc.extractSource(element)); + } } - private BeanReference createSecurityFilterChainBean(Element element, - ParserContext pc, List filterChain) { - BeanMetadataElement filterChainMatcher; + private static boolean isDefaultHttpConfig(Element httpElt) { + return httpElt.getChildNodes().getLength() == 0 && httpElt.getAttributes().getLength() == 0; + } + private BeanReference createSecurityFilterChainBean(Element element, ParserContext pc, List filterChain) { + BeanMetadataElement filterChainMatcher; String requestMatcherRef = element.getAttribute(ATT_REQUEST_MATCHER_REF); String filterChainPattern = element.getAttribute(ATT_PATH_PATTERN); - if (StringUtils.hasText(requestMatcherRef)) { if (StringUtils.hasText(filterChainPattern)) { pc.getReaderContext().error( - "You can't define a pattern and a request-matcher-ref for the " - + "same filter chain", pc.extractSource(element)); + "You can't define a pattern and a request-matcher-ref for the " + "same filter chain", + pc.extractSource(element)); } filterChainMatcher = new RuntimeBeanReference(requestMatcherRef); } else if (StringUtils.hasText(filterChainPattern)) { - filterChainMatcher = MatcherType.fromElement(element).createMatcher(pc, - filterChainPattern, null); + filterChainMatcher = MatcherType.fromElement(element).createMatcher(pc, filterChainPattern, null); } else { filterChainMatcher = new RootBeanDefinition(AnyRequestMatcher.class); } - BeanDefinitionBuilder filterChainBldr = BeanDefinitionBuilder .rootBeanDefinition(DefaultSecurityFilterChain.class); filterChainBldr.addConstructorArgValue(filterChainMatcher); filterChainBldr.addConstructorArgValue(filterChain); - BeanDefinition filterChainBean = filterChainBldr.getBeanDefinition(); - String id = element.getAttribute("name"); if (!StringUtils.hasText(id)) { id = element.getAttribute("id"); @@ -230,30 +210,25 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { id = pc.getReaderContext().generateBeanName(filterChainBean); } } - pc.registerBeanComponent(new BeanComponentDefinition(filterChainBean, id)); - return new RuntimeBeanReference(id); } private BeanReference createPortMapper(Element elt, ParserContext pc) { // Register the portMapper. A default will always be created, even if no element // exists. - BeanDefinition portMapper = new PortMappingsBeanDefinitionParser().parse( - DomUtils.getChildElementByTagName(elt, Elements.PORT_MAPPINGS), pc); + BeanDefinition portMapper = new PortMappingsBeanDefinitionParser() + .parse(DomUtils.getChildElementByTagName(elt, Elements.PORT_MAPPINGS), pc); String portMapperName = pc.getReaderContext().generateBeanName(portMapper); pc.registerBeanComponent(new BeanComponentDefinition(portMapper, portMapperName)); - return new RuntimeBeanReference(portMapperName); } - private RuntimeBeanReference createPortResolver(BeanReference portMapper, - ParserContext pc) { + private RuntimeBeanReference createPortResolver(BeanReference portMapper, ParserContext pc) { RootBeanDefinition portResolver = new RootBeanDefinition(PortResolverImpl.class); portResolver.getPropertyValues().addPropertyValue("portMapper", portMapper); String portResolverName = pc.getReaderContext().generateBeanName(portResolver); - pc.registerBeanComponent(new BeanComponentDefinition(portResolver, - portResolverName)); + pc.registerBeanComponent(new BeanComponentDefinition(portResolver, portResolverName)); return new RuntimeBeanReference(portResolverName); } @@ -268,67 +243,49 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { private BeanReference createAuthenticationManager(Element element, ParserContext pc, ManagedList authenticationProviders) { String parentMgrRef = element.getAttribute(ATT_AUTHENTICATION_MANAGER_REF); - BeanDefinitionBuilder authManager = BeanDefinitionBuilder - .rootBeanDefinition(ProviderManager.class); + BeanDefinitionBuilder authManager = BeanDefinitionBuilder.rootBeanDefinition(ProviderManager.class); authManager.addConstructorArgValue(authenticationProviders); - if (StringUtils.hasText(parentMgrRef)) { - RuntimeBeanReference parentAuthManager = new RuntimeBeanReference( - parentMgrRef); + RuntimeBeanReference parentAuthManager = new RuntimeBeanReference(parentMgrRef); authManager.addConstructorArgValue(parentAuthManager); RootBeanDefinition clearCredentials = new RootBeanDefinition( ClearCredentialsMethodInvokingFactoryBean.class); - clearCredentials.getPropertyValues().addPropertyValue("targetObject", - parentAuthManager); + clearCredentials.getPropertyValues().addPropertyValue("targetObject", parentAuthManager); clearCredentials.getPropertyValues().addPropertyValue("targetMethod", "isEraseCredentialsAfterAuthentication"); - - authManager.addPropertyValue("eraseCredentialsAfterAuthentication", - clearCredentials); + authManager.addPropertyValue("eraseCredentialsAfterAuthentication", clearCredentials); } else { - RootBeanDefinition amfb = new RootBeanDefinition( - AuthenticationManagerFactoryBean.class); + RootBeanDefinition amfb = new RootBeanDefinition(AuthenticationManagerFactoryBean.class); amfb.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); String amfbId = pc.getReaderContext().generateBeanName(amfb); pc.registerBeanComponent(new BeanComponentDefinition(amfb, amfbId)); - RootBeanDefinition clearCredentials = new RootBeanDefinition( - MethodInvokingFactoryBean.class); - clearCredentials.getPropertyValues().addPropertyValue("targetObject", - new RuntimeBeanReference(amfbId)); + RootBeanDefinition clearCredentials = new RootBeanDefinition(MethodInvokingFactoryBean.class); + clearCredentials.getPropertyValues().addPropertyValue("targetObject", new RuntimeBeanReference(amfbId)); clearCredentials.getPropertyValues().addPropertyValue("targetMethod", "isEraseCredentialsAfterAuthentication"); - authManager.addConstructorArgValue(new RuntimeBeanReference(amfbId)); - authManager.addPropertyValue("eraseCredentialsAfterAuthentication", - clearCredentials); + authManager.addPropertyValue("eraseCredentialsAfterAuthentication", clearCredentials); } - // gh-6009 - authManager.addPropertyValue("authenticationEventPublisher", new RootBeanDefinition(DefaultAuthenticationEventPublisher.class)); + authManager.addPropertyValue("authenticationEventPublisher", + new RootBeanDefinition(DefaultAuthenticationEventPublisher.class)); authManager.getRawBeanDefinition().setSource(pc.extractSource(element)); BeanDefinition authMgrBean = authManager.getBeanDefinition(); String id = pc.getReaderContext().generateBeanName(authMgrBean); pc.registerBeanComponent(new BeanComponentDefinition(authMgrBean, id)); - return new RuntimeBeanReference(id); } - private void checkFilterChainOrder(List filters, ParserContext pc, - Object source) { + private void checkFilterChainOrder(List filters, ParserContext pc, Object source) { logger.info("Checking sorted filter chain: " + filters); - for (int i = 0; i < filters.size(); i++) { OrderDecorator filter = filters.get(i); - if (i > 0) { OrderDecorator previous = filters.get(i - 1); if (filter.getOrder() == previous.getOrder()) { pc.getReaderContext() - .error("Filter beans '" - + filter.bean - + "' and '" - + previous.bean + .error("Filter beans '" + filter.bean + "' and '" + previous.bean + "' have the same 'order' value. When using custom filters, " + "please make sure the positions do not conflict with default filters. " + "Alternatively you can disable the default filters by removing the corresponding " @@ -340,39 +297,23 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { } List buildCustomFilterList(Element element, ParserContext pc) { - List customFilterElts = DomUtils.getChildElementsByTagName(element, - Elements.CUSTOM_FILTER); + List customFilterElts = DomUtils.getChildElementsByTagName(element, Elements.CUSTOM_FILTER); List customFilters = new ArrayList<>(); - - final String ATT_AFTER = "after"; - final String ATT_BEFORE = "before"; - final String ATT_POSITION = "position"; - for (Element elt : customFilterElts) { String after = elt.getAttribute(ATT_AFTER); String before = elt.getAttribute(ATT_BEFORE); String position = elt.getAttribute(ATT_POSITION); - String ref = elt.getAttribute(ATT_REF); - if (!StringUtils.hasText(ref)) { - pc.getReaderContext().error( - "The '" + ATT_REF + "' attribute must be supplied", - pc.extractSource(elt)); + pc.getReaderContext().error("The '" + ATT_REF + "' attribute must be supplied", pc.extractSource(elt)); } - RuntimeBeanReference bean = new RuntimeBeanReference(ref); - if (WebConfigUtils.countNonEmpty(new String[] { after, before, position }) != 1) { - pc.getReaderContext().error( - "A single '" + ATT_AFTER + "', '" + ATT_BEFORE + "', or '" - + ATT_POSITION + "' attribute must be supplied", - pc.extractSource(elt)); + pc.getReaderContext().error("A single '" + ATT_AFTER + "', '" + ATT_BEFORE + "', or '" + ATT_POSITION + + "' attribute must be supplied", pc.extractSource(elt)); } - if (StringUtils.hasText(position)) { - customFilters.add(new OrderDecorator(bean, SecurityFilters - .valueOf(position))); + customFilters.add(new OrderDecorator(bean, SecurityFilters.valueOf(position))); } else if (StringUtils.hasText(after)) { SecurityFilters order = SecurityFilters.valueOf(after); @@ -393,7 +334,6 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { } } } - return customFilters; } @@ -406,22 +346,16 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { // FilterChainProxy BeanDefinition listFactoryBean = new RootBeanDefinition(ListFactoryBean.class); listFactoryBean.getPropertyValues().add("sourceList", new ManagedList()); - pc.registerBeanComponent(new BeanComponentDefinition(listFactoryBean, - BeanIds.FILTER_CHAINS)); - - BeanDefinitionBuilder fcpBldr = BeanDefinitionBuilder - .rootBeanDefinition(FilterChainProxy.class); + pc.registerBeanComponent(new BeanComponentDefinition(listFactoryBean, BeanIds.FILTER_CHAINS)); + BeanDefinitionBuilder fcpBldr = BeanDefinitionBuilder.rootBeanDefinition(FilterChainProxy.class); fcpBldr.getRawBeanDefinition().setSource(source); fcpBldr.addConstructorArgReference(BeanIds.FILTER_CHAINS); - fcpBldr.addPropertyValue("filterChainValidator", new RootBeanDefinition( - DefaultFilterChainValidator.class)); + fcpBldr.addPropertyValue("filterChainValidator", new RootBeanDefinition(DefaultFilterChainValidator.class)); BeanDefinition fcpBean = fcpBldr.getBeanDefinition(); - pc.registerBeanComponent(new BeanComponentDefinition(fcpBean, - BeanIds.FILTER_CHAIN_PROXY)); - registry.registerAlias(BeanIds.FILTER_CHAIN_PROXY, - BeanIds.SPRING_SECURITY_FILTER_CHAIN); - - BeanDefinitionBuilder requestRejected = BeanDefinitionBuilder.rootBeanDefinition(RequestRejectedHandlerPostProcessor.class); + pc.registerBeanComponent(new BeanComponentDefinition(fcpBean, BeanIds.FILTER_CHAIN_PROXY)); + registry.registerAlias(BeanIds.FILTER_CHAIN_PROXY, BeanIds.SPRING_SECURITY_FILTER_CHAIN); + BeanDefinitionBuilder requestRejected = BeanDefinitionBuilder + .rootBeanDefinition(RequestRejectedHandlerPostProcessor.class); requestRejected.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); requestRejected.addConstructorArgValue("requestRejectedHandler"); requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY); @@ -431,91 +365,70 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser { registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean); } -} + static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor { -class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor { - private final String beanName; + private final String beanName; - private final String targetBeanName; + private final String targetBeanName; - private final String targetPropertyName; + private final String targetPropertyName; - RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) { - this.beanName = beanName; - this.targetBeanName = targetBeanName; - this.targetPropertyName = targetPropertyName; - } - - @Override - public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { - if (registry.containsBeanDefinition(this.beanName)) { - BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName); - beanDefinition.getPropertyValues().add(this.targetPropertyName, new RuntimeBeanReference(this.beanName)); + RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) { + this.beanName = beanName; + this.targetBeanName = targetBeanName; + this.targetPropertyName = targetPropertyName; } - } - @Override - public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { - - } -} - -class OrderDecorator implements Ordered { - final BeanMetadataElement bean; - final int order; - - OrderDecorator(BeanMetadataElement bean, SecurityFilters filterOrder) { - this.bean = bean; - this.order = filterOrder.getOrder(); - } - - OrderDecorator(BeanMetadataElement bean, int order) { - this.bean = bean; - this.order = order; - } - - @Override - public int getOrder() { - return order; - } - - @Override - public String toString() { - return bean + ", order = " + order; - } -} - -/** - * Custom {@link MethodInvokingFactoryBean} that is specifically used for looking up the - * child {@link ProviderManager} value for - * {@link ProviderManager#setEraseCredentialsAfterAuthentication(boolean)} given the - * parent {@link AuthenticationManager}. This is necessary because the parent - * {@link AuthenticationManager} might not be a {@link ProviderManager}. - * - * @author Rob Winch - */ -final class ClearCredentialsMethodInvokingFactoryBean extends MethodInvokingFactoryBean { - @Override - public void afterPropertiesSet() throws Exception { - boolean isTargetProviderManager = getTargetObject() instanceof ProviderManager; - if (!isTargetProviderManager) { - setTargetObject(this); + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (registry.containsBeanDefinition(this.beanName)) { + BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName); + beanDefinition.getPropertyValues().add(this.targetPropertyName, + new RuntimeBeanReference(this.beanName)); + } } - super.afterPropertiesSet(); + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + + } + } /** - * The default value if the target object is not a ProviderManager is false. We use - * false because this feature is associated with {@link ProviderManager} not - * {@link AuthenticationManager}. If the user wants to leverage - * {@link ProviderManager#setEraseCredentialsAfterAuthentication(boolean)} their - * original {@link AuthenticationManager} must be a {@link ProviderManager} (we should - * not magically add this functionality to their implementation since we cannot - * determine if it should be on or off). + * Custom {@link MethodInvokingFactoryBean} that is specifically used for looking up + * the child {@link ProviderManager} value for + * {@link ProviderManager#setEraseCredentialsAfterAuthentication(boolean)} given the + * parent {@link AuthenticationManager}. This is necessary because the parent + * {@link AuthenticationManager} might not be a {@link ProviderManager}. * - * @return + * @author Rob Winch */ - public boolean isEraseCredentialsAfterAuthentication() { - return false; + static final class ClearCredentialsMethodInvokingFactoryBean extends MethodInvokingFactoryBean { + + @Override + public void afterPropertiesSet() throws Exception { + boolean isTargetProviderManager = getTargetObject() instanceof ProviderManager; + if (!isTargetProviderManager) { + setTargetObject(this); + } + super.afterPropertiesSet(); + } + + /** + * The default value if the target object is not a ProviderManager is false. We + * use false because this feature is associated with {@link ProviderManager} not + * {@link AuthenticationManager}. If the user wants to leverage + * {@link ProviderManager#setEraseCredentialsAfterAuthentication(boolean)} their + * original {@link AuthenticationManager} must be a {@link ProviderManager} (we + * should not magically add this functionality to their implementation since we + * cannot determine if it should be on or off). + * @return + */ + boolean isEraseCredentialsAfterAuthentication() { + return false; + } + } + } diff --git a/config/src/main/java/org/springframework/security/config/http/LogoutBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/LogoutBeanDefinitionParser.java index 8b8c4f8b6c..65c1b3b931 100644 --- a/config/src/main/java/org/springframework/security/config/http/LogoutBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/LogoutBeanDefinitionParser.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -28,7 +31,6 @@ import org.springframework.security.web.authentication.logout.LogoutFilter; import org.springframework.security.web.authentication.logout.LogoutSuccessEventPublishingLogoutHandler; import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * @author Luke Taylor @@ -36,40 +38,44 @@ import org.w3c.dom.Element; * @author Onur Kagan Ozcan */ class LogoutBeanDefinitionParser implements BeanDefinitionParser { + static final String ATT_LOGOUT_SUCCESS_URL = "logout-success-url"; static final String ATT_INVALIDATE_SESSION = "invalidate-session"; static final String ATT_LOGOUT_URL = "logout-url"; + static final String DEF_LOGOUT_URL = "/logout"; + static final String ATT_LOGOUT_HANDLER = "success-handler-ref"; + static final String ATT_DELETE_COOKIES = "delete-cookies"; final String rememberMeServices; + private final String defaultLogoutUrl; + private ManagedList logoutHandlers = new ManagedList<>(); + private boolean csrfEnabled; - LogoutBeanDefinitionParser(String loginPageUrl, String rememberMeServices, - BeanMetadataElement csrfLogoutHandler) { + LogoutBeanDefinitionParser(String loginPageUrl, String rememberMeServices, BeanMetadataElement csrfLogoutHandler) { this.defaultLogoutUrl = loginPageUrl + "?logout"; this.rememberMeServices = rememberMeServices; this.csrfEnabled = csrfLogoutHandler != null; if (this.csrfEnabled) { - logoutHandlers.add(csrfLogoutHandler); + this.logoutHandlers.add(csrfLogoutHandler); } } + @Override public BeanDefinition parse(Element element, ParserContext pc) { String logoutUrl = null; String successHandlerRef = null; String logoutSuccessUrl = null; String invalidateSession = null; String deleteCookies = null; - - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(LogoutFilter.class); - + BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(LogoutFilter.class); if (element != null) { Object source = pc.extractSource(element); builder.getRawBeanDefinition().setSource(source); @@ -81,51 +87,39 @@ class LogoutBeanDefinitionParser implements BeanDefinitionParser { invalidateSession = element.getAttribute(ATT_INVALIDATE_SESSION); deleteCookies = element.getAttribute(ATT_DELETE_COOKIES); } - if (!StringUtils.hasText(logoutUrl)) { logoutUrl = DEF_LOGOUT_URL; } - - builder.addPropertyValue("logoutRequestMatcher", - getLogoutRequestMatcher(logoutUrl)); - + builder.addPropertyValue("logoutRequestMatcher", getLogoutRequestMatcher(logoutUrl)); if (StringUtils.hasText(successHandlerRef)) { if (StringUtils.hasText(logoutSuccessUrl)) { pc.getReaderContext().error( - "Use " + ATT_LOGOUT_SUCCESS_URL + " or " + ATT_LOGOUT_HANDLER - + ", but not both", pc.extractSource(element)); + "Use " + ATT_LOGOUT_SUCCESS_URL + " or " + ATT_LOGOUT_HANDLER + ", but not both", + pc.extractSource(element)); } builder.addConstructorArgReference(successHandlerRef); } else { // Use the logout URL if no handler set if (!StringUtils.hasText(logoutSuccessUrl)) { - logoutSuccessUrl = defaultLogoutUrl; + logoutSuccessUrl = this.defaultLogoutUrl; } builder.addConstructorArgValue(logoutSuccessUrl); } - BeanDefinition sclh = new RootBeanDefinition(SecurityContextLogoutHandler.class); - sclh.getPropertyValues().addPropertyValue("invalidateHttpSession", - !"false".equals(invalidateSession)); - logoutHandlers.add(sclh); - - if (rememberMeServices != null) { - logoutHandlers.add(new RuntimeBeanReference(rememberMeServices)); + sclh.getPropertyValues().addPropertyValue("invalidateHttpSession", !"false".equals(invalidateSession)); + this.logoutHandlers.add(sclh); + if (this.rememberMeServices != null) { + this.logoutHandlers.add(new RuntimeBeanReference(this.rememberMeServices)); } - if (StringUtils.hasText(deleteCookies)) { - BeanDefinition cookieDeleter = new RootBeanDefinition( - CookieClearingLogoutHandler.class); + BeanDefinition cookieDeleter = new RootBeanDefinition(CookieClearingLogoutHandler.class); String[] names = StringUtils.tokenizeToStringArray(deleteCookies, ","); cookieDeleter.getConstructorArgumentValues().addGenericArgumentValue(names); - logoutHandlers.add(cookieDeleter); + this.logoutHandlers.add(cookieDeleter); } - - logoutHandlers.add(new RootBeanDefinition(LogoutSuccessEventPublishingLogoutHandler.class)); - - builder.addConstructorArgValue(logoutHandlers); - + this.logoutHandlers.add(new RootBeanDefinition(LogoutSuccessEventPublishingLogoutHandler.class)); + builder.addConstructorArgValue(this.logoutHandlers); return builder.getBeanDefinition(); } @@ -136,11 +130,11 @@ class LogoutBeanDefinitionParser implements BeanDefinitionParser { if (this.csrfEnabled) { matcherBuilder.addConstructorArgValue("POST"); } - return matcherBuilder.getBeanDefinition(); } ManagedList getLogoutHandlers() { - return logoutHandlers; + return this.logoutHandlers; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/MatcherType.java b/config/src/main/java/org/springframework/security/config/http/MatcherType.java index 58cd734aed..9d65f9ea63 100644 --- a/config/src/main/java/org/springframework/security/config/http/MatcherType.java +++ b/config/src/main/java/org/springframework/security/config/http/MatcherType.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.RootBeanDefinition; @@ -25,7 +28,6 @@ import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * Defines the {@link RequestMatcher} types supported by the namespace. @@ -34,8 +36,9 @@ import org.w3c.dom.Element; * @since 3.1 */ public enum MatcherType { - ant(AntPathRequestMatcher.class), regex(RegexRequestMatcher.class), ciRegex( - RegexRequestMatcher.class), mvc(MvcRequestMatcher.class); + + ant(AntPathRequestMatcher.class), regex(RegexRequestMatcher.class), ciRegex(RegexRequestMatcher.class), mvc( + MvcRequestMatcher.class); private static final String HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME = "mvcHandlerMappingIntrospector"; @@ -55,14 +58,10 @@ public enum MatcherType { if (("/**".equals(path) || "**".equals(path)) && method == null) { return new RootBeanDefinition(AnyRequestMatcher.class); } - - BeanDefinitionBuilder matcherBldr = BeanDefinitionBuilder - .rootBeanDefinition(type); - + BeanDefinitionBuilder matcherBldr = BeanDefinitionBuilder.rootBeanDefinition(this.type); if (this == mvc) { matcherBldr.addConstructorArgValue(new RootBeanDefinition(HandlerMappingIntrospectorFactoryBean.class)); } - matcherBldr.addConstructorArgValue(path); if (this == mvc) { matcherBldr.addPropertyValue("method", method); @@ -71,11 +70,9 @@ public enum MatcherType { else { matcherBldr.addConstructorArgValue(method); } - if (this == ciRegex) { matcherBldr.addConstructorArgValue(true); } - return matcherBldr.getBeanDefinition(); } @@ -86,4 +83,5 @@ public enum MatcherType { return ant; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java index ccc3219e46..3a72f62585 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; @@ -28,27 +31,31 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequest import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; - -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository; -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository; -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService; -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository; /** * @author Joe Grandja * @since 5.3 */ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { + private static final String ELT_AUTHORIZATION_CODE_GRANT = "authorization-code-grant"; + private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref"; + private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; + private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; + private final BeanReference requestCache; + private final BeanReference authenticationManager; + private BeanDefinition defaultAuthorizedClientRepository; + private BeanDefinition authorizationRequestRedirectFilter; + private BeanDefinition authorizationCodeGrantFilter; + private BeanDefinition authorizationCodeAuthenticationProvider; OAuth2ClientBeanDefinitionParser(BeanReference requestCache, BeanReference authenticationManager) { @@ -59,76 +66,66 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { @Override public BeanDefinition parse(Element element, ParserContext parserContext) { Element authorizationCodeGrantElt = DomUtils.getChildElementByTagName(element, ELT_AUTHORIZATION_CODE_GRANT); - - BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element); - BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element); + BeanMetadataElement clientRegistrationRepository = OAuth2ClientBeanDefinitionParserUtils + .getClientRegistrationRepository(element); + BeanMetadataElement authorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils + .getAuthorizedClientRepository(element); if (authorizedClientRepository == null) { - BeanMetadataElement authorizedClientService = getAuthorizedClientService(element); - this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository( - clientRegistrationRepository, authorizedClientService); + BeanMetadataElement authorizedClientService = OAuth2ClientBeanDefinitionParserUtils + .getAuthorizedClientService(element); + this.defaultAuthorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils + .createDefaultAuthorizedClientRepository(clientRegistrationRepository, authorizedClientService); authorizedClientRepository = new RuntimeBeanReference(OAuth2AuthorizedClientRepository.class); } BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository( authorizationCodeGrantElt); - BeanDefinitionBuilder authorizationRequestRedirectFilterBuilder = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationRequestRedirectFilter.class); - String authorizationRequestResolverRef = authorizationCodeGrantElt != null ? - authorizationCodeGrantElt.getAttribute(ATT_AUTHORIZATION_REQUEST_RESOLVER_REF) : null; + String authorizationRequestResolverRef = (authorizationCodeGrantElt != null) + ? authorizationCodeGrantElt.getAttribute(ATT_AUTHORIZATION_REQUEST_RESOLVER_REF) : null; if (!StringUtils.isEmpty(authorizationRequestResolverRef)) { authorizationRequestRedirectFilterBuilder.addConstructorArgReference(authorizationRequestResolverRef); - } else { + } + else { authorizationRequestRedirectFilterBuilder.addConstructorArgValue(clientRegistrationRepository); } this.authorizationRequestRedirectFilter = authorizationRequestRedirectFilterBuilder .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) - .addPropertyValue("requestCache", this.requestCache) - .getBeanDefinition(); - + .addPropertyValue("requestCache", this.requestCache).getBeanDefinition(); this.authorizationCodeGrantFilter = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationCodeGrantFilter.class) - .addConstructorArgValue(clientRegistrationRepository) - .addConstructorArgValue(authorizedClientRepository) + .addConstructorArgValue(clientRegistrationRepository).addConstructorArgValue(authorizedClientRepository) .addConstructorArgValue(this.authenticationManager) - .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) - .getBeanDefinition(); + .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository).getBeanDefinition(); BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(authorizationCodeGrantElt); - this.authorizationCodeAuthenticationProvider = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationCodeAuthenticationProvider.class) - .addConstructorArgValue(accessTokenResponseClient) - .getBeanDefinition(); + .addConstructorArgValue(accessTokenResponseClient).getBeanDefinition(); return null; } private BeanMetadataElement getAuthorizationRequestRepository(Element element) { - BeanMetadataElement authorizationRequestRepository; - String authorizationRequestRepositoryRef = element != null ? - element.getAttribute(ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF) : null; + String authorizationRequestRepositoryRef = (element != null) + ? element.getAttribute(ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF) : null; if (!StringUtils.isEmpty(authorizationRequestRepositoryRef)) { - authorizationRequestRepository = new RuntimeBeanReference(authorizationRequestRepositoryRef); - } else { - authorizationRequestRepository = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository") - .getBeanDefinition(); + return new RuntimeBeanReference(authorizationRequestRepositoryRef); } - return authorizationRequestRepository; + return BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository") + .getBeanDefinition(); } private BeanMetadataElement getAccessTokenResponseClient(Element element) { - BeanMetadataElement accessTokenResponseClient; - String accessTokenResponseClientRef = element != null ? - element.getAttribute(ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF) : null; + String accessTokenResponseClientRef = (element != null) + ? element.getAttribute(ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF) : null; if (!StringUtils.isEmpty(accessTokenResponseClientRef)) { - accessTokenResponseClient = new RuntimeBeanReference(accessTokenResponseClientRef); - } else { - accessTokenResponseClient = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient") - .getBeanDefinition(); + return new RuntimeBeanReference(accessTokenResponseClientRef); } - return accessTokenResponseClient; + return BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient") + .getBeanDefinition(); } BeanDefinition getDefaultAuthorizedClientRepository() { @@ -146,4 +143,5 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { BeanDefinition getAuthorizationCodeAuthenticationProvider() { return this.authorizationCodeAuthenticationProvider; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java index c0af4d7d44..8b8a333c7b 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java @@ -13,34 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * @author Joe Grandja * @since 5.4 */ final class OAuth2ClientBeanDefinitionParserUtils { + private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref"; + private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref"; + private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref"; + private OAuth2ClientBeanDefinitionParserUtils() { + } + static BeanMetadataElement getClientRegistrationRepository(Element element) { - BeanMetadataElement clientRegistrationRepository; String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF); if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) { - clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef); - } else { - clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class); + return new RuntimeBeanReference(clientRegistrationRepositoryRef); } - return clientRegistrationRepository; + return new RuntimeBeanReference(ClientRegistrationRepository.class); } static BeanMetadataElement getAuthorizedClientRepository(Element element) { @@ -59,17 +64,17 @@ final class OAuth2ClientBeanDefinitionParserUtils { return null; } - static BeanDefinition createDefaultAuthorizedClientRepository( - BeanMetadataElement clientRegistrationRepository, BeanMetadataElement authorizedClientService) { + static BeanDefinition createDefaultAuthorizedClientRepository(BeanMetadataElement clientRegistrationRepository, + BeanMetadataElement authorizedClientService) { if (authorizedClientService == null) { - authorizedClientService = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService") - .addConstructorArgValue(clientRegistrationRepository) - .getBeanDefinition(); + authorizedClientService = BeanDefinitionBuilder + .rootBeanDefinition( + "org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService") + .addConstructorArgValue(clientRegistrationRepository).getBeanDefinition(); } return BeanDefinitionBuilder.rootBeanDefinition( "org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository") - .addConstructorArgValue(authorizedClientService) - .getBeanDefinition(); + .addConstructorArgValue(authorizedClientService).getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java index a6535ea064..a62cd65055 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.springframework.beans.BeansException; @@ -38,7 +39,9 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandl * @since 5.4 */ final class OAuth2ClientWebMvcSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware { + private static final String CUSTOM_ARGUMENT_RESOLVERS_PROPERTY = "customArgumentResolvers"; + private BeanFactory beanFactory; @Override @@ -47,29 +50,26 @@ final class OAuth2ClientWebMvcSecurityPostProcessor implements BeanDefinitionReg (ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, false, false); String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, false, false); - if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) { return; } - for (String beanName : registry.getBeanDefinitionNames()) { BeanDefinition beanDefinition = registry.getBeanDefinition(beanName); if (RequestMappingHandlerAdapter.class.getName().equals(beanDefinition.getBeanClassName())) { - PropertyValue currentArgumentResolvers = - beanDefinition.getPropertyValues().getPropertyValue(CUSTOM_ARGUMENT_RESOLVERS_PROPERTY); + PropertyValue currentArgumentResolvers = beanDefinition.getPropertyValues() + .getPropertyValue(CUSTOM_ARGUMENT_RESOLVERS_PROPERTY); ManagedList argumentResolvers = new ManagedList<>(); if (currentArgumentResolvers != null) { argumentResolvers.addAll((ManagedList) currentArgumentResolvers.getValue()); } - String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors( (ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, false, false); - - BeanDefinitionBuilder beanDefinitionBuilder = - BeanDefinitionBuilder.genericBeanDefinition(OAuth2AuthorizedClientArgumentResolver.class); + BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder + .genericBeanDefinition(OAuth2AuthorizedClientArgumentResolver.class); if (authorizedClientManagerBeanNames.length == 1) { beanDefinitionBuilder.addConstructorArgReference(authorizedClientManagerBeanNames[0]); - } else { + } + else { beanDefinitionBuilder.addConstructorArgReference(clientRegistrationRepositoryBeanNames[0]); beanDefinitionBuilder.addConstructorArgReference(authorizedClientRepositoryBeanNames[0]); } @@ -88,4 +88,5 @@ final class OAuth2ClientWebMvcSecurityPostProcessor implements BeanDefinitionReg public void setBeanFactory(BeanFactory beanFactory) throws BeansException { this.beanFactory = beanFactory; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java index 68ac6f2041..288b09072e 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanDefinition; @@ -58,19 +68,6 @@ import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.springframework.web.accept.ContentNegotiationStrategy; import org.springframework.web.accept.HeaderContentNegotiationStrategy; -import org.w3c.dom.Element; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository; -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository; -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService; -import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository; /** * @author Ruby Hartono @@ -79,26 +76,43 @@ import static org.springframework.security.config.http.OAuth2ClientBeanDefinitio final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { private static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization"; + private static final String DEFAULT_LOGIN_URI = DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL; private static final String ELT_CLIENT_REGISTRATION = "client-registration"; + private static final String ATT_REGISTRATION_ID = "registration-id"; + private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref"; + private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; + private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; + private static final String ATT_USER_AUTHORITIES_MAPPER_REF = "user-authorities-mapper-ref"; + private static final String ATT_USER_SERVICE_REF = "user-service-ref"; + private static final String ATT_OIDC_USER_SERVICE_REF = "oidc-user-service-ref"; + private static final String ATT_LOGIN_PROCESSING_URL = "login-processing-url"; + private static final String ATT_LOGIN_PAGE = "login-page"; + private static final String ATT_AUTHENTICATION_SUCCESS_HANDLER_REF = "authentication-success-handler-ref"; + private static final String ATT_AUTHENTICATION_FAILURE_HANDLER_REF = "authentication-failure-handler-ref"; + private static final String ATT_JWT_DECODER_FACTORY_REF = "jwt-decoder-factory-ref"; private final BeanReference requestCache; + private final BeanReference portMapper; + private final BeanReference portResolver; + private final BeanReference sessionStrategy; + private final boolean allowSessionCreation; private BeanDefinition defaultAuthorizedClientRepository; @@ -128,204 +142,174 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { BeanDefinition oauth2LoginBeanConfig = BeanDefinitionBuilder.rootBeanDefinition(OAuth2LoginBeanConfig.class) .getBeanDefinition(); String oauth2LoginBeanConfigId = parserContext.getReaderContext().generateBeanName(oauth2LoginBeanConfig); - parserContext.registerBeanComponent( - new BeanComponentDefinition(oauth2LoginBeanConfig, oauth2LoginBeanConfigId)); - + parserContext + .registerBeanComponent(new BeanComponentDefinition(oauth2LoginBeanConfig, oauth2LoginBeanConfigId)); // configure filter - BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element); - BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element); + BeanMetadataElement clientRegistrationRepository = OAuth2ClientBeanDefinitionParserUtils + .getClientRegistrationRepository(element); + BeanMetadataElement authorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils + .getAuthorizedClientRepository(element); if (authorizedClientRepository == null) { - BeanMetadataElement authorizedClientService = getAuthorizedClientService(element); - this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository( - clientRegistrationRepository, authorizedClientService); + BeanMetadataElement authorizedClientService = OAuth2ClientBeanDefinitionParserUtils + .getAuthorizedClientService(element); + this.defaultAuthorizedClientRepository = OAuth2ClientBeanDefinitionParserUtils + .createDefaultAuthorizedClientRepository(clientRegistrationRepository, authorizedClientService); authorizedClientRepository = new RuntimeBeanReference(OAuth2AuthorizedClientRepository.class); } BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(element); BeanMetadataElement oauth2UserService = getOAuth2UserService(element); BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository(element); - BeanDefinitionBuilder oauth2LoginAuthenticationFilterBuilder = BeanDefinitionBuilder .rootBeanDefinition(OAuth2LoginAuthenticationFilter.class) - .addConstructorArgValue(clientRegistrationRepository) - .addConstructorArgValue(authorizedClientRepository) + .addConstructorArgValue(clientRegistrationRepository).addConstructorArgValue(authorizedClientRepository) .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository); - - if (sessionStrategy != null) { - oauth2LoginAuthenticationFilterBuilder.addPropertyValue("sessionAuthenticationStrategy", sessionStrategy); + if (this.sessionStrategy != null) { + oauth2LoginAuthenticationFilterBuilder.addPropertyValue("sessionAuthenticationStrategy", + this.sessionStrategy); } - Object source = parserContext.extractSource(element); String loginProcessingUrl = element.getAttribute(ATT_LOGIN_PROCESSING_URL); if (!StringUtils.isEmpty(loginProcessingUrl)) { WebConfigUtils.validateHttpRedirect(loginProcessingUrl, parserContext, source); oauth2LoginAuthenticationFilterBuilder.addConstructorArgValue(loginProcessingUrl); - } else { + } + else { oauth2LoginAuthenticationFilterBuilder .addConstructorArgValue(OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI); } - BeanDefinitionBuilder oauth2LoginAuthenticationProviderBuilder = BeanDefinitionBuilder .rootBeanDefinition(OAuth2LoginAuthenticationProvider.class) - .addConstructorArgValue(accessTokenResponseClient) - .addConstructorArgValue(oauth2UserService); - + .addConstructorArgValue(accessTokenResponseClient).addConstructorArgValue(oauth2UserService); String userAuthoritiesMapperRef = element.getAttribute(ATT_USER_AUTHORITIES_MAPPER_REF); if (!StringUtils.isEmpty(userAuthoritiesMapperRef)) { - oauth2LoginAuthenticationProviderBuilder.addPropertyReference("authoritiesMapper", userAuthoritiesMapperRef); + oauth2LoginAuthenticationProviderBuilder.addPropertyReference("authoritiesMapper", + userAuthoritiesMapperRef); } - - oauth2LoginAuthenticationProvider = oauth2LoginAuthenticationProviderBuilder.getBeanDefinition(); - - oauth2LoginOidcAuthenticationProvider = getOidcAuthProvider( - element, accessTokenResponseClient, userAuthoritiesMapperRef); - + this.oauth2LoginAuthenticationProvider = oauth2LoginAuthenticationProviderBuilder.getBeanDefinition(); + this.oauth2LoginOidcAuthenticationProvider = getOidcAuthProvider(element, accessTokenResponseClient, + userAuthoritiesMapperRef); BeanDefinitionBuilder oauth2AuthorizationRequestRedirectFilterBuilder = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationRequestRedirectFilter.class); - String authorizationRequestResolverRef = element.getAttribute(ATT_AUTHORIZATION_REQUEST_RESOLVER_REF); if (!StringUtils.isEmpty(authorizationRequestResolverRef)) { - oauth2AuthorizationRequestRedirectFilterBuilder - .addConstructorArgReference(authorizationRequestResolverRef); - } else { + oauth2AuthorizationRequestRedirectFilterBuilder.addConstructorArgReference(authorizationRequestResolverRef); + } + else { oauth2AuthorizationRequestRedirectFilterBuilder.addConstructorArgValue(clientRegistrationRepository); } - oauth2AuthorizationRequestRedirectFilterBuilder .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) - .addPropertyValue("requestCache", requestCache); - oauth2AuthorizationRequestRedirectFilter = oauth2AuthorizationRequestRedirectFilterBuilder.getBeanDefinition(); - + .addPropertyValue("requestCache", this.requestCache); + this.oauth2AuthorizationRequestRedirectFilter = oauth2AuthorizationRequestRedirectFilterBuilder + .getBeanDefinition(); String authenticationSuccessHandlerRef = element.getAttribute(ATT_AUTHENTICATION_SUCCESS_HANDLER_REF); if (!StringUtils.isEmpty(authenticationSuccessHandlerRef)) { oauth2LoginAuthenticationFilterBuilder.addPropertyReference("authenticationSuccessHandler", authenticationSuccessHandlerRef); - } else { + } + else { BeanDefinitionBuilder successHandlerBuilder = BeanDefinitionBuilder.rootBeanDefinition( "org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler") - .addPropertyValue("requestCache", requestCache); + .addPropertyValue("requestCache", this.requestCache); oauth2LoginAuthenticationFilterBuilder.addPropertyValue("authenticationSuccessHandler", successHandlerBuilder.getBeanDefinition()); } - String loginPage = element.getAttribute(ATT_LOGIN_PAGE); if (!StringUtils.isEmpty(loginPage)) { WebConfigUtils.validateHttpRedirect(loginPage, parserContext, source); - oauth2LoginAuthenticationEntryPoint = BeanDefinitionBuilder - .rootBeanDefinition(LoginUrlAuthenticationEntryPoint.class) - .addConstructorArgValue(loginPage) - .addPropertyValue("portMapper", portMapper) - .addPropertyValue("portResolver", portResolver) + this.oauth2LoginAuthenticationEntryPoint = BeanDefinitionBuilder + .rootBeanDefinition(LoginUrlAuthenticationEntryPoint.class).addConstructorArgValue(loginPage) + .addPropertyValue("portMapper", this.portMapper).addPropertyValue("portResolver", this.portResolver) .getBeanDefinition(); - } else { + } + else { Map entryPoint = getLoginEntryPoint(element); if (entryPoint != null) { - oauth2LoginAuthenticationEntryPoint = BeanDefinitionBuilder - .rootBeanDefinition(DelegatingAuthenticationEntryPoint.class) - .addConstructorArgValue(entryPoint) + this.oauth2LoginAuthenticationEntryPoint = BeanDefinitionBuilder + .rootBeanDefinition(DelegatingAuthenticationEntryPoint.class).addConstructorArgValue(entryPoint) .addPropertyValue("defaultEntryPoint", new LoginUrlAuthenticationEntryPoint(DEFAULT_LOGIN_URI)) .getBeanDefinition(); } } - String authenticationFailureHandlerRef = element.getAttribute(ATT_AUTHENTICATION_FAILURE_HANDLER_REF); if (!StringUtils.isEmpty(authenticationFailureHandlerRef)) { oauth2LoginAuthenticationFilterBuilder.addPropertyReference("authenticationFailureHandler", authenticationFailureHandlerRef); - } else { + } + else { BeanDefinitionBuilder failureHandlerBuilder = BeanDefinitionBuilder.rootBeanDefinition( "org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler"); failureHandlerBuilder.addConstructorArgValue( DEFAULT_LOGIN_URI + "?" + DefaultLoginPageGeneratingFilter.ERROR_PARAMETER_NAME); - failureHandlerBuilder.addPropertyValue("allowSessionCreation", allowSessionCreation); + failureHandlerBuilder.addPropertyValue("allowSessionCreation", this.allowSessionCreation); oauth2LoginAuthenticationFilterBuilder.addPropertyValue("authenticationFailureHandler", failureHandlerBuilder.getBeanDefinition()); } - // prepare loginlinks - oauth2LoginLinks = BeanDefinitionBuilder.rootBeanDefinition(Map.class) + this.oauth2LoginLinks = BeanDefinitionBuilder.rootBeanDefinition(Map.class) .setFactoryMethodOnBean("getLoginLinks", oauth2LoginBeanConfigId).getBeanDefinition(); - return oauth2LoginAuthenticationFilterBuilder.getBeanDefinition(); } private BeanMetadataElement getAuthorizationRequestRepository(Element element) { - BeanMetadataElement authorizationRequestRepository; String authorizationRequestRepositoryRef = element.getAttribute(ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF); if (!StringUtils.isEmpty(authorizationRequestRepositoryRef)) { - authorizationRequestRepository = new RuntimeBeanReference(authorizationRequestRepositoryRef); - } else { - authorizationRequestRepository = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository") - .getBeanDefinition(); + return new RuntimeBeanReference(authorizationRequestRepositoryRef); } - return authorizationRequestRepository; + return BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository") + .getBeanDefinition(); } - private BeanDefinition getOidcAuthProvider(Element element, - BeanMetadataElement accessTokenResponseClient, String userAuthoritiesMapperRef) { - - boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent( - "org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); + private BeanDefinition getOidcAuthProvider(Element element, BeanMetadataElement accessTokenResponseClient, + String userAuthoritiesMapperRef) { + boolean oidcAuthenticationProviderEnabled = ClassUtils + .isPresent("org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); if (!oidcAuthenticationProviderEnabled) { return BeanDefinitionBuilder.rootBeanDefinition(OidcAuthenticationRequestChecker.class).getBeanDefinition(); } - BeanMetadataElement oidcUserService = getOidcUserService(element); - BeanDefinitionBuilder oidcAuthProviderBuilder = BeanDefinitionBuilder.rootBeanDefinition( "org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider") - .addConstructorArgValue(accessTokenResponseClient) - .addConstructorArgValue(oidcUserService); - + .addConstructorArgValue(accessTokenResponseClient).addConstructorArgValue(oidcUserService); if (!StringUtils.isEmpty(userAuthoritiesMapperRef)) { oidcAuthProviderBuilder.addPropertyReference("authoritiesMapper", userAuthoritiesMapperRef); } - String jwtDecoderFactoryRef = element.getAttribute(ATT_JWT_DECODER_FACTORY_REF); if (!StringUtils.isEmpty(jwtDecoderFactoryRef)) { oidcAuthProviderBuilder.addPropertyReference("jwtDecoderFactory", jwtDecoderFactoryRef); } - return oidcAuthProviderBuilder.getBeanDefinition(); } private BeanMetadataElement getOidcUserService(Element element) { - BeanMetadataElement oidcUserService; String oidcUserServiceRef = element.getAttribute(ATT_OIDC_USER_SERVICE_REF); if (!StringUtils.isEmpty(oidcUserServiceRef)) { - oidcUserService = new RuntimeBeanReference(oidcUserServiceRef); - } else { - oidcUserService = BeanDefinitionBuilder - .rootBeanDefinition("org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService") - .getBeanDefinition(); + return new RuntimeBeanReference(oidcUserServiceRef); } - return oidcUserService; + return BeanDefinitionBuilder + .rootBeanDefinition("org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService") + .getBeanDefinition(); } private BeanMetadataElement getOAuth2UserService(Element element) { - BeanMetadataElement oauth2UserService; String oauth2UserServiceRef = element.getAttribute(ATT_USER_SERVICE_REF); if (!StringUtils.isEmpty(oauth2UserServiceRef)) { - oauth2UserService = new RuntimeBeanReference(oauth2UserServiceRef); - } else { - oauth2UserService = BeanDefinitionBuilder - .rootBeanDefinition("org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService") - .getBeanDefinition(); + return new RuntimeBeanReference(oauth2UserServiceRef); } - return oauth2UserService; + return BeanDefinitionBuilder + .rootBeanDefinition("org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService") + .getBeanDefinition(); } private BeanMetadataElement getAccessTokenResponseClient(Element element) { - BeanMetadataElement accessTokenResponseClient; String accessTokenResponseClientRef = element.getAttribute(ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF); if (!StringUtils.isEmpty(accessTokenResponseClientRef)) { - accessTokenResponseClient = new RuntimeBeanReference(accessTokenResponseClientRef); - } else { - accessTokenResponseClient = BeanDefinitionBuilder.rootBeanDefinition( - "org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient") - .getBeanDefinition(); + return new RuntimeBeanReference(accessTokenResponseClientRef); } - return accessTokenResponseClient; + return BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient") + .getBeanDefinition(); } BeanDefinition getDefaultAuthorizedClientRepository() { @@ -333,23 +317,23 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { } BeanDefinition getOAuth2AuthorizationRequestRedirectFilter() { - return oauth2AuthorizationRequestRedirectFilter; + return this.oauth2AuthorizationRequestRedirectFilter; } BeanDefinition getOAuth2LoginAuthenticationEntryPoint() { - return oauth2LoginAuthenticationEntryPoint; + return this.oauth2LoginAuthenticationEntryPoint; } BeanDefinition getOAuth2LoginAuthenticationProvider() { - return oauth2LoginAuthenticationProvider; + return this.oauth2LoginAuthenticationProvider; } BeanDefinition getOAuth2LoginOidcAuthenticationProvider() { - return oauth2LoginOidcAuthenticationProvider; + return this.oauth2LoginOidcAuthenticationProvider; } BeanDefinition getOAuth2LoginLinks() { - return oauth2LoginLinks; + return this.oauth2LoginLinks; } private Map getLoginEntryPoint(Element element) { @@ -364,10 +348,8 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { RequestMatcher defaultEntryPointMatcher = this.getAuthenticationEntryPointMatcher(); RequestMatcher defaultLoginPageMatcher = new AndRequestMatcher( new OrRequestMatcher(loginPageMatcher, faviconMatcher), defaultEntryPointMatcher); - RequestMatcher notXRequestedWith = new NegatedRequestMatcher( new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest")); - Element clientRegElt = clientRegList.get(0); entryPoints = new LinkedHashMap<>(); entryPoints.put( @@ -395,28 +377,26 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2LoginAuthenticationToken authorizationCodeAuthentication = (OAuth2LoginAuthenticationToken) authentication; - - // Section 3.1.2.1 Authentication Request - - // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope - // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (authorizationCodeAuthentication.getAuthorizationExchange().getAuthorizationRequest().getScopes() + if (!authorizationCodeAuthentication.getAuthorizationExchange().getAuthorizationRequest().getScopes() .contains(OidcScopes.OPENID)) { - - OAuth2Error oauth2Error = new OAuth2Error("oidc_provider_not_configured", - "An OpenID Connect Authentication Provider has not been configured. " - + "Check to ensure you include the dependency 'spring-security-oauth2-jose'.", - null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + return null; } - - return null; + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope + // value. + OAuth2Error oauth2Error = new OAuth2Error("oidc_provider_not_configured", + "An OpenID Connect Authentication Provider has not been configured. " + + "Check to ensure you include the dependency 'spring-security-oauth2-jose'.", + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } @Override public boolean supports(Class authentication) { return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication); } + } /** @@ -432,9 +412,9 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { } @SuppressWarnings({ "unchecked", "unused" }) - public Map getLoginLinks() { + Map getLoginLinks() { Iterable clientRegistrations = null; - ClientRegistrationRepository clientRegistrationRepository = context + ClientRegistrationRepository clientRegistrationRepository = this.context .getBean(ClientRegistrationRepository.class); ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class); if (type != ResolvableType.NONE && ClientRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) { @@ -443,14 +423,14 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { if (clientRegistrations == null) { return Collections.emptyMap(); } - String authorizationRequestBaseUri = DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; Map loginUrlToClientName = new HashMap<>(); - clientRegistrations.forEach(registration -> loginUrlToClientName.put( + clientRegistrations.forEach((registration) -> loginUrlToClientName.put( authorizationRequestBaseUri + "/" + registration.getRegistrationId(), registration.getClientName())); - return loginUrlToClientName; } + } + } diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java index 5068a74b45..2be6d13796 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParser.java @@ -18,6 +18,7 @@ package org.springframework.security.config.http; import java.util.List; import java.util.Map; + import javax.servlet.http.HttpServletRequest; import org.w3c.dom.Element; @@ -52,35 +53,42 @@ import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; /** - * A {@link BeanDefinitionParser} for <http>'s <oauth2-resource-server> element. + * A {@link BeanDefinitionParser} for <http>'s <oauth2-resource-server> + * element. * - * @since 5.3 * @author Josh Cummings + * @since 5.3 */ final class OAuth2ResourceServerBeanDefinitionParser implements BeanDefinitionParser { + static final String AUTHENTICATION_MANAGER_RESOLVER_REF = "authentication-manager-resolver-ref"; + static final String BEARER_TOKEN_RESOLVER_REF = "bearer-token-resolver-ref"; + static final String ENTRY_POINT_REF = "entry-point-ref"; static final String BEARER_TOKEN_RESOLVER = "bearerTokenResolver"; + static final String AUTHENTICATION_ENTRY_POINT = "authenticationEntryPoint"; private final BeanReference authenticationManager; + private final List authenticationProviders; + private final Map entryPoints; + private final Map deniedHandlers; + private final List ignoreCsrfRequestMatchers; - private final BeanDefinition authenticationEntryPoint = - new RootBeanDefinition(BearerTokenAuthenticationEntryPoint.class); - private final BeanDefinition accessDeniedHandler = - new RootBeanDefinition(BearerTokenAccessDeniedHandler.class); + private final BeanDefinition authenticationEntryPoint = new RootBeanDefinition( + BearerTokenAuthenticationEntryPoint.class); + + private final BeanDefinition accessDeniedHandler = new RootBeanDefinition(BearerTokenAccessDeniedHandler.class); OAuth2ResourceServerBeanDefinitionParser(BeanReference authenticationManager, - List authenticationProviders, - Map entryPoints, - Map deniedHandlers, - List ignoreCsrfRequestMatchers) { + List authenticationProviders, Map entryPoints, + Map deniedHandlers, List ignoreCsrfRequestMatchers) { this.authenticationManager = authenticationManager; this.authenticationProviders = authenticationProviders; this.entryPoints = entryPoints; @@ -91,44 +99,36 @@ final class OAuth2ResourceServerBeanDefinitionParser implements BeanDefinitionPa /** * Parse a <oauth2-resource-server> element and return the corresponding * {@link BearerTokenAuthenticationFilter} - * * @param oauth2ResourceServer the <oauth2-resource-server> element. * @param pc the {@link ParserContext} - * @return a {@link BeanDefinition} representing a {@link BearerTokenAuthenticationFilter} definition + * @return a {@link BeanDefinition} representing a + * {@link BearerTokenAuthenticationFilter} definition */ @Override public BeanDefinition parse(Element oauth2ResourceServer, ParserContext pc) { Element jwt = DomUtils.getChildElementByTagName(oauth2ResourceServer, Elements.JWT); Element opaqueToken = DomUtils.getChildElementByTagName(oauth2ResourceServer, Elements.OPAQUE_TOKEN); - validateConfiguration(oauth2ResourceServer, jwt, opaqueToken, pc); - if (jwt != null) { - BeanDefinition jwtAuthenticationProvider = - new JwtBeanDefinitionParser().parse(jwt, pc); - this.authenticationProviders.add(new RuntimeBeanReference - (pc.getReaderContext().registerWithGeneratedName(jwtAuthenticationProvider))); + BeanDefinition jwtAuthenticationProvider = new JwtBeanDefinitionParser().parse(jwt, pc); + this.authenticationProviders.add(new RuntimeBeanReference( + pc.getReaderContext().registerWithGeneratedName(jwtAuthenticationProvider))); } - if (opaqueToken != null) { - BeanDefinition opaqueTokenAuthenticationProvider = - new OpaqueTokenBeanDefinitionParser().parse(opaqueToken, pc); - this.authenticationProviders.add(new RuntimeBeanReference - (pc.getReaderContext().registerWithGeneratedName(opaqueTokenAuthenticationProvider))); + BeanDefinition opaqueTokenAuthenticationProvider = new OpaqueTokenBeanDefinitionParser().parse(opaqueToken, + pc); + this.authenticationProviders.add(new RuntimeBeanReference( + pc.getReaderContext().registerWithGeneratedName(opaqueTokenAuthenticationProvider))); } - BeanMetadataElement bearerTokenResolver = getBearerTokenResolver(oauth2ResourceServer); BeanDefinitionBuilder requestMatcherBuilder = BeanDefinitionBuilder .rootBeanDefinition(BearerTokenRequestMatcher.class); requestMatcherBuilder.addConstructorArgValue(bearerTokenResolver); BeanDefinition requestMatcher = requestMatcherBuilder.getBeanDefinition(); - BeanMetadataElement authenticationEntryPoint = getEntryPoint(oauth2ResourceServer); - this.entryPoints.put(requestMatcher, authenticationEntryPoint); this.deniedHandlers.put(requestMatcher, this.accessDeniedHandler); this.ignoreCsrfRequestMatchers.add(requestMatcher); - BeanDefinitionBuilder filterBuilder = BeanDefinitionBuilder .rootBeanDefinition(BearerTokenAuthenticationFilter.class); BeanMetadataElement authenticationManagerResolver = getAuthenticationManagerResolver(oauth2ResourceServer); @@ -141,23 +141,20 @@ final class OAuth2ResourceServerBeanDefinitionParser implements BeanDefinitionPa void validateConfiguration(Element oauth2ResourceServer, Element jwt, Element opaqueToken, ParserContext pc) { if (!oauth2ResourceServer.hasAttribute(AUTHENTICATION_MANAGER_RESOLVER_REF)) { if (jwt == null && opaqueToken == null) { - pc.getReaderContext().error - ("Didn't find authentication-manager-resolver-ref, , or . " + - "Please select one.", oauth2ResourceServer); + pc.getReaderContext().error("Didn't find authentication-manager-resolver-ref, " + + ", or . " + "Please select one.", oauth2ResourceServer); } return; } - if (jwt != null) { - pc.getReaderContext().error - ("Found as well as authentication-manager-resolver-ref. " + - "Please select just one.", oauth2ResourceServer); + pc.getReaderContext().error( + "Found as well as authentication-manager-resolver-ref. Please select just one.", + oauth2ResourceServer); } - if (opaqueToken != null) { - pc.getReaderContext().error - ("Found as well as authentication-manager-resolver-ref. " + - "Please select just one.", oauth2ResourceServer); + pc.getReaderContext().error( + "Found as well as authentication-manager-resolver-ref. Please select just one.", + oauth2ResourceServer); } } @@ -176,183 +173,182 @@ final class OAuth2ResourceServerBeanDefinitionParser implements BeanDefinitionPa String bearerTokenResolverRef = element.getAttribute(BEARER_TOKEN_RESOLVER_REF); if (StringUtils.isEmpty(bearerTokenResolverRef)) { return new RootBeanDefinition(DefaultBearerTokenResolver.class); - } else { - return new RuntimeBeanReference(bearerTokenResolverRef); } + return new RuntimeBeanReference(bearerTokenResolverRef); } BeanMetadataElement getEntryPoint(Element element) { String entryPointRef = element.getAttribute(ENTRY_POINT_REF); if (StringUtils.isEmpty(entryPointRef)) { return this.authenticationEntryPoint; - } else { - return new RuntimeBeanReference(entryPointRef); } - } -} - -final class JwtBeanDefinitionParser implements BeanDefinitionParser { - static final String DECODER_REF = "decoder-ref"; - static final String JWK_SET_URI = "jwk-set-uri"; - static final String JWT_AUTHENTICATION_CONVERTER_REF = "jwt-authentication-converter-ref"; - static final String JWT_AUTHENTICATION_CONVERTER = "jwtAuthenticationConverter"; - - @Override - public BeanDefinition parse(Element element, ParserContext pc) { - validateConfiguration(element, pc); - - BeanDefinitionBuilder jwtProviderBuilder = - BeanDefinitionBuilder.rootBeanDefinition(JwtAuthenticationProvider.class); - jwtProviderBuilder.addConstructorArgValue(getDecoder(element)); - jwtProviderBuilder.addPropertyValue(JWT_AUTHENTICATION_CONVERTER, getJwtAuthenticationConverter(element)); - - return jwtProviderBuilder.getBeanDefinition(); + return new RuntimeBeanReference(entryPointRef); } - void validateConfiguration(Element element, ParserContext pc) { - boolean usesDecoder = element.hasAttribute(DECODER_REF); - boolean usesJwkSetUri = element.hasAttribute(JWK_SET_URI); + static final class JwtBeanDefinitionParser implements BeanDefinitionParser { - if (usesDecoder == usesJwkSetUri) { - pc.getReaderContext().error - ("Please specify either decoder-ref or jwk-set-uri.", element); - } - } + static final String DECODER_REF = "decoder-ref"; - Object getDecoder(Element element) { - String decoderRef = element.getAttribute(DECODER_REF); - if (!StringUtils.isEmpty(decoderRef)) { - return new RuntimeBeanReference(decoderRef); + static final String JWK_SET_URI = "jwk-set-uri"; + + static final String JWT_AUTHENTICATION_CONVERTER_REF = "jwt-authentication-converter-ref"; + + static final String JWT_AUTHENTICATION_CONVERTER = "jwtAuthenticationConverter"; + + JwtBeanDefinitionParser() { } - BeanDefinitionBuilder builder = BeanDefinitionBuilder - .rootBeanDefinition(NimbusJwtDecoderJwkSetUriFactoryBean.class); - builder.addConstructorArgValue(element.getAttribute(JWK_SET_URI)); - return builder.getBeanDefinition(); - } - - Object getJwtAuthenticationConverter(Element element) { - String jwtDecoderRef = element.getAttribute(JWT_AUTHENTICATION_CONVERTER_REF); - if (!StringUtils.isEmpty(jwtDecoderRef)) { - return new RuntimeBeanReference(jwtDecoderRef); + @Override + public BeanDefinition parse(Element element, ParserContext pc) { + validateConfiguration(element, pc); + BeanDefinitionBuilder jwtProviderBuilder = BeanDefinitionBuilder + .rootBeanDefinition(JwtAuthenticationProvider.class); + jwtProviderBuilder.addConstructorArgValue(getDecoder(element)); + jwtProviderBuilder.addPropertyValue(JWT_AUTHENTICATION_CONVERTER, getJwtAuthenticationConverter(element)); + return jwtProviderBuilder.getBeanDefinition(); } - return new JwtAuthenticationConverter(); - } - - JwtBeanDefinitionParser() {} -} - -final class OpaqueTokenBeanDefinitionParser implements BeanDefinitionParser { - static final String INTROSPECTOR_REF = "introspector-ref"; - static final String INTROSPECTION_URI = "introspection-uri"; - static final String CLIENT_ID = "client-id"; - static final String CLIENT_SECRET = "client-secret"; - - @Override - public BeanDefinition parse(Element element, ParserContext pc) { - validateConfiguration(element, pc); - - BeanMetadataElement introspector = getIntrospector(element); - BeanDefinitionBuilder opaqueTokenProviderBuilder = - BeanDefinitionBuilder.rootBeanDefinition(OpaqueTokenAuthenticationProvider.class); - opaqueTokenProviderBuilder.addConstructorArgValue(introspector); - - return opaqueTokenProviderBuilder.getBeanDefinition(); - } - - void validateConfiguration(Element element, ParserContext pc) { - boolean usesIntrospector = element.hasAttribute(INTROSPECTOR_REF); - boolean usesEndpoint = element.hasAttribute(INTROSPECTION_URI) || - element.hasAttribute(CLIENT_ID) || - element.hasAttribute(CLIENT_SECRET); - - if (usesIntrospector == usesEndpoint) { - pc.getReaderContext().error - ("Please specify either introspector-ref or all of " + - "introspection-uri, client-id, and client-secret.", element); - return; - } - - if (usesEndpoint) { - if (!(element.hasAttribute(INTROSPECTION_URI) && - element.hasAttribute(CLIENT_ID) && - element.hasAttribute(CLIENT_SECRET))) { - pc.getReaderContext().error - ("Please specify introspection-uri, client-id, and client-secret together", element); + void validateConfiguration(Element element, ParserContext pc) { + boolean usesDecoder = element.hasAttribute(DECODER_REF); + boolean usesJwkSetUri = element.hasAttribute(JWK_SET_URI); + if (usesDecoder == usesJwkSetUri) { + pc.getReaderContext().error("Please specify either decoder-ref or jwk-set-uri.", element); } } - } - BeanMetadataElement getIntrospector(Element element) { - String introspectorRef = element.getAttribute(INTROSPECTOR_REF); - if (!StringUtils.isEmpty(introspectorRef)) { - return new RuntimeBeanReference(introspectorRef); + Object getDecoder(Element element) { + String decoderRef = element.getAttribute(DECODER_REF); + if (!StringUtils.isEmpty(decoderRef)) { + return new RuntimeBeanReference(decoderRef); + } + BeanDefinitionBuilder builder = BeanDefinitionBuilder + .rootBeanDefinition(NimbusJwtDecoderJwkSetUriFactoryBean.class); + builder.addConstructorArgValue(element.getAttribute(JWK_SET_URI)); + return builder.getBeanDefinition(); } - String introspectionUri = element.getAttribute(INTROSPECTION_URI); - String clientId = element.getAttribute(CLIENT_ID); - String clientSecret = element.getAttribute(CLIENT_SECRET); - - BeanDefinitionBuilder introspectorBuilder = BeanDefinitionBuilder - .rootBeanDefinition(NimbusOpaqueTokenIntrospector.class); - introspectorBuilder.addConstructorArgValue(introspectionUri); - introspectorBuilder.addConstructorArgValue(clientId); - introspectorBuilder.addConstructorArgValue(clientSecret); - - return introspectorBuilder.getBeanDefinition(); - } - - OpaqueTokenBeanDefinitionParser() {} -} - -final class StaticAuthenticationManagerResolver implements - AuthenticationManagerResolver { - private final AuthenticationManager authenticationManager; - - StaticAuthenticationManagerResolver(AuthenticationManager authenticationManager) { - this.authenticationManager = authenticationManager; - } - - @Override - public AuthenticationManager resolve(HttpServletRequest context) { - return this.authenticationManager; - } -} - -final class NimbusJwtDecoderJwkSetUriFactoryBean implements FactoryBean { - private final String jwkSetUri; - - NimbusJwtDecoderJwkSetUriFactoryBean(String jwkSetUri) { - this.jwkSetUri = jwkSetUri; - } - - @Override - public JwtDecoder getObject() { - return NimbusJwtDecoder.withJwkSetUri(this.jwkSetUri).build(); - } - - @Override - public Class getObjectType() { - return JwtDecoder.class; - } -} - -final class BearerTokenRequestMatcher implements RequestMatcher { - private final BearerTokenResolver bearerTokenResolver; - - BearerTokenRequestMatcher(BearerTokenResolver bearerTokenResolver) { - Assert.notNull(bearerTokenResolver, "bearerTokenResolver cannot be null"); - this.bearerTokenResolver = bearerTokenResolver; - } - - @Override - public boolean matches(HttpServletRequest request) { - try { - return this.bearerTokenResolver.resolve(request) != null; - } catch (OAuth2AuthenticationException e) { - return false; + Object getJwtAuthenticationConverter(Element element) { + String jwtDecoderRef = element.getAttribute(JWT_AUTHENTICATION_CONVERTER_REF); + return (!StringUtils.isEmpty(jwtDecoderRef)) ? new RuntimeBeanReference(jwtDecoderRef) + : new JwtAuthenticationConverter(); } - } -} + } + + static final class OpaqueTokenBeanDefinitionParser implements BeanDefinitionParser { + + static final String INTROSPECTOR_REF = "introspector-ref"; + + static final String INTROSPECTION_URI = "introspection-uri"; + + static final String CLIENT_ID = "client-id"; + + static final String CLIENT_SECRET = "client-secret"; + + OpaqueTokenBeanDefinitionParser() { + } + + @Override + public BeanDefinition parse(Element element, ParserContext pc) { + validateConfiguration(element, pc); + BeanMetadataElement introspector = getIntrospector(element); + BeanDefinitionBuilder opaqueTokenProviderBuilder = BeanDefinitionBuilder + .rootBeanDefinition(OpaqueTokenAuthenticationProvider.class); + opaqueTokenProviderBuilder.addConstructorArgValue(introspector); + return opaqueTokenProviderBuilder.getBeanDefinition(); + } + + void validateConfiguration(Element element, ParserContext pc) { + boolean usesIntrospector = element.hasAttribute(INTROSPECTOR_REF); + boolean usesEndpoint = element.hasAttribute(INTROSPECTION_URI) || element.hasAttribute(CLIENT_ID) + || element.hasAttribute(CLIENT_SECRET); + if (usesIntrospector == usesEndpoint) { + pc.getReaderContext().error("Please specify either introspector-ref or all of " + + "introspection-uri, client-id, and client-secret.", element); + return; + } + if (usesEndpoint) { + if (!(element.hasAttribute(INTROSPECTION_URI) && element.hasAttribute(CLIENT_ID) + && element.hasAttribute(CLIENT_SECRET))) { + pc.getReaderContext() + .error("Please specify introspection-uri, client-id, and client-secret together", element); + } + } + } + + BeanMetadataElement getIntrospector(Element element) { + String introspectorRef = element.getAttribute(INTROSPECTOR_REF); + if (!StringUtils.isEmpty(introspectorRef)) { + return new RuntimeBeanReference(introspectorRef); + } + String introspectionUri = element.getAttribute(INTROSPECTION_URI); + String clientId = element.getAttribute(CLIENT_ID); + String clientSecret = element.getAttribute(CLIENT_SECRET); + BeanDefinitionBuilder introspectorBuilder = BeanDefinitionBuilder + .rootBeanDefinition(NimbusOpaqueTokenIntrospector.class); + introspectorBuilder.addConstructorArgValue(introspectionUri); + introspectorBuilder.addConstructorArgValue(clientId); + introspectorBuilder.addConstructorArgValue(clientSecret); + return introspectorBuilder.getBeanDefinition(); + } + + } + + static final class StaticAuthenticationManagerResolver + implements AuthenticationManagerResolver { + + private final AuthenticationManager authenticationManager; + + StaticAuthenticationManagerResolver(AuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + } + + @Override + public AuthenticationManager resolve(HttpServletRequest context) { + return this.authenticationManager; + } + + } + + static final class NimbusJwtDecoderJwkSetUriFactoryBean implements FactoryBean { + + private final String jwkSetUri; + + NimbusJwtDecoderJwkSetUriFactoryBean(String jwkSetUri) { + this.jwkSetUri = jwkSetUri; + } + + @Override + public JwtDecoder getObject() { + return NimbusJwtDecoder.withJwkSetUri(this.jwkSetUri).build(); + } + + @Override + public Class getObjectType() { + return JwtDecoder.class; + } + + } + + static final class BearerTokenRequestMatcher implements RequestMatcher { + + private final BearerTokenResolver bearerTokenResolver; + + BearerTokenRequestMatcher(BearerTokenResolver bearerTokenResolver) { + Assert.notNull(bearerTokenResolver, "bearerTokenResolver cannot be null"); + this.bearerTokenResolver = bearerTokenResolver; + } + + @Override + public boolean matches(HttpServletRequest request) { + try { + return this.bearerTokenResolver.resolve(request) != null; + } + catch (OAuth2AuthenticationException ex) { + return false; + } + } + + } + +} diff --git a/config/src/main/java/org/springframework/security/config/http/OrderDecorator.java b/config/src/main/java/org/springframework/security/config/http/OrderDecorator.java new file mode 100644 index 0000000000..3b14abaf8d --- /dev/null +++ b/config/src/main/java/org/springframework/security/config/http/OrderDecorator.java @@ -0,0 +1,53 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.http; + +import org.springframework.beans.BeanMetadataElement; +import org.springframework.core.Ordered; + +/** + * Wrapper to provide ordering to a {@link BeanMetadataElement}. + * + * @author Rob Winch + */ +class OrderDecorator implements Ordered { + + final BeanMetadataElement bean; + + final int order; + + OrderDecorator(BeanMetadataElement bean, SecurityFilters filterOrder) { + this.bean = bean; + this.order = filterOrder.getOrder(); + } + + OrderDecorator(BeanMetadataElement bean, int order) { + this.bean = bean; + this.order = order; + } + + @Override + public int getOrder() { + return this.order; + } + + @Override + public String toString() { + return this.bean + ", order = " + this.order; + } + +} diff --git a/config/src/main/java/org/springframework/security/config/http/PortMappingsBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/PortMappingsBeanDefinitionParser.java index 4dd13b82d9..dff5ebfa5b 100644 --- a/config/src/main/java/org/springframework/security/config/http/PortMappingsBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/PortMappingsBeanDefinitionParser.java @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.List; import java.util.Map; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.ManagedMap; import org.springframework.beans.factory.support.RootBeanDefinition; @@ -27,7 +30,6 @@ import org.springframework.security.config.Elements; import org.springframework.security.web.PortMapperImpl; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** * Parses a port-mappings element, producing a single @@ -36,44 +38,36 @@ import org.w3c.dom.Element; * @author Luke Taylor */ class PortMappingsBeanDefinitionParser implements BeanDefinitionParser { + public static final String ATT_HTTP_PORT = "http"; + public static final String ATT_HTTPS_PORT = "https"; + @Override @SuppressWarnings("unchecked") public BeanDefinition parse(Element element, ParserContext parserContext) { RootBeanDefinition portMapper = new RootBeanDefinition(PortMapperImpl.class); portMapper.setSource(parserContext.extractSource(element)); - if (element != null) { - List mappingElts = DomUtils.getChildElementsByTagName(element, - Elements.PORT_MAPPING); + List mappingElts = DomUtils.getChildElementsByTagName(element, Elements.PORT_MAPPING); if (mappingElts.isEmpty()) { - parserContext.getReaderContext().error( - "No port-mapping child elements specified", element); + parserContext.getReaderContext().error("No port-mapping child elements specified", element); } - Map mappings = new ManagedMap(); - for (Element elt : mappingElts) { String httpPort = elt.getAttribute(ATT_HTTP_PORT); String httpsPort = elt.getAttribute(ATT_HTTPS_PORT); - if (!StringUtils.hasText(httpPort)) { - parserContext.getReaderContext().error( - "No http port supplied in port mapping", elt); + parserContext.getReaderContext().error("No http port supplied in port mapping", elt); } - if (!StringUtils.hasText(httpsPort)) { - parserContext.getReaderContext().error( - "No https port supplied in port mapping", elt); + parserContext.getReaderContext().error("No https port supplied in port mapping", elt); } - mappings.put(httpPort, httpsPort); } - portMapper.getPropertyValues().addPropertyValue("portMappings", mappings); } - return portMapper; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java index 0fbb14c9bc..85ac360cf6 100644 --- a/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/RememberMeBeanDefinitionParser.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; import org.springframework.beans.factory.config.RuntimeBeanReference; @@ -32,7 +35,6 @@ import org.springframework.security.web.authentication.rememberme.PersistentToke import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter; import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; /** * @author Luke Taylor @@ -41,20 +43,33 @@ import org.w3c.dom.Element; * @author Oliver Becker */ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { + static final String ATT_DATA_SOURCE = "data-source-ref"; + static final String ATT_SERVICES_REF = "services-ref"; + static final String ATT_SERVICES_ALIAS = "services-alias"; + static final String ATT_TOKEN_REPOSITORY = "token-repository-ref"; + static final String ATT_USER_SERVICE_REF = "user-service-ref"; + static final String ATT_SUCCESS_HANDLER_REF = "authentication-success-handler-ref"; + static final String ATT_TOKEN_VALIDITY = "token-validity-seconds"; + static final String ATT_SECURE_COOKIE = "use-secure-cookie"; + static final String ATT_FORM_REMEMBERME_PARAMETER = "remember-me-parameter"; + static final String ATT_REMEMBERME_COOKIE = "remember-me-cookie"; protected final Log logger = LogFactory.getLog(getClass()); + private final String key; + private final BeanReference authenticationManager; + private String rememberMeServicesId; RememberMeBeanDefinitionParser(String key, BeanReference authenticationManager) { @@ -62,11 +77,11 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { this.authenticationManager = authenticationManager; } + @Override public BeanDefinition parse(Element element, ParserContext pc) { - CompositeComponentDefinition compositeDef = new CompositeComponentDefinition( - element.getTagName(), pc.extractSource(element)); + CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), + pc.extractSource(element)); pc.pushContainingComponent(compositeDef); - String tokenRepository = element.getAttribute(ATT_TOKEN_REPOSITORY); String dataSource = element.getAttribute(ATT_DATA_SOURCE); String userServiceRef = element.getAttribute(ATT_USER_SERVICE_REF); @@ -77,9 +92,7 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { String remembermeParameter = element.getAttribute(ATT_FORM_REMEMBERME_PARAMETER); String remembermeCookie = element.getAttribute(ATT_REMEMBERME_COOKIE); Object source = pc.extractSource(element); - RootBeanDefinition services = null; - boolean dataSourceSet = StringUtils.hasText(dataSource); boolean tokenRepoSet = StringUtils.hasText(tokenRepository); boolean servicesRefSet = StringUtils.hasText(rememberMeServicesRef); @@ -88,85 +101,62 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { boolean tokenValiditySet = StringUtils.hasText(tokenValiditySeconds); boolean remembermeParameterSet = StringUtils.hasText(remembermeParameter); boolean remembermeCookieSet = StringUtils.hasText(remembermeCookie); - - if (servicesRefSet - && (dataSourceSet || tokenRepoSet || userServiceSet || tokenValiditySet - || useSecureCookieSet || remembermeParameterSet || remembermeCookieSet)) { - pc.getReaderContext().error( - ATT_SERVICES_REF + " can't be used in combination with attributes " - + ATT_TOKEN_REPOSITORY + "," + ATT_DATA_SOURCE + ", " - + ATT_USER_SERVICE_REF + ", " + ATT_TOKEN_VALIDITY + ", " - + ATT_SECURE_COOKIE + ", " + ATT_FORM_REMEMBERME_PARAMETER - + " or " + ATT_REMEMBERME_COOKIE, source); + if (servicesRefSet && (dataSourceSet || tokenRepoSet || userServiceSet || tokenValiditySet || useSecureCookieSet + || remembermeParameterSet || remembermeCookieSet)) { + pc.getReaderContext() + .error(ATT_SERVICES_REF + " can't be used in combination with attributes " + ATT_TOKEN_REPOSITORY + + "," + ATT_DATA_SOURCE + ", " + ATT_USER_SERVICE_REF + ", " + ATT_TOKEN_VALIDITY + ", " + + ATT_SECURE_COOKIE + ", " + ATT_FORM_REMEMBERME_PARAMETER + " or " + ATT_REMEMBERME_COOKIE, + source); } - if (dataSourceSet && tokenRepoSet) { - pc.getReaderContext().error( - "Specify " + ATT_TOKEN_REPOSITORY + " or " + ATT_DATA_SOURCE - + " but not both", source); + pc.getReaderContext().error("Specify " + ATT_TOKEN_REPOSITORY + " or " + ATT_DATA_SOURCE + " but not both", + source); } - boolean isPersistent = dataSourceSet | tokenRepoSet; - if (isPersistent) { Object tokenRepo; - services = new RootBeanDefinition( - PersistentTokenBasedRememberMeServices.class); - + services = new RootBeanDefinition(PersistentTokenBasedRememberMeServices.class); if (tokenRepoSet) { tokenRepo = new RuntimeBeanReference(tokenRepository); } else { tokenRepo = new RootBeanDefinition(JdbcTokenRepositoryImpl.class); - ((BeanDefinition) tokenRepo).getPropertyValues().addPropertyValue( - "dataSource", new RuntimeBeanReference(dataSource)); + ((BeanDefinition) tokenRepo).getPropertyValues().addPropertyValue("dataSource", + new RuntimeBeanReference(dataSource)); } services.getConstructorArgumentValues().addIndexedArgumentValue(2, tokenRepo); } else if (!servicesRefSet) { services = new RootBeanDefinition(TokenBasedRememberMeServices.class); } - String servicesName; - if (services != null) { RootBeanDefinition uds = new RootBeanDefinition(); uds.setFactoryBeanName(BeanIds.USER_DETAILS_SERVICE_FACTORY); uds.setFactoryMethodName("cachingUserDetailsService"); uds.getConstructorArgumentValues().addGenericArgumentValue(userServiceRef); - - services.getConstructorArgumentValues().addGenericArgumentValue(key); + services.getConstructorArgumentValues().addGenericArgumentValue(this.key); services.getConstructorArgumentValues().addGenericArgumentValue(uds); // tokenRepo is already added if it is a // PersistentTokenBasedRememberMeServices - if (useSecureCookieSet) { - services.getPropertyValues().addPropertyValue("useSecureCookie", - Boolean.valueOf(useSecureCookie)); + services.getPropertyValues().addPropertyValue("useSecureCookie", Boolean.valueOf(useSecureCookie)); } - if (tokenValiditySet) { boolean isTokenValidityNegative = tokenValiditySeconds.startsWith("-"); if (isTokenValidityNegative && isPersistent) { - pc.getReaderContext().error( - ATT_TOKEN_VALIDITY + " cannot be negative if using" - + " a persistent remember-me token repository", - source); + pc.getReaderContext().error(ATT_TOKEN_VALIDITY + " cannot be negative if using" + + " a persistent remember-me token repository", source); } - services.getPropertyValues().addPropertyValue("tokenValiditySeconds", - tokenValiditySeconds); + services.getPropertyValues().addPropertyValue("tokenValiditySeconds", tokenValiditySeconds); } - if (remembermeParameterSet) { - services.getPropertyValues().addPropertyValue("parameter", - remembermeParameter); + services.getPropertyValues().addPropertyValue("parameter", remembermeParameter); } - if (remembermeCookieSet) { - services.getPropertyValues().addPropertyValue("cookieName", - remembermeCookie); + services.getPropertyValues().addPropertyValue("cookieName", remembermeCookie); } - services.setSource(source); servicesName = pc.getReaderContext().generateBeanName(services); pc.registerBeanComponent(new BeanComponentDefinition(services, servicesName)); @@ -174,31 +164,23 @@ class RememberMeBeanDefinitionParser implements BeanDefinitionParser { else { servicesName = rememberMeServicesRef; } - if (StringUtils.hasText(element.getAttribute(ATT_SERVICES_ALIAS))) { - pc.getRegistry().registerAlias(servicesName, - element.getAttribute(ATT_SERVICES_ALIAS)); + pc.getRegistry().registerAlias(servicesName, element.getAttribute(ATT_SERVICES_ALIAS)); } - this.rememberMeServicesId = servicesName; - - BeanDefinitionBuilder filter = BeanDefinitionBuilder - .rootBeanDefinition(RememberMeAuthenticationFilter.class); + BeanDefinitionBuilder filter = BeanDefinitionBuilder.rootBeanDefinition(RememberMeAuthenticationFilter.class); filter.getRawBeanDefinition().setSource(source); - if (StringUtils.hasText(successHandlerRef)) { filter.addPropertyReference("authenticationSuccessHandler", successHandlerRef); } - - filter.addConstructorArgValue(authenticationManager); + filter.addConstructorArgValue(this.authenticationManager); filter.addConstructorArgReference(servicesName); - pc.popAndRegisterContainingComponent(); - return filter.getBeanDefinition(); } String getRememberMeServicesId() { return this.rememberMeServicesId; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/SecurityFilters.java b/config/src/main/java/org/springframework/security/config/http/SecurityFilters.java index ee6f366d48..d2c34c9ac5 100644 --- a/config/src/main/java/org/springframework/security/config/http/SecurityFilters.java +++ b/config/src/main/java/org/springframework/security/config/http/SecurityFilters.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.http; -import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter; +package org.springframework.security.config.http; /** * Stores the default order numbers of all Spring Security filters for use in @@ -26,43 +25,77 @@ import org.springframework.security.web.context.request.async.WebAsyncManagerInt */ enum SecurityFilters { + FIRST(Integer.MIN_VALUE), + CHANNEL_FILTER, + SECURITY_CONTEXT_FILTER, + CONCURRENT_SESSION_FILTER, - WEB_ASYNC_MANAGER_FILTER /** {@link WebAsyncManagerIntegrationFilter} */, - HEADERS_FILTER, CORS_FILTER, + + WEB_ASYNC_MANAGER_FILTER, + + HEADERS_FILTER, + + CORS_FILTER, + CSRF_FILTER, + LOGOUT_FILTER, + OAUTH2_AUTHORIZATION_REQUEST_FILTER, + X509_FILTER, + PRE_AUTH_FILTER, + CAS_FILTER, + OAUTH2_LOGIN_FILTER, + FORM_LOGIN_FILTER, + OPENID_FILTER, + LOGIN_PAGE_FILTER, + LOGOUT_PAGE_FILTER, + DIGEST_AUTH_FILTER, + BEARER_TOKEN_AUTH_FILTER, + BASIC_AUTH_FILTER, + REQUEST_CACHE_FILTER, + SERVLET_API_SUPPORT_FILTER, + JAAS_API_SUPPORT_FILTER, + REMEMBER_ME_FILTER, + ANONYMOUS_FILTER, + OAUTH2_AUTHORIZATION_CODE_GRANT_FILTER, + SESSION_MANAGEMENT_FILTER, + EXCEPTION_TRANSLATION_FILTER, + FILTER_SECURITY_INTERCEPTOR, + SWITCH_USER_FILTER, + LAST(Integer.MAX_VALUE); private static final int INTERVAL = 100; + private final int order; SecurityFilters() { - order = ordinal() * INTERVAL; + this.order = ordinal() * INTERVAL; } SecurityFilters(int order) { @@ -70,6 +103,7 @@ enum SecurityFilters { } public int getOrder() { - return order; + return this.order; } + } diff --git a/config/src/main/java/org/springframework/security/config/http/SessionCreationPolicy.java b/config/src/main/java/org/springframework/security/config/http/SessionCreationPolicy.java index fc1f2636b2..74beef711d 100644 --- a/config/src/main/java/org/springframework/security/config/http/SessionCreationPolicy.java +++ b/config/src/main/java/org/springframework/security/config/http/SessionCreationPolicy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import javax.servlet.http.HttpSession; @@ -26,18 +27,27 @@ import org.springframework.security.core.context.SecurityContext; * @since 3.1 */ public enum SessionCreationPolicy { - /** Always create an {@link HttpSession} */ + + /** + * Always create an {@link HttpSession} + */ ALWAYS, + /** * Spring Security will never create an {@link HttpSession}, but will use the * {@link HttpSession} if it already exists */ NEVER, - /** Spring Security will only create an {@link HttpSession} if required */ + + /** + * Spring Security will only create an {@link HttpSession} if required + */ IF_REQUIRED, + /** * Spring Security will never create an {@link HttpSession} and it will never use it * to obtain the {@link SecurityContext} */ STATELESS + } diff --git a/config/src/main/java/org/springframework/security/config/http/UserDetailsServiceFactoryBean.java b/config/src/main/java/org/springframework/security/config/http/UserDetailsServiceFactoryBean.java index d7e18b7dae..851416c3f8 100644 --- a/config/src/main/java/org/springframework/security/config/http/UserDetailsServiceFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/http/UserDetailsServiceFactoryBean.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.Map; @@ -24,8 +25,8 @@ import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.ApplicationContextException; -import org.springframework.security.config.authentication.AbstractUserDetailsServiceBeanDefinitionParser; import org.springframework.security.authentication.CachingUserDetailsService; +import org.springframework.security.config.authentication.AbstractUserDetailsServiceBeanDefinitionParser; import org.springframework.security.core.userdetails.AuthenticationUserDetailsService; import org.springframework.security.core.userdetails.UserDetailsByNameServiceWrapper; import org.springframework.security.core.userdetails.UserDetailsService; @@ -45,8 +46,7 @@ public class UserDetailsServiceFactoryBean implements ApplicationContextAware { if (!StringUtils.hasText(id)) { return getUserDetailsService(); } - - return (UserDetailsService) beanFactory.getBean(id); + return (UserDetailsService) this.beanFactory.getBean(id); } UserDetailsService cachingUserDetailsService(String id) { @@ -54,54 +54,43 @@ public class UserDetailsServiceFactoryBean implements ApplicationContextAware { return getUserDetailsService(); } // Overwrite with the caching version if available - String cachingId = id - + AbstractUserDetailsServiceBeanDefinitionParser.CACHING_SUFFIX; - - if (beanFactory.containsBeanDefinition(cachingId)) { - return (UserDetailsService) beanFactory.getBean(cachingId); + String cachingId = id + AbstractUserDetailsServiceBeanDefinitionParser.CACHING_SUFFIX; + if (this.beanFactory.containsBeanDefinition(cachingId)) { + return (UserDetailsService) this.beanFactory.getBean(cachingId); } - - return (UserDetailsService) beanFactory.getBean(id); + return (UserDetailsService) this.beanFactory.getBean(id); } @SuppressWarnings("unchecked") AuthenticationUserDetailsService authenticationUserDetailsService(String name) { UserDetailsService uds; - if (!StringUtils.hasText(name)) { Map beans = getBeansOfType(AuthenticationUserDetailsService.class); - if (!beans.isEmpty()) { if (beans.size() > 1) { - throw new ApplicationContextException( - "More than one AuthenticationUserDetailsService registered." - + " Please use a specific Id reference."); + throw new ApplicationContextException("More than one AuthenticationUserDetailsService registered." + + " Please use a specific Id reference."); } return (AuthenticationUserDetailsService) beans.values().toArray()[0]; } - uds = getUserDetailsService(); } else { - Object bean = beanFactory.getBean(name); - + Object bean = this.beanFactory.getBean(name); if (bean instanceof AuthenticationUserDetailsService) { return (AuthenticationUserDetailsService) bean; } else if (bean instanceof UserDetailsService) { uds = cachingUserDetailsService(name); - if (uds == null) { uds = (UserDetailsService) bean; } } else { - throw new ApplicationContextException("Bean '" + name - + "' must be a UserDetailsService or an" - + " AuthenticationUserDetailsService"); + throw new ApplicationContextException( + "Bean '" + name + "' must be a UserDetailsService or an" + " AuthenticationUserDetailsService"); } } - return new UserDetailsByNameServiceWrapper(uds); } @@ -112,35 +101,29 @@ public class UserDetailsServiceFactoryBean implements ApplicationContextAware { */ private UserDetailsService getUserDetailsService() { Map beans = getBeansOfType(CachingUserDetailsService.class); - if (beans.size() == 0) { beans = getBeansOfType(UserDetailsService.class); } - if (beans.size() == 0) { throw new ApplicationContextException("No UserDetailsService registered."); - } - else if (beans.size() > 1) { - throw new ApplicationContextException( - "More than one UserDetailsService registered. Please " - + "use a specific Id reference in or elements."); + if (beans.size() > 1) { + throw new ApplicationContextException("More than one UserDetailsService registered. Please " + + "use a specific Id reference in or elements."); } - return (UserDetailsService) beans.values().toArray()[0]; } - public void setApplicationContext(ApplicationContext beanFactory) - throws BeansException { + @Override + public void setApplicationContext(ApplicationContext beanFactory) throws BeansException { this.beanFactory = beanFactory; } private Map getBeansOfType(Class type) { - Map beans = beanFactory.getBeansOfType(type); - + Map beans = this.beanFactory.getBeansOfType(type); // Check ancestor bean factories if they exist and the current one has none of the // required type - BeanFactory parent = beanFactory.getParentBeanFactory(); + BeanFactory parent = this.beanFactory.getParentBeanFactory(); while (parent != null && beans.size() == 0) { if (parent instanceof ListableBeanFactory) { beans = ((ListableBeanFactory) parent).getBeansOfType(type); @@ -152,7 +135,6 @@ public class UserDetailsServiceFactoryBean implements ApplicationContextAware { break; } } - return beans; } diff --git a/config/src/main/java/org/springframework/security/config/http/WebConfigUtils.java b/config/src/main/java/org/springframework/security/config/http/WebConfigUtils.java index ea53d9dc82..64aa026c36 100644 --- a/config/src/main/java/org/springframework/security/config/http/WebConfigUtils.java +++ b/config/src/main/java/org/springframework/security/config/http/WebConfigUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.springframework.beans.factory.xml.ParserContext; @@ -26,17 +27,18 @@ import org.springframework.util.StringUtils; * @author Luke Taylor * @author Ben Alex */ -abstract class WebConfigUtils { +final class WebConfigUtils { - public static int countNonEmpty(String[] objects) { + private WebConfigUtils() { + } + + static int countNonEmpty(String[] objects) { int nonNulls = 0; - for (String object : objects) { if (StringUtils.hasText(object)) { nonNulls++; } } - return nonNulls; } @@ -46,13 +48,11 @@ abstract class WebConfigUtils { * SpEL), "/" or "http" it will raise an error. */ static void validateHttpRedirect(String url, ParserContext pc, Object source) { - if (!StringUtils.hasText(url) || UrlUtils.isValidRedirectUrl(url) - || url.startsWith("$") || url.startsWith("#")) { + if (!StringUtils.hasText(url) || UrlUtils.isValidRedirectUrl(url) || url.startsWith("$") + || url.startsWith("#")) { return; } - pc.getReaderContext().warning( - url + " is not a valid redirect URL (must start with '/' or http(s))", - source); + pc.getReaderContext().warning(url + " is not a valid redirect URL (must start with '/' or http(s))", source); } } diff --git a/config/src/main/java/org/springframework/security/config/http/package-info.java b/config/src/main/java/org/springframework/security/config/http/package-info.java index 654cd0d372..4a02a0ebda 100644 --- a/config/src/main/java/org/springframework/security/config/http/package-info.java +++ b/config/src/main/java/org/springframework/security/config/http/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Parsing of the <http> namespace element. */ package org.springframework.security.config.http; - diff --git a/config/src/main/java/org/springframework/security/config/ldap/ContextSourceSettingPostProcessor.java b/config/src/main/java/org/springframework/security/config/ldap/ContextSourceSettingPostProcessor.java index a55523d04f..1a7b5e72b7 100644 --- a/config/src/main/java/org/springframework/security/config/ldap/ContextSourceSettingPostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/ldap/ContextSourceSettingPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.ldap; import org.springframework.beans.BeansException; @@ -26,15 +27,16 @@ import org.springframework.util.ClassUtils; /** * Checks for the presence of a ContextSource instance. Also supplies the standard - * reference to any unconfigured or - * beans. This is necessary in cases where the user has given the server a specific Id, - * but hasn't used the server-ref attribute to link this to the other ldap definitions. - * See SEC-799. + * reference to any unconfigured <ldap-authentication-provider> or + * <ldap-user-service> beans. This is necessary in cases where the user has given + * the server a specific Id, but hasn't used the server-ref attribute to link this to the + * other ldap definitions. See SEC-799. * * @author Luke Taylor * @since 3.0 */ -class ContextSourceSettingPostProcessor implements BeanFactoryPostProcessor, Ordered { +public class ContextSourceSettingPostProcessor implements BeanFactoryPostProcessor, Ordered { + private static final String REQUIRED_CONTEXT_SOURCE_CLASS_NAME = "org.springframework.ldap.core.support.BaseLdapPathContextSource"; /** @@ -43,51 +45,46 @@ class ContextSourceSettingPostProcessor implements BeanFactoryPostProcessor, Ord */ private boolean defaultNameRequired; - public void postProcessBeanFactory(ConfigurableListableBeanFactory bf) - throws BeansException { - Class contextSourceClass; - - try { - contextSourceClass = ClassUtils.forName(REQUIRED_CONTEXT_SOURCE_CLASS_NAME, - ClassUtils.getDefaultClassLoader()); - } - catch (ClassNotFoundException e) { - throw new ApplicationContextException( - "Couldn't locate: " - + REQUIRED_CONTEXT_SOURCE_CLASS_NAME - + ". " - + " If you are using LDAP with Spring Security, please ensure that you include the spring-ldap " - + "jar file in your application", e); - } + ContextSourceSettingPostProcessor() { + } + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory bf) throws BeansException { + Class contextSourceClass = getContextSourceClass(); String[] sources = bf.getBeanNamesForType(contextSourceClass, false, false); - if (sources.length == 0) { - throw new ApplicationContextException( - "No BaseLdapPathContextSource instances found. Have you " - + "added an <" + Elements.LDAP_SERVER - + " /> element to your application context? If you have " - + "declared an explicit bean, do not use lazy-init"); + throw new ApplicationContextException("No BaseLdapPathContextSource instances found. Have you " + + "added an <" + Elements.LDAP_SERVER + " /> element to your application context? If you have " + + "declared an explicit bean, do not use lazy-init"); } - - if (!bf.containsBean(BeanIds.CONTEXT_SOURCE) && defaultNameRequired) { + if (!bf.containsBean(BeanIds.CONTEXT_SOURCE) && this.defaultNameRequired) { if (sources.length > 1) { - throw new ApplicationContextException( - "More than one BaseLdapPathContextSource instance found. " - + "Please specify a specific server id using the 'server-ref' attribute when configuring your <" - + Elements.LDAP_PROVIDER + "> " + "or <" - + Elements.LDAP_USER_SERVICE + ">."); + throw new ApplicationContextException("More than one BaseLdapPathContextSource instance found. " + + "Please specify a specific server id using the 'server-ref' attribute when configuring your <" + + Elements.LDAP_PROVIDER + "> " + "or <" + Elements.LDAP_USER_SERVICE + ">."); } - bf.registerAlias(sources[0], BeanIds.CONTEXT_SOURCE); } } + private Class getContextSourceClass() throws LinkageError { + try { + return ClassUtils.forName(REQUIRED_CONTEXT_SOURCE_CLASS_NAME, ClassUtils.getDefaultClassLoader()); + } + catch (ClassNotFoundException ex) { + throw new ApplicationContextException("Couldn't locate: " + REQUIRED_CONTEXT_SOURCE_CLASS_NAME + ". " + + " If you are using LDAP with Spring Security, please ensure that you include the spring-ldap " + + "jar file in your application", ex); + } + } + public void setDefaultNameRequired(boolean defaultNameRequired) { this.defaultNameRequired = defaultNameRequired; } + @Override public int getOrder() { return LOWEST_PRECEDENCE; } + } diff --git a/config/src/main/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParser.java index e2d86aab62..3e735458f4 100644 --- a/config/src/main/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/ldap/LdapProviderBeanDefinitionParser.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.ldap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; @@ -26,7 +29,6 @@ import org.springframework.security.config.Elements; import org.springframework.security.config.authentication.PasswordEncoderParser; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** * Ldap authentication provider namespace configuration. @@ -35,34 +37,36 @@ import org.w3c.dom.Element; * @since 2.0 */ public class LdapProviderBeanDefinitionParser implements BeanDefinitionParser { + private final Log logger = LogFactory.getLog(getClass()); private static final String ATT_USER_DN_PATTERN = "user-dn-pattern"; + private static final String ATT_USER_PASSWORD = "password-attribute"; + private static final String ATT_HASH = PasswordEncoderParser.ATT_HASH; private static final String DEF_USER_SEARCH_FILTER = "uid={0}"; static final String PROVIDER_CLASS = "org.springframework.security.ldap.authentication.LdapAuthenticationProvider"; + static final String BIND_AUTH_CLASS = "org.springframework.security.ldap.authentication.BindAuthenticator"; + static final String PASSWD_AUTH_CLASS = "org.springframework.security.ldap.authentication.PasswordComparisonAuthenticator"; + @Override public BeanDefinition parse(Element elt, ParserContext parserContext) { - RuntimeBeanReference contextSource = LdapUserServiceBeanDefinitionParser - .parseServerReference(elt, parserContext); - - BeanDefinition searchBean = LdapUserServiceBeanDefinitionParser.parseSearchBean( - elt, parserContext); + RuntimeBeanReference contextSource = LdapUserServiceBeanDefinitionParser.parseServerReference(elt, + parserContext); + BeanDefinition searchBean = LdapUserServiceBeanDefinitionParser.parseSearchBean(elt, parserContext); String userDnPattern = elt.getAttribute(ATT_USER_DN_PATTERN); - String[] userDnPatternArray = new String[0]; - if (StringUtils.hasText(userDnPattern)) { userDnPatternArray = new String[] { userDnPattern }; // TODO: Validate the pattern and make sure it is a valid DN. } else if (searchBean == null) { - logger.info("No search information or DN pattern specified. Using default search filter '" + this.logger.info("No search information or DN pattern specified. Using default search filter '" + DEF_USER_SEARCH_FILTER + "'"); BeanDefinitionBuilder searchBeanBuilder = BeanDefinitionBuilder .rootBeanDefinition(LdapUserServiceBeanDefinitionParser.LDAP_SEARCH_CLASS); @@ -72,61 +76,43 @@ public class LdapProviderBeanDefinitionParser implements BeanDefinitionParser { searchBeanBuilder.addConstructorArgValue(contextSource); searchBean = searchBeanBuilder.getBeanDefinition(); } - - BeanDefinitionBuilder authenticatorBuilder = BeanDefinitionBuilder - .rootBeanDefinition(BIND_AUTH_CLASS); - Element passwordCompareElt = DomUtils.getChildElementByTagName(elt, - Elements.LDAP_PASSWORD_COMPARE); - + BeanDefinitionBuilder authenticatorBuilder = BeanDefinitionBuilder.rootBeanDefinition(BIND_AUTH_CLASS); + Element passwordCompareElt = DomUtils.getChildElementByTagName(elt, Elements.LDAP_PASSWORD_COMPARE); if (passwordCompareElt != null) { - authenticatorBuilder = BeanDefinitionBuilder - .rootBeanDefinition(PASSWD_AUTH_CLASS); - + authenticatorBuilder = BeanDefinitionBuilder.rootBeanDefinition(PASSWD_AUTH_CLASS); String passwordAttribute = passwordCompareElt.getAttribute(ATT_USER_PASSWORD); if (StringUtils.hasText(passwordAttribute)) { - authenticatorBuilder.addPropertyValue("passwordAttributeName", - passwordAttribute); + authenticatorBuilder.addPropertyValue("passwordAttributeName", passwordAttribute); } - - Element passwordEncoderElement = DomUtils.getChildElementByTagName( - passwordCompareElt, Elements.PASSWORD_ENCODER); + Element passwordEncoderElement = DomUtils.getChildElementByTagName(passwordCompareElt, + Elements.PASSWORD_ENCODER); String hash = passwordCompareElt.getAttribute(ATT_HASH); - if (passwordEncoderElement != null) { if (StringUtils.hasText(hash)) { parserContext.getReaderContext().warning( - "Attribute 'hash' cannot be used with 'password-encoder' and " - + "will be ignored.", + "Attribute 'hash' cannot be used with 'password-encoder' and " + "will be ignored.", parserContext.extractSource(elt)); } - PasswordEncoderParser pep = new PasswordEncoderParser( - passwordEncoderElement, parserContext); - authenticatorBuilder.addPropertyValue("passwordEncoder", - pep.getPasswordEncoder()); + PasswordEncoderParser pep = new PasswordEncoderParser(passwordEncoderElement, parserContext); + authenticatorBuilder.addPropertyValue("passwordEncoder", pep.getPasswordEncoder()); } else if (StringUtils.hasText(hash)) { authenticatorBuilder.addPropertyValue("passwordEncoder", - PasswordEncoderParser.createPasswordEncoderBeanDefinition(hash, - false)); + PasswordEncoderParser.createPasswordEncoderBeanDefinition(hash, false)); } } - authenticatorBuilder.addConstructorArgValue(contextSource); authenticatorBuilder.addPropertyValue("userDnPatterns", userDnPatternArray); - if (searchBean != null) { authenticatorBuilder.addPropertyValue("userSearch", searchBean); } - - BeanDefinitionBuilder ldapProvider = BeanDefinitionBuilder - .rootBeanDefinition(PROVIDER_CLASS); + BeanDefinitionBuilder ldapProvider = BeanDefinitionBuilder.rootBeanDefinition(PROVIDER_CLASS); ldapProvider.addConstructorArgValue(authenticatorBuilder.getBeanDefinition()); - ldapProvider.addConstructorArgValue(LdapUserServiceBeanDefinitionParser - .parseAuthoritiesPopulator(elt, parserContext)); + ldapProvider.addConstructorArgValue( + LdapUserServiceBeanDefinitionParser.parseAuthoritiesPopulator(elt, parserContext)); ldapProvider.addPropertyValue("userDetailsContextMapper", - LdapUserServiceBeanDefinitionParser.parseUserDetailsClassOrUserMapperRef( - elt, parserContext)); - + LdapUserServiceBeanDefinitionParser.parseUserDetailsClassOrUserMapperRef(elt, parserContext)); return ldapProvider.getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParser.java index cbdda20de6..3da012c375 100644 --- a/config/src/main/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/ldap/LdapServerBeanDefinitionParser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.ldap; import java.io.IOException; @@ -42,6 +43,7 @@ import org.springframework.util.StringUtils; * @author Evgeniy Cheban */ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { + private static final String CONTEXT_SOURCE_CLASS = "org.springframework.security.ldap.DefaultSpringSecurityContextSource"; /** @@ -51,36 +53,43 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_URL = "url"; private static final String ATT_PRINCIPAL = "manager-dn"; + private static final String ATT_PASSWORD = "manager-password"; // Properties which apply to embedded server only - when no Url is set /** sets the configuration suffix (default is "dc=springframework,dc=org"). */ public static final String ATT_ROOT_SUFFIX = "root"; + private static final String OPT_DEFAULT_ROOT_SUFFIX = "dc=springframework,dc=org"; + /** * Optionally defines an ldif resource to be loaded. Otherwise an attempt will be made * to load all ldif files found on the classpath. */ public static final String ATT_LDIF_FILE = "ldif"; + private static final String OPT_DEFAULT_LDIF_FILE = "classpath*:*.ldif"; /** Defines the port the LDAP_PROVIDER server should run on */ public static final String ATT_PORT = "port"; + private static final String RANDOM_PORT = "0"; + private static final int DEFAULT_PORT = 33389; private static final String APACHEDS_CLASSNAME = "org.apache.directory.server.core.DefaultDirectoryService"; + private static final String UNBOUNID_CLASSNAME = "com.unboundid.ldap.listener.InMemoryDirectoryServer"; private static final String APACHEDS_CONTAINER_CLASSNAME = "org.springframework.security.ldap.server.ApacheDSContainer"; + private static final String UNBOUNDID_CONTAINER_CLASSNAME = "org.springframework.security.ldap.server.UnboundIdContainer"; + @Override public BeanDefinition parse(Element elt, ParserContext parserContext) { String url = elt.getAttribute(ATT_URL); - RootBeanDefinition contextSource; - if (!StringUtils.hasText(url)) { contextSource = createEmbeddedServer(elt, parserContext); } @@ -89,31 +98,20 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { contextSource.setBeanClassName(CONTEXT_SOURCE_CLASS); contextSource.getConstructorArgumentValues().addIndexedArgumentValue(0, url); } - contextSource.setSource(parserContext.extractSource(elt)); - String managerDn = elt.getAttribute(ATT_PRINCIPAL); String managerPassword = elt.getAttribute(ATT_PASSWORD); - if (StringUtils.hasText(managerDn)) { if (!StringUtils.hasText(managerPassword)) { - parserContext.getReaderContext().error( - "You must specify the " + ATT_PASSWORD + " if you supply a " - + managerDn, elt); + parserContext.getReaderContext() + .error("You must specify the " + ATT_PASSWORD + " if you supply a " + managerDn, elt); } - contextSource.getPropertyValues().addPropertyValue("userDn", managerDn); - contextSource.getPropertyValues().addPropertyValue("password", - managerPassword); + contextSource.getPropertyValues().addPropertyValue("password", managerPassword); } - String id = elt.getAttribute(AbstractBeanDefinitionParser.ID_ATTRIBUTE); - String contextSourceId = StringUtils.hasText(id) ? id : BeanIds.CONTEXT_SOURCE; - - parserContext.getRegistry() - .registerBeanDefinition(contextSourceId, contextSource); - + parserContext.getRegistry().registerBeanDefinition(contextSourceId, contextSource); return null; } @@ -121,64 +119,47 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { * Will be called if no url attribute is supplied. * * Registers beans to create an embedded apache directory server. - * * @return the BeanDefinition for the ContextSource for the embedded server. * * @see ApacheDSContainer * @see UnboundIdContainer */ - private RootBeanDefinition createEmbeddedServer(Element element, - ParserContext parserContext) { + private RootBeanDefinition createEmbeddedServer(Element element, ParserContext parserContext) { Object source = parserContext.extractSource(element); - String suffix = element.getAttribute(ATT_ROOT_SUFFIX); - if (!StringUtils.hasText(suffix)) { suffix = OPT_DEFAULT_ROOT_SUFFIX; } - - BeanDefinitionBuilder contextSource = BeanDefinitionBuilder - .rootBeanDefinition(CONTEXT_SOURCE_CLASS); + BeanDefinitionBuilder contextSource = BeanDefinitionBuilder.rootBeanDefinition(CONTEXT_SOURCE_CLASS); contextSource.addConstructorArgValue(suffix); contextSource.addPropertyValue("userDn", "uid=admin,ou=system"); contextSource.addPropertyValue("password", "secret"); - BeanDefinition embeddedLdapServerConfigBean = BeanDefinitionBuilder .rootBeanDefinition(EmbeddedLdapServerConfigBean.class).getBeanDefinition(); String embeddedLdapServerConfigBeanName = parserContext.getReaderContext() .generateBeanName(embeddedLdapServerConfigBean); - - parserContext.registerBeanComponent(new BeanComponentDefinition(embeddedLdapServerConfigBean, - embeddedLdapServerConfigBeanName)); - + parserContext.registerBeanComponent( + new BeanComponentDefinition(embeddedLdapServerConfigBean, embeddedLdapServerConfigBeanName)); contextSource.setFactoryMethodOnBean("createEmbeddedContextSource", embeddedLdapServerConfigBeanName); - String mode = element.getAttribute("mode"); RootBeanDefinition ldapContainer = getRootBeanDefinition(mode); ldapContainer.setSource(source); ldapContainer.getConstructorArgumentValues().addGenericArgumentValue(suffix); - String ldifs = element.getAttribute(ATT_LDIF_FILE); if (!StringUtils.hasText(ldifs)) { ldifs = OPT_DEFAULT_LDIF_FILE; } - ldapContainer.getConstructorArgumentValues().addGenericArgumentValue(ldifs); ldapContainer.getPropertyValues().addPropertyValue("port", getPort(element)); - - if (parserContext.getRegistry() - .containsBeanDefinition(BeanIds.EMBEDDED_APACHE_DS) || - parserContext.getRegistry().containsBeanDefinition(BeanIds.EMBEDDED_UNBOUNDID)) { - parserContext.getReaderContext().error( - "Only one embedded server bean is allowed per application context", + if (parserContext.getRegistry().containsBeanDefinition(BeanIds.EMBEDDED_APACHE_DS) + || parserContext.getRegistry().containsBeanDefinition(BeanIds.EMBEDDED_UNBOUNDID)) { + parserContext.getReaderContext().error("Only one embedded server bean is allowed per application context", element); } - String beanId = resolveBeanId(mode); if (beanId != null) { parserContext.getRegistry().registerBeanDefinition(beanId, ldapContainer); } - return (RootBeanDefinition) contextSource.getBeanDefinition(); } @@ -186,7 +167,7 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { if (isApacheDsEnabled(mode)) { return new RootBeanDefinition(APACHEDS_CONTAINER_CLASSNAME, null, null); } - else if (isUnboundidEnabled(mode)) { + if (isUnboundidEnabled(mode)) { return new RootBeanDefinition(UNBOUNDID_CONTAINER_CLASSNAME, null, null); } throw new IllegalStateException("Embedded LDAP server is not provided"); @@ -196,7 +177,7 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { if (isApacheDsEnabled(mode)) { return BeanIds.EMBEDDED_APACHE_DS; } - else if (isUnboundidEnabled(mode)) { + if (isUnboundidEnabled(mode)) { return BeanIds.EMBEDDED_UNBOUNDID; } return null; @@ -218,7 +199,8 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { private String getDefaultPort() { try (ServerSocket serverSocket = new ServerSocket(DEFAULT_PORT)) { return String.valueOf(serverSocket.getLocalPort()); - } catch (IOException e) { + } + catch (IOException ex) { return RANDOM_PORT; } } @@ -234,22 +216,23 @@ public class LdapServerBeanDefinitionParser implements BeanDefinitionParser { @SuppressWarnings("unused") private DefaultSpringSecurityContextSource createEmbeddedContextSource(String suffix) { - int port; - if (ClassUtils.isPresent(APACHEDS_CLASSNAME, getClass().getClassLoader())) { - ApacheDSContainer apacheDSContainer = this.applicationContext.getBean(ApacheDSContainer.class); - port = apacheDSContainer.getLocalPort(); - } - else if (ClassUtils.isPresent(UNBOUNID_CLASSNAME, getClass().getClassLoader())) { - UnboundIdContainer unboundIdContainer = this.applicationContext.getBean(UnboundIdContainer.class); - port = unboundIdContainer.getPort(); - } - else { - throw new IllegalStateException("Embedded LDAP server is not provided"); - } - + int port = getPort(); String providerUrl = "ldap://127.0.0.1:" + port + "/" + suffix; - return new DefaultSpringSecurityContextSource(providerUrl); } + + private int getPort() { + if (ClassUtils.isPresent(APACHEDS_CLASSNAME, getClass().getClassLoader())) { + ApacheDSContainer apacheDSContainer = this.applicationContext.getBean(ApacheDSContainer.class); + return apacheDSContainer.getLocalPort(); + } + if (ClassUtils.isPresent(UNBOUNID_CLASSNAME, getClass().getClassLoader())) { + UnboundIdContainer unboundIdContainer = this.applicationContext.getBean(UnboundIdContainer.class); + return unboundIdContainer.getPort(); + } + throw new IllegalStateException("Embedded LDAP server is not provided"); + } + } + } diff --git a/config/src/main/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParser.java index 219de13e33..dec9a6ac41 100644 --- a/config/src/main/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/ldap/LdapUserServiceBeanDefinitionParser.java @@ -13,203 +13,184 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.ldap; +import org.w3c.dom.Element; + import org.springframework.beans.BeanMetadataElement; -import org.springframework.beans.factory.xml.ParserContext; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.RootBeanDefinition; -import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.xml.ParserContext; import org.springframework.security.config.BeanIds; import org.springframework.security.config.authentication.AbstractUserDetailsServiceBeanDefinitionParser; import org.springframework.util.StringUtils; -import org.w3c.dom.Element; - /** * @author Luke Taylor * @since 2.0 */ -public class LdapUserServiceBeanDefinitionParser extends - AbstractUserDetailsServiceBeanDefinitionParser { +public class LdapUserServiceBeanDefinitionParser extends AbstractUserDetailsServiceBeanDefinitionParser { + public static final String ATT_SERVER = "server-ref"; + public static final String ATT_USER_SEARCH_FILTER = "user-search-filter"; + public static final String ATT_USER_SEARCH_BASE = "user-search-base"; + public static final String DEF_USER_SEARCH_BASE = ""; public static final String ATT_GROUP_SEARCH_FILTER = "group-search-filter"; + public static final String ATT_GROUP_SEARCH_BASE = "group-search-base"; + public static final String ATT_GROUP_ROLE_ATTRIBUTE = "group-role-attribute"; + public static final String DEF_GROUP_SEARCH_FILTER = "(uniqueMember={0})"; + public static final String DEF_GROUP_SEARCH_BASE = ""; static final String ATT_ROLE_PREFIX = "role-prefix"; + static final String ATT_USER_CLASS = "user-details-class"; + static final String ATT_USER_CONTEXT_MAPPER_REF = "user-context-mapper-ref"; + static final String OPT_PERSON = "person"; + static final String OPT_INETORGPERSON = "inetOrgPerson"; public static final String LDAP_SEARCH_CLASS = "org.springframework.security.ldap.search.FilterBasedLdapUserSearch"; + public static final String PERSON_MAPPER_CLASS = "org.springframework.security.ldap.userdetails.PersonContextMapper"; + public static final String INET_ORG_PERSON_MAPPER_CLASS = "org.springframework.security.ldap.userdetails.InetOrgPersonContextMapper"; + public static final String LDAP_USER_MAPPER_CLASS = "org.springframework.security.ldap.userdetails.LdapUserDetailsMapper"; + public static final String LDAP_AUTHORITIES_POPULATOR_CLASS = "org.springframework.security.ldap.userdetails.DefaultLdapAuthoritiesPopulator"; + @Override protected String getBeanClassName(Element element) { return "org.springframework.security.ldap.userdetails.LdapUserDetailsService"; } - protected void doParse(Element elt, ParserContext parserContext, - BeanDefinitionBuilder builder) { - + @Override + protected void doParse(Element elt, ParserContext parserContext, BeanDefinitionBuilder builder) { if (!StringUtils.hasText(elt.getAttribute(ATT_USER_SEARCH_FILTER))) { - parserContext.getReaderContext().error("User search filter must be supplied", - elt); + parserContext.getReaderContext().error("User search filter must be supplied", elt); } - builder.addConstructorArgValue(parseSearchBean(elt, parserContext)); builder.getRawBeanDefinition().setSource(parserContext.extractSource(elt)); builder.addConstructorArgValue(parseAuthoritiesPopulator(elt, parserContext)); - builder.addPropertyValue("userDetailsMapper", - parseUserDetailsClassOrUserMapperRef(elt, parserContext)); + builder.addPropertyValue("userDetailsMapper", parseUserDetailsClassOrUserMapperRef(elt, parserContext)); } static RootBeanDefinition parseSearchBean(Element elt, ParserContext parserContext) { String userSearchFilter = elt.getAttribute(ATT_USER_SEARCH_FILTER); String userSearchBase = elt.getAttribute(ATT_USER_SEARCH_BASE); Object source = parserContext.extractSource(elt); - if (StringUtils.hasText(userSearchBase)) { if (!StringUtils.hasText(userSearchFilter)) { - parserContext.getReaderContext().error( - ATT_USER_SEARCH_BASE + " cannot be used without a " - + ATT_USER_SEARCH_FILTER, source); + parserContext.getReaderContext() + .error(ATT_USER_SEARCH_BASE + " cannot be used without a " + ATT_USER_SEARCH_FILTER, source); } } else { userSearchBase = DEF_USER_SEARCH_BASE; } - if (!StringUtils.hasText(userSearchFilter)) { return null; } - - BeanDefinitionBuilder searchBuilder = BeanDefinitionBuilder - .rootBeanDefinition(LDAP_SEARCH_CLASS); + BeanDefinitionBuilder searchBuilder = BeanDefinitionBuilder.rootBeanDefinition(LDAP_SEARCH_CLASS); searchBuilder.getRawBeanDefinition().setSource(source); searchBuilder.addConstructorArgValue(userSearchBase); searchBuilder.addConstructorArgValue(userSearchFilter); searchBuilder.addConstructorArgValue(parseServerReference(elt, parserContext)); - return (RootBeanDefinition) searchBuilder.getBeanDefinition(); } - static RuntimeBeanReference parseServerReference(Element elt, - ParserContext parserContext) { + static RuntimeBeanReference parseServerReference(Element elt, ParserContext parserContext) { String server = elt.getAttribute(ATT_SERVER); boolean requiresDefaultName = false; - if (!StringUtils.hasText(server)) { server = BeanIds.CONTEXT_SOURCE; requiresDefaultName = true; } - RuntimeBeanReference contextSource = new RuntimeBeanReference(server); contextSource.setSource(parserContext.extractSource(elt)); registerPostProcessorIfNecessary(parserContext.getRegistry(), requiresDefaultName); - return contextSource; } - private static void registerPostProcessorIfNecessary(BeanDefinitionRegistry registry, - boolean defaultNameRequired) { - if (registry - .containsBeanDefinition(BeanIds.CONTEXT_SOURCE_SETTING_POST_PROCESSOR)) { + private static void registerPostProcessorIfNecessary(BeanDefinitionRegistry registry, boolean defaultNameRequired) { + if (registry.containsBeanDefinition(BeanIds.CONTEXT_SOURCE_SETTING_POST_PROCESSOR)) { if (defaultNameRequired) { - BeanDefinition bd = registry - .getBeanDefinition(BeanIds.CONTEXT_SOURCE_SETTING_POST_PROCESSOR); - bd.getPropertyValues().addPropertyValue("defaultNameRequired", - defaultNameRequired); + BeanDefinition bd = registry.getBeanDefinition(BeanIds.CONTEXT_SOURCE_SETTING_POST_PROCESSOR); + bd.getPropertyValues().addPropertyValue("defaultNameRequired", defaultNameRequired); } return; } - - BeanDefinitionBuilder bdb = BeanDefinitionBuilder - .rootBeanDefinition(ContextSourceSettingPostProcessor.class); + BeanDefinitionBuilder bdb = BeanDefinitionBuilder.rootBeanDefinition(ContextSourceSettingPostProcessor.class); bdb.addPropertyValue("defaultNameRequired", defaultNameRequired); - registry.registerBeanDefinition(BeanIds.CONTEXT_SOURCE_SETTING_POST_PROCESSOR, - bdb.getBeanDefinition()); + registry.registerBeanDefinition(BeanIds.CONTEXT_SOURCE_SETTING_POST_PROCESSOR, bdb.getBeanDefinition()); } - static BeanMetadataElement parseUserDetailsClassOrUserMapperRef(Element elt, - ParserContext parserContext) { + static BeanMetadataElement parseUserDetailsClassOrUserMapperRef(Element elt, ParserContext parserContext) { String userDetailsClass = elt.getAttribute(ATT_USER_CLASS); String userMapperRef = elt.getAttribute(ATT_USER_CONTEXT_MAPPER_REF); - if (StringUtils.hasText(userDetailsClass) && StringUtils.hasText(userMapperRef)) { - parserContext.getReaderContext().error( - "Attributes " + ATT_USER_CLASS + " and " - + ATT_USER_CONTEXT_MAPPER_REF + " cannot be used together.", - parserContext.extractSource(elt)); + parserContext.getReaderContext().error("Attributes " + ATT_USER_CLASS + " and " + + ATT_USER_CONTEXT_MAPPER_REF + " cannot be used together.", parserContext.extractSource(elt)); } - if (StringUtils.hasText(userMapperRef)) { return new RuntimeBeanReference(userMapperRef); } - - RootBeanDefinition mapper; - - if (OPT_PERSON.equals(userDetailsClass)) { - mapper = new RootBeanDefinition(PERSON_MAPPER_CLASS, null, null); - } - else if (OPT_INETORGPERSON.equals(userDetailsClass)) { - mapper = new RootBeanDefinition(INET_ORG_PERSON_MAPPER_CLASS, null, null); - } - else { - mapper = new RootBeanDefinition(LDAP_USER_MAPPER_CLASS, null, null); - } - + RootBeanDefinition mapper = getMapper(userDetailsClass); mapper.setSource(parserContext.extractSource(elt)); - return mapper; } - static RootBeanDefinition parseAuthoritiesPopulator(Element elt, - ParserContext parserContext) { + private static RootBeanDefinition getMapper(String userDetailsClass) { + if (OPT_PERSON.equals(userDetailsClass)) { + return new RootBeanDefinition(PERSON_MAPPER_CLASS, null, null); + } + if (OPT_INETORGPERSON.equals(userDetailsClass)) { + return new RootBeanDefinition(INET_ORG_PERSON_MAPPER_CLASS, null, null); + } + return new RootBeanDefinition(LDAP_USER_MAPPER_CLASS, null, null); + } + + static RootBeanDefinition parseAuthoritiesPopulator(Element elt, ParserContext parserContext) { String groupSearchFilter = elt.getAttribute(ATT_GROUP_SEARCH_FILTER); String groupSearchBase = elt.getAttribute(ATT_GROUP_SEARCH_BASE); String groupRoleAttribute = elt.getAttribute(ATT_GROUP_ROLE_ATTRIBUTE); String rolePrefix = elt.getAttribute(ATT_ROLE_PREFIX); - if (!StringUtils.hasText(groupSearchFilter)) { groupSearchFilter = DEF_GROUP_SEARCH_FILTER; } - if (!StringUtils.hasText(groupSearchBase)) { groupSearchBase = DEF_GROUP_SEARCH_BASE; } - - BeanDefinitionBuilder populator = BeanDefinitionBuilder - .rootBeanDefinition(LDAP_AUTHORITIES_POPULATOR_CLASS); + BeanDefinitionBuilder populator = BeanDefinitionBuilder.rootBeanDefinition(LDAP_AUTHORITIES_POPULATOR_CLASS); populator.getRawBeanDefinition().setSource(parserContext.extractSource(elt)); populator.addConstructorArgValue(parseServerReference(elt, parserContext)); populator.addConstructorArgValue(groupSearchBase); populator.addPropertyValue("groupSearchFilter", groupSearchFilter); populator.addPropertyValue("searchSubtree", Boolean.TRUE); - if (StringUtils.hasText(rolePrefix)) { if ("none".equals(rolePrefix)) { rolePrefix = ""; } populator.addPropertyValue("rolePrefix", rolePrefix); } - if (StringUtils.hasLength(groupRoleAttribute)) { populator.addPropertyValue("groupRoleAttribute", groupRoleAttribute); } - return (RootBeanDefinition) populator.getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/ldap/package-info.java b/config/src/main/java/org/springframework/security/config/ldap/package-info.java index 392fc2146b..54f428d6fa 100644 --- a/config/src/main/java/org/springframework/security/config/ldap/package-info.java +++ b/config/src/main/java/org/springframework/security/config/ldap/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Security namespace support for LDAP authentication. */ package org.springframework.security.config.ldap; - diff --git a/config/src/main/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParser.java index 29a58e040c..c87b32e169 100644 --- a/config/src/main/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParser.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.method; -import static org.springframework.security.config.Elements.*; +package org.springframework.security.config.method; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -24,6 +23,8 @@ import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.w3c.dom.Element; + import org.springframework.aop.config.AopNamespaceUtils; import org.springframework.aop.framework.ProxyFactoryBean; import org.springframework.aop.target.LazyInitTargetSource; @@ -47,6 +48,7 @@ import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.annotation.Jsr250MethodSecurityMetadataSource; @@ -79,7 +81,6 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** * Processes the top-level "global-method-security" element. @@ -94,56 +95,56 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP private final Log logger = LogFactory.getLog(getClass()); private static final String ATT_AUTHENTICATION_MANAGER_REF = "authentication-manager-ref"; + private static final String ATT_ACCESS = "access"; + private static final String ATT_EXPRESSION = "expression"; + private static final String ATT_ACCESS_MGR = "access-decision-manager-ref"; + private static final String ATT_RUN_AS_MGR = "run-as-manager-ref"; + private static final String ATT_USE_JSR250 = "jsr250-annotations"; + private static final String ATT_USE_SECURED = "secured-annotations"; + private static final String ATT_USE_PREPOST = "pre-post-annotations"; + private static final String ATT_REF = "ref"; + private static final String ATT_MODE = "mode"; + private static final String ATT_ADVICE_ORDER = "order"; + private static final String ATT_META_DATA_SOURCE_REF = "metadata-source-ref"; + @Override public BeanDefinition parse(Element element, ParserContext pc) { - CompositeComponentDefinition compositeDef = new CompositeComponentDefinition( - element.getTagName(), pc.extractSource(element)); + CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), + pc.extractSource(element)); pc.pushContainingComponent(compositeDef); - Object source = pc.extractSource(element); // The list of method metadata delegates ManagedList delegates = new ManagedList<>(); - boolean jsr250Enabled = "enabled".equals(element.getAttribute(ATT_USE_JSR250)); boolean useSecured = "enabled".equals(element.getAttribute(ATT_USE_SECURED)); - boolean prePostAnnotationsEnabled = "enabled".equals(element - .getAttribute(ATT_USE_PREPOST)); + boolean prePostAnnotationsEnabled = "enabled".equals(element.getAttribute(ATT_USE_PREPOST)); boolean useAspectJ = "aspectj".equals(element.getAttribute(ATT_MODE)); - BeanDefinition preInvocationVoter = null; ManagedList afterInvocationProviders = new ManagedList<>(); - // Check for an external SecurityMetadataSource, which takes priority over other // sources String metaDataSourceId = element.getAttribute(ATT_META_DATA_SOURCE_REF); - if (StringUtils.hasText(metaDataSourceId)) { delegates.add(new RuntimeBeanReference(metaDataSourceId)); } - if (prePostAnnotationsEnabled) { - Element prePostElt = DomUtils.getChildElementByTagName(element, - INVOCATION_HANDLING); - Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, - EXPRESSION_HANDLER); - + Element prePostElt = DomUtils.getChildElementByTagName(element, Elements.INVOCATION_HANDLING); + Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, Elements.EXPRESSION_HANDLER); if (prePostElt != null && expressionHandlerElt != null) { - pc.getReaderContext().error( - INVOCATION_HANDLING + " and " + EXPRESSION_HANDLER - + " cannot be used together ", source); + pc.getReaderContext().error(Elements.INVOCATION_HANDLING + " and " + Elements.EXPRESSION_HANDLER + + " cannot be used together ", source); } - BeanDefinitionBuilder preInvocationVoterBldr = BeanDefinitionBuilder .rootBeanDefinition(PreInvocationAuthorizationAdviceVoter.class); // After-invocation provider to handle post-invocation filtering and @@ -153,167 +154,123 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP // The metadata source for the security interceptor BeanDefinitionBuilder mds = BeanDefinitionBuilder .rootBeanDefinition(PrePostAnnotationSecurityMetadataSource.class); - if (prePostElt != null) { // Customized override of expression handling system - String attributeFactoryRef = DomUtils.getChildElementByTagName( - prePostElt, INVOCATION_ATTRIBUTE_FACTORY).getAttribute("ref"); - String preAdviceRef = DomUtils.getChildElementByTagName(prePostElt, - PRE_INVOCATION_ADVICE).getAttribute("ref"); - String postAdviceRef = DomUtils.getChildElementByTagName(prePostElt, - POST_INVOCATION_ADVICE).getAttribute("ref"); - + String attributeFactoryRef = DomUtils + .getChildElementByTagName(prePostElt, Elements.INVOCATION_ATTRIBUTE_FACTORY) + .getAttribute("ref"); + String preAdviceRef = DomUtils.getChildElementByTagName(prePostElt, Elements.PRE_INVOCATION_ADVICE) + .getAttribute("ref"); + String postAdviceRef = DomUtils.getChildElementByTagName(prePostElt, Elements.POST_INVOCATION_ADVICE) + .getAttribute("ref"); mds.addConstructorArgReference(attributeFactoryRef); preInvocationVoterBldr.addConstructorArgReference(preAdviceRef); afterInvocationBldr.addConstructorArgReference(postAdviceRef); } else { // The default expression-based system - String expressionHandlerRef = expressionHandlerElt == null ? null - : expressionHandlerElt.getAttribute("ref"); - + String expressionHandlerRef = (expressionHandlerElt != null) ? expressionHandlerElt.getAttribute("ref") + : null; if (StringUtils.hasText(expressionHandlerRef)) { - logger.info("Using bean '" + expressionHandlerRef - + "' as method ExpressionHandler implementation"); + this.logger.info(LogMessage.format("Using bean '%s' as method ExpressionHandler implementation", + expressionHandlerRef)); RootBeanDefinition lazyInitPP = new RootBeanDefinition( LazyInitBeanDefinitionRegistryPostProcessor.class); - lazyInitPP.getConstructorArgumentValues().addGenericArgumentValue( - expressionHandlerRef); + lazyInitPP.getConstructorArgumentValues().addGenericArgumentValue(expressionHandlerRef); pc.getReaderContext().registerWithGeneratedName(lazyInitPP); - BeanDefinitionBuilder lazyMethodSecurityExpressionHandlerBldr = BeanDefinitionBuilder .rootBeanDefinition(LazyInitTargetSource.class); - lazyMethodSecurityExpressionHandlerBldr.addPropertyValue( - "targetBeanName", expressionHandlerRef); - + lazyMethodSecurityExpressionHandlerBldr.addPropertyValue("targetBeanName", expressionHandlerRef); BeanDefinitionBuilder expressionHandlerProxyBldr = BeanDefinitionBuilder .rootBeanDefinition(ProxyFactoryBean.class); expressionHandlerProxyBldr.addPropertyValue("targetSource", lazyMethodSecurityExpressionHandlerBldr.getBeanDefinition()); expressionHandlerProxyBldr.addPropertyValue("proxyInterfaces", MethodSecurityExpressionHandler.class); - - expressionHandlerRef = pc.getReaderContext().generateBeanName( - expressionHandlerProxyBldr.getBeanDefinition()); - - pc.registerBeanComponent(new BeanComponentDefinition( - expressionHandlerProxyBldr.getBeanDefinition(), + expressionHandlerRef = pc.getReaderContext() + .generateBeanName(expressionHandlerProxyBldr.getBeanDefinition()); + pc.registerBeanComponent(new BeanComponentDefinition(expressionHandlerProxyBldr.getBeanDefinition(), expressionHandlerRef)); } else { - RootBeanDefinition expressionHandler = registerWithDefaultRolePrefix(pc, DefaultMethodSecurityExpressionHandlerBeanFactory.class); - - expressionHandlerRef = pc.getReaderContext().generateBeanName( - expressionHandler); - pc.registerBeanComponent(new BeanComponentDefinition( - expressionHandler, expressionHandlerRef)); - logger.info("Expressions were enabled for method security but no SecurityExpressionHandler was configured. " - + "All hasPermission() expressions will evaluate to false."); + RootBeanDefinition expressionHandler = registerWithDefaultRolePrefix(pc, + DefaultMethodSecurityExpressionHandlerBeanFactory.class); + expressionHandlerRef = pc.getReaderContext().generateBeanName(expressionHandler); + pc.registerBeanComponent(new BeanComponentDefinition(expressionHandler, expressionHandlerRef)); + this.logger.info("Expressions were enabled for method security but no SecurityExpressionHandler " + + "was configured. All hasPermission() expressions will evaluate to false."); } - BeanDefinitionBuilder expressionPreAdviceBldr = BeanDefinitionBuilder .rootBeanDefinition(ExpressionBasedPreInvocationAdvice.class); - expressionPreAdviceBldr.addPropertyReference("expressionHandler", - expressionHandlerRef); - preInvocationVoterBldr.addConstructorArgValue(expressionPreAdviceBldr - .getBeanDefinition()); - + expressionPreAdviceBldr.addPropertyReference("expressionHandler", expressionHandlerRef); + preInvocationVoterBldr.addConstructorArgValue(expressionPreAdviceBldr.getBeanDefinition()); BeanDefinitionBuilder expressionPostAdviceBldr = BeanDefinitionBuilder .rootBeanDefinition(ExpressionBasedPostInvocationAdvice.class); expressionPostAdviceBldr.addConstructorArgReference(expressionHandlerRef); - afterInvocationBldr.addConstructorArgValue(expressionPostAdviceBldr - .getBeanDefinition()); - + afterInvocationBldr.addConstructorArgValue(expressionPostAdviceBldr.getBeanDefinition()); BeanDefinitionBuilder annotationInvocationFactory = BeanDefinitionBuilder .rootBeanDefinition(ExpressionBasedAnnotationAttributeFactory.class); - annotationInvocationFactory - .addConstructorArgReference(expressionHandlerRef); - mds.addConstructorArgValue(annotationInvocationFactory - .getBeanDefinition()); + annotationInvocationFactory.addConstructorArgReference(expressionHandlerRef); + mds.addConstructorArgValue(annotationInvocationFactory.getBeanDefinition()); } - preInvocationVoter = preInvocationVoterBldr.getBeanDefinition(); afterInvocationProviders.add(afterInvocationBldr.getBeanDefinition()); delegates.add(mds.getBeanDefinition()); } - if (useSecured) { - delegates.add(BeanDefinitionBuilder.rootBeanDefinition( - SecuredAnnotationSecurityMetadataSource.class).getBeanDefinition()); + delegates.add(BeanDefinitionBuilder.rootBeanDefinition(SecuredAnnotationSecurityMetadataSource.class) + .getBeanDefinition()); } - if (jsr250Enabled) { - RootBeanDefinition jsrMetadataSource = registerWithDefaultRolePrefix(pc, Jsr250MethodSecurityMetadataSourceBeanFactory.class); + RootBeanDefinition jsrMetadataSource = registerWithDefaultRolePrefix(pc, + Jsr250MethodSecurityMetadataSourceBeanFactory.class); delegates.add(jsrMetadataSource); } - // Now create a Map for each // sub-element Map> pointcutMap = parseProtectPointcuts(pc, - DomUtils.getChildElementsByTagName(element, PROTECT_POINTCUT)); - + DomUtils.getChildElementsByTagName(element, Elements.PROTECT_POINTCUT)); if (pointcutMap.size() > 0) { if (useAspectJ) { - pc.getReaderContext().error( - "You can't use AspectJ mode with protect-pointcut definitions", - source); + pc.getReaderContext().error("You can't use AspectJ mode with protect-pointcut definitions", source); } // Only add it if there are actually any pointcuts defined. - BeanDefinition mapBasedMetadataSource = new RootBeanDefinition( - MapBasedMethodSecurityMetadataSource.class); - BeanReference ref = new RuntimeBeanReference(pc.getReaderContext() - .generateBeanName(mapBasedMetadataSource)); - + BeanDefinition mapBasedMetadataSource = new RootBeanDefinition(MapBasedMethodSecurityMetadataSource.class); + BeanReference ref = new RuntimeBeanReference( + pc.getReaderContext().generateBeanName(mapBasedMetadataSource)); delegates.add(ref); - pc.registerBeanComponent(new BeanComponentDefinition(mapBasedMetadataSource, - ref.getBeanName())); + pc.registerBeanComponent(new BeanComponentDefinition(mapBasedMetadataSource, ref.getBeanName())); registerProtectPointcutPostProcessor(pc, pointcutMap, ref, source); } - - BeanReference metadataSource = registerDelegatingMethodSecurityMetadataSource(pc, - delegates, source); - + BeanReference metadataSource = registerDelegatingMethodSecurityMetadataSource(pc, delegates, source); // Check for additional after-invocation-providers.. List afterInvocationElts = DomUtils.getChildElementsByTagName(element, Elements.AFTER_INVOCATION_PROVIDER); - for (Element elt : afterInvocationElts) { - afterInvocationProviders.add(new RuntimeBeanReference(elt - .getAttribute(ATT_REF))); + afterInvocationProviders.add(new RuntimeBeanReference(elt.getAttribute(ATT_REF))); } - String accessManagerId = element.getAttribute(ATT_ACCESS_MGR); - if (!StringUtils.hasText(accessManagerId)) { accessManagerId = registerAccessManager(pc, jsr250Enabled, preInvocationVoter); } - String authMgrRef = element.getAttribute(ATT_AUTHENTICATION_MANAGER_REF); - String runAsManagerId = element.getAttribute(ATT_RUN_AS_MGR); - BeanReference interceptor = registerMethodSecurityInterceptor(pc, authMgrRef, - accessManagerId, runAsManagerId, metadataSource, - afterInvocationProviders, source, useAspectJ); - + BeanReference interceptor = registerMethodSecurityInterceptor(pc, authMgrRef, accessManagerId, runAsManagerId, + metadataSource, afterInvocationProviders, source, useAspectJ); if (useAspectJ) { - BeanDefinitionBuilder aspect = BeanDefinitionBuilder - .rootBeanDefinition("org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"); + BeanDefinitionBuilder aspect = BeanDefinitionBuilder.rootBeanDefinition( + "org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"); aspect.setFactoryMethod("aspectOf"); aspect.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); aspect.addPropertyValue("securityInterceptor", interceptor); - String id = pc.getReaderContext().registerWithGeneratedName( - aspect.getBeanDefinition()); - pc.registerBeanComponent(new BeanComponentDefinition(aspect - .getBeanDefinition(), id)); + String id = pc.getReaderContext().registerWithGeneratedName(aspect.getBeanDefinition()); + pc.registerBeanComponent(new BeanComponentDefinition(aspect.getBeanDefinition(), id)); } else { - registerAdvisor(pc, interceptor, metadataSource, source, - element.getAttribute(ATT_ADVICE_ORDER)); + registerAdvisor(pc, interceptor, metadataSource, source, element.getAttribute(ATT_ADVICE_ORDER)); AopNamespaceUtils.registerAutoProxyCreatorIfNecessary(pc, element); } - pc.popAndRegisterContainingComponent(); - return null; } @@ -323,165 +280,123 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP * @return */ @SuppressWarnings({ "unchecked", "rawtypes" }) - private String registerAccessManager(ParserContext pc, boolean jsr250Enabled, - BeanDefinition expressionVoter) { - - BeanDefinitionBuilder accessMgrBuilder = BeanDefinitionBuilder - .rootBeanDefinition(AffirmativeBased.class); + private String registerAccessManager(ParserContext pc, boolean jsr250Enabled, BeanDefinition expressionVoter) { + BeanDefinitionBuilder accessMgrBuilder = BeanDefinitionBuilder.rootBeanDefinition(AffirmativeBased.class); ManagedList voters = new ManagedList(4); - if (expressionVoter != null) { voters.add(expressionVoter); } voters.add(new RootBeanDefinition(RoleVoter.class)); voters.add(new RootBeanDefinition(AuthenticatedVoter.class)); - if (jsr250Enabled) { voters.add(new RootBeanDefinition(Jsr250Voter.class)); } - accessMgrBuilder.addConstructorArgValue(voters); - BeanDefinition accessManager = accessMgrBuilder.getBeanDefinition(); String id = pc.getReaderContext().generateBeanName(accessManager); pc.registerBeanComponent(new BeanComponentDefinition(accessManager, id)); - return id; } @SuppressWarnings("rawtypes") - private BeanReference registerDelegatingMethodSecurityMetadataSource( - ParserContext pc, ManagedList delegates, Object source) { + private BeanReference registerDelegatingMethodSecurityMetadataSource(ParserContext pc, ManagedList delegates, + Object source) { RootBeanDefinition delegatingMethodSecurityMetadataSource = new RootBeanDefinition( DelegatingMethodSecurityMetadataSource.class); delegatingMethodSecurityMetadataSource.setSource(source); - delegatingMethodSecurityMetadataSource.getConstructorArgumentValues() - .addGenericArgumentValue(delegates); - - String id = pc.getReaderContext().generateBeanName( - delegatingMethodSecurityMetadataSource); - pc.registerBeanComponent(new BeanComponentDefinition( - delegatingMethodSecurityMetadataSource, id)); + delegatingMethodSecurityMetadataSource.getConstructorArgumentValues().addGenericArgumentValue(delegates); + String id = pc.getReaderContext().generateBeanName(delegatingMethodSecurityMetadataSource); + pc.registerBeanComponent(new BeanComponentDefinition(delegatingMethodSecurityMetadataSource, id)); return new RuntimeBeanReference(id); } private void registerProtectPointcutPostProcessor(ParserContext parserContext, - Map> pointcutMap, - BeanReference mapBasedMethodSecurityMetadataSource, Object source) { - RootBeanDefinition ppbp = new RootBeanDefinition( - ProtectPointcutPostProcessor.class); + Map> pointcutMap, BeanReference mapBasedMethodSecurityMetadataSource, + Object source) { + RootBeanDefinition ppbp = new RootBeanDefinition(ProtectPointcutPostProcessor.class); ppbp.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); ppbp.setSource(source); - ppbp.getConstructorArgumentValues().addGenericArgumentValue( - mapBasedMethodSecurityMetadataSource); + ppbp.getConstructorArgumentValues().addGenericArgumentValue(mapBasedMethodSecurityMetadataSource); ppbp.getPropertyValues().addPropertyValue("pointcutMap", pointcutMap); parserContext.getReaderContext().registerWithGeneratedName(ppbp); } - private Map> parseProtectPointcuts( - ParserContext parserContext, List protectPointcutElts) { + private Map> parseProtectPointcuts(ParserContext parserContext, + List protectPointcutElts) { Map> pointcutMap = new LinkedHashMap<>(); - for (Element childElt : protectPointcutElts) { String accessConfig = childElt.getAttribute(ATT_ACCESS); String expression = childElt.getAttribute(ATT_EXPRESSION); - if (!StringUtils.hasText(accessConfig)) { parserContext.getReaderContext().error("Access configuration required", parserContext.extractSource(childElt)); } - if (!StringUtils.hasText(expression)) { parserContext.getReaderContext().error("Pointcut expression required", parserContext.extractSource(childElt)); } - - String[] attributeTokens = StringUtils - .commaDelimitedListToStringArray(accessConfig); - List attributes = new ArrayList<>( - attributeTokens.length); - + String[] attributeTokens = StringUtils.commaDelimitedListToStringArray(accessConfig); + List attributes = new ArrayList<>(attributeTokens.length); for (String token : attributeTokens) { attributes.add(new SecurityConfig(token)); } - pointcutMap.put(expression, attributes); } - return pointcutMap; } - private BeanReference registerMethodSecurityInterceptor(ParserContext pc, - String authMgrRef, String accessManagerId, String runAsManagerId, - BeanReference metadataSource, - List afterInvocationProviders, Object source, - boolean useAspectJ) { - BeanDefinitionBuilder bldr = BeanDefinitionBuilder - .rootBeanDefinition(useAspectJ ? AspectJMethodSecurityInterceptor.class - : MethodSecurityInterceptor.class); + private BeanReference registerMethodSecurityInterceptor(ParserContext pc, String authMgrRef, String accessManagerId, + String runAsManagerId, BeanReference metadataSource, List afterInvocationProviders, + Object source, boolean useAspectJ) { + BeanDefinitionBuilder bldr = BeanDefinitionBuilder.rootBeanDefinition( + useAspectJ ? AspectJMethodSecurityInterceptor.class : MethodSecurityInterceptor.class); bldr.getRawBeanDefinition().setSource(source); bldr.addPropertyReference("accessDecisionManager", accessManagerId); - RootBeanDefinition authMgr = new RootBeanDefinition( - AuthenticationManagerDelegator.class); + RootBeanDefinition authMgr = new RootBeanDefinition(AuthenticationManagerDelegator.class); authMgr.getConstructorArgumentValues().addGenericArgumentValue(authMgrRef); bldr.addPropertyValue("authenticationManager", authMgr); bldr.addPropertyValue("securityMetadataSource", metadataSource); - if (StringUtils.hasText(runAsManagerId)) { bldr.addPropertyReference("runAsManager", runAsManagerId); } - if (!afterInvocationProviders.isEmpty()) { BeanDefinition afterInvocationManager; - afterInvocationManager = new RootBeanDefinition( - AfterInvocationProviderManager.class); - afterInvocationManager.getPropertyValues().addPropertyValue("providers", - afterInvocationProviders); + afterInvocationManager = new RootBeanDefinition(AfterInvocationProviderManager.class); + afterInvocationManager.getPropertyValues().addPropertyValue("providers", afterInvocationProviders); bldr.addPropertyValue("afterInvocationManager", afterInvocationManager); } - BeanDefinition bean = bldr.getBeanDefinition(); String id = pc.getReaderContext().generateBeanName(bean); pc.registerBeanComponent(new BeanComponentDefinition(bean, id)); - return new RuntimeBeanReference(id); } - private void registerAdvisor(ParserContext parserContext, BeanReference interceptor, - BeanReference metadataSource, Object source, String adviceOrder) { - if (parserContext.getRegistry().containsBeanDefinition( - BeanIds.METHOD_SECURITY_METADATA_SOURCE_ADVISOR)) { - parserContext.getReaderContext().error( - "Duplicate detected.", source); + private void registerAdvisor(ParserContext parserContext, BeanReference interceptor, BeanReference metadataSource, + Object source, String adviceOrder) { + if (parserContext.getRegistry().containsBeanDefinition(BeanIds.METHOD_SECURITY_METADATA_SOURCE_ADVISOR)) { + parserContext.getReaderContext().error("Duplicate detected.", source); } - RootBeanDefinition advisor = new RootBeanDefinition( - MethodSecurityMetadataSourceAdvisor.class); - + RootBeanDefinition advisor = new RootBeanDefinition(MethodSecurityMetadataSourceAdvisor.class); if (StringUtils.hasText(adviceOrder)) { advisor.getPropertyValues().addPropertyValue("order", adviceOrder); } - // advisor must be an infrastructure bean as Spring's // InfrastructureAdvisorAutoProxyCreator will ignore it // otherwise advisor.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); advisor.setSource(source); - advisor.getConstructorArgumentValues().addGenericArgumentValue( - interceptor.getBeanName()); + advisor.getConstructorArgumentValues().addGenericArgumentValue(interceptor.getBeanName()); advisor.getConstructorArgumentValues().addGenericArgumentValue(metadataSource); - advisor.getConstructorArgumentValues().addGenericArgumentValue( - metadataSource.getBeanName()); - - parserContext.getRegistry().registerBeanDefinition( - BeanIds.METHOD_SECURITY_METADATA_SOURCE_ADVISOR, advisor); + advisor.getConstructorArgumentValues().addGenericArgumentValue(metadataSource.getBeanName()); + parserContext.getRegistry().registerBeanDefinition(BeanIds.METHOD_SECURITY_METADATA_SOURCE_ADVISOR, advisor); } - private RootBeanDefinition registerWithDefaultRolePrefix(ParserContext pc, Class beanFactoryClass) { + private RootBeanDefinition registerWithDefaultRolePrefix(ParserContext pc, + Class beanFactoryClass) { RootBeanDefinition beanFactoryDefinition = new RootBeanDefinition(beanFactoryClass); String beanFactoryRef = pc.getReaderContext().generateBeanName(beanFactoryDefinition); pc.getRegistry().registerBeanDefinition(beanFactoryRef, beanFactoryDefinition); - RootBeanDefinition bean = new RootBeanDefinition(); bean.setFactoryBeanName(beanFactoryRef); bean.setFactoryMethodName("getBean"); @@ -495,77 +410,85 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP * @author Luke Taylor * @since 3.0 */ - static final class AuthenticationManagerDelegator implements AuthenticationManager, - BeanFactoryAware { + static final class AuthenticationManagerDelegator implements AuthenticationManager, BeanFactoryAware { + private AuthenticationManager delegate; + private final Object delegateMonitor = new Object(); + private BeanFactory beanFactory; + private final String authMgrBean; AuthenticationManagerDelegator(String authMgrBean) { - this.authMgrBean = StringUtils.hasText(authMgrBean) ? authMgrBean - : BeanIds.AUTHENTICATION_MANAGER; + this.authMgrBean = StringUtils.hasText(authMgrBean) ? authMgrBean : BeanIds.AUTHENTICATION_MANAGER; } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { - synchronized (delegateMonitor) { - if (delegate == null) { - Assert.state(beanFactory != null, - () -> "BeanFactory must be set to resolve " + authMgrBean); + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + synchronized (this.delegateMonitor) { + if (this.delegate == null) { + Assert.state(this.beanFactory != null, + () -> "BeanFactory must be set to resolve " + this.authMgrBean); try { - delegate = beanFactory.getBean(authMgrBean, - AuthenticationManager.class); + this.delegate = this.beanFactory.getBean(this.authMgrBean, AuthenticationManager.class); } - catch (NoSuchBeanDefinitionException e) { - if (BeanIds.AUTHENTICATION_MANAGER.equals(e.getBeanName())) { - throw new NoSuchBeanDefinitionException( - BeanIds.AUTHENTICATION_MANAGER, + catch (NoSuchBeanDefinitionException ex) { + if (BeanIds.AUTHENTICATION_MANAGER.equals(ex.getBeanName())) { + throw new NoSuchBeanDefinitionException(BeanIds.AUTHENTICATION_MANAGER, AuthenticationManagerFactoryBean.MISSING_BEAN_ERROR_MESSAGE); } - throw e; + throw ex; } } } - - return delegate.authenticate(authentication); + return this.delegate.authenticate(authentication); } + @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { this.beanFactory = beanFactory; } + } static class Jsr250MethodSecurityMetadataSourceBeanFactory extends AbstractGrantedAuthorityDefaultsBeanFactory { + private Jsr250MethodSecurityMetadataSource source = new Jsr250MethodSecurityMetadataSource(); - public Jsr250MethodSecurityMetadataSource getBean() { - source.setDefaultRolePrefix(this.rolePrefix); - return source; + Jsr250MethodSecurityMetadataSource getBean() { + this.source.setDefaultRolePrefix(this.rolePrefix); + return this.source; } + } static class DefaultMethodSecurityExpressionHandlerBeanFactory extends AbstractGrantedAuthorityDefaultsBeanFactory { + private DefaultMethodSecurityExpressionHandler handler = new DefaultMethodSecurityExpressionHandler(); - public DefaultMethodSecurityExpressionHandler getBean() { - handler.setDefaultRolePrefix(this.rolePrefix); - return handler; + DefaultMethodSecurityExpressionHandler getBean() { + this.handler.setDefaultRolePrefix(this.rolePrefix); + return this.handler; } + } - static abstract class AbstractGrantedAuthorityDefaultsBeanFactory implements ApplicationContextAware { + abstract static class AbstractGrantedAuthorityDefaultsBeanFactory implements ApplicationContextAware { + protected String rolePrefix = "ROLE_"; @Override - public final void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - String[] grantedAuthorityDefaultsBeanNames = applicationContext.getBeanNamesForType(GrantedAuthorityDefaults.class); + public final void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + String[] grantedAuthorityDefaultsBeanNames = applicationContext + .getBeanNamesForType(GrantedAuthorityDefaults.class); if (grantedAuthorityDefaultsBeanNames.length == 1) { - GrantedAuthorityDefaults grantedAuthorityDefaults = applicationContext.getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); + GrantedAuthorityDefaults grantedAuthorityDefaults = applicationContext + .getBean(grantedAuthorityDefaultsBeanNames[0], GrantedAuthorityDefaults.class); this.rolePrefix = grantedAuthorityDefaults.getRolePrefix(); } } + } /** @@ -575,25 +498,28 @@ public class GlobalMethodSecurityBeanDefinitionParser implements BeanDefinitionP * @author Rob Winch * @since 3.2 */ - private static final class LazyInitBeanDefinitionRegistryPostProcessor implements - BeanDefinitionRegistryPostProcessor { + private static final class LazyInitBeanDefinitionRegistryPostProcessor + implements BeanDefinitionRegistryPostProcessor { + private final String beanName; private LazyInitBeanDefinitionRegistryPostProcessor(String beanName) { this.beanName = beanName; } - public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) - throws BeansException { - if (!registry.containsBeanDefinition(beanName)) { + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (!registry.containsBeanDefinition(this.beanName)) { return; } - BeanDefinition beanDefinition = registry.getBeanDefinition(beanName); + BeanDefinition beanDefinition = registry.getBeanDefinition(this.beanName); beanDefinition.setLazyInit(true); } - public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) - throws BeansException { + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { } + } + } diff --git a/config/src/main/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecorator.java b/config/src/main/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecorator.java index f57eb1004b..ca4a87b477 100644 --- a/config/src/main/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecorator.java +++ b/config/src/main/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecorator.java @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; +import java.util.List; +import java.util.Map; + +import org.w3c.dom.Element; +import org.w3c.dom.Node; + import org.springframework.aop.config.AbstractInterceptorDrivenBeanDefinitionDecorator; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.ManagedMap; import org.springframework.beans.factory.support.RootBeanDefinition; @@ -31,10 +39,6 @@ import org.springframework.security.config.BeanIds; import org.springframework.security.config.Elements; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; -import org.w3c.dom.Node; - -import java.util.*; /** * @author Luke Taylor @@ -42,78 +46,66 @@ import java.util.*; * */ public class InterceptMethodsBeanDefinitionDecorator implements BeanDefinitionDecorator { + private final BeanDefinitionDecorator delegate = new InternalInterceptMethodsBeanDefinitionDecorator(); - public BeanDefinitionHolder decorate(Node node, BeanDefinitionHolder definition, - ParserContext parserContext) { + @Override + public BeanDefinitionHolder decorate(Node node, BeanDefinitionHolder definition, ParserContext parserContext) { MethodConfigUtils.registerDefaultMethodAccessManagerIfNecessary(parserContext); - - return delegate.decorate(node, definition, parserContext); + return this.delegate.decorate(node, definition, parserContext); } -} -/** - * This is the real class which does the work. We need access to the ParserContext in - * order to do bean registration. - */ -class InternalInterceptMethodsBeanDefinitionDecorator extends - AbstractInterceptorDrivenBeanDefinitionDecorator { - static final String ATT_METHOD = "method"; - static final String ATT_ACCESS = "access"; - private static final String ATT_ACCESS_MGR = "access-decision-manager-ref"; + /** + * This is the real class which does the work. We need access to the ParserContext in + * order to do bean registration. + */ + static class InternalInterceptMethodsBeanDefinitionDecorator + extends AbstractInterceptorDrivenBeanDefinitionDecorator { - protected BeanDefinition createInterceptorDefinition(Node node) { - Element interceptMethodsElt = (Element) node; - BeanDefinitionBuilder interceptor = BeanDefinitionBuilder - .rootBeanDefinition(MethodSecurityInterceptor.class); + static final String ATT_METHOD = "method"; - // Default to autowiring to pick up after invocation mgr - interceptor.setAutowireMode(RootBeanDefinition.AUTOWIRE_BY_TYPE); + static final String ATT_ACCESS = "access"; - String accessManagerId = interceptMethodsElt.getAttribute(ATT_ACCESS_MGR); + private static final String ATT_ACCESS_MGR = "access-decision-manager-ref"; - if (!StringUtils.hasText(accessManagerId)) { - accessManagerId = BeanIds.METHOD_ACCESS_MANAGER; - } - - interceptor.addPropertyValue("accessDecisionManager", new RuntimeBeanReference( - accessManagerId)); - interceptor.addPropertyValue("authenticationManager", new RuntimeBeanReference( - BeanIds.AUTHENTICATION_MANAGER)); - - // Lookup parent bean information - - String parentBeanClass = ((Element) node.getParentNode()).getAttribute("class"); - - // Parse the included methods - List methods = DomUtils.getChildElementsByTagName(interceptMethodsElt, - Elements.PROTECT); - Map mappings = new ManagedMap<>(); - - for (Element protectmethodElt : methods) { - BeanDefinitionBuilder attributeBuilder = BeanDefinitionBuilder - .rootBeanDefinition(SecurityConfig.class); - attributeBuilder.setFactoryMethod("createListFromCommaDelimitedString"); - attributeBuilder.addConstructorArgValue(protectmethodElt - .getAttribute(ATT_ACCESS)); - - // Support inference of class names - String methodName = protectmethodElt.getAttribute(ATT_METHOD); - - if (methodName.lastIndexOf(".") == -1) { - if (parentBeanClass != null && !"".equals(parentBeanClass)) { - methodName = parentBeanClass + "." + methodName; - } + @Override + protected BeanDefinition createInterceptorDefinition(Node node) { + Element interceptMethodsElt = (Element) node; + BeanDefinitionBuilder interceptor = BeanDefinitionBuilder + .rootBeanDefinition(MethodSecurityInterceptor.class); + // Default to autowiring to pick up after invocation mgr + interceptor.setAutowireMode(AbstractBeanDefinition.AUTOWIRE_BY_TYPE); + String accessManagerId = interceptMethodsElt.getAttribute(ATT_ACCESS_MGR); + if (!StringUtils.hasText(accessManagerId)) { + accessManagerId = BeanIds.METHOD_ACCESS_MANAGER; } - - mappings.put(methodName, attributeBuilder.getBeanDefinition()); + interceptor.addPropertyValue("accessDecisionManager", new RuntimeBeanReference(accessManagerId)); + interceptor.addPropertyValue("authenticationManager", + new RuntimeBeanReference(BeanIds.AUTHENTICATION_MANAGER)); + // Lookup parent bean information + String parentBeanClass = ((Element) node.getParentNode()).getAttribute("class"); + // Parse the included methods + List methods = DomUtils.getChildElementsByTagName(interceptMethodsElt, Elements.PROTECT); + Map mappings = new ManagedMap<>(); + for (Element protectmethodElt : methods) { + BeanDefinitionBuilder attributeBuilder = BeanDefinitionBuilder.rootBeanDefinition(SecurityConfig.class); + attributeBuilder.setFactoryMethod("createListFromCommaDelimitedString"); + attributeBuilder.addConstructorArgValue(protectmethodElt.getAttribute(ATT_ACCESS)); + // Support inference of class names + String methodName = protectmethodElt.getAttribute(ATT_METHOD); + if (methodName.lastIndexOf(".") == -1) { + if (parentBeanClass != null && !"".equals(parentBeanClass)) { + methodName = parentBeanClass + "." + methodName; + } + } + mappings.put(methodName, attributeBuilder.getBeanDefinition()); + } + BeanDefinition metadataSource = new RootBeanDefinition(MapBasedMethodSecurityMetadataSource.class); + metadataSource.getConstructorArgumentValues().addGenericArgumentValue(mappings); + interceptor.addPropertyValue("securityMetadataSource", metadataSource); + return interceptor.getBeanDefinition(); } - BeanDefinition metadataSource = new RootBeanDefinition( - MapBasedMethodSecurityMetadataSource.class); - metadataSource.getConstructorArgumentValues().addGenericArgumentValue(mappings); - interceptor.addPropertyValue("securityMetadataSource", metadataSource); - - return interceptor.getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/method/MethodConfigUtils.java b/config/src/main/java/org/springframework/security/config/method/MethodConfigUtils.java index 32ae929c32..f951d01446 100644 --- a/config/src/main/java/org/springframework/security/config/method/MethodConfigUtils.java +++ b/config/src/main/java/org/springframework/security/config/method/MethodConfigUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import org.springframework.beans.factory.support.BeanDefinitionBuilder; @@ -33,28 +34,24 @@ import org.springframework.security.config.BeanIds; * @author Rob Winch */ abstract class MethodConfigUtils { + @SuppressWarnings("unchecked") static void registerDefaultMethodAccessManagerIfNecessary(ParserContext parserContext) { - if (!parserContext.getRegistry().containsBeanDefinition( - BeanIds.METHOD_ACCESS_MANAGER)) { - parserContext.getRegistry().registerBeanDefinition( - BeanIds.METHOD_ACCESS_MANAGER, + if (!parserContext.getRegistry().containsBeanDefinition(BeanIds.METHOD_ACCESS_MANAGER)) { + parserContext.getRegistry().registerBeanDefinition(BeanIds.METHOD_ACCESS_MANAGER, createAccessManagerBean(RoleVoter.class, AuthenticatedVoter.class)); } } @SuppressWarnings("unchecked") - private static RootBeanDefinition createAccessManagerBean( - Class... voters) { + private static RootBeanDefinition createAccessManagerBean(Class... voters) { ManagedList defaultVoters = new ManagedList(voters.length); - for (Class voter : voters) { defaultVoters.add(new RootBeanDefinition(voter)); } - - BeanDefinitionBuilder accessMgrBuilder = BeanDefinitionBuilder - .rootBeanDefinition(AffirmativeBased.class); + BeanDefinitionBuilder accessMgrBuilder = BeanDefinitionBuilder.rootBeanDefinition(AffirmativeBased.class); accessMgrBuilder.addConstructorArgValue(defaultVoters); return (RootBeanDefinition) accessMgrBuilder.getBeanDefinition(); } + } diff --git a/config/src/main/java/org/springframework/security/config/method/MethodSecurityMetadataSourceBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/method/MethodSecurityMetadataSourceBeanDefinitionParser.java index 6d4f18b79b..6b0228de93 100644 --- a/config/src/main/java/org/springframework/security/config/method/MethodSecurityMetadataSourceBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/method/MethodSecurityMetadataSourceBeanDefinitionParser.java @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.w3c.dom.Element; + import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.AbstractBeanDefinitionParser; @@ -29,36 +32,29 @@ import org.springframework.security.access.method.MapBasedMethodSecurityMetadata import org.springframework.security.config.Elements; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** - * * @author Luke Taylor * @since 3.1 */ -public class MethodSecurityMetadataSourceBeanDefinitionParser extends - AbstractBeanDefinitionParser { +public class MethodSecurityMetadataSourceBeanDefinitionParser extends AbstractBeanDefinitionParser { + static final String ATT_METHOD = "method"; + static final String ATT_ACCESS = "access"; + @Override public AbstractBeanDefinition parseInternal(Element elt, ParserContext pc) { // Parse the included methods List methods = DomUtils.getChildElementsByTagName(elt, Elements.PROTECT); Map> mappings = new LinkedHashMap<>(); - for (Element protectmethodElt : methods) { - String[] tokens = StringUtils - .commaDelimitedListToStringArray(protectmethodElt - .getAttribute(ATT_ACCESS)); + String[] tokens = StringUtils.commaDelimitedListToStringArray(protectmethodElt.getAttribute(ATT_ACCESS)); String methodName = protectmethodElt.getAttribute(ATT_METHOD); - mappings.put(methodName, SecurityConfig.createList(tokens)); } - - RootBeanDefinition metadataSource = new RootBeanDefinition( - MapBasedMethodSecurityMetadataSource.class); + RootBeanDefinition metadataSource = new RootBeanDefinition(MapBasedMethodSecurityMetadataSource.class); metadataSource.getConstructorArgumentValues().addGenericArgumentValue(mappings); - return metadataSource; } diff --git a/config/src/main/java/org/springframework/security/config/method/ProtectPointcutPostProcessor.java b/config/src/main/java/org/springframework/security/config/method/ProtectPointcutPostProcessor.java index 1164a9682a..f35ecc9a88 100644 --- a/config/src/main/java/org/springframework/security/config/method/ProtectPointcutPostProcessor.java +++ b/config/src/main/java/org/springframework/security/config/method/ProtectPointcutPostProcessor.java @@ -13,16 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.lang.reflect.Method; -import java.util.*; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.aspectj.weaver.tools.PointcutExpression; import org.aspectj.weaver.tools.PointcutParser; import org.aspectj.weaver.tools.PointcutPrimitive; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.security.access.ConfigAttribute; @@ -59,21 +66,22 @@ import org.springframework.util.StringUtils; */ final class ProtectPointcutPostProcessor implements BeanPostProcessor { - private static final Log logger = LogFactory - .getLog(ProtectPointcutPostProcessor.class); + private static final Log logger = LogFactory.getLog(ProtectPointcutPostProcessor.class); private final Map> pointcutMap = new LinkedHashMap<>(); + private final MapBasedMethodSecurityMetadataSource mapBasedMethodSecurityMetadataSource; + private final Set pointCutExpressions = new LinkedHashSet<>(); + private final PointcutParser parser; + private final Set processedBeans = new HashSet<>(); - ProtectPointcutPostProcessor( - MapBasedMethodSecurityMetadataSource mapBasedMethodSecurityMetadataSource) { + ProtectPointcutPostProcessor(MapBasedMethodSecurityMetadataSource mapBasedMethodSecurityMetadataSource) { Assert.notNull(mapBasedMethodSecurityMetadataSource, "MapBasedMethodSecurityMetadataSource to populate is required"); this.mapBasedMethodSecurityMetadataSource = mapBasedMethodSecurityMetadataSource; - // Set up AspectJ pointcut expression parser Set supportedPrimitives = new HashSet<>(3); supportedPrimitives.add(PointcutPrimitive.EXECUTION); @@ -86,41 +94,33 @@ final class ProtectPointcutPostProcessor implements BeanPostProcessor { // supportedPrimitives.add(PointcutPrimitive.AT_WITHIN); // supportedPrimitives.add(PointcutPrimitive.AT_ARGS); // supportedPrimitives.add(PointcutPrimitive.AT_TARGET); - parser = PointcutParser - .getPointcutParserSupportingSpecifiedPrimitivesAndUsingContextClassloaderForResolution(supportedPrimitives); + this.parser = PointcutParser + .getPointcutParserSupportingSpecifiedPrimitivesAndUsingContextClassloaderForResolution( + supportedPrimitives); } - public Object postProcessAfterInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { return bean; } - public Object postProcessBeforeInitialization(Object bean, String beanName) - throws BeansException { - if (processedBeans.contains(beanName)) { + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + if (this.processedBeans.contains(beanName)) { // We already have the metadata for this bean return bean; } - - synchronized (processedBeans) { + synchronized (this.processedBeans) { // check again synchronized this time - if (processedBeans.contains(beanName)) { + if (this.processedBeans.contains(beanName)) { return bean; } - // Obtain methods for the present bean - Method[] methods; - try { - methods = bean.getClass().getMethods(); - } - catch (Exception e) { - throw new IllegalStateException(e.getMessage()); - } - + Method[] methods = getBeanMethods(bean); // Check to see if any of those methods are compatible with our pointcut // expressions for (Method method : methods) { - for (PointcutExpression expression : pointCutExpressions) { + for (PointcutExpression expression : this.pointCutExpressions) { // Try for the bean class directly if (attemptMatch(bean.getClass(), method, expression, beanName)) { // We've found the first expression that matches this method, so @@ -129,36 +129,34 @@ final class ProtectPointcutPostProcessor implements BeanPostProcessor { } } } - - processedBeans.add(beanName); + this.processedBeans.add(beanName); } - return bean; } - private boolean attemptMatch(Class targetClass, Method method, - PointcutExpression expression, String beanName) { + private Method[] getBeanMethods(Object bean) { + try { + return bean.getClass().getMethods(); + } + catch (Exception ex) { + throw new IllegalStateException(ex.getMessage()); + } + } + + private boolean attemptMatch(Class targetClass, Method method, PointcutExpression expression, String beanName) { // Determine if the presented AspectJ pointcut expression matches this method boolean matches = expression.matchesMethodExecution(method).alwaysMatches(); - // Handle accordingly if (matches) { - List attr = pointcutMap.get(expression - .getPointcutExpression()); - + List attr = this.pointcutMap.get(expression.getPointcutExpression()); if (logger.isDebugEnabled()) { - logger.debug("AspectJ pointcut expression '" - + expression.getPointcutExpression() + "' matches target class '" - + targetClass.getName() + "' (bean ID '" + beanName - + "') for method '" + method - + "'; registering security configuration attribute '" + attr + logger.debug("AspectJ pointcut expression '" + expression.getPointcutExpression() + + "' matches target class '" + targetClass.getName() + "' (bean ID '" + beanName + + "') for method '" + method + "'; registering security configuration attribute '" + attr + "'"); } - - mapBasedMethodSecurityMetadataSource.addSecureMethod(targetClass, method, - attr); + this.mapBasedMethodSecurityMetadataSource.addSecureMethod(targetClass, method, attr); } - return matches; } @@ -174,14 +172,12 @@ final class ProtectPointcutPostProcessor implements BeanPostProcessor { Assert.hasText(pointcutExpression, "An AspectJ pointcut expression is required"); Assert.notNull(definition, "A List of ConfigAttributes is required"); pointcutExpression = replaceBooleanOperators(pointcutExpression); - pointcutMap.put(pointcutExpression, definition); + this.pointcutMap.put(pointcutExpression, definition); // Parse the presented AspectJ pointcut expression and add it to the cache - pointCutExpressions.add(parser.parsePointcutExpression(pointcutExpression)); - + this.pointCutExpressions.add(this.parser.parsePointcutExpression(pointcutExpression)); if (logger.isDebugEnabled()) { logger.debug("AspectJ pointcut expression '" + pointcutExpression - + "' registered for security configuration attribute '" + definition - + "'"); + + "' registered for security configuration attribute '" + definition + "'"); } } diff --git a/config/src/main/java/org/springframework/security/config/method/package-info.java b/config/src/main/java/org/springframework/security/config/method/package-info.java index c985fc7e60..f393eefa82 100644 --- a/config/src/main/java/org/springframework/security/config/method/package-info.java +++ b/config/src/main/java/org/springframework/security/config/method/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Support for parsing of the <global-method-security> and <intercept-methods> elements. + * Support for parsing of the <global-method-security> and <intercept-methods> + * elements. */ package org.springframework.security.config.method; - diff --git a/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java index 976818d0ff..78fb3543a7 100644 --- a/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.oauth2.client; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.w3c.dom.Element; + import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.parsing.BeanComponentDefinition; import org.springframework.beans.factory.parsing.CompositeComponentDefinition; @@ -29,14 +39,6 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; - -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; /** * @author Ruby Hartono @@ -45,22 +47,39 @@ import java.util.Optional; public final class ClientRegistrationsBeanDefinitionParser implements BeanDefinitionParser { private static final String ELT_CLIENT_REGISTRATION = "client-registration"; + private static final String ELT_PROVIDER = "provider"; + private static final String ATT_REGISTRATION_ID = "registration-id"; + private static final String ATT_CLIENT_ID = "client-id"; + private static final String ATT_CLIENT_SECRET = "client-secret"; + private static final String ATT_CLIENT_AUTHENTICATION_METHOD = "client-authentication-method"; + private static final String ATT_AUTHORIZATION_GRANT_TYPE = "authorization-grant-type"; + private static final String ATT_REDIRECT_URI = "redirect-uri"; + private static final String ATT_SCOPE = "scope"; + private static final String ATT_CLIENT_NAME = "client-name"; + private static final String ATT_PROVIDER_ID = "provider-id"; + private static final String ATT_AUTHORIZATION_URI = "authorization-uri"; + private static final String ATT_TOKEN_URI = "token-uri"; + private static final String ATT_USER_INFO_URI = "user-info-uri"; + private static final String ATT_USER_INFO_AUTHENTICATION_METHOD = "user-info-authentication-method"; + private static final String ATT_USER_INFO_USER_NAME_ATTRIBUTE = "user-info-user-name-attribute"; + private static final String ATT_JWK_SET_URI = "jwk-set-uri"; + private static final String ATT_ISSUER_URI = "issuer-uri"; @Override @@ -68,19 +87,15 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), parserContext.extractSource(element)); parserContext.pushContainingComponent(compositeDef); - Map> providers = getProviders(element); List clientRegistrations = getClientRegistrations(element, parserContext, providers); - BeanDefinition clientRegistrationRepositoryBean = BeanDefinitionBuilder .rootBeanDefinition(InMemoryClientRegistrationRepository.class) - .addConstructorArgValue(clientRegistrations) - .getBeanDefinition(); - String clientRegistrationRepositoryId = parserContext.getReaderContext().generateBeanName( - clientRegistrationRepositoryBean); - parserContext.registerBeanComponent(new BeanComponentDefinition( - clientRegistrationRepositoryBean, clientRegistrationRepositoryId)); - + .addConstructorArgValue(clientRegistrations).getBeanDefinition(); + String clientRegistrationRepositoryId = parserContext.getReaderContext() + .generateBeanName(clientRegistrationRepositoryBean); + parserContext.registerBeanComponent( + new BeanComponentDefinition(clientRegistrationRepositoryBean, clientRegistrationRepositoryId)); parserContext.popAndRegisterContainingComponent(); return null; } @@ -89,7 +104,6 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini Map> providers) { List clientRegistrationElts = DomUtils.getChildElementsByTagName(element, ELT_CLIENT_REGISTRATION); List clientRegistrations = new ArrayList<>(); - for (Element clientRegistrationElt : clientRegistrationElts) { String registrationId = clientRegistrationElt.getAttribute(ATT_REGISTRATION_ID); String providerId = clientRegistrationElt.getAttribute(ATT_PROVIDER_ID); @@ -103,60 +117,51 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini continue; } } - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_ID)) - .ifPresent(builder::clientId); + getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_ID)).ifPresent(builder::clientId); getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET)) .ifPresent(builder::clientSecret); getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD)) - .map(ClientAuthenticationMethod::new) - .ifPresent(builder::clientAuthenticationMethod); + .map(ClientAuthenticationMethod::new).ifPresent(builder::clientAuthenticationMethod); getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE)) - .map(AuthorizationGrantType::new) - .ifPresent(builder::authorizationGrantType); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_REDIRECT_URI)) - .ifPresent(builder::redirectUri); + .map(AuthorizationGrantType::new).ifPresent(builder::authorizationGrantType); + getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_REDIRECT_URI)).ifPresent(builder::redirectUri); getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_SCOPE)) - .map(StringUtils::commaDelimitedListToSet) - .ifPresent(builder::scope); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_NAME)) - .ifPresent(builder::clientName); + .map(StringUtils::commaDelimitedListToSet).ifPresent(builder::scope); + getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_NAME)).ifPresent(builder::clientName); clientRegistrations.add(builder.build()); } - return clientRegistrations; } private Map> getProviders(Element element) { List providerElts = DomUtils.getChildElementsByTagName(element, ELT_PROVIDER); Map> providers = new HashMap<>(); - for (Element providerElt : providerElts) { Map provider = new HashMap<>(); String providerId = providerElt.getAttribute(ATT_PROVIDER_ID); provider.put(ATT_PROVIDER_ID, providerId); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_AUTHORIZATION_URI)) - .ifPresent(value -> provider.put(ATT_AUTHORIZATION_URI, value)); + .ifPresent((value) -> provider.put(ATT_AUTHORIZATION_URI, value)); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_TOKEN_URI)) - .ifPresent(value -> provider.put(ATT_TOKEN_URI, value)); + .ifPresent((value) -> provider.put(ATT_TOKEN_URI, value)); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_URI)) - .ifPresent(value -> provider.put(ATT_USER_INFO_URI, value)); + .ifPresent((value) -> provider.put(ATT_USER_INFO_URI, value)); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD)) - .ifPresent(value -> provider.put(ATT_USER_INFO_AUTHENTICATION_METHOD, value)); + .ifPresent((value) -> provider.put(ATT_USER_INFO_AUTHENTICATION_METHOD, value)); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE)) - .ifPresent(value -> provider.put(ATT_USER_INFO_USER_NAME_ATTRIBUTE, value)); + .ifPresent((value) -> provider.put(ATT_USER_INFO_USER_NAME_ATTRIBUTE, value)); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_JWK_SET_URI)) - .ifPresent(value -> provider.put(ATT_JWK_SET_URI, value)); + .ifPresent((value) -> provider.put(ATT_JWK_SET_URI, value)); getOptionalIfNotEmpty(providerElt.getAttribute(ATT_ISSUER_URI)) - .ifPresent(value -> provider.put(ATT_ISSUER_URI, value)); + .ifPresent((value) -> provider.put(ATT_ISSUER_URI, value)); providers.put(providerId, provider); } - return providers; } private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(String registrationId, String configuredProviderId, Map> providers) { - String providerId = configuredProviderId != null ? configuredProviderId : registrationId; + String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId; if (providers.containsKey(providerId)) { Map provider = providers.get(providerId); String issuer = provider.get(ATT_ISSUER_URI); @@ -176,7 +181,7 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini if (provider == null && !providers.containsKey(providerId)) { return null; } - ClientRegistration.Builder builder = provider != null ? provider.getBuilder(registrationId) + ClientRegistration.Builder builder = (provider != null) ? provider.getBuilder(registrationId) : ClientRegistration.withRegistrationId(registrationId); if (providers.containsKey(providerId)) { return getBuilder(builder, providers.get(providerId)); @@ -186,24 +191,19 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini private static ClientRegistration.Builder getBuilder(ClientRegistration.Builder builder, Map provider) { - getOptionalIfNotEmpty(provider.get(ATT_AUTHORIZATION_URI)) - .ifPresent(builder::authorizationUri); - getOptionalIfNotEmpty(provider.get(ATT_TOKEN_URI)) - .ifPresent(builder::tokenUri); - getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_URI)) - .ifPresent(builder::userInfoUri); - getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD)) - .map(AuthenticationMethod::new) + getOptionalIfNotEmpty(provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri); + getOptionalIfNotEmpty(provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri); + getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri); + getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD)).map(AuthenticationMethod::new) .ifPresent(builder::userInfoAuthenticationMethod); - getOptionalIfNotEmpty(provider.get(ATT_JWK_SET_URI)) - .ifPresent(builder::jwkSetUri); + getOptionalIfNotEmpty(provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri); getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE)) .ifPresent(builder::userNameAttributeName); return builder; } private static Optional getOptionalIfNotEmpty(String str) { - return Optional.ofNullable(str).filter(s -> !s.isEmpty()); + return Optional.ofNullable(str).filter((s) -> !s.isEmpty()); } private static CommonOAuth2Provider getCommonProvider(String providerId) { @@ -214,10 +214,12 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini } try { return CommonOAuth2Provider.valueOf(value); - } catch (Exception ex) { + } + catch (Exception ex) { return findEnum(value); } - } catch (Exception ex) { + } + catch (Exception ex) { return null; } } @@ -237,12 +239,13 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini private static String getCanonicalName(String name) { StringBuilder canonicalName = new StringBuilder(name.length()); name.chars().filter(Character::isLetterOrDigit).map(Character::toLowerCase) - .forEach(c -> canonicalName.append((char) c)); + .forEach((c) -> canonicalName.append((char) c)); return canonicalName.toString(); } private static String getErrorMessage(String configuredProviderId, String registrationId) { - return configuredProviderId != null ? "Unknown provider ID '" + configuredProviderId + "'" + return (configuredProviderId != null) ? "Unknown provider ID '" + configuredProviderId + "'" : "Provider ID must be specified for client registration '" + registrationId + "'"; } + } diff --git a/config/src/main/java/org/springframework/security/config/oauth2/client/CommonOAuth2Provider.java b/config/src/main/java/org/springframework/security/config/oauth2/client/CommonOAuth2Provider.java index 71539b8ac1..74cec2b2ca 100644 --- a/config/src/main/java/org/springframework/security/config/oauth2/client/CommonOAuth2Provider.java +++ b/config/src/main/java/org/springframework/security/config/oauth2/client/CommonOAuth2Provider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.oauth2.client; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -35,8 +36,8 @@ public enum CommonOAuth2Provider { @Override public Builder getBuilder(String registrationId) { - ClientRegistration.Builder builder = getBuilder(registrationId, - ClientAuthenticationMethod.BASIC, DEFAULT_REDIRECT_URL); + ClientRegistration.Builder builder = getBuilder(registrationId, ClientAuthenticationMethod.BASIC, + DEFAULT_REDIRECT_URL); builder.scope("openid", "profile", "email"); builder.authorizationUri("https://accounts.google.com/o/oauth2/v2/auth"); builder.tokenUri("https://www.googleapis.com/oauth2/v4/token"); @@ -47,14 +48,15 @@ public enum CommonOAuth2Provider { builder.clientName("Google"); return builder; } + }, GITHUB { @Override public Builder getBuilder(String registrationId) { - ClientRegistration.Builder builder = getBuilder(registrationId, - ClientAuthenticationMethod.BASIC, DEFAULT_REDIRECT_URL); + ClientRegistration.Builder builder = getBuilder(registrationId, ClientAuthenticationMethod.BASIC, + DEFAULT_REDIRECT_URL); builder.scope("read:user"); builder.authorizationUri("https://github.com/login/oauth/authorize"); builder.tokenUri("https://github.com/login/oauth/access_token"); @@ -63,14 +65,15 @@ public enum CommonOAuth2Provider { builder.clientName("GitHub"); return builder; } + }, FACEBOOK { @Override public Builder getBuilder(String registrationId) { - ClientRegistration.Builder builder = getBuilder(registrationId, - ClientAuthenticationMethod.POST, DEFAULT_REDIRECT_URL); + ClientRegistration.Builder builder = getBuilder(registrationId, ClientAuthenticationMethod.POST, + DEFAULT_REDIRECT_URL); builder.scope("public_profile", "email"); builder.authorizationUri("https://www.facebook.com/v2.8/dialog/oauth"); builder.tokenUri("https://graph.facebook.com/v2.8/oauth/access_token"); @@ -79,25 +82,27 @@ public enum CommonOAuth2Provider { builder.clientName("Facebook"); return builder; } + }, OKTA { @Override public Builder getBuilder(String registrationId) { - ClientRegistration.Builder builder = getBuilder(registrationId, - ClientAuthenticationMethod.BASIC, DEFAULT_REDIRECT_URL); + ClientRegistration.Builder builder = getBuilder(registrationId, ClientAuthenticationMethod.BASIC, + DEFAULT_REDIRECT_URL); builder.scope("openid", "profile", "email"); builder.userNameAttributeName(IdTokenClaimNames.SUB); builder.clientName("Okta"); return builder; } + }; private static final String DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"; - protected final ClientRegistration.Builder getBuilder(String registrationId, - ClientAuthenticationMethod method, String redirectUri) { + protected final ClientRegistration.Builder getBuilder(String registrationId, ClientAuthenticationMethod method, + String redirectUri) { ClientRegistration.Builder builder = ClientRegistration.withRegistrationId(registrationId); builder.clientAuthenticationMethod(method); builder.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE); diff --git a/config/src/main/java/org/springframework/security/config/package-info.java b/config/src/main/java/org/springframework/security/config/package-info.java index 9d9c373acf..10c9d64e61 100644 --- a/config/src/main/java/org/springframework/security/config/package-info.java +++ b/config/src/main/java/org/springframework/security/config/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Support classes for the Spring Security namespace. None of the code in these packages should be used directly - * in applications. + * Support classes for the Spring Security namespace. None of the code in these packages + * should be used directly in applications. */ package org.springframework.security.config; - diff --git a/config/src/main/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBean.java b/config/src/main/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBean.java index b6552168c2..a8c45f04d3 100644 --- a/config/src/main/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBean.java +++ b/config/src/main/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBean.java @@ -16,6 +16,8 @@ package org.springframework.security.config.provisioning; +import java.util.Collection; + import org.springframework.beans.factory.FactoryBean; import org.springframework.context.ResourceLoaderAware; import org.springframework.core.io.Resource; @@ -25,21 +27,22 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.util.InMemoryResource; -import java.util.Collection; - /** - * Constructs an {@link InMemoryUserDetailsManager} from a resource using {@link UserDetailsResourceFactoryBean}. + * Constructs an {@link InMemoryUserDetailsManager} from a resource using + * {@link UserDetailsResourceFactoryBean}. * * @author Rob Winch * @since 5.0 * @see UserDetailsResourceFactoryBean */ -public class UserDetailsManagerResourceFactoryBean implements ResourceLoaderAware, FactoryBean { +public class UserDetailsManagerResourceFactoryBean + implements ResourceLoaderAware, FactoryBean { + private UserDetailsResourceFactoryBean userDetails = new UserDetailsResourceFactoryBean(); @Override public InMemoryUserDetailsManager getObject() throws Exception { - Collection users = userDetails.getObject(); + Collection users = this.userDetails.getObject(); return new InMemoryUserDetailsManager(users); } @@ -50,21 +53,22 @@ public class UserDetailsManagerResourceFactoryBean implements ResourceLoaderAwar @Override public void setResourceLoader(ResourceLoader resourceLoader) { - userDetails.setResourceLoader(resourceLoader); + this.userDetails.setResourceLoader(resourceLoader); } /** - * Sets the location of a Resource that is a Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param resourceLocation the location of the properties file that contains the users (i.e. "classpath:users.properties") + * Sets the location of a Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. + * @param resourceLocation the location of the properties file that contains the users + * (i.e. "classpath:users.properties") */ public void setResourceLocation(String resourceLocation) { this.userDetails.setResourceLocation(resourceLocation); } /** - * Sets a Resource that is a Properties file in the format defined in {@link UserDetailsResourceFactoryBean}. - * + * Sets a Resource that is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. * @param resource the Resource to use */ public void setResource(Resource resource) { @@ -72,10 +76,11 @@ public class UserDetailsManagerResourceFactoryBean implements ResourceLoaderAwar } /** - * Create a UserDetailsManagerResourceFactoryBean with the location of a Resource that is a Properties file in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param resourceLocation the location of the properties file that contains the users (i.e. "classpath:users.properties") + * Create a UserDetailsManagerResourceFactoryBean with the location of a Resource that + * is a Properties file in the format defined in + * {@link UserDetailsResourceFactoryBean}. + * @param resourceLocation the location of the properties file that contains the users + * (i.e. "classpath:users.properties") * @return the UserDetailsManagerResourceFactoryBean */ public static UserDetailsManagerResourceFactoryBean fromResourceLocation(String resourceLocation) { @@ -85,9 +90,8 @@ public class UserDetailsManagerResourceFactoryBean implements ResourceLoaderAwar } /** - * Create a UserDetailsManagerResourceFactoryBean with a Resource that is a Properties file in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * + * Create a UserDetailsManagerResourceFactoryBean with a Resource that is a Properties + * file in the format defined in {@link UserDetailsResourceFactoryBean}. * @param resource the Resource that is a properties file that contains the users * @return the UserDetailsManagerResourceFactoryBean */ @@ -98,10 +102,10 @@ public class UserDetailsManagerResourceFactoryBean implements ResourceLoaderAwar } /** - * Create a UserDetailsManagerResourceFactoryBean with a String that is in the - * format defined in {@link UserDetailsResourceFactoryBean}. - * - * @param users the users in the format defined in {@link UserDetailsResourceFactoryBean} + * Create a UserDetailsManagerResourceFactoryBean with a String that is in the format + * defined in {@link UserDetailsResourceFactoryBean}. + * @param users the users in the format defined in + * {@link UserDetailsResourceFactoryBean} * @return the UserDetailsManagerResourceFactoryBean */ public static UserDetailsManagerResourceFactoryBean fromString(String users) { @@ -109,4 +113,5 @@ public class UserDetailsManagerResourceFactoryBean implements ResourceLoaderAwar result.setResource(new InMemoryResource(users)); return result; } + } diff --git a/config/src/main/java/org/springframework/security/config/web/server/AbstractServerWebExchangeMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/web/server/AbstractServerWebExchangeMatcherRegistry.java index fc2641b022..aa3382f1f0 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/AbstractServerWebExchangeMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/web/server/AbstractServerWebExchangeMatcherRegistry.java @@ -13,71 +13,65 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.web.server; +import java.util.List; + import org.springframework.http.HttpMethod; import org.springframework.security.web.server.util.matcher.OrServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; -import java.util.List; - - /** * @author Rob Winch * @since 5.0 */ -abstract class AbstractServerWebExchangeMatcherRegistry { +public abstract class AbstractServerWebExchangeMatcherRegistry { + + AbstractServerWebExchangeMatcherRegistry() { + } /** * Maps any request. - * - * @return the object that is chained after creating the {@link ServerWebExchangeMatcher} + * @return the object that is chained after creating the + * {@link ServerWebExchangeMatcher} */ public T anyExchange() { return matcher(ServerWebExchangeMatchers.anyExchange()); } /** - * Maps a {@link List} of - * {@link PathPatternParserServerWebExchangeMatcher} - * instances. - * - * @param method the {@link HttpMethod} to use for any - * {@link HttpMethod}. - * - * @return the object that is chained after creating the {@link ServerWebExchangeMatcher} + * Maps a {@link List} of {@link PathPatternParserServerWebExchangeMatcher} instances. + * @param method the {@link HttpMethod} to use for any {@link HttpMethod}. + * @return the object that is chained after creating the + * {@link ServerWebExchangeMatcher} */ public T pathMatchers(HttpMethod method) { return pathMatchers(method, new String[] { "/**" }); } /** - * Maps a {@link List} of - * {@link PathPatternParserServerWebExchangeMatcher} - * instances. - * + * Maps a {@link List} of {@link PathPatternParserServerWebExchangeMatcher} instances. * @param method the {@link HttpMethod} to use or {@code null} for any * {@link HttpMethod}. - * @param antPatterns the ant patterns to create. If {@code null} or empty, then matches on nothing. - * {@link PathPatternParserServerWebExchangeMatcher} from - * - * @return the object that is chained after creating the {@link ServerWebExchangeMatcher} + * @param antPatterns the ant patterns to create. If {@code null} or empty, then + * matches on nothing. {@link PathPatternParserServerWebExchangeMatcher} from + * @return the object that is chained after creating the + * {@link ServerWebExchangeMatcher} */ public T pathMatchers(HttpMethod method, String... antPatterns) { return matcher(ServerWebExchangeMatchers.pathMatchers(method, antPatterns)); } /** - * Maps a {@link List} of - * {@link PathPatternParserServerWebExchangeMatcher} - * instances that do not care which {@link HttpMethod} is used. - * + * Maps a {@link List} of {@link PathPatternParserServerWebExchangeMatcher} instances + * that do not care which {@link HttpMethod} is used. * @param antPatterns the ant patterns to create * {@link PathPatternParserServerWebExchangeMatcher} from - * - * @return the object that is chained after creating the {@link ServerWebExchangeMatcher} + * @return the object that is chained after creating the + * {@link ServerWebExchangeMatcher} */ public T pathMatchers(String... antPatterns) { return matcher(ServerWebExchangeMatchers.pathMatchers(antPatterns)); @@ -85,10 +79,9 @@ abstract class AbstractServerWebExchangeMatcherRegistry { /** * Associates a list of {@link ServerWebExchangeMatcher} instances - * * @param matchers the {@link ServerWebExchangeMatcher} instances - * - * @return the object that is chained after creating the {@link ServerWebExchangeMatcher} + * @return the object that is chained after creating the + * {@link ServerWebExchangeMatcher} */ public T matchers(ServerWebExchangeMatcher... matchers) { return registerMatcher(new OrServerWebExchangeMatcher(matchers)); @@ -97,7 +90,6 @@ abstract class AbstractServerWebExchangeMatcherRegistry { /** * Subclasses should implement this method for returning the object that is chained to * the creation of the {@link ServerWebExchangeMatcher} instances. - * * @param matcher the {@link ServerWebExchangeMatcher} instances that were created * @return the chained Object for the subclass which allows association of something * else to the {@link ServerWebExchangeMatcher} @@ -106,12 +98,12 @@ abstract class AbstractServerWebExchangeMatcherRegistry { /** * Associates a {@link ServerWebExchangeMatcher} instances - * * @param matcher the {@link ServerWebExchangeMatcher} instance - * - * @return the object that is chained after creating the {@link ServerWebExchangeMatcher} + * @return the object that is chained after creating the + * {@link ServerWebExchangeMatcher} */ private T matcher(ServerWebExchangeMatcher matcher) { return registerMatcher(matcher); } + } diff --git a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java index 4078ff5deb..101d5b15c5 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java +++ b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java @@ -13,59 +13,76 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.web.server; +package org.springframework.security.config.web.server; /** * @author Rob Winch * @since 5.0 */ public enum SecurityWebFiltersOrder { + FIRST(Integer.MIN_VALUE), + HTTP_HEADERS_WRITER, + /** * {@link org.springframework.security.web.server.transport.HttpsRedirectWebFilter} */ HTTPS_REDIRECT, + /** * {@link org.springframework.web.cors.reactive.CorsWebFilter} */ CORS, + /** * {@link org.springframework.security.web.server.csrf.CsrfWebFilter} */ CSRF, + /** * {@link org.springframework.security.web.server.context.ReactorContextWebFilter} */ REACTOR_CONTEXT, + /** * Instance of AuthenticationWebFilter */ HTTP_BASIC, + /** * Instance of AuthenticationWebFilter */ - FORM_LOGIN, - AUTHENTICATION, + FORM_LOGIN, AUTHENTICATION, + /** * Instance of AnonymousAuthenticationWebFilter */ ANONYMOUS_AUTHENTICATION, + OAUTH2_AUTHORIZATION_CODE, + LOGIN_PAGE_GENERATING, + LOGOUT_PAGE_GENERATING, + /** * {@link org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter} */ SECURITY_CONTEXT_SERVER_WEB_EXCHANGE, + /** * {@link org.springframework.security.web.server.savedrequest.ServerRequestCacheWebFilter} */ SERVER_REQUEST_CACHE, + LOGOUT, + EXCEPTION_TRANSLATION, + AUTHORIZATION, + LAST(Integer.MAX_VALUE); private static final int INTERVAL = 100; @@ -83,4 +100,5 @@ public enum SecurityWebFiltersOrder { public int getOrder() { return this.order; } + } diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 8334fce3eb..2cbbaeb231 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -31,9 +31,6 @@ import java.util.UUID; import java.util.function.Function; import java.util.function.Supplier; -import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import reactor.core.publisher.Mono; import reactor.util.context.Context; @@ -58,6 +55,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; @@ -83,6 +81,8 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; @@ -103,6 +103,7 @@ import org.springframework.security.web.PortMapper; import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint; +import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; import org.springframework.security.web.server.MatcherSecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; @@ -179,13 +180,11 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; -import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; - /** - * A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux. - * It allows configuring web based security for specific http requests. By default it will be applied - * to all requests, but can be restricted using {@link #securityMatcher(ServerWebExchangeMatcher)} or - * other similar methods. + * A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but + * for WebFlux. It allows configuring web based security for specific http requests. By + * default it will be applied to all requests, but can be restricted using + * {@link #securityMatcher(ServerWebExchangeMatcher)} or other similar methods. * * A minimal configuration can be found below: * @@ -195,14 +194,15 @@ import static org.springframework.security.web.server.DelegatingServerAuthentica * * @Bean * public MapReactiveUserDetailsService userDetailsService() { - * UserDetails user = User.withDefaultPasswordEncoder() - * .username("user") - * .password("password") - * .roles("USER") - * .build(); - * return new MapReactiveUserDetailsService(user); + * UserDetails user = User.withDefaultPasswordEncoder() + * .username("user") + * .password("password") + * .roles("USER") + * .build(); + * return new MapReactiveUserDetailsService(user); * } * } + * * * Below is the same as our minimal configuration, but explicitly declaring the * {@code ServerHttpSecurity}. @@ -210,27 +210,29 @@ import static org.springframework.security.web.server.DelegatingServerAuthentica *
      * @EnableWebFluxSecurity
      * public class MyExplicitSecurityConfiguration {
    + *
      *     @Bean
      *     public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
    - *          http
    - *               .authorizeExchange()
    - *                    .anyExchange().authenticated()
    - *                         .and()
    - *                    .httpBasic().and()
    - *                    .formLogin();
    - *          return http.build();
    + *         http
    + *             .authorizeExchange()
    + *               .anyExchange().authenticated()
    + *             .and()
    + *               .httpBasic().and()
    + *               .formLogin();
    + *             return http.build();
      *     }
      *
      *     @Bean
      *     public MapReactiveUserDetailsService userDetailsService() {
    - *          UserDetails user = User.withDefaultPasswordEncoder()
    - *               .username("user")
    - *               .password("password")
    - *               .roles("USER")
    - *               .build();
    - *          return new MapReactiveUserDetailsService(user);
    + *         UserDetails user = User.withDefaultPasswordEncoder()
    + *             .username("user")
    + *             .password("password")
    + *             .roles("USER")
    + *             .build();
    + *         return new MapReactiveUserDetailsService(user);
      *     }
      * }
    + * 
    * * @author Rob Winch * @author Vedran Pavic @@ -241,6 +243,7 @@ import static org.springframework.security.web.server.DelegatingServerAuthentica * @since 5.0 */ public class ServerHttpSecurity { + private ServerWebExchangeMatcher securityMatcher = ServerWebExchangeMatchers.anyExchange(); private AuthorizeExchangeSpec authorizeExchange; @@ -283,8 +286,7 @@ public class ServerHttpSecurity { private ServerAccessDeniedHandler accessDeniedHandler; - private List - defaultAccessDeniedHandlers = new ArrayList<>(); + private List defaultAccessDeniedHandlers = new ArrayList<>(); private List webFilters = new ArrayList<>(); @@ -294,11 +296,14 @@ public class ServerHttpSecurity { private AnonymousSpec anonymous; + protected ServerHttpSecurity() { + } + /** - * The ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. - * - * @param matcher the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. - * Default is all requests. + * The ServerExchangeMatcher that determines which requests apply to this HttpSecurity + * instance. + * @param matcher the ServerExchangeMatcher that determines which requests apply to + * this HttpSecurity instance. Default is all requests. * @return the {@link ServerHttpSecurity} to continue configuring */ public ServerHttpSecurity securityMatcher(ServerWebExchangeMatcher matcher) { @@ -345,16 +350,19 @@ public class ServerHttpSecurity { } /** - * Gets the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. - * @return the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance. + * Gets the ServerExchangeMatcher that determines which requests apply to this + * HttpSecurity instance. + * @return the ServerExchangeMatcher that determines which requests apply to this + * HttpSecurity instance. */ private ServerWebExchangeMatcher getSecurityMatcher() { return this.securityMatcher; } /** - * The strategy used with {@code ReactorContextWebFilter}. It does impact how the {@code SecurityContext} is - * saved which is configured on a per {@link AuthenticationWebFilter} basis. + * The strategy used with {@code ReactorContextWebFilter}. It does impact how the + * {@code SecurityContext} is saved which is configured on a per + * {@link AuthenticationWebFilter} basis. * @param securityContextRepository the repository to use * @return the {@link ServerHttpSecurity} to continue configuring */ @@ -379,7 +387,8 @@ public class ServerHttpSecurity { * * Then all non-HTTPS requests will be redirected to HTTPS. * - * Typically, all requests should be HTTPS; however, the focus for redirection can also be narrowed: + * Typically, all requests should be HTTPS; however, the focus for redirection can + * also be narrowed: * *
     	 *  @Bean
    @@ -387,12 +396,11 @@ public class ServerHttpSecurity {
     	 * 	    http
     	 * 	        // ...
     	 * 	        .redirectToHttps()
    -	 * 	            .httpsRedirectWhen(serverWebExchange ->
    +	 * 	            .httpsRedirectWhen((serverWebExchange) ->
     	 * 	            	serverWebExchange.getRequest().getHeaders().containsKey("X-Requires-Https"))
     	 * 	    return http.build();
     	 * 	}
     	 * 
    - * * @return the {@link HttpsRedirectSpec} to customize */ public HttpsRedirectSpec redirectToHttps() { @@ -415,35 +423,36 @@ public class ServerHttpSecurity { * * Then all non-HTTPS requests will be redirected to HTTPS. * - * Typically, all requests should be HTTPS; however, the focus for redirection can also be narrowed: + * Typically, all requests should be HTTPS; however, the focus for redirection can + * also be narrowed: * *
     	 *  @Bean
     	 * 	public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 * 	    http
     	 * 	        // ...
    -	 * 	        .redirectToHttps(redirectToHttps ->
    +	 * 	        .redirectToHttps((redirectToHttps) ->
     	 * 	        	redirectToHttps
    -	 * 	            	.httpsRedirectWhen(serverWebExchange ->
    +	 * 	            	.httpsRedirectWhen((serverWebExchange) ->
     	 * 	            		serverWebExchange.getRequest().getHeaders().containsKey("X-Requires-Https"))
     	 * 	            );
     	 * 	    return http.build();
     	 * 	}
     	 * 
    - * * @param httpsRedirectCustomizer the {@link Customizer} to provide more options for * the {@link HttpsRedirectSpec} * @return the {@link ServerHttpSecurity} to customize */ - public ServerHttpSecurity redirectToHttps(Customizer httpsRedirectCustomizer) { + public ServerHttpSecurity redirectToHttps(Customizer httpsRedirectCustomizer) { this.httpsRedirectSpec = new HttpsRedirectSpec(); httpsRedirectCustomizer.customize(this.httpsRedirectSpec); return this; } /** - * Configures CSRF Protection - * which is enabled by default. You can disable it using: + * Configures CSRF + * Protection which is enabled by default. You can disable it using: * *
     	 *  @Bean
    @@ -473,7 +482,6 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * * @return the {@link CsrfSpec} to customize */ public CsrfSpec csrf() { @@ -484,15 +492,16 @@ public class ServerHttpSecurity { } /** - * Configures CSRF Protection - * which is enabled by default. You can disable it using: + * Configures CSRF + * Protection which is enabled by default. You can disable it using: * *
     	 *  @Bean
     	 *  public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 *      http
     	 *          // ...
    -	 *          .csrf(csrf ->
    +	 *          .csrf((csrf) ->
     	 *              csrf.disabled()
     	 *          );
     	 *      return http.build();
    @@ -507,7 +516,7 @@ public class ServerHttpSecurity {
     	 *  public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 *      http
     	 *          // ...
    -	 *          .csrf(csrf ->
    +	 *          .csrf((csrf) ->
     	 *              csrf
     	 *                  // Handle CSRF failures
     	 *                  .accessDeniedHandler(accessDeniedHandler)
    @@ -519,9 +528,8 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * - * @param csrfCustomizer the {@link Customizer} to provide more options for - * the {@link CsrfSpec} + * @param csrfCustomizer the {@link Customizer} to provide more options for the + * {@link CsrfSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity csrf(Customizer csrfCustomizer) { @@ -533,9 +541,11 @@ public class ServerHttpSecurity { } /** - * Configures CORS headers. By default if a {@link CorsConfigurationSource} Bean is found, it will be used - * to create a {@link CorsWebFilter}. If {@link CorsSpec#configurationSource(CorsConfigurationSource)} is invoked - * it will be used instead. If neither has been configured, the Cors configuration will do nothing. + * Configures CORS headers. By default if a {@link CorsConfigurationSource} Bean is + * found, it will be used to create a {@link CorsWebFilter}. If + * {@link CorsSpec#configurationSource(CorsConfigurationSource)} is invoked it will be + * used instead. If neither has been configured, the Cors configuration will do + * nothing. * @return the {@link CorsSpec} to customize */ public CorsSpec cors() { @@ -546,12 +556,13 @@ public class ServerHttpSecurity { } /** - * Configures CORS headers. By default if a {@link CorsConfigurationSource} Bean is found, it will be used - * to create a {@link CorsWebFilter}. If {@link CorsSpec#configurationSource(CorsConfigurationSource)} is invoked - * it will be used instead. If neither has been configured, the Cors configuration will do nothing. - * - * @param corsCustomizer the {@link Customizer} to provide more options for - * the {@link CorsSpec} + * Configures CORS headers. By default if a {@link CorsConfigurationSource} Bean is + * found, it will be used to create a {@link CorsWebFilter}. If + * {@link CorsSpec#configurationSource(CorsConfigurationSource)} is invoked it will be + * used instead. If neither has been configured, the Cors configuration will do + * nothing. + * @param corsCustomizer the {@link Customizer} to provide more options for the + * {@link CorsSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity cors(Customizer corsCustomizer) { @@ -563,7 +574,8 @@ public class ServerHttpSecurity { } /** - * Enables and Configures anonymous authentication. Anonymous Authentication is disabled by default. + * Enables and Configures anonymous authentication. Anonymous Authentication is + * disabled by default. * *
     	 *  @Bean
    @@ -579,7 +591,7 @@ public class ServerHttpSecurity {
     	 * @since 5.2.0
     	 * @author Ankur Pathak
     	 */
    -	public AnonymousSpec anonymous(){
    +	public AnonymousSpec anonymous() {
     		if (this.anonymous == null) {
     			this.anonymous = new AnonymousSpec();
     		}
    @@ -587,14 +599,15 @@ public class ServerHttpSecurity {
     	}
     
     	/**
    -	 * Enables and Configures anonymous authentication. Anonymous Authentication is disabled by default.
    +	 * Enables and Configures anonymous authentication. Anonymous Authentication is
    +	 * disabled by default.
     	 *
     	 * 
     	 *  @Bean
     	 *  public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 *      http
     	 *          // ...
    -	 *          .anonymous(anonymous ->
    +	 *          .anonymous((anonymous) ->
     	 *              anonymous
     	 *                  .key("key")
     	 *                  .authorities("ROLE_ANONYMOUS")
    @@ -602,9 +615,8 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * - * @param anonymousCustomizer the {@link Customizer} to provide more options for - * the {@link AnonymousSpec} + * @param anonymousCustomizer the {@link Customizer} to provide more options for the + * {@link AnonymousSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity anonymous(Customizer anonymousCustomizer) { @@ -615,67 +627,6 @@ public class ServerHttpSecurity { return this; } - /** - * Configures CORS support within Spring Security. This ensures that the {@link CorsWebFilter} is place in the - * correct order. - */ - public class CorsSpec { - private CorsWebFilter corsFilter; - - /** - * Configures the {@link CorsConfigurationSource} to be used - * @param source the source to use - * @return the {@link CorsSpec} for additional configuration - */ - public CorsSpec configurationSource(CorsConfigurationSource source) { - this.corsFilter = new CorsWebFilter(source); - return this; - } - - /** - * Disables CORS support within Spring Security. - * @return the {@link ServerHttpSecurity} to continue configuring - */ - public ServerHttpSecurity disable() { - ServerHttpSecurity.this.cors = null; - return ServerHttpSecurity.this; - } - - /** - * Allows method chaining to continue configuring the {@link ServerHttpSecurity} - * @return the {@link ServerHttpSecurity} to continue configuring - */ - public ServerHttpSecurity and() { - return ServerHttpSecurity.this; - } - - protected void configure(ServerHttpSecurity http) { - CorsWebFilter corsFilter = getCorsFilter(); - if (corsFilter != null) { - http.addFilterAt(this.corsFilter, SecurityWebFiltersOrder.CORS); - } - } - - private CorsWebFilter getCorsFilter() { - if (this.corsFilter != null) { - return this.corsFilter; - } - - CorsConfigurationSource source = getBeanOrNull(CorsConfigurationSource.class); - if (source == null) { - return null; - } - CorsProcessor processor = getBeanOrNull(CorsProcessor.class); - if (processor == null) { - processor = new DefaultCorsProcessor(); - } - this.corsFilter = new CorsWebFilter(source, processor); - return this.corsFilter; - } - - private CorsSpec() {} - } - /** * Configures HTTP Basic authentication. An example configuration is provided below: * @@ -692,7 +643,6 @@ public class ServerHttpSecurity { * return http.build(); * } *
    - * * @return the {@link HttpBasicSpec} to customize */ public HttpBasicSpec httpBasic() { @@ -710,7 +660,7 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .httpBasic(httpBasic -> + * .httpBasic((httpBasic) -> * httpBasic * // used for authenticating the credentials * .authenticationManager(authenticationManager) @@ -720,9 +670,8 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * @param httpBasicCustomizer the {@link Customizer} to provide more options for - * the {@link HttpBasicSpec} + * @param httpBasicCustomizer the {@link Customizer} to provide more options for the + * {@link HttpBasicSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity httpBasic(Customizer httpBasicCustomizer) { @@ -753,7 +702,6 @@ public class ServerHttpSecurity { * return http.build(); * } * - * * @return the {@link FormLoginSpec} to customize */ public FormLoginSpec formLogin() { @@ -771,7 +719,7 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .formLogin(formLogin -> + * .formLogin((formLogin) -> * formLogin * // used for authenticating the credentials * .authenticationManager(authenticationManager) @@ -785,9 +733,8 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * @param formLoginCustomizer the {@link Customizer} to provide more options for - * the {@link FormLoginSpec} + * @param formLoginCustomizer the {@link Customizer} to provide more options for the + * {@link FormLoginSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity formLogin(Customizer formLoginCustomizer) { @@ -812,9 +759,9 @@ public class ServerHttpSecurity { * } * * - * Note that if extractor is not specified, {@link SubjectDnX509PrincipalExtractor} will be used. - * If authenticationManager is not specified, {@link ReactivePreAuthenticatedAuthenticationManager} will be used. - * + * Note that if extractor is not specified, {@link SubjectDnX509PrincipalExtractor} + * will be used. If authenticationManager is not specified, + * {@link ReactivePreAuthenticatedAuthenticationManager} will be used. * @return the {@link X509Spec} to customize * @author Alexey Nesterov * @since 5.2 @@ -834,7 +781,7 @@ public class ServerHttpSecurity { * @Bean * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http - * .x509(x509 -> + * .x509((x509) -> * x509 * .authenticationManager(authenticationManager) * .principalExtractor(principalExtractor) @@ -843,13 +790,13 @@ public class ServerHttpSecurity { * } * * - * Note that if extractor is not specified, {@link SubjectDnX509PrincipalExtractor} will be used. - * If authenticationManager is not specified, {@link ReactivePreAuthenticatedAuthenticationManager} will be used. - * - * @since 5.2 - * @param x509Customizer the {@link Customizer} to provide more options for - * the {@link X509Spec} + * Note that if extractor is not specified, {@link SubjectDnX509PrincipalExtractor} + * will be used. If authenticationManager is not specified, + * {@link ReactivePreAuthenticatedAuthenticationManager} will be used. + * @param x509Customizer the {@link Customizer} to provide more options for the + * {@link X509Spec} * @return the {@link ServerHttpSecurity} to customize + * @since 5.2 */ public ServerHttpSecurity x509(Customizer x509Customizer) { if (this.x509 == null) { @@ -860,65 +807,8 @@ public class ServerHttpSecurity { } /** - * Configures X509 authentication - * - * @author Alexey Nesterov - * @since 5.2 - * @see #x509() - */ - public class X509Spec { - - private X509PrincipalExtractor principalExtractor; - private ReactiveAuthenticationManager authenticationManager; - - public X509Spec principalExtractor(X509PrincipalExtractor principalExtractor) { - this.principalExtractor = principalExtractor; - return this; - } - - public X509Spec authenticationManager(ReactiveAuthenticationManager authenticationManager) { - this.authenticationManager = authenticationManager; - return this; - } - - public ServerHttpSecurity and() { - return ServerHttpSecurity.this; - } - - protected void configure(ServerHttpSecurity http) { - ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); - X509PrincipalExtractor principalExtractor = getPrincipalExtractor(); - - AuthenticationWebFilter filter = new AuthenticationWebFilter(authenticationManager); - filter.setServerAuthenticationConverter(new ServerX509AuthenticationConverter(principalExtractor)); - http.addFilterAt(filter, SecurityWebFiltersOrder.AUTHENTICATION); - } - - private X509PrincipalExtractor getPrincipalExtractor() { - if (this.principalExtractor != null) { - return this.principalExtractor; - } - - return new SubjectDnX509PrincipalExtractor(); - } - - private ReactiveAuthenticationManager getAuthenticationManager() { - if (this.authenticationManager != null) { - return this.authenticationManager; - } - - ReactiveUserDetailsService userDetailsService = getBean(ReactiveUserDetailsService.class); - ReactivePreAuthenticatedAuthenticationManager authenticationManager = new ReactivePreAuthenticatedAuthenticationManager(userDetailsService); - - return authenticationManager; - } - - private X509Spec() { - } - } - - /** - * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 Provider. + * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 + * Provider. * *
     	 *  @Bean
    @@ -931,8 +821,6 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * - * * @return the {@link OAuth2LoginSpec} to customize */ public OAuth2LoginSpec oauth2Login() { @@ -943,14 +831,15 @@ public class ServerHttpSecurity { } /** - * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 Provider. + * Configures authentication support using an OAuth 2.0 and/or OpenID Connect 1.0 + * Provider. * *
     	 *  @Bean
     	 *  public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 *      http
     	 *          // ...
    -	 *          .oauth2Login(oauth2Login ->
    +	 *          .oauth2Login((oauth2Login) ->
     	 *              oauth2Login
     	 *                  .authenticationConverter(authenticationConverter)
     	 *                  .authenticationManager(manager)
    @@ -958,9 +847,8 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * - * @param oauth2LoginCustomizer the {@link Customizer} to provide more options for - * the {@link OAuth2LoginSpec} + * @param oauth2LoginCustomizer the {@link Customizer} to provide more options for the + * {@link OAuth2LoginSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity oauth2Login(Customizer oauth2LoginCustomizer) { @@ -971,388 +859,6 @@ public class ServerHttpSecurity { return this; } - public class OAuth2LoginSpec { - private ReactiveClientRegistrationRepository clientRegistrationRepository; - - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - - private ServerAuthorizationRequestRepository authorizationRequestRepository; - - private ReactiveAuthenticationManager authenticationManager; - - private ServerSecurityContextRepository securityContextRepository; - - private ServerAuthenticationConverter authenticationConverter; - - private ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; - - private ServerWebExchangeMatcher authenticationMatcher; - - private ServerAuthenticationSuccessHandler authenticationSuccessHandler; - - private ServerAuthenticationFailureHandler authenticationFailureHandler; - - /** - * Configures the {@link ReactiveAuthenticationManager} to use. The default is - * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} - * @param authenticationManager the manager to use - * @return the {@link OAuth2LoginSpec} to customize - */ - public OAuth2LoginSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { - this.authenticationManager = authenticationManager; - return this; - } - - /** - * The {@link ServerSecurityContextRepository} used to save the {@code Authentication}. Defaults to - * {@link WebSessionServerSecurityContextRepository}. - * - * @since 5.2 - * @param securityContextRepository the repository to use - * @return the {@link OAuth2LoginSpec} to continue configuring - */ - public OAuth2LoginSpec securityContextRepository(ServerSecurityContextRepository securityContextRepository) { - this.securityContextRepository = securityContextRepository; - return this; - } - - /** - * The {@link ServerAuthenticationSuccessHandler} used after authentication success. Defaults to - * {@link RedirectServerAuthenticationSuccessHandler} redirecting to "/". - * - * @since 5.2 - * @param authenticationSuccessHandler the success handler to use - * @return the {@link OAuth2LoginSpec} to customize - */ - public OAuth2LoginSpec authenticationSuccessHandler(ServerAuthenticationSuccessHandler authenticationSuccessHandler) { - Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); - this.authenticationSuccessHandler = authenticationSuccessHandler; - return this; - } - - /** - * The {@link ServerAuthenticationFailureHandler} used after authentication failure. - * Defaults to {@link RedirectServerAuthenticationFailureHandler} redirecting to "/login?error". - * - * @since 5.2 - * @param authenticationFailureHandler the failure handler to use - * @return the {@link OAuth2LoginSpec} to customize - */ - public OAuth2LoginSpec authenticationFailureHandler(ServerAuthenticationFailureHandler authenticationFailureHandler) { - Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); - this.authenticationFailureHandler = authenticationFailureHandler; - return this; - } - - /** - * Gets the {@link ReactiveAuthenticationManager} to use. First tries an explicitly configured manager, and - * defaults to {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} - * - * @return the {@link ReactiveAuthenticationManager} to use - */ - private ReactiveAuthenticationManager getAuthenticationManager() { - if (this.authenticationManager == null) { - this.authenticationManager = createDefault(); - } - return this.authenticationManager; - } - - private ReactiveAuthenticationManager createDefault() { - ReactiveOAuth2AccessTokenResponseClient client = getAccessTokenResponseClient(); - OAuth2LoginReactiveAuthenticationManager oauth2Manager = new OAuth2LoginReactiveAuthenticationManager(client, getOauth2UserService()); - GrantedAuthoritiesMapper authoritiesMapper = getBeanOrNull(GrantedAuthoritiesMapper.class); - if (authoritiesMapper != null) { - oauth2Manager.setAuthoritiesMapper(authoritiesMapper); - } - boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent( - "org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); - if (oidcAuthenticationProviderEnabled) { - OidcAuthorizationCodeReactiveAuthenticationManager oidc = - new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService()); - ResolvableType type = ResolvableType.forClassWithGenerics( - ReactiveJwtDecoderFactory.class, ClientRegistration.class); - ReactiveJwtDecoderFactory jwtDecoderFactory = getBeanOrNull(type); - if (jwtDecoderFactory != null) { - oidc.setJwtDecoderFactory(jwtDecoderFactory); - } - if (authoritiesMapper != null) { - oidc.setAuthoritiesMapper(authoritiesMapper); - } - return new DelegatingReactiveAuthenticationManager(oidc, oauth2Manager); - } - return oauth2Manager; - } - - /** - * Sets the converter to use - * @param authenticationConverter the converter to use - * @return the {@link OAuth2LoginSpec} to customize - */ - public OAuth2LoginSpec authenticationConverter(ServerAuthenticationConverter authenticationConverter) { - this.authenticationConverter = authenticationConverter; - return this; - } - - private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) { - if (this.authenticationConverter == null) { - ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate = - new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); - delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); - ServerAuthenticationConverter authenticationConverter = exchange -> - delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class, - e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())); - this.authenticationConverter = authenticationConverter; - return authenticationConverter; - } - return this.authenticationConverter; - } - - public OAuth2LoginSpec clientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; - return this; - } - - public OAuth2LoginSpec authorizedClientService(ReactiveOAuth2AuthorizedClientService authorizedClientService) { - this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(authorizedClientService); - return this; - } - - public OAuth2LoginSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this.authorizedClientRepository = authorizedClientRepository; - return this; - } - - /** - * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s. - * - * @since 5.2 - * @param authorizationRequestRepository the repository to use for storing {@link OAuth2AuthorizationRequest}'s - * @return the {@link OAuth2LoginSpec} for further configuration - */ - public OAuth2LoginSpec authorizationRequestRepository( - ServerAuthorizationRequestRepository authorizationRequestRepository) { - this.authorizationRequestRepository = authorizationRequestRepository; - return this; - } - - /** - * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s. - * - * @since 5.2 - * @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s - * @return the {@link OAuth2LoginSpec} for further configuration - */ - public OAuth2LoginSpec authorizationRequestResolver(ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver) { - this.authorizationRequestResolver = authorizationRequestResolver; - return this; - } - - /** - * Sets the {@link ServerWebExchangeMatcher matcher} used for determining if the request is an authentication request. - * - * @since 5.2 - * @param authenticationMatcher the {@link ServerWebExchangeMatcher matcher} used for determining if the request is an authentication request - * @return the {@link OAuth2LoginSpec} for further configuration - */ - public OAuth2LoginSpec authenticationMatcher(ServerWebExchangeMatcher authenticationMatcher) { - this.authenticationMatcher = authenticationMatcher; - return this; - } - - private ServerWebExchangeMatcher getAuthenticationMatcher() { - if (this.authenticationMatcher == null) { - this.authenticationMatcher = createAttemptAuthenticationRequestMatcher(); - } - return this.authenticationMatcher; - } - - /** - * Allows method chaining to continue configuring the {@link ServerHttpSecurity} - * @return the {@link ServerHttpSecurity} to continue configuring - */ - public ServerHttpSecurity and() { - return ServerHttpSecurity.this; - } - - - protected void configure(ServerHttpSecurity http) { - ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); - OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter(); - ServerAuthorizationRequestRepository authorizationRequestRepository = - getAuthorizationRequestRepository(); - oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); - oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); - - ReactiveAuthenticationManager manager = getAuthenticationManager(); - - AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository); - authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher()); - authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository)); - - authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http)); - authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); - authenticationFilter.setSecurityContextRepository(this.securityContextRepository); - - setDefaultEntryPoints(http); - - http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); - http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); - } - - private void setDefaultEntryPoints(ServerHttpSecurity http) { - String defaultLoginPage = "/login"; - Map urlToText = http.oauth2Login.getLinks(); - String providerLoginPage = null; - if (urlToText.size() == 1) { - providerLoginPage = urlToText.keySet().iterator().next(); - } - - MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( - MediaType.APPLICATION_XHTML_XML, new MediaType("image", "*"), - MediaType.TEXT_HTML, MediaType.TEXT_PLAIN); - htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); - - ServerWebExchangeMatcher xhrMatcher = exchange -> { - if (exchange.getRequest().getHeaders().getOrEmpty("X-Requested-With").contains("XMLHttpRequest")) { - return ServerWebExchangeMatcher.MatchResult.match(); - } - return ServerWebExchangeMatcher.MatchResult.notMatch(); - }; - ServerWebExchangeMatcher notXhrMatcher = new NegatedServerWebExchangeMatcher(xhrMatcher); - - ServerWebExchangeMatcher defaultEntryPointMatcher = new AndServerWebExchangeMatcher( - notXhrMatcher, htmlMatcher); - - if (providerLoginPage != null) { - ServerWebExchangeMatcher loginPageMatcher = new PathPatternParserServerWebExchangeMatcher(defaultLoginPage); - ServerWebExchangeMatcher faviconMatcher = new PathPatternParserServerWebExchangeMatcher("/favicon.ico"); - ServerWebExchangeMatcher defaultLoginPageMatcher = new AndServerWebExchangeMatcher( - new OrServerWebExchangeMatcher(loginPageMatcher, faviconMatcher), defaultEntryPointMatcher); - - ServerWebExchangeMatcher matcher = new AndServerWebExchangeMatcher( - notXhrMatcher, new NegatedServerWebExchangeMatcher(defaultLoginPageMatcher)); - RedirectServerAuthenticationEntryPoint entryPoint = - new RedirectServerAuthenticationEntryPoint(providerLoginPage); - entryPoint.setRequestCache(http.requestCache.requestCache); - http.defaultEntryPoints.add(new DelegateEntry(matcher, entryPoint)); - } - - RedirectServerAuthenticationEntryPoint defaultEntryPoint = - new RedirectServerAuthenticationEntryPoint(defaultLoginPage); - defaultEntryPoint.setRequestCache(http.requestCache.requestCache); - http.defaultEntryPoints.add(new DelegateEntry(defaultEntryPointMatcher, defaultEntryPoint)); - } - - private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) { - if (this.authenticationSuccessHandler == null) { - RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler(); - handler.setRequestCache(http.requestCache.requestCache); - this.authenticationSuccessHandler = handler; - } - return this.authenticationSuccessHandler; - } - - private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() { - if (this.authenticationFailureHandler == null) { - this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error"); - } - return this.authenticationFailureHandler; - } - - private ServerWebExchangeMatcher createAttemptAuthenticationRequestMatcher() { - return new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"); - } - - private ReactiveOAuth2UserService getOidcUserService() { - ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2UserService.class, OidcUserRequest.class, OidcUser.class); - ReactiveOAuth2UserService bean = getBeanOrNull(type); - if (bean == null) { - return new OidcReactiveOAuth2UserService(); - } - - return bean; - } - - private ReactiveOAuth2UserService getOauth2UserService() { - ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2UserService.class, OAuth2UserRequest.class, OAuth2User.class); - ReactiveOAuth2UserService bean = getBeanOrNull(type); - if (bean == null) { - return new DefaultReactiveOAuth2UserService(); - } - - return bean; - } - - private Map getLinks() { - Iterable registrations = getBeanOrNull(ResolvableType.forClassWithGenerics(Iterable.class, ClientRegistration.class)); - if (registrations == null) { - return Collections.emptyMap(); - } - Map result = new HashMap<>(); - registrations.iterator().forEachRemaining(r -> result.put("/oauth2/authorization/" + r.getRegistrationId(), r.getClientName())); - return result; - } - - private ReactiveOAuth2AccessTokenResponseClient getAccessTokenResponseClient() { - ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, OAuth2AuthorizationCodeGrantRequest.class); - ReactiveOAuth2AccessTokenResponseClient bean = getBeanOrNull(type); - if (bean == null) { - return new WebClientReactiveAuthorizationCodeTokenResponseClient(); - } - return bean; - } - - private ReactiveClientRegistrationRepository getClientRegistrationRepository() { - if (this.clientRegistrationRepository == null) { - this.clientRegistrationRepository = getBeanOrNull(ReactiveClientRegistrationRepository.class); - } - return this.clientRegistrationRepository; - } - - private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() { - OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter; - if (this.authorizationRequestResolver == null) { - oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository()); - } else { - oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver); - } - return oauthRedirectFilter; - } - - private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { - ServerOAuth2AuthorizedClientRepository result = this.authorizedClientRepository; - if (result == null) { - result = getBeanOrNull(ServerOAuth2AuthorizedClientRepository.class); - } - if (result == null) { - ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); - if (authorizedClientService != null) { - result = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( - authorizedClientService); - } - } - return result; - } - - private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { - if (this.authorizationRequestRepository == null) { - this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); - } - return this.authorizationRequestRepository; - } - - private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { - ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); - if (service == null) { - service = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); - } - return service; - } - - private OAuth2LoginSpec() {} - } - /** * Configures the OAuth2 client. * @@ -1367,8 +873,6 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * * @return the {@link OAuth2ClientSpec} to customize */ public OAuth2ClientSpec oauth2Client() { @@ -1386,7 +890,7 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .oauth2Client(oauth2Client -> + * .oauth2Client((oauth2Client) -> * oauth2Client * .clientRegistrationRepository(clientRegistrationRepository) * .authorizedClientRepository(authorizedClientRepository) @@ -1394,7 +898,6 @@ public class ServerHttpSecurity { * return http.build(); * } * - * * @param oauth2ClientCustomizer the {@link Customizer} to provide more options for * the {@link OAuth2ClientSpec} * @return the {@link ServerHttpSecurity} to customize @@ -1407,165 +910,6 @@ public class ServerHttpSecurity { return this; } - public class OAuth2ClientSpec { - private ReactiveClientRegistrationRepository clientRegistrationRepository; - - private ServerAuthenticationConverter authenticationConverter; - - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - - private ReactiveAuthenticationManager authenticationManager; - - private ServerAuthorizationRequestRepository authorizationRequestRepository; - - /** - * Sets the converter to use - * @param authenticationConverter the converter to use - * @return the {@link OAuth2ClientSpec} to customize - */ - public OAuth2ClientSpec authenticationConverter(ServerAuthenticationConverter authenticationConverter) { - this.authenticationConverter = authenticationConverter; - return this; - } - - private ServerAuthenticationConverter getAuthenticationConverter() { - if (this.authenticationConverter == null) { - ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = - new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(getClientRegistrationRepository()); - authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); - this.authenticationConverter = authenticationConverter; - } - return this.authenticationConverter; - } - - /** - * Configures the {@link ReactiveAuthenticationManager} to use. The default is - * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} - * @param authenticationManager the manager to use - * @return the {@link OAuth2ClientSpec} to customize - */ - public OAuth2ClientSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { - this.authenticationManager = authenticationManager; - return this; - } - - /** - * Gets the {@link ReactiveAuthenticationManager} to use. First tries an explicitly configured manager, and - * defaults to {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} - * - * @return the {@link ReactiveAuthenticationManager} to use - */ - private ReactiveAuthenticationManager getAuthenticationManager() { - if (this.authenticationManager == null) { - this.authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(new WebClientReactiveAuthorizationCodeTokenResponseClient()); - } - return this.authenticationManager; - } - - /** - * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look the value up as a Bean. - * @param clientRegistrationRepository the repository to use - * @return the {@link OAuth2ClientSpec} to customize - */ - public OAuth2ClientSpec clientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; - return this; - } - - /** - * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look the value up as a Bean. - * @param authorizedClientRepository the repository to use - * @return the {@link OAuth2ClientSpec} to customize - */ - public OAuth2ClientSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this.authorizedClientRepository = authorizedClientRepository; - return this; - } - - /** - * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s. - * - * @since 5.2 - * @param authorizationRequestRepository the repository to use for storing {@link OAuth2AuthorizationRequest}'s - * @return the {@link OAuth2ClientSpec} to customize - */ - public OAuth2ClientSpec authorizationRequestRepository( - ServerAuthorizationRequestRepository authorizationRequestRepository) { - this.authorizationRequestRepository = authorizationRequestRepository; - return this; - } - - private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { - if (this.authorizationRequestRepository == null) { - this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); - } - return this.authorizationRequestRepository; - } - - /** - * Allows method chaining to continue configuring the {@link ServerHttpSecurity} - * @return the {@link ServerHttpSecurity} to continue configuring - */ - public ServerHttpSecurity and() { - return ServerHttpSecurity.this; - } - - protected void configure(ServerHttpSecurity http) { - ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); - ServerAuthenticationConverter authenticationConverter = getAuthenticationConverter(); - ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); - OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter( - authenticationManager, authenticationConverter, authorizedClientRepository); - codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); - if (http.requestCache != null) { - codeGrantWebFilter.setRequestCache(http.requestCache.requestCache); - } - - OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( - clientRegistrationRepository); - oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); - if (http.requestCache != null) { - oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); - } - - http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); - http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); - } - - private ReactiveClientRegistrationRepository getClientRegistrationRepository() { - if (this.clientRegistrationRepository != null) { - return this.clientRegistrationRepository; - } - return getBeanOrNull(ReactiveClientRegistrationRepository.class); - } - - private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { - if (this.authorizedClientRepository != null) { - return this.authorizedClientRepository; - } - ServerOAuth2AuthorizedClientRepository result = getBeanOrNull(ServerOAuth2AuthorizedClientRepository.class); - if (result == null) { - ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); - if (authorizedClientService != null) { - result = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( - authorizedClientService); - } - } - return result; - } - - private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { - ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); - if (service == null) { - service = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); - } - return service; - } - - private OAuth2ClientSpec() {} - } - /** * Configures OAuth 2.0 Resource Server support. * @@ -1580,7 +924,6 @@ public class ServerHttpSecurity { * return http.build(); * } * - * * @return the {@link OAuth2ResourceServerSpec} to customize */ public OAuth2ResourceServerSpec oauth2ResourceServer() { @@ -1598,9 +941,9 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .oauth2ResourceServer(oauth2ResourceServer -> + * .oauth2ResourceServer((oauth2ResourceServer) -> * oauth2ResourceServer - * .jwt(jwt -> + * .jwt((jwt) -> * jwt * .publicKey(publicKey()) * ) @@ -1608,12 +951,12 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * @param oauth2ResourceServerCustomizer the {@link Customizer} to provide more options for - * the {@link OAuth2ResourceServerSpec} + * @param oauth2ResourceServerCustomizer the {@link Customizer} to provide more + * options for the {@link OAuth2ResourceServerSpec} * @return the {@link ServerHttpSecurity} to customize */ - public ServerHttpSecurity oauth2ResourceServer(Customizer oauth2ResourceServerCustomizer) { + public ServerHttpSecurity oauth2ResourceServer( + Customizer oauth2ResourceServerCustomizer) { if (this.resourceServer == null) { this.resourceServer = new OAuth2ResourceServerSpec(); } @@ -1621,422 +964,6 @@ public class ServerHttpSecurity { return this; } - /** - * Configures OAuth2 Resource Server Support - */ - public class OAuth2ResourceServerSpec { - private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); - private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler(); - private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter(); - private AuthenticationConverterServerWebExchangeMatcher authenticationConverterServerWebExchangeMatcher; - - private JwtSpec jwt; - private OpaqueTokenSpec opaqueToken; - private ReactiveAuthenticationManagerResolver authenticationManagerResolver; - - /** - * Configures the {@link ServerAccessDeniedHandler} to use for requests authenticating with - * Bearer Tokens. - * requests. - * - * @param accessDeniedHandler the {@link ServerAccessDeniedHandler} to use - * @return the {@link OAuth2ResourceServerSpec} for additional configuration - * @since 5.2 - */ - public OAuth2ResourceServerSpec accessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { - Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null"); - this.accessDeniedHandler = accessDeniedHandler; - return this; - } - - /** - * Configures the {@link ServerAuthenticationEntryPoint} to use for requests authenticating with - * Bearer Tokens. - * - * @param entryPoint the {@link ServerAuthenticationEntryPoint} to use - * @return the {@link OAuth2ResourceServerSpec} for additional configuration - * @since 5.2 - */ - public OAuth2ResourceServerSpec authenticationEntryPoint(ServerAuthenticationEntryPoint entryPoint) { - Assert.notNull(entryPoint, "entryPoint cannot be null"); - this.entryPoint = entryPoint; - return this; - } - - /** - * Configures the {@link ServerAuthenticationConverter} to use for requests authenticating with - * Bearer Tokens. - * - * @param bearerTokenConverter The {@link ServerAuthenticationConverter} to use - * @return The {@link OAuth2ResourceServerSpec} for additional configuration - * @since 5.2 - */ - public OAuth2ResourceServerSpec bearerTokenConverter(ServerAuthenticationConverter bearerTokenConverter) { - Assert.notNull(bearerTokenConverter, "bearerTokenConverter cannot be null"); - this.bearerTokenConverter = bearerTokenConverter; - return this; - } - - /** - * Configures the {@link ReactiveAuthenticationManagerResolver} - * - * @param authenticationManagerResolver the {@link ReactiveAuthenticationManagerResolver} - * @return the {@link OAuth2ResourceServerSpec} for additional configuration - * @since 5.3 - */ - public OAuth2ResourceServerSpec authenticationManagerResolver( - ReactiveAuthenticationManagerResolver authenticationManagerResolver) { - Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); - this.authenticationManagerResolver = authenticationManagerResolver; - return this; - } - - /** - * Enables JWT Resource Server support. - * - * @return the {@link JwtSpec} for additional configuration - */ - public JwtSpec jwt() { - if (this.jwt == null) { - this.jwt = new JwtSpec(); - } - return this.jwt; - } - - /** - * Enables JWT Resource Server support. - * - * @param jwtCustomizer the {@link Customizer} to provide more options for - * the {@link JwtSpec} - * @return the {@link OAuth2ResourceServerSpec} to customize - */ - public OAuth2ResourceServerSpec jwt(Customizer jwtCustomizer) { - if (this.jwt == null) { - this.jwt = new JwtSpec(); - } - jwtCustomizer.customize(this.jwt); - return this; - } - - /** - * Enables Opaque Token Resource Server support. - * - * @return the {@link OpaqueTokenSpec} for additional configuration - */ - public OpaqueTokenSpec opaqueToken() { - if (this.opaqueToken == null) { - this.opaqueToken = new OpaqueTokenSpec(); - } - return this.opaqueToken; - } - - /** - * Enables Opaque Token Resource Server support. - * - * @param opaqueTokenCustomizer the {@link Customizer} to provide more options for - * the {@link OpaqueTokenSpec} - * @return the {@link OAuth2ResourceServerSpec} to customize - */ - public OAuth2ResourceServerSpec opaqueToken(Customizer opaqueTokenCustomizer) { - if (this.opaqueToken == null) { - this.opaqueToken = new OpaqueTokenSpec(); - } - opaqueTokenCustomizer.customize(this.opaqueToken); - return this; - } - - protected void configure(ServerHttpSecurity http) { - this.authenticationConverterServerWebExchangeMatcher = - new AuthenticationConverterServerWebExchangeMatcher(this.bearerTokenConverter); - - registerDefaultAccessDeniedHandler(http); - registerDefaultAuthenticationEntryPoint(http); - registerDefaultCsrfOverride(http); - - validateConfiguration(); - - if (this.authenticationManagerResolver != null) { - AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver); - oauth2.setServerAuthenticationConverter(bearerTokenConverter); - oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); - http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); - } else if (this.jwt != null) { - this.jwt.configure(http); - } else if (this.opaqueToken != null) { - this.opaqueToken.configure(http); - } - } - - private void validateConfiguration() { - if (this.authenticationManagerResolver == null) { - if (this.jwt == null && this.opaqueToken == null) { - throw new IllegalStateException("Jwt and Opaque Token are the only supported formats for bearer tokens " + - "in Spring Security and neither was found. Make sure to configure JWT " + - "via http.oauth2ResourceServer().jwt() or Opaque Tokens via " + - "http.oauth2ResourceServer().opaqueToken()."); - } - - if (this.jwt != null && this.opaqueToken != null) { - throw new IllegalStateException("Spring Security only supports JWTs or Opaque Tokens, not both at the " + - "same time."); - } - } else { - if (this.jwt != null || this.opaqueToken != null) { - throw new IllegalStateException("If an authenticationManagerResolver() is configured, then it takes " + - "precedence over any jwt() or opaqueToken() configuration."); - } - } - } - - private void registerDefaultAccessDeniedHandler(ServerHttpSecurity http) { - if ( http.exceptionHandling != null ) { - http.defaultAccessDeniedHandlers.add( - new ServerWebExchangeDelegatingServerAccessDeniedHandler.DelegateEntry( - this.authenticationConverterServerWebExchangeMatcher, - OAuth2ResourceServerSpec.this.accessDeniedHandler - ) - ); - } - } - - private void registerDefaultAuthenticationEntryPoint(ServerHttpSecurity http) { - if (http.exceptionHandling != null) { - http.defaultEntryPoints.add( - new DelegateEntry( - this.authenticationConverterServerWebExchangeMatcher, - OAuth2ResourceServerSpec.this.entryPoint - ) - ); - } - } - - private void registerDefaultCsrfOverride(ServerHttpSecurity http) { - if ( http.csrf != null && !http.csrf.specifiedRequireCsrfProtectionMatcher ) { - http - .csrf() - .requireCsrfProtectionMatcher( - new AndServerWebExchangeMatcher( - CsrfWebFilter.DEFAULT_CSRF_MATCHER, - new NegatedServerWebExchangeMatcher( - this.authenticationConverterServerWebExchangeMatcher))); - } - } - - private class BearerTokenAuthenticationWebFilter extends AuthenticationWebFilter { - private ServerAuthenticationFailureHandler authenticationFailureHandler; - - BearerTokenAuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) { - super(authenticationManager); - } - - @Override - public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { - WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); - return super.filter(exchange, chain) - .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler - .onAuthenticationFailure(webFilterExchange, e)); - } - - @Override - public void setAuthenticationFailureHandler(ServerAuthenticationFailureHandler authenticationFailureHandler) { - super.setAuthenticationFailureHandler(authenticationFailureHandler); - this.authenticationFailureHandler = authenticationFailureHandler; - } - } - - /** - * Configures JWT Resource Server Support - */ - public class JwtSpec { - private ReactiveAuthenticationManager authenticationManager; - private ReactiveJwtDecoder jwtDecoder; - private Converter> jwtAuthenticationConverter - = new ReactiveJwtAuthenticationConverterAdapter(new JwtAuthenticationConverter()); - - /** - * Configures the {@link ReactiveAuthenticationManager} to use - * @param authenticationManager the authentication manager to use - * @return the {@code JwtSpec} for additional configuration - */ - public JwtSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { - Assert.notNull(authenticationManager, "authenticationManager cannot be null"); - this.authenticationManager = authenticationManager; - return this; - } - - /** - * Configures the {@link Converter} to use for converting a {@link Jwt} into - * an {@link AbstractAuthenticationToken}. - * - * @param jwtAuthenticationConverter the converter to use - * @return the {@code JwtSpec} for additional configuration - * @since 5.1.1 - */ - public JwtSpec jwtAuthenticationConverter - (Converter> jwtAuthenticationConverter) { - Assert.notNull(jwtAuthenticationConverter, "jwtAuthenticationConverter cannot be null"); - this.jwtAuthenticationConverter = jwtAuthenticationConverter; - return this; - } - - /** - * Configures the {@link ReactiveJwtDecoder} to use - * @param jwtDecoder the decoder to use - * @return the {@code JwtSpec} for additional configuration - */ - public JwtSpec jwtDecoder(ReactiveJwtDecoder jwtDecoder) { - this.jwtDecoder = jwtDecoder; - return this; - } - - /** - * Configures a {@link ReactiveJwtDecoder} that leverages the provided {@link RSAPublicKey} - * - * @param publicKey the public key to use. - * @return the {@code JwtSpec} for additional configuration - */ - public JwtSpec publicKey(RSAPublicKey publicKey) { - this.jwtDecoder = new NimbusReactiveJwtDecoder(publicKey); - return this; - } - - /** - * Configures a {@link ReactiveJwtDecoder} using - * JSON Web Key (JWK) URL - * @param jwkSetUri the URL to use. - * @return the {@code JwtSpec} for additional configuration - */ - public JwtSpec jwkSetUri(String jwkSetUri) { - this.jwtDecoder = new NimbusReactiveJwtDecoder(jwkSetUri); - return this; - } - - public OAuth2ResourceServerSpec and() { - return OAuth2ResourceServerSpec.this; - } - - protected void configure(ServerHttpSecurity http) { - ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); - AuthenticationWebFilter oauth2 = new BearerTokenAuthenticationWebFilter(authenticationManager); - oauth2.setServerAuthenticationConverter(bearerTokenConverter); - oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); - http - .addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); - } - - protected ReactiveJwtDecoder getJwtDecoder() { - if (this.jwtDecoder == null) { - return getBean(ReactiveJwtDecoder.class); - } - return this.jwtDecoder; - } - - protected Converter> - getJwtAuthenticationConverter() { - - return this.jwtAuthenticationConverter; - } - - private ReactiveAuthenticationManager getAuthenticationManager() { - if (this.authenticationManager != null) { - return this.authenticationManager; - } - - ReactiveJwtDecoder jwtDecoder = getJwtDecoder(); - Converter> jwtAuthenticationConverter = - getJwtAuthenticationConverter(); - JwtReactiveAuthenticationManager authenticationManager = - new JwtReactiveAuthenticationManager(jwtDecoder); - authenticationManager.setJwtAuthenticationConverter(jwtAuthenticationConverter); - - return authenticationManager; - } - } - - /** - * Configures Opaque Token Resource Server support - * - * @author Josh Cummings - * @since 5.2 - */ - public class OpaqueTokenSpec { - private String introspectionUri; - private String clientId; - private String clientSecret; - private Supplier introspector; - - /** - * Configures the URI of the Introspection endpoint - * @param introspectionUri The URI of the Introspection endpoint - * @return the {@code OpaqueTokenSpec} for additional configuration - */ - public OpaqueTokenSpec introspectionUri(String introspectionUri) { - Assert.hasText(introspectionUri, "introspectionUri cannot be empty"); - this.introspectionUri = introspectionUri; - this.introspector = () -> - new NimbusReactiveOpaqueTokenIntrospector( - this.introspectionUri, this.clientId, this.clientSecret); - return this; - } - - /** - * Configures the credentials for Introspection endpoint - * @param clientId The clientId part of the credentials - * @param clientSecret The clientSecret part of the credentials - * @return the {@code OpaqueTokenSpec} for additional configuration - */ - public OpaqueTokenSpec introspectionClientCredentials(String clientId, String clientSecret) { - Assert.hasText(clientId, "clientId cannot be empty"); - Assert.notNull(clientSecret, "clientSecret cannot be null"); - this.clientId = clientId; - this.clientSecret = clientSecret; - this.introspector = () -> - new NimbusReactiveOpaqueTokenIntrospector( - this.introspectionUri, this.clientId, this.clientSecret); - return this; - } - - public OpaqueTokenSpec introspector(ReactiveOpaqueTokenIntrospector introspector) { - Assert.notNull(introspector, "introspector cannot be null"); - this.introspector = () -> introspector; - return this; - } - - /** - * Allows method chaining to continue configuring the {@link ServerHttpSecurity} - * @return the {@link ServerHttpSecurity} to continue configuring - */ - public OAuth2ResourceServerSpec and() { - return OAuth2ResourceServerSpec.this; - } - - protected ReactiveAuthenticationManager getAuthenticationManager() { - return new OpaqueTokenReactiveAuthenticationManager(getIntrospector()); - } - - protected ReactiveOpaqueTokenIntrospector getIntrospector() { - if (this.introspector != null) { - return this.introspector.get(); - } - return getBean(ReactiveOpaqueTokenIntrospector.class); - } - - protected void configure(ServerHttpSecurity http) { - ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); - AuthenticationWebFilter oauth2 = new BearerTokenAuthenticationWebFilter(authenticationManager); - oauth2.setServerAuthenticationConverter(bearerTokenConverter); - oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); - http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); - } - - private OpaqueTokenSpec() {} - } - - public ServerHttpSecurity and() { - return ServerHttpSecurity.this; - } - } - /** * Configures HTTP Response Headers. The default headers are: * @@ -2069,7 +996,6 @@ public class ServerHttpSecurity { * return http.build(); * } * - * * @return the {@link HeaderSpec} to customize */ public HeaderSpec headers() { @@ -2101,15 +1027,15 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .headers(headers -> + * .headers((headers) -> * headers * // customize frame options to be same origin - * .frameOptions(frameOptions -> + * .frameOptions((frameOptions) -> * frameOptions * .mode(XFrameOptionsServerHttpHeadersWriter.Mode.SAMEORIGIN) * ) * // disable cache control - * .cache(cache -> + * .cache((cache) -> * cache * .disable() * ) @@ -2117,9 +1043,8 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * @param headerCustomizer the {@link Customizer} to provide more options for - * the {@link HeaderSpec} + * @param headerCustomizer the {@link Customizer} to provide more options for the + * {@link HeaderSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity headers(Customizer headerCustomizer) { @@ -2131,8 +1056,8 @@ public class ServerHttpSecurity { } /** - * Configures exception handling (i.e. handles when authentication is requested). An example configuration can - * be found below: + * Configures exception handling (i.e. handles when authentication is requested). An + * example configuration can be found below: * *
     	 *  @Bean
    @@ -2145,7 +1070,6 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * * @return the {@link ExceptionHandlingSpec} to customize */ public ExceptionHandlingSpec exceptionHandling() { @@ -2156,15 +1080,15 @@ public class ServerHttpSecurity { } /** - * Configures exception handling (i.e. handles when authentication is requested). An example configuration can - * be found below: + * Configures exception handling (i.e. handles when authentication is requested). An + * example configuration can be found below: * *
     	 *  @Bean
     	 *  public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 *      http
     	 *          // ...
    -	 *          .exceptionHandling(exceptionHandling ->
    +	 *          .exceptionHandling((exceptionHandling) ->
     	 *              exceptionHandling
     	 *                  // customize how to request for authentication
     	 *                  .authenticationEntryPoint(entryPoint)
    @@ -2172,9 +1096,8 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * - * @param exceptionHandlingCustomizer the {@link Customizer} to provide more options for - * the {@link ExceptionHandlingSpec} + * @param exceptionHandlingCustomizer the {@link Customizer} to provide more options + * for the {@link ExceptionHandlingSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity exceptionHandling(Customizer exceptionHandlingCustomizer) { @@ -2203,7 +1126,7 @@ public class ServerHttpSecurity { * .pathMatchers("/users/{username}").access((authentication, context) -> * authentication * .map(Authentication::getName) - * .map(username -> username.equals(context.getVariables().get("username"))) + * .map((username) -> username.equals(context.getVariables().get("username"))) * .map(AuthorizationDecision::new) * ) * // allows providing a custom matching strategy that requires the role "ROLE_CUSTOM" @@ -2213,7 +1136,6 @@ public class ServerHttpSecurity { * return http.build(); * } * - * * @return the {@link AuthorizeExchangeSpec} to customize */ public AuthorizeExchangeSpec authorizeExchange() { @@ -2231,7 +1153,7 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .authorizeExchange(exchanges -> + * .authorizeExchange((exchanges) -> * exchanges * // any URL that starts with /admin/ requires the role "ROLE_ADMIN" * .pathMatchers("/admin/**").hasRole("ADMIN") @@ -2242,7 +1164,7 @@ public class ServerHttpSecurity { * .pathMatchers("/users/{username}").access((authentication, context) -> * authentication * .map(Authentication::getName) - * .map(username -> username.equals(context.getVariables().get("username"))) + * .map((username) -> username.equals(context.getVariables().get("username"))) * .map(AuthorizationDecision::new) * ) * // allows providing a custom matching strategy that requires the role "ROLE_CUSTOM" @@ -2253,9 +1175,8 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * @param authorizeExchangeCustomizer the {@link Customizer} to provide more options for - * the {@link AuthorizeExchangeSpec} + * @param authorizeExchangeCustomizer the {@link Customizer} to provide more options + * for the {@link AuthorizeExchangeSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity authorizeExchange(Customizer authorizeExchangeCustomizer) { @@ -2301,7 +1222,7 @@ public class ServerHttpSecurity { * public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { * http * // ... - * .logout(logout -> + * .logout((logout) -> * logout * // configures how log out is done * .logoutHandler(logoutHandler) @@ -2313,9 +1234,8 @@ public class ServerHttpSecurity { * return http.build(); * } * - * - * @param logoutCustomizer the {@link Customizer} to provide more options for - * the {@link LogoutSpec} + * @param logoutCustomizer the {@link Customizer} to provide more options for the + * {@link LogoutSpec} * @return the {@link ServerHttpSecurity} to customize */ public ServerHttpSecurity logout(Customizer logoutCustomizer) { @@ -2327,8 +1247,9 @@ public class ServerHttpSecurity { } /** - * Configures the request cache which is used when a flow is interrupted (i.e. due to requesting credentials) so - * that the request can be replayed after authentication. An example configuration can be found below: + * Configures the request cache which is used when a flow is interrupted (i.e. due to + * requesting credentials) so that the request can be replayed after authentication. + * An example configuration can be found below: * *
     	 *  @Bean
    @@ -2341,7 +1262,6 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * * @return the {@link RequestCacheSpec} to customize */ public RequestCacheSpec requestCache() { @@ -2349,15 +1269,16 @@ public class ServerHttpSecurity { } /** - * Configures the request cache which is used when a flow is interrupted (i.e. due to requesting credentials) so - * that the request can be replayed after authentication. An example configuration can be found below: + * Configures the request cache which is used when a flow is interrupted (i.e. due to + * requesting credentials) so that the request can be replayed after authentication. + * An example configuration can be found below: * *
     	 *  @Bean
     	 *  public SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
     	 *      http
     	 *          // ...
    -	 *          .requestCache(requestCache ->
    +	 *          .requestCache((requestCache) ->
     	 *              requestCache
     	 *                  // configures how the request is cached
     	 *                  .requestCache(customRequestCache)
    @@ -2365,7 +1286,6 @@ public class ServerHttpSecurity {
     	 *      return http.build();
     	 *  }
     	 * 
    - * * @param requestCacheCustomizer the {@link Customizer} to provide more options for * the {@link RequestCacheSpec} * @return the {@link ServerHttpSecurity} to customize @@ -2391,7 +1311,8 @@ public class ServerHttpSecurity { */ public SecurityWebFilterChain build() { if (this.built != null) { - throw new IllegalStateException("This has already been built with the following stacktrace. " + buildToString()); + throw new IllegalStateException( + "This has already been built with the following stacktrace. " + buildToString()); } this.built = new RuntimeException("First Build Invocation").fillInStackTrace(); if (this.headers != null) { @@ -2467,25 +1388,24 @@ public class ServerHttpSecurity { this.logout.configure(this); } this.requestCache.configure(this); - this.addFilterAt(new SecurityContextServerWebExchangeWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE); + this.addFilterAt(new SecurityContextServerWebExchangeWebFilter(), + SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE); if (this.authorizeExchange != null) { ServerAuthenticationEntryPoint authenticationEntryPoint = getAuthenticationEntryPoint(); ExceptionTranslationWebFilter exceptionTranslationWebFilter = new ExceptionTranslationWebFilter(); if (authenticationEntryPoint != null) { - exceptionTranslationWebFilter.setAuthenticationEntryPoint( - authenticationEntryPoint); + exceptionTranslationWebFilter.setAuthenticationEntryPoint(authenticationEntryPoint); } ServerAccessDeniedHandler accessDeniedHandler = getAccessDeniedHandler(); if (accessDeniedHandler != null) { - exceptionTranslationWebFilter.setAccessDeniedHandler( - accessDeniedHandler); + exceptionTranslationWebFilter.setAccessDeniedHandler(accessDeniedHandler); } this.addFilterAt(exceptionTranslationWebFilter, SecurityWebFiltersOrder.EXCEPTION_TRANSLATION); this.authorizeExchange.configure(this); } AnnotationAwareOrderComparator.sort(this.webFilters); List sortedWebFilters = new ArrayList<>(); - this.webFilters.forEach( f -> { + this.webFilters.forEach((f) -> { if (f instanceof OrderedWebFilter) { f = ((OrderedWebFilter) f).webFilter; } @@ -2496,8 +1416,8 @@ public class ServerHttpSecurity { } private String buildToString() { - try(StringWriter writer = new StringWriter()) { - try(PrintWriter printer = new PrintWriter(writer)) { + try (StringWriter writer = new StringWriter()) { + try (PrintWriter printer = new PrintWriter(writer)) { printer.println(); printer.println(); this.built.printStackTrace(printer); @@ -2505,8 +1425,9 @@ public class ServerHttpSecurity { printer.println(); return writer.toString(); } - } catch(IOException e) { - throw new RuntimeException(e); + } + catch (IOException ex) { + throw new RuntimeException(ex); } } @@ -2517,7 +1438,8 @@ public class ServerHttpSecurity { if (this.defaultEntryPoints.size() == 1) { return this.defaultEntryPoints.get(0).getEntryPoint(); } - DelegatingServerAuthenticationEntryPoint result = new DelegatingServerAuthenticationEntryPoint(this.defaultEntryPoints); + DelegatingServerAuthenticationEntryPoint result = new DelegatingServerAuthenticationEntryPoint( + this.defaultEntryPoints); result.setDefaultEntryPoint(this.defaultEntryPoints.get(this.defaultEntryPoints.size() - 1).getEntryPoint()); return result; } @@ -2529,8 +1451,8 @@ public class ServerHttpSecurity { if (this.defaultAccessDeniedHandlers.size() == 1) { return this.defaultAccessDeniedHandlers.get(0).getAccessDeniedHandler(); } - ServerWebExchangeDelegatingServerAccessDeniedHandler result = - new ServerWebExchangeDelegatingServerAccessDeniedHandler(this.defaultAccessDeniedHandlers); + ServerWebExchangeDelegatingServerAccessDeniedHandler result = new ServerWebExchangeDelegatingServerAccessDeniedHandler( + this.defaultAccessDeniedHandlers); result.setDefaultAccessDeniedHandler(this.defaultAccessDeniedHandlers .get(this.defaultAccessDeniedHandlers.size() - 1).getAccessDeniedHandler()); return result; @@ -2545,13 +1467,37 @@ public class ServerHttpSecurity { } private WebFilter securityContextRepositoryWebFilter() { - ServerSecurityContextRepository repository = this.securityContextRepository == null ? - new WebSessionServerSecurityContextRepository() : this.securityContextRepository; + ServerSecurityContextRepository repository = (this.securityContextRepository != null) + ? this.securityContextRepository : new WebSessionServerSecurityContextRepository(); WebFilter result = new ReactorContextWebFilter(repository); return new OrderedWebFilter(result, SecurityWebFiltersOrder.REACTOR_CONTEXT.getOrder()); } - protected ServerHttpSecurity() {} + private T getBean(Class beanClass) { + if (this.context == null) { + return null; + } + return this.context.getBean(beanClass); + } + + private T getBeanOrNull(Class beanClass) { + return getBeanOrNull(ResolvableType.forClass(beanClass)); + } + + private T getBeanOrNull(ResolvableType type) { + if (this.context == null) { + return null; + } + String[] names = this.context.getBeanNamesForType(type); + if (names.length == 1) { + return (T) this.context.getBean(names[0]); + } + return null; + } + + protected void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.context = applicationContext; + } /** * Configures authorization @@ -2560,10 +1506,13 @@ public class ServerHttpSecurity { * @since 5.0 * @see #authorizeExchange() */ - public class AuthorizeExchangeSpec - extends AbstractServerWebExchangeMatcherRegistry { - private DelegatingReactiveAuthorizationManager.Builder managerBldr = DelegatingReactiveAuthorizationManager.builder(); + public class AuthorizeExchangeSpec extends AbstractServerWebExchangeMatcherRegistry { + + private DelegatingReactiveAuthorizationManager.Builder managerBldr = DelegatingReactiveAuthorizationManager + .builder(); + private ServerWebExchangeMatcher matcher; + private boolean anyExchangeRegistered; /** @@ -2587,20 +1536,17 @@ public class ServerHttpSecurity { @Override protected Access registerMatcher(ServerWebExchangeMatcher matcher) { - if (this.anyExchangeRegistered) { - throw new IllegalStateException("Cannot register " + matcher + " which would be unreachable because anyExchange() has already been registered."); - } - if (this.matcher != null) { - throw new IllegalStateException("The matcher " + matcher + " does not have an access rule defined"); - } + Assert.state(!this.anyExchangeRegistered, () -> "Cannot register " + matcher + + " which would be unreachable because anyExchange() has already been registered."); + Assert.state(this.matcher == null, + () -> "The matcher " + matcher + " does not have an access rule defined"); this.matcher = matcher; return new Access(); } protected void configure(ServerHttpSecurity http) { - if (this.matcher != null) { - throw new IllegalStateException("The matcher " + this.matcher + " does not have an access rule defined"); - } + Assert.state(this.matcher == null, + () -> "The matcher " + this.matcher + " does not have an access rule defined"); AuthorizationWebFilter result = new AuthorizationWebFilter(this.managerBldr.build()); http.addFilterAt(result, SecurityWebFiltersOrder.AUTHORIZATION); } @@ -2615,7 +1561,7 @@ public class ServerHttpSecurity { * @return the {@link AuthorizeExchangeSpec} to configure */ public AuthorizeExchangeSpec permitAll() { - return access( (a, e) -> Mono.just(new AuthorizationDecision(true))); + return access((a, e) -> Mono.just(new AuthorizationDecision(true))); } /** @@ -2623,11 +1569,12 @@ public class ServerHttpSecurity { * @return the {@link AuthorizeExchangeSpec} to configure */ public AuthorizeExchangeSpec denyAll() { - return access( (a, e) -> Mono.just(new AuthorizationDecision(false))); + return access((a, e) -> Mono.just(new AuthorizationDecision(false))); } /** - * Require a specific role. This is a shorcut for {@link #hasAuthority(String)} + * Require a specific role. This is a shorcut for + * {@link #hasAuthority(String)} * @param role the role (i.e. "USER" would require "ROLE_USER") * @return the {@link AuthorizeExchangeSpec} to configure */ @@ -2636,7 +1583,8 @@ public class ServerHttpSecurity { } /** - * Require any specific role. This is a shortcut for {@link #hasAnyAuthority(String...)} + * Require any specific role. This is a shortcut for + * {@link #hasAnyAuthority(String...)} * @param roles the roles (i.e. "USER" would require "ROLE_USER") * @return the {@link AuthorizeExchangeSpec} to configure */ @@ -2646,7 +1594,8 @@ public class ServerHttpSecurity { /** * Require a specific authority. - * @param authority the authority to require (i.e. "USER" would require authority of "USER"). + * @param authority the authority to require (i.e. "USER" would require + * authority of "USER"). * @return the {@link AuthorizeExchangeSpec} to configure */ public AuthorizeExchangeSpec hasAuthority(String authority) { @@ -2655,7 +1604,8 @@ public class ServerHttpSecurity { /** * Require any authority - * @param authorities the authorities to require (i.e. "USER" would require authority of "USER"). + * @param authorities the authorities to require (i.e. "USER" would require + * authority of "USER"). * @return the {@link AuthorizeExchangeSpec} to configure */ public AuthorizeExchangeSpec hasAnyAuthority(String... authorities) { @@ -2677,12 +1627,13 @@ public class ServerHttpSecurity { */ public AuthorizeExchangeSpec access(ReactiveAuthorizationManager manager) { AuthorizeExchangeSpec.this.managerBldr - .add(new ServerWebExchangeMatcherEntry<>( - AuthorizeExchangeSpec.this.matcher, manager)); + .add(new ServerWebExchangeMatcherEntry<>(AuthorizeExchangeSpec.this.matcher, manager)); AuthorizeExchangeSpec.this.matcher = null; return AuthorizeExchangeSpec.this; } + } + } /** @@ -2693,15 +1644,17 @@ public class ServerHttpSecurity { * @see #redirectToHttps() */ public class HttpsRedirectSpec { + private ServerWebExchangeMatcher serverWebExchangeMatcher; + private PortMapper portMapper; /** * Configures when this filter should redirect to https * * By default, the filter will redirect whenever an exchange's scheme is not https - * - * @param matchers the list of conditions that, when any are met, the filter should redirect to https + * @param matchers the list of conditions that, when any are met, the filter + * should redirect to https * @return the {@link HttpsRedirectSpec} for additional configuration */ public HttpsRedirectSpec httpsRedirectWhen(ServerWebExchangeMatcher... matchers) { @@ -2713,21 +1666,17 @@ public class ServerHttpSecurity { * Configures when this filter should redirect to https * * By default, the filter will redirect whenever an exchange's scheme is not https - * * @param when determines when to redirect to https * @return the {@link HttpsRedirectSpec} for additional configuration */ - public HttpsRedirectSpec httpsRedirectWhen( - Function when) { - ServerWebExchangeMatcher matcher = e -> when.apply(e) ? - ServerWebExchangeMatcher.MatchResult.match() : - ServerWebExchangeMatcher.MatchResult.notMatch(); + public HttpsRedirectSpec httpsRedirectWhen(Function when) { + ServerWebExchangeMatcher matcher = (e) -> when.apply(e) ? ServerWebExchangeMatcher.MatchResult.match() + : ServerWebExchangeMatcher.MatchResult.notMatch(); return httpsRedirectWhen(matcher); } /** * Configures a custom HTTPS port to redirect to - * * @param portMapper the {@link PortMapper} to use * @return the {@link HttpsRedirectSpec} for additional configuration */ @@ -2754,66 +1703,70 @@ public class ServerHttpSecurity { public ServerHttpSecurity and() { return ServerHttpSecurity.this; } + } /** - * Configures CSRF Protection + * Configures CSRF + * Protection * * @author Rob Winch * @since 5.0 * @see #csrf() */ - public class CsrfSpec { + public final class CsrfSpec { + + private CsrfSpec() { + } + private CsrfWebFilter filter = new CsrfWebFilter(); + private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository(); private boolean specifiedRequireCsrfProtectionMatcher; /** - * Configures the {@link ServerAccessDeniedHandler} used when a CSRF token is invalid. Default is - * to send an {@link org.springframework.http.HttpStatus#FORBIDDEN}. - * + * Configures the {@link ServerAccessDeniedHandler} used when a CSRF token is + * invalid. Default is to send an + * {@link org.springframework.http.HttpStatus#FORBIDDEN}. * @param accessDeniedHandler the access denied handler. * @return the {@link CsrfSpec} for additional configuration */ - public CsrfSpec accessDeniedHandler( - ServerAccessDeniedHandler accessDeniedHandler) { + public CsrfSpec accessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { this.filter.setAccessDeniedHandler(accessDeniedHandler); return this; } /** - * Configures the {@link ServerCsrfTokenRepository} used to persist the CSRF Token. Default is + * Configures the {@link ServerCsrfTokenRepository} used to persist the CSRF + * Token. Default is * {@link org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository}. - * * @param csrfTokenRepository the repository to use * @return the {@link CsrfSpec} for additional configuration */ - public CsrfSpec csrfTokenRepository( - ServerCsrfTokenRepository csrfTokenRepository) { + public CsrfSpec csrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) { this.csrfTokenRepository = csrfTokenRepository; return this; } /** - * Configures the {@link ServerWebExchangeMatcher} used to determine when CSRF protection is enabled. Default is - * PUT, POST, DELETE requests. - * + * Configures the {@link ServerWebExchangeMatcher} used to determine when CSRF + * protection is enabled. Default is PUT, POST, DELETE requests. * @param requireCsrfProtectionMatcher the matcher to use * @return the {@link CsrfSpec} for additional configuration */ - public CsrfSpec requireCsrfProtectionMatcher( - ServerWebExchangeMatcher requireCsrfProtectionMatcher) { + public CsrfSpec requireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) { this.filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); this.specifiedRequireCsrfProtectionMatcher = true; return this; } /** - * Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart - * data requests. - * - * @param enabled true if should read from multipart form body, else false. Default is false + * Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token + * from the body of multipart data requests. + * @param enabled true if should read from multipart form body, else false. + * Default is false * @return the {@link CsrfSpec} for additional configuration */ public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) { @@ -2821,7 +1774,6 @@ public class ServerHttpSecurity { return this; } - /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -2831,8 +1783,8 @@ public class ServerHttpSecurity { } /** - * Disables CSRF Protection. Disabling CSRF Protection is only recommended when the application is never used - * within a browser. + * Disables CSRF Protection. Disabling CSRF Protection is only recommended when + * the application is never used within a browser. * @return the {@link ServerHttpSecurity} to continue configuring */ public ServerHttpSecurity disable() { @@ -2844,13 +1796,13 @@ public class ServerHttpSecurity { if (this.csrfTokenRepository != null) { this.filter.setCsrfTokenRepository(this.csrfTokenRepository); if (ServerHttpSecurity.this.logout != null) { - ServerHttpSecurity.this.logout.addLogoutHandler(new CsrfServerLogoutHandler(this.csrfTokenRepository)); + ServerHttpSecurity.this.logout + .addLogoutHandler(new CsrfServerLogoutHandler(this.csrfTokenRepository)); } } http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF); } - private CsrfSpec() {} } /** @@ -2860,7 +1812,10 @@ public class ServerHttpSecurity { * @since 5.0 * @see #exceptionHandling() */ - public class ExceptionHandlingSpec { + public final class ExceptionHandlingSpec { + + private ExceptionHandlingSpec() { + } /** * Configures what to do when the application request authentication @@ -2873,7 +1828,8 @@ public class ServerHttpSecurity { } /** - * Configures what to do when an authenticated user does not hold a required authority + * Configures what to do when an authenticated user does not hold a required + * authority * @param accessDeniedHandler the access denied handler to use * @return the {@link ExceptionHandlingSpec} to configure * @@ -2892,20 +1848,23 @@ public class ServerHttpSecurity { return ServerHttpSecurity.this; } - private ExceptionHandlingSpec() {} } /** - * Configures the request cache which is used when a flow is interrupted (i.e. due to requesting credentials) so - * that the request can be replayed after authentication. + * Configures the request cache which is used when a flow is interrupted (i.e. due to + * requesting credentials) so that the request can be replayed after authentication. * * @author Rob Winch * @since 5.0 * @see #requestCache() */ - public class RequestCacheSpec { + public final class RequestCacheSpec { + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); + private RequestCacheSpec() { + } + /** * Configures the cache used * @param requestCache the request cache @@ -2940,7 +1899,6 @@ public class ServerHttpSecurity { return and(); } - private RequestCacheSpec() {} } /** @@ -2950,17 +1908,20 @@ public class ServerHttpSecurity { * @since 5.0 * @see #httpBasic() */ - public class HttpBasicSpec { + public final class HttpBasicSpec { + private ReactiveAuthenticationManager authenticationManager; private ServerSecurityContextRepository securityContextRepository; private ServerAuthenticationEntryPoint entryPoint = new HttpBasicServerAuthenticationEntryPoint(); + private HttpBasicSpec() { + } + /** * The {@link ReactiveAuthenticationManager} used to authenticate. Defaults to * {@link ServerHttpSecurity#authenticationManager(ReactiveAuthenticationManager)}. - * * @param authenticationManager the authentication manager to use * @return the {@link HttpBasicSpec} to continue configuring */ @@ -2970,11 +1931,11 @@ public class ServerHttpSecurity { } /** - * The {@link ServerSecurityContextRepository} used to save the {@code Authentication}. Defaults to - * {@link NoOpServerSecurityContextRepository}. For the {@code SecurityContext} to be loaded on subsequent - * requests the {@link ReactorContextWebFilter} must be configured to be able to load the value (they are not - * implicitly linked). - * + * The {@link ServerSecurityContextRepository} used to save the + * {@code Authentication}. Defaults to + * {@link NoOpServerSecurityContextRepository}. For the {@code SecurityContext} to + * be loaded on subsequent requests the {@link ReactorContextWebFilter} must be + * configured to be able to load the value (they are not implicitly linked). * @param securityContextRepository the repository to use * @return the {@link HttpBasicSpec} to continue configuring */ @@ -2985,12 +1946,13 @@ public class ServerHttpSecurity { /** * Allows easily setting the entry point. - * @param authenticationEntryPoint the {@link ServerAuthenticationEntryPoint} to use + * @param authenticationEntryPoint the {@link ServerAuthenticationEntryPoint} to + * use * @return {@link HttpBasicSpec} for additional customization * @since 5.2.0 * @author Ankur Pathak */ - public HttpBasicSpec authenticationEntryPoint(ServerAuthenticationEntryPoint authenticationEntryPoint){ + public HttpBasicSpec authenticationEntryPoint(ServerAuthenticationEntryPoint authenticationEntryPoint) { Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null"); this.entryPoint = authenticationEntryPoint; return this; @@ -3015,21 +1977,19 @@ public class ServerHttpSecurity { protected void configure(ServerHttpSecurity http) { MediaTypeServerWebExchangeMatcher restMatcher = new MediaTypeServerWebExchangeMatcher( - MediaType.APPLICATION_ATOM_XML, - MediaType.APPLICATION_FORM_URLENCODED, MediaType.APPLICATION_JSON, - MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_XML, - MediaType.MULTIPART_FORM_DATA, MediaType.TEXT_XML); + MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_FORM_URLENCODED, MediaType.APPLICATION_JSON, + MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_XML, MediaType.MULTIPART_FORM_DATA, + MediaType.TEXT_XML); restMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); ServerHttpSecurity.this.defaultEntryPoints.add(new DelegateEntry(restMatcher, this.entryPoint)); - AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter( - this.authenticationManager); - authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); + AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager); + authenticationFilter + .setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC); } - private HttpBasicSpec() {} } /** @@ -3039,8 +1999,10 @@ public class ServerHttpSecurity { * @since 5.0 * @see #formLogin() */ - public class FormLoginSpec { - private final RedirectServerAuthenticationSuccessHandler defaultSuccessHandler = new RedirectServerAuthenticationSuccessHandler("/"); + public final class FormLoginSpec { + + private final RedirectServerAuthenticationSuccessHandler defaultSuccessHandler = new RedirectServerAuthenticationSuccessHandler( + "/"); private RedirectServerAuthenticationEntryPoint defaultEntryPoint; @@ -3058,10 +2020,12 @@ public class ServerHttpSecurity { private ServerAuthenticationSuccessHandler authenticationSuccessHandler = this.defaultSuccessHandler; + private FormLoginSpec() { + } + /** * The {@link ReactiveAuthenticationManager} used to authenticate. Defaults to * {@link ServerHttpSecurity#authenticationManager(ReactiveAuthenticationManager)}. - * * @param authenticationManager the authentication manager to use * @return the {@link FormLoginSpec} to continue configuring */ @@ -3071,29 +2035,32 @@ public class ServerHttpSecurity { } /** - * The {@link ServerAuthenticationSuccessHandler} used after authentication success. Defaults to - * {@link RedirectServerAuthenticationSuccessHandler}. + * The {@link ServerAuthenticationSuccessHandler} used after authentication + * success. Defaults to {@link RedirectServerAuthenticationSuccessHandler}. * @param authenticationSuccessHandler the success handler to use * @return the {@link FormLoginSpec} to continue configuring */ public FormLoginSpec authenticationSuccessHandler( - ServerAuthenticationSuccessHandler authenticationSuccessHandler) { + ServerAuthenticationSuccessHandler authenticationSuccessHandler) { Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); this.authenticationSuccessHandler = authenticationSuccessHandler; return this; } /** - * Configures the log in page to redirect to, the authentication failure page, and when authentication is - * performed. The default is that Spring Security will generate a log in page at "/login" and a log out page at - * "/logout". If this is customized: + * Configures the log in page to redirect to, the authentication failure page, and + * when authentication is performed. The default is that Spring Security will + * generate a log in page at "/login" and a log out page at "/logout". If this is + * customized: *
      *
    • The default log in & log out page are no longer provided
    • *
    • The application must render a log in page at the provided URL
    • - *
    • The application must render an authentication error page at the provided URL + "?error"
    • + *
    • The application must render an authentication error page at the provided + * URL + "?error"
    • *
    • Authentication will occur for POST to the provided URL
    • *
    - * @param loginPage the url to redirect to which provides a form to log in (i.e. "/login") + * @param loginPage the url to redirect to which provides a form to log in (i.e. + * "/login") * @return the {@link FormLoginSpec} to continue configuring * @see #authenticationEntryPoint(ServerAuthenticationEntryPoint) * @see #requiresAuthenticationMatcher(ServerWebExchangeMatcher) @@ -3106,7 +2073,8 @@ public class ServerHttpSecurity { this.requiresAuthenticationMatcher = ServerWebExchangeMatchers.pathMatchers(HttpMethod.POST, loginPage); } if (this.authenticationFailureHandler == null) { - this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler(loginPage + "?error"); + this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler( + loginPage + "?error"); } return this; } @@ -3135,22 +2103,25 @@ public class ServerHttpSecurity { } /** - * Configures how a failed authentication is handled. The default is to redirect to "/login?error". + * Configures how a failed authentication is handled. The default is to redirect + * to "/login?error". * @param authenticationFailureHandler the handler to use * @return the {@link FormLoginSpec} to continue configuring * @see #loginPage(String) */ - public FormLoginSpec authenticationFailureHandler(ServerAuthenticationFailureHandler authenticationFailureHandler) { + public FormLoginSpec authenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { this.authenticationFailureHandler = authenticationFailureHandler; return this; } /** - * The {@link ServerSecurityContextRepository} used to save the {@code Authentication}. Defaults to - * {@link WebSessionServerSecurityContextRepository}. For the {@code SecurityContext} to be loaded on subsequent - * requests the {@link ReactorContextWebFilter} must be configured to be able to load the value (they are not - * implicitly linked). - * + * The {@link ServerSecurityContextRepository} used to save the + * {@code Authentication}. Defaults to + * {@link WebSessionServerSecurityContextRepository}. For the + * {@code SecurityContext} to be loaded on subsequent requests the + * {@link ReactorContextWebFilter} must be configured to be able to load the value + * (they are not implicitly linked). * @param securityContextRepository the repository to use * @return the {@link FormLoginSpec} to continue configuring */ @@ -3180,7 +2151,8 @@ public class ServerHttpSecurity { if (this.authenticationEntryPoint == null) { this.isEntryPointExplicit = false; loginPage("/login"); - } else { + } + else { this.isEntryPointExplicit = true; } if (http.requestCache != null) { @@ -3190,12 +2162,11 @@ public class ServerHttpSecurity { this.defaultEntryPoint.setRequestCache(requestCache); } } - MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( - MediaType.TEXT_HTML); + MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(MediaType.TEXT_HTML); htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); - ServerHttpSecurity.this.defaultEntryPoints.add(0, new DelegateEntry(htmlMatcher, this.authenticationEntryPoint)); - AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter( - this.authenticationManager); + ServerHttpSecurity.this.defaultEntryPoints.add(0, + new DelegateEntry(htmlMatcher, this.authenticationEntryPoint)); + AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager); authenticationFilter.setRequiresAuthenticationMatcher(this.requiresAuthenticationMatcher); authenticationFilter.setAuthenticationFailureHandler(this.authenticationFailureHandler); authenticationFilter.setAuthenticationConverter(new ServerFormLoginAuthenticationConverter()); @@ -3204,11 +2175,13 @@ public class ServerHttpSecurity { http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.FORM_LOGIN); } - private FormLoginSpec() { - } } - private class LoginPageSpec { + private final class LoginPageSpec { + + private LoginPageSpec() { + } + protected void configure(ServerHttpSecurity http) { if (http.authenticationEntryPoint != null) { return; @@ -3234,7 +2207,6 @@ public class ServerHttpSecurity { } } - private LoginPageSpec() {} } /** @@ -3244,7 +2216,8 @@ public class ServerHttpSecurity { * @since 5.0 * @see #headers() */ - public class HeaderSpec { + public final class HeaderSpec { + private final List writers; private CacheControlServerHttpHeadersWriter cacheControl = new CacheControlServerHttpHeadersWriter(); @@ -3263,6 +2236,11 @@ public class ServerHttpSecurity { private ReferrerPolicyServerHttpHeadersWriter referrerPolicy = new ReferrerPolicyServerHttpHeadersWriter(); + private HeaderSpec() { + this.writers = new ArrayList<>(Arrays.asList(this.cacheControl, this.contentTypeOptions, this.hsts, + this.frameOptions, this.xss, this.featurePolicy, this.contentSecurityPolicy, this.referrerPolicy)); + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -3290,9 +2268,8 @@ public class ServerHttpSecurity { /** * Configures cache control headers - * - * @param cacheCustomizer the {@link Customizer} to provide more options for - * the {@link CacheSpec} + * @param cacheCustomizer the {@link Customizer} to provide more options for the + * {@link CacheSpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec cache(Customizer cacheCustomizer) { @@ -3310,9 +2287,8 @@ public class ServerHttpSecurity { /** * Configures content type response headers - * - * @param contentTypeOptionsCustomizer the {@link Customizer} to provide more options for - * the {@link ContentTypeOptionsSpec} + * @param contentTypeOptionsCustomizer the {@link Customizer} to provide more + * options for the {@link ContentTypeOptionsSpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec contentTypeOptions(Customizer contentTypeOptionsCustomizer) { @@ -3330,9 +2306,8 @@ public class ServerHttpSecurity { /** * Configures frame options response headers - * - * @param frameOptionsCustomizer the {@link Customizer} to provide more options for - * the {@link FrameOptionsSpec} + * @param frameOptionsCustomizer the {@link Customizer} to provide more options + * for the {@link FrameOptionsSpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec frameOptions(Customizer frameOptionsCustomizer) { @@ -3342,8 +2317,8 @@ public class ServerHttpSecurity { /** * Configures custom headers writer - * - * @param serverHttpHeadersWriter the {@link ServerHttpHeadersWriter} to provide custom headers writer + * @param serverHttpHeadersWriter the {@link ServerHttpHeadersWriter} to provide + * custom headers writer * @return the {@link HeaderSpec} to customize * @since 5.3.0 * @author Ankur Pathak @@ -3364,9 +2339,8 @@ public class ServerHttpSecurity { /** * Configures the Strict Transport Security response headers - * - * @param hstsCustomizer the {@link Customizer} to provide more options for - * the {@link HstsSpec} + * @param hstsCustomizer the {@link Customizer} to provide more options for the + * {@link HstsSpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec hsts(Customizer hstsCustomizer) { @@ -3390,9 +2364,8 @@ public class ServerHttpSecurity { /** * Configures x-xss-protection response header. - * - * @param xssProtectionCustomizer the {@link Customizer} to provide more options for - * the {@link XssProtectionSpec} + * @param xssProtectionCustomizer the {@link Customizer} to provide more options + * for the {@link XssProtectionSpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec xssProtection(Customizer xssProtectionCustomizer) { @@ -3411,9 +2384,8 @@ public class ServerHttpSecurity { /** * Configures {@code Content-Security-Policy} response header. - * - * @param contentSecurityPolicyCustomizer the {@link Customizer} to provide more options for - * the {@link ContentSecurityPolicySpec} + * @param contentSecurityPolicyCustomizer the {@link Customizer} to provide more + * options for the {@link ContentSecurityPolicySpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec contentSecurityPolicy(Customizer contentSecurityPolicyCustomizer) { @@ -3449,9 +2421,8 @@ public class ServerHttpSecurity { /** * Configures {@code Referrer-Policy} response header. - * - * @param referrerPolicyCustomizer the {@link Customizer} to provide more options for - * the {@link ReferrerPolicySpec} + * @param referrerPolicyCustomizer the {@link Customizer} to provide more options + * for the {@link ReferrerPolicySpec} * @return the {@link HeaderSpec} to customize */ public HeaderSpec referrerPolicy(Customizer referrerPolicyCustomizer) { @@ -3461,9 +2432,14 @@ public class ServerHttpSecurity { /** * Configures cache control headers + * * @see #cache() */ - public class CacheSpec { + public final class CacheSpec { + + private CacheSpec() { + } + /** * Disables cache control response headers * @return the {@link HeaderSpec} to configure @@ -3473,14 +2449,18 @@ public class ServerHttpSecurity { return HeaderSpec.this; } - private CacheSpec() {} } /** * The content type headers + * * @see #contentTypeOptions() */ - public class ContentTypeOptionsSpec { + public final class ContentTypeOptionsSpec { + + private ContentTypeOptionsSpec() { + } + /** * Disables the content type options response header * @return the {@link HeaderSpec} to configure @@ -3490,14 +2470,18 @@ public class ServerHttpSecurity { return HeaderSpec.this; } - private ContentTypeOptionsSpec() {} } /** * Configures frame options response header + * * @see #frameOptions() */ - public class FrameOptionsSpec { + public final class FrameOptionsSpec { + + private FrameOptionsSpec() { + } + /** * The mode to configure. Default is * {@link org.springframework.security.web.server.header.XFrameOptionsServerHttpHeadersWriter.Mode#DENY} @@ -3510,7 +2494,8 @@ public class ServerHttpSecurity { } /** - * Allows method chaining to continue configuring the {@link ServerHttpSecurity} + * Allows method chaining to continue configuring the + * {@link ServerHttpSecurity} * @return the {@link HeaderSpec} to continue configuring */ private HeaderSpec and() { @@ -3526,14 +2511,18 @@ public class ServerHttpSecurity { return and(); } - private FrameOptionsSpec() {} } /** * Configures Strict Transport Security response header + * * @see #hsts() */ - public class HstsSpec { + public final class HstsSpec { + + private HstsSpec() { + } + /** * Configures the max age. Default is one year. * @param maxAge the max age @@ -3560,10 +2549,9 @@ public class ServerHttpSecurity { *

    * *

    - * See Website hstspreload.org - * for additional details. + * See Website hstspreload.org for + * additional details. *

    - * * @param preload if subdomains should be included * @return the {@link HstsSpec} to continue configuring * @since 5.2.0 @@ -3575,7 +2563,8 @@ public class ServerHttpSecurity { } /** - * Allows method chaining to continue configuring the {@link ServerHttpSecurity} + * Allows method chaining to continue configuring the + * {@link ServerHttpSecurity} * @return the {@link HeaderSpec} to continue configuring */ public HeaderSpec and() { @@ -3591,14 +2580,18 @@ public class ServerHttpSecurity { return HeaderSpec.this; } - private HstsSpec() {} } /** * Configures x-xss-protection response header + * * @see #xssProtection() */ - public class XssProtectionSpec { + public final class XssProtectionSpec { + + private XssProtectionSpec() { + } + /** * Disables the x-xss-protection response header * @return the {@link HeaderSpec} to continue configuring @@ -3608,21 +2601,26 @@ public class ServerHttpSecurity { return HeaderSpec.this; } - private XssProtectionSpec() {} } /** * Configures {@code Content-Security-Policy} response header. * - * @see #contentSecurityPolicy(String) * @since 5.1 + * @see #contentSecurityPolicy(String) */ - public class ContentSecurityPolicySpec { + public final class ContentSecurityPolicySpec { + private static final String DEFAULT_SRC_SELF_POLICY = "default-src 'self'"; + private ContentSecurityPolicySpec() { + HeaderSpec.this.contentSecurityPolicy.setPolicyDirectives(DEFAULT_SRC_SELF_POLICY); + } + /** - * Whether to include the {@code Content-Security-Policy-Report-Only} header in - * the response. Otherwise, defaults to the {@code Content-Security-Policy} header. + * Whether to include the {@code Content-Security-Policy-Report-Only} header + * in the response. Otherwise, defaults to the {@code Content-Security-Policy} + * header. * @param reportOnly whether to only report policy violations * @return the {@link HeaderSpec} to continue configuring */ @@ -3633,7 +2631,6 @@ public class ServerHttpSecurity { /** * Sets the security policy directive(s) to be used in the response header. - * * @param policyDirectives the security policy directive(s) * @return the {@link HeaderSpec} to continue configuring */ @@ -3655,18 +2652,19 @@ public class ServerHttpSecurity { HeaderSpec.this.contentSecurityPolicy.setPolicyDirectives(policyDirectives); } - private ContentSecurityPolicySpec() { - HeaderSpec.this.contentSecurityPolicy.setPolicyDirectives(DEFAULT_SRC_SELF_POLICY); - } } /** * Configures {@code Feature-Policy} response header. * - * @see #featurePolicy(String) * @since 5.1 + * @see #featurePolicy(String) */ - public class FeaturePolicySpec { + public final class FeaturePolicySpec { + + private FeaturePolicySpec(String policyDirectives) { + HeaderSpec.this.featurePolicy.setPolicyDirectives(policyDirectives); + } /** * Allows method chaining to continue configuring the @@ -3677,24 +2675,26 @@ public class ServerHttpSecurity { return HeaderSpec.this; } - private FeaturePolicySpec(String policyDirectives) { - HeaderSpec.this.featurePolicy.setPolicyDirectives(policyDirectives); - } - } /** * Configures {@code Referrer-Policy} response header. * + * @since 5.1 * @see #referrerPolicy() * @see #referrerPolicy(ReferrerPolicy) - * @since 5.1 */ - public class ReferrerPolicySpec { + public final class ReferrerPolicySpec { + + private ReferrerPolicySpec() { + } + + private ReferrerPolicySpec(ReferrerPolicy referrerPolicy) { + HeaderSpec.this.referrerPolicy.setPolicy(referrerPolicy); + } /** * Sets the policy to be used in the response header. - * * @param referrerPolicy a referrer policy * @return the {@link ReferrerPolicySpec} to continue configuring */ @@ -3712,37 +2712,31 @@ public class ServerHttpSecurity { return HeaderSpec.this; } - private ReferrerPolicySpec() { - } - - private ReferrerPolicySpec(ReferrerPolicy referrerPolicy) { - HeaderSpec.this.referrerPolicy.setPolicy(referrerPolicy); - } - - } - - private HeaderSpec() { - this.writers = new ArrayList<>( - Arrays.asList(this.cacheControl, this.contentTypeOptions, this.hsts, - this.frameOptions, this.xss, this.featurePolicy, this.contentSecurityPolicy, - this.referrerPolicy)); } } /** * Configures log out + * * @author Shazin Sadakath * @since 5.0 * @see #logout() */ public final class LogoutSpec { + private LogoutWebFilter logoutWebFilter = new LogoutWebFilter(); + private final SecurityContextServerLogoutHandler DEFAULT_LOGOUT_HANDLER = new SecurityContextServerLogoutHandler(); + private List logoutHandlers = new ArrayList<>(Arrays.asList(this.DEFAULT_LOGOUT_HANDLER)); + private LogoutSpec() { + } + /** - * Configures the logout handler. Default is {@code SecurityContextServerLogoutHandler} + * Configures the logout handler. Default is + * {@code SecurityContextServerLogoutHandler} * @param logoutHandler * @return the {@link LogoutSpec} to configure */ @@ -3760,13 +2754,14 @@ public class ServerHttpSecurity { /** * Configures what URL a POST to will trigger a log out. - * @param logoutUrl the url to trigger a log out (i.e. "/signout" would mean a POST to "/signout" would trigger - * log out) + * @param logoutUrl the url to trigger a log out (i.e. "/signout" would mean a + * POST to "/signout" would trigger log out) * @return the {@link LogoutSpec} to configure */ public LogoutSpec logoutUrl(String logoutUrl) { Assert.notNull(logoutUrl, "logoutUrl must not be null"); - ServerWebExchangeMatcher requiresLogout = ServerWebExchangeMatchers.pathMatchers(HttpMethod.POST, logoutUrl); + ServerWebExchangeMatcher requiresLogout = ServerWebExchangeMatchers.pathMatchers(HttpMethod.POST, + logoutUrl); return requiresLogout(requiresLogout); } @@ -3809,11 +2804,11 @@ public class ServerHttpSecurity { } if (this.logoutHandlers.isEmpty()) { return null; - } else if (this.logoutHandlers.size() == 1) { - return this.logoutHandlers.get(0); - } else { - return new DelegatingServerLogoutHandler(this.logoutHandlers); } + if (this.logoutHandlers.size() == 1) { + return this.logoutHandlers.get(0); + } + return new DelegatingServerLogoutHandler(this.logoutHandlers); } protected void configure(ServerHttpSecurity http) { @@ -3824,39 +2819,12 @@ public class ServerHttpSecurity { http.addFilterAt(this.logoutWebFilter, SecurityWebFiltersOrder.LOGOUT); } - private LogoutSpec() {} - } - - private T getBean(Class beanClass) { - if (this.context == null) { - return null; - } - return this.context.getBean(beanClass); - } - - private T getBeanOrNull(Class beanClass) { - return getBeanOrNull(ResolvableType.forClass(beanClass)); - } - - - private T getBeanOrNull(ResolvableType type) { - if (this.context == null) { - return null; - } - String[] names = this.context.getBeanNamesForType(type); - if (names.length == 1) { - return (T) this.context.getBean(names[0]); - } - return null; - } - - protected void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - this.context = applicationContext; } private static class OrderedWebFilter implements WebFilter, Ordered { + private final WebFilter webFilter; + private final int order; OrderedWebFilter(WebFilter webFilter, int order) { @@ -3865,8 +2833,7 @@ public class ServerHttpSecurity { } @Override - public Mono filter(ServerWebExchange exchange, - WebFilterChain chain) { + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return this.webFilter.filter(exchange, chain); } @@ -3877,39 +2844,1124 @@ public class ServerHttpSecurity { @Override public String toString() { - return "OrderedWebFilter{" + "webFilter=" + this.webFilter + ", order=" + this.order - + '}'; + return "OrderedWebFilter{" + "webFilter=" + this.webFilter + ", order=" + this.order + '}'; } + } /** * Workaround https://jira.spring.io/projects/SPR/issues/SPR-17213 */ static class ServerWebExchangeReactorContextWebFilter implements WebFilter { + @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { - return chain.filter(exchange) - .subscriberContext(Context.of(ServerWebExchange.class, exchange)); + return chain.filter(exchange).subscriberContext(Context.of(ServerWebExchange.class, exchange)); } + + } + + /** + * Configures CORS support within Spring Security. This ensures that the + * {@link CorsWebFilter} is place in the correct order. + */ + public final class CorsSpec { + + private CorsWebFilter corsFilter; + + private CorsSpec() { + } + + /** + * Configures the {@link CorsConfigurationSource} to be used + * @param source the source to use + * @return the {@link CorsSpec} for additional configuration + */ + public CorsSpec configurationSource(CorsConfigurationSource source) { + this.corsFilter = new CorsWebFilter(source); + return this; + } + + /** + * Disables CORS support within Spring Security. + * @return the {@link ServerHttpSecurity} to continue configuring + */ + public ServerHttpSecurity disable() { + ServerHttpSecurity.this.cors = null; + return ServerHttpSecurity.this; + } + + /** + * Allows method chaining to continue configuring the {@link ServerHttpSecurity} + * @return the {@link ServerHttpSecurity} to continue configuring + */ + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + protected void configure(ServerHttpSecurity http) { + CorsWebFilter corsFilter = getCorsFilter(); + if (corsFilter != null) { + http.addFilterAt(this.corsFilter, SecurityWebFiltersOrder.CORS); + } + } + + private CorsWebFilter getCorsFilter() { + if (this.corsFilter != null) { + return this.corsFilter; + } + CorsConfigurationSource source = getBeanOrNull(CorsConfigurationSource.class); + if (source == null) { + return null; + } + CorsProcessor processor = getBeanOrNull(CorsProcessor.class); + if (processor == null) { + processor = new DefaultCorsProcessor(); + } + this.corsFilter = new CorsWebFilter(source, processor); + return this.corsFilter; + } + + } + + /** + * Configures X509 authentication + * + * @author Alexey Nesterov + * @since 5.2 + * @see #x509() + */ + public final class X509Spec { + + private X509PrincipalExtractor principalExtractor; + + private ReactiveAuthenticationManager authenticationManager; + + private X509Spec() { + } + + public X509Spec principalExtractor(X509PrincipalExtractor principalExtractor) { + this.principalExtractor = principalExtractor; + return this; + } + + public X509Spec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + return this; + } + + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + protected void configure(ServerHttpSecurity http) { + ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); + X509PrincipalExtractor principalExtractor = getPrincipalExtractor(); + AuthenticationWebFilter filter = new AuthenticationWebFilter(authenticationManager); + filter.setServerAuthenticationConverter(new ServerX509AuthenticationConverter(principalExtractor)); + http.addFilterAt(filter, SecurityWebFiltersOrder.AUTHENTICATION); + } + + private X509PrincipalExtractor getPrincipalExtractor() { + if (this.principalExtractor != null) { + return this.principalExtractor; + } + return new SubjectDnX509PrincipalExtractor(); + } + + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager != null) { + return this.authenticationManager; + } + ReactiveUserDetailsService userDetailsService = getBean(ReactiveUserDetailsService.class); + return new ReactivePreAuthenticatedAuthenticationManager(userDetailsService); + } + + } + + public final class OAuth2LoginSpec { + + private ReactiveClientRegistrationRepository clientRegistrationRepository; + + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private ServerAuthorizationRequestRepository authorizationRequestRepository; + + private ReactiveAuthenticationManager authenticationManager; + + private ServerSecurityContextRepository securityContextRepository; + + private ServerAuthenticationConverter authenticationConverter; + + private ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; + + private ServerWebExchangeMatcher authenticationMatcher; + + private ServerAuthenticationSuccessHandler authenticationSuccessHandler; + + private ServerAuthenticationFailureHandler authenticationFailureHandler; + + private OAuth2LoginSpec() { + } + + /** + * Configures the {@link ReactiveAuthenticationManager} to use. The default is + * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} + * @param authenticationManager the manager to use + * @return the {@link OAuth2LoginSpec} to customize + */ + public OAuth2LoginSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + return this; + } + + /** + * The {@link ServerSecurityContextRepository} used to save the + * {@code Authentication}. Defaults to + * {@link WebSessionServerSecurityContextRepository}. + * @param securityContextRepository the repository to use + * @return the {@link OAuth2LoginSpec} to continue configuring + * @since 5.2 + */ + public OAuth2LoginSpec securityContextRepository(ServerSecurityContextRepository securityContextRepository) { + this.securityContextRepository = securityContextRepository; + return this; + } + + /** + * The {@link ServerAuthenticationSuccessHandler} used after authentication + * success. Defaults to {@link RedirectServerAuthenticationSuccessHandler} + * redirecting to "/". + * @param authenticationSuccessHandler the success handler to use + * @return the {@link OAuth2LoginSpec} to customize + * @since 5.2 + */ + public OAuth2LoginSpec authenticationSuccessHandler( + ServerAuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + return this; + } + + /** + * The {@link ServerAuthenticationFailureHandler} used after authentication + * failure. Defaults to {@link RedirectServerAuthenticationFailureHandler} + * redirecting to "/login?error". + * @param authenticationFailureHandler the failure handler to use + * @return the {@link OAuth2LoginSpec} to customize + * @since 5.2 + */ + public OAuth2LoginSpec authenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + return this; + } + + /** + * Gets the {@link ReactiveAuthenticationManager} to use. First tries an + * explicitly configured manager, and defaults to + * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} + * @return the {@link ReactiveAuthenticationManager} to use + */ + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager == null) { + this.authenticationManager = createDefault(); + } + return this.authenticationManager; + } + + private ReactiveAuthenticationManager createDefault() { + ReactiveOAuth2AccessTokenResponseClient client = getAccessTokenResponseClient(); + OAuth2LoginReactiveAuthenticationManager oauth2Manager = new OAuth2LoginReactiveAuthenticationManager( + client, getOauth2UserService()); + GrantedAuthoritiesMapper authoritiesMapper = getBeanOrNull(GrantedAuthoritiesMapper.class); + if (authoritiesMapper != null) { + oauth2Manager.setAuthoritiesMapper(authoritiesMapper); + } + boolean oidcAuthenticationProviderEnabled = ClassUtils + .isPresent("org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); + if (!oidcAuthenticationProviderEnabled) { + return oauth2Manager; + } + OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager( + client, getOidcUserService()); + ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveJwtDecoderFactory.class, + ClientRegistration.class); + ReactiveJwtDecoderFactory jwtDecoderFactory = getBeanOrNull(type); + if (jwtDecoderFactory != null) { + oidc.setJwtDecoderFactory(jwtDecoderFactory); + } + if (authoritiesMapper != null) { + oidc.setAuthoritiesMapper(authoritiesMapper); + } + return new DelegatingReactiveAuthenticationManager(oidc, oauth2Manager); + } + + /** + * Sets the converter to use + * @param authenticationConverter the converter to use + * @return the {@link OAuth2LoginSpec} to customize + */ + public OAuth2LoginSpec authenticationConverter(ServerAuthenticationConverter authenticationConverter) { + this.authenticationConverter = authenticationConverter; + return this; + } + + private ServerAuthenticationConverter getAuthenticationConverter( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + if (this.authenticationConverter != null) { + return this.authenticationConverter; + } + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter( + clientRegistrationRepository); + delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + ServerAuthenticationConverter authenticationConverter = (exchange) -> delegate.convert(exchange).onErrorMap( + OAuth2AuthorizationException.class, + (e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())); + this.authenticationConverter = authenticationConverter; + return authenticationConverter; + } + + public OAuth2LoginSpec clientRegistrationRepository( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + this.clientRegistrationRepository = clientRegistrationRepository; + return this; + } + + public OAuth2LoginSpec authorizedClientService(ReactiveOAuth2AuthorizedClientService authorizedClientService) { + this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( + authorizedClientService); + return this; + } + + public OAuth2LoginSpec authorizedClientRepository( + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this.authorizedClientRepository = authorizedClientRepository; + return this; + } + + /** + * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s. + * @param authorizationRequestRepository the repository to use for storing + * {@link OAuth2AuthorizationRequest}'s + * @return the {@link OAuth2LoginSpec} for further configuration + * @since 5.2 + */ + public OAuth2LoginSpec authorizationRequestRepository( + ServerAuthorizationRequestRepository authorizationRequestRepository) { + this.authorizationRequestRepository = authorizationRequestRepository; + return this; + } + + /** + * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s. + * @param authorizationRequestResolver the resolver used for resolving + * {@link OAuth2AuthorizationRequest}'s + * @return the {@link OAuth2LoginSpec} for further configuration + * @since 5.2 + */ + public OAuth2LoginSpec authorizationRequestResolver( + ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver) { + this.authorizationRequestResolver = authorizationRequestResolver; + return this; + } + + /** + * Sets the {@link ServerWebExchangeMatcher matcher} used for determining if the + * request is an authentication request. + * @param authenticationMatcher the {@link ServerWebExchangeMatcher matcher} used + * for determining if the request is an authentication request + * @return the {@link OAuth2LoginSpec} for further configuration + * @since 5.2 + */ + public OAuth2LoginSpec authenticationMatcher(ServerWebExchangeMatcher authenticationMatcher) { + this.authenticationMatcher = authenticationMatcher; + return this; + } + + private ServerWebExchangeMatcher getAuthenticationMatcher() { + if (this.authenticationMatcher == null) { + this.authenticationMatcher = createAttemptAuthenticationRequestMatcher(); + } + return this.authenticationMatcher; + } + + /** + * Allows method chaining to continue configuring the {@link ServerHttpSecurity} + * @return the {@link ServerHttpSecurity} to continue configuring + */ + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + protected void configure(ServerHttpSecurity http) { + ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); + OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter(); + ServerAuthorizationRequestRepository authorizationRequestRepository = getAuthorizationRequestRepository(); + oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); + oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); + ReactiveAuthenticationManager manager = getAuthenticationManager(); + AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, + authorizedClientRepository); + authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher()); + authenticationFilter + .setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository)); + authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http)); + authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); + authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + setDefaultEntryPoints(http); + http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); + http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); + } + + private void setDefaultEntryPoints(ServerHttpSecurity http) { + String defaultLoginPage = "/login"; + Map urlToText = http.oauth2Login.getLinks(); + String providerLoginPage = null; + if (urlToText.size() == 1) { + providerLoginPage = urlToText.keySet().iterator().next(); + } + MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( + MediaType.APPLICATION_XHTML_XML, new MediaType("image", "*"), MediaType.TEXT_HTML, + MediaType.TEXT_PLAIN); + htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); + ServerWebExchangeMatcher xhrMatcher = (exchange) -> { + if (exchange.getRequest().getHeaders().getOrEmpty("X-Requested-With").contains("XMLHttpRequest")) { + return ServerWebExchangeMatcher.MatchResult.match(); + } + return ServerWebExchangeMatcher.MatchResult.notMatch(); + }; + ServerWebExchangeMatcher notXhrMatcher = new NegatedServerWebExchangeMatcher(xhrMatcher); + ServerWebExchangeMatcher defaultEntryPointMatcher = new AndServerWebExchangeMatcher(notXhrMatcher, + htmlMatcher); + if (providerLoginPage != null) { + ServerWebExchangeMatcher loginPageMatcher = new PathPatternParserServerWebExchangeMatcher( + defaultLoginPage); + ServerWebExchangeMatcher faviconMatcher = new PathPatternParserServerWebExchangeMatcher("/favicon.ico"); + ServerWebExchangeMatcher defaultLoginPageMatcher = new AndServerWebExchangeMatcher( + new OrServerWebExchangeMatcher(loginPageMatcher, faviconMatcher), defaultEntryPointMatcher); + + ServerWebExchangeMatcher matcher = new AndServerWebExchangeMatcher(notXhrMatcher, + new NegatedServerWebExchangeMatcher(defaultLoginPageMatcher)); + RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint( + providerLoginPage); + entryPoint.setRequestCache(http.requestCache.requestCache); + http.defaultEntryPoints.add(new DelegateEntry(matcher, entryPoint)); + } + RedirectServerAuthenticationEntryPoint defaultEntryPoint = new RedirectServerAuthenticationEntryPoint( + defaultLoginPage); + defaultEntryPoint.setRequestCache(http.requestCache.requestCache); + http.defaultEntryPoints.add(new DelegateEntry(defaultEntryPointMatcher, defaultEntryPoint)); + } + + private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) { + if (this.authenticationSuccessHandler == null) { + RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler(); + handler.setRequestCache(http.requestCache.requestCache); + this.authenticationSuccessHandler = handler; + } + return this.authenticationSuccessHandler; + } + + private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() { + if (this.authenticationFailureHandler == null) { + this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error"); + } + return this.authenticationFailureHandler; + } + + private ServerWebExchangeMatcher createAttemptAuthenticationRequestMatcher() { + return new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"); + } + + private ReactiveOAuth2UserService getOidcUserService() { + ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2UserService.class, + OidcUserRequest.class, OidcUser.class); + ReactiveOAuth2UserService bean = getBeanOrNull(type); + if (bean != null) { + return bean; + } + return new OidcReactiveOAuth2UserService(); + } + + private ReactiveOAuth2UserService getOauth2UserService() { + ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2UserService.class, + OAuth2UserRequest.class, OAuth2User.class); + ReactiveOAuth2UserService bean = getBeanOrNull(type); + if (bean != null) { + return bean; + } + return new DefaultReactiveOAuth2UserService(); + } + + private Map getLinks() { + Iterable registrations = getBeanOrNull( + ResolvableType.forClassWithGenerics(Iterable.class, ClientRegistration.class)); + if (registrations == null) { + return Collections.emptyMap(); + } + Map result = new HashMap<>(); + registrations.iterator().forEachRemaining( + (r) -> result.put("/oauth2/authorization/" + r.getRegistrationId(), r.getClientName())); + return result; + } + + private ReactiveOAuth2AccessTokenResponseClient getAccessTokenResponseClient() { + ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, + OAuth2AuthorizationCodeGrantRequest.class); + ReactiveOAuth2AccessTokenResponseClient bean = getBeanOrNull(type); + if (bean != null) { + return bean; + } + return new WebClientReactiveAuthorizationCodeTokenResponseClient(); + } + + private ReactiveClientRegistrationRepository getClientRegistrationRepository() { + if (this.clientRegistrationRepository == null) { + this.clientRegistrationRepository = getBeanOrNull(ReactiveClientRegistrationRepository.class); + } + return this.clientRegistrationRepository; + } + + private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() { + OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter; + if (this.authorizationRequestResolver != null) { + return new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver); + } + return new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository()); + } + + private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { + ServerOAuth2AuthorizedClientRepository result = this.authorizedClientRepository; + if (result == null) { + result = getBeanOrNull(ServerOAuth2AuthorizedClientRepository.class); + } + if (result == null) { + ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); + if (authorizedClientService != null) { + result = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(authorizedClientService); + } + } + return result; + } + + private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { + if (this.authorizationRequestRepository == null) { + this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + } + return this.authorizationRequestRepository; + } + + private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { + ReactiveOAuth2AuthorizedClientService bean = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); + if (bean != null) { + return bean; + } + return new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); + } + + } + + public final class OAuth2ClientSpec { + + private ReactiveClientRegistrationRepository clientRegistrationRepository; + + private ServerAuthenticationConverter authenticationConverter; + + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private ReactiveAuthenticationManager authenticationManager; + + private ServerAuthorizationRequestRepository authorizationRequestRepository; + + private OAuth2ClientSpec() { + } + + /** + * Sets the converter to use + * @param authenticationConverter the converter to use + * @return the {@link OAuth2ClientSpec} to customize + */ + public OAuth2ClientSpec authenticationConverter(ServerAuthenticationConverter authenticationConverter) { + this.authenticationConverter = authenticationConverter; + return this; + } + + private ServerAuthenticationConverter getAuthenticationConverter() { + if (this.authenticationConverter == null) { + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter( + getClientRegistrationRepository()); + authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + this.authenticationConverter = authenticationConverter; + } + return this.authenticationConverter; + } + + /** + * Configures the {@link ReactiveAuthenticationManager} to use. The default is + * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} + * @param authenticationManager the manager to use + * @return the {@link OAuth2ClientSpec} to customize + */ + public OAuth2ClientSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + this.authenticationManager = authenticationManager; + return this; + } + + /** + * Gets the {@link ReactiveAuthenticationManager} to use. First tries an + * explicitly configured manager, and defaults to + * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager} + * @return the {@link ReactiveAuthenticationManager} to use + */ + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager == null) { + this.authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager( + new WebClientReactiveAuthorizationCodeTokenResponseClient()); + } + return this.authenticationManager; + } + + /** + * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look + * the value up as a Bean. + * @param clientRegistrationRepository the repository to use + * @return the {@link OAuth2ClientSpec} to customize + */ + public OAuth2ClientSpec clientRegistrationRepository( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + this.clientRegistrationRepository = clientRegistrationRepository; + return this; + } + + /** + * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look + * the value up as a Bean. + * @param authorizedClientRepository the repository to use + * @return the {@link OAuth2ClientSpec} to customize + */ + public OAuth2ClientSpec authorizedClientRepository( + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this.authorizedClientRepository = authorizedClientRepository; + return this; + } + + /** + * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s. + * @param authorizationRequestRepository the repository to use for storing + * {@link OAuth2AuthorizationRequest}'s + * @return the {@link OAuth2ClientSpec} to customize + * @since 5.2 + */ + public OAuth2ClientSpec authorizationRequestRepository( + ServerAuthorizationRequestRepository authorizationRequestRepository) { + this.authorizationRequestRepository = authorizationRequestRepository; + return this; + } + + private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { + if (this.authorizationRequestRepository == null) { + this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + } + return this.authorizationRequestRepository; + } + + /** + * Allows method chaining to continue configuring the {@link ServerHttpSecurity} + * @return the {@link ServerHttpSecurity} to continue configuring + */ + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + protected void configure(ServerHttpSecurity http) { + ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); + ServerAuthenticationConverter authenticationConverter = getAuthenticationConverter(); + ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); + OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter( + authenticationManager, authenticationConverter, authorizedClientRepository); + codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + if (http.requestCache != null) { + codeGrantWebFilter.setRequestCache(http.requestCache.requestCache); + } + OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( + clientRegistrationRepository); + oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + if (http.requestCache != null) { + oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); + } + http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); + http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); + } + + private ReactiveClientRegistrationRepository getClientRegistrationRepository() { + if (this.clientRegistrationRepository != null) { + return this.clientRegistrationRepository; + } + return getBeanOrNull(ReactiveClientRegistrationRepository.class); + } + + private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { + if (this.authorizedClientRepository != null) { + return this.authorizedClientRepository; + } + ServerOAuth2AuthorizedClientRepository result = getBeanOrNull(ServerOAuth2AuthorizedClientRepository.class); + if (result != null) { + return result; + } + ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); + if (authorizedClientService != null) { + return new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(authorizedClientService); + } + return null; + } + + private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { + ReactiveOAuth2AuthorizedClientService bean = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); + if (bean != null) { + return bean; + } + return new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); + } + + } + + /** + * Configures OAuth2 Resource Server Support + */ + public class OAuth2ResourceServerSpec { + + private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); + + private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler(); + + private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter(); + + private AuthenticationConverterServerWebExchangeMatcher authenticationConverterServerWebExchangeMatcher; + + private JwtSpec jwt; + + private OpaqueTokenSpec opaqueToken; + + private ReactiveAuthenticationManagerResolver authenticationManagerResolver; + + /** + * Configures the {@link ServerAccessDeniedHandler} to use for requests + * authenticating with + * Bearer Tokens. requests. + * @param accessDeniedHandler the {@link ServerAccessDeniedHandler} to use + * @return the {@link OAuth2ResourceServerSpec} for additional configuration + * @since 5.2 + */ + public OAuth2ResourceServerSpec accessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { + Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null"); + this.accessDeniedHandler = accessDeniedHandler; + return this; + } + + /** + * Configures the {@link ServerAuthenticationEntryPoint} to use for requests + * authenticating with + * Bearer Tokens. + * @param entryPoint the {@link ServerAuthenticationEntryPoint} to use + * @return the {@link OAuth2ResourceServerSpec} for additional configuration + * @since 5.2 + */ + public OAuth2ResourceServerSpec authenticationEntryPoint(ServerAuthenticationEntryPoint entryPoint) { + Assert.notNull(entryPoint, "entryPoint cannot be null"); + this.entryPoint = entryPoint; + return this; + } + + /** + * Configures the {@link ServerAuthenticationConverter} to use for requests + * authenticating with + * Bearer Tokens. + * @param bearerTokenConverter The {@link ServerAuthenticationConverter} to use + * @return The {@link OAuth2ResourceServerSpec} for additional configuration + * @since 5.2 + */ + public OAuth2ResourceServerSpec bearerTokenConverter(ServerAuthenticationConverter bearerTokenConverter) { + Assert.notNull(bearerTokenConverter, "bearerTokenConverter cannot be null"); + this.bearerTokenConverter = bearerTokenConverter; + return this; + } + + /** + * Configures the {@link ReactiveAuthenticationManagerResolver} + * @param authenticationManagerResolver the + * {@link ReactiveAuthenticationManagerResolver} + * @return the {@link OAuth2ResourceServerSpec} for additional configuration + * @since 5.3 + */ + public OAuth2ResourceServerSpec authenticationManagerResolver( + ReactiveAuthenticationManagerResolver authenticationManagerResolver) { + Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); + this.authenticationManagerResolver = authenticationManagerResolver; + return this; + } + + /** + * Enables JWT Resource Server support. + * @return the {@link JwtSpec} for additional configuration + */ + public JwtSpec jwt() { + if (this.jwt == null) { + this.jwt = new JwtSpec(); + } + return this.jwt; + } + + /** + * Enables JWT Resource Server support. + * @param jwtCustomizer the {@link Customizer} to provide more options for the + * {@link JwtSpec} + * @return the {@link OAuth2ResourceServerSpec} to customize + */ + public OAuth2ResourceServerSpec jwt(Customizer jwtCustomizer) { + if (this.jwt == null) { + this.jwt = new JwtSpec(); + } + jwtCustomizer.customize(this.jwt); + return this; + } + + /** + * Enables Opaque Token Resource Server support. + * @return the {@link OpaqueTokenSpec} for additional configuration + */ + public OpaqueTokenSpec opaqueToken() { + if (this.opaqueToken == null) { + this.opaqueToken = new OpaqueTokenSpec(); + } + return this.opaqueToken; + } + + /** + * Enables Opaque Token Resource Server support. + * @param opaqueTokenCustomizer the {@link Customizer} to provide more options for + * the {@link OpaqueTokenSpec} + * @return the {@link OAuth2ResourceServerSpec} to customize + */ + public OAuth2ResourceServerSpec opaqueToken(Customizer opaqueTokenCustomizer) { + if (this.opaqueToken == null) { + this.opaqueToken = new OpaqueTokenSpec(); + } + opaqueTokenCustomizer.customize(this.opaqueToken); + return this; + } + + protected void configure(ServerHttpSecurity http) { + this.authenticationConverterServerWebExchangeMatcher = new AuthenticationConverterServerWebExchangeMatcher( + this.bearerTokenConverter); + registerDefaultAccessDeniedHandler(http); + registerDefaultAuthenticationEntryPoint(http); + registerDefaultCsrfOverride(http); + validateConfiguration(); + if (this.authenticationManagerResolver != null) { + AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver); + oauth2.setServerAuthenticationConverter(this.bearerTokenConverter); + oauth2.setAuthenticationFailureHandler( + new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); + http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); + } + else if (this.jwt != null) { + this.jwt.configure(http); + } + else if (this.opaqueToken != null) { + this.opaqueToken.configure(http); + } + } + + private void validateConfiguration() { + if (this.authenticationManagerResolver == null) { + Assert.state(this.jwt != null || this.opaqueToken != null, + "Jwt and Opaque Token are the only supported formats for bearer tokens " + + "in Spring Security and neither was found. Make sure to configure JWT " + + "via http.oauth2ResourceServer().jwt() or Opaque Tokens via " + + "http.oauth2ResourceServer().opaqueToken()."); + Assert.state(this.jwt == null || this.opaqueToken == null, + "Spring Security only supports JWTs or Opaque Tokens, not both at the " + "same time."); + } + else { + Assert.state(this.jwt == null && this.opaqueToken == null, + "If an authenticationManagerResolver() is configured, then it takes " + + "precedence over any jwt() or opaqueToken() configuration."); + } + } + + private void registerDefaultAccessDeniedHandler(ServerHttpSecurity http) { + if (http.exceptionHandling != null) { + http.defaultAccessDeniedHandlers + .add(new ServerWebExchangeDelegatingServerAccessDeniedHandler.DelegateEntry( + this.authenticationConverterServerWebExchangeMatcher, + OAuth2ResourceServerSpec.this.accessDeniedHandler)); + } + } + + private void registerDefaultAuthenticationEntryPoint(ServerHttpSecurity http) { + if (http.exceptionHandling != null) { + http.defaultEntryPoints.add(new DelegateEntry(this.authenticationConverterServerWebExchangeMatcher, + OAuth2ResourceServerSpec.this.entryPoint)); + } + } + + private void registerDefaultCsrfOverride(ServerHttpSecurity http) { + if (http.csrf != null && !http.csrf.specifiedRequireCsrfProtectionMatcher) { + AndServerWebExchangeMatcher matcher = new AndServerWebExchangeMatcher( + CsrfWebFilter.DEFAULT_CSRF_MATCHER, + new NegatedServerWebExchangeMatcher(this.authenticationConverterServerWebExchangeMatcher)); + http.csrf().requireCsrfProtectionMatcher(matcher); + } + } + + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + private class BearerTokenAuthenticationWebFilter extends AuthenticationWebFilter { + + private ServerAuthenticationFailureHandler authenticationFailureHandler; + + BearerTokenAuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) { + super(authenticationManager); + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); + return super.filter(exchange, chain).onErrorResume(AuthenticationException.class, + (e) -> this.authenticationFailureHandler.onAuthenticationFailure(webFilterExchange, e)); + } + + @Override + public void setAuthenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { + super.setAuthenticationFailureHandler(authenticationFailureHandler); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + } + + /** + * Configures JWT Resource Server Support + */ + public class JwtSpec { + + private ReactiveAuthenticationManager authenticationManager; + + private ReactiveJwtDecoder jwtDecoder; + + private Converter> jwtAuthenticationConverter = new ReactiveJwtAuthenticationConverterAdapter( + new JwtAuthenticationConverter()); + + /** + * Configures the {@link ReactiveAuthenticationManager} to use + * @param authenticationManager the authentication manager to use + * @return the {@code JwtSpec} for additional configuration + */ + public JwtSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + this.authenticationManager = authenticationManager; + return this; + } + + /** + * Configures the {@link Converter} to use for converting a {@link Jwt} into + * an {@link AbstractAuthenticationToken}. + * @param jwtAuthenticationConverter the converter to use + * @return the {@code JwtSpec} for additional configuration + * @since 5.1.1 + */ + public JwtSpec jwtAuthenticationConverter( + Converter> jwtAuthenticationConverter) { + Assert.notNull(jwtAuthenticationConverter, "jwtAuthenticationConverter cannot be null"); + this.jwtAuthenticationConverter = jwtAuthenticationConverter; + return this; + } + + /** + * Configures the {@link ReactiveJwtDecoder} to use + * @param jwtDecoder the decoder to use + * @return the {@code JwtSpec} for additional configuration + */ + public JwtSpec jwtDecoder(ReactiveJwtDecoder jwtDecoder) { + this.jwtDecoder = jwtDecoder; + return this; + } + + /** + * Configures a {@link ReactiveJwtDecoder} that leverages the provided + * {@link RSAPublicKey} + * @param publicKey the public key to use. + * @return the {@code JwtSpec} for additional configuration + */ + public JwtSpec publicKey(RSAPublicKey publicKey) { + this.jwtDecoder = new NimbusReactiveJwtDecoder(publicKey); + return this; + } + + /** + * Configures a {@link ReactiveJwtDecoder} using + * JSON Web Key + * (JWK) URL + * @param jwkSetUri the URL to use. + * @return the {@code JwtSpec} for additional configuration + */ + public JwtSpec jwkSetUri(String jwkSetUri) { + this.jwtDecoder = new NimbusReactiveJwtDecoder(jwkSetUri); + return this; + } + + public OAuth2ResourceServerSpec and() { + return OAuth2ResourceServerSpec.this; + } + + protected void configure(ServerHttpSecurity http) { + ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); + AuthenticationWebFilter oauth2 = new BearerTokenAuthenticationWebFilter(authenticationManager); + oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter); + oauth2.setAuthenticationFailureHandler( + new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint)); + http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); + } + + protected ReactiveJwtDecoder getJwtDecoder() { + return (this.jwtDecoder != null) ? this.jwtDecoder : getBean(ReactiveJwtDecoder.class); + } + + protected Converter> getJwtAuthenticationConverter() { + return this.jwtAuthenticationConverter; + } + + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager != null) { + return this.authenticationManager; + } + ReactiveJwtDecoder jwtDecoder = getJwtDecoder(); + Converter> jwtAuthenticationConverter = getJwtAuthenticationConverter(); + JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager( + jwtDecoder); + authenticationManager.setJwtAuthenticationConverter(jwtAuthenticationConverter); + return authenticationManager; + } + + } + + /** + * Configures Opaque Token Resource Server support + * + * @author Josh Cummings + * @since 5.2 + */ + public final class OpaqueTokenSpec { + + private String introspectionUri; + + private String clientId; + + private String clientSecret; + + private Supplier introspector; + + private OpaqueTokenSpec() { + } + + /** + * Configures the URI of the Introspection endpoint + * @param introspectionUri The URI of the Introspection endpoint + * @return the {@code OpaqueTokenSpec} for additional configuration + */ + public OpaqueTokenSpec introspectionUri(String introspectionUri) { + Assert.hasText(introspectionUri, "introspectionUri cannot be empty"); + this.introspectionUri = introspectionUri; + this.introspector = () -> new NimbusReactiveOpaqueTokenIntrospector(this.introspectionUri, + this.clientId, this.clientSecret); + return this; + } + + /** + * Configures the credentials for Introspection endpoint + * @param clientId The clientId part of the credentials + * @param clientSecret The clientSecret part of the credentials + * @return the {@code OpaqueTokenSpec} for additional configuration + */ + public OpaqueTokenSpec introspectionClientCredentials(String clientId, String clientSecret) { + Assert.hasText(clientId, "clientId cannot be empty"); + Assert.notNull(clientSecret, "clientSecret cannot be null"); + this.clientId = clientId; + this.clientSecret = clientSecret; + this.introspector = () -> new NimbusReactiveOpaqueTokenIntrospector(this.introspectionUri, + this.clientId, this.clientSecret); + return this; + } + + public OpaqueTokenSpec introspector(ReactiveOpaqueTokenIntrospector introspector) { + Assert.notNull(introspector, "introspector cannot be null"); + this.introspector = () -> introspector; + return this; + } + + /** + * Allows method chaining to continue configuring the + * {@link ServerHttpSecurity} + * @return the {@link ServerHttpSecurity} to continue configuring + */ + public OAuth2ResourceServerSpec and() { + return OAuth2ResourceServerSpec.this; + } + + protected ReactiveAuthenticationManager getAuthenticationManager() { + return new OpaqueTokenReactiveAuthenticationManager(getIntrospector()); + } + + protected ReactiveOpaqueTokenIntrospector getIntrospector() { + if (this.introspector != null) { + return this.introspector.get(); + } + return getBean(ReactiveOpaqueTokenIntrospector.class); + } + + protected void configure(ServerHttpSecurity http) { + ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); + AuthenticationWebFilter oauth2 = new BearerTokenAuthenticationWebFilter(authenticationManager); + oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter); + oauth2.setAuthenticationFailureHandler( + new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint)); + http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); + } + + } + } /** * Configures anonymous authentication + * * @author Ankur Pathak * @since 5.2.0 */ public final class AnonymousSpec { + private String key; + private AnonymousAuthenticationWebFilter authenticationFilter; + private Object principal = "anonymousUser"; + private List authorities = AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"); /** - * Sets the key to identify tokens created for anonymous authentication. Default is a - * secure randomly generated key. - * - * @param key the key to identify tokens created for anonymous authentication. Default + * Sets the key to identify tokens created for anonymous authentication. Default * is a secure randomly generated key. + * @param key the key to identify tokens created for anonymous authentication. + * Default is a secure randomly generated key. * @return the {@link AnonymousSpec} for further customization of anonymous * authentication */ @@ -3920,7 +3972,6 @@ public class ServerHttpSecurity { /** * Sets the principal for {@link Authentication} objects of anonymous users - * * @param principal used for the {@link Authentication} object of anonymous users * @return the {@link AnonymousSpec} for further customization of anonymous * authentication @@ -3931,9 +3982,9 @@ public class ServerHttpSecurity { } /** - * Sets the {@link org.springframework.security.core.Authentication#getAuthorities()} - * for anonymous users - * + * Sets the + * {@link org.springframework.security.core.Authentication#getAuthorities()} for + * anonymous users * @param authorities Sets the * {@link org.springframework.security.core.Authentication#getAuthorities()} for * anonymous users @@ -3946,9 +3997,9 @@ public class ServerHttpSecurity { } /** - * Sets the {@link org.springframework.security.core.Authentication#getAuthorities()} - * for anonymous users - * + * Sets the + * {@link org.springframework.security.core.Authentication#getAuthorities()} for + * anonymous users * @param authorities Sets the * {@link org.springframework.security.core.Authentication#getAuthorities()} for * anonymous users (i.e. "ROLE_ANONYMOUS") @@ -3960,18 +4011,15 @@ public class ServerHttpSecurity { } /** - * Sets the {@link AnonymousAuthenticationWebFilter} used to populate an anonymous user. - * If this is set, no attributes on the {@link AnonymousSpec} will be set on the - * {@link AnonymousAuthenticationWebFilter}. - * - * @param authenticationFilter the {@link AnonymousAuthenticationWebFilter} used to - * populate an anonymous user. - * + * Sets the {@link AnonymousAuthenticationWebFilter} used to populate an anonymous + * user. If this is set, no attributes on the {@link AnonymousSpec} will be set on + * the {@link AnonymousAuthenticationWebFilter}. + * @param authenticationFilter the {@link AnonymousAuthenticationWebFilter} used + * to populate an anonymous user. * @return the {@link AnonymousSpec} for further customization of anonymous * authentication */ - public AnonymousSpec authenticationFilter( - AnonymousAuthenticationWebFilter authenticationFilter) { + public AnonymousSpec authenticationFilter(AnonymousAuthenticationWebFilter authenticationFilter) { this.authenticationFilter = authenticationFilter; return this; } @@ -3994,22 +4042,22 @@ public class ServerHttpSecurity { } protected void configure(ServerHttpSecurity http) { - if (authenticationFilter == null) { - authenticationFilter = new AnonymousAuthenticationWebFilter(getKey(), principal, - authorities); + if (this.authenticationFilter == null) { + this.authenticationFilter = new AnonymousAuthenticationWebFilter(getKey(), this.principal, + this.authorities); } - http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.ANONYMOUS_AUTHENTICATION); + http.addFilterAt(this.authenticationFilter, SecurityWebFiltersOrder.ANONYMOUS_AUTHENTICATION); } private String getKey() { - if (key == null) { - key = UUID.randomUUID().toString(); + if (this.key == null) { + this.key = UUID.randomUUID().toString(); } - return key; + return this.key; } - - private AnonymousSpec() {} + private AnonymousSpec() { + } } diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index 24fe2003dc..ba8d3ac4e0 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.websocket; -import static org.springframework.security.config.Elements.*; +package org.springframework.security.config.websocket; import java.util.Comparator; import java.util.List; import java.util.Map; +import org.w3c.dom.Element; + import org.springframework.beans.BeansException; import org.springframework.beans.PropertyValue; import org.springframework.beans.factory.config.BeanDefinition; @@ -53,7 +54,6 @@ import org.springframework.util.AntPathMatcher; import org.springframework.util.PathMatcher; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; -import org.w3c.dom.Element; /** * Parses Spring Security's websocket namespace support. A simple example is: @@ -93,8 +93,8 @@ import org.w3c.dom.Element; * @author Rob Winch * @since 4.0 */ -public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements - BeanDefinitionParser { +public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements BeanDefinitionParser { + private static final String ID_ATTR = "id"; private static final String DISABLED_ATTR = "same-origin-disabled"; @@ -112,33 +112,24 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements * @param parserContext * @return the {@link BeanDefinition} */ + @Override public BeanDefinition parse(Element element, ParserContext parserContext) { BeanDefinitionRegistry registry = parserContext.getRegistry(); XmlReaderContext context = parserContext.getReaderContext(); - ManagedMap matcherToExpression = new ManagedMap<>(); - String id = element.getAttribute(ID_ATTR); - Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, - EXPRESSION_HANDLER); - String expressionHandlerRef = expressionHandlerElt == null ? null : expressionHandlerElt.getAttribute("ref"); + Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, Elements.EXPRESSION_HANDLER); + String expressionHandlerRef = (expressionHandlerElt != null) ? expressionHandlerElt.getAttribute("ref") : null; boolean expressionHandlerDefined = StringUtils.hasText(expressionHandlerRef); - - boolean sameOriginDisabled = Boolean.parseBoolean(element - .getAttribute(DISABLED_ATTR)); - - List interceptMessages = DomUtils.getChildElementsByTagName(element, - Elements.INTERCEPT_MESSAGE); + boolean sameOriginDisabled = Boolean.parseBoolean(element.getAttribute(DISABLED_ATTR)); + List interceptMessages = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_MESSAGE); for (Element interceptMessage : interceptMessages) { String matcherPattern = interceptMessage.getAttribute(PATTERN_ATTR); String accessExpression = interceptMessage.getAttribute(ACCESS_ATTR); String messageType = interceptMessage.getAttribute(TYPE_ATTR); - - BeanDefinition matcher = createMatcher(matcherPattern, messageType, - parserContext, interceptMessage); + BeanDefinition matcher = createMatcher(matcherPattern, messageType, parserContext, interceptMessage); matcherToExpression.put(matcher, accessExpression); } - BeanDefinitionBuilder mds = BeanDefinitionBuilder .rootBeanDefinition(ExpressionBasedMessageSecurityMetadataSourceFactory.class); mds.setFactoryMethod("createExpressionMessageMetadataSource"); @@ -146,58 +137,46 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements if (expressionHandlerDefined) { mds.addConstructorArgReference(expressionHandlerRef); } - String mdsId = context.registerWithGeneratedName(mds.getBeanDefinition()); - ManagedList voters = new ManagedList<>(); - BeanDefinitionBuilder messageExpressionVoterBldr = BeanDefinitionBuilder.rootBeanDefinition(MessageExpressionVoter.class); + BeanDefinitionBuilder messageExpressionVoterBldr = BeanDefinitionBuilder + .rootBeanDefinition(MessageExpressionVoter.class); if (expressionHandlerDefined) { messageExpressionVoterBldr.addPropertyReference("expressionHandler", expressionHandlerRef); } voters.add(messageExpressionVoterBldr.getBeanDefinition()); - BeanDefinitionBuilder adm = BeanDefinitionBuilder - .rootBeanDefinition(ConsensusBased.class); + BeanDefinitionBuilder adm = BeanDefinitionBuilder.rootBeanDefinition(ConsensusBased.class); adm.addConstructorArgValue(voters); - BeanDefinitionBuilder inboundChannelSecurityInterceptor = BeanDefinitionBuilder .rootBeanDefinition(ChannelSecurityInterceptor.class); - inboundChannelSecurityInterceptor.addConstructorArgValue(registry - .getBeanDefinition(mdsId)); - inboundChannelSecurityInterceptor.addPropertyValue("accessDecisionManager", - adm.getBeanDefinition()); + inboundChannelSecurityInterceptor.addConstructorArgValue(registry.getBeanDefinition(mdsId)); + inboundChannelSecurityInterceptor.addPropertyValue("accessDecisionManager", adm.getBeanDefinition()); String inSecurityInterceptorName = context - .registerWithGeneratedName(inboundChannelSecurityInterceptor - .getBeanDefinition()); - + .registerWithGeneratedName(inboundChannelSecurityInterceptor.getBeanDefinition()); if (StringUtils.hasText(id)) { registry.registerAlias(inSecurityInterceptorName, id); - if (!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) { registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, new RootBeanDefinition(AntPathMatcher.class)); } } else { - BeanDefinitionBuilder mspp = BeanDefinitionBuilder - .rootBeanDefinition(MessageSecurityPostProcessor.class); + BeanDefinitionBuilder mspp = BeanDefinitionBuilder.rootBeanDefinition(MessageSecurityPostProcessor.class); mspp.addConstructorArgValue(inSecurityInterceptorName); mspp.addConstructorArgValue(sameOriginDisabled); context.registerWithGeneratedName(mspp.getBeanDefinition()); } - return null; } - private BeanDefinition createMatcher(String matcherPattern, String messageType, - ParserContext parserContext, Element interceptMessage) { + private BeanDefinition createMatcher(String matcherPattern, String messageType, ParserContext parserContext, + Element interceptMessage) { boolean hasPattern = StringUtils.hasText(matcherPattern); boolean hasMessageType = StringUtils.hasText(messageType); if (!hasPattern) { - BeanDefinitionBuilder matcher = BeanDefinitionBuilder - .rootBeanDefinition(SimpMessageTypeMatcher.class); + BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpMessageTypeMatcher.class); matcher.addConstructorArgValue(messageType); return matcher.getBeanDefinition(); } - String factoryName = null; if (hasPattern && hasMessageType) { SimpMessageType type = SimpMessageType.valueOf(messageType); @@ -208,25 +187,18 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements factoryName = "createSubscribeMatcher"; } else { - parserContext - .getReaderContext() - .error("Cannot use intercept-websocket@message-type=" - + messageType - + " with a pattern because the type does not have a destination.", - interceptMessage); + parserContext.getReaderContext().error("Cannot use intercept-websocket@message-type=" + messageType + + " with a pattern because the type does not have a destination.", interceptMessage); } } - - BeanDefinitionBuilder matcher = BeanDefinitionBuilder - .rootBeanDefinition(SimpDestinationMessageMatcher.class); + BeanDefinitionBuilder matcher = BeanDefinitionBuilder.rootBeanDefinition(SimpDestinationMessageMatcher.class); matcher.setFactoryMethod(factoryName); matcher.addConstructorArgValue(matcherPattern); matcher.addConstructorArgValue(new RuntimeBeanReference("springSecurityMessagePathMatcher")); return matcher.getBeanDefinition(); } - static class MessageSecurityPostProcessor implements - BeanDefinitionRegistryPostProcessor { + static class MessageSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor { /** * This is not available prior to Spring 4.2 @@ -248,27 +220,24 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements this.sameOriginDisabled = sameOriginDisabled; } - public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) - throws BeansException { + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { String[] beanNames = registry.getBeanDefinitionNames(); for (String beanName : beanNames) { BeanDefinition bd = registry.getBeanDefinition(beanName); String beanClassName = bd.getBeanClassName(); - if (SimpAnnotationMethodMessageHandler.class.getName().equals(beanClassName) || - WEB_SOCKET_AMMH_CLASS_NAME.equals(beanClassName)) { - PropertyValue current = bd.getPropertyValues().getPropertyValue( - CUSTOM_ARG_RESOLVERS_PROP); + if (SimpAnnotationMethodMessageHandler.class.getName().equals(beanClassName) + || WEB_SOCKET_AMMH_CLASS_NAME.equals(beanClassName)) { + PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP); ManagedList argResolvers = new ManagedList<>(); if (current != null) { argResolvers.addAll((ManagedList) current.getValue()); } - argResolvers.add(new RootBeanDefinition( - AuthenticationPrincipalArgumentResolver.class)); + argResolvers.add(new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class)); bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers); - if (!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) { PropertyValue pathMatcherProp = bd.getPropertyValues().getPropertyValue("pathMatcher"); - Object pathMatcher = pathMatcherProp == null ? null : pathMatcherProp.getValue(); + Object pathMatcher = (pathMatcherProp != null) ? pathMatcherProp.getValue() : null; if (pathMatcher instanceof BeanReference) { registry.registerAlias(((BeanReference) pathMatcher).getBeanName(), PATH_MATCHER_BEAN_NAME); } @@ -287,87 +256,89 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements addCsrfTokenHandshakeInterceptor(bd); } } - if (!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) { return; } ManagedList interceptors = new ManagedList(); - interceptors.add(new RootBeanDefinition( - SecurityContextChannelInterceptor.class)); - if (!sameOriginDisabled) { + interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); + if (!this.sameOriginDisabled) { interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); } - interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId)); - - BeanDefinition inboundChannel = registry - .getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); + interceptors.add(registry.getBeanDefinition(this.inboundSecurityInterceptorId)); + BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); PropertyValue currentInterceptorsPv = inboundChannel.getPropertyValues() .getPropertyValue(INTERCEPTORS_PROP); if (currentInterceptorsPv != null) { - ManagedList currentInterceptors = (ManagedList) currentInterceptorsPv - .getValue(); + ManagedList currentInterceptors = (ManagedList) currentInterceptorsPv.getValue(); interceptors.addAll(currentInterceptors); } - inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors); - if (!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) { registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, new RootBeanDefinition(AntPathMatcher.class)); } } private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) { - if (sameOriginDisabled) { + if (this.sameOriginDisabled) { return; } String interceptorPropertyName = "handshakeInterceptors"; ManagedList interceptors = new ManagedList<>(); interceptors.add(new RootBeanDefinition(CsrfTokenHandshakeInterceptor.class)); - interceptors.addAll((ManagedList) bd.getPropertyValues().get( - interceptorPropertyName)); + interceptors.addAll((ManagedList) bd.getPropertyValues().get(interceptorPropertyName)); bd.getPropertyValues().add(interceptorPropertyName, interceptors); } - public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) - throws BeansException { + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { } + } static class DelegatingPathMatcher implements PathMatcher { private PathMatcher delegate = new AntPathMatcher(); + @Override public boolean isPattern(String path) { - return delegate.isPattern(path); + return this.delegate.isPattern(path); } + @Override public boolean match(String pattern, String path) { - return delegate.match(pattern, path); + return this.delegate.match(pattern, path); } + @Override public boolean matchStart(String pattern, String path) { - return delegate.matchStart(pattern, path); + return this.delegate.matchStart(pattern, path); } + @Override public String extractPathWithinPattern(String pattern, String path) { - return delegate.extractPathWithinPattern(pattern, path); + return this.delegate.extractPathWithinPattern(pattern, path); } + @Override public Map extractUriTemplateVariables(String pattern, String path) { - return delegate.extractUriTemplateVariables(pattern, path); + return this.delegate.extractUriTemplateVariables(pattern, path); } + @Override public Comparator getPatternComparator(String path) { - return delegate.getPatternComparator(path); + return this.delegate.getPatternComparator(path); } + @Override public String combine(String pattern1, String pattern2) { - return delegate.combine(pattern1, pattern2); + return this.delegate.combine(pattern1, pattern2); } void setPathMatcher(PathMatcher pathMatcher) { this.delegate = pathMatcher; } + } + } diff --git a/config/src/test/java/org/springframework/security/BeanNameCollectingPostProcessor.java b/config/src/test/java/org/springframework/security/BeanNameCollectingPostProcessor.java index ff39888179..e91b07ba4d 100644 --- a/config/src/test/java/org/springframework/security/BeanNameCollectingPostProcessor.java +++ b/config/src/test/java/org/springframework/security/BeanNameCollectingPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security; import java.util.HashSet; @@ -25,30 +26,33 @@ import org.springframework.beans.factory.config.BeanPostProcessor; * @author Luke Taylor */ public class BeanNameCollectingPostProcessor implements BeanPostProcessor { + Set beforeInitPostProcessedBeans = new HashSet<>(); + Set afterInitPostProcessedBeans = new HashSet<>(); - public Object postProcessBeforeInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { if (beanName != null) { - beforeInitPostProcessedBeans.add(beanName); + this.beforeInitPostProcessedBeans.add(beanName); } return bean; } - public Object postProcessAfterInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { if (beanName != null) { - afterInitPostProcessedBeans.add(beanName); + this.afterInitPostProcessedBeans.add(beanName); } return bean; } public Set getBeforeInitPostProcessedBeans() { - return beforeInitPostProcessedBeans; + return this.beforeInitPostProcessedBeans; } public Set getAfterInitPostProcessedBeans() { - return afterInitPostProcessedBeans; + return this.afterInitPostProcessedBeans; } + } diff --git a/config/src/test/java/org/springframework/security/CollectingAppListener.java b/config/src/test/java/org/springframework/security/CollectingAppListener.java index 1dd1c5c21a..5861750fe2 100644 --- a/config/src/test/java/org/springframework/security/CollectingAppListener.java +++ b/config/src/test/java/org/springframework/security/CollectingAppListener.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security; -import java.util.*; +import java.util.HashSet; +import java.util.Set; import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationListener; @@ -30,39 +32,45 @@ import org.springframework.security.authentication.event.AbstractAuthenticationF * @since 3.1 */ public class CollectingAppListener implements ApplicationListener { + Set events = new HashSet<>(); + Set authenticationEvents = new HashSet<>(); + Set authenticationFailureEvents = new HashSet<>(); + Set authorizationEvents = new HashSet<>(); + @Override public void onApplicationEvent(ApplicationEvent event) { if (event instanceof AbstractAuthenticationEvent) { - events.add(event); - authenticationEvents.add((AbstractAuthenticationEvent) event); + this.events.add(event); + this.authenticationEvents.add((AbstractAuthenticationEvent) event); } if (event instanceof AbstractAuthenticationFailureEvent) { - events.add(event); - authenticationFailureEvents.add((AbstractAuthenticationFailureEvent) event); + this.events.add(event); + this.authenticationFailureEvents.add((AbstractAuthenticationFailureEvent) event); } if (event instanceof AbstractAuthorizationEvent) { - events.add(event); - authorizationEvents.add((AbstractAuthorizationEvent) event); + this.events.add(event); + this.authorizationEvents.add((AbstractAuthorizationEvent) event); } } public Set getEvents() { - return events; + return this.events; } public Set getAuthenticationEvents() { - return authenticationEvents; + return this.authenticationEvents; } public Set getAuthenticationFailureEvents() { - return authenticationFailureEvents; + return this.authenticationFailureEvents; } public Set getAuthorizationEvents() { - return authorizationEvents; + return this.authorizationEvents; } + } diff --git a/config/src/test/java/org/springframework/security/config/ConfigTestUtils.java b/config/src/test/java/org/springframework/security/config/ConfigTestUtils.java index a95ffe3c4e..92075d6649 100644 --- a/config/src/test/java/org/springframework/security/config/ConfigTestUtils.java +++ b/config/src/test/java/org/springframework/security/config/ConfigTestUtils.java @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; public abstract class ConfigTestUtils { + + // @formatter:off public static final String AUTH_PROVIDER_XML = "" + " " + " " @@ -26,4 +29,6 @@ public abstract class ConfigTestUtils { + " " + " " + ""; + // @formatter:on + } diff --git a/config/src/test/java/org/springframework/security/config/DataSourcePopulator.java b/config/src/test/java/org/springframework/security/config/DataSourcePopulator.java index 212e2ad6b6..103addaebd 100644 --- a/config/src/test/java/org/springframework/security/config/DataSourcePopulator.java +++ b/config/src/test/java/org/springframework/security/config/DataSourcePopulator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import javax.sql.DataSource; @@ -27,18 +28,17 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class DataSourcePopulator implements InitializingBean { - // ~ Instance fields - // ================================================================================================ JdbcTemplate template; + @Override public void afterPropertiesSet() { - Assert.notNull(template, "dataSource required"); - - template.execute("CREATE TABLE USERS(USERNAME VARCHAR_IGNORECASE(50) NOT NULL PRIMARY KEY,PASSWORD VARCHAR_IGNORECASE(500) NOT NULL,ENABLED BOOLEAN NOT NULL);"); - template.execute("CREATE TABLE AUTHORITIES(USERNAME VARCHAR_IGNORECASE(50) NOT NULL,AUTHORITY VARCHAR_IGNORECASE(50) NOT NULL,CONSTRAINT FK_AUTHORITIES_USERS FOREIGN KEY(USERNAME) REFERENCES USERS(USERNAME));"); - template.execute("CREATE UNIQUE INDEX IX_AUTH_USERNAME ON AUTHORITIES(USERNAME,AUTHORITY);"); - + Assert.notNull(this.template, "dataSource required"); + this.template.execute( + "CREATE TABLE USERS(USERNAME VARCHAR_IGNORECASE(50) NOT NULL PRIMARY KEY,PASSWORD VARCHAR_IGNORECASE(500) NOT NULL,ENABLED BOOLEAN NOT NULL);"); + this.template.execute( + "CREATE TABLE AUTHORITIES(USERNAME VARCHAR_IGNORECASE(50) NOT NULL,AUTHORITY VARCHAR_IGNORECASE(50) NOT NULL,CONSTRAINT FK_AUTHORITIES_USERS FOREIGN KEY(USERNAME) REFERENCES USERS(USERNAME));"); + this.template.execute("CREATE UNIQUE INDEX IX_AUTH_USERNAME ON AUTHORITIES(USERNAME,AUTHORITY);"); /* * Passwords encoded using MD5, NOT in Base64 format, with null as salt Encoded * password for rod is "koala" Encoded password for dianne is "emu" Encoded @@ -46,24 +46,25 @@ public class DataSourcePopulator implements InitializingBean { * is disabled) Encoded password for bill is "wombat" Encoded password for bob is * "wombat" Encoded password for jane is "wombat" */ - template.execute("INSERT INTO USERS VALUES('rod','{noop}koala',TRUE);"); - template.execute("INSERT INTO USERS VALUES('dianne','{MD5}65d15fe9156f9c4bbffd98085992a44e',TRUE);"); - template.execute("INSERT INTO USERS VALUES('scott','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); - template.execute("INSERT INTO USERS VALUES('peter','{MD5}22b5c9accc6e1ba628cedc63a72d57f8',FALSE);"); - template.execute("INSERT INTO USERS VALUES('bill','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); - template.execute("INSERT INTO USERS VALUES('bob','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); - template.execute("INSERT INTO USERS VALUES('jane','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); - template.execute("INSERT INTO AUTHORITIES VALUES('rod','ROLE_USER');"); - template.execute("INSERT INTO AUTHORITIES VALUES('rod','ROLE_SUPERVISOR');"); - template.execute("INSERT INTO AUTHORITIES VALUES('dianne','ROLE_USER');"); - template.execute("INSERT INTO AUTHORITIES VALUES('scott','ROLE_USER');"); - template.execute("INSERT INTO AUTHORITIES VALUES('peter','ROLE_USER');"); - template.execute("INSERT INTO AUTHORITIES VALUES('bill','ROLE_USER');"); - template.execute("INSERT INTO AUTHORITIES VALUES('bob','ROLE_USER');"); - template.execute("INSERT INTO AUTHORITIES VALUES('jane','ROLE_USER');"); + this.template.execute("INSERT INTO USERS VALUES('rod','{noop}koala',TRUE);"); + this.template.execute("INSERT INTO USERS VALUES('dianne','{MD5}65d15fe9156f9c4bbffd98085992a44e',TRUE);"); + this.template.execute("INSERT INTO USERS VALUES('scott','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); + this.template.execute("INSERT INTO USERS VALUES('peter','{MD5}22b5c9accc6e1ba628cedc63a72d57f8',FALSE);"); + this.template.execute("INSERT INTO USERS VALUES('bill','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); + this.template.execute("INSERT INTO USERS VALUES('bob','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); + this.template.execute("INSERT INTO USERS VALUES('jane','{MD5}2b58af6dddbd072ed27ffc86725d7d3a',TRUE);"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('rod','ROLE_USER');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('rod','ROLE_SUPERVISOR');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('dianne','ROLE_USER');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('scott','ROLE_USER');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('peter','ROLE_USER');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('bill','ROLE_USER');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('bob','ROLE_USER');"); + this.template.execute("INSERT INTO AUTHORITIES VALUES('jane','ROLE_USER');"); } public void setDataSource(DataSource dataSource) { this.template = new JdbcTemplate(dataSource); } + } diff --git a/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java b/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java index 42a6eba6d5..d570897460 100644 --- a/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/FilterChainProxyConfigTests.java @@ -16,9 +16,6 @@ package org.springframework.security.config; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import java.util.List; import javax.servlet.Filter; @@ -43,6 +40,11 @@ import org.springframework.security.web.servletapi.SecurityContextHolderAwareReq import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AnyRequestMatcher; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + /** * Tests {@link FilterChainProxy}. * @@ -50,37 +52,32 @@ import org.springframework.security.web.util.matcher.AnyRequestMatcher; * @author Ben Alex */ public class FilterChainProxyConfigTests { - private ClassPathXmlApplicationContext appCtx; - // ~ Methods - // ======================================================================================================== + private ClassPathXmlApplicationContext appCtx; @Before public void loadContext() { System.setProperty("sec1235.pattern1", "/login"); System.setProperty("sec1235.pattern2", "/logout"); - appCtx = new ClassPathXmlApplicationContext( - "org/springframework/security/util/filtertest-valid.xml"); + this.appCtx = new ClassPathXmlApplicationContext("org/springframework/security/util/filtertest-valid.xml"); } @After public void closeContext() { - if (appCtx != null) { - appCtx.close(); + if (this.appCtx != null) { + this.appCtx.close(); } } @Test public void normalOperation() throws Exception { - FilterChainProxy filterChainProxy = appCtx.getBean("filterChain", - FilterChainProxy.class); + FilterChainProxy filterChainProxy = this.appCtx.getBean("filterChain", FilterChainProxy.class); doNormalOperation(filterChainProxy); } @Test public void normalOperationWithNewConfig() throws Exception { - FilterChainProxy filterChainProxy = appCtx.getBean("newFilterChainProxy", - FilterChainProxy.class); + FilterChainProxy filterChainProxy = this.appCtx.getBean("newFilterChainProxy", FilterChainProxy.class); filterChainProxy.setFirewall(new DefaultHttpFirewall()); checkPathAndFilterOrder(filterChainProxy); doNormalOperation(filterChainProxy); @@ -88,8 +85,7 @@ public class FilterChainProxyConfigTests { @Test public void normalOperationWithNewConfigRegex() throws Exception { - FilterChainProxy filterChainProxy = appCtx.getBean("newFilterChainProxyRegex", - FilterChainProxy.class); + FilterChainProxy filterChainProxy = this.appCtx.getBean("newFilterChainProxyRegex", FilterChainProxy.class); filterChainProxy.setFirewall(new DefaultHttpFirewall()); checkPathAndFilterOrder(filterChainProxy); doNormalOperation(filterChainProxy); @@ -97,8 +93,8 @@ public class FilterChainProxyConfigTests { @Test public void normalOperationWithNewConfigNonNamespace() throws Exception { - FilterChainProxy filterChainProxy = appCtx.getBean( - "newFilterChainProxyNonNamespace", FilterChainProxy.class); + FilterChainProxy filterChainProxy = this.appCtx.getBean("newFilterChainProxyNonNamespace", + FilterChainProxy.class); filterChainProxy.setFirewall(new DefaultHttpFirewall()); checkPathAndFilterOrder(filterChainProxy); doNormalOperation(filterChainProxy); @@ -106,43 +102,38 @@ public class FilterChainProxyConfigTests { @Test public void pathWithNoMatchHasNoFilters() { - FilterChainProxy filterChainProxy = appCtx.getBean( - "newFilterChainProxyNoDefaultPath", FilterChainProxy.class); + FilterChainProxy filterChainProxy = this.appCtx.getBean("newFilterChainProxyNoDefaultPath", + FilterChainProxy.class); assertThat(filterChainProxy.getFilters("/nomatch")).isNull(); } // SEC-1235 @Test public void mixingPatternsAndPlaceholdersDoesntCauseOrderingIssues() { - FilterChainProxy fcp = appCtx.getBean("sec1235FilterChainProxy", - FilterChainProxy.class); - + FilterChainProxy fcp = this.appCtx.getBean("sec1235FilterChainProxy", FilterChainProxy.class); List chains = fcp.getFilterChains(); assertThat(getPattern(chains.get(0))).isEqualTo("/login*"); assertThat(getPattern(chains.get(1))).isEqualTo("/logout"); - assertThat(((DefaultSecurityFilterChain) chains.get(2)).getRequestMatcher() instanceof AnyRequestMatcher).isTrue(); + assertThat(((DefaultSecurityFilterChain) chains.get(2)).getRequestMatcher() instanceof AnyRequestMatcher) + .isTrue(); } private String getPattern(SecurityFilterChain chain) { - return ((AntPathRequestMatcher) ((DefaultSecurityFilterChain) chain) - .getRequestMatcher()).getPattern(); + return ((AntPathRequestMatcher) ((DefaultSecurityFilterChain) chain).getRequestMatcher()).getPattern(); } private void checkPathAndFilterOrder(FilterChainProxy filterChainProxy) { List filters = filterChainProxy.getFilters("/foo/blah;x=1"); assertThat(filters).hasSize(1); assertThat(filters.get(0) instanceof SecurityContextHolderAwareRequestFilter).isTrue(); - filters = filterChainProxy.getFilters("/some;x=2,y=3/other/path;z=4/blah"); assertThat(filters).isNotNull(); assertThat(filters).hasSize(3); assertThat(filters.get(0) instanceof SecurityContextPersistenceFilter).isTrue(); assertThat(filters.get(1) instanceof SecurityContextHolderAwareRequestFilter).isTrue(); assertThat(filters.get(2) instanceof SecurityContextHolderAwareRequestFilter).isTrue(); - filters = filterChainProxy.getFilters("/do/not/filter;x=7"); assertThat(filters).isEmpty(); - filters = filterChainProxy.getFilters("/another/nonspecificmatch"); assertThat(filters).hasSize(3); assertThat(filters.get(0) instanceof SecurityContextPersistenceFilter).isTrue(); @@ -153,18 +144,14 @@ public class FilterChainProxyConfigTests { private void doNormalOperation(FilterChainProxy filterChainProxy) throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); request.setServletPath("/foo/secure/super/somefile.html"); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - filterChainProxy.doFilter(request, response, chain); - verify(chain).doFilter(any(HttpServletRequest.class), - any(HttpServletResponse.class)); - + verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); request.setServletPath("/a/path/which/doesnt/match/any/filter.html"); chain = mock(FilterChain.class); filterChainProxy.doFilter(request, response, chain); - verify(chain).doFilter(any(HttpServletRequest.class), - any(HttpServletResponse.class)); + verify(chain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } + } diff --git a/config/src/test/java/org/springframework/security/config/InvalidConfigurationTests.java b/config/src/test/java/org/springframework/security/config/InvalidConfigurationTests.java index 25ca973ef3..7d6134f7f3 100644 --- a/config/src/test/java/org/springframework/security/config/InvalidConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/InvalidConfigurationTests.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config; -import static org.assertj.core.api.Assertions.*; -import static org.junit.Assert.fail; +package org.springframework.security.config; import org.junit.After; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.xml.XmlBeanDefinitionStoreException; import org.springframework.security.config.authentication.AuthenticationManagerFactoryBean; import org.springframework.security.config.util.InMemoryXmlApplicationContext; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.fail; + /** * Tests which make sure invalid configurations are rejected by the namespace. In * particular invalid top-level elements. These are likely to fail after the namespace has @@ -34,12 +36,13 @@ import org.springframework.security.config.util.InMemoryXmlApplicationContext; * @author Luke Taylor */ public class InvalidConfigurationTests { + private InMemoryXmlApplicationContext appContext; @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); + if (this.appContext != null) { + this.appContext.close(); } } @@ -60,24 +63,24 @@ public class InvalidConfigurationTests { setContext(""); fail(); } - catch (BeanCreationException e) { - Throwable cause = ultimateCause(e); + catch (BeanCreationException ex) { + Throwable cause = ultimateCause(ex); assertThat(cause instanceof NoSuchBeanDefinitionException).isTrue(); NoSuchBeanDefinitionException nsbe = (NoSuchBeanDefinitionException) cause; assertThat(nsbe.getBeanName()).isEqualTo(BeanIds.AUTHENTICATION_MANAGER); - assertThat(nsbe.getMessage()).endsWith( - AuthenticationManagerFactoryBean.MISSING_BEAN_ERROR_MESSAGE); + assertThat(nsbe.getMessage()).endsWith(AuthenticationManagerFactoryBean.MISSING_BEAN_ERROR_MESSAGE); } } - private Throwable ultimateCause(Throwable e) { - if (e.getCause() == null) { - return e; + private Throwable ultimateCause(Throwable ex) { + if (ex.getCause() == null) { + return ex; } - return ultimateCause(e.getCause()); + return ultimateCause(ex.getCause()); } private void setContext(String context) { - appContext = new InMemoryXmlApplicationContext(context); + this.appContext = new InMemoryXmlApplicationContext(context); } + } diff --git a/config/src/test/java/org/springframework/security/config/MockAfterInvocationProvider.java b/config/src/test/java/org/springframework/security/config/MockAfterInvocationProvider.java index 6fc6ae1b3c..1891bf0178 100644 --- a/config/src/test/java/org/springframework/security/config/MockAfterInvocationProvider.java +++ b/config/src/test/java/org/springframework/security/config/MockAfterInvocationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import java.util.Collection; @@ -24,16 +25,18 @@ import org.springframework.security.core.Authentication; public class MockAfterInvocationProvider implements AfterInvocationProvider { - public Object decide(Authentication authentication, Object object, - Collection config, Object returnedObject) - throws AccessDeniedException { + @Override + public Object decide(Authentication authentication, Object object, Collection config, + Object returnedObject) throws AccessDeniedException { return returnedObject; } + @Override public boolean supports(ConfigAttribute attribute) { return true; } + @Override public boolean supports(Class clazz) { return true; } diff --git a/config/src/test/java/org/springframework/security/config/MockEventListener.java b/config/src/test/java/org/springframework/security/config/MockEventListener.java index f92f7bcc8d..05a5d5b4f9 100644 --- a/config/src/test/java/org/springframework/security/config/MockEventListener.java +++ b/config/src/test/java/org/springframework/security/config/MockEventListener.java @@ -16,20 +16,21 @@ package org.springframework.security.config; -import org.springframework.context.ApplicationEvent; -import org.springframework.context.ApplicationListener; - import java.util.ArrayList; import java.util.List; +import org.springframework.context.ApplicationEvent; +import org.springframework.context.ApplicationListener; + /** * @author Rob Winch * @since 5.0.2 */ -public class MockEventListener - implements ApplicationListener { +public class MockEventListener implements ApplicationListener { + private List events = new ArrayList<>(); + @Override public void onApplicationEvent(T event) { this.events.add(event); } @@ -37,4 +38,5 @@ public class MockEventListener public List getEvents() { return this.events; } + } diff --git a/config/src/test/java/org/springframework/security/config/MockTransactionManager.java b/config/src/test/java/org/springframework/security/config/MockTransactionManager.java index fb4a207c96..ebd49d751d 100644 --- a/config/src/test/java/org/springframework/security/config/MockTransactionManager.java +++ b/config/src/test/java/org/springframework/security/config/MockTransactionManager.java @@ -13,27 +13,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config; -import static org.mockito.Mockito.mock; +package org.springframework.security.config; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.TransactionDefinition; import org.springframework.transaction.TransactionException; import org.springframework.transaction.TransactionStatus; +import static org.mockito.Mockito.mock; + /** * @author Luke Taylor */ public class MockTransactionManager implements PlatformTransactionManager { - public TransactionStatus getTransaction(TransactionDefinition definition) - throws TransactionException { + + @Override + public TransactionStatus getTransaction(TransactionDefinition definition) throws TransactionException { return mock(TransactionStatus.class); } + @Override public void commit(TransactionStatus status) throws TransactionException { } + @Override public void rollback(TransactionStatus status) throws TransactionException { } + } diff --git a/config/src/test/java/org/springframework/security/config/MockUserServiceBeanPostProcessor.java b/config/src/test/java/org/springframework/security/config/MockUserServiceBeanPostProcessor.java index eb7acd044f..fdec646a62 100644 --- a/config/src/test/java/org/springframework/security/config/MockUserServiceBeanPostProcessor.java +++ b/config/src/test/java/org/springframework/security/config/MockUserServiceBeanPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import org.springframework.beans.BeansException; @@ -26,18 +27,17 @@ import org.springframework.beans.factory.config.BeanPostProcessor; */ public class MockUserServiceBeanPostProcessor implements BeanPostProcessor { - public Object postProcessAfterInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { return bean; } - public Object postProcessBeforeInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { if (bean instanceof PostProcessedMockUserDetailsService) { - ((PostProcessedMockUserDetailsService) bean) - .setPostProcessorWasHere("Hello from the post processor!"); + ((PostProcessedMockUserDetailsService) bean).setPostProcessorWasHere("Hello from the post processor!"); } - return bean; } + } diff --git a/config/src/test/java/org/springframework/security/config/PostProcessedMockUserDetailsService.java b/config/src/test/java/org/springframework/security/config/PostProcessedMockUserDetailsService.java index 2e85d78f7b..8cba5084f0 100644 --- a/config/src/test/java/org/springframework/security/config/PostProcessedMockUserDetailsService.java +++ b/config/src/test/java/org/springframework/security/config/PostProcessedMockUserDetailsService.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; public class PostProcessedMockUserDetailsService implements UserDetailsService { + private String postProcessorWasHere; public PostProcessedMockUserDetailsService() { @@ -26,14 +28,16 @@ public class PostProcessedMockUserDetailsService implements UserDetailsService { } public String getPostProcessorWasHere() { - return postProcessorWasHere; + return this.postProcessorWasHere; } public void setPostProcessorWasHere(String postProcessorWasHere) { this.postProcessorWasHere = postProcessorWasHere; } + @Override public UserDetails loadUserByUsername(String username) { throw new UnsupportedOperationException("Not for actual use"); } + } diff --git a/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java b/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java index a2d17aa5dc..c80bcc308c 100644 --- a/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java +++ b/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import org.apache.commons.logging.Log; @@ -20,6 +21,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -34,14 +36,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.powermock.api.mockito.PowerMockito.doThrow; -import static org.powermock.api.mockito.PowerMockito.mock; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.verifyStatic; -import static org.powermock.api.mockito.PowerMockito.verifyZeroInteractions; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyZeroInteractions; /** - * * @author Luke Taylor * @author Rob Winch * @since 3.0 @@ -50,15 +48,22 @@ import static org.powermock.api.mockito.PowerMockito.verifyZeroInteractions; @PrepareForTest({ ClassUtils.class }) @PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" }) public class SecurityNamespaceHandlerTests { + @Rule public ExpectedException thrown = ExpectedException.none(); + // @formatter:off private static final String XML_AUTHENTICATION_MANAGER = "" - + " " + " " + + " " + + " " + " " - + " " + " " + + " " + + " " + ""; + // @formatter:on + private static final String XML_HTTP_BLOCK = ""; + private static final String FILTER_CHAIN_PROXY_CLASSNAME = "org.springframework.security.web.FilterChainProxy"; @Test @@ -74,15 +79,13 @@ public class SecurityNamespaceHandlerTests { @Test public void pre32SchemaAreNotSupported() { try { - new InMemoryXmlApplicationContext( - "" - + " " - + "", "3.0.3", null); + new InMemoryXmlApplicationContext("" + + " " + "", "3.0.3", + null); fail("Expected BeanDefinitionParsingException"); } catch (BeanDefinitionParsingException expected) { - assertThat(expected.getMessage().contains( - "You cannot use a spring-security-2.0.xsd")); + assertThat(expected.getMessage().contains("You cannot use a spring-security-2.0.xsd")); } } @@ -90,17 +93,14 @@ public class SecurityNamespaceHandlerTests { @Test public void initDoesNotLogErrorWhenFilterChainProxyFailsToLoad() throws Exception { String className = "javax.servlet.Filter"; - spy(ClassUtils.class); - doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", + PowerMockito.spy(ClassUtils.class); + PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); - Log logger = mock(Log.class); SecurityNamespaceHandler handler = new SecurityNamespaceHandler(); ReflectionTestUtils.setField(handler, "logger", logger); - handler.init(); - - verifyStatic(ClassUtils.class); + PowerMockito.verifyStatic(ClassUtils.class); ClassUtils.forName(eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); verifyZeroInteractions(logger); } @@ -108,10 +108,10 @@ public class SecurityNamespaceHandlerTests { @Test public void filterNoClassDefFoundError() throws Exception { String className = "javax.servlet.Filter"; - thrown.expect(BeanDefinitionParsingException.class); - thrown.expectMessage("NoClassDefFoundError: " + className); - spy(ClassUtils.class); - doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", + this.thrown.expect(BeanDefinitionParsingException.class); + this.thrown.expectMessage("NoClassDefFoundError: " + className); + PowerMockito.spy(ClassUtils.class); + PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); } @@ -119,8 +119,8 @@ public class SecurityNamespaceHandlerTests { @Test public void filterNoClassDefFoundErrorNoHttpBlock() throws Exception { String className = "javax.servlet.Filter"; - spy(ClassUtils.class); - doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", + PowerMockito.spy(ClassUtils.class); + PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER); // should load just fine since no http block @@ -129,10 +129,10 @@ public class SecurityNamespaceHandlerTests { @Test public void filterChainProxyClassNotFoundException() throws Exception { String className = FILTER_CHAIN_PROXY_CLASSNAME; - thrown.expect(BeanDefinitionParsingException.class); - thrown.expectMessage("ClassNotFoundException: " + className); - spy(ClassUtils.class); - doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", + this.thrown.expect(BeanDefinitionParsingException.class); + this.thrown.expectMessage("ClassNotFoundException: " + className); + PowerMockito.spy(ClassUtils.class); + PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); } @@ -140,8 +140,8 @@ public class SecurityNamespaceHandlerTests { @Test public void filterChainProxyClassNotFoundExceptionNoHttpBlock() throws Exception { String className = FILTER_CHAIN_PROXY_CLASSNAME; - spy(ClassUtils.class); - doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", + PowerMockito.spy(ClassUtils.class); + PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER); // should load just fine since no http block @@ -150,10 +150,11 @@ public class SecurityNamespaceHandlerTests { @Test public void websocketNotFoundExceptionNoMessageBlock() throws Exception { String className = FILTER_CHAIN_PROXY_CLASSNAME; - spy(ClassUtils.class); - doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", + PowerMockito.spy(ClassUtils.class); + PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", eq(Message.class.getName()), any(ClassLoader.class)); new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER); // should load just fine since no websocket block } + } diff --git a/config/src/test/java/org/springframework/security/config/TestBusinessBean.java b/config/src/test/java/org/springframework/security/config/TestBusinessBean.java index 8a27748dcf..d066c7e30e 100644 --- a/config/src/test/java/org/springframework/security/config/TestBusinessBean.java +++ b/config/src/test/java/org/springframework/security/config/TestBusinessBean.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; /** @@ -29,4 +30,5 @@ public interface TestBusinessBean { void doSomething(); void unprotected(); + } diff --git a/config/src/test/java/org/springframework/security/config/TestBusinessBeanImpl.java b/config/src/test/java/org/springframework/security/config/TestBusinessBeanImpl.java index 18e5dbbe7f..ae11a823c1 100644 --- a/config/src/test/java/org/springframework/security/config/TestBusinessBeanImpl.java +++ b/config/src/test/java/org/springframework/security/config/TestBusinessBeanImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import org.springframework.context.ApplicationListener; @@ -21,15 +22,18 @@ import org.springframework.security.core.session.SessionCreationEvent; /** * @author Luke Taylor */ -public class TestBusinessBeanImpl implements TestBusinessBean, - ApplicationListener { +public class TestBusinessBeanImpl implements TestBusinessBean, ApplicationListener { + + @Override public void setInteger(int i) { } + @Override public int getInteger() { return 1314; } + @Override public void setString(String s) { } @@ -37,13 +41,17 @@ public class TestBusinessBeanImpl implements TestBusinessBean, return "A string."; } + @Override public void doSomething() { } + @Override public void unprotected() { } + @Override public void onApplicationEvent(SessionCreationEvent event) { System.out.println(event); } + } diff --git a/config/src/test/java/org/springframework/security/config/TransactionalTestBusinessBean.java b/config/src/test/java/org/springframework/security/config/TransactionalTestBusinessBean.java index f45648d261..447d8ae88b 100644 --- a/config/src/test/java/org/springframework/security/config/TransactionalTestBusinessBean.java +++ b/config/src/test/java/org/springframework/security/config/TransactionalTestBusinessBean.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config; import org.springframework.transaction.annotation.Transactional; @@ -21,20 +22,27 @@ import org.springframework.transaction.annotation.Transactional; * @author Luke Taylor */ public class TransactionalTestBusinessBean implements TestBusinessBean { + + @Override public void setInteger(int i) { } + @Override public int getInteger() { return 0; } + @Override public void setString(String s) { } + @Override @Transactional public void doSomething() { } + @Override public void unprotected() { } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/ConcereteSecurityConfigurerAdapter.java b/config/src/test/java/org/springframework/security/config/annotation/ConcereteSecurityConfigurerAdapter.java index 78c12afeea..05c6425ff2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/ConcereteSecurityConfigurerAdapter.java +++ b/config/src/test/java/org/springframework/security/config/annotation/ConcereteSecurityConfigurerAdapter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; import java.util.ArrayList; @@ -22,17 +23,18 @@ import java.util.List; * @author Rob Winch * */ -class ConcereteSecurityConfigurerAdapter extends - SecurityConfigurerAdapter> { +class ConcereteSecurityConfigurerAdapter extends SecurityConfigurerAdapter> { + private List list = new ArrayList<>(); @Override public void configure(SecurityBuilder builder) { - list = postProcess(list); + this.list = postProcess(this.list); } - public ConcereteSecurityConfigurerAdapter list(List l) { + ConcereteSecurityConfigurerAdapter list(List l) { this.list = l; return this; } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/ObjectPostProcessorTests.java b/config/src/test/java/org/springframework/security/config/annotation/ObjectPostProcessorTests.java index ad93dc340e..03da3c1814 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/ObjectPostProcessorTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/ObjectPostProcessorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation; import java.util.ArrayList; @@ -21,7 +22,7 @@ import java.util.List; import org.junit.Test; -import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatObject; /** * @author Rob Winch @@ -32,22 +33,24 @@ public class ObjectPostProcessorTests { @Test public void convertTypes() { - assertThat((Object) PerformConversion.perform(new ArrayList<>())) - .isInstanceOf(LinkedList.class); + assertThatObject(PerformConversion.perform(new ArrayList<>())).isInstanceOf(LinkedList.class); } static class ListToLinkedListObjectPostProcessor implements ObjectPostProcessor> { + @Override public > O postProcess(O l) { return (O) new LinkedList(l); } + } static class PerformConversion { - public static List perform(ArrayList l) { + + static List perform(ArrayList l) { return new ListToLinkedListObjectPostProcessor().postProcess(l); } + } + } - - diff --git a/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterClosureTests.java b/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterClosureTests.java index 9565546229..b98c51df78 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterClosureTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterClosureTests.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation; +package org.springframework.security.config.annotation; import java.util.ArrayList; import java.util.List; @@ -30,6 +30,7 @@ import static org.mockito.Mockito.mock; * */ public class SecurityConfigurerAdapterClosureTests { + ConcereteSecurityConfigurerAdapter conf = new ConcereteSecurityConfigurerAdapter(); @Test @@ -42,25 +43,25 @@ public class SecurityConfigurerAdapterClosureTests { return l; } }); - this.conf.init(builder); this.conf.configure(builder); - assertThat(this.conf.list).contains("a"); } - static class ConcereteSecurityConfigurerAdapter extends - SecurityConfigurerAdapter> { - private List list = new ArrayList(); + static class ConcereteSecurityConfigurerAdapter extends SecurityConfigurerAdapter> { + + private List list = new ArrayList<>(); @Override public void configure(SecurityBuilder builder) throws Exception { this.list = postProcess(this.list); } - public ConcereteSecurityConfigurerAdapter list(List l) { + ConcereteSecurityConfigurerAdapter list(List l) { this.list = l; return this; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterTests.java b/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterTests.java index 2b01f0137f..9e6aeba683 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/SecurityConfigurerAdapterTests.java @@ -13,45 +13,52 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.config.annotation; import org.junit.Before; import org.junit.Test; + import org.springframework.core.Ordered; +import static org.assertj.core.api.Assertions.assertThat; + public class SecurityConfigurerAdapterTests { + ConcereteSecurityConfigurerAdapter adapter; @Before public void setup() { - adapter = new ConcereteSecurityConfigurerAdapter(); + this.adapter = new ConcereteSecurityConfigurerAdapter(); } @Test public void postProcessObjectPostProcessorsAreSorted() { - adapter.addObjectPostProcessor(new OrderedObjectPostProcessor(Ordered.LOWEST_PRECEDENCE)); - adapter.addObjectPostProcessor(new OrderedObjectPostProcessor(Ordered.HIGHEST_PRECEDENCE)); - - assertThat(adapter.postProcess("hi")) + this.adapter.addObjectPostProcessor(new OrderedObjectPostProcessor(Ordered.LOWEST_PRECEDENCE)); + this.adapter.addObjectPostProcessor(new OrderedObjectPostProcessor(Ordered.HIGHEST_PRECEDENCE)); + assertThat(this.adapter.postProcess("hi")) .isEqualTo("hi " + Ordered.HIGHEST_PRECEDENCE + " " + Ordered.LOWEST_PRECEDENCE); } static class OrderedObjectPostProcessor implements ObjectPostProcessor, Ordered { + private final int order; OrderedObjectPostProcessor(int order) { this.order = order; } + @Override public int getOrder() { - return order; + return this.order; } + @Override @SuppressWarnings("unchecked") public String postProcess(String object) { - return object + " " + order; + return object + " " + this.order; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/AuthenticationManagerBuilderTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/AuthenticationManagerBuilderTests.java index 6b02c0cb3f..80aeefbd7e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/AuthenticationManagerBuilderTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/AuthenticationManagerBuilderTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication; +import java.util.Arrays; +import java.util.Properties; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; @@ -45,17 +50,15 @@ import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.crypto.password.NoOpPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.provisioning.InMemoryUserDetailsManager; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.test.web.servlet.MockMvc; -import java.util.Arrays; -import java.util.Properties; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; @@ -63,18 +66,20 @@ import static org.springframework.security.test.web.servlet.response.SecurityMoc * @author Rob Winch */ public class AuthenticationManagerBuilderTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); + @Autowired(required = false) + MockMvc mockMvc; + @Test public void buildWhenAddAuthenticationProviderThenDoesNotPerformRegistration() throws Exception { ObjectPostProcessor opp = mock(ObjectPostProcessor.class); AuthenticationProvider provider = mock(AuthenticationProvider.class); - AuthenticationManagerBuilder builder = new AuthenticationManagerBuilder(opp); builder.authenticationProvider(provider); builder.build(); - verify(opp, never()).postProcess(provider); } @@ -83,109 +88,55 @@ public class AuthenticationManagerBuilderTests { public void customAuthenticationEventPublisherWithWeb() throws Exception { ObjectPostProcessor opp = mock(ObjectPostProcessor.class); AuthenticationEventPublisher aep = mock(AuthenticationEventPublisher.class); - when(opp.postProcess(any())).thenAnswer(a -> a.getArgument(0)); - AuthenticationManager am = new AuthenticationManagerBuilder(opp) - .authenticationEventPublisher(aep) - .inMemoryAuthentication() - .and() - .build(); - + given(opp.postProcess(any())).willAnswer((a) -> a.getArgument(0)); + AuthenticationManager am = new AuthenticationManagerBuilder(opp).authenticationEventPublisher(aep) + .inMemoryAuthentication().and().build(); try { am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - } catch (AuthenticationException success) {} - + } + catch (AuthenticationException success) { + } verify(aep).publishAuthenticationFailure(any(), any()); } @Test public void getAuthenticationManagerWhenGlobalPasswordEncoderBeanThenUsed() throws Exception { this.spring.register(PasswordEncoderGlobalConfig.class).autowire(); - AuthenticationManager manager = this.spring.getContext() - .getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - + AuthenticationManager manager = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); Authentication auth = manager.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - assertThat(auth.getName()).isEqualTo("user"); assertThat(auth.getAuthorities()).extracting(GrantedAuthority::getAuthority).containsOnly("ROLE_USER"); } - @EnableWebSecurity - static class PasswordEncoderGlobalConfig extends WebSecurityConfigurerAdapter { - @Autowired - void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("user").password("password").roles("USER"); - } - - @Bean - PasswordEncoder passwordEncoder() { - return NoOpPasswordEncoder.getInstance(); - } - } - @Test public void getAuthenticationManagerWhenProtectedPasswordEncoderBeanThenUsed() throws Exception { this.spring.register(PasswordEncoderGlobalConfig.class).autowire(); - AuthenticationManager manager = this.spring.getContext() - .getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - + AuthenticationManager manager = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); Authentication auth = manager.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - assertThat(auth.getName()).isEqualTo("user"); assertThat(auth.getAuthorities()).extracting(GrantedAuthority::getAuthority).containsOnly("ROLE_USER"); } - @EnableWebSecurity - static class PasswordEncoderConfig extends WebSecurityConfigurerAdapter { - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("user").password("password").roles("USER"); - } - - @Bean - PasswordEncoder passwordEncoder() { - return NoOpPasswordEncoder.getInstance(); - } - } - - @Autowired(required = false) - MockMvc mockMvc; - @Test public void authenticationManagerWhenMultipleProvidersThenWorks() throws Exception { this.spring.register(MultiAuthenticationProvidersConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withUsername("user").withRoles("USER")); - - this.mockMvc.perform(formLogin().user("admin")) - .andExpect(authenticated().withUsername("admin").withRoles("USER", "ADMIN")); - } - - @EnableWebSecurity - static class MultiAuthenticationProvidersConfig - extends WebSecurityConfigurerAdapter { - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth.inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()) - .and() - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.admin()); - } - + SecurityMockMvcResultMatchers.AuthenticatedMatcher user = authenticated().withUsername("user") + .withRoles("USER"); + this.mockMvc.perform(formLogin()).andExpect(user); + SecurityMockMvcResultMatchers.AuthenticatedMatcher admin = authenticated().withUsername("admin") + .withRoles("USER", "ADMIN"); + this.mockMvc.perform(formLogin().user("admin")).andExpect(admin); } @Test public void buildWhenAuthenticationProviderThenIsConfigured() throws Exception { ObjectPostProcessor opp = mock(ObjectPostProcessor.class); AuthenticationProvider provider = mock(AuthenticationProvider.class); - AuthenticationManagerBuilder builder = new AuthenticationManagerBuilder(opp); builder.authenticationProvider(provider); builder.build(); - assertThat(builder.isConfigured()).isTrue(); } @@ -193,29 +144,79 @@ public class AuthenticationManagerBuilderTests { public void buildWhenParentThenIsConfigured() throws Exception { ObjectPostProcessor opp = mock(ObjectPostProcessor.class); AuthenticationManager parent = mock(AuthenticationManager.class); - AuthenticationManagerBuilder builder = new AuthenticationManagerBuilder(opp); builder.parentAuthenticationManager(parent); builder.build(); - assertThat(builder.isConfigured()).isTrue(); } @Test public void buildWhenNotConfiguredThenIsConfiguredFalse() throws Exception { ObjectPostProcessor opp = mock(ObjectPostProcessor.class); - AuthenticationManagerBuilder builder = new AuthenticationManagerBuilder(opp); builder.build(); - assertThat(builder.isConfigured()).isFalse(); } public void buildWhenUserFromProperties() throws Exception { this.spring.register(UserFromPropertiesConfig.class).autowire(); - this.mockMvc.perform(formLogin().user("joe", "joespassword")) - .andExpect(authenticated().withUsername("joe").withRoles("USER")); + .andExpect(authenticated().withUsername("joe").withRoles("USER")); + } + + @EnableWebSecurity + static class MultiAuthenticationProvidersConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()) + .and() + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.admin()); + // @formatter:on + } + + } + + @EnableWebSecurity + static class PasswordEncoderGlobalConfig extends WebSecurityConfigurerAdapter { + + @Autowired + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("user").password("password").roles("USER"); + // @formatter:on + } + + @Bean + PasswordEncoder passwordEncoder() { + return NoOpPasswordEncoder.getInstance(); + } + + } + + @EnableWebSecurity + static class PasswordEncoderConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("user").password("password").roles("USER"); + // @formatter:on + } + + @Bean + PasswordEncoder passwordEncoder() { + return NoOpPasswordEncoder.getInstance(); + } + } @Configuration @@ -227,23 +228,24 @@ public class AuthenticationManagerBuilderTests { Resource users; @Bean - public AuthenticationManager authenticationManager() throws Exception { + AuthenticationManager authenticationManager() throws Exception { return new ProviderManager(Arrays.asList(authenticationProvider())); } @Bean - public AuthenticationProvider authenticationProvider() throws Exception { + AuthenticationProvider authenticationProvider() throws Exception { DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setUserDetailsService(userDetailsService()); return provider; } @Bean - public UserDetailsService userDetailsService() throws Exception { + UserDetailsService userDetailsService() throws Exception { Properties properties = new Properties(); properties.load(this.users.getInputStream()); return new InMemoryUserDetailsManager(properties); } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/BaseAuthenticationConfig.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/BaseAuthenticationConfig.java index 172d463d46..0f9ffa5107 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/BaseAuthenticationConfig.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/BaseAuthenticationConfig.java @@ -13,24 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; - /** - * * @author Rob Winch */ @Configuration public class BaseAuthenticationConfig { + @Autowired protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationManagerTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationManagerTests.java index 31c48eca29..f1d828d7e2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationManagerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationManagerTests.java @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.userdetails.PasswordEncodedUser; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.test.web.servlet.MockMvc; import static org.assertj.core.api.Assertions.assertThat; @@ -33,6 +36,7 @@ import static org.springframework.security.test.web.servlet.response.SecurityMoc * @author Rob Winch */ public class NamespaceAuthenticationManagerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -42,65 +46,74 @@ public class NamespaceAuthenticationManagerTests { @Test public void authenticationMangerWhenDefaultThenEraseCredentialsIsTrue() throws Exception { this.spring.register(EraseCredentialsTrueDefaultConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthentication(a-> assertThat(a.getCredentials()).isNull())); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthentication(a-> assertThat(a.getCredentials()).isNull())); + SecurityMockMvcResultMatchers.AuthenticatedMatcher nullCredentials = authenticated() + .withAuthentication((a) -> assertThat(a.getCredentials()).isNull()); + this.mockMvc.perform(formLogin()).andExpect(nullCredentials); + this.mockMvc.perform(formLogin()).andExpect(nullCredentials); // no exception due to username being cleared out } - @EnableWebSecurity - static class EraseCredentialsTrueDefaultConfig extends WebSecurityConfigurerAdapter { - @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - } - @Test public void authenticationMangerWhenEraseCredentialsIsFalseThenCredentialsNotNull() throws Exception { this.spring.register(EraseCredentialsFalseConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthentication(a-> assertThat(a.getCredentials()).isNotNull())); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthentication(a-> assertThat(a.getCredentials()).isNotNull())); + SecurityMockMvcResultMatchers.AuthenticatedMatcher notNullCredentials = authenticated() + .withAuthentication((a) -> assertThat(a.getCredentials()).isNotNull()); + this.mockMvc.perform(formLogin()).andExpect(notNullCredentials); + this.mockMvc.perform(formLogin()).andExpect(notNullCredentials); // no exception due to username being cleared out } - @EnableWebSecurity - static class EraseCredentialsFalseConfig extends WebSecurityConfigurerAdapter { - @Override - public void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .eraseCredentials(false) - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - } - @Test // SEC-2533 public void authenticationManagerWhenGlobalAndEraseCredentialsIsFalseThenCredentialsNotNull() throws Exception { this.spring.register(GlobalEraseCredentialsFalseConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthentication(a-> assertThat(a.getCredentials()).isNotNull())); + SecurityMockMvcResultMatchers.AuthenticatedMatcher notNullCredentials = authenticated() + .withAuthentication((a) -> assertThat(a.getCredentials()).isNotNull()); + this.mockMvc.perform(formLogin()).andExpect(notNullCredentials); } @EnableWebSecurity - static class GlobalEraseCredentialsFalseConfig extends WebSecurityConfigurerAdapter { + static class EraseCredentialsTrueDefaultConfig extends WebSecurityConfigurerAdapter { + @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + + @EnableWebSecurity + static class EraseCredentialsFalseConfig extends WebSecurityConfigurerAdapter { + + @Override + public void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .eraseCredentials(false) .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } + } + + @EnableWebSecurity + static class GlobalEraseCredentialsFalseConfig extends WebSecurityConfigurerAdapter { + + @Autowired + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .eraseCredentials(false) + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationProviderTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationProviderTests.java index b26b9fc2c6..f29882cea9 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationProviderTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceAuthenticationProviderTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.authentication; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.authentication.dao.DaoAuthenticationProvider; @@ -48,45 +49,53 @@ public class NamespaceAuthenticationProviderTests { // authentication-provider@ref public void authenticationProviderRef() throws Exception { this.spring.register(AuthenticationProviderRefConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withUsername("user")); - } - - @EnableWebSecurity - static class AuthenticationProviderRefConfig extends WebSecurityConfigurerAdapter { - protected void configure(AuthenticationManagerBuilder auth) { - auth - .authenticationProvider(authenticationProvider()); - } - - @Bean - public DaoAuthenticationProvider authenticationProvider() { - DaoAuthenticationProvider result = new DaoAuthenticationProvider(); - result.setUserDetailsService(new InMemoryUserDetailsManager(PasswordEncodedUser.user())); - return result; - } + this.mockMvc.perform(formLogin()).andExpect(authenticated().withUsername("user")); } @Test // authentication-provider@user-service-ref public void authenticationProviderUserServiceRef() throws Exception { this.spring.register(AuthenticationProviderRefConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(authenticated().withUsername("user")); + } + + @EnableWebSecurity + static class AuthenticationProviderRefConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) { + // @formatter:off + auth + .authenticationProvider(authenticationProvider()); + // @formatter:on + } + + @Bean + DaoAuthenticationProvider authenticationProvider() { + DaoAuthenticationProvider result = new DaoAuthenticationProvider(); + result.setUserDetailsService(new InMemoryUserDetailsManager(PasswordEncodedUser.user())); + return result; + } - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withUsername("user")); } @EnableWebSecurity static class UserServiceRefConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .userDetailsService(userDetailsService()); + // @formatter:on } + @Override @Bean public UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager(PasswordEncodedUser.user()); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceJdbcUserServiceTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceJdbcUserServiceTests.java index b1242b5a3f..86e28e1003 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceJdbcUserServiceTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespaceJdbcUserServiceTests.java @@ -16,8 +16,11 @@ package org.springframework.security.config.annotation.authentication; +import javax.sql.DataSource; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -30,11 +33,10 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.UserCache; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.security.provisioning.JdbcUserDetailsManager; +import org.springframework.security.core.userdetails.jdbc.JdbcDaoImpl; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.test.web.servlet.MockMvc; -import javax.sql.DataSource; - import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; @@ -52,48 +54,60 @@ public class NamespaceJdbcUserServiceTests { @Test public void jdbcUserService() throws Exception { this.spring.register(DataSourceConfig.class, JdbcUserServiceConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withUsername("user")); - } - - @EnableWebSecurity - static class JdbcUserServiceConfig extends WebSecurityConfigurerAdapter { - @Autowired - private DataSource dataSource; - - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .jdbcAuthentication() - .withDefaultSchema() - .withUser(PasswordEncodedUser.user()) - .dataSource(this.dataSource); // jdbc-user-service@data-source-ref - } - } - - @Configuration - static class DataSourceConfig { - @Bean - public DataSource dataSource() { - EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); - return builder.setType(EmbeddedDatabaseType.HSQL).build(); - } + SecurityMockMvcResultMatchers.AuthenticatedMatcher user = authenticated().withUsername("user"); + this.mockMvc.perform(formLogin()).andExpect(user); } @Test public void jdbcUserServiceCustom() throws Exception { this.spring.register(CustomDataSourceConfig.class, CustomJdbcUserServiceSampleConfig.class).autowire(); + // @formatter:off + SecurityMockMvcResultMatchers.AuthenticatedMatcher dba = authenticated() + .withUsername("user") + .withRoles("DBA", "USER"); + // @formatter:on + this.mockMvc.perform(formLogin()).andExpect(dba); + } + + @EnableWebSecurity + static class JdbcUserServiceConfig extends WebSecurityConfigurerAdapter { + + @Autowired + private DataSource dataSource; + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .jdbcAuthentication() + .withDefaultSchema() + .withUser(PasswordEncodedUser.user()) + .dataSource(this.dataSource); // jdbc-user-service@data-source-ref + // @formatter:on + } + + } + + @Configuration + static class DataSourceConfig { + + @Bean + DataSource dataSource() { + EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); + return builder.setType(EmbeddedDatabaseType.HSQL).build(); + } - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withUsername("user").withRoles("DBA", "USER")); } @EnableWebSecurity static class CustomJdbcUserServiceSampleConfig extends WebSecurityConfigurerAdapter { + @Autowired private DataSource dataSource; + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .jdbcAuthentication() // jdbc-user-service@dataSource @@ -105,10 +119,10 @@ public class NamespaceJdbcUserServiceTests { // jdbc-user-service@authorities-by-username-query .authoritiesByUsernameQuery("select principal,role from roles where principal = ?") // jdbc-user-service@group-authorities-by-username-query - .groupAuthoritiesByUsername(JdbcUserDetailsManager.DEF_GROUP_AUTHORITIES_BY_USERNAME_QUERY) + .groupAuthoritiesByUsername(JdbcDaoImpl.DEF_GROUP_AUTHORITIES_BY_USERNAME_QUERY) // jdbc-user-service@role-prefix .rolePrefix("ROLE_"); - + // @formatter:on } static class CustomUserCache implements UserCache { @@ -125,16 +139,22 @@ public class NamespaceJdbcUserServiceTests { @Override public void removeUserFromCache(String username) { } + } + } + @Configuration static class CustomDataSourceConfig { + @Bean - public DataSource dataSource() { + DataSource dataSource() { EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder() - // simulate that the DB already has the schema loaded and users in it - .addScript("CustomJdbcUserServiceSampleConfig.sql"); + // simulate that the DB already has the schema loaded and users in it + .addScript("CustomJdbcUserServiceSampleConfig.sql"); return builder.setType(EmbeddedDatabaseType.HSQL).build(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespacePasswordEncoderTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespacePasswordEncoderTests.java index d8c2ce0ccc..91759cd703 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespacePasswordEncoderTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/NamespacePasswordEncoderTests.java @@ -16,8 +16,11 @@ package org.springframework.security.config.annotation.authentication; +import javax.sql.DataSource; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; @@ -32,8 +35,6 @@ import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.test.web.servlet.MockMvc; -import javax.sql.DataSource; - import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; @@ -51,80 +52,88 @@ public class NamespacePasswordEncoderTests { @Test public void passwordEncoderRefWithInMemory() throws Exception { this.spring.register(PasswordEncoderWithInMemoryConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated()); - } - - @EnableWebSecurity - static class PasswordEncoderWithInMemoryConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(); - auth - .inMemoryAuthentication() - .withUser("user").password(encoder.encode("password")).roles("USER").and() - .passwordEncoder(encoder); - } + this.mockMvc.perform(formLogin()).andExpect(authenticated()); } @Test public void passwordEncoderRefWithJdbc() throws Exception { this.spring.register(PasswordEncoderWithJdbcConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(authenticated()); + } + + @Test + public void passwordEncoderRefWithUserDetailsService() throws Exception { + this.spring.register(PasswordEncoderWithUserDetailsServiceConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(authenticated()); + } + + @EnableWebSecurity + static class PasswordEncoderWithInMemoryConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(); + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("user").password(encoder.encode("password")).roles("USER").and() + .passwordEncoder(encoder); + // @formatter:on + } - this.mockMvc.perform(formLogin()) - .andExpect(authenticated()); } @EnableWebSecurity static class PasswordEncoderWithJdbcConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { - BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(); + // @formatter:off auth .jdbcAuthentication() .withDefaultSchema() .dataSource(dataSource()) .withUser("user").password(encoder.encode("password")).roles("USER").and() .passwordEncoder(encoder); + // @formatter:on } @Bean - public DataSource dataSource() { + DataSource dataSource() { EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); return builder.setType(EmbeddedDatabaseType.HSQL).build(); } - } - @Test - public void passwordEncoderRefWithUserDetailsService() throws Exception { - this.spring.register(PasswordEncoderWithUserDetailsServiceConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated()); } @EnableWebSecurity static class PasswordEncoderWithUserDetailsServiceConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(); + // @formatter:off UserDetails user = User.withUsername("user") .passwordEncoder(encoder::encode) .password("password") .roles("USER") .build(); + // @formatter:on InMemoryUserDetailsManager uds = new InMemoryUserDetailsManager(user); + // @formatter:off auth .userDetailsService(uds) .passwordEncoder(encoder); + // @formatter:on } @Bean - public DataSource dataSource() { + DataSource dataSource() { EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); return builder.setType(EmbeddedDatabaseType.HSQL).build(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/PasswordEncoderConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/PasswordEncoderConfigurerTests.java index e145d411c5..456efd48f6 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/PasswordEncoderConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/PasswordEncoderConfigurerTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.authentication; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; @@ -47,52 +48,56 @@ public class PasswordEncoderConfigurerTests { this.spring.register(PasswordEncoderConfig.class).autowire(); } + @Test + public void passwordEncoderRefWhenAuthenticationManagerBuilderThenAuthenticationSuccess() throws Exception { + this.spring.register(PasswordEncoderNoAuthManagerLoadsConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(authenticated()); + } + @EnableWebSecurity static class PasswordEncoderConfig extends WebSecurityConfigurerAdapter { - // @formatter:off + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { BCryptPasswordEncoder encoder = passwordEncoder(); + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password(encoder.encode("password")).roles("USER").and() .passwordEncoder(encoder); + // @formatter:on } - // @formatter:on @Override protected void configure(HttpSecurity http) { } @Bean - public BCryptPasswordEncoder passwordEncoder() { + BCryptPasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); } - } - @Test - public void passwordEncoderRefWhenAuthenticationManagerBuilderThenAuthenticationSuccess() throws Exception { - this.spring.register(PasswordEncoderNoAuthManagerLoadsConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated()); } @EnableWebSecurity - static class PasswordEncoderNoAuthManagerLoadsConfig extends - WebSecurityConfigurerAdapter { - // @formatter:off + static class PasswordEncoderNoAuthManagerLoadsConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { BCryptPasswordEncoder encoder = passwordEncoder(); + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password(encoder.encode("password")).roles("USER").and() .passwordEncoder(encoder); + // @formatter:on } - // @formatter:on @Bean - public BCryptPasswordEncoder passwordEncoder() { + BCryptPasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationPublishTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationPublishTests.java index d797632e5b..3c6ec39bf5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationPublishTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationPublishTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.authentication.configurat import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -30,13 +31,14 @@ import org.springframework.security.config.MockEventListener; import org.springframework.security.config.users.AuthenticationTestConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch */ @RunWith(SpringJUnit4ClassRunner.class) public class AuthenticationConfigurationPublishTests { + @Autowired MockEventListener listener; @@ -46,7 +48,6 @@ public class AuthenticationConfigurationPublishTests { @Test public void authenticationEventPublisherBeanUsedByDefault() { this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - assertThat(this.listener.getEvents()).hasSize(1); } @@ -58,6 +59,7 @@ public class AuthenticationConfigurationPublishTests { @EnableGlobalAuthentication @Import(AuthenticationTestConfiguration.class) static class Config { + @Bean AuthenticationEventPublisher publisher() { return new DefaultAuthenticationEventPublisher(); @@ -65,8 +67,10 @@ public class AuthenticationConfigurationPublishTests { @Bean MockEventListener eventListener() { - return new MockEventListener(){}; + return new MockEventListener() { + }; } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationTests.java index 3ba459170f..d410060767 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/AuthenticationConfigurationTests.java @@ -13,11 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configuration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.junit.After; import org.junit.Rule; import org.junit.Test; + import org.springframework.aop.framework.ProxyFactoryBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -50,19 +56,20 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.provisioning.InMemoryUserDetailsManager; -import org.springframework.security.core.userdetails.UserDetailsPasswordService; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.startsWith; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; public class AuthenticationConfigurationTests { @@ -79,153 +86,87 @@ public class AuthenticationConfigurationTests { @Test public void orderingAutowiredOnEnableGlobalMethodSecurity() { - this.spring.register(AuthenticationTestConfiguration.class, GlobalMethodSecurityAutowiredConfig.class, ServicesConfig.class).autowire(); - - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); + this.spring.register(AuthenticationTestConfiguration.class, GlobalMethodSecurityAutowiredConfig.class, + ServicesConfig.class).autowire(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); this.service.run(); } - @EnableGlobalMethodSecurity(securedEnabled = true) - static class GlobalMethodSecurityAutowiredConfig { - } - @Test public void orderingAutowiredOnEnableWebSecurity() { - this.spring.register(AuthenticationTestConfiguration.class, WebSecurityConfig.class, GlobalMethodSecurityAutowiredConfig.class, ServicesConfig.class).autowire(); - - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); + this.spring.register(AuthenticationTestConfiguration.class, WebSecurityConfig.class, + GlobalMethodSecurityAutowiredConfig.class, ServicesConfig.class).autowire(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); this.service.run(); } - @EnableWebSecurity - static class WebSecurityConfig {} - - @Test public void orderingAutowiredOnEnableWebMvcSecurity() { - this.spring.register(AuthenticationTestConfiguration.class, WebMvcSecurityConfig.class, GlobalMethodSecurityAutowiredConfig.class, ServicesConfig.class).autowire(); - - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); + this.spring.register(AuthenticationTestConfiguration.class, WebMvcSecurityConfig.class, + GlobalMethodSecurityAutowiredConfig.class, ServicesConfig.class).autowire(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); this.service.run(); } - @EnableWebMvcSecurity - static class WebMvcSecurityConfig {} - @Test public void getAuthenticationManagerWhenNoAuthenticationThenNull() throws Exception { this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class).autowire(); - - assertThat(this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager()).isNull(); + assertThat(this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager()) + .isNull(); } - @Test public void getAuthenticationManagerWhenNoOpGlobalAuthenticationConfigurerAdapterThenNull() throws Exception { - this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, NoOpGlobalAuthenticationConfigurerAdapter.class).autowire(); - - assertThat(this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager()).isNull(); + this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, + NoOpGlobalAuthenticationConfigurerAdapter.class).autowire(); + assertThat(this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager()) + .isNull(); } - @Configuration - static class NoOpGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter {} - @Test public void getAuthenticationWhenGlobalAuthenticationConfigurerAdapterThenAuthenticates() throws Exception { UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password"); - this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, UserGlobalAuthenticationConfigurerAdapter.class).autowire(); - - AuthenticationManager authentication = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - + this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, + UserGlobalAuthenticationConfigurerAdapter.class).autowire(); + AuthenticationManager authentication = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); assertThat(authentication.authenticate(token).getName()).isEqualTo(token.getName()); } - @Configuration - static class UserGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { - public void init(AuthenticationManagerBuilder auth) throws Exception { - auth.inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - } - @Test public void getAuthenticationWhenAuthenticationManagerBeanThenAuthenticates() throws Exception { UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password"); - this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, AuthenticationManagerBeanConfig.class).autowire(); - - AuthenticationManager authentication = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - when(authentication.authenticate(token)).thenReturn(TestAuthentication.authenticatedUser()); - + this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, + AuthenticationManagerBeanConfig.class).autowire(); + AuthenticationManager authentication = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); + given(authentication.authenticate(token)).willReturn(TestAuthentication.authenticatedUser()); assertThat(authentication.authenticate(token).getName()).isEqualTo(token.getName()); } - @Configuration - static class AuthenticationManagerBeanConfig { - AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - - @Bean - public AuthenticationManager authenticationManager() { - return this.authenticationManager; - } + @Test + public void getAuthenticationWhenMultipleThenOrdered() throws Exception { + this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, + AuthenticationManagerBeanConfig.class).autowire(); + AuthenticationConfiguration config = this.spring.getContext().getBean(AuthenticationConfiguration.class); + config.setGlobalAuthenticationConfigurers(Arrays.asList(new LowestOrderGlobalAuthenticationConfigurerAdapter(), + new HighestOrderGlobalAuthenticationConfigurerAdapter(), + new DefaultOrderGlobalAuthenticationConfigurerAdapter())); } - // - // // - // - @Configuration - static class ServicesConfig { - @Bean - public Service service() { - return new ServiceImpl(); - } - } - - interface Service { - void run(); - } - - static class ServiceImpl implements Service { - @Secured("ROLE_USER") - public void run() {} - } - - @Test - public void getAuthenticationWhenMultipleThenOrdered() throws Exception { - this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class, AuthenticationManagerBeanConfig.class).autowire(); - AuthenticationConfiguration config = this.spring.getContext().getBean(AuthenticationConfiguration.class); - config.setGlobalAuthenticationConfigurers(Arrays.asList(new LowestOrderGlobalAuthenticationConfigurerAdapter(), new HighestOrderGlobalAuthenticationConfigurerAdapter(), new DefaultOrderGlobalAuthenticationConfigurerAdapter())); - } - - static class DefaultOrderGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { - static List> inits = new ArrayList<>(); - static List> configs = new ArrayList<>(); - - public void init(AuthenticationManagerBuilder auth) throws Exception { - inits.add(getClass()); - } - - public void configure(AuthenticationManagerBuilder auth) { - configs.add(getClass()); - } - } - - @Order(Ordered.LOWEST_PRECEDENCE) - static class LowestOrderGlobalAuthenticationConfigurerAdapter extends DefaultOrderGlobalAuthenticationConfigurerAdapter {} - - @Order(Ordered.HIGHEST_PRECEDENCE) - static class HighestOrderGlobalAuthenticationConfigurerAdapter extends DefaultOrderGlobalAuthenticationConfigurerAdapter {} @Test public void getAuthenticationWhenConfiguredThenBootNotTrigger() throws Exception { this.spring.register(AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class).autowire(); AuthenticationConfiguration config = this.spring.getContext().getBean(AuthenticationConfiguration.class); - config.setGlobalAuthenticationConfigurers(Arrays.asList(new ConfiguresInMemoryConfigurerAdapter(), new BootGlobalAuthenticationConfigurerAdapter())); + config.setGlobalAuthenticationConfigurers(Arrays.asList(new ConfiguresInMemoryConfigurerAdapter(), + new BootGlobalAuthenticationConfigurerAdapter())); AuthenticationManager authenticationManager = config.getAuthenticationManager(); - authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - - assertThatThrownBy(() -> authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("boot", "password"))) - .isInstanceOf(AuthenticationException.class); - + assertThatExceptionOfType(AuthenticationException.class).isThrownBy( + () -> authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("boot", "password"))); } @Test @@ -234,154 +175,332 @@ public class AuthenticationConfigurationTests { AuthenticationConfiguration config = this.spring.getContext().getBean(AuthenticationConfiguration.class); config.setGlobalAuthenticationConfigurers(Arrays.asList(new BootGlobalAuthenticationConfigurerAdapter())); AuthenticationManager authenticationManager = config.getAuthenticationManager(); - authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("boot", "password")); } - static class ConfiguresInMemoryConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { - - public void init(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - } - - @Order(Ordered.LOWEST_PRECEDENCE) - static class BootGlobalAuthenticationConfigurerAdapter extends DefaultOrderGlobalAuthenticationConfigurerAdapter { - public void init(AuthenticationManagerBuilder auth) throws Exception { - auth.apply(new DefaultBootGlobalAuthenticationConfigurerAdapter()); - } - } - - static class DefaultBootGlobalAuthenticationConfigurerAdapter extends DefaultOrderGlobalAuthenticationConfigurerAdapter { - @Override - public void configure(AuthenticationManagerBuilder auth) { - if (auth.isConfigured()) { - return; - } - - UserDetails user = User.withUserDetails(PasswordEncodedUser.user()).username("boot").build(); - - List users = Arrays.asList(user); - InMemoryUserDetailsManager inMemory = new InMemoryUserDetailsManager(users); - - DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); - provider.setUserDetailsService(inMemory); - - auth.authenticationProvider(provider); - } - } - // gh-2531 @Test public void getAuthenticationManagerWhenPostProcessThenUsesBeanClassLoaderOnProxyFactoryBean() throws Exception { this.spring.register(Sec2531Config.class).autowire(); ObjectPostProcessor opp = this.spring.getContext().getBean(ObjectPostProcessor.class); - when(opp.postProcess(any())).thenAnswer(a -> a.getArgument(0)); - + given(opp.postProcess(any())).willAnswer((a) -> a.getArgument(0)); AuthenticationConfiguration config = this.spring.getContext().getBean(AuthenticationConfiguration.class); config.getAuthenticationManager(); - verify(opp).postProcess(any(ProxyFactoryBean.class)); } + @Test + public void getAuthenticationManagerWhenSec2822ThenCannotForceAuthenticationAlreadyBuilt() throws Exception { + this.spring.register(Sec2822WebSecurity.class, Sec2822UseAuth.class, Sec2822Config.class).autowire(); + this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); + // no exception + } + + // sec-2868 + @Test + public void getAuthenticationWhenUserDetailsServiceBeanThenAuthenticationManagerUsesUserDetailsServiceBean() + throws Exception { + this.spring.register(UserDetailsServiceBeanConfig.class).autowire(); + UserDetailsService uds = this.spring.getContext().getBean(UserDetailsService.class); + AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); + given(uds.loadUserByUsername("user")).willReturn(PasswordEncodedUser.user(), PasswordEncodedUser.user()); + am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); + assertThatExceptionOfType(AuthenticationException.class) + .isThrownBy(() -> am.authenticate(new UsernamePasswordAuthenticationToken("user", "invalid"))); + } + + @Test + public void getAuthenticationWhenUserDetailsServiceAndPasswordEncoderBeanThenEncoderUsed() throws Exception { + UserDetails user = new User("user", "$2a$10$FBAKClV1zBIOOC9XMXf3AO8RoGXYVYsfvUdoLxGkd/BnXEn4tqT3u", + AuthorityUtils.createAuthorityList("ROLE_USER")); + this.spring.register(UserDetailsServiceBeanWithPasswordEncoderConfig.class).autowire(); + UserDetailsService uds = this.spring.getContext().getBean(UserDetailsService.class); + AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); + given(uds.loadUserByUsername("user")).willReturn(User.withUserDetails(user).build(), + User.withUserDetails(user).build()); + am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); + assertThatExceptionOfType(AuthenticationException.class) + .isThrownBy(() -> am.authenticate(new UsernamePasswordAuthenticationToken("user", "invalid"))); + } + + @Test + public void getAuthenticationWhenUserDetailsServiceAndPasswordManagerThenManagerUsed() throws Exception { + UserDetails user = new User("user", "{noop}password", AuthorityUtils.createAuthorityList("ROLE_USER")); + this.spring.register(UserDetailsPasswordManagerBeanConfig.class).autowire(); + UserDetailsPasswordManagerBeanConfig.Manager manager = this.spring.getContext() + .getBean(UserDetailsPasswordManagerBeanConfig.Manager.class); + AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); + given(manager.loadUserByUsername("user")).willReturn(User.withUserDetails(user).build(), + User.withUserDetails(user).build()); + given(manager.updatePassword(any(), any())).willReturn(user); + am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); + verify(manager).updatePassword(eq(user), startsWith("{bcrypt}")); + } + + @Test + public void getAuthenticationWhenAuthenticationProviderAndUserDetailsBeanThenAuthenticationProviderUsed() + throws Exception { + this.spring.register(AuthenticationProviderBeanAndUserDetailsServiceConfig.class).autowire(); + AuthenticationProvider ap = this.spring.getContext().getBean(AuthenticationProvider.class); + AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); + given(ap.supports(any())).willReturn(true); + given(ap.authenticate(any())).willReturn(TestAuthentication.authenticatedUser()); + am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); + } + + // gh-3091 + @Test + public void getAuthenticationWhenAuthenticationProviderBeanThenUsed() throws Exception { + this.spring.register(AuthenticationProviderBeanConfig.class).autowire(); + AuthenticationProvider ap = this.spring.getContext().getBean(AuthenticationProvider.class); + AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class) + .getAuthenticationManager(); + given(ap.supports(any())).willReturn(true); + given(ap.authenticate(any())).willReturn(TestAuthentication.authenticatedUser()); + am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); + } + + @Test + public void enableGlobalMethodSecurityWhenPreAuthorizeThenNoException() { + this.spring.register(UsesPreAuthorizeMethodSecurityConfig.class, AuthenticationManagerBeanConfig.class) + .autowire(); + // no exception + } + + @Test + public void enableGlobalMethodSecurityWhenPreAuthorizeThenUsesMethodSecurityService() { + this.spring.register(ServicesConfig.class, UsesPreAuthorizeMethodSecurityConfig.class, + AuthenticationManagerBeanConfig.class).autowire(); + // no exception + } + + @Test + public void getAuthenticationManagerBeanWhenMultipleDefinedAndOnePrimaryThenNoException() throws Exception { + this.spring.register(MultipleAuthenticationManagerBeanConfig.class).autowire(); + this.spring.getContext().getBeanFactory().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); + } + + @Test + public void getAuthenticationManagerWhenAuthenticationConfigurationSubclassedThenBuildsUsingBean() + throws Exception { + this.spring.register(AuthenticationConfigurationSubclass.class).autowire(); + AuthenticationManagerBuilder ap = this.spring.getContext().getBean(AuthenticationManagerBuilder.class); + this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); + assertThatExceptionOfType(AlreadyBuiltException.class).isThrownBy(ap::build); + } + + @EnableGlobalMethodSecurity(securedEnabled = true) + static class GlobalMethodSecurityAutowiredConfig { + + } + + @EnableWebSecurity + static class WebSecurityConfig { + + } + + @EnableWebMvcSecurity + static class WebMvcSecurityConfig { + + } + + @Configuration + static class NoOpGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { + + } + + @Configuration + static class UserGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { + + @Override + public void init(AuthenticationManagerBuilder auth) throws Exception { + auth.inMemoryAuthentication().withUser(PasswordEncodedUser.user()); + } + + } + + @Configuration + static class AuthenticationManagerBeanConfig { + + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + + @Bean + AuthenticationManager authenticationManager() { + return this.authenticationManager; + } + + } + + @Configuration + static class ServicesConfig { + + @Bean + Service service() { + return new ServiceImpl(); + } + + } + + interface Service { + + void run(); + + } + + static class ServiceImpl implements Service { + + @Override + @Secured("ROLE_USER") + public void run() { + } + + } + + static class DefaultOrderGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { + + static List> inits = new ArrayList<>(); + static List> configs = new ArrayList<>(); + + @Override + public void init(AuthenticationManagerBuilder auth) throws Exception { + inits.add(getClass()); + } + + @Override + public void configure(AuthenticationManagerBuilder auth) { + configs.add(getClass()); + } + + } + + @Order(Ordered.LOWEST_PRECEDENCE) + static class LowestOrderGlobalAuthenticationConfigurerAdapter + extends DefaultOrderGlobalAuthenticationConfigurerAdapter { + + } + + @Order(Ordered.HIGHEST_PRECEDENCE) + static class HighestOrderGlobalAuthenticationConfigurerAdapter + extends DefaultOrderGlobalAuthenticationConfigurerAdapter { + + } + + static class ConfiguresInMemoryConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { + + @Override + public void init(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + + @Order(Ordered.LOWEST_PRECEDENCE) + static class BootGlobalAuthenticationConfigurerAdapter extends DefaultOrderGlobalAuthenticationConfigurerAdapter { + + @Override + public void init(AuthenticationManagerBuilder auth) throws Exception { + auth.apply(new DefaultBootGlobalAuthenticationConfigurerAdapter()); + } + + } + + static class DefaultBootGlobalAuthenticationConfigurerAdapter + extends DefaultOrderGlobalAuthenticationConfigurerAdapter { + + @Override + public void configure(AuthenticationManagerBuilder auth) { + if (auth.isConfigured()) { + return; + } + UserDetails user = User.withUserDetails(PasswordEncodedUser.user()).username("boot").build(); + List users = Arrays.asList(user); + InMemoryUserDetailsManager inMemory = new InMemoryUserDetailsManager(users); + DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); + provider.setUserDetailsService(inMemory); + auth.authenticationProvider(provider); + } + + } + @Configuration @Import(AuthenticationConfiguration.class) static class Sec2531Config { @Bean - public ObjectPostProcessor objectPostProcessor() { + ObjectPostProcessor objectPostProcessor() { return mock(ObjectPostProcessor.class); } @Bean - public AuthenticationManager manager() { + AuthenticationManager manager() { return null; } - } - @Test - public void getAuthenticationManagerWhenSec2822ThenCannotForceAuthenticationAlreadyBuilt() throws Exception { - this.spring.register(Sec2822WebSecurity.class, Sec2822UseAuth.class, Sec2822Config.class).autowire(); - - this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - // no exception } @Configuration @Import(AuthenticationConfiguration.class) - static class Sec2822Config {} + static class Sec2822Config { + + } @Configuration @EnableWebSecurity static class Sec2822WebSecurity extends WebSecurityConfigurerAdapter { + @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { auth.inMemoryAuthentication(); } + } @Configuration static class Sec2822UseAuth { + @Autowired - public void useAuthenticationManager(AuthenticationConfiguration auth) throws Exception { + void useAuthenticationManager(AuthenticationConfiguration auth) throws Exception { auth.getAuthenticationManager(); } // Ensures that Sec2822UseAuth is initialized before Sec2822WebSecurity // must have additional GlobalAuthenticationConfigurerAdapter to trigger SEC-2822 @Bean - public static GlobalAuthenticationConfigurerAdapter bootGlobalAuthenticationConfigurerAdapter() { + static GlobalAuthenticationConfigurerAdapter bootGlobalAuthenticationConfigurerAdapter() { return new BootGlobalAuthenticationConfigurerAdapter(); } - static class BootGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { } - } + static class BootGlobalAuthenticationConfigurerAdapter extends GlobalAuthenticationConfigurerAdapter { - // sec-2868 - @Test - public void getAuthenticationWhenUserDetailsServiceBeanThenAuthenticationManagerUsesUserDetailsServiceBean() throws Exception { - this.spring.register(UserDetailsServiceBeanConfig.class).autowire(); - UserDetailsService uds = this.spring.getContext().getBean(UserDetailsService.class); - AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - when(uds.loadUserByUsername("user")).thenReturn(PasswordEncodedUser.user(), PasswordEncodedUser.user()); + } - am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - - assertThatThrownBy(() -> am.authenticate(new UsernamePasswordAuthenticationToken("user", "invalid"))) - .isInstanceOf(AuthenticationException.class); } @Configuration - @Import({AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class}) + @Import({ AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class }) static class UserDetailsServiceBeanConfig { + UserDetailsService uds = mock(UserDetailsService.class); @Bean UserDetailsService userDetailsService() { return this.uds; } - } - @Test - public void getAuthenticationWhenUserDetailsServiceAndPasswordEncoderBeanThenEncoderUsed() throws Exception { - UserDetails user = new User("user", "$2a$10$FBAKClV1zBIOOC9XMXf3AO8RoGXYVYsfvUdoLxGkd/BnXEn4tqT3u", - AuthorityUtils.createAuthorityList("ROLE_USER")); - this.spring.register(UserDetailsServiceBeanWithPasswordEncoderConfig.class).autowire(); - UserDetailsService uds = this.spring.getContext().getBean(UserDetailsService.class); - AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - when(uds.loadUserByUsername("user")).thenReturn(User.withUserDetails(user).build(), User.withUserDetails(user).build()); - - am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - - assertThatThrownBy(() -> am.authenticate(new UsernamePasswordAuthenticationToken("user", "invalid"))) - .isInstanceOf(AuthenticationException.class); } @Configuration - @Import({AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class}) + @Import({ AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class }) static class UserDetailsServiceBeanWithPasswordEncoderConfig { + UserDetailsService uds = mock(UserDetailsService.class); @Bean @@ -393,26 +512,13 @@ public class AuthenticationConfigurationTests { PasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); } - } - @Test - public void getAuthenticationWhenUserDetailsServiceAndPasswordManagerThenManagerUsed() throws Exception { - UserDetails user = new User("user", "{noop}password", - AuthorityUtils.createAuthorityList("ROLE_USER")); - this.spring.register(UserDetailsPasswordManagerBeanConfig.class).autowire(); - UserDetailsPasswordManagerBeanConfig.Manager manager = this.spring.getContext().getBean(UserDetailsPasswordManagerBeanConfig.Manager.class); - AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - when(manager.loadUserByUsername("user")).thenReturn(User.withUserDetails(user).build(), User.withUserDetails(user).build()); - when(manager.updatePassword(any(), any())).thenReturn(user); - - am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); - - verify(manager).updatePassword(eq(user), startsWith("{bcrypt}")); } @Configuration - @Import({AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class}) + @Import({ AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class }) static class UserDetailsPasswordManagerBeanConfig { + Manager manager = mock(Manager.class); @Bean @@ -421,46 +527,28 @@ public class AuthenticationConfigurationTests { } interface Manager extends UserDetailsService, UserDetailsPasswordService { + } - } - //gh-3091 - @Test - public void getAuthenticationWhenAuthenticationProviderBeanThenUsed() throws Exception { - this.spring.register(AuthenticationProviderBeanConfig.class).autowire(); - AuthenticationProvider ap = this.spring.getContext().getBean(AuthenticationProvider.class); - AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - when(ap.supports(any())).thenReturn(true); - when(ap.authenticate(any())).thenReturn(TestAuthentication.authenticatedUser()); - - am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); } @Configuration - @Import({AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class}) + @Import({ AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class }) static class AuthenticationProviderBeanConfig { + AuthenticationProvider provider = mock(AuthenticationProvider.class); @Bean AuthenticationProvider authenticationProvider() { return this.provider; } - } - @Test - public void getAuthenticationWhenAuthenticationProviderAndUserDetailsBeanThenAuthenticationProviderUsed() throws Exception { - this.spring.register(AuthenticationProviderBeanAndUserDetailsServiceConfig.class).autowire(); - AuthenticationProvider ap = this.spring.getContext().getBean(AuthenticationProvider.class); - AuthenticationManager am = this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - when(ap.supports(any())).thenReturn(true); - when(ap.authenticate(any())).thenReturn(TestAuthentication.authenticatedUser()); - - am.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); } @Configuration - @Import({AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class}) + @Import({ AuthenticationConfiguration.class, ObjectPostProcessorConfiguration.class }) static class AuthenticationProviderBeanAndUserDetailsServiceConfig { + AuthenticationProvider provider = mock(AuthenticationProvider.class); UserDetailsService uds = mock(UserDetailsService.class); @@ -474,40 +562,26 @@ public class AuthenticationConfigurationTests { AuthenticationProvider authenticationProvider() { return this.provider; } - } - @Test - public void enableGlobalMethodSecurityWhenPreAuthorizeThenNoException() { - this.spring.register(UsesPreAuthorizeMethodSecurityConfig.class, AuthenticationManagerBeanConfig.class).autowire(); - - // no exception } @Configuration @EnableGlobalMethodSecurity(prePostEnabled = true) static class UsesPreAuthorizeMethodSecurityConfig { + @PreAuthorize("denyAll") - void run() {} - } + void run() { + } - @Test - public void enableGlobalMethodSecurityWhenPreAuthorizeThenUsesMethodSecurityService() { - this.spring.register(ServicesConfig.class, UsesPreAuthorizeMethodSecurityConfig.class, AuthenticationManagerBeanConfig.class).autowire(); - - // no exception } @Configuration @EnableGlobalMethodSecurity(securedEnabled = true) static class UsesServiceMethodSecurityConfig { + @Autowired Service service; - } - @Test - public void getAuthenticationManagerBeanWhenMultipleDefinedAndOnePrimaryThenNoException() throws Exception { - this.spring.register(MultipleAuthenticationManagerBeanConfig.class).autowire(); - this.spring.getContext().getBeanFactory().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); } @Configuration @@ -516,30 +590,20 @@ public class AuthenticationConfigurationTests { @Bean @Primary - public AuthenticationManager manager1() { + AuthenticationManager manager1() { return mock(AuthenticationManager.class); } @Bean - public AuthenticationManager manager2() { + AuthenticationManager manager2() { return mock(AuthenticationManager.class); } } - @Test - public void getAuthenticationManagerWhenAuthenticationConfigurationSubclassedThenBuildsUsingBean() - throws Exception { - this.spring.register(AuthenticationConfigurationSubclass.class).autowire(); - AuthenticationManagerBuilder ap = this.spring.getContext().getBean(AuthenticationManagerBuilder.class); - - this.spring.getContext().getBean(AuthenticationConfiguration.class).getAuthenticationManager(); - - assertThatThrownBy(ap::build) - .isInstanceOf(AlreadyBuiltException.class); - } - @Configuration static class AuthenticationConfigurationSubclass extends AuthenticationConfiguration { + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthenticationTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthenticationTests.java index 0670a82769..8a5e0db601 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthenticationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/configuration/EnableGlobalAuthenticationTests.java @@ -13,24 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.authentication.configuration; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.config.annotation.authentication.configuration; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.test.SpringTestRule; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Rob Winch * */ public class EnableGlobalAuthenticationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -38,83 +40,87 @@ public class EnableGlobalAuthenticationTests { @Test public void authenticationConfigurationWhenGetAuthenticationManagerThenNotNull() throws Exception { this.spring.register(Config.class).autowire(); - - AuthenticationConfiguration auth = spring.getContext().getBean(AuthenticationConfiguration.class); - + AuthenticationConfiguration auth = this.spring.getContext().getBean(AuthenticationConfiguration.class); assertThat(auth.getAuthenticationManager()).isNotNull(); } + @Test + public void enableGlobalAuthenticationWhenNoConfigurationAnnotationThenBeanProxyingEnabled() { + this.spring.register(BeanProxyEnabledByDefaultConfig.class).autowire(); + Child childBean = this.spring.getContext().getBean(Child.class); + Parent parentBean = this.spring.getContext().getBean(Parent.class); + assertThat(parentBean.getChild()).isSameAs(childBean); + } + + @Test + public void enableGlobalAuthenticationWhenProxyBeanMethodsFalseThenBeanProxyingDisabled() { + this.spring.register(BeanProxyDisabledConfig.class).autowire(); + Child childBean = this.spring.getContext().getBean(Child.class); + Parent parentBean = this.spring.getContext().getBean(Parent.class); + assertThat(parentBean.getChild()).isNotSameAs(childBean); + } + @Configuration @EnableGlobalAuthentication static class Config { @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { auth.inMemoryAuthentication().withUser("user").password("password").roles("USER"); } - } - @Test - public void enableGlobalAuthenticationWhenNoConfigurationAnnotationThenBeanProxyingEnabled() { - this.spring.register(BeanProxyEnabledByDefaultConfig.class).autowire(); - - Child childBean = this.spring.getContext().getBean(Child.class); - Parent parentBean = this.spring.getContext().getBean(Parent.class); - - assertThat(parentBean.getChild()).isSameAs(childBean); } @EnableGlobalAuthentication static class BeanProxyEnabledByDefaultConfig { + @Bean - public Child child() { + Child child() { return new Child(); } @Bean - public Parent parent() { + Parent parent() { return new Parent(child()); } - } - @Test - public void enableGlobalAuthenticationWhenProxyBeanMethodsFalseThenBeanProxyingDisabled() { - this.spring.register(BeanProxyDisabledConfig.class).autowire(); - - Child childBean = this.spring.getContext().getBean(Child.class); - Parent parentBean = this.spring.getContext().getBean(Parent.class); - - assertThat(parentBean.getChild()).isNotSameAs(childBean); } @Configuration(proxyBeanMethods = false) @EnableGlobalAuthentication static class BeanProxyDisabledConfig { + @Bean - public Child child() { + Child child() { return new Child(); } @Bean - public Parent parent() { + Parent parent() { return new Parent(child()); } + } static class Parent { + private Child child; Parent(Child child) { this.child = child; } - public Child getChild() { - return child; + Child getChild() { + return this.child; } + } static class Child { + Child() { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurerTest.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurerTests.java similarity index 77% rename from config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurerTest.java rename to config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurerTests.java index f3f33a91d9..40e9f8e149 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurerTest.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/ldap/LdapAuthenticationProviderConfigurerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.authentication.configurers.ldap; import org.junit.Before; @@ -24,21 +25,21 @@ import org.springframework.security.core.authority.mapping.SimpleAuthorityMapper import static org.assertj.core.api.Assertions.assertThat; -public class LdapAuthenticationProviderConfigurerTest { +public class LdapAuthenticationProviderConfigurerTests { private LdapAuthenticationProviderConfigurer configurer; @Before public void setUp() { - configurer = new LdapAuthenticationProviderConfigurer<>(); + this.configurer = new LdapAuthenticationProviderConfigurer<>(); } // SEC-2557 @Test public void getAuthoritiesMapper() throws Exception { - assertThat(configurer.getAuthoritiesMapper()).isInstanceOf(SimpleAuthorityMapper.class); - configurer.authoritiesMapper(new NullAuthoritiesMapper()); - assertThat(configurer.getAuthoritiesMapper()).isInstanceOf(NullAuthoritiesMapper.class); - + assertThat(this.configurer.getAuthoritiesMapper()).isInstanceOf(SimpleAuthorityMapper.class); + this.configurer.authoritiesMapper(new NullAuthoritiesMapper()); + assertThat(this.configurer.getAuthoritiesMapper()).isInstanceOf(NullAuthoritiesMapper.class); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurerTests.java index 501c6cfe79..ba43d04f0c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/authentication/configurers/provisioning/UserDetailsManagerConfigurerTests.java @@ -16,44 +16,44 @@ package org.springframework.security.config.annotation.authentication.configurers.provisioning; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Arrays; import org.junit.Before; import org.junit.Test; + import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.provisioning.InMemoryUserDetailsManager; +import static org.assertj.core.api.Assertions.assertThat; + /** -* -* @author Rob Winch -* @author Adolfo Eloy -*/ + * @author Rob Winch + * @author Adolfo Eloy + */ public class UserDetailsManagerConfigurerTests { private InMemoryUserDetailsManager userDetailsManager; @Before public void setup() { - userDetailsManager = new InMemoryUserDetailsManager(); + this.userDetailsManager = new InMemoryUserDetailsManager(); } @Test public void allAttributesSupported() { - UserDetails userDetails = new UserDetailsManagerConfigurer>(userDetailsManager) - .withUser("user") - .password("password") - .roles("USER") - .disabled(true) - .accountExpired(true) - .accountLocked(true) - .credentialsExpired(true) - .build(); - + // @formatter:off + UserDetails userDetails = configurer() + .withUser("user") + .password("password") + .roles("USER") + .disabled(true) + .accountExpired(true) + .accountLocked(true) + .credentialsExpired(true) + .build(); + // @formatter:on assertThat(userDetails.getUsername()).isEqualTo("user"); assertThat(userDetails.getPassword()).isEqualTo("password"); assertThat(userDetails.getAuthorities().stream().findFirst().get().getAuthority()).isEqualTo("ROLE_USER"); @@ -66,42 +66,45 @@ public class UserDetailsManagerConfigurerTests { @Test public void authoritiesWithGrantedAuthorityWorks() { SimpleGrantedAuthority authority = new SimpleGrantedAuthority("ROLE_USER"); - - UserDetails userDetails = new UserDetailsManagerConfigurer>(userDetailsManager) - .withUser("user") + // @formatter:off + UserDetails userDetails = configurer() + .withUser("user") .password("password") .authorities(authority) .build(); - + // @formatter:on assertThat(userDetails.getAuthorities().stream().findFirst().get()).isEqualTo(authority); } @Test public void authoritiesWithStringAuthorityWorks() { String authority = "ROLE_USER"; - - UserDetails userDetails = new UserDetailsManagerConfigurer>(userDetailsManager) - .withUser("user") + // @formatter:off + UserDetails userDetails = configurer() + .withUser("user") .password("password") .authorities(authority) .build(); - + // @formatter:on assertThat(userDetails.getAuthorities().stream().findFirst().get().getAuthority()).isEqualTo(authority); } @Test public void authoritiesWithAListOfGrantedAuthorityWorks() { SimpleGrantedAuthority authority = new SimpleGrantedAuthority("ROLE_USER"); - - UserDetails userDetails = new UserDetailsManagerConfigurer>(userDetailsManager) - .withUser("user") + // @formatter:off + UserDetails userDetails = configurer() + .withUser("user") .password("password") .authorities(Arrays.asList(authority)) .build(); - + // @formatter:on assertThat(userDetails.getAuthorities().stream().findFirst().get()).isEqualTo(authority); } + + private UserDetailsManagerConfigurer> configurer() { + return new UserDetailsManagerConfigurer>( + this.userDetailsManager); + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/configuration/AroundMethodInterceptor.java b/config/src/test/java/org/springframework/security/config/annotation/configuration/AroundMethodInterceptor.java index be4f491a96..4dff4bfca5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/configuration/AroundMethodInterceptor.java +++ b/config/src/test/java/org/springframework/security/config/annotation/configuration/AroundMethodInterceptor.java @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.configuration; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; public class AroundMethodInterceptor implements MethodInterceptor { + + @Override public Object invoke(MethodInvocation methodInvocation) throws Throwable { return String.valueOf(methodInvocation.proceed()); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessorTests.java b/config/src/test/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessorTests.java index 9f51ff9038..c0efa4c5e7 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessorTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/configuration/AutowireBeanFactoryObjectPostProcessorTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.configuration; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanClassLoaderAware; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.DisposableBean; @@ -33,16 +35,16 @@ import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.test.SpringTestRule; import org.springframework.web.context.ServletContextAware; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.isNotNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; /** - * * @author Rob Winch */ public class AutowireBeanFactoryObjectPostProcessorTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -52,7 +54,6 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenApplicationContextAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - ApplicationContextAware toPostProcess = mock(ApplicationContextAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setApplicationContext(isNotNull()); @@ -61,17 +62,14 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenApplicationEventPublisherAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - ApplicationEventPublisherAware toPostProcess = mock(ApplicationEventPublisherAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setApplicationEventPublisher(isNotNull()); - } @Test public void postProcessWhenBeanClassLoaderAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - BeanClassLoaderAware toPostProcess = mock(BeanClassLoaderAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setBeanClassLoader(isNotNull()); @@ -80,7 +78,6 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenBeanFactoryAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - BeanFactoryAware toPostProcess = mock(BeanFactoryAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setBeanFactory(isNotNull()); @@ -89,7 +86,6 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenEnvironmentAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - EnvironmentAware toPostProcess = mock(EnvironmentAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setEnvironment(isNotNull()); @@ -98,7 +94,6 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenMessageSourceAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - MessageSourceAware toPostProcess = mock(MessageSourceAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setMessageSource(isNotNull()); @@ -107,7 +102,6 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenServletContextAwareThenAwareInvoked() { this.spring.register(Config.class).autowire(); - ServletContextAware toPostProcess = mock(ServletContextAware.class); this.objectObjectPostProcessor.postProcess(toPostProcess); verify(toPostProcess).setServletContext(isNotNull()); @@ -116,62 +110,62 @@ public class AutowireBeanFactoryObjectPostProcessorTests { @Test public void postProcessWhenDisposableBeanThenAwareInvoked() throws Exception { this.spring.register(Config.class).autowire(); - DisposableBean toPostProcess = mock(DisposableBean.class); this.objectObjectPostProcessor.postProcess(toPostProcess); - this.spring.getContext().close(); - verify(toPostProcess).destroy(); } - @Configuration - static class Config { - @Bean - public ObjectPostProcessor objectPostProcessor(AutowireCapableBeanFactory beanFactory) { - return new AutowireBeanFactoryObjectPostProcessor(beanFactory); - } - } - @Test public void postProcessWhenSmartInitializingSingletonThenAwareInvoked() { this.spring.register(Config.class, SmartConfig.class).autowire(); - SmartConfig config = this.spring.getContext().getBean(SmartConfig.class); - verify(config.toTest).afterSingletonsInstantiated(); } - @Configuration - static class SmartConfig { - SmartInitializingSingleton toTest = mock(SmartInitializingSingleton.class); - - @Autowired - public void configure(ObjectPostProcessor p) { - p.postProcess(this.toTest); - } - } - @Test // SEC-2382 public void autowireBeanFactoryWhenBeanNameAutoProxyCreatorThenWorks() { this.spring.testConfigLocations("AutowireBeanFactoryObjectPostProcessorTests-aopconfig.xml").autowire(); - MyAdvisedBean bean = this.spring.getContext().getBean(MyAdvisedBean.class); - assertThat(bean.doStuff()).isEqualTo("null"); } @Configuration - static class WithBeanNameAutoProxyCreatorConfig { + static class Config { + @Bean - public ObjectPostProcessor objectPostProcessor(AutowireCapableBeanFactory beanFactory) { + ObjectPostProcessor objectPostProcessor(AutowireCapableBeanFactory beanFactory) { + return new AutowireBeanFactoryObjectPostProcessor(beanFactory); + } + + } + + @Configuration + static class SmartConfig { + + SmartInitializingSingleton toTest = mock(SmartInitializingSingleton.class); + + @Autowired + void configure(ObjectPostProcessor p) { + p.postProcess(this.toTest); + } + + } + + @Configuration + static class WithBeanNameAutoProxyCreatorConfig { + + @Bean + ObjectPostProcessor objectPostProcessor(AutowireCapableBeanFactory beanFactory) { return new AutowireBeanFactoryObjectPostProcessor(beanFactory); } @Autowired - public void configure(ObjectPostProcessor p) { + void configure(ObjectPostProcessor p) { p.postProcess(new Object()); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/configuration/MyAdvisedBean.java b/config/src/test/java/org/springframework/security/config/annotation/configuration/MyAdvisedBean.java index d68c075ee9..ababb41c1c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/configuration/MyAdvisedBean.java +++ b/config/src/test/java/org/springframework/security/config/annotation/configuration/MyAdvisedBean.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.configuration; public class MyAdvisedBean { @@ -20,4 +21,5 @@ public class MyAdvisedBean { public Object doStuff() { return null; } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/issue50/ApplicationConfig.java b/config/src/test/java/org/springframework/security/config/annotation/issue50/ApplicationConfig.java index 168d33d687..47535d1ce5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/issue50/ApplicationConfig.java +++ b/config/src/test/java/org/springframework/security/config/annotation/issue50/ApplicationConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.issue50; import javax.sql.DataSource; @@ -38,6 +39,7 @@ import org.springframework.transaction.annotation.EnableTransactionManagement; @EnableJpaRepositories("org.springframework.security.config.annotation.issue50.repo") @EnableTransactionManagement public class ApplicationConfig { + @Bean public DataSource dataSource() { EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder(); @@ -50,12 +52,10 @@ public class ApplicationConfig { vendorAdapter.setDatabase(Database.HSQL); vendorAdapter.setGenerateDdl(true); vendorAdapter.setShowSql(true); - LocalContainerEntityManagerFactoryBean factory = new LocalContainerEntityManagerFactoryBean(); factory.setJpaVendorAdapter(vendorAdapter); factory.setPackagesToScan(User.class.getPackage().getName()); factory.setDataSource(dataSource()); - return factory; } @@ -65,4 +65,5 @@ public class ApplicationConfig { txManager.setEntityManagerFactory(entityManagerFactory().getObject()); return txManager; } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/issue50/Issue50Tests.java b/config/src/test/java/org/springframework/security/config/annotation/issue50/Issue50Tests.java index 3d6f172921..647ca8f834 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/issue50/Issue50Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/issue50/Issue50Tests.java @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.issue50; +import javax.transaction.Transactional; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationManager; @@ -33,8 +37,6 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; -import javax.transaction.Transactional; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -43,8 +45,9 @@ import static org.assertj.core.api.Assertions.assertThat; */ @Transactional @RunWith(SpringRunner.class) -@ContextConfiguration(classes = {ApplicationConfig.class, SecurityConfig.class}) +@ContextConfiguration(classes = { ApplicationConfig.class, SecurityConfig.class }) public class Issue50Tests { + @Autowired private AuthenticationManager authenticationManager; @@ -53,7 +56,8 @@ public class Issue50Tests { @Before public void setup() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("test", null, "ROLE_ADMIN")); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("test", null, "ROLE_ADMIN")); } @After @@ -82,7 +86,7 @@ public class Issue50Tests { public void authenticateWhenValidUserThenAuthenticates() { this.userRepo.save(User.withUsernameAndPassword("test", "password")); Authentication result = this.authenticationManager - .authenticate(new UsernamePasswordAuthenticationToken("test", "password")); + .authenticate(new UsernamePasswordAuthenticationToken("test", "password")); assertThat(result.getName()).isEqualTo("test"); } @@ -91,6 +95,7 @@ public class Issue50Tests { SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("test", null, "ROLE_USER")); this.userRepo.save(User.withUsernameAndPassword("denied", "password")); Authentication result = this.authenticationManager - .authenticate(new UsernamePasswordAuthenticationToken("test", "password")); + .authenticate(new UsernamePasswordAuthenticationToken("test", "password")); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/issue50/SecurityConfig.java b/config/src/test/java/org/springframework/security/config/annotation/issue50/SecurityConfig.java index 6c9e56e5c6..de54b6d5ad 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/issue50/SecurityConfig.java +++ b/config/src/test/java/org/springframework/security/config/annotation/issue50/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.issue50; import org.springframework.beans.factory.annotation.Autowired; @@ -42,25 +43,26 @@ import org.springframework.util.Assert; @EnableGlobalMethodSecurity(prePostEnabled = true) @Configuration public class SecurityConfig extends WebSecurityConfigurerAdapter { + @Autowired private UserRepository myUserRepository; - // @formatter:off @Override protected void configure(AuthenticationManagerBuilder auth) { + // @formatter:off auth .authenticationProvider(authenticationProvider()); + // @formatter:on } - // @formatter:on - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/*").permitAll(); + // @formatter:on } - // @formatter:on @Bean @Override @@ -70,20 +72,20 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Bean public AuthenticationProvider authenticationProvider() { - Assert.notNull(myUserRepository); + Assert.notNull(this.myUserRepository); return new AuthenticationProvider() { + @Override public boolean supports(Class authentication) { return true; } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { Object principal = authentication.getPrincipal(); String username = String.valueOf(principal); - User user = myUserRepository.findByUsername(username); + User user = SecurityConfig.this.myUserRepository.findByUsername(username); if (user == null) { - throw new UsernameNotFoundException("No user for principal " - + principal); + throw new UsernameNotFoundException("No user for principal " + principal); } if (!authentication.getCredentials().equals(user.getPassword())) { throw new BadCredentialsException("Invalid password"); @@ -92,4 +94,5 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { } }; } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/issue50/domain/User.java b/config/src/test/java/org/springframework/security/config/annotation/issue50/domain/User.java index 7d4a7dabf2..d30ada88c4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/issue50/domain/User.java +++ b/config/src/test/java/org/springframework/security/config/annotation/issue50/domain/User.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.issue50.domain; import javax.persistence.Entity; @@ -36,7 +37,7 @@ public class User { private String password; public Long getId() { - return id; + return this.id; } public void setId(Long id) { @@ -44,7 +45,7 @@ public class User { } public String getUsername() { - return username; + return this.username; } public void setUsername(String username) { @@ -52,7 +53,7 @@ public class User { } public String getPassword() { - return password; + return this.password; } public void setPassword(String password) { @@ -63,6 +64,7 @@ public class User { User user = new User(); user.setUsername(username); user.setPassword(password); - return user; + return user; } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/issue50/repo/UserRepository.java b/config/src/test/java/org/springframework/security/config/annotation/issue50/repo/UserRepository.java index 08a6f05176..984be0a815 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/issue50/repo/UserRepository.java +++ b/config/src/test/java/org/springframework/security/config/annotation/issue50/repo/UserRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.issue50.repo; import org.springframework.data.repository.CrudRepository; @@ -27,4 +28,5 @@ public interface UserRepository extends CrudRepository { @PreAuthorize("hasRole('ROLE_ADMIN')") User findByUsername(String username); -} \ No newline at end of file + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/Authz.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/Authz.java index d4be78f42a..c206ef37f2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/Authz.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/Authz.java @@ -35,7 +35,7 @@ public class Authz { } public boolean check(Authentication authentication, String message) { - return message != null && - message.contains(authentication.getName()); + return message != null && message.contains(authentication.getName()); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/DelegatingReactiveMessageService.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/DelegatingReactiveMessageService.java index 11e9e42a1a..23d1df8069 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/DelegatingReactiveMessageService.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/DelegatingReactiveMessageService.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import org.reactivestreams.Publisher; -import org.springframework.security.access.prepost.PostAuthorize; -import org.springframework.security.access.prepost.PreAuthorize; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.security.access.prepost.PostAuthorize; +import org.springframework.security.access.prepost.PreAuthorize; + public class DelegatingReactiveMessageService implements ReactiveMessageService { + private final ReactiveMessageService delegate; public DelegatingReactiveMessageService(ReactiveMessageService delegate) { @@ -36,100 +39,89 @@ public class DelegatingReactiveMessageService implements ReactiveMessageService @Override public Mono monoFindById(long id) { - return delegate.monoFindById(id); + return this.delegate.monoFindById(id); } @Override @PreAuthorize("hasRole('ADMIN')") - public Mono monoPreAuthorizeHasRoleFindById( - long id) { - return delegate.monoPreAuthorizeHasRoleFindById(id); + public Mono monoPreAuthorizeHasRoleFindById(long id) { + return this.delegate.monoPreAuthorizeHasRoleFindById(id); } @Override @PostAuthorize("returnObject?.contains(authentication?.name)") - public Mono monoPostAuthorizeFindById( - long id) { - return delegate.monoPostAuthorizeFindById(id); + public Mono monoPostAuthorizeFindById(long id) { + return this.delegate.monoPostAuthorizeFindById(id); } @Override @PreAuthorize("@authz.check(#id)") - public Mono monoPreAuthorizeBeanFindById( - long id) { - return delegate.monoPreAuthorizeBeanFindById(id); + public Mono monoPreAuthorizeBeanFindById(long id) { + return this.delegate.monoPreAuthorizeBeanFindById(id); } @Override @PostAuthorize("@authz.check(authentication, returnObject)") - public Mono monoPostAuthorizeBeanFindById( - long id) { - return delegate.monoPostAuthorizeBeanFindById(id); + public Mono monoPostAuthorizeBeanFindById(long id) { + return this.delegate.monoPostAuthorizeBeanFindById(id); } @Override public Flux fluxFindById(long id) { - return delegate.fluxFindById(id); + return this.delegate.fluxFindById(id); } @Override @PreAuthorize("hasRole('ADMIN')") - public Flux fluxPreAuthorizeHasRoleFindById( - long id) { - return delegate.fluxPreAuthorizeHasRoleFindById(id); + public Flux fluxPreAuthorizeHasRoleFindById(long id) { + return this.delegate.fluxPreAuthorizeHasRoleFindById(id); } @Override @PostAuthorize("returnObject?.contains(authentication?.name)") - public Flux fluxPostAuthorizeFindById( - long id) { - return delegate.fluxPostAuthorizeFindById(id); + public Flux fluxPostAuthorizeFindById(long id) { + return this.delegate.fluxPostAuthorizeFindById(id); } @Override @PreAuthorize("@authz.check(#id)") - public Flux fluxPreAuthorizeBeanFindById( - long id) { - return delegate.fluxPreAuthorizeBeanFindById(id); + public Flux fluxPreAuthorizeBeanFindById(long id) { + return this.delegate.fluxPreAuthorizeBeanFindById(id); } @Override @PostAuthorize("@authz.check(authentication, returnObject)") - public Flux fluxPostAuthorizeBeanFindById( - long id) { - return delegate.fluxPostAuthorizeBeanFindById(id); + public Flux fluxPostAuthorizeBeanFindById(long id) { + return this.delegate.fluxPostAuthorizeBeanFindById(id); } @Override public Publisher publisherFindById(long id) { - return delegate.publisherFindById(id); + return this.delegate.publisherFindById(id); } @Override @PreAuthorize("hasRole('ADMIN')") - public Publisher publisherPreAuthorizeHasRoleFindById( - long id) { - return delegate.publisherPreAuthorizeHasRoleFindById(id); + public Publisher publisherPreAuthorizeHasRoleFindById(long id) { + return this.delegate.publisherPreAuthorizeHasRoleFindById(id); } @Override @PostAuthorize("returnObject?.contains(authentication?.name)") - public Publisher publisherPostAuthorizeFindById( - long id) { - return delegate.publisherPostAuthorizeFindById(id); + public Publisher publisherPostAuthorizeFindById(long id) { + return this.delegate.publisherPostAuthorizeFindById(id); } @Override @PreAuthorize("@authz.check(#id)") - public Publisher publisherPreAuthorizeBeanFindById( - long id) { - return delegate.publisherPreAuthorizeBeanFindById(id); + public Publisher publisherPreAuthorizeBeanFindById(long id) { + return this.delegate.publisherPreAuthorizeBeanFindById(id); } @Override @PostAuthorize("@authz.check(authentication, returnObject)") - public Publisher publisherPostAuthorizeBeanFindById( - long id) { - return delegate.publisherPostAuthorizeBeanFindById(id); + public Publisher publisherPostAuthorizeBeanFindById(long id) { + return this.delegate.publisherPostAuthorizeBeanFindById(id); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurityTests.java index 9f87b52d4a..83ff811ce7 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurityTests.java @@ -16,11 +16,16 @@ package org.springframework.security.config.annotation.method.configuration; -import org.assertj.core.api.AssertionsForClassTypes; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.util.context.Context; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.access.AccessDeniedException; @@ -28,14 +33,12 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import reactor.test.publisher.TestPublisher; -import reactor.util.context.Context; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; /** * @author Rob Winch @@ -44,16 +47,23 @@ import static org.mockito.Mockito.*; @RunWith(SpringRunner.class) @ContextConfiguration public class EnableReactiveMethodSecurityTests { - @Autowired ReactiveMessageService messageService; + + @Autowired + ReactiveMessageService messageService; + ReactiveMessageService delegate; + TestPublisher result = TestPublisher.create(); - Context withAdmin = ReactiveSecurityContextHolder.withAuthentication(new TestingAuthenticationToken("admin", "password", "ROLE_USER", "ROLE_ADMIN")); - Context withUser = ReactiveSecurityContextHolder.withAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); + Context withAdmin = ReactiveSecurityContextHolder + .withAuthentication(new TestingAuthenticationToken("admin", "password", "ROLE_USER", "ROLE_ADMIN")); + + Context withUser = ReactiveSecurityContextHolder + .withAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); @After public void cleanup() { - reset(delegate); + reset(this.delegate); } @Autowired @@ -63,516 +73,332 @@ public class EnableReactiveMethodSecurityTests { @Test public void notPublisherPreAuthorizeFindByIdThenThrowsIllegalStateException() { - assertThatThrownBy(() -> this.messageService.notPublisherPreAuthorizeFindById(1L)) - .isInstanceOf(IllegalStateException.class) - .extracting(Throwable::getMessage) - .isEqualTo("The returnType class java.lang.String on public abstract java.lang.String org.springframework.security.config.annotation.method.configuration.ReactiveMessageService.notPublisherPreAuthorizeFindById(long) must return an instance of org.reactivestreams.Publisher (i.e. Mono / Flux) in order to support Reactor Context"); + assertThatIllegalStateException().isThrownBy(() -> this.messageService.notPublisherPreAuthorizeFindById(1L)) + .withMessage("The returnType class java.lang.String on public abstract java.lang.String " + + "org.springframework.security.config.annotation.method.configuration.ReactiveMessageService" + + ".notPublisherPreAuthorizeFindById(long) must return an instance of org.reactivestreams" + + ".Publisher (i.e. Mono / Flux) in order to support Reactor Context"); } @Test public void monoWhenPermitAllThenAopDoesNotSubscribe() { - when(this.delegate.monoFindById(1L)).thenReturn(Mono.from(result)); - + given(this.delegate.monoFindById(1L)).willReturn(Mono.from(this.result)); this.delegate.monoFindById(1L); - - result.assertNoSubscribers(); + this.result.assertNoSubscribers(); } @Test public void monoWhenPermitAllThenSuccess() { - when(this.delegate.monoFindById(1L)).thenReturn(Mono.just("success")); - - StepVerifier.create(this.delegate.monoFindById(1L)) - .expectNext("success") - .verifyComplete(); + given(this.delegate.monoFindById(1L)).willReturn(Mono.just("success")); + StepVerifier.create(this.delegate.monoFindById(1L)).expectNext("success").verifyComplete(); } @Test public void monoPreAuthorizeHasRoleWhenGrantedThenSuccess() { - when(this.delegate.monoPreAuthorizeHasRoleFindById(1L)).thenReturn(Mono.just("result")); - + given(this.delegate.monoPreAuthorizeHasRoleFindById(1L)).willReturn(Mono.just("result")); Mono findById = this.messageService.monoPreAuthorizeHasRoleFindById(1L) - .subscriberContext(withAdmin); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + .subscriberContext(this.withAdmin); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void monoPreAuthorizeHasRoleWhenNoAuthenticationThenDenied() { - when(this.delegate.monoPreAuthorizeHasRoleFindById(1L)).thenReturn(Mono.from(result)); - + given(this.delegate.monoPreAuthorizeHasRoleFindById(1L)).willReturn(Mono.from(this.result)); Mono findById = this.messageService.monoPreAuthorizeHasRoleFindById(1L); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void monoPreAuthorizeHasRoleWhenNotAuthorizedThenDenied() { - when(this.delegate.monoPreAuthorizeHasRoleFindById(1L)).thenReturn(Mono.from(result)); - + given(this.delegate.monoPreAuthorizeHasRoleFindById(1L)).willReturn(Mono.from(this.result)); Mono findById = this.messageService.monoPreAuthorizeHasRoleFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void monoPreAuthorizeBeanWhenGrantedThenSuccess() { - when(this.delegate.monoPreAuthorizeBeanFindById(2L)).thenReturn(Mono.just("result")); - - Mono findById = this.messageService.monoPreAuthorizeBeanFindById(2L) - .subscriberContext(withAdmin); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + given(this.delegate.monoPreAuthorizeBeanFindById(2L)).willReturn(Mono.just("result")); + Mono findById = this.messageService.monoPreAuthorizeBeanFindById(2L).subscriberContext(this.withAdmin); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void monoPreAuthorizeBeanWhenNotAuthenticatedAndGrantedThenSuccess() { - when(this.delegate.monoPreAuthorizeBeanFindById(2L)).thenReturn(Mono.just("result")); - + given(this.delegate.monoPreAuthorizeBeanFindById(2L)).willReturn(Mono.just("result")); Mono findById = this.messageService.monoPreAuthorizeBeanFindById(2L); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void monoPreAuthorizeBeanWhenNoAuthenticationThenDenied() { - when(this.delegate.monoPreAuthorizeBeanFindById(1L)).thenReturn(Mono.from(result)); - + given(this.delegate.monoPreAuthorizeBeanFindById(1L)).willReturn(Mono.from(this.result)); Mono findById = this.messageService.monoPreAuthorizeBeanFindById(1L); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void monoPreAuthorizeBeanWhenNotAuthorizedThenDenied() { - when(this.delegate.monoPreAuthorizeBeanFindById(1L)).thenReturn(Mono.from(result)); - - Mono findById = this.messageService.monoPreAuthorizeBeanFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + given(this.delegate.monoPreAuthorizeBeanFindById(1L)).willReturn(Mono.from(this.result)); + Mono findById = this.messageService.monoPreAuthorizeBeanFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void monoPostAuthorizeWhenAuthorizedThenSuccess() { - when(this.delegate.monoPostAuthorizeFindById(1L)).thenReturn(Mono.just("user")); - - Mono findById = this.messageService.monoPostAuthorizeFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectNext("user") - .verifyComplete(); + given(this.delegate.monoPostAuthorizeFindById(1L)).willReturn(Mono.just("user")); + Mono findById = this.messageService.monoPostAuthorizeFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectNext("user").verifyComplete(); } @Test public void monoPostAuthorizeWhenNotAuthorizedThenDenied() { - when(this.delegate.monoPostAuthorizeBeanFindById(1L)).thenReturn(Mono.just("not-authorized")); - - Mono findById = this.messageService.monoPostAuthorizeBeanFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); + given(this.delegate.monoPostAuthorizeBeanFindById(1L)).willReturn(Mono.just("not-authorized")); + Mono findById = this.messageService.monoPostAuthorizeBeanFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); } @Test public void monoPostAuthorizeWhenBeanAndAuthorizedThenSuccess() { - when(this.delegate.monoPostAuthorizeBeanFindById(2L)).thenReturn(Mono.just("user")); - - Mono findById = this.messageService.monoPostAuthorizeBeanFindById(2L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectNext("user") - .verifyComplete(); + given(this.delegate.monoPostAuthorizeBeanFindById(2L)).willReturn(Mono.just("user")); + Mono findById = this.messageService.monoPostAuthorizeBeanFindById(2L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectNext("user").verifyComplete(); } @Test public void monoPostAuthorizeWhenBeanAndNotAuthenticatedAndAuthorizedThenSuccess() { - when(this.delegate.monoPostAuthorizeBeanFindById(2L)).thenReturn(Mono.just("anonymous")); - + given(this.delegate.monoPostAuthorizeBeanFindById(2L)).willReturn(Mono.just("anonymous")); Mono findById = this.messageService.monoPostAuthorizeBeanFindById(2L); - StepVerifier - .create(findById) - .expectNext("anonymous") - .verifyComplete(); + StepVerifier.create(findById).expectNext("anonymous").verifyComplete(); } @Test public void monoPostAuthorizeWhenBeanAndNotAuthorizedThenDenied() { - when(this.delegate.monoPostAuthorizeBeanFindById(1L)).thenReturn(Mono.just("not-authorized")); - - Mono findById = this.messageService.monoPostAuthorizeBeanFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); + given(this.delegate.monoPostAuthorizeBeanFindById(1L)).willReturn(Mono.just("not-authorized")); + Mono findById = this.messageService.monoPostAuthorizeBeanFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); } // Flux tests - @Test public void fluxWhenPermitAllThenAopDoesNotSubscribe() { - when(this.delegate.fluxFindById(1L)).thenReturn(Flux.from(result)); - + given(this.delegate.fluxFindById(1L)).willReturn(Flux.from(this.result)); this.delegate.fluxFindById(1L); - - result.assertNoSubscribers(); + this.result.assertNoSubscribers(); } @Test public void fluxWhenPermitAllThenSuccess() { - when(this.delegate.fluxFindById(1L)).thenReturn(Flux.just("success")); - - StepVerifier.create(this.delegate.fluxFindById(1L)) - .expectNext("success") - .verifyComplete(); + given(this.delegate.fluxFindById(1L)).willReturn(Flux.just("success")); + StepVerifier.create(this.delegate.fluxFindById(1L)).expectNext("success").verifyComplete(); } @Test public void fluxPreAuthorizeHasRoleWhenGrantedThenSuccess() { - when(this.delegate.fluxPreAuthorizeHasRoleFindById(1L)).thenReturn(Flux.just("result")); - + given(this.delegate.fluxPreAuthorizeHasRoleFindById(1L)).willReturn(Flux.just("result")); Flux findById = this.messageService.fluxPreAuthorizeHasRoleFindById(1L) - .subscriberContext(withAdmin); - StepVerifier - .create(findById) - .consumeNextWith( s -> AssertionsForClassTypes.assertThat(s).isEqualTo("result")) - .verifyComplete(); + .subscriberContext(this.withAdmin); + StepVerifier.create(findById).consumeNextWith((s) -> assertThat(s).isEqualTo("result")).verifyComplete(); } @Test public void fluxPreAuthorizeHasRoleWhenNoAuthenticationThenDenied() { - when(this.delegate.fluxPreAuthorizeHasRoleFindById(1L)).thenReturn(Flux.from(result)); - + given(this.delegate.fluxPreAuthorizeHasRoleFindById(1L)).willReturn(Flux.from(this.result)); Flux findById = this.messageService.fluxPreAuthorizeHasRoleFindById(1L); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void fluxPreAuthorizeHasRoleWhenNotAuthorizedThenDenied() { - when(this.delegate.fluxPreAuthorizeHasRoleFindById(1L)).thenReturn(Flux.from(result)); - + given(this.delegate.fluxPreAuthorizeHasRoleFindById(1L)).willReturn(Flux.from(this.result)); Flux findById = this.messageService.fluxPreAuthorizeHasRoleFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void fluxPreAuthorizeBeanWhenGrantedThenSuccess() { - when(this.delegate.fluxPreAuthorizeBeanFindById(2L)).thenReturn(Flux.just("result")); - - Flux findById = this.messageService.fluxPreAuthorizeBeanFindById(2L) - .subscriberContext(withAdmin); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + given(this.delegate.fluxPreAuthorizeBeanFindById(2L)).willReturn(Flux.just("result")); + Flux findById = this.messageService.fluxPreAuthorizeBeanFindById(2L).subscriberContext(this.withAdmin); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void fluxPreAuthorizeBeanWhenNotAuthenticatedAndGrantedThenSuccess() { - when(this.delegate.fluxPreAuthorizeBeanFindById(2L)).thenReturn(Flux.just("result")); - + given(this.delegate.fluxPreAuthorizeBeanFindById(2L)).willReturn(Flux.just("result")); Flux findById = this.messageService.fluxPreAuthorizeBeanFindById(2L); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void fluxPreAuthorizeBeanWhenNoAuthenticationThenDenied() { - when(this.delegate.fluxPreAuthorizeBeanFindById(1L)).thenReturn(Flux.from(result)); - + given(this.delegate.fluxPreAuthorizeBeanFindById(1L)).willReturn(Flux.from(this.result)); Flux findById = this.messageService.fluxPreAuthorizeBeanFindById(1L); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void fluxPreAuthorizeBeanWhenNotAuthorizedThenDenied() { - when(this.delegate.fluxPreAuthorizeBeanFindById(1L)).thenReturn(Flux.from(result)); - - Flux findById = this.messageService.fluxPreAuthorizeBeanFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + given(this.delegate.fluxPreAuthorizeBeanFindById(1L)).willReturn(Flux.from(this.result)); + Flux findById = this.messageService.fluxPreAuthorizeBeanFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void fluxPostAuthorizeWhenAuthorizedThenSuccess() { - when(this.delegate.fluxPostAuthorizeFindById(1L)).thenReturn(Flux.just("user")); - - Flux findById = this.messageService.fluxPostAuthorizeFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectNext("user") - .verifyComplete(); + given(this.delegate.fluxPostAuthorizeFindById(1L)).willReturn(Flux.just("user")); + Flux findById = this.messageService.fluxPostAuthorizeFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectNext("user").verifyComplete(); } @Test public void fluxPostAuthorizeWhenNotAuthorizedThenDenied() { - when(this.delegate.fluxPostAuthorizeBeanFindById(1L)).thenReturn(Flux.just("not-authorized")); - - Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); + given(this.delegate.fluxPostAuthorizeBeanFindById(1L)).willReturn(Flux.just("not-authorized")); + Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); } @Test public void fluxPostAuthorizeWhenBeanAndAuthorizedThenSuccess() { - when(this.delegate.fluxPostAuthorizeBeanFindById(2L)).thenReturn(Flux.just("user")); - - Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(2L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectNext("user") - .verifyComplete(); + given(this.delegate.fluxPostAuthorizeBeanFindById(2L)).willReturn(Flux.just("user")); + Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(2L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectNext("user").verifyComplete(); } @Test public void fluxPostAuthorizeWhenBeanAndNotAuthenticatedAndAuthorizedThenSuccess() { - when(this.delegate.fluxPostAuthorizeBeanFindById(2L)).thenReturn(Flux.just("anonymous")); - + given(this.delegate.fluxPostAuthorizeBeanFindById(2L)).willReturn(Flux.just("anonymous")); Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(2L); - StepVerifier - .create(findById) - .expectNext("anonymous") - .verifyComplete(); + StepVerifier.create(findById).expectNext("anonymous").verifyComplete(); } @Test public void fluxPostAuthorizeWhenBeanAndNotAuthorizedThenDenied() { - when(this.delegate.fluxPostAuthorizeBeanFindById(1L)).thenReturn(Flux.just("not-authorized")); - - Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(1L) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); + given(this.delegate.fluxPostAuthorizeBeanFindById(1L)).willReturn(Flux.just("not-authorized")); + Flux findById = this.messageService.fluxPostAuthorizeBeanFindById(1L).subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); } // Publisher tests - @Test public void publisherWhenPermitAllThenAopDoesNotSubscribe() { - when(this.delegate.publisherFindById(1L)).thenReturn(result); - + given(this.delegate.publisherFindById(1L)).willReturn(this.result); this.delegate.publisherFindById(1L); - - result.assertNoSubscribers(); + this.result.assertNoSubscribers(); } @Test public void publisherWhenPermitAllThenSuccess() { - when(this.delegate.publisherFindById(1L)).thenReturn(publisherJust("success")); - - StepVerifier.create(this.delegate.publisherFindById(1L)) - .expectNext("success") - .verifyComplete(); + given(this.delegate.publisherFindById(1L)).willReturn(publisherJust("success")); + StepVerifier.create(this.delegate.publisherFindById(1L)).expectNext("success").verifyComplete(); } @Test public void publisherPreAuthorizeHasRoleWhenGrantedThenSuccess() { - when(this.delegate.publisherPreAuthorizeHasRoleFindById(1L)).thenReturn(publisherJust("result")); - + given(this.delegate.publisherPreAuthorizeHasRoleFindById(1L)).willReturn(publisherJust("result")); Publisher findById = Flux.from(this.messageService.publisherPreAuthorizeHasRoleFindById(1L)) - .subscriberContext(withAdmin); - StepVerifier - .create(findById) - .consumeNextWith( s -> AssertionsForClassTypes.assertThat(s).isEqualTo("result")) - .verifyComplete(); + .subscriberContext(this.withAdmin); + StepVerifier.create(findById).consumeNextWith((s) -> assertThat(s).isEqualTo("result")).verifyComplete(); } @Test public void publisherPreAuthorizeHasRoleWhenNoAuthenticationThenDenied() { - when(this.delegate.publisherPreAuthorizeHasRoleFindById(1L)).thenReturn(result); - + given(this.delegate.publisherPreAuthorizeHasRoleFindById(1L)).willReturn(this.result); Publisher findById = this.messageService.publisherPreAuthorizeHasRoleFindById(1L); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void publisherPreAuthorizeHasRoleWhenNotAuthorizedThenDenied() { - when(this.delegate.publisherPreAuthorizeHasRoleFindById(1L)).thenReturn(result); - + given(this.delegate.publisherPreAuthorizeHasRoleFindById(1L)).willReturn(this.result); Publisher findById = Flux.from(this.messageService.publisherPreAuthorizeHasRoleFindById(1L)) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void publisherPreAuthorizeBeanWhenGrantedThenSuccess() { - when(this.delegate.publisherPreAuthorizeBeanFindById(2L)).thenReturn(publisherJust("result")); - + given(this.delegate.publisherPreAuthorizeBeanFindById(2L)).willReturn(publisherJust("result")); Publisher findById = Flux.from(this.messageService.publisherPreAuthorizeBeanFindById(2L)) - .subscriberContext(withAdmin); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + .subscriberContext(this.withAdmin); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void publisherPreAuthorizeBeanWhenNotAuthenticatedAndGrantedThenSuccess() { - when(this.delegate.publisherPreAuthorizeBeanFindById(2L)).thenReturn(publisherJust("result")); - + given(this.delegate.publisherPreAuthorizeBeanFindById(2L)).willReturn(publisherJust("result")); Publisher findById = this.messageService.publisherPreAuthorizeBeanFindById(2L); - StepVerifier - .create(findById) - .expectNext("result") - .verifyComplete(); + StepVerifier.create(findById).expectNext("result").verifyComplete(); } @Test public void publisherPreAuthorizeBeanWhenNoAuthenticationThenDenied() { - when(this.delegate.publisherPreAuthorizeBeanFindById(1L)).thenReturn(result); - + given(this.delegate.publisherPreAuthorizeBeanFindById(1L)).willReturn(this.result); Publisher findById = this.messageService.publisherPreAuthorizeBeanFindById(1L); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void publisherPreAuthorizeBeanWhenNotAuthorizedThenDenied() { - when(this.delegate.publisherPreAuthorizeBeanFindById(1L)).thenReturn(result); - + given(this.delegate.publisherPreAuthorizeBeanFindById(1L)).willReturn(this.result); Publisher findById = Flux.from(this.messageService.publisherPreAuthorizeBeanFindById(1L)) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); - - result.assertNoSubscribers(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); + this.result.assertNoSubscribers(); } @Test public void publisherPostAuthorizeWhenAuthorizedThenSuccess() { - when(this.delegate.publisherPostAuthorizeFindById(1L)).thenReturn(publisherJust("user")); - + given(this.delegate.publisherPostAuthorizeFindById(1L)).willReturn(publisherJust("user")); Publisher findById = Flux.from(this.messageService.publisherPostAuthorizeFindById(1L)) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectNext("user") - .verifyComplete(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectNext("user").verifyComplete(); } @Test public void publisherPostAuthorizeWhenNotAuthorizedThenDenied() { - when(this.delegate.publisherPostAuthorizeBeanFindById(1L)).thenReturn(publisherJust("not-authorized")); - + given(this.delegate.publisherPostAuthorizeBeanFindById(1L)).willReturn(publisherJust("not-authorized")); Publisher findById = Flux.from(this.messageService.publisherPostAuthorizeBeanFindById(1L)) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); } @Test public void publisherPostAuthorizeWhenBeanAndAuthorizedThenSuccess() { - when(this.delegate.publisherPostAuthorizeBeanFindById(2L)).thenReturn(publisherJust("user")); - + given(this.delegate.publisherPostAuthorizeBeanFindById(2L)).willReturn(publisherJust("user")); Publisher findById = Flux.from(this.messageService.publisherPostAuthorizeBeanFindById(2L)) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectNext("user") - .verifyComplete(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectNext("user").verifyComplete(); } @Test public void publisherPostAuthorizeWhenBeanAndNotAuthenticatedAndAuthorizedThenSuccess() { - when(this.delegate.publisherPostAuthorizeBeanFindById(2L)).thenReturn(publisherJust("anonymous")); - + given(this.delegate.publisherPostAuthorizeBeanFindById(2L)).willReturn(publisherJust("anonymous")); Publisher findById = this.messageService.publisherPostAuthorizeBeanFindById(2L); - StepVerifier - .create(findById) - .expectNext("anonymous") - .verifyComplete(); + StepVerifier.create(findById).expectNext("anonymous").verifyComplete(); } @Test public void publisherPostAuthorizeWhenBeanAndNotAuthorizedThenDenied() { - when(this.delegate.publisherPostAuthorizeBeanFindById(1L)).thenReturn(publisherJust("not-authorized")); - + given(this.delegate.publisherPostAuthorizeBeanFindById(1L)).willReturn(publisherJust("not-authorized")); Publisher findById = Flux.from(this.messageService.publisherPostAuthorizeBeanFindById(1L)) - .subscriberContext(withUser); - StepVerifier - .create(findById) - .expectError(AccessDeniedException.class) - .verify(); + .subscriberContext(this.withUser); + StepVerifier.create(findById).expectError(AccessDeniedException.class).verify(); } static Publisher publisher(Flux flux) { - return subscriber -> flux.subscribe(subscriber); + return (subscriber) -> flux.subscribe(subscriber); } static Publisher publisherJust(T... data) { @@ -581,16 +407,19 @@ public class EnableReactiveMethodSecurityTests { @EnableReactiveMethodSecurity static class Config { + ReactiveMessageService delegate = mock(ReactiveMessageService.class); @Bean - public DelegatingReactiveMessageService defaultMessageService() { - return new DelegatingReactiveMessageService(delegate); + DelegatingReactiveMessageService defaultMessageService() { + return new DelegatingReactiveMessageService(this.delegate); } @Bean - public Authz authz() { + Authz authz() { return new Authz(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java index 221a63d24c..5b4ce0106e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import java.lang.reflect.Proxy; import java.util.HashMap; import java.util.Map; + import javax.sql.DataSource; import org.aopalliance.intercept.MethodInterceptor; @@ -59,22 +61,22 @@ import org.springframework.transaction.annotation.EnableTransactionManagement; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** - * * @author Rob Winch * @author Artsiom Yudovin */ @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class GlobalMethodSecurityConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -100,197 +102,78 @@ public class GlobalMethodSecurityConfigurationTests { this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire(); } - @EnableGlobalMethodSecurity - public static class IllegalStateGlobalMethodSecurityConfig extends GlobalMethodSecurityConfiguration { - - } - @Test public void configureWhenGlobalMethodSecurityHasCustomMetadataSourceThenNoEnablingAttributeIsNeeded() { this.spring.register(CustomMetadataSourceConfig.class).autowire(); } - @EnableGlobalMethodSecurity - public static class CustomMetadataSourceConfig extends GlobalMethodSecurityConfiguration { - @Bean - @Override - protected MethodSecurityMetadataSource customMethodSecurityMetadataSource() { - return mock(MethodSecurityMetadataSource.class); - } - } - @Test public void methodSecurityAuthenticationManagerPublishesEvent() { this.spring.register(InMemoryAuthWithGlobalMethodSecurityConfig.class).autowire(); - try { this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("foo", "bar")); - } catch(AuthenticationException e) {} - - assertThat(this.events.getEvents()).extracting(Object::getClass).containsOnly((Class) AuthenticationFailureBadCredentialsEvent.class); - } - - @EnableGlobalMethodSecurity(prePostEnabled = true) - public static class InMemoryAuthWithGlobalMethodSecurityConfig extends GlobalMethodSecurityConfiguration { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication(); } - - @Bean - public MockEventListener listener() { - return new MockEventListener() {}; + catch (AuthenticationException ex) { } + assertThat(this.events.getEvents()).extracting(Object::getClass) + .containsOnly((Class) AuthenticationFailureBadCredentialsEvent.class); } @Test @WithMockUser public void methodSecurityWhenAuthenticationTrustResolverIsBeanThenAutowires() { this.spring.register(CustomTrustResolverConfig.class).autowire(); - AuthenticationTrustResolver trustResolver = this.spring.getContext().getBean(AuthenticationTrustResolver.class); - when(trustResolver.isAnonymous(any())).thenReturn(true, false); - - assertThatThrownBy(() -> this.service.preAuthorizeNotAnonymous()) - .isInstanceOf(AccessDeniedException.class); - + given(trustResolver.isAnonymous(any())).willReturn(true, false); + assertThatExceptionOfType(AccessDeniedException.class) + .isThrownBy(() -> this.service.preAuthorizeNotAnonymous()); this.service.preAuthorizeNotAnonymous(); - verify(trustResolver, atLeastOnce()).isAnonymous(any()); } - @EnableGlobalMethodSecurity(prePostEnabled = true) - static class CustomTrustResolverConfig { - - @Bean - public AuthenticationTrustResolver trustResolver() { - return mock(AuthenticationTrustResolver.class); - } - - @Bean - public MethodSecurityServiceImpl service() { - return new MethodSecurityServiceImpl(); - } - } - // SEC-2301 @Test @WithMockUser public void defaultWebSecurityExpressionHandlerHasBeanResolverSet() { this.spring.register(ExpressionHandlerHasBeanResolverSetConfig.class).autowire(); Authz authz = this.spring.getContext().getBean(Authz.class); - - assertThatThrownBy(() -> this.service.preAuthorizeBean(false)) - .isInstanceOf(AccessDeniedException.class); - + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorizeBean(false)); this.service.preAuthorizeBean(true); } - @EnableGlobalMethodSecurity(prePostEnabled = true, proxyTargetClass = true) - static class ExpressionHandlerHasBeanResolverSetConfig { - - @Bean - public MethodSecurityServiceImpl service() { - return new MethodSecurityServiceImpl(); - } - - @Bean - public Authz authz() { - return new Authz(); - } - } - @Test @WithMockUser public void methodSecuritySupportsAnnotaitonsOnInterfaceParamerNames() { this.spring.register(MethodSecurityServiceConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.postAnnotation("deny")) - .isInstanceOf(AccessDeniedException.class); - + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.postAnnotation("deny")); this.service.postAnnotation("grant"); // no exception } - @EnableGlobalMethodSecurity(prePostEnabled = true) - static class MethodSecurityServiceConfig { - - @Bean - public MethodSecurityService service() { - return new MethodSecurityServiceImpl(); - } - } - @Test @WithMockUser public void globalMethodSecurityConfigurationAutowiresPermissionEvaluator() { this.spring.register(AutowirePermissionEvaluatorConfig.class).autowire(); PermissionEvaluator permission = this.spring.getContext().getBean(PermissionEvaluator.class); - when(permission.hasPermission(any(), eq("something"), eq("read"))).thenReturn(true, false); - + given(permission.hasPermission(any(), eq("something"), eq("read"))).willReturn(true, false); this.service.hasPermission("something"); // no exception - - assertThatThrownBy(() -> this.service.hasPermission("something")) - .isInstanceOf(AccessDeniedException.class); - } - - @EnableGlobalMethodSecurity(prePostEnabled = true) - public static class AutowirePermissionEvaluatorConfig { - - @Bean - public PermissionEvaluator permissionEvaluator() { - return mock(PermissionEvaluator.class); - } - - @Bean - public MethodSecurityService service() { - return new MethodSecurityServiceImpl(); - } + assertThatExceptionOfType(AccessDeniedException.class) + .isThrownBy(() -> this.service.hasPermission("something")); } @Test public void multiPermissionEvaluatorConfig() { this.spring.register(MultiPermissionEvaluatorConfig.class).autowire(); - // no exception } - @EnableGlobalMethodSecurity(prePostEnabled = true) - public static class MultiPermissionEvaluatorConfig { - - @Bean - public PermissionEvaluator permissionEvaluator() { - return mock(PermissionEvaluator.class); - } - - @Bean - public PermissionEvaluator permissionEvaluator2() { - return mock(PermissionEvaluator.class); - } - } - // SEC-2425 @Test @WithMockUser public void enableGlobalMethodSecurityWorksOnSuperclass() { this.spring.register(ChildConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); - } - - @Configuration - static class ChildConfig extends ParentConfig {} - - @EnableGlobalMethodSecurity(prePostEnabled = true) - static class ParentConfig { - - @Bean - public MethodSecurityService service() { - return new MethodSecurityServiceImpl(); - } + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); } // SEC-2479 @@ -305,58 +188,254 @@ public class GlobalMethodSecurityConfigurationTests { child.register(Sec2479ChildConfig.class); child.refresh(); this.spring.context(child).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); } } } - @Configuration - static class Sec2479ParentConfig { - @Bean - public AuthenticationManager am() { - return mock(AuthenticationManager.class); - } - } - - @EnableGlobalMethodSecurity(prePostEnabled = true) - static class Sec2479ChildConfig { - @Bean - public MethodSecurityService service() { - return new MethodSecurityServiceImpl(); - } - } - @Test public void enableGlobalMethodSecurityDoesNotTriggerEagerInitializationOfBeansInGlobalAuthenticationConfigurer() { this.spring.register(Sec2815Config.class).autowire(); - MockBeanPostProcessor pp = this.spring.getContext().getBean(MockBeanPostProcessor.class); - assertThat(pp.beforeInit).containsKeys("dataSource"); assertThat(pp.afterInit).containsKeys("dataSource"); } - @EnableGlobalMethodSecurity(prePostEnabled = true) - static class Sec2815Config { + // SEC-3045 + @Test + public void globalSecurityProxiesSecurity() { + this.spring.register(Sec3005Config.class).autowire(); + assertThat(this.service.getClass()).matches((c) -> !Proxy.isProxyClass(c), "is not proxy class"); + } + + // + // // gh-3797 + // def preAuthorizeBeanSpel() { + // setup: + // SecurityContextHolder.getContext().setAuthentication( + // new TestingAuthenticationToken("user", "password","ROLE_USER")) + // context = new AnnotationConfigApplicationContext(PreAuthorizeBeanSpelConfig) + // BeanSpelService service = context.getBean(BeanSpelService) + // when: + // service.run(true) + // then: + // noExceptionThrown() + // when: + // service.run(false) + // then: + // thrown(AccessDeniedException) + // } + // + @Test + @WithMockUser + public void preAuthorizeBeanSpel() { + this.spring.register(PreAuthorizeBeanSpelConfig.class).autowire(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorizeBean(false)); + this.service.preAuthorizeBean(true); + } + + // gh-3394 + @Test + @WithMockUser + public void roleHierarchy() { + this.spring.register(RoleHierarchyConfig.class).autowire(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + this.service.preAuthorizeAdmin(); + } + + @Test + @WithMockUser(authorities = "ROLE:USER") + public void grantedAuthorityDefaultsAutowires() { + this.spring.register(CustomGrantedAuthorityConfig.class).autowire(); + CustomGrantedAuthorityConfig.CustomAuthorityService customService = this.spring.getContext() + .getBean(CustomGrantedAuthorityConfig.CustomAuthorityService.class); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + customService.customPrefixRoleUser(); + // no exception + } + + @Test + @WithMockUser(authorities = "USER") + public void grantedAuthorityDefaultsWithEmptyRolePrefix() { + this.spring.register(EmptyRolePrefixGrantedAuthorityConfig.class).autowire(); + EmptyRolePrefixGrantedAuthorityConfig.CustomAuthorityService customService = this.spring.getContext() + .getBean(EmptyRolePrefixGrantedAuthorityConfig.CustomAuthorityService.class); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.securedUser()); + customService.emptyPrefixRoleUser(); + // no exception + } + + @Test + public void methodSecurityInterceptorUsesMetadataSourceBeanWhenProxyingDisabled() { + this.spring.register(CustomMetadataSourceBeanProxyEnabledConfig.class).autowire(); + MethodSecurityInterceptor methodInterceptor = (MethodSecurityInterceptor) this.spring.getContext() + .getBean(MethodInterceptor.class); + MethodSecurityMetadataSource methodSecurityMetadataSource = this.spring.getContext() + .getBean(MethodSecurityMetadataSource.class); + assertThat(methodInterceptor.getSecurityMetadataSource()).isSameAs(methodSecurityMetadataSource); + } + + @EnableGlobalMethodSecurity + public static class IllegalStateGlobalMethodSecurityConfig extends GlobalMethodSecurityConfiguration { + + } + + @EnableGlobalMethodSecurity + public static class CustomMetadataSourceConfig extends GlobalMethodSecurityConfiguration { + @Bean - public MethodSecurityService service() { + @Override + protected MethodSecurityMetadataSource customMethodSecurityMetadataSource() { + return mock(MethodSecurityMetadataSource.class); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + public static class InMemoryAuthWithGlobalMethodSecurityConfig extends GlobalMethodSecurityConfiguration { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication(); + // @formatter:on + } + + @Bean + public MockEventListener listener() { + return new MockEventListener() { + }; + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + static class CustomTrustResolverConfig { + + @Bean + AuthenticationTrustResolver trustResolver() { + return mock(AuthenticationTrustResolver.class); + } + + @Bean + MethodSecurityServiceImpl service() { + return new MethodSecurityServiceImpl(); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true, proxyTargetClass = true) + static class ExpressionHandlerHasBeanResolverSetConfig { + + @Bean + MethodSecurityServiceImpl service() { return new MethodSecurityServiceImpl(); } @Bean - public MockBeanPostProcessor mockBeanPostProcessor() { + Authz authz() { + return new Authz(); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + static class MethodSecurityServiceConfig { + + @Bean + MethodSecurityService service() { + return new MethodSecurityServiceImpl(); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + public static class AutowirePermissionEvaluatorConfig { + + @Bean + PermissionEvaluator permissionEvaluator() { + return mock(PermissionEvaluator.class); + } + + @Bean + MethodSecurityService service() { + return new MethodSecurityServiceImpl(); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + public static class MultiPermissionEvaluatorConfig { + + @Bean + PermissionEvaluator permissionEvaluator() { + return mock(PermissionEvaluator.class); + } + + @Bean + PermissionEvaluator permissionEvaluator2() { + return mock(PermissionEvaluator.class); + } + + } + + @Configuration + static class ChildConfig extends ParentConfig { + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + static class ParentConfig { + + @Bean + MethodSecurityService service() { + return new MethodSecurityServiceImpl(); + } + + } + + @Configuration + static class Sec2479ParentConfig { + + @Bean + AuthenticationManager am() { + return mock(AuthenticationManager.class); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + static class Sec2479ChildConfig { + + @Bean + MethodSecurityService service() { + return new MethodSecurityServiceImpl(); + } + + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) + static class Sec2815Config { + + @Bean + MethodSecurityService service() { + return new MethodSecurityServiceImpl(); + } + + @Bean + MockBeanPostProcessor mockBeanPostProcessor() { return new MockBeanPostProcessor(); } @Bean - public DataSource dataSource() { + DataSource dataSource() { return mock(DataSource.class); } @Configuration static class AuthConfig extends GlobalAuthenticationConfigurerAdapter { + @Autowired DataSource dataSource; @@ -364,16 +443,19 @@ public class GlobalMethodSecurityConfigurationTests { public void init(AuthenticationManagerBuilder auth) throws Exception { auth.inMemoryAuthentication(); } + } + } static class MockBeanPostProcessor implements BeanPostProcessor { + Map beforeInit = new HashMap<>(); + Map afterInit = new HashMap<>(); @Override - public Object postProcessBeforeInitialization(Object bean, String beanName) throws - BeansException { + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { this.beforeInit.put(beanName, bean); return bean; } @@ -383,63 +465,29 @@ public class GlobalMethodSecurityConfigurationTests { this.afterInit.put(beanName, bean); return bean; } + } - // SEC-3045 - @Test - public void globalSecurityProxiesSecurity() { - this.spring.register(Sec3005Config.class).autowire(); - - assertThat(this.service.getClass()).matches(c-> !Proxy.isProxyClass(c), "is not proxy class"); - } - - @EnableGlobalMethodSecurity(prePostEnabled = true, mode= AdviceMode.ASPECTJ) + @EnableGlobalMethodSecurity(prePostEnabled = true, mode = AdviceMode.ASPECTJ) @EnableTransactionManagement static class Sec3005Config { + @Bean - public MethodSecurityService service() { + MethodSecurityService service() { return new MethodSecurityServiceImpl(); } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { auth.inMemoryAuthentication(); } - } - // - // // gh-3797 - // def preAuthorizeBeanSpel() { - // setup: - // SecurityContextHolder.getContext().setAuthentication( - // new TestingAuthenticationToken("user", "password","ROLE_USER")) - // context = new AnnotationConfigApplicationContext(PreAuthorizeBeanSpelConfig) - // BeanSpelService service = context.getBean(BeanSpelService) - // when: - // service.run(true) - // then: - // noExceptionThrown() - // when: - // service.run(false) - // then: - // thrown(AccessDeniedException) - // } - // - - @Test - @WithMockUser - public void preAuthorizeBeanSpel() { - this.spring.register(PreAuthorizeBeanSpelConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorizeBean(false)) - .isInstanceOf(AccessDeniedException.class); - - this.service.preAuthorizeBean(true); } @Configuration @EnableGlobalMethodSecurity(prePostEnabled = true) public static class PreAuthorizeBeanSpelConfig { + @Bean MethodSecurityService service() { return new MethodSecurityServiceImpl(); @@ -449,22 +497,13 @@ public class GlobalMethodSecurityConfigurationTests { Authz authz() { return new Authz(); } - } - // gh-3394 - @Test - @WithMockUser - public void roleHierarchy() { - this.spring.register(RoleHierarchyConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); - this.service.preAuthorizeAdmin(); } @EnableGlobalMethodSecurity(prePostEnabled = true) @Configuration public static class RoleHierarchyConfig { + @Bean MethodSecurityService service() { return new MethodSecurityServiceImpl(); @@ -476,98 +515,69 @@ public class GlobalMethodSecurityConfigurationTests { result.setHierarchy("ROLE_USER > ROLE_ADMIN"); return result; } - } - @Test - @WithMockUser(authorities = "ROLE:USER") - public void grantedAuthorityDefaultsAutowires() { - this.spring.register(CustomGrantedAuthorityConfig.class).autowire(); - - CustomGrantedAuthorityConfig.CustomAuthorityService customService = this.spring.getContext().getBean( - CustomGrantedAuthorityConfig.CustomAuthorityService.class); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); - - customService.customPrefixRoleUser(); - // no exception } @EnableGlobalMethodSecurity(prePostEnabled = true) static class CustomGrantedAuthorityConfig { @Bean - public GrantedAuthorityDefaults ga() { + GrantedAuthorityDefaults ga() { return new GrantedAuthorityDefaults("ROLE:"); } @Bean - public CustomAuthorityService service() { + CustomAuthorityService service() { return new CustomAuthorityService(); } @Bean - public MethodSecurityServiceImpl methodSecurityService() { + MethodSecurityServiceImpl methodSecurityService() { return new MethodSecurityServiceImpl(); } static class CustomAuthorityService { + @PreAuthorize("hasRole('ROLE:USER')") - public void customPrefixRoleUser() {} + void customPrefixRoleUser() { + } + } - } - @Test - @WithMockUser(authorities = "USER") - public void grantedAuthorityDefaultsWithEmptyRolePrefix() { - this.spring.register(EmptyRolePrefixGrantedAuthorityConfig.class).autowire(); - - EmptyRolePrefixGrantedAuthorityConfig.CustomAuthorityService customService = this.spring.getContext() - .getBean(EmptyRolePrefixGrantedAuthorityConfig.CustomAuthorityService.class); - - assertThatThrownBy(() -> this.service.securedUser()) - .isInstanceOf(AccessDeniedException.class); - - customService.emptyPrefixRoleUser(); - // no exception } @EnableGlobalMethodSecurity(securedEnabled = true) static class EmptyRolePrefixGrantedAuthorityConfig { + @Bean - public GrantedAuthorityDefaults ga() { + GrantedAuthorityDefaults ga() { return new GrantedAuthorityDefaults(""); } @Bean - public CustomAuthorityService service() { + CustomAuthorityService service() { return new CustomAuthorityService(); } @Bean - public MethodSecurityServiceImpl methodSecurityService() { + MethodSecurityServiceImpl methodSecurityService() { return new MethodSecurityServiceImpl(); } static class CustomAuthorityService { + @Secured("USER") - public void emptyPrefixRoleUser() {} + void emptyPrefixRoleUser() { + } + } - } - @Test - public void methodSecurityInterceptorUsesMetadataSourceBeanWhenProxyingDisabled() { - this.spring.register(CustomMetadataSourceBeanProxyEnabledConfig.class).autowire(); - MethodSecurityInterceptor methodInterceptor = - (MethodSecurityInterceptor) this.spring.getContext().getBean(MethodInterceptor.class); - MethodSecurityMetadataSource methodSecurityMetadataSource = - this.spring.getContext().getBean(MethodSecurityMetadataSource.class); - - assertThat(methodInterceptor.getSecurityMetadataSource()).isSameAs(methodSecurityMetadataSource); } @EnableGlobalMethodSecurity(prePostEnabled = true) @Configuration public static class CustomMetadataSourceBeanProxyEnabledConfig extends GlobalMethodSecurityConfiguration { + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityService.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityService.java index b558c9911d..525ce2a477 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityService.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityService.java @@ -16,19 +16,20 @@ package org.springframework.security.config.annotation.method.configuration; +import javax.annotation.security.DenyAll; +import javax.annotation.security.PermitAll; + import org.springframework.security.access.annotation.Secured; import org.springframework.security.access.prepost.PostAuthorize; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.security.core.Authentication; import org.springframework.security.core.parameters.P; -import javax.annotation.security.DenyAll; -import javax.annotation.security.PermitAll; - /** * @author Rob Winch */ public interface MethodSecurityService { + @PreAuthorize("denyAll") String preAuthorize(); @@ -67,4 +68,5 @@ public interface MethodSecurityService { @PostAuthorize("#o?.contains('grant')") String postAnnotation(@P("o") String object); + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceConfig.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceConfig.java index 0a5fd815eb..ee664f5a45 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceConfig.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import org.springframework.context.annotation.Bean; @@ -21,8 +22,10 @@ import org.springframework.context.annotation.Bean; * @author Josh Cummings */ public class MethodSecurityServiceConfig { + @Bean MethodSecurityService service() { return new MethodSecurityServiceImpl(); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceImpl.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceImpl.java index 1ce3a60f3a..94a05216bc 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceImpl.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/MethodSecurityServiceImpl.java @@ -23,6 +23,7 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Rob Winch */ public class MethodSecurityServiceImpl implements MethodSecurityService { + @Override public String preAuthorize() { return null; diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityExpressionHandlerTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityExpressionHandlerTests.java index 74cf91ab87..b9bf7080b8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityExpressionHandlerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityExpressionHandlerTests.java @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; +import java.io.Serializable; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.PermissionEvaluator; @@ -29,13 +33,10 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import java.io.Serializable; - -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** - * * @author Rob Winch * @author Josh Cummings */ @@ -53,43 +54,41 @@ public class NamespaceGlobalMethodSecurityExpressionHandlerTests { @WithMockUser public void methodSecurityWhenUsingCustomPermissionEvaluatorThenPreAuthorizesAccordingly() { this.spring.register(CustomAccessDecisionManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatCode(() -> this.service.hasPermission("granted")) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.service.hasPermission("denied")) - .isInstanceOf(AccessDeniedException.class); + assertThat(this.service.hasPermission("granted")).isNull(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.hasPermission("denied")); } @Test @WithMockUser public void methodSecurityWhenUsingCustomPermissionEvaluatorThenPostAuthorizesAccordingly() { this.spring.register(CustomAccessDecisionManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatCode(() -> this.service.postHasPermission("granted")) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.service.postHasPermission("denied")) - .isInstanceOf(AccessDeniedException.class); + assertThat(this.service.postHasPermission("granted")).isNull(); + assertThatExceptionOfType(AccessDeniedException.class) + .isThrownBy(() -> this.service.postHasPermission("denied")); } @EnableGlobalMethodSecurity(prePostEnabled = true) public static class CustomAccessDecisionManagerConfig extends GlobalMethodSecurityConfiguration { + @Override protected MethodSecurityExpressionHandler createExpressionHandler() { DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler(); - expressionHandler.setPermissionEvaluator(new PermissionEvaluator() { - public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Object targetDomainObject, + Object permission) { return "granted".equals(targetDomainObject); } - public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, + Object permission) { throw new UnsupportedOperationException(); } }); - return expressionHandler; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityTests.java index 1174a77d34..fd46256b4b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/NamespaceGlobalMethodSecurityTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import java.lang.reflect.Method; @@ -56,11 +57,9 @@ import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** - * * @author Rob Winch * @author Josh Cummings */ @@ -74,19 +73,166 @@ public class NamespaceGlobalMethodSecurityTests { @Autowired(required = false) private MethodSecurityService service; - // --- access-decision-manager-ref --- - @Test @WithMockUser public void methodSecurityWhenCustomAccessDecisionManagerThenAuthorizes() { this.spring.register(CustomAccessDecisionManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.secured()); + } - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); + @Test + @WithMockUser + public void methodSecurityWhenCustomAfterInvocationManagerThenAuthorizes() { + this.spring.register(CustomAfterInvocationManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorizePermitAll()); + } - assertThatThrownBy(() -> this.service.secured()) - .isInstanceOf(AccessDeniedException.class); + @Test + @WithMockUser + public void methodSecurityWhenCustomAuthenticationManagerThenAuthorizes() { + this.spring.register(CustomAuthenticationConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(() -> this.service.preAuthorize()); + } + @Test + @WithMockUser + public void methodSecurityWhenJsr250EnabledThenAuthorizes() { + this.spring.register(Jsr250Config.class, MethodSecurityServiceConfig.class).autowire(); + this.service.preAuthorize(); + this.service.secured(); + this.service.jsr250PermitAll(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.jsr250()); + } + + @Test + @WithMockUser + public void methodSecurityWhenCustomMethodSecurityMetadataSourceThenAuthorizes() { + this.spring.register(CustomMethodSecurityMetadataSourceConfig.class, MethodSecurityServiceConfig.class) + .autowire(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.secured()); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.jsr250()); + } + + @Test + @WithMockUser + public void contextRefreshWhenUsingAspectJThenAutowire() throws Exception { + this.spring.register(AspectJModeConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThat(this.spring.getContext().getBean( + Class.forName("org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"))) + .isNotNull(); + assertThat(this.spring.getContext().getBean(AspectJMethodSecurityInterceptor.class)).isNotNull(); + // TODO diagnose why aspectj isn't weaving method security advice around + // MethodSecurityServiceImpl + } + + @Test + public void contextRefreshWhenUsingAspectJAndCustomGlobalMethodSecurityConfigurationThenAutowire() + throws Exception { + this.spring.register(AspectJModeExtendsGMSCConfig.class).autowire(); + assertThat(this.spring.getContext().getBean( + Class.forName("org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"))) + .isNotNull(); + assertThat(this.spring.getContext().getBean(AspectJMethodSecurityInterceptor.class)).isNotNull(); + } + + @Test + @WithMockUser + public void methodSecurityWhenOrderSpecifiedThenConfigured() { + this.spring.register(CustomOrderConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThat(this.spring.getContext().getBean("metaDataSourceAdvisor", MethodSecurityMetadataSourceAdvisor.class) + .getOrder()).isEqualTo(-135); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.jsr250()); + } + + @Test + @WithMockUser + public void methodSecurityWhenOrderUnspecifiedThenConfiguredToLowestPrecedence() { + this.spring.register(DefaultOrderConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThat(this.spring.getContext().getBean("metaDataSourceAdvisor", MethodSecurityMetadataSourceAdvisor.class) + .getOrder()).isEqualTo(Ordered.LOWEST_PRECEDENCE); + assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(() -> this.service.jsr250()); + } + + @Test + @WithMockUser + public void methodSecurityWhenOrderUnspecifiedAndCustomGlobalMethodSecurityConfigurationThenConfiguredToLowestPrecedence() { + this.spring.register(DefaultOrderExtendsMethodSecurityConfig.class, MethodSecurityServiceConfig.class) + .autowire(); + assertThat(this.spring.getContext().getBean("metaDataSourceAdvisor", MethodSecurityMetadataSourceAdvisor.class) + .getOrder()).isEqualTo(Ordered.LOWEST_PRECEDENCE); + assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(() -> this.service.jsr250()); + } + + @Test + @WithMockUser + public void methodSecurityWhenPrePostEnabledThenPreAuthorizes() { + this.spring.register(PreAuthorizeConfig.class, MethodSecurityServiceConfig.class).autowire(); + this.service.secured(); + this.service.jsr250(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + } + + @Test + @WithMockUser + public void methodSecurityWhenPrePostEnabledAndCustomGlobalMethodSecurityConfigurationThenPreAuthorizes() { + this.spring.register(PreAuthorizeExtendsGMSCConfig.class, MethodSecurityServiceConfig.class).autowire(); + this.service.secured(); + this.service.jsr250(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + } + + @Test + @WithMockUser + public void methodSecurityWhenProxyTargetClassThenDoesNotWireToInterface() { + this.spring.register(ProxyTargetClassConfig.class, MethodSecurityServiceConfig.class).autowire(); + // make sure service was actually proxied + assertThat(this.service.getClass().getInterfaces()).doesNotContain(MethodSecurityService.class); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + } + + @Test + @WithMockUser + public void methodSecurityWhenDefaultProxyThenWiresToInterface() { + this.spring.register(DefaultProxyConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThat(this.service.getClass().getInterfaces()).contains(MethodSecurityService.class); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); + } + + @Test + @WithMockUser + public void methodSecurityWhenCustomRunAsManagerThenRunAsWrapsAuthentication() { + this.spring.register(CustomRunAsManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThat(this.service.runAs().getAuthorities()) + .anyMatch((authority) -> "ROLE_RUN_AS_SUPER".equals(authority.getAuthority())); + } + + @Test + @WithMockUser + public void methodSecurityWhenSecuredEnabledThenSecures() { + this.spring.register(SecuredConfig.class, MethodSecurityServiceConfig.class).autowire(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.secured()); + this.service.securedUser(); + this.service.preAuthorize(); + this.service.jsr250(); + } + + @Test + @WithMockUser + public void methodSecurityWhenMissingEnableAnnotationThenShowsHelpfulError() { + assertThatExceptionOfType(Exception.class) + .isThrownBy(() -> this.spring.register(ExtendsNoEnableAnntotationConfig.class).autowire()) + .withStackTraceContaining(EnableGlobalMethodSecurity.class.getName() + " is required"); + } + + @Test + @WithMockUser + public void methodSecurityWhenImportingGlobalMethodSecurityConfigurationSubclassThenAuthorizes() { + this.spring.register(ImportSubclassGMSCConfig.class, MethodSecurityServiceConfig.class).autowire(); + this.service.secured(); + this.service.jsr250(); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> this.service.preAuthorize()); } @EnableGlobalMethodSecurity(prePostEnabled = true, securedEnabled = true) @@ -98,32 +244,29 @@ public class NamespaceGlobalMethodSecurityTests { } public static class DenyAllAccessDecisionManager implements AccessDecisionManager { - public void decide(Authentication authentication, Object object, Collection configAttributes) { + + @Override + public void decide(Authentication authentication, Object object, + Collection configAttributes) { throw new AccessDeniedException("Always Denied"); } + + @Override public boolean supports(ConfigAttribute attribute) { return true; } + + @Override public boolean supports(Class clazz) { return true; } + } - } - // --- after-invocation-provider - - @Test - @WithMockUser - public void methodSecurityWhenCustomAfterInvocationManagerThenAuthorizes() { - this.spring.register(CustomAfterInvocationManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorizePermitAll()) - .isInstanceOf(AccessDeniedException.class); } @EnableGlobalMethodSecurity(prePostEnabled = true) - public static class CustomAfterInvocationManagerConfig - extends GlobalMethodSecurityConfiguration { + public static class CustomAfterInvocationManagerConfig extends GlobalMethodSecurityConfiguration { @Override protected AfterInvocationManager afterInvocationManager() { @@ -131,32 +274,25 @@ public class NamespaceGlobalMethodSecurityTests { } public static class AfterInvocationManagerStub implements AfterInvocationManager { - public Object decide(Authentication authentication, - Object object, - Collection attributes, - Object returnedObject) throws AccessDeniedException { + @Override + public Object decide(Authentication authentication, Object object, Collection attributes, + Object returnedObject) throws AccessDeniedException { throw new AccessDeniedException("custom AfterInvocationManager"); } + @Override public boolean supports(ConfigAttribute attribute) { return true; } + + @Override public boolean supports(Class clazz) { return true; } + } - } - // --- authentication-manager-ref --- - - @Test - @WithMockUser - public void methodSecurityWhenCustomAuthenticationManagerThenAuthorizes() { - this.spring.register(CustomAuthenticationConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(UnsupportedOperationException.class); } @EnableGlobalMethodSecurity(prePostEnabled = true) @@ -175,26 +311,6 @@ public class NamespaceGlobalMethodSecurityTests { throw new UnsupportedOperationException(); }; } - } - - // --- jsr250-annotations --- - - @Test - @WithMockUser - public void methodSecurityWhenJsr250EnabledThenAuthorizes() { - this.spring.register(Jsr250Config.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatCode(() -> this.service.preAuthorize()) - .doesNotThrowAnyException(); - - assertThatCode(() -> this.service.secured()) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.service.jsr250()) - .isInstanceOf(AccessDeniedException.class); - - assertThatCode(() -> this.service.jsr250PermitAll()) - .doesNotThrowAnyException(); } @@ -204,53 +320,27 @@ public class NamespaceGlobalMethodSecurityTests { } - // --- metadata-source-ref --- - - @Test - @WithMockUser - public void methodSecurityWhenCustomMethodSecurityMetadataSourceThenAuthorizes() { - this.spring.register(CustomMethodSecurityMetadataSourceConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); - - assertThatThrownBy(() -> this.service.secured()) - .isInstanceOf(AccessDeniedException.class); - - assertThatThrownBy(() -> this.service.jsr250()) - .isInstanceOf(AccessDeniedException.class); - } - @EnableGlobalMethodSecurity public static class CustomMethodSecurityMetadataSourceConfig extends GlobalMethodSecurityConfiguration { @Override protected MethodSecurityMetadataSource customMethodSecurityMetadataSource() { return new AbstractMethodSecurityMetadataSource() { + @Override public Collection getAttributes(Method method, Class targetClass) { - // require ROLE_NOBODY for any method on MethodSecurityService interface - return MethodSecurityService.class.isAssignableFrom(targetClass) ? - Arrays.asList(new SecurityConfig("ROLE_NOBODY")) : - Collections.emptyList(); + // require ROLE_NOBODY for any method on MethodSecurityService + // interface + return MethodSecurityService.class.isAssignableFrom(targetClass) + ? Arrays.asList(new SecurityConfig("ROLE_NOBODY")) : Collections.emptyList(); } + + @Override public Collection getAllConfigAttributes() { return null; } }; } - } - // --- mode --- - - @Test - @WithMockUser - public void contextRefreshWhenUsingAspectJThenAutowire() throws Exception { - this.spring.register(AspectJModeConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThat(this.spring.getContext().getBean(Class.forName("org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"))).isNotNull(); - assertThat(this.spring.getContext().getBean(AspectJMethodSecurityInterceptor.class)).isNotNull(); - - //TODO diagnose why aspectj isn't weaving method security advice around MethodSecurityServiceImpl } @EnableGlobalMethodSecurity(mode = AdviceMode.ASPECTJ, securedEnabled = true) @@ -258,64 +348,37 @@ public class NamespaceGlobalMethodSecurityTests { } - @Test - public void contextRefreshWhenUsingAspectJAndCustomGlobalMethodSecurityConfigurationThenAutowire() - throws Exception { - - this.spring.register(AspectJModeExtendsGMSCConfig.class).autowire(); - - assertThat(this.spring.getContext().getBean(Class.forName("org.springframework.security.access.intercept.aspectj.aspect.AnnotationSecurityAspect"))).isNotNull(); - assertThat(this.spring.getContext().getBean(AspectJMethodSecurityInterceptor.class)).isNotNull(); - - } - @EnableGlobalMethodSecurity(mode = AdviceMode.ASPECTJ, securedEnabled = true) public static class AspectJModeExtendsGMSCConfig extends GlobalMethodSecurityConfiguration { + } - // --- order --- - - private static class AdvisorOrderConfig - implements ImportBeanDefinitionRegistrar { - - private static class ExceptingInterceptor implements MethodInterceptor { - @Override - public Object invoke(MethodInvocation invocation) { - throw new UnsupportedOperationException("Deny All"); - } - } + private static class AdvisorOrderConfig implements ImportBeanDefinitionRegistrar { @Override - public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) { - BeanDefinitionBuilder advice = BeanDefinitionBuilder - .rootBeanDefinition(ExceptingInterceptor.class); - registry.registerBeanDefinition("exceptingInterceptor", - advice.getBeanDefinition()); - + public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, + BeanDefinitionRegistry registry) { + BeanDefinitionBuilder advice = BeanDefinitionBuilder.rootBeanDefinition(ExceptingInterceptor.class); + registry.registerBeanDefinition("exceptingInterceptor", advice.getBeanDefinition()); BeanDefinitionBuilder advisor = BeanDefinitionBuilder - .rootBeanDefinition(MethodSecurityMetadataSourceAdvisor.class); + .rootBeanDefinition(MethodSecurityMetadataSourceAdvisor.class); advisor.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); advisor.addConstructorArgValue("exceptingInterceptor"); advisor.addConstructorArgReference("methodSecurityMetadataSource"); advisor.addConstructorArgValue("methodSecurityMetadataSource"); advisor.addPropertyValue("order", 0); - registry.registerBeanDefinition("exceptingAdvisor", - advisor.getBeanDefinition()); + registry.registerBeanDefinition("exceptingAdvisor", advisor.getBeanDefinition()); } - } - @Test - @WithMockUser - public void methodSecurityWhenOrderSpecifiedThenConfigured() { - this.spring.register(CustomOrderConfig.class, MethodSecurityServiceConfig.class).autowire(); + private static class ExceptingInterceptor implements MethodInterceptor { - assertThat(this.spring.getContext() - .getBean("metaDataSourceAdvisor", MethodSecurityMetadataSourceAdvisor.class) - .getOrder()) - .isEqualTo(-135); + @Override + public Object invoke(MethodInvocation invocation) { + throw new UnsupportedOperationException("Deny All"); + } + + } - assertThatThrownBy(() -> this.service.jsr250()) - .isInstanceOf(AccessDeniedException.class); } @EnableGlobalMethodSecurity(order = -135, jsr250Enabled = true) @@ -324,128 +387,36 @@ public class NamespaceGlobalMethodSecurityTests { } - @Test - @WithMockUser - public void methodSecurityWhenOrderUnspecifiedThenConfiguredToLowestPrecedence() { - this.spring.register(DefaultOrderConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThat(this.spring.getContext() - .getBean("metaDataSourceAdvisor", MethodSecurityMetadataSourceAdvisor.class) - .getOrder()) - .isEqualTo(Ordered.LOWEST_PRECEDENCE); - - assertThatThrownBy(() -> this.service.jsr250()) - .isInstanceOf(UnsupportedOperationException.class); - } - @EnableGlobalMethodSecurity(jsr250Enabled = true) @Import(AdvisorOrderConfig.class) public static class DefaultOrderConfig { - } - @Test - @WithMockUser - public void methodSecurityWhenOrderUnspecifiedAndCustomGlobalMethodSecurityConfigurationThenConfiguredToLowestPrecedence() { - this.spring.register(DefaultOrderExtendsMethodSecurityConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThat(this.spring.getContext() - .getBean("metaDataSourceAdvisor", MethodSecurityMetadataSourceAdvisor.class) - .getOrder()) - .isEqualTo(Ordered.LOWEST_PRECEDENCE); - - assertThatThrownBy(() -> this.service.jsr250()) - .isInstanceOf(UnsupportedOperationException.class); } @EnableGlobalMethodSecurity(jsr250Enabled = true) @Import(AdvisorOrderConfig.class) public static class DefaultOrderExtendsMethodSecurityConfig extends GlobalMethodSecurityConfiguration { - } - // --- pre-post-annotations --- - - @Test - @WithMockUser - public void methodSecurityWhenPrePostEnabledThenPreAuthorizes() { - this.spring.register(PreAuthorizeConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatCode(() -> this.service.secured()) - .doesNotThrowAnyException(); - - assertThatCode(() -> this.service.jsr250()) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); } @EnableGlobalMethodSecurity(prePostEnabled = true) public static class PreAuthorizeConfig { - } - @Test - @WithMockUser - public void methodSecurityWhenPrePostEnabledAndCustomGlobalMethodSecurityConfigurationThenPreAuthorizes() { - this.spring.register(PreAuthorizeExtendsGMSCConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatCode(() -> this.service.secured()) - .doesNotThrowAnyException(); - - assertThatCode(() -> this.service.jsr250()) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); } @EnableGlobalMethodSecurity(prePostEnabled = true) public static class PreAuthorizeExtendsGMSCConfig extends GlobalMethodSecurityConfiguration { - } - // --- proxy-target-class --- - - @Test - @WithMockUser - public void methodSecurityWhenProxyTargetClassThenDoesNotWireToInterface() { - this.spring.register(ProxyTargetClassConfig.class, MethodSecurityServiceConfig.class).autowire(); - - // make sure service was actually proxied - assertThat(this.service.getClass().getInterfaces()) - .doesNotContain(MethodSecurityService.class); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); } @EnableGlobalMethodSecurity(proxyTargetClass = true, prePostEnabled = true) public static class ProxyTargetClassConfig { - } - @Test - @WithMockUser - public void methodSecurityWhenDefaultProxyThenWiresToInterface() { - this.spring.register(DefaultProxyConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThat(this.service.getClass().getInterfaces()) - .contains(MethodSecurityService.class); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); } @EnableGlobalMethodSecurity(prePostEnabled = true) public static class DefaultProxyConfig { - } - // --- run-as-manager-ref --- - - @Test - @WithMockUser - public void methodSecurityWhenCustomRunAsManagerThenRunAsWrapsAuthentication() { - this.spring.register(CustomRunAsManagerConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThat(service.runAs().getAuthorities()) - .anyMatch(authority -> "ROLE_RUN_AS_SUPER".equals(authority.getAuthority())); } @EnableGlobalMethodSecurity(securedEnabled = true) @@ -457,64 +428,23 @@ public class NamespaceGlobalMethodSecurityTests { runAsManager.setKey("some key"); return runAsManager; } - } - // --- secured-annotation --- - - @Test - @WithMockUser - public void methodSecurityWhenSecuredEnabledThenSecures() { - this.spring.register(SecuredConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatThrownBy(() -> this.service.secured()) - .isInstanceOf(AccessDeniedException.class); - - assertThatCode(() -> this.service.securedUser()) - .doesNotThrowAnyException(); - - assertThatCode(() -> this.service.preAuthorize()) - .doesNotThrowAnyException(); - - assertThatCode(() -> this.service.jsr250()) - .doesNotThrowAnyException(); } @EnableGlobalMethodSecurity(securedEnabled = true) public static class SecuredConfig { - } - // --- unsorted --- - - @Test - @WithMockUser - public void methodSecurityWhenMissingEnableAnnotationThenShowsHelpfulError() { - assertThatThrownBy(() -> - this.spring.register(ExtendsNoEnableAnntotationConfig.class).autowire()) - .hasStackTraceContaining(EnableGlobalMethodSecurity.class.getName() + " is required"); } @Configuration - public static class ExtendsNoEnableAnntotationConfig - extends GlobalMethodSecurityConfiguration { - } + public static class ExtendsNoEnableAnntotationConfig extends GlobalMethodSecurityConfiguration { - @Test - @WithMockUser - public void methodSecurityWhenImportingGlobalMethodSecurityConfigurationSubclassThenAuthorizes() { - this.spring.register(ImportSubclassGMSCConfig.class, MethodSecurityServiceConfig.class).autowire(); - - assertThatCode(() -> this.service.secured()) - .doesNotThrowAnyException(); - - assertThatCode(() -> this.service.jsr250()) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.service.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); } @Configuration @Import(PreAuthorizeExtendsGMSCConfig.class) public static class ImportSubclassGMSCConfig { + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMessageService.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMessageService.java index d4bc9d45c9..908014c65c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMessageService.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMessageService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; import org.reactivestreams.Publisher; @@ -20,23 +21,37 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public interface ReactiveMessageService { + String notPublisherPreAuthorizeFindById(long id); Mono monoFindById(long id); + Mono monoPreAuthorizeHasRoleFindById(long id); + Mono monoPostAuthorizeFindById(long id); + Mono monoPreAuthorizeBeanFindById(long id); + Mono monoPostAuthorizeBeanFindById(long id); Flux fluxFindById(long id); + Flux fluxPreAuthorizeHasRoleFindById(long id); + Flux fluxPostAuthorizeFindById(long id); + Flux fluxPreAuthorizeBeanFindById(long id); + Flux fluxPostAuthorizeBeanFindById(long id); Publisher publisherFindById(long id); + Publisher publisherPreAuthorizeHasRoleFindById(long id); + Publisher publisherPostAuthorizeFindById(long id); + Publisher publisherPreAuthorizeBeanFindById(long id); + Publisher publisherPostAuthorizeBeanFindById(long id); + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfigurationTests.java index 9df29cd18f..f4c7d66b2f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/ReactiveMethodSecurityConfigurationTests.java @@ -16,10 +16,9 @@ package org.springframework.security.config.annotation.method.configuration; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -31,6 +30,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.core.GrantedAuthorityDefaults; import org.springframework.security.config.test.SpringTestRule; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Tadaya Tsuyukubo */ @@ -45,16 +46,12 @@ public class ReactiveMethodSecurityConfigurationTests { @Test public void rolePrefixWithGrantedAuthorityDefaults() throws NoSuchMethodException { this.spring.register(WithRolePrefixConfiguration.class).autowire(); - - TestingAuthenticationToken authentication = new TestingAuthenticationToken( - "principal", "credential", "CUSTOM_ABC"); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("principal", "credential", + "CUSTOM_ABC"); MockMethodInvocation methodInvocation = new MockMethodInvocation(new Foo(), Foo.class, "bar", String.class); - - EvaluationContext context = this.methodSecurityExpressionHandler - .createEvaluationContext(authentication, methodInvocation); - SecurityExpressionRoot root = (SecurityExpressionRoot) context.getRootObject() - .getValue(); - + EvaluationContext context = this.methodSecurityExpressionHandler.createEvaluationContext(authentication, + methodInvocation); + SecurityExpressionRoot root = (SecurityExpressionRoot) context.getRootObject().getValue(); assertThat(root.hasRole("ROLE_ABC")).isFalse(); assertThat(root.hasRole("ROLE_CUSTOM_ABC")).isFalse(); assertThat(root.hasRole("CUSTOM_ABC")).isTrue(); @@ -64,16 +61,25 @@ public class ReactiveMethodSecurityConfigurationTests { @Test public void rolePrefixWithDefaultConfig() throws NoSuchMethodException { this.spring.register(ReactiveMethodSecurityConfiguration.class).autowire(); - - TestingAuthenticationToken authentication = new TestingAuthenticationToken( - "principal", "credential", "ROLE_ABC"); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("principal", "credential", + "ROLE_ABC"); MockMethodInvocation methodInvocation = new MockMethodInvocation(new Foo(), Foo.class, "bar", String.class); + EvaluationContext context = this.methodSecurityExpressionHandler.createEvaluationContext(authentication, + methodInvocation); + SecurityExpressionRoot root = (SecurityExpressionRoot) context.getRootObject().getValue(); + assertThat(root.hasRole("ROLE_ABC")).isTrue(); + assertThat(root.hasRole("ABC")).isTrue(); + } - EvaluationContext context = this.methodSecurityExpressionHandler - .createEvaluationContext(authentication, methodInvocation); - SecurityExpressionRoot root = (SecurityExpressionRoot) context.getRootObject() - .getValue(); - + @Test + public void rolePrefixWithGrantedAuthorityDefaultsAndSubclassWithProxyingEnabled() throws NoSuchMethodException { + this.spring.register(SubclassConfig.class).autowire(); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("principal", "credential", + "ROLE_ABC"); + MockMethodInvocation methodInvocation = new MockMethodInvocation(new Foo(), Foo.class, "bar", String.class); + EvaluationContext context = this.methodSecurityExpressionHandler.createEvaluationContext(authentication, + methodInvocation); + SecurityExpressionRoot root = (SecurityExpressionRoot) context.getRootObject().getValue(); assertThat(root.hasRole("ROLE_ABC")).isTrue(); assertThat(root.hasRole("ABC")).isTrue(); } @@ -81,35 +87,24 @@ public class ReactiveMethodSecurityConfigurationTests { @Configuration @EnableReactiveMethodSecurity // this imports ReactiveMethodSecurityConfiguration static class WithRolePrefixConfiguration { + @Bean GrantedAuthorityDefaults grantedAuthorityDefaults() { return new GrantedAuthorityDefaults("CUSTOM_"); } - } - @Test - public void rolePrefixWithGrantedAuthorityDefaultsAndSubclassWithProxyingEnabled() throws NoSuchMethodException { - this.spring.register(SubclassConfig.class).autowire(); - - TestingAuthenticationToken authentication = new TestingAuthenticationToken( - "principal", "credential", "ROLE_ABC"); - MockMethodInvocation methodInvocation = new MockMethodInvocation(new Foo(), Foo.class, "bar", String.class); - - EvaluationContext context = this.methodSecurityExpressionHandler - .createEvaluationContext(authentication, methodInvocation); - SecurityExpressionRoot root = (SecurityExpressionRoot) context.getRootObject() - .getValue(); - - assertThat(root.hasRole("ROLE_ABC")).isTrue(); - assertThat(root.hasRole("ABC")).isTrue(); } @Configuration static class SubclassConfig extends ReactiveMethodSecurityConfiguration { + } - private static class Foo { - public void bar(String param){ + static class Foo { + + public void bar(String param) { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/SampleEnableGlobalMethodSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/SampleEnableGlobalMethodSecurityTests.java index ea690af732..bf2bf1ae6c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/SampleEnableGlobalMethodSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/SampleEnableGlobalMethodSecurityTests.java @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.method.configuration; +import java.io.Serializable; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.access.AccessDeniedException; @@ -30,10 +34,8 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import java.io.Serializable; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Demonstrate the samples @@ -42,6 +44,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * */ public class SampleEnableGlobalMethodSecurityTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -50,51 +53,50 @@ public class SampleEnableGlobalMethodSecurityTests { @Before public void setup() { - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("user", "password", "ROLE_USER")); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); } @Test public void preAuthorize() { this.spring.register(SampleWebSecurityConfig.class).autowire(); - assertThat(this.methodSecurityService.secured()).isNull(); assertThat(this.methodSecurityService.jsr250()).isNull(); - - assertThatThrownBy(() -> this.methodSecurityService.preAuthorize()) - .isInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(AccessDeniedException.class) + .isThrownBy(() -> this.methodSecurityService.preAuthorize()); } - @EnableGlobalMethodSecurity(prePostEnabled=true) + @Test + public void customPermissionHandler() { + this.spring.register(CustomPermissionEvaluatorWebSecurityConfig.class).autowire(); + assertThat(this.methodSecurityService.hasPermission("allowed")).isNull(); + assertThatExceptionOfType(AccessDeniedException.class) + .isThrownBy(() -> this.methodSecurityService.hasPermission("denied")); + } + + @EnableGlobalMethodSecurity(prePostEnabled = true) static class SampleWebSecurityConfig { + @Bean - public MethodSecurityService methodSecurityService() { + MethodSecurityService methodSecurityService() { return new MethodSecurityServiceImpl(); } @Autowired protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } + } - - @Test - public void customPermissionHandler() { - this.spring.register(CustomPermissionEvaluatorWebSecurityConfig.class).autowire(); - - assertThat(this.methodSecurityService.hasPermission("allowed")).isNull(); - - assertThatThrownBy(() -> this.methodSecurityService.hasPermission("denied")) - .isInstanceOf(AccessDeniedException.class); - } - - - @EnableGlobalMethodSecurity(prePostEnabled=true) + @EnableGlobalMethodSecurity(prePostEnabled = true) public static class CustomPermissionEvaluatorWebSecurityConfig extends GlobalMethodSecurityConfiguration { + @Bean public MethodSecurityService methodSecurityService() { return new MethodSecurityServiceImpl(); @@ -109,23 +111,29 @@ public class SampleEnableGlobalMethodSecurityTests { @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } + } static class CustomPermissionEvaluator implements PermissionEvaluator { - public boolean hasPermission(Authentication authentication, - Object targetDomainObject, Object permission) { + + @Override + public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { return !"denied".equals(targetDomainObject); } - public boolean hasPermission(Authentication authentication, - Serializable targetId, String targetType, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, + Object permission) { return !"denied".equals(targetId); } } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/sec2758/Sec2758Tests.java b/config/src/test/java/org/springframework/security/config/annotation/sec2758/Sec2758Tests.java index 9cf8126e3e..bd94e55f45 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/sec2758/Sec2758Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/sec2758/Sec2758Tests.java @@ -13,15 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.sec2758; +import javax.annotation.security.RolesAllowed; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.context.annotation.Bean; +import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.security.access.annotation.Jsr250MethodSecurityMetadataSource; import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler; @@ -39,9 +44,6 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import javax.annotation.security.RolesAllowed; - -import static org.assertj.core.api.Assertions.assertThatCode; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -65,44 +67,33 @@ public class Sec2758Tests { @WithMockUser(authorities = "CUSTOM") @Test public void requestWhenNullifyingRolePrefixThenPassivityRestored() throws Exception { - this.spring.register(SecurityConfig.class).autowire(); - this.mvc.perform(get("/")).andExpect(status().isOk()); } @WithMockUser(authorities = "CUSTOM") @Test public void methodSecurityWhenNullifyingRolePrefixThenPassivityRestored() { - this.spring.register(SecurityConfig.class).autowire(); - - assertThatCode(() -> service.doJsr250()) - .doesNotThrowAnyException(); - - assertThatCode(() -> service.doPreAuthorize()) - .doesNotThrowAnyException(); + this.service.doJsr250(); + this.service.doPreAuthorize(); } @EnableWebSecurity - @EnableGlobalMethodSecurity(prePostEnabled=true, jsr250Enabled = true) + @EnableGlobalMethodSecurity(prePostEnabled = true, jsr250Enabled = true) static class SecurityConfig extends WebSecurityConfigurerAdapter { - @RestController - static class RootController { - @GetMapping("/") - public String ok() { return "ok"; } - } - @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .authorizeRequests() - .anyRequest().access("hasAnyRole('CUSTOM')"); + .authorizeRequests() + .anyRequest().access("hasAnyRole('CUSTOM')"); + // @formatter:on } @Bean - public Service service() { + Service service() { return new Service(); } @@ -111,14 +102,28 @@ public class Sec2758Tests { return new DefaultRolesPrefixPostProcessor(); } + @RestController + static class RootController { + + @GetMapping("/") + String ok() { + return "ok"; + } + + } + } static class Service { + @PreAuthorize("hasRole('CUSTOM')") - public void doPreAuthorize() {} + void doPreAuthorize() { + } @RolesAllowed("CUSTOM") - public void doJsr250() {} + void doJsr250() { + } + } static class DefaultRolesPrefixPostProcessor implements BeanPostProcessor, PriorityOrdered { @@ -144,7 +149,9 @@ public class Sec2758Tests { @Override public int getOrder() { - return PriorityOrdered.HIGHEST_PRECEDENCE; + return Ordered.HIGHEST_PRECEDENCE; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractConfiguredSecurityBuilderTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractConfiguredSecurityBuilderTests.java index 4f6fa4a613..4c9eb0d3e3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractConfiguredSecurityBuilderTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractConfiguredSecurityBuilderTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; +import java.util.List; + import org.junit.Before; import org.junit.Test; + import org.springframework.security.config.annotation.AbstractConfiguredSecurityBuilder; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.SecurityConfigurer; import org.springframework.security.config.annotation.SecurityConfigurerAdapter; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -34,6 +36,7 @@ import static org.mockito.Mockito.verify; * @author Joe Grandja */ public class AbstractConfiguredSecurityBuilderTests { + private TestConfiguredSecurityBuilder builder; @Before @@ -80,7 +83,8 @@ public class AbstractConfiguredSecurityBuilderTests { @Test(expected = IllegalStateException.class) public void getConfigurerWhenMultipleConfigurersThenThrowIllegalStateException() throws Exception { - TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), true); + TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), + true); builder.apply(new DelegateSecurityConfigurer()); builder.apply(new DelegateSecurityConfigurer()); builder.getConfigurer(DelegateSecurityConfigurer.class); @@ -88,7 +92,8 @@ public class AbstractConfiguredSecurityBuilderTests { @Test(expected = IllegalStateException.class) public void removeConfigurerWhenMultipleConfigurersThenThrowIllegalStateException() throws Exception { - TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), true); + TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), + true); builder.apply(new DelegateSecurityConfigurer()); builder.apply(new DelegateSecurityConfigurer()); builder.removeConfigurer(DelegateSecurityConfigurer.class); @@ -98,10 +103,12 @@ public class AbstractConfiguredSecurityBuilderTests { public void removeConfigurersWhenMultipleConfigurersThenConfigurersRemoved() throws Exception { DelegateSecurityConfigurer configurer1 = new DelegateSecurityConfigurer(); DelegateSecurityConfigurer configurer2 = new DelegateSecurityConfigurer(); - TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), true); + TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), + true); builder.apply(configurer1); builder.apply(configurer2); - List removedConfigurers = builder.removeConfigurers(DelegateSecurityConfigurer.class); + List removedConfigurers = builder + .removeConfigurers(DelegateSecurityConfigurer.class); assertThat(removedConfigurers).hasSize(2); assertThat(removedConfigurers).containsExactly(configurer1, configurer2); assertThat(builder.getConfigurers(DelegateSecurityConfigurer.class)).isEmpty(); @@ -111,7 +118,8 @@ public class AbstractConfiguredSecurityBuilderTests { public void getConfigurersWhenMultipleConfigurersThenConfigurersReturned() throws Exception { DelegateSecurityConfigurer configurer1 = new DelegateSecurityConfigurer(); DelegateSecurityConfigurer configurer2 = new DelegateSecurityConfigurer(); - TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), true); + TestConfiguredSecurityBuilder builder = new TestConfiguredSecurityBuilder(mock(ObjectPostProcessor.class), + true); builder.apply(configurer1); builder.apply(configurer2); List configurers = builder.getConfigurers(DelegateSecurityConfigurer.class); @@ -120,29 +128,40 @@ public class AbstractConfiguredSecurityBuilderTests { assertThat(builder.getConfigurers(DelegateSecurityConfigurer.class)).hasSize(2); } - private static class DelegateSecurityConfigurer extends SecurityConfigurerAdapter { + private static class DelegateSecurityConfigurer + extends SecurityConfigurerAdapter { + private static SecurityConfigurer CONFIGURER; @Override public void init(TestConfiguredSecurityBuilder builder) throws Exception { builder.apply(CONFIGURER); } + } - private static class TestSecurityConfigurer extends SecurityConfigurerAdapter { } + private static class TestSecurityConfigurer + extends SecurityConfigurerAdapter { - private static class TestConfiguredSecurityBuilder extends AbstractConfiguredSecurityBuilder { + } + + private static final class TestConfiguredSecurityBuilder + extends AbstractConfiguredSecurityBuilder { private TestConfiguredSecurityBuilder(ObjectPostProcessor objectPostProcessor) { super(objectPostProcessor); } - private TestConfiguredSecurityBuilder(ObjectPostProcessor objectPostProcessor, boolean allowConfigurersOfSameType) { + private TestConfiguredSecurityBuilder(ObjectPostProcessor objectPostProcessor, + boolean allowConfigurersOfSameType) { super(objectPostProcessor, allowConfigurersOfSameType); } + @Override public Object performBuild() { return "success"; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryAnyMatcherTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryAnyMatcherTests.java index e0124a570f..98232d5a6d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryAnyMatcherTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryAnyMatcherTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; import org.junit.Test; @@ -30,83 +31,28 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon * * @author Ankur Pathak */ -public class AbstractRequestMatcherRegistryAnyMatcherTests{ - - @EnableWebSecurity - static class AntMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated() - .antMatchers("/demo/**").permitAll(); - - } - } +public class AbstractRequestMatcherRegistryAnyMatcherTests { @Test(expected = BeanCreationException.class) - public void antMatchersCanNotWorkAfterAnyRequest(){ + public void antMatchersCanNotWorkAfterAnyRequest() { loadConfig(AntMatchersAfterAnyRequestConfig.class); } - @EnableWebSecurity - static class MvcMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated() - .mvcMatchers("/demo/**").permitAll(); - - } - } - @Test(expected = BeanCreationException.class) public void mvcMatchersCanNotWorkAfterAnyRequest() { loadConfig(MvcMatchersAfterAnyRequestConfig.class); } - @EnableWebSecurity - static class RegexMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated() - .regexMatchers(".*").permitAll(); - - } - } - @Test(expected = BeanCreationException.class) public void regexMatchersCanNotWorkAfterAnyRequest() { loadConfig(RegexMatchersAfterAnyRequestConfig.class); } - @EnableWebSecurity - static class AnyRequestAfterItselfConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated() - .anyRequest().permitAll(); - - } - } - @Test(expected = BeanCreationException.class) public void anyRequestCanNotWorkAfterItself() { loadConfig(AnyRequestAfterItselfConfig.class); } - @EnableWebSecurity - static class RequestMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated() - .requestMatchers(new AntPathRequestMatcher("/**")).permitAll(); - - } - } - @Test(expected = BeanCreationException.class) public void requestMatchersCanNotWorkAfterAnyRequest() { loadConfig(RequestMatchersAfterAnyRequestConfig.class); @@ -119,4 +65,80 @@ public class AbstractRequestMatcherRegistryAnyMatcherTests{ context.setServletContext(new MockServletContext()); context.refresh(); } + + @EnableWebSecurity + static class AntMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .antMatchers("/demo/**").permitAll(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class MvcMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .mvcMatchers("/demo/**").permitAll(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class RegexMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .regexMatchers(".*").permitAll(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class AnyRequestAfterItselfConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .anyRequest().permitAll(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class RequestMatchersAfterAnyRequestConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .requestMatchers(new AntPathRequestMatcher("/**")).permitAll(); + // @formatter:on + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index e0070a80c9..ef2b793024 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; +import java.util.List; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpMethod; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -32,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class AbstractRequestMatcherRegistryTests { + private TestRequestMatcherRegistry matcherRegistry; @Before @@ -87,5 +90,7 @@ public class AbstractRequestMatcherRegistryTests { protected List chainRequestMatchers(List requestMatchers) { return requestMatchers; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java index 9a7a66ec64..5cb55181c3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java @@ -13,17 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.config.annotation.web; import javax.servlet.Filter; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; @@ -40,8 +38,11 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + /** - * * @author Rob Winch * */ @@ -49,8 +50,10 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @ContextConfiguration @WebAppConfiguration public class HttpSecurityHeadersTests { + @Autowired WebApplicationContext wac; + @Autowired Filter springSecurityFilterChain; @@ -58,44 +61,54 @@ public class HttpSecurityHeadersTests { @Before public void setup() { - mockMvc = MockMvcBuilders - .webAppContextSetup(wac) - .addFilters(springSecurityFilterChain) - .build(); + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).addFilters(this.springSecurityFilterChain).build(); } // gh-2953 // gh-3975 @Test public void headerWhenSpringMvcResourceThenCacheRelatedHeadersReset() throws Exception { - mockMvc.perform(get("/resources/file.js")) - .andExpect(status().isOk()) - .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "max-age=12345")) - .andExpect(header().doesNotExist(HttpHeaders.PRAGMA)) - .andExpect(header().doesNotExist(HttpHeaders.EXPIRES)); + // @formatter:off + this.mockMvc.perform(get("/resources/file.js")) + .andExpect(status().isOk()) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "max-age=12345")) + .andExpect(header().doesNotExist(HttpHeaders.PRAGMA)) + .andExpect(header().doesNotExist(HttpHeaders.EXPIRES)); + // @formatter:on } @Test public void headerWhenNotSpringResourceThenCacheRelatedHeadersSet() throws Exception { - mockMvc.perform(get("/notresource")) - .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) - .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) - .andExpect(header().string(HttpHeaders.EXPIRES, "0")); + // @formatter:off + this.mockMvc.perform(get("/notresource")) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) + .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) + .andExpect(header().string(HttpHeaders.EXPIRES, "0")); + // @formatter:on } @EnableWebSecurity static class WebSecurityConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { } + } @EnableWebMvc @Configuration static class WebMvcConfig implements WebMvcConfigurer { + @Override public void addResourceHandlers(ResourceHandlerRegistry registry) { - registry.addResourceHandler("/resources/**").addResourceLocations("classpath:/resources/").setCachePeriod(12345); + // @formatter:off + registry.addResourceHandler("/resources/**") + .addResourceLocations("classpath:/resources/") + .setCachePeriod(12345); + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/SampleWebSecurityConfigurerAdapterTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/SampleWebSecurityConfigurerAdapterTests.java index b8304f9e13..8250a819ad 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/SampleWebSecurityConfigurerAdapterTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/SampleWebSecurityConfigurerAdapterTests.java @@ -13,11 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; +import java.util.Base64; + +import javax.servlet.http.HttpServletResponse; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.core.annotation.Order; @@ -36,9 +42,6 @@ import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; -import javax.servlet.http.HttpServletResponse; -import java.util.Base64; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -48,6 +51,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class SampleWebSecurityConfigurerAdapterTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -55,7 +59,9 @@ public class SampleWebSecurityConfigurerAdapterTests { private FilterChainProxy springSecurityFilterChain; private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockFilterChain chain; @Before @@ -63,7 +69,6 @@ public class SampleWebSecurityConfigurerAdapterTests { this.request = new MockHttpServletRequest("GET", ""); this.response = new MockHttpServletResponse(); this.chain = new MockFilterChain(); - CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "CSRF-TOKEN-TEST"); new HttpSessionCsrfTokenRepository().saveToken(csrfToken, this.request, this.response); this.request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); @@ -72,138 +77,187 @@ public class SampleWebSecurityConfigurerAdapterTests { @Test public void helloWorldSampleWhenRequestSecureResourceThenRedirectToLogin() throws Exception { this.spring.register(HelloWorldWebSecurityConfigurerAdapter.class).autowire(); - this.request.addHeader("Accept", "text/html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getRedirectedUrl()).isEqualTo("http://localhost/login"); } @Test public void helloWorldSampleWhenRequestLoginWithoutCredentialsThenRedirectToLogin() throws Exception { this.spring.register(HelloWorldWebSecurityConfigurerAdapter.class).autowire(); - this.request.setServletPath("/login"); this.request.setMethod("POST"); this.request.addHeader("Accept", "text/html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getRedirectedUrl()).isEqualTo("/login?error"); } @Test public void helloWorldSampleWhenRequestLoginWithValidCredentialsThenRedirectToIndex() throws Exception { this.spring.register(HelloWorldWebSecurityConfigurerAdapter.class).autowire(); - this.request.setServletPath("/login"); this.request.setMethod("POST"); this.request.addHeader("Accept", "text/html"); this.request.addParameter("username", "user"); this.request.addParameter("password", "password"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getRedirectedUrl()).isEqualTo("/"); } - /** - * - * - * - * - * - * login-processing-url="/login" - * password-parameter="password" - * username-parameter="username" - * /> - * - * - * - * - * - * - * - * - * - * @author Rob Winch - */ - @EnableWebSecurity - public static class HelloWorldWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - } - - @Test public void readmeSampleWhenRequestSecureResourceThenRedirectToLogin() throws Exception { this.spring.register(SampleWebSecurityConfigurerAdapter.class).autowire(); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getRedirectedUrl()).isEqualTo("http://localhost/login"); } @Test public void readmeSampleWhenRequestLoginWithoutCredentialsThenRedirectToLogin() throws Exception { this.spring.register(SampleWebSecurityConfigurerAdapter.class).autowire(); - this.request.setServletPath("/login"); this.request.setMethod("POST"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getRedirectedUrl()).isEqualTo("/login?error"); } @Test public void readmeSampleWhenRequestLoginWithValidCredentialsThenRedirectToIndex() throws Exception { this.spring.register(SampleWebSecurityConfigurerAdapter.class).autowire(); - this.request.setServletPath("/login"); this.request.setMethod("POST"); this.request.addParameter("username", "user"); this.request.addParameter("password", "password"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getRedirectedUrl()).isEqualTo("/"); } + @Test + public void multiHttpSampleWhenRequestSecureResourceThenRedirectToLogin() throws Exception { + this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getRedirectedUrl()).isEqualTo("http://localhost/login"); + } + + @Test + public void multiHttpSampleWhenRequestLoginWithoutCredentialsThenRedirectToLogin() throws Exception { + this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); + this.request.setServletPath("/login"); + this.request.setMethod("POST"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getRedirectedUrl()).isEqualTo("/login?error"); + } + + @Test + public void multiHttpSampleWhenRequestLoginWithValidCredentialsThenRedirectToIndex() throws Exception { + this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); + this.request.setServletPath("/login"); + this.request.setMethod("POST"); + this.request.addParameter("username", "user"); + this.request.addParameter("password", "password"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getRedirectedUrl()).isEqualTo("/"); + } + + @Test + public void multiHttpSampleWhenRequestProtectedResourceThenStatusUnauthorized() throws Exception { + this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); + this.request.setServletPath("/api/admin/test"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void multiHttpSampleWhenRequestAdminResourceWithRegularUserThenStatusForbidden() throws Exception { + this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); + this.request.setServletPath("/api/admin/test"); + this.request.addHeader("Authorization", + "Basic " + Base64.getEncoder().encodeToString("user:password".getBytes())); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + } + + @Test + public void multiHttpSampleWhenRequestAdminResourceWithAdminUserThenStatusOk() throws Exception { + this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); + this.request.setServletPath("/api/admin/test"); + this.request.addHeader("Authorization", + "Basic " + Base64.getEncoder().encodeToString("admin:password".getBytes())); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + /** - * - * - * - * - * - * - * - * - * + * <http> + * <intercept-url pattern="/resources/**" access="permitAll"/> + * <intercept-url pattern="/**" access="authenticated"/> + * <logout * logout-success-url="/login?logout" * logout-url="/logout" - * + * login-page="/login" <!-- Except Spring Security renders the login page --> + * login-processing-url="/login" <!-- but only POST --> * password-parameter="password" * username-parameter="username" - * /> - * - * - * - * - * - * - * - * - * - * + * /> + * </http> + * <authentication-manager> + * <authentication-provider> + * <user-service> + * <user username="user" password="password" authorities="ROLE_USER"/> + * </user-service> + * </authentication-provider> + * </authentication-manager> + * + * + * @author Rob Winch + */ + @EnableWebSecurity + public static class HelloWorldWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + + /** + *
    +	 *   <http security="none" pattern="/resources/**"/>
    +	 *   <http>
    +	 *     <intercept-url pattern="/logout" access="permitAll"/>
    +	 *     <intercept-url pattern="/login" access="permitAll"/>
    +	 *     <intercept-url pattern="/signup" access="permitAll"/>
    +	 *     <intercept-url pattern="/about" access="permitAll"/>
    +	 *     <intercept-url pattern="/**" access="hasRole('ROLE_USER')"/>
    +	 *     <logout
    +	 *         logout-success-url="/login?logout"
    +	 *         logout-url="/logout"
    +	 *     <form-login
    +	 *         authentication-failure-url="/login?error"
    +	 *         login-page="/login"
    +	 *         login-processing-url="/login" <!-- but only POST -->
    +	 *         password-parameter="password"
    +	 *         username-parameter="username"
    +	 *     />
    +	 *   </http>
    +	 *   <authentication-manager>
    +	 *     <authentication-provider>
    +	 *       <user-service>
    +	 *         <user username="user" password="password" authorities="ROLE_USER"/>
    +	 *         <user username="admin" password="password" authorities=
    +	"ROLE_USER,ROLE_ADMIN"/>
    +	 *       </user-service>
    +	 *     </authentication-provider>
    +	 *   </authentication-manager>
    +	 * 
    + * * @author Rob Winch */ @EnableWebSecurity @@ -211,13 +265,12 @@ public class SampleWebSecurityConfigurerAdapterTests { @Override public void configure(WebSecurity web) { - web - .ignoring() - .antMatchers("/resources/**"); + web.ignoring().antMatchers("/resources/**"); } @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/signup", "/about").permitAll() @@ -227,133 +280,79 @@ public class SampleWebSecurityConfigurerAdapterTests { .loginPage("/login") // set permitAll for all URLs associated with Form Login .permitAll(); + // @formatter:on } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()) .withUser(PasswordEncodedUser.admin()); + // @formatter:on } - } - - @Test - public void multiHttpSampleWhenRequestSecureResourceThenRedirectToLogin() throws Exception { - this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("http://localhost/login"); - } - - @Test - public void multiHttpSampleWhenRequestLoginWithoutCredentialsThenRedirectToLogin() throws Exception { - this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); - - this.request.setServletPath("/login"); - this.request.setMethod("POST"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("/login?error"); - } - - @Test - public void multiHttpSampleWhenRequestLoginWithValidCredentialsThenRedirectToIndex() throws Exception { - this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); - - this.request.setServletPath("/login"); - this.request.setMethod("POST"); - this.request.addParameter("username", "user"); - this.request.addParameter("password", "password"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("/"); - } - - @Test - public void multiHttpSampleWhenRequestProtectedResourceThenStatusUnauthorized() throws Exception { - this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); - - this.request.setServletPath("/api/admin/test"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - } - - @Test - public void multiHttpSampleWhenRequestAdminResourceWithRegularUserThenStatusForbidden() throws Exception { - this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); - - this.request.setServletPath("/api/admin/test"); - this.request.addHeader("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".getBytes())); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); - } - - @Test - public void multiHttpSampleWhenRequestAdminResourceWithAdminUserThenStatusOk() throws Exception { - this.spring.register(SampleMultiHttpSecurityConfig.class).autowire(); - - this.request.setServletPath("/api/admin/test"); - this.request.addHeader("Authorization", "Basic " + Base64.getEncoder().encodeToString("admin:password".getBytes())); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } /** * - * - * - * - * - * - * - * - * - * - * - * - * - * + * login-processing-url="/login" <!-- but only POST --> * password-parameter="password" * username-parameter="username" - * /> - * - * - * - * - * - * - * - * - * + * /> + * </http> + * <authentication-manager> + * <authentication-provider> + * <user-service> + * <user username="user" password="password" authorities="ROLE_USER"/> + * <user username="admin" password="password" authorities= + "ROLE_USER,ROLE_ADMIN"/> + * </user-service> + * </authentication-provider> + * </authentication-manager> * + * * @author Rob Winch */ @EnableWebSecurity public static class SampleMultiHttpSecurityConfig { + @Autowired protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()) .withUser(PasswordEncodedUser.admin()); + // @formatter:on } @Configuration @Order(1) public static class ApiWebSecurityConfigurationAdapter extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .antMatcher("/api/**") .authorizeRequests() @@ -361,20 +360,22 @@ public class SampleWebSecurityConfigurerAdapterTests { .antMatchers("/api/**").hasRole("USER") .and() .httpBasic(); + // @formatter:on } + } @Configuration public static class FormLoginWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { + @Override public void configure(WebSecurity web) { - web - .ignoring() - .antMatchers("/resources/**"); + web.ignoring().antMatchers("/resources/**"); } @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/signup", "/about").permitAll() @@ -383,7 +384,11 @@ public class SampleWebSecurityConfigurerAdapterTests { .formLogin() .loginPage("/login") .permitAll(); + // @formatter:on } + } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterPowermockTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterPowermockTests.java index c35cdb39fe..cccc804823 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterPowermockTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterPowermockTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; import java.util.Arrays; @@ -22,6 +23,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -48,19 +50,18 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** - * * @author Rob Winch * */ @RunWith(PowerMockRunner.class) @PrepareForTest({ SpringFactoriesLoader.class, WebAsyncManager.class }) -@PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*", "javax.xml.transform.*" }) +@PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*", + "javax.xml.transform.*" }) public class WebSecurityConfigurerAdapterPowermockTests { + ConfigurableWebApplicationContext context; @Rule @@ -71,25 +72,39 @@ public class WebSecurityConfigurerAdapterPowermockTests { @After public void close() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @Test public void loadConfigWhenDefaultConfigurerAsSpringFactoryhenDefaultConfigurerApplied() { - spy(SpringFactoriesLoader.class); + PowerMockito.spy(SpringFactoriesLoader.class); DefaultConfigurer configurer = new DefaultConfigurer(); - when(SpringFactoriesLoader - .loadFactories(AbstractHttpConfigurer.class, getClass().getClassLoader())) - .thenReturn(Arrays.asList(configurer)); - + PowerMockito + .when(SpringFactoriesLoader.loadFactories(AbstractHttpConfigurer.class, getClass().getClassLoader())) + .thenReturn(Arrays.asList(configurer)); loadConfig(Config.class); - assertThat(configurer.init).isTrue(); assertThat(configurer.configure).isTrue(); } + @Test + public void loadConfigWhenDefaultConfigThenWebAsyncManagerIntegrationFilterAdded() throws Exception { + this.spring.register(WebAsyncPopulatedByDefaultConfig.class).autowire(); + WebAsyncManager webAsyncManager = mock(WebAsyncManager.class); + this.mockMvc.perform(get("/").requestAttr(WebAsyncUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, webAsyncManager)); + ArgumentCaptor callableProcessingInterceptorArgCaptor = ArgumentCaptor + .forClass(CallableProcessingInterceptor.class); + verify(webAsyncManager, atLeastOnce()).registerCallableInterceptor(any(), + callableProcessingInterceptorArgCaptor.capture()); + CallableProcessingInterceptor callableProcessingInterceptor = callableProcessingInterceptorArgCaptor + .getAllValues().stream() + .filter((e) -> SecurityContextCallableProcessingInterceptor.class.isAssignableFrom(e.getClass())) + .findFirst().orElse(null); + assertThat(callableProcessingInterceptor).isNotNull(); + } + private void loadConfig(Class... classes) { AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext(); context.setClassLoader(getClass().getClassLoader()); @@ -100,13 +115,17 @@ public class WebSecurityConfigurerAdapterPowermockTests { @EnableWebSecurity static class Config extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { } + } static class DefaultConfigurer extends AbstractHttpConfigurer { + boolean init; + boolean configure; @Override @@ -118,27 +137,7 @@ public class WebSecurityConfigurerAdapterPowermockTests { public void configure(HttpSecurity builder) { this.configure = true; } - } - @Test - public void loadConfigWhenDefaultConfigThenWebAsyncManagerIntegrationFilterAdded() throws Exception { - this.spring.register(WebAsyncPopulatedByDefaultConfig.class).autowire(); - - WebAsyncManager webAsyncManager = mock(WebAsyncManager.class); - - this.mockMvc.perform(get("/").requestAttr(WebAsyncUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, webAsyncManager)); - - ArgumentCaptor callableProcessingInterceptorArgCaptor = - ArgumentCaptor.forClass(CallableProcessingInterceptor.class); - verify(webAsyncManager, atLeastOnce()).registerCallableInterceptor(any(), callableProcessingInterceptorArgCaptor.capture()); - - CallableProcessingInterceptor callableProcessingInterceptor = - callableProcessingInterceptorArgCaptor.getAllValues().stream() - .filter(e -> SecurityContextCallableProcessingInterceptor.class.isAssignableFrom(e.getClass())) - .findFirst() - .orElse(null); - - assertThat(callableProcessingInterceptor).isNotNull(); } @EnableWebSecurity @@ -146,13 +145,17 @@ public class WebSecurityConfigurerAdapterPowermockTests { @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } @Override protected void configure(HttpSecurity http) { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.java index 8f784aa513..f23d730292 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web; import java.io.IOException; import java.util.ArrayList; import java.util.List; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -49,12 +51,13 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.accept.ContentNegotiationStrategy; import org.springframework.web.accept.HeaderContentNegotiationStrategy; import org.springframework.web.filter.OncePerRequestFilter; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.ThrowableAssert.catchThrowable; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -71,6 +74,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Joe Grandja */ public class WebSecurityConfigurerAdapterTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -80,15 +84,119 @@ public class WebSecurityConfigurerAdapterTests { @Test public void loadConfigWhenRequestSecureThenDefaultSecurityHeadersReturned() throws Exception { this.spring.register(HeadersArePopulatedByDefaultConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(get("/").secure(true)) - .andExpect(header().string("X-Content-Type-Options", "nosniff")) - .andExpect(header().string("X-Frame-Options", "DENY")) - .andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains")) - .andExpect(header().string("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")) - .andExpect(header().string("Pragma", "no-cache")) - .andExpect(header().string("Expires", "0")) - .andExpect(header().string("X-XSS-Protection", "1; mode=block")); + .andExpect(header().string("X-Content-Type-Options", "nosniff")) + .andExpect(header().string("X-Frame-Options", "DENY")) + .andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains")) + .andExpect(header().string("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")) + .andExpect(header().string("Pragma", "no-cache")).andExpect(header().string("Expires", "0")) + .andExpect(header().string("X-XSS-Protection", "1; mode=block")); + // @formatter:on + } + + @Test + public void loadConfigWhenRequestAuthenticateThenAuthenticationEventPublished() throws Exception { + this.spring.register(InMemoryAuthWithWebSecurityConfigurerAdapter.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(status().is3xxRedirection()); + assertThat(InMemoryAuthWithWebSecurityConfigurerAdapter.EVENTS).isNotEmpty(); + assertThat(InMemoryAuthWithWebSecurityConfigurerAdapter.EVENTS).hasSize(1); + } + + @Test + public void loadConfigWhenInMemoryConfigureProtectedThenPasswordUpgraded() throws Exception { + this.spring.register(InMemoryConfigureProtectedConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(status().is3xxRedirection()); + UserDetailsService uds = this.spring.getContext().getBean(UserDetailsService.class); + assertThat(uds.loadUserByUsername("user").getPassword()).startsWith("{bcrypt}"); + } + + @Test + public void loadConfigWhenInMemoryConfigureGlobalThenPasswordUpgraded() throws Exception { + this.spring.register(InMemoryConfigureGlobalConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(status().is3xxRedirection()); + UserDetailsService uds = this.spring.getContext().getBean(UserDetailsService.class); + assertThat(uds.loadUserByUsername("user").getPassword()).startsWith("{bcrypt}"); + } + + @Test + public void loadConfigWhenCustomContentNegotiationStrategyBeanThenOverridesDefault() { + OverrideContentNegotiationStrategySharedObjectConfig.CONTENT_NEGOTIATION_STRATEGY_BEAN = mock( + ContentNegotiationStrategy.class); + this.spring.register(OverrideContentNegotiationStrategySharedObjectConfig.class).autowire(); + OverrideContentNegotiationStrategySharedObjectConfig securityConfig = this.spring.getContext() + .getBean(OverrideContentNegotiationStrategySharedObjectConfig.class); + assertThat(securityConfig.contentNegotiationStrategySharedObject).isNotNull(); + assertThat(securityConfig.contentNegotiationStrategySharedObject) + .isSameAs(OverrideContentNegotiationStrategySharedObjectConfig.CONTENT_NEGOTIATION_STRATEGY_BEAN); + } + + @Test + public void loadConfigWhenDefaultContentNegotiationStrategyThenHeaderContentNegotiationStrategy() { + this.spring.register(ContentNegotiationStrategyDefaultSharedObjectConfig.class).autowire(); + ContentNegotiationStrategyDefaultSharedObjectConfig securityConfig = this.spring.getContext() + .getBean(ContentNegotiationStrategyDefaultSharedObjectConfig.class); + assertThat(securityConfig.contentNegotiationStrategySharedObject).isNotNull(); + assertThat(securityConfig.contentNegotiationStrategySharedObject) + .isInstanceOf(HeaderContentNegotiationStrategy.class); + } + + @Test + public void loadConfigWhenUserDetailsServiceHasCircularReferenceThenStillLoads() { + this.spring.register(RequiresUserDetailsServiceConfig.class, UserDetailsServiceConfig.class).autowire(); + MyFilter myFilter = this.spring.getContext().getBean(MyFilter.class); + myFilter.userDetailsService.loadUserByUsername("user"); + assertThatExceptionOfType(UsernameNotFoundException.class) + .isThrownBy(() -> myFilter.userDetailsService.loadUserByUsername("admin")); + } + + // SEC-2274: WebSecurityConfigurer adds ApplicationContext as a shared object + @Test + public void loadConfigWhenSharedObjectsCreatedThenApplicationContextAdded() { + this.spring.register(ApplicationContextSharedObjectConfig.class).autowire(); + ApplicationContextSharedObjectConfig securityConfig = this.spring.getContext() + .getBean(ApplicationContextSharedObjectConfig.class); + assertThat(securityConfig.applicationContextSharedObject).isNotNull(); + assertThat(securityConfig.applicationContextSharedObject).isSameAs(this.spring.getContext()); + } + + @Test + public void loadConfigWhenCustomAuthenticationTrustResolverBeanThenOverridesDefault() { + CustomTrustResolverConfig.AUTHENTICATION_TRUST_RESOLVER_BEAN = mock(AuthenticationTrustResolver.class); + this.spring.register(CustomTrustResolverConfig.class).autowire(); + CustomTrustResolverConfig securityConfig = this.spring.getContext().getBean(CustomTrustResolverConfig.class); + assertThat(securityConfig.authenticationTrustResolverSharedObject).isNotNull(); + assertThat(securityConfig.authenticationTrustResolverSharedObject) + .isSameAs(CustomTrustResolverConfig.AUTHENTICATION_TRUST_RESOLVER_BEAN); + } + + @Test + public void compareOrderWebSecurityConfigurerAdapterWhenLowestOrderToDefaultOrderThenGreaterThanZero() { + AnnotationAwareOrderComparator comparator = new AnnotationAwareOrderComparator(); + assertThat(comparator.compare(new LowestPriorityWebSecurityConfig(), new DefaultOrderWebSecurityConfig())) + .isGreaterThan(0); + } + + // gh-7515 + @Test + public void performWhenUsingAuthenticationEventPublisherBeanThenUses() throws Exception { + this.spring.register(CustomAuthenticationEventPublisherBean.class).autowire(); + AuthenticationEventPublisher authenticationEventPublisher = this.spring.getContext() + .getBean(AuthenticationEventPublisher.class); + this.mockMvc.perform(get("/").with(httpBasic("user", "password"))); + verify(authenticationEventPublisher).publishAuthenticationSuccess(any(Authentication.class)); + } + + // gh-4400 + @Test + public void performWhenUsingAuthenticationEventPublisherInDslThenUses() throws Exception { + this.spring.register(CustomAuthenticationEventPublisherDsl.class).autowire(); + AuthenticationEventPublisher authenticationEventPublisher = CustomAuthenticationEventPublisherDsl.EVENT_PUBLISHER; + MockHttpServletRequestBuilder userRequest = get("/").with(httpBasic("user", "password")); + // fails since no providers configured + this.mockMvc.perform(userRequest); + verify(authenticationEventPublisher).publishAuthenticationFailure(any(AuthenticationException.class), + any(Authentication.class)); } @EnableWebSecurity @@ -96,65 +204,51 @@ public class WebSecurityConfigurerAdapterTests { @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } @Override protected void configure(HttpSecurity http) { } - } - @Test - public void loadConfigWhenRequestAuthenticateThenAuthenticationEventPublished() throws Exception { - this.spring.register(InMemoryAuthWithWebSecurityConfigurerAdapter.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(status().is3xxRedirection()); - - assertThat(InMemoryAuthWithWebSecurityConfigurerAdapter.EVENTS).isNotEmpty(); - assertThat(InMemoryAuthWithWebSecurityConfigurerAdapter.EVENTS).hasSize(1); } @EnableWebSecurity static class InMemoryAuthWithWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter - implements ApplicationListener { + implements ApplicationListener { static List EVENTS = new ArrayList<>(); @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } @Override public void onApplicationEvent(AuthenticationSuccessEvent event) { EVENTS.add(event); } - } - @Test - public void loadConfigWhenInMemoryConfigureProtectedThenPasswordUpgraded() throws Exception { - this.spring.register(InMemoryConfigureProtectedConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(status().is3xxRedirection()); - - UserDetailsService uds = this.spring.getContext() - .getBean(UserDetailsService.class); - assertThat(uds.loadUserByUsername("user").getPassword()).startsWith("{bcrypt}"); } @EnableWebSecurity static class InMemoryConfigureProtectedConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } @Override @@ -162,27 +256,19 @@ public class WebSecurityConfigurerAdapterTests { public UserDetailsService userDetailsServiceBean() throws Exception { return super.userDetailsServiceBean(); } - } - @Test - public void loadConfigWhenInMemoryConfigureGlobalThenPasswordUpgraded() throws Exception { - this.spring.register(InMemoryConfigureGlobalConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(status().is3xxRedirection()); - - UserDetailsService uds = this.spring.getContext() - .getBean(UserDetailsService.class); - assertThat(uds.loadUserByUsername("user").getPassword()).startsWith("{bcrypt}"); } @EnableWebSecurity static class InMemoryConfigureGlobalConfig extends WebSecurityConfigurerAdapter { + @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } @Override @@ -190,28 +276,18 @@ public class WebSecurityConfigurerAdapterTests { public UserDetailsService userDetailsServiceBean() throws Exception { return super.userDetailsServiceBean(); } - } - @Test - public void loadConfigWhenCustomContentNegotiationStrategyBeanThenOverridesDefault() { - OverrideContentNegotiationStrategySharedObjectConfig.CONTENT_NEGOTIATION_STRATEGY_BEAN = mock(ContentNegotiationStrategy.class); - this.spring.register(OverrideContentNegotiationStrategySharedObjectConfig.class).autowire(); - - OverrideContentNegotiationStrategySharedObjectConfig securityConfig = - this.spring.getContext().getBean(OverrideContentNegotiationStrategySharedObjectConfig.class); - - assertThat(securityConfig.contentNegotiationStrategySharedObject).isNotNull(); - assertThat(securityConfig.contentNegotiationStrategySharedObject) - .isSameAs(OverrideContentNegotiationStrategySharedObjectConfig.CONTENT_NEGOTIATION_STRATEGY_BEAN); } @EnableWebSecurity static class OverrideContentNegotiationStrategySharedObjectConfig extends WebSecurityConfigurerAdapter { + static ContentNegotiationStrategy CONTENT_NEGOTIATION_STRATEGY_BEAN; + private ContentNegotiationStrategy contentNegotiationStrategySharedObject; @Bean - public ContentNegotiationStrategy contentNegotiationStrategy() { + ContentNegotiationStrategy contentNegotiationStrategy() { return CONTENT_NEGOTIATION_STRATEGY_BEAN; } @@ -220,21 +296,12 @@ public class WebSecurityConfigurerAdapterTests { this.contentNegotiationStrategySharedObject = http.getSharedObject(ContentNegotiationStrategy.class); super.configure(http); } - } - @Test - public void loadConfigWhenDefaultContentNegotiationStrategyThenHeaderContentNegotiationStrategy() { - this.spring.register(ContentNegotiationStrategyDefaultSharedObjectConfig.class).autowire(); - - ContentNegotiationStrategyDefaultSharedObjectConfig securityConfig = - this.spring.getContext().getBean(ContentNegotiationStrategyDefaultSharedObjectConfig.class); - - assertThat(securityConfig.contentNegotiationStrategySharedObject).isNotNull(); - assertThat(securityConfig.contentNegotiationStrategySharedObject).isInstanceOf(HeaderContentNegotiationStrategy.class); } @EnableWebSecurity static class ContentNegotiationStrategyDefaultSharedObjectConfig extends WebSecurityConfigurerAdapter { + private ContentNegotiationStrategy contentNegotiationStrategySharedObject; @Override @@ -242,31 +309,22 @@ public class WebSecurityConfigurerAdapterTests { this.contentNegotiationStrategySharedObject = http.getSharedObject(ContentNegotiationStrategy.class); super.configure(http); } - } - @Test - public void loadConfigWhenUserDetailsServiceHasCircularReferenceThenStillLoads() { - this.spring.register(RequiresUserDetailsServiceConfig.class, UserDetailsServiceConfig.class).autowire(); - - MyFilter myFilter = this.spring.getContext().getBean(MyFilter.class); - - Throwable thrown = catchThrowable(() -> myFilter.userDetailsService.loadUserByUsername("user") ); - assertThat(thrown).isNull(); - - thrown = catchThrowable(() -> myFilter.userDetailsService.loadUserByUsername("admin") ); - assertThat(thrown).isInstanceOf(UsernameNotFoundException.class); } @Configuration static class RequiresUserDetailsServiceConfig { + @Bean - public MyFilter myFilter(UserDetailsService userDetailsService) { + MyFilter myFilter(UserDetailsService userDetailsService) { return new MyFilter(userDetailsService); } + } @EnableWebSecurity static class UserDetailsServiceConfig extends WebSecurityConfigurerAdapter { + @Autowired private MyFilter myFilter; @@ -283,13 +341,17 @@ public class WebSecurityConfigurerAdapterTests { @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } + } static class MyFilter extends OncePerRequestFilter { + private UserDetailsService userDetailsService; MyFilter(UserDetailsService userDetailsService) { @@ -297,27 +359,16 @@ public class WebSecurityConfigurerAdapterTests { } @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, - FilterChain filterChain) throws ServletException, IOException { + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { filterChain.doFilter(request, response); } - } - // SEC-2274: WebSecurityConfigurer adds ApplicationContext as a shared object - @Test - public void loadConfigWhenSharedObjectsCreatedThenApplicationContextAdded() { - this.spring.register(ApplicationContextSharedObjectConfig.class).autowire(); - - ApplicationContextSharedObjectConfig securityConfig = - this.spring.getContext().getBean(ApplicationContextSharedObjectConfig.class); - - assertThat(securityConfig.applicationContextSharedObject).isNotNull(); - assertThat(securityConfig.applicationContextSharedObject).isSameAs(this.spring.getContext()); } @EnableWebSecurity static class ApplicationContextSharedObjectConfig extends WebSecurityConfigurerAdapter { + private ApplicationContext applicationContextSharedObject; @Override @@ -325,28 +376,18 @@ public class WebSecurityConfigurerAdapterTests { this.applicationContextSharedObject = http.getSharedObject(ApplicationContext.class); super.configure(http); } - } - @Test - public void loadConfigWhenCustomAuthenticationTrustResolverBeanThenOverridesDefault() { - CustomTrustResolverConfig.AUTHENTICATION_TRUST_RESOLVER_BEAN = mock(AuthenticationTrustResolver.class); - this.spring.register(CustomTrustResolverConfig.class).autowire(); - - CustomTrustResolverConfig securityConfig = - this.spring.getContext().getBean(CustomTrustResolverConfig.class); - - assertThat(securityConfig.authenticationTrustResolverSharedObject).isNotNull(); - assertThat(securityConfig.authenticationTrustResolverSharedObject) - .isSameAs(CustomTrustResolverConfig.AUTHENTICATION_TRUST_RESOLVER_BEAN); } @EnableWebSecurity static class CustomTrustResolverConfig extends WebSecurityConfigurerAdapter { + static AuthenticationTrustResolver AUTHENTICATION_TRUST_RESOLVER_BEAN; + private AuthenticationTrustResolver authenticationTrustResolverSharedObject; @Bean - public AuthenticationTrustResolver authenticationTrustResolver() { + AuthenticationTrustResolver authenticationTrustResolver() { return AUTHENTICATION_TRUST_RESOLVER_BEAN; } @@ -355,39 +396,21 @@ public class WebSecurityConfigurerAdapterTests { this.authenticationTrustResolverSharedObject = http.getSharedObject(AuthenticationTrustResolver.class); super.configure(http); } - } - @Test - public void compareOrderWebSecurityConfigurerAdapterWhenLowestOrderToDefaultOrderThenGreaterThanZero() { - AnnotationAwareOrderComparator comparator = new AnnotationAwareOrderComparator(); - assertThat(comparator.compare( - new LowestPriorityWebSecurityConfig(), - new DefaultOrderWebSecurityConfig())).isGreaterThan(0); } static class DefaultOrderWebSecurityConfig extends WebSecurityConfigurerAdapter { + } @Order static class LowestPriorityWebSecurityConfig extends WebSecurityConfigurerAdapter { - } - // gh-7515 - @Test - public void performWhenUsingAuthenticationEventPublisherBeanThenUses() throws Exception { - this.spring.register(CustomAuthenticationEventPublisherBean.class).autowire(); - - AuthenticationEventPublisher authenticationEventPublisher = - this.spring.getContext().getBean(AuthenticationEventPublisher.class); - - this.mockMvc.perform(get("/") - .with(httpBasic("user", "password"))); - - verify(authenticationEventPublisher).publishAuthenticationSuccess(any(Authentication.class)); } @EnableWebSecurity static class CustomAuthenticationEventPublisherBean extends WebSecurityConfigurerAdapter { + @Bean @Override public UserDetailsService userDetailsService() { @@ -395,34 +418,22 @@ public class WebSecurityConfigurerAdapterTests { } @Bean - public AuthenticationEventPublisher authenticationEventPublisher() { + AuthenticationEventPublisher authenticationEventPublisher() { return mock(AuthenticationEventPublisher.class); } - } - // gh-4400 - @Test - public void performWhenUsingAuthenticationEventPublisherInDslThenUses() throws Exception { - this.spring.register(CustomAuthenticationEventPublisherDsl.class).autowire(); - - AuthenticationEventPublisher authenticationEventPublisher = - CustomAuthenticationEventPublisherDsl.EVENT_PUBLISHER; - - this.mockMvc.perform(get("/") - .with(httpBasic("user", "password"))); // fails since no providers configured - - verify(authenticationEventPublisher).publishAuthenticationFailure( - any(AuthenticationException.class), - any(Authentication.class)); } @EnableWebSecurity static class CustomAuthenticationEventPublisherDsl extends WebSecurityConfigurerAdapter { + static AuthenticationEventPublisher EVENT_PUBLISHER = mock(AuthenticationEventPublisher.class); @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth.authenticationEventPublisher(EVENT_PUBLISHER); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java index dbeda6bdd2..5fa3a802fa 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpConfigurationTests.java @@ -13,10 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.builders; +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.cas.web.CasAuthenticationFilter; @@ -28,17 +39,10 @@ import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.test.web.servlet.MockMvc; import org.springframework.web.filter.OncePerRequestFilter; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.ThrowableAssert.catchThrowable; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -49,6 +53,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Joe Grandja */ public class HttpConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -57,35 +62,11 @@ public class HttpConfigurationTests { @Test public void configureWhenAddFilterUnregisteredThenThrowsBeanCreationException() { - Throwable thrown = catchThrowable(() -> this.spring.register(UnregisteredFilterConfig.class).autowire() ); - assertThat(thrown).isInstanceOf(BeanCreationException.class); - assertThat(thrown.getMessage()).contains("The Filter class " + UnregisteredFilter.class.getName() + - " does not have a registered order and cannot be added without a specified order." + - " Consider using addFilterBefore or addFilterAfter instead."); - } - - @EnableWebSecurity - static class UnregisteredFilterConfig extends WebSecurityConfigurerAdapter { - - protected void configure(HttpSecurity http) { - http - .addFilter(new UnregisteredFilter()); - } - - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - } - - static class UnregisteredFilter extends OncePerRequestFilter { - @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, - FilterChain filterChain) throws ServletException, IOException { - filterChain.doFilter(request, response); - } + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(UnregisteredFilterConfig.class).autowire()) + .withMessageContaining("The Filter class " + UnregisteredFilter.class.getName() + + " does not have a registered order and cannot be added without a specified order." + + " Consider using addFilterBefore or addFilterAfter instead."); } // https://github.com/spring-projects/spring-security-javaconfig/issues/104 @@ -93,37 +74,73 @@ public class HttpConfigurationTests { public void configureWhenAddFilterCasAuthenticationFilterThenFilterAdded() throws Exception { CasAuthenticationFilterConfig.CAS_AUTHENTICATION_FILTER = spy(new CasAuthenticationFilter()); this.spring.register(CasAuthenticationFilterConfig.class).autowire(); - this.mockMvc.perform(get("/")); - - verify(CasAuthenticationFilterConfig.CAS_AUTHENTICATION_FILTER).doFilter( - any(ServletRequest.class), any(ServletResponse.class), any(FilterChain.class)); - } - - @EnableWebSecurity - static class CasAuthenticationFilterConfig extends WebSecurityConfigurerAdapter { - static CasAuthenticationFilter CAS_AUTHENTICATION_FILTER; - - protected void configure(HttpSecurity http) { - http - .addFilter(CAS_AUTHENTICATION_FILTER); - } + verify(CasAuthenticationFilterConfig.CAS_AUTHENTICATION_FILTER).doFilter(any(ServletRequest.class), + any(ServletResponse.class), any(FilterChain.class)); } @Test public void configureWhenConfigIsRequestMatchersJavadocThenAuthorizationApplied() throws Exception { this.spring.register(RequestMatcherRegistryConfigs.class).autowire(); - this.mockMvc.perform(get("/oauth/a")).andExpect(status().isUnauthorized()); this.mockMvc.perform(get("/oauth/b")).andExpect(status().isUnauthorized()); this.mockMvc.perform(get("/api/a")).andExpect(status().isUnauthorized()); this.mockMvc.perform(get("/api/b")).andExpect(status().isUnauthorized()); } + @EnableWebSecurity + static class UnregisteredFilterConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) { + // @formatter:off + http + .addFilter(new UnregisteredFilter()); + // @formatter:on + } + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + + static class UnregisteredFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + filterChain.doFilter(request, response); + } + + } + + @EnableWebSecurity + static class CasAuthenticationFilterConfig extends WebSecurityConfigurerAdapter { + + static CasAuthenticationFilter CAS_AUTHENTICATION_FILTER; + + @Override + protected void configure(HttpSecurity http) { + // @formatter:off + http + .addFilter(CAS_AUTHENTICATION_FILTER); + // @formatter:on + } + + } + @EnableWebSecurity static class RequestMatcherRegistryConfigs extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .requestMatchers() .antMatchers("/api/**") @@ -133,6 +150,9 @@ public class HttpConfigurationTests { .antMatchers("/**").hasRole("USER") .and() .httpBasic(); + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/builders/NamespaceHttpTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/builders/NamespaceHttpTests.java index d219a50dbe..45478c3a56 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/builders/NamespaceHttpTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/builders/NamespaceHttpTests.java @@ -13,10 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.builders; +import javax.security.auth.Subject; +import javax.security.auth.login.LoginContext; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpSession; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.access.ConfigAttribute; @@ -49,71 +56,247 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.web.bind.annotation.GetMapping; -import javax.security.auth.Subject; -import javax.security.auth.login.LoginContext; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpSession; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrlPattern; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes are present in Java Config. + * Tests to verify that all the functionality of <http> attributes are present in + * Java Config. * * @author Rob Winch * @author Joe Grandja */ public class NamespaceHttpTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @Autowired private MockMvc mockMvc; - @Test // http@access-decision-manager-ref + @Test // http@access-decision-manager-ref public void configureWhenAccessDecisionManagerSetThenVerifyUse() throws Exception { AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER = mock(AccessDecisionManager.class); - when(AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER.supports(FilterInvocation.class)).thenReturn(true); - when(AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER.supports(any(ConfigAttribute.class))).thenReturn(true); - + given(AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER.supports(FilterInvocation.class)).willReturn(true); + given(AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER.supports(any(ConfigAttribute.class))) + .willReturn(true); this.spring.register(AccessDecisionManagerRefConfig.class).autowire(); - this.mockMvc.perform(get("/")); + verify(AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER, times(1)).decide(any(Authentication.class), + any(), anyCollection()); + } - verify(AccessDecisionManagerRefConfig.ACCESS_DECISION_MANAGER, times(1)).decide(any(Authentication.class), any(), anyCollection()); + @Test // http@access-denied-page + public void configureWhenAccessDeniedPageSetAndRequestForbiddenThenForwardedToAccessDeniedPage() throws Exception { + this.spring.register(AccessDeniedPageConfig.class).autowire(); + this.mockMvc.perform(get("/admin").with(user(PasswordEncodedUser.user()))).andExpect(status().isForbidden()) + .andExpect(forwardedUrl("/AccessDeniedPage")); + } + + @Test // http@authentication-manager-ref + public void configureWhenAuthenticationManagerProvidedThenVerifyUse() throws Exception { + AuthenticationManagerRefConfig.AUTHENTICATION_MANAGER = mock(AuthenticationManager.class); + this.spring.register(AuthenticationManagerRefConfig.class).autowire(); + this.mockMvc.perform(formLogin()); + verify(AuthenticationManagerRefConfig.AUTHENTICATION_MANAGER, times(1)).authenticate(any(Authentication.class)); + } + + @Test // http@create-session=always + public void configureWhenSessionCreationPolicyAlwaysThenSessionCreatedOnRequest() throws Exception { + this.spring.register(CreateSessionAlwaysConfig.class).autowire(); + MvcResult mvcResult = this.mockMvc.perform(get("/")).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNotNull(); + assertThat(session.isNew()).isTrue(); + } + + @Test // http@create-session=stateless + public void configureWhenSessionCreationPolicyStatelessThenSessionNotCreatedOnRequest() throws Exception { + this.spring.register(CreateSessionStatelessConfig.class).autowire(); + MvcResult mvcResult = this.mockMvc.perform(get("/")).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + } + + @Test // http@create-session=ifRequired + public void configureWhenSessionCreationPolicyIfRequiredThenSessionCreatedWhenRequiredOnRequest() throws Exception { + this.spring.register(IfRequiredConfig.class).autowire(); + MvcResult mvcResult = this.mockMvc.perform(get("/unsecure")).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + mvcResult = this.mockMvc.perform(formLogin()).andReturn(); + session = mvcResult.getRequest().getSession(false); + assertThat(session).isNotNull(); + assertThat(session.isNew()).isTrue(); + } + + @Test // http@create-session=never + public void configureWhenSessionCreationPolicyNeverThenSessionNotCreatedOnRequest() throws Exception { + this.spring.register(CreateSessionNeverConfig.class).autowire(); + MvcResult mvcResult = this.mockMvc.perform(get("/")).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + } + + @Test // http@entry-point-ref + public void configureWhenAuthenticationEntryPointSetAndRequestUnauthorizedThenRedirectedToAuthenticationEntryPoint() + throws Exception { + this.spring.register(EntryPointRefConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/")) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrlPattern("**/entry-point")); + // @formatter:on + } + + @Test // http@jaas-api-provision + public void configureWhenJaasApiIntegrationFilterAddedThenJaasSubjectObtained() throws Exception { + LoginContext loginContext = mock(LoginContext.class); + given(loginContext.getSubject()).willReturn(new Subject()); + JaasAuthenticationToken authenticationToken = mock(JaasAuthenticationToken.class); + given(authenticationToken.isAuthenticated()).willReturn(true); + given(authenticationToken.getLoginContext()).willReturn(loginContext); + this.spring.register(JaasApiProvisionConfig.class).autowire(); + this.mockMvc.perform(get("/").with(authentication(authenticationToken))); + verify(loginContext, times(1)).getSubject(); + } + + @Test // http@realm + public void configureWhenHttpBasicAndRequestUnauthorizedThenReturnWWWAuthenticateWithRealm() throws Exception { + this.spring.register(RealmConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/")) + .andExpect(status().isUnauthorized()) + .andExpect(header().string("WWW-Authenticate", "Basic realm=\"RealmConfig\"")); + // @formatter:on + } + + @Test // http@request-matcher-ref ant + public void configureWhenAntPatternMatchingThenAntPathRequestMatcherUsed() { + this.spring.register(RequestMatcherAntConfig.class).autowire(); + FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); + assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); + DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains() + .get(0); + assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(AntPathRequestMatcher.class); + } + + @Test // http@request-matcher-ref regex + public void configureWhenRegexPatternMatchingThenRegexRequestMatcherUsed() { + this.spring.register(RequestMatcherRegexConfig.class).autowire(); + FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); + assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); + DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains() + .get(0); + assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(RegexRequestMatcher.class); + } + + @Test // http@request-matcher-ref + public void configureWhenRequestMatcherProvidedThenRequestMatcherUsed() { + this.spring.register(RequestMatcherRefConfig.class).autowire(); + FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); + assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); + DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains() + .get(0); + assertThat(securityFilterChain.getRequestMatcher()) + .isInstanceOf(RequestMatcherRefConfig.MyRequestMatcher.class); + } + + @Test // http@security=none + public void configureWhenIgnoredAntPatternsThenAntPathRequestMatcherUsedWithNoFilters() { + this.spring.register(SecurityNoneConfig.class).autowire(); + FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); + assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); + DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains() + .get(0); + assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(AntPathRequestMatcher.class); + assertThat(((AntPathRequestMatcher) securityFilterChain.getRequestMatcher()).getPattern()) + .isEqualTo("/resources/**"); + assertThat(securityFilterChain.getFilters()).isEmpty(); + assertThat(filterChainProxy.getFilterChains().get(1)).isInstanceOf(DefaultSecurityFilterChain.class); + securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains().get(1); + assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(AntPathRequestMatcher.class); + assertThat(((AntPathRequestMatcher) securityFilterChain.getRequestMatcher()).getPattern()) + .isEqualTo("/public/**"); + assertThat(securityFilterChain.getFilters()).isEmpty(); + } + + @Test // http@security-context-repository-ref + public void configureWhenNullSecurityContextRepositoryThenSecurityContextNotSavedInSession() throws Exception { + this.spring.register(SecurityContextRepoConfig.class).autowire(); + MvcResult mvcResult = this.mockMvc.perform(formLogin()).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + } + + @Test // http@servlet-api-provision=false + public void configureWhenServletApiDisabledThenRequestNotServletApiWrapper() throws Exception { + this.spring.register(ServletApiProvisionConfig.class, MainController.class).autowire(); + this.mockMvc.perform(get("/")); + assertThat(MainController.HTTP_SERVLET_REQUEST_TYPE) + .isNotInstanceOf(SecurityContextHolderAwareRequestWrapper.class); + } + + @Test // http@servlet-api-provision defaults to true + public void configureWhenServletApiDefaultThenRequestIsServletApiWrapper() throws Exception { + this.spring.register(ServletApiProvisionDefaultsConfig.class, MainController.class).autowire(); + this.mockMvc.perform(get("/")); + assertThat(SecurityContextHolderAwareRequestWrapper.class) + .isAssignableFrom(MainController.HTTP_SERVLET_REQUEST_TYPE); + } + + @Test // http@use-expressions=true + public void configureWhenUseExpressionsEnabledThenExpressionBasedSecurityMetadataSource() { + this.spring.register(UseExpressionsConfig.class).autowire(); + UseExpressionsConfig config = this.spring.getContext().getBean(UseExpressionsConfig.class); + assertThat(ExpressionBasedFilterInvocationSecurityMetadataSource.class) + .isAssignableFrom(config.filterInvocationSecurityMetadataSourceType); + } + + @Test // http@use-expressions=false + public void configureWhenUseExpressionsDisabledThenDefaultSecurityMetadataSource() { + this.spring.register(DisableUseExpressionsConfig.class).autowire(); + DisableUseExpressionsConfig config = this.spring.getContext().getBean(DisableUseExpressionsConfig.class); + assertThat(DefaultFilterInvocationSecurityMetadataSource.class) + .isAssignableFrom(config.filterInvocationSecurityMetadataSourceType); } @EnableWebSecurity static class AccessDecisionManagerRefConfig extends WebSecurityConfigurerAdapter { + static AccessDecisionManager ACCESS_DECISION_MANAGER; @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().permitAll() .accessDecisionManager(ACCESS_DECISION_MANAGER); + // @formatter:on } - } - @Test // http@access-denied-page - public void configureWhenAccessDeniedPageSetAndRequestForbiddenThenForwardedToAccessDeniedPage() throws Exception { - this.spring.register(AccessDeniedPageConfig.class).autowire(); - - this.mockMvc.perform(get("/admin").with(user(PasswordEncodedUser.user()))) - .andExpect(status().isForbidden()) - .andExpect(forwardedUrl("/AccessDeniedPage")); } @EnableWebSecurity static class AccessDeniedPageConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin").hasRole("ADMIN") @@ -121,21 +304,14 @@ public class NamespaceHttpTests { .and() .exceptionHandling() .accessDeniedPage("/AccessDeniedPage"); + // @formatter:on } - } - @Test // http@authentication-manager-ref - public void configureWhenAuthenticationManagerProvidedThenVerifyUse() throws Exception { - AuthenticationManagerRefConfig.AUTHENTICATION_MANAGER = mock(AuthenticationManager.class); - this.spring.register(AuthenticationManagerRefConfig.class).autowire(); - - this.mockMvc.perform(formLogin()); - - verify(AuthenticationManagerRefConfig.AUTHENTICATION_MANAGER, times(1)).authenticate(any(Authentication.class)); } @EnableWebSecurity static class AuthenticationManagerRefConfig extends WebSecurityConfigurerAdapter { + static AuthenticationManager AUTHENTICATION_MANAGER; @Override @@ -145,81 +321,57 @@ public class NamespaceHttpTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() .and() .formLogin(); + // @formatter:on } - } - @Test // http@create-session=always - public void configureWhenSessionCreationPolicyAlwaysThenSessionCreatedOnRequest() throws Exception { - this.spring.register(CreateSessionAlwaysConfig.class).autowire(); - - MvcResult mvcResult = this.mockMvc.perform(get("/")).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNotNull(); - assertThat(session.isNew()).isTrue(); } @EnableWebSecurity static class CreateSessionAlwaysConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().permitAll() .and() .sessionManagement() .sessionCreationPolicy(SessionCreationPolicy.ALWAYS); + // @formatter:on } - } - @Test // http@create-session=stateless - public void configureWhenSessionCreationPolicyStatelessThenSessionNotCreatedOnRequest() throws Exception { - this.spring.register(CreateSessionStatelessConfig.class).autowire(); - - MvcResult mvcResult = this.mockMvc.perform(get("/")).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNull(); } @EnableWebSecurity static class CreateSessionStatelessConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().permitAll() .and() .sessionManagement() .sessionCreationPolicy(SessionCreationPolicy.STATELESS); + // @formatter:on } - } - @Test // http@create-session=ifRequired - public void configureWhenSessionCreationPolicyIfRequiredThenSessionCreatedWhenRequiredOnRequest() throws Exception { - this.spring.register(IfRequiredConfig.class).autowire(); - - MvcResult mvcResult = this.mockMvc.perform(get("/unsecure")).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNull(); - - mvcResult = this.mockMvc.perform(formLogin()).andReturn(); - session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNotNull(); - assertThat(session.isNew()).isTrue(); } @EnableWebSecurity static class IfRequiredConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/unsecure").permitAll() @@ -229,45 +381,34 @@ public class NamespaceHttpTests { .sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED) .and() .formLogin(); + // @formatter:on } - } - @Test // http@create-session=never - public void configureWhenSessionCreationPolicyNeverThenSessionNotCreatedOnRequest() throws Exception { - this.spring.register(CreateSessionNeverConfig.class).autowire(); - - MvcResult mvcResult = this.mockMvc.perform(get("/")).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNull(); } @EnableWebSecurity static class CreateSessionNeverConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().anonymous() .and() .sessionManagement() .sessionCreationPolicy(SessionCreationPolicy.NEVER); + // @formatter:on } - } - @Test // http@entry-point-ref - public void configureWhenAuthenticationEntryPointSetAndRequestUnauthorizedThenRedirectedToAuthenticationEntryPoint() throws Exception { - this.spring.register(EntryPointRefConfig.class).autowire(); - - this.mockMvc.perform(get("/")) - .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrlPattern("**/entry-point")); } @EnableWebSecurity static class EntryPointRefConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -276,139 +417,87 @@ public class NamespaceHttpTests { .authenticationEntryPoint(new LoginUrlAuthenticationEntryPoint("/entry-point")) .and() .formLogin(); + // @formatter:on } - } - @Test // http@jaas-api-provision - public void configureWhenJaasApiIntegrationFilterAddedThenJaasSubjectObtained() throws Exception { - LoginContext loginContext = mock(LoginContext.class); - when(loginContext.getSubject()).thenReturn(new Subject()); - - JaasAuthenticationToken authenticationToken = mock(JaasAuthenticationToken.class); - when(authenticationToken.isAuthenticated()).thenReturn(true); - when(authenticationToken.getLoginContext()).thenReturn(loginContext); - - this.spring.register(JaasApiProvisionConfig.class).autowire(); - - this.mockMvc.perform(get("/").with(authentication(authenticationToken))); - - verify(loginContext, times(1)).getSubject(); } @EnableWebSecurity static class JaasApiProvisionConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { + // @formatter:off http .addFilter(new JaasApiIntegrationFilter()); + // @formatter:on } - } - @Test // http@realm - public void configureWhenHttpBasicAndRequestUnauthorizedThenReturnWWWAuthenticateWithRealm() throws Exception { - this.spring.register(RealmConfig.class).autowire(); - - this.mockMvc.perform(get("/")) - .andExpect(status().isUnauthorized()) - .andExpect(header().string("WWW-Authenticate", "Basic realm=\"RealmConfig\"")); } @EnableWebSecurity static class RealmConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() .and() .httpBasic() .realmName("RealmConfig"); + // @formatter:on } - } - @Test // http@request-matcher-ref ant - public void configureWhenAntPatternMatchingThenAntPathRequestMatcherUsed() { - this.spring.register(RequestMatcherAntConfig.class).autowire(); - - FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); - - assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); - DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains().get(0); - assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(AntPathRequestMatcher.class); } @EnableWebSecurity static class RequestMatcherAntConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { + // @formatter:off http .antMatcher("/api/**"); + // @formatter:on } - } - @Test // http@request-matcher-ref regex - public void configureWhenRegexPatternMatchingThenRegexRequestMatcherUsed() { - this.spring.register(RequestMatcherRegexConfig.class).autowire(); - - FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); - - assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); - DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains().get(0); - assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(RegexRequestMatcher.class); } @EnableWebSecurity static class RequestMatcherRegexConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { + // @formatter:off http .regexMatcher("/regex/.*"); + // @formatter:on } - } - @Test // http@request-matcher-ref - public void configureWhenRequestMatcherProvidedThenRequestMatcherUsed() { - this.spring.register(RequestMatcherRefConfig.class).autowire(); - - FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); - - assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); - DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains().get(0); - assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(RequestMatcherRefConfig.MyRequestMatcher.class); } @EnableWebSecurity static class RequestMatcherRefConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { + // @formatter:off http .requestMatcher(new MyRequestMatcher()); + // @formatter:on } static class MyRequestMatcher implements RequestMatcher { + + @Override public boolean matches(HttpServletRequest request) { return true; } + } - } - @Test // http@security=none - public void configureWhenIgnoredAntPatternsThenAntPathRequestMatcherUsedWithNoFilters() { - this.spring.register(SecurityNoneConfig.class).autowire(); - - FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); - - assertThat(filterChainProxy.getFilterChains().get(0)).isInstanceOf(DefaultSecurityFilterChain.class); - DefaultSecurityFilterChain securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains().get(0); - assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(AntPathRequestMatcher.class); - assertThat(((AntPathRequestMatcher) securityFilterChain.getRequestMatcher()).getPattern()).isEqualTo("/resources/**"); - assertThat(securityFilterChain.getFilters()).isEmpty(); - - assertThat(filterChainProxy.getFilterChains().get(1)).isInstanceOf(DefaultSecurityFilterChain.class); - securityFilterChain = (DefaultSecurityFilterChain) filterChainProxy.getFilterChains().get(1); - assertThat(securityFilterChain.getRequestMatcher()).isInstanceOf(AntPathRequestMatcher.class); - assertThat(((AntPathRequestMatcher) securityFilterChain.getRequestMatcher()).getPattern()).isEqualTo("/public/**"); - assertThat(securityFilterChain.getFilters()).isEmpty(); } @EnableWebSecurity @@ -416,29 +505,21 @@ public class NamespaceHttpTests { @Override public void configure(WebSecurity web) { - web - .ignoring() - .antMatchers("/resources/**", "/public/**"); + web.ignoring().antMatchers("/resources/**", "/public/**"); } @Override protected void configure(HttpSecurity http) { } - } - @Test // http@security-context-repository-ref - public void configureWhenNullSecurityContextRepositoryThenSecurityContextNotSavedInSession() throws Exception { - this.spring.register(SecurityContextRepoConfig.class).autowire(); - - MvcResult mvcResult = this.mockMvc.perform(formLogin()).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - assertThat(session).isNull(); } @EnableWebSecurity static class SecurityContextRepoConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -447,89 +528,78 @@ public class NamespaceHttpTests { .securityContextRepository(new NullSecurityContextRepository()) .and() .formLogin(); + // @formatter:on } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } - } - @Test // http@servlet-api-provision=false - public void configureWhenServletApiDisabledThenRequestNotServletApiWrapper() throws Exception { - this.spring.register(ServletApiProvisionConfig.class, MainController.class).autowire(); - - this.mockMvc.perform(get("/")); - - assertThat(MainController.HTTP_SERVLET_REQUEST_TYPE).isNotInstanceOf(SecurityContextHolderAwareRequestWrapper.class); } @EnableWebSecurity static class ServletApiProvisionConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().permitAll() .and() .servletApi() .disable(); + // @formatter:on } - } - @Test // http@servlet-api-provision defaults to true - public void configureWhenServletApiDefaultThenRequestIsServletApiWrapper() throws Exception { - this.spring.register(ServletApiProvisionDefaultsConfig.class, MainController.class).autowire(); - - this.mockMvc.perform(get("/")); - - assertThat(SecurityContextHolderAwareRequestWrapper.class).isAssignableFrom(MainController.HTTP_SERVLET_REQUEST_TYPE); } @EnableWebSecurity static class ServletApiProvisionDefaultsConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().permitAll(); + // @formatter:on } + } @Controller static class MainController { + static Class HTTP_SERVLET_REQUEST_TYPE; @GetMapping("/") - public String index(HttpServletRequest request) { + String index(HttpServletRequest request) { HTTP_SERVLET_REQUEST_TYPE = request.getClass(); return "index"; } - } - @Test // http@use-expressions=true - public void configureWhenUseExpressionsEnabledThenExpressionBasedSecurityMetadataSource() { - this.spring.register(UseExpressionsConfig.class).autowire(); - - UseExpressionsConfig config = this.spring.getContext().getBean(UseExpressionsConfig.class); - - assertThat(ExpressionBasedFilterInvocationSecurityMetadataSource.class) - .isAssignableFrom(config.filterInvocationSecurityMetadataSourceType); } @EnableWebSecurity static class UseExpressionsConfig extends WebSecurityConfigurerAdapter { + private Class filterInvocationSecurityMetadataSourceType; @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/users**", "/sessions/**").hasRole("USER") .antMatchers("/signup").permitAll() .anyRequest().hasRole("USER"); + // @formatter:on } @Override @@ -538,33 +608,27 @@ public class NamespaceHttpTests { final HttpSecurity http = this.getHttp(); web.postBuildAction(() -> { FilterSecurityInterceptor securityInterceptor = http.getSharedObject(FilterSecurityInterceptor.class); - UseExpressionsConfig.this.filterInvocationSecurityMetadataSourceType = - securityInterceptor.getSecurityMetadataSource().getClass(); + UseExpressionsConfig.this.filterInvocationSecurityMetadataSourceType = securityInterceptor + .getSecurityMetadataSource().getClass(); }); } - } - @Test // http@use-expressions=false - public void configureWhenUseExpressionsDisabledThenDefaultSecurityMetadataSource() { - this.spring.register(DisableUseExpressionsConfig.class).autowire(); - - DisableUseExpressionsConfig config = this.spring.getContext().getBean(DisableUseExpressionsConfig.class); - - assertThat(DefaultFilterInvocationSecurityMetadataSource.class) - .isAssignableFrom(config.filterInvocationSecurityMetadataSourceType); } @EnableWebSecurity static class DisableUseExpressionsConfig extends WebSecurityConfigurerAdapter { + private Class filterInvocationSecurityMetadataSourceType; @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .apply(new UrlAuthorizationConfigurer<>(getApplicationContext())).getRegistry() .antMatchers("/users**", "/sessions/**").hasRole("USER") .antMatchers("/signup").hasRole("ANONYMOUS") .anyRequest().hasRole("USER"); + // @formatter:on } @Override @@ -573,9 +637,11 @@ public class NamespaceHttpTests { final HttpSecurity http = this.getHttp(); web.postBuildAction(() -> { FilterSecurityInterceptor securityInterceptor = http.getSharedObject(FilterSecurityInterceptor.class); - DisableUseExpressionsConfig.this.filterInvocationSecurityMetadataSourceType = - securityInterceptor.getSecurityMetadataSource().getClass(); + DisableUseExpressionsConfig.this.filterInvocationSecurityMetadataSourceType = securityInterceptor + .getSecurityMetadataSource().getClass(); }); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/builders/WebSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/builders/WebSecurityTests.java index a096e39ec6..8e58db3ec9 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/builders/WebSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/builders/WebSecurityTests.java @@ -45,10 +45,13 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Rob Winch */ public class WebSecurityTests { + AnnotationConfigWebApplicationContext context; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Autowired @@ -72,39 +75,60 @@ public class WebSecurityTests { @Test public void ignoringMvcMatcher() throws Exception { loadConfig(MvcMatcherConfig.class, LegacyMvcMatchingConfig.class); - this.request.setRequestURI("/path"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - setup(); - this.request.setRequestURI("/path.html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - setup(); - this.request.setRequestURI("/path/"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - setup(); - this.request.setRequestURI("/other"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + @Test + public void ignoringMvcMatcherServletPath() throws Exception { + loadConfig(MvcMatcherServletPathConfig.class, LegacyMvcMatchingConfig.class); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/other"); + this.request.setRequestURI("/other/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + public void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.setServletContext(new MockServletContext()); + this.context.refresh(); + this.context.getAutowireCapableBeanFactory().autowireBean(this); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(WebSecurity web) { // @formatter:off @@ -134,53 +158,21 @@ public class WebSecurityTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void ignoringMvcMatcherServletPath() throws Exception { - loadConfig(MvcMatcherServletPathConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherServletPathConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(WebSecurity web) { // @formatter:off @@ -211,28 +203,24 @@ public class WebSecurityTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } + } @Configuration static class LegacyMvcMatchingConfig implements WebMvcConfigurer { + @Override public void configurePathMatch(PathMatchConfigurer configurer) { configurer.setUseSuffixPatternMatch(true); } - } - public void loadConfig(Class... configs) { - this.context = new AnnotationConfigWebApplicationContext(); - this.context.register(configs); - this.context.setServletContext(new MockServletContext()); - this.context.refresh(); - - this.context.getAutowireCapableBeanFactory().autowireBean(this); } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/AuthenticationPrincipalArgumentResolverTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/AuthenticationPrincipalArgumentResolverTests.java index 9392335bb8..a03eb8650b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/AuthenticationPrincipalArgumentResolverTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/AuthenticationPrincipalArgumentResolverTests.java @@ -13,15 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configuration; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.config.annotation.web.configuration; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -41,8 +39,11 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + /** - * * @author Rob Winch * */ @@ -50,6 +51,7 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc; @ContextConfiguration @WebAppConfiguration public class AuthenticationPrincipalArgumentResolverTests { + @Autowired WebApplicationContext wac; @@ -62,32 +64,32 @@ public class AuthenticationPrincipalArgumentResolverTests { public void authenticationPrincipalExpressionWhenBeanExpressionSuppliedThenBeanUsed() throws Exception { User user = new User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContext context = SecurityContextHolder.createEmptyContext(); - context.setAuthentication(new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities())); + context.setAuthentication( + new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities())); SecurityContextHolder.setContext(context); - - MockMvc mockMvc = MockMvcBuilders - .webAppContextSetup(wac) - .build(); - + MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).build(); + // @formatter:off mockMvc.perform(get("/users/self")) - .andExpect(status().isOk()) - .andExpect(content().string("extracted-user")); + .andExpect(status().isOk()) + .andExpect(content().string("extracted-user")); + // @formatter:on } @EnableWebSecurity @EnableWebMvc static class Config { + @Autowired public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication(); + // @formatter:off } - @Bean public UsernameExtractor usernameExtractor() { return new UsernameExtractor(); } - @RestController static class UserController { @GetMapping("/users/self") @@ -96,7 +98,6 @@ public class AuthenticationPrincipalArgumentResolverTests { } } } - static class UsernameExtractor { public String extract(User u) { return "extracted-" + u.getUsername(); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurityTests.java index 180dcac62b..96b058bd50 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/EnableWebSecurityTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -45,6 +47,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Joe Grandja */ public class EnableWebSecurityTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -54,19 +57,58 @@ public class EnableWebSecurityTests { @Test public void configureWhenOverrideAuthenticationManagerBeanThenAuthenticationManagerBeanRegistered() { this.spring.register(SecurityConfig.class).autowire(); - AuthenticationManager authenticationManager = this.spring.getContext().getBean(AuthenticationManager.class); - Authentication authentication = authenticationManager.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); + Authentication authentication = authenticationManager + .authenticate(new UsernamePasswordAuthenticationToken("user", "password")); assertThat(authentication.isAuthenticated()).isTrue(); } + @Test + public void loadConfigWhenChildConfigExtendsSecurityConfigThenSecurityConfigInherited() { + this.spring.register(ChildSecurityConfig.class).autowire(); + this.spring.getContext().getBean("springSecurityFilterChain", DebugFilter.class); + } + + @Test + public void configureWhenEnableWebMvcThenAuthenticationPrincipalResolvable() throws Exception { + this.spring.register(AuthenticationPrincipalConfig.class).autowire(); + this.mockMvc.perform(get("/").with(authentication(new TestingAuthenticationToken("user1", "password")))) + .andExpect(content().string("user1")); + } + + @Test + public void securityFilterChainWhenEnableWebMvcThenAuthenticationPrincipalResolvable() throws Exception { + this.spring.register(SecurityFilterChainAuthenticationPrincipalConfig.class).autowire(); + this.mockMvc.perform(get("/").with(authentication(new TestingAuthenticationToken("user1", "password")))) + .andExpect(content().string("user1")); + } + + @Test + public void enableWebSecurityWhenNoConfigurationAnnotationThenBeanProxyingEnabled() { + this.spring.register(BeanProxyEnabledByDefaultConfig.class).autowire(); + Child childBean = this.spring.getContext().getBean(Child.class); + Parent parentBean = this.spring.getContext().getBean(Parent.class); + assertThat(parentBean.getChild()).isSameAs(childBean); + } + + @Test + public void enableWebSecurityWhenProxyBeanMethodsFalseThenBeanProxyingDisabled() { + this.spring.register(BeanProxyDisabledConfig.class).autowire(); + Child childBean = this.spring.getContext().getBean(Child.class); + Parent parentBean = this.spring.getContext().getBean(Parent.class); + assertThat(parentBean.getChild()).isNotSameAs(childBean); + } + @EnableWebSecurity static class SecurityConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } @Bean @@ -77,39 +119,31 @@ public class EnableWebSecurityTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/*").hasRole("USER") .and() .formLogin(); + // @formatter:on } - } - @Test - public void loadConfigWhenChildConfigExtendsSecurityConfigThenSecurityConfigInherited() { - this.spring.register(ChildSecurityConfig.class).autowire(); - this.spring.getContext().getBean("springSecurityFilterChain", DebugFilter.class); } @Configuration static class ChildSecurityConfig extends DebugSecurityConfig { + } - @EnableWebSecurity(debug=true) + @EnableWebSecurity(debug = true) static class DebugSecurityConfig extends WebSecurityConfigurerAdapter { - } - @Test - public void configureWhenEnableWebMvcThenAuthenticationPrincipalResolvable() throws Exception { - this.spring.register(AuthenticationPrincipalConfig.class).autowire(); - - this.mockMvc.perform(get("/").with(authentication(new TestingAuthenticationToken("user1", "password")))) - .andExpect(content().string("user1")); } @EnableWebSecurity @EnableWebMvc static class AuthenticationPrincipalConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { } @@ -121,20 +155,15 @@ public class EnableWebSecurityTests { String principal(@AuthenticationPrincipal String principal) { return principal; } + } - } - @Test - public void securityFilterChainWhenEnableWebMvcThenAuthenticationPrincipalResolvable() throws Exception { - this.spring.register(SecurityFilterChainAuthenticationPrincipalConfig.class).autowire(); - - this.mockMvc.perform(get("/").with(authentication(new TestingAuthenticationToken("user1", "password")))) - .andExpect(content().string("user1")); } @EnableWebSecurity @EnableWebMvc static class SecurityFilterChainAuthenticationPrincipalConfig { + @Bean SecurityFilterChain filterChain(HttpSecurity http) throws Exception { return http.build(); @@ -147,70 +176,61 @@ public class EnableWebSecurityTests { String principal(@AuthenticationPrincipal String principal) { return principal; } + } - } - @Test - public void enableWebSecurityWhenNoConfigurationAnnotationThenBeanProxyingEnabled() { - this.spring.register(BeanProxyEnabledByDefaultConfig.class).autowire(); - - Child childBean = this.spring.getContext().getBean(Child.class); - Parent parentBean = this.spring.getContext().getBean(Parent.class); - - assertThat(parentBean.getChild()).isSameAs(childBean); } @EnableWebSecurity static class BeanProxyEnabledByDefaultConfig extends WebSecurityConfigurerAdapter { + @Bean - public Child child() { + Child child() { return new Child(); } @Bean - public Parent parent() { + Parent parent() { return new Parent(child()); } - } - @Test - public void enableWebSecurityWhenProxyBeanMethodsFalseThenBeanProxyingDisabled() { - this.spring.register(BeanProxyDisabledConfig.class).autowire(); - - Child childBean = this.spring.getContext().getBean(Child.class); - Parent parentBean = this.spring.getContext().getBean(Parent.class); - - assertThat(parentBean.getChild()).isNotSameAs(childBean); } @Configuration(proxyBeanMethods = false) @EnableWebSecurity static class BeanProxyDisabledConfig extends WebSecurityConfigurerAdapter { + @Bean - public Child child() { + Child child() { return new Child(); } @Bean - public Parent parent() { + Parent parent() { return new Parent(child()); } + } static class Parent { + private Child child; Parent(Child child) { this.child = child; } - public Child getChild() { - return child; + Child getChild() { + return this.child; } + } static class Child { + Child() { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java index 5297578a65..56a0909c47 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java @@ -16,9 +16,14 @@ package org.springframework.security.config.annotation.web.configuration; +import java.util.concurrent.Callable; + +import javax.servlet.http.HttpServletRequest; + import com.google.common.net.HttpHeaders; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -36,12 +41,10 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import javax.servlet.http.HttpServletRequest; -import java.util.concurrent.Callable; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; @@ -62,6 +65,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Eleftheria Stein */ public class HttpSecurityConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -70,26 +74,27 @@ public class HttpSecurityConfigurationTests { @Test public void postWhenDefaultFilterChainBeanThenRespondsWithForbidden() throws Exception { - this.spring.register(DefaultWithFilterChainConfig.class) - .autowire(); + this.spring.register(DefaultWithFilterChainConfig.class).autowire(); - this.mockMvc.perform(post("/")) - .andExpect(status().isForbidden()); + this.mockMvc.perform(post("/")).andExpect(status().isForbidden()); } @Test public void getWhenDefaultFilterChainBeanThenDefaultHeadersInResponse() throws Exception { this.spring.register(DefaultWithFilterChainConfig.class).autowire(); - + // @formatter:off MvcResult mvcResult = this.mockMvc.perform(get("/").secure(true)) .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")) - .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsHeaderWriter.XFrameOptionsMode.DENY.name())) - .andExpect(header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) + .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, + XFrameOptionsHeaderWriter.XFrameOptionsMode.DENY.name())) + .andExpect( + header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) .andExpect(header().string(HttpHeaders.EXPIRES, "0")) .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")) .andReturn(); + // @formatter:on assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder( HttpHeaders.X_CONTENT_TYPE_OPTIONS, HttpHeaders.X_FRAME_OPTIONS, HttpHeaders.STRICT_TRANSPORT_SECURITY, HttpHeaders.CACHE_CONTROL, HttpHeaders.EXPIRES, HttpHeaders.PRAGMA, HttpHeaders.X_XSS_PROTECTION); @@ -98,151 +103,179 @@ public class HttpSecurityConfigurationTests { @Test public void logoutWhenDefaultFilterChainBeanThenCreatesDefaultLogoutEndpoint() throws Exception { this.spring.register(DefaultWithFilterChainConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(post("/logout").with(csrf())) .andExpect(redirectedUrl("/login?logout")); + // @formatter:on } @Test public void loadConfigWhenDefaultConfigThenWebAsyncManagerIntegrationFilterAdded() throws Exception { this.spring.register(DefaultWithFilterChainConfig.class, NameController.class).autowire(); - - MvcResult mvcResult = this.mockMvc.perform(get("/name").with(user("Bob"))) + // @formatter:off + MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob")); + MvcResult mvcResult = this.mockMvc.perform(requestWithBob) .andExpect(request().asyncStarted()) .andReturn(); - this.mockMvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().string("Bob")); - } - - @RestController - static class NameController { - @GetMapping("/name") - public Callable name() { - return () -> SecurityContextHolder.getContext().getAuthentication().getName(); - } - } - - @EnableWebSecurity - static class DefaultWithFilterChainConfig { - @Bean - public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { - return http.build(); - } + // @formatter:on } @Test public void getWhenDefaultFilterChainBeanThenAnonymousPermitted() throws Exception { this.spring.register(AuthorizeRequestsConfig.class, UserDetailsConfig.class, BaseController.class).autowire(); - + // @formatter:off this.mockMvc.perform(get("/")) .andExpect(status().isOk()); - } - - @EnableWebSecurity - static class AuthorizeRequestsConfig { - @Bean - public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { - return http - .authorizeRequests(authorize -> authorize - .anyRequest().permitAll() - ) - .build(); - } + // @formatter:on } @Test public void authenticateWhenDefaultFilterChainBeanThenSessionIdChanges() throws Exception { this.spring.register(SecurityEnabledConfig.class, UserDetailsConfig.class).autowire(); - MockHttpSession session = new MockHttpSession(); String sessionId = session.getId(); - - MvcResult result = - this.mockMvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session(session) - .with(csrf())) - .andReturn(); - + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .session(session) + .with(csrf()); + // @formatter:on + MvcResult result = this.mockMvc.perform(loginRequest).andReturn(); assertThat(result.getRequest().getSession(false).getId()).isNotEqualTo(sessionId); } @Test public void authenticateWhenDefaultFilterChainBeanThenRedirectsToSavedRequest() throws Exception { this.spring.register(SecurityEnabledConfig.class, UserDetailsConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mockMvc.perform(get("/messages")) - .andReturn().getRequest().getSession(); - - this.mockMvc.perform(post("/login") + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mockMvc.perform(get("/messages")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") .param("password", "password") .session(session) - .with(csrf())) + .with(csrf()); + // @formatter:on + // @formatter:off + this.mockMvc.perform(loginRequest) .andExpect(redirectedUrl("http://localhost/messages")); + // @formatter:on } @Test public void authenticateWhenDefaultFilterChainBeanThenRolePrefixIsSet() throws Exception { this.spring.register(SecurityEnabledConfig.class, UserDetailsConfig.class, UserController.class).autowire(); - - this.mockMvc.perform(get("/user") - .with(authentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")))) + TestingAuthenticationToken user = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + // @formatter:off + this.mockMvc + .perform(get("/user").with(authentication(user))) .andExpect(status().isOk()); + // @formatter:on } @Test public void loginWhenUsingDefaultsThenDefaultLoginPageGenerated() throws Exception { this.spring.register(SecurityEnabledConfig.class).autowire(); + this.mockMvc.perform(get("/login")).andExpect(status().isOk()); + } + + @RestController + static class NameController { + + @GetMapping("/name") + Callable name() { + return () -> SecurityContextHolder.getContext().getAuthentication().getName(); + } + + } + + @EnableWebSecurity + static class DefaultWithFilterChainConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + return http.build(); + } + + } + + @EnableWebSecurity + static class AuthorizeRequestsConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + return http + .authorizeRequests((authorize) -> authorize + .anyRequest().permitAll() + ) + .build(); + // @formatter:on + } - this.mockMvc.perform(get("/login")) - .andExpect(status().isOk()); } @EnableWebSecurity static class SecurityEnabledConfig { + @Bean - public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off return http - .authorizeRequests(authorize -> authorize - .anyRequest().authenticated() + .authorizeRequests((authorize) -> authorize + .anyRequest().authenticated() ) .formLogin(withDefaults()) .build(); + // @formatter:on } + } @Configuration static class UserDetailsConfig { + @Bean - public UserDetailsService userDetailsService() { + UserDetailsService userDetailsService() { + // @formatter:off UserDetails user = User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build(); + // @formatter:on return new InMemoryUserDetailsManager(user); } + } @RestController static class BaseController { + @GetMapping("/") - public void index() { + void index() { } + } @RestController static class UserController { + @GetMapping("/user") - public void user(HttpServletRequest request) { + void user(HttpServletRequest request) { if (!request.isUserInRole("USER")) { throw new AccessDeniedException("This resource is only available to users"); } } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index 48eb3fd45d..122ee95fea 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import javax.servlet.http.HttpServletRequest; + import org.junit.Rule; import org.junit.Test; + +import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; @@ -31,28 +36,26 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResp import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import javax.servlet.http.HttpServletRequest; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; @@ -64,6 +67,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Joe Grandja */ public class OAuth2ClientConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -75,68 +79,143 @@ public class OAuth2ClientConfigurationTests { String clientRegistrationId = "client1"; String principalName = "user1"; TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); - ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); - ClientRegistration clientRegistration = clientRegistration().registrationId(clientRegistrationId).build(); - when(clientRegistrationRepository.findByRegistrationId(eq(clientRegistrationId))).thenReturn(clientRegistration); - + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .registrationId(clientRegistrationId).build(); + given(clientRegistrationRepository.findByRegistrationId(eq(clientRegistrationId))) + .willReturn(clientRegistration); OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); - when(authorizedClient.getClientRegistration()).thenReturn(clientRegistration); - when(authorizedClientRepository.loadAuthorizedClient( - eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))) - .thenReturn(authorizedClient); - + given(authorizedClient.getClientRegistration()).willReturn(clientRegistration); + given(authorizedClientRepository.loadAuthorizedClient(eq(clientRegistrationId), eq(authentication), + any(HttpServletRequest.class))).willReturn(authorizedClient); OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); - when(authorizedClient.getAccessToken()).thenReturn(accessToken); - + given(authorizedClient.getAccessToken()).willReturn(accessToken); OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); - OAuth2AuthorizedClientArgumentResolverConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository; OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository; OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient; this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))) .andExpect(status().isOk()) .andExpect(content().string("resolved")); + // @formatter:on verifyZeroInteractions(accessTokenResponseClient); } @Test - public void requestWhenAuthorizedClientNotFoundAndClientCredentialsThenTokenResponseClientIsUsed() throws Exception { + public void requestWhenAuthorizedClientNotFoundAndClientCredentialsThenTokenResponseClientIsUsed() + throws Exception { String clientRegistrationId = "client1"; String principalName = "user1"; TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); - ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); - - ClientRegistration clientRegistration = clientCredentials().registrationId(clientRegistrationId).build(); - when(clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(clientRegistration); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() + .registrationId(clientRegistrationId).build(); + given(clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).willReturn(clientRegistration); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(300) .build(); - when(accessTokenResponseClient.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class))) - .thenReturn(accessTokenResponse); - + // @formatter:on + given(accessTokenResponseClient.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class))) + .willReturn(accessTokenResponse); OAuth2AuthorizedClientArgumentResolverConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository; OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository; OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient; this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire(); - - this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))) + MockHttpServletRequestBuilder authenticatedRequest = get("/authorized-client") + .with(authentication(authentication)); + // @formatter:off + this.mockMvc.perform(authenticatedRequest) .andExpect(status().isOk()) .andExpect(content().string("resolved")); + // @formatter:on verify(accessTokenResponseClient, times(1)).getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class)); } + // gh-5321 + @Test + public void loadContextWhenOAuth2AuthorizedClientRepositoryRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { + assertThatExceptionOfType(BeanCreationException.class).isThrownBy( + () -> this.spring.register(OAuth2AuthorizedClientRepositoryRegisteredTwiceConfig.class).autowire()) + .withRootCauseInstanceOf(NoUniqueBeanDefinitionException.class).withMessageContaining( + "Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() + + "' but found 2: authorizedClientRepository1,authorizedClientRepository2"); + } + + @Test + public void loadContextWhenClientRegistrationRepositoryNotRegisteredThenThrowNoSuchBeanDefinitionException() { + assertThatExceptionOfType(Exception.class) + .isThrownBy( + () -> this.spring.register(ClientRegistrationRepositoryNotRegisteredConfig.class).autowire()) + .withRootCauseInstanceOf(NoSuchBeanDefinitionException.class).withMessageContaining( + "No qualifying bean of type '" + ClientRegistrationRepository.class.getName() + "' available"); + } + + @Test + public void loadContextWhenClientRegistrationRepositoryRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { + // @formatter:off + assertThatExceptionOfType(Exception.class) + .isThrownBy( + () -> this.spring.register(ClientRegistrationRepositoryRegisteredTwiceConfig.class).autowire()) + .withMessageContaining( + "expected single matching bean but found 2: clientRegistrationRepository1,clientRegistrationRepository2") + .withRootCauseInstanceOf(NoUniqueBeanDefinitionException.class); + // @formatter:on + } + + @Test + public void loadContextWhenAccessTokenResponseClientRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { + // @formatter:off + assertThatExceptionOfType(Exception.class) + .isThrownBy(() -> this.spring.register(AccessTokenResponseClientRegisteredTwiceConfig.class).autowire()) + .withRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) + .withMessageContaining( + "expected single matching bean but found 2: accessTokenResponseClient1,accessTokenResponseClient2"); + // @formatter:on + } + + // gh-8700 + @Test + public void requestWhenAuthorizedClientManagerConfiguredThenUsed() throws Exception { + String clientRegistrationId = "client1"; + String principalName = "user1"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + OAuth2AuthorizedClientManager authorizedClientManager = mock(OAuth2AuthorizedClientManager.class); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .registrationId(clientRegistrationId).build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principalName, + TestOAuth2AccessTokens.noScopes()); + given(authorizedClientManager.authorize(any())).willReturn(authorizedClient); + OAuth2AuthorizedClientManagerRegisteredConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository; + OAuth2AuthorizedClientManagerRegisteredConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository; + OAuth2AuthorizedClientManagerRegisteredConfig.AUTHORIZED_CLIENT_MANAGER = authorizedClientManager; + this.spring.register(OAuth2AuthorizedClientManagerRegisteredConfig.class).autowire(); + MockHttpServletRequestBuilder authenticatedRequest = get("/authorized-client") + .with(authentication(authentication)); + // @formatter:off + this.mockMvc + .perform(authenticatedRequest) + .andExpect(status().isOk()) + .andExpect(content().string("resolved")); + // @formatter:on + verify(authorizedClientManager).authorize(any()); + verifyNoInteractions(clientRegistrationRepository); + verifyNoInteractions(authorizedClientRepository); + } + @EnableWebMvc @EnableWebSecurity static class OAuth2AuthorizedClientArgumentResolverConfig extends WebSecurityConfigurerAdapter { + static ClientRegistrationRepository CLIENT_REGISTRATION_REPOSITORY; static OAuth2AuthorizedClientRepository AUTHORIZED_CLIENT_REPOSITORY; static OAuth2AccessTokenResponseClient ACCESS_TOKEN_RESPONSE_CLIENT; @@ -145,38 +224,32 @@ public class OAuth2ClientConfigurationTests { protected void configure(HttpSecurity http) { } - @RestController - public class Controller { - - @GetMapping("/authorized-client") - public String authorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { - return authorizedClient != null ? "resolved" : "not-resolved"; - } - } - @Bean - public ClientRegistrationRepository clientRegistrationRepository() { + ClientRegistrationRepository clientRegistrationRepository() { return CLIENT_REGISTRATION_REPOSITORY; } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository() { + OAuth2AuthorizedClientRepository authorizedClientRepository() { return AUTHORIZED_CLIENT_REPOSITORY; } @Bean - public OAuth2AccessTokenResponseClient accessTokenResponseClient() { + OAuth2AccessTokenResponseClient accessTokenResponseClient() { return ACCESS_TOKEN_RESPONSE_CLIENT; } - } - // gh-5321 - @Test - public void loadContextWhenOAuth2AuthorizedClientRepositoryRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { - assertThatThrownBy(() -> this.spring.register(OAuth2AuthorizedClientRepositoryRegisteredTwiceConfig.class).autowire()) - .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) - .hasMessageContaining("Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() + - "' but found 2: authorizedClientRepository1,authorizedClientRepository2"); + @RestController + class Controller { + + @GetMapping("/authorized-client") + String authorizedClient( + @RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { + return (authorizedClient != null) ? "resolved" : "not-resolved"; + } + + } + } @EnableWebMvc @@ -195,31 +268,25 @@ public class OAuth2ClientConfigurationTests { } @Bean - public ClientRegistrationRepository clientRegistrationRepository() { + ClientRegistrationRepository clientRegistrationRepository() { return mock(ClientRegistrationRepository.class); } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository1() { + OAuth2AuthorizedClientRepository authorizedClientRepository1() { return mock(OAuth2AuthorizedClientRepository.class); } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository2() { + OAuth2AuthorizedClientRepository authorizedClientRepository2() { return mock(OAuth2AuthorizedClientRepository.class); } @Bean - public OAuth2AccessTokenResponseClient accessTokenResponseClient() { + OAuth2AccessTokenResponseClient accessTokenResponseClient() { return mock(OAuth2AccessTokenResponseClient.class); } - } - @Test - public void loadContextWhenClientRegistrationRepositoryNotRegisteredThenThrowNoSuchBeanDefinitionException() { - assertThatThrownBy(() -> this.spring.register(ClientRegistrationRepositoryNotRegisteredConfig.class).autowire()) - .hasRootCauseInstanceOf(NoSuchBeanDefinitionException.class) - .hasMessageContaining("No qualifying bean of type '" + ClientRegistrationRepository.class.getName() + "' available"); } @EnableWebMvc @@ -236,13 +303,7 @@ public class OAuth2ClientConfigurationTests { .oauth2Login(); // @formatter:on } - } - @Test - public void loadContextWhenClientRegistrationRepositoryRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { - assertThatThrownBy(() -> this.spring.register(ClientRegistrationRepositoryRegisteredTwiceConfig.class).autowire()) - .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) - .hasMessageContaining("expected single matching bean but found 2: clientRegistrationRepository1,clientRegistrationRepository2"); } @EnableWebMvc @@ -261,31 +322,25 @@ public class OAuth2ClientConfigurationTests { } @Bean - public ClientRegistrationRepository clientRegistrationRepository1() { + ClientRegistrationRepository clientRegistrationRepository1() { return mock(ClientRegistrationRepository.class); } @Bean - public ClientRegistrationRepository clientRegistrationRepository2() { + ClientRegistrationRepository clientRegistrationRepository2() { return mock(ClientRegistrationRepository.class); } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository() { + OAuth2AuthorizedClientRepository authorizedClientRepository() { return mock(OAuth2AuthorizedClientRepository.class); } @Bean - public OAuth2AccessTokenResponseClient accessTokenResponseClient() { + OAuth2AccessTokenResponseClient accessTokenResponseClient() { return mock(OAuth2AccessTokenResponseClient.class); } - } - @Test - public void loadContextWhenAccessTokenResponseClientRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() { - assertThatThrownBy(() -> this.spring.register(AccessTokenResponseClientRegisteredTwiceConfig.class).autowire()) - .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) - .hasMessageContaining("expected single matching bean but found 2: accessTokenResponseClient1,accessTokenResponseClient2"); } @EnableWebMvc @@ -304,60 +359,31 @@ public class OAuth2ClientConfigurationTests { } @Bean - public ClientRegistrationRepository clientRegistrationRepository() { + ClientRegistrationRepository clientRegistrationRepository() { return mock(ClientRegistrationRepository.class); } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository() { + OAuth2AuthorizedClientRepository authorizedClientRepository() { return mock(OAuth2AuthorizedClientRepository.class); } @Bean - public OAuth2AccessTokenResponseClient accessTokenResponseClient1() { + OAuth2AccessTokenResponseClient accessTokenResponseClient1() { return mock(OAuth2AccessTokenResponseClient.class); } @Bean - public OAuth2AccessTokenResponseClient accessTokenResponseClient2() { + OAuth2AccessTokenResponseClient accessTokenResponseClient2() { return mock(OAuth2AccessTokenResponseClient.class); } - } - // gh-8700 - @Test - public void requestWhenAuthorizedClientManagerConfiguredThenUsed() throws Exception { - String clientRegistrationId = "client1"; - String principalName = "user1"; - TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); - - ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); - OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - OAuth2AuthorizedClientManager authorizedClientManager = mock(OAuth2AuthorizedClientManager.class); - - ClientRegistration clientRegistration = clientRegistration().registrationId(clientRegistrationId).build(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, principalName, TestOAuth2AccessTokens.noScopes()); - - when(authorizedClientManager.authorize(any())).thenReturn(authorizedClient); - - OAuth2AuthorizedClientManagerRegisteredConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository; - OAuth2AuthorizedClientManagerRegisteredConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository; - OAuth2AuthorizedClientManagerRegisteredConfig.AUTHORIZED_CLIENT_MANAGER = authorizedClientManager; - this.spring.register(OAuth2AuthorizedClientManagerRegisteredConfig.class).autowire(); - - this.mockMvc.perform(get("/authorized-client").with(authentication(authentication))) - .andExpect(status().isOk()) - .andExpect(content().string("resolved")); - - verify(authorizedClientManager).authorize(any()); - verifyNoInteractions(clientRegistrationRepository); - verifyNoInteractions(authorizedClientRepository); } @EnableWebMvc @EnableWebSecurity static class OAuth2AuthorizedClientManagerRegisteredConfig extends WebSecurityConfigurerAdapter { + static ClientRegistrationRepository CLIENT_REGISTRATION_REPOSITORY; static OAuth2AuthorizedClientRepository AUTHORIZED_CLIENT_REPOSITORY; static OAuth2AuthorizedClientManager AUTHORIZED_CLIENT_MANAGER; @@ -366,28 +392,32 @@ public class OAuth2ClientConfigurationTests { protected void configure(HttpSecurity http) { } - @RestController - public class Controller { - - @GetMapping("/authorized-client") - public String authorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { - return authorizedClient != null ? "resolved" : "not-resolved"; - } - } - @Bean - public ClientRegistrationRepository clientRegistrationRepository() { + ClientRegistrationRepository clientRegistrationRepository() { return CLIENT_REGISTRATION_REPOSITORY; } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository() { + OAuth2AuthorizedClientRepository authorizedClientRepository() { return AUTHORIZED_CLIENT_REPOSITORY; } @Bean - public OAuth2AuthorizedClientManager authorizedClientManager() { + OAuth2AuthorizedClientManager authorizedClientManager() { return AUTHORIZED_CLIENT_MANAGER; } + + @RestController + class Controller { + + @GetMapping("/authorized-client") + String authorizedClient( + @RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { + return (authorizedClient != null) ? "resolved" : "not-resolved"; + } + + } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/Sec2515Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/Sec2515Tests.java index 845488117c..77ee64ea66 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/Sec2515Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/Sec2515Tests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import java.net.URL; +import java.net.URLClassLoader; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.FatalBeanException; import org.springframework.context.annotation.Bean; import org.springframework.security.authentication.AuthenticationManager; @@ -24,9 +29,6 @@ import org.springframework.security.config.annotation.authentication.builders.Au import org.springframework.security.config.test.SpringTestRule; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import java.net.URL; -import java.net.URLClassLoader; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; @@ -34,6 +36,7 @@ import static org.mockito.Mockito.mock; * @author Joe Grandja */ public class Sec2515Tests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -43,6 +46,28 @@ public class Sec2515Tests { this.spring.register(StackOverflowSecurityConfig.class).autowire(); } + @Test(expected = FatalBeanException.class) + public void loadConfigWhenAuthenticationManagerNotConfiguredAndRegisterBeanCustomNameThenThrowFatalBeanException() { + this.spring.register(CustomBeanNameStackOverflowSecurityConfig.class).autowire(); + } + + // SEC-2549 + @Test + public void loadConfigWhenChildClassLoaderSetThenContextLoads() { + CanLoadWithChildConfig.AUTHENTICATION_MANAGER = mock(AuthenticationManager.class); + this.spring.register(CanLoadWithChildConfig.class); + AnnotationConfigWebApplicationContext context = (AnnotationConfigWebApplicationContext) this.spring + .getContext(); + context.setClassLoader(new URLClassLoader(new URL[0], context.getClassLoader())); + this.spring.autowire(); + assertThat(this.spring.getContext().getBean(AuthenticationManager.class)).isNotNull(); + } // SEC-2515 + + @Test + public void loadConfigWhenAuthenticationManagerConfiguredAndRegisterBeanThenContextLoads() { + this.spring.register(SecurityConfig.class).autowire(); + } + @EnableWebSecurity static class StackOverflowSecurityConfig extends WebSecurityConfigurerAdapter { @@ -51,49 +76,31 @@ public class Sec2515Tests { public AuthenticationManager authenticationManagerBean() throws Exception { return super.authenticationManagerBean(); } - } - @Test(expected = FatalBeanException.class) - public void loadConfigWhenAuthenticationManagerNotConfiguredAndRegisterBeanCustomNameThenThrowFatalBeanException() { - this.spring.register(CustomBeanNameStackOverflowSecurityConfig.class).autowire(); } @EnableWebSecurity static class CustomBeanNameStackOverflowSecurityConfig extends WebSecurityConfigurerAdapter { @Override - @Bean(name="custom") + @Bean(name = "custom") public AuthenticationManager authenticationManagerBean() throws Exception { return super.authenticationManagerBean(); } - } - // SEC-2549 - @Test - public void loadConfigWhenChildClassLoaderSetThenContextLoads() { - CanLoadWithChildConfig.AUTHENTICATION_MANAGER = mock(AuthenticationManager.class); - this.spring.register(CanLoadWithChildConfig.class); - AnnotationConfigWebApplicationContext context = (AnnotationConfigWebApplicationContext) this.spring.getContext(); - context.setClassLoader(new URLClassLoader(new URL[0], context.getClassLoader())); - this.spring.autowire(); - - assertThat(this.spring.getContext().getBean(AuthenticationManager.class)).isNotNull(); } @EnableWebSecurity static class CanLoadWithChildConfig extends WebSecurityConfigurerAdapter { + static AuthenticationManager AUTHENTICATION_MANAGER; + @Override @Bean public AuthenticationManager authenticationManager() { return AUTHENTICATION_MANAGER; } - } - // SEC-2515 - @Test - public void loadConfigWhenAuthenticationManagerConfiguredAndRegisterBeanThenContextLoads() { - this.spring.register(SecurityConfig.class).autowire(); } @EnableWebSecurity @@ -109,5 +116,7 @@ public class Sec2515Tests { protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth.inMemoryAuthentication(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java index 23e84c88b9..0069e02c01 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; import javax.annotation.PreDestroy; @@ -31,24 +32,27 @@ import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; +import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications; import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.function.client.WebClient; -import static org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications.bearer; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests for applications of {@link SecurityReactorContextConfiguration} in resource servers. + * Tests for applications of {@link SecurityReactorContextConfiguration} in resource + * servers. * * @author Josh Cummings */ public class SecurityReactorContextConfigurationResourceServerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -58,45 +62,47 @@ public class SecurityReactorContextConfigurationResourceServerTests { // gh-7418 @Test public void requestWhenUsingFilterThenBearerTokenPropagated() throws Exception { - BearerTokenAuthentication authentication = bearer(); + BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer(); this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class).autowire(); - - this.mockMvc.perform(get("/token") - .with(authentication(authentication))) + MockHttpServletRequestBuilder authenticatedRequest = get("/token").with(authentication(authentication)); + // @formatter:off + this.mockMvc.perform(authenticatedRequest) .andExpect(status().isOk()) .andExpect(content().string("Bearer token")); + // @formatter:on } // gh-7418 @Test public void requestWhenNotUsingFilterThenBearerTokenNotPropagated() throws Exception { - BearerTokenAuthentication authentication = bearer(); + BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer(); this.spring.register(BearerFilterlessConfig.class, WebServerConfig.class, Controller.class).autowire(); - - this.mockMvc.perform(get("/token") - .with(authentication(authentication))) + MockHttpServletRequestBuilder authenticatedRequest = get("/token").with(authentication(authentication)); + // @formatter:off + this.mockMvc.perform(authenticatedRequest) .andExpect(status().isOk()) .andExpect(content().string("")); + // @formatter:on } - @EnableWebSecurity static class BearerFilterConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { } @Bean WebClient rest() { - ServletBearerExchangeFilterFunction bearer = - new ServletBearerExchangeFilterFunction(); - return WebClient.builder() - .filter(bearer).build(); + ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction(); + return WebClient.builder().filter(bearer).build(); } + } @EnableWebSecurity static class BearerFilterlessConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { } @@ -105,35 +111,43 @@ public class SecurityReactorContextConfigurationResourceServerTests { WebClient rest() { return WebClient.create(); } + } @RestController static class Controller { + private final WebClient rest; + private final String uri; @Autowired - Controller(MockWebServer server, WebClient rest) { + Controller(MockWebServer server, WebClient rest) { this.uri = server.url("/").toString(); this.rest = rest; } @GetMapping("/token") - public String token() { + String token() { + // @formatter:off return this.rest.get() .uri(this.uri) .retrieve() .bodyToMono(String.class) - .flatMap(result -> this.rest.get() + .flatMap((result) -> this.rest.get() .uri(this.uri) .retrieve() - .bodyToMono(String.class)) + .bodyToMono(String.class) + ) .block(); + // @formatter:on } + } @Configuration static class WebServerConfig { + private final MockWebServer server = new MockWebServer(); @Bean @@ -147,6 +161,7 @@ public class SecurityReactorContextConfigurationResourceServerTests { void shutdown() throws Exception { this.server.shutdown(); } + } static class AuthorizationHeaderDispatcher extends Dispatcher { @@ -157,9 +172,10 @@ public class SecurityReactorContextConfigurationResourceServerTests { String header = request.getHeader("Authorization"); if (StringUtils.isBlank(header)) { return response; - } return response.setBody(header); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java index 11d8317dde..f7cfa89a12 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java @@ -13,17 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.util.context.Context; + +import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -34,23 +51,9 @@ import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.BaseSubscriber; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Operators; -import reactor.test.StepVerifier; -import reactor.util.context.Context; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.net.URI; -import java.util.HashMap; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.entry; -import static org.springframework.http.HttpMethod.GET; -import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES; /** * Tests for {@link SecurityReactorContextConfiguration}. @@ -59,11 +62,14 @@ import static org.springframework.security.config.annotation.web.configuration.S * @since 5.2 */ public class SecurityReactorContextConfigurationTests { + private MockHttpServletRequest servletRequest; + private MockHttpServletResponse servletResponse; + private Authentication authentication; - private SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar subscriberRegistrar = - new SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar(); + + private SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar subscriberRegistrar = new SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar(); @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -83,26 +89,25 @@ public class SecurityReactorContextConfigurationTests { @Test public void createSubscriberIfNecessaryWhenSubscriberContextContainsSecurityContextAttributesThenReturnOriginalSubscriber() { - Context context = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>()); + Context context = Context.of(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>()); BaseSubscriber originalSubscriber = new BaseSubscriber() { @Override public Context currentContext() { return context; } }; - CoreSubscriber resultSubscriber = this.subscriberRegistrar.createSubscriberIfNecessary(originalSubscriber); + CoreSubscriber resultSubscriber = this.subscriberRegistrar + .createSubscriberIfNecessary(originalSubscriber); assertThat(resultSubscriber).isSameAs(originalSubscriber); } @Test public void createSubscriberIfNecessaryWhenWebSecurityContextAvailableThenCreateWithParentContext() { - RequestContextHolder.setRequestAttributes( - new ServletRequestAttributes(this.servletRequest, this.servletResponse)); + RequestContextHolder + .setRequestAttributes(new ServletRequestAttributes(this.servletRequest, this.servletResponse)); SecurityContextHolder.getContext().setAuthentication(this.authentication); - String testKey = "test_key"; String testValue = "test_value"; - BaseSubscriber parent = new BaseSubscriber() { @Override public Context currentContext() { @@ -110,25 +115,23 @@ public class SecurityReactorContextConfigurationTests { } }; CoreSubscriber subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent); - Context resultContext = subscriber.currentContext(); - assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue); - Map securityContextAttributes = resultContext.getOrDefault(SECURITY_CONTEXT_ATTRIBUTES, null); + Map securityContextAttributes = resultContext + .getOrDefault(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, null); assertThat(securityContextAttributes).hasSize(3); - assertThat(securityContextAttributes).contains( - entry(HttpServletRequest.class, this.servletRequest), + assertThat(securityContextAttributes).contains(entry(HttpServletRequest.class, this.servletRequest), entry(HttpServletResponse.class, this.servletResponse), entry(Authentication.class, this.authentication)); } @Test public void createSubscriberIfNecessaryWhenParentContextContainsSecurityContextAttributesThenUseParentContext() { - RequestContextHolder.setRequestAttributes( - new ServletRequestAttributes(this.servletRequest, this.servletResponse)); + RequestContextHolder + .setRequestAttributes(new ServletRequestAttributes(this.servletRequest, this.servletResponse)); SecurityContextHolder.getContext().setAuthentication(this.authentication); - - Context parentContext = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>()); + Context parentContext = Context.of(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, + new HashMap<>()); BaseSubscriber parent = new BaseSubscriber() { @Override public Context currentContext() { @@ -136,101 +139,97 @@ public class SecurityReactorContextConfigurationTests { } }; CoreSubscriber subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent); - Context resultContext = subscriber.currentContext(); assertThat(resultContext).isSameAs(parentContext); } @Test public void createSubscriberIfNecessaryWhenNotServletRequestAttributesThenStillCreate() { - RequestContextHolder.setRequestAttributes( - new RequestAttributes() { - @Override - public Object getAttribute(String name, int scope) { - return null; - } + RequestContextHolder.setRequestAttributes(new RequestAttributes() { + @Override + public Object getAttribute(String name, int scope) { + return null; + } - @Override - public void setAttribute(String name, Object value, int scope) { - } + @Override + public void setAttribute(String name, Object value, int scope) { + } - @Override - public void removeAttribute(String name, int scope) { - } + @Override + public void removeAttribute(String name, int scope) { + } - @Override - public String[] getAttributeNames(int scope) { - return new String[0]; - } + @Override + public String[] getAttributeNames(int scope) { + return new String[0]; + } - @Override - public void registerDestructionCallback(String name, Runnable callback, int scope) { - } + @Override + public void registerDestructionCallback(String name, Runnable callback, int scope) { + } - @Override - public Object resolveReference(String key) { - return null; - } + @Override + public Object resolveReference(String key) { + return null; + } - @Override - public String getSessionId() { - return null; - } + @Override + public String getSessionId() { + return null; + } - @Override - public Object getSessionMutex() { - return null; - } - }); - - CoreSubscriber subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(Operators.emptySubscriber()); + @Override + public Object getSessionMutex() { + return null; + } + }); + CoreSubscriber subscriber = this.subscriberRegistrar + .createSubscriberIfNecessary(Operators.emptySubscriber()); assertThat(subscriber).isInstanceOf(SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.class); } @Test public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() { - // Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector + // Trigger the importing of SecurityReactorContextConfiguration via + // OAuth2ImportSelector this.spring.register(SecurityConfig.class).autowire(); - // Setup for SecurityReactorContextSubscriberRegistrar - RequestContextHolder.setRequestAttributes( - new ServletRequestAttributes(this.servletRequest, this.servletResponse)); + RequestContextHolder + .setRequestAttributes(new ServletRequestAttributes(this.servletRequest, this.servletResponse)); SecurityContextHolder.getContext().setAuthentication(this.authentication); - ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build(); - - ExchangeFilterFunction filter = (req, next) -> - Mono.subscriberContext() - .filter(ctx -> ctx.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) - .map(ctx -> ctx.get(SECURITY_CONTEXT_ATTRIBUTES)) - .cast(Map.class) - .map(attributes -> { - if (attributes.containsKey(HttpServletRequest.class) && - attributes.containsKey(HttpServletResponse.class) && - attributes.containsKey(Authentication.class)) { - return clientResponseOk; - } else { - return ClientResponse.create(HttpStatus.NOT_FOUND).build(); - } - }); - - ClientRequest clientRequest = ClientRequest.create(GET, URI.create("https://example.com")).build(); + // @formatter:off + ExchangeFilterFunction filter = (req, next) -> Mono.subscriberContext() + .filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) + .map((ctx) -> ctx.get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) + .cast(Map.class) + .map((attributes) -> { + if (attributes.containsKey(HttpServletRequest.class) + && attributes.containsKey(HttpServletResponse.class) + && attributes.containsKey(Authentication.class)) { + return clientResponseOk; + } + else { + return ClientResponse.create(HttpStatus.NOT_FOUND).build(); + } + }); + // @formatter:on + ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); MockExchangeFunction exchange = new MockExchangeFunction(); - Map expectedContextAttributes = new HashMap<>(); expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest); expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse); expectedContextAttributes.put(Authentication.class, this.authentication); - Mono clientResponseMono = filter.filter(clientRequest, exchange) - .flatMap(response -> filter.filter(clientRequest, exchange)); - + .flatMap((response) -> filter.filter(clientRequest, exchange)); + // @formatter:off StepVerifier.create(clientResponseMono) .expectAccessibleContext() - .contains(SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes) + .contains(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes) .then() .expectNext(clientResponseOk) .verifyComplete(); + // @formatter:on } @EnableWebSecurity @@ -239,5 +238,7 @@ public class SecurityReactorContextConfigurationTests { @Override protected void configure(HttpSecurity http) throws Exception { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java index 9d714433e9..09d5c6d44b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java @@ -13,15 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configuration; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +package org.springframework.security.config.annotation.web.configuration; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -45,6 +44,10 @@ import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.view; + /** * @author Rob Winch */ @@ -62,10 +65,10 @@ public class WebMvcSecurityConfigurationTests { @Before public void setup() { - mockMvc = MockMvcBuilders.webAppContextSetup(context).build(); - authentication = new TestingAuthenticationToken("user", "password", + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.context).build(); + this.authentication = new TestingAuthenticationToken("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER")); - SecurityContextHolder.getContext().setAuthentication(authentication); + SecurityContextHolder.getContext().setAuthentication(this.authentication); } @After @@ -75,25 +78,23 @@ public class WebMvcSecurityConfigurationTests { @Test public void authenticationPrincipalResolved() throws Exception { - mockMvc.perform(get("/authentication-principal")) - .andExpect(assertResult(authentication.getPrincipal())) + this.mockMvc.perform(get("/authentication-principal")) + .andExpect(assertResult(this.authentication.getPrincipal())) .andExpect(view().name("authentication-principal-view")); } @Test public void deprecatedAuthenticationPrincipalResolved() throws Exception { - mockMvc.perform(get("/deprecated-authentication-principal")) - .andExpect(assertResult(authentication.getPrincipal())) + this.mockMvc.perform(get("/deprecated-authentication-principal")) + .andExpect(assertResult(this.authentication.getPrincipal())) .andExpect(view().name("deprecated-authentication-principal-view")); } @Test public void csrfToken() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "token"); - MockHttpServletRequestBuilder request = get("/csrf").requestAttr( - CsrfToken.class.getName(), csrfToken); - - mockMvc.perform(request).andExpect(assertResult(csrfToken)); + MockHttpServletRequestBuilder request = get("/csrf").requestAttr(CsrfToken.class.getName(), csrfToken); + this.mockMvc.perform(request).andExpect(assertResult(csrfToken)); } private ResultMatcher assertResult(Object expected) { @@ -104,32 +105,33 @@ public class WebMvcSecurityConfigurationTests { static class TestController { @RequestMapping("/authentication-principal") - public ModelAndView authenticationPrincipal( - @AuthenticationPrincipal String principal) { + ModelAndView authenticationPrincipal(@AuthenticationPrincipal String principal) { return new ModelAndView("authentication-principal-view", "result", principal); } @RequestMapping("/deprecated-authentication-principal") - public ModelAndView deprecatedAuthenticationPrincipal( + ModelAndView deprecatedAuthenticationPrincipal( @org.springframework.security.web.bind.annotation.AuthenticationPrincipal String principal) { - return new ModelAndView("deprecated-authentication-principal-view", "result", - principal); + return new ModelAndView("deprecated-authentication-principal-view", "result", principal); } @RequestMapping("/csrf") - public ModelAndView csrf(CsrfToken token) { + ModelAndView csrf(CsrfToken token) { return new ModelAndView("view", "result", token); } + } @Configuration @EnableWebMvc @EnableWebSecurity static class Config { + @Bean - public TestController testController() { + TestController testController() { return new TestController(); } + } -} \ No newline at end of file +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurationTests.java index 29dab5ac05..9f8527e3a4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebSecurityConfigurationTests.java @@ -13,10 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration; +import java.io.Serializable; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.List; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -53,15 +60,10 @@ import org.springframework.util.ClassUtils; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import java.io.Serializable; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.catchThrowable; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -74,293 +76,95 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Evgeniy Cheban */ public class WebSecurityConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); + @Rule + public SpringTestRule child = new SpringTestRule(); + @Autowired private MockMvc mockMvc; @Test public void loadConfigWhenWebSecurityConfigurersHaveOrderThenFilterChainsOrdered() { this.spring.register(SortedWebSecurityConfigurerAdaptersConfig.class).autowire(); - FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); List filterChains = filterChainProxy.getFilterChains(); assertThat(filterChains).hasSize(6); - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setServletPath("/ignore1"); assertThat(filterChains.get(0).matches(request)).isTrue(); assertThat(filterChains.get(0).getFilters()).isEmpty(); - request.setServletPath("/ignore2"); assertThat(filterChains.get(1).matches(request)).isTrue(); assertThat(filterChains.get(1).getFilters()).isEmpty(); - request.setServletPath("/role1/**"); assertThat(filterChains.get(2).matches(request)).isTrue(); - request.setServletPath("/role2/**"); assertThat(filterChains.get(3).matches(request)).isTrue(); - request.setServletPath("/role3/**"); assertThat(filterChains.get(4).matches(request)).isTrue(); - request.setServletPath("/**"); assertThat(filterChains.get(5).matches(request)).isTrue(); } - @EnableWebSecurity - @Import(AuthenticationTestConfiguration.class) - static class SortedWebSecurityConfigurerAdaptersConfig { - - @Configuration - @Order(1) - static class WebConfigurer1 extends WebSecurityConfigurerAdapter { - @Override - public void configure(WebSecurity web) { - web - .ignoring() - .antMatchers("/ignore1", "/ignore2"); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .antMatcher("/role1/**") - .authorizeRequests() - .anyRequest().hasRole("1"); - } - } - - @Configuration - @Order(2) - static class WebConfigurer2 extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .antMatcher("/role2/**") - .authorizeRequests() - .anyRequest().hasRole("2"); - } - } - - @Configuration - @Order(3) - static class WebConfigurer3 extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .antMatcher("/role3/**") - .authorizeRequests() - .anyRequest().hasRole("3"); - } - } - - @Configuration - static class WebConfigurer4 extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("4"); - } - } - } - @Test public void loadConfigWhenSecurityFilterChainsHaveOrderThenFilterChainsOrdered() { this.spring.register(SortedSecurityFilterChainConfig.class).autowire(); - FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); List filterChains = filterChainProxy.getFilterChains(); assertThat(filterChains).hasSize(4); - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setServletPath("/role1/**"); assertThat(filterChains.get(0).matches(request)).isTrue(); - request.setServletPath("/role2/**"); assertThat(filterChains.get(1).matches(request)).isTrue(); - request.setServletPath("/role3/**"); assertThat(filterChains.get(2).matches(request)).isTrue(); - request.setServletPath("/**"); assertThat(filterChains.get(3).matches(request)).isTrue(); } - @EnableWebSecurity - @Import(AuthenticationTestConfiguration.class) - static class SortedSecurityFilterChainConfig { - - @Order(1) - @Bean - SecurityFilterChain filterChain1(HttpSecurity http) throws Exception { - return http - .antMatcher("/role1/**") - .authorizeRequests(authorize -> authorize - .anyRequest().hasRole("1") - ) - .build(); - } - - @Order(2) - @Bean - SecurityFilterChain filterChain2(HttpSecurity http) throws Exception { - return http - .antMatcher("/role2/**") - .authorizeRequests(authorize -> authorize - .anyRequest().hasRole("2") - ) - .build(); - } - - @Order(3) - @Bean - SecurityFilterChain filterChain3(HttpSecurity http) throws Exception { - return http - .antMatcher("/role3/**") - .authorizeRequests(authorize -> authorize - .anyRequest().hasRole("3") - ) - .build(); - } - - @Bean - SecurityFilterChain filterChain4(HttpSecurity http) throws Exception { - return http - .authorizeRequests(authorize -> authorize - .anyRequest().hasRole("4") - ) - .build(); - } - } - @Test public void loadConfigWhenWebSecurityConfigurersHaveSameOrderThenThrowBeanCreationException() { - Throwable thrown = catchThrowable(() -> this.spring.register(DuplicateOrderConfig.class).autowire()); - - assertThat(thrown).isInstanceOf(BeanCreationException.class) - .hasMessageContaining("@Order on WebSecurityConfigurers must be unique") - .hasMessageContaining(DuplicateOrderConfig.WebConfigurer1.class.getName()) - .hasMessageContaining(DuplicateOrderConfig.WebConfigurer2.class.getName()); - } - - @EnableWebSecurity - @Import(AuthenticationTestConfiguration.class) - static class DuplicateOrderConfig { - - @Configuration - static class WebConfigurer1 extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .antMatcher("/role1/**") - .authorizeRequests() - .anyRequest().hasRole("1"); - } - } - - @Configuration - static class WebConfigurer2 extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .antMatcher("/role2/**") - .authorizeRequests() - .anyRequest().hasRole("2"); - } - } + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(DuplicateOrderConfig.class).autowire()) + .withMessageContaining("@Order on WebSecurityConfigurers must be unique") + .withMessageContaining(DuplicateOrderConfig.WebConfigurer1.class.getName()) + .withMessageContaining(DuplicateOrderConfig.WebConfigurer2.class.getName()); } @Test public void loadConfigWhenWebInvocationPrivilegeEvaluatorSetThenIsRegistered() { PrivilegeEvaluatorConfigurerAdapterConfig.PRIVILEGE_EVALUATOR = mock(WebInvocationPrivilegeEvaluator.class); - this.spring.register(PrivilegeEvaluatorConfigurerAdapterConfig.class).autowire(); - assertThat(this.spring.getContext().getBean(WebInvocationPrivilegeEvaluator.class)) - .isSameAs(PrivilegeEvaluatorConfigurerAdapterConfig.PRIVILEGE_EVALUATOR); - } - - @EnableWebSecurity - static class PrivilegeEvaluatorConfigurerAdapterConfig extends WebSecurityConfigurerAdapter { - static WebInvocationPrivilegeEvaluator PRIVILEGE_EVALUATOR; - - @Override - public void configure(WebSecurity web) { - web.privilegeEvaluator(PRIVILEGE_EVALUATOR); - } + .isSameAs(PrivilegeEvaluatorConfigurerAdapterConfig.PRIVILEGE_EVALUATOR); } @Test public void loadConfigWhenSecurityExpressionHandlerSetThenIsRegistered() { WebSecurityExpressionHandlerConfig.EXPRESSION_HANDLER = mock(SecurityExpressionHandler.class); - when(WebSecurityExpressionHandlerConfig.EXPRESSION_HANDLER.getExpressionParser()).thenReturn(mock(ExpressionParser.class)); - + given(WebSecurityExpressionHandlerConfig.EXPRESSION_HANDLER.getExpressionParser()) + .willReturn(mock(ExpressionParser.class)); this.spring.register(WebSecurityExpressionHandlerConfig.class).autowire(); - assertThat(this.spring.getContext().getBean(SecurityExpressionHandler.class)) - .isSameAs(WebSecurityExpressionHandlerConfig.EXPRESSION_HANDLER); - } - - @EnableWebSecurity - static class WebSecurityExpressionHandlerConfig extends WebSecurityConfigurerAdapter { - static SecurityExpressionHandler EXPRESSION_HANDLER; - - @Override - public void configure(WebSecurity web) { - web.expressionHandler(EXPRESSION_HANDLER); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated() - .expressionHandler(EXPRESSION_HANDLER); - } + .isSameAs(WebSecurityExpressionHandlerConfig.EXPRESSION_HANDLER); } @Test public void loadConfigWhenSecurityExpressionHandlerIsNullThenException() { - Throwable thrown = catchThrowable(() -> - this.spring.register(NullWebSecurityExpressionHandlerConfig.class).autowire() - ); - - assertThat(thrown).isInstanceOf(BeanCreationException.class); - assertThat(thrown).hasRootCauseExactlyInstanceOf(IllegalArgumentException.class); - } - - @EnableWebSecurity - static class NullWebSecurityExpressionHandlerConfig extends WebSecurityConfigurerAdapter { - - @Override - public void configure(WebSecurity web) { - web.expressionHandler(null); - } + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullWebSecurityExpressionHandlerConfig.class).autowire()) + .havingRootCause().isExactlyInstanceOf(IllegalArgumentException.class); } @Test public void loadConfigWhenDefaultSecurityExpressionHandlerThenDefaultIsRegistered() { this.spring.register(WebSecurityExpressionHandlerDefaultsConfig.class).autowire(); - assertThat(this.spring.getContext().getBean(SecurityExpressionHandler.class)) - .isInstanceOf(DefaultWebSecurityExpressionHandler.class); - } - - @EnableWebSecurity - static class WebSecurityExpressionHandlerDefaultsConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated(); - } + .isInstanceOf(DefaultWebSecurityExpressionHandler.class); } @Test @@ -369,77 +173,33 @@ public class WebSecurityConfigurationTests { TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "notused", "ROLE_ADMIN"); FilterInvocation invocation = new FilterInvocation(new MockHttpServletRequest("GET", ""), new MockHttpServletResponse(), new MockFilterChain()); - - AbstractSecurityExpressionHandler handler = this.spring.getContext().getBean(AbstractSecurityExpressionHandler.class); + AbstractSecurityExpressionHandler handler = this.spring.getContext() + .getBean(AbstractSecurityExpressionHandler.class); EvaluationContext evaluationContext = handler.createEvaluationContext(authentication, invocation); - Expression expression = handler.getExpressionParser() - .parseExpression("hasRole('ROLE_USER')"); + Expression expression = handler.getExpressionParser().parseExpression("hasRole('ROLE_USER')"); boolean granted = expression.getValue(evaluationContext, Boolean.class); assertThat(granted).isTrue(); } - @EnableWebSecurity - static class WebSecurityExpressionHandlerRoleHierarchyBeanConfig extends WebSecurityConfigurerAdapter { - @Bean - RoleHierarchy roleHierarchy() { - RoleHierarchyImpl roleHierarchy = new RoleHierarchyImpl(); - roleHierarchy.setHierarchy("ROLE_ADMIN > ROLE_USER"); - return roleHierarchy; - } - } - @Test public void securityExpressionHandlerWhenPermissionEvaluatorBeanThenPermissionEvaluatorUsed() { this.spring.register(WebSecurityExpressionHandlerPermissionEvaluatorBeanConfig.class).autowire(); TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "notused"); - FilterInvocation invocation = new FilterInvocation(new MockHttpServletRequest("GET", ""), new MockHttpServletResponse(), new MockFilterChain()); - - AbstractSecurityExpressionHandler handler = this.spring.getContext().getBean(AbstractSecurityExpressionHandler.class); + FilterInvocation invocation = new FilterInvocation(new MockHttpServletRequest("GET", ""), + new MockHttpServletResponse(), new MockFilterChain()); + AbstractSecurityExpressionHandler handler = this.spring.getContext() + .getBean(AbstractSecurityExpressionHandler.class); EvaluationContext evaluationContext = handler.createEvaluationContext(authentication, invocation); - Expression expression = handler.getExpressionParser() - .parseExpression("hasPermission(#study,'DELETE')"); + Expression expression = handler.getExpressionParser().parseExpression("hasPermission(#study,'DELETE')"); boolean granted = expression.getValue(evaluationContext, Boolean.class); assertThat(granted).isTrue(); } - @EnableWebSecurity - static class WebSecurityExpressionHandlerPermissionEvaluatorBeanConfig extends WebSecurityConfigurerAdapter { - static final PermissionEvaluator PERMIT_ALL_PERMISSION_EVALUATOR = new PermissionEvaluator() { - @Override - public boolean hasPermission(Authentication authentication, - Object targetDomainObject, Object permission) { - return true; - } - - @Override - public boolean hasPermission(Authentication authentication, - Serializable targetId, String targetType, Object permission) { - return true; - } - }; - - @Bean - public PermissionEvaluator permissionEvaluator() { - return PERMIT_ALL_PERMISSION_EVALUATOR; - } - } - @Test public void loadConfigWhenDefaultWebInvocationPrivilegeEvaluatorThenDefaultIsRegistered() { this.spring.register(WebInvocationPrivilegeEvaluatorDefaultsConfig.class).autowire(); - assertThat(this.spring.getContext().getBean(WebInvocationPrivilegeEvaluator.class)) - .isInstanceOf(DefaultWebInvocationPrivilegeEvaluator.class); - } - - @EnableWebSecurity - static class WebInvocationPrivilegeEvaluatorDefaultsConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().authenticated(); - } + .isInstanceOf(DefaultWebInvocationPrivilegeEvaluator.class); } @Test @@ -447,94 +207,30 @@ public class WebSecurityConfigurationTests { this.spring.register(AuthorizeRequestsFilterChainConfig.class).autowire(); assertThat(this.spring.getContext().getBean(WebInvocationPrivilegeEvaluator.class)) - .isInstanceOf(DefaultWebInvocationPrivilegeEvaluator.class); - } - - @EnableWebSecurity - static class AuthorizeRequestsFilterChainConfig { - @Bean - public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { - return http - .authorizeRequests(authorize -> authorize - .anyRequest().authenticated() - ) - .build(); - } + .isInstanceOf(DefaultWebInvocationPrivilegeEvaluator.class); } // SEC-2303 @Test public void loadConfigWhenDefaultSecurityExpressionHandlerThenBeanResolverSet() throws Exception { this.spring.register(DefaultExpressionHandlerSetsBeanResolverConfig.class).autowire(); - this.mockMvc.perform(get("/")).andExpect(status().isOk()); this.mockMvc.perform(post("/")).andExpect(status().isForbidden()); } - @EnableWebSecurity - static class DefaultExpressionHandlerSetsBeanResolverConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().access("request.method == 'GET' ? @b.grant() : @b.deny()"); - } - - @RestController - public class HomeController { - @GetMapping("/") - public String home() { - return "home"; - } - } - - @Bean - public MyBean b() { - return new MyBean(); - } - - static class MyBean { - public boolean deny() { - return false; - } - - public boolean grant() { - return true; - } - } - } - - @Rule - public SpringTestRule child = new SpringTestRule(); - // SEC-2461 @Test public void loadConfigWhenMultipleWebSecurityConfigurationThenContextLoads() { this.spring.register(ParentConfig.class).autowire(); - this.child.register(ChildConfig.class); this.child.getContext().setParent(this.spring.getContext()); this.child.autowire(); - assertThat(this.spring.getContext().getBean("springSecurityFilterChain")).isNotNull(); assertThat(this.child.getContext().getBean("springSecurityFilterChain")).isNotNull(); - assertThat(this.spring.getContext().containsBean("springSecurityFilterChain")).isTrue(); assertThat(this.child.getContext().containsBean("springSecurityFilterChain")).isTrue(); } - @EnableWebSecurity - static class ParentConfig extends WebSecurityConfigurerAdapter { - @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { - auth.inMemoryAuthentication(); - } - } - - @EnableWebSecurity - static class ChildConfig extends WebSecurityConfigurerAdapter { - } - // SEC-2773 @Test public void getMethodDelegatingApplicationListenerWhenWebSecurityConfigurationThenIsStatic() { @@ -544,38 +240,168 @@ public class WebSecurityConfigurationTests { @Test public void loadConfigWhenBeanProxyingEnabledAndSubclassThenFilterChainsCreated() { - this.spring.register(GlobalAuthenticationWebSecurityConfigurerAdaptersConfig.class, SubclassConfig.class).autowire(); - + this.spring.register(GlobalAuthenticationWebSecurityConfigurerAdaptersConfig.class, SubclassConfig.class) + .autowire(); FilterChainProxy filterChainProxy = this.spring.getContext().getBean(FilterChainProxy.class); List filterChains = filterChainProxy.getFilterChains(); - assertThat(filterChains).hasSize(4); } - @Configuration - static class SubclassConfig extends WebSecurityConfiguration { + @Test + public void loadConfigWhenBothAdapterAndFilterChainConfiguredThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(AdapterAndFilterChainConfig.class).autowire()) + .withRootCauseExactlyInstanceOf(IllegalStateException.class) + .withMessageContaining("Found WebSecurityConfigurerAdapter as well as SecurityFilterChain."); + } + @EnableWebSecurity @Import(AuthenticationTestConfiguration.class) - @EnableGlobalAuthentication - static class GlobalAuthenticationWebSecurityConfigurerAdaptersConfig { + static class SortedWebSecurityConfigurerAdaptersConfig { + @Configuration @Order(1) static class WebConfigurer1 extends WebSecurityConfigurerAdapter { + @Override public void configure(WebSecurity web) { - web - .ignoring() - .antMatchers("/ignore1", "/ignore2"); + web.ignoring().antMatchers("/ignore1", "/ignore2"); } @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .antMatcher("/anonymous/**") - .authorizeRequests() - .anyRequest().anonymous(); + .antMatcher("/role1/**") + .authorizeRequests() + .anyRequest().hasRole("1"); + // @formatter:on } + + } + + @Configuration + @Order(2) + static class WebConfigurer2 extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .antMatcher("/role2/**") + .authorizeRequests() + .anyRequest().hasRole("2"); + // @formatter:on + } + + } + + @Configuration + @Order(3) + static class WebConfigurer3 extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .antMatcher("/role3/**") + .authorizeRequests() + .anyRequest().hasRole("3"); + // @formatter:on + } + + } + + @Configuration + static class WebConfigurer4 extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("4"); + // @formatter:on + } + + } + + } + + @EnableWebSecurity + @Import(AuthenticationTestConfiguration.class) + static class SortedSecurityFilterChainConfig { + + @Order(1) + @Bean + SecurityFilterChain filterChain1(HttpSecurity http) throws Exception { + // @formatter:off + return http + .antMatcher("/role1/**") + .authorizeRequests((authorize) -> authorize + .anyRequest().hasRole("1") + ) + .build(); + // @formatter:on + } + + @Order(2) + @Bean + SecurityFilterChain filterChain2(HttpSecurity http) throws Exception { + // @formatter:off + return http + .antMatcher("/role2/**") + .authorizeRequests((authorize) -> authorize + .anyRequest().hasRole("2") + ) + .build(); + // @formatter:on + } + + @Order(3) + @Bean + SecurityFilterChain filterChain3(HttpSecurity http) throws Exception { + // @formatter:off + return http + .antMatcher("/role3/**") + .authorizeRequests((authorize) -> authorize + .anyRequest().hasRole("3") + ) + .build(); + // @formatter:on + } + + @Bean + SecurityFilterChain filterChain4(HttpSecurity http) throws Exception { + // @formatter:off + return http + .authorizeRequests((authorize) -> authorize + .anyRequest().hasRole("4") + ) + .build(); + // @formatter:on + } + + } + + @EnableWebSecurity + @Import(AuthenticationTestConfiguration.class) + static class DuplicateOrderConfig { + + @Configuration + static class WebConfigurer1 extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .antMatcher("/role1/**") + .authorizeRequests() + .anyRequest().hasRole("1"); + // @formatter:on + } + } @Configuration @@ -583,48 +409,277 @@ public class WebSecurityConfigurationTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .antMatcher("/role2/**") + .authorizeRequests() + .anyRequest().hasRole("2"); + // @formatter:on + } + + } + + } + + @EnableWebSecurity + static class PrivilegeEvaluatorConfigurerAdapterConfig extends WebSecurityConfigurerAdapter { + + static WebInvocationPrivilegeEvaluator PRIVILEGE_EVALUATOR; + + @Override + public void configure(WebSecurity web) { + web.privilegeEvaluator(PRIVILEGE_EVALUATOR); + } + + } + + @EnableWebSecurity + static class WebSecurityExpressionHandlerConfig extends WebSecurityConfigurerAdapter { + + static SecurityExpressionHandler EXPRESSION_HANDLER; + + @Override + public void configure(WebSecurity web) { + web.expressionHandler(EXPRESSION_HANDLER); + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .expressionHandler(EXPRESSION_HANDLER); + // @formatter:on + } + + } + + @EnableWebSecurity + static class NullWebSecurityExpressionHandlerConfig extends WebSecurityConfigurerAdapter { + + @Override + public void configure(WebSecurity web) { + web.expressionHandler(null); + } + + } + + @EnableWebSecurity + static class WebSecurityExpressionHandlerDefaultsConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class WebSecurityExpressionHandlerRoleHierarchyBeanConfig extends WebSecurityConfigurerAdapter { + + @Bean + RoleHierarchy roleHierarchy() { + RoleHierarchyImpl roleHierarchy = new RoleHierarchyImpl(); + roleHierarchy.setHierarchy("ROLE_ADMIN > ROLE_USER"); + return roleHierarchy; + } + + } + + @EnableWebSecurity + static class WebSecurityExpressionHandlerPermissionEvaluatorBeanConfig extends WebSecurityConfigurerAdapter { + + static final PermissionEvaluator PERMIT_ALL_PERMISSION_EVALUATOR = new PermissionEvaluator() { + @Override + public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { + return true; + } + + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, + Object permission) { + return true; + } + }; + + @Bean + PermissionEvaluator permissionEvaluator() { + return PERMIT_ALL_PERMISSION_EVALUATOR; + } + + } + + @EnableWebSecurity + static class WebInvocationPrivilegeEvaluatorDefaultsConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class AuthorizeRequestsFilterChainConfig { + + @Bean + public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + return http + .authorizeRequests((authorize) -> authorize + .anyRequest().authenticated() + ) + .build(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class DefaultExpressionHandlerSetsBeanResolverConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().access("request.method == 'GET' ? @b.grant() : @b.deny()"); + // @formatter:on + } + + @Bean + public MyBean b() { + return new MyBean(); + } + + @RestController + class HomeController { + + @GetMapping("/") + String home() { + return "home"; + } + + } + + static class MyBean { + + public boolean deny() { + return false; + } + + public boolean grant() { + return true; + } + + } + + } + + @EnableWebSecurity + static class ParentConfig extends WebSecurityConfigurerAdapter { + + @Autowired + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + auth.inMemoryAuthentication(); + } + + } + + @EnableWebSecurity + static class ChildConfig extends WebSecurityConfigurerAdapter { + + } + + @Configuration + static class SubclassConfig extends WebSecurityConfiguration { + + } + + @Import(AuthenticationTestConfiguration.class) + @EnableGlobalAuthentication + static class GlobalAuthenticationWebSecurityConfigurerAdaptersConfig { + + @Configuration + @Order(1) + static class WebConfigurer1 extends WebSecurityConfigurerAdapter { + + @Override + public void configure(WebSecurity web) { + web.ignoring().antMatchers("/ignore1", "/ignore2"); + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .antMatcher("/anonymous/**") + .authorizeRequests() + .anyRequest().anonymous(); + // @formatter:on + } + + } + + @Configuration + static class WebConfigurer2 extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated(); + // @formatter:on } + } - } - - @Test - public void loadConfigWhenBothAdapterAndFilterChainConfiguredThenException() { - Throwable thrown = catchThrowable(() -> this.spring.register(AdapterAndFilterChainConfig.class).autowire()); - - assertThat(thrown).isInstanceOf(BeanCreationException.class) - .hasRootCauseExactlyInstanceOf(IllegalStateException.class) - .hasMessageContaining("Found WebSecurityConfigurerAdapter as well as SecurityFilterChain."); } @EnableWebSecurity @Import(AuthenticationTestConfiguration.class) static class AdapterAndFilterChainConfig { - @Order(1) - @Configuration - static class WebConfigurer extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .antMatcher("/config/**") - .authorizeRequests(authorize -> authorize - .anyRequest().permitAll() - ); - } - } @Order(2) @Bean SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off return http .antMatcher("/filter/**") - .authorizeRequests(authorize -> authorize + .authorizeRequests((authorize) -> authorize .anyRequest().authenticated() ) .build(); + // @formatter:on } + + @Order(1) + @Configuration + static class WebConfigurer extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .antMatcher("/config/**") + .authorizeRequests((authorize) -> authorize + .anyRequest().permitAll() + ); + // @formatter:on + } + + } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/Sec2377Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/Sec2377Tests.java index 59631f73b6..729c041d08 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/Sec2377Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/Sec2377Tests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration.sec2377; import org.junit.Rule; import org.junit.Test; + import org.springframework.context.ConfigurableApplicationContext; import org.springframework.security.config.annotation.web.configuration.sec2377.a.Sec2377AConfig; import org.springframework.security.config.annotation.web.configuration.sec2377.b.Sec2377BConfig; @@ -37,11 +39,9 @@ public class Sec2377Tests { @Test public void refreshContextWhenParentAndChildRegisteredThenNoException() { this.parent.register(Sec2377AConfig.class).autowire(); - - ConfigurableApplicationContext context = - this.child.register(Sec2377BConfig.class).getContext(); + ConfigurableApplicationContext context = this.child.register(Sec2377BConfig.class).getContext(); context.setParent(this.parent.getContext()); - this.child.autowire(); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/a/Sec2377AConfig.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/a/Sec2377AConfig.java index 23ba45407d..ed8857c73a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/a/Sec2377AConfig.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/a/Sec2377AConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration.sec2377.a; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/b/Sec2377BConfig.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/b/Sec2377BConfig.java index 444b3c0d79..9a7cd92134 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/b/Sec2377BConfig.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/sec2377/b/Sec2377BConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configuration.sec2377.b; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AbstractConfigAttributeRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AbstractConfigAttributeRequestMatcherRegistryTests.java index f5f3e9f89b..39a50575ff 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AbstractConfigAttributeRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AbstractConfigAttributeRequestMatcherRegistryTests.java @@ -16,8 +16,11 @@ package org.springframework.security.config.annotation.web.configurers; +import java.util.List; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpMethod; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -26,67 +29,69 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import static org.assertj.core.api.Assertions.assertThat; -import java.util.List; - public class AbstractConfigAttributeRequestMatcherRegistryTests { + private ConcreteAbstractRequestMatcherMappingConfigurer registry; @Before public void setup() { - registry = new ConcreteAbstractRequestMatcherMappingConfigurer(); + this.registry = new ConcreteAbstractRequestMatcherMappingConfigurer(); } @Test - public void testGetRequestMatcherIsTypeRegexMatcher(){ - List requestMatchers = registry.regexMatchers(HttpMethod.GET, "/a.*"); - + public void testGetRequestMatcherIsTypeRegexMatcher() { + List requestMatchers = this.registry.regexMatchers(HttpMethod.GET, "/a.*"); for (RequestMatcher requestMatcher : requestMatchers) { assertThat(requestMatcher).isInstanceOf(RegexRequestMatcher.class); } } @Test - public void testRequestMatcherIsTypeRegexMatcher(){ - List requestMatchers = registry.regexMatchers( "/a.*"); - + public void testRequestMatcherIsTypeRegexMatcher() { + List requestMatchers = this.registry.regexMatchers("/a.*"); for (RequestMatcher requestMatcher : requestMatchers) { assertThat(requestMatcher).isInstanceOf(RegexRequestMatcher.class); } } @Test - public void testGetRequestMatcherIsTypeAntPathRequestMatcher(){ - List requestMatchers = registry.antMatchers(HttpMethod.GET, "/a.*"); - + public void testGetRequestMatcherIsTypeAntPathRequestMatcher() { + List requestMatchers = this.registry.antMatchers(HttpMethod.GET, "/a.*"); for (RequestMatcher requestMatcher : requestMatchers) { assertThat(requestMatcher).isInstanceOf(AntPathRequestMatcher.class); } } @Test - public void testRequestMatcherIsTypeAntPathRequestMatcher(){ - List requestMatchers = registry.antMatchers("/a.*"); - + public void testRequestMatcherIsTypeAntPathRequestMatcher() { + List requestMatchers = this.registry.antMatchers("/a.*"); for (RequestMatcher requestMatcher : requestMatchers) { assertThat(requestMatcher).isInstanceOf(AntPathRequestMatcher.class); } } - static class ConcreteAbstractRequestMatcherMappingConfigurer extends AbstractConfigAttributeRequestMatcherRegistry> { + static class ConcreteAbstractRequestMatcherMappingConfigurer + extends AbstractConfigAttributeRequestMatcherRegistry> { + List decisionVoters() { return null; } + @Override protected List chainRequestMatchersInternal(List requestMatchers) { return requestMatchers; } + @Override public List mvcMatchers(String... mvcPatterns) { return null; } + @Override public List mvcMatchers(HttpMethod method, String... mvcPatterns) { return null; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurerTests.java index 11d80e02d8..c25de5ebaf 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AnonymousConfigurerTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -40,6 +42,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Josh Cummings */ public class AnonymousConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -49,9 +52,25 @@ public class AnonymousConfigurerTests { @Test public void requestWhenAnonymousTwiceInvokedThenDoesNotOverride() throws Exception { this.spring.register(InvokeTwiceDoesNotOverride.class, PrincipalController.class).autowire(); + this.mockMvc.perform(get("/")).andExpect(content().string("principal")); + } - this.mockMvc.perform(get("/")) - .andExpect(content().string("principal")); + @Test + public void requestWhenAnonymousPrincipalInLambdaThenPrincipalUsed() throws Exception { + this.spring.register(AnonymousPrincipalInLambdaConfig.class, PrincipalController.class).autowire(); + this.mockMvc.perform(get("/")).andExpect(content().string("principal")); + } + + @Test + public void requestWhenAnonymousDisabledInLambdaThenRespondsWithForbidden() throws Exception { + this.spring.register(AnonymousDisabledInLambdaConfig.class, PrincipalController.class).autowire(); + this.mockMvc.perform(get("/")).andExpect(status().isForbidden()); + } + + @Test + public void requestWhenAnonymousWithDefaultsInLambdaThenRespondsWithOk() throws Exception { + this.spring.register(AnonymousWithDefaultsInLambdaConfig.class, PrincipalController.class).autowire(); + this.mockMvc.perform(get("/")).andExpect(status().isOk()); } @EnableWebSecurity @@ -60,21 +79,16 @@ public class AnonymousConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .anonymous() .key("key") .principal("principal") .and() .anonymous(); + // @formatter:on } - } - @Test - public void requestWhenAnonymousPrincipalInLambdaThenPrincipalUsed() throws Exception { - this.spring.register(AnonymousPrincipalInLambdaConfig.class, PrincipalController.class).autowire(); - - this.mockMvc.perform(get("/")) - .andExpect(content().string("principal")); } @EnableWebSecurity @@ -85,29 +99,23 @@ public class AnonymousConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .anonymous(anonymous -> + .anonymous((anonymous) -> anonymous .principal("principal") ); // @formatter:on } - } - @Test - public void requestWhenAnonymousDisabledInLambdaThenRespondsWithForbidden() throws Exception { - this.spring.register(AnonymousDisabledInLambdaConfig.class, PrincipalController.class).autowire(); - - this.mockMvc.perform(get("/")) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class AnonymousDisabledInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().permitAll() ) @@ -115,6 +123,7 @@ public class AnonymousConfigurerTests { // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth @@ -122,23 +131,17 @@ public class AnonymousConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void requestWhenAnonymousWithDefaultsInLambdaThenRespondsWithOk() throws Exception { - this.spring.register(AnonymousWithDefaultsInLambdaConfig.class, PrincipalController.class).autowire(); - - this.mockMvc.perform(get("/")) - .andExpect(status().isOk()); } @EnableWebSecurity static class AnonymousWithDefaultsInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().permitAll() ) @@ -146,6 +149,7 @@ public class AnonymousConfigurerTests { // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth @@ -153,13 +157,17 @@ public class AnonymousConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } + } @RestController static class PrincipalController { + @GetMapping("/") String principal(@AuthenticationPrincipal String principal) { return principal; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java index d10ea89dce..22792f926e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletResponse; @@ -58,11 +59,15 @@ import static org.springframework.security.config.Customizer.withDefaults; * */ public class AuthorizeRequestsTests { + AnnotationConfigWebApplicationContext context; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; + MockServletContext servletContext; @Autowired @@ -89,15 +94,196 @@ public class AuthorizeRequestsTests { public void antMatchersMethodAndNoPatterns() throws Exception { loadConfig(AntMatchersNoPatternsConfig.class); this.request.setMethod("POST"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } + @Test + public void postWhenPostDenyAllInLambdaThenRespondsWithForbidden() throws Exception { + loadConfig(AntMatchersNoPatternsInLambdaConfig.class); + this.request.setMethod("POST"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + } + + // SEC-2256 + @Test + public void antMatchersPathVariables() throws Exception { + loadConfig(AntPatchersPathVariables.class); + this.request.setServletPath("/user/user"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.setup(); + this.request.setServletPath("/user/deny"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + } + + // SEC-2256 + @Test + public void antMatchersPathVariablesCaseInsensitive() throws Exception { + loadConfig(AntPatchersPathVariables.class); + this.request.setServletPath("/USER/user"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.setup(); + this.request.setServletPath("/USER/deny"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + } + + // gh-3786 + @Test + public void antMatchersPathVariablesCaseInsensitiveCamelCaseVariables() throws Exception { + loadConfig(AntMatchersPathVariablesCamelCaseVariables.class); + this.request.setServletPath("/USER/user"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.setup(); + this.request.setServletPath("/USER/deny"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + } + + // gh-3394 + @Test + public void roleHiearchy() throws Exception { + loadConfig(RoleHiearchyConfig.class); + SecurityContext securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new UsernamePasswordAuthenticationToken("test", "notused", + AuthorityUtils.createAuthorityList("ROLE_USER"))); + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + securityContext); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + @Test + public void mvcMatcher() throws Exception { + loadConfig(MvcMatcherConfig.class, LegacyMvcMatchingConfig.class); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setRequestURI("/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void requestWhenMvcMatcherDenyAllThenRespondsWithUnauthorized() throws Exception { + loadConfig(MvcMatcherInLambdaConfig.class, LegacyMvcMatchingConfig.class); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setRequestURI("/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void requestWhenMvcMatcherServletPathDenyAllThenMatchesOnServletPath() throws Exception { + loadConfig(MvcMatcherServletPathInLambdaConfig.class, LegacyMvcMatchingConfig.class); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/foo"); + this.request.setRequestURI("/foo/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/"); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + @Test + public void mvcMatcherPathVariables() throws Exception { + loadConfig(MvcMatcherPathVariablesConfig.class); + this.request.setRequestURI("/user/user"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.setup(); + this.request.setRequestURI("/user/deny"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void requestWhenMvcMatcherPathVariablesThenMatchesOnPathVariables() throws Exception { + loadConfig(MvcMatcherPathVariablesInLambdaConfig.class); + this.request.setRequestURI("/user/user"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.setup(); + this.request.setRequestURI("/user/deny"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void mvcMatcherServletPath() throws Exception { + loadConfig(MvcMatcherServletPathConfig.class, LegacyMvcMatchingConfig.class); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/foo"); + this.request.setRequestURI("/foo/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/"); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + public void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.setServletContext(this.servletContext); + this.context.refresh(); + this.context.getAutowireCapableBeanFactory().autowireBean(this); + } + @EnableWebSecurity @Configuration static class AntMatchersNoPatternsConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -114,26 +300,18 @@ public class AuthorizeRequestsTests { .inMemoryAuthentication(); // @formatter:on } - } - @Test - public void postWhenPostDenyAllInLambdaThenRespondsWithForbidden() throws Exception { - loadConfig(AntMatchersNoPatternsInLambdaConfig.class); - this.request.setMethod("POST"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } @EnableWebSecurity @Configuration static class AntMatchersNoPatternsInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers(HttpMethod.POST).denyAll() ); @@ -147,49 +325,13 @@ public class AuthorizeRequestsTests { .inMemoryAuthentication(); // @formatter:on } - } - // SEC-2256 - @Test - public void antMatchersPathVariables() throws Exception { - loadConfig(AntPatchersPathVariables.class); - - this.request.setServletPath("/user/user"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - this.setup(); - this.request.setServletPath("/user/deny"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); - } - - // SEC-2256 - @Test - public void antMatchersPathVariablesCaseInsensitive() throws Exception { - loadConfig(AntPatchersPathVariables.class); - - this.request.setServletPath("/USER/user"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - this.setup(); - this.request.setServletPath("/USER/deny"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } @EnableWebSecurity @Configuration static class AntPatchersPathVariables extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -207,30 +349,13 @@ public class AuthorizeRequestsTests { .inMemoryAuthentication(); // @formatter:on } - } - // gh-3786 - @Test - public void antMatchersPathVariablesCaseInsensitiveCamelCaseVariables() throws Exception { - loadConfig(AntMatchersPathVariablesCamelCaseVariables.class); - - this.request.setServletPath("/USER/user"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - this.setup(); - this.request.setServletPath("/USER/deny"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } @EnableWebSecurity @Configuration static class AntMatchersPathVariablesCamelCaseVariables extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -248,25 +373,13 @@ public class AuthorizeRequestsTests { .inMemoryAuthentication(); // @formatter:on } - } - // gh-3394 - @Test - public void roleHiearchy() throws Exception { - loadConfig(RoleHiearchyConfig.class); - - SecurityContext securityContext = new SecurityContextImpl(); - securityContext.setAuthentication(new UsernamePasswordAuthenticationToken("test", "notused", AuthorityUtils.createAuthorityList("ROLE_USER"))); - this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, securityContext); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration static class RoleHiearchyConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -285,44 +398,19 @@ public class AuthorizeRequestsTests { } @Bean - public RoleHierarchy roleHiearchy() { + RoleHierarchy roleHiearchy() { RoleHierarchyImpl result = new RoleHierarchyImpl(); result.setHierarchy("ROLE_USER > ROLE_ADMIN"); return result; } - } - @Test - public void mvcMatcher() throws Exception { - loadConfig(MvcMatcherConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setRequestURI("/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -343,50 +431,27 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestWhenMvcMatcherDenyAllThenRespondsWithUnauthorized() throws Exception { - loadConfig(MvcMatcherInLambdaConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setRequestURI("/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http .httpBasic(withDefaults()) - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .mvcMatchers("/path").denyAll() ); @@ -403,63 +468,21 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void mvcMatcherServletPath() throws Exception { - loadConfig(MvcMatcherServletPathConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/foo"); - this.request.setRequestURI("/foo/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/"); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherServletPathConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -480,69 +503,27 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestWhenMvcMatcherServletPathDenyAllThenMatchesOnServletPath() throws Exception { - loadConfig(MvcMatcherServletPathInLambdaConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/foo"); - this.request.setRequestURI("/foo/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/"); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherServletPathInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http .httpBasic(withDefaults()) - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .mvcMatchers("/path").servletPath("/spring").denyAll() ); @@ -559,36 +540,21 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void mvcMatcherPathVariables() throws Exception { - loadConfig(MvcMatcherPathVariablesConfig.class); - - this.request.setRequestURI("/user/user"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - this.setup(); - this.request.setRequestURI("/user/deny"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherPathVariablesConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -609,42 +575,27 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestWhenMvcMatcherPathVariablesThenMatchesOnPathVariables() throws Exception { - loadConfig(MvcMatcherPathVariablesInLambdaConfig.class); - - this.request.setRequestURI("/user/user"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - this.setup(); - this.request.setRequestURI("/user/deny"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherPathVariablesInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http .httpBasic(withDefaults()) - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .mvcMatchers("/user/{userName}").access("#userName == 'user'") ); @@ -661,17 +612,21 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } + } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherPathServletPathRequiredConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -692,27 +647,24 @@ public class AuthorizeRequestsTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } + } @Configuration static class LegacyMvcMatchingConfig implements WebMvcConfigurer { + @Override public void configurePathMatch(PathMatchConfigurer configurer) { configurer.setUseSuffixPatternMatch(true); } + } - public void loadConfig(Class... configs) { - this.context = new AnnotationConfigWebApplicationContext(); - this.context.register(configs); - this.context.setServletContext(this.servletContext); - this.context.refresh(); - - this.context.getAutowireCapableBeanFactory().autowireBean(this); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurerTests.java index c33682e716..491c1960ea 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ChannelSecurityConfigurerTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.ObjectPostProcessor; @@ -55,40 +56,45 @@ public class ChannelSecurityConfigurerTests { public void configureWhenRegisteringObjectPostProcessorThenInvokedOnInsecureChannelProcessor() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(InsecureChannelProcessor.class)); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(InsecureChannelProcessor.class)); } @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnSecureChannelProcessor() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(SecureChannelProcessor.class)); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(SecureChannelProcessor.class)); } @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnChannelDecisionManagerImpl() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ChannelDecisionManagerImpl.class)); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ChannelDecisionManagerImpl.class)); } @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnChannelProcessingFilter() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ChannelProcessingFilter.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ChannelProcessingFilter.class)); + @Test + public void requiresChannelWhenInvokesTwiceThenUsesOriginalRequiresSecure() throws Exception { + this.spring.register(DuplicateInvocationsDoesNotOverrideConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(redirectedUrl("https://localhost/")); + } + + @Test + public void requestWhenRequiresChannelConfiguredInLambdaThenRedirectsToHttps() throws Exception { + this.spring.register(RequiresChannelInLambdaConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(redirectedUrl("https://localhost/")); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor; @Override @@ -104,21 +110,16 @@ public class ChannelSecurityConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void requiresChannelWhenInvokesTwiceThenUsesOriginalRequiresSecure() throws Exception { - this.spring.register(DuplicateInvocationsDoesNotOverrideConfig.class).autowire(); - - mvc.perform(get("/")) - .andExpect(redirectedUrl("https://localhost/")); } @EnableWebSecurity @@ -134,14 +135,7 @@ public class ChannelSecurityConfigurerTests { .requiresChannel(); // @formatter:on } - } - @Test - public void requestWhenRequiresChannelConfiguredInLambdaThenRedirectsToHttps() throws Exception { - this.spring.register(RequiresChannelInLambdaConfig.class).autowire(); - - mvc.perform(get("/")) - .andExpect(redirectedUrl("https://localhost/")); } @EnableWebSecurity @@ -151,11 +145,13 @@ public class ChannelSecurityConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .requiresChannel(requiresChannel -> + .requiresChannel((requiresChannel) -> requiresChannel .anyRequest().requiresSecure() ); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurerTests.java index b880b3ef5b..a13bb2af4f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CorsConfigurerTests.java @@ -16,9 +16,13 @@ package org.springframework.security.config.annotation.web.configurers; +import java.util.Arrays; +import java.util.Collections; + import com.google.common.net.HttpHeaders; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -38,10 +42,7 @@ import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import org.springframework.web.filter.CorsFilter; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import java.util.Arrays; -import java.util.Collections; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.options; @@ -55,6 +56,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Eleftheria Stein */ public class CorsConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -63,9 +65,119 @@ public class CorsConfigurerTests { @Test public void configureWhenNoMvcThenException() { - assertThatThrownBy(() -> this.spring.register(DefaultCorsConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(DefaultCorsConfig.class).autowire()).withMessageContaining( + "Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext"); + } + + @Test + public void getWhenCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(MvcCorsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ORIGIN, "https://example.com")) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void optionsWhenCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(MvcCorsConfig.class).autowire(); + this.mvc.perform(options("/") + .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) + .header(HttpHeaders.ORIGIN, "https://example.com")).andExpect(status().isOk()) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void getWhenDefaultsInLambdaAndCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(MvcCorsInLambdaConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ORIGIN, "https://example.com")) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void optionsWhenDefaultsInLambdaAndCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(MvcCorsInLambdaConfig.class).autowire(); + this.mvc.perform(options("/") + .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) + .header(HttpHeaders.ORIGIN, "https://example.com")).andExpect(status().isOk()) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void getWhenCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(ConfigSourceConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ORIGIN, "https://example.com")) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void optionsWhenCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(ConfigSourceConfig.class).autowire(); + this.mvc.perform(options("/") + .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) + .header(HttpHeaders.ORIGIN, "https://example.com")).andExpect(status().isOk()) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void getWhenMvcCorsInLambdaConfigAndCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() + throws Exception { + this.spring.register(ConfigSourceInLambdaConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ORIGIN, "https://example.com")) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void optionsWhenMvcCorsInLambdaConfigAndCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() + throws Exception { + this.spring.register(ConfigSourceInLambdaConfig.class).autowire(); + this.mvc.perform(options("/") + .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) + .header(HttpHeaders.ORIGIN, "https://example.com")).andExpect(status().isOk()) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void getWhenCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(CorsFilterConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ORIGIN, "https://example.com")) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void optionsWhenCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(CorsFilterConfig.class).autowire(); + this.mvc.perform(options("/") + .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) + .header(HttpHeaders.ORIGIN, "https://example.com")).andExpect(status().isOk()) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void getWhenConfigSourceInLambdaConfigAndCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(CorsFilterInLambdaConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ORIGIN, "https://example.com")) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); + } + + @Test + public void optionsWhenConfigSourceInLambdaConfigAndCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { + this.spring.register(CorsFilterInLambdaConfig.class).autowire(); + this.mvc.perform(options("/") + .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) + .header(HttpHeaders.ORIGIN, "https://example.com")).andExpect(status().isOk()) + .andExpect(header().exists("Access-Control-Allow-Origin")) + .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebSecurity @@ -81,28 +193,7 @@ public class CorsConfigurerTests { .cors(); // @formatter:on } - } - @Test - public void getWhenCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(MvcCorsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); - } - - @Test - public void optionsWhenCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(MvcCorsConfig.class).autowire(); - - this.mvc.perform(options("/") - .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(status().isOk()) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebMvc @@ -121,37 +212,16 @@ public class CorsConfigurerTests { } @RestController - @CrossOrigin(methods = { - RequestMethod.GET, RequestMethod.POST - }) + @CrossOrigin(methods = { RequestMethod.GET, RequestMethod.POST }) static class CorsController { + @RequestMapping("/") String hello() { return "Hello"; } + } - } - @Test - public void getWhenDefaultsInLambdaAndCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(MvcCorsInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); - } - - @Test - public void optionsWhenDefaultsInLambdaAndCrossOriginAnnotationThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(MvcCorsInLambdaConfig.class).autowire(); - - this.mvc.perform(options("/") - .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(status().isOk()) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebMvc @@ -162,7 +232,7 @@ public class CorsConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -171,37 +241,16 @@ public class CorsConfigurerTests { } @RestController - @CrossOrigin(methods = { - RequestMethod.GET, RequestMethod.POST - }) + @CrossOrigin(methods = { RequestMethod.GET, RequestMethod.POST }) static class CorsController { + @RequestMapping("/") String hello() { return "Hello"; } + } - } - @Test - public void getWhenCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(ConfigSourceConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); - } - - @Test - public void optionsWhenCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(ConfigSourceConfig.class).autowire(); - - this.mvc.perform(options("/") - .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(status().isOk()) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebSecurity @@ -223,36 +272,11 @@ public class CorsConfigurerTests { UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); CorsConfiguration corsConfiguration = new CorsConfiguration(); corsConfiguration.setAllowedOrigins(Collections.singletonList("*")); - corsConfiguration.setAllowedMethods(Arrays.asList( - RequestMethod.GET.name(), - RequestMethod.POST.name())); + corsConfiguration.setAllowedMethods(Arrays.asList(RequestMethod.GET.name(), RequestMethod.POST.name())); source.registerCorsConfiguration("/**", corsConfiguration); return source; } - } - @Test - public void getWhenMvcCorsInLambdaConfigAndCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() - throws Exception { - this.spring.register(ConfigSourceInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); - } - - @Test - public void optionsWhenMvcCorsInLambdaConfigAndCorsConfigurationSourceBeanThenRespondsWithCorsHeaders() - throws Exception { - this.spring.register(ConfigSourceInLambdaConfig.class).autowire(); - - this.mvc.perform(options("/") - .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(status().isOk()) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebSecurity @@ -262,7 +286,7 @@ public class CorsConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -275,34 +299,11 @@ public class CorsConfigurerTests { UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); CorsConfiguration corsConfiguration = new CorsConfiguration(); corsConfiguration.setAllowedOrigins(Collections.singletonList("*")); - corsConfiguration.setAllowedMethods(Arrays.asList( - RequestMethod.GET.name(), - RequestMethod.POST.name())); + corsConfiguration.setAllowedMethods(Arrays.asList(RequestMethod.GET.name(), RequestMethod.POST.name())); source.registerCorsConfiguration("/**", corsConfiguration); return source; } - } - @Test - public void getWhenCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(CorsFilterConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); - } - - @Test - public void optionsWhenCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(CorsFilterConfig.class).autowire(); - - this.mvc.perform(options("/") - .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(status().isOk()) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebSecurity @@ -324,34 +325,11 @@ public class CorsConfigurerTests { UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); CorsConfiguration corsConfiguration = new CorsConfiguration(); corsConfiguration.setAllowedOrigins(Collections.singletonList("*")); - corsConfiguration.setAllowedMethods(Arrays.asList( - RequestMethod.GET.name(), - RequestMethod.POST.name())); + corsConfiguration.setAllowedMethods(Arrays.asList(RequestMethod.GET.name(), RequestMethod.POST.name())); source.registerCorsConfiguration("/**", corsConfiguration); return new CorsFilter(source); } - } - @Test - public void getWhenConfigSourceInLambdaConfigAndCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(CorsFilterInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); - } - - @Test - public void optionsWhenConfigSourceInLambdaConfigAndCorsFilterBeanThenRespondsWithCorsHeaders() throws Exception { - this.spring.register(CorsFilterInLambdaConfig.class).autowire(); - - this.mvc.perform(options("/") - .header(org.springframework.http.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()) - .header(HttpHeaders.ORIGIN, "https://example.com")) - .andExpect(status().isOk()) - .andExpect(header().exists("Access-Control-Allow-Origin")) - .andExpect(header().exists("X-Content-Type-Options")); } @EnableWebSecurity @@ -361,7 +339,7 @@ public class CorsConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -374,11 +352,11 @@ public class CorsConfigurerTests { UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); CorsConfiguration corsConfiguration = new CorsConfiguration(); corsConfiguration.setAllowedOrigins(Collections.singletonList("*")); - corsConfiguration.setAllowedMethods(Arrays.asList( - RequestMethod.GET.name(), - RequestMethod.POST.name())); + corsConfiguration.setAllowedMethods(Arrays.asList(RequestMethod.GET.name(), RequestMethod.POST.name())); source.registerCorsConfiguration("/**", corsConfiguration); return new CorsFilter(source); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerIgnoringRequestMatchersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerIgnoringRequestMatchersTests.java index 56d99f0610..7ae673cef3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerIgnoringRequestMatchersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerIgnoringRequestMatchersTests.java @@ -48,21 +48,40 @@ public class CsrfConfigurerIgnoringRequestMatchersTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void requestWhenIgnoringRequestMatchersThenAugmentedByConfiguredRequestMatcher() - throws Exception { + public void requestWhenIgnoringRequestMatchersThenAugmentedByConfiguredRequestMatcher() throws Exception { this.spring.register(IgnoringRequestMatchers.class, BasicController.class).autowire(); + this.mvc.perform(get("/path")).andExpect(status().isForbidden()); + this.mvc.perform(post("/path")).andExpect(status().isOk()); + } - this.mvc.perform(get("/path")) - .andExpect(status().isForbidden()); + @Test + public void requestWhenIgnoringRequestMatchersInLambdaThenAugmentedByConfiguredRequestMatcher() throws Exception { + this.spring.register(IgnoringRequestInLambdaMatchers.class, BasicController.class).autowire(); + this.mvc.perform(get("/path")).andExpect(status().isForbidden()); + this.mvc.perform(post("/path")).andExpect(status().isOk()); + } - this.mvc.perform(post("/path")) - .andExpect(status().isOk()); + @Test + public void requestWhenIgnoringRequestMatcherThenUnionsWithConfiguredIgnoringAntMatchers() throws Exception { + this.spring.register(IgnoringPathsAndMatchers.class, BasicController.class).autowire(); + this.mvc.perform(put("/csrf")).andExpect(status().isForbidden()); + this.mvc.perform(post("/csrf")).andExpect(status().isOk()); + this.mvc.perform(put("/no-csrf")).andExpect(status().isOk()); + } + + @Test + public void requestWhenIgnoringRequestMatcherInLambdaThenUnionsWithConfiguredIgnoringAntMatchers() + throws Exception { + this.spring.register(IgnoringPathsAndMatchersInLambdaConfig.class, BasicController.class).autowire(); + this.mvc.perform(put("/csrf")).andExpect(status().isForbidden()); + this.mvc.perform(post("/csrf")).andExpect(status().isOk()); + this.mvc.perform(put("/no-csrf")).andExpect(status().isOk()); } @EnableWebSecurity static class IgnoringRequestMatchers extends WebSecurityConfigurerAdapter { - RequestMatcher requestMatcher = - request -> HttpMethod.POST.name().equals(request.getMethod()); + + RequestMatcher requestMatcher = (request) -> HttpMethod.POST.name().equals(request.getMethod()); @Override protected void configure(HttpSecurity http) throws Exception { @@ -73,58 +92,32 @@ public class CsrfConfigurerIgnoringRequestMatchersTests { .ignoringRequestMatchers(this.requestMatcher); // @formatter:on } - } - @Test - public void requestWhenIgnoringRequestMatchersInLambdaThenAugmentedByConfiguredRequestMatcher() - throws Exception { - this.spring.register(IgnoringRequestInLambdaMatchers.class, BasicController.class).autowire(); - - this.mvc.perform(get("/path")) - .andExpect(status().isForbidden()); - - this.mvc.perform(post("/path")) - .andExpect(status().isOk()); } @EnableWebSecurity static class IgnoringRequestInLambdaMatchers extends WebSecurityConfigurerAdapter { - RequestMatcher requestMatcher = - request -> HttpMethod.POST.name().equals(request.getMethod()); + + RequestMatcher requestMatcher = (request) -> HttpMethod.POST.name().equals(request.getMethod()); @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .csrf(csrf -> + .csrf((csrf) -> csrf .requireCsrfProtectionMatcher(new AntPathRequestMatcher("/path")) .ignoringRequestMatchers(this.requestMatcher) ); // @formatter:on } - } - @Test - public void requestWhenIgnoringRequestMatcherThenUnionsWithConfiguredIgnoringAntMatchers() - throws Exception { - - this.spring.register(IgnoringPathsAndMatchers.class, BasicController.class).autowire(); - - this.mvc.perform(put("/csrf")) - .andExpect(status().isForbidden()); - - this.mvc.perform(post("/csrf")) - .andExpect(status().isOk()); - - this.mvc.perform(put("/no-csrf")) - .andExpect(status().isOk()); } @EnableWebSecurity static class IgnoringPathsAndMatchers extends WebSecurityConfigurerAdapter { - RequestMatcher requestMatcher = - request -> HttpMethod.POST.name().equals(request.getMethod()); + + RequestMatcher requestMatcher = (request) -> HttpMethod.POST.name().equals(request.getMethod()); @Override protected void configure(HttpSecurity http) throws Exception { @@ -135,44 +128,31 @@ public class CsrfConfigurerIgnoringRequestMatchersTests { .ignoringRequestMatchers(this.requestMatcher); // @formatter:on } - } - @Test - public void requestWhenIgnoringRequestMatcherInLambdaThenUnionsWithConfiguredIgnoringAntMatchers() - throws Exception { - - this.spring.register(IgnoringPathsAndMatchersInLambdaConfig.class, BasicController.class).autowire(); - - this.mvc.perform(put("/csrf")) - .andExpect(status().isForbidden()); - - this.mvc.perform(post("/csrf")) - .andExpect(status().isOk()); - - this.mvc.perform(put("/no-csrf")) - .andExpect(status().isOk()); } @EnableWebSecurity static class IgnoringPathsAndMatchersInLambdaConfig extends WebSecurityConfigurerAdapter { - RequestMatcher requestMatcher = - request -> HttpMethod.POST.name().equals(request.getMethod()); + + RequestMatcher requestMatcher = (request) -> HttpMethod.POST.name().equals(request.getMethod()); @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .csrf(csrf -> + .csrf((csrf) -> csrf .ignoringAntMatchers("/no-csrf") .ignoringRequestMatchers(this.requestMatcher) ); // @formatter:on } + } @RestController public static class BasicController { + @RequestMapping("/path") public String path() { return "path"; @@ -187,5 +167,7 @@ public class CsrfConfigurerIgnoringRequestMatchersTests { public String noCsrf() { return "no-csrf"; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerNoWebMvcTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerNoWebMvcTests.java index 61df857d62..108f04abb2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerNoWebMvcTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerNoWebMvcTests.java @@ -13,13 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +package org.springframework.security.config.annotation.web.configurers; import org.junit.After; import org.junit.Test; + import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; @@ -27,75 +26,80 @@ import org.springframework.context.annotation.Primary; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; - import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor; import org.springframework.web.servlet.support.RequestDataValueProcessor; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + /** * @author Rob Winch * */ public class CsrfConfigurerNoWebMvcTests { + ConfigurableApplicationContext context; @After public void teardown() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @Test public void missingDispatcherServletPreventsCsrfRequestDataValueProcessor() { loadContext(EnableWebConfig.class); - - assertThat(context.containsBeanDefinition("requestDataValueProcessor")).isTrue(); + assertThat(this.context.containsBeanDefinition("requestDataValueProcessor")).isTrue(); } @Test public void findDispatcherServletPreventsCsrfRequestDataValueProcessor() { loadContext(EnableWebMvcConfig.class); - - assertThat(context.containsBeanDefinition("requestDataValueProcessor")).isTrue(); + assertThat(this.context.containsBeanDefinition("requestDataValueProcessor")).isTrue(); } @Test public void overrideCsrfRequestDataValueProcessor() { loadContext(EnableWebOverrideRequestDataConfig.class); - - assertThat(context.getBean(RequestDataValueProcessor.class).getClass()) + assertThat(this.context.getBean(RequestDataValueProcessor.class).getClass()) .isNotEqualTo(CsrfRequestDataValueProcessor.class); } - @EnableWebSecurity - static class EnableWebConfig extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) { - } - } - - @EnableWebSecurity - static class EnableWebOverrideRequestDataConfig { - @Bean - @Primary - public RequestDataValueProcessor requestDataValueProcessor() { - return mock(RequestDataValueProcessor.class); - } - } - - @EnableWebSecurity - static class EnableWebMvcConfig extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) { - } - } - private void loadContext(Class configs) { AnnotationConfigApplicationContext annotationConfigApplicationContext = new AnnotationConfigApplicationContext(); annotationConfigApplicationContext.register(configs); annotationConfigApplicationContext.refresh(); this.context = annotationConfigApplicationContext; } + + @EnableWebSecurity + static class EnableWebConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) { + } + + } + + @EnableWebSecurity + static class EnableWebOverrideRequestDataConfig { + + @Bean + @Primary + RequestDataValueProcessor requestDataValueProcessor() { + return mock(RequestDataValueProcessor.class); + } + + } + + @EnableWebSecurity + static class EnableWebMvcConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) { + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index 7165b06422..a533dd62d8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -16,8 +16,14 @@ package org.springframework.security.config.annotation.web.configurers; +import java.net.URI; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -40,29 +46,33 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.support.RequestDataValueProcessor; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.net.URI; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.head; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.options; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -75,6 +85,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Sam Simmons */ public class CsrfConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -83,100 +94,317 @@ public class CsrfConfigurerTests { @Test public void postWhenWebSecurityEnabledThenRespondsWithForbidden() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(post("/")) - .andExpect(status().isForbidden()); + this.mvc.perform(post("/")).andExpect(status().isForbidden()); } @Test public void putWhenWebSecurityEnabledThenRespondsWithForbidden() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(put("/")) - .andExpect(status().isForbidden()); + this.mvc.perform(put("/")).andExpect(status().isForbidden()); } @Test public void patchWhenWebSecurityEnabledThenRespondsWithForbidden() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(patch("/")) - .andExpect(status().isForbidden()); + this.mvc.perform(patch("/")).andExpect(status().isForbidden()); } @Test public void deleteWhenWebSecurityEnabledThenRespondsWithForbidden() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(delete("/")) - .andExpect(status().isForbidden()); + this.mvc.perform(delete("/")).andExpect(status().isForbidden()); } @Test public void invalidWhenWebSecurityEnabledThenRespondsWithForbidden() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(request("INVALID", URI.create("/"))) - .andExpect(status().isForbidden()); + this.mvc.perform(request("INVALID", URI.create("/"))).andExpect(status().isForbidden()); } @Test public void getWhenWebSecurityEnabledThenRespondsWithOk() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); + this.mvc.perform(get("/")).andExpect(status().isOk()); } @Test public void headWhenWebSecurityEnabledThenRespondsWithOk() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(head("/")) - .andExpect(status().isOk()); + this.mvc.perform(head("/")).andExpect(status().isOk()); } @Test public void traceWhenWebSecurityEnabledThenRespondsWithOk() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(request(HttpMethod.TRACE, "/")) - .andExpect(status().isOk()); + this.mvc.perform(request(HttpMethod.TRACE, "/")).andExpect(status().isOk()); } @Test public void optionsWhenWebSecurityEnabledThenRespondsWithOk() throws Exception { - this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) + this.spring + .register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class, BasicController.class) .autowire(); - - this.mvc.perform(options("/")) - .andExpect(status().isOk()); + this.mvc.perform(options("/")).andExpect(status().isOk()); } @Test public void enableWebSecurityWhenDefaultConfigurationThenCreatesRequestDataValueProcessor() { this.spring.register(CsrfAppliedDefaultConfig.class, AllowHttpMethodsFirewallConfig.class).autowire(); - assertThat(this.spring.getContext().getBean(RequestDataValueProcessor.class)).isNotNull(); } + @Test + public void postWhenCsrfDisabledThenRespondsWithOk() throws Exception { + this.spring.register(DisableCsrfConfig.class, BasicController.class).autowire(); + this.mvc.perform(post("/")).andExpect(status().isOk()); + } + + @Test + public void postWhenCsrfDisabledInLambdaThenRespondsWithOk() throws Exception { + this.spring.register(DisableCsrfInLambdaConfig.class, BasicController.class).autowire(); + this.mvc.perform(post("/")).andExpect(status().isOk()); + } + + // SEC-2498 + @Test + public void loginWhenCsrfDisabledThenRedirectsToPreviousPostRequest() throws Exception { + this.spring.register(DisableCsrfEnablesRequestCacheConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(post("/to-save")).andReturn(); + this.mvc.perform(post("/login").param("username", "user").param("password", "password") + .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/to-save")); + } + + @Test + public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception { + CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); + DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); + given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken); + given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); + this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn(); + this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) + .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) + .andExpect(redirectedUrl("/")); + verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) + .loadToken(any(HttpServletRequest.class)); + } + + @Test + public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception { + CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); + DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); + given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken); + given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken); + this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn(); + this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf()) + .session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/some-url")); + verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()) + .loadToken(any(HttpServletRequest.class)); + } + + // SEC-2422 + @Test + public void postWhenCsrfEnabledAndSessionIsExpiredThenRespondsWithForbidden() throws Exception { + this.spring.register(InvalidSessionUrlConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(post("/").param("_csrf", "abc")).andExpect(status().isFound()) + .andExpect(redirectedUrl("/error/sessionError")).andReturn(); + this.mvc.perform(post("/").session((MockHttpSession) mvcResult.getRequest().getSession())) + .andExpect(status().isForbidden()); + } + + @Test + public void requireCsrfProtectionMatcherWhenRequestDoesNotMatchThenRespondsWithOk() throws Exception { + this.spring.register(RequireCsrfProtectionMatcherConfig.class, BasicController.class).autowire(); + given(RequireCsrfProtectionMatcherConfig.MATCHER.matches(any())).willReturn(false); + this.mvc.perform(get("/")).andExpect(status().isOk()); + } + + @Test + public void requireCsrfProtectionMatcherWhenRequestMatchesThenRespondsWithForbidden() throws Exception { + RequireCsrfProtectionMatcherConfig.MATCHER = mock(RequestMatcher.class); + given(RequireCsrfProtectionMatcherConfig.MATCHER.matches(any())).willReturn(true); + this.spring.register(RequireCsrfProtectionMatcherConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isForbidden()); + } + + @Test + public void requireCsrfProtectionMatcherInLambdaWhenRequestDoesNotMatchThenRespondsWithOk() throws Exception { + RequireCsrfProtectionMatcherInLambdaConfig.MATCHER = mock(RequestMatcher.class); + this.spring.register(RequireCsrfProtectionMatcherInLambdaConfig.class, BasicController.class).autowire(); + given(RequireCsrfProtectionMatcherInLambdaConfig.MATCHER.matches(any())).willReturn(false); + this.mvc.perform(get("/")).andExpect(status().isOk()); + } + + @Test + public void requireCsrfProtectionMatcherInLambdaWhenRequestMatchesThenRespondsWithForbidden() throws Exception { + RequireCsrfProtectionMatcherInLambdaConfig.MATCHER = mock(RequestMatcher.class); + given(RequireCsrfProtectionMatcherInLambdaConfig.MATCHER.matches(any())).willReturn(true); + this.spring.register(RequireCsrfProtectionMatcherInLambdaConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isForbidden()); + } + + @Test + public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception { + CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); + given(CsrfTokenRepositoryConfig.REPO.loadToken(any())) + .willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); + this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isOk()); + verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class)); + } + + @Test + public void logoutWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { + CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); + this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); + this.mvc.perform(post("/logout").with(csrf()).with(user("user"))); + verify(CsrfTokenRepositoryConfig.REPO).saveToken(isNull(), any(HttpServletRequest.class), + any(HttpServletResponse.class)); + } + + @Test + public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { + CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); + DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); + given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken); + given(CsrfTokenRepositoryConfig.REPO.generateToken(any())).willReturn(csrfToken); + this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + // @formatter:on + this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + verify(CsrfTokenRepositoryConfig.REPO).saveToken(isNull(), any(HttpServletRequest.class), + any(HttpServletResponse.class)); + } + + @Test + public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception { + CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class); + given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any())) + .willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); + this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isOk()); + verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class)); + } + + @Test + public void getWhenCustomAccessDeniedHandlerThenHandlerIsUsed() throws Exception { + AccessDeniedHandlerConfig.DENIED_HANDLER = mock(AccessDeniedHandler.class); + this.spring.register(AccessDeniedHandlerConfig.class, BasicController.class).autowire(); + this.mvc.perform(post("/")).andExpect(status().isOk()); + verify(AccessDeniedHandlerConfig.DENIED_HANDLER).handle(any(HttpServletRequest.class), + any(HttpServletResponse.class), any()); + } + + @Test + public void loginWhenNoCsrfTokenThenRespondsWithForbidden() throws Exception { + this.spring.register(FormLoginConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password"); + this.mvc.perform(loginRequest) + .andExpect(status().isForbidden()) + .andExpect(unauthenticated()); + // @formatter:on + } + + @Test + public void logoutWhenNoCsrfTokenThenRespondsWithForbidden() throws Exception { + this.spring.register(FormLoginConfig.class).autowire(); + MockHttpServletRequestBuilder logoutRequest = post("/logout").with(user("username")); + // @formatter:off + this.mvc.perform(logoutRequest) + .andExpect(status().isForbidden()) + .andExpect(authenticated()); + // @formatter:on + } + + // SEC-2543 + @Test + public void logoutWhenCsrfEnabledAndGetRequestThenDoesNotLogout() throws Exception { + this.spring.register(FormLoginConfig.class).autowire(); + MockHttpServletRequestBuilder logoutRequest = get("/logout").with(user("username")); + this.mvc.perform(logoutRequest).andExpect(authenticated()); + } + + @Test + public void logoutWhenGetRequestAndGetEnabledForLogoutThenLogsOut() throws Exception { + this.spring.register(LogoutAllowsGetConfig.class).autowire(); + MockHttpServletRequestBuilder logoutRequest = get("/logout").with(user("username")); + this.mvc.perform(logoutRequest).andExpect(unauthenticated()); + } + + // SEC-2749 + @Test + public void configureWhenRequireCsrfProtectionMatcherNullThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullRequireCsrfProtectionMatcherConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getWhenDefaultCsrfTokenRepositoryThenDoesNotCreateSession() throws Exception { + this.spring.register(DefaultDoesNotCreateSession.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")).andReturn(); + assertThat(mvcResult.getRequest().getSession(false)).isNull(); + } + + @Test + public void getWhenNullAuthenticationStrategyThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullAuthenticationStrategy.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void csrfAuthenticationStrategyConfiguredThenStrategyUsed() throws Exception { + CsrfAuthenticationStrategyConfig.STRATEGY = mock(SessionAuthenticationStrategy.class); + this.spring.register(CsrfAuthenticationStrategyConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + // @formatter:on + this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + verify(CsrfAuthenticationStrategyConfig.STRATEGY, atLeastOnce()).onAuthentication(any(Authentication.class), + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + @Configuration static class AllowHttpMethodsFirewallConfig { + @Bean StrictHttpFirewall strictHttpFirewall() { StrictHttpFirewall result = new StrictHttpFirewall(); result.setUnsafeAllowAnyHttpMethod(true); return result; } + } @EnableWebSecurity @@ -185,14 +413,7 @@ public class CsrfConfigurerTests { @Override protected void configure(HttpSecurity http) { } - } - @Test - public void postWhenCsrfDisabledThenRespondsWithOk() throws Exception { - this.spring.register(DisableCsrfConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/")) - .andExpect(status().isOk()); } @EnableWebSecurity @@ -206,14 +427,7 @@ public class CsrfConfigurerTests { .disable(); // @formatter:on } - } - @Test - public void postWhenCsrfDisabledInLambdaThenRespondsWithOk() throws Exception { - this.spring.register(DisableCsrfInLambdaConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/")) - .andExpect(status().isOk()); } @EnableWebSecurity @@ -226,21 +440,7 @@ public class CsrfConfigurerTests { .csrf(AbstractHttpConfigurer::disable); // @formatter:on } - } - // SEC-2498 - @Test - public void loginWhenCsrfDisabledThenRedirectsToPreviousPostRequest() throws Exception { - this.spring.register(DisableCsrfEnablesRequestCacheConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/to-save")).andReturn(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session((MockHttpSession) mvcResult.getRequest().getSession())) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/to-save")); } @EnableWebSecurity @@ -268,52 +468,12 @@ public class CsrfConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception { - CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); - DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - when(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).thenReturn(csrfToken); - when(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).thenReturn(csrfToken); - this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/some-url")) - .andReturn(); - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .with(csrf()) - .session((MockHttpSession) mvcResult.getRequest().getSession())) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")); - - verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()).loadToken(any(HttpServletRequest.class)); - } - - @Test - public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception { - CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class); - DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - when(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).thenReturn(csrfToken); - when(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).thenReturn(csrfToken); - this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/some-url")) - .andReturn(); - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .with(csrf()) - .session((MockHttpSession) mvcResult.getRequest().getSession())) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/some-url")); - - verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce()).loadToken(any(HttpServletRequest.class)); } @EnableWebSecurity static class CsrfDisablesPostRequestFromRequestCacheConfig extends WebSecurityConfigurerAdapter { + static CsrfTokenRepository REPO; @Override @@ -338,26 +498,12 @@ public class CsrfConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - // SEC-2422 - @Test - public void postWhenCsrfEnabledAndSessionIsExpiredThenRespondsWithForbidden() throws Exception { - this.spring.register(InvalidSessionUrlConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/") - .param("_csrf", "abc")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/error/sessionError")) - .andReturn(); - - this.mvc.perform(post("/") - .session((MockHttpSession) mvcResult.getRequest().getSession())) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class InvalidSessionUrlConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -368,30 +514,12 @@ public class CsrfConfigurerTests { .invalidSessionUrl("/error/sessionError"); // @formatter:on } - } - @Test - public void requireCsrfProtectionMatcherWhenRequestDoesNotMatchThenRespondsWithOk() throws Exception { - this.spring.register(RequireCsrfProtectionMatcherConfig.class, BasicController.class).autowire(); - when(RequireCsrfProtectionMatcherConfig.MATCHER.matches(any())) - .thenReturn(false); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - } - - @Test - public void requireCsrfProtectionMatcherWhenRequestMatchesThenRespondsWithForbidden() throws Exception { - RequireCsrfProtectionMatcherConfig.MATCHER = mock(RequestMatcher.class); - when(RequireCsrfProtectionMatcherConfig.MATCHER.matches(any())).thenReturn(true); - this.spring.register(RequireCsrfProtectionMatcherConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RequireCsrfProtectionMatcherConfig extends WebSecurityConfigurerAdapter { + static RequestMatcher MATCHER; @Override @@ -402,87 +530,27 @@ public class CsrfConfigurerTests { .requireCsrfProtectionMatcher(MATCHER); // @formatter:on } - } - @Test - public void requireCsrfProtectionMatcherInLambdaWhenRequestDoesNotMatchThenRespondsWithOk() throws Exception { - RequireCsrfProtectionMatcherInLambdaConfig.MATCHER = mock(RequestMatcher.class); - this.spring.register(RequireCsrfProtectionMatcherInLambdaConfig.class, BasicController.class).autowire(); - when(RequireCsrfProtectionMatcherInLambdaConfig.MATCHER.matches(any())) - .thenReturn(false); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - } - - @Test - public void requireCsrfProtectionMatcherInLambdaWhenRequestMatchesThenRespondsWithForbidden() throws Exception { - RequireCsrfProtectionMatcherInLambdaConfig.MATCHER = mock(RequestMatcher.class); - when(RequireCsrfProtectionMatcherInLambdaConfig.MATCHER.matches(any())).thenReturn(true); - this.spring.register(RequireCsrfProtectionMatcherInLambdaConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RequireCsrfProtectionMatcherInLambdaConfig extends WebSecurityConfigurerAdapter { + static RequestMatcher MATCHER; @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .csrf(csrf -> csrf.requireCsrfProtectionMatcher(MATCHER)); + .csrf((csrf) -> csrf.requireCsrfProtectionMatcher(MATCHER)); // @formatter:on } - } - @Test - public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception { - CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); - when(CsrfTokenRepositoryConfig.REPO.loadToken(any())) - .thenReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); - this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class)); - } - - @Test - public void logoutWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { - CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); - this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf()) - .with(user("user"))); - - verify(CsrfTokenRepositoryConfig.REPO) - .saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - - @Test - public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { - CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); - DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); - when(CsrfTokenRepositoryConfig.REPO.loadToken(any())).thenReturn(csrfToken); - when(CsrfTokenRepositoryConfig.REPO.generateToken(any())).thenReturn(csrfToken); - this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")) - .andExpect(redirectedUrl("/")); - - verify(CsrfTokenRepositoryConfig.REPO) - .saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @EnableWebSecurity static class CsrfTokenRepositoryConfig extends WebSecurityConfigurerAdapter { + static CsrfTokenRepository REPO; @Override @@ -504,22 +572,12 @@ public class CsrfConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception { - CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class); - when(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any())) - .thenReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")); - this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class)); } @EnableWebSecurity static class CsrfTokenRepositoryInLambdaConfig extends WebSecurityConfigurerAdapter { + static CsrfTokenRepository REPO; @Override @@ -527,25 +585,15 @@ public class CsrfConfigurerTests { // @formatter:off http .formLogin(withDefaults()) - .csrf(csrf -> csrf.csrfTokenRepository(REPO)); + .csrf((csrf) -> csrf.csrfTokenRepository(REPO)); // @formatter:on } - } - @Test - public void getWhenCustomAccessDeniedHandlerThenHandlerIsUsed() throws Exception { - AccessDeniedHandlerConfig.DENIED_HANDLER = mock(AccessDeniedHandler.class); - this.spring.register(AccessDeniedHandlerConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/")) - .andExpect(status().isOk()); - - verify(AccessDeniedHandlerConfig.DENIED_HANDLER) - .handle(any(HttpServletRequest.class), any(HttpServletResponse.class), any()); } @EnableWebSecurity static class AccessDeniedHandlerConfig extends WebSecurityConfigurerAdapter { + static AccessDeniedHandler DENIED_HANDLER; @Override @@ -556,37 +604,7 @@ public class CsrfConfigurerTests { .accessDeniedHandler(DENIED_HANDLER); // @formatter:on } - } - @Test - public void loginWhenNoCsrfTokenThenRespondsWithForbidden() throws Exception { - this.spring.register(FormLoginConfig.class).autowire(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password")) - .andExpect(status().isForbidden()) - .andExpect(unauthenticated()); - } - - @Test - public void logoutWhenNoCsrfTokenThenRespondsWithForbidden() throws Exception { - this.spring.register(FormLoginConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(user("username"))) - .andExpect(status().isForbidden()) - .andExpect(authenticated()); - } - - // SEC-2543 - @Test - public void logoutWhenCsrfEnabledAndGetRequestThenDoesNotLogout() throws Exception { - this.spring.register(FormLoginConfig.class).autowire(); - - this.mvc.perform(get("/logout") - .with(user("username"))) - .andExpect(authenticated()); } @EnableWebSecurity @@ -599,15 +617,7 @@ public class CsrfConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void logoutWhenGetRequestAndGetEnabledForLogoutThenLogsOut() throws Exception { - this.spring.register(LogoutAllowsGetConfig.class).autowire(); - - this.mvc.perform(get("/logout") - .with(user("username"))) - .andExpect(unauthenticated()); } @EnableWebSecurity @@ -623,18 +633,12 @@ public class CsrfConfigurerTests { .logoutRequestMatcher(new AntPathRequestMatcher("/logout")); // @formatter:on } - } - // SEC-2749 - @Test - public void configureWhenRequireCsrfProtectionMatcherNullThenException() { - assertThatThrownBy(() -> this.spring.register(NullRequireCsrfProtectionMatcherConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class NullRequireCsrfProtectionMatcherConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -643,16 +647,7 @@ public class CsrfConfigurerTests { .requireCsrfProtectionMatcher(null); // @formatter:on } - } - @Test - public void getWhenDefaultCsrfTokenRepositoryThenDoesNotCreateSession() throws Exception { - this.spring.register(DefaultDoesNotCreateSession.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andReturn(); - - assertThat(mvcResult.getRequest().getSession(false)).isNull(); } @EnableWebSecurity @@ -679,10 +674,12 @@ public class CsrfConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } + } @EnableWebSecurity static class NullAuthenticationStrategy extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -691,17 +688,12 @@ public class CsrfConfigurerTests { .sessionAuthenticationStrategy(null); // @formatter:on } - } - @Test - public void getWhenNullAuthenticationStrategyThenException() { - assertThatThrownBy(() -> this.spring.register(NullAuthenticationStrategy.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class CsrfAuthenticationStrategyConfig extends WebSecurityConfigurerAdapter { + static SessionAuthenticationStrategy STRATEGY; @Override @@ -723,32 +715,20 @@ public class CsrfConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void csrfAuthenticationStrategyConfiguredThenStrategyUsed() throws Exception { - CsrfAuthenticationStrategyConfig.STRATEGY = mock(SessionAuthenticationStrategy.class); - - this.spring.register(CsrfAuthenticationStrategyConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")) - .andExpect(redirectedUrl("/")); - - verify(CsrfAuthenticationStrategyConfig.STRATEGY, atLeastOnce()) - .onAuthentication(any(Authentication.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @RestController static class BasicController { + @GetMapping("/") - public void rootGet() { + void rootGet() { } @PostMapping("/") - public void rootPost() { + void rootPost() { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java index 3802618cf3..f16d989226 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.io.IOException; import java.util.List; import java.util.stream.Collectors; + import javax.servlet.Filter; import javax.servlet.ServletException; @@ -76,61 +78,30 @@ public class DefaultFiltersTests { assertThat(this.spring.getContext().getBean(FilterChainProxy.class)).isNotNull(); } - @EnableWebSecurity - static class FilterChainProxyBuilderMissingConfig { - @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { - // @formatter:off - auth - .inMemoryAuthentication() - .withUser("user").password("password").roles("USER"); - // @formatter:on - } - } - @Test public void nullWebInvocationPrivilegeEvaluator() { this.spring.register(NullWebInvocationPrivilegeEvaluatorConfig.class, UserDetailsServiceConfig.class); - List filterChains = this.spring.getContext().getBean(FilterChainProxy.class).getFilterChains(); + List filterChains = this.spring.getContext().getBean(FilterChainProxy.class) + .getFilterChains(); assertThat(filterChains.size()).isEqualTo(1); DefaultSecurityFilterChain filterChain = (DefaultSecurityFilterChain) filterChains.get(0); assertThat(filterChain.getRequestMatcher()).isInstanceOf(AnyRequestMatcher.class); assertThat(filterChain.getFilters().size()).isEqualTo(1); long filter = filterChain.getFilters().stream() - .filter(it -> it instanceof UsernamePasswordAuthenticationFilter).count(); + .filter((it) -> it instanceof UsernamePasswordAuthenticationFilter).count(); assertThat(filter).isEqualTo(1); } - @Configuration - static class UserDetailsServiceConfig { - @Bean - public UserDetailsService userDetailsService() { - return new InMemoryUserDetailsManager(PasswordEncodedUser.user(), PasswordEncodedUser.admin()); - } - } - - @EnableWebSecurity - static class NullWebInvocationPrivilegeEvaluatorConfig extends WebSecurityConfigurerAdapter { - NullWebInvocationPrivilegeEvaluatorConfig() { - super(true); - } - - protected void configure(HttpSecurity http) throws Exception { - http.formLogin(); - } - } - @Test public void filterChainProxyBuilderIgnoringResources() { this.spring.register(FilterChainProxyBuilderIgnoringConfig.class, UserDetailsServiceConfig.class); - List filterChains = this.spring.getContext().getBean(FilterChainProxy.class).getFilterChains(); + List filterChains = this.spring.getContext().getBean(FilterChainProxy.class) + .getFilterChains(); assertThat(filterChains.size()).isEqualTo(2); DefaultSecurityFilterChain firstFilter = (DefaultSecurityFilterChain) filterChains.get(0); DefaultSecurityFilterChain secondFilter = (DefaultSecurityFilterChain) filterChains.get(1); - assertThat(firstFilter.getFilters().isEmpty()).isEqualTo(true); assertThat(secondFilter.getRequestMatcher()).isInstanceOf(AnyRequestMatcher.class); - List> classes = secondFilter.getFilters().stream().map(Filter::getClass) .collect(Collectors.toList()); assertThat(classes.contains(WebAsyncManagerIntegrationFilter.class)).isTrue(); @@ -146,8 +117,61 @@ public class DefaultFiltersTests { assertThat(classes.contains(FilterSecurityInterceptor.class)).isTrue(); } + @Test + public void defaultFiltersPermitAll() throws IOException, ServletException { + this.spring.register(DefaultFiltersConfigPermitAll.class, UserDetailsServiceConfig.class); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockHttpServletRequest request = new MockHttpServletRequest("POST", ""); + request.setServletPath("/logout"); + CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); + new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request, response); + request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); + this.spring.getContext().getBean("springSecurityFilterChain", Filter.class).doFilter(request, response, + new MockFilterChain()); + assertThat(response.getRedirectedUrl()).isEqualTo("/login?logout"); + } + + @EnableWebSecurity + static class FilterChainProxyBuilderMissingConfig { + + @Autowired + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("user").password("password").roles("USER"); + // @formatter:on + } + + } + + @Configuration + static class UserDetailsServiceConfig { + + @Bean + UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager(PasswordEncodedUser.user(), PasswordEncodedUser.admin()); + } + + } + + @EnableWebSecurity + static class NullWebInvocationPrivilegeEvaluatorConfig extends WebSecurityConfigurerAdapter { + + NullWebInvocationPrivilegeEvaluatorConfig() { + super(true); + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + http.formLogin(); + } + + } + @EnableWebSecurity static class FilterChainProxyBuilderIgnoringConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(WebSecurity web) { // @formatter:off @@ -165,27 +189,16 @@ public class DefaultFiltersTests { .anyRequest().hasRole("USER"); // @formatter:on } - } - @Test - public void defaultFiltersPermitAll() throws IOException, ServletException { - this.spring.register(DefaultFiltersConfigPermitAll.class, UserDetailsServiceConfig.class); - MockHttpServletResponse response = new MockHttpServletResponse(); - MockHttpServletRequest request = new MockHttpServletRequest("POST", ""); - request.setServletPath("/logout"); - - CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); - new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request, response); - request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); - - this.spring.getContext().getBean("springSecurityFilterChain", Filter.class).doFilter(request, response, - new MockFilterChain()); - assertThat(response.getRedirectedUrl()).isEqualTo("/login?logout"); } @EnableWebSecurity static class DefaultFiltersConfigPermitAll extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java index 0bd6e0417f..5b18272caa 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpSession; @@ -39,6 +40,7 @@ import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -67,9 +69,7 @@ public class DefaultLoginPageConfigurerTests { @Test public void getWhenFormLoginEnabledThenRedirectsToLoginPage() throws Exception { this.spring.register(DefaultLoginPageConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost/login")); + this.mvc.perform(get("/")).andExpect(redirectedUrl("http://localhost/login")); } @Test @@ -77,48 +77,43 @@ public class DefaultLoginPageConfigurerTests { this.spring.register(DefaultLoginPageConfig.class).autowire(); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); - - this.mvc.perform(get("/login") - .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string( - "\n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Please sign in\n" + - " \n" + - " \n" + - " \n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - "" - )); + // @formatter:off + this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) + .andExpect(content().string("\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + "")); + // @formatter:on } @Test public void loginWhenNoCredentialsThenRedirectedToLoginPageWithError() throws Exception { this.spring.register(DefaultLoginPageConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf())) - .andExpect(redirectedUrl("/login?error")); + this.mvc.perform(post("/login").with(csrf())).andExpect(redirectedUrl("/login?error")); } @Test @@ -126,55 +121,50 @@ public class DefaultLoginPageConfigurerTests { this.spring.register(DefaultLoginPageConfig.class).autowire(); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); - - MvcResult mvcResult = this.mvc.perform(post("/login") - .with(csrf())) - .andReturn(); - - this.mvc.perform(get("/login?error") - .session((MockHttpSession) mvcResult.getRequest().getSession()) + MvcResult mvcResult = this.mvc.perform(post("/login").with(csrf())).andReturn(); + // @formatter:off + this.mvc.perform(get("/login?error").session((MockHttpSession) mvcResult.getRequest().getSession()) .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string( - "\n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Please sign in\n" + - " \n" + - " \n" + - " \n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "
    Bad credentials

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - "" - )); + .andExpect(content().string("\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "
    Bad credentials

    \n" + + " \n" + + " \n" + + "

    \n" + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + "")); + // @formatter:on } @Test public void loginWhenValidCredentialsThenRedirectsToDefaultSuccessPage() throws Exception { this.spring.register(DefaultLoginPageConfig.class).autowire(); - - this.mvc.perform(post("/login") + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .with(csrf()) .param("username", "user") - .param("password", "password")) - .andExpect(redirectedUrl("/")); + .param("password", "password"); + // @formatter:on + this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); } @Test @@ -182,43 +172,212 @@ public class DefaultLoginPageConfigurerTests { this.spring.register(DefaultLoginPageConfig.class).autowire(); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); + // @formatter:off + this.mvc.perform(get("/login?logout").sessionAttr(csrfAttributeName, csrfToken)) + .andExpect(content().string("\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "
    You have been signed out

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + "")); + // @formatter:on + } - this.mvc.perform(get("/login?logout") - .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string( - "\n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Please sign in\n" + - " \n" + - " \n" + - " \n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "
    You have been signed out

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - "" - )); + @Test + public void loginPageWhenLoggedOutAndCustomLogoutSuccessHandlerThenDoesNotRenderLoginPage() throws Exception { + this.spring.register(DefaultLoginPageCustomLogoutSuccessHandlerConfig.class).autowire(); + this.mvc.perform(get("/login?logout")).andExpect(content().string("")); + } + + @Test + public void loginPageWhenLoggedOutAndCustomLogoutSuccessUrlThenDoesNotRenderLoginPage() throws Exception { + this.spring.register(DefaultLoginPageCustomLogoutSuccessUrlConfig.class).autowire(); + this.mvc.perform(get("/login?logout")).andExpect(content().string("")); + } + + @Test + public void loginPageWhenRememberConfigureThenDefaultLoginPageWithRememberMeCheckbox() throws Exception { + this.spring.register(DefaultLoginPageWithRememberMeConfig.class).autowire(); + CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); + String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); + // @formatter:off + this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) + .andExpect(content().string("\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    Remember me on this computer.

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + "")); + // @formatter:on + } + + @Test + public void loginPageWhenOpenIdLoginConfiguredThenOpedIdLoginPage() throws Exception { + this.spring.register(DefaultLoginPageWithOpenIDConfig.class).autowire(); + CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); + String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); + // @formatter:off + this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) + .andExpect(content().string("\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + "")); + // @formatter:on + } + + @Test + public void loginPageWhenOpenIdLoginAndFormLoginAndRememberMeConfiguredThenOpedIdLoginPage() throws Exception { + this.spring.register(DefaultLoginPageWithFormLoginOpenIDRememberMeConfig.class).autowire(); + CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); + String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); + // @formatter:off + this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) + .andExpect(content().string("\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    Remember me on this computer.

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + " \n" + + " \n" + + "

    \n" + + "

    Remember me on this computer.

    \n" + + "\n" + + " \n" + + "
    \n" + + "
    \n" + + "")); + // @formatter:on + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnDefaultLoginPageGeneratingFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(DefaultLoginPageGeneratingFilter.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnUsernamePasswordAuthenticationFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor) + .postProcess(any(UsernamePasswordAuthenticationFilter.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnLoginUrlAuthenticationEntryPoint() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(LoginUrlAuthenticationEntryPoint.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnExceptionTranslationFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ExceptionTranslationFilter.class)); + } + + @Test + public void configureWhenAuthenticationEntryPointThenNoDefaultLoginPageGeneratingFilter() { + this.spring.register(DefaultLoginWithCustomAuthenticationEntryPointConfig.class).autowire(); + FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); + assertThat(filterChain.getFilterChains().get(0).getFilters().stream() + .filter((filter) -> filter.getClass().isAssignableFrom(DefaultLoginPageGeneratingFilter.class)).count()) + .isZero(); } @EnableWebSecurity static class DefaultLoginPageConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -238,19 +397,12 @@ public class DefaultLoginPageConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } + } - @Test - public void loginPageWhenLoggedOutAndCustomLogoutSuccessHandlerThenDoesNotRenderLoginPage() throws Exception { - this.spring.register(DefaultLoginPageCustomLogoutSuccessHandlerConfig.class).autowire(); - - this.mvc.perform(get("/login?logout")) - .andExpect(content().string("")); - } - - @EnableWebSecurity static class DefaultLoginPageCustomLogoutSuccessHandlerConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -264,18 +416,12 @@ public class DefaultLoginPageConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void loginPageWhenLoggedOutAndCustomLogoutSuccessUrlThenDoesNotRenderLoginPage() throws Exception { - this.spring.register(DefaultLoginPageCustomLogoutSuccessUrlConfig.class).autowire(); - - this.mvc.perform(get("/login?logout")) - .andExpect(content().string("")); } @EnableWebSecurity static class DefaultLoginPageCustomLogoutSuccessUrlConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -289,51 +435,12 @@ public class DefaultLoginPageConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void loginPageWhenRememberConfigureThenDefaultLoginPageWithRememberMeCheckbox() throws Exception { - this.spring.register(DefaultLoginPageWithRememberMeConfig.class).autowire(); - CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); - String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); - - this.mvc.perform(get("/login") - .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string( - "\n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Please sign in\n" + - " \n" + - " \n" + - " \n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    Remember me on this computer.

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - "" - )); } @EnableWebSecurity static class DefaultLoginPageWithRememberMeConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -346,47 +453,12 @@ public class DefaultLoginPageConfigurerTests { .rememberMe(); // @formatter:on } - } - @Test - public void loginPageWhenOpenIdLoginConfiguredThenOpedIdLoginPage() throws Exception { - this.spring.register(DefaultLoginPageWithOpenIDConfig.class).autowire(); - - CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); - String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); - - this.mvc.perform(get("/login") - .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string( - "\n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Please sign in\n" + - " \n" + - " \n" + - " \n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - "" - )); } @EnableWebSecurity static class DefaultLoginPageWithOpenIDConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -397,61 +469,12 @@ public class DefaultLoginPageConfigurerTests { .openidLogin(); // @formatter:on } - } - @Test - public void loginPageWhenOpenIdLoginAndFormLoginAndRememberMeConfiguredThenOpedIdLoginPage() throws Exception { - this.spring.register(DefaultLoginPageWithFormLoginOpenIDRememberMeConfig.class).autowire(); - CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); - String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); - - this.mvc.perform(get("/login") - .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string( - "\n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Please sign in\n" + - " \n" + - " \n" + - " \n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    Remember me on this computer.

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - " \n" + - "

    \n" + - " \n" + - " \n" + - "

    \n" + - "

    Remember me on this computer.

    \n" + - "\n" + - " \n" + - "
    \n" + - "
    \n" + - "" - )); } @EnableWebSecurity static class DefaultLoginPageWithFormLoginOpenIDRememberMeConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -466,21 +489,12 @@ public class DefaultLoginPageConfigurerTests { .openidLogin(); // @formatter:on } - } - @Test - public void configureWhenAuthenticationEntryPointThenNoDefaultLoginPageGeneratingFilter() { - this.spring.register(DefaultLoginWithCustomAuthenticationEntryPointConfig.class).autowire(); - - FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); - assertThat(filterChain.getFilterChains().get(0).getFilters().stream() - .filter(filter -> filter.getClass().isAssignableFrom(DefaultLoginPageGeneratingFilter.class)) - .count()) - .isZero(); } @EnableWebSecurity static class DefaultLoginWithCustomAuthenticationEntryPointConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -494,46 +508,12 @@ public class DefaultLoginPageConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnDefaultLoginPageGeneratingFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(DefaultLoginPageGeneratingFilter.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnUsernamePasswordAuthenticationFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(UsernamePasswordAuthenticationFilter.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnLoginUrlAuthenticationEntryPoint() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(LoginUrlAuthenticationEntryPoint.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnExceptionTranslationFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ExceptionTranslationFilter.class)); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor; @Override @@ -550,12 +530,16 @@ public class DefaultLoginPageConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerAccessDeniedHandlerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerAccessDeniedHandlerTests.java index ef1084cdbd..1022b268e5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerAccessDeniedHandlerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerAccessDeniedHandlerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; @@ -43,6 +44,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class ExceptionHandlingConfigurerAccessDeniedHandlerTests { + @Autowired MockMvc mvc; @@ -51,22 +53,33 @@ public class ExceptionHandlingConfigurerAccessDeniedHandlerTests { @Test @WithMockUser(roles = "ANYTHING") - public void getWhenAccessDeniedOverriddenThenCustomizesResponseByRequest() - throws Exception { + public void getWhenAccessDeniedOverriddenThenCustomizesResponseByRequest() throws Exception { this.spring.register(RequestMatcherBasedAccessDeniedHandlerConfig.class).autowire(); + this.mvc.perform(get("/hello")).andExpect(status().isIAmATeapot()); + this.mvc.perform(get("/goodbye")).andExpect(status().isForbidden()); + } - this.mvc.perform(get("/hello")) - .andExpect(status().isIAmATeapot()); + @Test + @WithMockUser(roles = "ANYTHING") + public void getWhenAccessDeniedOverriddenInLambdaThenCustomizesResponseByRequest() throws Exception { + this.spring.register(RequestMatcherBasedAccessDeniedHandlerInLambdaConfig.class).autowire(); + this.mvc.perform(get("/hello")).andExpect(status().isIAmATeapot()); + this.mvc.perform(get("/goodbye")).andExpect(status().isForbidden()); + } - this.mvc.perform(get("/goodbye")) - .andExpect(status().isForbidden()); + @Test + @WithMockUser(roles = "ANYTHING") + public void getWhenAccessDeniedOverriddenByOnlyOneHandlerThenAllRequestsUseThatHandler() throws Exception { + this.spring.register(SingleRequestMatcherAccessDeniedHandlerConfig.class).autowire(); + this.mvc.perform(get("/hello")).andExpect(status().isIAmATeapot()); + this.mvc.perform(get("/goodbye")).andExpect(status().isIAmATeapot()); } @EnableWebSecurity static class RequestMatcherBasedAccessDeniedHandlerConfig extends WebSecurityConfigurerAdapter { - AccessDeniedHandler teapotDeniedHandler = - (request, response, exception) -> - response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); + + AccessDeniedHandler teapotDeniedHandler = (request, response, exception) -> response + .setStatus(HttpStatus.I_AM_A_TEAPOT.value()); @Override protected void configure(HttpSecurity http) throws Exception { @@ -84,36 +97,24 @@ public class ExceptionHandlingConfigurerAccessDeniedHandlerTests { AnyRequestMatcher.INSTANCE); // @formatter:on } - } - @Test - @WithMockUser(roles = "ANYTHING") - public void getWhenAccessDeniedOverriddenInLambdaThenCustomizesResponseByRequest() - throws Exception { - this.spring.register(RequestMatcherBasedAccessDeniedHandlerInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/hello")) - .andExpect(status().isIAmATeapot()); - - this.mvc.perform(get("/goodbye")) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RequestMatcherBasedAccessDeniedHandlerInLambdaConfig extends WebSecurityConfigurerAdapter { - AccessDeniedHandler teapotDeniedHandler = - (request, response, exception) -> - response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); + + AccessDeniedHandler teapotDeniedHandler = (request, response, exception) -> response + .setStatus(HttpStatus.I_AM_A_TEAPOT.value()); @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().denyAll() ) - .exceptionHandling(exceptionHandling -> + .exceptionHandling((exceptionHandling) -> exceptionHandling .defaultAccessDeniedHandlerFor( this.teapotDeniedHandler, @@ -126,26 +127,14 @@ public class ExceptionHandlingConfigurerAccessDeniedHandlerTests { ); // @formatter:on } - } - @Test - @WithMockUser(roles = "ANYTHING") - public void getWhenAccessDeniedOverriddenByOnlyOneHandlerThenAllRequestsUseThatHandler() - throws Exception { - this.spring.register(SingleRequestMatcherAccessDeniedHandlerConfig.class).autowire(); - - this.mvc.perform(get("/hello")) - .andExpect(status().isIAmATeapot()); - - this.mvc.perform(get("/goodbye")) - .andExpect(status().isIAmATeapot()); } @EnableWebSecurity static class SingleRequestMatcherAccessDeniedHandlerConfig extends WebSecurityConfigurerAdapter { - AccessDeniedHandler teapotDeniedHandler = - (request, response, exception) -> - response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); + + AccessDeniedHandler teapotDeniedHandler = (request, response, exception) -> response + .setStatus(HttpStatus.I_AM_A_TEAPOT.value()); @Override protected void configure(HttpSecurity http) throws Exception { @@ -160,5 +149,7 @@ public class ExceptionHandlingConfigurerAccessDeniedHandlerTests { new AntPathRequestMatcher("/hello/**")); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerTests.java index bd452d0090..8837354084 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExceptionHandlingConfigurerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -66,13 +67,154 @@ public class ExceptionHandlingConfigurerTests { @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnExceptionTranslationFilter() { this.spring.register(ObjectPostProcessorConfig.class, DefaultSecurityConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ExceptionTranslationFilter.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ExceptionTranslationFilter.class)); + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsApplicationXhtmlXmlThenRespondsWith302() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.APPLICATION_XHTML_XML)) + .andExpect(status().isFound()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsImageGifThenRespondsWith302() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.IMAGE_GIF)).andExpect(status().isFound()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsImageJpgThenRespondsWith302() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.IMAGE_JPEG)).andExpect(status().isFound()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsImagePngThenRespondsWith302() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.IMAGE_PNG)).andExpect(status().isFound()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsTextHtmlThenRespondsWith302() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML)).andExpect(status().isFound()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsTextPlainThenRespondsWith302() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.TEXT_PLAIN)).andExpect(status().isFound()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsApplicationAtomXmlThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.APPLICATION_ATOM_XML)) + .andExpect(status().isUnauthorized()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsApplicationFormUrlEncodedThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.APPLICATION_FORM_URLENCODED)) + .andExpect(status().isUnauthorized()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsApplicationJsonThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON)) + .andExpect(status().isUnauthorized()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsApplicationOctetStreamThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.APPLICATION_OCTET_STREAM)) + .andExpect(status().isUnauthorized()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsMultipartFormDataThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.MULTIPART_FORM_DATA)) + .andExpect(status().isUnauthorized()); + } + + // SEC-2199 + @Test + public void getWhenAcceptHeaderIsTextXmlThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.TEXT_XML)).andExpect(status().isUnauthorized()); + } + + // gh-4831 + @Test + public void getWhenAcceptIsAnyThenRespondsWith401() throws Exception { + this.spring.register(DefaultSecurityConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, MediaType.ALL)).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenAcceptIsChromeThenRespondsWith302() throws Exception { + this.spring.register(DefaultSecurityConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8")) + .andExpect(status().isFound()); + } + + @Test + public void getWhenAcceptIsTextPlainAndXRequestedWithIsXHRThenRespondsWith401() throws Exception { + this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); + this.mvc.perform(get("/").header("Accept", MediaType.TEXT_PLAIN).header("X-Requested-With", "XMLHttpRequest")) + .andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenCustomContentNegotiationStrategyThenStrategyIsUsed() throws Exception { + this.spring.register(OverrideContentNegotiationStrategySharedObjectConfig.class, DefaultSecurityConfig.class) + .autowire(); + this.mvc.perform(get("/")); + verify(OverrideContentNegotiationStrategySharedObjectConfig.CNS, atLeastOnce()) + .resolveMediaTypes(any(NativeWebRequest.class)); + } + + @Test + public void getWhenUsingDefaultsAndUnauthenticatedThenRedirectsToLogin() throws Exception { + this.spring.register(DefaultHttpConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, "bogus/type")) + .andExpect(redirectedUrl("http://localhost/login")); + } + + @Test + public void getWhenDeclaringHttpBasicBeforeFormLoginThenRespondsWith401() throws Exception { + this.spring.register(BasicAuthenticationEntryPointBeforeFormLoginConfig.class).autowire(); + this.mvc.perform(get("/").header(HttpHeaders.ACCEPT, "bogus/type")).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenInvokingExceptionHandlingTwiceThenOriginalEntryPointUsed() throws Exception { + this.spring.register(InvokeTwiceDoesNotOverrideConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(InvokeTwiceDoesNotOverrideConfig.AEP).commence(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(AuthenticationException.class)); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -87,189 +229,41 @@ public class ExceptionHandlingConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsApplicationXhtmlXmlThenRespondsWith302() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_XHTML_XML)) - .andExpect(status().isFound()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsImageGifThenRespondsWith302() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.IMAGE_GIF)) - .andExpect(status().isFound()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsImageJpgThenRespondsWith302() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.IMAGE_JPEG)) - .andExpect(status().isFound()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsImagePngThenRespondsWith302() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.IMAGE_PNG)) - .andExpect(status().isFound()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsTextHtmlThenRespondsWith302() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML)) - .andExpect(status().isFound()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsTextPlainThenRespondsWith302() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.TEXT_PLAIN)) - .andExpect(status().isFound()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsApplicationAtomXmlThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_ATOM_XML)) - .andExpect(status().isUnauthorized()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsApplicationFormUrlEncodedThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_FORM_URLENCODED)) - .andExpect(status().isUnauthorized()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsApplicationJsonThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON)) - .andExpect(status().isUnauthorized()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsApplicationOctetStreamThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_OCTET_STREAM)) - .andExpect(status().isUnauthorized()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsMultipartFormDataThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.MULTIPART_FORM_DATA)) - .andExpect(status().isUnauthorized()); - } - - // SEC-2199 - @Test - public void getWhenAcceptHeaderIsTextXmlThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.TEXT_XML)) - .andExpect(status().isUnauthorized()); - } - - // gh-4831 - @Test - public void getWhenAcceptIsAnyThenRespondsWith401() throws Exception { - this.spring.register(DefaultSecurityConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, MediaType.ALL)) - .andExpect(status().isUnauthorized()); - } - - @Test - public void getWhenAcceptIsChromeThenRespondsWith302() throws Exception { - this.spring.register(DefaultSecurityConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, - "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8")) - .andExpect(status().isFound()); - } - - @Test - public void getWhenAcceptIsTextPlainAndXRequestedWithIsXHRThenRespondsWith401() throws Exception { - this.spring.register(HttpBasicAndFormLoginEntryPointsConfig.class).autowire(); - - this.mvc.perform(get("/") - .header("Accept", MediaType.TEXT_PLAIN) - .header("X-Requested-With", "XMLHttpRequest")) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class DefaultSecurityConfig { @Bean - public InMemoryUserDetailsManager userDetailsManager() { + InMemoryUserDetailsManager userDetailsManager() { + // @formatter:off return new InMemoryUserDetailsManager(User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build() ); + // @formatter:off } } - @EnableWebSecurity static class HttpBasicAndFormLoginEntryPointsConfig extends WebSecurityConfigurerAdapter { - @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); } - @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -282,53 +276,29 @@ public class ExceptionHandlingConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void getWhenCustomContentNegotiationStrategyThenStrategyIsUsed() throws Exception { - this.spring.register(OverrideContentNegotiationStrategySharedObjectConfig.class, - DefaultSecurityConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(OverrideContentNegotiationStrategySharedObjectConfig.CNS, atLeastOnce()) - .resolveMediaTypes(any(NativeWebRequest.class)); } @EnableWebSecurity static class OverrideContentNegotiationStrategySharedObjectConfig extends WebSecurityConfigurerAdapter { + static ContentNegotiationStrategy CNS = mock(ContentNegotiationStrategy.class); @Bean - public static ContentNegotiationStrategy cns() { + static ContentNegotiationStrategy cns() { return CNS; } - } - @Test - public void getWhenUsingDefaultsAndUnauthenticatedThenRedirectsToLogin() throws Exception { - this.spring.register(DefaultHttpConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, "bogus/type")) - .andExpect(redirectedUrl("http://localhost/login")); } @EnableWebSecurity static class DefaultHttpConfig extends WebSecurityConfigurerAdapter { - } - @Test - public void getWhenDeclaringHttpBasicBeforeFormLoginThenRespondsWith401() throws Exception { - this.spring.register(BasicAuthenticationEntryPointBeforeFormLoginConfig.class).autowire(); - - this.mvc.perform(get("/") - .header(HttpHeaders.ACCEPT, "bogus/type")) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class BasicAuthenticationEntryPointBeforeFormLoginConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -341,35 +311,27 @@ public class ExceptionHandlingConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void getWhenInvokingExceptionHandlingTwiceThenOriginalEntryPointUsed() throws Exception { - this.spring.register(InvokeTwiceDoesNotOverrideConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(InvokeTwiceDoesNotOverrideConfig.AEP) - .commence(any(HttpServletRequest.class), - any(HttpServletResponse.class), any(AuthenticationException.class)); } @EnableWebSecurity static class InvokeTwiceDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + static AuthenticationEntryPoint AEP = mock(AuthenticationEntryPoint.class); @Override protected void configure(HttpSecurity http) throws Exception { - // @formatter:on + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() .and() .exceptionHandling() - .authenticationEntryPoint(AEP) - .and() + .authenticationEntryPoint(AEP).and() .exceptionHandling(); - // @formatter:off + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurerTests.java index 4f7844ca69..6ee24188f8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ExpressionUrlAuthorizationConfigurerTests.java @@ -16,8 +16,12 @@ package org.springframework.security.config.annotation.web.configurers; +import java.io.Serializable; +import java.util.Collections; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; @@ -47,15 +51,13 @@ import org.springframework.security.web.access.expression.WebExpressionVoter; import org.springframework.security.web.access.expression.WebSecurityExpressionRoot; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RestController; -import java.io.Serializable; -import java.util.Collections; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -83,14 +85,426 @@ public class ExpressionUrlAuthorizationConfigurerTests { @Test public void configureWhenHasRoleStartingWithStringRoleThenException() { - assertThatThrownBy(() -> this.spring.register(HasRoleStartingWithRoleConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("role should not start with 'ROLE_' since it is automatically inserted. Got 'ROLE_USER'"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(HasRoleStartingWithRoleConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class).withMessageContaining( + "role should not start with 'ROLE_' since it is automatically inserted. Got 'ROLE_USER'"); + } + + @Test + public void configureWhenNoCustomAccessDecisionManagerThenUsesAffirmativeBased() { + this.spring.register(NoSpecificAccessDecisionManagerConfig.class).autowire(); + verify(NoSpecificAccessDecisionManagerConfig.objectPostProcessor).postProcess(any(AffirmativeBased.class)); + } + + @Test + public void configureWhenAuthorizedRequestsAndNoRequestsThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NoRequestsConfig.class).autowire()).withMessageContaining( + "At least one mapping is required (i.e. authorizeRequests().anyRequest().authenticated())"); + } + + @Test + public void configureWhenAnyRequestIncompleteMappingThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(IncompleteMappingConfig.class).autowire()) + .withMessageContaining("An incomplete mapping was found for "); + } + + @Test + public void getWhenHasAnyAuthorityRoleUserConfiguredAndAuthorityIsRoleUserThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserAnyAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_USER"))); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenHasAnyAuthorityRoleUserConfiguredAndAuthorityIsRoleAdminThenRespondsWithForbidden() + throws Exception { + this.spring.register(RoleUserAnyAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithAdmin = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_ADMIN"))); + // @formatter:on + this.mvc.perform(requestWithAdmin).andExpect(status().isForbidden()); + } + + @Test + public void getWhenHasAnyAuthorityRoleUserConfiguredAndNoAuthorityThenRespondsWithUnauthorized() throws Exception { + this.spring.register(RoleUserAnyAuthorityConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenHasAuthorityRoleUserConfiguredAndAuthorityIsRoleUserThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_USER"))); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenHasAuthorityRoleUserConfiguredAndAuthorityIsRoleAdminThenRespondsWithForbidden() + throws Exception { + this.spring.register(RoleUserAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithAdmin = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_ADMIN"))); + // @formatter:on + this.mvc.perform(requestWithAdmin).andExpect(status().isForbidden()); + } + + @Test + public void getWhenHasAuthorityRoleUserConfiguredAndNoAuthorityThenRespondsWithUnauthorized() throws Exception { + this.spring.register(RoleUserAuthorityConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenAuthorityRoleUserOrAdminRequiredAndAuthorityIsRoleUserThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_USER"))); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenAuthorityRoleUserOrAdminRequiredAndAuthorityIsRoleAdminThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_ADMIN"))); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenAuthorityRoleUserOrAdminRequiredAndAuthorityIsRoleOtherThenRespondsWithForbidden() + throws Exception { + this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .authorities(new SimpleGrantedAuthority("ROLE_OTHER"))); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + @Test + public void getWhenAuthorityRoleUserOrAdminAuthRequiredAndNoUserThenRespondsWithUnauthorized() throws Exception { + this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenHasAnyRoleUserConfiguredAndRoleIsUserThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .roles("USER")); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenHasAnyRoleUserConfiguredAndRoleIsAdminThenRespondsWithForbidden() throws Exception { + this.spring.register(RoleUserConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithAdmin = get("/") + .with(user("user") + .roles("ADMIN")); + // @formatter:on + this.mvc.perform(requestWithAdmin).andExpect(status().isForbidden()); + } + + @Test + public void getWhenRoleUserOrAdminConfiguredAndRoleIsUserThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserOrAdminConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = get("/") + .with(user("user") + .roles("USER")); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenRoleUserOrAdminConfiguredAndRoleIsAdminThenRespondsWithOk() throws Exception { + this.spring.register(RoleUserOrAdminConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithAdmin = get("/") + .with(user("user") + .roles("ADMIN")); + // @formatter:on + this.mvc.perform(requestWithAdmin).andExpect(status().isOk()); + } + + @Test + public void getWhenRoleUserOrAdminConfiguredAndRoleIsOtherThenRespondsWithForbidden() throws Exception { + this.spring.register(RoleUserOrAdminConfig.class, BasicController.class).autowire(); + // + MockHttpServletRequestBuilder requestWithRoleOther = get("/").with(user("user").roles("OTHER")); + // + this.mvc.perform(requestWithRoleOther).andExpect(status().isForbidden()); + } + + @Test + public void getWhenHasIpAddressConfiguredAndIpAddressMatchesThenRespondsWithOk() throws Exception { + this.spring.register(HasIpAddressConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/").with((request) -> { + request.setRemoteAddr("192.168.1.0"); + return request; + })).andExpect(status().isOk()); + } + + @Test + public void getWhenHasIpAddressConfiguredAndIpAddressDoesNotMatchThenRespondsWithUnauthorized() throws Exception { + this.spring.register(HasIpAddressConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/").with((request) -> { + request.setRemoteAddr("192.168.1.1"); + return request; + })).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenAnonymousConfiguredAndAnonymousUserThenRespondsWithOk() throws Exception { + this.spring.register(AnonymousConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isOk()); + } + + @Test + public void getWhenAnonymousConfiguredAndLoggedInUserThenRespondsWithForbidden() throws Exception { + this.spring.register(AnonymousConfig.class, BasicController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/").with(user("user")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + @Test + public void getWhenRememberMeConfiguredAndNoUserThenRespondsWithUnauthorized() throws Exception { + this.spring.register(RememberMeConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenRememberMeConfiguredAndRememberMeTokenThenRespondsWithOk() throws Exception { + this.spring.register(RememberMeConfig.class, BasicController.class).autowire(); + RememberMeAuthenticationToken rememberme = new RememberMeAuthenticationToken("key", "user", + AuthorityUtils.createAuthorityList("ROLE_USER")); + MockHttpServletRequestBuilder requestWithRememberme = get("/").with(authentication(rememberme)); + this.mvc.perform(requestWithRememberme).andExpect(status().isOk()); + } + + @Test + public void getWhenDenyAllConfiguredAndNoUserThenRespondsWithUnauthorized() throws Exception { + this.spring.register(DenyAllConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + } + + @Test + public void getWheDenyAllConfiguredAndUserLoggedInThenRespondsWithForbidden() throws Exception { + this.spring.register(DenyAllConfig.class, BasicController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + @Test + public void getWhenNotDenyAllConfiguredAndNoUserThenRespondsWithOk() throws Exception { + this.spring.register(NotDenyAllConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isOk()); + } + + @Test + public void getWhenNotDenyAllConfiguredAndRememberMeTokenThenRespondsWithOk() throws Exception { + this.spring.register(NotDenyAllConfig.class, BasicController.class).autowire(); + RememberMeAuthenticationToken rememberme = new RememberMeAuthenticationToken("key", "user", + AuthorityUtils.createAuthorityList("ROLE_USER")); + MockHttpServletRequestBuilder requestWithRememberme = get("/").with(authentication(rememberme)); + this.mvc.perform(requestWithRememberme).andExpect(status().isOk()); + } + + @Test + public void getWhenFullyAuthenticatedConfiguredAndRememberMeTokenThenRespondsWithUnauthorized() throws Exception { + this.spring.register(FullyAuthenticatedConfig.class, BasicController.class).autowire(); + RememberMeAuthenticationToken rememberme = new RememberMeAuthenticationToken("key", "user", + AuthorityUtils.createAuthorityList("ROLE_USER")); + MockHttpServletRequestBuilder requestWithRememberme = get("/").with(authentication(rememberme)); + this.mvc.perform(requestWithRememberme).andExpect(status().isUnauthorized()); + } + + @Test + public void getWhenFullyAuthenticatedConfiguredAndUserThenRespondsWithOk() throws Exception { + this.spring.register(FullyAuthenticatedConfig.class, BasicController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenAccessRoleUserOrGetRequestConfiguredThenRespondsWithOk() throws Exception { + this.spring.register(AccessConfig.class, BasicController.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().isOk()); + } + + @Test + public void postWhenAccessRoleUserOrGetRequestConfiguredAndRoleUserThenRespondsWithOk() throws Exception { + this.spring.register(AccessConfig.class, BasicController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithUser = post("/") + .with(csrf()) + .with(user("user").roles("USER")); + // @formatter:on + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void postWhenAccessRoleUserOrGetRequestConfiguredThenRespondsWithUnauthorized() throws Exception { + this.spring.register(AccessConfig.class, BasicController.class).autowire(); + MockHttpServletRequestBuilder requestWithCsrf = post("/").with(csrf()); + this.mvc.perform(requestWithCsrf).andExpect(status().isUnauthorized()); + } + + @Test + public void authorizeRequestsWhenInvokedTwiceThenUsesOriginalConfiguration() throws Exception { + this.spring.register(InvokeTwiceDoesNotResetConfig.class, BasicController.class).autowire(); + MockHttpServletRequestBuilder requestWithCsrf = post("/").with(csrf()); + this.mvc.perform(requestWithCsrf).andExpect(status().isUnauthorized()); + } + + @Test + public void configureWhenUsingAllAuthorizeRequestPropertiesThenCompiles() { + this.spring.register(AllPropertiesWorkConfig.class).autowire(); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenApplicationListenerInvokedOnAuthorizedEvent() + throws Exception { + this.spring.register(AuthorizedRequestsWithPostProcessorConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(AuthorizedRequestsWithPostProcessorConfig.AL).onApplicationEvent(any(AuthorizedEvent.class)); + } + + @Test + public void getWhenPermissionCheckAndRoleDoesNotMatchThenRespondsWithForbidden() throws Exception { + this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/admin").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + @Test + public void getWhenPermissionCheckAndRoleMatchesThenRespondsWithOk() throws Exception { + this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/user").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenPermissionCheckAndAuthenticationNameMatchesThenRespondsWithOk() throws Exception { + this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/allow").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenPermissionCheckAndAuthenticationNameDoesNotMatchThenRespondsWithForbidden() throws Exception { + this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/deny").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + @Test + public void getWhenCustomExpressionHandlerAndRoleDoesNotMatchThenRespondsWithForbidden() throws Exception { + this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/admin").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + @Test + public void getWhenCustomExpressionHandlerAndRoleMatchesThenRespondsWithOk() throws Exception { + this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/user").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenCustomExpressionHandlerAndAuthenticationNameMatchesThenRespondsWithOk() throws Exception { + this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/allow").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenCustomExpressionHandlerAndAuthenticationNameDoesNotMatchThenRespondsWithForbidden() + throws Exception { + this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/deny").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); + } + + // SEC-3011 + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnAccessDecisionManager() { + this.spring.register(Sec3011Config.class).autowire(); + verify(Sec3011Config.objectPostProcessor).postProcess(any(AccessDecisionManager.class)); + } + + @Test + public void getWhenRegisteringPermissionEvaluatorAndPermissionWithIdAndTypeMatchesThenRespondsWithOk() + throws Exception { + this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); + this.mvc.perform(get("/allow")).andExpect(status().isOk()); + } + + @Test + public void getWhenRegisteringPermissionEvaluatorAndPermissionWithIdAndTypeDoesNotMatchThenRespondsWithForbidden() + throws Exception { + this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); + this.mvc.perform(get("/deny")).andExpect(status().isForbidden()); + } + + @Test + public void getWhenRegisteringPermissionEvaluatorAndPermissionWithObjectMatchesThenRespondsWithOk() + throws Exception { + this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); + this.mvc.perform(get("/allowObject")).andExpect(status().isOk()); + } + + @Test + public void getWhenRegisteringPermissionEvaluatorAndPermissionWithObjectDoesNotMatchThenRespondsWithForbidden() + throws Exception { + this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); + this.mvc.perform(get("/denyObject")).andExpect(status().isForbidden()); + } + + @Test + public void getWhenRegisteringRoleHierarchyAndRelatedRoleAllowedThenRespondsWithOk() throws Exception { + this.spring.register(RoleHierarchyConfig.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/allow").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isOk()); + } + + @Test + public void getWhenRegisteringRoleHierarchyAndNoRelatedRolesAllowedThenRespondsWithForbidden() throws Exception { + this.spring.register(RoleHierarchyConfig.class, WildcardController.class).autowire(); + MockHttpServletRequestBuilder requestWithUser = get("/deny").with(user("user").roles("USER")); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); } @EnableWebSecurity static class HasRoleStartingWithRoleConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -99,18 +513,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasRole("ROLE_USER"); // @formatter:on } - } - @Test - public void configureWhenNoCustomAccessDecisionManagerThenUsesAffirmativeBased() { - this.spring.register(NoSpecificAccessDecisionManagerConfig.class).autowire(); - - verify(NoSpecificAccessDecisionManagerConfig.objectPostProcessor) - .postProcess(any(AffirmativeBased.class)); } @EnableWebSecurity static class NoSpecificAccessDecisionManagerConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -126,34 +534,25 @@ public class ExpressionUrlAuthorizationConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } - } - @Test - public void configureWhenAuthorizedRequestsAndNoRequestsThenException() { - assertThatThrownBy(() -> this.spring.register(NoRequestsConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("At least one mapping is required (i.e. authorizeRequests().anyRequest().authenticated())"); } @EnableWebSecurity static class NoRequestsConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http .authorizeRequests(); // @formatter:on } - } - @Test - public void configureWhenAnyRequestIncompleteMappingThenException() { - assertThatThrownBy(() -> this.spring.register(IncompleteMappingConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("An incomplete mapping was found for "); } @EnableWebSecurity static class IncompleteMappingConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -163,39 +562,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest(); // @formatter:on } - } - @Test - public void getWhenHasAnyAuthorityRoleUserConfiguredAndAuthorityIsRoleUserThenRespondsWithOk() - throws Exception { - this.spring.register(RoleUserAnyAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_USER")))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenHasAnyAuthorityRoleUserConfiguredAndAuthorityIsRoleAdminThenRespondsWithForbidden() - throws Exception { - this.spring.register(RoleUserAnyAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_ADMIN")))) - .andExpect(status().isForbidden()); - } - - @Test - public void getWhenHasAnyAuthorityRoleUserConfiguredAndNoAuthorityThenRespondsWithUnauthorized() - throws Exception { - this.spring.register(RoleUserAnyAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class RoleUserAnyAuthorityConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -206,39 +578,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasAnyAuthority("ROLE_USER"); // @formatter:on } - } - @Test - public void getWhenHasAuthorityRoleUserConfiguredAndAuthorityIsRoleUserThenRespondsWithOk() - throws Exception { - this.spring.register(RoleUserAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_USER")))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenHasAuthorityRoleUserConfiguredAndAuthorityIsRoleAdminThenRespondsWithForbidden() - throws Exception { - this.spring.register(RoleUserAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_ADMIN")))) - .andExpect(status().isForbidden()); - } - - @Test - public void getWhenHasAuthorityRoleUserConfiguredAndNoAuthorityThenRespondsWithUnauthorized() - throws Exception { - this.spring.register(RoleUserAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class RoleUserAuthorityConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -249,49 +594,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasAuthority("ROLE_USER"); // @formatter:on } - } - @Test - public void getWhenAuthorityRoleUserOrAdminRequiredAndAuthorityIsRoleUserThenRespondsWithOk() - throws Exception { - this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_USER")))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenAuthorityRoleUserOrAdminRequiredAndAuthorityIsRoleAdminThenRespondsWithOk() - throws Exception { - this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_ADMIN")))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenAuthorityRoleUserOrAdminRequiredAndAuthorityIsRoleOtherThenRespondsWithForbidden() - throws Exception { - this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").authorities(new SimpleGrantedAuthority("ROLE_OTHER")))) - .andExpect(status().isForbidden()); - } - - @Test - public void getWhenAuthorityRoleUserOrAdminAuthRequiredAndNoUserThenRespondsWithUnauthorized() - throws Exception { - this.spring.register(RoleUserOrRoleAdminAuthorityConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class RoleUserOrRoleAdminAuthorityConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -302,28 +610,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasAnyAuthority("ROLE_USER", "ROLE_ADMIN"); // @formatter:on } - } - @Test - public void getWhenHasAnyRoleUserConfiguredAndRoleIsUserThenRespondsWithOk() throws Exception { - this.spring.register(RoleUserConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenHasAnyRoleUserConfiguredAndRoleIsAdminThenRespondsWithForbidden() throws Exception { - this.spring.register(RoleUserConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("ADMIN"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RoleUserConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -332,37 +624,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasAnyRole("USER"); // @formatter:on } - } - @Test - public void getWhenRoleUserOrAdminConfiguredAndRoleIsUserThenRespondsWithOk() throws Exception { - this.spring.register(RoleUserOrAdminConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenRoleUserOrAdminConfiguredAndRoleIsAdminThenRespondsWithOk() throws Exception { - this.spring.register(RoleUserOrAdminConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("ADMIN"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenRoleUserOrAdminConfiguredAndRoleIsOtherThenRespondsWithForbidden() throws Exception { - this.spring.register(RoleUserOrAdminConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("OTHER"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RoleUserOrAdminConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -371,34 +638,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasAnyRole("USER", "ADMIN"); // @formatter:on } - } - @Test - public void getWhenHasIpAddressConfiguredAndIpAddressMatchesThenRespondsWithOk() throws Exception { - this.spring.register(HasIpAddressConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(request -> { - request.setRemoteAddr("192.168.1.0"); - return request; - })) - .andExpect(status().isOk()); - } - - @Test - public void getWhenHasIpAddressConfiguredAndIpAddressDoesNotMatchThenRespondsWithUnauthorized() throws Exception { - this.spring.register(HasIpAddressConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(request -> { - request.setRemoteAddr("192.168.1.1"); - return request; - })) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class HasIpAddressConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -409,27 +654,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().hasIpAddress("192.168.1.0"); // @formatter:on } - } - @Test - public void getWhenAnonymousConfiguredAndAnonymousUserThenRespondsWithOk() throws Exception { - this.spring.register(AnonymousConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - } - - @Test - public void getWhenAnonymousConfiguredAndLoggedInUserThenRespondsWithForbidden() throws Exception { - this.spring.register(AnonymousConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class AnonymousConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -440,27 +670,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().anonymous(); // @formatter:on } - } - @Test - public void getWhenRememberMeConfiguredAndNoUserThenRespondsWithUnauthorized() throws Exception { - this.spring.register(RememberMeConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); - } - - @Test - public void getWhenRememberMeConfiguredAndRememberMeTokenThenRespondsWithOk() throws Exception { - this.spring.register(RememberMeConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(authentication(new RememberMeAuthenticationToken("key", "user", AuthorityUtils.createAuthorityList("ROLE_USER"))))) - .andExpect(status().isOk()); } @EnableWebSecurity static class RememberMeConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -482,27 +697,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .withUser("user").password("password").roles("USER"); // @formatter:on } - } - @Test - public void getWhenDenyAllConfiguredAndNoUserThenRespondsWithUnauthorized() throws Exception { - this.spring.register(DenyAllConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); - } - - @Test - public void getWheDenyAllConfiguredAndUserLoggedInThenRespondsWithForbidden() throws Exception { - this.spring.register(DenyAllConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("USER"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class DenyAllConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -513,27 +713,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().denyAll(); // @formatter:on } - } - @Test - public void getWhenNotDenyAllConfiguredAndNoUserThenRespondsWithOk() throws Exception { - this.spring.register(NotDenyAllConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - } - - @Test - public void getWhenNotDenyAllConfiguredAndRememberMeTokenThenRespondsWithOk() throws Exception { - this.spring.register(NotDenyAllConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(authentication(new RememberMeAuthenticationToken("key", "user", AuthorityUtils.createAuthorityList("ROLE_USER"))))) - .andExpect(status().isOk()); } @EnableWebSecurity static class NotDenyAllConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -544,28 +729,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().not().denyAll(); // @formatter:on } - } - @Test - public void getWhenFullyAuthenticatedConfiguredAndRememberMeTokenThenRespondsWithUnauthorized() throws Exception { - this.spring.register(FullyAuthenticatedConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(authentication(new RememberMeAuthenticationToken("key", "user", AuthorityUtils.createAuthorityList("ROLE_USER"))))) - .andExpect(status().isUnauthorized()); - } - - @Test - public void getWhenFullyAuthenticatedConfiguredAndUserThenRespondsWithOk() throws Exception { - this.spring.register(FullyAuthenticatedConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); } @EnableWebSecurity static class FullyAuthenticatedConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -578,37 +747,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().fullyAuthenticated(); // @formatter:on } - } - @Test - public void getWhenAccessRoleUserOrGetRequestConfiguredThenRespondsWithOk() throws Exception { - this.spring.register(AccessConfig.class, BasicController.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isOk()); - } - - @Test - public void postWhenAccessRoleUserOrGetRequestConfiguredAndRoleUserThenRespondsWithOk() throws Exception { - this.spring.register(AccessConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/") - .with(csrf()) - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void postWhenAccessRoleUserOrGetRequestConfiguredThenRespondsWithUnauthorized() throws Exception { - this.spring.register(AccessConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/") - .with(csrf())) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class AccessConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -621,19 +765,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .anyRequest().access("hasRole('ROLE_USER') or request.method == 'GET'"); // @formatter:on } - } - @Test - public void authorizeRequestsWhenInvokedTwiceThenUsesOriginalConfiguration() throws Exception { - this.spring.register(InvokeTwiceDoesNotResetConfig.class, BasicController.class).autowire(); - - this.mvc.perform(post("/") - .with(csrf())) - .andExpect(status().isUnauthorized()); } @EnableWebSecurity static class InvokeTwiceDoesNotResetConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -646,15 +783,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .authorizeRequests(); // @formatter:on } - } - @Test - public void configureWhenUsingAllAuthorizeRequestPropertiesThenCompiles() { - this.spring.register(AllPropertiesWorkConfig.class).autowire(); } @EnableWebSecurity static class AllPropertiesWorkConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { SecurityExpressionHandler handler = new DefaultWebSecurityExpressionHandler(); @@ -672,20 +806,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { .formLogin(); // @formatter:on } - } - @Test - public void configureWhenRegisteringObjectPostProcessorThenApplicationListenerInvokedOnAuthorizedEvent() - throws Exception { - this.spring.register(AuthorizedRequestsWithPostProcessorConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(AuthorizedRequestsWithPostProcessorConfig.AL).onApplicationEvent(any(AuthorizedEvent.class)); } @EnableWebSecurity static class AuthorizedRequestsWithPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ApplicationListener AL = mock(ApplicationListener.class); @Override @@ -695,6 +821,7 @@ public class ExpressionUrlAuthorizationConfigurerTests { .authorizeRequests() .anyRequest().permitAll() .withObjectPostProcessor(new ObjectPostProcessor() { + @Override public O postProcess( O fsi) { fsi.setPublishAuthorizationSuccess(true); @@ -705,49 +832,15 @@ public class ExpressionUrlAuthorizationConfigurerTests { } @Bean - public ApplicationListener applicationListener() { + ApplicationListener applicationListener() { return AL; } - } - @Test - public void getWhenPermissionCheckAndRoleDoesNotMatchThenRespondsWithForbidden() throws Exception { - this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/admin") - .with(user("user").roles("USER"))) - .andExpect(status().isForbidden()); - } - - @Test - public void getWhenPermissionCheckAndRoleMatchesThenRespondsWithOk() throws Exception { - this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/user") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenPermissionCheckAndAuthenticationNameMatchesThenRespondsWithOk() throws Exception { - this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/allow") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenPermissionCheckAndAuthenticationNameDoesNotMatchThenRespondsWithForbidden() throws Exception { - this.spring.register(UseBeansInExpressions.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/deny") - .with(user("user").roles("USER"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class UseBeansInExpressions extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -761,59 +854,23 @@ public class ExpressionUrlAuthorizationConfigurerTests { } @Bean - public Checker permission() { + Checker permission() { return new Checker(); } static class Checker { + public boolean check(Authentication authentication, String customArg) { return authentication.getName().contains(customArg); } + } - } - @Test - public void getWhenCustomExpressionHandlerAndRoleDoesNotMatchThenRespondsWithForbidden() - throws Exception { - this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/admin") - .with(user("user").roles("USER"))) - .andExpect(status().isForbidden()); - } - - @Test - public void getWhenCustomExpressionHandlerAndRoleMatchesThenRespondsWithOk() - throws Exception { - this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/user") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenCustomExpressionHandlerAndAuthenticationNameMatchesThenRespondsWithOk() - throws Exception { - this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/allow") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenCustomExpressionHandlerAndAuthenticationNameDoesNotMatchThenRespondsWithForbidden() - throws Exception { - this.spring.register(CustomExpressionRootConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/deny") - .with(user("user").roles("USER"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class CustomExpressionRootConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -828,21 +885,22 @@ public class ExpressionUrlAuthorizationConfigurerTests { } @Bean - public CustomExpressionHandler expressionHandler() { + CustomExpressionHandler expressionHandler() { return new CustomExpressionHandler(); } static class CustomExpressionHandler extends DefaultWebSecurityExpressionHandler { @Override - protected SecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, FilterInvocation fi) { + protected SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, + FilterInvocation fi) { WebSecurityExpressionRoot root = new CustomExpressionRoot(authentication, fi); root.setPermissionEvaluator(getPermissionEvaluator()); root.setTrustResolver(new AuthenticationTrustResolverImpl()); root.setRoleHierarchy(getRoleHierarchy()); return root; } + } static class CustomExpressionRoot extends WebSecurityExpressionRoot { @@ -855,20 +913,14 @@ public class ExpressionUrlAuthorizationConfigurerTests { Authentication auth = this.getAuthentication(); return auth.getName().contains(customArg); } + } - } - //SEC-3011 - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnAccessDecisionManager() { - this.spring.register(Sec3011Config.class).autowire(); - - verify(Sec3011Config.objectPostProcessor) - .postProcess(any(AccessDecisionManager.class)); } @EnableWebSecurity static class Sec3011Config extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -892,46 +944,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } - } - @Test - public void getWhenRegisteringPermissionEvaluatorAndPermissionWithIdAndTypeMatchesThenRespondsWithOk() - throws Exception { - this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/allow")) - .andExpect(status().isOk()); - } - - @Test - public void getWhenRegisteringPermissionEvaluatorAndPermissionWithIdAndTypeDoesNotMatchThenRespondsWithForbidden() - throws Exception { - this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/deny")) - .andExpect(status().isForbidden()); - } - - @Test - public void getWhenRegisteringPermissionEvaluatorAndPermissionWithObjectMatchesThenRespondsWithOk() - throws Exception { - this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/allowObject")) - .andExpect(status().isOk()); - } - - @Test - public void getWhenRegisteringPermissionEvaluatorAndPermissionWithObjectDoesNotMatchThenRespondsWithForbidden() - throws Exception { - this.spring.register(PermissionEvaluatorConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/denyObject")) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class PermissionEvaluatorConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -946,10 +964,11 @@ public class ExpressionUrlAuthorizationConfigurerTests { } @Bean - public PermissionEvaluator permissionEvaluator() { + PermissionEvaluator permissionEvaluator() { return new PermissionEvaluator() { @Override - public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { + public boolean hasPermission(Authentication authentication, Object targetDomainObject, + Object permission) { return "TESTOBJ".equals(targetDomainObject) && "PERMISSION".equals(permission); } @@ -960,28 +979,12 @@ public class ExpressionUrlAuthorizationConfigurerTests { } }; } - } - @Test - public void getWhenRegisteringRoleHierarchyAndRelatedRoleAllowedThenRespondsWithOk() throws Exception { - this.spring.register(RoleHierarchyConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/allow") - .with(user("user").roles("USER"))) - .andExpect(status().isOk()); - } - - @Test - public void getWhenRegisteringRoleHierarchyAndNoRelatedRolesAllowedThenRespondsWithForbidden() throws Exception { - this.spring.register(RoleHierarchyConfig.class, WildcardController.class).autowire(); - - this.mvc.perform(get("/deny") - .with(user("user").roles("USER"))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RoleHierarchyConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -994,35 +997,43 @@ public class ExpressionUrlAuthorizationConfigurerTests { } @Bean - public RoleHierarchy roleHierarchy() { + RoleHierarchy roleHierarchy() { RoleHierarchyImpl roleHierarchy = new RoleHierarchyImpl(); roleHierarchy.setHierarchy("ROLE_USER > ROLE_MEMBER"); return roleHierarchy; } + } @RestController static class BasicController { + @GetMapping("/") - public void rootGet() { + void rootGet() { } @PostMapping("/") - public void rootPost() { + void rootPost() { } + } @RestController static class WildcardController { + @GetMapping("/{path}") - public void wildcard(@PathVariable String path) { + void wildcard(@PathVariable String path) { } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java index 991e598581..80261676f0 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.ObjectPostProcessor; @@ -29,6 +30,7 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.users.AuthenticationTestConfiguration; import org.springframework.security.core.userdetails.PasswordEncodedUser; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders; import org.springframework.security.web.PortMapper; import org.springframework.security.web.access.ExceptionTranslationFilter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; @@ -38,10 +40,10 @@ import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.test.web.servlet.MockMvc; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout; @@ -58,6 +60,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @since 5.1 */ public class FormLoginConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -66,105 +69,317 @@ public class FormLoginConfigurerTests { @Test public void requestCache() throws Exception { - this.spring.register(RequestCacheConfig.class, - AuthenticationTestConfiguration.class).autowire(); - + this.spring.register(RequestCacheConfig.class, AuthenticationTestConfiguration.class).autowire(); RequestCacheConfig config = this.spring.getContext().getBean(RequestCacheConfig.class); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated()); - + this.mockMvc.perform(formLogin()).andExpect(authenticated()); verify(config.requestCache).getRequest(any(), any()); } - @EnableWebSecurity - static class RequestCacheConfig extends WebSecurityConfigurerAdapter { - private RequestCache requestCache = mock(RequestCache.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .formLogin().and() - .requestCache() - .requestCache(this.requestCache); - } - } - @Test public void requestCacheAsBean() throws Exception { - this.spring.register(RequestCacheBeanConfig.class, - AuthenticationTestConfiguration.class).autowire(); - + this.spring.register(RequestCacheBeanConfig.class, AuthenticationTestConfiguration.class).autowire(); RequestCache requestCache = this.spring.getContext().getBean(RequestCache.class); - - this.mockMvc.perform(formLogin()) - .andExpect(authenticated()); - + this.mockMvc.perform(formLogin()).andExpect(authenticated()); verify(requestCache).getRequest(any(), any()); } - @EnableWebSecurity - static class RequestCacheBeanConfig { - @Bean - RequestCache requestCache() { - return mock(RequestCache.class); - } - } - @Test public void loginWhenFormLoginConfiguredThenHasDefaultUsernameAndPasswordParameterNames() throws Exception { this.spring.register(FormLoginConfig.class).autowire(); - - this.mockMvc.perform(formLogin().user("username", "user").password("password", "password")) + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder loginRequest = formLogin() + .user("username", "user") + .password("password", "password"); + this.mockMvc.perform(loginRequest) .andExpect(status().isFound()) .andExpect(redirectedUrl("/")); + // @formatter:on } @Test public void loginWhenFormLoginConfiguredThenHasDefaultFailureUrl() throws Exception { this.spring.register(FormLoginConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(formLogin().user("invalid")) .andExpect(status().isFound()) .andExpect(redirectedUrl("/login?error")); + // @formatter:on } @Test public void loginWhenFormLoginConfiguredThenHasDefaultSuccessUrl() throws Exception { this.spring.register(FormLoginConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(formLogin()) .andExpect(status().isFound()) .andExpect(redirectedUrl("/")); + // @formatter:on } @Test public void getLoginPageWhenFormLoginConfiguredThenNotSecured() throws Exception { this.spring.register(FormLoginConfig.class).autowire(); - - this.mockMvc.perform(get("/login")) - .andExpect(status().isFound()); + this.mockMvc.perform(get("/login")).andExpect(status().isFound()); } @Test public void loginWhenFormLoginConfiguredThenSecured() throws Exception { this.spring.register(FormLoginConfig.class).autowire(); - - this.mockMvc.perform(post("/login")) - .andExpect(status().isForbidden()); + this.mockMvc.perform(post("/login")).andExpect(status().isForbidden()); } @Test public void requestProtectedWhenFormLoginConfiguredThenRedirectsToLogin() throws Exception { this.spring.register(FormLoginConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(get("/private")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on + } + + @Test + public void loginWhenFormLoginDefaultsInLambdaThenHasDefaultUsernameAndPasswordParameterNames() throws Exception { + this.spring.register(FormLoginInLambdaConfig.class).autowire(); + // @formatter:off + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder loginRequest = formLogin() + .user("username", "user") + .password("password", "password"); + this.mockMvc.perform(loginRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/")); + // @formatter:on + } + + @Test + public void loginWhenFormLoginDefaultsInLambdaThenHasDefaultFailureUrl() throws Exception { + this.spring.register(FormLoginInLambdaConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(formLogin().user("invalid")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?error")); + // @formatter:on + } + + @Test + public void loginWhenFormLoginDefaultsInLambdaThenHasDefaultSuccessUrl() throws Exception { + this.spring.register(FormLoginInLambdaConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(formLogin()) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/")); + // @formatter:on + } + + @Test + public void getLoginPageWhenFormLoginDefaultsInLambdaThenNotSecured() throws Exception { + this.spring.register(FormLoginInLambdaConfig.class).autowire(); + this.mockMvc.perform(get("/login")).andExpect(status().isOk()); + } + + @Test + public void loginWhenFormLoginDefaultsInLambdaThenSecured() throws Exception { + this.spring.register(FormLoginInLambdaConfig.class).autowire(); + this.mockMvc.perform(post("/login")).andExpect(status().isForbidden()); + } + + @Test + public void requestProtectedWhenFormLoginDefaultsInLambdaThenRedirectsToLogin() throws Exception { + this.spring.register(FormLoginInLambdaConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/private")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on + } + + @Test + public void getLoginPageWhenFormLoginPermitAllThenPermittedAndNoRedirect() throws Exception { + this.spring.register(FormLoginConfigPermitAll.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/login")) + .andExpect(status().isOk()) + .andExpect(redirectedUrl(null)); + // @formatter:on + } + + @Test + public void getLoginPageWithErrorQueryWhenFormLoginPermitAllThenPermittedAndNoRedirect() throws Exception { + this.spring.register(FormLoginConfigPermitAll.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/login?error")) + .andExpect(status().isOk()) + .andExpect(redirectedUrl(null)); + // @formatter:on + } + + @Test + public void loginWhenFormLoginPermitAllAndInvalidUserThenRedirectsToLoginPageWithError() throws Exception { + this.spring.register(FormLoginConfigPermitAll.class).autowire(); + // @formatter:off + this.mockMvc.perform(formLogin().user("invalid")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?error")); + // @formatter:on + } + + @Test + public void getLoginPageWhenCustomLoginPageThenPermittedAndNoRedirect() throws Exception { + this.spring.register(FormLoginDefaultsConfig.class).autowire(); + this.mockMvc.perform(get("/authenticate")).andExpect(redirectedUrl(null)); + } + + @Test + public void getLoginPageWithErrorQueryWhenCustomLoginPageThenPermittedAndNoRedirect() throws Exception { + this.spring.register(FormLoginDefaultsConfig.class).autowire(); + this.mockMvc.perform(get("/authenticate?error")).andExpect(redirectedUrl(null)); + } + + @Test + public void loginWhenCustomLoginPageAndInvalidUserThenRedirectsToCustomLoginPageWithError() throws Exception { + this.spring.register(FormLoginDefaultsConfig.class).autowire(); + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder request = formLogin("/authenticate").user("invalid"); + // @formatter:off + this.mockMvc.perform(request) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/authenticate?error")); + // @formatter:on + } + + @Test + public void logoutWhenCustomLoginPageThenRedirectsToCustomLoginPage() throws Exception { + this.spring.register(FormLoginDefaultsConfig.class).autowire(); + this.mockMvc.perform(logout()).andExpect(redirectedUrl("/authenticate?logout")); + } + + @Test + public void getLoginPageWithLogoutQueryWhenCustomLoginPageThenPermittedAndNoRedirect() throws Exception { + this.spring.register(FormLoginDefaultsConfig.class).autowire(); + this.mockMvc.perform(get("/authenticate?logout")).andExpect(redirectedUrl(null)); + } + + @Test + public void getLoginPageWhenCustomLoginPageInLambdaThenPermittedAndNoRedirect() throws Exception { + this.spring.register(FormLoginDefaultsInLambdaConfig.class).autowire(); + this.mockMvc.perform(get("/authenticate")).andExpect(redirectedUrl(null)); + } + + @Test + public void loginWhenCustomLoginProcessingUrlThenRedirectsToHome() throws Exception { + this.spring.register(FormLoginLoginProcessingUrlConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(formLogin("/loginCheck")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/")); + // @formatter:on + } + + @Test + public void loginWhenCustomLoginProcessingUrlInLambdaThenRedirectsToHome() throws Exception { + this.spring.register(FormLoginLoginProcessingUrlInLambdaConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(formLogin("/loginCheck")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/")); + // @formatter:on + } + + @Test + public void requestWhenCustomPortMapperThenPortMapperUsed() throws Exception { + FormLoginUsesPortMapperConfig.PORT_MAPPER = mock(PortMapper.class); + given(FormLoginUsesPortMapperConfig.PORT_MAPPER.lookupHttpsPort(any())).willReturn(9443); + this.spring.register(FormLoginUsesPortMapperConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("http://localhost:9090")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("https://localhost:9443/login")); + // @formatter:on + verify(FormLoginUsesPortMapperConfig.PORT_MAPPER).lookupHttpsPort(any()); + } + + @Test + public void failureUrlWhenPermitAllAndFailureHandlerThenSecured() throws Exception { + this.spring.register(PermitAllIgnoresFailureHandlerConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/login?error")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on + } + + @Test + public void formLoginWhenInvokedTwiceThenUsesOriginalUsernameParameter() throws Exception { + this.spring.register(DuplicateInvocationsDoesNotOverrideConfig.class).autowire(); + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder loginRequest = formLogin().user("custom-username", + "user"); + this.mockMvc.perform(loginRequest).andExpect(authenticated()); + } + + @Test + public void loginWhenInvalidLoginAndFailureForwardUrlThenForwardsToFailureForwardUrl() throws Exception { + this.spring.register(FormLoginUserForwardAuthenticationSuccessAndFailureConfig.class).autowire(); + SecurityMockMvcRequestBuilders.FormLoginRequestBuilder loginRequest = formLogin().user("invalid"); + this.mockMvc.perform(loginRequest).andExpect(forwardedUrl("/failure_forward_url")); + } + + @Test + public void loginWhenSuccessForwardUrlThenForwardsToSuccessForwardUrl() throws Exception { + this.spring.register(FormLoginUserForwardAuthenticationSuccessAndFailureConfig.class).autowire(); + this.mockMvc.perform(formLogin()).andExpect(forwardedUrl("/success_forward_url")); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnUsernamePasswordAuthenticationFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor) + .postProcess(any(UsernamePasswordAuthenticationFilter.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnLoginUrlAuthenticationEntryPoint() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(LoginUrlAuthenticationEntryPoint.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnExceptionTranslationFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ExceptionTranslationFilter.class)); + } + + @EnableWebSecurity + static class RequestCacheConfig extends WebSecurityConfigurerAdapter { + + private RequestCache requestCache = mock(RequestCache.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .formLogin().and() + .requestCache() + .requestCache(this.requestCache); + // @formatter:on + } + + } + + @EnableWebSecurity + static class RequestCacheBeanConfig { + + @Bean + RequestCache requestCache() { + return mock(RequestCache.class); + } + } @EnableWebSecurity static class FormLoginConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(WebSecurity web) { // @formatter:off @@ -194,67 +409,17 @@ public class FormLoginConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenFormLoginDefaultsInLambdaThenHasDefaultUsernameAndPasswordParameterNames() throws Exception { - this.spring.register(FormLoginInLambdaConfig.class).autowire(); - - this.mockMvc.perform(formLogin().user("username", "user").password("password", "password")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")); - } - - @Test - public void loginWhenFormLoginDefaultsInLambdaThenHasDefaultFailureUrl() throws Exception { - this.spring.register(FormLoginInLambdaConfig.class).autowire(); - - this.mockMvc.perform(formLogin().user("invalid")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?error")); - } - - @Test - public void loginWhenFormLoginDefaultsInLambdaThenHasDefaultSuccessUrl() throws Exception { - this.spring.register(FormLoginInLambdaConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")); - } - - @Test - public void getLoginPageWhenFormLoginDefaultsInLambdaThenNotSecured() throws Exception { - this.spring.register(FormLoginInLambdaConfig.class).autowire(); - - this.mockMvc.perform(get("/login")) - .andExpect(status().isOk()); - } - - @Test - public void loginWhenFormLoginDefaultsInLambdaThenSecured() throws Exception { - this.spring.register(FormLoginInLambdaConfig.class).autowire(); - - this.mockMvc.perform(post("/login")) - .andExpect(status().isForbidden()); - } - - @Test - public void requestProtectedWhenFormLoginDefaultsInLambdaThenRedirectsToLogin() throws Exception { - this.spring.register(FormLoginInLambdaConfig.class).autowire(); - - this.mockMvc.perform(get("/private")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/login")); } @EnableWebSecurity static class FormLoginInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) @@ -270,37 +435,12 @@ public class FormLoginConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void getLoginPageWhenFormLoginPermitAllThenPermittedAndNoRedirect() throws Exception { - this.spring.register(FormLoginConfigPermitAll.class).autowire(); - - this.mockMvc.perform(get("/login")) - .andExpect(status().isOk()) - .andExpect(redirectedUrl(null)); - } - - @Test - public void getLoginPageWithErrorQueryWhenFormLoginPermitAllThenPermittedAndNoRedirect() throws Exception { - this.spring.register(FormLoginConfigPermitAll.class).autowire(); - - this.mockMvc.perform(get("/login?error")) - .andExpect(status().isOk()) - .andExpect(redirectedUrl(null)); - } - - @Test - public void loginWhenFormLoginPermitAllAndInvalidUserThenRedirectsToLoginPageWithError() throws Exception { - this.spring.register(FormLoginConfigPermitAll.class).autowire(); - - this.mockMvc.perform(formLogin().user("invalid")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?error")); } @EnableWebSecurity static class FormLoginConfigPermitAll extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -312,51 +452,12 @@ public class FormLoginConfigurerTests { .permitAll(); // @formatter:on } - } - @Test - public void getLoginPageWhenCustomLoginPageThenPermittedAndNoRedirect() throws Exception { - this.spring.register(FormLoginDefaultsConfig.class).autowire(); - - this.mockMvc.perform(get("/authenticate")) - .andExpect(redirectedUrl(null)); - } - - @Test - public void getLoginPageWithErrorQueryWhenCustomLoginPageThenPermittedAndNoRedirect() throws Exception { - this.spring.register(FormLoginDefaultsConfig.class).autowire(); - - this.mockMvc.perform(get("/authenticate?error")) - .andExpect(redirectedUrl(null)); - } - - @Test - public void loginWhenCustomLoginPageAndInvalidUserThenRedirectsToCustomLoginPageWithError() throws Exception { - this.spring.register(FormLoginDefaultsConfig.class).autowire(); - - this.mockMvc.perform(formLogin("/authenticate").user("invalid")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/authenticate?error")); - } - - @Test - public void logoutWhenCustomLoginPageThenRedirectsToCustomLoginPage() throws Exception { - this.spring.register(FormLoginDefaultsConfig.class).autowire(); - - this.mockMvc.perform(logout()) - .andExpect(redirectedUrl("/authenticate?logout")); - } - - @Test - public void getLoginPageWithLogoutQueryWhenCustomLoginPageThenPermittedAndNoRedirect() throws Exception { - this.spring.register(FormLoginDefaultsConfig.class).autowire(); - - this.mockMvc.perform(get("/authenticate?logout")) - .andExpect(redirectedUrl(null)); } @EnableWebSecurity static class FormLoginDefaultsConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -372,27 +473,21 @@ public class FormLoginConfigurerTests { .permitAll(); // @formatter:on } - } - @Test - public void getLoginPageWhenCustomLoginPageInLambdaThenPermittedAndNoRedirect() throws Exception { - this.spring.register(FormLoginDefaultsInLambdaConfig.class).autowire(); - - this.mockMvc.perform(get("/authenticate")) - .andExpect(redirectedUrl(null)); } @EnableWebSecurity static class FormLoginDefaultsInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) - .formLogin(formLogin -> + .formLogin((formLogin) -> formLogin .loginPage("/authenticate") .permitAll() @@ -400,19 +495,12 @@ public class FormLoginConfigurerTests { .logout(LogoutConfigurer::permitAll); // @formatter:on } - } - @Test - public void loginWhenCustomLoginProcessingUrlThenRedirectsToHome() throws Exception { - this.spring.register(FormLoginLoginProcessingUrlConfig.class).autowire(); - - this.mockMvc.perform(formLogin("/loginCheck")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")); } @EnableWebSecurity static class FormLoginLoginProcessingUrlConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -443,35 +531,28 @@ public class FormLoginConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenCustomLoginProcessingUrlInLambdaThenRedirectsToHome() throws Exception { - this.spring.register(FormLoginLoginProcessingUrlInLambdaConfig.class).autowire(); - - this.mockMvc.perform(formLogin("/loginCheck")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")); } @EnableWebSecurity static class FormLoginLoginProcessingUrlInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) - .formLogin(formLogin -> + .formLogin((formLogin) -> formLogin .loginProcessingUrl("/loginCheck") .loginPage("/login") .defaultSuccessUrl("/", true) .permitAll() ) - .logout(logout -> + .logout((logout) -> logout .logoutSuccessUrl("/login") .logoutUrl("/logout") @@ -488,23 +569,12 @@ public class FormLoginConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void requestWhenCustomPortMapperThenPortMapperUsed() throws Exception { - FormLoginUsesPortMapperConfig.PORT_MAPPER = mock(PortMapper.class); - when(FormLoginUsesPortMapperConfig.PORT_MAPPER.lookupHttpsPort(any())).thenReturn(9443); - this.spring.register(FormLoginUsesPortMapperConfig.class).autowire(); - - this.mockMvc.perform(get("http://localhost:9090")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("https://localhost:9443/login")); - - verify(FormLoginUsesPortMapperConfig.PORT_MAPPER).lookupHttpsPort(any()); } @EnableWebSecurity static class FormLoginUsesPortMapperConfig extends WebSecurityConfigurerAdapter { + static PortMapper PORT_MAPPER; @Override @@ -520,23 +590,16 @@ public class FormLoginConfigurerTests { .portMapper() .portMapper(PORT_MAPPER); // @formatter:on - LoginUrlAuthenticationEntryPoint authenticationEntryPoint = - (LoginUrlAuthenticationEntryPoint) http.getConfigurer(FormLoginConfigurer.class).getAuthenticationEntryPoint(); + LoginUrlAuthenticationEntryPoint authenticationEntryPoint = (LoginUrlAuthenticationEntryPoint) http + .getConfigurer(FormLoginConfigurer.class).getAuthenticationEntryPoint(); authenticationEntryPoint.setForceHttps(true); } - } - @Test - public void failureUrlWhenPermitAllAndFailureHandlerThenSecured() throws Exception { - this.spring.register(PermitAllIgnoresFailureHandlerConfig.class).autowire(); - - this.mockMvc.perform(get("/login?error")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/login")); } @EnableWebSecurity static class PermitAllIgnoresFailureHandlerConfig extends WebSecurityConfigurerAdapter { + static AuthenticationFailureHandler FAILURE_HANDLER = mock(AuthenticationFailureHandler.class); @Override @@ -551,18 +614,12 @@ public class FormLoginConfigurerTests { .permitAll(); // @formatter:on } - } - @Test - public void formLoginWhenInvokedTwiceThenUsesOriginalUsernameParameter() throws Exception { - this.spring.register(DuplicateInvocationsDoesNotOverrideConfig.class).autowire(); - - this.mockMvc.perform(formLogin().user("custom-username", "user")) - .andExpect(authenticated()); } @EnableWebSecurity static class DuplicateInvocationsDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -582,22 +639,7 @@ public class FormLoginConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenInvalidLoginAndFailureForwardUrlThenForwardsToFailureForwardUrl() throws Exception { - this.spring.register(FormLoginUserForwardAuthenticationSuccessAndFailureConfig.class).autowire(); - - this.mockMvc.perform(formLogin().user("invalid")) - .andExpect(forwardedUrl("/failure_forward_url")); - } - - @Test - public void loginWhenSuccessForwardUrlThenForwardsToSuccessForwardUrl() throws Exception { - this.spring.register(FormLoginUserForwardAuthenticationSuccessAndFailureConfig.class).autowire(); - - this.mockMvc.perform(formLogin()) - .andExpect(forwardedUrl("/success_forward_url")); } @EnableWebSecurity @@ -627,37 +669,12 @@ public class FormLoginConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnUsernamePasswordAuthenticationFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(UsernamePasswordAuthenticationFilter.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnLoginUrlAuthenticationEntryPoint() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(LoginUrlAuthenticationEntryPoint.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnExceptionTranslationFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ExceptionTranslationFilter.class)); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor; @Override @@ -674,12 +691,16 @@ public class FormLoginConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerEagerHeadersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerEagerHeadersTests.java index f3a21064ff..86163e04c8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerEagerHeadersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerEagerHeadersTests.java @@ -18,7 +18,9 @@ package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -27,7 +29,6 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.web.header.HeaderWriterFilter; import org.springframework.test.web.servlet.MockMvc; -import static org.springframework.http.HttpHeaders.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; @@ -44,11 +45,24 @@ public class HeadersConfigurerEagerHeadersTests { @Autowired MockMvc mvc; + @Test + public void requestWhenHeadersEagerlyConfiguredThenHeadersAreWritten() throws Exception { + this.spring.register(HeadersAtTheBeginningOfRequestConfig.class).autowire(); + this.mvc.perform(get("/").secure(true)).andExpect(header().string("X-Content-Type-Options", "nosniff")) + .andExpect(header().string("X-Frame-Options", "DENY")) + .andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains")) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) + .andExpect(header().string(HttpHeaders.EXPIRES, "0")) + .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) + .andExpect(header().string("X-XSS-Protection", "1; mode=block")); + } + @EnableWebSecurity public static class HeadersAtTheBeginningOfRequestConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { - //@ formatter:off + // @formatter:off http .headers() .addObjectPostProcessor(new ObjectPostProcessor() { @@ -58,21 +72,9 @@ public class HeadersConfigurerEagerHeadersTests { return filter; } }); - //@ formatter:on + // @formatter:on } + } - @Test - public void requestWhenHeadersEagerlyConfiguredThenHeadersAreWritten() throws Exception { - this.spring.register(HeadersAtTheBeginningOfRequestConfig.class).autowire(); - - this.mvc.perform(get("/").secure(true)) - .andExpect(header().string("X-Content-Type-Options", "nosniff")) - .andExpect(header().string("X-Frame-Options", "DENY")) - .andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains")) - .andExpect(header().string(CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) - .andExpect(header().string(EXPIRES, "0")) - .andExpect(header().string(PRAGMA, "no-cache")) - .andExpect(header().string("X-XSS-Protection", "1; mode=block")); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerTests.java index 390925e8d5..497fe083aa 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HeadersConfigurerTests.java @@ -16,9 +16,14 @@ package org.springframework.security.config.annotation.web.configurers; +import java.net.URI; +import java.util.LinkedHashMap; +import java.util.Map; + import com.google.common.net.HttpHeaders; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -29,13 +34,10 @@ import org.springframework.security.web.header.writers.ReferrerPolicyHeaderWrite import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter.XFrameOptionsMode; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; - -import java.net.URI; -import java.util.LinkedHashMap; -import java.util.Map; +import org.springframework.test.web.servlet.ResultMatcher; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; @@ -61,21 +63,418 @@ public class HeadersConfigurerTests { @Test public void getWhenHeadersConfiguredThenDefaultHeadersInResponse() throws Exception { this.spring.register(HeadersConfig.class).autowire(); - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")) .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.DENY.name())) - .andExpect(header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) + .andExpect( + header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) .andExpect(header().string(HttpHeaders.EXPIRES, "0")) .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) - .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")) - .andReturn(); + .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")).andReturn(); assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder( HttpHeaders.X_CONTENT_TYPE_OPTIONS, HttpHeaders.X_FRAME_OPTIONS, HttpHeaders.STRICT_TRANSPORT_SECURITY, HttpHeaders.CACHE_CONTROL, HttpHeaders.EXPIRES, HttpHeaders.PRAGMA, HttpHeaders.X_XSS_PROTECTION); } + @Test + public void getWhenHeadersConfiguredInLambdaThenDefaultHeadersInResponse() throws Exception { + this.spring.register(HeadersInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")) + .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.DENY.name())) + .andExpect( + header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) + .andExpect(header().string(HttpHeaders.EXPIRES, "0")) + .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) + .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder( + HttpHeaders.X_CONTENT_TYPE_OPTIONS, HttpHeaders.X_FRAME_OPTIONS, HttpHeaders.STRICT_TRANSPORT_SECURITY, + HttpHeaders.CACHE_CONTROL, HttpHeaders.EXPIRES, HttpHeaders.PRAGMA, HttpHeaders.X_XSS_PROTECTION); + } + + @Test + public void getWhenHeaderDefaultsDisabledAndContentTypeConfiguredThenOnlyContentTypeHeaderInResponse() + throws Exception { + this.spring.register(ContentTypeOptionsConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")) + .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_CONTENT_TYPE_OPTIONS); + } + + @Test + public void getWhenOnlyContentTypeConfiguredInLambdaThenOnlyContentTypeHeaderInResponse() throws Exception { + this.spring.register(ContentTypeOptionsInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")) + .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_CONTENT_TYPE_OPTIONS); + } + + @Test + public void getWhenHeaderDefaultsDisabledAndFrameOptionsConfiguredThenOnlyFrameOptionsHeaderInResponse() + throws Exception { + this.spring.register(FrameOptionsConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")) + .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.DENY.name())).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_FRAME_OPTIONS); + } + + @Test + public void getWhenHeaderDefaultsDisabledAndHstsConfiguredThenOnlyStrictTransportSecurityHeaderInResponse() + throws Exception { + this.spring.register(HstsConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect( + header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) + .andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.STRICT_TRANSPORT_SECURITY); + } + + @Test + public void getWhenHeaderDefaultsDisabledAndCacheControlConfiguredThenCacheControlAndExpiresAndPragmaHeadersInResponse() + throws Exception { + this.spring.register(CacheControlConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) + .andExpect(header().string(HttpHeaders.EXPIRES, "0")) + .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder(HttpHeaders.CACHE_CONTROL, + HttpHeaders.EXPIRES, HttpHeaders.PRAGMA); + } + + @Test + public void getWhenOnlyCacheControlConfiguredInLambdaThenCacheControlAndExpiresAndPragmaHeadersInResponse() + throws Exception { + this.spring.register(CacheControlInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) + .andExpect(header().string(HttpHeaders.EXPIRES, "0")) + .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder(HttpHeaders.CACHE_CONTROL, + HttpHeaders.EXPIRES, HttpHeaders.PRAGMA); + } + + @Test + public void getWhenHeaderDefaultsDisabledAndXssProtectionConfiguredThenOnlyXssProtectionHeaderInResponse() + throws Exception { + this.spring.register(XssProtectionConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_XSS_PROTECTION); + } + + @Test + public void getWhenOnlyXssProtectionConfiguredInLambdaThenOnlyXssProtectionHeaderInResponse() throws Exception { + this.spring.register(XssProtectionInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_XSS_PROTECTION); + } + + @Test + public void getWhenFrameOptionsSameOriginConfiguredThenFrameOptionsHeaderHasValueSameOrigin() throws Exception { + this.spring.register(HeadersCustomSameOriginConfig.class).autowire(); + this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.SAMEORIGIN.name())) + .andReturn(); + } + + @Test + public void getWhenFrameOptionsSameOriginConfiguredInLambdaThenFrameOptionsHeaderHasValueSameOrigin() + throws Exception { + this.spring.register(HeadersCustomSameOriginInLambdaConfig.class).autowire(); + this.mvc.perform(get("/").secure(true)) + .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.SAMEORIGIN.name())) + .andReturn(); + } + + @Test + public void getWhenHeaderDefaultsDisabledAndPublicHpkpWithNoPinThenNoHeadersInResponse() throws Exception { + this.spring.register(HpkpConfigNoPins.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).isEmpty(); + } + + @Test + public void getWhenSecureRequestAndHpkpWithPinThenPublicKeyPinsReportOnlyHeaderInResponse() throws Exception { + this.spring.register(HpkpConfig.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenInsecureRequestHeaderDefaultsDisabledAndHpkpWithPinThenNoHeadersInResponse() throws Exception { + this.spring.register(HpkpConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")).andReturn(); + assertThat(mvcResult.getResponse().getHeaderNames()).isEmpty(); + } + + @Test + public void getWhenHpkpWithMultiplePinsThenPublicKeyPinsReportOnlyHeaderWithMultiplePinsInResponse() + throws Exception { + this.spring.register(HpkpConfigWithPins.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenHpkpWithCustomAgeThenPublicKeyPinsReportOnlyHeaderWithCustomAgeInResponse() throws Exception { + this.spring.register(HpkpConfigCustomAge.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=604800 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenHpkpWithReportOnlyFalseThenPublicKeyPinsHeaderInResponse() throws Exception { + this.spring.register(HpkpConfigTerminateConnection.class).autowire(); + ResultMatcher pins = header().string(HttpHeaders.PUBLIC_KEY_PINS, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pins) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS); + } + + @Test + public void getWhenHpkpIncludeSubdomainThenPublicKeyPinsReportOnlyHeaderWithIncludeSubDomainsInResponse() + throws Exception { + this.spring.register(HpkpConfigIncludeSubDomains.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; includeSubDomains"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenHpkpWithReportUriThenPublicKeyPinsReportOnlyHeaderWithReportUriInResponse() throws Exception { + this.spring.register(HpkpConfigWithReportURI.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenHpkpWithReportUriAsStringThenPublicKeyPinsReportOnlyHeaderWithReportUriInResponse() + throws Exception { + this.spring.register(HpkpConfigWithReportURIAsString.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenHpkpWithReportUriInLambdaThenPublicKeyPinsReportOnlyHeaderWithReportUriInResponse() + throws Exception { + this.spring.register(HpkpWithReportUriInLambdaConfig.class).autowire(); + ResultMatcher pinsReportOnly = header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\""); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(pinsReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); + } + + @Test + public void getWhenContentSecurityPolicyConfiguredThenContentSecurityPolicyHeaderInResponse() throws Exception { + this.spring.register(ContentSecurityPolicyDefaultConfig.class).autowire(); + ResultMatcher csp = header().string(HttpHeaders.CONTENT_SECURITY_POLICY, "default-src 'self'"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(csp) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY); + } + + @Test + public void getWhenContentSecurityPolicyWithReportOnlyThenContentSecurityPolicyReportOnlyHeaderInResponse() + throws Exception { + this.spring.register(ContentSecurityPolicyReportOnlyConfig.class).autowire(); + ResultMatcher cspReportOnly = header().string(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY, + "default-src 'self'; script-src trustedscripts.example.com"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(cspReportOnly) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()) + .containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY); + } + + @Test + public void getWhenContentSecurityPolicyWithReportOnlyInLambdaThenContentSecurityPolicyReportOnlyHeaderInResponse() + throws Exception { + this.spring.register(ContentSecurityPolicyReportOnlyInLambdaConfig.class).autowire(); + ResultMatcher csp = header().string(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY, + "default-src 'self'; script-src trustedscripts.example.com"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(csp) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()) + .containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY); + } + + @Test + public void configureWhenContentSecurityPolicyEmptyThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(ContentSecurityPolicyInvalidConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenContentSecurityPolicyEmptyInLambdaThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(ContentSecurityPolicyInvalidInLambdaConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenContentSecurityPolicyNoPolicyDirectivesInLambdaThenDefaultHeaderValue() throws Exception { + this.spring.register(ContentSecurityPolicyNoDirectivesInLambdaConfig.class).autowire(); + ResultMatcher csp = header().string(HttpHeaders.CONTENT_SECURITY_POLICY, "default-src 'self'"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(csp) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY); + } + + @Test + public void getWhenReferrerPolicyConfiguredThenReferrerPolicyHeaderInResponse() throws Exception { + this.spring.register(ReferrerPolicyDefaultConfig.class).autowire(); + ResultMatcher referrerPolicy = header().string("Referrer-Policy", ReferrerPolicy.NO_REFERRER.getPolicy()); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(referrerPolicy) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); + } + + @Test + public void getWhenReferrerPolicyInLambdaThenReferrerPolicyHeaderInResponse() throws Exception { + this.spring.register(ReferrerPolicyDefaultInLambdaConfig.class).autowire(); + ResultMatcher referrerPolicy = header().string("Referrer-Policy", ReferrerPolicy.NO_REFERRER.getPolicy()); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(referrerPolicy) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); + } + + @Test + public void getWhenReferrerPolicyConfiguredWithCustomValueThenReferrerPolicyHeaderWithCustomValueInResponse() + throws Exception { + this.spring.register(ReferrerPolicyCustomConfig.class).autowire(); + ResultMatcher referrerPolicy = header().string("Referrer-Policy", ReferrerPolicy.SAME_ORIGIN.getPolicy()); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(referrerPolicy) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); + } + + @Test + public void getWhenReferrerPolicyConfiguredWithCustomValueInLambdaThenCustomValueInResponse() throws Exception { + this.spring.register(ReferrerPolicyCustomInLambdaConfig.class).autowire(); + ResultMatcher referrerPolicy = header().string("Referrer-Policy", ReferrerPolicy.SAME_ORIGIN.getPolicy()); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(referrerPolicy) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); + } + + @Test + public void getWhenFeaturePolicyConfiguredThenFeaturePolicyHeaderInResponse() throws Exception { + this.spring.register(FeaturePolicyConfig.class).autowire(); + ResultMatcher featurePolicy = header().string("Feature-Policy", "geolocation 'self'"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(featurePolicy) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Feature-Policy"); + } + + @Test + public void configureWhenFeaturePolicyEmptyThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(FeaturePolicyInvalidConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getWhenHstsConfiguredWithPreloadThenStrictTransportSecurityHeaderWithPreloadInResponse() + throws Exception { + this.spring.register(HstsWithPreloadConfig.class).autowire(); + ResultMatcher hsts = header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, + "max-age=31536000 ; includeSubDomains ; preload"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(hsts) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.STRICT_TRANSPORT_SECURITY); + } + + @Test + public void getWhenHstsConfiguredWithPreloadInLambdaThenStrictTransportSecurityHeaderWithPreloadInResponse() + throws Exception { + this.spring.register(HstsWithPreloadInLambdaConfig.class).autowire(); + ResultMatcher hsts = header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, + "max-age=31536000 ; includeSubDomains ; preload"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) + .andExpect(hsts) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.STRICT_TRANSPORT_SECURITY); + } + @EnableWebSecurity static class HeadersConfig extends WebSecurityConfigurerAdapter { @@ -86,24 +485,7 @@ public class HeadersConfigurerTests { .headers(); // @formatter:on } - } - @Test - public void getWhenHeadersConfiguredInLambdaThenDefaultHeadersInResponse() throws Exception { - this.spring.register(HeadersInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")) - .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.DENY.name())) - .andExpect(header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) - .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) - .andExpect(header().string(HttpHeaders.EXPIRES, "0")) - .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) - .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder( - HttpHeaders.X_CONTENT_TYPE_OPTIONS, HttpHeaders.X_FRAME_OPTIONS, HttpHeaders.STRICT_TRANSPORT_SECURITY, - HttpHeaders.CACHE_CONTROL, HttpHeaders.EXPIRES, HttpHeaders.PRAGMA, HttpHeaders.X_XSS_PROTECTION); } @EnableWebSecurity @@ -116,17 +498,7 @@ public class HeadersConfigurerTests { .headers(withDefaults()); // @formatter:on } - } - @Test - public void getWhenHeaderDefaultsDisabledAndContentTypeConfiguredThenOnlyContentTypeHeaderInResponse() - throws Exception { - this.spring.register(ContentTypeOptionsConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_CONTENT_TYPE_OPTIONS); } @EnableWebSecurity @@ -141,17 +513,7 @@ public class HeadersConfigurerTests { .contentTypeOptions(); // @formatter:on } - } - @Test - public void getWhenOnlyContentTypeConfiguredInLambdaThenOnlyContentTypeHeaderInResponse() - throws Exception { - this.spring.register(ContentTypeOptionsInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andExpect(header().string(HttpHeaders.X_CONTENT_TYPE_OPTIONS, "nosniff")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_CONTENT_TYPE_OPTIONS); } @EnableWebSecurity @@ -161,24 +523,14 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() .contentTypeOptions(withDefaults()) ); // @formatter:on } - } - @Test - public void getWhenHeaderDefaultsDisabledAndFrameOptionsConfiguredThenOnlyFrameOptionsHeaderInResponse() - throws Exception { - this.spring.register(FrameOptionsConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.DENY.name())) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_FRAME_OPTIONS); } @EnableWebSecurity @@ -193,17 +545,7 @@ public class HeadersConfigurerTests { .frameOptions(); // @formatter:on } - } - @Test - public void getWhenHeaderDefaultsDisabledAndHstsConfiguredThenOnlyStrictTransportSecurityHeaderInResponse() - throws Exception { - this.spring.register(HstsConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.STRICT_TRANSPORT_SECURITY); } @EnableWebSecurity @@ -218,20 +560,7 @@ public class HeadersConfigurerTests { .httpStrictTransportSecurity(); // @formatter:on } - } - @Test - public void getWhenHeaderDefaultsDisabledAndCacheControlConfiguredThenCacheControlAndExpiresAndPragmaHeadersInResponse() - throws Exception { - this.spring.register(CacheControlConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) - .andExpect(header().string(HttpHeaders.EXPIRES, "0")) - .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder(HttpHeaders.CACHE_CONTROL, - HttpHeaders.EXPIRES, HttpHeaders.PRAGMA); } @EnableWebSecurity @@ -246,20 +575,7 @@ public class HeadersConfigurerTests { .cacheControl(); // @formatter:on } - } - @Test - public void getWhenOnlyCacheControlConfiguredInLambdaThenCacheControlAndExpiresAndPragmaHeadersInResponse() - throws Exception { - this.spring.register(CacheControlInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")) - .andExpect(header().string(HttpHeaders.EXPIRES, "0")) - .andExpect(header().string(HttpHeaders.PRAGMA, "no-cache")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactlyInAnyOrder(HttpHeaders.CACHE_CONTROL, - HttpHeaders.EXPIRES, HttpHeaders.PRAGMA); } @EnableWebSecurity @@ -269,24 +585,14 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() .cacheControl(withDefaults()) ); // @formatter:on } - } - @Test - public void getWhenHeaderDefaultsDisabledAndXssProtectionConfiguredThenOnlyXssProtectionHeaderInResponse() - throws Exception { - this.spring.register(XssProtectionConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_XSS_PROTECTION); } @EnableWebSecurity @@ -301,17 +607,7 @@ public class HeadersConfigurerTests { .xssProtection(); // @formatter:on } - } - @Test - public void getWhenOnlyXssProtectionConfiguredInLambdaThenOnlyXssProtectionHeaderInResponse() - throws Exception { - this.spring.register(XssProtectionInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.X_XSS_PROTECTION, "1; mode=block")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.X_XSS_PROTECTION); } @EnableWebSecurity @@ -321,22 +617,14 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() .xssProtection(withDefaults()) ); // @formatter:on } - } - @Test - public void getWhenFrameOptionsSameOriginConfiguredThenFrameOptionsHeaderHasValueSameOrigin() throws Exception { - this.spring.register(HeadersCustomSameOriginConfig.class).autowire(); - - this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.SAMEORIGIN.name())) - .andReturn(); } @EnableWebSecurity @@ -350,16 +638,7 @@ public class HeadersConfigurerTests { .frameOptions().sameOrigin(); // @formatter:on } - } - @Test - public void getWhenFrameOptionsSameOriginConfiguredInLambdaThenFrameOptionsHeaderHasValueSameOrigin() - throws Exception { - this.spring.register(HeadersCustomSameOriginInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.X_FRAME_OPTIONS, XFrameOptionsMode.SAMEORIGIN.name())) - .andReturn(); } @EnableWebSecurity @@ -369,21 +648,13 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers - .frameOptions(frameOptionsConfig -> frameOptionsConfig.sameOrigin()) + .frameOptions((frameOptionsConfig) -> frameOptionsConfig.sameOrigin()) ); // @formatter:on } - } - @Test - public void getWhenHeaderDefaultsDisabledAndPublicHpkpWithNoPinThenNoHeadersInResponse() throws Exception { - this.spring.register(HpkpConfigNoPins.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).isEmpty(); } @EnableWebSecurity @@ -398,28 +669,7 @@ public class HeadersConfigurerTests { .httpPublicKeyPinning(); // @formatter:on } - } - @Test - public void getWhenSecureRequestAndHpkpWithPinThenPublicKeyPinsReportOnlyHeaderInResponse() - throws Exception { - this.spring.register(HpkpConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); - } - - @Test - public void getWhenInsecureRequestHeaderDefaultsDisabledAndHpkpWithPinThenNoHeadersInResponse() - throws Exception { - this.spring.register(HpkpConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).isEmpty(); } @EnableWebSecurity @@ -435,18 +685,7 @@ public class HeadersConfigurerTests { .addSha256Pins("d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="); // @formatter:on } - } - @Test - public void getWhenHpkpWithMultiplePinsThenPublicKeyPinsReportOnlyHeaderWithMultiplePinsInResponse() - throws Exception { - this.spring.register(HpkpConfigWithPins.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); } @EnableWebSecurity @@ -457,7 +696,6 @@ public class HeadersConfigurerTests { Map pins = new LinkedHashMap<>(); pins.put("d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=", "sha256"); pins.put("E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=", "sha256"); - // @formatter:off http .headers() @@ -466,17 +704,7 @@ public class HeadersConfigurerTests { .withPins(pins); // @formatter:on } - } - @Test - public void getWhenHpkpWithCustomAgeThenPublicKeyPinsReportOnlyHeaderWithCustomAgeInResponse() throws Exception { - this.spring.register(HpkpConfigCustomAge.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=604800 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); } @EnableWebSecurity @@ -493,17 +721,7 @@ public class HeadersConfigurerTests { .maxAgeInSeconds(604800); // @formatter:on } - } - @Test - public void getWhenHpkpWithReportOnlyFalseThenPublicKeyPinsHeaderInResponse() throws Exception { - this.spring.register(HpkpConfigTerminateConnection.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS); } @EnableWebSecurity @@ -520,18 +738,7 @@ public class HeadersConfigurerTests { .reportOnly(false); // @formatter:on } - } - @Test - public void getWhenHpkpIncludeSubdomainThenPublicKeyPinsReportOnlyHeaderWithIncludeSubDomainsInResponse() - throws Exception { - this.spring.register(HpkpConfigIncludeSubDomains.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; includeSubDomains")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); } @EnableWebSecurity @@ -548,17 +755,7 @@ public class HeadersConfigurerTests { .includeSubDomains(true); // @formatter:on } - } - @Test - public void getWhenHpkpWithReportUriThenPublicKeyPinsReportOnlyHeaderWithReportUriInResponse() throws Exception { - this.spring.register(HpkpConfigWithReportURI.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); } @EnableWebSecurity @@ -575,18 +772,7 @@ public class HeadersConfigurerTests { .reportUri(new URI("https://example.net/pkp-report")); // @formatter:on } - } - @Test - public void getWhenHpkpWithReportUriAsStringThenPublicKeyPinsReportOnlyHeaderWithReportUriInResponse() - throws Exception { - this.spring.register(HpkpConfigWithReportURIAsString.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); } @EnableWebSecurity @@ -603,18 +789,7 @@ public class HeadersConfigurerTests { .reportUri("https://example.net/pkp-report"); // @formatter:on } - } - @Test - public void getWhenHpkpWithReportUriInLambdaThenPublicKeyPinsReportOnlyHeaderWithReportUriInResponse() - throws Exception { - this.spring.register(HpkpWithReportUriInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY, - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\"")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.PUBLIC_KEY_PINS_REPORT_ONLY); } @EnableWebSecurity @@ -624,10 +799,10 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() - .httpPublicKeyPinning(hpkp -> + .httpPublicKeyPinning((hpkp) -> hpkp .addSha256Pins("d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=") .reportUri("https://example.net/pkp-report") @@ -635,16 +810,7 @@ public class HeadersConfigurerTests { ); // @formatter:on } - } - @Test - public void getWhenContentSecurityPolicyConfiguredThenContentSecurityPolicyHeaderInResponse() throws Exception { - this.spring.register(ContentSecurityPolicyDefaultConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.CONTENT_SECURITY_POLICY, "default-src 'self'")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY); } @EnableWebSecurity @@ -659,17 +825,7 @@ public class HeadersConfigurerTests { .contentSecurityPolicy("default-src 'self'"); // @formatter:on } - } - @Test - public void getWhenContentSecurityPolicyWithReportOnlyThenContentSecurityPolicyReportOnlyHeaderInResponse() throws Exception { - this.spring.register(ContentSecurityPolicyReportOnlyConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY, - "default-src 'self'; script-src trustedscripts.example.com")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY); } @EnableWebSecurity @@ -685,18 +841,7 @@ public class HeadersConfigurerTests { .reportOnly(); // @formatter:on } - } - @Test - public void getWhenContentSecurityPolicyWithReportOnlyInLambdaThenContentSecurityPolicyReportOnlyHeaderInResponse() - throws Exception { - this.spring.register(ContentSecurityPolicyReportOnlyInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY, - "default-src 'self'; script-src trustedscripts.example.com")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY_REPORT_ONLY); } @EnableWebSecurity @@ -706,10 +851,10 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() - .contentSecurityPolicy(csp -> + .contentSecurityPolicy((csp) -> csp .policyDirectives("default-src 'self'; script-src trustedscripts.example.com") .reportOnly() @@ -717,13 +862,7 @@ public class HeadersConfigurerTests { ); // @formatter:on } - } - @Test - public void configureWhenContentSecurityPolicyEmptyThenException() { - assertThatThrownBy(() -> this.spring.register(ContentSecurityPolicyInvalidConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity @@ -738,13 +877,7 @@ public class HeadersConfigurerTests { .contentSecurityPolicy(""); // @formatter:on } - } - @Test - public void configureWhenContentSecurityPolicyEmptyInLambdaThenException() { - assertThatThrownBy(() -> this.spring.register(ContentSecurityPolicyInvalidInLambdaConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity @@ -754,26 +887,16 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() - .contentSecurityPolicy(csp -> + .contentSecurityPolicy((csp) -> csp.policyDirectives("") ) ); // @formatter:on } - } - @Test - public void configureWhenContentSecurityPolicyNoPolicyDirectivesInLambdaThenDefaultHeaderValue() throws Exception { - this.spring.register(ContentSecurityPolicyNoDirectivesInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.CONTENT_SECURITY_POLICY, - "default-src 'self'")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.CONTENT_SECURITY_POLICY); } @EnableWebSecurity @@ -783,23 +906,14 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() .contentSecurityPolicy(withDefaults()) ); // @formatter:on } - } - @Test - public void getWhenReferrerPolicyConfiguredThenReferrerPolicyHeaderInResponse() throws Exception { - this.spring.register(ReferrerPolicyDefaultConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string("Referrer-Policy", ReferrerPolicy.NO_REFERRER.getPolicy())) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); } @EnableWebSecurity @@ -814,16 +928,7 @@ public class HeadersConfigurerTests { .referrerPolicy(); // @formatter:on } - } - @Test - public void getWhenReferrerPolicyInLambdaThenReferrerPolicyHeaderInResponse() throws Exception { - this.spring.register(ReferrerPolicyDefaultInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string("Referrer-Policy", ReferrerPolicy.NO_REFERRER.getPolicy())) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); } @EnableWebSecurity @@ -833,24 +938,14 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() .referrerPolicy() ); // @formatter:on } - } - @Test - public void getWhenReferrerPolicyConfiguredWithCustomValueThenReferrerPolicyHeaderWithCustomValueInResponse() - throws Exception { - this.spring.register(ReferrerPolicyCustomConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string("Referrer-Policy", ReferrerPolicy.SAME_ORIGIN.getPolicy())) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); } @EnableWebSecurity @@ -865,16 +960,7 @@ public class HeadersConfigurerTests { .referrerPolicy(ReferrerPolicy.SAME_ORIGIN); // @formatter:on } - } - @Test - public void getWhenReferrerPolicyConfiguredWithCustomValueInLambdaThenCustomValueInResponse() throws Exception { - this.spring.register(ReferrerPolicyCustomInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string("Referrer-Policy", ReferrerPolicy.SAME_ORIGIN.getPolicy())) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Referrer-Policy"); } @EnableWebSecurity @@ -884,25 +970,16 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() - .referrerPolicy(referrerPolicy -> + .referrerPolicy((referrerPolicy) -> referrerPolicy.policy(ReferrerPolicy.SAME_ORIGIN) ) ); // @formatter:on } - } - @Test - public void getWhenFeaturePolicyConfiguredThenFeaturePolicyHeaderInResponse() throws Exception { - this.spring.register(FeaturePolicyConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string("Feature-Policy", "geolocation 'self'")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly("Feature-Policy"); } @EnableWebSecurity @@ -917,13 +994,7 @@ public class HeadersConfigurerTests { .featurePolicy("geolocation 'self'"); // @formatter:on } - } - @Test - public void configureWhenFeaturePolicyEmptyThenException() { - assertThatThrownBy(() -> this.spring.register(FeaturePolicyInvalidConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity @@ -938,18 +1009,7 @@ public class HeadersConfigurerTests { .featurePolicy(""); // @formatter:on } - } - @Test - public void getWhenHstsConfiguredWithPreloadThenStrictTransportSecurityHeaderWithPreloadInResponse() - throws Exception { - this.spring.register(HstsWithPreloadConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, - "max-age=31536000 ; includeSubDomains ; preload")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.STRICT_TRANSPORT_SECURITY); } @EnableWebSecurity @@ -965,18 +1025,7 @@ public class HeadersConfigurerTests { .preload(true); // @formatter:on } - } - @Test - public void getWhenHstsConfiguredWithPreloadInLambdaThenStrictTransportSecurityHeaderWithPreloadInResponse() - throws Exception { - this.spring.register(HstsWithPreloadInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/").secure(true)) - .andExpect(header().string(HttpHeaders.STRICT_TRANSPORT_SECURITY, - "max-age=31536000 ; includeSubDomains ; preload")) - .andReturn(); - assertThat(mvcResult.getResponse().getHeaderNames()).containsExactly(HttpHeaders.STRICT_TRANSPORT_SECURITY); } @EnableWebSecurity @@ -986,12 +1035,14 @@ public class HeadersConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .headers(headers -> + .headers((headers) -> headers .defaultsDisabled() - .httpStrictTransportSecurity(hstsConfig -> hstsConfig.preload(true)) + .httpStrictTransportSecurity((hstsConfig) -> hstsConfig.preload(true)) ); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurerTests.java index 4a33b440e5..8d8f686510 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpBasicConfigurerTests.java @@ -16,8 +16,12 @@ package org.springframework.security.config.annotation.web.configurers; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -34,15 +38,18 @@ import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.www.BasicAuthenticationFilter; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** * Tests for {@link HttpBasicConfigurer} @@ -61,13 +68,58 @@ public class HttpBasicConfigurerTests { @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnBasicAuthenticationFilter() { this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(BasicAuthenticationFilter.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(BasicAuthenticationFilter.class)); + @Test + public void httpBasicWhenUsingDefaultsInLambdaThenResponseIncludesBasicChallenge() throws Exception { + this.spring.register(DefaultsLambdaEntryPointConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/")) + .andExpect(status().isUnauthorized()) + .andExpect(header().string("WWW-Authenticate", "Basic realm=\"Realm\"")); + // @formatter:on + } + + // SEC-2198 + @Test + public void httpBasicWhenUsingDefaultsThenResponseIncludesBasicChallenge() throws Exception { + this.spring.register(DefaultsEntryPointConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/")) + .andExpect(status().isUnauthorized()) + .andExpect(header().string("WWW-Authenticate", "Basic realm=\"Realm\"")); + // @formatter:on + } + + @Test + public void httpBasicWhenUsingCustomAuthenticationEntryPointThenResponseIncludesBasicChallenge() throws Exception { + this.spring.register(CustomAuthenticationEntryPointConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(CustomAuthenticationEntryPointConfig.ENTRY_POINT).commence(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(AuthenticationException.class)); + } + + @Test + public void httpBasicWhenInvokedTwiceThenUsesOriginalEntryPoint() throws Exception { + this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(DuplicateDoesNotOverrideConfig.ENTRY_POINT).commence(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(AuthenticationException.class)); + } + + // SEC-3019 + @Test + public void httpBasicWhenRememberMeConfiguredThenSetsRememberMeCookie() throws Exception { + this.spring.register(BasicUsesRememberMeConfig.class).autowire(); + MockHttpServletRequestBuilder rememberMeRequest = get("/").with(httpBasic("user", "password")) + .param("remember-me", "true"); + this.mvc.perform(rememberMeRequest).andExpect(cookie().exists("remember-me")); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -82,31 +134,26 @@ public class HttpBasicConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void httpBasicWhenUsingDefaultsInLambdaThenResponseIncludesBasicChallenge() throws Exception { - this.spring.register(DefaultsLambdaEntryPointConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()) - .andExpect(header().string("WWW-Authenticate", "Basic realm=\"Realm\"")); } @EnableWebSecurity static class DefaultsLambdaEntryPointConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -121,20 +168,12 @@ public class HttpBasicConfigurerTests { .inMemoryAuthentication(); // @formatter:on } - } - //SEC-2198 - @Test - public void httpBasicWhenUsingDefaultsThenResponseIncludesBasicChallenge() throws Exception { - this.spring.register(DefaultsEntryPointConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()) - .andExpect(header().string("WWW-Authenticate", "Basic realm=\"Realm\"")); } @EnableWebSecurity static class DefaultsEntryPointConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -153,22 +192,12 @@ public class HttpBasicConfigurerTests { .inMemoryAuthentication(); // @formatter:on } - } - @Test - public void httpBasicWhenUsingCustomAuthenticationEntryPointThenResponseIncludesBasicChallenge() throws Exception { - this.spring.register(CustomAuthenticationEntryPointConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(CustomAuthenticationEntryPointConfig.ENTRY_POINT) - .commence(any(HttpServletRequest.class), - any(HttpServletResponse.class), - any(AuthenticationException.class)); } @EnableWebSecurity static class CustomAuthenticationEntryPointConfig extends WebSecurityConfigurerAdapter { + static AuthenticationEntryPoint ENTRY_POINT = mock(AuthenticationEntryPoint.class); @Override @@ -190,22 +219,12 @@ public class HttpBasicConfigurerTests { .inMemoryAuthentication(); // @formatter:on } - } - @Test - public void httpBasicWhenInvokedTwiceThenUsesOriginalEntryPoint() throws Exception { - this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(DuplicateDoesNotOverrideConfig.ENTRY_POINT) - .commence(any(HttpServletRequest.class), - any(HttpServletResponse.class), - any(AuthenticationException.class)); } @EnableWebSecurity static class DuplicateDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + static AuthenticationEntryPoint ENTRY_POINT = mock(AuthenticationEntryPoint.class); @Override @@ -229,17 +248,7 @@ public class HttpBasicConfigurerTests { .inMemoryAuthentication(); // @formatter:on } - } - //SEC-3019 - @Test - public void httpBasicWhenRememberMeConfiguredThenSetsRememberMeCookie() throws Exception { - this.spring.register(BasicUsesRememberMeConfig.class).autowire(); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password")) - .param("remember-me", "true")) - .andExpect(cookie().exists("remember-me")); } @EnableWebSecurity @@ -256,15 +265,20 @@ public class HttpBasicConfigurerTests { // @formatter:on } + @Override @Bean public UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager( + // @formatter:off User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build() + // @formatter:on ); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityAntMatchersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityAntMatchersTests.java index 1fc9ec3dbf..9690aed1dd 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityAntMatchersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityAntMatchersTests.java @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletResponse; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpMethod; @@ -35,15 +35,20 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur import org.springframework.security.web.FilterChainProxy; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Rob Winch * */ public class HttpSecurityAntMatchersTests { + AnnotationConfigWebApplicationContext context; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Autowired @@ -51,15 +56,15 @@ public class HttpSecurityAntMatchersTests { @Before public void setup() { - request = new MockHttpServletRequest("GET", ""); - response = new MockHttpServletResponse(); - chain = new MockFilterChain(); + this.request = new MockHttpServletRequest("GET", ""); + this.response = new MockHttpServletResponse(); + this.chain = new MockFilterChain(); } @After public void cleanup() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @@ -67,47 +72,60 @@ public class HttpSecurityAntMatchersTests { @Test public void antMatchersMethodAndNoPatterns() throws Exception { loadConfig(AntMatchersNoPatternsConfig.class); - request.setMethod("POST"); - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); - } - - @EnableWebSecurity - @Configuration - static class AntMatchersNoPatternsConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .requestMatchers() - .antMatchers(HttpMethod.POST) - .and() - .authorizeRequests() - .anyRequest().denyAll(); - } - - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication(); - } + this.request.setMethod("POST"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } // SEC-3135 @Test public void antMatchersMethodAndEmptyPatterns() throws Exception { loadConfig(AntMatchersEmptyPatternsConfig.class); - request.setMethod("POST"); + this.request.setMethod("POST"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } - springSecurityFilterChain.doFilter(request, response, chain); + public void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.refresh(); + this.context.getAutowireCapableBeanFactory().autowireBean(this); + } + + @EnableWebSecurity + @Configuration + static class AntMatchersNoPatternsConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .requestMatchers() + .antMatchers(HttpMethod.POST) + .and() + .authorizeRequests() + .anyRequest().denyAll(); + // @formatter:on + } + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication(); + // @formatter:on + } - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration static class AntMatchersEmptyPatternsConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .requestMatchers() .antMatchers("/never/") @@ -115,22 +133,17 @@ public class HttpSecurityAntMatchersTests { .and() .authorizeRequests() .anyRequest().denyAll(); + // @formatter:on } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication(); + // @formatter:on } + } - public void loadConfig(Class... configs) { - context = new AnnotationConfigWebApplicationContext(); - context.register(configs); - context.refresh(); - - context.getAutowireCapableBeanFactory().autowireBean(this); - } - - } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java index 315a85a74b..da18813b49 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityLogoutTests.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.config.annotation.web.configurers; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.mock.web.MockFilterChain; @@ -36,15 +36,20 @@ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Rob Winch * */ public class HttpSecurityLogoutTests { + AnnotationConfigWebApplicationContext context; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Autowired @@ -52,15 +57,15 @@ public class HttpSecurityLogoutTests { @Before public void setup() { - request = new MockHttpServletRequest("GET", ""); - response = new MockHttpServletResponse(); - chain = new MockFilterChain(); + this.request = new MockHttpServletRequest("GET", ""); + this.response = new MockHttpServletResponse(); + this.chain = new MockFilterChain(); } @After public void cleanup() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @@ -68,43 +73,45 @@ public class HttpSecurityLogoutTests { @Test public void clearAuthenticationFalse() throws Exception { loadConfig(ClearAuthenticationFalseConfig.class); - SecurityContext currentContext = SecurityContextHolder.createEmptyContext(); currentContext.setAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); - - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, currentContext); - request.setMethod("POST"); - request.setServletPath("/logout"); - - springSecurityFilterChain.doFilter(request, response, chain); - + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + currentContext); + this.request.setMethod("POST"); + this.request.setServletPath("/logout"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); assertThat(currentContext.getAuthentication()).isNotNull(); } + public void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.refresh(); + this.context.getAutowireCapableBeanFactory().autowireBean(this); + } + @EnableWebSecurity @Configuration static class ClearAuthenticationFalseConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .csrf().disable() .logout() .clearAuthentication(false); + // @formatter:on } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication(); + // @formatter:on } + } - public void loadConfig(Class... configs) { - context = new AnnotationConfigWebApplicationContext(); - context.register(configs); - context.refresh(); - - context.getAutowireCapableBeanFactory().autowireBean(this); - } - - } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java index fec73febcc..ca4ae5ff5b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecurityRequestMatchersTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletResponse; @@ -47,10 +48,13 @@ import static org.springframework.security.config.Customizer.withDefaults; * */ public class HttpSecurityRequestMatchersTests { + AnnotationConfigWebApplicationContext context; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Autowired @@ -74,41 +78,108 @@ public class HttpSecurityRequestMatchersTests { @Test public void mvcMatcher() throws Exception { loadConfig(MvcMatcherConfig.class, LegacyMvcMatchingConfig.class); - this.request.setServletPath("/path"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath("/path.html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath("/path/"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @Test public void mvcMatcherGetFiltersNoUnsupportedMethodExceptionFromDummyRequest() { loadConfig(MvcMatcherConfig.class); + assertThat(this.springSecurityFilterChain.getFilters("/path")).isNotEmpty(); + } - assertThat(springSecurityFilterChain.getFilters("/path")).isNotEmpty(); + @Test + public void requestMatchersMvcMatcher() throws Exception { + loadConfig(RequestMatchersMvcMatcherConfig.class, LegacyMvcMatchingConfig.class); + this.request.setServletPath("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void requestMatchersWhenMvcMatcherInLambdaThenPathIsSecured() throws Exception { + loadConfig(RequestMatchersMvcMatcherInLambdaConfig.class, LegacyMvcMatchingConfig.class); + this.request.setServletPath("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } + + @Test + public void requestMatchersMvcMatcherServletPath() throws Exception { + loadConfig(RequestMatchersMvcMatcherServeltPathConfig.class); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath(""); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/other"); + this.request.setRequestURI("/other/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + @Test + public void requestMatcherWhensMvcMatcherServletPathInLambdaThenPathIsSecured() throws Exception { + loadConfig(RequestMatchersMvcMatcherServletPathInLambdaConfig.class); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath(""); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/other"); + this.request.setRequestURI("/other/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + public void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.setServletContext(new MockServletContext()); + this.context.refresh(); + this.context.getAutowireCapableBeanFactory().autowireBean(this); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -130,44 +201,21 @@ public class HttpSecurityRequestMatchersTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestMatchersMvcMatcher() throws Exception { - loadConfig(RequestMatchersMvcMatcherConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setServletPath("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class RequestMatchersMvcMatcherConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -191,54 +239,31 @@ public class HttpSecurityRequestMatchersTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestMatchersWhenMvcMatcherInLambdaThenPathIsSecured() throws Exception { - loadConfig(RequestMatchersMvcMatcherInLambdaConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setServletPath("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); } @EnableWebSecurity @Configuration @EnableWebMvc static class RequestMatchersMvcMatcherInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .requestMatchers(requestMatchers -> + .requestMatchers((requestMatchers) -> requestMatchers .mvcMatchers("/path") ) .httpBasic(withDefaults()) - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().denyAll() ); @@ -247,47 +272,21 @@ public class HttpSecurityRequestMatchersTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestMatchersMvcMatcherServletPath() throws Exception { - loadConfig(RequestMatchersMvcMatcherServeltPathConfig.class); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath(""); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration @EnableWebMvc - static class RequestMatchersMvcMatcherServeltPathConfig - extends WebSecurityConfigurerAdapter { + static class RequestMatchersMvcMatcherServeltPathConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -312,58 +311,32 @@ public class HttpSecurityRequestMatchersTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void requestMatcherWhensMvcMatcherServletPathInLambdaThenPathIsSecured() throws Exception { - loadConfig(RequestMatchersMvcMatcherServletPathInLambdaConfig.class); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath(""); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/other"); - this.request.setRequestURI("/other/path"); - - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration @EnableWebMvc - static class RequestMatchersMvcMatcherServletPathInLambdaConfig - extends WebSecurityConfigurerAdapter { + static class RequestMatchersMvcMatcherServletPathInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .requestMatchers(requestMatchers -> + .requestMatchers((requestMatchers) -> requestMatchers .mvcMatchers("/path").servletPath("/spring") .mvcMatchers("/never-match") ) .httpBasic(withDefaults()) - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().denyAll() ); @@ -372,27 +345,24 @@ public class HttpSecurityRequestMatchersTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } + } @Configuration static class LegacyMvcMatchingConfig implements WebMvcConfigurer { + @Override public void configurePathMatch(PathMatchConfigurer configurer) { configurer.setUseSuffixPatternMatch(true); } + } - public void loadConfig(Class... configs) { - this.context = new AnnotationConfigWebApplicationContext(); - this.context.register(configs); - this.context.setServletContext(new MockServletContext()); - this.context.refresh(); - - this.context.getAutowireCapableBeanFactory().autowireBean(this); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/Issue55Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/Issue55Tests.java index 62aeedce86..bb9e239a48 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/Issue55Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/Issue55Tests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.lang.reflect.InvocationTargetException; import java.util.List; + import javax.servlet.Filter; import org.junit.Rule; @@ -38,7 +40,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; import org.springframework.stereotype.Component; -import static org.assertj.core.api.Java6Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -54,91 +56,22 @@ public class Issue55Tests { TestingAuthenticationToken token = new TestingAuthenticationToken("test", "this"); this.spring.register(WebSecurityConfigurerAdapterDefaultsAuthManagerConfig.class); this.spring.getContext().getBean(FilterChainProxy.class); - FilterSecurityInterceptor filter = (FilterSecurityInterceptor) findFilter(FilterSecurityInterceptor.class, 0); assertThat(filter.getAuthenticationManager().authenticate(token)).isEqualTo(CustomAuthenticationManager.RESULT); } - - @EnableWebSecurity - static class WebSecurityConfigurerAdapterDefaultsAuthManagerConfig { - @Component - public static class WebSecurityAdapter extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().hasRole("USER"); - // @formatter:on - } - } - - @Configuration - public static class AuthenticationManagerConfiguration { - @Bean - public AuthenticationManager authenticationManager() throws Exception { - return new CustomAuthenticationManager(); - } - } - } - @Test - public void multiHttpWebSecurityConfigurerAdapterDefaultsToAutowired() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + public void multiHttpWebSecurityConfigurerAdapterDefaultsToAutowired() + throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { TestingAuthenticationToken token = new TestingAuthenticationToken("test", "this"); this.spring.register(MultiWebSecurityConfigurerAdapterDefaultsAuthManagerConfig.class); this.spring.getContext().getBean(FilterChainProxy.class); - FilterSecurityInterceptor filter = (FilterSecurityInterceptor) findFilter(FilterSecurityInterceptor.class, 0); assertThat(filter.getAuthenticationManager().authenticate(token)).isEqualTo(CustomAuthenticationManager.RESULT); - - FilterSecurityInterceptor secondFilter = (FilterSecurityInterceptor) findFilter(FilterSecurityInterceptor.class, 1); - assertThat(secondFilter.getAuthenticationManager().authenticate(token)).isEqualTo(CustomAuthenticationManager.RESULT); - } - - @EnableWebSecurity - static class MultiWebSecurityConfigurerAdapterDefaultsAuthManagerConfig { - @Component - @Order(1) - public static class ApiWebSecurityAdapter extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http.antMatcher("/api/**") - .authorizeRequests() - .anyRequest().hasRole("USER"); - // @formatter:on - } - } - - @Component - public static class WebSecurityAdapter extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().hasRole("USER"); - // @formatter:on - } - } - - @Configuration - public static class AuthenticationManagerConfiguration { - @Bean - public AuthenticationManager authenticationManager() throws Exception { - return new CustomAuthenticationManager(); - } - } - } - - static class CustomAuthenticationManager implements AuthenticationManager { - static Authentication RESULT = new TestingAuthenticationToken("test", "this", "ROLE_USER"); - - public Authentication authenticate(Authentication authentication) throws AuthenticationException { - return RESULT; - } + FilterSecurityInterceptor secondFilter = (FilterSecurityInterceptor) findFilter(FilterSecurityInterceptor.class, + 1); + assertThat(secondFilter.getAuthenticationManager().authenticate(token)) + .isEqualTo(CustomAuthenticationManager.RESULT); } Filter findFilter(Class filter, int index) { @@ -154,4 +87,89 @@ public class Issue55Tests { SecurityFilterChain filterChain(int index) { return this.spring.getContext().getBean(FilterChainProxy.class).getFilterChains().get(index); } + + @EnableWebSecurity + static class WebSecurityConfigurerAdapterDefaultsAuthManagerConfig { + + @Component + public static class WebSecurityAdapter extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER"); + // @formatter:on + } + + } + + @Configuration + public static class AuthenticationManagerConfiguration { + + @Bean + public AuthenticationManager authenticationManager() throws Exception { + return new CustomAuthenticationManager(); + } + + } + + } + + @EnableWebSecurity + static class MultiWebSecurityConfigurerAdapterDefaultsAuthManagerConfig { + + @Component + @Order(1) + public static class ApiWebSecurityAdapter extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http.antMatcher("/api/**") + .authorizeRequests() + .anyRequest().hasRole("USER"); + // @formatter:on + } + + } + + @Component + public static class WebSecurityAdapter extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER"); + // @formatter:on + } + + } + + @Configuration + public static class AuthenticationManagerConfiguration { + + @Bean + public AuthenticationManager authenticationManager() throws Exception { + return new CustomAuthenticationManager(); + } + + } + + } + + static class CustomAuthenticationManager implements AuthenticationManager { + + static Authentication RESULT = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + return RESULT; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurerTests.java index 57de13f087..2562ede8a8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/JeeConfigurerTests.java @@ -16,8 +16,11 @@ package org.springframework.security.config.annotation.web.configurers; +import java.security.Principal; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.ObjectPostProcessor; @@ -28,17 +31,17 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.AuthenticationUserDetailsService; import org.springframework.security.core.userdetails.User; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.security.web.authentication.preauth.j2ee.J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource; import org.springframework.security.web.authentication.preauth.j2ee.J2eePreAuthenticatedProcessingFilter; import org.springframework.test.web.servlet.MockMvc; - -import java.security.Principal; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -60,7 +63,6 @@ public class JeeConfigurerTests { public void configureWhenRegisteringObjectPostProcessorThenInvokedOnJ2eePreAuthenticatedProcessingFilter() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); - verify(ObjectPostProcessorConfig.objectPostProcessor) .postProcess(any(J2eePreAuthenticatedProcessingFilter.class)); } @@ -69,13 +71,88 @@ public class JeeConfigurerTests { public void configureWhenRegisteringObjectPostProcessorThenInvokedOnJ2eeBasedPreAuthenticatedWebAuthenticationDetailsSource() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); - verify(ObjectPostProcessorConfig.objectPostProcessor) .postProcess(any(J2eeBasedPreAuthenticatedWebAuthenticationDetailsSource.class)); } + @Test + public void jeeWhenInvokedTwiceThenUsesOriginalMappableRoles() throws Exception { + this.spring.register(InvokeTwiceDoesNotOverride.class).autowire(); + Principal user = mock(Principal.class); + given(user.getName()).willReturn("user"); + // @formatter:off + MockHttpServletRequestBuilder authRequest = get("/") + .principal(user) + .with((request) -> { + request.addUserRole("ROLE_ADMIN"); + request.addUserRole("ROLE_USER"); + return request; + }); + // @formatter:on + this.mvc.perform(authRequest).andExpect(authenticated().withRoles("USER")); + } + + @Test + public void requestWhenJeeMappableRolesInLambdaThenAuthenticatedWithMappableRoles() throws Exception { + this.spring.register(JeeMappableRolesConfig.class).autowire(); + Principal user = mock(Principal.class); + given(user.getName()).willReturn("user"); + // @formatter:off + MockHttpServletRequestBuilder authRequest = get("/") + .principal(user) + .with((request) -> { + request.addUserRole("ROLE_ADMIN"); + request.addUserRole("ROLE_USER"); + return request; + }); + // @formatter:on + this.mvc.perform(authRequest).andExpect(authenticated().withRoles("USER")); + } + + @Test + public void requestWhenJeeMappableAuthoritiesInLambdaThenAuthenticatedWithMappableAuthorities() throws Exception { + this.spring.register(JeeMappableAuthoritiesConfig.class).autowire(); + Principal user = mock(Principal.class); + given(user.getName()).willReturn("user"); + // @formatter:off + MockHttpServletRequestBuilder authRequest = get("/") + .principal(user) + .with((request) -> { + request.addUserRole("ROLE_ADMIN"); + request.addUserRole("ROLE_USER"); + return request; + }); + // @formatter:on + SecurityMockMvcResultMatchers.AuthenticatedMatcher authenticatedAsUser = authenticated() + .withAuthorities(AuthorityUtils.createAuthorityList("ROLE_USER")); + this.mvc.perform(authRequest).andExpect(authenticatedAsUser); + } + + @Test + public void requestWhenCustomAuthenticatedUserDetailsServiceInLambdaThenCustomAuthenticatedUserDetailsServiceUsed() + throws Exception { + this.spring.register(JeeCustomAuthenticatedUserDetailsServiceConfig.class).autowire(); + Principal user = mock(Principal.class); + User userDetails = new User("user", "N/A", true, true, true, true, + AuthorityUtils.createAuthorityList("ROLE_USER")); + given(user.getName()).willReturn("user"); + given(JeeCustomAuthenticatedUserDetailsServiceConfig.authenticationUserDetailsService.loadUserDetails(any())) + .willReturn(userDetails); + // @formatter:off + MockHttpServletRequestBuilder authRequest = get("/") + .principal(user) + .with((request) -> { + request.addUserRole("ROLE_ADMIN"); + request.addUserRole("ROLE_USER"); + return request; + }); + // @formatter:on + this.mvc.perform(authRequest).andExpect(authenticated().withRoles("USER")); + } + @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor; @Override @@ -90,33 +167,21 @@ public class JeeConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void jeeWhenInvokedTwiceThenUsesOriginalMappableRoles() throws Exception { - this.spring.register(InvokeTwiceDoesNotOverride.class).autowire(); - Principal user = mock(Principal.class); - when(user.getName()).thenReturn("user"); - - this.mvc.perform(get("/") - .principal(user) - .with(request -> { - request.addUserRole("ROLE_ADMIN"); - request.addUserRole("ROLE_USER"); - return request; - })) - .andExpect(authenticated().withRoles("USER")); } @EnableWebSecurity static class InvokeTwiceDoesNotOverride extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -127,56 +192,27 @@ public class JeeConfigurerTests { .jee(); // @formatter:on } - } - @Test - public void requestWhenJeeMappableRolesInLambdaThenAuthenticatedWithMappableRoles() throws Exception { - this.spring.register(JeeMappableRolesConfig.class).autowire(); - Principal user = mock(Principal.class); - when(user.getName()).thenReturn("user"); - - this.mvc.perform(get("/") - .principal(user) - .with(request -> { - request.addUserRole("ROLE_ADMIN"); - request.addUserRole("ROLE_USER"); - return request; - })) - .andExpect(authenticated().withRoles("USER")); } @EnableWebSecurity public static class JeeMappableRolesConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) - .jee(jee -> + .jee((jee) -> jee .mappableRoles("USER") ); // @formatter:on } - } - @Test - public void requestWhenJeeMappableAuthoritiesInLambdaThenAuthenticatedWithMappableAuthorities() throws Exception { - this.spring.register(JeeMappableAuthoritiesConfig.class).autowire(); - Principal user = mock(Principal.class); - when(user.getName()).thenReturn("user"); - - this.mvc.perform(get("/") - .principal(user) - .with(request -> { - request.addUserRole("ROLE_ADMIN"); - request.addUserRole("ROLE_USER"); - return request; - })) - .andExpect(authenticated().withAuthorities(AuthorityUtils.createAuthorityList("ROLE_USER"))); } @EnableWebSecurity @@ -186,57 +222,40 @@ public class JeeConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) - .jee(jee -> + .jee((jee) -> jee .mappableAuthorities("ROLE_USER") ); // @formatter:on } - } - @Test - public void requestWhenCustomAuthenticatedUserDetailsServiceInLambdaThenCustomAuthenticatedUserDetailsServiceUsed() - throws Exception { - this.spring.register(JeeCustomAuthenticatedUserDetailsServiceConfig.class).autowire(); - Principal user = mock(Principal.class); - User userDetails = new User("user", "N/A", true, true, true, true, - AuthorityUtils.createAuthorityList("ROLE_USER")); - when(user.getName()).thenReturn("user"); - when(JeeCustomAuthenticatedUserDetailsServiceConfig.authenticationUserDetailsService.loadUserDetails(any())) - .thenReturn(userDetails); - - this.mvc.perform(get("/") - .principal(user) - .with(request -> { - request.addUserRole("ROLE_ADMIN"); - request.addUserRole("ROLE_USER"); - return request; - })) - .andExpect(authenticated().withRoles("USER")); } @EnableWebSecurity public static class JeeCustomAuthenticatedUserDetailsServiceConfig extends WebSecurityConfigurerAdapter { - static AuthenticationUserDetailsService authenticationUserDetailsService = - mock(AuthenticationUserDetailsService.class); + + static AuthenticationUserDetailsService authenticationUserDetailsService = mock( + AuthenticationUserDetailsService.class); @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) - .jee(jee -> + .jee((jee) -> jee .authenticatedUserDetailsService(authenticationUserDetailsService) ); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerClearSiteDataTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerClearSiteDataTests.java index a46fdaba66..0587442d9f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerClearSiteDataTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerClearSiteDataTests.java @@ -29,22 +29,20 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.authentication.logout.HeaderWriterLogoutHandler; import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter; +import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; -import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.CACHE; -import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.COOKIES; -import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.EXECUTION_CONTEXTS; -import static org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive.STORAGE; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; /** * - * Tests for {@link HeaderWriterLogoutHandler} that passing {@link ClearSiteDataHeaderWriter} - * implementation. + * Tests for {@link HeaderWriterLogoutHandler} that passing + * {@link ClearSiteDataHeaderWriter} implementation. * * @author Rafiullah Hamedy * @@ -55,8 +53,8 @@ public class LogoutConfigurerClearSiteDataTests { private static final String CLEAR_SITE_DATA_HEADER = "Clear-Site-Data"; - private static final ClearSiteDataHeaderWriter.Directive[] SOURCE = - { CACHE, COOKIES, STORAGE, EXECUTION_CONTEXTS }; + private static final Directive[] SOURCE = { Directive.CACHE, Directive.COOKIES, Directive.STORAGE, + Directive.EXECUTION_CONTEXTS }; private static final String HEADER_VALUE = "\"cache\", \"cookies\", \"storage\", \"executionContexts\""; @@ -70,36 +68,38 @@ public class LogoutConfigurerClearSiteDataTests { @WithMockUser public void logoutWhenRequestTypeGetThenHeaderNotPresentt() throws Exception { this.spring.register(HttpLogoutConfig.class).autowire(); - - this.mvc.perform(get("/logout").secure(true).with(csrf())) - .andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER)); + MockHttpServletRequestBuilder logoutRequest = get("/logout").secure(true).with(csrf()); + this.mvc.perform(logoutRequest).andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER)); } @Test @WithMockUser public void logoutWhenRequestTypePostAndNotSecureThenHeaderNotPresent() throws Exception { this.spring.register(HttpLogoutConfig.class).autowire(); - - this.mvc.perform(post("/logout").with(csrf())) - .andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER)); + MockHttpServletRequestBuilder logoutRequest = post("/logout").with(csrf()); + this.mvc.perform(logoutRequest).andExpect(header().doesNotExist(CLEAR_SITE_DATA_HEADER)); } @Test @WithMockUser public void logoutWhenRequestTypePostAndSecureThenHeaderIsPresent() throws Exception { this.spring.register(HttpLogoutConfig.class).autowire(); - - this.mvc.perform(post("/logout").secure(true).with(csrf())) - .andExpect(header().stringValues(CLEAR_SITE_DATA_HEADER, HEADER_VALUE)); + MockHttpServletRequestBuilder logoutRequest = post("/logout").secure(true).with(csrf()); + this.mvc.perform(logoutRequest).andExpect(header().stringValues(CLEAR_SITE_DATA_HEADER, HEADER_VALUE)); } @EnableWebSecurity static class HttpLogoutConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .logout() .addLogoutHandler(new HeaderWriterLogoutHandler(new ClearSiteDataHeaderWriter(SOURCE))); + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerTests.java index bb24708390..610005145d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/LogoutConfigurerTests.java @@ -19,6 +19,7 @@ package org.springframework.security.config.annotation.web.configurers; import org.apache.http.HttpHeaders; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -34,8 +35,9 @@ import org.springframework.security.web.authentication.logout.LogoutFilter; import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -65,13 +67,241 @@ public class LogoutConfigurerTests { @Test public void configureWhenDefaultLogoutSuccessHandlerForHasNullLogoutHandlerThenException() { - assertThatThrownBy(() -> this.spring.register(NullLogoutSuccessHandlerConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullLogoutSuccessHandlerConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenDefaultLogoutSuccessHandlerForHasNullLogoutHandlerInLambdaThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullLogoutSuccessHandlerInLambdaConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenDefaultLogoutSuccessHandlerForHasNullMatcherThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullMatcherConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenDefaultLogoutSuccessHandlerForHasNullMatcherInLambdaThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullMatcherInLambdaConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnLogoutFilter() { + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(LogoutFilter.class)); + } + + @Test + public void logoutWhenInvokedTwiceThenUsesOriginalLogoutUrl() throws Exception { + this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); + MockHttpServletRequestBuilder logoutRequest = post("/custom/logout").with(csrf()); + // @formatter:off + this.mvc.perform(logoutRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + // SEC-2311 + @Test + public void logoutWhenGetRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenPostRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledConfig.class).autowire(); + // @formatter:off + this.mvc.perform(post("/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenPutRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledConfig.class).autowire(); + // @formatter:off + this.mvc.perform(put("/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenDeleteRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledConfig.class).autowire(); + // @formatter:off + this.mvc.perform(delete("/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenGetRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/custom/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenPostRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); + // @formatter:off + this.mvc.perform(post("/custom/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenPutRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); + // @formatter:off + this.mvc.perform(put("/custom/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenDeleteRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); + // @formatter:off + this.mvc.perform(delete("/custom/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + @Test + public void logoutWhenCustomLogoutUrlInLambdaThenRedirectsToLogin() throws Exception { + this.spring.register(CsrfDisabledAndCustomLogoutInLambdaConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/custom/logout")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + // SEC-3170 + @Test + public void configureWhenLogoutHandlerNullThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullLogoutHandlerConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test + public void configureWhenLogoutHandlerNullInLambdaThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NullLogoutHandlerInLambdaConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + // SEC-3170 + @Test + public void rememberMeWhenRememberMeServicesNotLogoutHandlerThenRedirectsToLogin() throws Exception { + this.spring.register(RememberMeNoLogoutHandler.class).autowire(); + this.mvc.perform(post("/logout").with(csrf())).andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + } + + @Test + public void logoutWhenAcceptTextHtmlThenRedirectsToLogin() throws Exception { + this.spring.register(BasicSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder logoutRequest = post("/logout") + .with(csrf()) + .with(user("user")) + .header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML_VALUE); + this.mvc.perform(logoutRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + // gh-3282 + @Test + public void logoutWhenAcceptApplicationJsonThenReturnsStatusNoContent() throws Exception { + this.spring.register(BasicSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = post("/logout") + .with(csrf()) + .with(user("user")) + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE); + // @formatter:on + this.mvc.perform(request).andExpect(status().isNoContent()); + } + + // gh-4831 + @Test + public void logoutWhenAcceptAllThenReturnsStatusNoContent() throws Exception { + this.spring.register(BasicSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder logoutRequest = post("/logout") + .with(csrf()) + .with(user("user")) + .header(HttpHeaders.ACCEPT, MediaType.ALL_VALUE); + // @formatter:on + this.mvc.perform(logoutRequest).andExpect(status().isNoContent()); + } + + // gh-3902 + @Test + public void logoutWhenAcceptFromChromeThenRedirectsToLogin() throws Exception { + this.spring.register(BasicSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = post("/logout") + .with(csrf()) + .with(user("user")) + .header(HttpHeaders.ACCEPT, "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8"); + this.mvc.perform(request) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?logout")); + // @formatter:on + } + + // gh-3997 + @Test + public void logoutWhenXMLHttpRequestThenReturnsStatusNoContent() throws Exception { + this.spring.register(BasicSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = post("/logout") + .with(csrf()) + .with(user("user")) + .header(HttpHeaders.ACCEPT, "text/html,application/json") + .header("X-Requested-With", "XMLHttpRequest"); + // @formatter:on + this.mvc.perform(request).andExpect(status().isNoContent()); + } + + @Test + public void logoutWhenDisabledThenLogoutUrlNotFound() throws Exception { + this.spring.register(LogoutDisabledConfig.class).autowire(); + this.mvc.perform(post("/logout").with(csrf())).andExpect(status().isNotFound()); } @EnableWebSecurity static class NullLogoutSuccessHandlerConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -80,37 +310,27 @@ public class LogoutConfigurerTests { .defaultLogoutSuccessHandlerFor(null, mock(RequestMatcher.class)); // @formatter:on } - } - @Test - public void configureWhenDefaultLogoutSuccessHandlerForHasNullLogoutHandlerInLambdaThenException() { - assertThatThrownBy(() -> this.spring.register(NullLogoutSuccessHandlerInLambdaConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class NullLogoutSuccessHandlerInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .logout(logout -> + .logout((logout) -> logout.defaultLogoutSuccessHandlerFor(null, mock(RequestMatcher.class)) ); // @formatter:on } - } - @Test - public void configureWhenDefaultLogoutSuccessHandlerForHasNullMatcherThenException() { - assertThatThrownBy(() -> this.spring.register(NullMatcherConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class NullMatcherConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -119,38 +339,27 @@ public class LogoutConfigurerTests { .defaultLogoutSuccessHandlerFor(mock(LogoutSuccessHandler.class), null); // @formatter:on } - } - @Test - public void configureWhenDefaultLogoutSuccessHandlerForHasNullMatcherInLambdaThenException() { - assertThatThrownBy(() -> this.spring.register(NullMatcherInLambdaConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class NullMatcherInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .logout(logout -> + .logout((logout) -> logout.defaultLogoutSuccessHandlerFor(mock(LogoutSuccessHandler.class), null) ); // @formatter:on } - } - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnLogoutFilter() { - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(LogoutFilter.class)); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -165,27 +374,21 @@ public class LogoutConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void logoutWhenInvokedTwiceThenUsesOriginalLogoutUrl() throws Exception { - this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); - - this.mvc.perform(post("/custom/logout") - .with(csrf())) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); } @EnableWebSecurity static class DuplicateDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -204,43 +407,7 @@ public class LogoutConfigurerTests { .inMemoryAuthentication(); // @formatter:on } - } - // SEC-2311 - @Test - public void logoutWhenGetRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledConfig.class).autowire(); - - this.mvc.perform(get("/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - @Test - public void logoutWhenPostRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledConfig.class).autowire(); - - this.mvc.perform(post("/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - @Test - public void logoutWhenPutRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledConfig.class).autowire(); - - this.mvc.perform(put("/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - @Test - public void logoutWhenDeleteRequestAndCsrfDisabledThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledConfig.class).autowire(); - - this.mvc.perform(delete("/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); } @EnableWebSecurity @@ -255,42 +422,7 @@ public class LogoutConfigurerTests { .logout(); // @formatter:on } - } - @Test - public void logoutWhenGetRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); - - this.mvc.perform(get("/custom/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - @Test - public void logoutWhenPostRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); - - this.mvc.perform(post("/custom/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - @Test - public void logoutWhenPutRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); - - this.mvc.perform(put("/custom/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - @Test - public void logoutWhenDeleteRequestAndCsrfDisabledAndCustomLogoutUrlThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledAndCustomLogoutConfig.class).autowire(); - - this.mvc.perform(delete("/custom/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); } @EnableWebSecurity @@ -306,15 +438,7 @@ public class LogoutConfigurerTests { .logoutUrl("/custom/logout"); // @formatter:on } - } - @Test - public void logoutWhenCustomLogoutUrlInLambdaThenRedirectsToLogin() throws Exception { - this.spring.register(CsrfDisabledAndCustomLogoutInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/custom/logout")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); } @EnableWebSecurity @@ -326,21 +450,15 @@ public class LogoutConfigurerTests { http .csrf() .disable() - .logout(logout -> logout.logoutUrl("/custom/logout")); + .logout((logout) -> logout.logoutUrl("/custom/logout")); // @formatter:on } - } - // SEC-3170 - @Test - public void configureWhenLogoutHandlerNullThenException() { - assertThatThrownBy(() -> this.spring.register(NullLogoutHandlerConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class NullLogoutHandlerConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -349,39 +467,25 @@ public class LogoutConfigurerTests { .addLogoutHandler(null); // @formatter:on } - } - @Test - public void configureWhenLogoutHandlerNullInLambdaThenException() { - assertThatThrownBy(() -> this.spring.register(NullLogoutHandlerInLambdaConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity static class NullLogoutHandlerInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .logout(logout -> logout.addLogoutHandler(null)); + .logout((logout) -> logout.addLogoutHandler(null)); // @formatter:on } - } - // SEC-3170 - @Test - public void rememberMeWhenRememberMeServicesNotLogoutHandlerThenRedirectsToLogin() throws Exception { - this.spring.register(RememberMeNoLogoutHandler.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf())) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); } @EnableWebSecurity static class RememberMeNoLogoutHandler extends WebSecurityConfigurerAdapter { + static RememberMeServices REMEMBER_ME = mock(RememberMeServices.class); @Override @@ -392,84 +496,17 @@ public class LogoutConfigurerTests { .rememberMeServices(REMEMBER_ME); // @formatter:on } - } - @Test - public void logoutWhenAcceptTextHtmlThenRedirectsToLogin() throws Exception { - this.spring.register(BasicSecurityConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf()) - .with(user("user")) - .header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML_VALUE)) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - // gh-3282 - @Test - public void logoutWhenAcceptApplicationJsonThenReturnsStatusNoContent() throws Exception { - this.spring.register(BasicSecurityConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf()) - .with(user("user")) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)) - .andExpect(status().isNoContent()); - } - - // gh-4831 - @Test - public void logoutWhenAcceptAllThenReturnsStatusNoContent() throws Exception { - this.spring.register(BasicSecurityConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf()) - .with(user("user")) - .header(HttpHeaders.ACCEPT, MediaType.ALL_VALUE)) - .andExpect(status().isNoContent()); - } - - // gh-3902 - @Test - public void logoutWhenAcceptFromChromeThenRedirectsToLogin() throws Exception { - this.spring.register(BasicSecurityConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf()).with(user("user")) - .header(HttpHeaders.ACCEPT, "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?logout")); - } - - // gh-3997 - @Test - public void logoutWhenXMLHttpRequestThenReturnsStatusNoContent() throws Exception { - this.spring.register(BasicSecurityConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf()) - .with(user("user")) - .header(HttpHeaders.ACCEPT, "text/html,application/json") - .header("X-Requested-With", "XMLHttpRequest")) - .andExpect(status().isNoContent()); } @EnableWebSecurity static class BasicSecurityConfig extends WebSecurityConfigurerAdapter { - } - @Test - public void logoutWhenDisabledThenLogoutUrlNotFound() throws Exception { - this.spring.register(LogoutDisabledConfig.class).autowire(); - - this.mvc.perform(post("/logout") - .with(csrf())) - .andExpect(status().isNotFound()); } @EnableWebSecurity static class LogoutDisabledConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -478,5 +515,7 @@ public class LogoutConfigurerTests { .disable(); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceDebugTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceDebugTests.java index 13fdc95f49..4401d80f0b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceDebugTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceDebugTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import ch.qos.logback.classic.Level; @@ -45,6 +46,7 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder * @author Josh Cummings */ public class NamespaceDebugTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -60,10 +62,6 @@ public class NamespaceDebugTests { verify(appender, atLeastOnce()).doAppend(any(ILoggingEvent.class)); } - @EnableWebSecurity(debug=true) - static class DebugWebSecurity extends WebSecurityConfigurerAdapter { - } - @Test public void requestWhenDebugSetToFalseThenDoesNotLogDebugInformation() throws Exception { Appender appender = mockAppenderFor("Spring Security Debugger"); @@ -73,10 +71,6 @@ public class NamespaceDebugTests { verify(appender, never()).doAppend(any(ILoggingEvent.class)); } - @EnableWebSecurity - static class NoDebugWebSecurity extends WebSecurityConfigurerAdapter { - } - private Appender mockAppenderFor(String name) { Appender appender = mock(Appender.class); Logger logger = (Logger) LoggerFactory.getLogger(name); @@ -88,4 +82,15 @@ public class NamespaceDebugTests { private Class filterChainClass() { return this.spring.getContext().getBean("springSecurityFilterChain").getClass(); } + + @EnableWebSecurity(debug = true) + static class DebugWebSecurity extends WebSecurityConfigurerAdapter { + + } + + @EnableWebSecurity + static class NoDebugWebSecurity extends WebSecurityConfigurerAdapter { + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAnonymousTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAnonymousTests.java index f6bd08bf47..6ebf4a23dd 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAnonymousTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpAnonymousTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.Optional; @@ -39,7 +40,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <anonymous> attributes is present * * @author Rob Winch * @author Josh Cummings @@ -57,12 +58,36 @@ public class NamespaceHttpAnonymousTests { public void anonymousRequestWhenUsingDefaultAnonymousConfigurationThenUsesAnonymousAuthentication() throws Exception { this.spring.register(AnonymousConfig.class, AnonymousController.class).autowire(); - this.mvc.perform(get("/type")) - .andExpect(content().string(AnonymousAuthenticationToken.class.getSimpleName())); + this.mvc.perform(get("/type")).andExpect(content().string(AnonymousAuthenticationToken.class.getSimpleName())); + } + + @Test + public void anonymousRequestWhenDisablingAnonymousThenDenies() throws Exception { + this.spring.register(AnonymousDisabledConfig.class, AnonymousController.class).autowire(); + this.mvc.perform(get("/type")).andExpect(status().isForbidden()); + } + + @Test + public void requestWhenAnonymousThenSendsAnonymousConfiguredAuthorities() throws Exception { + this.spring.register(AnonymousGrantedAuthorityConfig.class, AnonymousController.class).autowire(); + this.mvc.perform(get("/type")).andExpect(content().string(AnonymousAuthenticationToken.class.getSimpleName())); + } + + @Test + public void anonymousRequestWhenAnonymousKeyConfiguredThenKeyIsUsed() throws Exception { + this.spring.register(AnonymousKeyConfig.class, AnonymousController.class).autowire(); + this.mvc.perform(get("/key")).andExpect(content().string(String.valueOf("AnonymousKeyConfig".hashCode()))); + } + + @Test + public void anonymousRequestWhenAnonymousUsernameConfiguredThenUsernameIsUsed() throws Exception { + this.spring.register(AnonymousUsernameConfig.class, AnonymousController.class).autowire(); + this.mvc.perform(get("/principal")).andExpect(content().string("AnonymousUsernameConfig")); } @EnableWebSecurity static class AnonymousConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -72,18 +97,12 @@ public class NamespaceHttpAnonymousTests { .anyRequest().denyAll(); // @formatter:on } - } - @Test - public void anonymousRequestWhenDisablingAnonymousThenDenies() - throws Exception { - this.spring.register(AnonymousDisabledConfig.class, AnonymousController.class).autowire(); - this.mvc.perform(get("/type")) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class AnonymousDisabledConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -95,6 +114,7 @@ public class NamespaceHttpAnonymousTests { // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth @@ -103,18 +123,12 @@ public class NamespaceHttpAnonymousTests { .withUser(PasswordEncodedUser.admin()); // @formatter:on } - } - @Test - public void requestWhenAnonymousThenSendsAnonymousConfiguredAuthorities() - throws Exception { - this.spring.register(AnonymousGrantedAuthorityConfig.class, AnonymousController.class).autowire(); - this.mvc.perform(get("/type")) - .andExpect(content().string(AnonymousAuthenticationToken.class.getSimpleName())); } @EnableWebSecurity static class AnonymousGrantedAuthorityConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -127,17 +141,12 @@ public class NamespaceHttpAnonymousTests { .authorities("ROLE_ANON"); // @formatter:on } - } - @Test - public void anonymousRequestWhenAnonymousKeyConfiguredThenKeyIsUsed() throws Exception { - this.spring.register(AnonymousKeyConfig.class, AnonymousController.class).autowire(); - this.mvc.perform(get("/key")) - .andExpect(content().string(String.valueOf("AnonymousKeyConfig".hashCode()))); } @EnableWebSecurity static class AnonymousKeyConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -149,17 +158,12 @@ public class NamespaceHttpAnonymousTests { .anonymous().key("AnonymousKeyConfig"); // @formatter:on } - } - @Test - public void anonymousRequestWhenAnonymousUsernameConfiguredThenUsernameIsUsed() throws Exception { - this.spring.register(AnonymousUsernameConfig.class, AnonymousController.class).autowire(); - this.mvc.perform(get("/principal")) - .andExpect(content().string("AnonymousUsernameConfig")); } @EnableWebSecurity static class AnonymousUsernameConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -171,39 +175,33 @@ public class NamespaceHttpAnonymousTests { .anonymous().principal("AnonymousUsernameConfig"); // @formatter:on } + } @RestController static class AnonymousController { + @GetMapping("/type") String type() { - return anonymousToken() - .map(AnonymousAuthenticationToken::getClass) - .map(Class::getSimpleName) - .orElse(null); + return anonymousToken().map(AnonymousAuthenticationToken::getClass).map(Class::getSimpleName).orElse(null); } @GetMapping("/key") String key() { - return anonymousToken() - .map(AnonymousAuthenticationToken::getKeyHash) - .map(String::valueOf) - .orElse(null); + return anonymousToken().map(AnonymousAuthenticationToken::getKeyHash).map(String::valueOf).orElse(null); } @GetMapping("/principal") String principal() { - return anonymousToken() - .map(AnonymousAuthenticationToken::getName) - .orElse(null); + return anonymousToken().map(AnonymousAuthenticationToken::getName).orElse(null); } Optional anonymousToken() { - return Optional.of(SecurityContextHolder.getContext()) - .map(SecurityContext::getAuthentication) - .filter(a -> a instanceof AnonymousAuthenticationToken) + return Optional.of(SecurityContextHolder.getContext()).map(SecurityContext::getAuthentication) + .filter((a) -> a instanceof AnonymousAuthenticationToken) .map(AnonymousAuthenticationToken.class::cast); } + } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpBasicTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpBasicTests.java index 9a66e356fd..ac85f2b260 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpBasicTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpBasicTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -34,6 +35,7 @@ import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -45,7 +47,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <http-basic> attributes is present * * @author Rob Winch * @author Josh Cummings @@ -64,74 +66,29 @@ public class NamespaceHttpBasicTests { @Test public void basicAuthenticationWhenUsingDefaultsThenMatchesNamespace() throws Exception { this.spring.register(HttpBasicConfig.class, UserConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); - - this.mvc.perform(get("/") - .with(httpBasic("user", "invalid"))) + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + MockHttpServletRequestBuilder requestWithInvalidPassword = get("/").with(httpBasic("user", "invalid")); + // @formatter:off + this.mvc.perform(requestWithInvalidPassword) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Basic realm=\"Realm\"")); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()); - } - - @Configuration - static class UserConfig { - @Bean - public UserDetailsService userDetailsService() { - return new InMemoryUserDetailsManager( - User.withDefaultPasswordEncoder() - .username("user") - .password("password") - .roles("USER") - .build() - ); - } - } - - @EnableWebSecurity - static class HttpBasicConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .httpBasic(); - } + // @formatter:on + MockHttpServletRequestBuilder requestWithValidPassword = get("/").with(httpBasic("user", "password")); + this.mvc.perform(requestWithValidPassword).andExpect(status().isNotFound()); } @Test public void basicAuthenticationWhenUsingDefaultsInLambdaThenMatchesNamespace() throws Exception { this.spring.register(HttpBasicLambdaConfig.class, UserConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()); - - this.mvc.perform(get("/") - .with(httpBasic("user", "invalid"))) + this.mvc.perform(get("/")).andExpect(status().isUnauthorized()); + MockHttpServletRequestBuilder requestWithInvalidPassword = get("/").with(httpBasic("user", "invalid")); + // @formatter:off + this.mvc.perform(requestWithInvalidPassword) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Basic realm=\"Realm\"")); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()); - } - - @EnableWebSecurity - static class HttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .anyRequest().hasRole("USER") - ) - .httpBasic(withDefaults()); - // @formatter:on - } + // @formatter:on + MockHttpServletRequestBuilder requestWithValidPassword = get("/").with(httpBasic("user", "password")); + this.mvc.perform(requestWithValidPassword).andExpect(status().isNotFound()); } /** @@ -140,48 +97,23 @@ public class NamespaceHttpBasicTests { @Test public void basicAuthenticationWhenUsingCustomRealmThenMatchesNamespace() throws Exception { this.spring.register(CustomHttpBasicConfig.class, UserConfig.class).autowire(); - - this.mvc.perform(get("/") - .with(httpBasic("user", "invalid"))) + MockHttpServletRequestBuilder requestWithInvalidPassword = get("/").with(httpBasic("user", "invalid")); + // @formatter:off + this.mvc.perform(requestWithInvalidPassword) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Basic realm=\"Custom Realm\"")); - } - - @EnableWebSecurity - static class CustomHttpBasicConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .httpBasic().realmName("Custom Realm"); - } + // @formatter:on } @Test public void basicAuthenticationWhenUsingCustomRealmInLambdaThenMatchesNamespace() throws Exception { this.spring.register(CustomHttpBasicLambdaConfig.class, UserConfig.class).autowire(); - - this.mvc.perform(get("/") - .with(httpBasic("user", "invalid"))) + MockHttpServletRequestBuilder requestWithInvalidPassword = get("/").with(httpBasic("user", "invalid")); + // @formatter:off + this.mvc.perform(requestWithInvalidPassword) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Basic realm=\"Custom Realm\"")); - } - - @EnableWebSecurity - static class CustomHttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .anyRequest().hasRole("USER") - ) - .httpBasic(httpBasicConfig -> httpBasicConfig.realmName("Custom Realm")); - // @formatter:on - } + // @formatter:on } /** @@ -190,135 +122,206 @@ public class NamespaceHttpBasicTests { @Test public void basicAuthenticationWhenUsingAuthenticationDetailsSourceRefThenMatchesNamespace() throws Exception { this.spring.register(AuthenticationDetailsSourceHttpBasicConfig.class, UserConfig.class).autowire(); - - AuthenticationDetailsSource source = - this.spring.getContext().getBean(AuthenticationDetailsSource.class); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))); - + AuthenticationDetailsSource source = this.spring.getContext() + .getBean(AuthenticationDetailsSource.class); + this.mvc.perform(get("/").with(httpBasic("user", "password"))); verify(source).buildDetails(any(HttpServletRequest.class)); } - @EnableWebSecurity - static class AuthenticationDetailsSourceHttpBasicConfig extends WebSecurityConfigurerAdapter { - AuthenticationDetailsSource authenticationDetailsSource = - mock(AuthenticationDetailsSource.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .httpBasic() - .authenticationDetailsSource(this.authenticationDetailsSource); - } - - @Bean - AuthenticationDetailsSource authenticationDetailsSource() { - return this.authenticationDetailsSource; - } - } - @Test public void basicAuthenticationWhenUsingAuthenticationDetailsSourceRefInLambdaThenMatchesNamespace() throws Exception { this.spring.register(AuthenticationDetailsSourceHttpBasicLambdaConfig.class, UserConfig.class).autowire(); - - AuthenticationDetailsSource source = - this.spring.getContext().getBean(AuthenticationDetailsSource.class); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))); - + AuthenticationDetailsSource source = this.spring.getContext() + .getBean(AuthenticationDetailsSource.class); + this.mvc.perform(get("/").with(httpBasic("user", "password"))); verify(source).buildDetails(any(HttpServletRequest.class)); } - @EnableWebSecurity - static class AuthenticationDetailsSourceHttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { - AuthenticationDetailsSource authenticationDetailsSource = - mock(AuthenticationDetailsSource.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .httpBasic(httpBasicConfig -> - httpBasicConfig.authenticationDetailsSource(this.authenticationDetailsSource)); - // @formatter:on - } - - @Bean - AuthenticationDetailsSource authenticationDetailsSource() { - return this.authenticationDetailsSource; - } - } - /** * http/http-basic@entry-point-ref */ @Test public void basicAuthenticationWhenUsingEntryPointRefThenMatchesNamespace() throws Exception { this.spring.register(EntryPointRefHttpBasicConfig.class, UserConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().is(999)); + this.mvc.perform(get("/").with(httpBasic("user", "invalid"))).andExpect(status().is(999)); + this.mvc.perform(get("/").with(httpBasic("user", "password"))).andExpect(status().isNotFound()); + } - this.mvc.perform(get("/")) - .andExpect(status().is(999)); + @Test + public void basicAuthenticationWhenUsingEntryPointRefInLambdaThenMatchesNamespace() throws Exception { + this.spring.register(EntryPointRefHttpBasicLambdaConfig.class, UserConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(status().is(999)); + this.mvc.perform(get("/").with(httpBasic("user", "invalid"))).andExpect(status().is(999)); + this.mvc.perform(get("/").with(httpBasic("user", "password"))).andExpect(status().isNotFound()); + } - this.mvc.perform(get("/") - .with(httpBasic("user", "invalid"))) - .andExpect(status().is(999)); + @Configuration + static class UserConfig { + + @Bean + UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager( + // @formatter:off + User.withDefaultPasswordEncoder() + .username("user") + .password("password") + .roles("USER") + .build() + // @formatter:on + ); + } + + } + + @EnableWebSecurity + static class HttpBasicConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .httpBasic(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class HttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .anyRequest().hasRole("USER") + ) + .httpBasic(withDefaults()); + // @formatter:on + } + + } + + @EnableWebSecurity + static class CustomHttpBasicConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .httpBasic().realmName("Custom Realm"); + // @formatter:on + } + + } + + @EnableWebSecurity + static class CustomHttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .anyRequest().hasRole("USER") + ) + .httpBasic((httpBasicConfig) -> httpBasicConfig.realmName("Custom Realm")); + // @formatter:on + } + + } + + @EnableWebSecurity + static class AuthenticationDetailsSourceHttpBasicConfig extends WebSecurityConfigurerAdapter { + + AuthenticationDetailsSource authenticationDetailsSource = mock( + AuthenticationDetailsSource.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .httpBasic() + .authenticationDetailsSource(this.authenticationDetailsSource); + // @formatter:on + } + + @Bean + AuthenticationDetailsSource authenticationDetailsSource() { + return this.authenticationDetailsSource; + } + + } + + @EnableWebSecurity + static class AuthenticationDetailsSourceHttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { + + AuthenticationDetailsSource authenticationDetailsSource = mock( + AuthenticationDetailsSource.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .httpBasic((httpBasicConfig) -> + httpBasicConfig.authenticationDetailsSource(this.authenticationDetailsSource)); + // @formatter:on + } + + @Bean + AuthenticationDetailsSource authenticationDetailsSource() { + return this.authenticationDetailsSource; + } - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()); } @EnableWebSecurity static class EntryPointRefHttpBasicConfig extends WebSecurityConfigurerAdapter { - AuthenticationEntryPoint authenticationEntryPoint = - (request, response, ex) -> response.setStatus(999); + + AuthenticationEntryPoint authenticationEntryPoint = (request, response, ex) -> response.setStatus(999); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") .and() .httpBasic() .authenticationEntryPoint(this.authenticationEntryPoint); + // @formatter:on } - } - @Test - public void basicAuthenticationWhenUsingEntryPointRefInLambdaThenMatchesNamespace() throws Exception { - this.spring.register(EntryPointRefHttpBasicLambdaConfig.class, UserConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().is(999)); - - this.mvc.perform(get("/") - .with(httpBasic("user", "invalid"))) - .andExpect(status().is(999)); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()); } @EnableWebSecurity static class EntryPointRefHttpBasicLambdaConfig extends WebSecurityConfigurerAdapter { - AuthenticationEntryPoint authenticationEntryPoint = - (request, response, ex) -> response.setStatus(999); + + AuthenticationEntryPoint authenticationEntryPoint = (request, response, ex) -> response.setStatus(999); @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) - .httpBasic(httpBasicConfig -> + .httpBasic((httpBasicConfig) -> httpBasicConfig.authenticationEntryPoint(this.authenticationEntryPoint)); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpCustomFilterTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpCustomFilterTests.java index b4d07cc1ab..f3623fd1a0 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpCustomFilterTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpCustomFilterTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; +package org.springframework.security.config.annotation.web.configurers; import java.io.IOException; import java.util.List; import java.util.stream.Collectors; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -38,6 +39,7 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.User; +import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.FilterChainProxy; @@ -47,7 +49,8 @@ import org.springframework.web.filter.OncePerRequestFilter; import static org.assertj.core.api.Assertions.assertThat; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <custom-filter> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -64,130 +67,166 @@ public class NamespaceHttpCustomFilterTests { assertThatFilters().containsSubsequence(CustomFilter.class, UsernamePasswordAuthenticationFilter.class); } - @EnableWebSecurity - static class CustomFilterBeforeConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .addFilterBefore(new CustomFilter(), UsernamePasswordAuthenticationFilter.class) - .formLogin(); - } - } - @Test public void getFiltersWhenFilterAddedAfterThenBehaviorMatchesNamespace() { this.spring.register(CustomFilterAfterConfig.class, UserDetailsServiceConfig.class).autowire(); assertThatFilters().containsSubsequence(UsernamePasswordAuthenticationFilter.class, CustomFilter.class); } - @EnableWebSecurity - static class CustomFilterAfterConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .addFilterAfter(new CustomFilter(), UsernamePasswordAuthenticationFilter.class) - .formLogin(); - } - } - @Test public void getFiltersWhenFilterAddedThenBehaviorMatchesNamespace() { this.spring.register(CustomFilterPositionConfig.class, UserDetailsServiceConfig.class).autowire(); assertThatFilters().containsExactly(CustomFilter.class); } - @EnableWebSecurity - static class CustomFilterPositionConfig extends WebSecurityConfigurerAdapter { - CustomFilterPositionConfig() { - // do not add the default filters to make testing easier - super(true); - } - - protected void configure(HttpSecurity http) { - http - // this works so long as the CustomFilter extends one of the standard filters - // if not, use addFilterBefore or addFilterAfter - .addFilter(new CustomFilter()); - } - } - - @Test public void getFiltersWhenFilterAddedAtPositionThenBehaviorMatchesNamespace() { this.spring.register(CustomFilterPositionAtConfig.class, UserDetailsServiceConfig.class).autowire(); assertThatFilters().containsExactly(OtherCustomFilter.class); } - @EnableWebSecurity - static class CustomFilterPositionAtConfig extends WebSecurityConfigurerAdapter { - CustomFilterPositionAtConfig() { - // do not add the default filters to make testing easier - super(true); - } - - protected void configure(HttpSecurity http) { - http - .addFilterAt(new OtherCustomFilter(), UsernamePasswordAuthenticationFilter.class); - } - } - @Test public void getFiltersWhenCustomAuthenticationManagerThenBehaviorMatchesNamespace() { this.spring.register(NoAuthenticationManagerInHttpConfigurationConfig.class).autowire(); assertThatFilters().startsWith(CustomFilter.class); } + private ListAssert> assertThatFilters() { + FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); + List> filters = filterChain.getFilters("/").stream().map(Object::getClass) + .collect(Collectors.toList()); + return assertThat(filters); + } + + @EnableWebSecurity + static class CustomFilterBeforeConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterBefore(new CustomFilter(), UsernamePasswordAuthenticationFilter.class) + .formLogin(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class CustomFilterAfterConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .addFilterAfter(new CustomFilter(), UsernamePasswordAuthenticationFilter.class) + .formLogin(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class CustomFilterPositionConfig extends WebSecurityConfigurerAdapter { + + CustomFilterPositionConfig() { + // do not add the default filters to make testing easier + super(true); + } + + @Override + protected void configure(HttpSecurity http) { + // @formatter:off + http + // this works so long as the CustomFilter extends one of the standard filters + // if not, use addFilterBefore or addFilterAfter + .addFilter(new CustomFilter()); + // @formatter:on + } + + } + + @EnableWebSecurity + static class CustomFilterPositionAtConfig extends WebSecurityConfigurerAdapter { + + CustomFilterPositionAtConfig() { + // do not add the default filters to make testing easier + super(true); + } + + @Override + protected void configure(HttpSecurity http) { + // @formatter:off + http + .addFilterAt(new OtherCustomFilter(), UsernamePasswordAuthenticationFilter.class); + // @formatter:on + } + + } + @EnableWebSecurity static class NoAuthenticationManagerInHttpConfigurationConfig extends WebSecurityConfigurerAdapter { + NoAuthenticationManagerInHttpConfigurationConfig() { super(true); } + @Override protected AuthenticationManager authenticationManager() { return new CustomAuthenticationManager(); } @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") .and() .addFilterBefore(new CustomFilter(), UsernamePasswordAuthenticationFilter.class); + // @formatter:on } + } @Configuration static class UserDetailsServiceConfig { + @Bean - public UserDetailsService userDetailsService() { - return new InMemoryUserDetailsManager( - User.withDefaultPasswordEncoder() - .username("user") - .password("password") - .roles("USER") - .build()); + UserDetailsService userDetailsService() { + // @formatter:off + UserDetails user = User.withDefaultPasswordEncoder() + .username("user") + .password("password") + .roles("USER") + .build(); + // @formatter:on + return new InMemoryUserDetailsManager(user); } + + } + + static class CustomFilter extends UsernamePasswordAuthenticationFilter { + } - static class CustomFilter extends UsernamePasswordAuthenticationFilter {} static class OtherCustomFilter extends OncePerRequestFilter { - protected void doFilterInternal( - HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { filterChain.doFilter(request, response); } + } static class CustomAuthenticationManager implements AuthenticationManager { - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { return null; } + } - private ListAssert> assertThatFilters() { - FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); - List> filters = filterChain.getFilters("/").stream() - .map(Object::getClass).collect(Collectors.toList()); - return assertThat(filters); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpExpressionHandlerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpExpressionHandlerTests.java index c031809554..8255188f71 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpExpressionHandlerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpExpressionHandlerTests.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; +package org.springframework.security.config.annotation.web.configurers; import java.security.Principal; @@ -46,7 +46,8 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <expression-handler> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -70,44 +71,53 @@ public class NamespaceHttpExpressionHandlerTests { verifyBean("expressionParser", ExpressionParser.class).parseExpression("hasRole('USER')"); } + private T verifyBean(String beanName, Class beanClass) { + return verify(this.spring.getContext().getBean(beanName, beanClass)); + } + @EnableWebMvc @EnableWebSecurity private static class ExpressionHandlerConfig extends WebSecurityConfigurerAdapter { - ExpressionHandlerConfig() {} + + ExpressionHandlerConfig() { + } @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("rod").password("password").roles("USER", "ADMIN"); + // @formatter:on } @Override protected void configure(HttpSecurity http) throws Exception { DefaultWebSecurityExpressionHandler handler = new DefaultWebSecurityExpressionHandler(); handler.setExpressionParser(expressionParser()); - + // @formatter:off http .authorizeRequests() .expressionHandler(handler) .anyRequest().access("hasRole('USER')"); + // @formatter:on } @Bean ExpressionParser expressionParser() { return spy(new SpelExpressionParser()); } + } @RestController private static class ExpressionHandlerController { + @GetMapping("/whoami") String whoami(Principal user) { return user.getName(); } + } - private T verifyBean(String beanName, Class beanClass) { - return verify(this.spring.getContext().getBean(beanName, beanClass)); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFirewallTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFirewallTests.java index 63876be7f9..61b7b91598 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFirewallTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFirewallTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -32,11 +33,12 @@ import org.springframework.security.web.firewall.HttpFirewall; import org.springframework.security.web.firewall.RequestRejectedException; import org.springframework.test.web.servlet.MockMvc; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <http-firewall> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -52,53 +54,59 @@ public class NamespaceHttpFirewallTests { @Test public void requestWhenPathContainsDoubleDotsThenBehaviorMatchesNamespace() { this.rule.register(HttpFirewallConfig.class).autowire(); - assertThatCode(() -> this.mvc.perform(get("/public/../private/"))) - .isInstanceOf(RequestRejectedException.class); + assertThatExceptionOfType(RequestRejectedException.class) + .isThrownBy(() -> this.mvc.perform(get("/public/../private/"))); } - @EnableWebSecurity - static class HttpFirewallConfig {} - @Test public void requestWithCustomFirewallThenBehaviorMatchesNamespace() { this.rule.register(CustomHttpFirewallConfig.class).autowire(); - assertThatCode(() -> this.mvc.perform(get("/").param("deny", "true"))) - .isInstanceOf(RequestRejectedException.class); - } - - @EnableWebSecurity - static class CustomHttpFirewallConfig extends WebSecurityConfigurerAdapter { - @Override - public void configure(WebSecurity web) { - web - .httpFirewall(new CustomHttpFirewall()); - } + assertThatExceptionOfType(RequestRejectedException.class) + .isThrownBy(() -> this.mvc.perform(get("/").param("deny", "true"))); } @Test public void requestWithCustomFirewallBeanThenBehaviorMatchesNamespace() { this.rule.register(CustomHttpFirewallBeanConfig.class).autowire(); - assertThatCode(() -> this.mvc.perform(get("/").param("deny", "true"))) - .isInstanceOf(RequestRejectedException.class); + assertThatExceptionOfType(RequestRejectedException.class) + .isThrownBy(() -> this.mvc.perform(get("/").param("deny", "true"))); + } + + @EnableWebSecurity + static class HttpFirewallConfig { + + } + + @EnableWebSecurity + static class CustomHttpFirewallConfig extends WebSecurityConfigurerAdapter { + + @Override + public void configure(WebSecurity web) { + web.httpFirewall(new CustomHttpFirewall()); + } + } @EnableWebSecurity static class CustomHttpFirewallBeanConfig { + @Bean HttpFirewall firewall() { return new CustomHttpFirewall(); } + } static class CustomHttpFirewall extends DefaultHttpFirewall { @Override - public FirewalledRequest getFirewalledRequest(HttpServletRequest request) - throws RequestRejectedException { + public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException { if (request.getParameter("deny") != null) { throw new RequestRejectedException("custom rejection"); } return super.getFirewalledRequest(request); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFormLoginTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFormLoginTests.java index f33623f1ce..af53890e84 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpFormLoginTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -35,6 +36,7 @@ import org.springframework.security.web.authentication.SavedRequestAwareAuthenti import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; @@ -45,7 +47,7 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <form-login> attributes is present * * @author Rob Winch * @author Josh Cummings @@ -59,23 +61,52 @@ public class NamespaceHttpFormLoginTests { @Autowired MockMvc mvc; - @Test public void formLoginWhenDefaultConfigurationThenMatchesNamespace() throws Exception { this.spring.register(FormLoginConfig.class, UserDetailsServiceConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost/login")); - - this.mvc.perform(post("/login") - .with(csrf())) - .andExpect(redirectedUrl("/login?error")); - - this.mvc.perform(post("/login") + this.mvc.perform(get("/")).andExpect(redirectedUrl("http://localhost/login")); + this.mvc.perform(post("/login").with(csrf())).andExpect(redirectedUrl("/login?error")); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") .param("password", "password") - .with(csrf())) - .andExpect(redirectedUrl("/")); + .with(csrf()); + // @formatter:on + this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + } + + @Test + public void formLoginWithCustomEndpointsThenBehaviorMatchesNamespace() throws Exception { + this.spring.register(FormLoginCustomConfig.class, UserDetailsServiceConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(redirectedUrl("http://localhost/authentication/login")); + this.mvc.perform(post("/authentication/login/process").with(csrf())) + .andExpect(redirectedUrl("/authentication/login?failed")); + // @formatter:off + MockHttpServletRequestBuilder request = post("/authentication/login/process") + .param("username", "user") + .param("password", "password") + .with(csrf()); + // @formatter:on + this.mvc.perform(request).andExpect(redirectedUrl("/default")); + } + + @Test + public void formLoginWithCustomHandlersThenBehaviorMatchesNamespace() throws Exception { + this.spring.register(FormLoginCustomRefsConfig.class, UserDetailsServiceConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(redirectedUrl("http://localhost/login")); + this.mvc.perform(post("/login").with(csrf())).andExpect(redirectedUrl("/custom/failure")); + verifyBean(WebAuthenticationDetailsSource.class).buildDetails(any(HttpServletRequest.class)); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .with(csrf()); + // @formatter:on + this.mvc.perform(loginRequest).andExpect(redirectedUrl("/custom/targetUrl")); + } + + private T verifyBean(Class beanClass) { + return verify(this.spring.getContext().getBean(beanClass)); } @EnableWebSecurity @@ -83,43 +114,29 @@ public class NamespaceHttpFormLoginTests { @Override public void configure(WebSecurity web) { - web - .ignoring() - .antMatchers("/resources/**"); + web.ignoring().antMatchers("/resources/**"); } @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") .and() .formLogin(); + // @formatter:on } - } - @Test - public void formLoginWithCustomEndpointsThenBehaviorMatchesNamespace() throws Exception { - this.spring.register(FormLoginCustomConfig.class, UserDetailsServiceConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost/authentication/login")); - - this.mvc.perform(post("/authentication/login/process") - .with(csrf())) - .andExpect(redirectedUrl("/authentication/login?failed")); - - this.mvc.perform(post("/authentication/login/process") - .param("username", "user") - .param("password", "password") - .with(csrf())) - .andExpect(redirectedUrl("/default")); } @EnableWebSecurity static class FormLoginCustomConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { boolean alwaysUseDefaultSuccess = true; + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") @@ -131,35 +148,19 @@ public class NamespaceHttpFormLoginTests { .failureUrl("/authentication/login?failed") // form-login@authentication-failure-url .loginProcessingUrl("/authentication/login/process") // form-login@login-processing-url .defaultSuccessUrl("/default", alwaysUseDefaultSuccess); // form-login@default-target-url / form-login@always-use-default-target + // @formatter:on } - } - @Test - public void formLoginWithCustomHandlersThenBehaviorMatchesNamespace() throws Exception { - this.spring.register(FormLoginCustomRefsConfig.class, UserDetailsServiceConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost/login")); - - this.mvc.perform(post("/login") - .with(csrf())) - .andExpect(redirectedUrl("/custom/failure")); - verifyBean(WebAuthenticationDetailsSource.class).buildDetails(any(HttpServletRequest.class)); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .with(csrf())) - .andExpect(redirectedUrl("/custom/targetUrl")); } @EnableWebSecurity static class FormLoginCustomRefsConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - SavedRequestAwareAuthenticationSuccessHandler successHandler = - new SavedRequestAwareAuthenticationSuccessHandler(); - successHandler.setDefaultTargetUrl("/custom/targetUrl"); + @Override + protected void configure(HttpSecurity http) throws Exception { + SavedRequestAwareAuthenticationSuccessHandler successHandler = new SavedRequestAwareAuthenticationSuccessHandler(); + successHandler.setDefaultTargetUrl("/custom/targetUrl"); + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") @@ -170,28 +171,31 @@ public class NamespaceHttpFormLoginTests { .successHandler(successHandler) // form-login@authentication-success-handler-ref .authenticationDetailsSource(authenticationDetailsSource()) // form-login@authentication-details-source-ref .and(); + // @formatter:on } @Bean WebAuthenticationDetailsSource authenticationDetailsSource() { return spy(WebAuthenticationDetailsSource.class); } + } @Configuration static class UserDetailsServiceConfig { + @Bean - public UserDetailsService userDetailsService() { + UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager( + // @formatter:off User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build()); + // @formatter:on } + } - private T verifyBean(Class beanClass) { - return verify(this.spring.getContext().getBean(beanClass)); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.java index 6496668b7d..c9f11e767f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.net.URI; @@ -40,15 +41,15 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <headers> attributes is present * * @author Rob Winch * @author Josh Cummings * */ public class NamespaceHttpHeadersTests { - static final Map defaultHeaders = new LinkedHashMap<>(); + static final Map defaultHeaders = new LinkedHashMap<>(); static { defaultHeaders.put("X-Content-Type-Options", "nosniff"); defaultHeaders.put("X-Frame-Options", "DENY"); @@ -58,7 +59,6 @@ public class NamespaceHttpHeadersTests { defaultHeaders.put("Pragma", "no-cache"); defaultHeaders.put("X-XSS-Protection", "1; mode=block"); } - @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -68,208 +68,66 @@ public class NamespaceHttpHeadersTests { @Test public void secureRequestWhenDefaultConfigThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HeadersDefaultConfig.class).autowire(); - - this.mvc.perform(get("/").secure(true)) - .andExpect(includesDefaults()); - } - - @EnableWebSecurity - static class HeadersDefaultConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers(); - } + this.mvc.perform(get("/").secure(true)).andExpect(includesDefaults()); } @Test public void secureRequestWhenCacheControlOnlyThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HeadersCacheControlConfig.class).autowire(); - - this.mvc.perform(get("/").secure(true)) - .andExpect(includes("Cache-Control", "Expires", "Pragma")); - } - - @EnableWebSecurity - static class HeadersCacheControlConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - .defaultsDisabled() - .cacheControl(); - } + this.mvc.perform(get("/").secure(true)).andExpect(includes("Cache-Control", "Expires", "Pragma")); } @Test public void secureRequestWhenHstsOnlyThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HstsConfig.class).autowire(); - - this.mvc.perform(get("/").secure(true)) - .andExpect(includes("Strict-Transport-Security")); - } - - @EnableWebSecurity - static class HstsConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - .defaultsDisabled() - .httpStrictTransportSecurity(); - } + this.mvc.perform(get("/").secure(true)).andExpect(includes("Strict-Transport-Security")); } @Test public void requestWhenHstsCustomThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HstsCustomConfig.class).autowire(); - this.mvc.perform(get("/")) .andExpect(includes(Collections.singletonMap("Strict-Transport-Security", "max-age=15768000"))); } - @EnableWebSecurity - static class HstsCustomConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - // hsts@request-matcher-ref, hsts@max-age-seconds, hsts@include-subdomains - .defaultsDisabled() - .httpStrictTransportSecurity() - .requestMatcher(AnyRequestMatcher.INSTANCE) - .maxAgeInSeconds(15768000) - .includeSubDomains(false); - } - } - @Test public void requestWhenFrameOptionsSameOriginThenBehaviorMatchesNamespace() throws Exception { this.spring.register(FrameOptionsSameOriginConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(includes(Collections.singletonMap("X-Frame-Options", "SAMEORIGIN"))); + this.mvc.perform(get("/")).andExpect(includes(Collections.singletonMap("X-Frame-Options", "SAMEORIGIN"))); } - @EnableWebSecurity - static class FrameOptionsSameOriginConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - // frame-options@policy=SAMEORIGIN - .defaultsDisabled() - .frameOptions() - .sameOrigin(); - } - } - - // frame-options@strategy, frame-options@value, frame-options@parameter are not provided instead use frame-options@ref - @Test public void requestWhenFrameOptionsAllowFromThenBehaviorMatchesNamespace() throws Exception { this.spring.register(FrameOptionsAllowFromConfig.class).autowire(); - this.mvc.perform(get("/")) .andExpect(includes(Collections.singletonMap("X-Frame-Options", "ALLOW-FROM https://example.com"))); } - @EnableWebSecurity - static class FrameOptionsAllowFromConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - // frame-options@ref - .defaultsDisabled() - .addHeaderWriter(new XFrameOptionsHeaderWriter( - new StaticAllowFromStrategy(URI.create("https://example.com")))); - } - } - @Test public void requestWhenXssOnlyThenBehaviorMatchesNamespace() throws Exception { this.spring.register(XssProtectionConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(includes("X-XSS-Protection")); - } - - @EnableWebSecurity - static class XssProtectionConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - // xss-protection - .defaultsDisabled() - .xssProtection(); - } + this.mvc.perform(get("/")).andExpect(includes("X-XSS-Protection")); } @Test public void requestWhenXssCustomThenBehaviorMatchesNamespace() throws Exception { this.spring.register(XssProtectionCustomConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(includes(Collections.singletonMap("X-XSS-Protection", "1"))); - } - - @EnableWebSecurity - static class XssProtectionCustomConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - // xss-protection@enabled and xss-protection@block - .defaultsDisabled() - .xssProtection() - .xssProtectionEnabled(true) - .block(false); - } + this.mvc.perform(get("/")).andExpect(includes(Collections.singletonMap("X-XSS-Protection", "1"))); } @Test public void requestWhenXContentTypeOptionsOnlyThenBehaviorMatchesNamespace() throws Exception { this.spring.register(ContentTypeOptionsConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(includes("X-Content-Type-Options")); + this.mvc.perform(get("/")).andExpect(includes("X-Content-Type-Options")); } - @EnableWebSecurity - static class ContentTypeOptionsConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - // content-type-options - .defaultsDisabled() - .contentTypeOptions(); - } - } - - // header@name / header@value are not provided instead use header@ref - @Test public void requestWhenCustomHeaderOnlyThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HeaderRefConfig.class).autowire(); - this.mvc.perform(get("/")) .andExpect(includes(Collections.singletonMap("customHeaderName", "customHeaderValue"))); } - @EnableWebSecurity - static class HeaderRefConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .headers() - .defaultsDisabled() - .addHeaderWriter(new StaticHeadersWriter("customHeaderName", "customHeaderValue")); - } - } - private static ResultMatcher includesDefaults() { return includes(defaultHeaders); } @@ -283,11 +141,173 @@ public class NamespaceHttpHeadersTests { } private static ResultMatcher includes(Map headers, String... headerNames) { - return result -> { + return (result) -> { assertThat(result.getResponse().getHeaderNames()).hasSameSizeAs(headerNames); for (String headerName : headerNames) { header().string(headerName, headers.get(headerName)).match(result); } }; } + + @EnableWebSecurity + static class HeadersDefaultConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class HeadersCacheControlConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + .defaultsDisabled() + .cacheControl(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class HstsConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + .defaultsDisabled() + .httpStrictTransportSecurity(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class HstsCustomConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + // hsts@request-matcher-ref, hsts@max-age-seconds, hsts@include-subdomains + .defaultsDisabled() + .httpStrictTransportSecurity() + .requestMatcher(AnyRequestMatcher.INSTANCE) + .maxAgeInSeconds(15768000) + .includeSubDomains(false); + // @formatter:on + } + + } + + @EnableWebSecurity + static class FrameOptionsSameOriginConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + // frame-options@policy=SAMEORIGIN + .defaultsDisabled() + .frameOptions() + .sameOrigin(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class FrameOptionsAllowFromConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + // frame-options@ref + .defaultsDisabled() + .addHeaderWriter(new XFrameOptionsHeaderWriter( + new StaticAllowFromStrategy(URI.create("https://example.com")))); + // @formatter:on + } + + } + + @EnableWebSecurity + static class XssProtectionConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + // xss-protection + .defaultsDisabled() + .xssProtection(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class XssProtectionCustomConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + // xss-protection@enabled and xss-protection@block + .defaultsDisabled() + .xssProtection() + .xssProtectionEnabled(true) + .block(false); + // @formatter:on + } + + } + + @EnableWebSecurity + static class ContentTypeOptionsConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + // content-type-options + .defaultsDisabled() + .contentTypeOptions(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class HeaderRefConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .headers() + .defaultsDisabled() + .addHeaderWriter(new StaticHeadersWriter("customHeaderName", "customHeaderValue")); + // @formatter:on + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpInterceptUrlTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpInterceptUrlTests.java index 26fcdfcac6..b52eee3ec4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpInterceptUrlTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpInterceptUrlTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; @@ -29,6 +30,7 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -41,7 +43,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <intercept-url> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -58,61 +61,45 @@ public class NamespaceHttpInterceptUrlTests { @Test public void unauthenticatedRequestWhenUrlRequiresAuthenticationThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HttpInterceptUrlConfig.class).autowire(); - - this.mvc.perform(get("/users")) - .andExpect(status().isForbidden()); + this.mvc.perform(get("/users")).andExpect(status().isForbidden()); } @Test public void authenticatedRequestWhenUrlRequiresElevatedPrivilegesThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HttpInterceptUrlConfig.class).autowire(); - - - this.mvc.perform(get("/users") - .with(authentication(user("ROLE_USER")))) - .andExpect(status().isForbidden()); + MockHttpServletRequestBuilder requestWithUser = get("/users").with(authentication(user("ROLE_USER"))); + this.mvc.perform(requestWithUser).andExpect(status().isForbidden()); } @Test public void authenticatedRequestWhenAuthorizedThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HttpInterceptUrlConfig.class, BaseController.class).autowire(); - - this.mvc.perform(get("/users") - .with(authentication(user("ROLE_ADMIN")))) - .andExpect(status().isOk()) - .andReturn(); + MockHttpServletRequestBuilder requestWithAdmin = get("/users").with(authentication(user("ROLE_ADMIN"))); + this.mvc.perform(requestWithAdmin).andExpect(status().isOk()).andReturn(); } @Test public void requestWhenMappedByPostInterceptUrlThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HttpInterceptUrlConfig.class, BaseController.class).autowire(); - - this.mvc.perform(get("/admin/post") - .with(authentication(user("ROLE_USER")))) - .andExpect(status().isOk()); - - this.mvc.perform(post("/admin/post") - .with(authentication(user("ROLE_USER")))) - .andExpect(status().isForbidden()); - - this.mvc.perform(post("/admin/post") - .with(csrf()) - .with(authentication(user("ROLE_ADMIN")))) - .andExpect(status().isOk()); + MockHttpServletRequestBuilder getWithUser = get("/admin/post").with(authentication(user("ROLE_USER"))); + this.mvc.perform(getWithUser).andExpect(status().isOk()); + MockHttpServletRequestBuilder postWithUser = post("/admin/post").with(authentication(user("ROLE_USER"))); + this.mvc.perform(postWithUser).andExpect(status().isForbidden()); + MockHttpServletRequestBuilder requestWithAdmin = post("/admin/post").with(csrf()) + .with(authentication(user("ROLE_ADMIN"))); + this.mvc.perform(requestWithAdmin).andExpect(status().isOk()); } @Test public void requestWhenRequiresChannelThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HttpInterceptUrlConfig.class).autowire(); + this.mvc.perform(get("/login")).andExpect(redirectedUrl("https://localhost/login")); + this.mvc.perform(get("/secured/a")).andExpect(redirectedUrl("https://localhost/secured/a")); + this.mvc.perform(get("https://localhost/user")).andExpect(redirectedUrl("http://localhost/user")); + } - this.mvc.perform(get("/login")) - .andExpect(redirectedUrl("https://localhost/login")); - - this.mvc.perform(get("/secured/a")) - .andExpect(redirectedUrl("https://localhost/secured/a")); - - this.mvc.perform(get("https://localhost/user")) - .andExpect(redirectedUrl("http://localhost/user")); + private static Authentication user(String role) { + return new UsernamePasswordAuthenticationToken("user", null, AuthorityUtils.createAuthorityList(role)); } @EnableWebSecurity @@ -120,6 +107,7 @@ public class NamespaceHttpInterceptUrlTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() // the line below is similar to intercept-url@pattern: @@ -142,46 +130,49 @@ public class NamespaceHttpInterceptUrlTests { // the line below is similar to intercept-url@requires-channel="http": // .anyRequest().requiresInsecure(); + // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } + } @RestController static class BaseController { + @GetMapping("/users") - public String users() { + String users() { return "ok"; } @GetMapping("/sessions") - public String sessions() { + String sessions() { return "sessions"; } @RequestMapping("/admin/post") - public String adminPost() { + String adminPost() { return "adminPost"; } @GetMapping("/admin/another-post") - public String adminAnotherPost() { + String adminAnotherPost() { return "adminAnotherPost"; } @GetMapping("/signup") - public String signup() { + String signup() { return "signup"; } - } - private static Authentication user(String role) { - return new UsernamePasswordAuthenticationToken("user", null, AuthorityUtils.createAuthorityList(role)); } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpJeeTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpJeeTests.java index ed06b41183..f83e15ba4a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpJeeTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpJeeTests.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; +package org.springframework.security.config.annotation.web.configurers; import java.security.Principal; import java.util.stream.Collectors; @@ -37,15 +37,15 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <jee> attributes is present * * @author Rob Winch * @author Josh Cummings @@ -62,93 +62,29 @@ public class NamespaceHttpJeeTests { @Test public void requestWhenJeeUserThenBehaviorDiffersFromNamespaceForRoleNames() throws Exception { this.spring.register(JeeMappableRolesConfig.class, BaseController.class).autowire(); - Principal user = mock(Principal.class); - when(user.getName()).thenReturn("joe"); - - this.mvc.perform(get("/roles") - .principal(user) - .with(request -> { - request.addUserRole("ROLE_admin"); - request.addUserRole("ROLE_user"); - request.addUserRole("ROLE_unmapped"); - return request; - })) - .andExpect(status().isOk()) - .andExpect(content().string("ROLE_admin,ROLE_user")); - } - - @EnableWebSecurity - public static class JeeMappableRolesConfig extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("user") - .and() - .jee() - .mappableRoles("user", "admin"); - } + given(user.getName()).willReturn("joe"); + this.mvc.perform(get("/roles").principal(user).with((request) -> { + request.addUserRole("ROLE_admin"); + request.addUserRole("ROLE_user"); + request.addUserRole("ROLE_unmapped"); + return request; + })).andExpect(status().isOk()).andExpect(content().string("ROLE_admin,ROLE_user")); } @Test public void requestWhenCustomAuthenticatedUserDetailsServiceThenBehaviorMatchesNamespace() throws Exception { this.spring.register(JeeUserServiceRefConfig.class, BaseController.class).autowire(); - Principal user = mock(Principal.class); - when(user.getName()).thenReturn("joe"); - + given(user.getName()).willReturn("joe"); User result = new User(user.getName(), "N/A", true, true, true, true, AuthorityUtils.createAuthorityList("ROLE_user")); - - when(bean(AuthenticationUserDetailsService.class).loadUserDetails(any())) - .thenReturn(result); - - this.mvc.perform(get("/roles") - .principal(user)) - .andExpect(status().isOk()) + given(bean(AuthenticationUserDetailsService.class).loadUserDetails(any())).willReturn(result); + this.mvc.perform(get("/roles").principal(user)).andExpect(status().isOk()) .andExpect(content().string("ROLE_user")); - verifyBean(AuthenticationUserDetailsService.class).loadUserDetails(any()); } - @EnableWebSecurity - public static class JeeUserServiceRefConfig extends WebSecurityConfigurerAdapter { - private final AuthenticationUserDetailsService authenticationUserDetailsService = - mock(AuthenticationUserDetailsService.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("user") - .and() - .jee() - .mappableAuthorities("ROLE_user", "ROLE_admin") - .authenticatedUserDetailsService(this.authenticationUserDetailsService); - } - - @Bean - public AuthenticationUserDetailsService authenticationUserDetailsService() { - return this.authenticationUserDetailsService; - } - } - - @RestController - static class BaseController { - @GetMapping("/authenticated") - public String authenticated(Authentication authentication) { - return authentication.getName(); - } - - @GetMapping("/roles") - public String roles(Authentication authentication) { - return authentication.getAuthorities().stream() - .map(Object::toString).collect(Collectors.joining(",")); - } - } - private T bean(Class beanClass) { return this.spring.getContext().getBean(beanClass); } @@ -156,4 +92,63 @@ public class NamespaceHttpJeeTests { private T verifyBean(Class beanClass) { return verify(this.spring.getContext().getBean(beanClass)); } + + @EnableWebSecurity + public static class JeeMappableRolesConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("user") + .and() + .jee() + .mappableRoles("user", "admin"); + // @formatter:on + } + + } + + @EnableWebSecurity + public static class JeeUserServiceRefConfig extends WebSecurityConfigurerAdapter { + + private final AuthenticationUserDetailsService authenticationUserDetailsService = mock( + AuthenticationUserDetailsService.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("user") + .and() + .jee() + .mappableAuthorities("ROLE_user", "ROLE_admin") + .authenticatedUserDetailsService(this.authenticationUserDetailsService); + // @formatter:on + } + + @Bean + public AuthenticationUserDetailsService authenticationUserDetailsService() { + return this.authenticationUserDetailsService; + } + + } + + @RestController + static class BaseController { + + @GetMapping("/authenticated") + String authenticated(Authentication authentication) { + return authentication.getName(); + } + + @GetMapping("/roles") + String roles(Authentication authentication) { + return authentication.getAuthorities().stream().map(Object::toString).collect(Collectors.joining(",")); + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpLogoutTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpLogoutTests.java index f0edd6301c..5b02de5f44 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpLogoutTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpLogoutTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.Objects; import java.util.Optional; import java.util.function.Predicate; + import javax.servlet.http.HttpSession; import org.assertj.core.api.Condition; @@ -38,6 +40,7 @@ import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuc import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -48,7 +51,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <logout> attributes is present * * @author Rob Winch * @author Josh Cummings @@ -70,36 +73,21 @@ public class NamespaceHttpLogoutTests { @WithMockUser public void logoutWhenUsingDefaultsThenMatchesNamespace() throws Exception { this.spring.register(HttpLogoutConfig.class).autowire(); - + // @formatter:off this.mvc.perform(post("/logout").with(csrf())) .andExpect(authenticated(false)) .andExpect(redirectedUrl("/login?logout")) .andExpect(noCookies()) .andExpect(session(Objects::isNull)); - } - - @EnableWebSecurity - static class HttpLogoutConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) { - } + // @formatter:on } @Test @WithMockUser public void logoutWhenDisabledInLambdaThenRespondsWithNotFound() throws Exception { this.spring.register(HttpLogoutDisabledInLambdaConfig.class).autowire(); - - this.mvc.perform(post("/logout").with(csrf()).with(user("user"))) - .andExpect(status().isNotFound()); - } - - @EnableWebSecurity - static class HttpLogoutDisabledInLambdaConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http.logout(AbstractHttpConfigurer::disable); - } + MockHttpServletRequestBuilder logoutRequest = post("/logout").with(csrf()).with(user("user")); + this.mvc.perform(logoutRequest).andExpect(status().isNotFound()); } /** @@ -109,55 +97,28 @@ public class NamespaceHttpLogoutTests { @WithMockUser public void logoutWhenUsingVariousCustomizationsMatchesNamespace() throws Exception { this.spring.register(CustomHttpLogoutConfig.class).autowire(); - + // @formatter:off this.mvc.perform(post("/custom-logout").with(csrf())) .andExpect(authenticated(false)) .andExpect(redirectedUrl("/logout-success")) - .andExpect(result -> assertThat(result.getResponse().getCookies()).hasSize(1)) + .andExpect((result) -> assertThat(result.getResponse().getCookies()).hasSize(1)) .andExpect(cookie().maxAge("remove", 0)) .andExpect(session(Objects::nonNull)); - } - - @EnableWebSecurity - static class CustomHttpLogoutConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .logout() - .deleteCookies("remove") // logout@delete-cookies - .invalidateHttpSession(false) // logout@invalidate-session=false (default is true) - .logoutUrl("/custom-logout") // logout@logout-url (default is /logout) - .logoutSuccessUrl("/logout-success"); // logout@success-url (default is /login?logout) - } + // @formatter:on } @Test @WithMockUser public void logoutWhenUsingVariousCustomizationsInLambdaThenMatchesNamespace() throws Exception { this.spring.register(CustomHttpLogoutInLambdaConfig.class).autowire(); - + // @formatter:off this.mvc.perform(post("/custom-logout").with(csrf())) .andExpect(authenticated(false)) .andExpect(redirectedUrl("/logout-success")) - .andExpect(result -> assertThat(result.getResponse().getCookies()).hasSize(1)) + .andExpect((result) -> assertThat(result.getResponse().getCookies()).hasSize(1)) .andExpect(cookie().maxAge("remove", 0)) .andExpect(session(Objects::nonNull)); - } - - @EnableWebSecurity - static class CustomHttpLogoutInLambdaConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .logout(logout -> - logout.deleteCookies("remove") - .invalidateHttpSession(false) - .logoutUrl("/custom-logout") - .logoutSuccessUrl("/logout-success") - ); - // @formatter:on - } + // @formatter:on } /** @@ -167,67 +128,125 @@ public class NamespaceHttpLogoutTests { @WithMockUser public void logoutWhenUsingSuccessHandlerRefThenMatchesNamespace() throws Exception { this.spring.register(SuccessHandlerRefHttpLogoutConfig.class).autowire(); - + // @formatter:off this.mvc.perform(post("/logout").with(csrf())) .andExpect(authenticated(false)) .andExpect(redirectedUrl("/SuccessHandlerRefHttpLogoutConfig")) .andExpect(noCookies()) .andExpect(session(Objects::isNull)); - } - - @EnableWebSecurity - static class SuccessHandlerRefHttpLogoutConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - SimpleUrlLogoutSuccessHandler logoutSuccessHandler = - new SimpleUrlLogoutSuccessHandler(); - logoutSuccessHandler.setDefaultTargetUrl("/SuccessHandlerRefHttpLogoutConfig"); - - http - .logout() - .logoutSuccessHandler(logoutSuccessHandler); - } + // @formatter:on } @Test @WithMockUser public void logoutWhenUsingSuccessHandlerRefInLambdaThenMatchesNamespace() throws Exception { this.spring.register(SuccessHandlerRefHttpLogoutInLambdaConfig.class).autowire(); - + // @formatter:off this.mvc.perform(post("/logout").with(csrf())) .andExpect(authenticated(false)) .andExpect(redirectedUrl("/SuccessHandlerRefHttpLogoutConfig")) .andExpect(noCookies()) .andExpect(session(Objects::isNull)); + // @formatter:on + } + + ResultMatcher authenticated(boolean authenticated) { + return (result) -> assertThat(Optional.ofNullable(SecurityContextHolder.getContext().getAuthentication()) + .map(Authentication::isAuthenticated).orElse(false)).isEqualTo(authenticated); + } + + ResultMatcher noCookies() { + return (result) -> assertThat(result.getResponse().getCookies()).isEmpty(); + } + + ResultMatcher session(Predicate sessionPredicate) { + return (result) -> assertThat(result.getRequest().getSession(false)) + .is(new Condition<>(sessionPredicate, "sessionPredicate failed")); } @EnableWebSecurity - static class SuccessHandlerRefHttpLogoutInLambdaConfig extends WebSecurityConfigurerAdapter { + static class HttpLogoutConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) { + } + + } + + @EnableWebSecurity + static class HttpLogoutDisabledInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + http.logout(AbstractHttpConfigurer::disable); + } + + } + + @EnableWebSecurity + static class CustomHttpLogoutConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .logout() + .deleteCookies("remove") // logout@delete-cookies + .invalidateHttpSession(false) // logout@invalidate-session=false (default is true) + .logoutUrl("/custom-logout") // logout@logout-url (default is /logout) + .logoutSuccessUrl("/logout-success"); // logout@success-url (default is /login?logout) + // @formatter:on + } + + } + + @EnableWebSecurity + static class CustomHttpLogoutInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .logout((logout) -> + logout.deleteCookies("remove") + .invalidateHttpSession(false) + .logoutUrl("/custom-logout") + .logoutSuccessUrl("/logout-success") + ); + // @formatter:on + } + + } + + @EnableWebSecurity + static class SuccessHandlerRefHttpLogoutConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { SimpleUrlLogoutSuccessHandler logoutSuccessHandler = new SimpleUrlLogoutSuccessHandler(); logoutSuccessHandler.setDefaultTargetUrl("/SuccessHandlerRefHttpLogoutConfig"); - // @formatter:off http - .logout(logout -> logout.logoutSuccessHandler(logoutSuccessHandler)); + .logout() + .logoutSuccessHandler(logoutSuccessHandler); // @formatter:on } + } - ResultMatcher authenticated(boolean authenticated) { - return result -> assertThat( - Optional.ofNullable(SecurityContextHolder.getContext().getAuthentication()) - .map(Authentication::isAuthenticated) - .orElse(false)).isEqualTo(authenticated); + @EnableWebSecurity + static class SuccessHandlerRefHttpLogoutInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + SimpleUrlLogoutSuccessHandler logoutSuccessHandler = new SimpleUrlLogoutSuccessHandler(); + logoutSuccessHandler.setDefaultTargetUrl("/SuccessHandlerRefHttpLogoutConfig"); + // @formatter:off + http + .logout((logout) -> logout.logoutSuccessHandler(logoutSuccessHandler)); + // @formatter:on + } + } - ResultMatcher noCookies() { - return result -> assertThat(result.getResponse().getCookies()).isEmpty(); - } - - ResultMatcher session(Predicate sessionPredicate) { - return result -> assertThat(result.getRequest().getSession(false)) - .is(new Condition(sessionPredicate, "sessionPredicate failed")); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpOpenIDLoginTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpOpenIDLoginTests.java index c2be226617..273a528aef 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpOpenIDLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpOpenIDLoginTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.Arrays; import java.util.List; + import javax.servlet.http.HttpServletRequest; import okhttp3.mockwebserver.MockResponse; @@ -25,6 +27,7 @@ import org.junit.Rule; import org.junit.Test; import org.openid4java.consumer.ConsumerManager; import org.openid4java.discovery.DiscoveryInformation; +import org.openid4java.discovery.yadis.YadisResolver; import org.openid4java.message.AuthRequest; import org.springframework.beans.factory.annotation.Autowired; @@ -53,16 +56,16 @@ import org.springframework.security.web.authentication.SimpleUrlAuthenticationFa import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.openid4java.discovery.yadis.YadisResolver.YADIS_XRDS_LOCATION; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -70,7 +73,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <openid-login> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -86,24 +90,8 @@ public class NamespaceHttpOpenIDLoginTests { @Test public void openidLoginWhenUsingDefaultsThenMatchesNamespace() throws Exception { this.spring.register(OpenIDLoginConfig.class).autowire(); - this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost/login")); - this.mvc.perform(post("/login/openid").with(csrf())) - .andExpect(redirectedUrl("/login?error")); - } - - @Configuration - @EnableWebSecurity - static class OpenIDLoginConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .openidLogin() - .permitAll(); - } + this.mvc.perform(get("/")).andExpect(redirectedUrl("http://localhost/login")); + this.mvc.perform(post("/login/openid").with(csrf())).andExpect(redirectedUrl("/login?error")); } @Test @@ -111,54 +99,96 @@ public class NamespaceHttpOpenIDLoginTests { OpenIDLoginAttributeExchangeConfig.CONSUMER_MANAGER = mock(ConsumerManager.class); AuthRequest mockAuthRequest = mock(AuthRequest.class); DiscoveryInformation mockDiscoveryInformation = mock(DiscoveryInformation.class); - when(mockAuthRequest.getDestinationUrl(anyBoolean())).thenReturn("mockUrl"); - when(OpenIDLoginAttributeExchangeConfig.CONSUMER_MANAGER.associate(any())) - .thenReturn(mockDiscoveryInformation); - when(OpenIDLoginAttributeExchangeConfig.CONSUMER_MANAGER.authenticate(any(DiscoveryInformation.class), any(), any())) - .thenReturn(mockAuthRequest); + given(mockAuthRequest.getDestinationUrl(anyBoolean())).willReturn("mockUrl"); + given(OpenIDLoginAttributeExchangeConfig.CONSUMER_MANAGER.associate(any())) + .willReturn(mockDiscoveryInformation); + given(OpenIDLoginAttributeExchangeConfig.CONSUMER_MANAGER.authenticate(any(DiscoveryInformation.class), any(), + any())).willReturn(mockAuthRequest); this.spring.register(OpenIDLoginAttributeExchangeConfig.class).autowire(); - try (MockWebServer server = new MockWebServer()) { String endpoint = server.url("/").toString(); - - server.enqueue(new MockResponse() - .addHeader(YADIS_XRDS_LOCATION, endpoint)); + server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint)); server.enqueue(new MockResponse() .setBody(String.format("%s", endpoint))); - MvcResult mvcResult = this.mvc.perform(get("/login/openid") .param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, "https://www.google.com/1")) - .andExpect(status().isFound()) - .andReturn(); - - Object attributeObject = mvcResult.getRequest().getSession().getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"); + .andExpect(status().isFound()).andReturn(); + Object attributeObject = mvcResult.getRequest().getSession() + .getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"); assertThat(attributeObject).isInstanceOf(List.class); List attributeList = (List) attributeObject; - assertThat(attributeList.stream().anyMatch(attribute -> - "firstname".equals(attribute.getName()) - && "https://axschema.org/namePerson/first".equals(attribute.getType()) - && attribute.isRequired())) - .isTrue(); - assertThat(attributeList.stream().anyMatch(attribute -> - "lastname".equals(attribute.getName()) - && "https://axschema.org/namePerson/last".equals(attribute.getType()) - && attribute.isRequired())) - .isTrue(); - assertThat(attributeList.stream().anyMatch(attribute -> - "email".equals(attribute.getName()) - && "https://axschema.org/contact/email".equals(attribute.getType()) - && attribute.isRequired())) - .isTrue(); + assertThat(attributeList.stream().anyMatch((attribute) -> "firstname".equals(attribute.getName()) + && "https://axschema.org/namePerson/first".equals(attribute.getType()) && attribute.isRequired())) + .isTrue(); + assertThat(attributeList.stream().anyMatch((attribute) -> "lastname".equals(attribute.getName()) + && "https://axschema.org/namePerson/last".equals(attribute.getType()) && attribute.isRequired())) + .isTrue(); + assertThat(attributeList.stream().anyMatch((attribute) -> "email".equals(attribute.getName()) + && "https://axschema.org/contact/email".equals(attribute.getType()) && attribute.isRequired())) + .isTrue(); } } + @Test + public void openidLoginWhenUsingCustomEndpointsThenMatchesNamespace() throws Exception { + this.spring.register(OpenIDLoginCustomConfig.class).autowire(); + this.mvc.perform(get("/")).andExpect(redirectedUrl("http://localhost/authentication/login")); + this.mvc.perform(post("/authentication/login/process").with(csrf())) + .andExpect(redirectedUrl("/authentication/login?failed")); + } + + @Test + public void openidLoginWithCustomHandlersThenBehaviorMatchesNamespace() throws Exception { + OpenIDAuthenticationToken token = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.SUCCESS, + "identityUrl", "message", Arrays.asList(new OpenIDAttribute("name", "type"))); + OpenIDLoginCustomRefsConfig.AUDS = mock(AuthenticationUserDetailsService.class); + User user = new User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER")); + given(OpenIDLoginCustomRefsConfig.AUDS.loadUserDetails(any(Authentication.class))).willReturn(user); + OpenIDLoginCustomRefsConfig.ADS = spy(new WebAuthenticationDetailsSource()); + OpenIDLoginCustomRefsConfig.CONSUMER = mock(OpenIDConsumer.class); + this.spring.register(OpenIDLoginCustomRefsConfig.class, UserDetailsServiceConfig.class).autowire(); + given(OpenIDLoginCustomRefsConfig.CONSUMER.endConsumption(any(HttpServletRequest.class))) + .willThrow(new AuthenticationServiceException("boom")); + // @formatter:off + MockHttpServletRequestBuilder login = post("/login/openid") + .with(csrf()) + .param("openid.identity", "identity"); + // @formatter:on + this.mvc.perform(login).andExpect(redirectedUrl("/custom/failure")); + reset(OpenIDLoginCustomRefsConfig.CONSUMER); + given(OpenIDLoginCustomRefsConfig.CONSUMER.endConsumption(any(HttpServletRequest.class))).willReturn(token); + this.mvc.perform(login).andExpect(redirectedUrl("/custom/targetUrl")); + verify(OpenIDLoginCustomRefsConfig.AUDS).loadUserDetails(any(Authentication.class)); + verify(OpenIDLoginCustomRefsConfig.ADS).buildDetails(any(Object.class)); + } + + @Configuration + @EnableWebSecurity + static class OpenIDLoginConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .openidLogin() + .permitAll(); + // @formatter:on + } + + } + @Configuration @EnableWebSecurity static class OpenIDLoginAttributeExchangeConfig extends WebSecurityConfigurerAdapter { + static ConsumerManager CONSUMER_MANAGER; @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") @@ -191,24 +221,19 @@ public class NamespaceHttpOpenIDLoginTests { .and() .and() .permitAll(); + // @formatter:on } - } - @Test - public void openidLoginWhenUsingCustomEndpointsThenMatchesNamespace() throws Exception { - this.spring.register(OpenIDLoginCustomConfig.class).autowire(); - this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost/authentication/login")); - this.mvc.perform(post("/authentication/login/process").with(csrf())) - .andExpect(redirectedUrl("/authentication/login?failed")); } @Configuration @EnableWebSecurity static class OpenIDLoginCustomConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { boolean alwaysUseDefaultSuccess = true; + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") @@ -219,55 +244,24 @@ public class NamespaceHttpOpenIDLoginTests { .failureUrl("/authentication/login?failed") // openid-login@authentication-failure-url .loginProcessingUrl("/authentication/login/process") // openid-login@login-processing-url .defaultSuccessUrl("/default", alwaysUseDefaultSuccess); // openid-login@default-target-url / openid-login@always-use-default-target + // @formatter:on } - } - @Test - public void openidLoginWithCustomHandlersThenBehaviorMatchesNamespace() throws Exception { - OpenIDAuthenticationToken token = new OpenIDAuthenticationToken( - OpenIDAuthenticationStatus.SUCCESS, - "identityUrl", - "message", - Arrays.asList(new OpenIDAttribute("name", "type"))); - - OpenIDLoginCustomRefsConfig.AUDS = mock(AuthenticationUserDetailsService.class); - when(OpenIDLoginCustomRefsConfig.AUDS.loadUserDetails(any(Authentication.class))) - .thenReturn(new User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))); - OpenIDLoginCustomRefsConfig.ADS = spy(new WebAuthenticationDetailsSource()); - OpenIDLoginCustomRefsConfig.CONSUMER = mock(OpenIDConsumer.class); - - this.spring.register(OpenIDLoginCustomRefsConfig.class, UserDetailsServiceConfig.class).autowire(); - - when(OpenIDLoginCustomRefsConfig.CONSUMER.endConsumption(any(HttpServletRequest.class))) - .thenThrow(new AuthenticationServiceException("boom")); - this.mvc.perform(post("/login/openid").with(csrf()) - .param("openid.identity", "identity")) - .andExpect(redirectedUrl("/custom/failure")); - reset(OpenIDLoginCustomRefsConfig.CONSUMER); - - when(OpenIDLoginCustomRefsConfig.CONSUMER.endConsumption(any(HttpServletRequest.class))) - .thenReturn(token); - this.mvc.perform(post("/login/openid").with(csrf()) - .param("openid.identity", "identity")) - .andExpect(redirectedUrl("/custom/targetUrl")); - - verify(OpenIDLoginCustomRefsConfig.AUDS).loadUserDetails(any(Authentication.class)); - verify(OpenIDLoginCustomRefsConfig.ADS).buildDetails(any(Object.class)); } @Configuration @EnableWebSecurity static class OpenIDLoginCustomRefsConfig extends WebSecurityConfigurerAdapter { + static AuthenticationUserDetailsService AUDS; static AuthenticationDetailsSource ADS; static OpenIDConsumer CONSUMER; @Override protected void configure(HttpSecurity http) throws Exception { - SavedRequestAwareAuthenticationSuccessHandler handler = - new SavedRequestAwareAuthenticationSuccessHandler(); + SavedRequestAwareAuthenticationSuccessHandler handler = new SavedRequestAwareAuthenticationSuccessHandler(); handler.setDefaultTargetUrl("/custom/targetUrl"); - + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") @@ -285,20 +279,20 @@ public class NamespaceHttpOpenIDLoginTests { return filter; } }); - + // @formatter:on } + } @Configuration static class UserDetailsServiceConfig { + @Bean - public UserDetailsService userDetailsService() { + UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager( - User.withDefaultPasswordEncoder() - .username("user") - .password("password") - .roles("USER") - .build()); + User.withDefaultPasswordEncoder().username("user").password("password").roles("USER").build()); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpPortMappingsTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpPortMappingsTests.java index cd246b87f7..011409d9fb 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpPortMappingsTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpPortMappingsTests.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; +package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; @@ -31,7 +31,8 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <port-mappings> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -48,15 +49,10 @@ public class NamespaceHttpPortMappingsTests { @Test public void portMappingWhenRequestRequiresChannelThenBehaviorMatchesNamespace() throws Exception { this.spring.register(HttpInterceptUrlWithPortMapperConfig.class).autowire(); - - this.mvc.perform(get("http://localhost:9080/login")) - .andExpect(redirectedUrl("https://localhost:9443/login")); - + this.mvc.perform(get("http://localhost:9080/login")).andExpect(redirectedUrl("https://localhost:9443/login")); this.mvc.perform(get("http://localhost:9080/secured/a")) .andExpect(redirectedUrl("https://localhost:9443/secured/a")); - - this.mvc.perform(get("https://localhost:9443/user")) - .andExpect(redirectedUrl("http://localhost:9080/user")); + this.mvc.perform(get("https://localhost:9443/user")).andExpect(redirectedUrl("http://localhost:9080/user")); } @EnableWebSecurity @@ -64,6 +60,7 @@ public class NamespaceHttpPortMappingsTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().hasRole("USER") @@ -74,13 +71,19 @@ public class NamespaceHttpPortMappingsTests { .requiresChannel() .antMatchers("/login", "/secured/**").requiresSecure() .anyRequest().requiresInsecure(); + // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpRequestCacheTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpRequestCacheTests.java index 9d1cf59ef4..be604a4b8a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpRequestCacheTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpRequestCacheTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -42,7 +43,8 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <request-cache> attributes is + * present * * @author Rob Winch * @author Josh Cummings @@ -59,65 +61,77 @@ public class NamespaceHttpRequestCacheTests { @Test public void requestWhenCustomRequestCacheThenBehaviorMatchesNamespace() throws Exception { this.spring.register(RequestCacheRefConfig.class).autowire(); - this.mvc.perform(get("/")) - .andExpect(status().isForbidden()); + this.mvc.perform(get("/")).andExpect(status().isForbidden()); verifyBean(RequestCache.class).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } + @Test + public void requestWhenDefaultConfigurationThenUsesHttpSessionRequestCache() throws Exception { + this.spring.register(DefaultRequestCacheRefConfig.class).autowire(); + MvcResult result = this.mvc.perform(get("/")).andExpect(status().isForbidden()).andReturn(); + HttpSession session = result.getRequest().getSession(false); + assertThat(session).isNotNull(); + assertThat(session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")).isNotNull(); + } + + private T verifyBean(Class beanClass) { + return verify(this.spring.getContext().getBean(beanClass)); + } + @EnableWebSecurity static class RequestCacheRefConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() .and() .requestCache() .requestCache(requestCache()); + // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()) .withUser(PasswordEncodedUser.admin()); + // @formatter:on } @Bean - public RequestCache requestCache() { + RequestCache requestCache() { return mock(RequestCache.class); } - } - @Test - public void requestWhenDefaultConfigurationThenUsesHttpSessionRequestCache() throws Exception { - this.spring.register(DefaultRequestCacheRefConfig.class).autowire(); - - MvcResult result = this.mvc.perform(get("/")) - .andExpect(status().isForbidden()) - .andReturn(); - - HttpSession session = result.getRequest().getSession(false); - assertThat(session).isNotNull(); - assertThat(session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")).isNotNull(); } @EnableWebSecurity static class DefaultRequestCacheRefConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated(); + // @formatter:on } + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()) .withUser(PasswordEncodedUser.admin()); + // @formatter:on } + } - private T verifyBean(Class beanClass) { - return verify(this.spring.getContext().getBean(beanClass)); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpServerAccessDeniedHandlerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpServerAccessDeniedHandlerTests.java index 0aea788b3d..844058f8f4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpServerAccessDeniedHandlerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpServerAccessDeniedHandlerTests.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; +package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -44,7 +44,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <access-denied-handler> attributes + * is present * * @author Rob Winch * @author Josh Cummings @@ -61,103 +62,119 @@ public class NamespaceHttpServerAccessDeniedHandlerTests { @Test public void requestWhenCustomAccessDeniedPageThenBehaviorMatchesNamespace() throws Exception { this.spring.register(AccessDeniedPageConfig.class).autowire(); - this.mvc.perform(get("/") - .with(authentication(user()))) + // @formatter:off + this.mvc.perform(get("/").with(authentication(user()))) .andExpect(status().isForbidden()) .andExpect(forwardedUrl("/AccessDeniedPageConfig")); + // @formatter:on } - @EnableWebSecurity - static class AccessDeniedPageConfig extends WebSecurityConfigurerAdapter { - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().denyAll() - .and() - .exceptionHandling() - .accessDeniedPage("/AccessDeniedPageConfig"); - } + @Test + public void requestWhenCustomAccessDeniedPageInLambdaThenForwardedToCustomPage() throws Exception { + this.spring.register(AccessDeniedPageInLambdaConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/").with(authentication(user()))) + .andExpect(status().isForbidden()) + .andExpect(forwardedUrl("/AccessDeniedPageConfig")); + // @formatter:on + } + + @Test + public void requestWhenCustomAccessDeniedHandlerThenBehaviorMatchesNamespace() throws Exception { + this.spring.register(AccessDeniedHandlerRefConfig.class).autowire(); + this.mvc.perform(get("/").with(authentication(user()))); + verifyBean(AccessDeniedHandler.class).handle(any(HttpServletRequest.class), any(HttpServletResponse.class), + any(AccessDeniedException.class)); + } + + @Test + public void requestWhenCustomAccessDeniedHandlerInLambdaThenBehaviorMatchesNamespace() throws Exception { + this.spring.register(AccessDeniedHandlerRefInLambdaConfig.class).autowire(); + this.mvc.perform(get("/").with(authentication(user()))); + verify(AccessDeniedHandlerRefInLambdaConfig.accessDeniedHandler).handle(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(AccessDeniedException.class)); } private static Authentication user() { return new UsernamePasswordAuthenticationToken("user", null, AuthorityUtils.NO_AUTHORITIES); } - @Test - public void requestWhenCustomAccessDeniedPageInLambdaThenForwardedToCustomPage() throws Exception { - this.spring.register(AccessDeniedPageInLambdaConfig.class).autowire(); + private T verifyBean(Class beanClass) { + return verify(this.spring.getContext().getBean(beanClass)); + } + + @EnableWebSecurity + static class AccessDeniedPageConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().denyAll() + .and() + .exceptionHandling() + .accessDeniedPage("/AccessDeniedPageConfig"); + // @formatter:on + } - this.mvc.perform(get("/") - .with(authentication(user()))) - .andExpect(status().isForbidden()) - .andExpect(forwardedUrl("/AccessDeniedPageConfig")); } @EnableWebSecurity static class AccessDeniedPageInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().denyAll() ) - .exceptionHandling(exceptionHandling -> + .exceptionHandling((exceptionHandling) -> exceptionHandling.accessDeniedPage("/AccessDeniedPageConfig") ); // @formatter:on } - } - @Test - public void requestWhenCustomAccessDeniedHandlerThenBehaviorMatchesNamespace() throws Exception { - this.spring.register(AccessDeniedHandlerRefConfig.class).autowire(); - this.mvc.perform(get("/") - .with(authentication(user()))); - verifyBean(AccessDeniedHandler.class) - .handle(any(HttpServletRequest.class), any(HttpServletResponse.class), any(AccessDeniedException.class)); } @EnableWebSecurity static class AccessDeniedHandlerRefConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().denyAll() .and() .exceptionHandling() .accessDeniedHandler(accessDeniedHandler()); + // @formatter:on } @Bean AccessDeniedHandler accessDeniedHandler() { return mock(AccessDeniedHandler.class); } - } - @Test - public void requestWhenCustomAccessDeniedHandlerInLambdaThenBehaviorMatchesNamespace() throws Exception { - this.spring.register(AccessDeniedHandlerRefInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/") - .with(authentication(user()))); - - verify(AccessDeniedHandlerRefInLambdaConfig.accessDeniedHandler) - .handle(any(HttpServletRequest.class), any(HttpServletResponse.class), any(AccessDeniedException.class)); } @EnableWebSecurity static class AccessDeniedHandlerRefInLambdaConfig extends WebSecurityConfigurerAdapter { + static AccessDeniedHandler accessDeniedHandler = mock(AccessDeniedHandler.class); + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().denyAll() ) - .exceptionHandling(exceptionHandling -> + .exceptionHandling((exceptionHandling) -> exceptionHandling.accessDeniedHandler(accessDeniedHandler()) ); // @formatter:on @@ -167,9 +184,7 @@ public class NamespaceHttpServerAccessDeniedHandlerTests { AccessDeniedHandler accessDeniedHandler() { return accessDeniedHandler; } + } - private T verifyBean(Class beanClass) { - return verify(this.spring.getContext().getBean(beanClass)); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java index beaeb1f73a..06df9f2e0f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.io.InputStream; import java.security.cert.Certificate; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; + import javax.servlet.http.HttpServletRequest; import org.junit.Rule; @@ -51,7 +53,8 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; /** - * Tests to verify that all the functionality of attributes is present in Java config + * Tests to verify that all the functionality of <x509> attributes is present in + * Java config * * @author Rob Winch * @author Josh Cummings @@ -59,8 +62,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. */ public class NamespaceHttpX509Tests { - private static final User USER = - new User("customuser", "password", AuthorityUtils.createAuthorityList("ROLE_USER")); + private static final User USER = new User("customuser", "password", + AuthorityUtils.createAuthorityList("ROLE_USER")); @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -72,209 +75,242 @@ public class NamespaceHttpX509Tests { public void x509AuthenticationWhenUsingX509DefaultConfigurationThenMatchesNamespace() throws Exception { this.spring.register(X509Config.class, X509Controller.class).autowire(); X509Certificate certificate = loadCert("rod.cer"); - this.mvc.perform(get("/whoami").with(x509(certificate))) - .andExpect(content().string("rod")); - } - - @EnableWebSecurity - @EnableWebMvc - public static class X509Config extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("rod").password("password").roles("USER", "ADMIN"); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .x509(); - } + this.mvc.perform(get("/whoami").with(x509(certificate))).andExpect(content().string("rod")); } @Test public void x509AuthenticationWhenHasCustomAuthenticationDetailsSourceThenMatchesNamespace() throws Exception { this.spring.register(AuthenticationDetailsSourceRefConfig.class, X509Controller.class).autowire(); - X509Certificate certificate = loadCert("rod.cer"); - this.mvc.perform(get("/whoami").with(x509(certificate))) - .andExpect(content().string("rod")); - + this.mvc.perform(get("/whoami").with(x509(certificate))).andExpect(content().string("rod")); verifyBean(AuthenticationDetailsSource.class).buildDetails(any()); } - @EnableWebSecurity - @EnableWebMvc - static class AuthenticationDetailsSourceRefConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("rod").password("password").roles("USER", "ADMIN"); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .x509() - .authenticationDetailsSource(authenticationDetailsSource()); - } - - @Bean - AuthenticationDetailsSource - authenticationDetailsSource() { - - return mock(AuthenticationDetailsSource.class); - } - } - @Test public void x509AuthenticationWhenHasSubjectPrincipalRegexThenMatchesNamespace() throws Exception { this.spring.register(SubjectPrincipalRegexConfig.class, X509Controller.class).autowire(); X509Certificate certificate = loadCert("rodatexampledotcom.cer"); - this.mvc.perform(get("/whoami").with(x509(certificate))) - .andExpect(content().string("rod")); - } - - @EnableWebMvc - @EnableWebSecurity - public static class SubjectPrincipalRegexConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("rod").password("password").roles("USER", "ADMIN"); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .x509() - .subjectPrincipalRegex("CN=(.*?)@example.com(?:,|$)"); - } + this.mvc.perform(get("/whoami").with(x509(certificate))).andExpect(content().string("rod")); } @Test public void x509AuthenticationWhenHasCustomPrincipalExtractorThenMatchesNamespace() throws Exception { this.spring.register(CustomPrincipalExtractorConfig.class, X509Controller.class).autowire(); X509Certificate certificate = loadCert("rodatexampledotcom.cer"); - this.mvc.perform(get("/whoami").with(x509(certificate))) - .andExpect(content().string("rod@example.com")); - } - - @EnableWebMvc - @EnableWebSecurity - public static class CustomPrincipalExtractorConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("rod@example.com").password("password").roles("USER", "ADMIN"); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .x509() - .x509PrincipalExtractor(this::extractCommonName); - } - - private String extractCommonName(X509Certificate certificate) { - try { - return ((X500Name) certificate.getSubjectDN()).getCommonName(); - } catch (Exception e) { - throw new IllegalArgumentException(e); - } - } + this.mvc.perform(get("/whoami").with(x509(certificate))).andExpect(content().string("rod@example.com")); } @Test public void x509AuthenticationWhenHasCustomUserDetailsServiceThenMatchesNamespace() throws Exception { this.spring.register(UserDetailsServiceRefConfig.class, X509Controller.class).autowire(); X509Certificate certificate = loadCert("rodatexampledotcom.cer"); - this.mvc.perform(get("/whoami").with(x509(certificate))) - .andExpect(content().string("customuser")); - } - - @EnableWebMvc - @EnableWebSecurity - public static class UserDetailsServiceRefConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser("rod").password("password").roles("USER", "ADMIN"); - } - - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .x509() - .userDetailsService(username -> USER); - } + this.mvc.perform(get("/whoami").with(x509(certificate))).andExpect(content().string("customuser")); } @Test public void x509AuthenticationWhenHasCustomAuthenticationUserDetailsServiceThenMatchesNamespace() throws Exception { this.spring.register(AuthenticationUserDetailsServiceConfig.class, X509Controller.class).autowire(); X509Certificate certificate = loadCert("rodatexampledotcom.cer"); - this.mvc.perform(get("/whoami").with(x509(certificate))) - .andExpect(content().string("customuser")); - } - - @EnableWebMvc - @EnableWebSecurity - public static class AuthenticationUserDetailsServiceConfig extends WebSecurityConfigurerAdapter { - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth. - inMemoryAuthentication() - .withUser("rod").password("password").roles("USER", "ADMIN"); - } - - protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests() - .anyRequest().hasRole("USER") - .and() - .x509() - .authenticationUserDetailsService(authentication -> USER); - } - } - - @RestController - public static class X509Controller { - @GetMapping("/whoami") - public String whoami(@AuthenticationPrincipal(expression="username") String name) { - return name; - } + this.mvc.perform(get("/whoami").with(x509(certificate))).andExpect(content().string("customuser")); } T loadCert(String location) { try (InputStream is = new ClassPathResource(location).getInputStream()) { CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); return (T) certFactory.generateCertificate(is); - } catch (Exception e) { - throw new IllegalArgumentException(e); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); } } T verifyBean(Class beanClass) { return verify(this.spring.getContext().getBean(beanClass)); } + + @EnableWebSecurity + @EnableWebMvc + public static class X509Config extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("rod").password("password").roles("USER", "ADMIN"); + // @formatter:on + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .x509(); + // @formatter:on + } + + } + + @EnableWebSecurity + @EnableWebMvc + static class AuthenticationDetailsSourceRefConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("rod").password("password").roles("USER", "ADMIN"); + // @formatter:on + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .x509() + .authenticationDetailsSource(authenticationDetailsSource()); + // @formatter:on + } + + @Bean + AuthenticationDetailsSource authenticationDetailsSource() { + return mock(AuthenticationDetailsSource.class); + } + + } + + @EnableWebMvc + @EnableWebSecurity + public static class SubjectPrincipalRegexConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("rod").password("password").roles("USER", "ADMIN"); + // @formatter:on + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .x509() + .subjectPrincipalRegex("CN=(.*?)@example.com(?:,|$)"); + // @formatter:on + } + + } + + @EnableWebMvc + @EnableWebSecurity + public static class CustomPrincipalExtractorConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("rod@example.com").password("password").roles("USER", "ADMIN"); + // @formatter:on + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .x509() + .x509PrincipalExtractor(this::extractCommonName); + // @formatter:on + } + + private String extractCommonName(X509Certificate certificate) { + try { + return ((X500Name) certificate.getSubjectDN()).getCommonName(); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); + } + } + + } + + @EnableWebMvc + @EnableWebSecurity + public static class UserDetailsServiceRefConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("rod").password("password").roles("USER", "ADMIN"); + // @formatter:on + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .x509() + .userDetailsService((username) -> USER); + // @formatter:on + } + + } + + @EnableWebMvc + @EnableWebSecurity + public static class AuthenticationUserDetailsServiceConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser("rod").password("password").roles("USER", "ADMIN"); + // @formatter:on + } + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().hasRole("USER") + .and() + .x509() + .authenticationUserDetailsService((authentication) -> USER); + // @formatter:on + } + + } + + @RestController + public static class X509Controller { + + @GetMapping("/whoami") + public String whoami(@AuthenticationPrincipal(expression = "username") String name) { + return name; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceRememberMeTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceRememberMeTests.java index b1e03fe1d7..b3598dfca4 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceRememberMeTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceRememberMeTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.Cookie; @@ -45,16 +46,17 @@ import org.springframework.security.web.authentication.rememberme.PersistentReme import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -63,7 +65,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests to verify that all the functionality of attributes is present + * Tests to verify that all the functionality of <anonymous> attributes is present * * @author Rob Winch * @author Josh Cummings @@ -80,30 +82,202 @@ public class NamespaceRememberMeTests { @Test public void rememberMeLoginWhenUsingDefaultsThenMatchesNamespace() throws Exception { this.spring.register(RememberMeConfig.class, SecurityController.class).autowire(); - MvcResult result = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn(); - + MvcResult result = this.mvc.perform(post("/login").with(rememberMeLogin())).andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); Cookie rememberMe = result.getResponse().getCookie("remember-me"); assertThat(rememberMe).isNotNull(); - this.mvc.perform(get("/authentication-class") - .cookie(rememberMe)) + this.mvc.perform(get("/authentication-class").cookie(rememberMe)) .andExpect(content().string(RememberMeAuthenticationToken.class.getName())); - - result = this.mvc.perform(post("/logout").with(csrf()) + // @formatter:off + MockHttpServletRequestBuilder logoutRequest = post("/logout") + .with(csrf()) .session(session) - .cookie(rememberMe)) + .cookie(rememberMe); + result = this.mvc.perform(logoutRequest) .andExpect(redirectedUrl("/login?logout")) .andReturn(); - + // @formatter:on rememberMe = result.getResponse().getCookie("remember-me"); assertThat(rememberMe).isNotNull().extracting(Cookie::getMaxAge).isEqualTo(0); - - this.mvc.perform(post("/authentication-class").with(csrf()) - .cookie(rememberMe)) + // @formatter:off + MockHttpServletRequestBuilder authenticationClassRequest = post("/authentication-class") + .with(csrf()) + .cookie(rememberMe); + this.mvc.perform(authenticationClassRequest) .andExpect(redirectedUrl("http://localhost/login")) .andReturn(); + // @formatter:on + } + + // SEC-3170 - RememberMeService implementations should not have to also implement + // LogoutHandler + @Test + public void logoutWhenCustomRememberMeServicesDeclaredThenUses() throws Exception { + RememberMeServicesRefConfig.REMEMBER_ME_SERVICES = mock(RememberMeServicesWithoutLogoutHandler.class); + this.spring.register(RememberMeServicesRefConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).autoLogin(any(HttpServletRequest.class), + any(HttpServletResponse.class)); + this.mvc.perform(post("/login").with(csrf())); + verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).loginFail(any(HttpServletRequest.class), + any(HttpServletResponse.class)); + } + + @Test + public void rememberMeLoginWhenAuthenticationSuccessHandlerDeclaredThenUses() throws Exception { + AuthSuccessConfig.SUCCESS_HANDLER = mock(AuthenticationSuccessHandler.class); + this.spring.register(AuthSuccessConfig.class).autowire(); + MvcResult result = this.mvc.perform(post("/login").with(rememberMeLogin())).andReturn(); + verifyZeroInteractions(AuthSuccessConfig.SUCCESS_HANDLER); + Cookie rememberMe = result.getResponse().getCookie("remember-me"); + assertThat(rememberMe).isNotNull(); + this.mvc.perform(get("/somewhere").cookie(rememberMe)); + verify(AuthSuccessConfig.SUCCESS_HANDLER).onAuthenticationSuccess(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(Authentication.class)); + } + + @Test + public void rememberMeLoginWhenKeyDeclaredThenMatchesNamespace() throws Exception { + this.spring.register(WithoutKeyConfig.class, KeyConfig.class, SecurityController.class).autowire(); + MockHttpServletRequestBuilder requestWithRememberme = post("/without-key/login").with(rememberMeLogin()); + // @formatter:off + Cookie withoutKey = this.mvc.perform(requestWithRememberme) + .andExpect(redirectedUrl("/")) + .andReturn() + .getResponse() + .getCookie("remember-me"); + // @formatter:on + MockHttpServletRequestBuilder somewhereRequest = get("/somewhere").cookie(withoutKey); + // @formatter:off + this.mvc.perform(somewhereRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/login")); + MockHttpServletRequestBuilder loginWithRememberme = post("/login").with(rememberMeLogin()); + Cookie withKey = this.mvc.perform(loginWithRememberme) + .andReturn() + .getResponse() + .getCookie("remember-me"); + this.mvc.perform(get("/somewhere").cookie(withKey)) + .andExpect(status().isNotFound()); + // @formatter:on + } + + // http/remember-me@services-alias is not supported use standard aliasing instead + // (i.e. @Bean("alias")) + // http/remember-me@data-source-ref is not supported directly. Instead use + // http/remember-me@token-repository-ref example + @Test + public void rememberMeLoginWhenDeclaredTokenRepositoryThenMatchesNamespace() throws Exception { + TokenRepositoryRefConfig.TOKEN_REPOSITORY = mock(PersistentTokenRepository.class); + this.spring.register(TokenRepositoryRefConfig.class).autowire(); + this.mvc.perform(post("/login").with(rememberMeLogin())); + verify(TokenRepositoryRefConfig.TOKEN_REPOSITORY).createNewToken(any(PersistentRememberMeToken.class)); + } + + @Test + public void rememberMeLoginWhenTokenValidityDeclaredThenMatchesNamespace() throws Exception { + this.spring.register(TokenValiditySecondsConfig.class).autowire(); + // @formatter:off + Cookie expiredRememberMe = this.mvc.perform(post("/login").with(rememberMeLogin())) + .andReturn() + .getResponse() + .getCookie("remember-me"); + // @formatter:on + assertThat(expiredRememberMe).extracting(Cookie::getMaxAge).isEqualTo(314); + } + + @Test + public void rememberMeLoginWhenUsingDefaultsThenCookieMaxAgeMatchesNamespace() throws Exception { + this.spring.register(RememberMeConfig.class).autowire(); + // @formatter:off + Cookie expiredRememberMe = this.mvc.perform(post("/login").with(rememberMeLogin())) + .andReturn() + .getResponse() + .getCookie("remember-me"); + // @formatter:on + assertThat(expiredRememberMe).extracting(Cookie::getMaxAge).isEqualTo(AbstractRememberMeServices.TWO_WEEKS_S); + } + + @Test + public void rememberMeLoginWhenUsingSecureCookieThenMatchesNamespace() throws Exception { + this.spring.register(UseSecureCookieConfig.class).autowire(); + // @formatter:off + Cookie secureCookie = this.mvc.perform(post("/login").with(rememberMeLogin())) + .andReturn() + .getResponse() + .getCookie("remember-me"); + // @formatter:on + assertThat(secureCookie).extracting(Cookie::getSecure).isEqualTo(true); + } + + @Test + public void rememberMeLoginWhenUsingDefaultsThenCookieSecurityMatchesNamespace() throws Exception { + this.spring.register(RememberMeConfig.class).autowire(); + // @formatter:off + Cookie secureCookie = this.mvc.perform(post("/login").with(rememberMeLogin()).secure(true)) + .andReturn() + .getResponse() + .getCookie("remember-me"); + // @formatter:on + assertThat(secureCookie).extracting(Cookie::getSecure).isEqualTo(true); + } + + @Test + public void rememberMeLoginWhenParameterSpecifiedThenMatchesNamespace() throws Exception { + this.spring.register(RememberMeParameterConfig.class).autowire(); + MockHttpServletRequestBuilder loginWithRememberme = post("/login").with(rememberMeLogin("rememberMe", true)); + // @formatter:off + Cookie rememberMe = this.mvc.perform(loginWithRememberme) + .andReturn() + .getResponse() + .getCookie("remember-me"); + // @formatter:on + assertThat(rememberMe).isNotNull(); + } + + // SEC-2880 + @Test + public void rememberMeLoginWhenCookieNameDeclaredThenMatchesNamespace() throws Exception { + this.spring.register(RememberMeCookieNameConfig.class).autowire(); + // @formatter:off + Cookie rememberMe = this.mvc.perform(post("/login").with(rememberMeLogin())) + .andReturn() + .getResponse() + .getCookie("rememberMe"); + // @formatter:on + assertThat(rememberMe).isNotNull(); + } + + @Test + public void rememberMeLoginWhenGlobalUserDetailsServiceDeclaredThenMatchesNamespace() throws Exception { + DefaultsUserDetailsServiceWithDaoConfig.USERDETAILS_SERVICE = mock(UserDetailsService.class); + this.spring.register(DefaultsUserDetailsServiceWithDaoConfig.class).autowire(); + this.mvc.perform(post("/login").with(rememberMeLogin())); + verify(DefaultsUserDetailsServiceWithDaoConfig.USERDETAILS_SERVICE).loadUserByUsername("user"); + } + + @Test + public void rememberMeLoginWhenUserDetailsServiceDeclaredThenMatchesNamespace() throws Exception { + UserServiceRefConfig.USERDETAILS_SERVICE = mock(UserDetailsService.class); + this.spring.register(UserServiceRefConfig.class).autowire(); + User user = new User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER")); + given(UserServiceRefConfig.USERDETAILS_SERVICE.loadUserByUsername("user")).willReturn(user); + this.mvc.perform(post("/login").with(rememberMeLogin())); + verify(UserServiceRefConfig.USERDETAILS_SERVICE).loadUserByUsername("user"); + } + + static RequestPostProcessor rememberMeLogin() { + return rememberMeLogin("remember-me", true); + } + + static RequestPostProcessor rememberMeLogin(String parameterName, boolean parameterValue) { + return (request) -> { + csrf().postProcessRequest(request); + request.setParameter("username", "user"); + request.setParameter("password", "password"); + request.setParameter(parameterName, String.valueOf(parameterValue)); + return request; + }; } @Configuration @@ -122,28 +296,17 @@ public class NamespaceRememberMeTests { .rememberMe(); // @formatter:on } + } - // SEC-3170 - RememberMeService implementations should not have to also implement LogoutHandler - @Test - public void logoutWhenCustomRememberMeServicesDeclaredThenUses() throws Exception { - RememberMeServicesRefConfig.REMEMBER_ME_SERVICES = mock(RememberMeServicesWithoutLogoutHandler.class); - this.spring.register(RememberMeServicesRefConfig.class).autowire(); + interface RememberMeServicesWithoutLogoutHandler extends RememberMeServices { - this.mvc.perform(get("/")); - verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES) - .autoLogin(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - this.mvc.perform(post("/login").with(csrf())); - verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES) - .loginFail(any(HttpServletRequest.class), any(HttpServletResponse.class)); } - interface RememberMeServicesWithoutLogoutHandler extends RememberMeServices {} - @Configuration @EnableWebSecurity static class RememberMeServicesRefConfig extends WebSecurityConfigurerAdapter { + static RememberMeServices REMEMBER_ME_SERVICES; @Override @@ -156,31 +319,13 @@ public class NamespaceRememberMeTests { .rememberMeServices(REMEMBER_ME_SERVICES); // @formatter:on } - } - @Test - public void rememberMeLoginWhenAuthenticationSuccessHandlerDeclaredThenUses() throws Exception { - AuthSuccessConfig.SUCCESS_HANDLER = mock(AuthenticationSuccessHandler.class); - this.spring.register(AuthSuccessConfig.class).autowire(); - - MvcResult result = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn(); - - verifyZeroInteractions(AuthSuccessConfig.SUCCESS_HANDLER); - - Cookie rememberMe = result.getResponse().getCookie("remember-me"); - assertThat(rememberMe).isNotNull(); - this.mvc.perform(get("/somewhere") - .cookie(rememberMe)); - - verify(AuthSuccessConfig.SUCCESS_HANDLER).onAuthenticationSuccess - (any(HttpServletRequest.class), any(HttpServletResponse.class), any(Authentication.class)); } @Configuration @EnableWebSecurity static class AuthSuccessConfig extends UsersConfig { + static AuthenticationSuccessHandler SUCCESS_HANDLER; @Override @@ -193,33 +338,14 @@ public class NamespaceRememberMeTests { .authenticationSuccessHandler(SUCCESS_HANDLER); // @formatter:on } - } - @Test - public void rememberMeLoginWhenKeyDeclaredThenMatchesNamespace() throws Exception { - this.spring.register(WithoutKeyConfig.class, KeyConfig.class, SecurityController.class).autowire(); - Cookie withoutKey = this.mvc.perform(post("/without-key/login") - .with(rememberMeLogin())) - .andExpect(redirectedUrl("/")) - .andReturn().getResponse().getCookie("remember-me"); - - this.mvc.perform(get("/somewhere") - .cookie(withoutKey)) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/login")); - - Cookie withKey = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn().getResponse().getCookie("remember-me"); - this.mvc.perform(get("/somewhere") - .cookie(withKey)) - .andExpect(status().isNotFound()); } @Configuration @EnableWebSecurity @Order(0) static class WithoutKeyConfig extends UsersConfig { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -231,11 +357,13 @@ public class NamespaceRememberMeTests { .rememberMe(); // @formatter:on } + } @Configuration @EnableWebSecurity static class KeyConfig extends UsersConfig { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -249,32 +377,19 @@ public class NamespaceRememberMeTests { .key("KeyConfig"); // @formatter:on } - } - // http/remember-me@services-alias is not supported use standard aliasing instead (i.e. @Bean("alias")) - - // http/remember-me@data-source-ref is not supported directly. Instead use http/remember-me@token-repository-ref example - @Test - public void rememberMeLoginWhenDeclaredTokenRepositoryThenMatchesNamespace() throws Exception { - TokenRepositoryRefConfig.TOKEN_REPOSITORY = mock(PersistentTokenRepository.class); - this.spring.register(TokenRepositoryRefConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(rememberMeLogin())); - - verify(TokenRepositoryRefConfig.TOKEN_REPOSITORY).createNewToken(any(PersistentRememberMeToken.class)); } @Configuration @EnableWebSecurity static class TokenRepositoryRefConfig extends UsersConfig { + static PersistentTokenRepository TOKEN_REPOSITORY; @Override protected void configure(HttpSecurity http) throws Exception { // JdbcTokenRepositoryImpl tokenRepository = new JdbcTokenRepositoryImpl() // tokenRepository.setDataSource(dataSource); - // @formatter:off http .formLogin() @@ -283,21 +398,13 @@ public class NamespaceRememberMeTests { .tokenRepository(TOKEN_REPOSITORY); // @formatter:on } - } - @Test - public void rememberMeLoginWhenTokenValidityDeclaredThenMatchesNamespace() throws Exception { - this.spring.register(TokenValiditySecondsConfig.class).autowire(); - Cookie expiredRememberMe = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn().getResponse().getCookie("remember-me"); - - assertThat(expiredRememberMe).extracting(Cookie::getMaxAge).isEqualTo(314); } @Configuration @EnableWebSecurity static class TokenValiditySecondsConfig extends UsersConfig { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -311,32 +418,13 @@ public class NamespaceRememberMeTests { .tokenValiditySeconds(314); // @formatter:on } - } - @Test - public void rememberMeLoginWhenUsingDefaultsThenCookieMaxAgeMatchesNamespace() throws Exception { - this.spring.register(RememberMeConfig.class).autowire(); - Cookie expiredRememberMe = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn().getResponse().getCookie("remember-me"); - - assertThat(expiredRememberMe).extracting(Cookie::getMaxAge) - .isEqualTo(AbstractRememberMeServices.TWO_WEEKS_S); - } - - @Test - public void rememberMeLoginWhenUsingSecureCookieThenMatchesNamespace() throws Exception { - this.spring.register(UseSecureCookieConfig.class).autowire(); - Cookie secureCookie = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn().getResponse().getCookie("remember-me"); - - assertThat(secureCookie).extracting(Cookie::getSecure).isEqualTo(true); } @Configuration @EnableWebSecurity static class UseSecureCookieConfig extends UsersConfig { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -347,32 +435,13 @@ public class NamespaceRememberMeTests { .useSecureCookie(true); // @formatter:on } - } - @Test - public void rememberMeLoginWhenUsingDefaultsThenCookieSecurityMatchesNamespace() throws Exception { - this.spring.register(RememberMeConfig.class).autowire(); - Cookie secureCookie = this.mvc.perform(post("/login") - .with(rememberMeLogin()) - .secure(true)) - .andReturn().getResponse().getCookie("remember-me"); - - assertThat(secureCookie).extracting(Cookie::getSecure).isEqualTo(true); - } - - @Test - public void rememberMeLoginWhenParameterSpecifiedThenMatchesNamespace() throws Exception { - this.spring.register(RememberMeParameterConfig.class).autowire(); - Cookie rememberMe = this.mvc.perform(post("/login") - .with(rememberMeLogin("rememberMe", true))) - .andReturn().getResponse().getCookie("remember-me"); - - assertThat(rememberMe).isNotNull(); } @Configuration @EnableWebSecurity static class RememberMeParameterConfig extends UsersConfig { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -383,23 +452,13 @@ public class NamespaceRememberMeTests { .rememberMeParameter("rememberMe"); // @formatter:on } - } - // SEC-2880 - - @Test - public void rememberMeLoginWhenCookieNameDeclaredThenMatchesNamespace() throws Exception { - this.spring.register(RememberMeCookieNameConfig.class).autowire(); - Cookie rememberMe = this.mvc.perform(post("/login") - .with(rememberMeLogin())) - .andReturn().getResponse().getCookie("rememberMe"); - - assertThat(rememberMe).isNotNull(); } @Configuration @EnableWebSecurity static class RememberMeCookieNameConfig extends UsersConfig { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -410,23 +469,13 @@ public class NamespaceRememberMeTests { .rememberMeCookieName("rememberMe"); // @formatter:on } - } - @Test - public void rememberMeLoginWhenGlobalUserDetailsServiceDeclaredThenMatchesNamespace() throws Exception { - DefaultsUserDetailsServiceWithDaoConfig.USERDETAILS_SERVICE = mock(UserDetailsService.class); - this.spring.register(DefaultsUserDetailsServiceWithDaoConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(rememberMeLogin())); - - verify(DefaultsUserDetailsServiceWithDaoConfig.USERDETAILS_SERVICE) - .loadUserByUsername("user"); } @EnableWebSecurity @Configuration static class DefaultsUserDetailsServiceWithDaoConfig extends WebSecurityConfigurerAdapter { + static UserDetailsService USERDETAILS_SERVICE; @Override @@ -441,29 +490,18 @@ public class NamespaceRememberMeTests { @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .userDetailsService(USERDETAILS_SERVICE); + // @formatter:on } - } - @Test - public void rememberMeLoginWhenUserDetailsServiceDeclaredThenMatchesNamespace() throws Exception { - UserServiceRefConfig.USERDETAILS_SERVICE = mock(UserDetailsService.class); - this.spring.register(UserServiceRefConfig.class).autowire(); - - when(UserServiceRefConfig.USERDETAILS_SERVICE.loadUserByUsername("user")) - .thenReturn(new User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))); - - this.mvc.perform(post("/login") - .with(rememberMeLogin())); - - verify(UserServiceRefConfig.USERDETAILS_SERVICE) - .loadUserByUsername("user"); } @Configuration @EnableWebSecurity static class UserServiceRefConfig extends UsersConfig { + static UserDetailsService USERDETAILS_SERVICE; @Override @@ -476,40 +514,34 @@ public class NamespaceRememberMeTests { .userDetailsService(USERDETAILS_SERVICE); // @formatter:on } - } - static RequestPostProcessor rememberMeLogin() { - return rememberMeLogin("remember-me", true); - } - - static RequestPostProcessor rememberMeLogin(String parameterName, boolean parameterValue) { - return request -> { - csrf().postProcessRequest(request); - request.setParameter("username", "user"); - request.setParameter("password", "password"); - request.setParameter(parameterName, String.valueOf(parameterValue)); - return request; - }; } static class UsersConfig extends WebSecurityConfigurerAdapter { + @Override @Bean public UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager( + // @formatter:off User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build()); + // @formatter:on } + } @RestController static class SecurityController { + @GetMapping("/authentication-class") String authenticationClass(Authentication authentication) { return authentication.getClass().getName(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceSessionManagementTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceSessionManagementTests.java index e3a1f6bde0..8c43518461 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceSessionManagementTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceSessionManagementTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.security.Principal; import java.util.ArrayList; import java.util.Date; import java.util.List; + import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -50,22 +52,22 @@ import org.springframework.security.web.session.InvalidSessionStrategy; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Rob Winch * @author Josh Cummings */ @@ -79,101 +81,192 @@ public class NamespaceSessionManagementTests { @Test public void authenticateWhenDefaultSessionManagementThenMatchesNamespace() throws Exception { - this.spring.register - (SessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - + this.spring.register(SessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class) + .autowire(); MockHttpSession session = new MockHttpSession(); String sessionId = session.getId(); - - MvcResult result = - this.mvc.perform(get("/auth") - .session(session) - .with(httpBasic("user", "password"))) - .andExpect(session()) - .andReturn(); - + MockHttpServletRequestBuilder request = get("/auth").session(session).with(httpBasic("user", "password")); + // @formatter:off + MvcResult result = this.mvc.perform(request) + .andExpect(session()) + .andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false).getId()).isNotEqualTo(sessionId); } - @EnableWebSecurity - static class SessionManagementConfig extends WebSecurityConfigurerAdapter { - } - @Test public void authenticateWhenUsingInvalidSessionUrlThenMatchesNamespace() throws Exception { this.spring.register(CustomSessionManagementConfig.class).autowire(); - - this.mvc.perform(get("/auth") - .with(request -> { - request.setRequestedSessionIdValid(false); - request.setRequestedSessionId("id"); - return request; - })) - .andExpect(redirectedUrl("/invalid-session")); + MockHttpServletRequestBuilder authRequest = get("/auth").with((request) -> { + request.setRequestedSessionIdValid(false); + request.setRequestedSessionId("id"); + return request; + }); + this.mvc.perform(authRequest).andExpect(redirectedUrl("/invalid-session")); } - @Test public void authenticateWhenUsingExpiredUrlThenMatchesNamespace() throws Exception { this.spring.register(CustomSessionManagementConfig.class).autowire(); - MockHttpSession session = new MockHttpSession(); SessionInformation sessionInformation = new SessionInformation(new Object(), session.getId(), new Date(0)); sessionInformation.expireNow(); SessionRegistry sessionRegistry = this.spring.getContext().getBean(SessionRegistry.class); - when(sessionRegistry.getSessionInformation(session.getId())).thenReturn(sessionInformation); - - this.mvc.perform(get("/auth").session(session)) - .andExpect(redirectedUrl("/expired-session")); + given(sessionRegistry.getSessionInformation(session.getId())).willReturn(sessionInformation); + this.mvc.perform(get("/auth").session(session)).andExpect(redirectedUrl("/expired-session")); } @Test public void authenticateWhenUsingMaxSessionsThenMatchesNamespace() throws Exception { - this.spring.register(CustomSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + this.spring.register(CustomSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class) + .autowire(); + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))).andExpect(status().isOk()); + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(redirectedUrl("/session-auth-error")); } @Test public void authenticateWhenUsingFailureUrlThenMatchesNamespace() throws Exception { - this.spring.register(CustomSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - + this.spring.register(CustomSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class) + .autowire(); MockHttpServletRequest mock = spy(MockHttpServletRequest.class); mock.setSession(new MockHttpSession()); - when(mock.changeSessionId()).thenThrow(SessionAuthenticationException.class); + given(mock.changeSessionId()).willThrow(SessionAuthenticationException.class); mock.setMethod("GET"); - - this.mvc.perform(get("/auth") - .with(request -> mock) - .with(httpBasic("user", "password"))) - .andExpect(redirectedUrl("/session-auth-error")); + // @formatter:off + MockHttpServletRequestBuilder authRequest = get("/auth") + .with((request) -> mock) + .with(httpBasic("user", "password")); + // @formatter:on + this.mvc.perform(authRequest).andExpect(redirectedUrl("/session-auth-error")); } @Test public void authenticateWhenUsingSessionRegistryThenMatchesNamespace() throws Exception { - this.spring.register(CustomSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - + this.spring.register(CustomSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class) + .autowire(); SessionRegistry sessionRegistry = this.spring.getContext().getBean(SessionRegistry.class); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()); - + MockHttpServletRequestBuilder request = get("/auth").with(httpBasic("user", "password")); + this.mvc.perform(request).andExpect(status().isOk()); verify(sessionRegistry).registerNewSession(any(String.class), any(Object.class)); } + // gh-3371 + @Test + public void authenticateWhenUsingCustomInvalidSessionStrategyThenMatchesNamespace() throws Exception { + this.spring.register(InvalidSessionStrategyConfig.class).autowire(); + MockHttpServletRequestBuilder authRequest = get("/auth").with((request) -> { + request.setRequestedSessionIdValid(false); + request.setRequestedSessionId("id"); + return request; + }); + this.mvc.perform(authRequest).andExpect(status().isOk()); + verifyBean(InvalidSessionStrategy.class).onInvalidSessionDetected(any(HttpServletRequest.class), + any(HttpServletResponse.class)); + } + + @Test + public void authenticateWhenUsingCustomSessionAuthenticationStrategyThenMatchesNamespace() throws Exception { + this.spring.register(RefsSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class) + .autowire(); + MockHttpServletRequestBuilder request = get("/auth").with(httpBasic("user", "password")); + this.mvc.perform(request).andExpect(status().isOk()); + verifyBean(SessionAuthenticationStrategy.class).onAuthentication(any(Authentication.class), + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void authenticateWhenNoSessionFixationProtectionThenMatchesNamespace() throws Exception { + this.spring + .register(SFPNoneSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class) + .autowire(); + MockHttpSession givenSession = new MockHttpSession(); + String givenSessionId = givenSession.getId(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(givenSession) + .with(httpBasic("user", "password")); + MockHttpSession resultingSession = (MockHttpSession) this.mvc.perform(request) + .andExpect(status().isOk()) + .andReturn() + .getRequest() + .getSession(false); + // @formatter:on + assertThat(givenSessionId).isEqualTo(resultingSession.getId()); + } + + @Test + public void authenticateWhenMigrateSessionFixationProtectionThenMatchesNamespace() throws Exception { + this.spring.register(SFPMigrateSessionManagementConfig.class, BasicController.class, + UserDetailsServiceConfig.class).autowire(); + MockHttpSession givenSession = new MockHttpSession(); + String givenSessionId = givenSession.getId(); + givenSession.setAttribute("name", "value"); + // @formatter:off + MockHttpSession resultingSession = (MockHttpSession) this.mvc.perform(get("/auth") + .session(givenSession) + .with(httpBasic("user", "password"))) + .andExpect(status().isOk()) + .andReturn() + .getRequest() + .getSession(false); + // @formatter:on + assertThat(givenSessionId).isNotEqualTo(resultingSession.getId()); + assertThat(resultingSession.getAttribute("name")).isEqualTo("value"); + } + + // SEC-2913 + @Test + public void authenticateWhenUsingSessionFixationProtectionThenUsesNonNullEventPublisher() throws Exception { + this.spring.register(SFPPostProcessedConfig.class, UserDetailsServiceConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(new MockHttpSession()) + .with(httpBasic("user", "password")); + // @formatter:on + this.mvc.perform(request).andExpect(status().isNotFound()); + verifyBean(MockEventListener.class).onApplicationEvent(any(SessionFixationProtectionEvent.class)); + } + + @Test + public void authenticateWhenNewSessionFixationProtectionThenMatchesNamespace() throws Exception { + this.spring.register(SFPNewSessionSessionManagementConfig.class, UserDetailsServiceConfig.class).autowire(); + MockHttpSession givenSession = new MockHttpSession(); + String givenSessionId = givenSession.getId(); + givenSession.setAttribute("name", "value"); + MockHttpServletRequestBuilder request = get("/auth").session(givenSession).with(httpBasic("user", "password")); + // @formatter:off + MockHttpSession resultingSession = (MockHttpSession) this.mvc.perform(request) + .andExpect(status().isNotFound()) + .andReturn() + .getRequest() + .getSession(false); + // @formatter:on + assertThat(givenSessionId).isNotEqualTo(resultingSession.getId()); + assertThat(resultingSession.getAttribute("name")).isNull(); + } + + private T verifyBean(Class clazz) { + return verify(this.spring.getContext().getBean(clazz)); + } + + private static SessionResultMatcher session() { + return new SessionResultMatcher(); + } + + @EnableWebSecurity + static class SessionManagementConfig extends WebSecurityConfigurerAdapter { + + } + @EnableWebSecurity static class CustomSessionManagementConfig extends WebSecurityConfigurerAdapter { + SessionRegistry sessionRegistry = spy(SessionRegistryImpl.class); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -187,259 +280,189 @@ public class NamespaceSessionManagementTests { .maxSessionsPreventsLogin(true) // session-management/concurrency-control@error-if-maximum-exceeded .expiredUrl("/expired-session") // session-management/concurrency-control@expired-url .sessionRegistry(sessionRegistry()); // session-management/concurrency-control@session-registry-ref + // @formatter:on } @Bean SessionRegistry sessionRegistry() { return this.sessionRegistry; } - } - - // gh-3371 - @Test - public void authenticateWhenUsingCustomInvalidSessionStrategyThenMatchesNamespace() throws Exception { - this.spring.register(InvalidSessionStrategyConfig.class).autowire(); - - this.mvc.perform(get("/auth") - .with(request -> { - request.setRequestedSessionIdValid(false); - request.setRequestedSessionId("id"); - return request; - })) - .andExpect(status().isOk()); - - verifyBean(InvalidSessionStrategy.class) - .onInvalidSessionDetected(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @EnableWebSecurity static class InvalidSessionStrategyConfig extends WebSecurityConfigurerAdapter { + InvalidSessionStrategy invalidSessionStrategy = mock(InvalidSessionStrategy.class); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement() .invalidSessionStrategy(invalidSessionStrategy()); + // @formatter:on } @Bean InvalidSessionStrategy invalidSessionStrategy() { return this.invalidSessionStrategy; } - } - @Test - public void authenticateWhenUsingCustomSessionAuthenticationStrategyThenMatchesNamespace() throws Exception { - this.spring.register(RefsSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()); - - verifyBean(SessionAuthenticationStrategy.class) - .onAuthentication(any(Authentication.class), - any(HttpServletRequest.class), any(HttpServletResponse.class)); } @EnableWebSecurity static class RefsSessionManagementConfig extends WebSecurityConfigurerAdapter { - SessionAuthenticationStrategy sessionAuthenticationStrategy = - mock(SessionAuthenticationStrategy.class); + + SessionAuthenticationStrategy sessionAuthenticationStrategy = mock(SessionAuthenticationStrategy.class); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement() .sessionAuthenticationStrategy(sessionAuthenticationStrategy()) // session-management@session-authentication-strategy-ref .and() .httpBasic(); + // @formatter:on } @Bean SessionAuthenticationStrategy sessionAuthenticationStrategy() { return this.sessionAuthenticationStrategy; } - } - @Test - public void authenticateWhenNoSessionFixationProtectionThenMatchesNamespace() throws Exception { - this.spring.register(SFPNoneSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - - MockHttpSession givenSession = new MockHttpSession(); - String givenSessionId = givenSession.getId(); - MockHttpSession resultingSession = (MockHttpSession) - this.mvc.perform(get("/auth") - .session(givenSession) - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()) - .andReturn().getRequest().getSession(false); - - assertThat(givenSessionId).isEqualTo(resultingSession.getId()); } @EnableWebSecurity static class SFPNoneSessionManagementConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement() .sessionAuthenticationStrategy(new NullAuthenticatedSessionStrategy()) .and() .httpBasic(); + // @formatter:on } - } - @Test - public void authenticateWhenMigrateSessionFixationProtectionThenMatchesNamespace() throws Exception { - this.spring.register(SFPMigrateSessionManagementConfig.class, BasicController.class, UserDetailsServiceConfig.class).autowire(); - - MockHttpSession givenSession = new MockHttpSession(); - String givenSessionId = givenSession.getId(); - givenSession.setAttribute("name", "value"); - - MockHttpSession resultingSession = (MockHttpSession) - this.mvc.perform(get("/auth") - .session(givenSession) - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()) - .andReturn().getRequest().getSession(false); - - assertThat(givenSessionId).isNotEqualTo(resultingSession.getId()); - assertThat(resultingSession.getAttribute("name")).isEqualTo("value"); } @EnableWebSecurity static class SFPMigrateSessionManagementConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement() .and() .httpBasic(); + // @formatter:on } - } - // SEC-2913 - @Test - public void authenticateWhenUsingSessionFixationProtectionThenUsesNonNullEventPublisher() throws Exception { - this.spring.register(SFPPostProcessedConfig.class, UserDetailsServiceConfig.class).autowire(); - - this.mvc.perform(get("/auth") - .session(new MockHttpSession()) - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()); - - verifyBean(MockEventListener.class).onApplicationEvent(any(SessionFixationProtectionEvent.class)); } @EnableWebSecurity static class SFPPostProcessedConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement() .and() .httpBasic(); + // @formatter:on } @Bean - public MockEventListener eventListener() { + MockEventListener eventListener() { return spy(new MockEventListener()); } - } - @Test - public void authenticateWhenNewSessionFixationProtectionThenMatchesNamespace() throws Exception { - this.spring.register(SFPNewSessionSessionManagementConfig.class, UserDetailsServiceConfig.class).autowire(); - - MockHttpSession givenSession = new MockHttpSession(); - String givenSessionId = givenSession.getId(); - givenSession.setAttribute("name", "value"); - - MockHttpSession resultingSession = (MockHttpSession) - this.mvc.perform(get("/auth") - .session(givenSession) - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()) - .andReturn().getRequest().getSession(false); - - assertThat(givenSessionId).isNotEqualTo(resultingSession.getId()); - assertThat(resultingSession.getAttribute("name")).isNull(); } @EnableWebSecurity static class SFPNewSessionSessionManagementConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement() .sessionFixation().newSession() .and() .httpBasic(); + // @formatter:on } - } - - private T verifyBean(Class clazz) { - return verify(this.spring.getContext().getBean(clazz)); } static class MockEventListener implements ApplicationListener { + List events = new ArrayList<>(); + @Override public void onApplicationEvent(SessionFixationProtectionEvent event) { this.events.add(event); } + } @Configuration static class UserDetailsServiceConfig { + @Bean UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager( + // @formatter:off User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build()); + // @formatter:on } + } @RestController static class BasicController { + @GetMapping("/") - public String ok() { + String ok() { return "ok"; } @GetMapping("/auth") - public String auth(Principal principal) { + String auth(Principal principal) { return principal.getName(); } - } - private static SessionResultMatcher session() { - return new SessionResultMatcher(); } private static class SessionResultMatcher implements ResultMatcher { + private String id; + private Boolean valid; + private Boolean exists = true; - public ResultMatcher exists(boolean exists) { + ResultMatcher exists(boolean exists) { this.exists = exists; return this; } - public ResultMatcher valid(boolean valid) { + ResultMatcher valid(boolean valid) { this.valid = valid; return this.exists(true); } - public ResultMatcher id(String id) { + ResultMatcher id(String id) { this.id = id; return this.exists(true); } @@ -450,22 +473,21 @@ public class NamespaceSessionManagementTests { assertThat(result.getRequest().getSession(false)).isNull(); return; } - assertThat(result.getRequest().getSession(false)).isNotNull(); - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); - if (this.valid != null) { if (this.valid) { assertThat(session.isInvalid()).isFalse(); - } else { + } + else { assertThat(session.isInvalid()).isTrue(); } } - if (this.id != null) { assertThat(session.getId()).isEqualTo(this.id); } } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupportTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupportTests.java index eda13cb8b2..35d1573329 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupportTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PermitAllSupportTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; @@ -25,8 +26,9 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.test.SpringTestRule; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -48,21 +50,29 @@ public class PermitAllSupportTests { @Test public void performWhenUsingPermitAllExactUrlRequestMatcherThenMatchesExactUrl() throws Exception { this.spring.register(PermitAllConfig.class).autowire(); + MockHttpServletRequestBuilder request = get("/app/xyz").contextPath("/app"); + this.mvc.perform(request).andExpect(status().isNotFound()); + MockHttpServletRequestBuilder getWithQuery = get("/app/xyz?def").contextPath("/app"); + this.mvc.perform(getWithQuery).andExpect(status().isFound()); + MockHttpServletRequestBuilder postWithQueryAndCsrf = post("/app/abc?def").with(csrf()).contextPath("/app"); + this.mvc.perform(postWithQueryAndCsrf).andExpect(status().isNotFound()); + MockHttpServletRequestBuilder getWithCsrf = get("/app/abc").with(csrf()).contextPath("/app"); + this.mvc.perform(getWithCsrf).andExpect(status().isFound()); + } - this.mvc.perform(get("/app/xyz").contextPath("/app")) - .andExpect(status().isNotFound()); - this.mvc.perform(get("/app/xyz?def").contextPath("/app")) - .andExpect(status().isFound()); - this.mvc.perform(post("/app/abc?def").with(csrf()).contextPath("/app")) - .andExpect(status().isNotFound()); - this.mvc.perform(get("/app/abc").with(csrf()).contextPath("/app")) - .andExpect(status().isFound()); + @Test + public void configureWhenNotAuthorizeRequestsThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(NoAuthorizedUrlsConfig.class).autowire()) + .withMessageContaining("permitAll only works with HttpSecurity.authorizeRequests"); } @EnableWebSecurity static class PermitAllConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -70,14 +80,9 @@ public class PermitAllSupportTests { .formLogin() .loginPage("/xyz").permitAll() .loginProcessingUrl("/abc?def").permitAll(); + // @formatter:on } - } - @Test - public void configureWhenNotAuthorizeRequestsThenException() { - assertThatCode(() -> this.spring.register(NoAuthorizedUrlsConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("permitAll only works with HttpSecurity.authorizeRequests"); } @EnableWebSecurity @@ -85,9 +90,13 @@ public class PermitAllSupportTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .formLogin() .permitAll(); + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurerTests.java index bd68d2ecec..bf21abbd7c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/PortMapperConfigurerTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; +import java.util.Collections; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -25,8 +29,6 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.web.PortMapperImpl; import org.springframework.test.web.servlet.MockMvc; -import java.util.Collections; - import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -35,6 +37,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Josh Cummings */ public class PortMapperConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -44,9 +47,19 @@ public class PortMapperConfigurerTests { @Test public void requestWhenPortMapperTwiceInvokedThenDoesNotOverride() throws Exception { this.spring.register(InvokeTwiceDoesNotOverride.class).autowire(); + this.mockMvc.perform(get("http://localhost:543")).andExpect(redirectedUrl("https://localhost:123")); + } - this.mockMvc.perform(get("http://localhost:543")) - .andExpect(redirectedUrl("https://localhost:123")); + @Test + public void requestWhenPortMapperHttpMapsToInLambdaThenRedirectsToHttpsPort() throws Exception { + this.spring.register(HttpMapsToInLambdaConfig.class).autowire(); + this.mockMvc.perform(get("http://localhost:543")).andExpect(redirectedUrl("https://localhost:123")); + } + + @Test + public void requestWhenCustomPortMapperInLambdaThenRedirectsToHttpsPort() throws Exception { + this.spring.register(CustomPortMapperInLambdaConfig.class).autowire(); + this.mockMvc.perform(get("http://localhost:543")).andExpect(redirectedUrl("https://localhost:123")); } @EnableWebSecurity @@ -54,6 +67,7 @@ public class PortMapperConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .requiresChannel() .anyRequest().requiresSecure() @@ -62,60 +76,51 @@ public class PortMapperConfigurerTests { .http(543).mapsTo(123) .and() .portMapper(); + // @formatter:on } - } - @Test - public void requestWhenPortMapperHttpMapsToInLambdaThenRedirectsToHttpsPort() throws Exception { - this.spring.register(HttpMapsToInLambdaConfig.class).autowire(); - - this.mockMvc.perform(get("http://localhost:543")) - .andExpect(redirectedUrl("https://localhost:123")); } @EnableWebSecurity static class HttpMapsToInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .requiresChannel(requiresChannel -> + .requiresChannel((requiresChannel) -> requiresChannel .anyRequest().requiresSecure() ) - .portMapper(portMapper -> + .portMapper((portMapper) -> portMapper .http(543).mapsTo(123) ); // @formatter:on } - } - @Test - public void requestWhenCustomPortMapperInLambdaThenRedirectsToHttpsPort() throws Exception { - this.spring.register(CustomPortMapperInLambdaConfig.class).autowire(); - - this.mockMvc.perform(get("http://localhost:543")) - .andExpect(redirectedUrl("https://localhost:123")); } @EnableWebSecurity static class CustomPortMapperInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { PortMapperImpl customPortMapper = new PortMapperImpl(); customPortMapper.setPortMappings(Collections.singletonMap("543", "123")); // @formatter:off http - .requiresChannel(requiresChannel -> + .requiresChannel((requiresChannel) -> requiresChannel .anyRequest().requiresSecure() ) - .portMapper(portMapper -> + .portMapper((portMapper) -> portMapper .portMapper(customPortMapper) ); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java index 1cba7745fe..93946ab099 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java @@ -16,8 +16,14 @@ package org.springframework.security.config.annotation.web.configurers; +import java.util.Collections; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpSession; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -34,24 +40,23 @@ import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers; import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter; import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; - -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpSession; -import java.util.Collections; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; @@ -79,18 +84,194 @@ public class RememberMeConfigurerTests { @Test public void postWhenNoUserDetailsServiceThenException() { this.spring.register(NullUserDetailsConfig.class).autowire(); + assertThatIllegalStateException().isThrownBy(() -> { + // @formatter:off + MockHttpServletRequestBuilder request = post("/login") + .param("username", "user") + .param("password", "password") + .param("remember-me", "true") + .with(csrf()); + // @formatter:on + this.mvc.perform(request); + }).withMessageContaining("UserDetailsService is required"); + } - assertThatThrownBy(() -> - mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .param("remember-me", "true") - .with(csrf()))) - .hasMessageContaining("UserDetailsService is required"); + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnRememberMeAuthenticationFilter() { + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(RememberMeAuthenticationFilter.class)); + } + + @Test + public void rememberMeWhenInvokedTwiceThenUsesOriginalUserDetailsService() throws Exception { + given(DuplicateDoesNotOverrideConfig.userDetailsService.loadUserByUsername(anyString())) + .willReturn(new User("user", "password", Collections.emptyList())); + this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/") + .with(httpBasic("user", "password")) + .param("remember-me", "true"); + // @formatter:on + this.mvc.perform(request); + verify(DuplicateDoesNotOverrideConfig.userDetailsService).loadUserByUsername("user"); + } + + @Test + public void loginWhenRememberMeTrueThenRespondsWithRememberMeCookie() throws Exception { + this.spring.register(RememberMeConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + // @formatter:on + this.mvc.perform(request).andExpect(cookie().exists("remember-me")); + } + + @Test + public void getWhenRememberMeCookieThenAuthenticationIsRememberMeAuthenticationToken() throws Exception { + this.spring.register(RememberMeConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(post("/login").with(csrf()).param("username", "user") + .param("password", "password").param("remember-me", "true")).andReturn(); + Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); + // @formatter:off + MockHttpServletRequestBuilder request = get("/abc").cookie(rememberMeCookie); + SecurityMockMvcResultMatchers.AuthenticatedMatcher remembermeAuthentication = authenticated() + .withAuthentication((auth) -> assertThat(auth).isInstanceOf(RememberMeAuthenticationToken.class)); + // @formatter:on + this.mvc.perform(request).andExpect(remembermeAuthentication); + } + + @Test + public void logoutWhenRememberMeCookieThenAuthenticationIsRememberMeCookieExpired() throws Exception { + this.spring.register(RememberMeConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + // @formatter:on + MvcResult mvcResult = this.mvc.perform(loginRequest).andReturn(); + Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); + HttpSession session = mvcResult.getRequest().getSession(); + // @formatter:off + MockHttpServletRequestBuilder logoutRequest = post("/logout") + .with(csrf()) + .cookie(rememberMeCookie) + .session((MockHttpSession) session); + this.mvc.perform(logoutRequest) + .andExpect(redirectedUrl("/login?logout")) + .andExpect(cookie().maxAge("remember-me", 0)); + // @formatter:on + } + + @Test + public void getWhenRememberMeCookieAndLoggedOutThenRedirectsToLogin() throws Exception { + this.spring.register(RememberMeConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + // @formatter:on + MvcResult loginMvcResult = this.mvc.perform(loginRequest).andReturn(); + Cookie rememberMeCookie = loginMvcResult.getResponse().getCookie("remember-me"); + HttpSession session = loginMvcResult.getRequest().getSession(); + // @formatter:off + MockHttpServletRequestBuilder logoutRequest = post("/logout") + .with(csrf()) + .cookie(rememberMeCookie) + .session((MockHttpSession) session); + // @formatter:on + MvcResult logoutMvcResult = this.mvc.perform(logoutRequest).andReturn(); + Cookie expiredRememberMeCookie = logoutMvcResult.getResponse().getCookie("remember-me"); + // @formatter:off + MockHttpServletRequestBuilder expiredRequest = get("/abc") + .with(csrf()) + .cookie(expiredRememberMeCookie); + // @formatter:on + this.mvc.perform(expiredRequest).andExpect(redirectedUrl("http://localhost/login")); + } + + @Test + public void loginWhenRememberMeConfiguredInLambdaThenRespondsWithRememberMeCookie() throws Exception { + this.spring.register(RememberMeInLambdaConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + // @formatter:on + this.mvc.perform(request).andExpect(cookie().exists("remember-me")); + } + + @Test + public void loginWhenRememberMeTrueAndCookieDomainThenRememberMeCookieHasDomain() throws Exception { + this.spring.register(RememberMeCookieDomainConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + this.mvc.perform(request). + andExpect(cookie().exists("remember-me")) + .andExpect(cookie().domain("remember-me", "spring.io")); + // @formatter:on + } + + @Test + public void loginWhenRememberMeTrueAndCookieDomainInLambdaThenRememberMeCookieHasDomain() throws Exception { + this.spring.register(RememberMeCookieDomainInLambdaConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + this.mvc.perform(loginRequest) + .andExpect(cookie().exists("remember-me")) + .andExpect(cookie().domain("remember-me", "spring.io")); + // @formatter:on + } + + @Test + public void configureWhenRememberMeCookieNameAndRememberMeServicesThenException() { + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy( + () -> this.spring.register(RememberMeCookieNameAndRememberMeServicesConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class) + .withMessageContaining("Can not set rememberMeCookieName and custom rememberMeServices."); + } + + @Test + public void getWhenRememberMeCookieAndNoKeyConfiguredThenKeyFromRememberMeServicesIsUsed() throws Exception { + this.spring.register(FallbackRememberMeKeyConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + // @formatter:on + MvcResult mvcResult = this.mvc.perform(loginRequest).andReturn(); + Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); + MockHttpServletRequestBuilder requestWithRememberme = get("/abc").cookie(rememberMeCookie); + // @formatter:off + SecurityMockMvcResultMatchers.AuthenticatedMatcher remembermeAuthentication = authenticated() + .withAuthentication((auth) -> assertThat(auth).isInstanceOf(RememberMeAuthenticationToken.class)); + // @formatter:on + this.mvc.perform(requestWithRememberme).andExpect(remembermeAuthentication); } @EnableWebSecurity static class NullUserDetailsConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http @@ -113,18 +294,12 @@ public class RememberMeConfigurerTests { .authenticationProvider(provider); // @formatter:on } - } - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnRememberMeAuthenticationFilter() { - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(RememberMeAuthenticationFilter.class)); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -148,30 +323,21 @@ public class RememberMeConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void rememberMeWhenInvokedTwiceThenUsesOriginalUserDetailsService() throws Exception { - when(DuplicateDoesNotOverrideConfig.userDetailsService.loadUserByUsername(anyString())) - .thenReturn(new User("user", "password", Collections.emptyList())); - this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password")) - .param("remember-me", "true")); - - verify(DuplicateDoesNotOverrideConfig.userDetailsService).loadUserByUsername("user"); } @EnableWebSecurity static class DuplicateDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + static UserDetailsService userDetailsService = mock(UserDetailsService.class); @Override @@ -187,92 +353,20 @@ public class RememberMeConfigurerTests { // @formatter:on } + @Override @Bean public UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager( + // @formatter:off User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build() + // @formatter:on ); } - } - @Test - public void loginWhenRememberMeTrueThenRespondsWithRememberMeCookie() throws Exception { - this.spring.register(RememberMeConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andExpect(cookie().exists("remember-me")); - } - - @Test - public void getWhenRememberMeCookieThenAuthenticationIsRememberMeAuthenticationToken() throws Exception { - this.spring.register(RememberMeConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andReturn(); - Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); - - this.mvc.perform(get("/abc") - .cookie(rememberMeCookie)) - .andExpect(authenticated().withAuthentication(auth -> - assertThat(auth).isInstanceOf(RememberMeAuthenticationToken.class))); - } - - @Test - public void logoutWhenRememberMeCookieThenAuthenticationIsRememberMeCookieExpired() throws Exception { - this.spring.register(RememberMeConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andReturn(); - Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); - HttpSession session = mvcResult.getRequest().getSession(); - - this.mvc.perform(post("/logout") - .with(csrf()) - .cookie(rememberMeCookie) - .session((MockHttpSession) session)) - .andExpect(redirectedUrl("/login?logout")) - .andExpect(cookie().maxAge("remember-me", 0)); - } - - @Test - public void getWhenRememberMeCookieAndLoggedOutThenRedirectsToLogin() throws Exception { - this.spring.register(RememberMeConfig.class).autowire(); - - MvcResult loginMvcResult = this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andReturn(); - Cookie rememberMeCookie = loginMvcResult.getResponse().getCookie("remember-me"); - HttpSession session = loginMvcResult.getRequest().getSession(); - MvcResult logoutMvcResult = this.mvc.perform(post("/logout") - .with(csrf()) - .cookie(rememberMeCookie) - .session((MockHttpSession) session)) - .andReturn(); - Cookie expiredRememberMeCookie = logoutMvcResult.getResponse().getCookie("remember-me"); - - this.mvc.perform(get("/abc") - .with(csrf()) - .cookie(expiredRememberMeCookie)) - .andExpect(redirectedUrl("http://localhost/login")); } @EnableWebSecurity @@ -292,26 +386,14 @@ public class RememberMeConfigurerTests { } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - - @Test - public void loginWhenRememberMeConfiguredInLambdaThenRespondsWithRememberMeCookie() throws Exception { - this.spring.register(RememberMeInLambdaConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andExpect(cookie().exists("remember-me")); } @EnableWebSecurity @@ -321,7 +403,7 @@ public class RememberMeConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) @@ -331,30 +413,20 @@ public class RememberMeConfigurerTests { } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenRememberMeTrueAndCookieDomainThenRememberMeCookieHasDomain() throws Exception { - this.spring.register(RememberMeCookieDomainConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andExpect(cookie().exists("remember-me")) - .andExpect(cookie().domain("remember-me", "spring.io")); } @EnableWebSecurity static class RememberMeCookieDomainConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http @@ -369,39 +441,29 @@ public class RememberMeConfigurerTests { } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenRememberMeTrueAndCookieDomainInLambdaThenRememberMeCookieHasDomain() throws Exception { - this.spring.register(RememberMeCookieDomainInLambdaConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andExpect(cookie().exists("remember-me")) - .andExpect(cookie().domain("remember-me", "spring.io")); } @EnableWebSecurity static class RememberMeCookieDomainInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().hasRole("USER") ) .formLogin(withDefaults()) - .rememberMe(rememberMe -> + .rememberMe((rememberMe) -> rememberMe .rememberMeCookieDomain("spring.io") ); @@ -409,27 +471,22 @@ public class RememberMeConfigurerTests { } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void configureWhenRememberMeCookieNameAndRememberMeServicesThenException() { - assertThatThrownBy(() -> this.spring.register(RememberMeCookieNameAndRememberMeServicesConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Can not set rememberMeCookieName and custom rememberMeServices."); } @EnableWebSecurity static class RememberMeCookieNameAndRememberMeServicesConfig extends WebSecurityConfigurerAdapter { + static RememberMeServices REMEMBER_ME = mock(RememberMeServices.class); + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http @@ -446,32 +503,14 @@ public class RememberMeConfigurerTests { } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void getWhenRememberMeCookieAndNoKeyConfiguredThenKeyFromRememberMeServicesIsUsed() - throws Exception { - this.spring.register(FallbackRememberMeKeyConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password") - .param("remember-me", "true")) - .andReturn(); - Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); - - this.mvc.perform(get("/abc") - .cookie(rememberMeCookie)) - .andExpect(authenticated().withAuthentication(auth -> - assertThat(auth).isInstanceOf(RememberMeAuthenticationToken.class))); } @EnableWebSecurity @@ -485,5 +524,7 @@ public class RememberMeConfigurerTests { .rememberMeServices(new TokenBasedRememberMeServices("key", userDetailsService())); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurerTests.java index c2ec1211f5..effda94173 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestCacheConfigurerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletRequest; @@ -39,6 +40,8 @@ import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.request.MockMultipartHttpServletRequestBuilder; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -47,8 +50,8 @@ import static org.mockito.Mockito.verify; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; /** @@ -68,13 +71,237 @@ public class RequestCacheConfigurerTests { @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnExceptionTranslationFilter() { this.spring.register(ObjectPostProcessorConfig.class, DefaultSecurityConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(RequestCacheAwareFilter.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(RequestCacheAwareFilter.class)); + @Test + public void getWhenInvokingExceptionHandlingTwiceThenOriginalEntryPointUsed() throws Exception { + this.spring.register(InvokeTwiceDoesNotOverrideConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(InvokeTwiceDoesNotOverrideConfig.requestCache).getMatchingRequest(any(HttpServletRequest.class), + any(HttpServletResponse.class)); + } + + @Test + public void getWhenBookmarkedUrlIsFaviconIcoThenPostAuthenticationRedirectsToRoot() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("/favicon.ico")) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + // ignores favicon.ico + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + @Test + public void getWhenBookmarkedUrlIsFaviconPngThenPostAuthenticationRedirectsToRoot() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("/favicon.png")) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + // ignores favicon.png + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + // SEC-2321 + @Test + public void getWhenBookmarkedRequestIsApplicationJsonThenPostAuthenticationRedirectsToRoot() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + MockHttpServletRequestBuilder request = get("/messages").header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + // ignores application/json + // This is desirable since JSON requests are typically not invoked directly from + // the browser and we don't want the browser to replay them + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + // SEC-2321 + @Test + public void getWhenBookmarkedRequestIsXRequestedWithThenPostAuthenticationRedirectsToRoot() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder xRequestedWith = get("/messages") + .header("X-Requested-With", "XMLHttpRequest"); + MockHttpSession session = (MockHttpSession) this.mvc + .perform(xRequestedWith) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + // This is desirable since XHR requests are typically not invoked directly from + // the browser and we don't want the browser to replay them + } + + @Test + public void getWhenBookmarkedRequestIsTextEventStreamThenPostAuthenticationRedirectsToRoot() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + MockHttpServletRequestBuilder request = get("/messages").header(HttpHeaders.ACCEPT, + MediaType.TEXT_EVENT_STREAM); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + // ignores text/event-stream + // This is desirable since event-stream requests are typically not invoked + // directly from the browser and we don't want the browser to replay them + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + @Test + public void getWhenBookmarkedRequestIsAllMediaTypeThenPostAuthenticationRemembers() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + MockHttpServletRequestBuilder request = get("/messages").header(HttpHeaders.ACCEPT, MediaType.ALL); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); + } + + @Test + public void getWhenBookmarkedRequestIsTextHtmlThenPostAuthenticationRemembers() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + MockHttpServletRequestBuilder request = get("/messages").header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); + } + + @Test + public void getWhenBookmarkedRequestIsChromeThenPostAuthenticationRemembers() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/messages") + .header(HttpHeaders.ACCEPT, "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8"); + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); + } + + @Test + public void getWhenBookmarkedRequestIsRequestedWithAndroidThenPostAuthenticationRemembers() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/messages") + .header("X-Requested-With", "com.android"); + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/messages")); + } + + // gh-6102 + @Test + public void getWhenRequestCacheIsDisabledThenExceptionTranslationFilterDoesNotStoreRequest() throws Exception { + this.spring.register(RequestCacheDisabledConfig.class, + ExceptionHandlingConfigurerTests.DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("/bob")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + // SEC-7060 + @Test + public void postWhenRequestIsMultipartThenPostAuthenticationRedirectsToRoot() throws Exception { + this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); + MockMultipartFile aFile = new MockMultipartFile("aFile", "A_FILE".getBytes()); + MockMultipartHttpServletRequestBuilder request = multipart("/upload").file(aFile); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(request) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + @Test + public void getWhenRequestCacheIsDisabledInLambdaThenExceptionTranslationFilterDoesNotStoreRequest() + throws Exception { + this.spring.register(RequestCacheDisabledInLambdaConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("/bob")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + @Test + public void getWhenRequestCacheInLambdaThenRedirectedToCachedPage() throws Exception { + this.spring.register(RequestCacheInLambdaConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("/bob")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("http://localhost/bob")); + } + + @Test + public void getWhenCustomRequestCacheInLambdaThenCustomRequestCacheUsed() throws Exception { + this.spring.register(CustomRequestCacheInLambdaConfig.class, DefaultSecurityConfig.class).autowire(); + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("/bob")) + .andReturn() + .getRequest() + .getSession(); + // @formatter:on + this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); + } + + private static RequestBuilder formLogin(MockHttpSession session) { + // @formatter:off + return post("/login") + .param("username", "user") + .param("password", "password") + .session(session) + .with(csrf()); + // @formatter:on } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -89,27 +316,21 @@ public class RequestCacheConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void getWhenInvokingExceptionHandlingTwiceThenOriginalEntryPointUsed() throws Exception { - this.spring.register(InvokeTwiceDoesNotOverrideConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(InvokeTwiceDoesNotOverrideConfig.requestCache) - .getMatchingRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @EnableWebSecurity static class InvokeTwiceDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + static RequestCache requestCache = mock(RequestCache.class); @Override @@ -122,137 +343,7 @@ public class RequestCacheConfigurerTests { .requestCache(); // @formatter:on } - } - @Test - public void getWhenBookmarkedUrlIsFaviconIcoThenPostAuthenticationRedirectsToRoot() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/favicon.ico")) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); // ignores favicon.ico - } - - @Test - public void getWhenBookmarkedUrlIsFaviconPngThenPostAuthenticationRedirectsToRoot() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/favicon.png")) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); // ignores favicon.png - } - - // SEC-2321 - @Test - public void getWhenBookmarkedRequestIsApplicationJsonThenPostAuthenticationRedirectsToRoot() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON)) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); // ignores application/json - - // This is desirable since JSON requests are typically not invoked directly from the browser and we don't want the browser to replay them - } - - // SEC-2321 - @Test - public void getWhenBookmarkedRequestIsXRequestedWithThenPostAuthenticationRedirectsToRoot() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header("X-Requested-With", "XMLHttpRequest")) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); - - // This is desirable since XHR requests are typically not invoked directly from the browser and we don't want the browser to replay them - } - @Test - public void getWhenBookmarkedRequestIsTextEventStreamThenPostAuthenticationRedirectsToRoot() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header(HttpHeaders.ACCEPT, MediaType.TEXT_EVENT_STREAM)) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); // ignores text/event-stream - - // This is desirable since event-stream requests are typically not invoked directly from the browser and we don't want the browser to replay them - } - - @Test - public void getWhenBookmarkedRequestIsAllMediaTypeThenPostAuthenticationRemembers() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header(HttpHeaders.ACCEPT, MediaType.ALL)) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("http://localhost/messages")); - } - - @Test - public void getWhenBookmarkedRequestIsTextHtmlThenPostAuthenticationRemembers() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header(HttpHeaders.ACCEPT, MediaType.TEXT_HTML)) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("http://localhost/messages")); - } - - @Test - public void getWhenBookmarkedRequestIsChromeThenPostAuthenticationRemembers() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header(HttpHeaders.ACCEPT, "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("http://localhost/messages")); - } - - @Test - public void getWhenBookmarkedRequestIsRequestedWithAndroidThenPostAuthenticationRemembers() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/messages") - .header("X-Requested-With", "com.android")) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("http://localhost/messages")); } @EnableWebSecurity @@ -260,70 +351,36 @@ public class RequestCacheConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() .and() .formLogin(); + // @formatter:on } - } - // gh-6102 - @Test - public void getWhenRequestCacheIsDisabledThenExceptionTranslationFilterDoesNotStoreRequest() throws Exception { - this.spring.register(RequestCacheDisabledConfig.class, ExceptionHandlingConfigurerTests.DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/bob")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); - } - - // SEC-7060 - @Test - public void postWhenRequestIsMultipartThenPostAuthenticationRedirectsToRoot() throws Exception { - this.spring.register(RequestCacheDefaultsConfig.class, DefaultSecurityConfig.class).autowire(); - - MockMultipartFile aFile = new MockMultipartFile("aFile", "A_FILE".getBytes()); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(multipart("/upload") - .file(aFile)) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)).andExpect(redirectedUrl("/")); } @EnableWebSecurity static class RequestCacheDisabledConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { super.configure(http); http.requestCache().disable(); } - } - @Test - public void getWhenRequestCacheIsDisabledInLambdaThenExceptionTranslationFilterDoesNotStoreRequest() throws Exception { - this.spring.register(RequestCacheDisabledInLambdaConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/bob")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); } @EnableWebSecurity static class RequestCacheDisabledInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -331,27 +388,17 @@ public class RequestCacheConfigurerTests { .requestCache(RequestCacheConfigurer::disable); // @formatter:on } - } - @Test - public void getWhenRequestCacheInLambdaThenRedirectedToCachedPage() throws Exception { - this.spring.register(RequestCacheInLambdaConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/bob")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("http://localhost/bob")); } @EnableWebSecurity static class RequestCacheInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -359,58 +406,45 @@ public class RequestCacheConfigurerTests { .requestCache(withDefaults()); // @formatter:on } - } - @Test - public void getWhenCustomRequestCacheInLambdaThenCustomRequestCacheUsed() throws Exception { - this.spring.register(CustomRequestCacheInLambdaConfig.class, DefaultSecurityConfig.class).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("/bob")) - .andReturn().getRequest().getSession(); - - this.mvc.perform(formLogin(session)) - .andExpect(redirectedUrl("/")); } @EnableWebSecurity static class CustomRequestCacheInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) .formLogin(withDefaults()) - .requestCache(requestCache -> + .requestCache((requestCache) -> requestCache .requestCache(new NullRequestCache()) ); // @formatter:on } + } @EnableWebSecurity static class DefaultSecurityConfig { @Bean - public InMemoryUserDetailsManager userDetailsManager() { + InMemoryUserDetailsManager userDetailsManager() { + // @formatter:off return new InMemoryUserDetailsManager(User.withDefaultPasswordEncoder() .username("user") .password("password") .roles("USER") .build() ); + // @formatter:on } + } - private static RequestBuilder formLogin(MockHttpSession session) { - return post("/login") - .param("username", "user") - .param("password", "password") - .session(session) - .with(csrf()); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestMatcherConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestMatcherConfigurerTests.java index 4bc131761b..c5c6196869 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestMatcherConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RequestMatcherConfigurerTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -46,11 +47,23 @@ public class RequestMatcherConfigurerTests { @Test public void authorizeRequestsWhenInvokedMultipleTimesThenChainsPaths() throws Exception { this.spring.register(Sec2908Config.class).autowire(); - + // @formatter:off this.mvc.perform(get("/oauth/abc")) .andExpect(status().isForbidden()); this.mvc.perform(get("/api/abc")) .andExpect(status().isForbidden()); + // @formatter:on + } + + @Test + public void authorizeRequestsWhenInvokedMultipleTimesInLambdaThenChainsPaths() throws Exception { + this.spring.register(AuthorizeRequestInLambdaConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/oauth/abc")) + .andExpect(status().isForbidden()); + this.mvc.perform(get("/api/abc")) + .andExpect(status().isForbidden()); + // @formatter:on } @EnableWebSecurity @@ -70,16 +83,7 @@ public class RequestMatcherConfigurerTests { .anyRequest().denyAll(); // @formatter:on } - } - @Test - public void authorizeRequestsWhenInvokedMultipleTimesInLambdaThenChainsPaths() throws Exception { - this.spring.register(AuthorizeRequestInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/oauth/abc")) - .andExpect(status().isForbidden()); - this.mvc.perform(get("/api/abc")) - .andExpect(status().isForbidden()); } @EnableWebSecurity @@ -89,19 +93,21 @@ public class RequestMatcherConfigurerTests { protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .requestMatchers(requestMatchers -> + .requestMatchers((requestMatchers) -> requestMatchers .antMatchers("/api/**") ) - .requestMatchers(requestMatchers -> + .requestMatchers((requestMatchers) -> requestMatchers .antMatchers("/oauth/**") ) - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().denyAll() ); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java index 6d2b21c65e..fc99c2772d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java @@ -16,8 +16,11 @@ package org.springframework.security.config.annotation.web.configurers; +import javax.servlet.http.HttpSession; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -37,10 +40,12 @@ import org.springframework.security.web.context.request.async.WebAsyncManagerInt import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; -import javax.servlet.http.HttpSession; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -52,6 +57,7 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder * @author Eleftheria Stein */ public class SecurityContextConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -61,13 +67,51 @@ public class SecurityContextConfigurerTests { @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnSecurityContextPersistenceFilter() { this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(SecurityContextPersistenceFilter.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(SecurityContextPersistenceFilter.class)); + @Test + public void securityContextWhenInvokedTwiceThenUsesOriginalSecurityContextRepository() throws Exception { + this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); + given(DuplicateDoesNotOverrideConfig.SCR.loadContext(any())).willReturn(mock(SecurityContext.class)); + this.mvc.perform(get("/")); + verify(DuplicateDoesNotOverrideConfig.SCR).loadContext(any(HttpRequestResponseHolder.class)); + } + + // SEC-2932 + @Test + public void securityContextWhenSecurityContextRepositoryNotConfiguredThenDoesNotThrowException() throws Exception { + this.spring.register(SecurityContextRepositoryDefaultsSecurityContextRepositoryConfig.class).autowire(); + this.mvc.perform(get("/")); + } + + @Test + public void requestWhenSecurityContextWithDefaultsInLambdaThenSessionIsCreated() throws Exception { + this.spring.register(SecurityContextWithDefaultsInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(formLogin()).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNotNull(); + } + + @Test + public void requestWhenSecurityContextDisabledInLambdaThenContextNotSavedInSession() throws Exception { + this.spring.register(SecurityContextDisabledInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(formLogin()).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + } + + @Test + public void requestWhenNullSecurityContextRepositoryInLambdaThenContextNotSavedInSession() throws Exception { + this.spring.register(NullSecurityContextRepositoryInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(formLogin()).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -82,28 +126,21 @@ public class SecurityContextConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void securityContextWhenInvokedTwiceThenUsesOriginalSecurityContextRepository() throws Exception { - this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); - when(DuplicateDoesNotOverrideConfig.SCR.loadContext(any())).thenReturn(mock(SecurityContext.class)); - - this.mvc.perform(get("/")); - - verify(DuplicateDoesNotOverrideConfig.SCR) - .loadContext(any(HttpRequestResponseHolder.class)); } @EnableWebSecurity static class DuplicateDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + static SecurityContextRepository SCR = mock(SecurityContextRepository.class); @Override @@ -116,19 +153,13 @@ public class SecurityContextConfigurerTests { .securityContext(); // @formatter:on } - } - //SEC-2932 - @Test - public void securityContextWhenSecurityContextRepositoryNotConfiguredThenDoesNotThrowException() throws Exception { - this.spring.register(SecurityContextRepositoryDefaultsSecurityContextRepositoryConfig.class).autowire(); - - this.mvc.perform(get("/")); } @Configuration @EnableWebSecurity static class SecurityContextRepositoryDefaultsSecurityContextRepositoryConfig extends WebSecurityConfigurerAdapter { + SecurityContextRepositoryDefaultsSecurityContextRepositoryConfig() { super(true); } @@ -157,19 +188,12 @@ public class SecurityContextConfigurerTests { .withUser("user").password("password").roles("USER"); // @formatter:on } - } - @Test - public void requestWhenSecurityContextWithDefaultsInLambdaThenSessionIsCreated() throws Exception { - this.spring.register(SecurityContextWithDefaultsInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(formLogin()).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - assertThat(session).isNotNull(); } @EnableWebSecurity static class SecurityContextWithDefaultsInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -187,19 +211,12 @@ public class SecurityContextConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void requestWhenSecurityContextDisabledInLambdaThenContextNotSavedInSession() throws Exception { - this.spring.register(SecurityContextDisabledInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(formLogin()).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - assertThat(session).isNull(); } @EnableWebSecurity static class SecurityContextDisabledInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -217,25 +234,18 @@ public class SecurityContextConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void requestWhenNullSecurityContextRepositoryInLambdaThenContextNotSavedInSession() throws Exception { - this.spring.register(NullSecurityContextRepositoryInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(formLogin()).andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - assertThat(session).isNull(); } @EnableWebSecurity static class NullSecurityContextRepositoryInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http .formLogin(withDefaults()) - .securityContext(securityContext -> + .securityContext((securityContext) -> securityContext .securityContextRepository(new NullSecurityContextRepository()) ); @@ -250,5 +260,7 @@ public class SecurityContextConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurerTests.java index 7a9e5b1f94..8b511b841a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletApiConfigurerTests.java @@ -18,8 +18,14 @@ package org.springframework.security.config.annotation.web.configurers; import java.util.List; +import javax.servlet.Filter; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -36,6 +42,7 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.PasswordEncodedUser; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors; import org.springframework.security.util.FieldUtils; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.FilterChainProxy; @@ -46,16 +53,12 @@ import org.springframework.security.web.authentication.logout.LogoutSuccessEvent import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.ConfigurableWebApplicationContext; -import javax.servlet.Filter; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; @@ -77,6 +80,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Onur Kagan Ozcan */ public class ServletApiConfigurerTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -86,13 +90,130 @@ public class ServletApiConfigurerTests { @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnSecurityContextHolderAwareRequestFilter() { this.spring.register(ObjectPostProcessorConfig.class).autowire(); - verify(ObjectPostProcessorConfig.objectPostProcessor) .postProcess(any(SecurityContextHolderAwareRequestFilter.class)); } + // SEC-2215 + @Test + public void configureWhenUsingDefaultsThenAuthenticationManagerIsNotNull() { + this.spring.register(ServletApiConfig.class).autowire(); + assertThat(this.spring.getContext().getBean("customAuthenticationManager")).isNotNull(); + } + + @Test + public void configureWhenUsingDefaultsThenAuthenticationEntryPointIsLogin() throws Exception { + this.spring.register(ServletApiConfig.class).autowire(); + this.mvc.perform(formLogin()).andExpect(status().isFound()); + } + + // SEC-2926 + @Test + public void configureWhenUsingDefaultsThenRolePrefixIsSet() throws Exception { + this.spring.register(ServletApiConfig.class, AdminController.class).autowire(); + TestingAuthenticationToken user = new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN"); + MockHttpServletRequestBuilder request = get("/admin").with(authentication(user)); + this.mvc.perform(request).andExpect(status().isOk()); + } + + @Test + public void requestWhenCustomAuthenticationEntryPointThenEntryPointUsed() throws Exception { + this.spring.register(CustomEntryPointConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(CustomEntryPointConfig.ENTRYPOINT).commence(any(HttpServletRequest.class), + any(HttpServletResponse.class), any(AuthenticationException.class)); + } + + @Test + public void servletApiWhenInvokedTwiceThenUsesOriginalRole() throws Exception { + this.spring.register(DuplicateInvocationsDoesNotOverrideConfig.class, AdminController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/admin") + .with(user("user").authorities(AuthorityUtils.createAuthorityList("PERMISSION_ADMIN"))); + this.mvc.perform(request) + .andExpect(status().isOk()); + SecurityMockMvcRequestPostProcessors.UserRequestPostProcessor userWithRoleAdmin = user("user") + .authorities(AuthorityUtils.createAuthorityList("ROLE_ADMIN")); + MockHttpServletRequestBuilder requestWithRoleAdmin = get("/admin") + .with(userWithRoleAdmin); + this.mvc.perform(requestWithRoleAdmin) + .andExpect(status().isForbidden()); + // @formatter:on + } + + @Test + public void configureWhenSharedObjectTrustResolverThenTrustResolverUsed() throws Exception { + this.spring.register(SharedTrustResolverConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(SharedTrustResolverConfig.TR, atLeastOnce()).isAnonymous(any()); + } + + @Test + public void requestWhenServletApiWithDefaultsInLambdaThenUsesDefaultRolePrefix() throws Exception { + this.spring.register(ServletApiWithDefaultsInLambdaConfig.class, AdminController.class).autowire(); + MockHttpServletRequestBuilder request = get("/admin") + .with(user("user").authorities(AuthorityUtils.createAuthorityList("ROLE_ADMIN"))); + this.mvc.perform(request).andExpect(status().isOk()); + } + + @Test + public void requestWhenRolePrefixInLambdaThenUsesCustomRolePrefix() throws Exception { + this.spring.register(RolePrefixInLambdaConfig.class, AdminController.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder requestWithAdminPermission = get("/admin") + .with(user("user").authorities(AuthorityUtils.createAuthorityList("PERMISSION_ADMIN"))); + this.mvc.perform(requestWithAdminPermission) + .andExpect(status().isOk()); + MockHttpServletRequestBuilder requestWithAdminRole = get("/admin") + .with(user("user").authorities(AuthorityUtils.createAuthorityList("ROLE_ADMIN"))); + this.mvc.perform(requestWithAdminRole) + .andExpect(status().isForbidden()); + // @formatter:on + } + + @Test + public void checkSecurityContextAwareAndLogoutFilterHasSameSizeAndHasLogoutSuccessEventPublishingLogoutHandler() { + this.spring.register(ServletApiWithLogoutConfig.class); + SecurityContextHolderAwareRequestFilter scaFilter = getFilter(SecurityContextHolderAwareRequestFilter.class); + LogoutFilter logoutFilter = getFilter(LogoutFilter.class); + LogoutHandler lfLogoutHandler = getFieldValue(logoutFilter, "handler"); + assertThat(lfLogoutHandler).isInstanceOf(CompositeLogoutHandler.class); + List scaLogoutHandlers = getFieldValue(scaFilter, "logoutHandlers"); + List lfLogoutHandlers = getFieldValue(lfLogoutHandler, "logoutHandlers"); + assertThat(scaLogoutHandlers).hasSameSizeAs(lfLogoutHandlers); + assertThat(scaLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); + assertThat(lfLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); + } + + @Test + public void logoutServletApiWhenCsrfDisabled() throws Exception { + ConfigurableWebApplicationContext context = this.spring.register(CsrfDisabledConfig.class).getContext(); + MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + MvcResult mvcResult = mockMvc.perform(get("/")).andReturn(); + assertThat(mvcResult.getRequest().getSession(false)).isNull(); + } + + private T getFilter(Class filterClass) { + return (T) getFilters().stream().filter(filterClass::isInstance).findFirst().orElse(null); + } + + private List getFilters() { + FilterChainProxy proxy = this.spring.getContext().getBean(FilterChainProxy.class); + return proxy.getFilters("/"); + } + + private T getFieldValue(Object target, String fieldName) { + try { + return (T) FieldUtils.getFieldValue(target, fieldName); + } + catch (Exception ex) { + throw new RuntimeException(ex); + } + } + @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -107,43 +228,21 @@ public class ServletApiConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - // SEC-2215 - @Test - public void configureWhenUsingDefaultsThenAuthenticationManagerIsNotNull() { - this.spring.register(ServletApiConfig.class).autowire(); - - assertThat(this.spring.getContext().getBean("customAuthenticationManager")).isNotNull(); - } - - @Test - public void configureWhenUsingDefaultsThenAuthenticationEntryPointIsLogin() throws Exception { - this.spring.register(ServletApiConfig.class).autowire(); - - this.mvc.perform(formLogin()) - .andExpect(status().isFound()); - } - - // SEC-2926 - @Test - public void configureWhenUsingDefaultsThenRolePrefixIsSet() throws Exception { - this.spring.register(ServletApiConfig.class, AdminController.class).autowire(); - - this.mvc.perform(get("/admin") - .with(authentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")))) - .andExpect(status().isOk()); } @EnableWebSecurity static class ServletApiConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { // @formatter:off @@ -154,24 +253,15 @@ public class ServletApiConfigurerTests { } @Bean - public AuthenticationManager customAuthenticationManager() throws Exception { + AuthenticationManager customAuthenticationManager() throws Exception { return super.authenticationManagerBean(); } - } - @Test - public void requestWhenCustomAuthenticationEntryPointThenEntryPointUsed() throws Exception { - this.spring.register(CustomEntryPointConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(CustomEntryPointConfig.ENTRYPOINT) - .commence(any(HttpServletRequest.class), - any(HttpServletResponse.class), any(AuthenticationException.class)); } @EnableWebSecurity static class CustomEntryPointConfig extends WebSecurityConfigurerAdapter { + static AuthenticationEntryPoint ENTRYPOINT = spy(AuthenticationEntryPoint.class); @Override @@ -196,23 +286,12 @@ public class ServletApiConfigurerTests { .withUser("user").password("password").roles("USER"); // @formatter:on } - } - @Test - public void servletApiWhenInvokedTwiceThenUsesOriginalRole() throws Exception { - this.spring.register(DuplicateInvocationsDoesNotOverrideConfig.class, AdminController.class).autowire(); - - this.mvc.perform(get("/admin") - .with(user("user").authorities(AuthorityUtils.createAuthorityList("PERMISSION_ADMIN")))) - .andExpect(status().isOk()); - - this.mvc.perform(get("/admin") - .with(user("user").authorities(AuthorityUtils.createAuthorityList("ROLE_ADMIN")))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class DuplicateInvocationsDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -223,19 +302,12 @@ public class ServletApiConfigurerTests { .servletApi(); // @formatter:on } - } - @Test - public void configureWhenSharedObjectTrustResolverThenTrustResolverUsed() throws Exception { - this.spring.register(SharedTrustResolverConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(SharedTrustResolverConfig.TR, atLeastOnce()).isAnonymous(any()); } @EnableWebSecurity static class SharedTrustResolverConfig extends WebSecurityConfigurerAdapter { + static AuthenticationTrustResolver TR = spy(AuthenticationTrustResolver.class); @Override @@ -245,19 +317,12 @@ public class ServletApiConfigurerTests { .setSharedObject(AuthenticationTrustResolver.class, TR); // @formatter:on } - } - @Test - public void requestWhenServletApiWithDefaultsInLambdaThenUsesDefaultRolePrefix() throws Exception { - this.spring.register(ServletApiWithDefaultsInLambdaConfig.class, AdminController.class).autowire(); - - this.mvc.perform(get("/admin") - .with(user("user").authorities(AuthorityUtils.createAuthorityList("ROLE_ADMIN")))) - .andExpect(status().isOk()); } @EnableWebSecurity static class ServletApiWithDefaultsInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -265,66 +330,40 @@ public class ServletApiConfigurerTests { .servletApi(withDefaults()); // @formatter:on } - } - @Test - public void requestWhenRolePrefixInLambdaThenUsesCustomRolePrefix() throws Exception { - this.spring.register(RolePrefixInLambdaConfig.class, AdminController.class).autowire(); - - this.mvc.perform(get("/admin") - .with(user("user").authorities(AuthorityUtils.createAuthorityList("PERMISSION_ADMIN")))) - .andExpect(status().isOk()); - - this.mvc.perform(get("/admin") - .with(user("user").authorities(AuthorityUtils.createAuthorityList("ROLE_ADMIN")))) - .andExpect(status().isForbidden()); } @EnableWebSecurity static class RolePrefixInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .servletApi(servletApi -> + .servletApi((servletApi) -> servletApi .rolePrefix("PERMISSION_") ); // @formatter:on } + } @RestController static class AdminController { + @GetMapping("/admin") - public void admin(HttpServletRequest request) { + void admin(HttpServletRequest request) { if (!request.isUserInRole("ADMIN")) { throw new AccessDeniedException("This resource is only available to admins"); } } - } - @Test - public void checkSecurityContextAwareAndLogoutFilterHasSameSizeAndHasLogoutSuccessEventPublishingLogoutHandler() { - this.spring.register(ServletApiWithLogoutConfig.class); - - SecurityContextHolderAwareRequestFilter scaFilter = getFilter(SecurityContextHolderAwareRequestFilter.class); - LogoutFilter logoutFilter = getFilter(LogoutFilter.class); - - LogoutHandler lfLogoutHandler = getFieldValue(logoutFilter, "handler"); - assertThat(lfLogoutHandler).isInstanceOf(CompositeLogoutHandler.class); - - List scaLogoutHandlers = getFieldValue(scaFilter, "logoutHandlers"); - List lfLogoutHandlers = getFieldValue(lfLogoutHandler, "logoutHandlers"); - - assertThat(scaLogoutHandlers).hasSameSizeAs(lfLogoutHandlers); - - assertThat(scaLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); - assertThat(lfLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); } @EnableWebSecurity static class ServletApiWithLogoutConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -333,22 +372,13 @@ public class ServletApiConfigurerTests { .logout(); // @formatter:on } - } - @Test - public void logoutServletApiWhenCsrfDisabled() throws Exception { - ConfigurableWebApplicationContext context = this.spring.register(CsrfDisabledConfig.class).getContext(); - MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(context) - .apply(springSecurity()) - .build(); - MvcResult mvcResult = mockMvc.perform(get("/")) - .andReturn(); - assertThat(mvcResult.getRequest().getSession(false)).isNull(); } @Configuration @EnableWebSecurity static class CsrfDisabledConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -359,33 +389,16 @@ public class ServletApiConfigurerTests { @RestController static class LogoutController { + @GetMapping("/") String logout(HttpServletRequest request) throws ServletException { request.getSession().setAttribute("foo", "bar"); request.logout(); return "logout"; } + } - } - private T getFilter(Class filterClass) { - return (T) getFilters().stream() - .filter(filterClass::isInstance) - .findFirst() - .orElse(null); - } - - private List getFilters() { - FilterChainProxy proxy = this.spring.getContext().getBean(FilterChainProxy.class); - return proxy.getFilters("/"); - } - - private T getFieldValue(Object target, String fieldName) { - try { - return (T) FieldUtils.getFieldValue(target, fieldName); - } catch (Exception e) { - throw new RuntimeException(e); - } } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java index 3c5d5725db..53f7b0ec8a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.lang.reflect.Method; + import javax.servlet.Filter; import org.junit.After; @@ -43,20 +45,22 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** - * * @author Rob Winch */ @RunWith(PowerMockRunner.class) @PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" }) public class SessionManagementConfigurerServlet31Tests { + @Mock Method method; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; ConfigurableApplicationContext context; @@ -65,22 +69,22 @@ public class SessionManagementConfigurerServlet31Tests { @Before public void setup() { - request = new MockHttpServletRequest("GET", ""); - response = new MockHttpServletResponse(); - chain = new MockFilterChain(); + this.request = new MockHttpServletRequest("GET", ""); + this.response = new MockHttpServletResponse(); + this.chain = new MockFilterChain(); } @After public void teardown() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @Test public void changeSessionIdThenPreserveParameters() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - String id = request.getSession().getId(); + String id = request.getSession().getId(); request.getSession(); request.setServletPath("/login"); request.setMethod("POST"); @@ -88,59 +92,54 @@ public class SessionManagementConfigurerServlet31Tests { request.setParameter("password", "password"); HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); CsrfToken token = repository.generateToken(request); - repository.saveToken(token, request, response); + repository.saveToken(token, request, this.response); request.setParameter(token.getParameterName(), token.getToken()); request.getSession().setAttribute("attribute1", "value1"); - loadConfig(SessionManagementDefaultSessionFixationServlet31Config.class); - - springSecurityFilterChain.doFilter(request, response, chain); - + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(request.getSession().getId()).isNotEqualTo(id); assertThat(request.getSession().getAttribute("attribute1")).isEqualTo("value1"); } - @EnableWebSecurity - static class SessionManagementDefaultSessionFixationServlet31Config extends - WebSecurityConfigurerAdapter { - // @formatter:off - @Override - protected void configure(HttpSecurity http) throws Exception { - http - .formLogin() - .and() - .sessionManagement(); - } - // @formatter:on - - // @formatter:off - @Override - protected void configure(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication() - .withUser(PasswordEncodedUser.user()); - } - // @formatter:on - } - private void loadConfig(Class... classes) { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); context.register(classes); context.refresh(); this.context = context; - this.springSecurityFilterChain = this.context.getBean( - "springSecurityFilterChain", Filter.class); + this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } private void login(Authentication auth) { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder( - request, response); + HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response); repo.loadContext(requestResponseHolder); - SecurityContextImpl securityContextImpl = new SecurityContextImpl(); securityContextImpl.setAuthentication(auth); - repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), - requestResponseHolder.getResponse()); + repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse()); } + + @EnableWebSecurity + static class SessionManagementDefaultSessionFixationServlet31Config extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .formLogin() + .and() + .sessionManagement(); + // @formatter:on + } + + @Override + protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionAuthenticationStrategyTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionAuthenticationStrategyTests.java index 519b4a9703..8985070277 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionAuthenticationStrategyTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionAuthenticationStrategyTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -28,9 +33,6 @@ import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.test.web.servlet.MockMvc; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -52,32 +54,36 @@ public class SessionManagementConfigurerSessionAuthenticationStrategyTests { public void requestWhenCustomSessionAuthenticationStrategyProvidedThenCalled() throws Exception { this.spring.register(CustomSessionAuthenticationStrategyConfig.class).autowire(); this.mvc.perform(formLogin().user("user").password("password")); - verify(CustomSessionAuthenticationStrategyConfig.customSessionAuthenticationStrategy) - .onAuthentication(any(Authentication.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(CustomSessionAuthenticationStrategyConfig.customSessionAuthenticationStrategy).onAuthentication( + any(Authentication.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @EnableWebSecurity static class CustomSessionAuthenticationStrategyConfig extends WebSecurityConfigurerAdapter { - static SessionAuthenticationStrategy customSessionAuthenticationStrategy = mock(SessionAuthenticationStrategy.class); - // @formatter:off + static SessionAuthenticationStrategy customSessionAuthenticationStrategy = mock( + SessionAuthenticationStrategy.class); + @Override public void configure(HttpSecurity http) throws Exception { + // @formatter:off http .formLogin() .and() .sessionManagement() .sessionAuthenticationStrategy(customSessionAuthenticationStrategy); + // @formatter:on } - // @formatter:on - // @formatter:off @Override protected void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser(PasswordEncodedUser.user()); + // @formatter:on } - // @formatter:on + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionCreationPolicyTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionCreationPolicyTests.java index aaea0b4343..562e344784 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionCreationPolicyTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerSessionCreationPolicyTests.java @@ -38,6 +38,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Josh Cummings */ public class SessionManagementConfigurerSessionCreationPolicyTests { + @Autowired MockMvc mvc; @@ -45,71 +46,69 @@ public class SessionManagementConfigurerSessionCreationPolicyTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void getWhenSharedObjectSessionCreationPolicyConfigurationThenOverrides() - throws Exception { - + public void getWhenSharedObjectSessionCreationPolicyConfigurationThenOverrides() throws Exception { this.spring.register(StatelessCreateSessionSharedObjectConfig.class).autowire(); - MvcResult result = this.mvc.perform(get("/")).andReturn(); - assertThat(result.getRequest().getSession(false)).isNull(); } + @Test + public void getWhenUserSessionCreationPolicyConfigurationThenOverrides() throws Exception { + this.spring.register(StatelessCreateSessionUserConfig.class).autowire(); + MvcResult result = this.mvc.perform(get("/")).andReturn(); + assertThat(result.getRequest().getSession(false)).isNull(); + } + + @Test + public void getWhenDefaultsThenLoginChallengeCreatesSession() throws Exception { + this.spring.register(DefaultConfig.class, BasicController.class).autowire(); + // @formatter:off + MvcResult result = this.mvc.perform(get("/")) + .andExpect(status().isUnauthorized()) + .andReturn(); + // @formatter:on + assertThat(result.getRequest().getSession(false)).isNotNull(); + } + @EnableWebSecurity static class StatelessCreateSessionSharedObjectConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { super.configure(http); http.setSharedObject(SessionCreationPolicy.class, SessionCreationPolicy.STATELESS); } - } - @Test - public void getWhenUserSessionCreationPolicyConfigurationThenOverrides() - throws Exception { - - this.spring.register(StatelessCreateSessionUserConfig.class).autowire(); - - MvcResult result = this.mvc.perform(get("/")).andReturn(); - - assertThat(result.getRequest().getSession(false)).isNull(); } @EnableWebSecurity static class StatelessCreateSessionUserConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { super.configure(http); + // @formatter:off http .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS); - + // @formatter:on http.setSharedObject(SessionCreationPolicy.class, SessionCreationPolicy.ALWAYS); } - } - @Test - public void getWhenDefaultsThenLoginChallengeCreatesSession() - throws Exception { - - this.spring.register(DefaultConfig.class, BasicController.class).autowire(); - - MvcResult result = - this.mvc.perform(get("/")) - .andExpect(status().isUnauthorized()) - .andReturn(); - - assertThat(result.getRequest().getSession(false)).isNotNull(); } @EnableWebSecurity static class DefaultConfig extends WebSecurityConfigurerAdapter { + } @RestController static class BasicController { + @GetMapping("/") - public String root() { + String root() { return "ok"; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java index b1104f58d3..638391dd6e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTests.java @@ -16,8 +16,13 @@ package org.springframework.security.config.annotation.web.configurers; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpSession; @@ -44,18 +49,15 @@ import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.session.SessionManagementFilter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; @@ -82,15 +84,219 @@ public class SessionManagementConfigurerTests { public void sessionManagementWhenConfiguredThenDoesNotOverrideRequestCache() throws Exception { SessionManagementRequestCacheConfig.REQUEST_CACHE = mock(RequestCache.class); this.spring.register(SessionManagementRequestCacheConfig.class).autowire(); - this.mvc.perform(get("/")); + verify(SessionManagementRequestCacheConfig.REQUEST_CACHE).getMatchingRequest(any(HttpServletRequest.class), + any(HttpServletResponse.class)); + } - verify(SessionManagementRequestCacheConfig.REQUEST_CACHE) - .getMatchingRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); + @Test + public void sessionManagementWhenConfiguredThenDoesNotOverrideSecurityContextRepository() throws Exception { + SessionManagementSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPO = mock(SecurityContextRepository.class); + given(SessionManagementSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPO + .loadContext(any(HttpRequestResponseHolder.class))).willReturn(mock(SecurityContext.class)); + this.spring.register(SessionManagementSecurityContextRepositoryConfig.class).autowire(); + this.mvc.perform(get("/")); + verify(SessionManagementSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPO) + .saveContext(any(SecurityContext.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void sessionManagementWhenInvokedTwiceThenUsesOriginalSessionCreationPolicy() throws Exception { + this.spring.register(InvokeTwiceDoesNotOverride.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + } + + // SEC-2137 + @Test + public void getWhenSessionFixationDisabledAndConcurrencyControlEnabledThenSessionIsNotInvalidated() + throws Exception { + this.spring.register(DisableSessionFixationEnableConcurrencyControlConfig.class).autowire(); + MockHttpSession session = new MockHttpSession(); + String sessionId = session.getId(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/") + .with(httpBasic("user", "password")) + .session(session); + MvcResult mvcResult = this.mvc.perform(request) + .andExpect(status().isNotFound()) + .andReturn(); + // @formatter:on + assertThat(mvcResult.getRequest().getSession().getId()).isEqualTo(sessionId); + } + + @Test + public void authenticateWhenNewSessionFixationProtectionInLambdaThenCreatesNewSession() throws Exception { + this.spring.register(SFPNewSessionInLambdaConfig.class).autowire(); + MockHttpSession givenSession = new MockHttpSession(); + String givenSessionId = givenSession.getId(); + givenSession.setAttribute("name", "value"); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(givenSession) + .with(httpBasic("user", "password")); + MockHttpSession resultingSession = (MockHttpSession) this.mvc.perform(request) + .andExpect(status().isNotFound()) + .andReturn() + .getRequest() + .getSession(false); + // @formatter:on + assertThat(givenSessionId).isNotEqualTo(resultingSession.getId()); + assertThat(resultingSession.getAttribute("name")).isNull(); + } + + @Test + public void loginWhenUserLoggedInAndMaxSessionsIsOneThenLoginPrevented() throws Exception { + this.spring.register(ConcurrencyControlConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder firstRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + this.mvc.perform(firstRequest); + MockHttpServletRequestBuilder secondRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + this.mvc.perform(secondRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?error")); + // @formatter:on + } + + @Test + public void loginWhenUserSessionExpiredAndMaxSessionsIsOneThenLoggedIn() throws Exception { + this.spring.register(ConcurrencyControlConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder firstRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + MvcResult mvcResult = this.mvc.perform(firstRequest) + .andReturn(); + // @formatter:on + HttpSession authenticatedSession = mvcResult.getRequest().getSession(); + this.spring.getContext().publishEvent(new HttpSessionDestroyedEvent(authenticatedSession)); + // @formatter:off + MockHttpServletRequestBuilder secondRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + this.mvc.perform(secondRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/")); + // @formatter:on + } + + @Test + public void loginWhenUserLoggedInAndMaxSessionsOneInLambdaThenLoginPrevented() throws Exception { + this.spring.register(ConcurrencyControlInLambdaConfig.class).autowire(); + // @formatter:off + MockHttpServletRequestBuilder firstRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + // @formatter:on + this.mvc.perform(firstRequest); + // @formatter:off + MockHttpServletRequestBuilder secondRequest = post("/login") + .with(csrf()) + .param("username", "user") + .param("password", "password"); + this.mvc.perform(secondRequest) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?error")); + // @formatter:on + } + + @Test + public void requestWhenSessionCreationPolicyStateLessInLambdaThenNoSessionCreated() throws Exception { + this.spring.register(SessionCreationPolicyStateLessInLambdaConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")).andReturn(); + HttpSession session = mvcResult.getRequest().getSession(false); + assertThat(session).isNull(); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnSessionManagementFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(SessionManagementFilter.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnConcurrentSessionFilter() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ConcurrentSessionFilter.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnConcurrentSessionControlAuthenticationStrategy() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor) + .postProcess(any(ConcurrentSessionControlAuthenticationStrategy.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnCompositeSessionAuthenticationStrategy() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor) + .postProcess(any(CompositeSessionAuthenticationStrategy.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnRegisterSessionAuthenticationStrategy() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor) + .postProcess(any(RegisterSessionAuthenticationStrategy.class)); + } + + @Test + public void configureWhenRegisteringObjectPostProcessorThenInvokedOnChangeSessionIdAuthenticationStrategy() { + ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); + this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor) + .postProcess(any(ChangeSessionIdAuthenticationStrategy.class)); + } + + @Test + public void getWhenAnonymousRequestAndTrustResolverSharedObjectReturnsAnonymousFalseThenSessionIsSaved() + throws Exception { + SharedTrustResolverConfig.TR = mock(AuthenticationTrustResolver.class); + given(SharedTrustResolverConfig.TR.isAnonymous(any())).willReturn(false); + this.spring.register(SharedTrustResolverConfig.class).autowire(); + MvcResult mvcResult = this.mvc.perform(get("/")).andReturn(); + assertThat(mvcResult.getRequest().getSession(false)).isNotNull(); + } + + @Test + public void whenOneSessionRegistryBeanThenUseIt() throws Exception { + SessionRegistryOneBeanConfig.SESSION_REGISTRY = mock(SessionRegistry.class); + this.spring.register(SessionRegistryOneBeanConfig.class).autowire(); + MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext()); + this.mvc.perform(get("/").session(session)); + verify(SessionRegistryOneBeanConfig.SESSION_REGISTRY).getSessionInformation(session.getId()); + } + + @Test + public void whenTwoSessionRegistryBeansThenUseNeither() throws Exception { + SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE = mock(SessionRegistry.class); + SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO = mock(SessionRegistry.class); + this.spring.register(SessionRegistryTwoBeansConfig.class).autowire(); + MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext()); + this.mvc.perform(get("/").session(session)); + verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE); + verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO); } @EnableWebSecurity static class SessionManagementRequestCacheConfig extends WebSecurityConfigurerAdapter { + static RequestCache REQUEST_CACHE; @Override @@ -104,25 +310,12 @@ public class SessionManagementConfigurerTests { .sessionCreationPolicy(SessionCreationPolicy.STATELESS); // @formatter:on } - } - @Test - public void sessionManagementWhenConfiguredThenDoesNotOverrideSecurityContextRepository() throws Exception { - SessionManagementSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPO = mock(SecurityContextRepository.class); - when(SessionManagementSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPO.loadContext(any(HttpRequestResponseHolder.class))) - .thenReturn(mock(SecurityContext.class)); - this.spring.register(SessionManagementSecurityContextRepositoryConfig.class).autowire(); - - this.mvc.perform(get("/")); - - verify(SessionManagementSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPO).saveContext( - any(SecurityContext.class), - any(HttpServletRequest.class), - any(HttpServletResponse.class)); } @EnableWebSecurity static class SessionManagementSecurityContextRepositoryConfig extends WebSecurityConfigurerAdapter { + static SecurityContextRepository SECURITY_CONTEXT_REPO; @Override @@ -136,21 +329,12 @@ public class SessionManagementConfigurerTests { .sessionCreationPolicy(SessionCreationPolicy.STATELESS); // @formatter:on } - } - @Test - public void sessionManagementWhenInvokedTwiceThenUsesOriginalSessionCreationPolicy() throws Exception { - this.spring.register(InvokeTwiceDoesNotOverride.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNull(); } @EnableWebSecurity static class InvokeTwiceDoesNotOverride extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -161,27 +345,12 @@ public class SessionManagementConfigurerTests { .sessionManagement(); // @formatter:on } - } - // SEC-2137 - @Test - public void getWhenSessionFixationDisabledAndConcurrencyControlEnabledThenSessionIsNotInvalidated() - throws Exception { - this.spring.register(DisableSessionFixationEnableConcurrencyControlConfig.class).autowire(); - MockHttpSession session = new MockHttpSession(); - String sessionId = session.getId(); - - MvcResult mvcResult = this.mvc.perform(get("/") - .with(httpBasic("user", "password")) - .session(session)) - .andExpect(status().isNotFound()) - .andReturn(); - - assertThat(mvcResult.getRequest().getSession().getId()).isEqualTo(sessionId); } @EnableWebSecurity static class DisableSessionFixationEnableConcurrencyControlConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -202,36 +371,19 @@ public class SessionManagementConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void authenticateWhenNewSessionFixationProtectionInLambdaThenCreatesNewSession() throws Exception { - this.spring.register(SFPNewSessionInLambdaConfig.class).autowire(); - - MockHttpSession givenSession = new MockHttpSession(); - String givenSessionId = givenSession.getId(); - givenSession.setAttribute("name", "value"); - - MockHttpSession resultingSession = (MockHttpSession) - this.mvc.perform(get("/auth") - .session(givenSession) - .with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()) - .andReturn().getRequest().getSession(false); - - assertThat(givenSessionId).isNotEqualTo(resultingSession.getId()); - assertThat(resultingSession.getAttribute("name")).isNull(); } @EnableWebSecurity static class SFPNewSessionInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .sessionManagement(sessionManagement -> + .sessionManagement((sessionManagement) -> sessionManagement - .sessionFixation(sessionFixation -> + .sessionFixation((sessionFixation) -> sessionFixation.newSession() ) ) @@ -247,47 +399,12 @@ public class SessionManagementConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenUserLoggedInAndMaxSessionsIsOneThenLoginPrevented() throws Exception { - this.spring.register(ConcurrencyControlConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?error")); - } - - @Test - public void loginWhenUserSessionExpiredAndMaxSessionsIsOneThenLoggedIn() throws Exception { - this.spring.register(ConcurrencyControlConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")) - .andReturn(); - HttpSession authenticatedSession = mvcResult.getRequest().getSession(); - this.spring.getContext().publishEvent(new HttpSessionDestroyedEvent(authenticatedSession)); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")); } @EnableWebSecurity static class ConcurrencyControlConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -308,35 +425,20 @@ public class SessionManagementConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void loginWhenUserLoggedInAndMaxSessionsOneInLambdaThenLoginPrevented() throws Exception { - this.spring.register(ConcurrencyControlInLambdaConfig.class).autowire(); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")); - - this.mvc.perform(post("/login") - .with(csrf()) - .param("username", "user") - .param("password", "password")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?error")); } @EnableWebSecurity static class ConcurrencyControlInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(HttpSecurity http) throws Exception { // @formatter:off http .formLogin(withDefaults()) - .sessionManagement(sessionManagement -> + .sessionManagement((sessionManagement) -> sessionManagement - .sessionConcurrency(sessionConcurrency -> + .sessionConcurrency((sessionConcurrency) -> sessionConcurrency .maximumSessions(1) .maxSessionsPreventsLogin(true) @@ -353,89 +455,28 @@ public class SessionManagementConfigurerTests { .withUser(PasswordEncodedUser.user()); // @formatter:on } - } - @Test - public void requestWhenSessionCreationPolicyStateLessInLambdaThenNoSessionCreated() throws Exception { - this.spring.register(SessionCreationPolicyStateLessInLambdaConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andReturn(); - HttpSession session = mvcResult.getRequest().getSession(false); - - assertThat(session).isNull(); } @EnableWebSecurity static class SessionCreationPolicyStateLessInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .sessionManagement(sessionManagement -> + .sessionManagement((sessionManagement) -> sessionManagement .sessionCreationPolicy(SessionCreationPolicy.STATELESS) ); // @formatter:on } - } - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnSessionManagementFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(SessionManagementFilter.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnConcurrentSessionFilter() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ConcurrentSessionFilter.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnConcurrentSessionControlAuthenticationStrategy() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ConcurrentSessionControlAuthenticationStrategy.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnCompositeSessionAuthenticationStrategy() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(CompositeSessionAuthenticationStrategy.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnRegisterSessionAuthenticationStrategy() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(RegisterSessionAuthenticationStrategy.class)); - } - - @Test - public void configureWhenRegisteringObjectPostProcessorThenInvokedOnChangeSessionIdAuthenticationStrategy() { - ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); - this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(ChangeSessionIdAuthenticationStrategy.class)); } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor; @Override @@ -451,30 +492,21 @@ public class SessionManagementConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void getWhenAnonymousRequestAndTrustResolverSharedObjectReturnsAnonymousFalseThenSessionIsSaved() - throws Exception { - SharedTrustResolverConfig.TR = mock(AuthenticationTrustResolver.class); - when(SharedTrustResolverConfig.TR.isAnonymous(any())).thenReturn(false); - this.spring.register(SharedTrustResolverConfig.class).autowire(); - - MvcResult mvcResult = this.mvc.perform(get("/")) - .andReturn(); - - assertThat(mvcResult.getRequest().getSession(false)).isNotNull(); } @EnableWebSecurity static class SharedTrustResolverConfig extends WebSecurityConfigurerAdapter { + static AuthenticationTrustResolver TR; @Override @@ -484,35 +516,12 @@ public class SessionManagementConfigurerTests { .setSharedObject(AuthenticationTrustResolver.class, TR); // @formatter:on } - } - @Test - public void whenOneSessionRegistryBeanThenUseIt() throws Exception { - SessionRegistryOneBeanConfig.SESSION_REGISTRY = mock(SessionRegistry.class); - this.spring.register(SessionRegistryOneBeanConfig.class).autowire(); - - MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext()); - this.mvc.perform(get("/").session(session)); - - verify(SessionRegistryOneBeanConfig.SESSION_REGISTRY) - .getSessionInformation(session.getId()); - } - - @Test - public void whenTwoSessionRegistryBeansThenUseNeither() throws Exception { - SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE = mock(SessionRegistry.class); - SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO = mock(SessionRegistry.class); - this.spring.register(SessionRegistryTwoBeansConfig.class).autowire(); - - MockHttpSession session = new MockHttpSession(this.spring.getContext().getServletContext()); - this.mvc.perform(get("/").session(session)); - - verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_ONE); - verifyNoInteractions(SessionRegistryTwoBeansConfig.SESSION_REGISTRY_TWO); } @EnableWebSecurity static class SessionRegistryOneBeanConfig extends WebSecurityConfigurerAdapter { + private static SessionRegistry SESSION_REGISTRY; @Override @@ -525,13 +534,15 @@ public class SessionManagementConfigurerTests { } @Bean - public SessionRegistry sessionRegistry() { + SessionRegistry sessionRegistry() { return SESSION_REGISTRY; } + } @EnableWebSecurity static class SessionRegistryTwoBeansConfig extends WebSecurityConfigurerAdapter { + private static SessionRegistry SESSION_REGISTRY_ONE; private static SessionRegistry SESSION_REGISTRY_TWO; @@ -546,13 +557,15 @@ public class SessionManagementConfigurerTests { } @Bean - public SessionRegistry sessionRegistryOne() { + SessionRegistry sessionRegistryOne() { return SESSION_REGISTRY_ONE; } @Bean - public SessionRegistry sessionRegistryTwo() { + SessionRegistry sessionRegistryTwo() { return SESSION_REGISTRY_TWO; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTransientAuthenticationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTransientAuthenticationTests.java index 184943bec6..c5b5dd2cc5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTransientAuthenticationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerTransientAuthenticationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import org.junit.Rule; @@ -48,18 +49,14 @@ public class SessionManagementConfigurerTransientAuthenticationTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void postWhenTransientAuthenticationThenNoSessionCreated() - throws Exception { - + public void postWhenTransientAuthenticationThenNoSessionCreated() throws Exception { this.spring.register(WithTransientAuthenticationConfig.class).autowire(); MvcResult result = this.mvc.perform(post("/login")).andReturn(); assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void postWhenTransientAuthenticationThenAlwaysSessionOverrides() - throws Exception { - + public void postWhenTransientAuthenticationThenAlwaysSessionOverrides() throws Exception { this.spring.register(AlwaysCreateSessionConfig.class).autowire(); MvcResult result = this.mvc.perform(post("/login")).andReturn(); assertThat(result.getRequest().getSession(false)).isNotNull(); @@ -67,28 +64,37 @@ public class SessionManagementConfigurerTransientAuthenticationTests { @EnableWebSecurity static class WithTransientAuthenticationConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { super.configure(http); - + // @formatter:off http .csrf().disable(); + // @formatter:on } @Override protected void configure(AuthenticationManagerBuilder auth) { + // @formatter:off auth .authenticationProvider(new TransientAuthenticationProvider()); + // @formatter:on } + } @EnableWebSecurity static class AlwaysCreateSessionConfig extends WithTransientAuthenticationConfig { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.ALWAYS); + // @formatter:on } + } static class TransientAuthenticationProvider implements AuthenticationProvider { @@ -102,10 +108,12 @@ public class SessionManagementConfigurerTransientAuthenticationTests { public boolean supports(Class authentication) { return true; } + } @Transient static class SomeTransientAuthentication extends AbstractAuthenticationToken { + SomeTransientAuthentication() { super(null); } @@ -119,5 +127,7 @@ public class SessionManagementConfigurerTransientAuthenticationTests { public Object getPrincipal() { return null; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java index 9b86829ebc..34dcb0f52d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import javax.servlet.http.HttpServletResponse; @@ -47,10 +48,13 @@ import static org.assertj.core.api.Assertions.assertThat; * */ public class UrlAuthorizationConfigurerTests { + AnnotationConfigWebApplicationContext context; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Autowired @@ -74,34 +78,66 @@ public class UrlAuthorizationConfigurerTests { @Test public void mvcMatcher() throws Exception { loadConfig(MvcMatcherConfig.class, LegacyMvcMatchingConfig.class); - this.request.setRequestURI("/path"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setRequestURI("/path.html"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); setup(); - this.request.setServletPath("/path/"); this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + } - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + @Test + public void mvcMatcherServletPath() throws Exception { + loadConfig(MvcMatcherServletPathConfig.class, LegacyMvcMatchingConfig.class); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path.html"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/spring"); + this.request.setRequestURI("/spring/path/"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + setup(); + this.request.setServletPath("/foo"); + this.request.setRequestURI("/foo/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + setup(); + this.request.setServletPath("/"); + this.request.setRequestURI("/path"); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + @Test + public void anonymousUrlAuthorization() { + loadConfig(AnonymousUrlAuthorizationConfig.class); + } + + public void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.setServletContext(new MockServletContext()); + this.context.refresh(); + this.context.getAutowireCapableBeanFactory().autowireBean(this); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -122,63 +158,21 @@ public class UrlAuthorizationConfigurerTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } + } - } - @Test - public void mvcMatcherServletPath() throws Exception { - loadConfig(MvcMatcherServletPathConfig.class, LegacyMvcMatchingConfig.class); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path.html"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/spring"); - this.request.setRequestURI("/spring/path/"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()) - .isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); - - setup(); - - this.request.setServletPath("/foo"); - this.request.setRequestURI("/foo/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - - setup(); - - this.request.setServletPath("/"); - this.request.setRequestURI("/path"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @EnableWebSecurity @Configuration @EnableWebMvc static class MvcMatcherServletPathConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -199,21 +193,20 @@ public class UrlAuthorizationConfigurerTests { @RestController static class PathController { + @RequestMapping("/path") - public String path() { + String path() { return "path"; } - } - } - @Test - public void anonymousUrlAuthorization() { - loadConfig(AnonymousUrlAuthorizationConfig.class); + } + } @EnableWebSecurity @Configuration static class AnonymousUrlAuthorizationConfig extends WebSecurityConfigurerAdapter { + @Override public void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -222,22 +215,17 @@ public class UrlAuthorizationConfigurerTests { .anyRequest().anonymous(); // @formatter:on } + } @Configuration static class LegacyMvcMatchingConfig implements WebMvcConfigurer { + @Override public void configurePathMatch(PathMatchConfigurer configurer) { configurer.setUseSuffixPatternMatch(true); } + } - public void loadConfig(Class... configs) { - this.context = new AnnotationConfigWebApplicationContext(); - this.context.register(configs); - this.context.setServletContext(new MockServletContext()); - this.context.refresh(); - - this.context.getAutowireCapableBeanFactory().autowireBean(this); - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationsTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationsTests.java index 4d5b6523d8..90620c97fe 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationsTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationsTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers; import java.util.List; + import javax.servlet.Filter; import org.junit.Rule; @@ -41,7 +43,6 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Rob Winch * @author Josh Cummings * @@ -60,86 +61,63 @@ public class UrlAuthorizationsTests { @WithMockUser(authorities = "ROLE_USER") public void hasAnyAuthorityWhenAuthoritySpecifiedThenMatchesAuthority() throws Exception { this.spring.register(RoleConfig.class).autowire(); - + // @formatter:off this.mvc.perform(get("/role-user-authority")) .andExpect(status().isNotFound()); this.mvc.perform(get("/role-user")) .andExpect(status().isNotFound()); this.mvc.perform(get("/role-admin-authority")) .andExpect(status().isForbidden()); + // @formatter:on } @Test @WithMockUser(authorities = "ROLE_ADMIN") public void hasAnyAuthorityWhenAuthoritiesSpecifiedThenMatchesAuthority() throws Exception { this.spring.register(RoleConfig.class).autowire(); - - this.mvc.perform(get("/role-user-admin-authority")) - .andExpect(status().isNotFound()); - this.mvc.perform(get("/role-user-admin")) - .andExpect(status().isNotFound()); - this.mvc.perform(get("/role-user-authority")) - .andExpect(status().isForbidden()); + this.mvc.perform(get("/role-user-admin-authority")).andExpect(status().isNotFound()); + this.mvc.perform(get("/role-user-admin")).andExpect(status().isNotFound()); + this.mvc.perform(get("/role-user-authority")).andExpect(status().isForbidden()); } @Test @WithMockUser(roles = "USER") public void hasAnyRoleWhenRoleSpecifiedThenMatchesRole() throws Exception { this.spring.register(RoleConfig.class).autowire(); - + // @formatter:off this.mvc.perform(get("/role-user")) .andExpect(status().isNotFound()); this.mvc.perform(get("/role-admin")) .andExpect(status().isForbidden()); + // @formatter:on } @Test @WithMockUser(roles = "ADMIN") public void hasAnyRoleWhenRolesSpecifiedThenMatchesRole() throws Exception { this.spring.register(RoleConfig.class).autowire(); - - this.mvc.perform(get("/role-admin-user")) - .andExpect(status().isNotFound()); - this.mvc.perform(get("/role-user")) - .andExpect(status().isForbidden()); + this.mvc.perform(get("/role-admin-user")).andExpect(status().isNotFound()); + this.mvc.perform(get("/role-user")).andExpect(status().isForbidden()); } @Test @WithMockUser(authorities = "USER") public void hasAnyRoleWhenRoleSpecifiedThenDoesNotMatchAuthority() throws Exception { this.spring.register(RoleConfig.class).autowire(); - + // @formatter:off this.mvc.perform(get("/role-user")) .andExpect(status().isForbidden()); this.mvc.perform(get("/role-admin")) .andExpect(status().isForbidden()); - } - - @EnableWebSecurity - static class RoleConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .antMatchers("/role-user-authority").hasAnyAuthority("ROLE_USER") - .antMatchers("/role-admin-authority").hasAnyAuthority("ROLE_ADMIN") - .antMatchers("/role-user-admin-authority").hasAnyAuthority("ROLE_USER", "ROLE_ADMIN") - .antMatchers("/role-user").hasAnyRole("USER") - .antMatchers("/role-admin").hasAnyRole("ADMIN") - .antMatchers("/role-user-admin").hasAnyRole("USER", "ADMIN"); - // @formatter:on - } + // @formatter:on } @Test public void configureWhenNoAccessDecisionManagerThenDefaultsToAffirmativeBased() { this.spring.register(NoSpecificAccessDecisionManagerConfig.class).autowire(); - FilterSecurityInterceptor interceptor = getFilter(FilterSecurityInterceptor.class); assertThat(interceptor).isNotNull(); - assertThat(interceptor).extracting("accessDecisionManager") - .isInstanceOf(AffirmativeBased.class); + assertThat(interceptor).extracting("accessDecisionManager").isInstanceOf(AffirmativeBased.class); } private T getFilter(Class filterType) { @@ -153,18 +131,40 @@ public class UrlAuthorizationsTests { return null; } + @EnableWebSecurity + static class RoleConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .antMatchers("/role-user-authority").hasAnyAuthority("ROLE_USER") + .antMatchers("/role-admin-authority").hasAnyAuthority("ROLE_ADMIN") + .antMatchers("/role-user-admin-authority").hasAnyAuthority("ROLE_USER", "ROLE_ADMIN") + .antMatchers("/role-user").hasAnyRole("USER") + .antMatchers("/role-admin").hasAnyRole("ADMIN") + .antMatchers("/role-user-admin").hasAnyRole("USER", "ADMIN"); + // @formatter:on + } + + } + @EnableWebSecurity static class NoSpecificAccessDecisionManagerConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { ApplicationContext context = getApplicationContext(); - UrlAuthorizationConfigurer.StandardInterceptUrlRegistry registry = - http.apply(new UrlAuthorizationConfigurer(context)).getRegistry(); - + UrlAuthorizationConfigurer.StandardInterceptUrlRegistry registry = http + .apply(new UrlAuthorizationConfigurer(context)).getRegistry(); + // @formatter:off registry - .antMatchers("/a").hasRole("ADMIN") - .anyRequest().hasRole("USER"); + .antMatchers("/a").hasRole("ADMIN") + .anyRequest().hasRole("USER"); + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/X509ConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/X509ConfigurerTests.java index f7c24c0013..e7ef36d356 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/X509ConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/X509ConfigurerTests.java @@ -16,8 +16,14 @@ package org.springframework.security.config.annotation.web.configurers; +import java.io.InputStream; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.core.io.ClassPathResource; @@ -30,11 +36,6 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.web.authentication.preauth.x509.X509AuthenticationFilter; import org.springframework.test.web.servlet.MockMvc; -import java.io.InputStream; -import java.security.cert.Certificate; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; - import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -60,13 +61,52 @@ public class X509ConfigurerTests { @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnX509AuthenticationFilter() { this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(X509AuthenticationFilter.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(X509AuthenticationFilter.class)); + @Test + public void x509WhenInvokedTwiceThenUsesOriginalSubjectPrincipalRegex() throws Exception { + this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); + X509Certificate certificate = loadCert("rodatexampledotcom.cer"); + // @formatter:off + this.mvc.perform(get("/").with(x509(certificate))) + .andExpect(authenticated().withUsername("rod")); + // @formatter:on + } + + @Test + public void x509WhenConfiguredInLambdaThenUsesDefaults() throws Exception { + this.spring.register(DefaultsInLambdaConfig.class).autowire(); + X509Certificate certificate = loadCert("rod.cer"); + // @formatter:off + this.mvc.perform(get("/").with(x509(certificate))) + .andExpect(authenticated().withUsername("rod")); + // @formatter:on + } + + @Test + public void x509WhenSubjectPrincipalRegexInLambdaThenUsesRegexToExtractPrincipal() throws Exception { + this.spring.register(SubjectPrincipalRegexInLambdaConfig.class).autowire(); + X509Certificate certificate = loadCert("rodatexampledotcom.cer"); + // @formatter:off + this.mvc.perform(get("/").with(x509(certificate))) + .andExpect(authenticated().withUsername("rod")); + // @formatter:on + } + + private T loadCert(String location) { + try (InputStream is = new ClassPathResource(location).getInputStream()) { + CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); + return (T) certFactory.generateCertificate(is); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); + } } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor = spy(ReflectingObjectPostProcessor.class); @Override @@ -81,27 +121,21 @@ public class X509ConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void x509WhenInvokedTwiceThenUsesOriginalSubjectPrincipalRegex() throws Exception { - this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); - X509Certificate certificate = loadCert("rodatexampledotcom.cer"); - - this.mvc.perform(get("/") - .with(x509(certificate))) - .andExpect(authenticated().withUsername("rod")); } @EnableWebSecurity static class DuplicateDoesNotOverrideConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -121,20 +155,12 @@ public class X509ConfigurerTests { .withUser("rod").password("password").roles("USER", "ADMIN"); // @formatter:on } - } - @Test - public void x509WhenConfiguredInLambdaThenUsesDefaults() throws Exception { - this.spring.register(DefaultsInLambdaConfig.class).autowire(); - X509Certificate certificate = loadCert("rod.cer"); - - this.mvc.perform(get("/") - .with(x509(certificate))) - .andExpect(authenticated().withUsername("rod")); } @EnableWebSecurity static class DefaultsInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off @@ -151,25 +177,17 @@ public class X509ConfigurerTests { .withUser("rod").password("password").roles("USER", "ADMIN"); // @formatter:on } - } - @Test - public void x509WhenSubjectPrincipalRegexInLambdaThenUsesRegexToExtractPrincipal() throws Exception { - this.spring.register(SubjectPrincipalRegexInLambdaConfig.class).autowire(); - X509Certificate certificate = loadCert("rodatexampledotcom.cer"); - - this.mvc.perform(get("/") - .with(x509(certificate))) - .andExpect(authenticated().withUsername("rod")); } @EnableWebSecurity static class SubjectPrincipalRegexInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .x509(x509 -> + .x509((x509) -> x509 .subjectPrincipalRegex("CN=(.*?)@example.com(?:,|$)") ); @@ -184,14 +202,7 @@ public class X509ConfigurerTests { .withUser("rod").password("password").roles("USER", "ADMIN"); // @formatter:on } + } - private T loadCert(String location) { - try (InputStream is = new ClassPathResource(location).getInputStream()) { - CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); - return (T) certFactory.generateCertificate(is); - } catch (Exception e) { - throw new IllegalArgumentException(e); - } - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index 43e934c6bf..a601bd90b9 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.oauth2.client; +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.mock.web.MockHttpServletRequest; @@ -53,17 +61,16 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; @@ -78,6 +85,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Parikshit Dutta */ public class OAuth2ClientConfigurerTests { + private static ClientRegistrationRepository clientRegistrationRepository; private static OAuth2AuthorizedClientService authorizedClientService; @@ -100,68 +108,65 @@ public class OAuth2ClientConfigurerTests { @Before public void setup() { + // @formatter:off this.registration1 = TestClientRegistrations.clientRegistration() - .registrationId("registration-1") - .clientId("client-1") - .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri("{baseUrl}/client-1") - .scope("user") - .authorizationUri("https://provider.com/oauth2/authorize") - .tokenUri("https://provider.com/oauth2/token") - .userInfoUri("https://provider.com/oauth2/user") - .userNameAttributeName("id") - .clientName("client-1") - .build(); + .registrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri("{baseUrl}/client-1") + .scope("user") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/user") + .userNameAttributeName("id") + .clientName("client-1") + .build(); + // @formatter:on clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1); authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository); - authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService); - authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( - clientRegistrationRepository, "/oauth2/authorization"); - + authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + authorizedClientService); + authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, + "/oauth2/authorization"); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(300) - .build(); + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(300).build(); accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); - when(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class))).thenReturn(accessTokenResponse); + given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class))) + .willReturn(accessTokenResponse); requestCache = mock(RequestCache.class); } @Test public void configureWhenAuthorizationCodeRequestThenRedirectForAuthorization() throws Exception { this.spring.register(OAuth2ClientConfig.class).autowire(); - + // @formatter:off MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1")) - .andExpect(status().is3xxRedirection()) - .andReturn(); - assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?" + - "response_type=code&client_id=client-1&" + - "scope=user&state=.{15,}&" + - "redirect_uri=http://localhost/client-1"); + .andExpect(status().is3xxRedirection()).andReturn(); + assertThat(mvcResult.getResponse().getRedirectedUrl()) + .matches("https://provider.com/oauth2/authorize\\?" + "response_type=code&client_id=client-1&" + + "scope=user&state=.{15,}&" + "redirect_uri=http://localhost/client-1"); + // @formatter:on } @Test public void configureWhenOauth2ClientInLambdaThenRedirectForAuthorization() throws Exception { this.spring.register(OAuth2ClientInLambdaConfig.class).autowire(); - MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1")) - .andExpect(status().is3xxRedirection()) - .andReturn(); - assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?" + - "response_type=code&client_id=client-1&" + - "scope=user&state=.{15,}&" + - "redirect_uri=http://localhost/client-1"); + .andExpect(status().is3xxRedirection()).andReturn(); + assertThat(mvcResult.getResponse().getRedirectedUrl()) + .matches("https://provider.com/oauth2/authorize\\?" + "response_type=code&client_id=client-1&" + + "scope=user&state=.{15,}&" + "redirect_uri=http://localhost/client-1"); } @Test public void configureWhenAuthorizationCodeResponseSuccessThenAuthorizedClientSaved() throws Exception { this.spring.register(OAuth2ClientConfig.class).autowire(); - // Setup the Authorization Request in the session Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) .clientId(this.registration1.getClientId()) @@ -169,105 +174,100 @@ public class OAuth2ClientConfigurerTests { .state("state") .attributes(attributes) .build(); - - AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); + // @formatter:on + AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletResponse response = new MockHttpServletResponse(); authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); - MockHttpSession session = (MockHttpSession) request.getSession(); - String principalName = "user1"; TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); - - this.mockMvc.perform(get("/client-1") - .param(OAuth2ParameterNames.CODE, "code") - .param(OAuth2ParameterNames.STATE, "state") - .with(authentication(authentication)) - .session(session)) - .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrl("http://localhost/client-1")); - - OAuth2AuthorizedClient authorizedClient = authorizedClientRepository.loadAuthorizedClient( - this.registration1.getRegistrationId(), authentication, request); + // @formatter:off + MockHttpServletRequestBuilder clientRequest = get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(authentication(authentication)) + .session(session); + this.mockMvc.perform(clientRequest) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl("http://localhost/client-1")); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = authorizedClientRepository + .loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request); assertThat(authorizedClient).isNotNull(); } @Test - public void configureWhenRequestCacheProvidedAndClientAuthorizationRequiredExceptionThrownThenRequestCacheUsed() throws Exception { + public void configureWhenRequestCacheProvidedAndClientAuthorizationRequiredExceptionThrownThenRequestCacheUsed() + throws Exception { this.spring.register(OAuth2ClientConfig.class).autowire(); - MvcResult mvcResult = this.mockMvc.perform(get("/resource1").with(user("user1"))) - .andExpect(status().is3xxRedirection()) - .andReturn(); - assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?" + - "response_type=code&client_id=client-1&" + - "scope=user&state=.{15,}&" + - "redirect_uri=http://localhost/client-1"); - + .andExpect(status().is3xxRedirection()).andReturn(); + assertThat(mvcResult.getResponse().getRedirectedUrl()) + .matches("https://provider.com/oauth2/authorize\\?" + "response_type=code&client_id=client-1&" + + "scope=user&state=.{15,}&" + "redirect_uri=http://localhost/client-1"); verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void configureWhenRequestCacheProvidedAndClientAuthorizationSucceedsThenRequestCacheUsed() throws Exception { this.spring.register(OAuth2ClientConfig.class).autowire(); - // Setup the Authorization Request in the session Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) - .clientId(this.registration1.getClientId()) - .redirectUri("http://localhost/client-1") + .clientId(this.registration1.getClientId()).redirectUri("http://localhost/client-1") .state("state") .attributes(attributes) .build(); - - AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); + // @formatter:on + AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); MockHttpServletResponse response = new MockHttpServletResponse(); authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); - MockHttpSession session = (MockHttpSession) request.getSession(); - String principalName = "user1"; TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); - - this.mockMvc.perform(get("/client-1") + // @formatter:off + MockHttpServletRequestBuilder clientRequest = get("/client-1") .param(OAuth2ParameterNames.CODE, "code") .param(OAuth2ParameterNames.STATE, "state") .with(authentication(authentication)) - .session(session)) + .session(session); + this.mockMvc.perform(clientRequest) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/client-1")); - + // @formatter:on verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } // gh-5521 @Test - public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception { + public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() + throws Exception { // Override default resolver OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver; authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class); - when(authorizationRequestResolver.resolve(any())).thenAnswer(invocation -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0))); - + given(authorizationRequestResolver.resolve(any())) + .willAnswer((invocation) -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0))); this.spring.register(OAuth2ClientConfig.class).autowire(); - + // @formatter:off this.mockMvc.perform(get("/oauth2/authorization/registration-1")) .andExpect(status().is3xxRedirection()) .andReturn(); - + // @formatter:on verify(authorizationRequestResolver).resolve(any()); } @EnableWebSecurity @EnableWebMvc static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -279,35 +279,41 @@ public class OAuth2ClientConfigurerTests { .authorizationCodeGrant() .authorizationRequestResolver(authorizationRequestResolver) .accessTokenResponseClient(accessTokenResponseClient); + // @formatter:on } @Bean - public ClientRegistrationRepository clientRegistrationRepository() { + ClientRegistrationRepository clientRegistrationRepository() { return clientRegistrationRepository; } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository() { + OAuth2AuthorizedClientRepository authorizedClientRepository() { return authorizedClientRepository; } @RestController - public class ResourceController { + class ResourceController { + @GetMapping("/resource1") - public String resource1(@RegisteredOAuth2AuthorizedClient("registration-1") OAuth2AuthorizedClient authorizedClient) { + String resource1( + @RegisteredOAuth2AuthorizedClient("registration-1") OAuth2AuthorizedClient authorizedClient) { return "resource1"; } + } + } @EnableWebSecurity @EnableWebMvc static class OAuth2ClientInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) @@ -316,13 +322,15 @@ public class OAuth2ClientConfigurerTests { } @Bean - public ClientRegistrationRepository clientRegistrationRepository() { + ClientRegistrationRepository clientRegistrationRepository() { return clientRegistrationRepository; } @Bean - public OAuth2AuthorizedClientRepository authorizedClientRepository() { + OAuth2AuthorizedClientRepository authorizedClientRepository() { return authorizedClientRepository; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index 48274c5e08..5207bfd772 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.configurers.oauth2.client; import java.util.ArrayList; @@ -28,6 +29,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; @@ -69,6 +71,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; @@ -80,6 +83,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; @@ -88,12 +92,10 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -108,13 +110,21 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. */ public class OAuth2LoginConfigurerTests { + // @formatter:off private static final ClientRegistration GOOGLE_CLIENT_REGISTRATION = CommonOAuth2Provider.GOOGLE - .getBuilder("google").clientId("clientId").clientSecret("clientSecret") + .getBuilder("google") + .clientId("clientId") + .clientSecret("clientSecret") .build(); + // @formatter:on + // @formatter:off private static final ClientRegistration GITHUB_CLIENT_REGISTRATION = CommonOAuth2Provider.GITHUB - .getBuilder("github").clientId("clientId").clientSecret("clientSecret") + .getBuilder("github") + .clientId("clientId") + .clientSecret("clientSecret") .build(); + // @formatter:on private ConfigurableApplicationContext context; @@ -134,7 +144,9 @@ public class OAuth2LoginConfigurerTests { MockMvc mvc; private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockFilterChain filterChain; @Before @@ -156,45 +168,35 @@ public class OAuth2LoginConfigurerTests { public void oauth2Login() throws Exception { // setup application context loadConfig(OAuth2LoginConfig.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(1); - assertThat(authentication.getAuthorities()).first() - .isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).first().isInstanceOf(OAuth2UserAuthority.class) + .hasToString("ROLE_USER"); } @Test public void requestWhenOauth2LoginInLambdaThenAuthenticationContainsOauth2UserAuthority() throws Exception { loadConfig(OAuth2LoginInLambdaConfig.class); OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(1); - assertThat(authentication.getAuthorities()).first() - .isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).first().isInstanceOf(OAuth2UserAuthority.class) + .hasToString("ROLE_USER"); } // gh-6009 @@ -202,19 +204,14 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginWhenSuccessThenAuthenticationSuccessEventPublished() throws Exception { // setup application context loadConfig(OAuth2LoginConfig.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions assertThat(OAuth2LoginConfig.EVENTS).isNotEmpty(); assertThat(OAuth2LoginConfig.EVENTS).hasSize(1); @@ -225,23 +222,17 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginCustomWithConfigurer() throws Exception { // setup application context loadConfig(OAuth2LoginConfigCustomWithConfigurer.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(2); assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OAUTH2_USER"); @@ -251,23 +242,17 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginCustomWithBeanRegistration() throws Exception { // setup application context loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(2); assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OAUTH2_USER"); @@ -277,23 +262,17 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginCustomWithUserServiceBeanRegistration() throws Exception { // setup application context loadConfig(OAuth2LoginConfigCustomUserServiceBeanRegistration.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(2); assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OAUTH2_USER"); @@ -304,101 +283,92 @@ public class OAuth2LoginConfigurerTests { public void oauth2LoginConfigLoginProcessingUrl() throws Exception { // setup application context loadConfig(OAuth2LoginConfigLoginProcessingUrl.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); this.request.setServletPath("/login/oauth2/google"); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(1); - assertThat(authentication.getAuthorities()).first() - .isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).first().isInstanceOf(OAuth2UserAuthority.class) + .hasToString("ROLE_USER"); } // gh-5521 @Test public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception { loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class); - OAuth2AuthorizationRequestResolver resolver = this.context.getBean( - OAuth2LoginConfigCustomAuthorizationRequestResolver.class).resolver; + OAuth2AuthorizationRequestResolver resolver = this.context + .getBean(OAuth2LoginConfigCustomAuthorizationRequestResolver.class).resolver; + // @formatter:off OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://accounts.google.com/authorize") .clientId("client-id") .state("adsfa") - .authorizationRequestUri("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1") + .authorizationRequestUri( + "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1") .build(); - when(resolver.resolve(any())).thenReturn(result); - + // @formatter:on + given(resolver.resolve(any())).willReturn(result); String requestUri = "/oauth2/authorization/google"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); + assertThat(this.response.getRedirectedUrl()).isEqualTo( + "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); } @Test public void requestWhenOauth2LoginWithCustomAuthorizationRequestParametersThenParametersInRedirectedUrl() throws Exception { loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolverInLambda.class); - OAuth2AuthorizationRequestResolver resolver = this.context.getBean( - OAuth2LoginConfigCustomAuthorizationRequestResolverInLambda.class).resolver; + OAuth2AuthorizationRequestResolver resolver = this.context + .getBean(OAuth2LoginConfigCustomAuthorizationRequestResolverInLambda.class).resolver; + // @formatter:off OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://accounts.google.com/authorize") .clientId("client-id") .state("adsfa") - .authorizationRequestUri("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1") + .authorizationRequestUri( + "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1") .build(); - when(resolver.resolve(any())).thenReturn(result); - + // @formatter:on + given(resolver.resolve(any())).willReturn(result); String requestUri = "/oauth2/authorization/google"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); + assertThat(this.response.getRedirectedUrl()).isEqualTo( + "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); } // gh-5347 @Test public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception { loadConfig(OAuth2LoginConfig.class); - String requestUri = "/"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - assertThat(this.response.getRedirectedUrl()).matches("http://localhost/oauth2/authorization/google"); } // gh-5347 @Test - public void oauth2LoginWithOneClientConfiguredAndRequestFaviconNotAuthenticatedThenRedirectDefaultLoginPage() throws Exception { + public void oauth2LoginWithOneClientConfiguredAndRequestFaviconNotAuthenticatedThenRedirectDefaultLoginPage() + throws Exception { loadConfig(OAuth2LoginConfig.class); - String requestUri = "/favicon.ico"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); this.request.addHeader(HttpHeaders.ACCEPT, new MediaType("image", "*").toString()); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); } @@ -406,54 +376,43 @@ public class OAuth2LoginConfigurerTests { @Test public void oauth2LoginWithMultipleClientsConfiguredThenRedirectDefaultLoginPage() throws Exception { loadConfig(OAuth2LoginConfigMultipleClients.class); - String requestUri = "/"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - assertThat(this.response.getRedirectedUrl()).matches("http://localhost/login"); } // gh-6812 @Test - public void oauth2LoginWithOneClientConfiguredAndRequestXHRNotAuthenticatedThenDoesNotRedirectForAuthorization() throws Exception { + public void oauth2LoginWithOneClientConfiguredAndRequestXHRNotAuthenticatedThenDoesNotRedirectForAuthorization() + throws Exception { loadConfig(OAuth2LoginConfig.class); - String requestUri = "/"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); this.request.addHeader("X-Requested-With", "XMLHttpRequest"); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - assertThat(this.response.getRedirectedUrl()).doesNotMatch("http://localhost/oauth2/authorization/google"); } @Test public void oauth2LoginWithCustomLoginPageThenRedirectCustomLoginPage() throws Exception { loadConfig(OAuth2LoginConfigCustomLoginPage.class); - String requestUri = "/"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); } @Test public void requestWhenOauth2LoginWithCustomLoginPageInLambdaThenRedirectCustomLoginPage() throws Exception { loadConfig(OAuth2LoginConfigCustomLoginPageInLambda.class); - String requestUri = "/"; this.request = new MockHttpServletRequest("GET", requestUri); this.request.setServletPath(requestUri); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - assertThat(this.response.getRedirectedUrl()).matches("http://localhost/custom-login"); } @@ -461,75 +420,57 @@ public class OAuth2LoginConfigurerTests { public void oidcLogin() throws Exception { // setup application context loadConfig(OAuth2LoginConfig.class, JwtDecoderFactoryConfig.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(1); - assertThat(authentication.getAuthorities()).first() - .isInstanceOf(OidcUserAuthority.class).hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).first().isInstanceOf(OidcUserAuthority.class) + .hasToString("ROLE_USER"); } @Test public void requestWhenOauth2LoginInLambdaAndOidcThenAuthenticationContainsOidcUserAuthority() throws Exception { // setup application context loadConfig(OAuth2LoginInLambdaConfig.class, JwtDecoderFactoryConfig.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(1); - assertThat(authentication.getAuthorities()).first() - .isInstanceOf(OidcUserAuthority.class).hasToString("ROLE_USER"); + assertThat(authentication.getAuthorities()).first().isInstanceOf(OidcUserAuthority.class) + .hasToString("ROLE_USER"); } @Test public void oidcLoginCustomWithConfigurer() throws Exception { // setup application context loadConfig(OAuth2LoginConfigCustomWithConfigurer.class, JwtDecoderFactoryConfig.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(2); assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER"); @@ -539,23 +480,17 @@ public class OAuth2LoginConfigurerTests { public void oidcLoginCustomWithBeanRegistration() throws Exception { // setup application context loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class, JwtDecoderFactoryConfig.class); - // setup authorization request OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, this.request, this.response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); // setup authentication parameters this.request.setParameter("code", "code123"); this.request.setParameter("state", authorizationRequest.getState()); - // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); assertThat(authentication.getAuthorities()).hasSize(2); assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER"); assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER"); @@ -563,25 +498,20 @@ public class OAuth2LoginConfigurerTests { @Test public void oidcLoginCustomWithNoUniqueJwtDecoderFactory() { - assertThatThrownBy(() -> loadConfig(OAuth2LoginConfig.class, NoUniqueJwtDecoderFactoryConfig.class)) - .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) - .hasMessageContaining("No qualifying bean of type " + - "'org.springframework.security.oauth2.jwt.JwtDecoderFactory' " + - "available: expected single matching bean but found 2: jwtDecoderFactory1,jwtDecoderFactory2"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> loadConfig(OAuth2LoginConfig.class, NoUniqueJwtDecoderFactoryConfig.class)) + .withRootCauseInstanceOf(NoUniqueBeanDefinitionException.class) + .withMessageContaining("No qualifying bean of type " + + "'org.springframework.security.oauth2.jwt.JwtDecoderFactory' " + + "available: expected single matching bean but found 2: jwtDecoderFactory1,jwtDecoderFactory2"); } @Test public void logoutWhenUsingOidcLogoutHandlerThenRedirects() throws Exception { this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire(); - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - "registration-id"); - - this.mvc.perform(post("/logout") - .with(authentication(token)) - .with(csrf())) + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, "registration-id"); + this.mvc.perform(post("/logout").with(authentication(token)).with(csrf())) .andExpect(redirectedUrl("https://logout?id_token_hint=id-token")); } @@ -597,30 +527,66 @@ public class OAuth2LoginConfigurerTests { return this.createOAuth2AuthorizationRequest(GOOGLE_CLIENT_REGISTRATION, scopes); } - private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(ClientRegistration registration, String... scopes) { + private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(ClientRegistration registration, + String... scopes) { + // @formatter:off return OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) .clientId(registration.getClientId()) .state("state123") .redirectUri("http://localhost") - .attributes( - Collections.singletonMap( - OAuth2ParameterNames.REGISTRATION_ID, - registration.getRegistrationId())) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, + registration.getRegistrationId())) .scope(scopes) .build(); + // @formatter:on + } + + private static OAuth2AccessTokenResponseClient createOauth2AccessTokenResponseClient() { + return (request) -> { + Map additionalParameters = new HashMap<>(); + if (request.getAuthorizationExchange().getAuthorizationRequest().getScopes().contains("openid")) { + additionalParameters.put(OidcParameterNames.ID_TOKEN, "token123"); + } + return OAuth2AccessTokenResponse.withToken("accessToken123").tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(additionalParameters).build(); + }; + } + + private static OAuth2UserService createOauth2UserService() { + Map userAttributes = Collections.singletonMap("name", "spring"); + return (request) -> new DefaultOAuth2User(Collections.singleton(new OAuth2UserAuthority(userAttributes)), + userAttributes, "name"); + } + + private static OAuth2UserService createOidcUserService() { + OidcIdToken idToken = TestOidcIdTokens.idToken().build(); + return (request) -> new DefaultOidcUser(Collections.singleton(new OidcUserAuthority(idToken)), idToken); + } + + private static GrantedAuthoritiesMapper createGrantedAuthoritiesMapper() { + return (authorities) -> { + boolean isOidc = OidcUserAuthority.class.isInstance(authorities.iterator().next()); + List mappedAuthorities = new ArrayList<>(authorities); + mappedAuthorities.add(new SimpleGrantedAuthority(isOidc ? "ROLE_OIDC_USER" : "ROLE_OAUTH2_USER")); + return mappedAuthorities; + }; } @EnableWebSecurity - static class OAuth2LoginConfig extends CommonWebSecurityConfigurerAdapter implements ApplicationListener { + static class OAuth2LoginConfig extends CommonWebSecurityConfigurerAdapter + implements ApplicationListener { + static List EVENTS = new ArrayList<>(); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login() .clientRegistrationRepository( new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)); + // @formatter:on super.configure(http); } @@ -628,18 +594,20 @@ public class OAuth2LoginConfigurerTests { public void onApplicationEvent(AuthenticationSuccessEvent event) { EVENTS.add(event); } + } @EnableWebSecurity static class OAuth2LoginInLambdaConfig extends CommonLambdaWebSecurityConfigurerAdapter implements ApplicationListener { + static List EVENTS = new ArrayList<>(); @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .oauth2Login(oauth2Login -> + .oauth2Login((oauth2Login) -> oauth2Login .clientRegistrationRepository( new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) @@ -652,28 +620,36 @@ public class OAuth2LoginConfigurerTests { public void onApplicationEvent(AuthenticationSuccessEvent event) { EVENTS.add(event); } + } @EnableWebSecurity static class OAuth2LoginConfigCustomWithConfigurer extends CommonWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login() .clientRegistrationRepository( new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) .userInfoEndpoint() .userAuthoritiesMapper(createGrantedAuthoritiesMapper()); + // @formatter:on super.configure(http); } + } @EnableWebSecurity static class OAuth2LoginConfigCustomWithBeanRegistration extends CommonWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login(); + // @formatter:on super.configure(http); } @@ -686,12 +662,15 @@ public class OAuth2LoginConfigurerTests { GrantedAuthoritiesMapper grantedAuthoritiesMapper() { return createGrantedAuthoritiesMapper(); } + } @EnableWebSecurity static class OAuth2LoginConfigCustomUserServiceBeanRegistration extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -702,6 +681,7 @@ public class OAuth2LoginConfigurerTests { .oauth2Login() .tokenEndpoint() .accessTokenResponseClient(createOauth2AccessTokenResponseClient()); + // @formatter:on } @Bean @@ -733,94 +713,117 @@ public class OAuth2LoginConfigurerTests { OAuth2UserService oidcUserService() { return createOidcUserService(); } + } @EnableWebSecurity static class OAuth2LoginConfigLoginProcessingUrl extends CommonWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login() .clientRegistrationRepository( new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) .loginProcessingUrl("/login/oauth2/*"); + // @formatter:on super.configure(http); } + } @EnableWebSecurity static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonWebSecurityConfigurerAdapter { - private ClientRegistrationRepository clientRegistrationRepository = - new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION); + + private ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login() .clientRegistrationRepository(this.clientRegistrationRepository) .authorizationEndpoint() .authorizationRequestResolver(this.resolver); + // @formatter:on super.configure(http); } + } @EnableWebSecurity - static class OAuth2LoginConfigCustomAuthorizationRequestResolverInLambda extends CommonLambdaWebSecurityConfigurerAdapter { - private ClientRegistrationRepository clientRegistrationRepository = - new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION); + static class OAuth2LoginConfigCustomAuthorizationRequestResolverInLambda + extends CommonLambdaWebSecurityConfigurerAdapter { + + private ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .oauth2Login(oauth2Login -> + .oauth2Login((oauth2Login) -> oauth2Login .clientRegistrationRepository(this.clientRegistrationRepository) - .authorizationEndpoint(authorizationEndpoint -> + .authorizationEndpoint((authorizationEndpoint) -> authorizationEndpoint .authorizationRequestResolver(this.resolver) ) ); + // @formatter:on super.configure(http); } + } @EnableWebSecurity static class OAuth2LoginConfigMultipleClients extends CommonWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login() .clientRegistrationRepository( new InMemoryClientRegistrationRepository( GOOGLE_CLIENT_REGISTRATION, GITHUB_CLIENT_REGISTRATION)); + // @formatter:on super.configure(http); } + } @EnableWebSecurity static class OAuth2LoginConfigCustomLoginPage extends CommonWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .oauth2Login() .clientRegistrationRepository( new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) .loginPage("/custom-login"); + // @formatter:on super.configure(http); } + } @EnableWebSecurity static class OAuth2LoginConfigCustomLoginPageInLambda extends CommonLambdaWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .oauth2Login(oauth2Login -> + .oauth2Login((oauth2Login) -> oauth2Login .clientRegistrationRepository( new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) @@ -829,15 +832,19 @@ public class OAuth2LoginConfigurerTests { // @formatter:on super.configure(http); } + } @EnableWebSecurity static class OAuth2LoginConfigWithOidcLogoutSuccessHandler extends CommonWebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .logout() .logoutSuccessHandler(oidcLogoutSuccessHandler()); + // @formatter:on super.configure(http); } @@ -848,17 +855,18 @@ public class OAuth2LoginConfigurerTests { @Bean ClientRegistrationRepository clientRegistrationRepository() { - Map providerMetadata = - Collections.singletonMap("end_session_endpoint", "https://logout"); - return new InMemoryClientRegistrationRepository( - TestClientRegistrations.clientRegistration() - .providerConfigurationMetadata(providerMetadata).build()); + Map providerMetadata = Collections.singletonMap("end_session_endpoint", "https://logout"); + return new InMemoryClientRegistrationRepository(TestClientRegistrations.clientRegistration() + .providerConfigurationMetadata(providerMetadata).build()); } + } - private static abstract class CommonWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { + private abstract static class CommonWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -873,6 +881,7 @@ public class OAuth2LoginConfigurerTests { .userInfoEndpoint() .userService(createOauth2UserService()) .oidcUserService(createOidcUserService()); + // @formatter:on } @Bean @@ -884,28 +893,30 @@ public class OAuth2LoginConfigurerTests { HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestRepository() { return new HttpSessionOAuth2AuthorizationRequestRepository(); } + } - private static abstract class CommonLambdaWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { + private abstract static class CommonLambdaWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) - .securityContext(securityContext -> + .securityContext((securityContext) -> securityContext .securityContextRepository(securityContextRepository()) ) - .oauth2Login(oauth2Login -> + .oauth2Login((oauth2Login) -> oauth2Login - .tokenEndpoint(tokenEndpoint -> + .tokenEndpoint((tokenEndpoint) -> tokenEndpoint .accessTokenResponseClient(createOauth2AccessTokenResponseClient()) ) - .userInfoEndpoint(userInfoEndpoint -> + .userInfoEndpoint((userInfoEndpoint) -> userInfoEndpoint .userService(createOauth2UserService()) .oidcUserService(createOidcUserService()) @@ -923,6 +934,7 @@ public class OAuth2LoginConfigurerTests { HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestRepository() { return new HttpSessionOAuth2AuthorizationRequestRepository(); } + } @Configuration @@ -930,7 +942,7 @@ public class OAuth2LoginConfigurerTests { @Bean JwtDecoderFactory jwtDecoderFactory() { - return clientRegistration -> getJwtDecoder(); + return (clientRegistration) -> getJwtDecoder(); } private static JwtDecoder getJwtDecoder() { @@ -939,11 +951,12 @@ public class OAuth2LoginConfigurerTests { claims.put(IdTokenClaimNames.ISS, "http://localhost/iss"); claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d")); claims.put(IdTokenClaimNames.AZP, "clientId"); - Jwt jwt = jwt().claims(c -> c.putAll(claims)).build(); + Jwt jwt = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); JwtDecoder jwtDecoder = mock(JwtDecoder.class); - when(jwtDecoder.decode(any())).thenReturn(jwt); + given(jwtDecoder.decode(any())).willReturn(jwt); return jwtDecoder; } + } @Configuration @@ -951,50 +964,14 @@ public class OAuth2LoginConfigurerTests { @Bean JwtDecoderFactory jwtDecoderFactory1() { - return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder(); + return (clientRegistration) -> JwtDecoderFactoryConfig.getJwtDecoder(); } @Bean JwtDecoderFactory jwtDecoderFactory2() { - return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder(); + return (clientRegistration) -> JwtDecoderFactoryConfig.getJwtDecoder(); } } - private static OAuth2AccessTokenResponseClient createOauth2AccessTokenResponseClient() { - return request -> { - Map additionalParameters = new HashMap<>(); - if (request.getAuthorizationExchange().getAuthorizationRequest().getScopes().contains("openid")) { - additionalParameters.put(OidcParameterNames.ID_TOKEN, "token123"); - } - return OAuth2AccessTokenResponse.withToken("accessToken123") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(additionalParameters) - .build(); - }; - } - - private static OAuth2UserService createOauth2UserService() { - Map userAttributes = Collections.singletonMap("name", "spring"); - return request -> new DefaultOAuth2User( - Collections.singleton(new OAuth2UserAuthority(userAttributes)), - userAttributes, "name"); - } - - private static OAuth2UserService createOidcUserService() { - OidcIdToken idToken = idToken().build(); - return request -> new DefaultOidcUser( - Collections.singleton(new OidcUserAuthority(idToken)), idToken); - } - - private static GrantedAuthoritiesMapper createGrantedAuthoritiesMapper() { - return authorities -> { - boolean isOidc = OidcUserAuthority.class - .isInstance(authorities.iterator().next()); - List mappedAuthorities = new ArrayList<>(authorities); - mappedAuthorities.add(new SimpleGrantedAuthority( - isOidc ? "ROLE_OIDC_USER" : "ROLE_OAUTH2_USER")); - return mappedAuthorities; - }; - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java index cf552a6ece..0057529010 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java @@ -31,6 +31,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; + import javax.annotation.PreDestroy; import com.nimbusds.jose.JWSAlgorithm; @@ -93,13 +94,16 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtTimestampValidator; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; @@ -117,34 +121,31 @@ import org.springframework.security.web.access.AccessDeniedHandlerImpl; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.client.RestOperations; import org.springframework.web.context.support.GenericWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.hamcrest.CoreMatchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -153,8 +154,6 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; -import static org.springframework.web.bind.annotation.RequestMethod.GET; -import static org.springframework.web.bind.annotation.RequestMethod.POST; /** * Tests for {@link OAuth2ResourceServerConfigurer} @@ -163,20 +162,29 @@ import static org.springframework.web.bind.annotation.RequestMethod.POST; * @author Evgeniy Cheban */ public class OAuth2ResourceServerConfigurerTests { + private static final String JWT_TOKEN = "token"; + private static final String JWT_SUBJECT = "mock-test-subject"; - private static final Map JWT_CLAIMS = Collections.singletonMap(SUB, JWT_SUBJECT); - private static final Jwt JWT = jwt().build(); + + private static final Map JWT_CLAIMS = Collections.singletonMap(JwtClaimNames.SUB, JWT_SUBJECT); + + private static final Jwt JWT = TestJwts.jwt().build(); + private static final String JWK_SET_URI = "https://mock.org"; - private static final JwtAuthenticationToken JWT_AUTHENTICATION_TOKEN = - new JwtAuthenticationToken(JWT, Collections.emptyList()); + + private static final JwtAuthenticationToken JWT_AUTHENTICATION_TOKEN = new JwtAuthenticationToken(JWT, + Collections.emptyList()); private static final String INTROSPECTION_URI = "https://idp.example.com"; + private static final String CLIENT_ID = "client-id"; + private static final String CLIENT_SECRET = "client-secret"; - private static final BearerTokenAuthentication INTROSPECTION_AUTHENTICATION_TOKEN = - new BearerTokenAuthentication(new DefaultOAuth2AuthenticatedPrincipal(JWT_CLAIMS, Collections.emptyList()), - noScopes(), Collections.emptyList()); + + private static final BearerTokenAuthentication INTROSPECTION_AUTHENTICATION_TOKEN = new BearerTokenAuthentication( + new DefaultOAuth2AuthenticatedPrincipal(JWT_CLAIMS, Collections.emptyList()), + TestOAuth2AccessTokens.noScopes(), Collections.emptyList()); @Autowired(required = false) MockMvc mvc; @@ -188,29 +196,27 @@ public class OAuth2ResourceServerConfigurerTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void getWhenUsingDefaultsWithValidBearerTokenThenAcceptsRequest() - throws Exception { - + public void getWhenUsingDefaultsWithValidBearerTokenThenAcceptsRequest() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("ok")); + // @formatter:on } @Test - public void getWhenUsingDefaultsInLambdaWithValidBearerTokenThenAcceptsRequest() - throws Exception { - + public void getWhenUsingDefaultsInLambdaWithValidBearerTokenThenAcceptsRequest() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultInLambdaConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("ok")); + // @formatter:on } @Test @@ -218,10 +224,11 @@ public class OAuth2ResourceServerConfigurerTests { this.spring.register(WebServerConfig.class, JwkSetUriConfig.class, BasicController.class).autowire(); mockWebServer(jwks("Default")); String token = this.token("ValidNoScopes"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("ok")); + // @formatter:on } @Test @@ -229,151 +236,136 @@ public class OAuth2ResourceServerConfigurerTests { this.spring.register(WebServerConfig.class, JwkSetUriInLambdaConfig.class, BasicController.class).autowire(); mockWebServer(jwks("Default")); String token = this.token("ValidNoScopes"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("ok")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithExpiredBearerTokenThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithExpiredBearerTokenThenInvalidToken() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("Expired"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithBadJwkEndpointThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithBadJwkEndpointThenInvalidToken() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class).autowire(); mockRestOperations("malformed"); String token = this.token("ValidNoScopes"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(header().string("WWW-Authenticate", "Bearer")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithUnavailableJwkEndpointThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithUnavailableJwkEndpointThenInvalidToken() throws Exception { this.spring.register(WebServerConfig.class, JwkSetUriConfig.class).autowire(); this.web.shutdown(); String token = this.token("ValidNoScopes"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(header().string("WWW-Authenticate", "Bearer")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithMalformedBearerTokenThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithMalformedBearerTokenThenInvalidToken() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken("an\"invalid\"token"))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Bearer token is malformed")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithMalformedPayloadThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithMalformedPayloadThenInvalidToken() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("MalformedPayload"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt: Malformed payload")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithUnsignedBearerTokenThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithUnsignedBearerTokenThenInvalidToken() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); String token = this.token("Unsigned"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Unsupported algorithm of none")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithBearerTokenBeforeNotBeforeThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsWithBearerTokenBeforeNotBeforeThenInvalidToken() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class).autowire(); this.mockRestOperations(jwks("Default")); String token = this.token("TooEarly"); - + // @formatter:off this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithBearerTokenInTwoPlacesThenInvalidRequest() - throws Exception { - + public void getWhenUsingDefaultsWithBearerTokenInTwoPlacesThenInvalidRequest() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); - - this.mvc.perform(get("/") - .with(bearerToken("token")) - .with(bearerToken("token").asParam())) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken("token")).with(bearerToken("token").asParam())) .andExpect(status().isBadRequest()) .andExpect(invalidRequestHeader("Found multiple bearer tokens in the request")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithBearerTokenInTwoParametersThenInvalidRequest() - throws Exception { - + public void getWhenUsingDefaultsWithBearerTokenInTwoParametersThenInvalidRequest() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); - MultiValueMap params = new LinkedMultiValueMap<>(); params.add("access_token", "token1"); params.add("access_token", "token2"); - - this.mvc.perform(get("/") - .params(params)) + // @formatter:off + this.mvc.perform(get("/").params(params)) .andExpect(status().isBadRequest()) .andExpect(invalidRequestHeader("Found multiple bearer tokens in the request")); + // @formatter:on } @Test - public void postWhenUsingDefaultsWithBearerTokenAsFormParameterThenIgnoresToken() - throws Exception { - + public void postWhenUsingDefaultsWithBearerTokenAsFormParameterThenIgnoresToken() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); - - this.mvc.perform(post("/") // engage csrf - .with(bearerToken("token").asParam())) + // engage csrf + // @formatter:off + this.mvc.perform(post("/").with(bearerToken("token").asParam())) .andExpect(status().isForbidden()) .andExpect(header().doesNotExist(HttpHeaders.WWW_AUTHENTICATE)); + // @formatter:on } @Test - public void postWhenCsrfDisabledWithBearerTokenAsFormParameterThenIgnoresToken() - throws Exception { - + public void postWhenCsrfDisabledWithBearerTokenAsFormParameterThenIgnoresToken() throws Exception { this.spring.register(CsrfDisabledConfig.class).autowire(); - - this.mvc.perform(post("/") - .with(bearerToken("token").asParam())) + // @formatter:off + this.mvc.perform(post("/").with(bearerToken("token").asParam())) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); + // @formatter:on } // gh-8031 @@ -382,869 +374,702 @@ public class OAuth2ResourceServerConfigurerTests { this.spring.register(RestOperationsConfig.class, AnonymousDisabledConfig.class).autowire(); mockRestOperations(jwks("Default")); String token = token("ValidNoScopes"); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(token))) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithNoBearerTokenThenUnauthorized() - throws Exception { - + public void getWhenUsingDefaultsWithNoBearerTokenThenUnauthorized() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithSufficientlyScopedBearerTokenThenAcceptsRequest() - throws Exception { - + public void getWhenUsingDefaultsWithSufficientlyScopedBearerTokenThenAcceptsRequest() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageReadScope"); - - this.mvc.perform(get("/requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("[SCOPE_message:read]")); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithInsufficientScopeThenInsufficientScopeError() - throws Exception { - + public void getWhenUsingDefaultsWithInsufficientScopeThenInsufficientScopeError() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").with(bearerToken(token))) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); + // @formatter:on } @Test - public void getWhenUsingDefaultsWithInsufficientScpThenInsufficientScopeError() - throws Exception { - + public void getWhenUsingDefaultsWithInsufficientScpThenInsufficientScopeError() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageWriteScp"); - - this.mvc.perform(get("/requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").with(bearerToken(token))) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); + // @formatter:on } @Test - public void getWhenUsingDefaultsAndAuthorizationServerHasNoMatchingKeyThenInvalidToken() - throws Exception { - + public void getWhenUsingDefaultsAndAuthorizationServerHasNoMatchingKeyThenInvalidToken() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class).autowire(); mockRestOperations(jwks("Empty")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } @Test - public void getWhenUsingDefaultsAndAuthorizationServerHasMultipleMatchingKeysThenOk() - throws Exception { - + public void getWhenUsingDefaultsAndAuthorizationServerHasMultipleMatchingKeysThenOk() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("TwoKeys")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); + // @formatter:on } @Test - public void getWhenUsingDefaultsAndKeyMatchesByKidThenOk() - throws Exception { - + public void getWhenUsingDefaultsAndKeyMatchesByKidThenOk() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("TwoKeys")); String token = this.token("Kid"); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); + // @formatter:on } - // -- Method Security - @Test - public void getWhenUsingMethodSecurityWithValidBearerTokenThenAcceptsRequest() - throws Exception { - + public void getWhenUsingMethodSecurityWithValidBearerTokenThenAcceptsRequest() throws Exception { this.spring.register(RestOperationsConfig.class, MethodSecurityConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageReadScope"); - - this.mvc.perform(get("/ms-requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/ms-requires-read-scope").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("[SCOPE_message:read]")); + // @formatter:on } @Test - public void getWhenUsingMethodSecurityWithValidBearerTokenHavingScpAttributeThenAcceptsRequest() - throws Exception { - + public void getWhenUsingMethodSecurityWithValidBearerTokenHavingScpAttributeThenAcceptsRequest() throws Exception { this.spring.register(RestOperationsConfig.class, MethodSecurityConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageReadScp"); - - this.mvc.perform(get("/ms-requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/ms-requires-read-scope").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("[SCOPE_message:read]")); + // @formatter:on } @Test - public void getWhenUsingMethodSecurityWithInsufficientScopeThenInsufficientScopeError() - throws Exception { - + public void getWhenUsingMethodSecurityWithInsufficientScopeThenInsufficientScopeError() throws Exception { this.spring.register(RestOperationsConfig.class, MethodSecurityConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/ms-requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/ms-requires-read-scope").with(bearerToken(token))) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); - + // @formatter:on } @Test - public void getWhenUsingMethodSecurityWithInsufficientScpThenInsufficientScopeError() - throws Exception { - + public void getWhenUsingMethodSecurityWithInsufficientScpThenInsufficientScopeError() throws Exception { this.spring.register(RestOperationsConfig.class, MethodSecurityConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageWriteScp"); - - this.mvc.perform(get("/ms-requires-read-scope") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/ms-requires-read-scope").with(bearerToken(token))) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); + // @formatter:on } @Test - public void getWhenUsingMethodSecurityWithDenyAllThenInsufficientScopeError() - throws Exception { - + public void getWhenUsingMethodSecurityWithDenyAllThenInsufficientScopeError() throws Exception { this.spring.register(RestOperationsConfig.class, MethodSecurityConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageReadScope"); - - this.mvc.perform(get("/ms-deny") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/ms-deny").with(bearerToken(token))) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); + // @formatter:on } - // -- Resource Server should not engage csrf - @Test - public void postWhenUsingDefaultsWithValidBearerTokenAndNoCsrfTokenThenOk() - throws Exception { - + public void postWhenUsingDefaultsWithValidBearerTokenAndNoCsrfTokenThenOk() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(post("/authenticated") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(post("/authenticated").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); + // @formatter:on } @Test - public void postWhenUsingDefaultsWithNoBearerTokenThenCsrfDenies() - throws Exception { - + public void postWhenUsingDefaultsWithNoBearerTokenThenCsrfDenies() throws Exception { this.spring.register(JwkSetUriConfig.class).autowire(); - + // @formatter:off this.mvc.perform(post("/authenticated")) .andExpect(status().isForbidden()) .andExpect(header().doesNotExist(HttpHeaders.WWW_AUTHENTICATE)); + // @formatter:on } @Test - public void postWhenUsingDefaultsWithExpiredBearerTokenAndNoCsrfThenInvalidToken() - throws Exception { - + public void postWhenUsingDefaultsWithExpiredBearerTokenAndNoCsrfThenInvalidToken() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("Expired"); - - this.mvc.perform(post("/authenticated") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(post("/authenticated").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } - // -- Resource Server should not create sessions - @Test - public void requestWhenDefaultConfiguredThenSessionIsNotCreated() - throws Exception { - + public void requestWhenDefaultConfiguredThenSessionIsNotCreated() throws Exception { this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - MvcResult result = this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + MvcResult result = this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void requestWhenIntrospectionConfiguredThenSessionIsNotCreated() - throws Exception { - + public void requestWhenIntrospectionConfiguredThenSessionIsNotCreated() throws Exception { this.spring.register(RestOperationsConfig.class, OpaqueTokenConfig.class, BasicController.class).autowire(); mockRestOperations(json("Active")); - - MvcResult result = this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) + // @formatter:off + MvcResult result = this.mvc.perform(get("/authenticated").with(bearerToken("token"))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void requestWhenUsingDefaultsAndNoBearerTokenThenSessionIsCreated() - throws Exception { - + public void requestWhenUsingDefaultsAndNoBearerTokenThenSessionIsCreated() throws Exception { this.spring.register(JwkSetUriConfig.class, BasicController.class).autowire(); - + // @formatter:off MvcResult result = this.mvc.perform(get("/")) .andExpect(status().isUnauthorized()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNotNull(); } @Test - public void requestWhenSessionManagementConfiguredThenUserConfigurationOverrides() - throws Exception { - - this.spring.register(RestOperationsConfig.class, AlwaysSessionCreationConfig.class, BasicController.class).autowire(); + public void requestWhenSessionManagementConfiguredThenUserConfigurationOverrides() throws Exception { + this.spring.register(RestOperationsConfig.class, AlwaysSessionCreationConfig.class, BasicController.class) + .autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - MvcResult result = this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + MvcResult result = this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNotNull(); } - // -- custom bearer token resolver - @Test public void requestWhenBearerTokenResolverAllowsRequestBodyThenEitherHeaderOrRequestBodyIsAccepted() throws Exception { - - this.spring.register(AllowBearerTokenInRequestBodyConfig.class, JwtDecoderConfig.class, - BasicController.class).autowire(); - + this.spring.register(AllowBearerTokenInRequestBodyConfig.class, JwtDecoderConfig.class, BasicController.class) + .autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(JWT_TOKEN))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); - - this.mvc.perform(post("/authenticated") - .param("access_token", JWT_TOKEN)) + this.mvc.perform(post("/authenticated").param("access_token", JWT_TOKEN)) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); + // @formatter:on } @Test public void requestWhenBearerTokenResolverAllowsQueryParameterThenEitherHeaderOrQueryParameterIsAccepted() throws Exception { - - this.spring.register(AllowBearerTokenAsQueryParameterConfig.class, JwtDecoderConfig.class, - BasicController.class).autowire(); - + this.spring + .register(AllowBearerTokenAsQueryParameterConfig.class, JwtDecoderConfig.class, BasicController.class) + .autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(JWT_TOKEN))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); - - this.mvc.perform(get("/authenticated") - .param("access_token", JWT_TOKEN)) + this.mvc.perform(get("/authenticated").param("access_token", JWT_TOKEN)) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); + // @formatter:on } @Test public void requestWhenBearerTokenResolverAllowsRequestBodyAndRequestContainsTwoTokensThenInvalidRequest() throws Exception { - - this.spring.register(AllowBearerTokenInRequestBodyConfig.class, JwtDecoderConfig.class, - BasicController.class).autowire(); - + this.spring.register(AllowBearerTokenInRequestBodyConfig.class, JwtDecoderConfig.class, BasicController.class) + .autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(post("/authenticated") + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + MockHttpServletRequestBuilder request = post("/authenticated") .param("access_token", JWT_TOKEN) .with(bearerToken(JWT_TOKEN)) - .with(csrf())) + .with(csrf()); + this.mvc.perform(request) .andExpect(status().isBadRequest()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("invalid_request"))); + // @formatter:on } @Test public void requestWhenBearerTokenResolverAllowsQueryParameterAndRequestContainsTwoTokensThenInvalidRequest() throws Exception { - - this.spring.register(AllowBearerTokenAsQueryParameterConfig.class, JwtDecoderConfig.class, - BasicController.class).autowire(); - + this.spring + .register(AllowBearerTokenAsQueryParameterConfig.class, JwtDecoderConfig.class, BasicController.class) + .autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + MockHttpServletRequestBuilder request = get("/authenticated") .with(bearerToken(JWT_TOKEN)) - .param("access_token", JWT_TOKEN)) + .param("access_token", JWT_TOKEN); + this.mvc.perform(request) .andExpect(status().isBadRequest()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("invalid_request"))); + // @formatter:on } @Test public void getBearerTokenResolverWhenDuplicateResolverBeansAndAnotherOnTheDslThenTheDslOneIsUsed() { BearerTokenResolver resolverBean = mock(BearerTokenResolver.class); BearerTokenResolver resolver = mock(BearerTokenResolver.class); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean("resolverOne", BearerTokenResolver.class, () -> resolverBean); context.registerBean("resolverTwo", BearerTokenResolver.class, () -> resolverBean); this.spring.context(context).autowire(); - OAuth2ResourceServerConfigurer oauth2 = new OAuth2ResourceServerConfigurer(context); - oauth2.bearerTokenResolver(resolver); - assertThat(oauth2.getBearerTokenResolver()).isEqualTo(resolver); } @Test public void getBearerTokenResolverWhenDuplicateResolverBeansThenWiringException() { - assertThatCode(() -> this.spring - .register(MultipleBearerTokenResolverBeansConfig.class, JwtDecoderConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring + .register(MultipleBearerTokenResolverBeansConfig.class, JwtDecoderConfig.class).autowire()) + .withRootCauseInstanceOf(NoUniqueBeanDefinitionException.class); } @Test public void getBearerTokenResolverWhenResolverBeanAndAnotherOnTheDslThenTheDslOneIsUsed() { BearerTokenResolver resolver = mock(BearerTokenResolver.class); BearerTokenResolver resolverBean = mock(BearerTokenResolver.class); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean(BearerTokenResolver.class, () -> resolverBean); this.spring.context(context).autowire(); - OAuth2ResourceServerConfigurer oauth2 = new OAuth2ResourceServerConfigurer(context); oauth2.bearerTokenResolver(resolver); - assertThat(oauth2.getBearerTokenResolver()).isEqualTo(resolver); } @Test public void getBearerTokenResolverWhenNoResolverSpecifiedThenTheDefaultIsUsed() { - ApplicationContext context = - this.spring.context(new GenericWebApplicationContext()).getContext(); - + ApplicationContext context = this.spring.context(new GenericWebApplicationContext()).getContext(); OAuth2ResourceServerConfigurer oauth2 = new OAuth2ResourceServerConfigurer(context); - assertThat(oauth2.getBearerTokenResolver()).isInstanceOf(DefaultBearerTokenResolver.class); } - // -- custom jwt decoder - @Test - public void requestWhenCustomJwtDecoderWiredOnDslThenUsed() - throws Exception { - + public void requestWhenCustomJwtDecoderWiredOnDslThenUsed() throws Exception { this.spring.register(CustomJwtDecoderOnDsl.class, BasicController.class).autowire(); - CustomJwtDecoderOnDsl config = this.spring.getContext().getBean(CustomJwtDecoderOnDsl.class); JwtDecoder decoder = config.decoder(); - - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(JWT_TOKEN))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); + // @formatter:on } @Test - public void requestWhenCustomJwtDecoderInLambdaOnDslThenUsed() - throws Exception { - + public void requestWhenCustomJwtDecoderInLambdaOnDslThenUsed() throws Exception { this.spring.register(CustomJwtDecoderInLambdaOnDsl.class, BasicController.class).autowire(); - CustomJwtDecoderInLambdaOnDsl config = this.spring.getContext().getBean(CustomJwtDecoderInLambdaOnDsl.class); JwtDecoder decoder = config.decoder(); - - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(JWT_TOKEN))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); + // @formatter:on } @Test - public void requestWhenCustomJwtDecoderExposedAsBeanThenUsed() - throws Exception { - + public void requestWhenCustomJwtDecoderExposedAsBeanThenUsed() throws Exception { this.spring.register(CustomJwtDecoderAsBean.class, BasicController.class).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(JWT_TOKEN))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()) .andExpect(content().string(JWT_SUBJECT)); + // @formatter:on } @Test public void getJwtDecoderWhenConfiguredWithDecoderAndJwkSetUriThenLastOneWins() { ApplicationContext context = mock(ApplicationContext.class); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); JwtDecoder decoder = mock(JwtDecoder.class); - jwtConfigurer.jwkSetUri(JWK_SET_URI); jwtConfigurer.decoder(decoder); - assertThat(jwtConfigurer.getJwtDecoder()).isEqualTo(decoder); - - jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - + jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); jwtConfigurer.decoder(decoder); jwtConfigurer.jwkSetUri(JWK_SET_URI); - assertThat(jwtConfigurer.getJwtDecoder()).isInstanceOf(NimbusJwtDecoder.class); - } @Test public void getJwtDecoderWhenConflictingJwtDecodersThenTheDslWiredOneTakesPrecedence() { - JwtDecoder decoderBean = mock(JwtDecoder.class); JwtDecoder decoder = mock(JwtDecoder.class); - ApplicationContext context = mock(ApplicationContext.class); - when(context.getBean(JwtDecoder.class)).thenReturn(decoderBean); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); + given(context.getBean(JwtDecoder.class)).willReturn(decoderBean); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); jwtConfigurer.decoder(decoder); - assertThat(jwtConfigurer.getJwtDecoder()).isEqualTo(decoder); } @Test public void getJwtDecoderWhenContextHasBeanAndUserConfiguresJwkSetUriThenJwkSetUriTakesPrecedence() { - JwtDecoder decoder = mock(JwtDecoder.class); ApplicationContext context = mock(ApplicationContext.class); - when(context.getBean(JwtDecoder.class)).thenReturn(decoder); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - + given(context.getBean(JwtDecoder.class)).willReturn(decoder); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); jwtConfigurer.jwkSetUri(JWK_SET_URI); - assertThat(jwtConfigurer.getJwtDecoder()).isNotEqualTo(decoder); assertThat(jwtConfigurer.getJwtDecoder()).isInstanceOf(NimbusJwtDecoder.class); } @Test public void getJwtDecoderWhenTwoJwtDecoderBeansAndAnotherWiredOnDslThenDslWiredOneTakesPrecedence() { - JwtDecoder decoderBean = mock(JwtDecoder.class); JwtDecoder decoder = mock(JwtDecoder.class); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean("decoderOne", JwtDecoder.class, () -> decoderBean); context.registerBean("decoderTwo", JwtDecoder.class, () -> decoderBean); this.spring.context(context).autowire(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); jwtConfigurer.decoder(decoder); - assertThat(jwtConfigurer.getJwtDecoder()).isEqualTo(decoder); } @Test public void getJwtDecoderWhenTwoJwtDecoderBeansThenThrowsException() { - JwtDecoder decoder = mock(JwtDecoder.class); GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean("decoderOne", JwtDecoder.class, () -> decoder); context.registerBean("decoderTwo", JwtDecoder.class, () -> decoder); - this.spring.context(context).autowire(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - - assertThatCode(() -> jwtConfigurer.getJwtDecoder()) - .isInstanceOf(NoUniqueBeanDefinitionException.class); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); + assertThatExceptionOfType(NoUniqueBeanDefinitionException.class) + .isThrownBy(() -> jwtConfigurer.getJwtDecoder()); } - // -- exception handling - @Test - public void requestWhenRealmNameConfiguredThenUsesOnUnauthenticated() - throws Exception { - + public void requestWhenRealmNameConfiguredThenUsesOnUnauthenticated() throws Exception { this.spring.register(RealmNameConfiguredOnEntryPoint.class, JwtDecoderConfig.class).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenThrow(JwtException.class); - - this.mvc.perform(get("/authenticated") - .with(bearerToken("invalid_token"))) + given(decoder.decode(anyString())).willThrow(JwtException.class); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("invalid_token"))) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer realm=\"myRealm\""))); + // @formatter:on } @Test - public void requestWhenRealmNameConfiguredThenUsesOnAccessDenied() - throws Exception { - + public void requestWhenRealmNameConfiguredThenUsesOnAccessDenied() throws Exception { this.spring.register(RealmNameConfiguredOnAccessDeniedHandler.class, JwtDecoderConfig.class).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(bearerToken("insufficiently_scoped"))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("insufficiently_scoped"))) .andExpect(status().isForbidden()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer realm=\"myRealm\""))); + // @formatter:on } @Test public void authenticationEntryPointWhenGivenNullThenThrowsException() { ApplicationContext context = mock(ApplicationContext.class); OAuth2ResourceServerConfigurer configurer = new OAuth2ResourceServerConfigurer(context); - assertThatCode(() -> configurer.authenticationEntryPoint(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> configurer.authenticationEntryPoint(null)); } @Test public void accessDeniedHandlerWhenGivenNullThenThrowsException() { ApplicationContext context = mock(ApplicationContext.class); OAuth2ResourceServerConfigurer configurer = new OAuth2ResourceServerConfigurer(context); - assertThatCode(() -> configurer.accessDeniedHandler(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> configurer.accessDeniedHandler(null)); } - // -- token validator - @Test - public void requestWhenCustomJwtValidatorFailsThenCorrespondingErrorMessage() - throws Exception { - + public void requestWhenCustomJwtValidatorFailsThenCorrespondingErrorMessage() throws Exception { this.spring.register(RestOperationsConfig.class, CustomJwtValidatorConfig.class).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - OAuth2TokenValidator jwtValidator = - this.spring.getContext().getBean(CustomJwtValidatorConfig.class) - .getJwtValidator(); - + OAuth2TokenValidator jwtValidator = this.spring.getContext().getBean(CustomJwtValidatorConfig.class) + .getJwtValidator(); OAuth2Error error = new OAuth2Error("custom-error", "custom-description", "custom-uri"); - - when(jwtValidator.validate(any(Jwt.class))).thenReturn(OAuth2TokenValidatorResult.failure(error)); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + given(jwtValidator.validate(any(Jwt.class))).willReturn(OAuth2TokenValidatorResult.failure(error)); + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("custom-description"))); + // @formatter:on } @Test - public void requestWhenClockSkewSetThenTimestampWindowRelaxedAccordingly() - throws Exception { - - this.spring.register(RestOperationsConfig.class, UnexpiredJwtClockSkewConfig.class, BasicController.class).autowire(); + public void requestWhenClockSkewSetThenTimestampWindowRelaxedAccordingly() throws Exception { + this.spring.register(RestOperationsConfig.class, UnexpiredJwtClockSkewConfig.class, BasicController.class) + .autowire(); mockRestOperations(jwks("Default")); String token = this.token("ExpiresAt4687177990"); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()); + // @formatter:on } @Test - public void requestWhenClockSkewSetButJwtStillTooLateThenReportsExpired() - throws Exception { - - this.spring.register(RestOperationsConfig.class, ExpiredJwtClockSkewConfig.class, BasicController.class).autowire(); + public void requestWhenClockSkewSetButJwtStillTooLateThenReportsExpired() throws Exception { + this.spring.register(RestOperationsConfig.class, ExpiredJwtClockSkewConfig.class, BasicController.class) + .autowire(); mockRestOperations(jwks("Default")); String token = this.token("ExpiresAt4687177990"); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Jwt expired at")); + // @formatter:on } - // -- converter - @Test - public void requestWhenJwtAuthenticationConverterConfiguredOnDslThenIsUsed() - throws Exception { - + public void requestWhenJwtAuthenticationConverterConfiguredOnDslThenIsUsed() throws Exception { this.spring.register(JwtDecoderConfig.class, JwtAuthenticationConverterConfiguredOnDsl.class, BasicController.class).autowire(); - - Converter jwtAuthenticationConverter = - this.spring.getContext().getBean(JwtAuthenticationConverterConfiguredOnDsl.class) - .getJwtAuthenticationConverter(); - when(jwtAuthenticationConverter.convert(JWT)).thenReturn(JWT_AUTHENTICATION_TOKEN); - + Converter jwtAuthenticationConverter = this.spring.getContext() + .getBean(JwtAuthenticationConverterConfiguredOnDsl.class).getJwtAuthenticationConverter(); + given(jwtAuthenticationConverter.convert(JWT)).willReturn(JWT_AUTHENTICATION_TOKEN); JwtDecoder jwtDecoder = this.spring.getContext().getBean(JwtDecoder.class); - when(jwtDecoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/") - .with(bearerToken(JWT_TOKEN))) + given(jwtDecoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()); - + // @formatter:on verify(jwtAuthenticationConverter).convert(JWT); } @Test public void requestWhenJwtAuthenticationConverterCustomizedAuthoritiesThenThoseAuthoritiesArePropagated() throws Exception { - this.spring.register(JwtDecoderConfig.class, CustomAuthorityMappingConfig.class, BasicController.class) .autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(JWT_TOKEN)).thenReturn(JWT); - - this.mvc.perform(get("/requires-read-scope") - .with(bearerToken(JWT_TOKEN))) + given(decoder.decode(JWT_TOKEN)).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/requires-read-scope").with(bearerToken(JWT_TOKEN))) .andExpect(status().isOk()); + // @formatter:on } - // -- single key - @Test - public void requestWhenUsingPublicKeyAndValidTokenThenAuthenticates() - throws Exception { - + public void requestWhenUsingPublicKeyAndValidTokenThenAuthenticates() throws Exception { this.spring.register(SingleKeyConfig.class, BasicController.class).autowire(); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isOk()); + // @formatter:on } @Test - public void requestWhenUsingPublicKeyAndSignatureFailsThenReturnsInvalidToken() - throws Exception { - + public void requestWhenUsingPublicKeyAndSignatureFailsThenReturnsInvalidToken() throws Exception { this.spring.register(SingleKeyConfig.class).autowire(); String token = this.token("WrongSignature"); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(invalidTokenHeader("signature")); + // @formatter:on } @Test - public void requestWhenUsingPublicKeyAlgorithmDoesNotMatchThenReturnsInvalidToken() - throws Exception { - + public void requestWhenUsingPublicKeyAlgorithmDoesNotMatchThenReturnsInvalidToken() throws Exception { this.spring.register(SingleKeyConfig.class).autowire(); String token = this.token("WrongAlgorithm"); - - this.mvc.perform(get("/") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(invalidTokenHeader("algorithm")); + // @formatter:on } // gh-7793 @Test - public void requestWhenUsingCustomAuthenticationEventPublisherThenUses() throws Exception{ + public void requestWhenUsingCustomAuthenticationEventPublisherThenUses() throws Exception { this.spring.register(CustomAuthenticationEventPublisher.class).autowire(); - - when(bean(JwtDecoder.class).decode(anyString())) - .thenThrow(new BadJwtException("problem")); - + given(bean(JwtDecoder.class).decode(anyString())).willThrow(new BadJwtException("problem")); this.mvc.perform(get("/").with(bearerToken("token"))); - verifyBean(AuthenticationEventPublisher.class) - .publishAuthenticationFailure( - any(OAuth2AuthenticationException.class), - any(Authentication.class)); + .publishAuthenticationFailure(any(OAuth2AuthenticationException.class), any(Authentication.class)); } @Test public void getWhenCustomJwtAuthenticationManagerThenUsed() throws Exception { this.spring.register(JwtAuthenticationManagerConfig.class, BasicController.class).autowire(); - - when(bean(AuthenticationProvider.class).authenticate(any(Authentication.class))) - .thenReturn(JWT_AUTHENTICATION_TOKEN); - this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) + given(bean(AuthenticationProvider.class).authenticate(any(Authentication.class))) + .willReturn(JWT_AUTHENTICATION_TOKEN); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("token"))) .andExpect(status().isOk()) .andExpect(content().string("mock-test-subject")); - + // @formatter:on verifyBean(AuthenticationProvider.class).authenticate(any(Authentication.class)); } - // -- opaque - - @Test public void getWhenIntrospectingThenOk() throws Exception { this.spring.register(RestOperationsConfig.class, OpaqueTokenConfig.class, BasicController.class).autowire(); mockRestOperations(json("Active")); - - this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("token"))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); + // @formatter:on } @Test public void getWhenOpaqueTokenInLambdaAndIntrospectingThenOk() throws Exception { - this.spring.register(RestOperationsConfig.class, OpaqueTokenInLambdaConfig.class, BasicController.class).autowire(); + this.spring.register(RestOperationsConfig.class, OpaqueTokenInLambdaConfig.class, BasicController.class) + .autowire(); mockRestOperations(json("Active")); - - this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("token"))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); + // @formatter:on } @Test public void getWhenIntrospectionFailsThenUnauthorized() throws Exception { this.spring.register(RestOperationsConfig.class, OpaqueTokenConfig.class).autowire(); mockRestOperations(json("Inactive")); - - this.mvc.perform(get("/") - .with(bearerToken("token"))) + // @formatter:off + this.mvc.perform(get("/").with(bearerToken("token"))) .andExpect(status().isUnauthorized()) - .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, - containsString("Provided token isn't active"))); + .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("Provided token isn't active"))); + // @formatter:on } @Test public void getWhenIntrospectionLacksScopeThenForbidden() throws Exception { this.spring.register(RestOperationsConfig.class, OpaqueTokenConfig.class).autowire(); mockRestOperations(json("ActiveNoScopes")); - - this.mvc.perform(get("/requires-read-scope") - .with(bearerToken("token"))) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").with(bearerToken("token"))) .andExpect(status().isForbidden()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("scope"))); + // @formatter:on } @Test public void getWhenCustomIntrospectionAuthenticationManagerThenUsed() throws Exception { this.spring.register(OpaqueTokenAuthenticationManagerConfig.class, BasicController.class).autowire(); - - when(bean(AuthenticationProvider.class).authenticate(any(Authentication.class))) - .thenReturn(INTROSPECTION_AUTHENTICATION_TOKEN); - this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) + given(bean(AuthenticationProvider.class).authenticate(any(Authentication.class))) + .willReturn(INTROSPECTION_AUTHENTICATION_TOKEN); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("token"))) .andExpect(status().isOk()) .andExpect(content().string("mock-test-subject")); - + // @formatter:on verifyBean(AuthenticationProvider.class).authenticate(any(Authentication.class)); } @Test public void getWhenCustomIntrospectionAuthenticationManagerInLambdaThenUsed() throws Exception { this.spring.register(OpaqueTokenAuthenticationManagerInLambdaConfig.class, BasicController.class).autowire(); - - when(bean(AuthenticationProvider.class).authenticate(any(Authentication.class))) - .thenReturn(INTROSPECTION_AUTHENTICATION_TOKEN); - this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) + given(bean(AuthenticationProvider.class).authenticate(any(Authentication.class))) + .willReturn(INTROSPECTION_AUTHENTICATION_TOKEN); + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken("token"))) .andExpect(status().isOk()) .andExpect(content().string("mock-test-subject")); - + // @formatter:on verifyBean(AuthenticationProvider.class).authenticate(any(Authentication.class)); } @Test public void configureWhenOnlyIntrospectionUrlThenException() { - assertThatCode(() -> this.spring.register(OpaqueTokenHalfConfiguredConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(OpaqueTokenHalfConfiguredConfig.class).autowire()); } @Test public void getIntrospectionClientWhenConfiguredWithClientAndIntrospectionUriThenLastOneWins() { ApplicationContext context = mock(ApplicationContext.class); - - OAuth2ResourceServerConfigurer.OpaqueTokenConfigurer opaqueTokenConfigurer = - new OAuth2ResourceServerConfigurer(context).opaqueToken(); - + OAuth2ResourceServerConfigurer.OpaqueTokenConfigurer opaqueTokenConfigurer = new OAuth2ResourceServerConfigurer( + context).opaqueToken(); OpaqueTokenIntrospector client = mock(OpaqueTokenIntrospector.class); - opaqueTokenConfigurer.introspectionUri(INTROSPECTION_URI); opaqueTokenConfigurer.introspectionClientCredentials(CLIENT_ID, CLIENT_SECRET); opaqueTokenConfigurer.introspector(client); - assertThat(opaqueTokenConfigurer.getIntrospector()).isEqualTo(client); - - opaqueTokenConfigurer = - new OAuth2ResourceServerConfigurer(context).opaqueToken(); - + opaqueTokenConfigurer = new OAuth2ResourceServerConfigurer(context).opaqueToken(); opaqueTokenConfigurer.introspector(client); opaqueTokenConfigurer.introspectionUri(INTROSPECTION_URI); opaqueTokenConfigurer.introspectionClientCredentials(CLIENT_ID, CLIENT_SECRET); - - assertThat(opaqueTokenConfigurer.getIntrospector()) - .isInstanceOf(NimbusOpaqueTokenIntrospector.class); - + assertThat(opaqueTokenConfigurer.getIntrospector()).isInstanceOf(NimbusOpaqueTokenIntrospector.class); } @Test @@ -1252,141 +1077,104 @@ public class OAuth2ResourceServerConfigurerTests { GenericApplicationContext context = new GenericApplicationContext(); registerMockBean(context, "introspectionClientOne", OpaqueTokenIntrospector.class); registerMockBean(context, "introspectionClientTwo", OpaqueTokenIntrospector.class); - - OAuth2ResourceServerConfigurer.OpaqueTokenConfigurer opaqueToken = - new OAuth2ResourceServerConfigurer(context).opaqueToken(); + OAuth2ResourceServerConfigurer.OpaqueTokenConfigurer opaqueToken = new OAuth2ResourceServerConfigurer(context) + .opaqueToken(); opaqueToken.introspectionUri(INTROSPECTION_URI); opaqueToken.introspectionClientCredentials(CLIENT_ID, CLIENT_SECRET); - assertThat(opaqueToken.getIntrospector()).isNotNull(); } - // -- In combination with other authentication providers - @Test - public void requestWhenBasicAndResourceServerEntryPointsThenMatchedByRequest() - throws Exception { - + public void requestWhenBasicAndResourceServerEntryPointsThenMatchedByRequest() throws Exception { this.spring.register(BasicAndResourceServerConfig.class, JwtDecoderConfig.class).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenThrow(JwtException.class); - - this.mvc.perform(get("/authenticated") - .with(httpBasic("some", "user"))) + given(decoder.decode(anyString())).willThrow(JwtException.class); + // @formatter:off + this.mvc.perform(get("/authenticated").with(httpBasic("some", "user"))) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Basic"))); - this.mvc.perform(get("/authenticated")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Basic"))); - - this.mvc.perform(get("/authenticated") - .with(bearerToken("invalid_token"))) + this.mvc.perform(get("/authenticated").with(bearerToken("invalid_token"))) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer"))); + // @formatter:on } @Test - public void requestWhenFormLoginAndResourceServerEntryPointsThenSessionCreatedByRequest() - throws Exception { - + public void requestWhenFormLoginAndResourceServerEntryPointsThenSessionCreatedByRequest() throws Exception { this.spring.register(FormAndResourceServerConfig.class, JwtDecoderConfig.class).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenThrow(JwtException.class); - - MvcResult result = - this.mvc.perform(get("/authenticated")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn(); - + given(decoder.decode(anyString())).willThrow(JwtException.class); + // @formatter:off + MvcResult result = this.mvc.perform(get("/authenticated")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/login")) + .andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false)).isNotNull(); - - result = - this.mvc.perform(get("/authenticated") - .with(bearerToken("token"))) - .andExpect(status().isUnauthorized()) - .andReturn(); - + // @formatter:off + result = this.mvc.perform(get("/authenticated").with(bearerToken("token"))) + .andExpect(status().isUnauthorized()) + .andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void requestWhenDefaultAndResourceServerAccessDeniedHandlersThenMatchedByRequest() - throws Exception { - - this.spring.register(ExceptionHandlingAndResourceServerWithAccessDeniedHandlerConfig.class, - JwtDecoderConfig.class).autowire(); - + public void requestWhenDefaultAndResourceServerAccessDeniedHandlersThenMatchedByRequest() throws Exception { + this.spring + .register(ExceptionHandlingAndResourceServerWithAccessDeniedHandlerConfig.class, JwtDecoderConfig.class) + .autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(JWT); - - this.mvc.perform(get("/authenticated") - .with(httpBasic("basic-user", "basic-password"))) + given(decoder.decode(anyString())).willReturn(JWT); + // @formatter:off + this.mvc.perform(get("/authenticated").with(httpBasic("basic-user", "basic-password"))) .andExpect(status().isForbidden()) .andExpect(header().doesNotExist(HttpHeaders.WWW_AUTHENTICATE)); - - this.mvc.perform(get("/authenticated") - .with(bearerToken("insufficiently_scoped"))) + this.mvc.perform(get("/authenticated").with(bearerToken("insufficiently_scoped"))) .andExpect(status().isForbidden()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer"))); + // @formatter:on } @Test - public void getWhenAlsoUsingHttpBasicThenCorrectProviderEngages() - throws Exception { - - this.spring.register(RestOperationsConfig.class, BasicAndResourceServerConfig.class, BasicController.class).autowire(); + public void getWhenAlsoUsingHttpBasicThenCorrectProviderEngages() throws Exception { + this.spring.register(RestOperationsConfig.class, BasicAndResourceServerConfig.class, BasicController.class) + .autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(token))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(token))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); - - this.mvc.perform(get("/authenticated") - .with(httpBasic("basic-user", "basic-password"))) + this.mvc.perform(get("/authenticated").with(httpBasic("basic-user", "basic-password"))) .andExpect(status().isOk()) .andExpect(content().string("basic-user")); + // @formatter:on } - // -- authentication manager - @Test public void getAuthenticationManagerWhenConfiguredAuthenticationManagerThenTakesPrecedence() { ApplicationContext context = mock(ApplicationContext.class); HttpSecurityBuilder http = mock(HttpSecurityBuilder.class); - OAuth2ResourceServerConfigurer oauth2ResourceServer = new OAuth2ResourceServerConfigurer(context); AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - oauth2ResourceServer - .jwt() - .authenticationManager(authenticationManager) - .decoder(mock(JwtDecoder.class)); + oauth2ResourceServer.jwt().authenticationManager(authenticationManager).decoder(mock(JwtDecoder.class)); assertThat(oauth2ResourceServer.getAuthenticationManager(http)).isSameAs(authenticationManager); - oauth2ResourceServer = new OAuth2ResourceServerConfigurer(context); - oauth2ResourceServer - .opaqueToken() - .authenticationManager(authenticationManager) + oauth2ResourceServer.opaqueToken().authenticationManager(authenticationManager) .introspector(mock(OpaqueTokenIntrospector.class)); assertThat(oauth2ResourceServer.getAuthenticationManager(http)).isSameAs(authenticationManager); verify(http, never()).authenticationProvider(any(AuthenticationProvider.class)); } - // -- authentication manager resolver - @Test public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Exception { this.spring.register(WebServerConfig.class, MultipleIssuersConfig.class, BasicController.class).autowire(); - MockWebServer server = this.spring.getContext().getBean(MockWebServer.class); - String metadata = "{\n" - + " \"issuer\": \"%s\", \n" - + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" + String metadata = "{\n" + " \"issuer\": \"%s\", \n" + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" + "}"; String jwkSet = jwkSet(); String issuerOne = server.url("/issuerOne").toString(); @@ -1395,86 +1183,71 @@ public class OAuth2ResourceServerConfigurerTests { String jwtOne = jwtFromIssuer(issuerOne); String jwtTwo = jwtFromIssuer(issuerTwo); String jwtThree = jwtFromIssuer(issuerThree); - mockWebServer(String.format(metadata, issuerOne, issuerOne)); mockWebServer(jwkSet); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(jwtOne))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(jwtOne))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); - + // @formatter:on mockWebServer(String.format(metadata, issuerTwo, issuerTwo)); mockWebServer(jwkSet); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(jwtTwo))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(jwtTwo))) .andExpect(status().isOk()) .andExpect(content().string("test-subject")); - + // @formatter:on mockWebServer(String.format(metadata, issuerThree, issuerThree)); mockWebServer(jwkSet); - - this.mvc.perform(get("/authenticated") - .with(bearerToken(jwtThree))) + // @formatter:off + this.mvc.perform(get("/authenticated").with(bearerToken(jwtThree))) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Invalid issuer")); + // @formatter:on } - // -- Incorrect Configuration - @Test public void configuredWhenMissingJwtAuthenticationProviderThenWiringException() { - - assertThatCode(() -> this.spring.register(JwtlessConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("neither was found"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(JwtlessConfig.class).autowire()) + .withMessageContaining("neither was found"); } @Test public void configureWhenMissingJwkSetUriThenWiringException() { - - assertThatCode(() -> this.spring.register(JwtHalfConfiguredConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("No qualifying bean of type"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(JwtHalfConfiguredConfig.class).autowire()) + .withMessageContaining("No qualifying bean of type"); } @Test public void configureWhenUsingBothJwtAndOpaqueThenWiringException() { - assertThatCode(() -> this.spring.register(OpaqueAndJwtConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("Spring Security only supports JWTs or Opaque Tokens"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(OpaqueAndJwtConfig.class).autowire()) + .withMessageContaining("Spring Security only supports JWTs or Opaque Tokens"); } @Test public void configureWhenUsingBothAuthenticationManagerResolverAndOpaqueThenWiringException() { - assertThatCode(() -> this.spring.register(AuthenticationManagerResolverPlusOtherConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("authenticationManagerResolver"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(AuthenticationManagerResolverPlusOtherConfig.class).autowire()) + .withMessageContaining("authenticationManagerResolver"); } @Test public void getJwtAuthenticationConverterWhenNoConverterSpecifiedThenTheDefaultIsUsed() { - ApplicationContext context = - this.spring.context(new GenericWebApplicationContext()).getContext(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - + ApplicationContext context = this.spring.context(new GenericWebApplicationContext()).getContext(); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); assertThat(jwtConfigurer.getJwtAuthenticationConverter()).isInstanceOf(JwtAuthenticationConverter.class); } @Test public void getJwtAuthenticationConverterWhenConverterBeanSpecified() { JwtAuthenticationConverter converterBean = new JwtAuthenticationConverter(); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean(JwtAuthenticationConverter.class, () -> converterBean); this.spring.context(context).autowire(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); assertThat(jwtConfigurer.getJwtAuthenticationConverter()).isEqualTo(converterBean); } @@ -1482,15 +1255,11 @@ public class OAuth2ResourceServerConfigurerTests { public void getJwtAuthenticationConverterWhenConverterBeanAndAnotherOnTheDslThenTheDslOneIsUsed() { JwtAuthenticationConverter converter = new JwtAuthenticationConverter(); JwtAuthenticationConverter converterBean = new JwtAuthenticationConverter(); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean(JwtAuthenticationConverter.class, () -> converterBean); this.spring.context(context).autowire(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); jwtConfigurer.jwtAuthenticationConverter(converter); - assertThat(jwtConfigurer.getJwtAuthenticationConverter()).isEqualTo(converter); } @@ -1498,1011 +1267,73 @@ public class OAuth2ResourceServerConfigurerTests { public void getJwtAuthenticationConverterWhenDuplicateConverterBeansAndAnotherOnTheDslThenTheDslOneIsUsed() { JwtAuthenticationConverter converter = new JwtAuthenticationConverter(); JwtAuthenticationConverter converterBean = new JwtAuthenticationConverter(); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean("converterOne", JwtAuthenticationConverter.class, () -> converterBean); context.registerBean("converterTwo", JwtAuthenticationConverter.class, () -> converterBean); this.spring.context(context).autowire(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); jwtConfigurer.jwtAuthenticationConverter(converter); - assertThat(jwtConfigurer.getJwtAuthenticationConverter()).isEqualTo(converter); } @Test public void getJwtAuthenticationConverterWhenDuplicateConverterBeansThenThrowsException() { JwtAuthenticationConverter converterBean = new JwtAuthenticationConverter(); - GenericWebApplicationContext context = new GenericWebApplicationContext(); context.registerBean("converterOne", JwtAuthenticationConverter.class, () -> converterBean); context.registerBean("converterTwo", JwtAuthenticationConverter.class, () -> converterBean); this.spring.context(context).autowire(); - - OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = - new OAuth2ResourceServerConfigurer(context).jwt(); - - assertThatCode(jwtConfigurer::getJwtAuthenticationConverter) - .isInstanceOf(NoUniqueBeanDefinitionException.class); - } - - // -- support - - @EnableWebSecurity - static class DefaultConfig extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - } - - @EnableWebSecurity - static class DefaultInLambdaConfig extends WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") - .anyRequest().authenticated() - ) - .oauth2ResourceServer(oauth2ResourceServer -> - oauth2ResourceServer - .jwt(withDefaults()) - ); - // @formatter:on - } - } - - @EnableWebSecurity - static class JwkSetUriConfig extends WebSecurityConfigurerAdapter { - @Value("${mockwebserver.url:https://example.org}") - String jwkSetUri; - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt() - .jwkSetUri(this.jwkSetUri); - // @formatter:on - } - } - - @EnableWebSecurity - static class JwkSetUriInLambdaConfig extends WebSecurityConfigurerAdapter { - @Value("${mockwebserver.url:https://example.org}") - String jwkSetUri; - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") - .anyRequest().authenticated() - ) - .oauth2ResourceServer(oauth2ResourceServer -> - oauth2ResourceServer - .jwt(jwt -> - jwt - .jwkSetUri(this.jwkSetUri) - ) - ); - // @formatter:on - } - } - - @EnableWebSecurity - static class CsrfDisabledConfig extends WebSecurityConfigurerAdapter { - @Value("${mockwebserver.url:https://example.org}") - String jwkSetUri; - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") - .anyRequest().authenticated() - .and() - .csrf().disable() - .oauth2ResourceServer() - .jwt() - .jwkSetUri(this.jwkSetUri); - // @formatter:on - } - } - - @EnableWebSecurity - static class AnonymousDisabledConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .anonymous().disable() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - } - - @EnableWebSecurity - @EnableGlobalMethodSecurity(prePostEnabled = true) - static class MethodSecurityConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - } - - @EnableWebSecurity - static class JwtlessConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer(); - // @formatter:on - } - } - - @EnableWebSecurity - static class RealmNameConfiguredOnEntryPoint extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .authenticationEntryPoint(authenticationEntryPoint()) - .jwt(); - // @formatter:on - } - - AuthenticationEntryPoint authenticationEntryPoint() { - BearerTokenAuthenticationEntryPoint entryPoint = - new BearerTokenAuthenticationEntryPoint(); - entryPoint.setRealmName("myRealm"); - return entryPoint; - } - } - - @EnableWebSecurity - static class RealmNameConfiguredOnAccessDeniedHandler extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().denyAll() - .and() - .oauth2ResourceServer() - .accessDeniedHandler(accessDeniedHandler()) - .jwt(); - // @formatter:on - } - - AccessDeniedHandler accessDeniedHandler() { - BearerTokenAccessDeniedHandler accessDeniedHandler = - new BearerTokenAccessDeniedHandler(); - accessDeniedHandler.setRealmName("myRealm"); - return accessDeniedHandler; - } - } - - @EnableWebSecurity - static class ExceptionHandlingAndResourceServerWithAccessDeniedHandlerConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().denyAll() - .and() - .exceptionHandling() - .defaultAccessDeniedHandlerFor(new AccessDeniedHandlerImpl(), request -> false) - .and() - .httpBasic() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - public UserDetailsService userDetailsService() { - return new InMemoryUserDetailsManager( - org.springframework.security.core.userdetails.User.withDefaultPasswordEncoder() - .username("basic-user") - .password("basic-password") - .roles("USER") - .build()); - } - } - - @EnableWebSecurity - static class JwtAuthenticationConverterConfiguredOnDsl extends WebSecurityConfigurerAdapter { - private final Converter jwtAuthenticationConverter = mock(Converter.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt() - .jwtAuthenticationConverter(getJwtAuthenticationConverter()); - - // @formatter:on - } - - Converter getJwtAuthenticationConverter() { - return this.jwtAuthenticationConverter; - } - } - - @EnableWebSecurity - static class CustomAuthorityMappingConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - - http - .authorizeRequests() - .antMatchers("/requires-read-scope").access("hasAuthority('message:read')") - .and() - .oauth2ResourceServer() - .jwt() - .jwtAuthenticationConverter(getJwtAuthenticationConverter()); - - // @formatter:on - } - - Converter getJwtAuthenticationConverter() { - JwtAuthenticationConverter converter = new JwtAuthenticationConverter(); - converter.setJwtGrantedAuthoritiesConverter(jwt -> - Collections.singletonList(new SimpleGrantedAuthority("message:read")) - ); - return converter; - } - } - - @EnableWebSecurity - static class BasicAndResourceServerConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .httpBasic() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - public UserDetailsService userDetailsService() { - return new InMemoryUserDetailsManager( - org.springframework.security.core.userdetails.User.withDefaultPasswordEncoder() - .username("basic-user") - .password("basic-password") - .roles("USER") - .build()); - } - } - - @EnableWebSecurity - static class FormAndResourceServerConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .formLogin() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - } - - @EnableWebSecurity - static class JwtHalfConfiguredConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); // missing key configuration, e.g. jwkSetUri - // @formatter:on - } - } - - @EnableWebSecurity - static class AlwaysSessionCreationConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .sessionManagement() - .sessionCreationPolicy(SessionCreationPolicy.ALWAYS) - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - } - - @EnableWebSecurity - static class AllowBearerTokenInRequestBodyConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .bearerTokenResolver(allowRequestBody()) - .jwt(); - // @formatter:on - } - - private BearerTokenResolver allowRequestBody() { - DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); - resolver.setAllowFormEncodedBodyParameter(true); - return resolver; - } - } - - @EnableWebSecurity - static class AllowBearerTokenAsQueryParameterConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - BearerTokenResolver allowQueryParameter() { - DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); - resolver.setAllowUriQueryParameter(true); - return resolver; - } - } - - @EnableWebSecurity - static class MultipleBearerTokenResolverBeansConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - BearerTokenResolver resolverOne() { - DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); - resolver.setAllowUriQueryParameter(true); - return resolver; - } - - @Bean - BearerTokenResolver resolverTwo() { - DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); - resolver.setAllowFormEncodedBodyParameter(true); - return resolver; - } - } - - @EnableWebSecurity - static class CustomJwtDecoderOnDsl extends WebSecurityConfigurerAdapter { - JwtDecoder decoder = mock(JwtDecoder.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt() - .decoder(decoder()); - // @formatter:on - } - - JwtDecoder decoder() { - return this.decoder; - } - } - - @EnableWebSecurity - static class CustomJwtDecoderInLambdaOnDsl extends WebSecurityConfigurerAdapter { - JwtDecoder decoder = mock(JwtDecoder.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .anyRequest().authenticated() - ) - .oauth2ResourceServer(oauth2ResourceServer -> - oauth2ResourceServer - .jwt(jwt -> - jwt - .decoder(decoder()) - ) - ); - // @formatter:on - } - - JwtDecoder decoder() { - return this.decoder; - } - } - - @EnableWebSecurity - static class CustomJwtDecoderAsBean extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - public JwtDecoder decoder() { - return mock(JwtDecoder.class); - } - } - - @EnableWebSecurity - static class JwtAuthenticationManagerConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt() - .authenticationManager(authenticationProvider()::authenticate); - // @formatter:on - } - - @Bean - public AuthenticationProvider authenticationProvider() { - return mock(AuthenticationProvider.class); - } - } - - @EnableWebSecurity - static class CustomJwtValidatorConfig extends WebSecurityConfigurerAdapter { - @Autowired - NimbusJwtDecoder jwtDecoder; - - private final OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); - - @Override - protected void configure(HttpSecurity http) throws Exception { - this.jwtDecoder.setJwtValidator(this.jwtValidator); - - // @formatter:off - http - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - public OAuth2TokenValidator getJwtValidator() { - return this.jwtValidator; - } - } - - @EnableWebSecurity - static class UnexpiredJwtClockSkewConfig extends WebSecurityConfigurerAdapter { - @Autowired - NimbusJwtDecoder jwtDecoder; - - @Override - protected void configure(HttpSecurity http) throws Exception { - Clock nearlyAnHourFromTokenExpiry = - Clock.fixed(Instant.ofEpochMilli(4687181540000L), ZoneId.systemDefault()); - JwtTimestampValidator jwtValidator = new JwtTimestampValidator(Duration.ofHours(1)); - jwtValidator.setClock(nearlyAnHourFromTokenExpiry); - - this.jwtDecoder.setJwtValidator(jwtValidator); - - // @formatter:off - http - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - } - - @EnableWebSecurity - static class ExpiredJwtClockSkewConfig extends WebSecurityConfigurerAdapter { - @Autowired - NimbusJwtDecoder jwtDecoder; - - @Override - protected void configure(HttpSecurity http) throws Exception { - Clock justOverOneHourAfterExpiry = - Clock.fixed(Instant.ofEpochMilli(4687181595000L), ZoneId.systemDefault()); - JwtTimestampValidator jwtValidator = new JwtTimestampValidator(Duration.ofHours(1)); - jwtValidator.setClock(justOverOneHourAfterExpiry); - - this.jwtDecoder.setJwtValidator(jwtValidator); - - // @formatter:off - http - .oauth2ResourceServer() - .jwt(); - } - } - - @EnableWebSecurity - static class SingleKeyConfig extends WebSecurityConfigurerAdapter { - byte[] spec = Base64.getDecoder().decode( - "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoXJ8OyOv/eRnce4akdan" + - "R4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2" + - "UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48O" + - "cz06PGF3lhbz4t5UEZtdF4rIe7u+977QwHuh7yRPBQ3sII+cVoOUMgaXB9SHcGF2" + - "iZCtPzL/IffDUcfhLQteGebhW8A6eUHgpD5A1PQ+JCw/G7UOzZAjjDjtNM2eqm8j" + - "+Ms/gqnm4MiCZ4E+9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1Hu" + - "QwIDAQAB"); - - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - JwtDecoder decoder() throws Exception { - RSAPublicKey publicKey = (RSAPublicKey) - KeyFactory.getInstance("RSA").generatePublic(new X509EncodedKeySpec(this.spec)); - return withPublicKey(publicKey).build(); - } - } - - @EnableWebSecurity - static class CustomAuthenticationEventPublisher extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .jwt(); - // @formatter:on - } - - @Bean - JwtDecoder jwtDecoder() { - return mock(JwtDecoder.class); - } - - @Bean - AuthenticationEventPublisher authenticationEventPublisher() { - return mock(AuthenticationEventPublisher.class); - } - } - - @EnableWebSecurity - static class OpaqueTokenConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .antMatchers("/requires-read-scope").hasAuthority("SCOPE_message:read") - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .opaqueToken(); - // @formatter:on - } - } - - @EnableWebSecurity - static class OpaqueTokenInLambdaConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .antMatchers("/requires-read-scope").hasAuthority("SCOPE_message:read") - .anyRequest().authenticated() - ) - .oauth2ResourceServer(oauth2ResourceServer -> - oauth2ResourceServer - .opaqueToken(withDefaults()) - ); - // @formatter:on - } - } - - @EnableWebSecurity - static class OpaqueTokenAuthenticationManagerConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .opaqueToken() - .authenticationManager(authenticationProvider()::authenticate); - // @formatter:on - } - - @Bean - public AuthenticationProvider authenticationProvider() { - return mock(AuthenticationProvider.class); - } - } - - @EnableWebSecurity - static class OpaqueTokenAuthenticationManagerInLambdaConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests(authorizeRequests -> - authorizeRequests - .anyRequest().authenticated() - ) - .oauth2ResourceServer(oauth2ResourceServer -> - oauth2ResourceServer - .opaqueToken(opaqueToken -> - opaqueToken - .authenticationManager(authenticationProvider()::authenticate) - ) - ); - // @formatter:on - } - - @Bean - public AuthenticationProvider authenticationProvider() { - return mock(AuthenticationProvider.class); - } - } - - @EnableWebSecurity - static class OpaqueAndJwtConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .oauth2ResourceServer() - .jwt() - .and() - .opaqueToken(); - } - } - - @EnableWebSecurity - static class OpaqueTokenHalfConfiguredConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .opaqueToken() - .introspectionUri("https://idp.example.com"); // missing credentials - // @formatter:on - } - } - - @EnableWebSecurity - static class MultipleIssuersConfig extends WebSecurityConfigurerAdapter { - @Autowired - MockWebServer web; - - @Override - protected void configure(HttpSecurity http) throws Exception { - String issuerOne = this.web.url("/issuerOne").toString(); - String issuerTwo = this.web.url("/issuerTwo").toString(); - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver(issuerOne, issuerTwo); - - // @formatter:off - http - .oauth2ResourceServer() - .authenticationManagerResolver(authenticationManagerResolver); - // @formatter:on - } - } - - @EnableWebSecurity - static class AuthenticationManagerResolverPlusOtherConfig extends WebSecurityConfigurerAdapter { - @Override - protected void configure(HttpSecurity http) throws Exception { - // @formatter:off - http - .authorizeRequests() - .anyRequest().authenticated() - .and() - .oauth2ResourceServer() - .authenticationManagerResolver(mock(AuthenticationManagerResolver.class)) - .opaqueToken(); - } - } - - @Configuration - static class JwtDecoderConfig { - @Bean - public JwtDecoder jwtDecoder() { - return mock(JwtDecoder.class); - } - } - - @RestController - static class BasicController { - @GetMapping("/") - public String get() { - return "ok"; - } - - @PostMapping("/post") - public String post() { - return "post"; - } - - @RequestMapping(value = "/authenticated", method = { GET, POST }) - public String authenticated(Authentication authentication) { - return authentication.getName(); - } - - @GetMapping("/requires-read-scope") - public String requiresReadScope(JwtAuthenticationToken token) { - return token.getAuthorities().stream() - .map(GrantedAuthority::getAuthority) - .collect(Collectors.toList()).toString(); - } - - @GetMapping("/ms-requires-read-scope") - @PreAuthorize("hasAuthority('SCOPE_message:read')") - public String msRequiresReadScope(JwtAuthenticationToken token) { - return requiresReadScope(token); - } - - @GetMapping("/ms-deny") - @PreAuthorize("denyAll") - public String deny() { - return "hmm, that's odd"; - } - } - - @Configuration - static class WebServerConfig implements BeanPostProcessor, EnvironmentAware { - private final MockWebServer server = new MockWebServer(); - - @PreDestroy - public void shutdown() throws IOException { - this.server.shutdown(); - } - - @Override - public void setEnvironment(Environment environment) { - if (environment instanceof ConfigurableEnvironment) { - ((ConfigurableEnvironment) environment) - .getPropertySources().addFirst(new MockWebServerPropertySource()); - } - } - - @Bean - public MockWebServer web() { - return this.server; - } - - private class MockWebServerPropertySource extends PropertySource { - - MockWebServerPropertySource() { - super("mockwebserver"); - } - - @Override - public Object getProperty(String name) { - if ("mockwebserver.url".equals(name)) { - return WebServerConfig.this.server.url("/.well-known/jwks.json").toString(); - } else { - return null; - } - } - } - } - - @Configuration - static class RestOperationsConfig { - RestOperations rest = mock(RestOperations.class); - - @Bean - RestOperations rest() { - return this.rest; - } - - @Bean - NimbusJwtDecoder jwtDecoder() { - return withJwkSetUri("https://example.org/.well-known/jwks.json") - .restOperations(this.rest).build(); - } - - @Bean - NimbusOpaqueTokenIntrospector tokenIntrospectionClient() { - return new NimbusOpaqueTokenIntrospector("https://example.org/introspect", this.rest); - } + OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer = new OAuth2ResourceServerConfigurer(context).jwt(); + assertThatExceptionOfType(NoUniqueBeanDefinitionException.class) + .isThrownBy(jwtConfigurer::getJwtAuthenticationConverter); } private static void registerMockBean(GenericApplicationContext context, String name, Class clazz) { context.registerBean(name, clazz, () -> mock(clazz)); } - private static class BearerTokenRequestPostProcessor implements RequestPostProcessor { - private boolean asRequestParameter; - - private String token; - - BearerTokenRequestPostProcessor(String token) { - this.token = token; - } - - public BearerTokenRequestPostProcessor asParam() { - this.asRequestParameter = true; - return this; - } - - @Override - public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - if (this.asRequestParameter) { - request.setParameter("access_token", this.token); - } else { - request.addHeader("Authorization", "Bearer " + this.token); - } - - return request; - } - } - private static BearerTokenRequestPostProcessor bearerToken(String token) { return new BearerTokenRequestPostProcessor(token); } private static ResultMatcher invalidRequestHeader(String message) { return header().string(HttpHeaders.WWW_AUTHENTICATE, - AllOf.allOf( - new StringStartsWith("Bearer " + - "error=\"invalid_request\", " + - "error_description=\""), + AllOf.allOf(new StringStartsWith("Bearer " + "error=\"invalid_request\", " + "error_description=\""), new StringContains(message), - new StringEndsWith(", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"") - ) - ); + new StringEndsWith(", " + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""))); } private static ResultMatcher invalidTokenHeader(String message) { return header().string(HttpHeaders.WWW_AUTHENTICATE, - AllOf.allOf( - new StringStartsWith("Bearer " + - "error=\"invalid_token\", " + - "error_description=\""), + AllOf.allOf(new StringStartsWith("Bearer " + "error=\"invalid_token\", " + "error_description=\""), new StringContains(message), - new StringEndsWith(", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"") - ) - ); + new StringEndsWith(", " + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""))); } private static ResultMatcher insufficientScopeHeader() { - return header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer " + - "error=\"insufficient_scope\"" + - ", error_description=\"The request requires higher privileges than provided by the access token.\"" + - ", error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""); + return header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer " + "error=\"insufficient_scope\"" + + ", error_description=\"The request requires higher privileges than provided by the access token.\"" + + ", error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""); } private String jwkSet() { - return new JWKSet(new RSAKey.Builder(TestKeys.DEFAULT_PUBLIC_KEY) - .keyID("1").build()).toString(); + return new JWKSet(new RSAKey.Builder(TestKeys.DEFAULT_PUBLIC_KEY).keyID("1").build()).toString(); } private String jwtFromIssuer(String issuer) throws Exception { Map claims = new HashMap<>(); - claims.put(ISS, issuer); - claims.put(SUB, "test-subject"); + claims.put(JwtClaimNames.ISS, issuer); + claims.put(JwtClaimNames.SUB, "test-subject"); claims.put("scope", "message:read"); - JWSObject jws = new JWSObject( - new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(), + JWSObject jws = new JWSObject(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(), new Payload(new JSONObject(claims))); jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); return jws.serialize(); } private void mockWebServer(String response) { - this.web.enqueue(new MockResponse() - .setResponseCode(200) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(response)); + this.web.enqueue(new MockResponse().setResponseCode(200) + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(response)); } private void mockRestOperations(String response) { @@ -2510,8 +1341,7 @@ public class OAuth2ResourceServerConfigurerTests { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); ResponseEntity entity = new ResponseEntity<>(response, headers, HttpStatus.OK); - when(rest.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(entity); + given(rest.exchange(any(RequestEntity.class), eq(String.class))).willReturn(entity); } private T bean(Class beanClass) { @@ -2537,8 +1367,994 @@ public class OAuth2ResourceServerConfigurerTests { private String resource(String suffix) throws IOException { String name = this.getClass().getSimpleName() + "-" + suffix; ClassPathResource resource = new ClassPathResource(name, this.getClass()); - try ( BufferedReader reader = new BufferedReader(new FileReader(resource.getFile())) ) { + try (BufferedReader reader = new BufferedReader(new FileReader(resource.getFile()))) { return reader.lines().collect(Collectors.joining()); } } + + @EnableWebSecurity + static class DefaultConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class DefaultInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") + .anyRequest().authenticated() + ) + .oauth2ResourceServer((oauth2ResourceServer) -> + oauth2ResourceServer + .jwt(withDefaults()) + ); + // @formatter:on + } + + } + + @EnableWebSecurity + static class JwkSetUriConfig extends WebSecurityConfigurerAdapter { + + @Value("${mockwebserver.url:https://example.org}") + String jwkSetUri; + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt() + .jwkSetUri(this.jwkSetUri); + // @formatter:on + } + + } + + @EnableWebSecurity + static class JwkSetUriInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Value("${mockwebserver.url:https://example.org}") + String jwkSetUri; + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") + .anyRequest().authenticated() + ) + .oauth2ResourceServer((oauth2ResourceServer) -> + oauth2ResourceServer + .jwt((jwt) -> + jwt + .jwkSetUri(this.jwkSetUri) + ) + ); + // @formatter:on + } + + } + + @EnableWebSecurity + static class CsrfDisabledConfig extends WebSecurityConfigurerAdapter { + + @Value("${mockwebserver.url:https://example.org}") + String jwkSetUri; + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .antMatchers("/requires-read-scope").access("hasAuthority('SCOPE_message:read')") + .anyRequest().authenticated() + .and() + .csrf().disable() + .oauth2ResourceServer() + .jwt() + .jwkSetUri(this.jwkSetUri); + // @formatter:on + } + + } + + @EnableWebSecurity + static class AnonymousDisabledConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .anonymous().disable() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + } + + @EnableWebSecurity + @EnableGlobalMethodSecurity(prePostEnabled = true) + static class MethodSecurityConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class JwtlessConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class RealmNameConfiguredOnEntryPoint extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .authenticationEntryPoint(authenticationEntryPoint()) + .jwt(); + // @formatter:on + } + + AuthenticationEntryPoint authenticationEntryPoint() { + BearerTokenAuthenticationEntryPoint entryPoint = new BearerTokenAuthenticationEntryPoint(); + entryPoint.setRealmName("myRealm"); + return entryPoint; + } + + } + + @EnableWebSecurity + static class RealmNameConfiguredOnAccessDeniedHandler extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().denyAll() + .and() + .oauth2ResourceServer() + .accessDeniedHandler(accessDeniedHandler()) + .jwt(); + // @formatter:on + } + + AccessDeniedHandler accessDeniedHandler() { + BearerTokenAccessDeniedHandler accessDeniedHandler = new BearerTokenAccessDeniedHandler(); + accessDeniedHandler.setRealmName("myRealm"); + return accessDeniedHandler; + } + + } + + @EnableWebSecurity + static class ExceptionHandlingAndResourceServerWithAccessDeniedHandlerConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().denyAll() + .and() + .exceptionHandling() + .defaultAccessDeniedHandlerFor(new AccessDeniedHandlerImpl(), (request) -> false) + .and() + .httpBasic() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Override + @Bean + public UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager( + // @formatter:off + org.springframework.security.core.userdetails.User.withDefaultPasswordEncoder() + .username("basic-user") + .password("basic-password") + .roles("USER") + .build()); + // @formatter:on + } + + } + + @EnableWebSecurity + static class JwtAuthenticationConverterConfiguredOnDsl extends WebSecurityConfigurerAdapter { + + private final Converter jwtAuthenticationConverter = mock(Converter.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt() + .jwtAuthenticationConverter(getJwtAuthenticationConverter()); + // @formatter:on + } + + Converter getJwtAuthenticationConverter() { + return this.jwtAuthenticationConverter; + } + + } + + @EnableWebSecurity + static class CustomAuthorityMappingConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .antMatchers("/requires-read-scope").access("hasAuthority('message:read')") + .and() + .oauth2ResourceServer() + .jwt() + .jwtAuthenticationConverter(getJwtAuthenticationConverter()); + // @formatter:on + } + + Converter getJwtAuthenticationConverter() { + JwtAuthenticationConverter converter = new JwtAuthenticationConverter(); + converter.setJwtGrantedAuthoritiesConverter( + (jwt) -> Collections.singletonList(new SimpleGrantedAuthority("message:read"))); + return converter; + } + + } + + @EnableWebSecurity + static class BasicAndResourceServerConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .httpBasic() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Override + @Bean + public UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager( + // @formatter:off + org.springframework.security.core.userdetails.User.withDefaultPasswordEncoder() + .username("basic-user") + .password("basic-password") + .roles("USER") + .build()); + // @formatter:on + } + + } + + @EnableWebSecurity + static class FormAndResourceServerConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .formLogin() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class JwtHalfConfiguredConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); // missing key configuration, e.g. jwkSetUri + // @formatter:on + } + + } + + @EnableWebSecurity + static class AlwaysSessionCreationConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .sessionManagement() + .sessionCreationPolicy(SessionCreationPolicy.ALWAYS) + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class AllowBearerTokenInRequestBodyConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .bearerTokenResolver(allowRequestBody()) + .jwt(); + // @formatter:on + } + + private BearerTokenResolver allowRequestBody() { + DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); + resolver.setAllowFormEncodedBodyParameter(true); + return resolver; + } + + } + + @EnableWebSecurity + static class AllowBearerTokenAsQueryParameterConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Bean + BearerTokenResolver allowQueryParameter() { + DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); + resolver.setAllowUriQueryParameter(true); + return resolver; + } + + } + + @EnableWebSecurity + static class MultipleBearerTokenResolverBeansConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Bean + BearerTokenResolver resolverOne() { + DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); + resolver.setAllowUriQueryParameter(true); + return resolver; + } + + @Bean + BearerTokenResolver resolverTwo() { + DefaultBearerTokenResolver resolver = new DefaultBearerTokenResolver(); + resolver.setAllowFormEncodedBodyParameter(true); + return resolver; + } + + } + + @EnableWebSecurity + static class CustomJwtDecoderOnDsl extends WebSecurityConfigurerAdapter { + + JwtDecoder decoder = mock(JwtDecoder.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt() + .decoder(decoder()); + // @formatter:on + } + + JwtDecoder decoder() { + return this.decoder; + } + + } + + @EnableWebSecurity + static class CustomJwtDecoderInLambdaOnDsl extends WebSecurityConfigurerAdapter { + + JwtDecoder decoder = mock(JwtDecoder.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .anyRequest().authenticated() + ) + .oauth2ResourceServer((oauth2ResourceServer) -> + oauth2ResourceServer + .jwt((jwt) -> + jwt + .decoder(decoder()) + ) + ); + // @formatter:on + } + + JwtDecoder decoder() { + return this.decoder; + } + + } + + @EnableWebSecurity + static class CustomJwtDecoderAsBean extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Bean + JwtDecoder decoder() { + return mock(JwtDecoder.class); + } + + } + + @EnableWebSecurity + static class JwtAuthenticationManagerConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt() + .authenticationManager(authenticationProvider()::authenticate); + // @formatter:on + } + + @Bean + AuthenticationProvider authenticationProvider() { + return mock(AuthenticationProvider.class); + } + + } + + @EnableWebSecurity + static class CustomJwtValidatorConfig extends WebSecurityConfigurerAdapter { + + @Autowired + NimbusJwtDecoder jwtDecoder; + + private final OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + this.jwtDecoder.setJwtValidator(this.jwtValidator); + // @formatter:off + http + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + OAuth2TokenValidator getJwtValidator() { + return this.jwtValidator; + } + + } + + @EnableWebSecurity + static class UnexpiredJwtClockSkewConfig extends WebSecurityConfigurerAdapter { + + @Autowired + NimbusJwtDecoder jwtDecoder; + + @Override + protected void configure(HttpSecurity http) throws Exception { + Clock nearlyAnHourFromTokenExpiry = Clock.fixed(Instant.ofEpochMilli(4687181540000L), + ZoneId.systemDefault()); + JwtTimestampValidator jwtValidator = new JwtTimestampValidator(Duration.ofHours(1)); + jwtValidator.setClock(nearlyAnHourFromTokenExpiry); + this.jwtDecoder.setJwtValidator(jwtValidator); + // @formatter:off + http + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class ExpiredJwtClockSkewConfig extends WebSecurityConfigurerAdapter { + + @Autowired + NimbusJwtDecoder jwtDecoder; + + @Override + protected void configure(HttpSecurity http) throws Exception { + Clock justOverOneHourAfterExpiry = Clock.fixed(Instant.ofEpochMilli(4687181595000L), + ZoneId.systemDefault()); + JwtTimestampValidator jwtValidator = new JwtTimestampValidator(Duration.ofHours(1)); + jwtValidator.setClock(justOverOneHourAfterExpiry); + this.jwtDecoder.setJwtValidator(jwtValidator); + // @formatter:off + http + .oauth2ResourceServer() + .jwt(); + } + } + @EnableWebSecurity + static class SingleKeyConfig extends WebSecurityConfigurerAdapter { + byte[] spec = Base64.getDecoder().decode( + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoXJ8OyOv/eRnce4akdan" + + "R4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2" + + "UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48O" + + "cz06PGF3lhbz4t5UEZtdF4rIe7u+977QwHuh7yRPBQ3sII+cVoOUMgaXB9SHcGF2" + + "iZCtPzL/IffDUcfhLQteGebhW8A6eUHgpD5A1PQ+JCw/G7UOzZAjjDjtNM2eqm8j" + + "+Ms/gqnm4MiCZ4E+9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1Hu" + + "QwIDAQAB"); + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Bean + JwtDecoder decoder() throws Exception { + RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA") + .generatePublic(new X509EncodedKeySpec(this.spec)); + return NimbusJwtDecoder.withPublicKey(publicKey).build(); + } + + } + + @EnableWebSecurity + static class CustomAuthenticationEventPublisher extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .jwt(); + // @formatter:on + } + + @Bean + JwtDecoder jwtDecoder() { + return mock(JwtDecoder.class); + } + + @Bean + AuthenticationEventPublisher authenticationEventPublisher() { + return mock(AuthenticationEventPublisher.class); + } + + } + + @EnableWebSecurity + static class OpaqueTokenConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .antMatchers("/requires-read-scope").hasAuthority("SCOPE_message:read") + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .opaqueToken(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class OpaqueTokenInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .antMatchers("/requires-read-scope").hasAuthority("SCOPE_message:read") + .anyRequest().authenticated() + ) + .oauth2ResourceServer((oauth2ResourceServer) -> + oauth2ResourceServer + .opaqueToken(withDefaults()) + ); + // @formatter:on + } + + } + + @EnableWebSecurity + static class OpaqueTokenAuthenticationManagerConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .opaqueToken() + .authenticationManager(authenticationProvider()::authenticate); + // @formatter:on + } + + @Bean + AuthenticationProvider authenticationProvider() { + return mock(AuthenticationProvider.class); + } + + } + + @EnableWebSecurity + static class OpaqueTokenAuthenticationManagerInLambdaConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .anyRequest().authenticated() + ) + .oauth2ResourceServer((oauth2ResourceServer) -> + oauth2ResourceServer + .opaqueToken((opaqueToken) -> + opaqueToken + .authenticationManager(authenticationProvider()::authenticate) + ) + ); + // @formatter:on + } + + @Bean + AuthenticationProvider authenticationProvider() { + return mock(AuthenticationProvider.class); + } + + } + + @EnableWebSecurity + static class OpaqueAndJwtConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2ResourceServer() + .jwt() + .and() + .opaqueToken(); + // @formatter:on + } + + } + + @EnableWebSecurity + static class OpaqueTokenHalfConfiguredConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .opaqueToken() + .introspectionUri("https://idp.example.com"); // missing credentials + // @formatter:on + } + + } + + @EnableWebSecurity + static class MultipleIssuersConfig extends WebSecurityConfigurerAdapter { + + @Autowired + MockWebServer web; + + @Override + protected void configure(HttpSecurity http) throws Exception { + String issuerOne = this.web.url("/issuerOne").toString(); + String issuerTwo = this.web.url("/issuerTwo").toString(); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + issuerOne, issuerTwo); + // @formatter:off + http + .oauth2ResourceServer() + .authenticationManagerResolver(authenticationManagerResolver); + // @formatter:on + } + + } + + @EnableWebSecurity + static class AuthenticationManagerResolverPlusOtherConfig extends WebSecurityConfigurerAdapter { + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests() + .anyRequest().authenticated() + .and() + .oauth2ResourceServer() + .authenticationManagerResolver(mock(AuthenticationManagerResolver.class)) + .opaqueToken(); + // @formatter:on + } + + } + + @Configuration + static class JwtDecoderConfig { + + @Bean + JwtDecoder jwtDecoder() { + return mock(JwtDecoder.class); + } + + } + + @RestController + static class BasicController { + + @GetMapping("/") + String get() { + return "ok"; + } + + @PostMapping("/post") + String post() { + return "post"; + } + + @RequestMapping(value = "/authenticated", method = { RequestMethod.GET, RequestMethod.POST }) + String authenticated(Authentication authentication) { + return authentication.getName(); + } + + @GetMapping("/requires-read-scope") + String requiresReadScope(JwtAuthenticationToken token) { + return token.getAuthorities().stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList()) + .toString(); + } + + @GetMapping("/ms-requires-read-scope") + @PreAuthorize("hasAuthority('SCOPE_message:read')") + String msRequiresReadScope(JwtAuthenticationToken token) { + return requiresReadScope(token); + } + + @GetMapping("/ms-deny") + @PreAuthorize("denyAll") + String deny() { + return "hmm, that's odd"; + } + + } + + @Configuration + static class WebServerConfig implements BeanPostProcessor, EnvironmentAware { + + private final MockWebServer server = new MockWebServer(); + + @PreDestroy + void shutdown() throws IOException { + this.server.shutdown(); + } + + @Override + public void setEnvironment(Environment environment) { + if (environment instanceof ConfigurableEnvironment) { + ((ConfigurableEnvironment) environment).getPropertySources() + .addFirst(new MockWebServerPropertySource()); + } + } + + @Bean + MockWebServer web() { + return this.server; + } + + private class MockWebServerPropertySource extends PropertySource { + + MockWebServerPropertySource() { + super("mockwebserver"); + } + + @Override + public Object getProperty(String name) { + if ("mockwebserver.url".equals(name)) { + return WebServerConfig.this.server.url("/.well-known/jwks.json").toString(); + } + else { + return null; + } + } + + } + + } + + @Configuration + static class RestOperationsConfig { + + RestOperations rest = mock(RestOperations.class); + + @Bean + RestOperations rest() { + return this.rest; + } + + @Bean + NimbusJwtDecoder jwtDecoder() { + return NimbusJwtDecoder.withJwkSetUri("https://example.org/.well-known/jwks.json").restOperations(this.rest) + .build(); + } + + @Bean + NimbusOpaqueTokenIntrospector tokenIntrospectionClient() { + return new NimbusOpaqueTokenIntrospector("https://example.org/introspect", this.rest); + } + + } + + private static class BearerTokenRequestPostProcessor implements RequestPostProcessor { + + private boolean asRequestParameter; + + private String token; + + BearerTokenRequestPostProcessor(String token) { + this.token = token; + } + + BearerTokenRequestPostProcessor asParam() { + this.asRequestParameter = true; + return this; + } + + @Override + public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { + if (this.asRequestParameter) { + request.setParameter("access_token", this.token); + } + else { + request.addHeader("Authorization", "Bearer " + this.token); + } + return request; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurerTests.java index 159c54967e..420e8e1a0b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurerTests.java @@ -16,13 +16,17 @@ package org.springframework.security.config.annotation.web.configurers.openid; +import java.util.List; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.Rule; import org.junit.Test; import org.openid4java.consumer.ConsumerManager; import org.openid4java.discovery.DiscoveryInformation; +import org.openid4java.discovery.yadis.YadisResolver; import org.openid4java.message.AuthRequest; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.ObjectPostProcessor; @@ -36,17 +40,15 @@ import org.springframework.security.openid.OpenIDAuthenticationFilter; import org.springframework.security.openid.OpenIDAuthenticationProvider; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; - -import java.util.List; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.openid4java.discovery.yadis.YadisResolver.YADIS_XRDS_LOCATION; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -70,22 +72,104 @@ public class OpenIDLoginConfigurerTests { public void configureWhenRegisteringObjectPostProcessorThenInvokedOnOpenIDAuthenticationFilter() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); - - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(OpenIDAuthenticationFilter.class)); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(OpenIDAuthenticationFilter.class)); } @Test public void configureWhenRegisteringObjectPostProcessorThenInvokedOnOpenIDAuthenticationProvider() { ObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class); this.spring.register(ObjectPostProcessorConfig.class).autowire(); + verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(OpenIDAuthenticationProvider.class)); + } - verify(ObjectPostProcessorConfig.objectPostProcessor) - .postProcess(any(OpenIDAuthenticationProvider.class)); + @Test + public void openidLoginWhenInvokedTwiceThenUsesOriginalLoginPage() throws Exception { + this.spring.register(InvokeTwiceDoesNotOverrideConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/login/custom")); + // @formatter:on + } + + @Test + public void requestWhenOpenIdLoginPageInLambdaThenRedirectsToLoginPAge() throws Exception { + this.spring.register(OpenIdLoginPageInLambdaConfig.class).autowire(); + // @formatter:off + this.mvc.perform(get("/")) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/login/custom")); + // @formatter:on + } + + @Test + public void requestWhenAttributeExchangeConfiguredThenFetchAttributesMatchAttributeList() throws Exception { + OpenIdAttributesInLambdaConfig.CONSUMER_MANAGER = mock(ConsumerManager.class); + AuthRequest mockAuthRequest = mock(AuthRequest.class); + DiscoveryInformation mockDiscoveryInformation = mock(DiscoveryInformation.class); + given(mockAuthRequest.getDestinationUrl(anyBoolean())).willReturn("mockUrl"); + given(OpenIdAttributesInLambdaConfig.CONSUMER_MANAGER.associate(any())).willReturn(mockDiscoveryInformation); + given(OpenIdAttributesInLambdaConfig.CONSUMER_MANAGER.authenticate(any(DiscoveryInformation.class), any(), + any())).willReturn(mockAuthRequest); + this.spring.register(OpenIdAttributesInLambdaConfig.class).autowire(); + try (MockWebServer server = new MockWebServer()) { + String endpoint = server.url("/").toString(); + server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint)); + server.enqueue(new MockResponse() + .setBody(String.format("%s", endpoint))); + MvcResult mvcResult = this.mvc.perform( + get("/login/openid").param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, endpoint)) + .andExpect(status().isFound()).andReturn(); + Object attributeObject = mvcResult.getRequest().getSession() + .getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"); + assertThat(attributeObject).isInstanceOf(List.class); + List attributeList = (List) attributeObject; + assertThat( + attributeList.stream() + .anyMatch((attribute) -> "nickname".equals(attribute.getName()) + && "https://schema.openid.net/namePerson/friendly".equals(attribute.getType()))) + .isTrue(); + assertThat(attributeList.stream() + .anyMatch((attribute) -> "email".equals(attribute.getName()) + && "https://schema.openid.net/contact/email".equals(attribute.getType()) + && attribute.isRequired() && attribute.getCount() == 2)).isTrue(); + } + } + + @Test + public void requestWhenAttributeNameNotSpecifiedThenAttributeNameDefaulted() throws Exception { + OpenIdAttributesNullNameConfig.CONSUMER_MANAGER = mock(ConsumerManager.class); + AuthRequest mockAuthRequest = mock(AuthRequest.class); + DiscoveryInformation mockDiscoveryInformation = mock(DiscoveryInformation.class); + given(mockAuthRequest.getDestinationUrl(anyBoolean())).willReturn("mockUrl"); + given(OpenIdAttributesNullNameConfig.CONSUMER_MANAGER.associate(any())).willReturn(mockDiscoveryInformation); + given(OpenIdAttributesNullNameConfig.CONSUMER_MANAGER.authenticate(any(DiscoveryInformation.class), any(), + any())).willReturn(mockAuthRequest); + this.spring.register(OpenIdAttributesNullNameConfig.class).autowire(); + try (MockWebServer server = new MockWebServer()) { + String endpoint = server.url("/").toString(); + server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint)); + server.enqueue(new MockResponse() + .setBody(String.format("%s", endpoint))); + // @formatter:off + MockHttpServletRequestBuilder request = get("/login/openid") + .param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, endpoint); + MvcResult mvcResult = this.mvc.perform(request) + .andExpect(status().isFound()) + .andReturn(); + Object attributeObject = mvcResult.getRequest().getSession() + .getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"); + // @formatter:on + assertThat(attributeObject).isInstanceOf(List.class); + List attributeList = (List) attributeObject; + assertThat(attributeList).hasSize(1); + assertThat(attributeList.get(0).getName()).isEqualTo("default-attribute"); + } } @EnableWebSecurity static class ObjectPostProcessorConfig extends WebSecurityConfigurerAdapter { + static ObjectPostProcessor objectPostProcessor; @Override @@ -100,22 +184,16 @@ public class OpenIDLoginConfigurerTests { static ObjectPostProcessor objectPostProcessor() { return objectPostProcessor; } + } static class ReflectingObjectPostProcessor implements ObjectPostProcessor { + @Override public O postProcess(O object) { return object; } - } - @Test - public void openidLoginWhenInvokedTwiceThenUsesOriginalLoginPage() throws Exception { - this.spring.register(InvokeTwiceDoesNotOverrideConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/login/custom")); } @EnableWebSecurity @@ -142,100 +220,54 @@ public class OpenIDLoginConfigurerTests { .openidLogin(); // @formatter:on } - } - @Test - public void requestWhenOpenIdLoginPageInLambdaThenRedirectsToLoginPAge() throws Exception { - this.spring.register(OpenIdLoginPageInLambdaConfig.class).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/login/custom")); } @EnableWebSecurity static class OpenIdLoginPageInLambdaConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) - .openidLogin(openIdLogin -> + .openidLogin((openIdLogin) -> openIdLogin .loginPage("/login/custom") ); // @formatter:on } - } - @Test - public void requestWhenAttributeExchangeConfiguredThenFetchAttributesMatchAttributeList() throws Exception { - OpenIdAttributesInLambdaConfig.CONSUMER_MANAGER = mock(ConsumerManager.class); - AuthRequest mockAuthRequest = mock(AuthRequest.class); - DiscoveryInformation mockDiscoveryInformation = mock(DiscoveryInformation.class); - when(mockAuthRequest.getDestinationUrl(anyBoolean())).thenReturn("mockUrl"); - when(OpenIdAttributesInLambdaConfig.CONSUMER_MANAGER.associate(any())) - .thenReturn(mockDiscoveryInformation); - when(OpenIdAttributesInLambdaConfig.CONSUMER_MANAGER.authenticate(any(DiscoveryInformation.class), any(), any())) - .thenReturn(mockAuthRequest); - this.spring.register(OpenIdAttributesInLambdaConfig.class).autowire(); - - try ( MockWebServer server = new MockWebServer() ) { - String endpoint = server.url("/").toString(); - - server.enqueue(new MockResponse() - .addHeader(YADIS_XRDS_LOCATION, endpoint)); - server.enqueue(new MockResponse() - .setBody(String.format("%s", endpoint))); - - MvcResult mvcResult = this.mvc.perform(get("/login/openid") - .param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, endpoint)) - .andExpect(status().isFound()) - .andReturn(); - - Object attributeObject = mvcResult.getRequest().getSession().getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"); - assertThat(attributeObject).isInstanceOf(List.class); - List attributeList = (List) attributeObject; - assertThat(attributeList.stream().anyMatch(attribute -> - "nickname".equals(attribute.getName()) - && "https://schema.openid.net/namePerson/friendly".equals(attribute.getType()))) - .isTrue(); - assertThat(attributeList.stream().anyMatch(attribute -> - "email".equals(attribute.getName()) - && "https://schema.openid.net/contact/email".equals(attribute.getType()) - && attribute.isRequired() - && attribute.getCount() == 2)) - .isTrue(); - } } @EnableWebSecurity static class OpenIdAttributesInLambdaConfig extends WebSecurityConfigurerAdapter { + static ConsumerManager CONSUMER_MANAGER; @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().permitAll() ) - .openidLogin(openIdLogin -> + .openidLogin((openIdLogin) -> openIdLogin .consumerManager(CONSUMER_MANAGER) - .attributeExchange(attributeExchange -> + .attributeExchange((attributeExchange) -> attributeExchange .identifierPattern(".*") - .attribute(nicknameAttribute -> + .attribute((nicknameAttribute) -> nicknameAttribute .name("nickname") .type("https://schema.openid.net/namePerson/friendly") ) - .attribute(emailAttribute -> + .attribute((emailAttribute) -> emailAttribute .name("email") .type("https://schema.openid.net/contact/email") @@ -246,58 +278,26 @@ public class OpenIDLoginConfigurerTests { ); // @formatter:on } - } - @Test - public void requestWhenAttributeNameNotSpecifiedThenAttributeNameDefaulted() - throws Exception { - OpenIdAttributesNullNameConfig.CONSUMER_MANAGER = mock(ConsumerManager.class); - AuthRequest mockAuthRequest = mock(AuthRequest.class); - DiscoveryInformation mockDiscoveryInformation = mock(DiscoveryInformation.class); - when(mockAuthRequest.getDestinationUrl(anyBoolean())).thenReturn("mockUrl"); - when(OpenIdAttributesNullNameConfig.CONSUMER_MANAGER.associate(any())) - .thenReturn(mockDiscoveryInformation); - when(OpenIdAttributesNullNameConfig.CONSUMER_MANAGER.authenticate(any(DiscoveryInformation.class), any(), any())) - .thenReturn(mockAuthRequest); - this.spring.register(OpenIdAttributesNullNameConfig.class).autowire(); - - try ( MockWebServer server = new MockWebServer() ) { - String endpoint = server.url("/").toString(); - - server.enqueue(new MockResponse() - .addHeader(YADIS_XRDS_LOCATION, endpoint)); - server.enqueue(new MockResponse() - .setBody(String.format("%s", endpoint))); - - MvcResult mvcResult = this.mvc.perform(get("/login/openid") - .param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, endpoint)) - .andExpect(status().isFound()) - .andReturn(); - - Object attributeObject = mvcResult.getRequest().getSession().getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"); - assertThat(attributeObject).isInstanceOf(List.class); - List attributeList = (List) attributeObject; - assertThat(attributeList).hasSize(1); - assertThat(attributeList.get(0).getName()).isEqualTo("default-attribute"); - } } @EnableWebSecurity static class OpenIdAttributesNullNameConfig extends WebSecurityConfigurerAdapter { + static ConsumerManager CONSUMER_MANAGER; @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().permitAll() ) - .openidLogin(openIdLogin -> + .openidLogin((openIdLogin) -> openIdLogin .consumerManager(CONSUMER_MANAGER) - .attributeExchange(attributeExchange -> + .attributeExchange((attributeExchange) -> attributeExchange .identifierPattern(".*") .attribute(withDefaults()) @@ -305,5 +305,7 @@ public class OpenIDLoginConfigurerTests { ); // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index dccebcaf93..bad65f790c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -19,6 +19,7 @@ package org.springframework.security.config.annotation.web.configurers.saml2; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Arrays; import java.util.Base64; @@ -26,6 +27,7 @@ import java.util.Collection; import java.util.Collections; import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; + import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -60,14 +62,18 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; +import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects; +import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; @@ -78,22 +84,17 @@ import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; -import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest; -import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -104,13 +105,15 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. */ public class Saml2LoginConfigurerTests { - private static final Converter> - AUTHORITIES_EXTRACTOR = a -> Arrays.asList(new SimpleGrantedAuthority("TEST")); - private static final GrantedAuthoritiesMapper AUTHORITIES_MAPPER = - authorities -> Arrays.asList(new SimpleGrantedAuthority("TEST CONVERTED")); + private static final Converter> AUTHORITIES_EXTRACTOR = ( + a) -> Arrays.asList(new SimpleGrantedAuthority("TEST")); + + private static final GrantedAuthoritiesMapper AUTHORITIES_MAPPER = (authorities) -> Arrays + .asList(new SimpleGrantedAuthority("TEST CONVERTED")); + private static final Duration RESPONSE_TIME_VALIDATION_SKEW = Duration.ZERO; - private static final String SIGNED_RESPONSE = - "PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz48c2FtbDJwOlJlc3BvbnNlIHhtbG5zOnNhbWwycD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOnByb3RvY29sIiBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9ycC5leGFtcGxlLm9yZy9hY3MiIElEPSJfYzE3MzM2YTAtNTM1My00MTQ5LWI3MmMtMDNkOWY5YWYzMDdlIiBJc3N1ZUluc3RhbnQ9IjIwMjAtMDgtMDRUMjI6MDQ6NDUuMDE2WiIgVmVyc2lvbj0iMi4wIj48c2FtbDI6SXNzdWVyIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48ZHM6U2lnbmF0dXJlIHhtbG5zOmRzPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwLzA5L3htbGRzaWcjIj4KPGRzOlNpZ25lZEluZm8+CjxkczpDYW5vbmljYWxpemF0aW9uTWV0aG9kIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMS8xMC94bWwtZXhjLWMxNG4jIi8+CjxkczpTaWduYXR1cmVNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGRzaWctbW9yZSNyc2Etc2hhMjU2Ii8+CjxkczpSZWZlcmVuY2UgVVJJPSIjX2MxNzMzNmEwLTUzNTMtNDE0OS1iNzJjLTAzZDlmOWFmMzA3ZSI+CjxkczpUcmFuc2Zvcm1zPgo8ZHM6VHJhbnNmb3JtIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMC8wOS94bWxkc2lnI2VudmVsb3BlZC1zaWduYXR1cmUiLz4KPGRzOlRyYW5zZm9ybSBBbGdvcml0aG09Imh0dHA6Ly93d3cudzMub3JnLzIwMDEvMTAveG1sLWV4Yy1jMTRuIyIvPgo8L2RzOlRyYW5zZm9ybXM+CjxkczpEaWdlc3RNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGVuYyNzaGEyNTYiLz4KPGRzOkRpZ2VzdFZhbHVlPjYzTmlyenFzaDVVa0h1a3NuRWUrM0hWWU5aYWFsQW1OQXFMc1lGMlRuRDA9PC9kczpEaWdlc3RWYWx1ZT4KPC9kczpSZWZlcmVuY2U+CjwvZHM6U2lnbmVkSW5mbz4KPGRzOlNpZ25hdHVyZVZhbHVlPgpLMVlvWWJVUjBTclY4RTdVMkhxTTIvZUNTOTNoV25mOExnNnozeGZWMUlyalgzSXhWYkNvMVlYcnRBSGRwRVdvYTJKKzVOMmFNbFBHJiMxMzsKN2VpbDBZRC9xdUVRamRYbTNwQTBjZmEvY25pa2RuKzVhbnM0ZWQwanU1amo2dkpvZ2w2Smt4Q25LWUpwTU9HNzhtampmb0phengrWCYjMTM7CkM2NktQVStBYUdxeGVwUEQ1ZlhRdTFKSy9Jb3lBaitaa3k4Z2Jwc3VyZHFCSEJLRWxjdnVOWS92UGY0OGtBeFZBKzdtRGhNNUMvL1AmIzEzOwp0L084Y3NZYXB2UjZjdjZrdk45QXZ1N3FRdm9qVk1McHVxZWNJZDJwTUVYb0NSSnE2Nkd4MStNTUVPeHVpMWZZQlRoMEhhYjRmK3JyJiMxMzsKOEY2V1NFRC8xZllVeHliRkJqZ1Q4d2lEWHFBRU8wSVY4ZWRQeEE9PQo8L2RzOlNpZ25hdHVyZVZhbHVlPgo8L2RzOlNpZ25hdHVyZT48c2FtbDI6QXNzZXJ0aW9uIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIiBJRD0iQWUzZjQ5OGI4LTliMTctNDA3OC05ZDM1LTg2YTA4NDA4NDk5NSIgSXNzdWVJbnN0YW50PSIyMDIwLTA4LTA0VDIyOjA0OjQ1LjA3N1oiIFZlcnNpb249IjIuMCI+PHNhbWwyOklzc3Vlcj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48c2FtbDI6U3ViamVjdD48c2FtbDI6TmFtZUlEPnRlc3RAc2FtbC51c2VyPC9zYW1sMjpOYW1lSUQ+PHNhbWwyOlN1YmplY3RDb25maXJtYXRpb24gTWV0aG9kPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6Y206YmVhcmVyIj48c2FtbDI6U3ViamVjdENvbmZpcm1hdGlvbkRhdGEgTm90QmVmb3JlPSIyMDIwLTA4LTA0VDIxOjU5OjQ1LjA5MFoiIE5vdE9uT3JBZnRlcj0iMjA0MC0wNy0zMFQyMjowNTowNi4wODhaIiBSZWNpcGllbnQ9Imh0dHBzOi8vcnAuZXhhbXBsZS5vcmcvYWNzIi8+PC9zYW1sMjpTdWJqZWN0Q29uZmlybWF0aW9uPjwvc2FtbDI6U3ViamVjdD48c2FtbDI6Q29uZGl0aW9ucyBOb3RCZWZvcmU9IjIwMjAtMDgtMDRUMjE6NTk6NDUuMDgwWiIgTm90T25PckFmdGVyPSIyMDQwLTA3LTMwVDIyOjA1OjA2LjA4N1oiLz48L3NhbWwyOkFzc2VydGlvbj48L3NhbWwycDpSZXNwb25zZT4="; + + private static final String SIGNED_RESPONSE = "PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz48c2FtbDJwOlJlc3BvbnNlIHhtbG5zOnNhbWwycD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOnByb3RvY29sIiBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9ycC5leGFtcGxlLm9yZy9hY3MiIElEPSJfYzE3MzM2YTAtNTM1My00MTQ5LWI3MmMtMDNkOWY5YWYzMDdlIiBJc3N1ZUluc3RhbnQ9IjIwMjAtMDgtMDRUMjI6MDQ6NDUuMDE2WiIgVmVyc2lvbj0iMi4wIj48c2FtbDI6SXNzdWVyIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48ZHM6U2lnbmF0dXJlIHhtbG5zOmRzPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwLzA5L3htbGRzaWcjIj4KPGRzOlNpZ25lZEluZm8+CjxkczpDYW5vbmljYWxpemF0aW9uTWV0aG9kIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMS8xMC94bWwtZXhjLWMxNG4jIi8+CjxkczpTaWduYXR1cmVNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGRzaWctbW9yZSNyc2Etc2hhMjU2Ii8+CjxkczpSZWZlcmVuY2UgVVJJPSIjX2MxNzMzNmEwLTUzNTMtNDE0OS1iNzJjLTAzZDlmOWFmMzA3ZSI+CjxkczpUcmFuc2Zvcm1zPgo8ZHM6VHJhbnNmb3JtIEFsZ29yaXRobT0iaHR0cDovL3d3dy53My5vcmcvMjAwMC8wOS94bWxkc2lnI2VudmVsb3BlZC1zaWduYXR1cmUiLz4KPGRzOlRyYW5zZm9ybSBBbGdvcml0aG09Imh0dHA6Ly93d3cudzMub3JnLzIwMDEvMTAveG1sLWV4Yy1jMTRuIyIvPgo8L2RzOlRyYW5zZm9ybXM+CjxkczpEaWdlc3RNZXRob2QgQWxnb3JpdGhtPSJodHRwOi8vd3d3LnczLm9yZy8yMDAxLzA0L3htbGVuYyNzaGEyNTYiLz4KPGRzOkRpZ2VzdFZhbHVlPjYzTmlyenFzaDVVa0h1a3NuRWUrM0hWWU5aYWFsQW1OQXFMc1lGMlRuRDA9PC9kczpEaWdlc3RWYWx1ZT4KPC9kczpSZWZlcmVuY2U+CjwvZHM6U2lnbmVkSW5mbz4KPGRzOlNpZ25hdHVyZVZhbHVlPgpLMVlvWWJVUjBTclY4RTdVMkhxTTIvZUNTOTNoV25mOExnNnozeGZWMUlyalgzSXhWYkNvMVlYcnRBSGRwRVdvYTJKKzVOMmFNbFBHJiMxMzsKN2VpbDBZRC9xdUVRamRYbTNwQTBjZmEvY25pa2RuKzVhbnM0ZWQwanU1amo2dkpvZ2w2Smt4Q25LWUpwTU9HNzhtampmb0phengrWCYjMTM7CkM2NktQVStBYUdxeGVwUEQ1ZlhRdTFKSy9Jb3lBaitaa3k4Z2Jwc3VyZHFCSEJLRWxjdnVOWS92UGY0OGtBeFZBKzdtRGhNNUMvL1AmIzEzOwp0L084Y3NZYXB2UjZjdjZrdk45QXZ1N3FRdm9qVk1McHVxZWNJZDJwTUVYb0NSSnE2Nkd4MStNTUVPeHVpMWZZQlRoMEhhYjRmK3JyJiMxMzsKOEY2V1NFRC8xZllVeHliRkJqZ1Q4d2lEWHFBRU8wSVY4ZWRQeEE9PQo8L2RzOlNpZ25hdHVyZVZhbHVlPgo8L2RzOlNpZ25hdHVyZT48c2FtbDI6QXNzZXJ0aW9uIHhtbG5zOnNhbWwyPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6YXNzZXJ0aW9uIiBJRD0iQWUzZjQ5OGI4LTliMTctNDA3OC05ZDM1LTg2YTA4NDA4NDk5NSIgSXNzdWVJbnN0YW50PSIyMDIwLTA4LTA0VDIyOjA0OjQ1LjA3N1oiIFZlcnNpb249IjIuMCI+PHNhbWwyOklzc3Vlcj5hcC1lbnRpdHktaWQ8L3NhbWwyOklzc3Vlcj48c2FtbDI6U3ViamVjdD48c2FtbDI6TmFtZUlEPnRlc3RAc2FtbC51c2VyPC9zYW1sMjpOYW1lSUQ+PHNhbWwyOlN1YmplY3RDb25maXJtYXRpb24gTWV0aG9kPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6Y206YmVhcmVyIj48c2FtbDI6U3ViamVjdENvbmZpcm1hdGlvbkRhdGEgTm90QmVmb3JlPSIyMDIwLTA4LTA0VDIxOjU5OjQ1LjA5MFoiIE5vdE9uT3JBZnRlcj0iMjA0MC0wNy0zMFQyMjowNTowNi4wODhaIiBSZWNpcGllbnQ9Imh0dHBzOi8vcnAuZXhhbXBsZS5vcmcvYWNzIi8+PC9zYW1sMjpTdWJqZWN0Q29uZmlybWF0aW9uPjwvc2FtbDI6U3ViamVjdD48c2FtbDI6Q29uZGl0aW9ucyBOb3RCZWZvcmU9IjIwMjAtMDgtMDRUMjE6NTk6NDUuMDgwWiIgTm90T25PckFmdGVyPSIyMDQwLTA3LTMwVDIyOjA1OjA2LjA4N1oiLz48L3NhbWwyOkFzc2VydGlvbj48L3NhbWwycDpSZXNwb25zZT4="; @Autowired private ConfigurableApplicationContext context; @@ -131,7 +134,9 @@ public class Saml2LoginConfigurerTests { MockMvc mvc; private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockFilterChain filterChain; @Before @@ -157,7 +162,8 @@ public class Saml2LoginConfigurerTests { } @Test - public void saml2LoginWhenConfiguringAuthenticationDefaultsUsingCustomizerThenTheProviderIsConfigured() throws Exception { + public void saml2LoginWhenConfiguringAuthenticationDefaultsUsingCustomizerThenTheProviderIsConfigured() + throws Exception { // setup application context this.spring.register(Saml2LoginConfigWithAuthenticationDefaultsWithPostProcessor.class).autowire(); validateSaml2WebSsoAuthenticationFilterConfiguration(); @@ -166,14 +172,11 @@ public class Saml2LoginConfigurerTests { @Test public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() throws Exception { this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire(); - - Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); - Saml2AuthenticationRequestContextResolver resolver = - CustomAuthenticationRequestContextResolver.resolver; - when(resolver.resolve(any(HttpServletRequest.class))) - .thenReturn(context); - this.mvc.perform(get("/saml2/authenticate/registration-id")) - .andExpect(status().isFound()); + Saml2AuthenticationRequestContext context = TestSaml2AuthenticationRequestContexts + .authenticationRequestContext().build(); + Saml2AuthenticationRequestContextResolver resolver = CustomAuthenticationRequestContextResolver.resolver; + given(resolver.resolve(any(HttpServletRequest.class))).willReturn(context); + this.mvc.perform(get("/saml2/authenticate/registration-id")).andExpect(status().isFound()); verify(resolver).resolve(any(HttpServletRequest.class)); } @@ -181,10 +184,8 @@ public class Saml2LoginConfigurerTests { public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception { this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire(); - MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")) - .andReturn(); - UriComponents components = UriComponentsBuilder - .fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); + MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn(); + UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build(); String samlRequest = components.getQueryParams().getFirst("SAMLRequest"); String decoded = URLDecoder.decode(samlRequest, "UTF-8"); String inflated = samlInflate(samlDecode(decoded)); @@ -194,67 +195,91 @@ public class Saml2LoginConfigurerTests { @Test public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception { this.spring.register(CustomAuthenticationConverter.class).autowire(); - RelyingPartyRegistration relyingPartyRegistration = noCredentials() - .assertingPartyDetails(party -> party - .verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())) - ) + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials( + (c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) .build(); String response = new String(samlDecode(SIGNED_RESPONSE)); - when(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class))) - .thenReturn(new Saml2AuthenticationToken(relyingPartyRegistration, response)); - this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()) - .param("SAMLResponse", SIGNED_RESPONSE)) - .andExpect(redirectedUrl("/")); + given(CustomAuthenticationConverter.authenticationConverter.convert(any(HttpServletRequest.class))) + .willReturn(new Saml2AuthenticationToken(relyingPartyRegistration, response)); + // @formatter:off + MockHttpServletRequestBuilder request = post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()) + .param("SAMLResponse", SIGNED_RESPONSE); + // @formatter:on + this.mvc.perform(request).andExpect(redirectedUrl("/")); verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class)); } private void validateSaml2WebSsoAuthenticationFilterConfiguration() { // get the OpenSamlAuthenticationProvider Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain); - AuthenticationManager manager = - (AuthenticationManager) ReflectionTestUtils.getField(filter, "authenticationManager"); + AuthenticationManager manager = (AuthenticationManager) ReflectionTestUtils.getField(filter, + "authenticationManager"); ProviderManager pm = (ProviderManager) manager; - AuthenticationProvider provider = pm.getProviders() - .stream() - .filter(p -> p instanceof OpenSamlAuthenticationProvider) - .findFirst() - .get(); + AuthenticationProvider provider = pm.getProviders().stream() + .filter((p) -> p instanceof OpenSamlAuthenticationProvider).findFirst().get(); Assert.assertSame(AUTHORITIES_EXTRACTOR, ReflectionTestUtils.getField(provider, "authoritiesExtractor")); Assert.assertSame(AUTHORITIES_MAPPER, ReflectionTestUtils.getField(provider, "authoritiesMapper")); - Assert.assertSame(RESPONSE_TIME_VALIDATION_SKEW, ReflectionTestUtils.getField(provider, "responseTimeValidationSkew")); + Assert.assertSame(RESPONSE_TIME_VALIDATION_SKEW, + ReflectionTestUtils.getField(provider, "responseTimeValidationSkew")); } private Saml2WebSsoAuthenticationFilter getSaml2SsoFilter(FilterChainProxy chain) { - return (Saml2WebSsoAuthenticationFilter) chain.getFilters("/login/saml2/sso/test") - .stream() - .filter(f -> f instanceof Saml2WebSsoAuthenticationFilter) - .findFirst() - .get(); + return (Saml2WebSsoAuthenticationFilter) chain.getFilters("/login/saml2/sso/test").stream() + .filter((f) -> f instanceof Saml2WebSsoAuthenticationFilter).findFirst().get(); } private void performSaml2Login(String expected) throws IOException, ServletException { // setup authentication parameters - this.request.setParameter( - "SAMLResponse", - Base64.getEncoder().encodeToString( - "saml2-xml-response-object".getBytes() - ) - ); - - + this.request.setParameter("SAMLResponse", + Base64.getEncoder().encodeToString("saml2-xml-response-object".getBytes())); // perform test this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); - // assertions Authentication authentication = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(this.request, this.response)) - .getAuthentication(); + .loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication(); Assert.assertNotNull("Expected a valid authentication object.", authentication); assertThat(authentication.getAuthorities()).hasSize(1); - assertThat(authentication.getAuthorities()).first() - .isInstanceOf(SimpleGrantedAuthority.class).hasToString(expected); + assertThat(authentication.getAuthorities()).first().isInstanceOf(SimpleGrantedAuthority.class) + .hasToString(expected); } + private static org.apache.commons.codec.binary.Base64 BASE64 = new org.apache.commons.codec.binary.Base64(0, + new byte[] { '\n' }); + + private static byte[] samlDecode(String s) { + return BASE64.decode(s); + } + + private static String samlInflate(byte[] b) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); + iout.write(b); + iout.finish(); + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } + catch (IOException ex) { + throw new Saml2Exception("Unable to inflate string", ex); + } + } + + private static AuthenticationManager getAuthenticationManagerMock(String role) { + return new AuthenticationManager() { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + if (!supports(authentication.getClass())) { + throw new AuthenticationServiceException("not supported"); + } + return new Saml2Authentication(() -> "auth principal", "saml2 response", + Collections.singletonList(new SimpleGrantedAuthority(role))); + } + + public boolean supports(Class authentication) { + return authentication.isAssignableFrom(Saml2AuthenticationToken.class); + } + }; + } @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) @@ -262,12 +287,10 @@ public class Saml2LoginConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { - http.saml2Login() - .authenticationManager( - getAuthenticationManagerMock("ROLE_AUTH_MANAGER") - ); + http.saml2Login().authenticationManager(getAuthenticationManagerMock("ROLE_AUTH_MANAGER")); super.configure(http); } + } @EnableWebSecurity @@ -276,8 +299,7 @@ public class Saml2LoginConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { - ObjectPostProcessor processor - = new ObjectPostProcessor() { + ObjectPostProcessor processor = new ObjectPostProcessor() { @Override public O postProcess(O provider) { provider.setResponseTimeValidationSkew(RESPONSE_TIME_VALIDATION_SKEW); @@ -286,33 +308,35 @@ public class Saml2LoginConfigurerTests { return provider; } }; - - http.saml2Login() - .addObjectPostProcessor(processor) - ; + http.saml2Login().addObjectPostProcessor(processor); super.configure(http); } + } @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter { - private static final Saml2AuthenticationRequestContextResolver resolver = - mock(Saml2AuthenticationRequestContextResolver.class); + + private static final Saml2AuthenticationRequestContextResolver resolver = mock( + Saml2AuthenticationRequestContextResolver.class); @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .authorizeRequests(authz -> authz + .authorizeRequests((authz) -> authz .anyRequest().authenticated() ) .saml2Login(withDefaults()); + // @formatter:on } @Bean Saml2AuthenticationRequestContextResolver resolver() { return resolver; } + } @EnableWebSecurity @@ -321,66 +345,41 @@ public class Saml2LoginConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .authorizeRequests(authz -> authz + .authorizeRequests((authz) -> authz .anyRequest().authenticated() ) - .saml2Login(saml2 -> {}); + .saml2Login((saml2) -> { + }); + // @formatter:on } @Bean Saml2AuthenticationRequestFactory authenticationRequestFactory() { - OpenSamlAuthenticationRequestFactory authenticationRequestFactory = - new OpenSamlAuthenticationRequestFactory(); - authenticationRequestFactory.setAuthenticationRequestContextConverter( - context -> { - AuthnRequest authnRequest = authnRequest(); - authnRequest.setForceAuthn(true); - return authnRequest; - }); + OpenSamlAuthenticationRequestFactory authenticationRequestFactory = new OpenSamlAuthenticationRequestFactory(); + authenticationRequestFactory.setAuthenticationRequestContextConverter((context) -> { + AuthnRequest authnRequest = TestOpenSamlObjects.authnRequest(); + authnRequest.setForceAuthn(true); + return authnRequest; + }); return authenticationRequestFactory; } + } @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter { + static final AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); @Override protected void configure(HttpSecurity http) throws Exception { - http - .authorizeRequests(authz -> authz - .anyRequest().authenticated() - ) - .saml2Login(saml2 -> saml2 - .authenticationConverter(authenticationConverter) - ); + http.authorizeRequests((authz) -> authz.anyRequest().authenticated()) + .saml2Login((saml2) -> saml2.authenticationConverter(authenticationConverter)); } - } - private static AuthenticationManager getAuthenticationManagerMock(String role) { - return new AuthenticationManager() { - - @Override - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { - if (!supports(authentication.getClass())) { - throw new AuthenticationServiceException("not supported"); - } - return new Saml2Authentication( - () -> "auth principal", - "saml2 response", - Collections.singletonList( - new SimpleGrantedAuthority(role) - ) - ); - } - - public boolean supports(Class authentication) { - return authentication.isAssignableFrom(Saml2AuthenticationToken.class); - } - }; } static class Saml2LoginConfigBeans { @@ -393,29 +392,11 @@ public class Saml2LoginConfigurerTests { @Bean RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() { RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); - when(repository.findByRegistrationId(anyString())) - .thenReturn(relyingPartyRegistration().build()); + given(repository.findByRegistrationId(anyString())) + .willReturn(TestRelyingPartyRegistrations.relyingPartyRegistration().build()); return repository; } + } - private static org.apache.commons.codec.binary.Base64 BASE64 = - new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'}); - - private static byte[] samlDecode(String s) { - return BASE64.decode(s); - } - - private static String samlInflate(byte[] b) { - try { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); - iout.write(b); - iout.finish(); - return new String(out.toByteArray(), UTF_8); - } - catch (IOException e) { - throw new Saml2Exception("Unable to inflate string", e); - } - } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestSaml2Credentials.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestSaml2Credentials.java index 4a6922f4df..32d753450f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestSaml2Credentials.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestSaml2Credentials.java @@ -16,101 +16,105 @@ package org.springframework.security.config.annotation.web.configurers.saml2; -import org.springframework.security.converter.RsaKeyConverters; -import org.springframework.security.saml2.credentials.Saml2X509Credential; - import java.io.ByteArrayInputStream; import java.nio.charset.StandardCharsets; import java.security.PrivateKey; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION; +import org.springframework.security.converter.RsaKeyConverters; +import org.springframework.security.saml2.credentials.Saml2X509Credential; +import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType; /** * Preconfigured SAML credentials for SAML integration tests. */ -public class TestSaml2Credentials { +public final class TestSaml2Credentials { + + private TestSaml2Credentials() { + } static Saml2X509Credential verificationCertificate() { - String certificate = "-----BEGIN CERTIFICATE-----\n" + - "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" + - "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" + - "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" + - "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" + - "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" + - "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" + - "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" + - "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" + - "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" + - "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" + - "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" + - "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" + - "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" + - "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" + - "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" + - "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" + - "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" + - "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" + - "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" + - "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" + - "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" + - "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + - "-----END CERTIFICATE-----"; - return new Saml2X509Credential( - x509Certificate(certificate), - VERIFICATION - ); + // @formatter:off + String certificate = "-----BEGIN CERTIFICATE-----\n" + + "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" + + "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" + + "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" + + "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" + + "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" + + "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" + + "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" + + "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" + + "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" + + "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" + + "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" + + "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" + + "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" + + "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" + + "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" + + "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" + + "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" + + "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" + + "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" + + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" + + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" + + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + + "-----END CERTIFICATE-----"; + // @formatter:on + return new Saml2X509Credential(x509Certificate(certificate), Saml2X509CredentialType.VERIFICATION); } static X509Certificate x509Certificate(String source) { try { final CertificateFactory factory = CertificateFactory.getInstance("X.509"); - return (X509Certificate) factory.generateCertificate( - new ByteArrayInputStream(source.getBytes(StandardCharsets.UTF_8)) - ); - } catch (Exception e) { - throw new IllegalArgumentException(e); + return (X509Certificate) factory + .generateCertificate(new ByteArrayInputStream(source.getBytes(StandardCharsets.UTF_8))); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); } } static Saml2X509Credential signingCredential() { - String key = "-----BEGIN PRIVATE KEY-----\n" + - "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + - "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + - "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + - "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + - "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + - "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + - "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + - "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + - "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + - "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + - "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + - "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + - "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + - "INrtuLp4YHbgk1mi\n" + - "-----END PRIVATE KEY-----"; - String certificate = "-----BEGIN CERTIFICATE-----\n" + - "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + - "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + - "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + - "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + - "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + - "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + - "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + - "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + - "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + - "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + - "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + - "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + - "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + - "-----END CERTIFICATE-----"; + // @formatter:off + String key = "-----BEGIN PRIVATE KEY-----\n" + + "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + + "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + + "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + + "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + + "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + + "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + + "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + + "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + + "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + + "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + + "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + + "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + + "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + + "INrtuLp4YHbgk1mi\n" + + "-----END PRIVATE KEY-----"; + // @formatter:on + // @formatter:off + String certificate = "-----BEGIN CERTIFICATE-----\n" + + "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + + "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + + "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + + "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + + "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + + "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + + "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + + "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + + "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + + "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + + "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + + "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + + "-----END CERTIFICATE-----"; + // @formatter:on PrivateKey pk = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(key.getBytes())); X509Certificate cert = x509Certificate(certificate); - return new Saml2X509Credential(pk, cert, SIGNING, DECRYPTION); + return new Saml2X509Credential(pk, cert, Saml2X509CredentialType.SIGNING, Saml2X509CredentialType.DECRYPTION); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistryTests.java index 1893a4f082..68d23d130e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistryTests.java @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.messaging; +import java.util.Collection; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; @@ -29,13 +33,12 @@ import org.springframework.security.messaging.access.intercept.MessageSecurityMe import org.springframework.security.messaging.util.matcher.MessageMatcher; import org.springframework.util.AntPathMatcher; -import java.util.Collection; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; @RunWith(MockitoJUnitRunner.class) public class MessageSecurityMetadataSourceRegistryTests { + @Mock private MessageMatcher matcher; @@ -45,12 +48,13 @@ public class MessageSecurityMetadataSourceRegistryTests { @Before public void setup() { - messages = new MessageSecurityMetadataSourceRegistry(); - message = MessageBuilder - .withPayload("Hi") + this.messages = new MessageSecurityMetadataSourceRegistry(); + // @formatter:off + this.message = MessageBuilder.withPayload("Hi") .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "location") - .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.MESSAGE).build(); + .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.MESSAGE) + .build(); + // @formatter:on } // See @@ -58,248 +62,275 @@ public class MessageSecurityMetadataSourceRegistryTests { // https://jira.spring.io/browse/SPR-11660 @Test public void simpDestMatchersCustom() { - message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "price.stock.1.2").build(); - messages.simpDestPathMatcher(new AntPathMatcher(".")) - .simpDestMatchers("price.stock.*").permitAll(); - + // @formatter:off + this.message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "price.stock.1.2") + .build(); + // @formatter:on + this.messages.simpDestPathMatcher(new AntPathMatcher(".")).simpDestMatchers("price.stock.*").permitAll(); assertThat(getAttribute()).isNull(); - - message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "price.stock.1.2").build(); - messages.simpDestPathMatcher(new AntPathMatcher(".")) - .simpDestMatchers("price.stock.**").permitAll(); - + // @formatter:off + this.message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "price.stock.1.2") + .build(); + // @formatter:on + this.messages.simpDestPathMatcher(new AntPathMatcher(".")).simpDestMatchers("price.stock.**").permitAll(); assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpDestMatchersCustomSetAfterMatchersDoesNotMatter() { - message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "price.stock.1.2").build(); - messages.simpDestMatchers("price.stock.*").permitAll() - .simpDestPathMatcher(new AntPathMatcher(".")); - + // @formatter:off + this.message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "price.stock.1.2") + .build(); + // @formatter:on + this.messages.simpDestMatchers("price.stock.*").permitAll().simpDestPathMatcher(new AntPathMatcher(".")); assertThat(getAttribute()).isNull(); - - message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "price.stock.1.2").build(); - messages.simpDestMatchers("price.stock.**").permitAll() - .simpDestPathMatcher(new AntPathMatcher(".")); - + // @formatter:off + this.message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "price.stock.1.2") + .build(); + // @formatter:on + this.messages.simpDestMatchers("price.stock.**").permitAll().simpDestPathMatcher(new AntPathMatcher(".")); assertThat(getAttribute()).isEqualTo("permitAll"); } @Test(expected = IllegalArgumentException.class) public void pathMatcherNull() { - messages.simpDestPathMatcher(null); + this.messages.simpDestPathMatcher(null); } @Test public void matchersFalse() { - messages.matchers(matcher).permitAll(); - + this.messages.matchers(this.matcher).permitAll(); assertThat(getAttribute()).isNull(); } @Test public void matchersTrue() { - when(matcher.matches(message)).thenReturn(true); - messages.matchers(matcher).permitAll(); - + given(this.matcher.matches(this.message)).willReturn(true); + this.messages.matchers(this.matcher).permitAll(); assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpDestMatchersExact() { - messages.simpDestMatchers("location").permitAll(); - + this.messages.simpDestMatchers("location").permitAll(); assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpDestMatchersMulti() { - messages.simpDestMatchers("admin/**", "api/**").hasRole("ADMIN") + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "api/**").hasRole("ADMIN") .simpDestMatchers("location").permitAll(); - + // @formatter:on assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpDestMatchersRole() { - messages.simpDestMatchers("admin/**", "location/**").hasRole("ADMIN") + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").hasRole("ADMIN") .anyMessage().denyAll(); - + // @formatter:on assertThat(getAttribute()).isEqualTo("hasRole('ROLE_ADMIN')"); } @Test public void simpDestMatchersAnyRole() { - messages.simpDestMatchers("admin/**", "location/**").hasAnyRole("ADMIN", "ROOT") + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").hasAnyRole("ADMIN", "ROOT") .anyMessage().denyAll(); - + // @formatter:on assertThat(getAttribute()).isEqualTo("hasAnyRole('ROLE_ADMIN','ROLE_ROOT')"); } @Test public void simpDestMatchersAuthority() { - messages.simpDestMatchers("admin/**", "location/**").hasAuthority("ROLE_ADMIN") + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").hasAuthority("ROLE_ADMIN") .anyMessage().fullyAuthenticated(); - + // @formatter:on assertThat(getAttribute()).isEqualTo("hasAuthority('ROLE_ADMIN')"); } @Test public void simpDestMatchersAccess() { String expected = "hasRole('ROLE_ADMIN') and fullyAuthenticated"; - messages.simpDestMatchers("admin/**", "location/**").access(expected) - .anyMessage().denyAll(); - + this.messages.simpDestMatchers("admin/**", "location/**").access(expected).anyMessage().denyAll(); assertThat(getAttribute()).isEqualTo(expected); } @Test public void simpDestMatchersAnyAuthority() { - messages.simpDestMatchers("admin/**", "location/**") - .hasAnyAuthority("ROLE_ADMIN", "ROLE_ROOT").anyMessage().denyAll(); - + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").hasAnyAuthority("ROLE_ADMIN", "ROLE_ROOT") + .anyMessage().denyAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("hasAnyAuthority('ROLE_ADMIN','ROLE_ROOT')"); } @Test public void simpDestMatchersRememberMe() { - messages.simpDestMatchers("admin/**", "location/**").rememberMe().anyMessage() - .denyAll(); - + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").rememberMe() + .anyMessage().denyAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("rememberMe"); } @Test public void simpDestMatchersAnonymous() { - messages.simpDestMatchers("admin/**", "location/**").anonymous().anyMessage() - .denyAll(); - + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").anonymous() + .anyMessage().denyAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("anonymous"); } @Test public void simpDestMatchersFullyAuthenticated() { - messages.simpDestMatchers("admin/**", "location/**").fullyAuthenticated() + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").fullyAuthenticated() .anyMessage().denyAll(); - + // @formatter:on assertThat(getAttribute()).isEqualTo("fullyAuthenticated"); } @Test public void simpDestMatchersDenyAll() { - messages.simpDestMatchers("admin/**", "location/**").denyAll().anyMessage() - .permitAll(); - + // @formatter:off + this.messages + .simpDestMatchers("admin/**", "location/**").denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("denyAll"); } @Test public void simpDestMessageMatchersNotMatch() { - messages.simpMessageDestMatchers("admin/**").denyAll().anyMessage().permitAll(); - + // @formatter:off + this.messages. + simpMessageDestMatchers("admin/**").denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpDestMessageMatchersMatch() { - messages.simpMessageDestMatchers("location/**").denyAll().anyMessage() - .permitAll(); - + // @formatter:off + this.messages + .simpMessageDestMatchers("location/**").denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("denyAll"); } @Test public void simpDestSubscribeMatchersNotMatch() { - messages.simpSubscribeDestMatchers("location/**").denyAll().anyMessage() - .permitAll(); - + // @formatter:off + this.messages + .simpSubscribeDestMatchers("location/**").denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpDestSubscribeMatchersMatch() { - message = MessageBuilder - .fromMessage(message) - .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.SUBSCRIBE).build(); - - messages.simpSubscribeDestMatchers("location/**").denyAll().anyMessage() - .permitAll(); - + // @formatter:off + this.message = MessageBuilder.fromMessage(this.message) + .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.SUBSCRIBE) + .build(); + this.messages + .simpSubscribeDestMatchers("location/**").denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("denyAll"); } @Test public void nullDestMatcherNotMatches() { - messages.nullDestMatcher().denyAll().anyMessage().permitAll(); - + // @formatter:off + this.messages + .nullDestMatcher().denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void nullDestMatcherMatch() { - message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.CONNECT).build(); - - messages.nullDestMatcher().denyAll().anyMessage().permitAll(); - + // @formatter:off + this.message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.CONNECT) + .build(); + this.messages + .nullDestMatcher().denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("denyAll"); } @Test public void simpTypeMatchersMatch() { - messages.simpTypeMatchers(SimpMessageType.MESSAGE).denyAll().anyMessage() - .permitAll(); - + // @formatter:off + this.messages + .simpTypeMatchers(SimpMessageType.MESSAGE).denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("denyAll"); } @Test public void simpTypeMatchersMatchMulti() { - messages.simpTypeMatchers(SimpMessageType.CONNECT, SimpMessageType.MESSAGE) - .denyAll().anyMessage().permitAll(); - + // @formatter:off + this.messages + .simpTypeMatchers(SimpMessageType.CONNECT, SimpMessageType.MESSAGE).denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("denyAll"); } @Test public void simpTypeMatchersNotMatch() { - messages.simpTypeMatchers(SimpMessageType.CONNECT).denyAll().anyMessage() - .permitAll(); - + // @formatter:off + this.messages + .simpTypeMatchers(SimpMessageType.CONNECT).denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("permitAll"); } @Test public void simpTypeMatchersNotMatchMulti() { - messages.simpTypeMatchers(SimpMessageType.CONNECT, SimpMessageType.DISCONNECT) - .denyAll().anyMessage().permitAll(); - + // @formatter:off + this.messages + .simpTypeMatchers(SimpMessageType.CONNECT, SimpMessageType.DISCONNECT).denyAll() + .anyMessage().permitAll(); + // @formatter:on assertThat(getAttribute()).isEqualTo("permitAll"); } private String getAttribute() { - MessageSecurityMetadataSource source = messages.createMetadataSource(); - Collection attrs = source.getAttributes(message); + MessageSecurityMetadataSource source = this.messages.createMetadataSource(); + Collection attrs = source.getAttributes(this.message); if (attrs == null) { return null; } assertThat(attrs).hasSize(1); return attrs.iterator().next().toString(); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java index 4be1c2b325..fdee0ecb98 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java @@ -16,9 +16,13 @@ package org.springframework.security.config.annotation.web.reactive; +import java.nio.charset.StandardCharsets; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import reactor.core.publisher.Mono; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; @@ -64,10 +68,7 @@ import org.springframework.web.reactive.config.DelegatingWebFluxConfiguration; import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.result.view.AbstractView; -import reactor.core.publisher.Mono; - -import java.nio.charset.StandardCharsets; -import java.security.Principal; +import org.springframework.web.server.WebFilter; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf; @@ -79,6 +80,7 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock @RunWith(SpringRunner.class) @SecurityTestExecutionListeners public class EnableWebFluxSecurityTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -88,49 +90,48 @@ public class EnableWebFluxSecurityTests { @Test public void defaultRequiresAuthentication() { this.spring.register(Config.class).autowire(); - + // @formatter:off WebTestClient client = WebTestClientBuilder - .bindToWebFilters(this.springSecurityFilterChain) - .build(); - + .bindToWebFilters(this.springSecurityFilterChain) + .build(); client.get() - .uri("/") - .exchange() - .expectStatus().isUnauthorized() - .expectBody().isEmpty(); + .uri("/") + .exchange() + .expectStatus().isUnauthorized() + .expectBody().isEmpty(); + // @formatter:on } // gh-4831 @Test public void defaultMediaAllThenUnAuthorized() { this.spring.register(Config.class).autowire(); - + // @formatter:off WebTestClient client = WebTestClientBuilder - .bindToWebFilters(this.springSecurityFilterChain) - .build(); - + .bindToWebFilters(this.springSecurityFilterChain) + .build(); client.get() - .uri("/") - .accept(MediaType.ALL) - .exchange() - .expectStatus().isUnauthorized() - .expectBody().isEmpty(); + .uri("/") + .accept(MediaType.ALL) + .exchange() + .expectStatus().isUnauthorized() + .expectBody().isEmpty(); + // @formatter:on } @Test public void authenticateWhenBasicThenNoSession() { this.spring.register(Config.class).autowire(); - + // @formatter:off WebTestClient client = WebTestClientBuilder - .bindToWebFilters(this.springSecurityFilterChain) - .build(); - + .bindToWebFilters(this.springSecurityFilterChain) + .build(); FluxExchangeResult result = client.get() - .headers(headers -> headers.setBasicAuth("user", "password")) - .exchange() - .expectStatus() - .isOk() - .returnResult(String.class); + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus().isOk() + .returnResult(String.class); + // @formatter:on result.assertWithDiagnostics(() -> assertThat(result.getResponseCookies().isEmpty())); } @@ -140,229 +141,142 @@ public class EnableWebFluxSecurityTests { Authentication currentPrincipal = new TestingAuthenticationToken("user", "password", "ROLE_USER"); WebSessionServerSecurityContextRepository contextRepository = new WebSessionServerSecurityContextRepository(); SecurityContext context = new SecurityContextImpl(currentPrincipal); - WebTestClient client = WebTestClientBuilder.bindToWebFilters( - (exchange, chain) -> contextRepository.save(exchange, context) + // @formatter:off + WebFilter contextRepositoryWebFilter = (exchange, chain) -> contextRepository.save(exchange, context) .switchIfEmpty(chain.filter(exchange)) - .flatMap(e -> chain.filter(exchange)), - this.springSecurityFilterChain, - (exchange, chain) -> - ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .flatMap( principal -> exchange.getResponse() - .writeWith(Mono.just(toDataBuffer(principal.getName())))) - ).build(); + .flatMap((e) -> chain.filter(exchange)); + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(contextRepositoryWebFilter, this.springSecurityFilterChain, writePrincipalWebFilter()) + .build(); + client.get() + .uri("/") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).consumeWith((result) -> assertThat(result.getResponseBody()).isEqualTo(currentPrincipal.getName())); + // @formatter:on + } - client - .get() - .uri("/") - .exchange() - .expectStatus().isOk() - .expectBody(String.class).consumeWith( result -> assertThat(result.getResponseBody()).isEqualTo(currentPrincipal.getName())); + private WebFilter writePrincipalWebFilter() { + // @formatter:off + return (exchange, chain) -> ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .flatMap((principal) -> exchange.getResponse() + .writeWith(Mono.just(toDataBuffer(principal.getName()))) + ); + // @formatter:on } @Test public void defaultPopulatesReactorContextWhenAuthenticating() { this.spring.register(Config.class).autowire(); - WebTestClient client = WebTestClientBuilder.bindToWebFilters( - this.springSecurityFilterChain, - (exchange, chain) -> - ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .flatMap( principal -> exchange.getResponse() - .writeWith(Mono.just(toDataBuffer(principal.getName())))) - ) - .build(); - - client - .get() - .uri("/") - .headers(headers -> headers.setBasicAuth("user", "password")) - .exchange() - .expectStatus().isOk() - .expectBody(String.class).consumeWith( result -> assertThat(result.getResponseBody()).isEqualTo("user")); + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(this.springSecurityFilterChain, writePrincipalWebFilter()) + .build(); + client.get() + .uri("/") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus().isOk() + .expectBody(String.class).consumeWith((result) -> assertThat(result.getResponseBody()).isEqualTo("user")); + // @formatter:on } @Test public void requestDataValueProcessor() { this.spring.register(Config.class).autowire(); - ConfigurableApplicationContext context = this.spring.getContext(); - CsrfRequestDataValueProcessor rdvp = context.getBean(AbstractView.REQUEST_DATA_VALUE_PROCESSOR_BEAN_NAME, CsrfRequestDataValueProcessor.class); + CsrfRequestDataValueProcessor rdvp = context.getBean(AbstractView.REQUEST_DATA_VALUE_PROCESSOR_BEAN_NAME, + CsrfRequestDataValueProcessor.class); assertThat(rdvp).isNotNull(); } - @EnableWebFluxSecurity - @Import(ReactiveAuthenticationTestConfiguration.class) - static class Config { - } - @Test public void passwordEncoderBeanIsUsed() { this.spring.register(CustomPasswordEncoderConfig.class).autowire(); - WebTestClient client = WebTestClientBuilder.bindToWebFilters( - this.springSecurityFilterChain, - (exchange, chain) -> - ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .flatMap( principal -> exchange.getResponse() - .writeWith(Mono.just(toDataBuffer(principal.getName())))) - ) - .build(); - - client - .get() - .uri("/") - .headers(headers -> headers.setBasicAuth("user", "password")) - .exchange() - .expectStatus().isOk() - .expectBody(String.class).consumeWith( result -> assertThat(result.getResponseBody()).isEqualTo("user")); - } - - @EnableWebFluxSecurity - static class CustomPasswordEncoderConfig { - @Bean - public ReactiveUserDetailsService userDetailsService(PasswordEncoder encoder) { - return new MapReactiveUserDetailsService(User.withUsername("user") - .password(encoder.encode("password")) - .roles("USER") - .build() - ); - } - - @Bean - public static PasswordEncoder passwordEncoder() { - return new BCryptPasswordEncoder(); - } + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(this.springSecurityFilterChain, writePrincipalWebFilter()) + .build(); + client.get().uri("/").headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange().expectStatus().isOk() + .expectBody(String.class) + .consumeWith((result) -> assertThat(result.getResponseBody()).isEqualTo("user")); + // @formatter:on } @Test public void passwordUpdateManagerUsed() { this.spring.register(MapReactiveUserDetailsServiceConfig.class).autowire(); - WebTestClient client = WebTestClientBuilder.bindToWebFilters(this.springSecurityFilterChain).build(); - - client - .get() + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(this.springSecurityFilterChain) + .build(); + client.get() .uri("/") - .headers(h -> h.setBasicAuth("user", "password")) + .headers((h) -> h.setBasicAuth("user", "password")) .exchange() .expectStatus().isOk(); - + // @formatter:on ReactiveUserDetailsService users = this.spring.getContext().getBean(ReactiveUserDetailsService.class); assertThat(users.findByUsername("user").block().getPassword()).startsWith("{bcrypt}"); } - @EnableWebFluxSecurity - static class MapReactiveUserDetailsServiceConfig { - @Bean - public MapReactiveUserDetailsService userDetailsService() { - return new MapReactiveUserDetailsService(User.withUsername("user") - .password("{noop}password") - .roles("USER") - .build() - ); - } - } - @Test public void formLoginWorks() { this.spring.register(Config.class).autowire(); - WebTestClient client = WebTestClientBuilder.bindToWebFilters( - this.springSecurityFilterChain, - (exchange, chain) -> - Mono.subscriberContext() - .flatMap( c -> c.>get(Authentication.class)) - .flatMap( principal -> exchange.getResponse() - .writeWith(Mono.just(toDataBuffer(principal.getName())))) - ) - .build(); - - + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(this.springSecurityFilterChain, writePrincipalWebFilter()) + .build(); + // @formatter:on MultiValueMap data = new LinkedMultiValueMap<>(); data.add("username", "user"); data.add("password", "password"); - client - .mutateWith(csrf()) - .post() - .uri("/login") - .body(BodyInserters.fromFormData(data)) - .exchange() - .expectStatus().is3xxRedirection() - .expectHeader().valueMatches("Location", "/"); + // @formatter:off + client.mutateWith(csrf()) + .post() + .uri("/login") + .body(BodyInserters.fromFormData(data)) + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueMatches("Location", "/"); + // @formatter:on } @Test public void multiWorks() { this.spring.register(MultiSecurityHttpConfig.class).autowire(); - WebTestClient client = WebTestClientBuilder.bindToWebFilters(this.springSecurityFilterChain).build(); - + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(this.springSecurityFilterChain) + .build(); client.get() - .uri("/api/test") - .exchange() - .expectStatus().isUnauthorized() - .expectBody().isEmpty(); - + .uri("/api/test") + .exchange() + .expectStatus().isUnauthorized() + .expectBody().isEmpty(); client.get() - .uri("/test") - .exchange() - .expectStatus().isOk(); - } - - @EnableWebFluxSecurity - @Import(ReactiveAuthenticationTestConfiguration.class) - static class MultiSecurityHttpConfig { - @Order(Ordered.HIGHEST_PRECEDENCE) - @Bean - public SecurityWebFilterChain apiHttpSecurity( - ServerHttpSecurity http) { - http.securityMatcher(new PathPatternParserServerWebExchangeMatcher("/api/**")) - .authorizeExchange().anyExchange().denyAll(); - return http.build(); - } - - @Bean - public SecurityWebFilterChain httpSecurity(ServerHttpSecurity http) { - return http.build(); - } + .uri("/test") + .exchange() + .expectStatus().isOk(); + // @formatter:on } @Test @WithMockUser public void authenticationPrincipalArgumentResolverWhenSpelThenWorks() { this.spring.register(AuthenticationPrincipalConfig.class).autowire(); - - WebTestClient client = WebTestClient.bindToApplicationContext(this.spring.getContext()).build(); - + // @formatter:off + WebTestClient client = WebTestClient + .bindToApplicationContext(this.spring.getContext()) + .build(); client.get() - .uri("/spel") - .exchange() - .expectStatus().isOk() - .expectBody(String.class).isEqualTo("user"); - } - - - @EnableWebFluxSecurity - @EnableWebFlux - @Import(ReactiveAuthenticationTestConfiguration.class) - static class AuthenticationPrincipalConfig { - - @Bean - public PrincipalBean principalBean() { - return new PrincipalBean(); - } - - static class PrincipalBean { - public String username(UserDetails user) { - return user.getUsername(); - } - } - - @RestController - public static class AuthenticationPrincipalResolver { - @GetMapping("/spel") - String username(@AuthenticationPrincipal(expression = "@principalBean.username(#this)") String username) { - return username; - } - } + .uri("/spel") + .exchange() + .expectStatus().isOk() + .expectBody(String.class).isEqualTo("user"); + // @formatter:on } private static DataBuffer toDataBuffer(String body) { @@ -374,85 +288,179 @@ public class EnableWebFluxSecurityTests { @Test public void enableWebFluxSecurityWhenNoConfigurationAnnotationThenBeanProxyingEnabled() { this.spring.register(BeanProxyEnabledByDefaultConfig.class).autowire(); - Child childBean = this.spring.getContext().getBean(Child.class); Parent parentBean = this.spring.getContext().getBean(Parent.class); - assertThat(parentBean.getChild()).isSameAs(childBean); } - @EnableWebFluxSecurity - @Import(ReactiveAuthenticationTestConfiguration.class) - static class BeanProxyEnabledByDefaultConfig { - @Bean - public Child child() { - return new Child(); - } - - @Bean - public Parent parent() { - return new Parent(child()); - } - } - @Test public void enableWebFluxSecurityWhenProxyBeanMethodsFalseThenBeanProxyingDisabled() { this.spring.register(BeanProxyDisabledConfig.class).autowire(); - Child childBean = this.spring.getContext().getBean(Child.class); Parent parentBean = this.spring.getContext().getBean(Parent.class); - assertThat(parentBean.getChild()).isNotSameAs(childBean); } + @Test + // gh-8596 + public void resolveAuthenticationPrincipalArgumentResolverFirstDoesNotCauseBeanCurrentlyInCreationException() { + this.spring.register(EnableWebFluxSecurityConfiguration.class, ReactiveAuthenticationTestConfiguration.class, + DelegatingWebFluxConfiguration.class).autowire(); + } + + @EnableWebFluxSecurity + @Import(ReactiveAuthenticationTestConfiguration.class) + static class Config { + + } + + @EnableWebFluxSecurity + static class CustomPasswordEncoderConfig { + + @Bean + ReactiveUserDetailsService userDetailsService(PasswordEncoder encoder) { + return new MapReactiveUserDetailsService( + User.withUsername("user").password(encoder.encode("password")).roles("USER").build()); + } + + @Bean + static PasswordEncoder passwordEncoder() { + return new BCryptPasswordEncoder(); + } + + } + + @EnableWebFluxSecurity + static class MapReactiveUserDetailsServiceConfig { + + @Bean + MapReactiveUserDetailsService userDetailsService() { + // @formatter:off + return new MapReactiveUserDetailsService(User.withUsername("user") + .password("{noop}password") + .roles("USER") + .build() + // @formatter:on + ); + } + + } + + @EnableWebFluxSecurity + @Import(ReactiveAuthenticationTestConfiguration.class) + static class MultiSecurityHttpConfig { + + @Order(Ordered.HIGHEST_PRECEDENCE) + @Bean + SecurityWebFilterChain apiHttpSecurity(ServerHttpSecurity http) { + http.securityMatcher(new PathPatternParserServerWebExchangeMatcher("/api/**")).authorizeExchange() + .anyExchange().denyAll(); + return http.build(); + } + + @Bean + SecurityWebFilterChain httpSecurity(ServerHttpSecurity http) { + return http.build(); + } + + } + + @EnableWebFluxSecurity + @EnableWebFlux + @Import(ReactiveAuthenticationTestConfiguration.class) + static class AuthenticationPrincipalConfig { + + @Bean + PrincipalBean principalBean() { + return new PrincipalBean(); + } + + static class PrincipalBean { + + public String username(UserDetails user) { + return user.getUsername(); + } + + } + + @RestController + static class AuthenticationPrincipalResolver { + + @GetMapping("/spel") + String username(@AuthenticationPrincipal(expression = "@principalBean.username(#this)") String username) { + return username; + } + + } + + } + + @EnableWebFluxSecurity + @Import(ReactiveAuthenticationTestConfiguration.class) + static class BeanProxyEnabledByDefaultConfig { + + @Bean + Child child() { + return new Child(); + } + + @Bean + Parent parent() { + return new Parent(child()); + } + + } + @Configuration(proxyBeanMethods = false) @EnableWebFluxSecurity @Import(ReactiveAuthenticationTestConfiguration.class) static class BeanProxyDisabledConfig { + @Bean - public Child child() { + Child child() { return new Child(); } @Bean - public Parent parent() { + Parent parent() { return new Parent(child()); } + } static class Parent { + private Child child; Parent(Child child) { this.child = child; } - public Child getChild() { - return child; + Child getChild() { + return this.child; } + } static class Child { + Child() { } - } - @Test - // gh-8596 - public void resolveAuthenticationPrincipalArgumentResolverFirstDoesNotCauseBeanCurrentlyInCreationException() { - this.spring.register(EnableWebFluxSecurityConfiguration.class, - ReactiveAuthenticationTestConfiguration.class, - DelegatingWebFluxConfiguration.class).autowire(); } @EnableWebFluxSecurity @Configuration(proxyBeanMethods = false) static class EnableWebFluxSecurityConfiguration { + /** - * It is necessary to Autowire AuthenticationPrincipalArgumentResolver because it triggers eager loading of - * AuthenticationPrincipalArgumentResolver bean which causes BeanCurrentlyInCreationException + * It is necessary to Autowire AuthenticationPrincipalArgumentResolver because it + * triggers eager loading of AuthenticationPrincipalArgumentResolver bean which + * causes BeanCurrentlyInCreationException */ @Autowired AuthenticationPrincipalArgumentResolver resolver; + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationBuilder.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationBuilder.java index f6703a9977..d8922ad6c1 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationBuilder.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationBuilder.java @@ -26,17 +26,21 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsService; * @author Rob Winch * @since 5.0 */ -public class ServerHttpSecurityConfigurationBuilder { +public final class ServerHttpSecurityConfigurationBuilder { + + private ServerHttpSecurityConfigurationBuilder() { + } + public static ServerHttpSecurity http() { return new ServerHttpSecurityConfiguration().httpSecurity(); } public static ServerHttpSecurity httpWithDefaultAuthentication() { ReactiveUserDetailsService reactiveUserDetailsService = ReactiveAuthenticationTestConfiguration - .userDetailsService(); + .userDetailsService(); ReactiveAuthenticationManager authenticationManager = new UserDetailsRepositoryReactiveAuthenticationManager( - reactiveUserDetailsService); - return http() - .authenticationManager(authenticationManager); + reactiveUserDetailsService); + return http().authenticationManager(authenticationManager); } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTest.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java similarity index 97% rename from config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTest.java rename to config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java index 31c07cc2e0..b791eba328 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTest.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.reactive; import org.junit.Rule; import org.junit.Test; + import org.springframework.context.annotation.Configuration; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration; @@ -30,7 +31,8 @@ import static org.assertj.core.api.Assertions.assertThat; * * @author Eleftheria Stein */ -public class ServerHttpSecurityConfigurationTest { +public class ServerHttpSecurityConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -39,7 +41,6 @@ public class ServerHttpSecurityConfigurationTest { this.spring.register(ServerHttpSecurityConfiguration.class, ReactiveAuthenticationTestConfiguration.class, WebFluxSecurityConfiguration.class).autowire(); ServerHttpSecurity serverHttpSecurity = this.spring.getContext().getBean(ServerHttpSecurity.class); - assertThat(serverHttpSecurity).isNotNull(); } @@ -48,11 +49,12 @@ public class ServerHttpSecurityConfigurationTest { this.spring.register(SubclassConfig.class, ReactiveAuthenticationTestConfiguration.class, WebFluxSecurityConfiguration.class).autowire(); ServerHttpSecurity serverHttpSecurity = this.spring.getContext().getBean(ServerHttpSecurity.class); - assertThat(serverHttpSecurity).isNotNull(); } @Configuration static class SubclassConfig extends ServerHttpSecurityConfiguration { + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfigurationTests.java index 0eeec4140a..43fb7fb1da 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfigurationTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web.reactive; import org.junit.Rule; import org.junit.Test; + import org.springframework.context.annotation.Configuration; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration; @@ -31,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Eleftheria Stein */ public class WebFluxSecurityConfigurationTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -39,7 +41,6 @@ public class WebFluxSecurityConfigurationTests { this.spring.register(ServerHttpSecurityConfiguration.class, ReactiveAuthenticationTestConfiguration.class, WebFluxSecurityConfiguration.class).autowire(); WebFilterChainProxy webFilterChainProxy = this.spring.getContext().getBean(WebFilterChainProxy.class); - assertThat(webFilterChainProxy).isNotNull(); } @@ -48,11 +49,12 @@ public class WebFluxSecurityConfigurationTests { this.spring.register(ServerHttpSecurityConfiguration.class, ReactiveAuthenticationTestConfiguration.class, WebFluxSecurityConfigurationTests.SubclassConfig.class).autowire(); WebFilterChainProxy webFilterChainProxy = this.spring.getContext().getBean(WebFilterChainProxy.class); - assertThat(webFilterChainProxy).isNotNull(); } @Configuration static class SubclassConfig extends WebFluxSecurityConfiguration { + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java index 2343208681..de09d8ec07 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.socket; +import java.util.HashMap; + import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; @@ -25,8 +29,6 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; -import static org.springframework.messaging.simp.SimpMessageType.*; - import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.support.GenericMessage; @@ -43,13 +45,11 @@ import org.springframework.web.socket.config.annotation.AbstractWebSocketMessage import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; - -import java.util.HashMap; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { + AnnotationConfigWebApplicationContext context; TestingAuthenticationToken messageUser; @@ -60,25 +60,22 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); - sessionAttr = "sessionAttr"; - messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); + this.token = new DefaultCsrfToken("header", "param", "token"); + this.sessionAttr = "sessionAttr"; + this.messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); } @After public void cleanup() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @Test public void securityMappings() { loadConfig(WebSocketSecurityConfig.class); - - clientInboundChannel().send( - message("/user/queue/errors", SimpMessageType.SUBSCRIBE)); - + clientInboundChannel().send(message("/user/queue/errors", SimpMessageType.SUBSCRIBE)); try { clientInboundChannel().send(message("/denyAll", SimpMessageType.MESSAGE)); fail("Expected Exception"); @@ -89,15 +86,15 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { } private void loadConfig(Class... configs) { - context = new AnnotationConfigWebApplicationContext(); - context.register(configs); - context.register(WebSocketConfig.class, SyncExecutorConfig.class); - context.setServletConfig(new MockServletConfig()); - context.refresh(); + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.register(WebSocketConfig.class, SyncExecutorConfig.class); + this.context.setServletConfig(new MockServletConfig()); + this.context.refresh(); } private MessageChannel clientInboundChannel() { - return context.getBean("clientInboundChannel", MessageChannel.class); + return this.context.getBean("clientInboundChannel", MessageChannel.class); } private Message message(String destination, SimpMessageType type) { @@ -111,8 +108,8 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { if (destination != null) { headers.setDestination(destination); } - if (messageUser != null) { - headers.setUser(messageUser); + if (this.messageUser != null) { + headers.setUser(this.messageUser); } return new GenericMessage<>("hi", headers.getMessageHeaders()); } @@ -121,14 +118,14 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { static class MyController { @MessageMapping("/authentication") - public void authentication(@AuthenticationPrincipal String un) { + void authentication(@AuthenticationPrincipal String un) { // ... do something ... } + } @Configuration - static class WebSocketSecurityConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { + static class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { @Override protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { @@ -138,18 +135,18 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { // <2> .simpDestMatchers("/app/**").hasRole("USER") // <3> - .simpSubscribeDestMatchers("/user/**", "/topic/friends/*") - .hasRole("USER") // <4> - .simpTypeMatchers(MESSAGE, SUBSCRIBE).denyAll() // <5> + .simpSubscribeDestMatchers("/user/**", "/topic/friends/*").hasRole("USER") // <4> + .simpTypeMatchers(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE).denyAll() // <5> .anyMessage().denyAll(); // <6> - } + } @Configuration @EnableWebSocketMessageBroker static class WebSocketConfig extends AbstractWebSocketMessageBrokerConfigurer { + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/chat").withSockJS(); } @@ -161,16 +158,20 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { } @Bean - public MyController myController() { + MyController myController() { return new MyController(); } + } @Configuration static class SyncExecutorConfig { + @Bean - public static SyncExecutorSubscribableChannelPostProcessor postProcessor() { + static SyncExecutorSubscribableChannelPostProcessor postProcessor() { return new SyncExecutorSubscribableChannelPostProcessor(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index 2b4887acd6..c44063c236 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.annotation.web.socket; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +package org.springframework.security.config.annotation.web.socket; import java.util.HashMap; import java.util.Map; @@ -26,6 +24,7 @@ import javax.servlet.http.HttpServletRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -78,7 +77,11 @@ import org.springframework.web.socket.server.support.HttpSessionHandshakeInterce import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler; import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { + AnnotationConfigWebApplicationContext context; TestingAuthenticationToken messageUser; @@ -89,24 +92,22 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); - sessionAttr = "sessionAttr"; - messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); + this.token = new DefaultCsrfToken("header", "param", "token"); + this.sessionAttr = "sessionAttr"; + this.messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); } @After public void cleanup() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @Test public void simpleRegistryMappings() { loadConfig(SockJsSecurityConfig.class); - clientInboundChannel().send(message("/permitAll")); - try { clientInboundChannel().send(message("/denyAll")); fail("Expected Exception"); @@ -119,8 +120,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Test public void annonymousSupported() { loadConfig(SockJsSecurityConfig.class); - - messageUser = null; + this.messageUser = null; clientInboundChannel().send(message("/permitAll")); } @@ -128,44 +128,36 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Test public void beanResolver() { loadConfig(SockJsSecurityConfig.class); - - messageUser = null; + this.messageUser = null; clientInboundChannel().send(message("/beanResolver")); } @Test public void addsAuthenticationPrincipalResolver() { loadConfig(SockJsSecurityConfig.class); - MessageChannel messageChannel = clientInboundChannel(); Message message = message("/permitAll/authentication"); messageChannel.send(message); - - assertThat(context.getBean(MyController.class).authenticationPrincipal) - .isEqualTo((String) messageUser.getPrincipal()); + assertThat(this.context.getBean(MyController.class).authenticationPrincipal) + .isEqualTo((String) this.messageUser.getPrincipal()); } @Test public void addsAuthenticationPrincipalResolverWhenNoAuthorization() { loadConfig(NoInboundSecurityConfig.class); - MessageChannel messageChannel = clientInboundChannel(); Message message = message("/permitAll/authentication"); messageChannel.send(message); - - assertThat(context.getBean(MyController.class).authenticationPrincipal) - .isEqualTo((String) messageUser.getPrincipal()); + assertThat(this.context.getBean(MyController.class).authenticationPrincipal) + .isEqualTo((String) this.messageUser.getPrincipal()); } @Test public void addsCsrfProtectionWhenNoAuthorization() { loadConfig(NoInboundSecurityConfig.class); - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor - .create(SimpMessageType.CONNECT); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); Message message = message(headers, "/authentication"); MessageChannel messageChannel = clientInboundChannel(); - try { messageChannel.send(message); fail("Expected Exception"); @@ -178,12 +170,9 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Test public void csrfProtectionForConnect() { loadConfig(SockJsSecurityConfig.class); - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor - .create(SimpMessageType.CONNECT); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); Message message = message(headers, "/authentication"); MessageChannel messageChannel = clientInboundChannel(); - try { messageChannel.send(message); fail("Expected Exception"); @@ -196,79 +185,57 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Test public void csrfProtectionDisabledForConnect() { loadConfig(CsrfDisabledSockJsSecurityConfig.class); - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor - .create(SimpMessageType.CONNECT); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); Message message = message(headers, "/permitAll/connect"); MessageChannel messageChannel = clientInboundChannel(); - messageChannel.send(message); } @Test public void csrfProtectionDefinedByBean() { loadConfig(SockJsProxylessSecurityConfig.class); - MessageChannel messageChannel = clientInboundChannel(); - CsrfChannelInterceptor csrfChannelInterceptor = context.getBean(CsrfChannelInterceptor.class); - + CsrfChannelInterceptor csrfChannelInterceptor = this.context.getBean(CsrfChannelInterceptor.class); assertThat(((AbstractMessageChannel) messageChannel).getInterceptors()).contains(csrfChannelInterceptor); } @Test public void messagesConnectUseCsrfTokenHandshakeInterceptor() throws Exception { - loadConfig(SockJsSecurityConfig.class); - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor - .create(SimpMessageType.CONNECT); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); Message message = message(headers, "/authentication"); MockHttpServletRequest request = sockjsHttpRequest("/chat"); HttpRequestHandler handler = handler(request); - handler.handleRequest(request, new MockHttpServletResponse()); - assertHandshake(request); } @Test - public void messagesConnectUseCsrfTokenHandshakeInterceptorMultipleMappings() - throws Exception { + public void messagesConnectUseCsrfTokenHandshakeInterceptorMultipleMappings() throws Exception { loadConfig(SockJsSecurityConfig.class); - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor - .create(SimpMessageType.CONNECT); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); Message message = message(headers, "/authentication"); MockHttpServletRequest request = sockjsHttpRequest("/other"); HttpRequestHandler handler = handler(request); - handler.handleRequest(request, new MockHttpServletResponse()); - assertHandshake(request); } @Test - public void messagesConnectWebSocketUseCsrfTokenHandshakeInterceptor() - throws Exception { + public void messagesConnectWebSocketUseCsrfTokenHandshakeInterceptor() throws Exception { loadConfig(WebSocketSecurityConfig.class); - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor - .create(SimpMessageType.CONNECT); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); Message message = message(headers, "/authentication"); MockHttpServletRequest request = websocketHttpRequest("/websocket"); HttpRequestHandler handler = handler(request); - handler.handleRequest(request, new MockHttpServletResponse()); - assertHandshake(request); } @Test public void msmsRegistryCustomPatternMatcher() { loadConfig(MsmsRegistryCustomPatternMatcherConfig.class); - clientInboundChannel().send(message("/app/a.b")); - try { clientInboundChannel().send(message("/app/a.b.c")); fail("Expected Exception"); @@ -278,48 +245,10 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { } } - @Configuration - @EnableWebSocketMessageBroker - @Import(SyncExecutorConfig.class) - static class MsmsRegistryCustomPatternMatcherConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { - - // @formatter:off - public void registerStompEndpoints(StompEndpointRegistry registry) { - registry - .addEndpoint("/other") - .setHandshakeHandler(testHandshakeHandler()); - } - // @formatter:on - - // @formatter:off - @Override - protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { - messages - .simpDestMatchers("/app/a.*").permitAll() - .anyMessage().denyAll(); - } - // @formatter:on - - @Override - public void configureMessageBroker(MessageBrokerRegistry registry) { - registry.setPathMatcher(new AntPathMatcher(".")); - registry.enableSimpleBroker("/queue/", "/topic/"); - registry.setApplicationDestinationPrefixes("/app"); - } - - @Bean - public TestHandshakeHandler testHandshakeHandler() { - return new TestHandshakeHandler(); - } - } - @Test public void overrideMsmsRegistryCustomPatternMatcher() { loadConfig(OverrideMsmsRegistryCustomPatternMatcherConfig.class); - clientInboundChannel().send(message("/app/a/b")); - try { clientInboundChannel().send(message("/app/a/b/c")); fail("Expected Exception"); @@ -329,50 +258,10 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { } } - @Configuration - @EnableWebSocketMessageBroker - @Import(SyncExecutorConfig.class) - static class OverrideMsmsRegistryCustomPatternMatcherConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { - - // @formatter:off - public void registerStompEndpoints(StompEndpointRegistry registry) { - registry - .addEndpoint("/other") - .setHandshakeHandler(testHandshakeHandler()); - } - // @formatter:on - - - // @formatter:off - @Override - protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { - messages - .simpDestPathMatcher(new AntPathMatcher()) - .simpDestMatchers("/app/a/*").permitAll() - .anyMessage().denyAll(); - } - // @formatter:on - - @Override - public void configureMessageBroker(MessageBrokerRegistry registry) { - registry.setPathMatcher(new AntPathMatcher(".")); - registry.enableSimpleBroker("/queue/", "/topic/"); - registry.setApplicationDestinationPrefixes("/app"); - } - - @Bean - public TestHandshakeHandler testHandshakeHandler() { - return new TestHandshakeHandler(); - } - } - @Test public void defaultPatternMatcher() { loadConfig(DefaultPatternMatcherConfig.class); - clientInboundChannel().send(message("/app/a/b")); - try { clientInboundChannel().send(message("/app/a/b/c")); fail("Expected Exception"); @@ -382,47 +271,10 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { } } - @Configuration - @EnableWebSocketMessageBroker - @Import(SyncExecutorConfig.class) - static class DefaultPatternMatcherConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { - - // @formatter:off - public void registerStompEndpoints(StompEndpointRegistry registry) { - registry - .addEndpoint("/other") - .setHandshakeHandler(testHandshakeHandler()); - } - // @formatter:on - - // @formatter:off - @Override - protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { - messages - .simpDestMatchers("/app/a/*").permitAll() - .anyMessage().denyAll(); - } - // @formatter:on - - @Override - public void configureMessageBroker(MessageBrokerRegistry registry) { - registry.enableSimpleBroker("/queue/", "/topic/"); - registry.setApplicationDestinationPrefixes("/app"); - } - - @Bean - public TestHandshakeHandler testHandshakeHandler() { - return new TestHandshakeHandler(); - } - } - @Test public void customExpression() { loadConfig(CustomExpressionConfig.class); - clientInboundChannel().send(message("/denyRob")); - this.messageUser = new TestingAuthenticationToken("rob", "password", "ROLE_USER"); try { clientInboundChannel().send(message("/denyRob")); @@ -435,24 +287,19 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Test public void channelSecurityInterceptorUsesMetadataSourceBeanWhenProxyingDisabled() { - loadConfig(SockJsProxylessSecurityConfig.class); - - ChannelSecurityInterceptor channelSecurityInterceptor = context.getBean(ChannelSecurityInterceptor.class); - MessageSecurityMetadataSource messageSecurityMetadataSource = - context.getBean(MessageSecurityMetadataSource.class); - + ChannelSecurityInterceptor channelSecurityInterceptor = this.context.getBean(ChannelSecurityInterceptor.class); + MessageSecurityMetadataSource messageSecurityMetadataSource = this.context + .getBean(MessageSecurityMetadataSource.class); assertThat(channelSecurityInterceptor.obtainSecurityMetadataSource()).isSameAs(messageSecurityMetadataSource); } @Test public void securityContextChannelInterceptorDefinedByBean() { loadConfig(SockJsProxylessSecurityConfig.class); - MessageChannel messageChannel = clientInboundChannel(); - SecurityContextChannelInterceptor securityContextChannelInterceptor = - context.getBean(SecurityContextChannelInterceptor.class); - + SecurityContextChannelInterceptor securityContextChannelInterceptor = this.context + .getBean(SecurityContextChannelInterceptor.class); assertThat(((AbstractMessageChannel) messageChannel).getInterceptors()) .contains(securityContextChannelInterceptor); } @@ -460,28 +307,186 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Test public void inboundChannelSecurityDefinedByBean() { loadConfig(SockJsProxylessSecurityConfig.class); - MessageChannel messageChannel = clientInboundChannel(); - ChannelSecurityInterceptor inboundChannelSecurity = context.getBean(ChannelSecurityInterceptor.class); + ChannelSecurityInterceptor inboundChannelSecurity = this.context.getBean(ChannelSecurityInterceptor.class); + assertThat(((AbstractMessageChannel) messageChannel).getInterceptors()).contains(inboundChannelSecurity); + } - assertThat(((AbstractMessageChannel) messageChannel).getInterceptors()) - .contains(inboundChannelSecurity); + private void assertHandshake(HttpServletRequest request) { + TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); + assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThat(handshakeHandler.attributes.get(this.sessionAttr)) + .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); + } + + private HttpRequestHandler handler(HttpServletRequest request) throws Exception { + HandlerMapping handlerMapping = this.context.getBean(HandlerMapping.class); + return (HttpRequestHandler) handlerMapping.getHandler(request).getHandler(); + } + + private MockHttpServletRequest websocketHttpRequest(String mapping) { + MockHttpServletRequest request = sockjsHttpRequest(mapping); + request.setRequestURI(mapping); + return request; + } + + private MockHttpServletRequest sockjsHttpRequest(String mapping) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + request.setMethod("GET"); + request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); + request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); + request.getSession().setAttribute(this.sessionAttr, "sessionValue"); + request.setAttribute(CsrfToken.class.getName(), this.token); + return request; + } + + private Message message(String destination) { + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + return message(headers, destination); + } + + private Message message(SimpMessageHeaderAccessor headers, String destination) { + headers.setSessionId("123"); + headers.setSessionAttributes(new HashMap<>()); + if (destination != null) { + headers.setDestination(destination); + } + if (this.messageUser != null) { + headers.setUser(this.messageUser); + } + return new GenericMessage<>("hi", headers.getMessageHeaders()); + } + + private MessageChannel clientInboundChannel() { + return this.context.getBean("clientInboundChannel", MessageChannel.class); + } + + private void loadConfig(Class... configs) { + this.context = new AnnotationConfigWebApplicationContext(); + this.context.register(configs); + this.context.setServletConfig(new MockServletConfig()); + this.context.refresh(); } @Configuration @EnableWebSocketMessageBroker @Import(SyncExecutorConfig.class) - static class CustomExpressionConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { + static class MsmsRegistryCustomPatternMatcherConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { // @formatter:off + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry .addEndpoint("/other") .setHandshakeHandler(testHandshakeHandler()); } // @formatter:on + // @formatter:off + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestMatchers("/app/a.*").permitAll() + .anyMessage().denyAll(); + } + // @formatter:on + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.setPathMatcher(new AntPathMatcher(".")); + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/app"); + } + @Bean + TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class OverrideMsmsRegistryCustomPatternMatcherConfig + extends AbstractSecurityWebSocketMessageBrokerConfigurer { + + // @formatter:off + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()); + } + // @formatter:on + // @formatter:off + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestPathMatcher(new AntPathMatcher()) + .simpDestMatchers("/app/a/*").permitAll() + .anyMessage().denyAll(); + } + // @formatter:on + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.setPathMatcher(new AntPathMatcher(".")); + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/app"); + } + + @Bean + TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class DefaultPatternMatcherConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + + // @formatter:off + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()); + } + // @formatter:on + // @formatter:off + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestMatchers("/app/a/*").permitAll() + .anyMessage().denyAll(); + } + // @formatter:on + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/app"); + } + + @Bean + TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class CustomExpressionConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + + // @formatter:off + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()); + } + // @formatter:on // @formatter:off @Override protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { @@ -489,13 +494,11 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { .anyMessage().access("denyRob()"); } // @formatter:on - @Bean - public static SecurityExpressionHandler> messageSecurityExpressionHandler() { + static SecurityExpressionHandler> messageSecurityExpressionHandler() { return new DefaultMessageSecurityExpressionHandler() { @Override - protected SecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, + protected SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, Message invocation) { return new MessageSecurityExpressionRoot(authentication, invocation) { public boolean denyRob() { @@ -517,72 +520,14 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { public TestHandshakeHandler testHandshakeHandler() { return new TestHandshakeHandler(); } - } - private void assertHandshake(HttpServletRequest request) { - TestHandshakeHandler handshakeHandler = context - .getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs( - token); - assertThat(handshakeHandler.attributes.get(sessionAttr)).isEqualTo( - request.getSession().getAttribute(sessionAttr)); - } - - private HttpRequestHandler handler(HttpServletRequest request) throws Exception { - HandlerMapping handlerMapping = context.getBean(HandlerMapping.class); - return (HttpRequestHandler) handlerMapping.getHandler(request).getHandler(); - } - - private MockHttpServletRequest websocketHttpRequest(String mapping) { - MockHttpServletRequest request = sockjsHttpRequest(mapping); - request.setRequestURI(mapping); - return request; - } - - private MockHttpServletRequest sockjsHttpRequest(String mapping) { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); - request.setMethod("GET"); - request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, - "/289/tpyx6mde/websocket"); - request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); - request.getSession().setAttribute(sessionAttr, "sessionValue"); - - request.setAttribute(CsrfToken.class.getName(), token); - return request; - } - - private Message message(String destination) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); - return message(headers, destination); - } - - private Message message(SimpMessageHeaderAccessor headers, String destination) { - headers.setSessionId("123"); - headers.setSessionAttributes(new HashMap<>()); - if (destination != null) { - headers.setDestination(destination); - } - if (messageUser != null) { - headers.setUser(messageUser); - } - return new GenericMessage<>("hi", headers.getMessageHeaders()); - } - - private MessageChannel clientInboundChannel() { - return context.getBean("clientInboundChannel", MessageChannel.class); - } - - private void loadConfig(Class... configs) { - context = new AnnotationConfigWebApplicationContext(); - context.register(configs); - context.setServletConfig(new MockServletConfig()); - context.refresh(); } @Controller static class MyController { String authenticationPrincipal; + MyCustomArgument myCustomArgument; @MessageMapping("/authentication") @@ -594,29 +539,36 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { public void myCustom(MyCustomArgument myCustomArgument) { this.myCustomArgument = myCustomArgument; } + } static class MyCustomArgument { + MyCustomArgument(String notDefaultConstr) { } + } static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver { + @Override public boolean supportsParameter(MethodParameter parameter) { return parameter.getParameterType().isAssignableFrom(MyCustomArgument.class); } + @Override public Object resolveArgument(MethodParameter parameter, Message message) { return new MyCustomArgument(""); } + } static class TestHandshakeHandler implements HandshakeHandler { + Map attributes; - public boolean doHandshake(ServerHttpRequest request, - ServerHttpResponse response, WebSocketHandler wsHandler, + @Override + public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { this.attributes = attributes; if (wsHandler instanceof SockJsWebSocketHandler) { @@ -628,20 +580,22 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { } return true; } + } @Configuration @EnableWebSocketMessageBroker @Import(SyncExecutorConfig.class) - static class SockJsSecurityConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { + static class SockJsSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { + // @formatter:off registry.addEndpoint("/other").setHandshakeHandler(testHandshakeHandler()) .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor()); - registry.addEndpoint("/chat").setHandshakeHandler(testHandshakeHandler()) .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor()); + // @formatter:on } // @formatter:off @@ -653,7 +607,6 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { .anyMessage().denyAll(); } // @formatter:on - @Override public void configureMessageBroker(MessageBrokerRegistry registry) { registry.enableSimpleBroker("/queue/", "/topic/"); @@ -676,27 +629,31 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { } static class SecurityCheck { + private boolean check; public boolean check() { - check = !check; - return check; + this.check = !this.check; + return this.check; } + } + } @Configuration @EnableWebSocketMessageBroker @Import(SyncExecutorConfig.class) - static class NoInboundSecurityConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { + static class NoInboundSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { - registry.addEndpoint("/other").withSockJS() - .setInterceptors(new HttpSessionHandshakeInterceptor()); - - registry.addEndpoint("/chat").withSockJS() - .setInterceptors(new HttpSessionHandshakeInterceptor()); + // @formatter:off + registry.addEndpoint("/other") + .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor()); + registry.addEndpoint("/chat") + .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor()); + // @formatter:on } @Override @@ -713,6 +670,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { public MyController myController() { return new MyController(); } + } @Configuration @@ -722,47 +680,54 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { protected boolean sameOriginDisabled() { return true; } + } @Configuration @EnableWebSocketMessageBroker @Import(SyncExecutorConfig.class) - static class WebSocketSecurityConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { + static class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { + // @formatter:off registry.addEndpoint("/websocket") .setHandshakeHandler(testHandshakeHandler()) .addInterceptors(new HttpSessionHandshakeInterceptor()); + // @formatter:on } - // @formatter:off @Override protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + // @formatter:off messages .simpDestMatchers("/permitAll/**").permitAll() .simpDestMatchers("/customExpression/**").access("denyRob") .anyMessage().denyAll(); + // @formatter:on } - // @formatter:on @Bean public TestHandshakeHandler testHandshakeHandler() { return new TestHandshakeHandler(); } + } @Configuration(proxyBeanMethods = false) @EnableWebSocketMessageBroker @Import(SyncExecutorConfig.class) - static class SockJsProxylessSecurityConfig extends - AbstractSecurityWebSocketMessageBrokerConfigurer { + static class SockJsProxylessSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + private ApplicationContext context; + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { + // @formatter:off registry.addEndpoint("/chat") - .setHandshakeHandler(context.getBean(TestHandshakeHandler.class)) + .setHandshakeHandler(this.context.getBean(TestHandshakeHandler.class)) .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor()); + // @formatter:on } @Autowired @@ -777,18 +742,21 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { .anyMessage().denyAll(); } // @formatter:on - @Bean public TestHandshakeHandler testHandshakeHandler() { return new TestHandshakeHandler(); } + } @Configuration static class SyncExecutorConfig { + @Bean public static SyncExecutorSubscribableChannelPostProcessor postProcessor() { return new SyncExecutorSubscribableChannelPostProcessor(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java index 6f3074f506..fa1bf08d26 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.annotation.web.socket; import org.springframework.beans.BeansException; @@ -24,8 +25,8 @@ import org.springframework.messaging.support.ExecutorSubscribableChannel; */ public class SyncExecutorSubscribableChannelPostProcessor implements BeanPostProcessor { - public Object postProcessBeforeInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { if (bean instanceof ExecutorSubscribableChannel) { ExecutorSubscribableChannel original = (ExecutorSubscribableChannel) bean; ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); @@ -35,8 +36,9 @@ public class SyncExecutorSubscribableChannelPostProcessor implements BeanPostPro return bean; } - public Object postProcessAfterInitialization(Object bean, String beanName) - throws BeansException { + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { return bean; } + } diff --git a/config/src/test/java/org/springframework/security/config/authentication/AuthenticationConfigurationGh3935Tests.java b/config/src/test/java/org/springframework/security/config/authentication/AuthenticationConfigurationGh3935Tests.java index 7b13b8f08e..bdc4383017 100644 --- a/config/src/test/java/org/springframework/security/config/authentication/AuthenticationConfigurationGh3935Tests.java +++ b/config/src/test/java/org/springframework/security/config/authentication/AuthenticationConfigurationGh3935Tests.java @@ -38,9 +38,9 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -48,10 +48,13 @@ import static org.mockito.Mockito.when; @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration public class AuthenticationConfigurationGh3935Tests { + @Autowired FilterChainProxy springSecurityFilterChain; + @Autowired UserDetailsService uds; + @Autowired BootGlobalAuthenticationConfigurationAdapter adapter; @@ -65,24 +68,22 @@ public class AuthenticationConfigurationGh3935Tests { public void delegateUsesExisitingAuthentication() { String username = "user"; String password = "password"; - when(this.uds.loadUserByUsername(username)).thenReturn(PasswordEncodedUser.user()); - + given(this.uds.loadUserByUsername(username)).willReturn(PasswordEncodedUser.user()); AuthenticationManager authenticationManager = this.adapter.authenticationManager; assertThat(authenticationManager).isNotNull(); - - Authentication auth = authenticationManager.authenticate( - new UsernamePasswordAuthenticationToken(username, password)); - + Authentication auth = authenticationManager + .authenticate(new UsernamePasswordAuthenticationToken(username, password)); verify(this.uds).loadUserByUsername(username); assertThat(auth.getPrincipal()).isEqualTo(PasswordEncodedUser.user()); } @EnableWebSecurity static class WebSecurity extends WebSecurityConfigurerAdapter { + } - static class BootGlobalAuthenticationConfigurationAdapter - extends GlobalAuthenticationConfigurerAdapter { + static class BootGlobalAuthenticationConfigurationAdapter extends GlobalAuthenticationConfigurerAdapter { + private final ApplicationContext context; private AuthenticationManager authenticationManager; @@ -94,23 +95,25 @@ public class AuthenticationConfigurationGh3935Tests { @Override public void init(AuthenticationManagerBuilder auth) throws Exception { - AuthenticationConfiguration configuration = this.context - .getBean(AuthenticationConfiguration.class); + AuthenticationConfiguration configuration = this.context.getBean(AuthenticationConfiguration.class); this.authenticationManager = configuration.getAuthenticationManager(); } + } @Configuration static class AutoConfig { + @Bean - static BootGlobalAuthenticationConfigurationAdapter adapter( - ApplicationContext context) { + static BootGlobalAuthenticationConfigurationAdapter adapter(ApplicationContext context) { return new BootGlobalAuthenticationConfigurationAdapter(context); } @Bean - public UserDetailsService userDetailsService() { + UserDetailsService userDetailsService() { return mock(UserDetailsService.class); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParserTests.java index 6d64ca25eb..a7654277b1 100644 --- a/config/src/test/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/authentication/AuthenticationManagerBeanDefinitionParserTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; +import java.util.ArrayList; +import java.util.List; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; import org.springframework.context.ConfigurableApplicationContext; @@ -29,19 +34,17 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.util.FieldUtils; import org.springframework.test.web.servlet.MockMvc; -import java.util.ArrayList; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Luke Taylor */ public class AuthenticationManagerBeanDefinitionParserTests { + + // @formatter:off private static final String CONTEXT = "" + " " + " " @@ -49,49 +52,43 @@ public class AuthenticationManagerBeanDefinitionParserTests { + " " + " " + ""; + // @formatter:on + @Rule public final SpringTestRule spring = new SpringTestRule(); @Test // SEC-1225 public void providersAreRegisteredAsTopLevelBeans() { - ConfigurableApplicationContext context = this.spring.context(CONTEXT) - .getContext(); + ConfigurableApplicationContext context = this.spring.context(CONTEXT).getContext(); assertThat(context.getBeansOfType(AuthenticationProvider.class)).hasSize(1); } @Test public void eventsArePublishedByDefault() throws Exception { - ConfigurableApplicationContext appContext = this.spring.context(CONTEXT) - .getContext(); + ConfigurableApplicationContext appContext = this.spring.context(CONTEXT).getContext(); AuthListener listener = new AuthListener(); appContext.addApplicationListener(listener); - - ProviderManager pm = (ProviderManager) appContext - .getBeansOfType(ProviderManager.class).values().toArray()[0]; + ProviderManager pm = (ProviderManager) appContext.getBeansOfType(ProviderManager.class).values().toArray()[0]; Object eventPublisher = FieldUtils.getFieldValue(pm, "eventPublisher"); assertThat(eventPublisher).isNotNull(); assertThat(eventPublisher instanceof DefaultAuthenticationEventPublisher).isTrue(); - pm.authenticate(new UsernamePasswordAuthenticationToken("bob", "bobspassword")); assertThat(listener.events).hasSize(1); } @Test public void credentialsAreClearedByDefault() { - ConfigurableApplicationContext appContext = this.spring.context(CONTEXT) - .getContext(); - ProviderManager pm = (ProviderManager) appContext - .getBeansOfType(ProviderManager.class).values().toArray()[0]; + ConfigurableApplicationContext appContext = this.spring.context(CONTEXT).getContext(); + ProviderManager pm = (ProviderManager) appContext.getBeansOfType(ProviderManager.class).values().toArray()[0]; assertThat(pm.isEraseCredentialsAfterAuthentication()).isTrue(); } @Test public void clearCredentialsPropertyIsRespected() { - ConfigurableApplicationContext appContext = this.spring.context("") - .getContext(); - ProviderManager pm = (ProviderManager) appContext - .getBeansOfType(ProviderManager.class).values().toArray()[0]; + ConfigurableApplicationContext appContext = this.spring + .context("").getContext(); + ProviderManager pm = (ProviderManager) appContext.getBeansOfType(ProviderManager.class).values().toArray()[0]; assertThat(pm.isEraseCredentialsAfterAuthentication()).isFalse(); } @@ -100,24 +97,28 @@ public class AuthenticationManagerBeanDefinitionParserTests { @Test public void passwordEncoderBeanUsed() throws Exception { + // @formatter:off this.spring.context("" - + "" - + " " - + "" - + "") - .mockMvcAfterSpringSecurityOk() - .autowire(); - + + "" + + " " + + "" + + "") + .mockMvcAfterSpringSecurityOk() + .autowire(); this.mockMvc.perform(get("/").with(httpBasic("user", "password"))) - .andExpect(status().isOk()); + .andExpect(status().isOk()); + // @formatter:on } - private static class AuthListener implements - ApplicationListener { + private static class AuthListener implements ApplicationListener { + List events = new ArrayList<>(); + @Override public void onApplicationEvent(AbstractAuthenticationEvent event) { this.events.add(event); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParserTests.java index 968f1f9ae1..b1e453d708 100644 --- a/config/src/test/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/authentication/AuthenticationProviderBeanDefinitionParserTests.java @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; +import java.util.List; + +import org.junit.After; +import org.junit.Test; + +import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; +import org.springframework.context.support.AbstractXmlApplicationContext; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.ProviderManager; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -22,13 +30,6 @@ import org.springframework.security.config.BeanIds; import org.springframework.security.config.util.InMemoryXmlApplicationContext; import org.springframework.security.crypto.password.LdapShaPasswordEncoder; import org.springframework.security.crypto.password.MessageDigestPasswordEncoder; -import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; -import org.springframework.context.support.AbstractXmlApplicationContext; - -import org.junit.Test; -import org.junit.After; - -import java.util.List; /** * Tests for {@link AuthenticationProviderBeanDefinitionParser}. @@ -36,53 +37,60 @@ import java.util.List; * @author Luke Taylor */ public class AuthenticationProviderBeanDefinitionParserTests { + private AbstractXmlApplicationContext appContext; - private UsernamePasswordAuthenticationToken bob = new UsernamePasswordAuthenticationToken( - "bob", "bobspassword"); + + private UsernamePasswordAuthenticationToken bob = new UsernamePasswordAuthenticationToken("bob", "bobspassword"); @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); + if (this.appContext != null) { + this.appContext.close(); } } @Test public void worksWithEmbeddedUserService() { + // @formatter:off setContext(" " + " " + " " - + " " + " "); - getProvider().authenticate(bob); + + " " + + " "); + // @formatter:on + getProvider().authenticate(this.bob); } @Test public void externalUserServiceRefWorks() { - appContext = new InMemoryXmlApplicationContext( + // @formatter:off + this.appContext = new InMemoryXmlApplicationContext( " " - + " " - + " " - + " " - + " " - + " "); - getProvider().authenticate(bob); + + " " + + " " + " " + + " " + + " "); + // @formatter:on + getProvider().authenticate(this.bob); } @Test public void providerWithBCryptPasswordEncoderWorks() { + // @formatter:off setContext(" " + " " + " " + " " - + " " + " "); - - getProvider().authenticate(bob); + + " " + // @formatter:on + + " "); + getProvider().authenticate(this.bob); } @Test public void providerWithMd5PasswordEncoderWorks() { - appContext = new InMemoryXmlApplicationContext( - " " + // @formatter:off + this.appContext = new InMemoryXmlApplicationContext(" " + " " + " " + " " @@ -90,18 +98,17 @@ public class AuthenticationProviderBeanDefinitionParserTests { + " " + " " + " " - + " " + + " " + " " + " "); - - getProvider().authenticate(bob); + // @formatter:on + getProvider().authenticate(this.bob); } @Test public void providerWithShaPasswordEncoderWorks() { - appContext = new InMemoryXmlApplicationContext( - " " + // @formatter:off + this.appContext = new InMemoryXmlApplicationContext(" " + " " + " " + " " @@ -109,16 +116,15 @@ public class AuthenticationProviderBeanDefinitionParserTests { + " " + " " + " " - + " "); - - getProvider().authenticate(bob); + + " "); + // @formatter:on + getProvider().authenticate(this.bob); } @Test public void passwordIsBase64EncodedWhenBase64IsEnabled() { - appContext = new InMemoryXmlApplicationContext( - " " + // @formatter:off + this.appContext = new InMemoryXmlApplicationContext(" " + " " + " " + " " @@ -126,38 +132,37 @@ public class AuthenticationProviderBeanDefinitionParserTests { + " " + " " + " " - + " " - + " " - + " " + + " " + + " " + " " + " "); - - getProvider().authenticate(bob); + // @formatter:on + getProvider().authenticate(this.bob); } // SEC-1466 @Test(expected = BeanDefinitionParsingException.class) public void exernalProviderDoesNotSupportChildElements() { - appContext = new InMemoryXmlApplicationContext( - " " - + " " - + " " - + " " - + " " - + " " - + " "); + // @formatter:off + this.appContext = new InMemoryXmlApplicationContext(" " + + " " + + " " + + " " + + " " + + " " + + " "); + // @formatter:on } private AuthenticationProvider getProvider() { - List providers = ((ProviderManager) appContext + List providers = ((ProviderManager) this.appContext .getBean(BeanIds.AUTHENTICATION_MANAGER)).getProviders(); - return providers.get(0); } private void setContext(String context) { - appContext = new InMemoryXmlApplicationContext("" - + context + ""); + this.appContext = new InMemoryXmlApplicationContext( + "" + context + ""); } + } diff --git a/config/src/test/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParserTests.java index 14b833cd24..a8cd7ceb0b 100644 --- a/config/src/test/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/authentication/JdbcUserServiceBeanDefinitionParserTests.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; import org.junit.After; import org.junit.Test; -import org.springframework.security.authentication.CachingUserDetailsService; import org.w3c.dom.Element; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.CachingUserDetailsService; import org.springframework.security.authentication.ProviderManager; -import org.springframework.security.authentication - .UsernamePasswordAuthenticationToken; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.dao.DaoAuthenticationProvider; import org.springframework.security.config.BeanIds; import org.springframework.security.config.util.InMemoryXmlApplicationContext; @@ -41,57 +41,58 @@ import static org.mockito.Mockito.mock; * @author Eddú Meléndez */ public class JdbcUserServiceBeanDefinitionParserTests { + private static String USER_CACHE_XML = ""; + // @formatter:off private static String DATA_SOURCE = " " + " " + " " - + - - " " - + " " + " "; + + " " + + " " + + " "; + // @formatter:on private InMemoryXmlApplicationContext appContext; @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); + if (this.appContext != null) { + this.appContext.close(); } } @Test public void beanNameIsCorrect() { - assertThat(JdbcUserDetailsManager.class.getName()).isEqualTo( - new JdbcUserServiceBeanDefinitionParser() - .getBeanClassName(mock(Element.class))); + assertThat(JdbcUserDetailsManager.class.getName()) + .isEqualTo(new JdbcUserServiceBeanDefinitionParser().getBeanClassName(mock(Element.class))); } @Test public void validUsernameIsFound() { setContext("" + DATA_SOURCE); - JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) appContext - .getBean(BeanIds.USER_DETAILS_SERVICE); + JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) this.appContext.getBean(BeanIds.USER_DETAILS_SERVICE); assertThat(mgr.loadUserByUsername("rod")).isNotNull(); } @Test public void beanIdIsParsedCorrectly() { - setContext("" - + DATA_SOURCE); - assertThat(appContext.getBean("myUserService") instanceof JdbcUserDetailsManager).isTrue(); + setContext("" + DATA_SOURCE); + assertThat(this.appContext.getBean("myUserService") instanceof JdbcUserDetailsManager).isTrue(); } @Test public void usernameAndAuthorityQueriesAreParsedCorrectly() throws Exception { String userQuery = "select username, password, true from users where username = ?"; String authoritiesQuery = "select username, authority from authorities where username = ? and 1 = 1"; + // @formatter:off setContext("" + DATA_SOURCE); - JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) appContext - .getBean("myUserService"); + // @formatter:on + JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) this.appContext.getBean("myUserService"); assertThat(FieldUtils.getFieldValue(mgr, "usersByUsernameQuery")).isEqualTo(userQuery); assertThat(FieldUtils.getFieldValue(mgr, "authoritiesByUsernameQuery")).isEqualTo(authoritiesQuery); assertThat(mgr.loadUserByUsername("rod") != null).isTrue(); @@ -99,11 +100,9 @@ public class JdbcUserServiceBeanDefinitionParserTests { @Test public void groupQueryIsParsedCorrectly() throws Exception { - setContext("" + DATA_SOURCE); - JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) appContext - .getBean("myUserService"); + JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) this.appContext.getBean("myUserService"); assertThat(FieldUtils.getFieldValue(mgr, "groupAuthoritiesByUsernameQuery")).isEqualTo("blah blah"); assertThat((Boolean) FieldUtils.getFieldValue(mgr, "enableGroups")).isTrue(); } @@ -112,54 +111,57 @@ public class JdbcUserServiceBeanDefinitionParserTests { public void cacheRefIsparsedCorrectly() { setContext("" + DATA_SOURCE + USER_CACHE_XML); - CachingUserDetailsService cachingUserService = (CachingUserDetailsService) appContext - .getBean("myUserService" - + AbstractUserDetailsServiceBeanDefinitionParser.CACHING_SUFFIX); - assertThat(appContext.getBean("userCache")).isSameAs(cachingUserService.getUserCache()); + CachingUserDetailsService cachingUserService = (CachingUserDetailsService) this.appContext + .getBean("myUserService" + AbstractUserDetailsServiceBeanDefinitionParser.CACHING_SUFFIX); + assertThat(this.appContext.getBean("userCache")).isSameAs(cachingUserService.getUserCache()); assertThat(cachingUserService.loadUserByUsername("rod")).isNotNull(); assertThat(cachingUserService.loadUserByUsername("rod")).isNotNull(); } @Test public void isSupportedByAuthenticationProviderElement() { - setContext("" + " " + // @formatter:off + setContext("" + + " " + " " - + " " + "" + + " " + + "" + DATA_SOURCE); - AuthenticationManager mgr = (AuthenticationManager) appContext - .getBean(BeanIds.AUTHENTICATION_MANAGER); + // @formatter:on + AuthenticationManager mgr = (AuthenticationManager) this.appContext.getBean(BeanIds.AUTHENTICATION_MANAGER); mgr.authenticate(new UsernamePasswordAuthenticationToken("rod", "koala")); } @Test public void cacheIsInjectedIntoAuthenticationProvider() { + // @formatter:off setContext("" + " " + " " - + " " + "" - + DATA_SOURCE + USER_CACHE_XML); - ProviderManager mgr = (ProviderManager) appContext - .getBean(BeanIds.AUTHENTICATION_MANAGER); - DaoAuthenticationProvider provider = (DaoAuthenticationProvider) mgr - .getProviders().get(0); - assertThat(appContext.getBean("userCache")).isSameAs(provider.getUserCache()); + + " " + + "" + + DATA_SOURCE + + USER_CACHE_XML); + // @formatter:on + ProviderManager mgr = (ProviderManager) this.appContext.getBean(BeanIds.AUTHENTICATION_MANAGER); + DaoAuthenticationProvider provider = (DaoAuthenticationProvider) mgr.getProviders().get(0); + assertThat(this.appContext.getBean("userCache")).isSameAs(provider.getUserCache()); provider.authenticate(new UsernamePasswordAuthenticationToken("rod", "koala")); - assertThat(provider - .getUserCache().getUserFromCache("rod")).isNotNull().withFailMessage("Cache should contain user after authentication"); + assertThat(provider.getUserCache().getUserFromCache("rod")).isNotNull() + .withFailMessage("Cache should contain user after authentication"); } @Test public void rolePrefixIsUsedWhenSet() { setContext("" + DATA_SOURCE); - JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) appContext - .getBean("myUserService"); + JdbcUserDetailsManager mgr = (JdbcUserDetailsManager) this.appContext.getBean("myUserService"); UserDetails rod = mgr.loadUserByUsername("rod"); - assertThat(AuthorityUtils.authorityListToSet(rod.getAuthorities())) - .contains("PREFIX_ROLE_SUPERVISOR"); + assertThat(AuthorityUtils.authorityListToSet(rod.getAuthorities())).contains("PREFIX_ROLE_SUPERVISOR"); } private void setContext(String context) { - appContext = new InMemoryXmlApplicationContext(context); + this.appContext = new InMemoryXmlApplicationContext(context); } + } diff --git a/config/src/test/java/org/springframework/security/config/authentication/PasswordEncoderParserTests.java b/config/src/test/java/org/springframework/security/config/authentication/PasswordEncoderParserTests.java index 7f97bfce57..3a5be4981d 100644 --- a/config/src/test/java/org/springframework/security/config/authentication/PasswordEncoderParserTests.java +++ b/config/src/test/java/org/springframework/security/config/authentication/PasswordEncoderParserTests.java @@ -18,6 +18,7 @@ package org.springframework.security.config.authentication; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.test.SpringTestRule; import org.springframework.test.web.servlet.MockMvc; @@ -31,6 +32,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @since 5.0 */ public class PasswordEncoderParserTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -39,22 +41,24 @@ public class PasswordEncoderParserTests { @Test public void passwordEncoderDefaultsToDelegatingPasswordEncoder() throws Exception { - this.spring.configLocations("classpath:org/springframework/security/config/authentication/PasswordEncoderParserTests-default.xml") - .mockMvcAfterSpringSecurityOk() - .autowire(); - + this.spring.configLocations( + "classpath:org/springframework/security/config/authentication/PasswordEncoderParserTests-default.xml") + .mockMvcAfterSpringSecurityOk().autowire(); + // @formatter:off this.mockMvc.perform(get("/").with(httpBasic("user", "password"))) - .andExpect(status().isOk()); + .andExpect(status().isOk()); + // @formatter:on } @Test public void passwordEncoderDefaultsToPasswordEncoderBean() throws Exception { - this.spring.configLocations("classpath:org/springframework/security/config/authentication/PasswordEncoderParserTests-bean.xml") - .mockMvcAfterSpringSecurityOk() - .autowire(); - + this.spring.configLocations( + "classpath:org/springframework/security/config/authentication/PasswordEncoderParserTests-bean.xml") + .mockMvcAfterSpringSecurityOk().autowire(); + // @formatter:off this.mockMvc.perform(get("/").with(httpBasic("user", "password"))) - .andExpect(status().isOk()); + .andExpect(status().isOk()); + // @formatter:on } } diff --git a/config/src/test/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParserTests.java index b3791d7f1d..239f1145ce 100644 --- a/config/src/test/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/authentication/UserServiceBeanDefinitionParserTests.java @@ -13,29 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.authentication; -import static org.assertj.core.api.Assertions.*; +import org.junit.After; +import org.junit.Test; +import org.springframework.beans.FatalBeanException; +import org.springframework.context.support.AbstractXmlApplicationContext; import org.springframework.security.config.util.InMemoryXmlApplicationContext; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; -import org.springframework.context.support.AbstractXmlApplicationContext; -import org.springframework.beans.FatalBeanException; -import org.junit.Test; -import org.junit.After; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor */ public class UserServiceBeanDefinitionParserTests { + private AbstractXmlApplicationContext appContext; @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); + if (this.appContext != null) { + this.appContext.close(); } } @@ -43,19 +45,19 @@ public class UserServiceBeanDefinitionParserTests { public void userServiceWithValidPropertiesFileWorksSuccessfully() { setContext(""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); userService.loadUserByUsername("bob"); userService.loadUserByUsername("joe"); } @Test public void userServiceWithEmbeddedUsersWorksSuccessfully() { + // @formatter:off setContext("" + " " + ""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); + // @formatter:on + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); userService.loadUserByUsername("joe"); } @@ -64,12 +66,13 @@ public class UserServiceBeanDefinitionParserTests { System.setProperty("principal.name", "joe"); System.setProperty("principal.pass", "joespassword"); System.setProperty("principal.authorities", "ROLE_A,ROLE_B"); + // @formatter:off setContext("" + "" + " " + ""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); + // @formatter:on + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); UserDetails joe = userService.loadUserByUsername("joe"); assertThat(joe.getPassword()).isEqualTo("joespassword"); assertThat(joe.getAuthorities()).hasSize(2); @@ -77,10 +80,12 @@ public class UserServiceBeanDefinitionParserTests { @Test public void embeddedUsersWithNoPasswordIsGivenGeneratedValue() { + // @formatter:off setContext("" - + " " + ""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); + + " " + + ""); + // @formatter:on + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); UserDetails joe = userService.loadUserByUsername("joe"); assertThat(joe.getPassword().length() > 0).isTrue(); Long.parseLong(joe.getPassword()); @@ -88,30 +93,28 @@ public class UserServiceBeanDefinitionParserTests { @Test public void worksWithOpenIDUrlsAsNames() { + // @formatter:off setContext("" + " " + " " + ""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); - assertThat( - userService.loadUserByUsername("https://joe.myopenid.com/").getUsername()) + // @formatter:on + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); + assertThat(userService.loadUserByUsername("https://joe.myopenid.com/").getUsername()) .isEqualTo("https://joe.myopenid.com/"); - assertThat( - userService.loadUserByUsername( - "https://www.google.com/accounts/o8/id?id=MPtOaenBIk5yzW9n7n9") - .getUsername()) - .isEqualTo("https://www.google.com/accounts/o8/id?id=MPtOaenBIk5yzW9n7n9"); + assertThat(userService.loadUserByUsername("https://www.google.com/accounts/o8/id?id=MPtOaenBIk5yzW9n7n9") + .getUsername()).isEqualTo("https://www.google.com/accounts/o8/id?id=MPtOaenBIk5yzW9n7n9"); } @Test public void disabledAndEmbeddedFlagsAreSupported() { + // @formatter:off setContext("" + " " + " " + ""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); + // @formatter:on + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); UserDetails joe = userService.loadUserByUsername("joe"); assertThat(joe.isAccountNonLocked()).isFalse(); // Check case-sensitive lookup SEC-1432 @@ -121,11 +124,12 @@ public class UserServiceBeanDefinitionParserTests { @Test(expected = FatalBeanException.class) public void userWithBothPropertiesAndEmbeddedUsersThrowsException() { + // @formatter:off setContext("" + " " + ""); - UserDetailsService userService = (UserDetailsService) appContext - .getBean("service"); + // @formatter:on + UserDetailsService userService = (UserDetailsService) this.appContext.getBean("service"); userService.loadUserByUsername("Joe"); } @@ -133,7 +137,6 @@ public class UserServiceBeanDefinitionParserTests { public void multipleTopLevelUseWithoutIdThrowsException() { setContext("" + ""); - } @Test(expected = FatalBeanException.class) @@ -142,6 +145,7 @@ public class UserServiceBeanDefinitionParserTests { } private void setContext(String context) { - appContext = new InMemoryXmlApplicationContext(context); + this.appContext = new InMemoryXmlApplicationContext(context); } + } diff --git a/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsJcTests.java b/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsJcTests.java index c26ec69d11..49025a33c3 100644 --- a/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsJcTests.java +++ b/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsJcTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.core; import java.io.IOException; @@ -27,6 +28,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -52,23 +54,26 @@ import static org.assertj.core.api.Assertions.assertThat; @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration public class GrantedAuthorityDefaultsJcTests { + @Autowired FilterChainProxy springSecurityFilterChain; + @Autowired MessageService messageService; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Before public void setup() { setup("USER"); - - request = new MockHttpServletRequest("GET", ""); - request.setMethod("GET"); - response = new MockHttpServletResponse(); - chain = new MockFilterChain(); + this.request = new MockHttpServletRequest("GET", ""); + this.request.setMethod("GET"); + this.response = new MockHttpServletResponse(); + this.chain = new MockFilterChain(); } @After @@ -79,57 +84,51 @@ public class GrantedAuthorityDefaultsJcTests { @Test public void doFilter() throws Exception { SecurityContext context = SecurityContextHolder.getContext(); - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, context); - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + context); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void doFilterDenied() throws Exception { setup("DENIED"); - SecurityContext context = SecurityContextHolder.getContext(); - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, context); - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + context); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } @Test public void message() { - messageService.getMessage(); + this.messageService.getMessage(); } @Test public void jsrMessage() { - messageService.getJsrMessage(); + this.messageService.getJsrMessage(); } @Test(expected = AccessDeniedException.class) public void messageDenied() { setup("DENIED"); - - messageService.getMessage(); + this.messageService.getMessage(); } @Test(expected = AccessDeniedException.class) public void jsrMessageDenied() { setup("DENIED"); - - messageService.getJsrMessage(); + this.messageService.getJsrMessage(); } // SEC-2926 @Test public void doFilterIsUserInRole() throws Exception { SecurityContext context = SecurityContextHolder.getContext(); - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, context); - - chain = new MockFilterChain() { - + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + context); + this.chain = new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { @@ -138,12 +137,9 @@ public class GrantedAuthorityDefaultsJcTests { assertThat(httpRequest.isUserInRole("INVALID")).isFalse(); super.doFilter(request, response); } - }; - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(chain.getRequest()).isNotNull(); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.chain.getRequest()).isNotNull(); } private void setup(String role) { @@ -153,31 +149,37 @@ public class GrantedAuthorityDefaultsJcTests { @Configuration @EnableWebSecurity - @EnableGlobalMethodSecurity(prePostEnabled=true, jsr250Enabled=true) + @EnableGlobalMethodSecurity(prePostEnabled = true, jsr250Enabled = true) static class Config extends WebSecurityConfigurerAdapter { @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().access("hasRole('USER')"); + // @formatter:on } @Bean - public MessageService messageService() { + MessageService messageService() { return new HelloWorldMessageService(); } @Bean - public static GrantedAuthorityDefaults grantedAuthorityDefaults() { + static GrantedAuthorityDefaults grantedAuthorityDefaults() { return new GrantedAuthorityDefaults(""); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsXmlTests.java b/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsXmlTests.java index 9cd487988d..d8dad3d308 100644 --- a/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsXmlTests.java +++ b/config/src/test/java/org/springframework/security/config/core/GrantedAuthorityDefaultsXmlTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.core; import java.io.IOException; @@ -27,6 +28,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; @@ -45,23 +47,26 @@ import static org.assertj.core.api.Assertions.assertThat; @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration public class GrantedAuthorityDefaultsXmlTests { + @Autowired FilterChainProxy springSecurityFilterChain; + @Autowired MessageService messageService; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; @Before public void setup() { setup("USER"); - - request = new MockHttpServletRequest("GET", ""); - request.setMethod("GET"); - response = new MockHttpServletResponse(); - chain = new MockFilterChain(); + this.request = new MockHttpServletRequest("GET", ""); + this.request.setMethod("GET"); + this.response = new MockHttpServletResponse(); + this.chain = new MockFilterChain(); } @After @@ -72,57 +77,51 @@ public class GrantedAuthorityDefaultsXmlTests { @Test public void doFilter() throws Exception { SecurityContext context = SecurityContextHolder.getContext(); - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, context); - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + context); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void doFilterDenied() throws Exception { setup("DENIED"); - SecurityContext context = SecurityContextHolder.getContext(); - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, context); - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + context); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); } @Test public void message() { - messageService.getMessage(); + this.messageService.getMessage(); } @Test public void jsrMessage() { - messageService.getJsrMessage(); + this.messageService.getJsrMessage(); } @Test(expected = AccessDeniedException.class) public void messageDenied() { setup("DENIED"); - - messageService.getMessage(); + this.messageService.getMessage(); } @Test(expected = AccessDeniedException.class) public void jsrMessageDenied() { setup("DENIED"); - - messageService.getJsrMessage(); + this.messageService.getJsrMessage(); } // SEC-2926 @Test public void doFilterIsUserInRole() throws Exception { SecurityContext context = SecurityContextHolder.getContext(); - request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, context); - - chain = new MockFilterChain() { - + this.request.getSession().setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + context); + this.chain = new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { @@ -131,16 +130,14 @@ public class GrantedAuthorityDefaultsXmlTests { assertThat(httpRequest.isUserInRole("INVALID")).isFalse(); super.doFilter(request, response); } - }; - - springSecurityFilterChain.doFilter(request, response, chain); - - assertThat(chain.getRequest()).isNotNull(); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); + assertThat(this.chain.getRequest()).isNotNull(); } private void setup(String role) { TestingAuthenticationToken user = new TestingAuthenticationToken("user", "password", role); SecurityContextHolder.getContext().setAuthentication(user); } + } diff --git a/config/src/test/java/org/springframework/security/config/core/HelloWorldMessageService.java b/config/src/test/java/org/springframework/security/config/core/HelloWorldMessageService.java index 59a08a96c7..452a20042f 100755 --- a/config/src/test/java/org/springframework/security/config/core/HelloWorldMessageService.java +++ b/config/src/test/java/org/springframework/security/config/core/HelloWorldMessageService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.core; import javax.annotation.security.RolesAllowed; @@ -24,13 +25,16 @@ import org.springframework.security.access.prepost.PreAuthorize; */ public class HelloWorldMessageService implements MessageService { + @Override @PreAuthorize("hasRole('USER')") public String getMessage() { return "Hello World"; } + @Override @RolesAllowed("USER") public String getJsrMessage() { return "Hello JSR"; } + } diff --git a/config/src/test/java/org/springframework/security/config/core/MessageService.java b/config/src/test/java/org/springframework/security/config/core/MessageService.java index 3ec014e486..03eaefab59 100755 --- a/config/src/test/java/org/springframework/security/config/core/MessageService.java +++ b/config/src/test/java/org/springframework/security/config/core/MessageService.java @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.core; /** * @author Rob Winch */ public interface MessageService { + String getMessage(); + String getJsrMessage(); + } diff --git a/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceITests.java b/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceITests.java index 59f75ef5de..77bbef169f 100644 --- a/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceITests.java +++ b/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceITests.java @@ -16,9 +16,9 @@ package org.springframework.security.config.core.userdetails; - import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -26,7 +26,7 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.util.InMemoryResource; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -34,18 +34,24 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; */ @RunWith(SpringRunner.class) public class ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceITests { - @Autowired ReactiveUserDetailsService users; + + @Autowired + ReactiveUserDetailsService users; @Test public void loadUserByUsernameWhenUserFoundThenNotNull() { - assertThat(users.findByUsername("user").block()).isNotNull(); + assertThat(this.users.findByUsername("user").block()).isNotNull(); } @Configuration static class Config { + @Bean - public ReactiveUserDetailsServiceResourceFactoryBean userDetailsService() { - return ReactiveUserDetailsServiceResourceFactoryBean.fromResource(new InMemoryResource("user=password,ROLE_USER")); + ReactiveUserDetailsServiceResourceFactoryBean userDetailsService() { + return ReactiveUserDetailsServiceResourceFactoryBean + .fromResource(new InMemoryResource("user=password,ROLE_USER")); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceLocationITests.java b/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceLocationITests.java index 8ec512ce44..cd720f1b7b 100644 --- a/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceLocationITests.java +++ b/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceLocationITests.java @@ -16,16 +16,16 @@ package org.springframework.security.config.core.userdetails; - import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -33,18 +33,23 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; */ @RunWith(SpringRunner.class) public class ReactiveUserDetailsServiceResourceFactoryBeanPropertiesResourceLocationITests { - @Autowired ReactiveUserDetailsService users; + + @Autowired + ReactiveUserDetailsService users; @Test public void loadUserByUsernameWhenUserFoundThenNotNull() { - assertThat(users.findByUsername("user").block()).isNotNull(); + assertThat(this.users.findByUsername("user").block()).isNotNull(); } @Configuration static class Config { + @Bean - public ReactiveUserDetailsServiceResourceFactoryBean userDetailsService() { + ReactiveUserDetailsServiceResourceFactoryBean userDetailsService() { return ReactiveUserDetailsServiceResourceFactoryBean.fromResourceLocation("classpath:users.properties"); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanStringITests.java b/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanStringITests.java index d6a19645ca..f257138526 100644 --- a/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanStringITests.java +++ b/config/src/test/java/org/springframework/security/config/core/userdetails/ReactiveUserDetailsServiceResourceFactoryBeanStringITests.java @@ -16,16 +16,16 @@ package org.springframework.security.config.core.userdetails; - import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -33,18 +33,23 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; */ @RunWith(SpringRunner.class) public class ReactiveUserDetailsServiceResourceFactoryBeanStringITests { - @Autowired ReactiveUserDetailsService users; + + @Autowired + ReactiveUserDetailsService users; @Test public void findByUsernameWhenUserFoundThenNotNull() { - assertThat(users.findByUsername("user").block()).isNotNull(); + assertThat(this.users.findByUsername("user").block()).isNotNull(); } @Configuration static class Config { + @Bean - public ReactiveUserDetailsServiceResourceFactoryBean userDetailsService() { + ReactiveUserDetailsServiceResourceFactoryBean userDetailsService() { return ReactiveUserDetailsServiceResourceFactoryBean.fromString("user=password,ROLE_USER"); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBeanTest.java b/config/src/test/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBeanTests.java similarity index 64% rename from config/src/test/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBeanTest.java rename to config/src/test/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBeanTests.java index 8b53fc0e34..03ceae7ded 100644 --- a/config/src/test/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBeanTest.java +++ b/config/src/test/java/org/springframework/security/config/core/userdetails/UserDetailsResourceFactoryBeanTests.java @@ -16,26 +16,29 @@ package org.springframework.security.config.core.userdetails; +import java.util.Collection; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.core.io.ResourceLoader; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.util.InMemoryResource; -import java.util.Collection; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * @author Rob Winch * @since 5.0 */ @RunWith(MockitoJUnitRunner.class) -public class UserDetailsResourceFactoryBeanTest { +public class UserDetailsResourceFactoryBeanTests { + @Mock ResourceLoader resourceLoader; @@ -43,60 +46,62 @@ public class UserDetailsResourceFactoryBeanTest { @Test public void setResourceLoaderWhenNullThenThrowsException() { - assertThatThrownBy(() -> factory.setResourceLoader(null) ) - .isInstanceOf(IllegalArgumentException.class) - .hasStackTraceContaining("resourceLoader cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.factory.setResourceLoader(null)) + .withStackTraceContaining("resourceLoader cannot be null"); + // @formatter:on } @Test public void getObjectWhenPropertiesResourceLocationNullThenThrowsIllegalStateException() { - factory.setResourceLoader(resourceLoader); - - assertThatThrownBy(() -> factory.getObject() ) - .isInstanceOf(IllegalArgumentException.class) - .hasStackTraceContaining("resource cannot be null if resourceLocation is null"); + this.factory.setResourceLoader(this.resourceLoader); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.factory.getObject()) + .withStackTraceContaining("resource cannot be null if resourceLocation is null"); + // @formatter:on } @Test public void getObjectWhenPropertiesResourceLocationSingleUserThenThrowsGetsSingleUser() throws Exception { - factory.setResourceLocation("classpath:users.properties"); - - Collection users = factory.getObject(); - + this.factory.setResourceLocation("classpath:users.properties"); + Collection users = this.factory.getObject(); assertLoaded(); } @Test public void getObjectWhenPropertiesResourceSingleUserThenThrowsGetsSingleUser() throws Exception { - factory.setResource(new InMemoryResource("user=password,ROLE_USER")); - + this.factory.setResource(new InMemoryResource("user=password,ROLE_USER")); assertLoaded(); } @Test public void getObjectWhenInvalidUserThenThrowsMeaningfulException() { - factory.setResource(new InMemoryResource("user=invalidFormatHere")); - - assertThatThrownBy(() -> factory.getObject() ) - .isInstanceOf(IllegalStateException.class) - .hasStackTraceContaining("user") - .hasStackTraceContaining("invalidFormatHere"); + this.factory.setResource(new InMemoryResource("user=invalidFormatHere")); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> this.factory.getObject()) + .withStackTraceContaining("user") + .withStackTraceContaining("invalidFormatHere"); + // @formatter:on } @Test public void getObjectWhenStringSingleUserThenGetsSingleUser() throws Exception { this.factory = UserDetailsResourceFactoryBean.fromString("user=password,ROLE_USER"); - assertLoaded(); } private void assertLoaded() throws Exception { - Collection users = factory.getObject(); - + Collection users = this.factory.getObject(); + // @formatter:off UserDetails expectedUser = User.withUsername("user") .password("password") .authorities("ROLE_USER") .build(); + // @formatter:on assertThat(users).containsExactly(expectedUser); } + } diff --git a/config/src/test/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessorTests.java b/config/src/test/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessorTests.java index e4bb97ca8e..84709eea8b 100644 --- a/config/src/test/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessorTests.java +++ b/config/src/test/java/org/springframework/security/config/crypto/RsaKeyConversionServicePostProcessorTests.java @@ -38,46 +38,48 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe import org.springframework.security.config.test.SpringTestRule; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link RsaKeyConversionServicePostProcessor} */ public class RsaKeyConversionServicePostProcessorTests { - private static final String PKCS8_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n" + - "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCMk7CKSTfu3QoV\n" + - "HoPVXxwZO+qweztd36cVWYqGOZinrOR2crWFu50AgR2CsdIH0+cqo7F4Vx7/3O8i\n" + - "RpYYZPe2VoO5sumzJt8P6fS80/TAKjhJDAqgZKRJTgGN8KxCM6p/aJli1ZeDBqiV\n" + - "v7vJJe+ZgJuPGRS+HMNa/wPxEkqqXsglcJcQV1ZEtfKXSHB7jizKpRL38185SyAC\n" + - "pwyjvBu6Cmm1URfhQo88mf239ONh4dZ2HoDfzN1q6Ssu4F4hgutxr9B0DVLDP5u+\n" + - "WFrm3nsJ76zf99uJ+ntMUHJ+bY+gOjSlVWIVBIZeAaEGKCNWRk/knjvjbijpvm3U\n" + - "acGlgdL3AgMBAAECggEACxxxS7zVyu91qI2s5eSKmAQAXMqgup6+2hUluc47nqUv\n" + - "uZz/c/6MPkn2Ryo+65d4IgqmMFjSfm68B/2ER5FTcvoLl1Xo2twrrVpUmcg3BClS\n" + - "IZPuExdhVNnxjYKEWwcyZrehyAoR261fDdcFxLRW588efIUC+rPTTRHzAc7sT+Ln\n" + - "t/uFeYNWJm3LaegOLoOmlMAhJ5puAWSN1F0FxtRf/RVgzbLA9QC975SKHJsfWCSr\n" + - "IZyPsdeaqomKaF65l8nfqlE0Ua2L35gIOGKjUwb7uUE8nI362RWMtYdoi3zDDyoY\n" + - "hSFbgjylCHDM0u6iSh6KfqOHtkYyJ8tUYgVWl787wQKBgQDYO3wL7xuDdD101Lyl\n" + - "AnaDdFB9fxp83FG1cWr+t7LYm9YxGfEUsKHAJXN6TIayDkOOoVwIl+Gz0T3Z06Bm\n" + - "eBGLrB9mrVA7+C7NJwu5gTMlzP6HxUR9zKJIQ/VB1NUGM77LSmvOFbHc9Q0+z8EH\n" + - "X5WO516a3Z7lNtZJcCoPOtu2rwKBgQCmbj41Fh+SSEUApCEKms5ETRpe7LXQlJgx\n" + - "yW7zcJNNuIb1C3vBLPxjiOTMgYKOeMg5rtHTGLT43URHLh9ArjawasjSAr4AM3J4\n" + - "xpoi/sKGDdiKOsuDWIGfzdYL8qyTHSdpZLQsCTMRiRYgAHZFPgNa7SLZRfZicGlr\n" + - "GHN1rJW6OQKBgEjiM/upyrJSWeypUDSmUeAZMpA6aWkwsfHgmtnkfUn5rQa74cDB\n" + - "kKO9e+D7LmOR3z+SL/1NhGwh2SE07dncGr3jdGodfO/ZxZyszozmeaECKcEFwwJM\n" + - "GV8WWPKplGwUwPiwywmZ0mvRxXcoe73KgBS88+xrSwWjqDL0tZiQlEJNAoGATkei\n" + - "GMQMG3jEg9Wu+NbxV6zQT3+U0MNjhl9RQU1c63x0dcNt9OFc4NAdlZcAulRTENaK\n" + - "OHjxffBM0hH+fySx8m53gFfr2BpaqDX5f6ZGBlly1SlsWZ4CchCVsc71nshipi7I\n" + - "k8HL9F5/OpQdDNprJ5RMBNfkWE65Nrcsb1e6oPkCgYAxwgdiSOtNg8PjDVDmAhwT\n" + - "Mxj0Dtwi2fAqQ76RVrrXpNp3uCOIAu4CfruIb5llcJ3uak0ZbnWri32AxSgk80y3\n" + - "EWiRX/WEDu5znejF+5O3pI02atWWcnxifEKGGlxwkcMbQdA67MlrJLFaSnnGpNXo\n" + - "yPfcul058SOqhafIZQMEKQ==\n" + - "-----END PRIVATE KEY-----"; - private static final String X509_PUBLIC_KEY_LOCATION = - "classpath:org/springframework/security/config/annotation/web/configuration/simple.pub"; + // @formatter:off + private static final String PKCS8_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n" + + "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCMk7CKSTfu3QoV\n" + + "HoPVXxwZO+qweztd36cVWYqGOZinrOR2crWFu50AgR2CsdIH0+cqo7F4Vx7/3O8i\n" + + "RpYYZPe2VoO5sumzJt8P6fS80/TAKjhJDAqgZKRJTgGN8KxCM6p/aJli1ZeDBqiV\n" + + "v7vJJe+ZgJuPGRS+HMNa/wPxEkqqXsglcJcQV1ZEtfKXSHB7jizKpRL38185SyAC\n" + + "pwyjvBu6Cmm1URfhQo88mf239ONh4dZ2HoDfzN1q6Ssu4F4hgutxr9B0DVLDP5u+\n" + + "WFrm3nsJ76zf99uJ+ntMUHJ+bY+gOjSlVWIVBIZeAaEGKCNWRk/knjvjbijpvm3U\n" + + "acGlgdL3AgMBAAECggEACxxxS7zVyu91qI2s5eSKmAQAXMqgup6+2hUluc47nqUv\n" + + "uZz/c/6MPkn2Ryo+65d4IgqmMFjSfm68B/2ER5FTcvoLl1Xo2twrrVpUmcg3BClS\n" + + "IZPuExdhVNnxjYKEWwcyZrehyAoR261fDdcFxLRW588efIUC+rPTTRHzAc7sT+Ln\n" + + "t/uFeYNWJm3LaegOLoOmlMAhJ5puAWSN1F0FxtRf/RVgzbLA9QC975SKHJsfWCSr\n" + + "IZyPsdeaqomKaF65l8nfqlE0Ua2L35gIOGKjUwb7uUE8nI362RWMtYdoi3zDDyoY\n" + + "hSFbgjylCHDM0u6iSh6KfqOHtkYyJ8tUYgVWl787wQKBgQDYO3wL7xuDdD101Lyl\n" + + "AnaDdFB9fxp83FG1cWr+t7LYm9YxGfEUsKHAJXN6TIayDkOOoVwIl+Gz0T3Z06Bm\n" + + "eBGLrB9mrVA7+C7NJwu5gTMlzP6HxUR9zKJIQ/VB1NUGM77LSmvOFbHc9Q0+z8EH\n" + + "X5WO516a3Z7lNtZJcCoPOtu2rwKBgQCmbj41Fh+SSEUApCEKms5ETRpe7LXQlJgx\n" + + "yW7zcJNNuIb1C3vBLPxjiOTMgYKOeMg5rtHTGLT43URHLh9ArjawasjSAr4AM3J4\n" + + "xpoi/sKGDdiKOsuDWIGfzdYL8qyTHSdpZLQsCTMRiRYgAHZFPgNa7SLZRfZicGlr\n" + + "GHN1rJW6OQKBgEjiM/upyrJSWeypUDSmUeAZMpA6aWkwsfHgmtnkfUn5rQa74cDB\n" + + "kKO9e+D7LmOR3z+SL/1NhGwh2SE07dncGr3jdGodfO/ZxZyszozmeaECKcEFwwJM\n" + + "GV8WWPKplGwUwPiwywmZ0mvRxXcoe73KgBS88+xrSwWjqDL0tZiQlEJNAoGATkei\n" + + "GMQMG3jEg9Wu+NbxV6zQT3+U0MNjhl9RQU1c63x0dcNt9OFc4NAdlZcAulRTENaK\n" + + "OHjxffBM0hH+fySx8m53gFfr2BpaqDX5f6ZGBlly1SlsWZ4CchCVsc71nshipi7I\n" + + "k8HL9F5/OpQdDNprJ5RMBNfkWE65Nrcsb1e6oPkCgYAxwgdiSOtNg8PjDVDmAhwT\n" + + "Mxj0Dtwi2fAqQ76RVrrXpNp3uCOIAu4CfruIb5llcJ3uak0ZbnWri32AxSgk80y3\n" + + "EWiRX/WEDu5znejF+5O3pI02atWWcnxifEKGGlxwkcMbQdA67MlrJLFaSnnGpNXo\n" + + "yPfcul058SOqhafIZQMEKQ==\n" + + "-----END PRIVATE KEY-----"; + // @formatter:on + + private static final String X509_PUBLIC_KEY_LOCATION = "classpath:org/springframework/security/config/annotation/web/configuration/simple.pub"; + + private final RsaKeyConversionServicePostProcessor postProcessor = new RsaKeyConversionServicePostProcessor(); - private final RsaKeyConversionServicePostProcessor postProcessor = - new RsaKeyConversionServicePostProcessor(); private ConversionService service; @Value("classpath:org/springframework/security/config/annotation/web/configuration/simple.pub") @@ -132,43 +134,51 @@ public class RsaKeyConversionServicePostProcessorTests { @Test public void valueWhenOverridingConversionServiceThenUsed() { - assertThatCode(() -> - this.spring.register(OverrideConversionServiceConfig.class, DefaultConfig.class).autowire()) - .hasRootCauseInstanceOf(IllegalArgumentException.class); + assertThatExceptionOfType(Exception.class).isThrownBy( + () -> this.spring.register(OverrideConversionServiceConfig.class, DefaultConfig.class).autowire()) + .withRootCauseInstanceOf(IllegalArgumentException.class); } @EnableWebSecurity - static class DefaultConfig { } + static class DefaultConfig { + + } @Configuration static class CustomResourceLoaderConfig { + @Bean BeanFactoryPostProcessor conversionServiceCustomizer() { - return beanFactory -> beanFactory.getBean(RsaKeyConversionServicePostProcessor.class) + return (beanFactory) -> beanFactory.getBean(RsaKeyConversionServicePostProcessor.class) .setResourceLoader(new CustomResourceLoader()); } + } @Configuration static class OverrideConversionServiceConfig { + @Bean ConversionService conversionService() { GenericConversionService service = new GenericConversionService(); - service.addConverter(String.class, RSAPublicKey.class, source -> { + service.addConverter(String.class, RSAPublicKey.class, (source) -> { throw new IllegalArgumentException("unsupported"); }); return service; } + } private static class CustomResourceLoader implements ResourceLoader { + private final ResourceLoader delegate = new DefaultResourceLoader(); @Override public Resource getResource(String location) { if (location.startsWith("classpath:")) { return this.delegate.getResource(location); - } else if (location.startsWith("custom:")) { + } + else if (location.startsWith("custom:")) { String[] parts = location.split(":"); return this.delegate.getResource( "classpath:org/springframework/security/config/annotation/web/configuration/" + parts[1]); @@ -180,5 +190,7 @@ public class RsaKeyConversionServicePostProcessorTests { public ClassLoader getClassLoader() { return this.delegate.getClassLoader(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/debug/AuthProviderDependency.java b/config/src/test/java/org/springframework/security/config/debug/AuthProviderDependency.java index 2482e56434..d29f1b476d 100644 --- a/config/src/test/java/org/springframework/security/config/debug/AuthProviderDependency.java +++ b/config/src/test/java/org/springframework/security/config/debug/AuthProviderDependency.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.debug; import org.springframework.stereotype.Component; /** * Fake depenency for {@link TestAuthenticationProvider} + * * @author Rob Winch * */ diff --git a/config/src/test/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests.java b/config/src/test/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests.java index a721906b21..4d885e2116 100644 --- a/config/src/test/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests.java +++ b/config/src/test/java/org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.debug; import org.junit.Rule; import org.junit.Test; + +import org.springframework.security.config.BeanIds; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.debug.DebugFilter; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.config.BeanIds.FILTER_CHAIN_PROXY; -import static org.springframework.security.config.BeanIds.SPRING_SECURITY_FILTER_CHAIN; /** * @author Rob Winch @@ -37,10 +38,12 @@ public class SecurityDebugBeanFactoryPostProcessorTests { @Test public void contextRefreshWhenInDebugModeAndDependencyHasAutowiredConstructorThenDebugModeStillWorks() { // SEC-1885 - this.spring.configLocations("classpath:org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests-context.xml") - .autowire(); - - assertThat(this.spring.getContext().getBean(SPRING_SECURITY_FILTER_CHAIN)).isInstanceOf(DebugFilter.class); - assertThat(this.spring.getContext().getBean(FILTER_CHAIN_PROXY)).isInstanceOf(FilterChainProxy.class); + this.spring.configLocations( + "classpath:org/springframework/security/config/debug/SecurityDebugBeanFactoryPostProcessorTests-context.xml") + .autowire(); + assertThat(this.spring.getContext().getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN)) + .isInstanceOf(DebugFilter.class); + assertThat(this.spring.getContext().getBean(BeanIds.FILTER_CHAIN_PROXY)).isInstanceOf(FilterChainProxy.class); } + } diff --git a/config/src/test/java/org/springframework/security/config/debug/TestAuthenticationProvider.java b/config/src/test/java/org/springframework/security/config/debug/TestAuthenticationProvider.java index 1a7a6f8966..b6aa96f65e 100644 --- a/config/src/test/java/org/springframework/security/config/debug/TestAuthenticationProvider.java +++ b/config/src/test/java/org/springframework/security/config/debug/TestAuthenticationProvider.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.debug; +package org.springframework.security.config.debug; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.authentication.AuthenticationProvider; @@ -23,7 +23,9 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.stereotype.Service; /** - * An {@link AuthenticationProvider} that has an {@link Autowired} constructor which is necessary to recreate SEC-1885. + * An {@link AuthenticationProvider} that has an {@link Autowired} constructor which is + * necessary to recreate SEC-1885. + * * @author Rob Winch * */ @@ -34,11 +36,14 @@ public class TestAuthenticationProvider implements AuthenticationProvider { public TestAuthenticationProvider(AuthProviderDependency authProviderDependency) { } + @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { throw new UnsupportedOperationException(); } + @Override public boolean supports(Class authentication) { throw new UnsupportedOperationException(); } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/Attribute.java b/config/src/test/java/org/springframework/security/config/doc/Attribute.java index 85983a5b55..ff409f7505 100644 --- a/config/src/test/java/org/springframework/security/config/doc/Attribute.java +++ b/config/src/test/java/org/springframework/security/config/doc/Attribute.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.doc; /** - * Represents a Spring Security XSD Attribute. It is created when parsing the current xsd to compare to the documented appendix. + * Represents a Spring Security XSD Attribute. It is created when parsing the current xsd + * to compare to the documented appendix. * * @author Rob Winch * @author Josh Cummings - * * @see SpringSecurityXsdParser * @see XsdDocumentedTests */ public class Attribute { + private String name; private String desc; @@ -63,4 +65,5 @@ public class Attribute { public String getId() { return String.format("%s-%s", this.elmt.getId(), this.name); } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/Element.java b/config/src/test/java/org/springframework/security/config/doc/Element.java index 368dd5d500..365bcfb5b9 100644 --- a/config/src/test/java/org/springframework/security/config/doc/Element.java +++ b/config/src/test/java/org/springframework/security/config/doc/Element.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.doc; import java.util.ArrayList; @@ -21,25 +22,30 @@ import java.util.HashMap; import java.util.Map; /** - * Represents a Spring Security XSD Element. It is created when parsing - * the current xsd to compare to the documented appendix. + * Represents a Spring Security XSD Element. It is created when parsing the current xsd to + * compare to the documented appendix. * * @author Rob Winch * @author Josh Cummings - * * @see SpringSecurityXsdParser * @see XsdDocumentedTests -*/ + */ public class Element { + private String name; + private String desc; + private Collection attrs = new ArrayList<>(); /** - * Contains the elements that extend this element (i.e. any-user-service contains ldap-user-service) + * Contains the elements that extend this element (i.e. any-user-service contains + * ldap-user-service) */ private Collection subGrps = new ArrayList<>(); + private Map childElmts = new HashMap<>(); + private Map parentElmts = new HashMap<>(); public String getId() { @@ -95,75 +101,62 @@ public class Element { } /** - * Gets all the ids related to this Element including attributes, parent elements, and child elements. + * Gets all the ids related to this Element including attributes, parent elements, and + * child elements. * *

    * The expected ids to be found are documented below. *

      - *
    • Elements - any xml element will have the nsa-<element>. For example the http element will have the id - * nsa-http
    • - *
    • Parent Section - Any element with a parent other than beans will have a section named - * nsa-<element>-parents. For example, authentication-provider would have a section id of - * nsa-authentication-provider-parents. The section would then contain a list of links pointing to the - * documentation for each parent element.
    • - *
    • Attributes Section - Any element with attributes will have a section with the id - * nsa-<element>-attributes. For example the http element would require a section with the id - * http-attributes.
    • - *
    • Attribute - Each attribute of an element would have an id of nsa-<element>-<attributeName>. For - * example the attribute create-session for the http attribute would have the id http-create-session.
    • - *
    • Child Section - Any element with a child element will have a section named nsa-<element>-children. - * For example, authentication-provider would have a section id of nsa-authentication-provider-children. The - * section would then contain a list of links pointing to the documentation for each child element.
    • + *
    • Elements - any xml element will have the nsa-<element>. For example the + * http element will have the id nsa-http
    • + *
    • Parent Section - Any element with a parent other than beans will have a section + * named nsa-<element>-parents. For example, authentication-provider would have + * a section id of nsa-authentication-provider-parents. The section would then contain + * a list of links pointing to the documentation for each parent element.
    • + *
    • Attributes Section - Any element with attributes will have a section with the + * id nsa-<element>-attributes. For example the http element would require a + * section with the id http-attributes.
    • + *
    • Attribute - Each attribute of an element would have an id of + * nsa-<element>-<attributeName>. For example the attribute create-session + * for the http attribute would have the id http-create-session.
    • + *
    • Child Section - Any element with a child element will have a section named + * nsa-<element>-children. For example, authentication-provider would have a + * section id of nsa-authentication-provider-children. The section would then contain + * a list of links pointing to the documentation for each child element.
    • *
    * @return */ public Collection getIds() { Collection ids = new ArrayList<>(); ids.add(getId()); - - this.childElmts.values() - .forEach(elmt -> ids.add(elmt.getId())); - - this.attrs.forEach(attr -> ids.add(attr.getId())); - - if ( !this.childElmts.isEmpty() ) { + this.childElmts.values().forEach((elmt) -> ids.add(elmt.getId())); + this.attrs.forEach((attr) -> ids.add(attr.getId())); + if (!this.childElmts.isEmpty()) { ids.add(getId() + "-children"); } - - if ( !this.attrs.isEmpty() ) { + if (!this.attrs.isEmpty()) { ids.add(getId() + "-attributes"); } - - if ( !this.parentElmts.isEmpty() ) { + if (!this.parentElmts.isEmpty()) { ids.add(getId() + "-parents"); } - return ids; } public Map getAllChildElmts() { Map result = new HashMap<>(); - this.childElmts.values() - .forEach(elmt -> - elmt.subGrps.forEach( - subElmt -> result.put(subElmt.name, subElmt))); - + .forEach((elmt) -> elmt.subGrps.forEach((subElmt) -> result.put(subElmt.name, subElmt))); result.putAll(this.childElmts); - return result; } public Map getAllParentElmts() { Map result = new HashMap<>(); - this.parentElmts.values() - .forEach(elmt -> - elmt.subGrps.forEach( - subElmt -> result.put(subElmt.name, subElmt))); - + .forEach((elmt) -> elmt.subGrps.forEach((subElmt) -> result.put(subElmt.name, subElmt))); result.putAll(this.parentElmts); - return result; } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/SpringSecurityXsdParser.java b/config/src/test/java/org/springframework/security/config/doc/SpringSecurityXsdParser.java index 12e1e6b8f3..c58904745d 100644 --- a/config/src/test/java/org/springframework/security/config/doc/SpringSecurityXsdParser.java +++ b/config/src/test/java/org/springframework/security/config/doc/SpringSecurityXsdParser.java @@ -13,13 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.doc; -import org.springframework.util.StringUtils; - -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import java.util.stream.Stream; +import org.springframework.util.StringUtils; + /** * Parses the Spring Security Xsd Document * @@ -27,9 +33,11 @@ import java.util.stream.Stream; * @author Josh Cummings */ public class SpringSecurityXsdParser { + private XmlNode rootElement; private Set attrElmts = new LinkedHashSet<>(); + private Map elementNameToElement = new HashMap<>(); public SpringSecurityXsdParser(XmlNode rootElement) { @@ -38,7 +46,6 @@ public class SpringSecurityXsdParser { /** * Returns a map of the element name to the {@link Element}. - * * @return */ public Map parse() { @@ -48,71 +55,67 @@ public class SpringSecurityXsdParser { /** * Creates a Map of the name to an Element object of all the children of element. - * * @param node * @return */ private Map elements(XmlNode node) { Map elementNameToElement = new HashMap<>(); - - node.children().forEach(child -> { + node.children().forEach((child) -> { if ("element".equals(child.simpleName())) { Element e = elmt(child); elementNameToElement.put(e.getName(), e); - } else { + } + else { elementNameToElement.putAll(elements(child)); } }); - return elementNameToElement; } /** * Any children that are attribute will be returned as an Attribute object. - * * @param element * @return a collection of Attribute objects that are children of element. */ private Collection attrs(XmlNode element) { Collection attrs = new ArrayList<>(); - element.children().forEach(c -> { + element.children().forEach((c) -> { String name = c.simpleName(); if ("attribute".equals(name)) { attrs.add(attr(c)); - } else if ("element".equals(name)) { - } else { + } + else if (!"element".equals(name)) { attrs.addAll(attrs(c)); } }); - return attrs; } /** - * Any children will be searched for an attributeGroup, each of its children will be returned as an Attribute - * + * Any children will be searched for an attributeGroup, each of its children will be + * returned as an Attribute * @param element * @return */ private Collection attrgrps(XmlNode element) { Collection attrgrp = new ArrayList<>(); - - element.children().forEach(c -> { - if ("element".equals(c.simpleName())) { - - } else if ("attributeGroup".equals(c.simpleName())) { - if (c.attribute("name") != null) { - attrgrp.addAll(attrgrp(c)); - } else { - String name = c.attribute("ref").split(":")[1]; - XmlNode attrGrp = findNode(element, name); - attrgrp.addAll(attrgrp(attrGrp)); + element.children().forEach((c) -> { + if (!"element".equals(c.simpleName())) { + if ("attributeGroup".equals(c.simpleName())) { + if (c.attribute("name") != null) { + attrgrp.addAll(attrgrp(c)); + } + else { + String name = c.attribute("ref").split(":")[1]; + XmlNode attrGrp = findNode(element, name); + attrgrp.addAll(attrgrp(attrGrp)); + } + } + else { + attrgrp.addAll(attrgrps(c)); } - } else { - attrgrp.addAll(attrgrps(c)); } }); - return attrgrp; } @@ -121,23 +124,27 @@ public class SpringSecurityXsdParser { while (!"schema".equals(root.simpleName())) { root = root.parent().get(); } - + // @formatter:off return expand(root) - .filter(node -> name.equals(node.attribute("name"))) - .findFirst().orElseThrow(IllegalArgumentException::new); + .filter((node) -> name.equals(node.attribute("name"))) + .findFirst() + .orElseThrow(IllegalArgumentException::new); + // @formatter:on } private Stream expand(XmlNode root) { - return Stream.concat( - Stream.of(root), - root.children().flatMap(this::expand)); + // @formatter:off + return Stream.concat(Stream.of(root), root.children() + .flatMap(this::expand)); + // @formatter:on } /** - * Processes an individual attributeGroup by obtaining all the attributes and then looking for more attributeGroup elements and prcessing them. - * + * Processes an individual attributeGroup by obtaining all the attributes and then + * looking for more attributeGroup elements and prcessing them. * @param e - * @return all the attributes for a specific attributeGroup and any child attributeGroups + * @return all the attributes for a specific attributeGroup and any child + * attributeGroups */ private Collection attrgrp(XmlNode e) { Collection attrs = attrs(e); @@ -147,20 +154,16 @@ public class SpringSecurityXsdParser { /** * Obtains the description for a specific element - * * @param element * @return */ private String desc(XmlNode element) { - return element.child("annotation") - .flatMap(annotation -> annotation.child("documentation")) - .map(documentation -> documentation.text()) - .orElse(null); + return element.child("annotation").flatMap((annotation) -> annotation.child("documentation")) + .map((documentation) -> documentation.text()).orElse(null); } /** * Given an element creates an attribute from it. - * * @param n * @return */ @@ -169,8 +172,8 @@ public class SpringSecurityXsdParser { } /** - * Given an element creates an Element out of it by collecting all its attributes and child elements. - * + * Given an element creates an Element out of it by collecting all its attributes and + * child elements. * @param n * @return */ @@ -178,34 +181,30 @@ public class SpringSecurityXsdParser { String name = n.attribute("ref"); if (StringUtils.isEmpty(name)) { name = n.attribute("name"); - } else { + } + else { name = name.split(":")[1]; n = findNode(n, name); } - if (this.elementNameToElement.containsKey(name)) { return this.elementNameToElement.get(name); } this.attrElmts.add(name); - Element e = new Element(); e.setName(n.attribute("name")); e.setDesc(desc(n)); e.setChildElmts(elements(n)); e.setAttrs(attrs(n)); e.getAttrs().addAll(attrgrps(n)); - e.getAttrs().forEach(attr -> attr.setElmt(e)); - e.getChildElmts().values().forEach(element -> - element.getParentElmts().put(e.getName(), e)); - + e.getAttrs().forEach((attr) -> attr.setElmt(e)); + e.getChildElmts().values().forEach((element) -> element.getParentElmts().put(e.getName(), e)); String subGrpName = n.attribute("substitutionGroup"); if (!StringUtils.isEmpty(subGrpName)) { Element subGrp = elmt(findNode(n, subGrpName.split(":")[1])); subGrp.getSubGrps().add(e); } - this.elementNameToElement.put(name, e); - return e; } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/XmlNode.java b/config/src/test/java/org/springframework/security/config/doc/XmlNode.java index 0de48463f4..5ed6c38c0e 100644 --- a/config/src/test/java/org/springframework/security/config/doc/XmlNode.java +++ b/config/src/test/java/org/springframework/security/config/doc/XmlNode.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.doc; -import org.w3c.dom.Node; -import org.w3c.dom.NodeList; +package org.springframework.security.config.doc; import java.util.Optional; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; + /** * @author Josh Cummings */ public class XmlNode { + private final Node node; public XmlNode(Node node) { @@ -34,7 +36,7 @@ public class XmlNode { public String simpleName() { String[] parts = this.node.getNodeName().split(":"); - return parts[parts.length-1]; + return parts[parts.length - 1]; } public String text() { @@ -43,31 +45,35 @@ public class XmlNode { public Stream children() { NodeList children = this.node.getChildNodes(); - + // @formatter:off return IntStream.range(0, children.getLength()) .mapToObj(children::item) .map(XmlNode::new); + // @formatter:on } public Optional child(String name) { - return this.children() - .filter(child -> name.equals(child.simpleName())) - .findFirst(); + return this.children().filter((child) -> name.equals(child.simpleName())).findFirst(); } public Optional parent() { + // @formatter:off return Optional.ofNullable(this.node.getParentNode()) - .map(parent -> new XmlNode(parent)); + .map((parent) -> new XmlNode(parent)); + // @formatter:on } public String attribute(String name) { + // @formatter:off return Optional.ofNullable(this.node.getAttributes()) - .map(attrs -> attrs.getNamedItem(name)) - .map(attr -> attr.getTextContent()) + .map((attrs) -> attrs.getNamedItem(name)) + .map((attr) -> attr.getTextContent()) .orElse(null); + // @formatter:on } public Node node() { return this.node; } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/XmlParser.java b/config/src/test/java/org/springframework/security/config/doc/XmlParser.java index ffda4846c0..218f6cc5e1 100644 --- a/config/src/test/java/org/springframework/security/config/doc/XmlParser.java +++ b/config/src/test/java/org/springframework/security/config/doc/XmlParser.java @@ -13,20 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.doc; -import org.xml.sax.SAXException; +import java.io.IOException; +import java.io.InputStream; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; -import java.io.IOException; -import java.io.InputStream; + +import org.xml.sax.SAXException; /** * @author Josh Cummings */ public class XmlParser implements AutoCloseable { + private InputStream xml; public XmlParser(InputStream xml) { @@ -37,10 +40,10 @@ public class XmlParser implements AutoCloseable { try { DocumentBuilderFactory dbFactory = DocumentBuilderFactory.newInstance(); DocumentBuilder dBuilder = dbFactory.newDocumentBuilder(); - return new XmlNode(dBuilder.parse(this.xml)); - } catch ( IOException | ParserConfigurationException | SAXException e ) { - throw new IllegalStateException(e); + } + catch (IOException | ParserConfigurationException | SAXException ex) { + throw new IllegalStateException(ex); } } @@ -48,4 +51,5 @@ public class XmlParser implements AutoCloseable { public void close() throws IOException { this.xml.close(); } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/XmlSupport.java b/config/src/test/java/org/springframework/security/config/doc/XmlSupport.java index 3e18d00335..dc6337ec42 100644 --- a/config/src/test/java/org/springframework/security/config/doc/XmlSupport.java +++ b/config/src/test/java/org/springframework/security/config/doc/XmlSupport.java @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.doc; -import org.springframework.core.io.ClassPathResource; +package org.springframework.security.config.doc; import java.io.IOException; import java.util.Map; +import org.springframework.core.io.ClassPathResource; + /** * Support for ensuring preparing the givens in {@link XsdDocumentedTests} * * @author Josh Cummings */ public class XmlSupport { + private XmlParser parser; public XmlNode parse(String location) throws IOException { ClassPathResource resource = new ClassPathResource(location); this.parser = new XmlParser(resource.getInputStream()); - return this.parser.parse(); } @@ -41,8 +42,9 @@ public class XmlSupport { } public void close() throws IOException { - if ( this.parser != null ) { + if (this.parser != null) { this.parser.close(); } } + } diff --git a/config/src/test/java/org/springframework/security/config/doc/XsdDocumentedTests.java b/config/src/test/java/org/springframework/security/config/doc/XsdDocumentedTests.java index cd00c06896..9201f09dde 100644 --- a/config/src/test/java/org/springframework/security/config/doc/XsdDocumentedTests.java +++ b/config/src/test/java/org/springframework/security/config/doc/XsdDocumentedTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.doc; import java.io.IOException; @@ -45,23 +46,25 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class XsdDocumentedTests { - Collection ignoredIds = Arrays.asList( - "nsa-any-user-service", - "nsa-any-user-service-parents", - "nsa-authentication", - "nsa-websocket-security", - "nsa-ldap", - "nsa-method-security", - "nsa-web", - // deprecated and for removal - "nsa-frame-options-strategy", - "nsa-frame-options-ref", - "nsa-frame-options-value", - "nsa-frame-options-from-parameter"); + // @formatter:off + Collection ignoredIds = Arrays.asList("nsa-any-user-service", + "nsa-any-user-service-parents", + "nsa-authentication", + "nsa-websocket-security", + "nsa-ldap", + "nsa-method-security", + "nsa-web", + // deprecated and for removal + "nsa-frame-options-strategy", + "nsa-frame-options-ref", + "nsa-frame-options-value", + "nsa-frame-options-from-parameter"); + // @formatter:on String referenceLocation = "../docs/manual/src/docs/asciidoc/_includes/servlet/appendix/namespace.adoc"; String schema31xDocumentLocation = "org/springframework/security/config/spring-security-3.1.xsd"; + String schemaDocumentLocation = "org/springframework/security/config/spring-security-5.4.xsd"; XmlSupport xml = new XmlSupport(); @@ -72,234 +75,210 @@ public class XsdDocumentedTests { } @Test - public void parseWhenLatestXsdThenAllNamedSecurityFiltersAreDefinedAndOrderedProperly() - throws IOException { + public void parseWhenLatestXsdThenAllNamedSecurityFiltersAreDefinedAndOrderedProperly() throws IOException { XmlNode root = this.xml.parse(this.schemaDocumentLocation); - - List nodes = - root.child("schema") + // @formatter:off + List nodes = root.child("schema") .map(XmlNode::children) .orElse(Stream.empty()) - .filter(node -> - "simpleType".equals(node.simpleName()) && - "named-security-filter".equals(node.attribute("name"))) + .filter((node) -> "simpleType".equals(node.simpleName()) + && "named-security-filter".equals(node.attribute("name"))) .flatMap(XmlNode::children) .flatMap(XmlNode::children) - .map(node -> node.attribute("value")) + .map((node) -> node.attribute("value")) .filter(StringUtils::isNotEmpty) .collect(Collectors.toList()); - + // @formatter:on SecurityFiltersAssertions.assertEquals(nodes); } @Test - public void parseWhen31XsdThenAllNamedSecurityFiltersAreDefinedAndOrderedProperly() - throws IOException { - - List expected = Arrays.asList( - "FIRST", - "CHANNEL_FILTER", - "SECURITY_CONTEXT_FILTER", - "CONCURRENT_SESSION_FILTER", - "LOGOUT_FILTER", - "X509_FILTER", - "PRE_AUTH_FILTER", - "CAS_FILTER", - "FORM_LOGIN_FILTER", - "OPENID_FILTER", - "LOGIN_PAGE_FILTER", - "DIGEST_AUTH_FILTER", - "BASIC_AUTH_FILTER", - "REQUEST_CACHE_FILTER", - "SERVLET_API_SUPPORT_FILTER", - "JAAS_API_SUPPORT_FILTER", - "REMEMBER_ME_FILTER", - "ANONYMOUS_FILTER", - "SESSION_MANAGEMENT_FILTER", - "EXCEPTION_TRANSLATION_FILTER", - "FILTER_SECURITY_INTERCEPTOR", - "SWITCH_USER_FILTER", - "LAST" - ); - + public void parseWhen31XsdThenAllNamedSecurityFiltersAreDefinedAndOrderedProperly() throws IOException { + // @formatter:off + List expected = Arrays.asList("FIRST", + "CHANNEL_FILTER", + "SECURITY_CONTEXT_FILTER", + "CONCURRENT_SESSION_FILTER", + "LOGOUT_FILTER", + "X509_FILTER", + "PRE_AUTH_FILTER", + "CAS_FILTER", + "FORM_LOGIN_FILTER", + "OPENID_FILTER", + "LOGIN_PAGE_FILTER", + "DIGEST_AUTH_FILTER", + "BASIC_AUTH_FILTER", + "REQUEST_CACHE_FILTER", + "SERVLET_API_SUPPORT_FILTER", + "JAAS_API_SUPPORT_FILTER", + "REMEMBER_ME_FILTER", + "ANONYMOUS_FILTER", + "SESSION_MANAGEMENT_FILTER", + "EXCEPTION_TRANSLATION_FILTER", + "FILTER_SECURITY_INTERCEPTOR", + "SWITCH_USER_FILTER", + "LAST"); + // @formatter:on XmlNode root = this.xml.parse(this.schema31xDocumentLocation); - - List nodes = - root.child("schema") + // @formatter:off + List nodes = root.child("schema") .map(XmlNode::children) .orElse(Stream.empty()) - .filter(node -> - "simpleType".equals(node.simpleName()) && - "named-security-filter".equals(node.attribute("name"))) + .filter((node) -> "simpleType".equals(node.simpleName()) + && "named-security-filter".equals(node.attribute("name"))) .flatMap(XmlNode::children) .flatMap(XmlNode::children) - .map(node -> node.attribute("value")) + .map((node) -> node.attribute("value")) .filter(StringUtils::isNotEmpty) .collect(Collectors.toList()); - + // @formatter:on assertThat(nodes).isEqualTo(expected); } /** - * This will check to ensure that the expected number of xsd documents are found to ensure that we are validating - * against the current xsd document. If this test fails, all that is needed is to update the schemaDocument - * and the expected size for this test. + * This will check to ensure that the expected number of xsd documents are found to + * ensure that we are validating against the current xsd document. If this test fails, + * all that is needed is to update the schemaDocument and the expected size for this + * test. * @return */ @Test - public void sizeWhenReadingFilesystemThenIsCorrectNumberOfSchemaFiles() - throws IOException { - + public void sizeWhenReadingFilesystemThenIsCorrectNumberOfSchemaFiles() throws IOException { ClassPathResource resource = new ClassPathResource(this.schemaDocumentLocation); - - String[] schemas = resource.getFile().getParentFile().list((dir, name) -> name.endsWith(".xsd")); - + // @formatter:off + String[] schemas = resource.getFile() + .getParentFile() + .list((dir, name) -> name.endsWith(".xsd")); + // @formatter:on assertThat(schemas.length).isEqualTo(16) - .withFailMessage("the count is equal to 16, if not then schemaDocument needs updating"); + .withFailMessage("the count is equal to 16, if not then schemaDocument needs updating"); } /** - * This uses a naming convention for the ids of the appendix to ensure that the entire appendix is documented. - * The naming convention for the ids is documented in {@link Element#getIds()}. + * This uses a naming convention for the ids of the appendix to ensure that the entire + * appendix is documented. The naming convention for the ids is documented in + * {@link Element#getIds()}. * @return */ @Test - public void countReferencesWhenReviewingDocumentationThenEntireSchemaIsIncluded() - throws IOException { - - Map elementsByElementName = - this.xml.elementsByElementName(this.schemaDocumentLocation); - - List documentIds = - Files.lines(Paths.get(this.referenceLocation)) - .filter(line -> line.matches("\\[\\[(nsa-.*)\\]\\]")) - .map(line -> line.substring(2, line.length() - 2)) + public void countReferencesWhenReviewingDocumentationThenEntireSchemaIsIncluded() throws IOException { + Map elementsByElementName = this.xml.elementsByElementName(this.schemaDocumentLocation); + // @formatter:off + List documentIds = Files.lines(Paths.get(this.referenceLocation)) + .filter((line) -> line.matches("\\[\\[(nsa-.*)\\]\\]")) + .map((line) -> line.substring(2, line.length() - 2)) .collect(Collectors.toList()); - - Set expectedIds = - elementsByElementName.values().stream() - .flatMap(element -> element.getIds().stream()) + Set expectedIds = elementsByElementName.values() + .stream() + .flatMap((element) -> element.getIds().stream()) .collect(Collectors.toSet()); - + // @formatter:on documentIds.removeAll(this.ignoredIds); expectedIds.removeAll(this.ignoredIds); - assertThat(documentIds).containsAll(expectedIds); assertThat(expectedIds).containsAll(documentIds); } /** - * This test ensures that any element that has children or parents contains a section that has links pointing to that - * documentation. + * This test ensures that any element that has children or parents contains a section + * that has links pointing to that documentation. * @return */ @Test - public void countLinksWhenReviewingDocumentationThenParentsAndChildrenAreCorrectlyLinked() - throws IOException { - + public void countLinksWhenReviewingDocumentationThenParentsAndChildrenAreCorrectlyLinked() throws IOException { Map> docAttrNameToChildren = new HashMap<>(); Map> docAttrNameToParents = new HashMap<>(); - String docAttrName = null; Map> currentDocAttrNameToElmt = null; - List lines = Files.readAllLines(Paths.get(this.referenceLocation)); - for ( String line : lines ) { + for (String line : lines) { if (line.matches("^\\[\\[.*\\]\\]$")) { String id = line.substring(2, line.length() - 2); - if (id.endsWith("-children")) { docAttrName = id.substring(0, id.length() - 9); currentDocAttrNameToElmt = docAttrNameToChildren; - } else if (id.endsWith("-parents")) { + } + else if (id.endsWith("-parents")) { docAttrName = id.substring(0, id.length() - 8); currentDocAttrNameToElmt = docAttrNameToParents; - } else if (docAttrName != null && !id.startsWith(docAttrName)) { + } + else if (docAttrName != null && !id.startsWith(docAttrName)) { currentDocAttrNameToElmt = null; docAttrName = null; } } - if (docAttrName != null && currentDocAttrNameToElmt != null) { String expression = "^\\* <<(nsa-.*),.*>>$"; if (line.matches(expression)) { String elmtId = line.replaceAll(expression, "$1"); - currentDocAttrNameToElmt - .computeIfAbsent(docAttrName, key -> new ArrayList<>()) - .add(elmtId); + currentDocAttrNameToElmt.computeIfAbsent(docAttrName, (key) -> new ArrayList<>()).add(elmtId); } } } - Map elementNameToElement = this.xml.elementsByElementName(this.schemaDocumentLocation); - Map> schemaAttrNameToChildren = new HashMap<>(); Map> schemaAttrNameToParents = new HashMap<>(); - - elementNameToElement.entrySet().stream() - .forEach(entry -> { - String key = "nsa-" + entry.getKey(); - if (this.ignoredIds.contains(key) ) { - return; - } - - List parentIds = - entry.getValue().getAllParentElmts().values().stream() - .filter(element -> !this.ignoredIds.contains(element.getId())) - .map(element -> element.getId()) - .sorted() - .collect(Collectors.toList()); - if ( !parentIds.isEmpty() ) { - schemaAttrNameToParents.put(key, parentIds); - } - - List childIds = - entry.getValue().getAllChildElmts().values().stream() - .filter(element -> !this.ignoredIds.contains(element.getId())) - .map(element -> element.getId()) - .sorted() - .collect(Collectors.toList()); - if ( !childIds.isEmpty() ) { - schemaAttrNameToChildren.put(key, childIds); - } - }); - + elementNameToElement.entrySet().stream().forEach((entry) -> { + String key = "nsa-" + entry.getKey(); + if (this.ignoredIds.contains(key)) { + return; + } + // @formatter:off + List parentIds = entry.getValue() + .getAllParentElmts() + .values() + .stream() + .filter((element) -> !this.ignoredIds.contains(element.getId())) + .map((element) -> element.getId()) + .sorted() + .collect(Collectors.toList()); + // @formatter:on + if (!parentIds.isEmpty()) { + schemaAttrNameToParents.put(key, parentIds); + } + // @formatter:off + List childIds = entry.getValue() + .getAllChildElmts() + .values() + .stream() + .filter((element) -> !this.ignoredIds.contains(element.getId())).map((element) -> element.getId()) + .sorted() + .collect(Collectors.toList()); + // @formatter:on + if (!childIds.isEmpty()) { + schemaAttrNameToChildren.put(key, childIds); + } + }); assertThat(docAttrNameToChildren).isEqualTo(schemaAttrNameToChildren); assertThat(docAttrNameToParents).isEqualTo(schemaAttrNameToParents); } - /** * This test checks each xsd element and ensures there is documentation for it. * @return */ @Test - public void countWhenReviewingDocumentationThenAllElementsDocumented() - throws IOException { - - Map elementNameToElement = - this.xml.elementsByElementName(this.schemaDocumentLocation); - - String notDocElmtIds = - elementNameToElement.values().stream() - .filter(element -> - StringUtils.isEmpty(element.getDesc()) && - !this.ignoredIds.contains(element.getId())) - .map(element -> element.getId()) - .sorted() - .collect(Collectors.joining("\n")); - - String notDocAttrIds = - elementNameToElement.values().stream() - .flatMap(element -> element.getAttrs().stream()) - .filter(element -> - StringUtils.isEmpty(element.getDesc()) && - !this.ignoredIds.contains(element.getId())) - .map(element -> element.getId()) - .sorted() - .collect(Collectors.joining("\n")); - + public void countWhenReviewingDocumentationThenAllElementsDocumented() throws IOException { + Map elementNameToElement = this.xml.elementsByElementName(this.schemaDocumentLocation); + // @formatter:off + String notDocElmtIds = elementNameToElement.values() + .stream() + .filter((element) -> StringUtils.isEmpty(element.getDesc()) + && !this.ignoredIds.contains(element.getId())) + .map((element) -> element.getId()) + .sorted() + .collect(Collectors.joining("\n")); + String notDocAttrIds = elementNameToElement.values() + .stream() + .flatMap((element) -> element.getAttrs().stream()) + .filter((element) -> StringUtils.isEmpty(element.getDesc()) + && !this.ignoredIds.contains(element.getId())) + .map((element) -> element.getId()) + .sorted() + .collect(Collectors.joining("\n")); + // @formatter:on assertThat(notDocElmtIds).isEmpty(); assertThat(notDocAttrIds).isEmpty(); } + } diff --git a/config/src/test/java/org/springframework/security/config/http/AccessDeniedConfigTests.java b/config/src/test/java/org/springframework/security/config/http/AccessDeniedConfigTests.java index 2762628079..1344c1b3a5 100644 --- a/config/src/test/java/org/springframework/security/config/http/AccessDeniedConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/AccessDeniedConfigTests.java @@ -13,12 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.eclipse.jetty.http.HttpStatus; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; @@ -31,23 +36,19 @@ import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.web.servlet.MockMvc; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Luke Taylor * @author Josh Cummings */ @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class AccessDeniedConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/AccessDeniedConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/AccessDeniedConfigTests"; @Autowired MockMvc mvc; @@ -58,30 +59,22 @@ public class AccessDeniedConfigTests { @Test public void configureWhenAccessDeniedHandlerIsMissingLeadingSlashThenException() { SpringTestContext context = this.spring.configLocations(this.xml("NoLeadingSlash")); - - assertThatThrownBy(() -> context.autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("errorPage must begin with '/'"); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> context.autowire()) + .withMessageContaining("errorPage must begin with '/'"); } @Test @WithMockUser - public void configureWhenAccessDeniedHandlerRefThenAutowire() - throws Exception { - + public void configureWhenAccessDeniedHandlerRefThenAutowire() throws Exception { this.spring.configLocations(this.xml("AccessDeniedHandler")).autowire(); - - this.mvc.perform(get("/")) - .andExpect(status().is(HttpStatus.GONE_410)); + this.mvc.perform(get("/")).andExpect(status().is(HttpStatus.GONE_410)); } @Test public void configureWhenAccessDeniedHandlerUsesPathAndRefThenException() { SpringTestContext context = this.spring.configLocations(this.xml("UsesPathAndRef")); - - assertThatThrownBy(() -> context.autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("attribute error-page cannot be used together with the 'ref' attribute"); + assertThatExceptionOfType(BeanDefinitionParsingException.class).isThrownBy(() -> context.autowire()) + .withMessageContaining("attribute error-page cannot be used together with the 'ref' attribute"); } private String xml(String configName) { @@ -91,11 +84,11 @@ public class AccessDeniedConfigTests { public static class GoneAccessDeniedHandler implements AccessDeniedHandler { @Override - public void handle(HttpServletRequest request, - HttpServletResponse response, - AccessDeniedException accessDeniedException) { - + public void handle(HttpServletRequest request, HttpServletResponse response, + AccessDeniedException accessDeniedException) { response.setStatus(HttpStatus.GONE_410); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfBeanDefinitionParserTests.java index 145547c52a..d8e9862560 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfBeanDefinitionParserTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.junit.Test; @@ -23,8 +24,8 @@ import org.springframework.context.support.ClassPathXmlApplicationContext; * @author Ankur Pathak */ public class CsrfBeanDefinitionParserTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/CsrfBeanDefinitionParserTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/CsrfBeanDefinitionParserTests"; @Test public void registerDataValueProcessorOnlyIfNotRegistered() { @@ -38,4 +39,5 @@ public class CsrfBeanDefinitionParserTests { private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + } diff --git a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java index fa4b41b400..7f4894f161 100644 --- a/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java @@ -13,12 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.net.URI; +import java.util.List; + +import javax.servlet.Filter; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.eclipse.jetty.http.HttpStatus; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; @@ -39,22 +48,17 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.ResponseBody; -import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.support.RequestDataValueProcessor; -import javax.servlet.Filter; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.net.URI; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; @@ -67,25 +71,16 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; -import static org.springframework.web.bind.annotation.RequestMethod.DELETE; -import static org.springframework.web.bind.annotation.RequestMethod.GET; -import static org.springframework.web.bind.annotation.RequestMethod.HEAD; -import static org.springframework.web.bind.annotation.RequestMethod.OPTIONS; -import static org.springframework.web.bind.annotation.RequestMethod.PATCH; -import static org.springframework.web.bind.annotation.RequestMethod.POST; -import static org.springframework.web.bind.annotation.RequestMethod.PUT; -import static org.springframework.web.bind.annotation.RequestMethod.TRACE; /** - * * @author Rob Winch * @author Josh Cummings */ @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class CsrfConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/CsrfConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/CsrfConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -95,298 +90,249 @@ public class CsrfConfigTests { @Test public void postWhenDefaultConfigurationThenForbiddenSinceCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(post("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void putWhenDefaultConfigurationThenForbiddenSinceCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(put("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void patchWhenDefaultConfigurationThenForbiddenSinceCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(patch("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void deleteWhenDefaultConfigurationThenForbiddenSinceCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(delete("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void invalidWhenDefaultConfigurationThenForbiddenSinceCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(request("INVALID", new URI("/csrf"))) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void getWhenDefaultConfigurationThenCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(get("/csrf")) .andExpect(csrfInBody()); + // @formatter:on } - @Test public void headWhenDefaultConfigurationThenCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(head("/csrf-in-header")) .andExpect(csrfInHeader()); + // @formatter:on } @Test public void traceWhenDefaultConfigurationThenCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - - MockMvc traceEnabled = MockMvcBuilders - .webAppContextSetup((WebApplicationContext) this.spring.getContext()) + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); + // @formatter:off + MockMvc traceEnabled = MockMvcBuilders.webAppContextSetup(this.spring.getContext()) .apply(springSecurity()) - .addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setDispatchTraceRequest(true)) + .addDispatcherServletCustomizer((dispatcherServlet) -> dispatcherServlet.setDispatchTraceRequest(true)) .build(); - traceEnabled.perform(request(HttpMethod.TRACE, "/csrf-in-header")) .andExpect(csrfInHeader()); + // @formatter:on } @Test public void optionsWhenDefaultConfigurationThenCsrfIsEnabled() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); + // @formatter:off this.mvc.perform(options("/csrf-in-header")) .andExpect(csrfInHeader()); + // @formatter:on } @Test public void postWhenCsrfDisabledThenRequestAllowed() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("CsrfDisabled") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("CsrfDisabled")).autowire(); + // @formatter:off this.mvc.perform(post("/ok")) .andExpect(status().isOk()); - + // @formatter:on assertThat(getFilter(this.spring, CsrfFilter.class)).isNull(); } @Test public void postWhenCsrfElementEnabledThenForbidden() throws Exception { - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(post("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void putWhenCsrfElementEnabledThenForbidden() throws Exception { - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(put("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void patchWhenCsrfElementEnabledThenForbidden() throws Exception { - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(patch("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void deleteWhenCsrfElementEnabledThenForbidden() throws Exception { - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(delete("/csrf")) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void invalidWhenCsrfElementEnabledThenForbidden() throws Exception { - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(request("INVALID", new URI("/csrf"))) .andExpect(status().isForbidden()) .andExpect(csrfCreated()); + // @formatter:on } @Test public void getWhenCsrfElementEnabledThenOk() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(get("/csrf")) .andExpect(csrfInBody()); + // @formatter:on } @Test public void headWhenCsrfElementEnabledThenOk() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(head("/csrf-in-header")) .andExpect(csrfInHeader()); + // @formatter:on } @Test public void traceWhenCsrfElementEnabledThenOk() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("CsrfEnabled") - ).autowire(); - - MockMvc traceEnabled = MockMvcBuilders - .webAppContextSetup((WebApplicationContext) this.spring.getContext()) + this.spring.configLocations(this.xml("shared-controllers"), this.xml("CsrfEnabled")).autowire(); + // @formatter:off + MockMvc traceEnabled = MockMvcBuilders.webAppContextSetup(this.spring.getContext()) .apply(springSecurity()) - .addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setDispatchTraceRequest(true)) + .addDispatcherServletCustomizer((dispatcherServlet) -> dispatcherServlet.setDispatchTraceRequest(true)) .build(); - - traceEnabled.perform(request(HttpMethod.TRACE, "/csrf-in-header")) - .andExpect(csrfInHeader()); + // @formatter:on + traceEnabled.perform(request(HttpMethod.TRACE, "/csrf-in-header")).andExpect(csrfInHeader()); } @Test public void optionsWhenCsrfElementEnabledThenOk() throws Exception { - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("shared-controllers"), this.xml("CsrfEnabled")).autowire(); + // @formatter:off this.mvc.perform(options("/csrf-in-header")) .andExpect(csrfInHeader()); + // @formatter:on } @Test public void autowireWhenCsrfElementEnabledThenCreatesCsrfRequestDataValueProcessor() { - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); assertThat(this.spring.getContext().getBean(RequestDataValueProcessor.class)).isNotNull(); } @Test - public void postWhenUsingCsrfAndCustomAccessDeniedHandlerThenTheHandlerIsAppropriatelyEngaged() - throws Exception { - - this.spring.configLocations( - this.xml("WithAccessDeniedHandler"), - this.xml("shared-access-denied-handler") - ).autowire(); - + public void postWhenUsingCsrfAndCustomAccessDeniedHandlerThenTheHandlerIsAppropriatelyEngaged() throws Exception { + this.spring.configLocations(this.xml("WithAccessDeniedHandler"), this.xml("shared-access-denied-handler")) + .autowire(); + // @formatter:off this.mvc.perform(post("/ok")) .andExpect(status().isIAmATeapot()); + // @formatter:on } @Test public void postWhenHasCsrfTokenButSessionExpiresThenRequestIsCancelledAfterSuccessfulAuthentication() - throws Exception { - - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + throws Exception { + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); // simulates a request that has no authentication (e.g. session time-out) - MvcResult result = this.mvc.perform(post("/authenticated") - .with(csrf())) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn(); - + MvcResult result = this.mvc.perform(post("/authenticated").with(csrf())) + .andExpect(redirectedUrl("http://localhost/login")).andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); - - // if the request cache is consulted, then it will redirect back to /some-url, which we don't want - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session(session) - .with(csrf())) + // if the request cache is consulted, then it will redirect back to /some-url, + // which we don't want + // @formatter:off + MockHttpServletRequestBuilder login = post("/login") + .param("username", "user") + .param("password", "password") + .session(session) + .with(csrf()); + this.mvc.perform(login) .andExpect(redirectedUrl("/")); + // @formatter:on } @Test public void getWhenHasCsrfTokenButSessionExpiresThenRequestIsRememeberedAfterSuccessfulAuthentication() throws Exception { - - this.spring.configLocations( - this.xml("CsrfEnabled") - ).autowire(); - + this.spring.configLocations(this.xml("CsrfEnabled")).autowire(); // simulates a request that has no authentication (e.g. session time-out) - MvcResult result = - this.mvc.perform(get("/authenticated")) - .andExpect(redirectedUrl("http://localhost/login")) - .andReturn(); - + MvcResult result = this.mvc.perform(get("/authenticated")).andExpect(redirectedUrl("http://localhost/login")) + .andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); - - // if the request cache is consulted, then it will redirect back to /some-url, which we do want - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session(session) - .with(csrf())) + // if the request cache is consulted, then it will redirect back to /some-url, + // which we do want + // @formatter:off + MockHttpServletRequestBuilder login = post("/login") + .param("username", "user") + .param("password", "password") + .session(session) + .with(csrf()); + this.mvc.perform(login) .andExpect(redirectedUrl("http://localhost/authenticated")); + // @formatter:on } /** @@ -394,125 +340,92 @@ public class CsrfConfigTests { */ @Test public void postWhenUsingCsrfAndCustomSessionManagementAndNoSessionThenStillRedirectsToInvalidSessionUrl() - throws Exception { - - this.spring.configLocations( - this.xml("WithSessionManagement") - ).autowire(); - - MvcResult result = this.mvc.perform(post("/ok").param("_csrf", "abc")) - .andExpect(redirectedUrl("/error/sessionError")) - .andReturn(); - + throws Exception { + this.spring.configLocations(this.xml("WithSessionManagement")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder postToOk = post("/ok") + .param("_csrf", "abc"); + MvcResult result = this.mvc.perform(postToOk) + .andExpect(redirectedUrl("/error/sessionError")) + .andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); - - this.mvc.perform(post("/csrf") - .session(session)) - .andExpect(status().isForbidden()); + this.mvc.perform(post("/csrf").session(session)) + .andExpect(status().isForbidden()); + // @formatter:on } @Test - public void requestWhenUsingCustomRequestMatcherConfiguredThenAppliesAccordingly() - throws Exception { - - SpringTestContext context = - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("WithRequestMatcher"), - this.xml("mock-request-matcher") - ); - + public void requestWhenUsingCustomRequestMatcherConfiguredThenAppliesAccordingly() throws Exception { + SpringTestContext context = this.spring.configLocations(this.xml("shared-controllers"), + this.xml("WithRequestMatcher"), this.xml("mock-request-matcher")); context.autowire(); - RequestMatcher matcher = context.getContext().getBean(RequestMatcher.class); - when(matcher.matches(any(HttpServletRequest.class))).thenReturn(false); - - this.mvc.perform(post("/ok")).andExpect(status().isOk()); - - when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true); - - this.mvc.perform(get("/ok")).andExpect(status().isForbidden()); + given(matcher.matches(any(HttpServletRequest.class))).willReturn(false); + // @formatter:off + this.mvc.perform(post("/ok")) + .andExpect(status().isOk()); + // @formatter:on + given(matcher.matches(any(HttpServletRequest.class))).willReturn(true); + // @formatter:off + this.mvc.perform(get("/ok")) + .andExpect(status().isForbidden()); + // @formatter:on } @Test - public void getWhenDefaultConfigurationThenSessionNotImmediatelyCreated() - throws Exception { - - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - + public void getWhenDefaultConfigurationThenSessionNotImmediatelyCreated() throws Exception { + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); + // @formatter:off MvcResult result = this.mvc.perform(get("/ok")) - .andExpect(status().isOk()) - .andReturn(); - + .andExpect(status().isOk()).andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test @WithMockUser - public void postWhenCsrfMismatchesThenForbidden() - throws Exception { - - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - + public void postWhenCsrfMismatchesThenForbidden() throws Exception { + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); MvcResult result = this.mvc.perform(get("/ok")).andReturn(); - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); - - this.mvc.perform(post("/ok") - .session(session) - .with(csrf().useInvalidToken())) + // @formatter:off + MockHttpServletRequestBuilder postOk = post("/ok") + .session(session) + .with(csrf().useInvalidToken()); + this.mvc.perform(postOk) .andExpect(status().isForbidden()); + // @formatter:on } @Test - public void loginWhenDefaultConfigurationThenCsrfCleared() - throws Exception { - - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - + public void loginWhenDefaultConfigurationThenCsrfCleared() throws Exception { + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); MvcResult result = this.mvc.perform(get("/csrf")).andReturn(); - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session(session) - .with(csrf())) - .andExpect(status().isFound()); - - this.mvc.perform(get("/csrf").session(session)) - .andExpect(csrfChanged(result)); - } - - @Test - public void logoutWhenDefaultConfigurationThenCsrfCleared() - throws Exception { - - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("AutoConfig") - ).autowire(); - - MvcResult result = this.mvc.perform(get("/csrf")).andReturn(); - - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); - - this.mvc.perform(post("/logout").session(session) - .with(csrf())) + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .session(session) + .with(csrf()); + this.mvc.perform(loginRequest) .andExpect(status().isFound()); - this.mvc.perform(get("/csrf").session(session)) .andExpect(csrfChanged(result)); + // @formatter:on + } + + @Test + public void logoutWhenDefaultConfigurationThenCsrfCleared() throws Exception { + this.spring.configLocations(this.xml("shared-controllers"), this.xml("AutoConfig")).autowire(); + MvcResult result = this.mvc.perform(get("/csrf")).andReturn(); + MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); + // @formatter:off + this.mvc.perform(post("/logout").session(session).with(csrf())) + .andExpect(status().isFound()); + this.mvc.perform(get("/csrf").session(session)) + .andExpect(csrfChanged(result)); + // @formatter:on } /** @@ -520,32 +433,28 @@ public class CsrfConfigTests { */ @Test @WithMockUser - public void logoutWhenDefaultConfigurationThenDisabled() - throws Exception { - - this.spring.configLocations( - this.xml("shared-controllers"), - this.xml("CsrfEnabled") - ).autowire(); - + public void logoutWhenDefaultConfigurationThenDisabled() throws Exception { + this.spring.configLocations(this.xml("shared-controllers"), this.xml("CsrfEnabled")).autowire(); + // renders form to log out but does not perform a redirect + // @formatter:off this.mvc.perform(get("/logout")) - .andExpect(status().isOk()); // renders form to log out but does not perform a redirect - + .andExpect(status().isOk()); + // @formatter:on // still logged in - this.mvc.perform(get("/authenticated")).andExpect(status().isOk()); + // @formatter:off + this.mvc.perform(get("/authenticated")) + .andExpect(status().isOk()); + // @formatter:on } private T getFilter(SpringTestContext context, Class type) { FilterChainProxy chain = context.getContext().getBean(FilterChainProxy.class); - List filters = chain.getFilters("/any"); - - for ( Filter filter : filters ) { - if ( type.isAssignableFrom(filter.getClass()) ) { + for (Filter filter : filters) { + if (type.isAssignableFrom(filter.getClass())) { return (T) filter; } } - return null; } @@ -553,46 +462,6 @@ public class CsrfConfigTests { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } - @Controller - public static class RootController { - @RequestMapping(value = "/csrf-in-header", method = { HEAD, TRACE, OPTIONS }) - @ResponseBody - String csrfInHeaderAndBody(CsrfToken token, HttpServletResponse response) { - response.setHeader(token.getHeaderName(), token.getToken()); - return csrfInBody(token); - } - - @RequestMapping(value = "/csrf", method = { POST, PUT, PATCH, DELETE, GET }) - @ResponseBody - String csrfInBody(CsrfToken token) { - return token.getToken(); - } - - @RequestMapping(value = "/ok", method = { POST, GET }) - @ResponseBody - String ok() { - return "ok"; - } - - @GetMapping("/authenticated") - @ResponseBody - String authenticated() { - return "authenticated"; - } - } - - private static class TeapotAccessDeniedHandler implements AccessDeniedHandler { - - @Override - public void handle( - HttpServletRequest request, - HttpServletResponse response, - AccessDeniedException accessDeniedException) { - - response.setStatus(HttpStatus.IM_A_TEAPOT_418); - } - } - ResultMatcher csrfChanged(MvcResult first) { return (second) -> { assertThat(first).isNotNull(); @@ -607,28 +476,75 @@ public class CsrfConfigTests { } ResultMatcher csrfInHeader() { - return new CsrfReturnedResultMatcher(result -> result.getResponse().getHeader("X-CSRF-TOKEN")); + return new CsrfReturnedResultMatcher((result) -> result.getResponse().getHeader("X-CSRF-TOKEN")); } ResultMatcher csrfInBody() { - return new CsrfReturnedResultMatcher(result -> result.getResponse().getContentAsString()); + return new CsrfReturnedResultMatcher((result) -> result.getResponse().getContentAsString()); + } + + @Controller + public static class RootController { + + @RequestMapping(value = "/csrf-in-header", + method = { RequestMethod.HEAD, RequestMethod.TRACE, RequestMethod.OPTIONS }) + @ResponseBody + String csrfInHeaderAndBody(CsrfToken token, HttpServletResponse response) { + response.setHeader(token.getHeaderName(), token.getToken()); + return csrfInBody(token); + } + + @RequestMapping(value = "/csrf", method = { RequestMethod.POST, RequestMethod.PUT, RequestMethod.PATCH, + RequestMethod.DELETE, RequestMethod.GET }) + @ResponseBody + String csrfInBody(CsrfToken token) { + return token.getToken(); + } + + @RequestMapping(value = "/ok", method = { RequestMethod.POST, RequestMethod.GET }) + @ResponseBody + String ok() { + return "ok"; + } + + @GetMapping("/authenticated") + @ResponseBody + String authenticated() { + return "authenticated"; + } + + } + + private static class TeapotAccessDeniedHandler implements AccessDeniedHandler { + + @Override + public void handle(HttpServletRequest request, HttpServletResponse response, + AccessDeniedException accessDeniedException) { + response.setStatus(HttpStatus.IM_A_TEAPOT_418); + } + } @FunctionalInterface interface ExceptionalFunction { + OUT apply(IN in) throws Exception; + } static class CsrfCreatedResultMatcher implements ResultMatcher { + @Override public void match(MvcResult result) { MockHttpServletRequest request = result.getRequest(); CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request); assertThat(token).isNotNull(); } + } static class CsrfReturnedResultMatcher implements ResultMatcher { + ExceptionalFunction token; CsrfReturnedResultMatcher(ExceptionalFunction token) { @@ -642,6 +558,7 @@ public class CsrfConfigTests { assertThat(token).isNotNull(); assertThat(token.getToken()).isEqualTo(this.token.apply(result)); } + } } diff --git a/config/src/test/java/org/springframework/security/config/http/DefaultFilterChainValidatorTests.java b/config/src/test/java/org/springframework/security/config/http/DefaultFilterChainValidatorTests.java index 2e87b07360..61895a4994 100644 --- a/config/src/test/java/org/springframework/security/config/http/DefaultFilterChainValidatorTests.java +++ b/config/src/test/java/org/springframework/security/config/http/DefaultFilterChainValidatorTests.java @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.http; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyObject; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; +package org.springframework.security.config.http; import java.util.Collection; @@ -29,7 +24,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.test.util.ReflectionTestUtils; + import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.core.Authentication; import org.springframework.security.web.AuthenticationEntryPoint; @@ -42,19 +37,30 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept import org.springframework.security.web.authentication.AnonymousAuthenticationFilter; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.util.matcher.AnyRequestMatcher; +import org.springframework.test.util.ReflectionTestUtils; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyObject; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** - * * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class DefaultFilterChainValidatorTests { + private DefaultFilterChainValidator validator; + private FilterChainProxy fcp; + @Mock private Log logger; + @Mock private DefaultFilterInvocationSecurityMetadataSource metadataSource; + @Mock private AccessDecisionManager accessDecisionManager; @@ -63,43 +69,39 @@ public class DefaultFilterChainValidatorTests { @Before public void setUp() { AnonymousAuthenticationFilter aaf = new AnonymousAuthenticationFilter("anonymous"); - fsi = new FilterSecurityInterceptor(); - fsi.setAccessDecisionManager(accessDecisionManager); - fsi.setSecurityMetadataSource(metadataSource); - AuthenticationEntryPoint authenticationEntryPoint = new LoginUrlAuthenticationEntryPoint( - "/login"); - ExceptionTranslationFilter etf = new ExceptionTranslationFilter( - authenticationEntryPoint); - DefaultSecurityFilterChain securityChain = new DefaultSecurityFilterChain( - AnyRequestMatcher.INSTANCE, aaf, etf, fsi); - fcp = new FilterChainProxy(securityChain); - validator = new DefaultFilterChainValidator(); - - ReflectionTestUtils.setField(validator, "logger", logger); + this.fsi = new FilterSecurityInterceptor(); + this.fsi.setAccessDecisionManager(this.accessDecisionManager); + this.fsi.setSecurityMetadataSource(this.metadataSource); + AuthenticationEntryPoint authenticationEntryPoint = new LoginUrlAuthenticationEntryPoint("/login"); + ExceptionTranslationFilter etf = new ExceptionTranslationFilter(authenticationEntryPoint); + DefaultSecurityFilterChain securityChain = new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, aaf, etf, + this.fsi); + this.fcp = new FilterChainProxy(securityChain); + this.validator = new DefaultFilterChainValidator(); + ReflectionTestUtils.setField(this.validator, "logger", this.logger); } // SEC-1878 @SuppressWarnings("unchecked") @Test public void validateCheckLoginPageIsntProtectedThrowsIllegalArgumentException() { - IllegalArgumentException toBeThrown = new IllegalArgumentException( - "failed to eval expression"); - doThrow(toBeThrown).when(accessDecisionManager).decide(any(Authentication.class), - anyObject(), any(Collection.class)); - validator.validate(fcp); - verify(logger) - .info("Unable to check access to the login page to determine if anonymous access is allowed. This might be an error, but can happen under normal circumstances.", - toBeThrown); + IllegalArgumentException toBeThrown = new IllegalArgumentException("failed to eval expression"); + willThrow(toBeThrown).given(this.accessDecisionManager).decide(any(Authentication.class), anyObject(), + any(Collection.class)); + this.validator.validate(this.fcp); + verify(this.logger).info( + "Unable to check access to the login page to determine if anonymous access is allowed. This might be an error, but can happen under normal circumstances.", + toBeThrown); } // SEC-1957 @Test public void validateCustomMetadataSource() { - FilterInvocationSecurityMetadataSource customMetaDataSource = mock(FilterInvocationSecurityMetadataSource.class); - fsi.setSecurityMetadataSource(customMetaDataSource); - - validator.validate(fcp); - + FilterInvocationSecurityMetadataSource customMetaDataSource = mock( + FilterInvocationSecurityMetadataSource.class); + this.fsi.setSecurityMetadataSource(customMetaDataSource); + this.validator.validate(this.fcp); verify(customMetaDataSource).getAttributes(any()); } + } diff --git a/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java index 29164a0f8c..503e0bd803 100644 --- a/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/FilterSecurityMetadataSourceBeanDefinitionParserTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.Collection; + import org.junit.After; import org.junit.Test; + import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; import org.springframework.context.support.AbstractXmlApplicationContext; import org.springframework.mock.web.MockFilterChain; @@ -31,8 +35,6 @@ import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.access.expression.ExpressionBasedFilterInvocationSecurityMetadataSource; import org.springframework.security.web.access.intercept.DefaultFilterInvocationSecurityMetadataSource; -import java.util.Collection; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -58,26 +60,27 @@ public class FilterSecurityMetadataSourceBeanDefinitionParserTests { @Test public void parsingMinimalConfigurationIsSuccessful() { + // @formatter:off setContext("" + " " + ""); + // @formatter:on DefaultFilterInvocationSecurityMetadataSource fids = (DefaultFilterInvocationSecurityMetadataSource) this.appContext .getBean("fids"); - Collection cad = fids - .getAttributes(createFilterInvocation("/anything", "GET")); + Collection cad = fids.getAttributes(createFilterInvocation("/anything", "GET")); assertThat(cad).contains(new SecurityConfig("ROLE_A")); } @Test public void expressionsAreSupported() { + // @formatter:off setContext("" + " " + ""); - + // @formatter:on ExpressionBasedFilterInvocationSecurityMetadataSource fids = (ExpressionBasedFilterInvocationSecurityMetadataSource) this.appContext .getBean("fids"); - ConfigAttribute[] cad = fids - .getAttributes(createFilterInvocation("/anything", "GET")) + ConfigAttribute[] cad = fids.getAttributes(createFilterInvocation("/anything", "GET")) .toArray(new ConfigAttribute[0]); assertThat(cad).hasSize(1); assertThat(cad[0].toString()).isEqualTo("hasRole('ROLE_A')"); @@ -88,20 +91,19 @@ public class FilterSecurityMetadataSourceBeanDefinitionParserTests { public void interceptUrlsSupportPropertyPlaceholders() { System.setProperty("secure.url", "/secure"); System.setProperty("secure.role", "ROLE_A"); - setContext( - "" - + "" - + " " - + ""); + setContext("" + + "" + + " " + + ""); DefaultFilterInvocationSecurityMetadataSource fids = (DefaultFilterInvocationSecurityMetadataSource) this.appContext .getBean("fids"); - Collection cad = fids - .getAttributes(createFilterInvocation("/secure", "GET")); + Collection cad = fids.getAttributes(createFilterInvocation("/secure", "GET")); assertThat(cad).containsExactly(new SecurityConfig("ROLE_A")); } @Test public void parsingWithinFilterSecurityInterceptorIsSuccessful() { + // @formatter:off setContext("" + "" + " " @@ -109,27 +111,27 @@ public class FilterSecurityMetadataSourceBeanDefinitionParserTests { + " " + " " + " " - + " " + " " - + " " + "" + + " " + + " " + + " " + + "" + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on } @Test(expected = BeanDefinitionParsingException.class) public void parsingInterceptUrlServletPathFails() { setContext("" - + " " - + ""); + + " " + + ""); } private FilterInvocation createFilterInvocation(String path, String method) { MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); request.setRequestURI(null); request.setMethod(method); - request.setServletPath(path); - - return new FilterInvocation(request, new MockHttpServletResponse(), - new MockFilterChain()); + return new FilterInvocation(request, new MockHttpServletResponse(), new MockFilterChain()); } + } diff --git a/config/src/test/java/org/springframework/security/config/http/FormLoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/FormLoginBeanDefinitionParserTests.java index 4bef5265fa..27bffb53d1 100644 --- a/config/src/test/java/org/springframework/security/config/http/FormLoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/FormLoginBeanDefinitionParserTests.java @@ -13,18 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.web.WebAttributes; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.core.IsNot.not; -import static org.hamcrest.core.IsNull.nullValue; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; @@ -32,15 +35,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; - /** - * * @author Luke Taylor * @author Josh Cummings */ public class FormLoginBeanDefinitionParserTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/FormLoginBeanDefinitionParserTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/FormLoginBeanDefinitionParserTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -49,208 +50,198 @@ public class FormLoginBeanDefinitionParserTests { MockMvc mvc; @Test - public void getLoginWhenAutoConfigThenShowsDefaultLoginPage() - throws Exception { - + public void getLoginWhenAutoConfigThenShowsDefaultLoginPage() throws Exception { this.spring.configLocations(this.xml("Simple")).autowire(); - - String expectedContent = - "\n" - + "\n" - + " \n" - + " \n" - + " \n" - + " \n" - + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + "
    \n" - + "
    \n" - + ""; - + // @formatter:off + String expectedContent = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + "
    \n" + + "
    \n" + + ""; + // @formatter:on this.mvc.perform(get("/login")).andExpect(content().string(expectedContent)); } @Test - public void getLogoutWhenAutoConfigThenShowsDefaultLogoutPage() - throws Exception { - + public void getLogoutWhenAutoConfigThenShowsDefaultLogoutPage() throws Exception { this.spring.configLocations(this.xml("AutoConfig")).autowire(); - - this.mvc.perform(get("/logout")) - .andExpect(content().string(containsString("action=\"/logout\""))); + this.mvc.perform(get("/logout")).andExpect(content().string(containsString("action=\"/logout\""))); } @Test - public void getLoginWhenConfiguredWithCustomAttributesThenLoginPageReflects() - throws Exception { - + public void getLoginWhenConfiguredWithCustomAttributesThenLoginPageReflects() throws Exception { this.spring.configLocations(this.xml("WithCustomAttributes")).autowire(); - - String expectedContent = - "\n" - + "\n" + " \n" - + " \n" - + " \n" - + " \n" - + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + "
    \n" - + "
    \n" - + ""; - - this.mvc.perform(get("/login")).andExpect(content().string(expectedContent)); - - this.mvc.perform(get("/logout")).andExpect(status().is3xxRedirection()); + // @formatter:off + String expectedContent = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + "
    \n" + + "
    \n" + + ""; + this.mvc.perform(get("/login")) + .andExpect(content().string(expectedContent)); + this.mvc.perform(get("/logout")) + .andExpect(status().is3xxRedirection()); + // @formatter:on } @Test - public void getLoginWhenConfiguredForOpenIdThenLoginPageReflects() - throws Exception { - + public void getLoginWhenConfiguredForOpenIdThenLoginPageReflects() throws Exception { this.spring.configLocations(this.xml("WithOpenId")).autowire(); - - String expectedContent = - "\n" + "\n" + " \n" - + " \n" - + " \n" - + " \n" - + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + "
    \n" - + "
    \n" - + ""; - + // @formatter:off + String expectedContent = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + "
    \n" + + "
    \n" + + ""; + // @formatter:on this.mvc.perform(get("/login")).andExpect(content().string(expectedContent)); } @Test - public void getLoginWhenConfiguredForOpenIdWithCustomAttributesThenLoginPageReflects() - throws Exception { - + public void getLoginWhenConfiguredForOpenIdWithCustomAttributesThenLoginPageReflects() throws Exception { this.spring.configLocations(this.xml("WithOpenIdCustomAttributes")).autowire(); - - String expectedContent = - "\n" + "\n" + " \n" - + " \n" - + " \n" - + " \n" - + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + "
    \n" - + "
    \n" - + ""; - + // @formatter:off + String expectedContent = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + "
    \n" + + "
    \n" + + ""; + // @formatter:on this.mvc.perform(get("/login")).andExpect(content().string(expectedContent)); } @Test - public void failedLoginWhenConfiguredWithCustomAuthenticationFailureThenForwardsAccordingly() - throws Exception { - + public void failedLoginWhenConfiguredWithCustomAuthenticationFailureThenForwardsAccordingly() throws Exception { this.spring.configLocations(this.xml("WithAuthenticationFailureForwardUrl")).autowire(); - - this.mvc.perform(post("/login") - .param("username", "bob") - .param("password", "invalidpassword")) + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "bob") + .param("password", "invalidpassword"); + this.mvc.perform(loginRequest) .andExpect(status().isOk()) .andExpect(forwardedUrl("/failure_forward_url")) .andExpect(request().attribute(WebAttributes.AUTHENTICATION_EXCEPTION, not(nullValue()))); + // @formatter:on } @Test - public void successfulLoginWhenConfiguredWithCustomAuthenticationSuccessThenForwardsAccordingly() - throws Exception { - + public void successfulLoginWhenConfiguredWithCustomAuthenticationSuccessThenForwardsAccordingly() throws Exception { this.spring.configLocations(this.xml("WithAuthenticationSuccessForwardUrl")).autowire(); - - this.mvc.perform(post("/login") + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") - .param("password", "password")) + .param("password", "password"); + this.mvc.perform(loginRequest) .andExpect(status().isOk()) .andExpect(forwardedUrl("/success_forward_url")); + // @formatter:on } private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + } diff --git a/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java b/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java index 1fd9c265e2..ca04337f27 100644 --- a/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java @@ -13,10 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.List; + +import javax.servlet.Filter; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; @@ -30,16 +38,12 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import javax.servlet.Filter; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -48,13 +52,12 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Luke Taylor * @author Josh Cummings */ public class FormLoginConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/FormLoginConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/FormLoginConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -63,94 +66,83 @@ public class FormLoginConfigTests { MockMvc mvc; @Test - public void getProtectedPageWhenFormLoginConfiguredThenRedirectsToDefaultLoginPage() - throws Exception { - + public void getProtectedPageWhenFormLoginConfiguredThenRedirectsToDefaultLoginPage() throws Exception { this.spring.configLocations(this.xml("WithAntRequestMatcher")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test - public void authenticateWhenDefaultTargetUrlConfiguredThenRedirectsAccordingly() - throws Exception { - + public void authenticateWhenDefaultTargetUrlConfiguredThenRedirectsAccordingly() throws Exception { this.spring.configLocations(this.xml("WithDefaultTargetUrl")).autowire(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .with(csrf())) - .andExpect(redirectedUrl("/default")); - } - - @Test - public void authenticateWhenConfiguredWithSpelThenRedirectsAccordingly() - throws Exception { - - this.spring.configLocations(this.xml("UsingSpel")).autowire(); - - this.mvc.perform(post("/login") + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") .param("password", "password") - .with(csrf())) - .andExpect(redirectedUrl(WebConfigUtilsTest.URL + "/default")); + .with(csrf()); + this.mvc.perform(loginRequest) + .andExpect(redirectedUrl("/default")); + // @formatter:on + } - this.mvc.perform(post("/login") + @Test + public void authenticateWhenConfiguredWithSpelThenRedirectsAccordingly() throws Exception { + this.spring.configLocations(this.xml("UsingSpel")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .with(csrf()); + this.mvc.perform(loginRequest) + .andExpect(redirectedUrl(WebConfigUtilsTests.URL + "/default")); + MockHttpServletRequestBuilder invalidPassword = post("/login") .param("username", "user") .param("password", "wrong") - .with(csrf())) - .andExpect(redirectedUrl(WebConfigUtilsTest.URL + "/failure")); - + .with(csrf()); + this.mvc.perform(invalidPassword) + .andExpect(redirectedUrl(WebConfigUtilsTests.URL + "/failure")); this.mvc.perform(get("/")) - .andExpect(redirectedUrl("http://localhost" + WebConfigUtilsTest.URL + "/login")); + .andExpect(redirectedUrl("http://localhost" + WebConfigUtilsTests.URL + "/login")); + // @formatter:on } @Test public void autowireWhenLoginPageIsMisconfiguredThenDetects() { - - assertThatThrownBy(() -> this.spring.configLocations(this.xml("NoLeadingSlashLoginPage")).autowire()) - .isInstanceOf(BeanCreationException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("NoLeadingSlashLoginPage")).autowire()); } @Test public void autowireWhenDefaultTargetUrlIsMisconfiguredThenDetects() { - - assertThatThrownBy(() -> this.spring.configLocations(this.xml("NoLeadingSlashDefaultTargetUrl")).autowire()) - .isInstanceOf(BeanCreationException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("NoLeadingSlashDefaultTargetUrl")).autowire()); } @Test - public void authenticateWhenCustomHandlerBeansConfiguredThenInvokesAccordingly() - throws Exception { - + public void authenticateWhenCustomHandlerBeansConfiguredThenInvokesAccordingly() throws Exception { this.spring.configLocations(this.xml("WithSuccessAndFailureHandlers")).autowire(); - - this.mvc.perform(post("/login") + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") .param("password", "password") - .with(csrf())) + .with(csrf()); + this.mvc.perform(loginRequest) .andExpect(status().isIAmATeapot()); - - this.mvc.perform(post("/login") + MockHttpServletRequestBuilder invalidPassword = post("/login") .param("username", "user") .param("password", "wrong") - .with(csrf())) + .with(csrf()); + this.mvc.perform(invalidPassword) .andExpect(status().isIAmATeapot()); + // @formatter:on } - @Test - public void authenticateWhenCustomUsernameAndPasswordParametersThenSucceeds() - throws Exception { - + public void authenticateWhenCustomUsernameAndPasswordParametersThenSucceeds() throws Exception { this.spring.configLocations(this.xml("WithUsernameAndPasswordParameters")).autowire(); - - this.mvc.perform(post("/login") - .param("xname", "user") - .param("xpass", "password") - .with(csrf())) + this.mvc.perform(post("/login").param("xname", "user").param("xpass", "password").with(csrf())) .andExpect(redirectedUrl("/")); } @@ -159,102 +151,94 @@ public class FormLoginConfigTests { */ @Test public void autowireWhenCustomLoginPageIsSlashLoginThenNoDefaultLoginPageGeneratingFilterIsWired() - throws Exception { - + throws Exception { this.spring.configLocations(this.xml("ForSec2919")).autowire(); - - this.mvc.perform(get("/login")) - .andExpect(content().string("teapot")); - + this.mvc.perform(get("/login")).andExpect(content().string("teapot")); assertThat(getFilter(this.spring.getContext(), DefaultLoginPageGeneratingFilter.class)).isNull(); } @Test - public void authenticateWhenCsrfIsEnabledThenRequiresToken() - throws Exception { - + public void authenticateWhenCsrfIsEnabledThenRequiresToken() throws Exception { this.spring.configLocations(this.xml("WithCsrfEnabled")).autowire(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password")) + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password"); + this.mvc.perform(loginRequest) .andExpect(status().isForbidden()); + // @formatter:on } @Test - public void authenticateWhenCsrfIsDisabledThenDoesNotRequireToken() - throws Exception { - + public void authenticateWhenCsrfIsDisabledThenDoesNotRequireToken() throws Exception { this.spring.configLocations(this.xml("WithCsrfDisabled")).autowire(); - - this.mvc.perform(post("/login") + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") - .param("password", "password")) + .param("password", "password"); + this.mvc.perform(loginRequest) .andExpect(status().isFound()); + // @formatter:on } /** - * SEC-3147: authentication-failure-url should be contained "error" parameter if login-page="/login" + * SEC-3147: authentication-failure-url should be contained "error" parameter if + * login-page="/login" */ @Test public void authenticateWhenLoginPageIsSlashLoginAndAuthenticationFailsThenRedirectContainsErrorParameter() - throws Exception { - + throws Exception { this.spring.configLocations(this.xml("ForSec3147")).autowire(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "wrong") - .with(csrf())) + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "wrong") + .with(csrf()); + this.mvc.perform(loginRequest) .andExpect(redirectedUrl("/login?error")); - } - - @RestController - public static class LoginController { - @GetMapping("/login") - public String ok() { - return "teapot"; - } - } - - public static class TeapotAuthenticationHandler implements - AuthenticationSuccessHandler, - AuthenticationFailureHandler { - - @Override - public void onAuthenticationFailure( - HttpServletRequest request, - HttpServletResponse response, - AuthenticationException exception) { - - response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); - } - - @Override - public void onAuthenticationSuccess( - HttpServletRequest request, - HttpServletResponse response, - Authentication authentication) { - - response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); - } + // @formatter:on } private Filter getFilter(ApplicationContext context, Class filterClass) { FilterChainProxy filterChain = context.getBean(BeanIds.FILTER_CHAIN_PROXY, FilterChainProxy.class); - List filters = filterChain.getFilters("/any"); - - for ( Filter filter : filters ) { - if ( filter.getClass() == filterClass ) { + for (Filter filter : filters) { + if (filter.getClass() == filterClass) { return filter; } } - return null; } private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + + @RestController + public static class LoginController { + + @GetMapping("/login") + public String ok() { + return "teapot"; + } + + } + + public static class TeapotAuthenticationHandler + implements AuthenticationSuccessHandler, AuthenticationFailureHandler { + + @Override + public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, + AuthenticationException exception) { + response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); + } + + @Override + public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) { + response.setStatus(HttpStatus.I_AM_A_TEAPOT.value()); + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java b/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java index f0da65eac7..5174dc331c 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpConfigTests.java @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + import org.apache.http.HttpStatus; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -25,23 +30,18 @@ import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.web.FilterChainProxy; import org.springframework.test.web.servlet.MockMvc; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpServletResponseWrapper; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Rob Winch * @author Josh Cummings */ public class HttpConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/HttpConfigTests"; + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/HttpConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -50,38 +50,32 @@ public class HttpConfigTests { MockMvc mvc; @Test - public void getWhenUsingMinimalConfigurationThenRedirectsToLogin() - throws Exception { - + public void getWhenUsingMinimalConfigurationThenRedirectsToLogin() throws Exception { this.spring.configLocations(this.xml("Minimal")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test - public void getWhenUsingMinimalConfigurationThenPreventsSessionAsUrlParameter() - throws Exception { - + public void getWhenUsingMinimalConfigurationThenPreventsSessionAsUrlParameter() throws Exception { this.spring.configLocations(this.xml("Minimal")).autowire(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChainProxy proxy = this.spring.getContext().getBean(FilterChainProxy.class); - - proxy.doFilter( - request, - new EncodeUrlDenyingHttpServletResponseWrapper(response), - (req, resp) -> {}); - + proxy.doFilter(request, new EncodeUrlDenyingHttpServletResponseWrapper(response), (req, resp) -> { + }); assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/login"); } - private static class EncodeUrlDenyingHttpServletResponseWrapper - extends HttpServletResponseWrapper { + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; + } + + private static class EncodeUrlDenyingHttpServletResponseWrapper extends HttpServletResponseWrapper { EncodeUrlDenyingHttpServletResponseWrapper(HttpServletResponse response) { super(response); @@ -106,9 +100,7 @@ public class HttpConfigTests { public String encodeRedirectUrl(String url) { throw new RuntimeException("Unexpected invocation of encodeURL"); } + } - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/HttpCorsConfigTests.java b/config/src/test/java/org/springframework/security/config/http/HttpCorsConfigTests.java index 2ba0044db8..c1ad5415d1 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpCorsConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpCorsConfigTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.Arrays; + import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpHeaders; @@ -32,24 +36,20 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; -import java.util.Arrays; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.options; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Rob Winch * @author Tim Ysewyn * @author Josh Cummings */ public class HttpCorsConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/HttpCorsConfigTests"; + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/HttpCorsConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -59,78 +59,49 @@ public class HttpCorsConfigTests { @Test public void autowireWhenMissingMvcThenGivesInformativeError() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("RequiresMvc")).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("RequiresMvc")).autowire()) + .withMessageContaining( + "Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext"); } @Test - public void getWhenUsingCorsThenDoesSpringSecurityCorsHandshake() - throws Exception { - + public void getWhenUsingCorsThenDoesSpringSecurityCorsHandshake() throws Exception { this.spring.configLocations(this.xml("WithCors")).autowire(); - + // @formatter:off this.mvc.perform(get("/").with(this.approved())) .andExpect(corsResponseHeaders()) .andExpect((status().isIAmATeapot())); - this.mvc.perform(options("/").with(this.preflight())) .andExpect(corsResponseHeaders()) .andExpect(status().isOk()); + // @formatter:on } @Test - public void getWhenUsingCustomCorsConfigurationSourceThenDoesSpringSecurityCorsHandshake() - throws Exception { - + public void getWhenUsingCustomCorsConfigurationSourceThenDoesSpringSecurityCorsHandshake() throws Exception { this.spring.configLocations(this.xml("WithCorsConfigurationSource")).autowire(); - + // @formatter:off this.mvc.perform(get("/").with(this.approved())) .andExpect(corsResponseHeaders()) .andExpect((status().isIAmATeapot())); - this.mvc.perform(options("/").with(this.preflight())) .andExpect(corsResponseHeaders()) .andExpect(status().isOk()); + // @formatter:on } @Test - public void getWhenUsingCustomCorsFilterThenDoesSPringSecurityCorsHandshake() - throws Exception { - + public void getWhenUsingCustomCorsFilterThenDoesSPringSecurityCorsHandshake() throws Exception { this.spring.configLocations(this.xml("WithCorsFilter")).autowire(); - + // @formatter:off this.mvc.perform(get("/").with(this.approved())) .andExpect(corsResponseHeaders()) .andExpect((status().isIAmATeapot())); - this.mvc.perform(options("/").with(this.preflight())) .andExpect(corsResponseHeaders()) .andExpect(status().isOk()); - } - - @RestController - @CrossOrigin(methods = { - RequestMethod.GET, RequestMethod.POST - }) - static class CorsController { - @RequestMapping("/") - String hello() { - return "Hello"; - } - } - - static class MyCorsConfigurationSource extends UrlBasedCorsConfigurationSource { - MyCorsConfigurationSource() { - CorsConfiguration configuration = new CorsConfiguration(); - configuration.setAllowedOrigins(Arrays.asList("*")); - configuration.setAllowedMethods(Arrays.asList(RequestMethod.GET.name(), RequestMethod.POST.name())); - - super.registerCorsConfiguration( - "/**", - configuration); - } + // @formatter:on } private String xml(String configName) { @@ -148,21 +119,41 @@ public class HttpCorsConfigTests { private RequestPostProcessor cors(boolean preflight) { return (request) -> { request.addHeader(HttpHeaders.ORIGIN, "https://example.com"); - - if ( preflight ) { + if (preflight) { request.setMethod(HttpMethod.OPTIONS.name()); request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.POST.name()); } - return request; }; } private ResultMatcher corsResponseHeaders() { - return result -> { + return (result) -> { header().exists("Access-Control-Allow-Origin").match(result); header().exists("X-Content-Type-Options").match(result); }; } + @RestController + @CrossOrigin(methods = { RequestMethod.GET, RequestMethod.POST }) + static class CorsController { + + @RequestMapping("/") + String hello() { + return "Hello"; + } + + } + + static class MyCorsConfigurationSource extends UrlBasedCorsConfigurationSource { + + MyCorsConfigurationSource() { + CorsConfiguration configuration = new CorsConfiguration(); + configuration.setAllowedOrigins(Arrays.asList("*")); + configuration.setAllowedMethods(Arrays.asList(RequestMethod.GET.name(), RequestMethod.POST.name())); + super.registerCorsConfiguration("/**", configuration); + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/HttpHeadersConfigTests.java b/config/src/test/java/org/springframework/security/config/http/HttpHeadersConfigTests.java index 2df097e589..0d094ed267 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpHeadersConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpHeadersConfigTests.java @@ -13,11 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + import com.google.common.collect.ImmutableMap; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; @@ -28,20 +37,12 @@ import org.springframework.test.web.servlet.ResultMatcher; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Rob Winch * @author Tim Ysewyn * @author Josh Cummings @@ -49,19 +50,18 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. */ public class HttpHeadersConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/HttpHeadersConfigTests"; + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/HttpHeadersConfigTests"; - static final Map defaultHeaders = - ImmutableMap.builder() - .put("X-Content-Type-Options", "nosniff") - .put("X-Frame-Options", "DENY") - .put("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains") - .put("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - .put("Expires", "0") - .put("Pragma", "no-cache") - .put("X-XSS-Protection", "1; mode=block") - .build(); + // @formatter:off + static final Map defaultHeaders = ImmutableMap.builder() + .put("X-Content-Type-Options", "nosniff").put("X-Frame-Options", "DENY") + .put("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains") + .put("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + .put("Expires", "0") + .put("Pragma", "no-cache") + .put("X-XSS-Protection", "1; mode=block") + .build(); + // @formatter:on @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -70,745 +70,655 @@ public class HttpHeadersConfigTests { MockMvc mvc; @Test - public void requestWhenHeadersDisabledThenResponseExcludesAllSecureHeaders() - throws Exception { - + public void requestWhenHeadersDisabledThenResponseExcludesAllSecureHeaders() throws Exception { this.spring.configLocations(this.xml("HeadersDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenHeadersDisabledViaPlaceholderThenResponseExcludesAllSecureHeaders() - throws Exception { - + public void requestWhenHeadersDisabledViaPlaceholderThenResponseExcludesAllSecureHeaders() throws Exception { System.setProperty("security.headers.disabled", "true"); - this.spring.configLocations(this.xml("DisabledWithPlaceholder")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) - .andExpect(status().isOk()) - .andExpect(excludesDefaults()); + .andExpect(status().isOk()) + .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenHeadersEnabledViaPlaceholderThenResponseIncludesAllSecureHeaders() - throws Exception { - + public void requestWhenHeadersEnabledViaPlaceholderThenResponseIncludesAllSecureHeaders() throws Exception { System.setProperty("security.headers.disabled", "false"); - this.spring.configLocations(this.xml("DisabledWithPlaceholder")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) - .andExpect(status().isOk()) - .andExpect(includesDefaults()); + .andExpect(status().isOk()) + .andExpect(includesDefaults()); + // @formatter:on } @Test - public void requestWhenHeadersDisabledRefMissingPlaceholderThenResponseIncludesAllSecureHeaders() - throws Exception { - + public void requestWhenHeadersDisabledRefMissingPlaceholderThenResponseIncludesAllSecureHeaders() throws Exception { System.clearProperty("security.headers.disabled"); - this.spring.configLocations(this.xml("DisabledWithPlaceholder")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) - .andExpect(status().isOk()) - .andExpect(includesDefaults()); + .andExpect(status().isOk()) + .andExpect(includesDefaults()); + // @formatter:on } @Test public void configureWhenHeadersDisabledHavingChildElementThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("HeadersDisabledHavingChildElement")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("Cannot specify with child elements"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("HeadersDisabledHavingChildElement")).autowire()) + .withMessageContaining("Cannot specify with child elements"); } @Test - public void requestWhenHeadersEnabledThenResponseContainsAllSecureHeaders() - throws Exception { - + public void requestWhenHeadersEnabledThenResponseContainsAllSecureHeaders() throws Exception { this.spring.configLocations(this.xml("DefaultConfig")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includesDefaults()); + // @formatter:on } @Test - public void requestWhenHeadersElementUsedThenResponseContainsAllSecureHeaders() - throws Exception { - + public void requestWhenHeadersElementUsedThenResponseContainsAllSecureHeaders() throws Exception { this.spring.configLocations(this.xml("HeadersEnabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includesDefaults()); + // @formatter:on } @Test - public void requestWhenFrameOptionsConfiguredThenIncludesHeader() - throws Exception { - + public void requestWhenFrameOptionsConfiguredThenIncludesHeader() throws Exception { Map headers = new HashMap(defaultHeaders); headers.put("X-Frame-Options", "SAMEORIGIN"); - this.spring.configLocations(this.xml("WithFrameOptions")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(headers)); + // @formatter:on } - // -- defaults disabled - /** * gh-3986 */ @Test - public void requestWhenDefaultsDisabledWithNoOverrideThenExcludesAllSecureHeaders() - throws Exception { - + public void requestWhenDefaultsDisabledWithNoOverrideThenExcludesAllSecureHeaders() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithNoOverride")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenDefaultsDisabledWithPlaceholderTrueThenExcludesAllSecureHeaders() - throws Exception { - + public void requestWhenDefaultsDisabledWithPlaceholderTrueThenExcludesAllSecureHeaders() throws Exception { System.setProperty("security.headers.defaults.disabled", "true"); - this.spring.configLocations(this.xml("DefaultsDisabledWithPlaceholder")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenDefaultsDisabledWithPlaceholderFalseThenIncludeAllSecureHeaders() - throws Exception { - + public void requestWhenDefaultsDisabledWithPlaceholderFalseThenIncludeAllSecureHeaders() throws Exception { System.setProperty("security.headers.defaults.disabled", "false"); - this.spring.configLocations(this.xml("DefaultsDisabledWithPlaceholder")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includesDefaults()); + // @formatter:on } @Test - public void requestWhenDefaultsDisabledWithPlaceholderMissingThenIncludeAllSecureHeaders() - throws Exception { - + public void requestWhenDefaultsDisabledWithPlaceholderMissingThenIncludeAllSecureHeaders() throws Exception { System.clearProperty("security.headers.defaults.disabled"); - this.spring.configLocations(this.xml("DefaultsDisabledWithPlaceholder")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingContentTypeOptionsThenDefaultsToNoSniff() - throws Exception { - + public void requestWhenUsingContentTypeOptionsThenDefaultsToNoSniff() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-Content-Type-Options"); - this.spring.configLocations(this.xml("DefaultsDisabledWithContentTypeOptions")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-Content-Type-Options", "nosniff")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenUsingFrameOptionsThenDefaultsToDeny() - throws Exception { - + public void requestWhenUsingFrameOptionsThenDefaultsToDeny() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-Frame-Options"); - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptions")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-Frame-Options", "DENY")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenUsingFrameOptionsDenyThenRespondsWithDeny() - throws Exception { - + public void requestWhenUsingFrameOptionsDenyThenRespondsWithDeny() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-Frame-Options"); - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptionsDeny")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-Frame-Options", "DENY")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenUsingFrameOptionsSameOriginThenRespondsWithSameOrigin() - throws Exception { - + public void requestWhenUsingFrameOptionsSameOriginThenRespondsWithSameOrigin() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-Frame-Options"); - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptionsSameOrigin")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-Frame-Options", "SAMEORIGIN")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test public void configureWhenUsingFrameOptionsAllowFromNoOriginThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptionsAllowFromNoOrigin")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("Strategy requires a 'value' to be set."); // FIXME better error message? + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring + .configLocations(this.xml("DefaultsDisabledWithFrameOptionsAllowFromNoOrigin")).autowire()) + .withMessageContaining("Strategy requires a 'value' to be set."); + // FIXME better error message? } @Test public void configureWhenUsingFrameOptionsAllowFromBlankOriginThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptionsAllowFromBlankOrigin")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("Strategy requires a 'value' to be set."); // FIXME better error message? + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring + .configLocations(this.xml("DefaultsDisabledWithFrameOptionsAllowFromBlankOrigin")).autowire()) + .withMessageContaining("Strategy requires a 'value' to be set."); + // FIXME better error message? } @Test - public void requestWhenUsingFrameOptionsAllowFromThenRespondsWithAllowFrom() - throws Exception { - + public void requestWhenUsingFrameOptionsAllowFromThenRespondsWithAllowFrom() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-Frame-Options"); - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptionsAllowFrom")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-Frame-Options", "ALLOW-FROM https://example.org")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenUsingFrameOptionsAllowFromWhitelistThenRespondsWithAllowFrom() - throws Exception { - + public void requestWhenUsingFrameOptionsAllowFromWhitelistThenRespondsWithAllowFrom() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-Frame-Options"); - this.spring.configLocations(this.xml("DefaultsDisabledWithFrameOptionsAllowFromWhitelist")).autowire(); - + // @formatter:off this.mvc.perform(get("/").param("from", "https://example.org")) .andExpect(status().isOk()) .andExpect(header().string("X-Frame-Options", "ALLOW-FROM https://example.org")) .andExpect(excludes(excludedHeaders)); - this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-Frame-Options", "DENY")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenUsingCustomHeaderThenRespondsWithThatHeader() - throws Exception { - + public void requestWhenUsingCustomHeaderThenRespondsWithThatHeader() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithCustomHeader")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("a", "b")) .andExpect(header().string("c", "d")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingCustomHeaderWriterThenRespondsWithThatHeader() - throws Exception { - + public void requestWhenUsingCustomHeaderWriterThenRespondsWithThatHeader() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithCustomHeaderWriter")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("abc", "def")) .andExpect(excludesDefaults()); + // @formatter:on } @Test public void configureWhenUsingCustomHeaderNameOnlyThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithOnlyHeaderName")).autowire()) - .isInstanceOf(BeanCreationException.class); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy( + () -> this.spring.configLocations(this.xml("DefaultsDisabledWithOnlyHeaderName")).autowire()); } @Test public void configureWhenUsingCustomHeaderValueOnlyThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithOnlyHeaderValue")).autowire()) - .isInstanceOf(BeanCreationException.class); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy( + () -> this.spring.configLocations(this.xml("DefaultsDisabledWithOnlyHeaderValue")).autowire()); } @Test - public void requestWhenUsingXssProtectionThenDefaultsToModeBlock() - throws Exception { - + public void requestWhenUsingXssProtectionThenDefaultsToModeBlock() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-XSS-Protection"); - this.spring.configLocations(this.xml("DefaultsDisabledWithXssProtection")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-XSS-Protection", "1; mode=block")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenEnablingXssProtectionThenDefaultsToModeBlock() - throws Exception { - + public void requestWhenEnablingXssProtectionThenDefaultsToModeBlock() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-XSS-Protection"); - this.spring.configLocations(this.xml("DefaultsDisabledWithXssProtectionEnabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-XSS-Protection", "1; mode=block")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void requestWhenDisablingXssProtectionThenDefaultsToZero() - throws Exception { - + public void requestWhenDisablingXssProtectionThenDefaultsToZero() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("X-XSS-Protection"); - this.spring.configLocations(this.xml("DefaultsDisabledWithXssProtectionDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("X-XSS-Protection", "0")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test public void configureWhenXssProtectionDisabledAndBlockSetThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithXssProtectionDisabledAndBlockSet")).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("Cannot set block to true with enabled false"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring + .configLocations(this.xml("DefaultsDisabledWithXssProtectionDisabledAndBlockSet")).autowire()) + .withMessageContaining("Cannot set block to true with enabled false"); } @Test - public void requestWhenUsingCacheControlThenRespondsWithCorrespondingHeaders() - throws Exception { - + public void requestWhenUsingCacheControlThenRespondsWithCorrespondingHeaders() throws Exception { Map includedHeaders = ImmutableMap.builder() - .put("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - .put("Expires", "0") - .put("Pragma", "no-cache") - .build(); - + .put("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate").put("Expires", "0") + .put("Pragma", "no-cache").build(); this.spring.configLocations(this.xml("DefaultsDisabledWithCacheControl")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(includes(includedHeaders)); + // @formatter:on } @Test - public void requestWhenUsingHstsThenRespondsWithHstsHeader() - throws Exception { - + public void requestWhenUsingHstsThenRespondsWithHstsHeader() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("Strict-Transport-Security"); - this.spring.configLocations(this.xml("DefaultsDisabledWithHsts")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(header().string("Strict-Transport-Security", "max-age=31536000 ; includeSubDomains")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test - public void insecureRequestWhenUsingHstsThenExcludesHstsHeader() - throws Exception { - + public void insecureRequestWhenUsingHstsThenExcludesHstsHeader() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHsts")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void insecureRequestWhenUsingCustomHstsRequestMatcherThenIncludesHstsHeader() - throws Exception { - + public void insecureRequestWhenUsingCustomHstsRequestMatcherThenIncludesHstsHeader() throws Exception { Set excludedHeaders = new HashSet<>(defaultHeaders.keySet()); excludedHeaders.remove("Strict-Transport-Security"); - this.spring.configLocations(this.xml("DefaultsDisabledWithCustomHstsRequestMatcher")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().string("Strict-Transport-Security", "max-age=1")) .andExpect(excludes(excludedHeaders)); + // @formatter:on } @Test public void configureWhenUsingHpkpWithoutPinsThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithEmptyHpkp")).autowire()) - .isInstanceOf(XmlBeanDefinitionStoreException.class) - .hasMessageContaining("The content of element 'hpkp' is not complete"); + assertThatExceptionOfType(XmlBeanDefinitionStoreException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("DefaultsDisabledWithEmptyHpkp")).autowire()) + .withMessageContaining("The content of element 'hpkp' is not complete"); } @Test public void configureWhenUsingHpkpWithEmptyPinsThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("DefaultsDisabledWithEmptyPins")).autowire()) - .isInstanceOf(XmlBeanDefinitionStoreException.class) - .hasMessageContaining("The content of element 'pins' is not complete"); + assertThatExceptionOfType(XmlBeanDefinitionStoreException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("DefaultsDisabledWithEmptyPins")).autowire()) + .withMessageContaining("The content of element 'pins' is not complete"); } @Test - public void requestWhenUsingHpkpThenIncludesHpkpHeader() - throws Exception { + public void requestWhenUsingHpkpThenIncludesHpkpHeader() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkp")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) - .andExpect(header().string( - "Public-Key-Pins-Report-Only", + .andExpect(header().string("Public-Key-Pins-Report-Only", "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingHpkpDefaultsThenIncludesHpkpHeaderUsingSha256() - throws Exception { + public void requestWhenUsingHpkpDefaultsThenIncludesHpkpHeaderUsingSha256() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkpDefaults")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) - .andExpect(header().string( - "Public-Key-Pins-Report-Only", + .andExpect(header().string("Public-Key-Pins-Report-Only", "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void insecureRequestWhenUsingHpkpThenExcludesHpkpHeader() - throws Exception { + public void insecureRequestWhenUsingHpkpThenExcludesHpkpHeader() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkpDefaults")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(header().doesNotExist("Public-Key-Pins-Report-Only")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingHpkpCustomMaxAgeThenIncludesHpkpHeaderAccordingly() - throws Exception { + public void requestWhenUsingHpkpCustomMaxAgeThenIncludesHpkpHeaderAccordingly() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkpMaxAge")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) - .andExpect(header().string( - "Public-Key-Pins-Report-Only", + .andExpect(header().string("Public-Key-Pins-Report-Only", "max-age=604800 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingHpkpReportThenIncludesHpkpHeaderAccordingly() - throws Exception { + public void requestWhenUsingHpkpReportThenIncludesHpkpHeaderAccordingly() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkpReport")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) - .andExpect(header().string( - "Public-Key-Pins", + .andExpect(header().string("Public-Key-Pins", "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\"")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingHpkpIncludeSubdomainsThenIncludesHpkpHeaderAccordingly() - throws Exception { + public void requestWhenUsingHpkpIncludeSubdomainsThenIncludesHpkpHeaderAccordingly() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkpIncludeSubdomains")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) - .andExpect(header().string( - "Public-Key-Pins-Report-Only", - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; includeSubDomains")) + .andExpect(header().string("Public-Key-Pins-Report-Only", + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; includeSubDomains")) .andExpect(excludesDefaults()); + // @formatter:on } @Test - public void requestWhenUsingHpkpReportUriThenIncludesHpkpHeaderAccordingly() - throws Exception { + public void requestWhenUsingHpkpReportUriThenIncludesHpkpHeaderAccordingly() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithHpkpReportUri")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) - .andExpect(header().string( - "Public-Key-Pins-Report-Only", - "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\"")) + .andExpect(header().string("Public-Key-Pins-Report-Only", + "max-age=5184000 ; pin-sha256=\"d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=\" ; report-uri=\"https://example.net/pkp-report\"")) .andExpect(excludesDefaults()); + // @formatter:on } - // -- single-header disabled - @Test - public void requestWhenCacheControlDisabledThenExcludesHeader() - throws Exception { - + public void requestWhenCacheControlDisabledThenExcludesHeader() throws Exception { Collection cacheControl = Arrays.asList("Cache-Control", "Expires", "Pragma"); Map allButCacheControl = remove(defaultHeaders, cacheControl); - this.spring.configLocations(this.xml("CacheControlDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(allButCacheControl)) .andExpect(excludes(cacheControl)); + // @formatter:on } @Test - public void requestWhenContentTypeOptionsDisabledThenExcludesHeader() - throws Exception { - + public void requestWhenContentTypeOptionsDisabledThenExcludesHeader() throws Exception { Collection contentTypeOptions = Arrays.asList("X-Content-Type-Options"); Map allButContentTypeOptions = remove(defaultHeaders, contentTypeOptions); - this.spring.configLocations(this.xml("ContentTypeOptionsDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(allButContentTypeOptions)) .andExpect(excludes(contentTypeOptions)); + // @formatter:on } @Test - public void requestWhenHstsDisabledThenExcludesHeader() - throws Exception { - + public void requestWhenHstsDisabledThenExcludesHeader() throws Exception { Collection hsts = Arrays.asList("Strict-Transport-Security"); Map allButHsts = remove(defaultHeaders, hsts); - this.spring.configLocations(this.xml("HstsDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(allButHsts)) .andExpect(excludes(hsts)); + // @formatter:on } @Test - public void requestWhenHpkpDisabledThenExcludesHeader() - throws Exception { - + public void requestWhenHpkpDisabledThenExcludesHeader() throws Exception { this.spring.configLocations(this.xml("HpkpDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includesDefaults()); + // @formatter:on } @Test - public void requestWhenFrameOptionsDisabledThenExcludesHeader() - throws Exception { - + public void requestWhenFrameOptionsDisabledThenExcludesHeader() throws Exception { Collection frameOptions = Arrays.asList("X-Frame-Options"); Map allButFrameOptions = remove(defaultHeaders, frameOptions); - this.spring.configLocations(this.xml("FrameOptionsDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(allButFrameOptions)) .andExpect(excludes(frameOptions)); + // @formatter:on } @Test - public void requestWhenXssProtectionDisabledThenExcludesHeader() - throws Exception { - + public void requestWhenXssProtectionDisabledThenExcludesHeader() throws Exception { Collection xssProtection = Arrays.asList("X-XSS-Protection"); Map allButXssProtection = remove(defaultHeaders, xssProtection); - this.spring.configLocations(this.xml("XssProtectionDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(allButXssProtection)) .andExpect(excludes(xssProtection)); + // @formatter:on } - // --- disable error handling --- - @Test public void configureWhenHstsDisabledAndIncludeSubdomainsSpecifiedThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("HstsDisabledSpecifyingIncludeSubdomains")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("include-subdomains"); + assertThatExceptionOfType(BeanDefinitionParsingException.class).isThrownBy( + () -> this.spring.configLocations(this.xml("HstsDisabledSpecifyingIncludeSubdomains")).autowire()) + .withMessageContaining("include-subdomains"); } @Test public void configureWhenHstsDisabledAndMaxAgeSpecifiedThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("HstsDisabledSpecifyingMaxAge")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("max-age"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("HstsDisabledSpecifyingMaxAge")).autowire()) + .withMessageContaining("max-age"); } @Test public void configureWhenHstsDisabledAndRequestMatcherSpecifiedThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("HstsDisabledSpecifyingRequestMatcher")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("request-matcher-ref"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy( + () -> this.spring.configLocations(this.xml("HstsDisabledSpecifyingRequestMatcher")).autowire()) + .withMessageContaining("request-matcher-ref"); } @Test public void configureWhenXssProtectionDisabledAndEnabledThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("XssProtectionDisabledAndEnabled")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("enabled"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("XssProtectionDisabledAndEnabled")).autowire()) + .withMessageContaining("enabled"); } @Test public void configureWhenXssProtectionDisabledAndBlockSpecifiedThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("XssProtectionDisabledSpecifyingBlock")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("block"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy( + () -> this.spring.configLocations(this.xml("XssProtectionDisabledSpecifyingBlock")).autowire()) + .withMessageContaining("block"); } @Test public void configureWhenFrameOptionsDisabledAndPolicySpecifiedThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("FrameOptionsDisabledSpecifyingPolicy")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("policy"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy( + () -> this.spring.configLocations(this.xml("FrameOptionsDisabledSpecifyingPolicy")).autowire()) + .withMessageContaining("policy"); } @Test - public void requestWhenContentSecurityPolicyDirectivesConfiguredThenIncludesDirectives() - throws Exception { - + public void requestWhenContentSecurityPolicyDirectivesConfiguredThenIncludesDirectives() throws Exception { Map includedHeaders = new HashMap<>(defaultHeaders); includedHeaders.put("Content-Security-Policy", "default-src 'self'"); - this.spring.configLocations(this.xml("ContentSecurityPolicyWithPolicyDirectives")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(includedHeaders)); + // @formatter:on } @Test - public void requestWhenHeadersDisabledAndContentSecurityPolicyConfiguredThenExcludesHeader() - throws Exception { - + public void requestWhenHeadersDisabledAndContentSecurityPolicyConfiguredThenExcludesHeader() throws Exception { this.spring.configLocations(this.xml("HeadersDisabledWithContentSecurityPolicy")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(excludesDefaults()) .andExpect(excludes("Content-Security-Policy")); + // @formatter:on } @Test - public void requestWhenDefaultsDisabledAndContentSecurityPolicyConfiguredThenIncludesHeader() - throws Exception { - + public void requestWhenDefaultsDisabledAndContentSecurityPolicyConfiguredThenIncludesHeader() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithContentSecurityPolicy")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(excludesDefaults()) .andExpect(header().string("Content-Security-Policy", "default-src 'self'")); + // @formatter:on } @Test public void configureWhenContentSecurityPolicyConfiguredWithEmptyDirectivesThenAutowireFails() { - assertThatThrownBy(() -> - this.spring.configLocations(this.xml("ContentSecurityPolicyWithEmptyDirectives")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class).isThrownBy( + () -> this.spring.configLocations(this.xml("ContentSecurityPolicyWithEmptyDirectives")).autowire()); } @Test public void requestWhenContentSecurityPolicyConfiguredWithReportOnlyThenIncludesReportOnlyHeader() - throws Exception { - + throws Exception { Map includedHeaders = new HashMap<>(defaultHeaders); - includedHeaders.put("Content-Security-Policy-Report-Only", "default-src https:; report-uri https://example.org/"); - + includedHeaders.put("Content-Security-Policy-Report-Only", + "default-src https:; report-uri https://example.org/"); this.spring.configLocations(this.xml("ContentSecurityPolicyWithReportOnly")).autowire(); - + // @formatter:off this.mvc.perform(get("/").secure(true)) .andExpect(status().isOk()) .andExpect(includes(includedHeaders)); + // @formatter:on } @Test - public void requestWhenReferrerPolicyConfiguredThenResponseDefaultsToNoReferrer() - throws Exception { - + public void requestWhenReferrerPolicyConfiguredThenResponseDefaultsToNoReferrer() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithReferrerPolicy")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(excludesDefaults()) .andExpect(header().string("Referrer-Policy", "no-referrer")); + // @formatter:on } @Test - public void requestWhenReferrerPolicyConfiguredWithSameOriginThenRespondsWithSameOrigin() - throws Exception { - + public void requestWhenReferrerPolicyConfiguredWithSameOriginThenRespondsWithSameOrigin() throws Exception { this.spring.configLocations(this.xml("DefaultsDisabledWithReferrerPolicySameOrigin")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isOk()) .andExpect(excludesDefaults()) .andExpect(header().string("Referrer-Policy", "same-origin")); - } - - @RestController - public static class SimpleController { - @GetMapping("/") - public String ok() { return "ok"; } + // @formatter:on } private static ResultMatcher includesDefaults() { @@ -816,8 +726,8 @@ public class HttpHeadersConfigTests { } private static ResultMatcher includes(Map headers) { - return result -> { - for ( Map.Entry header : headers.entrySet() ) { + return (result) -> { + for (Map.Entry header : headers.entrySet()) { header().string(header.getKey(), header.getValue()).match(result); } }; @@ -828,8 +738,8 @@ public class HttpHeadersConfigTests { } private static ResultMatcher excludes(Collection headers) { - return result -> { - for ( String name : headers ) { + return (result) -> { + for (String name : headers) { header().doesNotExist(name).match(result); } }; @@ -841,15 +751,24 @@ public class HttpHeadersConfigTests { private static Map remove(Map map, Collection keys) { Map copy = new HashMap<>(map); - - for ( K key : keys ) { + for (K key : keys) { copy.remove(key); } - return copy; } private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + + @RestController + public static class SimpleController { + + @GetMapping("/") + public String ok() { + return "ok"; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/HttpInterceptUrlTests.java b/config/src/test/java/org/springframework/security/config/http/HttpInterceptUrlTests.java index e1b9ac6ca3..2f226a45b5 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpInterceptUrlTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpInterceptUrlTests.java @@ -13,15 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.http; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.config.http; import javax.servlet.Filter; import org.junit.After; import org.junit.Test; + import org.springframework.mock.web.MockServletContext; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; @@ -30,34 +29,37 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.ConfigurableWebApplicationContext; import org.springframework.web.context.support.XmlWebApplicationContext; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + public class HttpInterceptUrlTests { + ConfigurableWebApplicationContext context; MockMvc mockMvc; @After public void close() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @Test public void interceptUrlWhenRequestMatcherRefThenWorks() throws Exception { loadConfig("interceptUrlWhenRequestMatcherRefThenWorks.xml"); - - mockMvc.perform(get("/foo")) - .andExpect(status().isUnauthorized()); - - mockMvc.perform(get("/FOO")) - .andExpect(status().isUnauthorized()); - - mockMvc.perform(get("/other")) - .andExpect(status().isOk()); + // @formatter:off + this.mockMvc.perform(get("/foo")) + .andExpect(status().isUnauthorized()); + this.mockMvc.perform(get("/FOO")) + .andExpect(status().isUnauthorized()); + this.mockMvc.perform(get("/other")) + .andExpect(status().isOk()); + // @formatter:on } private void loadConfig(String... configLocations) { - for (int i=0; i this.spring.configLocations(this.xml("AntMatcherServletPath")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("AntMatcherServletPath")).autowire()); } @Test public void configureWhenUsingRegexMatcherAndServletPathThenThrowsException() { - assertThatCode(() -> this.spring.configLocations(this.xml("RegexMatcherServletPath")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("RegexMatcherServletPath")).autowire()); } @Test public void configureWhenUsingCiRegexMatcherAndServletPathThenThrowsException() { - assertThatCode(() -> this.spring.configLocations(this.xml("CiRegexMatcherServletPath")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("CiRegexMatcherServletPath")).autowire()); } @Test public void configureWhenUsingDefaultMatcherAndServletPathThenThrowsException() { - assertThatCode(() -> this.spring.configLocations(this.xml("DefaultMatcherServletPath")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("DefaultMatcherServletPath")).autowire()); } - @RestController - static class PathController { - @RequestMapping("/path") - public String path() { - return "path"; - } - - @RequestMapping("/path/{un}/path") - public String path(@PathVariable("un") String name) { - return name; - } + private static RequestPostProcessor adminCredentials() { + return httpBasic("admin", "password"); } - public static class Id { - public boolean isOne(int i) { - return i == 1; - } + private static RequestPostProcessor userCredentials() { + return httpBasic("user", "password"); } private MockServletContext mockServletContext(String servletPath) { MockServletContext servletContext = spy(new MockServletContext()); final ServletRegistration registration = mock(ServletRegistration.class); - when(registration.getMappings()).thenReturn(Collections.singleton(servletPath)); - Answer> answer = invocation -> - Collections.singletonMap("spring", registration); - when(servletContext.getServletRegistrations()).thenAnswer(answer); + given(registration.getMappings()).willReturn(Collections.singleton(servletPath)); + Answer> answer = (invocation) -> Collections.singletonMap("spring", + registration); + given(servletContext.getServletRegistrations()).willAnswer(answer); return servletContext; } private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + + @RestController + static class PathController { + + @RequestMapping("/path") + String path() { + return "path"; + } + + @RequestMapping("/path/{un}/path") + String path(@PathVariable("un") String name) { + return name; + } + + } + + public static class Id { + + public boolean isOne(int i) { + return i == 1; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java index 4049067b0f..a066387d6e 100644 --- a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java @@ -28,6 +28,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import javax.security.auth.Subject; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.spi.LoginModule; @@ -42,6 +43,7 @@ import ch.qos.logback.classic.spi.ILoggingEvent; import ch.qos.logback.core.Appender; import org.apache.http.HttpStatus; import org.assertj.core.api.iterable.Extractor; +import org.jetbrains.annotations.NotNull; import org.junit.Rule; import org.junit.Test; import org.mockito.stubbing.Answer; @@ -101,25 +103,27 @@ import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.session.SessionManagementFilter; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.support.XmlWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509; -import static org.springframework.test.util.ReflectionTestUtils.getField; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -127,15 +131,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; - /** - * * @author Luke Taylor * @author Rob Winch */ public class MiscHttpConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/MiscHttpConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/MiscHttpConfigTests"; @Autowired MockMvc mvc; @@ -161,216 +163,166 @@ public class MiscHttpConfigTests { } @Test - public void requestWhenUsingDebugFilterAndPatternIsNotConfigureForSecurityThenRespondsOk() - throws Exception { - + public void requestWhenUsingDebugFilterAndPatternIsNotConfigureForSecurityThenRespondsOk() throws Exception { this.spring.configLocations(xml("NoSecurityForPattern")).autowire(); - + // @formatter:off this.mvc.perform(get("/unprotected")) - .andExpect(status().isNotFound()); - + .andExpect(status().isNotFound()); this.mvc.perform(get("/nomatch")) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void requestWhenHttpPatternUsesRegexMatchingThenMatchesAccordingly() - throws Exception { - + public void requestWhenHttpPatternUsesRegexMatchingThenMatchesAccordingly() throws Exception { this.spring.configLocations(xml("RegexSecurityPattern")).autowire(); - + // @formatter:off this.mvc.perform(get("/protected")) .andExpect(status().isUnauthorized()); - this.mvc.perform(get("/unprotected")) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void requestWhenHttpPatternUsesCiRegexMatchingThenMatchesAccordingly() - throws Exception { - + public void requestWhenHttpPatternUsesCiRegexMatchingThenMatchesAccordingly() throws Exception { this.spring.configLocations(xml("CiRegexSecurityPattern")).autowire(); - + // @formatter:off this.mvc.perform(get("/ProTectEd")) .andExpect(status().isUnauthorized()); - this.mvc.perform(get("/UnProTectEd")) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void requestWhenHttpPatternUsesCustomRequestMatcherThenMatchesAccordingly() - throws Exception { - + public void requestWhenHttpPatternUsesCustomRequestMatcherThenMatchesAccordingly() throws Exception { this.spring.configLocations(xml("CustomRequestMatcher")).autowire(); - + // @formatter:off this.mvc.perform(get("/protected")) .andExpect(status().isUnauthorized()); - this.mvc.perform(get("/unprotected")) .andExpect(status().isNotFound()); + // @formatter:on } /** * SEC-1152 */ @Test - public void requestWhenUsingMinimalConfigurationThenHonorsAnonymousEndpoints() - throws Exception { - + public void requestWhenUsingMinimalConfigurationThenHonorsAnonymousEndpoints() throws Exception { this.spring.configLocations(xml("AnonymousEndpoints")).autowire(); - + // @formatter:off this.mvc.perform(get("/protected")) .andExpect(status().isUnauthorized()); - this.mvc.perform(get("/unprotected")) .andExpect(status().isNotFound()); - + // @formatter:on assertThat(getFilter(AnonymousAuthenticationFilter.class)).isNotNull(); } @Test - public void requestWhenAnonymousIsDisabledThenRejectsAnonymousEndpoints() - throws Exception { - + public void requestWhenAnonymousIsDisabledThenRejectsAnonymousEndpoints() throws Exception { this.spring.configLocations(xml("AnonymousDisabled")).autowire(); - + // @formatter:off this.mvc.perform(get("/protected")) .andExpect(status().isUnauthorized()); - this.mvc.perform(get("/unprotected")) .andExpect(status().isUnauthorized()); - + // @formatter:on assertThat(getFilter(AnonymousAuthenticationFilter.class)).isNull(); } @Test - public void requestWhenAnonymousUsesCustomAttributesThenRespondsWithThoseAttributes() - throws Exception { - + public void requestWhenAnonymousUsesCustomAttributesThenRespondsWithThoseAttributes() throws Exception { this.spring.configLocations(xml("AnonymousCustomAttributes")).autowire(); - - this.mvc.perform(get("/protected") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/protected").with(userCredentials())) .andExpect(status().isForbidden()); - this.mvc.perform(get("/protected")) .andExpect(status().isOk()) .andExpect(content().string("josh")); - this.mvc.perform(get("/customKey")) .andExpect(status().isOk()) .andExpect(content().string(String.valueOf("myCustomKey".hashCode()))); + // @formatter:on } @Test - public void requestWhenAnonymousUsesMultipleGrantedAuthoritiesThenRespondsWithThoseAttributes() - throws Exception { - + public void requestWhenAnonymousUsesMultipleGrantedAuthoritiesThenRespondsWithThoseAttributes() throws Exception { this.spring.configLocations(xml("AnonymousMultipleAuthorities")).autowire(); - - this.mvc.perform(get("/protected") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/protected").with(userCredentials())) .andExpect(status().isForbidden()); - this.mvc.perform(get("/protected")) .andExpect(status().isOk()) .andExpect(content().string("josh")); - this.mvc.perform(get("/customKey")) .andExpect(status().isOk()) .andExpect(content().string(String.valueOf("myCustomKey".hashCode()))); + // @formatter:on } @Test - public void requestWhenInterceptUrlMatchesMethodThenSecuresAccordingly() - throws Exception { - + public void requestWhenInterceptUrlMatchesMethodThenSecuresAccordingly() throws Exception { this.spring.configLocations(xml("InterceptUrlMethod")).autowire(); - - this.mvc.perform(get("/protected") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/protected").with(userCredentials())) .andExpect(status().isOk()); - - this.mvc.perform(post("/protected") - .with(httpBasic("user", "password"))) + this.mvc.perform(post("/protected").with(userCredentials())) .andExpect(status().isForbidden()); - - this.mvc.perform(post("/protected") - .with(httpBasic("poster", "password"))) + this.mvc.perform(post("/protected").with(postCredentials())) .andExpect(status().isOk()); - - this.mvc.perform(delete("/protected") - .with(httpBasic("poster", "password"))) + this.mvc.perform(delete("/protected").with(postCredentials())) .andExpect(status().isForbidden()); - - this.mvc.perform(delete("/protected") - .with(httpBasic("admin", "password"))) + this.mvc.perform(delete("/protected").with(adminCredentials())) .andExpect(status().isOk()); + // @formatter:on } @Test - public void requestWhenInterceptUrlMatchesMethodAndRequiresHttpsThenSecuresAccordingly() - throws Exception { - + public void requestWhenInterceptUrlMatchesMethodAndRequiresHttpsThenSecuresAccordingly() throws Exception { this.spring.configLocations(xml("InterceptUrlMethodRequiresHttps")).autowire(); - + // @formatter:off this.mvc.perform(post("/protected").with(csrf())) .andExpect(status().isOk()); - - this.mvc.perform(get("/protected") - .secure(true) - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/protected").secure(true).with(userCredentials())) .andExpect(status().isForbidden()); - - this.mvc.perform(get("/protected") - .secure(true) - .with(httpBasic("admin", "password"))) + this.mvc.perform(get("/protected").secure(true).with(adminCredentials())) .andExpect(status().isOk()); + // @formatter:on } @Test - public void requestWhenInterceptUrlMatchesAnyPatternAndRequiresHttpsThenSecuresAccordingly() - throws Exception { - + public void requestWhenInterceptUrlMatchesAnyPatternAndRequiresHttpsThenSecuresAccordingly() throws Exception { this.spring.configLocations(xml("InterceptUrlMethodRequiresHttpsAny")).autowire(); - + // @formatter:off this.mvc.perform(post("/protected").with(csrf())) .andExpect(status().isOk()); - - this.mvc.perform(get("/protected") - .secure(true) - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/protected").secure(true).with(userCredentials())) .andExpect(status().isForbidden()); - - this.mvc.perform(get("/protected") - .secure(true) - .with(httpBasic("admin", "password"))) + this.mvc.perform(get("/protected").secure(true).with(adminCredentials())) .andExpect(status().isOk()); + // @formatter:on } @Test public void configureWhenOncePerRequestIsFalseThenFilterSecurityInterceptorExercisedForForwards() { this.spring.configLocations(xml("OncePerRequest")).autowire(); - FilterSecurityInterceptor filterSecurityInterceptor = getFilter(FilterSecurityInterceptor.class); assertThat(filterSecurityInterceptor.isObserveOncePerRequest()).isFalse(); } @Test - public void requestWhenCustomHttpBasicEntryPointRefThenInvokesOnCommence() - throws Exception { - + public void requestWhenCustomHttpBasicEntryPointRefThenInvokesOnCommence() throws Exception { this.spring.configLocations(xml("CustomHttpBasicEntryPointRef")).autowire(); - AuthenticationEntryPoint entryPoint = this.spring.getContext().getBean(AuthenticationEntryPoint.class); - + // @formatter:off this.mvc.perform(get("/protected")) - .andExpect(status().isOk()); - - verify(entryPoint).commence( - any(HttpServletRequest.class), any(HttpServletResponse.class), any(AuthenticationException.class)); + .andExpect(status().isOk()); + // @formatter:on + verify(entryPoint).commence(any(HttpServletRequest.class), any(HttpServletResponse.class), + any(AuthenticationException.class)); } @Test @@ -382,191 +334,156 @@ public class MiscHttpConfigTests { @Test public void getWhenPortsMappedThenRedirectedAccordingly() throws Exception { this.spring.configLocations(xml("PortsMappedInterceptUrlMethodRequiresAny")).autowire(); - + // @formatter:off this.mvc.perform(get("http://localhost:9080/protected")) .andExpect(redirectedUrl("https://localhost:9443/protected")); + // @formatter:on } @Test public void configureWhenCustomFiltersThenAddedToChainInCorrectOrder() { System.setProperty("customFilterRef", "userFilter"); this.spring.configLocations(xml("CustomFilters")).autowire(); - List filters = getFilters("/"); - Class userFilterClass = this.spring.getContext().getBean("userFilter").getClass(); - - assertThat(filters) - .extracting((Extractor>) filter -> filter.getClass()) - .containsSubsequence( - userFilterClass, userFilterClass, - SecurityContextPersistenceFilter.class, LogoutFilter.class, - userFilterClass); + assertThat(filters).extracting((Extractor>) (filter) -> filter.getClass()).containsSubsequence( + userFilterClass, userFilterClass, SecurityContextPersistenceFilter.class, LogoutFilter.class, + userFilterClass); } @Test public void configureWhenTwoFiltersWithSameOrderThenException() { - assertThatCode(() -> this.spring.configLocations(xml("CollidingFilters")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("CollidingFilters")).autowire()); } @Test public void configureWhenUsingX509ThenAddsX509FilterCorrectly() { this.spring.configLocations(xml("X509")).autowire(); - - assertThat(getFilters("/")) - .extracting((Extractor>) filter -> filter.getClass()) - .containsSubsequence( - CsrfFilter.class, X509AuthenticationFilter.class, ExceptionTranslationFilter.class); + assertThat(getFilters("/")).extracting((Extractor>) (filter) -> filter.getClass()) + .containsSubsequence(CsrfFilter.class, X509AuthenticationFilter.class, + ExceptionTranslationFilter.class); } - @Test public void getWhenUsingX509AndPropertyPlaceholderThenSubjectPrincipalRegexIsConfigured() throws Exception { System.setProperty("subject_principal_regex", "OU=(.*?)(?:,|$)"); this.spring.configLocations(xml("X509")).autowire(); - - this.mvc.perform(get("/protected") - .with(x509("classpath:org/springframework/security/config/http/MiscHttpConfigTests-certificate.pem"))) + RequestPostProcessor x509 = x509( + "classpath:org/springframework/security/config/http/MiscHttpConfigTests-certificate.pem"); + // @formatter:off + this.mvc.perform(get("/protected").with(x509)) .andExpect(status().isOk()); + // @formatter:on } @Test public void configureWhenUsingInvalidLogoutSuccessUrlThenThrowsException() { - assertThatCode(() -> this.spring.configLocations(xml("InvalidLogoutSuccessUrl")).autowire()) - .isInstanceOf(BeanCreationException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.configLocations(xml("InvalidLogoutSuccessUrl")).autowire()); } @Test public void logoutWhenSpecifyingCookiesToDeleteThenSetCookieAdded() throws Exception { this.spring.configLocations(xml("DeleteCookies")).autowire(); - - MvcResult result = - this.mvc.perform(post("/logout").with(csrf())).andReturn(); - + MvcResult result = this.mvc.perform(post("/logout").with(csrf())).andReturn(); List values = result.getResponse().getHeaders("Set-Cookie"); assertThat(values.size()).isEqualTo(2); - assertThat(values).extracting(value -> value.split("=")[0]).contains("JSESSIONID", "mycookie"); + assertThat(values).extracting((value) -> value.split("=")[0]).contains("JSESSIONID", "mycookie"); } @Test public void logoutWhenSpecifyingSuccessHandlerRefThenResponseHandledAccordingly() throws Exception { this.spring.configLocations(xml("LogoutSuccessHandlerRef")).autowire(); - + // @formatter:off this.mvc.perform(post("/logout").with(csrf())) .andExpect(redirectedUrl("/logoutSuccessEndpoint")); + // @formatter:on } @Test public void getWhenUnauthenticatedThenUsesConfiguredRequestCache() throws Exception { this.spring.configLocations(xml("RequestCache")).autowire(); - RequestCache requestCache = this.spring.getContext().getBean(RequestCache.class); - this.mvc.perform(get("/")); - verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void getWhenUnauthenticatedThenUsesConfiguredAuthenticationEntryPoint() throws Exception { this.spring.configLocations(xml("EntryPoint")).autowire(); - AuthenticationEntryPoint entryPoint = this.spring.getContext().getBean(AuthenticationEntryPoint.class); - this.mvc.perform(get("/")); - - verify(entryPoint).commence( - any(HttpServletRequest.class), - any(HttpServletResponse.class), + verify(entryPoint).commence(any(HttpServletRequest.class), any(HttpServletResponse.class), any(AuthenticationException.class)); } /** - * See SEC-750. If the http security post processor causes beans to be instantiated too eagerly, they way miss - * additional processing. In this method we have a UserDetailsService which is referenced from the namespace - * and also has a post processor registered which will modify it. + * See SEC-750. If the http security post processor causes beans to be instantiated + * too eagerly, they way miss additional processing. In this method we have a + * UserDetailsService which is referenced from the namespace and also has a post + * processor registered which will modify it. */ @Test public void configureWhenUsingCustomUserDetailsServiceThenBeanPostProcessorsAreStillApplied() { this.spring.configLocations(xml("Sec750")).autowire(); - - BeanNameCollectingPostProcessor postProcessor = - this.spring.getContext().getBean(BeanNameCollectingPostProcessor.class); - - assertThat(postProcessor.getBeforeInitPostProcessedBeans()) - .contains("authenticationProvider", "userService"); - assertThat(postProcessor.getAfterInitPostProcessedBeans()) - .contains("authenticationProvider", "userService"); - + BeanNameCollectingPostProcessor postProcessor = this.spring.getContext() + .getBean(BeanNameCollectingPostProcessor.class); + assertThat(postProcessor.getBeforeInitPostProcessedBeans()).contains("authenticationProvider", "userService"); + assertThat(postProcessor.getAfterInitPostProcessedBeans()).contains("authenticationProvider", "userService"); } /* SEC-934 */ @Test public void getWhenUsingTwoIdenticalInterceptUrlsThenTheSecondTakesPrecedence() throws Exception { this.spring.configLocations(xml("Sec934")).autowire(); - - this.mvc.perform(get("/protected") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/protected").with(userCredentials())) .andExpect(status().isOk()); - - this.mvc.perform(get("/protected") - .with(httpBasic("admin", "password"))) + this.mvc.perform(get("/protected").with(adminCredentials())) .andExpect(status().isForbidden()); + // @formatter:on } @Test public void getWhenAuthenticatingThenConsultsCustomSecurityContextRepository() throws Exception { this.spring.configLocations(xml("SecurityContextRepository")).autowire(); - SecurityContextRepository repository = this.spring.getContext().getBean(SecurityContextRepository.class); SecurityContext context = new SecurityContextImpl(new TestingAuthenticationToken("user", "password")); - when(repository.loadContext(any(HttpRequestResponseHolder.class))).thenReturn(context); - - MvcResult result = - this.mvc.perform(get("/protected") - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()) - .andReturn(); - + given(repository.loadContext(any(HttpRequestResponseHolder.class))).willReturn(context); + // @formatter:off + MvcResult result = this.mvc.perform(get("/protected").with(userCredentials())) + .andExpect(status().isOk()) + .andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false)).isNotNull(); - - verify(repository, atLeastOnce()).saveContext( - any(SecurityContext.class), - any(HttpServletRequest.class), + verify(repository, atLeastOnce()).saveContext(any(SecurityContext.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void getWhenUsingInterceptUrlExpressionsThenAuthorizesAccordingly() throws Exception { this.spring.configLocations(xml("InterceptUrlExpressions")).autowire(); - - this.mvc.perform(get("/protected") - .with(httpBasic("admin", "password"))) + // @formatter:off + this.mvc.perform(get("/protected").with(adminCredentials())) .andExpect(status().isOk()); - - this.mvc.perform(get("/protected") - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/protected").with(userCredentials())) .andExpect(status().isForbidden()); - - this.mvc.perform(get("/unprotected") - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/unprotected").with(userCredentials())) .andExpect(status().isOk()); - + // @formatter:on } @Test public void getWhenUsingCustomExpressionHandlerThenAuthorizesAccordingly() throws Exception { this.spring.configLocations(xml("ExpressionHandler")).autowire(); - PermissionEvaluator permissionEvaluator = this.spring.getContext().getBean(PermissionEvaluator.class); - when(permissionEvaluator.hasPermission(any(Authentication.class), any(Object.class), any(Object.class))) - .thenReturn(false); - - this.mvc.perform(get("/") - .with(httpBasic("user", "password"))) + given(permissionEvaluator.hasPermission(any(Authentication.class), any(Object.class), any(Object.class))) + .willReturn(false); + // @formatter:off + this.mvc.perform(get("/").with(userCredentials())) .andExpect(status().isForbidden()); - + // @formatter:on verify(permissionEvaluator).hasPermission(any(Authentication.class), any(Object.class), any(Object.class)); } @@ -574,44 +491,33 @@ public class MiscHttpConfigTests { public void configureWhenProtectingLoginPageThenWarningLogged() { ByteArrayOutputStream baos = new ByteArrayOutputStream(); redirectLogsTo(baos, DefaultFilterChainValidator.class); - this.spring.configLocations(xml("ProtectedLoginPage")).autowire(); - assertThat(baos.toString()).contains("[WARN]"); } @Test public void configureWhenUsingDisableUrlRewritingThenRedirectIsNotEncodedByResponse() throws IOException, ServletException { - this.spring.configLocations(xml("DisableUrlRewriting")).autowire(); - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChainProxy proxy = this.spring.getContext().getBean(FilterChainProxy.class); - - proxy.doFilter( - request, - new EncodeUrlDenyingHttpServletResponseWrapper(response), - (req, resp) -> {}); - + proxy.doFilter(request, new EncodeUrlDenyingHttpServletResponseWrapper(response), (req, resp) -> { + }); assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/login"); } @Test public void configureWhenUserDetailsServiceInParentContextThenLocatesSuccessfully() { - assertThatCode(() -> this.spring.configLocations(this.xml("MissingUserDetailsService")).autowire()) - .isInstanceOf(BeansException.class); - - try ( XmlWebApplicationContext parent = new XmlWebApplicationContext() ) { - parent.setConfigLocations(this.xml("AutoConfig")); + assertThatExceptionOfType(BeansException.class).isThrownBy( + () -> this.spring.configLocations(MiscHttpConfigTests.xml("MissingUserDetailsService")).autowire()); + try (XmlWebApplicationContext parent = new XmlWebApplicationContext()) { + parent.setConfigLocations(MiscHttpConfigTests.xml("AutoConfig")); parent.refresh(); - - try ( XmlWebApplicationContext child = new XmlWebApplicationContext() ) { + try (XmlWebApplicationContext child = new XmlWebApplicationContext()) { child.setParent(parent); - child.setConfigLocation(this.xml("MissingUserDetailsService")); + child.setConfigLocation(MiscHttpConfigTests.xml("MissingUserDetailsService")); child.refresh(); } } @@ -620,33 +526,33 @@ public class MiscHttpConfigTests { @Test public void loginWhenConfiguredWithNoInternalAuthenticationProvidersThenSuccessfullyAuthenticates() throws Exception { - this.spring.configLocations(xml("NoInternalAuthenticationProviders")).autowire(); - - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password")) + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password"); + this.mvc.perform(loginRequest) .andExpect(redirectedUrl("/")); + // @formatter:on } @Test public void loginWhenUsingDefaultsThenErasesCredentialsAfterAuthentication() throws Exception { this.spring.configLocations(xml("HttpBasic")).autowire(); - - this.mvc.perform(get("/password") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/password").with(userCredentials())) .andExpect(content().string("")); + // @formatter:on } @Test public void loginWhenAuthenticationManagerConfiguredToEraseCredentialsThenErasesCredentialsAfterAuthentication() - throws Exception { - + throws Exception { this.spring.configLocations(xml("AuthenticationManagerEraseCredentials")).autowire(); - - this.mvc.perform(get("/password") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/password").with(userCredentials())) .andExpect(content().string("")); + // @formatter:on } /** @@ -654,104 +560,96 @@ public class MiscHttpConfigTests { */ @Test public void loginWhenAuthenticationManagerRefConfiguredToKeepCredentialsThenKeepsCredentialsAfterAuthentication() - throws Exception { - + throws Exception { this.spring.configLocations(xml("AuthenticationManagerRefKeepCredentials")).autowire(); - - this.mvc.perform(get("/password") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/password").with(userCredentials())) .andExpect(content().string("password")); + // @formatter:on } @Test public void loginWhenAuthenticationManagerRefIsNotAProviderManagerThenKeepsCredentialsAccordingly() - throws Exception { - + throws Exception { this.spring.configLocations(xml("AuthenticationManagerRefNotProviderManager")).autowire(); - - this.mvc.perform(get("/password") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/password").with(userCredentials())) .andExpect(content().string("password")); + // @formatter:on } @Test public void loginWhenJeeFilterThenExtractsRoles() throws Exception { this.spring.configLocations(xml("JeeFilter")).autowire(); - Principal user = mock(Principal.class); - when(user.getName()).thenReturn("joe"); - - this.mvc.perform(get("/roles") + given(user.getName()).willReturn("joe"); + // @formatter:off + MockHttpServletRequestBuilder rolesRequest = get("/roles") .principal(user) - .with(request -> { + .with((request) -> { request.addUserRole("admin"); request.addUserRole("user"); request.addUserRole("unmapped"); return request; - })) + }); + this.mvc.perform(rolesRequest) .andExpect(content().string("ROLE_admin,ROLE_user")); + // @formatter:on } @Test public void loginWhenUsingCustomAuthenticationDetailsSourceRefThenAuthenticationSourcesDetailsAccordingly() - throws Exception { - + throws Exception { this.spring.configLocations(xml("CustomAuthenticationDetailsSourceRef")).autowire(); - Object details = mock(Object.class); AuthenticationDetailsSource source = this.spring.getContext().getBean(AuthenticationDetailsSource.class); - when(source.buildDetails(any(Object.class))).thenReturn(details); - - this.mvc.perform(get("/details") - .with(httpBasic("user", "password"))) + given(source.buildDetails(any(Object.class))).willReturn(details); + RequestPostProcessor x509 = x509( + "classpath:org/springframework/security/config/http/MiscHttpConfigTests-certificate.pem"); + // @formatter:off + this.mvc.perform(get("/details").with(userCredentials())) .andExpect(content().string(details.getClass().getName())); - - this.mvc.perform(get("/details") - .with(x509("classpath:org/springframework/security/config/http/MiscHttpConfigTests-certificate.pem"))) + this.mvc.perform(get("/details").with(x509)) .andExpect(content().string(details.getClass().getName())); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .with(csrf())) - .andReturn().getRequest().getSession(false); - - this.mvc.perform(get("/details") - .session(session)) + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .with(csrf()); + MockHttpSession session = (MockHttpSession) this.mvc.perform(loginRequest) + .andReturn() + .getRequest() + .getSession(false); + this.mvc.perform(get("/details").session(session)) .andExpect(content().string(details.getClass().getName())); - - assertThat(getField(getFilter(OpenIDAuthenticationFilter.class), "authenticationDetailsSource")) - .isEqualTo(source); + // @formatter:on + assertThat(ReflectionTestUtils.getField(getFilter(OpenIDAuthenticationFilter.class), + "authenticationDetailsSource")).isEqualTo(source); } @Test public void loginWhenUsingJaasApiProvisionThenJaasSubjectContainsUsername() throws Exception { this.spring.configLocations(xml("Jaas")).autowire(); - AuthorityGranter granter = this.spring.getContext().getBean(AuthorityGranter.class); - when(granter.grant(any(Principal.class))).thenReturn(new HashSet<>(Arrays.asList("USER"))); - - this.mvc.perform(get("/username") - .with(httpBasic("user", "password"))) + given(granter.grant(any(Principal.class))).willReturn(new HashSet<>(Arrays.asList("USER"))); + // @formatter:off + this.mvc.perform(get("/username").with(userCredentials())) .andExpect(content().string("user")); + // @formatter:on } @Test public void getWhenUsingCustomHttpFirewallThenFirewallIsInvoked() throws Exception { this.spring.configLocations(xml("HttpFirewall")).autowire(); - FirewalledRequest request = new FirewalledRequest(new MockHttpServletRequest()) { @Override - public void reset() { } + public void reset() { + } }; HttpServletResponse response = new MockHttpServletResponse(); - HttpFirewall firewall = this.spring.getContext().getBean(HttpFirewall.class); - when(firewall.getFirewalledRequest(any(HttpServletRequest.class))).thenReturn(request); - when(firewall.getFirewalledResponse(any(HttpServletResponse.class))).thenReturn(response); + given(firewall.getFirewalledRequest(any(HttpServletRequest.class))).willReturn(request); + given(firewall.getFirewalledResponse(any(HttpServletResponse.class))).willReturn(response); this.mvc.perform(get("/unprotected")); - verify(firewall).getFirewalledRequest(any(HttpServletRequest.class)); verify(firewall).getFirewalledResponse(any(HttpServletResponse.class)); } @@ -759,25 +657,22 @@ public class MiscHttpConfigTests { @Test public void getWhenUsingCustomRequestRejectedHandlerThenRequestRejectedHandlerIsInvoked() throws Exception { this.spring.configLocations(xml("RequestRejectedHandler")).autowire(); - HttpServletResponse response = new MockHttpServletResponse(); - RequestRejectedException rejected = new RequestRejectedException("failed"); HttpFirewall firewall = this.spring.getContext().getBean(HttpFirewall.class); RequestRejectedHandler requestRejectedHandler = this.spring.getContext().getBean(RequestRejectedHandler.class); - when(firewall.getFirewalledRequest(any(HttpServletRequest.class))).thenThrow(rejected); + given(firewall.getFirewalledRequest(any(HttpServletRequest.class))).willThrow(rejected); this.mvc.perform(get("/unprotected")); - verify(requestRejectedHandler).handle(any(), any(), any()); } @Test public void getWhenUsingCustomAccessDecisionManagerThenAuthorizesAccordingly() throws Exception { this.spring.configLocations(xml("CustomAccessDecisionManager")).autowire(); - - this.mvc.perform(get("/unprotected") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/unprotected").with(userCredentials())) .andExpect(status().isForbidden()); + // @formatter:on } /** @@ -786,178 +681,37 @@ public class MiscHttpConfigTests { @Test public void authenticateWhenUsingPortMapperThenRedirectsAppropriately() throws Exception { this.spring.configLocations(xml("PortsMappedRequiresHttps")).autowire(); - - MockHttpSession session = (MockHttpSession) - this.mvc.perform(get("https://localhost:9080/protected")) + // @formatter:off + MockHttpSession session = (MockHttpSession) this.mvc.perform(get("https://localhost:9080/protected")) .andExpect(redirectedUrl("https://localhost:9443/login")) - .andReturn().getRequest().getSession(false); - - session = (MockHttpSession) - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session(session) - .with(csrf())) - .andExpect(redirectedUrl("https://localhost:9443/protected")) - .andReturn().getRequest().getSession(false); - - this.mvc.perform(get("http://localhost:9080/protected") - .session(session)) + .andReturn() + .getRequest() + .getSession(false); + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .session(session) + .with(csrf()); + session = (MockHttpSession) this.mvc.perform(loginRequest) + .andExpect(redirectedUrl("https://localhost:9443/protected")) + .andReturn() + .getRequest() + .getSession(false); + this.mvc.perform(get("http://localhost:9080/protected").session(session)) .andExpect(redirectedUrl("https://localhost:9443/protected")); - } - - @RestController - static class BasicController { - @RequestMapping("/unprotected") - public String unprotected() { - return "ok"; - } - - @RequestMapping("/protected") - public String protectedMethod(@AuthenticationPrincipal String name) { - return name; - } - } - - @RestController - static class CustomKeyController { - @GetMapping("/customKey") - public String customKey() { - Authentication authentication = - SecurityContextHolder.getContext().getAuthentication(); - - if ( authentication != null && - authentication instanceof AnonymousAuthenticationToken ) { - return String.valueOf( - ((AnonymousAuthenticationToken) authentication).getKeyHash()); - } - - return null; - } - } - - @RestController - static class AuthenticationController { - @GetMapping("/password") - public String password(@AuthenticationPrincipal Authentication authentication) { - return (String) authentication.getCredentials(); - } - - @GetMapping("/roles") - public String roles(@AuthenticationPrincipal Authentication authentication) { - return authentication.getAuthorities().stream() - .map(GrantedAuthority::getAuthority) - .collect(Collectors.joining(",")); - } - - @GetMapping("/details") - public String details(@AuthenticationPrincipal Authentication authentication) { - return authentication.getDetails().getClass().getName(); - } - } - - @RestController - static class JaasController { - @GetMapping("/username") - public String username() { - Subject subject = Subject.getSubject(AccessController.getContext()); - return subject.getPrincipals().iterator().next().getName(); - } - } - - public static class JaasLoginModule implements LoginModule { - private Subject subject; - - @Override - public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { - this.subject = subject; - } - - @Override - public boolean login() { - return this.subject.getPrincipals().add(() -> "user"); - } - - @Override - public boolean commit() { - return true; - } - - @Override - public boolean abort() { - return true; - } - - @Override - public boolean logout() { - return true; - } - } - - static class MockAccessDecisionManager implements AccessDecisionManager { - - @Override - public void decide(Authentication authentication, Object object, Collection configAttributes) throws AccessDeniedException, InsufficientAuthenticationException { - throw new AccessDeniedException("teapot"); - } - - @Override - public boolean supports(ConfigAttribute attribute) { - return true; - } - - @Override - public boolean supports(Class clazz) { - return true; - } - } - - static class MockAuthenticationManager implements AuthenticationManager { - public Authentication authenticate(Authentication authentication) { - return new TestingAuthenticationToken(authentication.getPrincipal(), - authentication.getCredentials(), - AuthorityUtils.createAuthorityList("ROLE_USER")); - } - } - - static class EncodeUrlDenyingHttpServletResponseWrapper - extends HttpServletResponseWrapper { - - EncodeUrlDenyingHttpServletResponseWrapper(HttpServletResponse response) { - super(response); - } - - @Override - public String encodeURL(String url) { - throw new RuntimeException("Unexpected invocation of encodeURL"); - } - - @Override - public String encodeRedirectURL(String url) { - throw new RuntimeException("Unexpected invocation of encodeURL"); - } - - @Override - public String encodeUrl(String url) { - throw new RuntimeException("Unexpected invocation of encodeURL"); - } - - @Override - public String encodeRedirectUrl(String url) { - throw new RuntimeException("Unexpected invocation of encodeURL"); - } + // @formatter:on } private void redirectLogsTo(OutputStream os, Class clazz) { Logger logger = (Logger) LoggerFactory.getLogger(clazz); Appender appender = mock(Appender.class); - when(appender.isStarted()).thenReturn(true); - doAnswer(writeTo(os)).when(appender).doAppend(any(ILoggingEvent.class)); + given(appender.isStarted()).willReturn(true); + willAnswer(writeTo(os)).given(appender).doAppend(any(ILoggingEvent.class)); logger.addAppender(appender); } private Answer writeTo(OutputStream os) { - return invocation -> { + return (invocation) -> { os.write(invocation.getArgument(0).toString().getBytes()); return null; }; @@ -969,7 +723,6 @@ public class MiscHttpConfigTests { private void assertThatFiltersMatchExpectedAutoConfigList(String url) { Iterator filters = getFilters(url).iterator(); - assertThat(filters.next()).isInstanceOf(SecurityContextPersistenceFilter.class); assertThat(filters.next()).isInstanceOf(WebAsyncManagerIntegrationFilter.class); assertThat(filters.next()).isInstanceOf(HeaderWriterFilter.class); @@ -997,7 +750,174 @@ public class MiscHttpConfigTests { return proxy.getFilters(url); } + @NotNull + private static RequestPostProcessor userCredentials() { + return httpBasic("user", "password"); + } + + @NotNull + private static RequestPostProcessor adminCredentials() { + return httpBasic("admin", "password"); + } + + @NotNull + private static RequestPostProcessor postCredentials() { + return httpBasic("poster", "password"); + } + private static String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + + @RestController + static class BasicController { + + @RequestMapping("/unprotected") + String unprotected() { + return "ok"; + } + + @RequestMapping("/protected") + String protectedMethod(@AuthenticationPrincipal String name) { + return name; + } + + } + + @RestController + static class CustomKeyController { + + @GetMapping("/customKey") + String customKey() { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication != null && authentication instanceof AnonymousAuthenticationToken) { + return String.valueOf(((AnonymousAuthenticationToken) authentication).getKeyHash()); + } + return null; + } + + } + + @RestController + static class AuthenticationController { + + @GetMapping("/password") + String password(@AuthenticationPrincipal Authentication authentication) { + return (String) authentication.getCredentials(); + } + + @GetMapping("/roles") + String roles(@AuthenticationPrincipal Authentication authentication) { + return authentication.getAuthorities().stream().map(GrantedAuthority::getAuthority) + .collect(Collectors.joining(",")); + } + + @GetMapping("/details") + String details(@AuthenticationPrincipal Authentication authentication) { + return authentication.getDetails().getClass().getName(); + } + + } + + @RestController + static class JaasController { + + @GetMapping("/username") + String username() { + Subject subject = Subject.getSubject(AccessController.getContext()); + return subject.getPrincipals().iterator().next().getName(); + } + + } + + public static class JaasLoginModule implements LoginModule { + + private Subject subject; + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, + Map options) { + this.subject = subject; + } + + @Override + public boolean login() { + return this.subject.getPrincipals().add(() -> "user"); + } + + @Override + public boolean commit() { + return true; + } + + @Override + public boolean abort() { + return true; + } + + @Override + public boolean logout() { + return true; + } + + } + + static class MockAccessDecisionManager implements AccessDecisionManager { + + @Override + public void decide(Authentication authentication, Object object, Collection configAttributes) + throws AccessDeniedException, InsufficientAuthenticationException { + throw new AccessDeniedException("teapot"); + } + + @Override + public boolean supports(ConfigAttribute attribute) { + return true; + } + + @Override + public boolean supports(Class clazz) { + return true; + } + + } + + static class MockAuthenticationManager implements AuthenticationManager { + + @Override + public Authentication authenticate(Authentication authentication) { + return new TestingAuthenticationToken(authentication.getPrincipal(), authentication.getCredentials(), + AuthorityUtils.createAuthorityList("ROLE_USER")); + } + + } + + static class EncodeUrlDenyingHttpServletResponseWrapper extends HttpServletResponseWrapper { + + EncodeUrlDenyingHttpServletResponseWrapper(HttpServletResponse response) { + super(response); + } + + @Override + public String encodeURL(String url) { + throw new RuntimeException("Unexpected invocation of encodeURL"); + } + + @Override + public String encodeRedirectURL(String url) { + throw new RuntimeException("Unexpected invocation of encodeURL"); + } + + @Override + public String encodeUrl(String url) { + throw new RuntimeException("Unexpected invocation of encodeURL"); + } + + @Override + public String encodeRedirectUrl(String url) { + throw new RuntimeException("Unexpected invocation of encodeURL"); + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/MultiHttpBlockConfigTests.java b/config/src/test/java/org/springframework/security/config/http/MultiHttpBlockConfigTests.java index 9e33b67e2e..38d42bd9b1 100644 --- a/config/src/test/java/org/springframework/security/config/http/MultiHttpBlockConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/MultiHttpBlockConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.junit.Rule; @@ -23,10 +24,10 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.test.SpringTestRule; import org.springframework.stereotype.Controller; import org.springframework.test.web.servlet.MockMvc; -import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -40,8 +41,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Luke Taylor */ public class MultiHttpBlockConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/MultiHttpBlockConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/MultiHttpBlockConfigTests"; @Autowired MockMvc mvc; @@ -50,35 +51,33 @@ public class MultiHttpBlockConfigTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void requestWhenUsingMutuallyExclusiveHttpElementsThenIsRoutedAccordingly() - throws Exception { - + public void requestWhenUsingMutuallyExclusiveHttpElementsThenIsRoutedAccordingly() throws Exception { this.spring.configLocations(this.xml("DistinctHttpElements")).autowire(); - - this.mvc.perform(MockMvcRequestBuilders.get("/first") - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/first").with(httpBasic("user", "password"))) .andExpect(status().isOk()); - - this.mvc.perform(post("/second/login") + MockHttpServletRequestBuilder formLoginRequest = post("/second/login") .param("username", "user") .param("password", "password") - .with(csrf())) + .with(csrf()); + this.mvc.perform(formLoginRequest) .andExpect(status().isFound()) .andExpect(redirectedUrl("/")); + // @formatter:on } @Test public void configureWhenUsingDuplicateHttpElementsThenThrowsWiringException() { - assertThatCode(() -> this.spring.configLocations(this.xml("IdenticalHttpElements")).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasCauseInstanceOf(IllegalArgumentException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("IdenticalHttpElements")).autowire()) + .withCauseInstanceOf(IllegalArgumentException.class); } @Test public void configureWhenUsingIndenticallyPatternedHttpElementsThenThrowsWiringException() { - assertThatCode(() -> this.spring.configLocations(this.xml("IdenticallyPatternedHttpElements")).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasCauseInstanceOf(IllegalArgumentException.class); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("IdenticallyPatternedHttpElements")).autowire()) + .withCauseInstanceOf(IllegalArgumentException.class); } /** @@ -87,30 +86,34 @@ public class MultiHttpBlockConfigTests { @Test public void requestWhenTargettingAuthenticationManagersToCorrespondingHttpElementsThenAuthenticationProceeds() throws Exception { - this.spring.configLocations(this.xml("Sec1937")).autowire(); - - this.mvc.perform(get("/first") + // @formatter:off + MockHttpServletRequestBuilder basicLoginRequest = get("/first") .with(httpBasic("first", "password")) - .with(csrf())) + .with(csrf()); + this.mvc.perform(basicLoginRequest) .andExpect(status().isOk()); - - this.mvc.perform(post("/second/login") + MockHttpServletRequestBuilder formLoginRequest = post("/second/login") .param("username", "second") .param("password", "password") - .with(csrf())) + .with(csrf()); + this.mvc.perform(formLoginRequest) .andExpect(redirectedUrl("/")); - } - - @Controller - static class BasicController { - @GetMapping("/first") - public String first() { - return "ok"; - } + // @formatter:on } private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + + @Controller + static class BasicController { + + @GetMapping("/first") + String first() { + return "ok"; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/NamespaceHttpBasicTests.java b/config/src/test/java/org/springframework/security/config/http/NamespaceHttpBasicTests.java index 6aecdb0335..03f54cd561 100644 --- a/config/src/test/java/org/springframework/security/config/http/NamespaceHttpBasicTests.java +++ b/config/src/test/java/org/springframework/security/config/http/NamespaceHttpBasicTests.java @@ -39,11 +39,14 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Rob Winch */ public class NamespaceHttpBasicTests { + @Mock Method method; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; ConfigurableApplicationContext context; @@ -69,28 +72,25 @@ public class NamespaceHttpBasicTests { @Test public void httpBasicWithPasswordEncoder() throws Exception { // @formatter:off - loadContext("\n" + - " \n" + - " \n" + - " \n" + - "\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " "); - // @formatter:on - + loadContext("\n" + + " \n" + + " \n" + + "\n" + + "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + "\n" + + ""); + // @formatter:on this.request.addHeader("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:test".getBytes("UTF-8"))); - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @@ -98,23 +98,21 @@ public class NamespaceHttpBasicTests { @Test public void httpBasicUnauthorizedOnDefault() throws Exception { // @formatter:off - loadContext("\n" + - " \n" + - " \n" + - " \n" + - "\n" + - " "); + loadContext("\n" + + " \n" + + " \n" + + "\n" + + "\n" + + ""); // @formatter:on - this.springSecurityFilterChain.doFilter(this.request, this.response, this.chain); - assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); assertThat(this.response.getHeader("WWW-Authenticate")).isEqualTo("Basic realm=\"Realm\""); } private void loadContext(String context) { this.context = new InMemoryXmlApplicationContext(context); - this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", - Filter.class); + this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } + } diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java index c146779731..085ec83e68 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java @@ -13,12 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.HashMap; +import java.util.Map; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestRule; @@ -36,6 +41,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -47,14 +53,10 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -68,6 +70,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class OAuth2ClientBeanDefinitionParserTests { + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests"; @Rule @@ -97,84 +100,77 @@ public class OAuth2ClientBeanDefinitionParserTests { @Test public void requestWhenAuthorizeThenRedirect() throws Exception { this.spring.configLocations(xml("Minimal")).autowire(); - + // @formatter:off MvcResult result = this.mvc.perform(get("/oauth2/authorization/google")) .andExpect(status().is3xxRedirection()) .andReturn(); + // @formatter:on assertThat(result.getResponse().getRedirectedUrl()).matches( - "https://accounts.google.com/o/oauth2/v2/auth\\?" + - "response_type=code&client_id=google-client-id&" + - "scope=scope1%20scope2&state=.{15,}&redirect_uri=http://localhost/callback/google"); + "https://accounts.google.com/o/oauth2/v2/auth\\?" + "response_type=code&client_id=google-client-id&" + + "scope=scope1%20scope2&state=.{15,}&redirect_uri=http://localhost/callback/google"); } @Test public void requestWhenCustomClientRegistrationRepositoryThenCalled() throws Exception { this.spring.configLocations(xml("CustomClientRegistrationRepository")).autowire(); - + // @formatter:off ClientRegistration clientRegistration = CommonOAuth2Provider.GOOGLE.getBuilder("google") .clientId("google-client-id") .clientSecret("google-client-secret") .redirectUri("http://localhost/callback/google") .scope("scope1", "scope2") .build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(clientRegistration); - + // @formatter:on + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); + // @formatter:off MvcResult result = this.mvc.perform(get("/oauth2/authorization/google")) .andExpect(status().is3xxRedirection()) .andReturn(); + // @formatter:on assertThat(result.getResponse().getRedirectedUrl()).matches( - "https://accounts.google.com/o/oauth2/v2/auth\\?" + - "response_type=code&client_id=google-client-id&" + - "scope=scope1%20scope2&state=.{15,}&redirect_uri=http://localhost/callback/google"); - + "https://accounts.google.com/o/oauth2/v2/auth\\?" + "response_type=code&client_id=google-client-id&" + + "scope=scope1%20scope2&state=.{15,}&redirect_uri=http://localhost/callback/google"); verify(this.clientRegistrationRepository).findByRegistrationId(any()); } @Test public void requestWhenCustomAuthorizationRequestResolverThenCalled() throws Exception { this.spring.configLocations(xml("CustomConfiguration")).autowire(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest(clientRegistration); - when(this.authorizationRequestResolver.resolve(any())).thenReturn(authorizationRequest); - + given(this.authorizationRequestResolver.resolve(any())).willReturn(authorizationRequest); + // @formatter:off this.mvc.perform(get("/oauth2/authorization/google")) .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrl( - "https://accounts.google.com/o/oauth2/v2/auth?" + - "response_type=code&client_id=google-client-id&" + - "scope=scope1%20scope2&state=state&redirect_uri=http://localhost/callback/google")); - + .andExpect(redirectedUrl("https://accounts.google.com/o/oauth2/v2/auth?" + + "response_type=code&client_id=google-client-id&" + + "scope=scope1%20scope2&state=state&redirect_uri=http://localhost/callback/google")); + // @formatter:on verify(this.authorizationRequestResolver).resolve(any()); } @Test public void requestWhenAuthorizationResponseMatchThenProcess() throws Exception { this.spring.configLocations(xml("CustomConfiguration")).autowire(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest(clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(authorizationRequest); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) - .thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())).willReturn(authorizationRequest); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); + // @formatter:off this.mvc.perform(get(authorizationRequest.getRedirectUri()).params(params)) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl(authorizationRequest.getRedirectUri())); - - ArgumentCaptor authorizedClientCaptor = - ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository).saveAuthorizedClient( - authorizedClientCaptor.capture(), any(), any(), any()); + // @formatter:on + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), any(), any(), + any()); OAuth2AuthorizedClient authorizedClient = authorizedClientCaptor.getValue(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(clientRegistration); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -184,25 +180,21 @@ public class OAuth2ClientBeanDefinitionParserTests { @Test public void requestWhenCustomAuthorizedClientServiceThenCalled() throws Exception { this.spring.configLocations(xml("CustomAuthorizedClientService")).autowire(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest(clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(authorizationRequest); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) - .thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())).willReturn(authorizationRequest); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); + // @formatter:off this.mvc.perform(get(authorizationRequest.getRedirectUri()).params(params)) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl(authorizationRequest.getRedirectUri())); - + // @formatter:on verify(this.authorizedClientService).saveAuthorizedClient(any(), any()); } @@ -210,42 +202,40 @@ public class OAuth2ClientBeanDefinitionParserTests { @Test public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception { this.spring.configLocations(xml("AuthorizedClientArgumentResolver")).autowire(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "user", + TestOAuth2AccessTokens.noScopes()); + given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).willReturn(authorizedClient); + this.mvc.perform(get("/authorized-client")).andExpect(status().isOk()).andExpect(content().string("resolved")); + } - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, "user", TestOAuth2AccessTokens.noScopes()); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) - .thenReturn(authorizedClient); + private static OAuth2AuthorizationRequest createAuthorizationRequest(ClientRegistration clientRegistration) { + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); + // @formatter:off + return OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .clientId(clientRegistration.getClientId()).redirectUri(clientRegistration.getRedirectUri()) + .scopes(clientRegistration.getScopes()) + .state("state") + .attributes(attributes) + .build(); + // @formatter:on + } - this.mvc.perform(get("/authorized-client")) - .andExpect(status().isOk()) - .andExpect(content().string("resolved")); + private static String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } @RestController static class AuthorizedClientController { @GetMapping("/authorized-client") - String authorizedClient(Model model, @RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) { - return authorizedClient != null ? "resolved" : "not-resolved"; + String authorizedClient(Model model, + @RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) { + return (authorizedClient != null) ? "resolved" : "not-resolved"; } + } - private static OAuth2AuthorizationRequest createAuthorizationRequest(ClientRegistration clientRegistration) { - Map attributes = new HashMap<>(); - attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); - return OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) - .clientId(clientRegistration.getClientId()) - .redirectUri(clientRegistration.getRedirectUri()) - .scopes(clientRegistration.getScopes()) - .state("state") - .attributes(attributes) - .build(); - } - - private static String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java index 3c0c793b25..123f7d9093 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; import org.springframework.http.MediaType; @@ -47,6 +53,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; @@ -68,17 +75,11 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.oidcAccessTokenResponse; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -92,6 +93,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class OAuth2LoginBeanDefinitionParserTests { + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests"; @Rule @@ -142,11 +144,11 @@ public class OAuth2LoginBeanDefinitionParserTests { @Test public void requestLoginWhenMultiClientRegistrationThenReturnLoginPageWithClients() throws Exception { this.spring.configLocations(this.xml("MultiClientRegistration")).autowire(); - + // @formatter:off MvcResult result = this.mvc.perform(get("/login")) .andExpect(status().is2xxSuccessful()) .andReturn(); - + // @formatter:on assertThat(result.getResponse().getContentAsString()) .contains("Google"); assertThat(result.getResponse().getContentAsString()) @@ -157,12 +159,12 @@ public class OAuth2LoginBeanDefinitionParserTests { @Test public void requestWhenSingleClientRegistrationThenAutoRedirect() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/oauth2/authorization/google-login")); - - verify(requestCache).saveRequest(any(), any()); + // @formatter:on + verify(this.requestCache).saveRequest(any(), any()); } // gh-5347 @@ -170,10 +172,11 @@ public class OAuth2LoginBeanDefinitionParserTests { public void requestWhenSingleClientRegistrationAndRequestFaviconNotAuthenticatedThenRedirectDefaultLoginPage() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration")).autowire(); - + // @formatter:off this.mvc.perform(get("/favicon.ico").accept(new MediaType("image", "*"))) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } // gh-6812 @@ -181,25 +184,24 @@ public class OAuth2LoginBeanDefinitionParserTests { public void requestWhenSingleClientRegistrationAndRequestXHRNotAuthenticatedThenDoesNotRedirectForAuthorization() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration")).autowire(); - + // @formatter:off this.mvc.perform(get("/").header("X-Requested-With", "XMLHttpRequest")) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test public void requestWhenAuthorizationRequestNotFoundThenThrowAuthenticationException() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration-WithCustomAuthenticationFailureHandler")) .autowire(); - MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", "state123"); this.mvc.perform(get("/login/oauth2/code/google").params(params)); - ArgumentCaptor exceptionCaptor = ArgumentCaptor .forClass(AuthenticationException.class); - verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), exceptionCaptor.capture()); + verify(this.authenticationFailureHandler).onAuthenticationFailure(any(), any(), exceptionCaptor.capture()); AuthenticationException exception = exceptionCaptor.getValue(); assertThat(exception).isInstanceOf(OAuth2AuthenticationException.class); assertThat(((OAuth2AuthenticationException) exception).getError().getErrorCode()) @@ -209,27 +211,25 @@ public class OAuth2LoginBeanDefinitionParserTests { @Test public void requestWhenAuthorizationResponseValidThenAuthenticate() throws Exception { this.spring.configLocations(this.xml("MultiClientRegistration-WithCustomConfiguration")).autowire(); - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, "github-login"); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); + // @formatter:off this.mvc.perform(get("/login/oauth2/code/github-login").params(params)) .andExpect(status().is2xxSuccessful()); - + // @formatter:on ArgumentCaptor authenticationCaptor = ArgumentCaptor.forClass(Authentication.class); - verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); + verify(this.authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); Authentication authentication = authenticationCaptor.getValue(); assertThat(authentication.getPrincipal()).isInstanceOf(OAuth2User.class); } @@ -238,51 +238,46 @@ public class OAuth2LoginBeanDefinitionParserTests { @Test public void requestWhenAuthorizationResponseValidThenAuthenticationSuccessEventPublished() throws Exception { this.spring.configLocations(this.xml("MultiClientRegistration-WithCustomConfiguration")).autowire(); - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, "github-login"); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); this.mvc.perform(get("/login/oauth2/code/github-login").params(params)); - - verify(authenticationSuccessListener).onApplicationEvent(any(AuthenticationSuccessEvent.class)); + verify(this.authenticationSuccessListener).onApplicationEvent(any(AuthenticationSuccessEvent.class)); } @Test public void requestWhenOidcAuthenticationResponseValidThenJwtDecoderFactoryCalled() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration-WithJwtDecoderFactoryAndDefaultSuccessHandler")) .autowire(); - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, "google-login"); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.oidcRequest() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = oidcAccessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse() + .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); Jwt jwt = TestJwts.user(); - when(this.jwtDecoderFactory.createDecoder(any())).thenReturn(token -> jwt); - + given(this.jwtDecoderFactory.createDecoder(any())).willReturn((token) -> jwt); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); + // @formatter:off this.mvc.perform(get("/login/oauth2/code/google-login").params(params)) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("/")); - + // @formatter:on verify(this.jwtDecoderFactory).createDecoder(any()); verify(this.requestCache).getRequest(any(), any()); } @@ -291,57 +286,47 @@ public class OAuth2LoginBeanDefinitionParserTests { @Test public void requestWhenCustomGrantedAuthoritiesMapperThenCalled() throws Exception { this.spring.configLocations(this.xml("MultiClientRegistration-WithCustomGrantedAuthorities")).autowire(); - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, "github-login"); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - - when(this.userAuthoritiesMapper.mapAuthorities(any())).thenReturn( - (Collection) AuthorityUtils.createAuthorityList("ROLE_OAUTH2_USER")); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); + given(this.userAuthoritiesMapper.mapAuthorities(any())) + .willReturn((Collection) AuthorityUtils.createAuthorityList("ROLE_OAUTH2_USER")); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); - this.mvc.perform(get("/login/oauth2/code/github-login").params(params)) - .andExpect(status().is2xxSuccessful()); - + this.mvc.perform(get("/login/oauth2/code/github-login").params(params)).andExpect(status().is2xxSuccessful()); ArgumentCaptor authenticationCaptor = ArgumentCaptor.forClass(Authentication.class); - verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); + verify(this.authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); Authentication authentication = authenticationCaptor.getValue(); assertThat(authentication.getPrincipal()).isInstanceOf(OAuth2User.class); assertThat(authentication.getAuthorities()).hasSize(1); assertThat(authentication.getAuthorities()).first().isInstanceOf(SimpleGrantedAuthority.class) .hasToString("ROLE_OAUTH2_USER"); - // re-setup for OIDC test attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, "google-login"); - authorizationRequest = TestOAuth2AuthorizationRequests.oidcRequest() - .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - accessTokenResponse = oidcAccessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + authorizationRequest = TestOAuth2AuthorizationRequests.oidcRequest().attributes(attributes).build(); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + accessTokenResponse = TestOAuth2AccessTokenResponses.oidcAccessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); Jwt jwt = TestJwts.user(); - when(this.jwtDecoderFactory.createDecoder(any())).thenReturn(token -> jwt); - - when(this.userAuthoritiesMapper.mapAuthorities(any())) - .thenReturn((Collection) AuthorityUtils.createAuthorityList("ROLE_OIDC_USER")); - + given(this.jwtDecoderFactory.createDecoder(any())).willReturn((token) -> jwt); + given(this.userAuthoritiesMapper.mapAuthorities(any())) + .willReturn((Collection) AuthorityUtils.createAuthorityList("ROLE_OIDC_USER")); + // @formatter:off this.mvc.perform(get("/login/oauth2/code/google-login").params(params)) .andExpect(status().is2xxSuccessful()); - + // @formatter:on authenticationCaptor = ArgumentCaptor.forClass(Authentication.class); - verify(authenticationSuccessHandler, times(2)).onAuthenticationSuccess(any(), any(), + verify(this.authenticationSuccessHandler, times(2)).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); authentication = authenticationCaptor.getValue(); assertThat(authentication.getPrincipal()).isInstanceOf(OidcUser.class); @@ -354,27 +339,25 @@ public class OAuth2LoginBeanDefinitionParserTests { @Test public void requestWhenCustomLoginProcessingUrlThenProcessAuthentication() throws Exception { this.spring.configLocations(this.xml("MultiClientRegistration-WithCustomLoginProcessingUrl")).autowire(); - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, "github-login"); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); + // @formatter:off this.mvc.perform(get("/login/oauth2/github-login").params(params)) .andExpect(status().is2xxSuccessful()); - + // @formatter:on ArgumentCaptor authenticationCaptor = ArgumentCaptor.forClass(Authentication.class); - verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); + verify(this.authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture()); Authentication authentication = authenticationCaptor.getValue(); assertThat(authentication.getPrincipal()).isInstanceOf(OAuth2User.class); } @@ -384,30 +367,32 @@ public class OAuth2LoginBeanDefinitionParserTests { public void requestWhenCustomAuthorizationRequestResolverThenCalled() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration-WithCustomAuthorizationRequestResolver")) .autowire(); - + // @formatter:off this.mvc.perform(get("/oauth2/authorization/google-login")) .andExpect(status().is3xxRedirection()); - - verify(authorizationRequestResolver).resolve(any()); + // @formatter:on + verify(this.authorizationRequestResolver).resolve(any()); } // gh-5347 @Test public void requestWhenMultiClientRegistrationThenRedirectDefaultLoginPage() throws Exception { this.spring.configLocations(this.xml("MultiClientRegistration")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test public void requestWhenCustomLoginPageThenRedirectCustomLoginPage() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration-WithCustomLoginPage")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/custom-login")); + // @formatter:on } // gh-6802 @@ -415,120 +400,107 @@ public class OAuth2LoginBeanDefinitionParserTests { public void requestWhenSingleClientRegistrationAndFormLoginConfiguredThenRedirectDefaultLoginPage() throws Exception { this.spring.configLocations(this.xml("SingleClientRegistration-WithFormLogin")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().is3xxRedirection()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test public void requestWhenCustomClientRegistrationRepositoryThenCalled() throws Exception { this.spring.configLocations(this.xml("WithCustomClientRegistrationRepository")).autowire(); - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(clientRegistration); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); this.mvc.perform(get("/login/oauth2/code/" + clientRegistration.getRegistrationId()).params(params)); - - verify(clientRegistrationRepository).findByRegistrationId(clientRegistration.getRegistrationId()); + verify(this.clientRegistrationRepository).findByRegistrationId(clientRegistration.getRegistrationId()); } @Test public void requestWhenCustomAuthorizedClientRepositoryThenCalled() throws Exception { this.spring.configLocations(this.xml("WithCustomAuthorizedClientRepository")).autowire(); - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(clientRegistration); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); this.mvc.perform(get("/login/oauth2/code/" + clientRegistration.getRegistrationId()).params(params)); - - verify(authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any()); } @Test public void requestWhenCustomAuthorizedClientServiceThenCalled() throws Exception { this.spring.configLocations(this.xml("WithCustomAuthorizedClientService")).autowire(); - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(clientRegistration); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .attributes(attributes).build(); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())).thenReturn(authorizationRequest); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any(), any())) + .willReturn(authorizationRequest); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User oauth2User = TestOAuth2Users.create(); - when(this.oauth2UserService.loadUser(any())).thenReturn(oauth2User); - + given(this.oauth2UserService.loadUser(any())).willReturn(oauth2User); MultiValueMap params = new LinkedMultiValueMap<>(); params.add("code", "code123"); params.add("state", authorizationRequest.getState()); this.mvc.perform(get("/login/oauth2/code/" + clientRegistration.getRegistrationId()).params(params)); - - verify(authorizedClientService).saveAuthorizedClient(any(), any()); + verify(this.authorizedClientService).saveAuthorizedClient(any(), any()); } @WithMockUser @Test public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception { this.spring.configLocations(xml("AuthorizedClientArgumentResolver")).autowire(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google-login"); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, "user", TestOAuth2AccessTokens.noScopes()); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) - .thenReturn(authorizedClient); - + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "user", + TestOAuth2AccessTokens.noScopes()); + given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).willReturn(authorizedClient); + // @formatter:off this.mvc.perform(get("/authorized-client")) .andExpect(status().isOk()) .andExpect(content().string("resolved")); + // @formatter:on + } + + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } @RestController static class AuthorizedClientController { @GetMapping("/authorized-client") - String authorizedClient(Model model, @RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) { - return authorizedClient != null ? "resolved" : "not-resolved"; + String authorizedClient(Model model, + @RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) { + return (authorizedClient != null) ? "resolved" : "not-resolved"; } + } - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java index 15186e0b7e..2cfdfc536c 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.io.BufferedReader; @@ -28,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Properties; import java.util.stream.Collectors; + import javax.servlet.http.HttpServletRequest; import com.nimbusds.jose.JWSAlgorithm; @@ -67,15 +69,19 @@ import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.security.authentication.AuthenticationManagerResolver; +import org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.JwtBeanDefinitionParser; +import org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.OpaqueTokenBeanDefinitionParser; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector; import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector; @@ -85,31 +91,23 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.hamcrest.CoreMatchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.config.http.JwtBeanDefinitionParser.DECODER_REF; -import static org.springframework.security.config.http.JwtBeanDefinitionParser.JWK_SET_URI; -import static org.springframework.security.config.http.OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF; -import static org.springframework.security.config.http.OpaqueTokenBeanDefinitionParser.INTROSPECTION_URI; -import static org.springframework.security.config.http.OpaqueTokenBeanDefinitionParser.INTROSPECTOR_REF; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -118,14 +116,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Josh Cummings */ @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class OAuth2ResourceServerBeanDefinitionParserTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -141,10 +138,10 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test @@ -152,356 +149,297 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { this.spring.configLocations(xml("WebServer"), xml("JwkSetUri")).autowire(); mockWebServer(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void getWhenExpiredBearerTokenThenInvalidToken() - throws Exception { - + public void getWhenExpiredBearerTokenThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("Expired"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } @Test - public void getWhenBadJwkEndpointThenInvalidToken() - throws Exception { - + public void getWhenBadJwkEndpointThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations("malformed"); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(header().string("WWW-Authenticate", "Bearer")); + // @formatter:on } @Test - public void getWhenUnavailableJwkEndpointThenInvalidToken() - throws Exception { - + public void getWhenUnavailableJwkEndpointThenInvalidToken() throws Exception { this.spring.configLocations(xml("WebServer"), xml("JwkSetUri")).autowire(); this.web.shutdown(); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(header().string("WWW-Authenticate", "Bearer")); + // @formatter:on } @Test - public void getWhenMalformedBearerTokenThenInvalidToken() - throws Exception { - + public void getWhenMalformedBearerTokenThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer an\"invalid\"token")) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer an\"invalid\"token")) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Bearer token is malformed")); + // @formatter:on } @Test - public void getWhenMalformedPayloadThenInvalidToken() - throws Exception { - + public void getWhenMalformedPayloadThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("MalformedPayload"); - + // @formatter:off this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt: Malformed payload")); + // @formatter:on } @Test - public void getWhenUnsignedBearerTokenThenInvalidToken() - throws Exception { - + public void getWhenUnsignedBearerTokenThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); String token = this.token("Unsigned"); - + // @formatter:off this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Unsupported algorithm of none")); + // @formatter:on } @Test - public void getWhenBearerTokenBeforeNotBeforeThenInvalidToken() - throws Exception { - + public void getWhenBearerTokenBeforeNotBeforeThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); this.mockRestOperations(jwks("Default")); String token = this.token("TooEarly"); - + // @formatter:off this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } @Test - public void getWhenBearerTokenInTwoPlacesThenInvalidRequest() - throws Exception { - + public void getWhenBearerTokenInTwoPlacesThenInvalidRequest() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer token") - .param("access_token", "token")) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer token").param("access_token", "token")) .andExpect(status().isBadRequest()) .andExpect(invalidRequestHeader("Found multiple bearer tokens in the request")); + // @formatter:on } @Test - public void getWhenBearerTokenInTwoParametersThenInvalidRequest() - throws Exception { - + public void getWhenBearerTokenInTwoParametersThenInvalidRequest() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - MultiValueMap params = new LinkedMultiValueMap<>(); params.add("access_token", "token1"); params.add("access_token", "token2"); - - this.mvc.perform(get("/") - .params(params)) + // @formatter:off + this.mvc.perform(get("/").params(params)) .andExpect(status().isBadRequest()) .andExpect(invalidRequestHeader("Found multiple bearer tokens in the request")); + // @formatter:on } @Test - public void postWhenBearerTokenAsFormParameterThenIgnoresToken() - throws Exception { - + public void postWhenBearerTokenAsFormParameterThenIgnoresToken() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - this.mvc.perform(post("/") // engage csrf - .param("access_token", "token")) - .andExpect(status().isForbidden()) - .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); // different from DSL + .param("access_token", "token")).andExpect(status().isForbidden()) + .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); // different + // from + // DSL } @Test - public void getWhenNoBearerTokenThenUnauthorized() - throws Exception { - + public void getWhenNoBearerTokenThenUnauthorized() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); + // @formatter:on } @Test - public void getWhenSufficientlyScopedBearerTokenThenAcceptsRequest() - throws Exception { - + public void getWhenSufficientlyScopedBearerTokenThenAcceptsRequest() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageReadScope"); - - this.mvc.perform(get("/requires-read-scope") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void getWhenInsufficientScopeThenInsufficientScopeError() - throws Exception { - + public void getWhenInsufficientScopeThenInsufficientScopeError() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/requires-read-scope") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").header("Authorization", "Bearer " + token)) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); + // @formatter:on } @Test - public void getWhenInsufficientScpThenInsufficientScopeError() - throws Exception { - + public void getWhenInsufficientScpThenInsufficientScopeError() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidMessageWriteScp"); - - this.mvc.perform(get("/requires-read-scope") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").header("Authorization", "Bearer " + token)) .andExpect(status().isForbidden()) .andExpect(insufficientScopeHeader()); + // @formatter:on } @Test - public void getWhenAuthorizationServerHasNoMatchingKeyThenInvalidToken() - throws Exception { - + public void getWhenAuthorizationServerHasNoMatchingKeyThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Empty")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } @Test - public void getWhenAuthorizationServerHasMultipleMatchingKeysThenOk() - throws Exception { - + public void getWhenAuthorizationServerHasMultipleMatchingKeysThenOk() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("TwoKeys")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void getWhenKeyMatchesByKidThenOk() - throws Exception { - + public void getWhenKeyMatchesByKidThenOk() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("TwoKeys")); String token = this.token("Kid"); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } - // -- Resource Server should not engage csrf - @Test - public void postWhenValidBearerTokenAndNoCsrfTokenThenOk() - throws Exception { - + public void postWhenValidBearerTokenAndNoCsrfTokenThenOk() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(post("/authenticated") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(post("/authenticated").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void postWhenNoBearerTokenThenCsrfDenies() - throws Exception { - + public void postWhenNoBearerTokenThenCsrfDenies() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - + // @formatter:off this.mvc.perform(post("/authenticated")) .andExpect(status().isForbidden()) - .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); // different from DSL + // different from DSL + .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer")); + // @formatter:on } @Test - public void postWhenExpiredBearerTokenAndNoCsrfThenInvalidToken() - throws Exception { - + public void postWhenExpiredBearerTokenAndNoCsrfThenInvalidToken() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("Expired"); - - this.mvc.perform(post("/authenticated") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(post("/authenticated").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt")); + // @formatter:on } - // -- Resource Server should not create sessions - @Test - public void requestWhenJwtThenSessionIsNotCreated() - throws Exception { - + public void requestWhenJwtThenSessionIsNotCreated() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - MvcResult result = this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + MvcResult result = this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void requestWhenIntrospectionThenSessionIsNotCreated() - throws Exception { - + public void requestWhenIntrospectionThenSessionIsNotCreated() throws Exception { this.spring.configLocations(xml("WebServer"), xml("IntrospectionUri")).autowire(); mockWebServer(json("Active")); - - MvcResult result = this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer token")) + // @formatter:off + MvcResult result = this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void requestWhenNoBearerTokenThenSessionIsCreated() - throws Exception { - + public void requestWhenNoBearerTokenThenSessionIsCreated() throws Exception { this.spring.configLocations(xml("JwkSetUri")).autowire(); - + // @formatter:off MvcResult result = this.mvc.perform(get("/")) .andExpect(status().isUnauthorized()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNotNull(); } @Test - public void requestWhenSessionManagementConfiguredThenUses() - throws Exception { - + public void requestWhenSessionManagementConfiguredThenUses() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("AlwaysSessionCreation")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - MvcResult result = this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + MvcResult result = this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()) .andReturn(); - + // @formatter:on assertThat(result.getRequest().getSession(false)).isNotNull(); } - // -- custom bearer token resolver - @Test public void getWhenCustomBearerTokenResolverThenUses() throws Exception { - this.spring.configLocations(xml("MockBearerTokenResolver"), xml("MockJwtDecoder"), - xml("BearerTokenResolver")).autowire(); - + this.spring.configLocations(xml("MockBearerTokenResolver"), xml("MockJwtDecoder"), xml("BearerTokenResolver")) + .autowire(); JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode("token")).thenReturn(jwt().build()); - + given(decoder.decode("token")).willReturn(TestJwts.jwt().build()); BearerTokenResolver bearerTokenResolver = this.spring.getContext().getBean(BearerTokenResolver.class); - when(bearerTokenResolver.resolve(any(HttpServletRequest.class))) - .thenReturn("token"); - - this.mvc.perform(get("/")) - .andExpect(status().isNotFound()); - + given(bearerTokenResolver.resolve(any(HttpServletRequest.class))).willReturn("token"); + this.mvc.perform(get("/")).andExpect(status().isNotFound()); verify(decoder).decode("token"); verify(bearerTokenResolver).resolve(any(HttpServletRequest.class)); } @@ -509,320 +447,258 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { @Test public void requestWhenBearerTokenResolverAllowsRequestBodyThenEitherHeaderOrRequestBodyIsAccepted() throws Exception { - this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInBody")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenReturn(jwt().build()); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer token")) + given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build()); + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()); - - this.mvc.perform(post("/authenticated") - .param("access_token", "token")) + this.mvc.perform(post("/authenticated").param("access_token", "token")) .andExpect(status().isNotFound()); + // @formatter:on } @Test public void requestWhenBearerTokenResolverAllowsQueryParameterThenEitherHeaderOrQueryParameterIsAccepted() throws Exception { - this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInQuery")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(decoder.decode(anyString())).thenReturn(jwt().build()); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer token")) + given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build()); + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()); - - this.mvc.perform(get("/authenticated") - .param("access_token", "token")) + this.mvc.perform(get("/authenticated").param("access_token", "token")) .andExpect(status().isNotFound()); - + // @formatter:on verify(decoder, times(2)).decode("token"); } @Test public void requestWhenBearerTokenResolverAllowsRequestBodyAndRequestContainsTwoTokensThenInvalidRequest() throws Exception { - this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInBody")).autowire(); - - this.mvc.perform(post("/authenticated") + // @formatter:off + MockHttpServletRequestBuilder request = post("/authenticated") .param("access_token", "token") .header("Authorization", "Bearer token") - .with(csrf())) + .with(csrf()); + this.mvc.perform(request) .andExpect(status().isBadRequest()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("invalid_request"))); + // @formatter:on } @Test public void requestWhenBearerTokenResolverAllowsQueryParameterAndRequestContainsTwoTokensThenInvalidRequest() throws Exception { - this.spring.configLocations(xml("MockJwtDecoder"), xml("AllowBearerTokenInQuery")).autowire(); - - this.mvc.perform(get("/authenticated") + // @formatter:off + MockHttpServletRequestBuilder request = get("/authenticated") .header("Authorization", "Bearer token") - .param("access_token", "token")) + .param("access_token", "token"); + this.mvc.perform(request) .andExpect(status().isBadRequest()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("invalid_request"))); + // @formatter:on } @Test public void getBearerTokenResolverWhenNoResolverSpecifiedThenTheDefaultIsUsed() { - OAuth2ResourceServerBeanDefinitionParser oauth2 = - new OAuth2ResourceServerBeanDefinitionParser - (mock(BeanReference.class), mock(List.class), mock(Map.class), - mock(Map.class), mock(List.class)); - - assertThat(oauth2.getBearerTokenResolver(mock(Element.class))) - .isInstanceOf(RootBeanDefinition.class); + OAuth2ResourceServerBeanDefinitionParser oauth2 = new OAuth2ResourceServerBeanDefinitionParser( + mock(BeanReference.class), mock(List.class), mock(Map.class), mock(Map.class), mock(List.class)); + assertThat(oauth2.getBearerTokenResolver(mock(Element.class))).isInstanceOf(RootBeanDefinition.class); } - // -- custom jwt decoder - @Test - public void requestWhenCustomJwtDecoderThenUsed() - throws Exception { - + public void requestWhenCustomJwtDecoderThenUsed() throws Exception { this.spring.configLocations(xml("MockJwtDecoder"), xml("Jwt")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - - when(decoder.decode(anyString())).thenReturn(jwt().build()); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer token")) + given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build()); + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()); - verify(decoder).decode("token"); } @Test public void configureWhenDecoderAndJwkSetUriThenException() { - assertThatThrownBy(() -> this.spring.configLocations(xml("JwtDecoderAndJwkSetUri")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("JwtDecoderAndJwkSetUri")).autowire()); } - // -- exception handling - @Test - public void requestWhenRealmNameConfiguredThenUsesOnUnauthenticated() - throws Exception { - + public void requestWhenRealmNameConfiguredThenUsesOnUnauthenticated() throws Exception { this.spring.configLocations(xml("MockJwtDecoder"), xml("AuthenticationEntryPoint")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); Mockito.when(decoder.decode(anyString())).thenThrow(JwtException.class); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer invalid_token")) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer invalid_token")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer realm=\"myRealm\""))); + // @formatter:on } @Test - public void requestWhenRealmNameConfiguredThenUsesOnAccessDenied() - throws Exception { - + public void requestWhenRealmNameConfiguredThenUsesOnAccessDenied() throws Exception { this.spring.configLocations(xml("MockJwtDecoder"), xml("AccessDeniedHandler")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(decoder.decode(anyString())).thenReturn(jwt().build()); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer insufficiently_scoped")) + given(decoder.decode(anyString())).willReturn(TestJwts.jwt().build()); + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer insufficiently_scoped")) .andExpect(status().isForbidden()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer realm=\"myRealm\""))); + // @formatter:on } - // -- token validator - @Test - public void requestWhenCustomJwtValidatorFailsThenCorrespondingErrorMessage() - throws Exception { - + public void requestWhenCustomJwtValidatorFailsThenCorrespondingErrorMessage() throws Exception { this.spring.configLocations(xml("MockJwtValidator"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - OAuth2TokenValidator jwtValidator = - this.spring.getContext().getBean(OAuth2TokenValidator.class); - + OAuth2TokenValidator jwtValidator = this.spring.getContext().getBean(OAuth2TokenValidator.class); OAuth2Error error = new OAuth2Error("custom-error", "custom-description", "custom-uri"); - - when(jwtValidator.validate(any(Jwt.class))).thenReturn(OAuth2TokenValidatorResult.failure(error)); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + given(jwtValidator.validate(any(Jwt.class))).willReturn(OAuth2TokenValidatorResult.failure(error)); + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("custom-description"))); + // @formatter:on } @Test - public void requestWhenClockSkewSetThenTimestampWindowRelaxedAccordingly() - throws Exception { - + public void requestWhenClockSkewSetThenTimestampWindowRelaxedAccordingly() throws Exception { this.spring.configLocations(xml("UnexpiredJwtClockSkew"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ExpiresAt4687177990"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void requestWhenClockSkewSetButJwtStillTooLateThenReportsExpired() - throws Exception { - + public void requestWhenClockSkewSetButJwtStillTooLateThenReportsExpired() throws Exception { this.spring.configLocations(xml("ExpiredJwtClockSkew"), xml("Jwt")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ExpiresAt4687177990"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Jwt expired at")); + // @formatter:on } - // -- converter - @Test - public void requestWhenJwtAuthenticationConverterThenUsed() - throws Exception { - - this.spring.configLocations(xml("MockJwtDecoder"), xml("MockJwtAuthenticationConverter"), xml("JwtAuthenticationConverter")).autowire(); - - Converter jwtAuthenticationConverter = - (Converter) this.spring.getContext().getBean("jwtAuthenticationConverter"); - when(jwtAuthenticationConverter.convert(any(Jwt.class))) - .thenReturn(new JwtAuthenticationToken(jwt().build(), Collections.emptyList())); - + public void requestWhenJwtAuthenticationConverterThenUsed() throws Exception { + this.spring.configLocations(xml("MockJwtDecoder"), xml("MockJwtAuthenticationConverter"), + xml("JwtAuthenticationConverter")).autowire(); + Converter jwtAuthenticationConverter = (Converter) this.spring + .getContext().getBean("jwtAuthenticationConverter"); + given(jwtAuthenticationConverter.convert(any(Jwt.class))) + .willReturn(new JwtAuthenticationToken(TestJwts.jwt().build(), Collections.emptyList())); JwtDecoder jwtDecoder = this.spring.getContext().getBean(JwtDecoder.class); - Mockito.when(jwtDecoder.decode(anyString())).thenReturn(jwt().build()); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer token")) + given(jwtDecoder.decode(anyString())).willReturn(TestJwts.jwt().build()); + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()); - + // @formatter:on verify(jwtAuthenticationConverter).convert(any(Jwt.class)); } - // -- single key - @Test - public void requestWhenUsingPublicKeyAndValidTokenThenAuthenticates() - throws Exception { - + public void requestWhenUsingPublicKeyAndValidTokenThenAuthenticates() throws Exception { this.spring.configLocations(xml("SingleKey"), xml("Jwt")).autowire(); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); + // @formatter:on } @Test - public void requestWhenUsingPublicKeyAndSignatureFailsThenReturnsInvalidToken() - throws Exception { - + public void requestWhenUsingPublicKeyAndSignatureFailsThenReturnsInvalidToken() throws Exception { this.spring.configLocations(xml("SingleKey"), xml("Jwt")).autowire(); String token = this.token("WrongSignature"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(invalidTokenHeader("signature")); + // @formatter:on } @Test - public void requestWhenUsingPublicKeyAlgorithmDoesNotMatchThenReturnsInvalidToken() - throws Exception { - + public void requestWhenUsingPublicKeyAlgorithmDoesNotMatchThenReturnsInvalidToken() throws Exception { this.spring.configLocations(xml("SingleKey"), xml("Jwt")).autowire(); String token = this.token("WrongAlgorithm"); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer " + token)) .andExpect(invalidTokenHeader("algorithm")); + // @formatter:on } - // -- opaque - @Test public void getWhenIntrospectingThenOk() throws Exception { this.spring.configLocations(xml("OpaqueTokenRestOperations"), xml("OpaqueToken")).autowire(); mockRestOperations(json("Active")); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer token")) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()); + // @formatter:on } @Test public void getWhenIntrospectionFailsThenUnauthorized() throws Exception { this.spring.configLocations(xml("OpaqueTokenRestOperations"), xml("OpaqueToken")).autowire(); mockRestOperations(json("Inactive")); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer token")) + // @formatter:off + MockHttpServletRequestBuilder request = get("/") + .header("Authorization", "Bearer token"); + this.mvc.perform(request) .andExpect(status().isUnauthorized()) - .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, - containsString("Provided token isn't active"))); + .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("Provided token isn't active"))); + // @formatter:on } @Test public void getWhenIntrospectionLacksScopeThenForbidden() throws Exception { this.spring.configLocations(xml("OpaqueTokenRestOperations"), xml("OpaqueToken")).autowire(); mockRestOperations(json("ActiveNoScopes")); - - this.mvc.perform(get("/requires-read-scope") - .header("Authorization", "Bearer token")) + // @formatter:off + this.mvc.perform(get("/requires-read-scope").header("Authorization", "Bearer token")) .andExpect(status().isForbidden()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("scope"))); + // @formatter:on } @Test public void configureWhenOnlyIntrospectionUrlThenException() { - assertThatCode(() -> this.spring.configLocations(xml("OpaqueTokenHalfConfigured")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("OpaqueTokenHalfConfigured")).autowire()); } @Test public void configureWhenIntrospectorAndIntrospectionUriThenError() { - assertThatCode(() -> this.spring.configLocations(xml("OpaqueTokenAndIntrospectionUri")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("OpaqueTokenAndIntrospectionUri")).autowire()); } - // -- authentication manager resolver - @Test public void getWhenAuthenticationManagerResolverThenUses() throws Exception { this.spring.configLocations(xml("AuthenticationManagerResolver")).autowire(); - - AuthenticationManagerResolver authenticationManagerResolver = - this.spring.getContext().getBean(AuthenticationManagerResolver.class); - when(authenticationManagerResolver.resolve(any(HttpServletRequest.class))) - .thenReturn(authentication -> new JwtAuthenticationToken(jwt().build(), Collections.emptyList())); - - this.mvc.perform(get("/") - .header("Authorization", "Bearer token")) + AuthenticationManagerResolver authenticationManagerResolver = this.spring.getContext() + .getBean(AuthenticationManagerResolver.class); + given(authenticationManagerResolver.resolve(any(HttpServletRequest.class))).willReturn( + (authentication) -> new JwtAuthenticationToken(TestJwts.jwt().build(), Collections.emptyList())); + // @formatter:off + this.mvc.perform(get("/").header("Authorization", "Bearer token")) .andExpect(status().isNotFound()); - + // @formatter:on verify(authenticationManagerResolver).resolve(any(HttpServletRequest.class)); } @Test public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Exception { this.spring.configLocations(xml("WebServer"), xml("MultipleIssuers")).autowire(); - MockWebServer server = this.spring.getContext().getBean(MockWebServer.class); - String metadata = "{\n" - + " \"issuer\": \"%s\", \n" - + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" + String metadata = "{\n" + " \"issuer\": \"%s\", \n" + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" + "}"; String jwkSet = jwkSet(); String issuerOne = server.url("/issuerOne").toString(); @@ -831,144 +707,119 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { String jwtOne = jwtFromIssuer(issuerOne); String jwtTwo = jwtFromIssuer(issuerTwo); String jwtThree = jwtFromIssuer(issuerThree); - mockWebServer(String.format(metadata, issuerOne, issuerOne)); mockWebServer(jwkSet); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer " + jwtOne)) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtOne)) .andExpect(status().isNotFound()); - + // @formatter:on mockWebServer(String.format(metadata, issuerTwo, issuerTwo)); mockWebServer(jwkSet); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer " + jwtTwo)) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtTwo)) .andExpect(status().isNotFound()); - + // @formatter:on mockWebServer(String.format(metadata, issuerThree, issuerThree)); mockWebServer(jwkSet); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer " + jwtThree)) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtThree)) .andExpect(status().isUnauthorized()) .andExpect(invalidTokenHeader("Invalid issuer")); + // @formatter:on } - // -- In combination with other authentication providers - @Test - public void requestWhenBasicAndResourceServerEntryPointsThenBearerTokenPresides() - throws Exception { // different from DSL - + public void requestWhenBasicAndResourceServerEntryPointsThenBearerTokenPresides() throws Exception { + // different from DSL this.spring.configLocations(xml("MockJwtDecoder"), xml("BasicAndResourceServer")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenThrow(JwtException.class); - - this.mvc.perform(get("/authenticated") - .with(httpBasic("some", "user"))) + given(decoder.decode(anyString())).willThrow(JwtException.class); + // @formatter:off + this.mvc.perform(get("/authenticated").with(httpBasic("some", "user"))) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Basic"))); - this.mvc.perform(get("/authenticated")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer"))); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer invalid_token")) + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer invalid_token")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer"))); + // @formatter:on } @Test - public void requestWhenFormLoginAndResourceServerEntryPointsThenSessionCreatedByRequest() - throws Exception { // different from DSL - + public void requestWhenFormLoginAndResourceServerEntryPointsThenSessionCreatedByRequest() throws Exception { + // different from DSL this.spring.configLocations(xml("MockJwtDecoder"), xml("FormAndResourceServer")).autowire(); - JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class); - when(decoder.decode(anyString())).thenThrow(JwtException.class); - - MvcResult result = - this.mvc.perform(get("/authenticated")) - .andExpect(status().isUnauthorized()) - .andReturn(); - + given(decoder.decode(anyString())).willThrow(JwtException.class); + MvcResult result = this.mvc.perform(get("/authenticated")).andExpect(status().isUnauthorized()).andReturn(); assertThat(result.getRequest().getSession(false)).isNotNull(); - - result = - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer token")) - .andExpect(status().isUnauthorized()) - .andReturn(); - + // @formatter:off + result = this.mvc.perform(get("/authenticated").header("Authorization", "Bearer token")) + .andExpect(status().isUnauthorized()) + .andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void getWhenAlsoUsingHttpBasicThenCorrectProviderEngages() - throws Exception { - + public void getWhenAlsoUsingHttpBasicThenCorrectProviderEngages() throws Exception { this.spring.configLocations(xml("JwtRestOperations"), xml("BasicAndResourceServer")).autowire(); mockRestOperations(jwks("Default")); String token = this.token("ValidNoScopes"); - - this.mvc.perform(get("/authenticated") - .header("Authorization", "Bearer " + token)) + // @formatter:off + this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + token)) .andExpect(status().isNotFound()); - - this.mvc.perform(get("/authenticated") - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/authenticated").with(httpBasic("user", "password"))) .andExpect(status().isNotFound()); + // @formatter:on } - // -- Incorrect Configuration - @Test public void configuredWhenMissingJwtAuthenticationProviderThenWiringException() { - assertThatCode(() -> this.spring.configLocations(xml("Jwtless")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("Please select one"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("Jwtless")).autowire()) + .withMessageContaining("Please select one"); } @Test public void configureWhenMissingJwkSetUriThenWiringException() { - assertThatCode(() -> this.spring.configLocations(xml("JwtHalfConfigured")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("Please specify either"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("JwtHalfConfigured")).autowire()) + .withMessageContaining("Please specify either"); } @Test public void configureWhenUsingBothAuthenticationManagerResolverAndJwtThenException() { - assertThatCode(() -> this.spring.configLocations(xml("AuthenticationManagerResolverPlusOtherConfig")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("authentication-manager-resolver-ref"); + assertThatExceptionOfType(BeanDefinitionParsingException.class).isThrownBy( + () -> this.spring.configLocations(xml("AuthenticationManagerResolverPlusOtherConfig")).autowire()) + .withMessageContaining("authentication-manager-resolver-ref"); } @Test public void validateConfigurationWhenMoreThanOneResourceServerModeThenError() { - OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser - (null, null, null, null, null); + OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser(null, null, null, + null, null); Element element = mock(Element.class); - when(element.hasAttribute(AUTHENTICATION_MANAGER_RESOLVER_REF)).thenReturn(true); + given(element.hasAttribute(OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF)) + .willReturn(true); Element child = mock(Element.class); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); - parser.validateConfiguration(element, child, null, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); reset(pc.getReaderContext()); - parser.validateConfiguration(element, null, child, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); } @Test public void validateConfigurationWhenNoResourceServerModeThenError() { - OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser - (null, null, null, null, null); + OAuth2ResourceServerBeanDefinitionParser parser = new OAuth2ResourceServerBeanDefinitionParser(null, null, null, + null, null); Element element = mock(Element.class); - when(element.hasAttribute(AUTHENTICATION_MANAGER_RESOLVER_REF)).thenReturn(false); + given(element.hasAttribute(OAuth2ResourceServerBeanDefinitionParser.AUTHENTICATION_MANAGER_RESOLVER_REF)) + .willReturn(false); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); parser.validateConfiguration(element, null, null, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); @@ -978,8 +829,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void validateConfigurationWhenBothJwtAttributesThenError() { JwtBeanDefinitionParser parser = new JwtBeanDefinitionParser(); Element element = mock(Element.class); - when(element.hasAttribute(JWK_SET_URI)).thenReturn(true); - when(element.hasAttribute(DECODER_REF)).thenReturn(true); + given(element.hasAttribute(JwtBeanDefinitionParser.JWK_SET_URI)).willReturn(true); + given(element.hasAttribute(JwtBeanDefinitionParser.DECODER_REF)).willReturn(true); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); parser.validateConfiguration(element, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); @@ -989,8 +840,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void validateConfigurationWhenNoJwtAttributesThenError() { JwtBeanDefinitionParser parser = new JwtBeanDefinitionParser(); Element element = mock(Element.class); - when(element.hasAttribute(JWK_SET_URI)).thenReturn(false); - when(element.hasAttribute(DECODER_REF)).thenReturn(false); + given(element.hasAttribute(JwtBeanDefinitionParser.JWK_SET_URI)).willReturn(false); + given(element.hasAttribute(JwtBeanDefinitionParser.DECODER_REF)).willReturn(false); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); parser.validateConfiguration(element, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); @@ -1000,8 +851,8 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void validateConfigurationWhenBothOpaqueTokenModesThenError() { OpaqueTokenBeanDefinitionParser parser = new OpaqueTokenBeanDefinitionParser(); Element element = mock(Element.class); - when(element.hasAttribute(INTROSPECTION_URI)).thenReturn(true); - when(element.hasAttribute(INTROSPECTOR_REF)).thenReturn(true); + given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTION_URI)).willReturn(true); + given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTOR_REF)).willReturn(true); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); parser.validateConfiguration(element, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); @@ -1011,16 +862,95 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void validateConfigurationWhenNoOpaqueTokenModeThenError() { OpaqueTokenBeanDefinitionParser parser = new OpaqueTokenBeanDefinitionParser(); Element element = mock(Element.class); - when(element.hasAttribute(INTROSPECTION_URI)).thenReturn(false); - when(element.hasAttribute(INTROSPECTOR_REF)).thenReturn(false); + given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTION_URI)).willReturn(false); + given(element.hasAttribute(OpaqueTokenBeanDefinitionParser.INTROSPECTOR_REF)).willReturn(false); ParserContext pc = new ParserContext(mock(XmlReaderContext.class), mock(BeanDefinitionParserDelegate.class)); parser.validateConfiguration(element, pc); verify(pc.getReaderContext()).error(anyString(), eq(element)); } + private static ResultMatcher invalidRequestHeader(String message) { + return header().string(HttpHeaders.WWW_AUTHENTICATE, + AllOf.allOf(new StringStartsWith("Bearer " + "error=\"invalid_request\", " + "error_description=\""), + new StringContains(message), + new StringEndsWith(", " + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""))); + } + + private static ResultMatcher invalidTokenHeader(String message) { + return header().string(HttpHeaders.WWW_AUTHENTICATE, + AllOf.allOf(new StringStartsWith("Bearer " + "error=\"invalid_token\", " + "error_description=\""), + new StringContains(message), + new StringEndsWith(", " + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""))); + } + + private static ResultMatcher insufficientScopeHeader() { + return header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer " + "error=\"insufficient_scope\"" + + ", error_description=\"The request requires higher privileges than provided by the access token.\"" + + ", error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""); + } + + private String jwkSet() { + return new JWKSet(new RSAKey.Builder(TestKeys.DEFAULT_PUBLIC_KEY).keyID("1").build()).toString(); + } + + private String jwtFromIssuer(String issuer) throws Exception { + Map claims = new HashMap<>(); + claims.put(JwtClaimNames.ISS, issuer); + claims.put(JwtClaimNames.SUB, "test-subject"); + claims.put("scope", "message:read"); + JWSObject jws = new JWSObject(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(), + new Payload(new JSONObject(claims))); + jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); + return jws.serialize(); + } + + private void mockWebServer(String response) { + this.web.enqueue(new MockResponse().setResponseCode(200) + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(response)); + } + + private void mockRestOperations(String response) { + RestOperations rest = this.spring.getContext().getBean(RestOperations.class); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + ResponseEntity entity = new ResponseEntity<>(response, headers, HttpStatus.OK); + given(rest.exchange(any(RequestEntity.class), eq(String.class))).willReturn(entity); + } + + private String json(String name) throws IOException { + return resource(name + ".json"); + } + + private String jwks(String name) throws IOException { + return resource(name + ".jwks"); + } + + private String token(String name) throws IOException { + return resource(name + ".token"); + } + + private String resource(String suffix) throws IOException { + String name = this.getClass().getSimpleName() + "-" + suffix; + ClassPathResource resource = new ClassPathResource(name, this.getClass()); + try (BufferedReader reader = new BufferedReader(new FileReader(resource.getFile()))) { + return reader.lines().collect(Collectors.joining()); + } + } + + private T bean(Class beanClass) { + return this.spring.getContext().getBean(beanClass); + } + + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; + } + static class JwtDecoderFactoryBean implements FactoryBean { + private RestOperations rest; + private RSAPublicKey key; + private OAuth2TokenValidator jwtValidator; @Override @@ -1028,9 +958,9 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { NimbusJwtDecoder decoder; if (this.key != null) { decoder = NimbusJwtDecoder.withPublicKey(this.key).build(); - } else { - decoder = NimbusJwtDecoder.withJwkSetUri("https://idp.example.org") - .restOperations(this.rest).build(); + } + else { + decoder = NimbusJwtDecoder.withJwkSetUri("https://idp.example.org").restOperations(this.rest).build(); } if (this.jwtValidator != null) { decoder.setJwtValidator(this.jwtValidator); @@ -1054,9 +984,11 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void setRest(RestOperations rest) { this.rest = rest; } + } static class OpaqueTokenIntrospectorFactoryBean implements FactoryBean { + private RestOperations rest; @Override @@ -1072,9 +1004,11 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void setRest(RestOperations rest) { this.rest = rest; } + } static class MockWebServerFactoryBean implements FactoryBean, DisposableBean { + private final MockWebServer web = new MockWebServer(); @Override @@ -1091,10 +1025,10 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public Class getObjectType() { return MockWebServer.class; } + } - static class MockWebServerPropertiesFactoryBean - implements FactoryBean, DisposableBean { + static class MockWebServerPropertiesFactoryBean implements FactoryBean, DisposableBean { MockWebServer web; @@ -1121,10 +1055,10 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void destroy() throws Exception { this.web.shutdown(); } + } - static class ClockFactoryBean - implements FactoryBean { + static class ClockFactoryBean implements FactoryBean { Clock clock; @@ -1141,99 +1075,7 @@ public class OAuth2ResourceServerBeanDefinitionParserTests { public void setMillis(long millis) { this.clock = Clock.fixed(Instant.ofEpochMilli(millis), ZoneId.systemDefault()); } + } - private static ResultMatcher invalidRequestHeader(String message) { - return header().string(HttpHeaders.WWW_AUTHENTICATE, - AllOf.allOf( - new StringStartsWith("Bearer " + - "error=\"invalid_request\", " + - "error_description=\""), - new StringContains(message), - new StringEndsWith(", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"") - ) - ); - } - - private static ResultMatcher invalidTokenHeader(String message) { - return header().string(HttpHeaders.WWW_AUTHENTICATE, - AllOf.allOf( - new StringStartsWith("Bearer " + - "error=\"invalid_token\", " + - "error_description=\""), - new StringContains(message), - new StringEndsWith(", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"") - ) - ); - } - - private static ResultMatcher insufficientScopeHeader() { - return header().string(HttpHeaders.WWW_AUTHENTICATE, "Bearer " + - "error=\"insufficient_scope\"" + - ", error_description=\"The request requires higher privileges than provided by the access token.\"" + - ", error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""); - } - - private String jwkSet() { - return new JWKSet(new RSAKey.Builder(TestKeys.DEFAULT_PUBLIC_KEY) - .keyID("1").build()).toString(); - } - - private String jwtFromIssuer(String issuer) throws Exception { - Map claims = new HashMap<>(); - claims.put(ISS, issuer); - claims.put(SUB, "test-subject"); - claims.put("scope", "message:read"); - JWSObject jws = new JWSObject( - new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("1").build(), - new Payload(new JSONObject(claims))); - jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); - return jws.serialize(); - } - - private void mockWebServer(String response) { - this.web.enqueue(new MockResponse() - .setResponseCode(200) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(response)); - } - - private void mockRestOperations(String response) { - RestOperations rest = this.spring.getContext().getBean(RestOperations.class); - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - ResponseEntity entity = new ResponseEntity<>(response, headers, HttpStatus.OK); - Mockito.when(rest.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(entity); - } - - private String json(String name) throws IOException { - return resource(name + ".json"); - } - - private String jwks(String name) throws IOException { - return resource(name + ".jwks"); - } - - private String token(String name) throws IOException { - return resource(name + ".token"); - } - - private String resource(String suffix) throws IOException { - String name = this.getClass().getSimpleName() + "-" + suffix; - ClassPathResource resource = new ClassPathResource(name, this.getClass()); - try ( BufferedReader reader = new BufferedReader(new FileReader(resource.getFile())) ) { - return reader.lines().collect(Collectors.joining()); - } - } - - private T bean(Class beanClass) { - return this.spring.getContext().getBean(beanClass); - } - - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/OpenIDConfigTests.java b/config/src/test/java/org/springframework/security/config/http/OpenIDConfigTests.java index b13678789c..5d2c13657d 100644 --- a/config/src/test/java/org/springframework/security/config/http/OpenIDConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OpenIDConfigTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.HashSet; import java.util.Set; + import javax.servlet.Filter; import javax.servlet.http.HttpServletRequest; @@ -25,6 +27,7 @@ import okhttp3.mockwebserver.MockWebServer; import org.junit.Rule; import org.junit.Test; import org.openid4java.consumer.ConsumerManager; +import org.openid4java.discovery.yadis.YadisResolver; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.parsing.BeanDefinitionParsingException; @@ -37,17 +40,17 @@ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.openid4java.discovery.yadis.YadisResolver.YADIS_XRDS_LOCATION; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -59,8 +62,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Luke Taylor */ public class OpenIDConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/OpenIDConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OpenIDConfigTests"; @Autowired MockMvc mvc; @@ -69,114 +72,92 @@ public class OpenIDConfigTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void requestWhenOpenIDAndFormLoginBothConfiguredThenRedirectsToGeneratedLoginPage() - throws Exception { - + public void requestWhenOpenIDAndFormLoginBothConfiguredThenRedirectsToGeneratedLoginPage() throws Exception { this.spring.configLocations(this.xml("WithFormLogin")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); - + // @formatter:on assertThat(getFilter(DefaultLoginPageGeneratingFilter.class)).isNotNull(); } @Test - public void requestWhenOpenIDAndFormLoginWithFormLoginPageConfiguredThenFormLoginPageWins() - throws Exception { - + public void requestWhenOpenIDAndFormLoginWithFormLoginPageConfiguredThenFormLoginPageWins() throws Exception { this.spring.configLocations(this.xml("WithFormLoginPage")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/form-page")); + // @formatter:on } @Test - public void requestWhenOpenIDAndFormLoginWithOpenIDLoginPageConfiguredThenOpenIDLoginPageWins() - throws Exception { - + public void requestWhenOpenIDAndFormLoginWithOpenIDLoginPageConfiguredThenOpenIDLoginPageWins() throws Exception { this.spring.configLocations(this.xml("WithOpenIDLoginPageAndFormLogin")).autowire(); - + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/openid-page")); + // @formatter:on } @Test public void configureWhenOpenIDAndFormLoginBothConfigureLoginPagesThenWiringException() { - - assertThatCode(() -> this.spring.configLocations(this.xml("WithFormLoginAndOpenIDLoginPages")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(this.xml("WithFormLoginAndOpenIDLoginPages")).autowire()); } @Test - public void requestWhenOpenIDAndRememberMeConfiguredThenRememberMePassedToIdp() - throws Exception { - + public void requestWhenOpenIDAndRememberMeConfiguredThenRememberMePassedToIdp() throws Exception { this.spring.configLocations(this.xml("WithRememberMe")).autowire(); - OpenIDAuthenticationFilter openIDFilter = getFilter(OpenIDAuthenticationFilter.class); - String openIdEndpointUrl = "https://testopenid.com?openid.return_to="; Set returnToUrlParameters = new HashSet<>(); returnToUrlParameters.add(AbstractRememberMeServices.DEFAULT_PARAMETER); openIDFilter.setReturnToUrlParameters(returnToUrlParameters); - OpenIDConsumer consumer = mock(OpenIDConsumer.class); - when(consumer.beginConsumption(any(HttpServletRequest.class), anyString(), anyString(), anyString())) - .then(invocation -> openIdEndpointUrl + invocation.getArgument(2)); + given(consumer.beginConsumption(any(HttpServletRequest.class), anyString(), anyString(), anyString())) + .will((invocation) -> openIdEndpointUrl + invocation.getArgument(2)); openIDFilter.setConsumer(consumer); - String expectedReturnTo = new StringBuilder("http://localhost/login/openid").append("?") - .append(AbstractRememberMeServices.DEFAULT_PARAMETER) - .append("=").append("on").toString(); - + .append(AbstractRememberMeServices.DEFAULT_PARAMETER).append("=").append("on").toString(); + // @formatter:off this.mvc.perform(get("/")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); - this.mvc.perform(get("/login")) .andExpect(status().isOk()) .andExpect(content().string(containsString(AbstractRememberMeServices.DEFAULT_PARAMETER))); - - this.mvc.perform(get("/login/openid") + MockHttpServletRequestBuilder openidLogin = get("/login/openid") .param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, "https://ww1.openid.com") - .param(AbstractRememberMeServices.DEFAULT_PARAMETER, "on")) + .param(AbstractRememberMeServices.DEFAULT_PARAMETER, "on"); + this.mvc.perform(openidLogin) .andExpect(status().isFound()) .andExpect(redirectedUrl(openIdEndpointUrl + expectedReturnTo)); + // @formatter:on } @Test - public void requestWhenAttributeExchangeConfiguredThenFetchAttributesPassedToIdp() - throws Exception { - + public void requestWhenAttributeExchangeConfiguredThenFetchAttributesPassedToIdp() throws Exception { this.spring.configLocations(this.xml("WithOpenIDAttributes")).autowire(); - OpenIDAuthenticationFilter openIDFilter = getFilter(OpenIDAuthenticationFilter.class); OpenID4JavaConsumer consumer = getFieldValue(openIDFilter, "consumer"); ConsumerManager manager = getFieldValue(consumer, "consumerManager"); manager.setMaxAssocAttempts(0); - - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { String endpoint = server.url("/").toString(); - + server.enqueue(new MockResponse().addHeader(YadisResolver.YADIS_XRDS_LOCATION, endpoint)); server.enqueue(new MockResponse() - .addHeader(YADIS_XRDS_LOCATION, endpoint)); - server.enqueue(new MockResponse() - .setBody(String.format( - "%s", - endpoint))); - - this.mvc.perform(get("/login/openid") - .param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, endpoint)) + .setBody(String.format("%s", endpoint))); + this.mvc.perform( + get("/login/openid").param(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, endpoint)) .andExpect(status().isFound()) - .andExpect(result -> result.getResponse().getRedirectedUrl().endsWith( - "openid.ext1.type.nickname=http%3A%2F%2Fschema.openid.net%2FnamePerson%2Ffriendly&" + - "openid.ext1.if_available=nickname&" + - "openid.ext1.type.email=http%3A%2F%2Fschema.openid.net%2Fcontact%2Femail&" + - "openid.ext1.required=email&" + - "openid.ext1.count.email=2")); + .andExpect((result) -> result.getResponse().getRedirectedUrl().endsWith( + "openid.ext1.type.nickname=http%3A%2F%2Fschema.openid.net%2FnamePerson%2Ffriendly&" + + "openid.ext1.if_available=nickname&" + + "openid.ext1.type.email=http%3A%2F%2Fschema.openid.net%2Fcontact%2Femail&" + + "openid.ext1.required=email&" + "openid.ext1.count.email=2")); } } @@ -186,30 +167,18 @@ public class OpenIDConfigTests { @Test public void requestWhenLoginPageConfiguredWithPhraseLoginThenRedirectsOnlyToUserGeneratedLoginPage() throws Exception { - this.spring.configLocations(this.xml("Sec2919")).autowire(); - assertThat(getFilter(DefaultLoginPageGeneratingFilter.class)).isNull(); - + // @formatter:off this.mvc.perform(get("/login")) .andExpect(status().isOk()) .andExpect(content().string("a custom login page")); - } - - @RestController - static class CustomLoginController { - @GetMapping("/login") - public String custom() { - return "a custom login page"; - } + // @formatter:on } private T getFilter(Class clazz) { FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class); - return (T) filterChain.getFilters("/").stream() - .filter(clazz::isInstance) - .findFirst() - .orElse(null); + return (T) filterChain.getFilters("/").stream().filter(clazz::isInstance).findFirst().orElse(null); } private String xml(String configName) { @@ -219,4 +188,15 @@ public class OpenIDConfigTests { private static T getFieldValue(Object bean, String fieldName) throws IllegalAccessException { return (T) FieldUtils.getFieldValue(bean, fieldName); } + + @RestController + static class CustomLoginController { + + @GetMapping("/login") + String custom() { + return "a custom login page"; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/PlaceHolderAndELConfigTests.java b/config/src/test/java/org/springframework/security/config/http/PlaceHolderAndELConfigTests.java index e887545055..10d289d8e4 100644 --- a/config/src/test/java/org/springframework/security/config/http/PlaceHolderAndELConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/PlaceHolderAndELConfigTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; @@ -40,8 +42,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @SecurityTestExecutionListeners public class PlaceHolderAndELConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/PlaceHolderAndELConfigTests"; + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/PlaceHolderAndELConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -50,151 +51,122 @@ public class PlaceHolderAndELConfigTests { MockMvc mvc; @Test - public void getWhenUsingPlaceholderThenUnsecuredPatternCorrectlyConfigured() - throws Exception { - + public void getWhenUsingPlaceholderThenUnsecuredPatternCorrectlyConfigured() throws Exception { System.setProperty("pattern.nofilters", "/unsecured"); - this.spring.configLocations(this.xml("UnsecuredPattern")).autowire(); - + // @formatter:off this.mvc.perform(get("/unsecured")) .andExpect(status().isOk()); + // @formatter:on } /** * SEC-1201 */ @Test - public void loginWhenUsingPlaceholderThenInterceptUrlsAndFormLoginWorks() - throws Exception { - + public void loginWhenUsingPlaceholderThenInterceptUrlsAndFormLoginWorks() throws Exception { System.setProperty("secure.Url", "/secured"); System.setProperty("secure.role", "ROLE_NUNYA"); System.setProperty("login.page", "/loginPage"); System.setProperty("default.target", "/defaultTarget"); System.setProperty("auth.failure", "/authFailure"); - this.spring.configLocations(this.xml("InterceptUrlAndFormLogin")).autowire(); - // login-page setting - + // @formatter:off this.mvc.perform(get("/secured")) .andExpect(redirectedUrl("http://localhost/loginPage")); - // login-processing-url setting // default-target-url setting - - this.mvc.perform(post("/loginPage") - .param("username", "user") - .param("password", "password")) + this.mvc.perform(post("/loginPage").param("username", "user").param("password", "password")) .andExpect(redirectedUrl("/defaultTarget")); - // authentication-failure-url setting - - this.mvc.perform(post("/loginPage") - .param("username", "user") - .param("password", "wrong")) + this.mvc.perform(post("/loginPage").param("username", "user").param("password", "wrong")) .andExpect(redirectedUrl("/authFailure")); + // @formatter:on } /** * SEC-1309 */ @Test - public void loginWhenUsingSpELThenInterceptUrlsAndFormLoginWorks() - throws Exception { - + public void loginWhenUsingSpELThenInterceptUrlsAndFormLoginWorks() throws Exception { System.setProperty("secure.url", "/secured"); System.setProperty("secure.role", "ROLE_NUNYA"); System.setProperty("login.page", "/loginPage"); System.setProperty("default.target", "/defaultTarget"); System.setProperty("auth.failure", "/authFailure"); - - this.spring.configLocations( - this.xml("InterceptUrlAndFormLoginWithSpEL")).autowire(); - + this.spring.configLocations(this.xml("InterceptUrlAndFormLoginWithSpEL")).autowire(); // login-page setting - + // @formatter:off this.mvc.perform(get("/secured")) .andExpect(redirectedUrl("http://localhost/loginPage")); - // login-processing-url setting // default-target-url setting - - this.mvc.perform(post("/loginPage") - .param("username", "user") - .param("password", "password")) + this.mvc.perform(post("/loginPage").param("username", "user").param("password", "password")) .andExpect(redirectedUrl("/defaultTarget")); - // authentication-failure-url setting - - this.mvc.perform(post("/loginPage") - .param("username", "user") - .param("password", "wrong")) + this.mvc.perform(post("/loginPage").param("username", "user").param("password", "wrong")) .andExpect(redirectedUrl("/authFailure")); - + // @formatter:on } @Test @WithMockUser - public void requestWhenUsingPlaceholderOrSpELThenPortMapperWorks() - throws Exception { - + public void requestWhenUsingPlaceholderOrSpELThenPortMapperWorks() throws Exception { System.setProperty("http", "9080"); System.setProperty("https", "9443"); - this.spring.configLocations(this.xml("PortMapping")).autowire(); - + // @formatter:off this.mvc.perform(get("http://localhost:9080/secured")) .andExpect(status().isFound()) .andExpect(redirectedUrl("https://localhost:9443/secured")); - this.mvc.perform(get("https://localhost:9443/unsecured")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost:9080/unsecured")); + // @formatter:on } @Test @WithMockUser - public void requestWhenUsingPlaceholderThenRequiresChannelWorks() - throws Exception { - + public void requestWhenUsingPlaceholderThenRequiresChannelWorks() throws Exception { System.setProperty("secure.url", "/secured"); System.setProperty("required.channel", "https"); - this.spring.configLocations(this.xml("RequiresChannel")).autowire(); - + // @formatter:off this.mvc.perform(get("http://localhost/secured")) .andExpect(status().isFound()) .andExpect(redirectedUrl("https://localhost/secured")); + // @formatter:on } @Test @WithMockUser - public void requestWhenUsingPlaceholderThenAccessDeniedPageWorks() - throws Exception { - + public void requestWhenUsingPlaceholderThenAccessDeniedPageWorks() throws Exception { System.setProperty("accessDenied", "/go-away"); - this.spring.configLocations(this.xml("AccessDeniedPage")).autowire(); - + // @formatter:off this.mvc.perform(get("/secured")) .andExpect(forwardedUrl("/go-away")); + // @formatter:on } @Test @WithMockUser - public void requestWhenUsingSpELThenAccessDeniedPageWorks() - throws Exception { - + public void requestWhenUsingSpELThenAccessDeniedPageWorks() throws Exception { this.spring.configLocations(this.xml("AccessDeniedPageWithSpEL")).autowire(); - + // @formatter:off this.mvc.perform(get("/secured")) .andExpect(forwardedUrl("/go-away")); + // @formatter:on + } + + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } @RestController static class SimpleController { + @GetMapping("/unsecured") String unsecured() { return "unsecured"; @@ -204,9 +176,7 @@ public class PlaceHolderAndELConfigTests { String secured() { return "secured"; } + } - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/RememberMeConfigTests.java b/config/src/test/java/org/springframework/security/config/http/RememberMeConfigTests.java index 0854316af8..2774114e6d 100644 --- a/config/src/test/java/org/springframework/security/config/http/RememberMeConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/RememberMeConfigTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.Collections; + import javax.servlet.http.Cookie; import org.junit.Rule; @@ -29,6 +31,8 @@ import org.springframework.security.TestDataSource; import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetailsService; +import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; +import org.springframework.security.web.authentication.rememberme.JdbcTokenRepositoryImpl; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultActions; @@ -37,14 +41,11 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; -import static org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices.DEFAULT_PARAMETER; -import static org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY; -import static org.springframework.security.web.authentication.rememberme.JdbcTokenRepositoryImpl.CREATE_TABLE_SQL; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie; @@ -52,14 +53,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Luke Taylor * @author Rob Winch * @author Oliver Becker */ public class RememberMeConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/RememberMeConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/RememberMeConfigTests"; @Autowired MockMvc mvc; @@ -68,205 +68,170 @@ public class RememberMeConfigTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void requestWithRememberMeWhenUsingCustomTokenRepositoryThenAutomaticallyReauthenticates() - throws Exception { - - this.spring.configLocations(this.xml("WithTokenRepository")).autowire(); - - MvcResult result = this.rememberAuthentication("user", "password") - .andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) + public void requestWithRememberMeWhenUsingCustomTokenRepositoryThenAutomaticallyReauthenticates() throws Exception { + this.spring.configLocations(xml("WithTokenRepository")).autowire(); + // @formatter:off + MvcResult result = rememberAuthentication("user", "password") + .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) .andReturn(); - + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); - + // @formatter:on JdbcTemplate template = this.spring.getContext().getBean(JdbcTemplate.class); int count = template.queryForObject("select count(*) from persistent_logins", int.class); assertThat(count).isEqualTo(1); } @Test - public void requestWithRememberMeWhenUsingCustomDataSourceThenAutomaticallyReauthenticates() - throws Exception { - - this.spring.configLocations(this.xml("WithDataSource")).autowire(); - + public void requestWithRememberMeWhenUsingCustomDataSourceThenAutomaticallyReauthenticates() throws Exception { + this.spring.configLocations(xml("WithDataSource")).autowire(); TestDataSource dataSource = this.spring.getContext().getBean(TestDataSource.class); JdbcTemplate template = new JdbcTemplate(dataSource); - template.execute(CREATE_TABLE_SQL); - - MvcResult result = this.rememberAuthentication("user", "password") - .andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) + template.execute(JdbcTokenRepositoryImpl.CREATE_TABLE_SQL); + // @formatter:off + MvcResult result = rememberAuthentication("user", "password") + .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) .andReturn(); - + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); - + // @formatter:on int count = template.queryForObject("select count(*) from persistent_logins", int.class); assertThat(count).isEqualTo(1); } @Test - public void requestWithRememberMeWhenUsingAuthenticationSuccessHandlerThenInvokesHandler() - throws Exception { - - this.spring.configLocations(this.xml("WithAuthenticationSuccessHandler")).autowire(); - + public void requestWithRememberMeWhenUsingAuthenticationSuccessHandlerThenInvokesHandler() throws Exception { + this.spring.configLocations(xml("WithAuthenticationSuccessHandler")).autowire(); TestDataSource dataSource = this.spring.getContext().getBean(TestDataSource.class); JdbcTemplate template = new JdbcTemplate(dataSource); - template.execute(CREATE_TABLE_SQL); - - MvcResult result = this.rememberAuthentication("user", "password") - .andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) + template.execute(JdbcTokenRepositoryImpl.CREATE_TABLE_SQL); + // @formatter:off + MvcResult result = rememberAuthentication("user", "password") + .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) .andReturn(); - + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(redirectedUrl("/target")); - + // @formatter:on int count = template.queryForObject("select count(*) from persistent_logins", int.class); assertThat(count).isEqualTo(1); } @Test - public void requestWithRememberMeWhenUsingCustomRememberMeServicesThenAuthenticates() - throws Exception { + public void requestWithRememberMeWhenUsingCustomRememberMeServicesThenAuthenticates() throws Exception { // SEC-1281 - using key with external services - this.spring.configLocations(this.xml("WithServicesRef")).autowire(); - - MvcResult result = this.rememberAuthentication("user", "password") - .andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) - .andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 5000)) + this.spring.configLocations(xml("WithServicesRef")).autowire(); + // @formatter:off + MvcResult result = rememberAuthentication("user", "password") + .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)) + .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 5000)) .andReturn(); - + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); - // SEC-909 - this.mvc.perform(post("/logout") - .cookie(cookie) - .with(csrf())) - .andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0)) + this.mvc.perform(post("/logout").cookie(cookie).with(csrf())) + .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0)) .andReturn(); + // @formatter:on } @Test - public void logoutWhenUsingRememberMeDefaultsThenCookieIsCancelled() - throws Exception { - - this.spring.configLocations(this.xml("DefaultConfig")).autowire(); - - MvcResult result = this.rememberAuthentication("user", "password").andReturn(); - + public void logoutWhenUsingRememberMeDefaultsThenCookieIsCancelled() throws Exception { + this.spring.configLocations(xml("DefaultConfig")).autowire(); + MvcResult result = rememberAuthentication("user", "password").andReturn(); Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(post("/logout") - .cookie(cookie) - .with(csrf())) - .andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0)); + // @formatter:off + this.mvc.perform(post("/logout").cookie(cookie).with(csrf())) + .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 0)); + // @formatter:on } @Test public void requestWithRememberMeWhenTokenValidityIsConfiguredThenCookieReflectsCorrectExpiration() throws Exception { - - this.spring.configLocations(this.xml("TokenValidity")).autowire(); - - MvcResult result = this.rememberAuthentication("user", "password") - .andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 10000)) + this.spring.configLocations(xml("TokenValidity")).autowire(); + // @formatter:off + MvcResult result = rememberAuthentication("user", "password") + .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 10000)) .andReturn(); - + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); + // @formatter:on } @Test - public void requestWithRememberMeWhenTokenValidityIsNegativeThenCookieReflectsCorrectExpiration() - throws Exception { - - this.spring.configLocations(this.xml("NegativeTokenValidity")).autowire(); - - this.rememberAuthentication("user", "password") - .andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, -1)); + public void requestWithRememberMeWhenTokenValidityIsNegativeThenCookieReflectsCorrectExpiration() throws Exception { + this.spring.configLocations(xml("NegativeTokenValidity")).autowire(); + // @formatter:off + rememberAuthentication("user", "password") + .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, -1)); + // @formatter:on } - @Test public void configureWhenUsingDataSourceAndANegativeTokenValidityThenThrowsWiringException() { - assertThatCode(() -> this.spring.configLocations(this.xml("NegativeTokenValidityWithDataSource")).autowire()) - .isInstanceOf(FatalBeanException.class); + assertThatExceptionOfType(FatalBeanException.class) + .isThrownBy(() -> this.spring.configLocations(xml("NegativeTokenValidityWithDataSource")).autowire()); } @Test public void requestWithRememberMeWhenTokenValidityIsResolvedByPropertyPlaceholderThenCookieReflectsCorrectExpiration() throws Exception { - - this.spring.configLocations(this.xml("Sec2165")).autowire(); - - this.rememberAuthentication("user", "password") - .andExpect(cookie().maxAge(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 30)); + this.spring.configLocations(xml("Sec2165")).autowire(); + rememberAuthentication("user", "password") + .andExpect(cookie().maxAge(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, 30)); } @Test - public void requestWithRememberMeWhenUseSecureCookieIsTrueThenCookieIsSecure() - throws Exception { - - this.spring.configLocations(this.xml("SecureCookie")).autowire(); - - this.rememberAuthentication("user", "password") - .andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, true)); + public void requestWithRememberMeWhenUseSecureCookieIsTrueThenCookieIsSecure() throws Exception { + this.spring.configLocations(xml("SecureCookie")).autowire(); + rememberAuthentication("user", "password") + .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, true)); } /** * SEC-1827 */ @Test - public void requestWithRememberMeWhenUseSecureCookieIsFalseThenCookieIsNotSecure() - throws Exception { - - this.spring.configLocations(this.xml("Sec1827")).autowire(); - - this.rememberAuthentication("user", "password") - .andExpect(cookie().secure(SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)); + public void requestWithRememberMeWhenUseSecureCookieIsFalseThenCookieIsNotSecure() throws Exception { + this.spring.configLocations(xml("Sec1827")).autowire(); + rememberAuthentication("user", "password") + .andExpect(cookie().secure(AbstractRememberMeServices.SPRING_SECURITY_REMEMBER_ME_COOKIE_KEY, false)); } @Test public void configureWhenUsingPersistentTokenRepositoryAndANegativeTokenValidityThenThrowsWiringException() { - assertThatCode(() -> this.spring.configLocations(this.xml("NegativeTokenValidityWithPersistentRepository")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class).isThrownBy( + () -> this.spring.configLocations(xml("NegativeTokenValidityWithPersistentRepository")).autowire()); } @Test public void requestWithRememberMeWhenUsingCustomUserDetailsServiceThenInvokesThisUserDetailsService() throws Exception { - this.spring.configLocations(this.xml("WithUserDetailsService")).autowire(); - + this.spring.configLocations(xml("WithUserDetailsService")).autowire(); UserDetailsService userDetailsService = this.spring.getContext().getBean(UserDetailsService.class); - when(userDetailsService.loadUserByUsername("user")).thenAnswer((invocation) -> - new User("user", "{noop}password", Collections.emptyList())); - - MvcResult result = this.rememberAuthentication("user", "password").andReturn(); - + given(userDetailsService.loadUserByUsername("user")) + .willAnswer((invocation) -> new User("user", "{noop}password", Collections.emptyList())); + MvcResult result = rememberAuthentication("user", "password").andReturn(); Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); - + // @formatter:on verify(userDetailsService, atLeastOnce()).loadUserByUsername("user"); } @@ -274,65 +239,57 @@ public class RememberMeConfigTests { * SEC-742 */ @Test - public void requestWithRememberMeWhenExcludingBasicAuthenticationFilterThenStillReauthenticates() - throws Exception { - - this.spring.configLocations(this.xml("Sec742")).autowire(); - - MvcResult result = - this.mvc.perform(login("user", "password") - .param("remember-me", "true") - .with(csrf())) - .andExpect(redirectedUrl("/messageList.html")) - .andReturn(); - + public void requestWithRememberMeWhenExcludingBasicAuthenticationFilterThenStillReauthenticates() throws Exception { + this.spring.configLocations(xml("Sec742")).autowire(); + // @formatter:off + MvcResult result = this.mvc.perform(login("user", "password").param("remember-me", "true").with(csrf())) + .andExpect(redirectedUrl("/messageList.html")) + .andReturn(); + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); + // @formatter:on } /** * SEC-2119 */ @Test - public void requestWithRememberMeWhenUsingCustomRememberMeParameterThenReauthenticates() - throws Exception { - - this.spring.configLocations(this.xml("WithRememberMeParameter")).autowire(); - - MvcResult result = - this.mvc.perform(login("user", "password") - .param("custom-remember-me-parameter", "true") - .with(csrf())) - .andExpect(redirectedUrl("/")) - .andReturn(); - + public void requestWithRememberMeWhenUsingCustomRememberMeParameterThenReauthenticates() throws Exception { + this.spring.configLocations(xml("WithRememberMeParameter")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = login("user", "password") + .param("custom-remember-me-parameter", "true") + .with(csrf()); + MvcResult result = this.mvc.perform(request) + .andExpect(redirectedUrl("/")) + .andReturn(); + // @formatter:on Cookie cookie = rememberMeCookie(result); - - this.mvc.perform(get("/authenticated") - .cookie(cookie)) + // @formatter:off + this.mvc.perform(get("/authenticated").cookie(cookie)) .andExpect(status().isOk()); + // @formatter:on } @Test public void configureWhenUsingRememberMeParameterAndServicesRefThenThrowsWiringException() { - assertThatCode(() -> this.spring.configLocations(this.xml("WithRememberMeParameterAndServicesRef")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("WithRememberMeParameterAndServicesRef")).autowire()); } /** * SEC-2826 */ @Test - public void authenticateWhenUsingCustomRememberMeCookieNameThenIssuesCookieWithThatName() - throws Exception { - - this.spring.configLocations(this.xml("WithRememberMeCookie")).autowire(); - - this.rememberAuthentication("user", "password") + public void authenticateWhenUsingCustomRememberMeCookieNameThenIssuesCookieWithThatName() throws Exception { + this.spring.configLocations(xml("WithRememberMeCookie")).autowire(); + // @formatter:off + rememberAuthentication("user", "password") .andExpect(cookie().exists("custom-remember-me-cookie")); + // @formatter:on } /** @@ -340,32 +297,30 @@ public class RememberMeConfigTests { */ @Test public void configureWhenUsingRememberMeCookieAndServicesRefThenThrowsWiringException() { - assertThatCode(() -> this.spring.configLocations(this.xml("WithRememberMeCookieAndServicesRef")).autowire()) - .isInstanceOf(BeanDefinitionParsingException.class) - .hasMessageContaining("Configuration problem: services-ref can't be used in combination with attributes " + - "token-repository-ref,data-source-ref, user-service-ref, token-validity-seconds, use-secure-cookie, " + - "remember-me-parameter or remember-me-cookie"); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("WithRememberMeCookieAndServicesRef")).autowire()) + .withMessageContaining( + "Configuration problem: services-ref can't be used in combination with attributes " + + "token-repository-ref,data-source-ref, user-service-ref, token-validity-seconds, " + + "use-secure-cookie, remember-me-parameter or remember-me-cookie"); } - @RestController - static class BasicController { - @GetMapping("/authenticated") - String ok() { - return "ok"; - } - } - - private ResultActions rememberAuthentication(String username, String password) - throws Exception { - - return this.mvc.perform(login(username, password) - .param(DEFAULT_PARAMETER, "true") - .with(csrf())) + private ResultActions rememberAuthentication(String username, String password) throws Exception { + // @formatter:off + MockHttpServletRequestBuilder request = login(username, password) + .param(AbstractRememberMeServices.DEFAULT_PARAMETER, "true") + .with(csrf()); + return this.mvc.perform(request) .andExpect(redirectedUrl("/")); + // @formatter:on } private static MockHttpServletRequestBuilder login(String username, String password) { - return post("/login").param("username", username).param("password", password); + // @formatter:off + return post("/login") + .param("username", username) + .param("password", password); + // @formatter:on } private static Cookie rememberMeCookie(MvcResult result) { @@ -375,4 +330,15 @@ public class RememberMeConfigTests { private String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + + @RestController + static class BasicController { + + @GetMapping("/authenticated") + String ok() { + return "ok"; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests.java b/config/src/test/java/org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests.java index 9f57135a89..0399d25cd1 100644 --- a/config/src/test/java/org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests.java @@ -13,12 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.apache.http.HttpHeaders; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.config.test.SpringTestRule; @@ -31,13 +39,8 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; - import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.core.StringContains.containsString; +import static org.hamcrest.CoreMatchers.containsString; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie; @@ -46,7 +49,6 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * * @author Rob Winch * @author Josh Cummings */ @@ -54,8 +56,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @SecurityTestExecutionListeners public class SecurityContextHolderAwareRequestConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests"; + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -64,163 +65,136 @@ public class SecurityContextHolderAwareRequestConfigTests { private MockMvc mvc; @Test - public void servletLoginWhenUsingDefaultConfigurationThenUsesSpringSecurity() - throws Exception { - + public void servletLoginWhenUsingDefaultConfigurationThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("Simple")).autowire(); - + // @formatter:off this.mvc.perform(get("/good-login")) .andExpect(status().isOk()) .andExpect(content().string("user")); + // @formatter:on } @Test - public void servletAuthenticateWhenUsingDefaultConfigurationThenUsesSpringSecurity() - throws Exception { - + public void servletAuthenticateWhenUsingDefaultConfigurationThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("Simple")).autowire(); - + // @formatter:off this.mvc.perform(get("/authenticate")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test - public void servletLogoutWhenUsingDefaultConfigurationThenUsesSpringSecurity() - throws Exception { - + public void servletLogoutWhenUsingDefaultConfigurationThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("Simple")).autowire(); - MvcResult result = this.mvc.perform(get("/good-login")).andReturn(); - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNotNull(); - + // @formatter:off result = this.mvc.perform(get("/do-logout").session(session)) .andExpect(status().isOk()) .andExpect(content().string("")) .andReturn(); - + // @formatter:on session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNull(); } @Test - public void servletAuthenticateWhenUsingHttpBasicThenUsesSpringSecurity() - throws Exception { - + public void servletAuthenticateWhenUsingHttpBasicThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("HttpBasic")).autowire(); - + // @formatter:off this.mvc.perform(get("/authenticate")) .andExpect(status().isUnauthorized()) .andExpect(header().string(HttpHeaders.WWW_AUTHENTICATE, containsString("discworld"))); + // @formatter:on } @Test - public void servletAuthenticateWhenUsingFormLoginThenUsesSpringSecurity() - throws Exception { - + public void servletAuthenticateWhenUsingFormLoginThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("FormLogin")).autowire(); - + // @formatter:off this.mvc.perform(get("/authenticate")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); + // @formatter:on } @Test - public void servletLoginWhenUsingMultipleHttpConfigsThenUsesSpringSecurity() - throws Exception { - + public void servletLoginWhenUsingMultipleHttpConfigsThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("MultiHttp")).autowire(); - + // @formatter:off this.mvc.perform(get("/good-login")) .andExpect(status().isOk()) .andExpect(content().string("user")); - this.mvc.perform(get("/v2/good-login")) .andExpect(status().isOk()) .andExpect(content().string("user2")); + // @formatter:on } @Test - public void servletAuthenticateWhenUsingMultipleHttpConfigsThenUsesSpringSecurity() - throws Exception { - + public void servletAuthenticateWhenUsingMultipleHttpConfigsThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("MultiHttp")).autowire(); - + // @formatter:off this.mvc.perform(get("/authenticate")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login")); - this.mvc.perform(get("/v2/authenticate")) .andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/login2")); - + // @formatter:on } @Test - public void servletLogoutWhenUsingMultipleHttpConfigsThenUsesSpringSecurity() - throws Exception { - + public void servletLogoutWhenUsingMultipleHttpConfigsThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("MultiHttp")).autowire(); - MvcResult result = this.mvc.perform(get("/good-login")).andReturn(); - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNotNull(); - + // @formatter:off result = this.mvc.perform(get("/do-logout").session(session)) .andExpect(status().isOk()) .andExpect(content().string("")) .andReturn(); - + // @formatter:on session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNotNull(); - - result = this.mvc.perform(get("/v2/good-login")).andReturn(); - + // @formatter:off + result = this.mvc.perform(get("/v2/good-login")) + .andReturn(); + // @formatter:on session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNotNull(); - + // @formatter:off result = this.mvc.perform(get("/v2/do-logout").session(session)) .andExpect(status().isOk()) .andExpect(content().string("")) .andReturn(); - + // @formatter:on session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNull(); } @Test - public void servletLogoutWhenUsingCustomLogoutThenUsesSpringSecurity() - throws Exception { - + public void servletLogoutWhenUsingCustomLogoutThenUsesSpringSecurity() throws Exception { this.spring.configLocations(this.xml("Logout")).autowire(); - - this.mvc.perform(get("/authenticate")) - .andExpect(status().isFound()) + this.mvc.perform(get("/authenticate")).andExpect(status().isFound()) .andExpect(redirectedUrl("http://localhost/signin")); - - MvcResult result = this.mvc.perform(get("/good-login")).andReturn(); - + // @formatter:off + MvcResult result = this.mvc.perform(get("/good-login")) + .andReturn(); + // @formatter:on MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNotNull(); - + // @formatter:off result = this.mvc.perform(get("/do-logout").session(session)) .andExpect(status().isOk()) .andExpect(content().string("")) .andExpect(cookie().maxAge("JSESSIONID", 0)) .andReturn(); - + // @formatter:on session = (MockHttpSession) result.getRequest().getSession(false); - assertThat(session).isNotNull(); } @@ -229,45 +203,43 @@ public class SecurityContextHolderAwareRequestConfigTests { */ @Test @WithMockUser - public void servletIsUserInRoleWhenUsingDefaultConfigThenRoleIsSet() - throws Exception { - + public void servletIsUserInRoleWhenUsingDefaultConfigThenRoleIsSet() throws Exception { this.spring.configLocations(this.xml("Simple")).autowire(); + // @formatter:off + this.mvc.perform(get("/role")) + .andExpect(content().string("true")); + // @formatter:on + } - this.mvc.perform(get("/role")).andExpect(content().string("true")); + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } @RestController public static class ServletAuthenticatedController { + @GetMapping("/v2/good-login") public String v2Login(HttpServletRequest request) throws ServletException { - request.login("user2", "password2"); - return this.principal(); } @GetMapping("/good-login") public String login(HttpServletRequest request) throws ServletException { - request.login("user", "password"); - return this.principal(); } @GetMapping("/v2/authenticate") public String v2Authenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { - return this.authenticate(request, response); } @GetMapping("/authenticate") public String authenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { - request.authenticate(response); - return this.principal(); } @@ -279,7 +251,6 @@ public class SecurityContextHolderAwareRequestConfigTests { @GetMapping("/do-logout") public String logout(HttpServletRequest request) throws ServletException { request.logout(); - return this.principal(); } @@ -289,14 +260,12 @@ public class SecurityContextHolderAwareRequestConfigTests { } private String principal() { - if ( SecurityContextHolder.getContext().getAuthentication() != null ) { + if (SecurityContextHolder.getContext().getAuthentication() != null) { return SecurityContextHolder.getContext().getAuthentication().getName(); } return null; } + } - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/SecurityFiltersAssertions.java b/config/src/test/java/org/springframework/security/config/http/SecurityFiltersAssertions.java index 776cb49a92..afa8a80d1a 100644 --- a/config/src/test/java/org/springframework/security/config/http/SecurityFiltersAssertions.java +++ b/config/src/test/java/org/springframework/security/config/http/SecurityFiltersAssertions.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.util.Arrays; @@ -23,18 +24,21 @@ import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; /** - * Assertions for tests that rely on confirming behavior of the package-private SecurityFilters enum + * Assertions for tests that rely on confirming behavior of the package-private + * SecurityFilters enum * * @author Josh Cummings */ -public class SecurityFiltersAssertions { +public final class SecurityFiltersAssertions { + private static Collection ordered = Arrays.asList(SecurityFilters.values()); - public static void assertEquals(List filters) { - List expected = ordered.stream() - .map(SecurityFilters::name) - .collect(Collectors.toList()); + private SecurityFiltersAssertions() { + } + public static void assertEquals(List filters) { + List expected = ordered.stream().map(SecurityFilters::name).collect(Collectors.toList()); assertThat(filters).isEqualTo(expected); } + } diff --git a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java index 7ced0f768a..b9e488cc11 100644 --- a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigServlet31Tests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.lang.reflect.Method; + import javax.servlet.Filter; import org.junit.After; @@ -38,7 +40,7 @@ import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.util.ReflectionUtils; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -48,17 +50,24 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @PrepareForTest({ ReflectionUtils.class, Method.class }) @PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" }) public class SessionManagementConfigServlet31Tests { + + // @formatter:off private static final String XML_AUTHENTICATION_MANAGER = "" - + " " + " " + + " " + + " " + " " - + " " + " " + + " " + + " " + ""; + // @formatter:on @Mock Method method; MockHttpServletRequest request; + MockHttpServletResponse response; + MockFilterChain chain; ConfigurableApplicationContext context; @@ -67,15 +76,15 @@ public class SessionManagementConfigServlet31Tests { @Before public void setup() { - request = new MockHttpServletRequest("GET", ""); - response = new MockHttpServletResponse(); - chain = new MockFilterChain(); + this.request = new MockHttpServletRequest("GET", ""); + this.response = new MockHttpServletResponse(); + this.chain = new MockFilterChain(); } @After public void teardown() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @@ -87,61 +96,54 @@ public class SessionManagementConfigServlet31Tests { request.setMethod("POST"); request.setParameter("username", "user"); request.setParameter("password", "password"); - request.getSession().setAttribute("attribute1", "value1"); - String id = request.getSession().getId(); - - loadContext("\n" + " \n" - + " \n" + " \n" - + " " + XML_AUTHENTICATION_MANAGER); - - springSecurityFilterChain.doFilter(request, response, chain); - - + // @formatter:off + loadContext("\n" + + " \n" + + " \n" + + " \n" + + " " + + XML_AUTHENTICATION_MANAGER); + // @formatter:on + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(request.getSession().getId()).isNotEqualTo(id); assertThat(request.getSession().getAttribute("attribute1")).isEqualTo("value1"); } @Test public void changeSessionId() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); request.getSession(); request.setServletPath("/login"); request.setMethod("POST"); request.setParameter("username", "user"); request.setParameter("password", "password"); - String id = request.getSession().getId(); - + // @formatter:off loadContext("\n" + " \n" + " \n" - + " \n" + " " + + " \n" + + " " + XML_AUTHENTICATION_MANAGER); - - springSecurityFilterChain.doFilter(request, response, chain); - + // @formatter:on + this.springSecurityFilterChain.doFilter(request, this.response, this.chain); assertThat(request.getSession().getId()).isNotEqualTo(id); - } private void loadContext(String context) { this.context = new InMemoryXmlApplicationContext(context); - this.springSecurityFilterChain = this.context.getBean( - "springSecurityFilterChain", Filter.class); + this.springSecurityFilterChain = this.context.getBean("springSecurityFilterChain", Filter.class); } private void login(Authentication auth) { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder( - request, response); + HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(this.request, this.response); repo.loadContext(requestResponseHolder); - SecurityContextImpl securityContextImpl = new SecurityContextImpl(); securityContextImpl.setAuthentication(auth); - repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), - requestResponseHolder.getResponse()); + repo.saveContext(securityContextImpl, requestResponseHolder.getRequest(), requestResponseHolder.getResponse()); } + } diff --git a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTests.java b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTests.java index 141ac51b75..f75a2edc4b 100644 --- a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import java.io.IOException; import java.security.Principal; import java.util.List; + import javax.servlet.Filter; import javax.servlet.ServletContext; import javax.servlet.ServletException; @@ -47,11 +49,13 @@ import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessEventPublishingLogoutHandler; import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; +import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.session.ConcurrentSessionFilter; import org.springframework.security.web.session.SessionManagementFilter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; +import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; @@ -59,7 +63,6 @@ import org.springframework.web.context.WebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; -import static org.springframework.security.web.context.HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; @@ -69,7 +72,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Tests session-related functionality for the <http> namespace element and <session-management> + * Tests session-related functionality for the <http> namespace element and + * <session-management> * * @author Luke Taylor * @author Rob Winch @@ -77,8 +81,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Onur Kagan Ozcan */ public class SessionManagementConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/SessionManagementConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/SessionManagementConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -87,156 +91,130 @@ public class SessionManagementConfigTests { MockMvc mvc; @Test - public void requestWhenCreateSessionAlwaysThenAlwaysCreatesSession() - throws Exception { - this.spring.configLocations(this.xml("CreateSessionAlways")).autowire(); - + public void requestWhenCreateSessionAlwaysThenAlwaysCreatesSession() throws Exception { + this.spring.configLocations(xml("CreateSessionAlways")).autowire(); MockHttpServletRequest request = get("/").buildRequest(this.servletContext()); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_OK); assertThat(request.getSession(false)).isNotNull(); } @Test - public void requestWhenCreateSessionIsSetToNeverThenDoesNotCreateSessionOnLoginChallenge() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionNever")).autowire(); - + public void requestWhenCreateSessionIsSetToNeverThenDoesNotCreateSessionOnLoginChallenge() throws Exception { + this.spring.configLocations(xml("CreateSessionNever")).autowire(); MockHttpServletRequest request = get("/auth").buildRequest(this.servletContext()); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(request.getSession(false)).isNull(); } @Test - public void requestWhenCreateSessionIsSetToNeverThenDoesNotCreateSessionOnLogin() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionNever")).autowire(); - + public void requestWhenCreateSessionIsSetToNeverThenDoesNotCreateSessionOnLogin() throws Exception { + this.spring.configLocations(xml("CreateSessionNever")).autowire(); + // @formatter:off MockHttpServletRequest request = post("/login") .param("username", "user") .param("password", "password") .buildRequest(this.servletContext()); + // @formatter:on request = csrf().postProcessRequest(request); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(request.getSession(false)).isNull(); } @Test - public void requestWhenCreateSessionIsSetToNeverThenUsesExistingSession() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionNever")).autowire(); - + public void requestWhenCreateSessionIsSetToNeverThenUsesExistingSession() throws Exception { + this.spring.configLocations(xml("CreateSessionNever")).autowire(); + // @formatter:off MockHttpServletRequest request = post("/login") .param("username", "user") .param("password", "password") .buildRequest(this.servletContext()); + // @formatter:on request = csrf().postProcessRequest(request); MockHttpSession session = new MockHttpSession(); request.setSession(session); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(request.getSession(false)).isNotNull(); - assertThat(request.getSession(false).getAttribute(SPRING_SECURITY_CONTEXT_KEY)) - .isNotNull(); + assertThat(request.getSession(false) + .getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)).isNotNull(); } @Test - public void requestWhenCreateSessionIsSetToStatelessThenDoesNotCreateSessionOnLoginChallenge() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionStateless")).autowire(); - + public void requestWhenCreateSessionIsSetToStatelessThenDoesNotCreateSessionOnLoginChallenge() throws Exception { + this.spring.configLocations(xml("CreateSessionStateless")).autowire(); + // @formatter:off this.mvc.perform(get("/auth")) .andExpect(status().isFound()) .andExpect(session().exists(false)); + // @formatter:on } @Test - public void requestWhenCreateSessionIsSetToStatelessThenDoesNotCreateSessionOnLogin() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionStateless")).autowire(); - - - this.mvc.perform(post("/login") + public void requestWhenCreateSessionIsSetToStatelessThenDoesNotCreateSessionOnLogin() throws Exception { + this.spring.configLocations(xml("CreateSessionStateless")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") .param("username", "user") .param("password", "password") - .with(csrf())) + .with(csrf()); + this.mvc.perform(loginRequest) .andExpect(status().isFound()) .andExpect(session().exists(false)); + // @formatter:on } @Test - public void requestWhenCreateSessionIsSetToStatelessThenIgnoresExistingSession() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionStateless")).autowire(); - - MvcResult result = - this.mvc.perform(post("/login") - .param("username", "user") - .param("password", "password") - .session(new MockHttpSession()) - .with(csrf())) - .andExpect(status().isFound()) - .andExpect(session()) - .andReturn(); - - assertThat(result.getRequest().getSession(false).getAttribute(SPRING_SECURITY_CONTEXT_KEY)) - .isNull(); + public void requestWhenCreateSessionIsSetToStatelessThenIgnoresExistingSession() throws Exception { + this.spring.configLocations(xml("CreateSessionStateless")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .param("username", "user") + .param("password", "password") + .session(new MockHttpSession()) + .with(csrf()); + MvcResult result = this.mvc.perform(loginRequest) + .andExpect(status().isFound()) + .andExpect(session()).andReturn(); + // @formatter:on + assertThat(result.getRequest().getSession(false) + .getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)).isNull(); } @Test - public void requestWhenCreateSessionIsSetToIfRequiredThenDoesNotCreateSessionOnPublicInvocation() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionIfRequired")).autowire(); - + public void requestWhenCreateSessionIsSetToIfRequiredThenDoesNotCreateSessionOnPublicInvocation() throws Exception { + this.spring.configLocations(xml("CreateSessionIfRequired")).autowire(); ServletContext servletContext = this.mvc.getDispatcherServlet().getServletContext(); MockHttpServletRequest request = get("/").buildRequest(servletContext); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_OK); assertThat(request.getSession(false)).isNull(); } @Test - public void requestWhenCreateSessionIsSetToIfRequiredThenCreatesSessionOnLoginChallenge() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionIfRequired")).autowire(); - + public void requestWhenCreateSessionIsSetToIfRequiredThenCreatesSessionOnLoginChallenge() throws Exception { + this.spring.configLocations(xml("CreateSessionIfRequired")).autowire(); ServletContext servletContext = this.mvc.getDispatcherServlet().getServletContext(); MockHttpServletRequest request = get("/auth").buildRequest(servletContext); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(request.getSession(false)).isNotNull(); } @Test - public void requestWhenCreateSessionIsSetToIfRequiredThenCreatesSessionOnLogin() - throws Exception { - - this.spring.configLocations(this.xml("CreateSessionIfRequired")).autowire(); - + public void requestWhenCreateSessionIsSetToIfRequiredThenCreatesSessionOnLogin() throws Exception { + this.spring.configLocations(xml("CreateSessionIfRequired")).autowire(); ServletContext servletContext = this.mvc.getDispatcherServlet().getServletContext(); + // @formatter:off MockHttpServletRequest request = post("/login") .param("username", "user") .param("password", "password") .buildRequest(servletContext); + // @formatter:on request = csrf().postProcessRequest(request); MockHttpServletResponse response = request(request, this.spring.getContext()); - assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_MOVED_TEMPORARILY); assertThat(request.getSession(false)).isNotNull(); } @@ -245,20 +223,16 @@ public class SessionManagementConfigTests { * SEC-1208 */ @Test - public void requestWhenRejectingUserBasedOnMaxSessionsExceededThenDoesNotCreateSession() - throws Exception { - - this.spring.configLocations(this.xml("Sec1208")).autowire(); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + public void requestWhenRejectingUserBasedOnMaxSessionsExceededThenDoesNotCreateSession() throws Exception { + this.spring.configLocations(xml("Sec1208")).autowire(); + // @formatter:off + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(status().isOk()) .andExpect(session()); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(status().isUnauthorized()) .andExpect(session().exists(false)); + // @formatter:on } /** @@ -267,284 +241,177 @@ public class SessionManagementConfigTests { @Test public void requestWhenSessionFixationProtectionDisabledAndConcurrencyControlEnabledThenSessionNotInvalidated() throws Exception { - - this.spring.configLocations(this.xml("Sec2137")).autowire(); - + this.spring.configLocations(xml("Sec2137")).autowire(); MockHttpSession session = new MockHttpSession(); - this.mvc.perform(get("/auth") - .session(session) - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/auth").session(session).with(httpBasic("user", "password"))) .andExpect(status().isOk()) .andExpect(session().id(session.getId())); + // @formatter:on } @Test public void autowireWhenExportingSessionRegistryBeanThenAvailableForWiring() { - this.spring.configLocations(this.xml("ConcurrencyControlSessionRegistryAlias")).autowire(); - + this.spring.configLocations(xml("ConcurrencyControlSessionRegistryAlias")).autowire(); this.sessionRegistryIsValid(); } @Test - public void requestWhenExpiredUrlIsSetThenInvalidatesSessionAndRedirects() - throws Exception { - - this.spring.configLocations(this.xml("ConcurrencyControlExpiredUrl")).autowire(); - - this.mvc.perform(get("/auth") - .session(this.expiredSession()) - .with(httpBasic("user", "password"))) + public void requestWhenExpiredUrlIsSetThenInvalidatesSessionAndRedirects() throws Exception { + this.spring.configLocations(xml("ConcurrencyControlExpiredUrl")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(expiredSession()) + .with(httpBasic("user", "password")); + this.mvc.perform(request) .andExpect(redirectedUrl("/expired")) .andExpect(session().exists(false)); + // @formatter:on } @Test public void requestWhenConcurrencyControlAndCustomLogoutHandlersAreSetThenAllAreInvokedWhenSessionExpires() throws Exception { - - this.spring.configLocations(this.xml("ConcurrencyControlLogoutAndRememberMeHandlers")).autowire(); - - this.mvc.perform(get("/auth") - .session(this.expiredSession()) - .with(httpBasic("user", "password"))) + this.spring.configLocations(xml("ConcurrencyControlLogoutAndRememberMeHandlers")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(expiredSession()) + .with(httpBasic("user", "password")); + this.mvc.perform(request) .andExpect(status().isOk()) .andExpect(cookie().maxAge("testCookie", 0)) .andExpect(cookie().exists("rememberMeCookie")) .andExpect(session().valid(true)); + // @formatter:on } @Test - public void requestWhenConcurrencyControlAndRememberMeAreSetThenInvokedWhenSessionExpires() - throws Exception { - - this.spring.configLocations(this.xml("ConcurrencyControlRememberMeHandler")).autowire(); - - this.mvc.perform(get("/auth") - .session(this.expiredSession()) - .with(httpBasic("user", "password"))) - .andExpect(status().isOk()) - .andExpect(cookie().exists("rememberMeCookie")) + public void requestWhenConcurrencyControlAndRememberMeAreSetThenInvokedWhenSessionExpires() throws Exception { + this.spring.configLocations(xml("ConcurrencyControlRememberMeHandler")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(expiredSession()) + .with(httpBasic("user", "password")); + this.mvc.perform(request) + .andExpect(status().isOk()).andExpect(cookie().exists("rememberMeCookie")) .andExpect(session().exists(false)); + // @formatter:on } /** * SEC-2057 */ @Test - public void autowireWhenConcurrencyControlIsSetThenLogoutHandlersGetAuthenticationObject() - throws Exception { - - this.spring.configLocations(this.xml("ConcurrencyControlCustomLogoutHandler")).autowire(); - - MvcResult result = - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) - .andExpect(session()) - .andReturn(); - + public void autowireWhenConcurrencyControlIsSetThenLogoutHandlersGetAuthenticationObject() throws Exception { + this.spring.configLocations(xml("ConcurrencyControlCustomLogoutHandler")).autowire(); + MvcResult result = this.mvc.perform(get("/auth").with(httpBasic("user", "password"))).andExpect(session()) + .andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); - SessionRegistry sessionRegistry = this.spring.getContext().getBean(SessionRegistry.class); sessionRegistry.getSessionInformation(session.getId()).expireNow(); - - this.mvc.perform(get("/auth") - .session(session)) + // @formatter:off + this.mvc.perform(get("/auth").session(session)) .andExpect(header().string("X-Username", "user")); + // @formatter:on } @Test - public void requestWhenConcurrencyControlIsSetThenDefaultsToResponseBodyExpirationResponse() - throws Exception { - - this.spring.configLocations(this.xml("ConcurrencyControlSessionRegistryAlias")).autowire(); - - this.mvc.perform(get("/auth") - .session(this.expiredSession()) - .with(httpBasic("user", "password"))) + public void requestWhenConcurrencyControlIsSetThenDefaultsToResponseBodyExpirationResponse() throws Exception { + this.spring.configLocations(xml("ConcurrencyControlSessionRegistryAlias")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder request = get("/auth") + .session(expiredSession()) + .with(httpBasic("user", "password")); + this.mvc.perform(request) .andExpect(content().string("This session has been expired (possibly due to multiple concurrent " + "logins being attempted as the same user).")); + // @formatter:on } @Test - public void requestWhenCustomSessionAuthenticationStrategyThenInvokesOnAuthentication() - throws Exception { - - this.spring.configLocations(this.xml("SessionAuthenticationStrategyRef")).autowire(); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + public void requestWhenCustomSessionAuthenticationStrategyThenInvokesOnAuthentication() throws Exception { + this.spring.configLocations(xml("SessionAuthenticationStrategyRef")).autowire(); + // @formatter:off + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(status().isIAmATeapot()); + // @formatter:on } @Test public void autowireWhenSessionRegistryRefIsSetThenAvailableForWiring() { - this.spring.configLocations(this.xml("ConcurrencyControlSessionRegistryRef")).autowire(); - + this.spring.configLocations(xml("ConcurrencyControlSessionRegistryRef")).autowire(); this.sessionRegistryIsValid(); } @Test - public void requestWhenMaxSessionsIsSetThenErrorsWhenExceeded() - throws Exception { - - this.spring.configLocations(this.xml("ConcurrencyControlMaxSessions")).autowire(); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + public void requestWhenMaxSessionsIsSetThenErrorsWhenExceeded() throws Exception { + this.spring.configLocations(xml("ConcurrencyControlMaxSessions")).autowire(); + // @formatter:off + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(status().isOk()); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(status().isOk()); - - this.mvc.perform(get("/auth") - .with(httpBasic("user", "password"))) + this.mvc.perform(get("/auth").with(httpBasic("user", "password"))) .andExpect(redirectedUrl("/max-exceeded")); + // @formatter:on } @Test public void autowireWhenSessionFixationProtectionIsNoneAndCsrfDisabledThenSessionManagementFilterIsNotWired() { - - this.spring.configLocations(this.xml("NoSessionManagementFilter")).autowire(); - + this.spring.configLocations(xml("NoSessionManagementFilter")).autowire(); assertThat(this.getFilter(SessionManagementFilter.class)).isNull(); } @Test - public void requestWhenSessionFixationProtectionIsNoneThenSessionNotInvalidated() - throws Exception { - - this.spring.configLocations(this.xml("SessionFixationProtectionNone")).autowire(); - + public void requestWhenSessionFixationProtectionIsNoneThenSessionNotInvalidated() throws Exception { + this.spring.configLocations(xml("SessionFixationProtectionNone")).autowire(); MockHttpSession session = new MockHttpSession(); String sessionId = session.getId(); - - this.mvc.perform(get("/auth") - .session(session) - .with(httpBasic("user", "password"))) + // @formatter:off + this.mvc.perform(get("/auth").session(session).with(httpBasic("user", "password"))) .andExpect(session().id(sessionId)); + // @formatter:on } @Test - public void requestWhenSessionFixationProtectionIsMigrateSessionThenSessionIsReplaced() - throws Exception { - - this.spring.configLocations(this.xml("SessionFixationProtectionMigrateSession")).autowire(); - + public void requestWhenSessionFixationProtectionIsMigrateSessionThenSessionIsReplaced() throws Exception { + this.spring.configLocations(xml("SessionFixationProtectionMigrateSession")).autowire(); MockHttpSession session = new MockHttpSession(); String sessionId = session.getId(); - - MvcResult result = - this.mvc.perform(get("/auth") - .session(session) - .with(httpBasic("user", "password"))) - .andExpect(session()) - .andReturn(); - + // @formatter:off + MvcResult result = this.mvc.perform(get("/auth").session(session).with(httpBasic("user", "password"))) + .andExpect(session()) + .andReturn(); + // @formatter:on assertThat(result.getRequest().getSession(false).getId()).isNotEqualTo(sessionId); } @Test public void requestWhenSessionFixationProtectionIsNoneAndInvalidSessionUrlIsSetThenStillRedirectsOnInvalidSession() throws Exception { - - this.spring.configLocations(this.xml("SessionFixationProtectionNoneWithInvalidSessionUrl")).autowire(); - - this.mvc.perform(get("/auth") - .with(request -> { + this.spring.configLocations(xml("SessionFixationProtectionNoneWithInvalidSessionUrl")).autowire(); + // @formatter:off + MockHttpServletRequestBuilder authRequest = get("/auth") + .with((request) -> { request.setRequestedSessionId("1"); request.setRequestedSessionIdValid(false); return request; - })) + }); + this.mvc.perform(authRequest) .andExpect(redirectedUrl("/timeoutUrl")); - } - - /** - * SEC-2680 - */ - @Test - public void checkConcurrencyAndLogoutFilterHasSameSizeAndHasLogoutSuccessEventPublishingLogoutHandler() { - - this.spring.configLocations(this.xml("ConcurrencyControlLogoutAndRememberMeHandlers")).autowire(); - - ConcurrentSessionFilter concurrentSessionFilter = getFilter(ConcurrentSessionFilter.class); - LogoutFilter logoutFilter = getFilter(LogoutFilter.class); - - LogoutHandler csfLogoutHandler = getFieldValue(concurrentSessionFilter, "handlers"); - LogoutHandler lfLogoutHandler = getFieldValue(logoutFilter, "handler"); - - assertThat(csfLogoutHandler).isInstanceOf(CompositeLogoutHandler.class); - assertThat(lfLogoutHandler).isInstanceOf(CompositeLogoutHandler.class); - - List csfLogoutHandlers = getFieldValue(csfLogoutHandler, "logoutHandlers"); - List lfLogoutHandlers = getFieldValue(lfLogoutHandler, "logoutHandlers"); - - assertThat(csfLogoutHandlers).hasSameSizeAs(lfLogoutHandlers); - - assertThat(csfLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); - assertThat(lfLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); - } - - static class TeapotSessionAuthenticationStrategy implements SessionAuthenticationStrategy { - - @Override - public void onAuthentication( - Authentication authentication, - HttpServletRequest request, - HttpServletResponse response) throws SessionAuthenticationException { - - response.setStatus(org.springframework.http.HttpStatus.I_AM_A_TEAPOT.value()); - } - } - - static class CustomRememberMeServices implements RememberMeServices, LogoutHandler { - @Override - public Authentication autoLogin(HttpServletRequest request, HttpServletResponse response) { - return null; - } - - @Override - public void loginFail(HttpServletRequest request, HttpServletResponse response) { - - } - - @Override - public void loginSuccess(HttpServletRequest request, HttpServletResponse response, Authentication successfulAuthentication) { - - } - - @Override - public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { - response.addHeader("X-Username", authentication.getName()); - } - } - - @RestController - static class BasicController { - @GetMapping("/") - public String ok() { - return "ok"; - } - - @GetMapping("/auth") - public String auth(Principal principal) { - return principal.getName(); - } + // @formatter:on } private void sessionRegistryIsValid() { - SessionRegistry sessionRegistry = this.spring.getContext() - .getBean("sessionRegistry", SessionRegistry.class); - + SessionRegistry sessionRegistry = this.spring.getContext().getBean("sessionRegistry", SessionRegistry.class); assertThat(sessionRegistry).isNotNull(); - - assertThat(this.getFilter(ConcurrentSessionFilter.class)) - .returns(sessionRegistry, this::extractSessionRegistry); - assertThat(this.getFilter(UsernamePasswordAuthenticationFilter.class)) - .returns(sessionRegistry, this::extractSessionRegistry); + assertThat(this.getFilter(ConcurrentSessionFilter.class)).returns(sessionRegistry, + this::extractSessionRegistry); + assertThat(this.getFilter(UsernamePasswordAuthenticationFilter.class)).returns(sessionRegistry, + this::extractSessionRegistry); // SEC-1143 - assertThat(this.getFilter(SessionManagementFilter.class)) - .returns(sessionRegistry, this::extractSessionRegistry); + assertThat(this.getFilter(SessionManagementFilter.class)).returns(sessionRegistry, + this::extractSessionRegistry); } private SessionRegistry extractSessionRegistry(ConcurrentSessionFilter filter) { @@ -566,8 +433,9 @@ public class SessionManagementConfigTests { private T getFieldValue(Object target, String fieldName) { try { return (T) FieldUtils.getFieldValue(target, fieldName); - } catch (Exception e) { - throw new RuntimeException(e); + } + catch (Exception ex) { + throw new RuntimeException(ex); } } @@ -575,22 +443,127 @@ public class SessionManagementConfigTests { return new SessionResultMatcher(); } + /** + * SEC-2680 + */ + @Test + public void checkConcurrencyAndLogoutFilterHasSameSizeAndHasLogoutSuccessEventPublishingLogoutHandler() { + this.spring.configLocations(xml("ConcurrencyControlLogoutAndRememberMeHandlers")).autowire(); + ConcurrentSessionFilter concurrentSessionFilter = getFilter(ConcurrentSessionFilter.class); + LogoutFilter logoutFilter = getFilter(LogoutFilter.class); + LogoutHandler csfLogoutHandler = getFieldValue(concurrentSessionFilter, "handlers"); + LogoutHandler lfLogoutHandler = getFieldValue(logoutFilter, "handler"); + assertThat(csfLogoutHandler).isInstanceOf(CompositeLogoutHandler.class); + assertThat(lfLogoutHandler).isInstanceOf(CompositeLogoutHandler.class); + List csfLogoutHandlers = getFieldValue(csfLogoutHandler, "logoutHandlers"); + List lfLogoutHandlers = getFieldValue(lfLogoutHandler, "logoutHandlers"); + assertThat(csfLogoutHandlers).hasSameSizeAs(lfLogoutHandlers); + assertThat(csfLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); + assertThat(lfLogoutHandlers).hasAtLeastOneElementOfType(LogoutSuccessEventPublishingLogoutHandler.class); + } + + private static MockHttpServletResponse request(MockHttpServletRequest request, ApplicationContext context) + throws IOException, ServletException { + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChainProxy proxy = context.getBean(FilterChainProxy.class); + proxy.doFilter(request, new EncodeUrlDenyingHttpServletResponseWrapper(response), (req, resp) -> { + }); + return response; + } + + private MockHttpSession expiredSession() { + MockHttpSession session = new MockHttpSession(); + SessionRegistry sessionRegistry = this.spring.getContext().getBean(SessionRegistry.class); + sessionRegistry.registerNewSession(session.getId(), "user"); + sessionRegistry.getSessionInformation(session.getId()).expireNow(); + return session; + } + + private T getFilter(Class filterClass) { + return (T) getFilters().stream().filter(filterClass::isInstance).findFirst().orElse(null); + } + + private List getFilters() { + FilterChainProxy proxy = this.spring.getContext().getBean(FilterChainProxy.class); + return proxy.getFilters("/"); + } + + private ServletContext servletContext() { + WebApplicationContext context = this.spring.getContext(); + return context.getServletContext(); + } + + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; + } + + static class TeapotSessionAuthenticationStrategy implements SessionAuthenticationStrategy { + + @Override + public void onAuthentication(Authentication authentication, HttpServletRequest request, + HttpServletResponse response) throws SessionAuthenticationException { + response.setStatus(org.springframework.http.HttpStatus.I_AM_A_TEAPOT.value()); + } + + } + + static class CustomRememberMeServices implements RememberMeServices, LogoutHandler { + + @Override + public Authentication autoLogin(HttpServletRequest request, HttpServletResponse response) { + return null; + } + + @Override + public void loginFail(HttpServletRequest request, HttpServletResponse response) { + } + + @Override + public void loginSuccess(HttpServletRequest request, HttpServletResponse response, + Authentication successfulAuthentication) { + } + + @Override + public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { + response.addHeader("X-Username", authentication.getName()); + } + + } + + @RestController + static class BasicController { + + @GetMapping("/") + String ok() { + return "ok"; + } + + @GetMapping("/auth") + String auth(Principal principal) { + return principal.getName(); + } + + } + private static class SessionResultMatcher implements ResultMatcher { + private String id; + private Boolean valid; + private Boolean exists = true; - public ResultMatcher exists(boolean exists) { + ResultMatcher exists(boolean exists) { this.exists = exists; return this; } - public ResultMatcher valid(boolean valid) { + ResultMatcher valid(boolean valid) { this.valid = valid; return this.exists(true); } - public ResultMatcher id(String id) { + ResultMatcher id(String id) { this.id = id; return this.exists(true); } @@ -601,44 +574,24 @@ public class SessionManagementConfigTests { assertThat(result.getRequest().getSession(false)).isNull(); return; } - assertThat(result.getRequest().getSession(false)).isNotNull(); - MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false); - if (this.valid != null) { if (this.valid) { assertThat(session.isInvalid()).isFalse(); - } else { + } + else { assertThat(session.isInvalid()).isTrue(); } } - if (this.id != null) { assertThat(session.getId()).isEqualTo(this.id); } } + } - private static MockHttpServletResponse request( - MockHttpServletRequest request, - ApplicationContext context) - throws IOException, ServletException { - - MockHttpServletResponse response = new MockHttpServletResponse(); - - FilterChainProxy proxy = context.getBean(FilterChainProxy.class); - - proxy.doFilter( - request, - new EncodeUrlDenyingHttpServletResponseWrapper(response), - (req, resp) -> {}); - - return response; - } - - private static class EncodeUrlDenyingHttpServletResponseWrapper - extends HttpServletResponseWrapper { + private static class EncodeUrlDenyingHttpServletResponseWrapper extends HttpServletResponseWrapper { EncodeUrlDenyingHttpServletResponseWrapper(HttpServletResponse response) { super(response); @@ -663,35 +616,7 @@ public class SessionManagementConfigTests { public String encodeRedirectUrl(String url) { throw new RuntimeException("Unexpected invocation of encodeURL"); } + } - private MockHttpSession expiredSession() { - MockHttpSession session = new MockHttpSession(); - SessionRegistry sessionRegistry = this.spring.getContext().getBean(SessionRegistry.class); - sessionRegistry.registerNewSession(session.getId(), "user"); - sessionRegistry.getSessionInformation(session.getId()).expireNow(); - return session; - } - - private T getFilter(Class filterClass) { - return (T) getFilters().stream() - .filter(filterClass::isInstance) - .findFirst() - .orElse(null); - } - - private List getFilters() { - FilterChainProxy proxy = this.spring.getContext().getBean(FilterChainProxy.class); - - return proxy.getFilters("/"); - } - - private ServletContext servletContext() { - WebApplicationContext context = (WebApplicationContext) this.spring.getContext(); - return context.getServletContext(); - } - - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTransientAuthenticationTests.java b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTransientAuthenticationTests.java index e13801ba58..8e2df3e6e9 100644 --- a/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTransientAuthenticationTests.java +++ b/config/src/test/java/org/springframework/security/config/http/SessionManagementConfigTransientAuthenticationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http; import org.junit.Rule; @@ -35,8 +36,8 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder * @author Josh Cummings */ public class SessionManagementConfigTransientAuthenticationTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/http/SessionManagementConfigTransientAuthenticationTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/SessionManagementConfigTransientAuthenticationTests"; @Autowired MockMvc mvc; @@ -45,23 +46,23 @@ public class SessionManagementConfigTransientAuthenticationTests { public final SpringTestRule spring = new SpringTestRule(); @Test - public void postWhenTransientAuthenticationThenNoSessionCreated() - throws Exception { - + public void postWhenTransientAuthenticationThenNoSessionCreated() throws Exception { this.spring.configLocations(this.xml("WithTransientAuthentication")).autowire(); MvcResult result = this.mvc.perform(post("/login")).andReturn(); assertThat(result.getRequest().getSession(false)).isNull(); } @Test - public void postWhenTransientAuthenticationThenAlwaysSessionOverrides() - throws Exception { - + public void postWhenTransientAuthenticationThenAlwaysSessionOverrides() throws Exception { this.spring.configLocations(this.xml("CreateSessionAlwaysWithTransientAuthentication")).autowire(); MvcResult result = this.mvc.perform(post("/login")).andReturn(); assertThat(result.getRequest().getSession(false)).isNotNull(); } + private String xml(String configName) { + return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; + } + static class TransientAuthenticationProvider implements AuthenticationProvider { @Override @@ -73,10 +74,12 @@ public class SessionManagementConfigTransientAuthenticationTests { public boolean supports(Class authentication) { return true; } + } @Transient static class SomeTransientAuthentication extends AbstractAuthenticationToken { + SomeTransientAuthentication() { super(null); } @@ -90,9 +93,7 @@ public class SessionManagementConfigTransientAuthenticationTests { public Object getPrincipal() { return null; } + } - private String xml(String configName) { - return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; - } } diff --git a/config/src/test/java/org/springframework/security/config/http/WebConfigUtilsTest.java b/config/src/test/java/org/springframework/security/config/http/WebConfigUtilsTests.java similarity index 81% rename from config/src/test/java/org/springframework/security/config/http/WebConfigUtilsTest.java rename to config/src/test/java/org/springframework/security/config/http/WebConfigUtilsTests.java index 0eb0023f65..7268e6400c 100644 --- a/config/src/test/java/org/springframework/security/config/http/WebConfigUtilsTest.java +++ b/config/src/test/java/org/springframework/security/config/http/WebConfigUtilsTests.java @@ -13,21 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.http; -import static org.mockito.Mockito.verifyZeroInteractions; +package org.springframework.security.config.http; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.beans.factory.xml.ParserContext; +import static org.mockito.Mockito.verifyZeroInteractions; + @RunWith(PowerMockRunner.class) @PrepareOnlyThisForTest(ParserContext.class) -public class WebConfigUtilsTest { - public final static String URL = "/url"; +public class WebConfigUtilsTests { + + public static final String URL = "/url"; @Mock private ParserContext parserContext; @@ -35,9 +38,9 @@ public class WebConfigUtilsTest { // SEC-1980 @Test public void validateHttpRedirectSpELNoParserWarning() { - WebConfigUtils.validateHttpRedirect( - "#{T(org.springframework.security.config.http.WebConfigUtilsTest).URL}", - parserContext, "fakeSource"); - verifyZeroInteractions(parserContext); + WebConfigUtils.validateHttpRedirect("#{T(org.springframework.security.config.http.WebConfigUtilsTest).URL}", + this.parserContext, "fakeSource"); + verifyZeroInteractions(this.parserContext); } -} \ No newline at end of file + +} diff --git a/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomConfigurer.java b/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomConfigurer.java index d377d0dc35..96f5469ef5 100644 --- a/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomConfigurer.java +++ b/config/src/test/java/org/springframework/security/config/http/customconfigurer/CustomConfigurer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.http.customconfigurer; import org.springframework.beans.factory.annotation.Value; @@ -33,26 +34,25 @@ public class CustomConfigurer extends SecurityConfigurerAdapter clazz) { @@ -120,21 +115,23 @@ public class CustomHttpSecurityConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .apply(customConfigurer()) + .apply(CustomConfigurer.customConfigurer()) .loginPage("/custom"); + // @formatter:on } @Bean - public static PropertyPlaceholderConfigurer propertyPlaceholderConfigurer() { + static PropertyPlaceholderConfigurer propertyPlaceholderConfigurer() { // Typically externalize this as a properties file Properties properties = new Properties(); properties.setProperty("permitAllPattern", "/public/**"); - PropertyPlaceholderConfigurer propertyPlaceholderConfigurer = new PropertyPlaceholderConfigurer(); propertyPlaceholderConfigurer.setProperties(properties); return propertyPlaceholderConfigurer; } + } @EnableWebSecurity @@ -142,23 +139,26 @@ public class CustomHttpSecurityConfigurerTests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .apply(customConfigurer()) + .apply(CustomConfigurer.customConfigurer()) .and() .csrf().disable() .formLogin() .loginPage("/other"); + // @formatter:on } @Bean - public static PropertyPlaceholderConfigurer propertyPlaceholderConfigurer() { + static PropertyPlaceholderConfigurer propertyPlaceholderConfigurer() { // Typically externalize this as a properties file Properties properties = new Properties(); properties.setProperty("permitAllPattern", "/public/**"); - PropertyPlaceholderConfigurer propertyPlaceholderConfigurer = new PropertyPlaceholderConfigurer(); propertyPlaceholderConfigurer.setProperties(properties); return propertyPlaceholderConfigurer; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/method/Contact.java b/config/src/test/java/org/springframework/security/config/method/Contact.java index 78e3420fba..819b97ad35 100644 --- a/config/src/test/java/org/springframework/security/config/method/Contact.java +++ b/config/src/test/java/org/springframework/security/config/method/Contact.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; /** @@ -20,6 +21,7 @@ package org.springframework.security.config.method; * */ public class Contact { + private String name; /** @@ -33,6 +35,7 @@ public class Contact { * @return the name */ public String getName() { - return name; + return this.name; } + } diff --git a/config/src/test/java/org/springframework/security/config/method/ContactPermission.java b/config/src/test/java/org/springframework/security/config/method/ContactPermission.java index 140cf3fd8a..a78b26f15f 100644 --- a/config/src/test/java/org/springframework/security/config/method/ContactPermission.java +++ b/config/src/test/java/org/springframework/security/config/method/ContactPermission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.lang.annotation.Retention; @@ -26,4 +27,6 @@ import org.springframework.security.access.prepost.PreAuthorize; */ @Retention(RetentionPolicy.RUNTIME) @PreAuthorize("#contact.name == authentication.name") -public @interface ContactPermission {} +public @interface ContactPermission { + +} diff --git a/config/src/test/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParserTests.java index 38b75d7743..0165eeb32c 100644 --- a/config/src/test/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/method/GlobalMethodSecurityBeanDefinitionParserTests.java @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.config.method; -import static org.assertj.core.api.Assertions.*; -import static org.springframework.security.config.ConfigTestUtils.AUTH_PROVIDER_XML; +package org.springframework.security.config.method; import java.util.ArrayList; import java.util.List; import org.junit.After; import org.junit.Test; + import org.springframework.aop.Advisor; import org.springframework.aop.framework.Advised; import org.springframework.beans.BeansException; @@ -59,55 +58,58 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.util.FieldUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * @author Ben Alex * @author Luke Taylor */ public class GlobalMethodSecurityBeanDefinitionParserTests { - private final UsernamePasswordAuthenticationToken bob = new UsernamePasswordAuthenticationToken( - "bob", "bobspassword"); + + private final UsernamePasswordAuthenticationToken bob = new UsernamePasswordAuthenticationToken("bob", + "bobspassword"); private AbstractXmlApplicationContext appContext; private BusinessService target; public void loadContext() { + // @formatter:off setContext("" + "" + " " + " " - + "" + ConfigTestUtils.AUTH_PROVIDER_XML); - target = (BusinessService) appContext.getBean("target"); + + "" + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + this.target = (BusinessService) this.appContext.getBean("target"); } @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); - appContext = null; + if (this.appContext != null) { + this.appContext.close(); + this.appContext = null; } SecurityContextHolder.clearContext(); - target = null; + this.target = null; } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void targetShouldPreventProtectedMethodInvocationWithNoContext() { loadContext(); - - target.someUserMethod1(); + this.target.someUserMethod1(); } @Test public void targetShouldAllowProtectedMethodInvocationWithCorrectRole() { loadContext(); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", "password"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password"); SecurityContextHolder.getContext().setAuthentication(token); - - target.someUserMethod1(); - + this.target.someUserMethod1(); // SEC-1213. Check the order - Advisor[] advisors = ((Advised) target).getAdvisors(); + Advisor[] advisors = ((Advised) this.target).getAdvisors(); assertThat(advisors).hasSize(1); assertThat(((MethodSecurityMetadataSourceAdvisor) advisors[0]).getOrder()).isEqualTo(1001); } @@ -115,68 +117,66 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { @Test(expected = AccessDeniedException.class) public void targetShouldPreventProtectedMethodInvocationWithIncorrectRole() { loadContext(); - TestingAuthenticationToken token = new TestingAuthenticationToken("Test", - "Password", "ROLE_SOMEOTHERROLE"); + TestingAuthenticationToken token = new TestingAuthenticationToken("Test", "Password", "ROLE_SOMEOTHERROLE"); token.setAuthenticated(true); - SecurityContextHolder.getContext().setAuthentication(token); - - target.someAdminMethod(); + this.target.someAdminMethod(); } @Test public void doesntInterfereWithBeanPostProcessing() { + // @formatter:off setContext("" + "" + "" + " " + "" + ""); - - PostProcessedMockUserDetailsService service = (PostProcessedMockUserDetailsService) appContext + // @formatter:on + PostProcessedMockUserDetailsService service = (PostProcessedMockUserDetailsService) this.appContext .getBean("myUserService"); - assertThat(service.getPostProcessorWasHere()).isEqualTo("Hello from the post processor!"); } @Test(expected = AccessDeniedException.class) public void worksWithAspectJAutoproxy() { + // @formatter:off setContext("" + " " + "" + "" - + "" + "" + + "" + + "" + " " + ""); - - UserDetailsService service = (UserDetailsService) appContext - .getBean("myUserService"); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", + // @formatter:on + UserDetailsService service = (UserDetailsService) this.appContext.getBean("myUserService"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_SOMEOTHERROLE")); SecurityContextHolder.getContext().setAuthentication(token); - service.loadUserByUsername("notused"); } @Test public void supportsMethodArgumentsInPointcut() { + // @formatter:off setContext("" + "" + " " + " " - + "" + ConfigTestUtils.AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("user", "password")); - target = (BusinessService) appContext.getBean("target"); + + "" + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext() + .setAuthentication(new UsernamePasswordAuthenticationToken("user", "password")); + this.target = (BusinessService) this.appContext.getBean("target"); // someOther(int) should not be matched by someOther(String), but should require // ROLE_USER - target.someOther(0); - + this.target.someOther(0); try { // String version should required admin role - target.someOther("somestring"); + this.target.someOther("somestring"); fail("Expected AccessDeniedException"); } catch (AccessDeniedException expected) { @@ -185,6 +185,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { @Test public void supportsBooleanPointcutExpressions() { + // @formatter:off setContext("" + "" + " " + "" - + AUTH_PROVIDER_XML); - target = (BusinessService) appContext.getBean("target"); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + this.target = (BusinessService) this.appContext.getBean("target"); // String method should not be protected - target.someOther("somestring"); - + this.target.someOther("somestring"); // All others should require ROLE_USER try { - target.someOther(0); + this.target.someOther(0); fail("Expected AuthenticationCredentialsNotFoundException"); } catch (AuthenticationCredentialsNotFoundException expected) { } - - SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("user", "password")); - target.someOther(0); + SecurityContextHolder.getContext() + .setAuthentication(new UsernamePasswordAuthenticationToken("user", "password")); + this.target.someOther(0); } @Test(expected = BeanDefinitionParsingException.class) @@ -218,80 +218,81 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { // SEC-936 @Test(expected = AccessDeniedException.class) public void worksWithoutTargetOrClass() { + // @formatter:off setContext("" + "" + " " + " " - + "" + AUTH_PROVIDER_XML); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", + + "" + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_SOMEOTHERROLE")); SecurityContextHolder.getContext().setAuthentication(token); - target = (BusinessService) appContext.getBean("businessService"); - target.someUserMethod1(); + this.target = (BusinessService) this.appContext.getBean("businessService"); + this.target.someUserMethod1(); } // Expression configuration tests - @SuppressWarnings("unchecked") @Test - public void expressionVoterAndAfterInvocationProviderUseSameExpressionHandlerInstance() - throws Exception { - setContext("" - + AUTH_PROVIDER_XML); - AffirmativeBased adm = (AffirmativeBased) appContext - .getBeansOfType(AffirmativeBased.class).values().toArray()[0]; - List voters = (List) FieldUtils.getFieldValue(adm, "decisionVoters"); - PreInvocationAuthorizationAdviceVoter mev = (PreInvocationAuthorizationAdviceVoter) voters - .get(0); - MethodSecurityMetadataSourceAdvisor msi = (MethodSecurityMetadataSourceAdvisor) appContext - .getBeansOfType(MethodSecurityMetadataSourceAdvisor.class).values() + public void expressionVoterAndAfterInvocationProviderUseSameExpressionHandlerInstance() throws Exception { + setContext("" + ConfigTestUtils.AUTH_PROVIDER_XML); + AffirmativeBased adm = (AffirmativeBased) this.appContext.getBeansOfType(AffirmativeBased.class).values() .toArray()[0]; + List voters = (List) FieldUtils.getFieldValue(adm, "decisionVoters"); + PreInvocationAuthorizationAdviceVoter mev = (PreInvocationAuthorizationAdviceVoter) voters.get(0); + MethodSecurityMetadataSourceAdvisor msi = (MethodSecurityMetadataSourceAdvisor) this.appContext + .getBeansOfType(MethodSecurityMetadataSourceAdvisor.class).values().toArray()[0]; AfterInvocationProviderManager pm = (AfterInvocationProviderManager) ((MethodSecurityInterceptor) msi .getAdvice()).getAfterInvocationManager(); - PostInvocationAdviceProvider aip = (PostInvocationAdviceProvider) pm - .getProviders().get(0); - assertThat(FieldUtils.getFieldValue(mev, "preAdvice.expressionHandler")).isSameAs(FieldUtils - .getFieldValue(aip, "postAdvice.expressionHandler")); + PostInvocationAdviceProvider aip = (PostInvocationAdviceProvider) pm.getProviders().get(0); + assertThat(FieldUtils.getFieldValue(mev, "preAdvice.expressionHandler")) + .isSameAs(FieldUtils.getFieldValue(aip, "postAdvice.expressionHandler")); } @Test(expected = AccessDeniedException.class) public void accessIsDeniedForHasRoleExpression() { + // @formatter:off setContext("" + "" - + AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication(bob); - target = (BusinessService) appContext.getBean("target"); - target.someAdminMethod(); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext().setAuthentication(this.bob); + this.target = (BusinessService) this.appContext.getBean("target"); + this.target.someAdminMethod(); } @Test public void beanNameExpressionPropertyIsSupported() { + // @formatter:off setContext("" + "" + " " + "" + "" - + AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication(bob); - ExpressionProtectedBusinessServiceImpl target = (ExpressionProtectedBusinessServiceImpl) appContext + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext().setAuthentication(this.bob); + ExpressionProtectedBusinessServiceImpl target = (ExpressionProtectedBusinessServiceImpl) this.appContext .getBean("target"); target.methodWithBeanNamePropertyAccessExpression("x"); } @Test public void preAndPostFilterAnnotationsWorkWithLists() { + // @formatter:off setContext("" + "" - + AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication(bob); - target = (BusinessService) appContext.getBean("target"); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext().setAuthentication(this.bob); + this.target = (BusinessService) this.appContext.getBean("target"); List arg = new ArrayList<>(); arg.add("joe"); arg.add("bob"); arg.add("sam"); - List result = target.methodReturningAList(arg); + List result = this.target.methodReturningAList(arg); // Expression is (filterObject == name or filterObject == 'sam'), so "joe" should // be gone after pre-filter // PostFilter should remove sam from the return object @@ -301,13 +302,15 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { @Test public void prePostFilterAnnotationWorksWithArrays() { + // @formatter:off setContext("" + "" - + AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication(bob); - target = (BusinessService) appContext.getBean("target"); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext().setAuthentication(this.bob); + this.target = (BusinessService) this.appContext.getBean("target"); Object[] arg = new String[] { "joe", "bob", "sam" }; - Object[] result = target.methodReturningAnArray(arg); + Object[] result = this.target.methodReturningAnArray(arg); assertThat(result).hasSize(1); assertThat(result[0]).isEqualTo("bob"); } @@ -315,6 +318,7 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { // SEC-1392 @Test public void customPermissionEvaluatorIsSupported() { + // @formatter:off setContext("" + " " + "" @@ -322,18 +326,23 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { + " " + "" + "" - + AUTH_PROVIDER_XML); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on } // SEC-1450 @Test(expected = AuthenticationException.class) @SuppressWarnings("unchecked") public void genericsAreMatchedByProtectPointcut() { - setContext("" - + "" - + " " - + "" + AUTH_PROVIDER_XML); - Foo foo = (Foo) appContext.getBean("target"); + // @formatter:off + setContext( + "" + + "" + + " " + + "" + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + Foo foo = (Foo) this.appContext.getBean("target"); foo.foo(new SecurityConfig("A")); } @@ -341,11 +350,13 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { @Test @SuppressWarnings("unchecked") public void genericsMethodArgumentNamesAreResolved() { + // @formatter:off setContext("" + "" - + AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication(bob); - Foo foo = (Foo) appContext.getBean("target"); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext().setAuthentication(this.bob); + Foo foo = (Foo) this.appContext.getBean("target"); foo.foo(new SecurityConfig("A")); } @@ -356,114 +367,108 @@ public class GlobalMethodSecurityBeanDefinitionParserTests { props.addPropertyValue("key", "blah"); parent.registerSingleton("runAsMgr", RunAsManagerImpl.class, props); parent.refresh(); - - setContext("" - + AUTH_PROVIDER_XML, parent); - RunAsManagerImpl ram = (RunAsManagerImpl) appContext.getBean("runAsMgr"); - MethodSecurityMetadataSourceAdvisor msi = (MethodSecurityMetadataSourceAdvisor) appContext - .getBeansOfType(MethodSecurityMetadataSourceAdvisor.class).values() - .toArray()[0]; + setContext("" + ConfigTestUtils.AUTH_PROVIDER_XML, + parent); + RunAsManagerImpl ram = (RunAsManagerImpl) this.appContext.getBean("runAsMgr"); + MethodSecurityMetadataSourceAdvisor msi = (MethodSecurityMetadataSourceAdvisor) this.appContext + .getBeansOfType(MethodSecurityMetadataSourceAdvisor.class).values().toArray()[0]; assertThat(ram).isSameAs(FieldUtils.getFieldValue(msi.getAdvice(), "runAsManager")); } @Test @SuppressWarnings("unchecked") public void supportsExternalMetadataSource() { - setContext("" + // @formatter:off + setContext("" + "" - + " " + + " " + "" + "" - + AUTH_PROVIDER_XML); + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on // External MDS should take precedence over PreAuthorize - SecurityContextHolder.getContext().setAuthentication(bob); - Foo foo = (Foo) appContext.getBean("target"); + SecurityContextHolder.getContext().setAuthentication(this.bob); + Foo foo = (Foo) this.appContext.getBean("target"); try { foo.foo(new SecurityConfig("A")); fail("Bob can't invoke admin methods"); } catch (AccessDeniedException expected) { } - SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("admin", "password")); + SecurityContextHolder.getContext() + .setAuthentication(new UsernamePasswordAuthenticationToken("admin", "password")); foo.foo(new SecurityConfig("A")); } @Test public void supportsCustomAuthenticationManager() { - setContext("" + // @formatter:off + setContext("" + "" - + " " + + " " + "" + "" + "" - + " " + "" - + AUTH_PROVIDER_XML); - SecurityContextHolder.getContext().setAuthentication(bob); - Foo foo = (Foo) appContext.getBean("target"); + + " " + + "" + + ConfigTestUtils.AUTH_PROVIDER_XML); + // @formatter:on + SecurityContextHolder.getContext().setAuthentication(this.bob); + Foo foo = (Foo) this.appContext.getBean("target"); try { foo.foo(new SecurityConfig("A")); fail("Bob can't invoke admin methods"); } catch (AccessDeniedException expected) { } - SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("admin", "password")); + SecurityContextHolder.getContext() + .setAuthentication(new UsernamePasswordAuthenticationToken("admin", "password")); foo.foo(new SecurityConfig("A")); } - static class CustomAuthManager implements AuthenticationManager, - ApplicationContextAware { + private void setContext(String context) { + this.appContext = new InMemoryXmlApplicationContext(context); + } + + private void setContext(String context, ApplicationContext parent) { + this.appContext = new InMemoryXmlApplicationContext(context, parent); + } + + static class CustomAuthManager implements AuthenticationManager, ApplicationContextAware { + private String beanName; + private AuthenticationManager authenticationManager; CustomAuthManager(String beanName) { this.beanName = beanName; } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { - return authenticationManager.authenticate(authentication); + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + return this.authenticationManager.authenticate(authentication); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.context.ApplicationContextAware#setApplicationContext(org - * .springframework.context.ApplicationContext) - */ - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - this.authenticationManager = applicationContext.getBean(beanName, - AuthenticationManager.class); + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.authenticationManager = applicationContext.getBean(this.beanName, AuthenticationManager.class); } - } - private void setContext(String context) { - appContext = new InMemoryXmlApplicationContext(context); - } - - private void setContext(String context, ApplicationContext parent) { - appContext = new InMemoryXmlApplicationContext(context, parent); } interface Foo { + void foo(T action); + } public static class ConcreteFoo implements Foo { + + @Override @PreAuthorize("#action.attribute == 'A'") public void foo(SecurityConfig action) { } + } } diff --git a/config/src/test/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecoratorTests.java b/config/src/test/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecoratorTests.java index 087d0e1bdc..0316daf0fe 100644 --- a/config/src/test/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecoratorTests.java +++ b/config/src/test/java/org/springframework/security/config/method/InterceptMethodsBeanDefinitionDecoratorTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; -import static org.assertj.core.api.Assertions.*; - -import org.junit.*; +import org.junit.After; +import org.junit.BeforeClass; +import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -37,19 +39,23 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Luke Taylor */ @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(locations = "classpath:org/springframework/security/config/method-security.xml") -public class InterceptMethodsBeanDefinitionDecoratorTests implements - ApplicationContextAware { +public class InterceptMethodsBeanDefinitionDecoratorTests implements ApplicationContextAware { + @Autowired @Qualifier("target") private TestBusinessBean target; + @Autowired @Qualifier("transactionalTarget") private TestBusinessBean transactionalTarget; + private ApplicationContext appContext; @BeforeClass @@ -65,50 +71,46 @@ public class InterceptMethodsBeanDefinitionDecoratorTests implements @Test public void targetDoesntLoseApplicationListenerInterface() { - assertThat(appContext.getBeansOfType(ApplicationListener.class)).hasSize(1); - assertThat(appContext.getBeanNamesForType(ApplicationListener.class)).hasSize(1); - appContext.publishEvent(new AuthenticationSuccessEvent( - new TestingAuthenticationToken("user", ""))); - - assertThat(target).isInstanceOf(ApplicationListener.class); + assertThat(this.appContext.getBeansOfType(ApplicationListener.class)).hasSize(1); + assertThat(this.appContext.getBeanNamesForType(ApplicationListener.class)).hasSize(1); + this.appContext.publishEvent(new AuthenticationSuccessEvent(new TestingAuthenticationToken("user", ""))); + assertThat(this.target).isInstanceOf(ApplicationListener.class); } @Test public void targetShouldAllowUnprotectedMethodInvocationWithNoContext() { - target.unprotected(); + this.target.unprotected(); } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void targetShouldPreventProtectedMethodInvocationWithNoContext() { - target.doSomething(); + this.target.doSomething(); } @Test public void targetShouldAllowProtectedMethodInvocationWithCorrectRole() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ROLE_USER")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContextHolder.getContext().setAuthentication(token); - - target.doSomething(); + this.target.doSomething(); } @Test(expected = AccessDeniedException.class) public void targetShouldPreventProtectedMethodInvocationWithIncorrectRole() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_SOMEOTHERROLE")); SecurityContextHolder.getContext().setAuthentication(token); - - target.doSomething(); + this.target.doSomething(); } @Test(expected = AuthenticationException.class) public void transactionalMethodsShouldBeSecured() { - transactionalTarget.doSomething(); + this.transactionalTarget.doSomething(); } - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.appContext = applicationContext; } + } diff --git a/config/src/test/java/org/springframework/security/config/method/Jsr250AnnotationDrivenBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/method/Jsr250AnnotationDrivenBeanDefinitionParserTests.java index c84bd208cb..571868f184 100644 --- a/config/src/test/java/org/springframework/security/config/method/Jsr250AnnotationDrivenBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/method/Jsr250AnnotationDrivenBeanDefinitionParserTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.annotation.BusinessService; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; @@ -31,66 +33,65 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Luke Taylor */ public class Jsr250AnnotationDrivenBeanDefinitionParserTests { + private InMemoryXmlApplicationContext appContext; private BusinessService target; @Before public void loadContext() { - appContext = new InMemoryXmlApplicationContext( + // @formatter:off + this.appContext = new InMemoryXmlApplicationContext( "" + "" + ConfigTestUtils.AUTH_PROVIDER_XML); - target = (BusinessService) appContext.getBean("target"); + // @formatter:on + this.target = (BusinessService) this.appContext.getBean("target"); } @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); + if (this.appContext != null) { + this.appContext.close(); } SecurityContextHolder.clearContext(); } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void targetShouldPreventProtectedMethodInvocationWithNoContext() { - target.someUserMethod1(); + this.target.someUserMethod1(); } @Test public void permitAllShouldBeDefaultAttribute() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ROLE_USER")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContextHolder.getContext().setAuthentication(token); - - target.someOther(0); + this.target.someOther(0); } @Test public void targetShouldAllowProtectedMethodInvocationWithCorrectRole() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ROLE_USER")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContextHolder.getContext().setAuthentication(token); - - target.someUserMethod1(); + this.target.someUserMethod1(); } @Test(expected = AccessDeniedException.class) public void targetShouldPreventProtectedMethodInvocationWithIncorrectRole() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_SOMEOTHERROLE")); SecurityContextHolder.getContext().setAuthentication(token); - - target.someAdminMethod(); + this.target.someAdminMethod(); } @Test public void hasAnyRoleAddsDefaultPrefix() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ROLE_USER")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContextHolder.getContext().setAuthentication(token); - - target.rolesAllowedUser(); + this.target.rolesAllowedUser(); } + } diff --git a/config/src/test/java/org/springframework/security/config/method/PreAuthorizeAdminRole.java b/config/src/test/java/org/springframework/security/config/method/PreAuthorizeAdminRole.java index 46c2d4cefb..0401ff3ebf 100644 --- a/config/src/test/java/org/springframework/security/config/method/PreAuthorizeAdminRole.java +++ b/config/src/test/java/org/springframework/security/config/method/PreAuthorizeAdminRole.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.lang.annotation.Retention; diff --git a/config/src/test/java/org/springframework/security/config/method/PreAuthorizeServiceImpl.java b/config/src/test/java/org/springframework/security/config/method/PreAuthorizeServiceImpl.java index 48e9e24ba2..bd3721a5db 100644 --- a/config/src/test/java/org/springframework/security/config/method/PreAuthorizeServiceImpl.java +++ b/config/src/test/java/org/springframework/security/config/method/PreAuthorizeServiceImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; /** @@ -22,9 +23,11 @@ package org.springframework.security.config.method; public class PreAuthorizeServiceImpl { @PreAuthorizeAdminRole - public void preAuthorizeAdminRole() {} + public void preAuthorizeAdminRole() { + } @ContactPermission - public void contactPermission(Contact contact) {} + public void contactPermission(Contact contact) { + } -} \ No newline at end of file +} diff --git a/config/src/test/java/org/springframework/security/config/method/PreAuthorizeTests.java b/config/src/test/java/org/springframework/security/config/method/PreAuthorizeTests.java index 676c31314c..91c10958e3 100644 --- a/config/src/test/java/org/springframework/security/config/method/PreAuthorizeTests.java +++ b/config/src/test/java/org/springframework/security/config/method/PreAuthorizeTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -26,13 +28,13 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; /** - * * @author Rob Winch * */ @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration public class PreAuthorizeTests { + @Autowired PreAuthorizeServiceImpl service; @@ -43,25 +45,30 @@ public class PreAuthorizeTests { @Test(expected = AccessDeniedException.class) public void preAuthorizeAdminRoleDenied() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_USER")); - service.preAuthorizeAdminRole(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_USER")); + this.service.preAuthorizeAdminRole(); } @Test public void preAuthorizeAdminRoleGranted() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); - service.preAuthorizeAdminRole(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); + this.service.preAuthorizeAdminRole(); } @Test public void preAuthorizeContactPermissionGranted() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); - service.contactPermission(new Contact("user")); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); + this.service.contactPermission(new Contact("user")); } @Test(expected = AccessDeniedException.class) public void preAuthorizeContactPermissionDenied() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); - service.contactPermission(new Contact("admin")); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); + this.service.contactPermission(new Contact("admin")); } -} \ No newline at end of file + +} diff --git a/config/src/test/java/org/springframework/security/config/method/Sec2196Tests.java b/config/src/test/java/org/springframework/security/config/method/Sec2196Tests.java index 6e1c38aade..7814a7f7d6 100644 --- a/config/src/test/java/org/springframework/security/config/method/Sec2196Tests.java +++ b/config/src/test/java/org/springframework/security/config/method/Sec2196Tests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import org.junit.After; import org.junit.Test; + import org.springframework.context.ConfigurableApplicationContext; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.prepost.PreAuthorize; @@ -36,10 +38,9 @@ public class Sec2196Tests { public void genericMethodsProtected() { loadContext("" + ""); - - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("test", "pass", "ROLE_USER")); - Service service = context.getBean(Service.class); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("test", "pass", "ROLE_USER")); + Service service = this.context.getBean(Service.class); service.save(new User()); } @@ -47,10 +48,9 @@ public class Sec2196Tests { public void genericMethodsAllowed() { loadContext("" + ""); - - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("test", "pass", "saveUsers")); - Service service = context.getBean(Service.class); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("test", "pass", "saveUsers")); + Service service = this.context.getBean(Service.class); service.save(new User()); } @@ -60,20 +60,24 @@ public class Sec2196Tests { @After public void closeAppContext() { - if (context != null) { - context.close(); - context = null; + if (this.context != null) { + this.context.close(); + this.context = null; } SecurityContextHolder.clearContext(); } public static class Service { + @PreAuthorize("hasAuthority('saveUsers')") public T save(T dto) { return dto; } + } static class User { + } + } diff --git a/config/src/test/java/org/springframework/security/config/method/SecuredAdminRole.java b/config/src/test/java/org/springframework/security/config/method/SecuredAdminRole.java index 8140450c86..629dcaa7b2 100644 --- a/config/src/test/java/org/springframework/security/config/method/SecuredAdminRole.java +++ b/config/src/test/java/org/springframework/security/config/method/SecuredAdminRole.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.lang.annotation.Retention; @@ -26,4 +27,6 @@ import org.springframework.security.access.annotation.Secured; */ @Retention(RetentionPolicy.RUNTIME) @Secured("ROLE_ADMIN") -public @interface SecuredAdminRole { } \ No newline at end of file +public @interface SecuredAdminRole { + +} diff --git a/config/src/test/java/org/springframework/security/config/method/SecuredAnnotationDrivenBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/method/SecuredAnnotationDrivenBeanDefinitionParserTests.java index 0f9af88507..c7f1cc54a2 100644 --- a/config/src/test/java/org/springframework/security/config/method/SecuredAnnotationDrivenBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/method/SecuredAnnotationDrivenBeanDefinitionParserTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.io.ByteArrayInputStream; @@ -24,6 +25,7 @@ import java.io.ObjectOutputStream; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.annotation.BusinessService; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; @@ -38,6 +40,7 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Ben Alex */ public class SecuredAnnotationDrivenBeanDefinitionParserTests { + private InMemoryXmlApplicationContext appContext; private BusinessService target; @@ -45,78 +48,71 @@ public class SecuredAnnotationDrivenBeanDefinitionParserTests { @Before public void loadContext() { SecurityContextHolder.clearContext(); - appContext = new InMemoryXmlApplicationContext( + this.appContext = new InMemoryXmlApplicationContext( "" + "" + ConfigTestUtils.AUTH_PROVIDER_XML); - target = (BusinessService) appContext.getBean("target"); + this.target = (BusinessService) this.appContext.getBean("target"); } @After public void closeAppContext() { - if (appContext != null) { - appContext.close(); + if (this.appContext != null) { + this.appContext.close(); } SecurityContextHolder.clearContext(); } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void targetShouldPreventProtectedMethodInvocationWithNoContext() { - target.someUserMethod1(); + this.target.someUserMethod1(); } @Test public void targetShouldAllowProtectedMethodInvocationWithCorrectRole() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ROLE_USER")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContextHolder.getContext().setAuthentication(token); - - target.someUserMethod1(); + this.target.someUserMethod1(); } @Test(expected = AccessDeniedException.class) public void targetShouldPreventProtectedMethodInvocationWithIncorrectRole() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ROLE_SOMEOTHER")); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ROLE_SOMEOTHER")); SecurityContextHolder.getContext().setAuthentication(token); - - target.someAdminMethod(); + this.target.someAdminMethod(); } // SEC-1387 @Test(expected = AuthenticationCredentialsNotFoundException.class) public void targetIsSerializableBeforeUse() throws Exception { - BusinessService chompedTarget = (BusinessService) serializeAndDeserialize(target); + BusinessService chompedTarget = (BusinessService) serializeAndDeserialize(this.target); chompedTarget.someAdminMethod(); } @Test(expected = AccessDeniedException.class) public void targetIsSerializableAfterUse() throws Exception { try { - target.someAdminMethod(); + this.target.someAdminMethod(); } catch (AuthenticationCredentialsNotFoundException expected) { } - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("u", "p", "ROLE_A")); - - BusinessService chompedTarget = (BusinessService) serializeAndDeserialize(target); + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("u", "p", "ROLE_A")); + BusinessService chompedTarget = (BusinessService) serializeAndDeserialize(this.target); chompedTarget.someAdminMethod(); } - private Object serializeAndDeserialize(Object o) throws IOException, - ClassNotFoundException { + private Object serializeAndDeserialize(Object o) throws IOException, ClassNotFoundException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(o); oos.flush(); baos.flush(); byte[] bytes = baos.toByteArray(); - ByteArrayInputStream is = new ByteArrayInputStream(bytes); ObjectInputStream ois = new ObjectInputStream(is); Object o2 = ois.readObject(); - return o2; } diff --git a/config/src/test/java/org/springframework/security/config/method/SecuredServiceImpl.java b/config/src/test/java/org/springframework/security/config/method/SecuredServiceImpl.java index 5dd48e224b..582f51dd0d 100644 --- a/config/src/test/java/org/springframework/security/config/method/SecuredServiceImpl.java +++ b/config/src/test/java/org/springframework/security/config/method/SecuredServiceImpl.java @@ -13,14 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; /** -* -* @author Rob Winch -* -*/ + * @author Rob Winch + * + */ public class SecuredServiceImpl { + @SecuredAdminRole - public void securedAdminRole() {} + public void securedAdminRole() { + } + } diff --git a/config/src/test/java/org/springframework/security/config/method/SecuredTests.java b/config/src/test/java/org/springframework/security/config/method/SecuredTests.java index 01788afc0b..607b164c67 100644 --- a/config/src/test/java/org/springframework/security/config/method/SecuredTests.java +++ b/config/src/test/java/org/springframework/security/config/method/SecuredTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -26,13 +28,13 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; /** - * * @author Rob Winch * */ @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration public class SecuredTests { + @Autowired SecuredServiceImpl service; @@ -43,13 +45,16 @@ public class SecuredTests { @Test(expected = AccessDeniedException.class) public void securedAdminRoleDenied() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_USER")); - service.securedAdminRole(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_USER")); + this.service.securedAdminRole(); } @Test public void securedAdminRoleGranted() { - SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); - service.securedAdminRole(); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_ADMIN")); + this.service.securedAdminRole(); } -} \ No newline at end of file + +} diff --git a/config/src/test/java/org/springframework/security/config/method/TestPermissionEvaluator.java b/config/src/test/java/org/springframework/security/config/method/TestPermissionEvaluator.java index 087250407b..0d6bbe37d8 100644 --- a/config/src/test/java/org/springframework/security/config/method/TestPermissionEvaluator.java +++ b/config/src/test/java/org/springframework/security/config/method/TestPermissionEvaluator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method; import java.io.Serializable; @@ -22,13 +23,14 @@ import org.springframework.security.core.Authentication; public class TestPermissionEvaluator implements PermissionEvaluator { - public boolean hasPermission(Authentication authentication, - Object targetDomainObject, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { return false; } - public boolean hasPermission(Authentication authentication, Serializable targetId, - String targetType, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, + Object permission) { return false; } diff --git a/config/src/test/java/org/springframework/security/config/method/configuration/Gh4020GlobalMethodSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/method/configuration/Gh4020GlobalMethodSecurityConfigurationTests.java index 4dcb611c12..ac1bc7b814 100644 --- a/config/src/test/java/org/springframework/security/config/method/configuration/Gh4020GlobalMethodSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/method/configuration/Gh4020GlobalMethodSecurityConfigurationTests.java @@ -39,6 +39,7 @@ import static org.mockito.Mockito.mock; @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration public class Gh4020GlobalMethodSecurityConfigurationTests { + @Autowired DenyAllService denyAll; @@ -51,6 +52,7 @@ public class Gh4020GlobalMethodSecurityConfigurationTests { @Configuration @EnableGlobalMethodSecurity(prePostEnabled = true) static class SecurityConfig { + @Bean PermissionEvaluator permissionEvaluator() { return mock(PermissionEvaluator.class); @@ -68,19 +70,25 @@ public class Gh4020GlobalMethodSecurityConfigurationTests { @Autowired DenyAllService denyAll; + } @Configuration static class ServiceConfig { + @Bean DenyAllService denyAllService() { return new DenyAllService(); } + } @PreAuthorize("denyAll") static class DenyAllService { + void denyAll() { } + } + } diff --git a/config/src/test/java/org/springframework/security/config/method/sec2136/JpaPermissionEvaluator.java b/config/src/test/java/org/springframework/security/config/method/sec2136/JpaPermissionEvaluator.java index d664cd6956..d186330f0e 100644 --- a/config/src/test/java/org/springframework/security/config/method/sec2136/JpaPermissionEvaluator.java +++ b/config/src/test/java/org/springframework/security/config/method/sec2136/JpaPermissionEvaluator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method.sec2136; import java.io.Serializable; @@ -24,11 +25,11 @@ import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.core.Authentication; /** - * * @author Rob Winch * */ public class JpaPermissionEvaluator implements PermissionEvaluator { + @Autowired private EntityManager entityManager; @@ -36,13 +37,15 @@ public class JpaPermissionEvaluator implements PermissionEvaluator { System.out.println("initializing " + this); } - public boolean hasPermission(Authentication authentication, - Object targetDomainObject, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { return true; } - public boolean hasPermission(Authentication authentication, Serializable targetId, - String targetType, Object permission) { + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, + Object permission) { return true; } + } diff --git a/config/src/test/java/org/springframework/security/config/method/sec2136/Sec2136Tests.java b/config/src/test/java/org/springframework/security/config/method/sec2136/Sec2136Tests.java index e921086911..4951e56206 100644 --- a/config/src/test/java/org/springframework/security/config/method/sec2136/Sec2136Tests.java +++ b/config/src/test/java/org/springframework/security/config/method/sec2136/Sec2136Tests.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method.sec2136; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; /** - * * @author Rob Winch * @since 3.2 */ @@ -31,6 +32,6 @@ public class Sec2136Tests { @Test public void configurationLoads() { - } + } diff --git a/config/src/test/java/org/springframework/security/config/method/sec2499/Sec2499Tests.java b/config/src/test/java/org/springframework/security/config/method/sec2499/Sec2499Tests.java index 2fb5200bbb..745c60eaa5 100644 --- a/config/src/test/java/org/springframework/security/config/method/sec2499/Sec2499Tests.java +++ b/config/src/test/java/org/springframework/security/config/method/sec2499/Sec2499Tests.java @@ -13,39 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.method.sec2499; import org.junit.After; import org.junit.Test; + import org.springframework.context.support.GenericXmlApplicationContext; /** - * * @author Rob Winch * */ public class Sec2499Tests { + private GenericXmlApplicationContext parent; private GenericXmlApplicationContext child; @After public void cleanup() { - if (parent != null) { - parent.close(); + if (this.parent != null) { + this.parent.close(); } - if (child != null) { - child.close(); + if (this.child != null) { + this.child.close(); } } @Test public void methodExpressionHandlerInParentContextLoads() { - parent = new GenericXmlApplicationContext( - "org/springframework/security/config/method/sec2499/parent.xml"); - child = new GenericXmlApplicationContext(); - child.load("org/springframework/security/config/method/sec2499/child.xml"); - child.setParent(parent); - child.refresh(); + this.parent = new GenericXmlApplicationContext("org/springframework/security/config/method/sec2499/parent.xml"); + this.child = new GenericXmlApplicationContext(); + this.child.load("org/springframework/security/config/method/sec2499/child.xml"); + this.child.setParent(this.parent); + this.child.refresh(); } -} \ No newline at end of file + +} diff --git a/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java index 308fa95b62..74cfdb5672 100644 --- a/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.oauth2.client; import okhttp3.mockwebserver.MockResponse; @@ -20,6 +21,7 @@ import okhttp3.mockwebserver.MockWebServer; import org.junit.After; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -41,27 +43,30 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Ruby Hartono */ public class ClientRegistrationsBeanDefinitionParserTests { + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests"; - private static final String ISSUER_URI_XML_CONFIG = "\n" + - "\n" + - "\t\n" + - "\t\t\n" + - "\t\t\n" + - "\t\n" + - "\n" + - "\n"; + // @formatter:off + private static final String ISSUER_URI_XML_CONFIG = "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "\n" + + "\n"; + // @formatter:on - private static final String OIDC_DISCOVERY_RESPONSE = - "{\n" + // @formatter:off + private static final String OIDC_DISCOVERY_RESPONSE = "{\n" + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" + " \"claims_supported\": [\n" + " \"aud\", \n" @@ -114,6 +119,7 @@ public class ClientRegistrationsBeanDefinitionParserTests { + " ], \n" + " \"userinfo_endpoint\": \"https://example.com/oauth2/v3/userinfo\"\n" + "}"; + // @formatter:on @Autowired private ClientRegistrationRepository clientRegistrationRepository; @@ -135,16 +141,12 @@ public class ClientRegistrationsBeanDefinitionParserTests { this.server = new MockWebServer(); this.server.start(); String serverUrl = this.server.url("/").toString(); - String discoveryResponse = OIDC_DISCOVERY_RESPONSE.replace("${issuer-uri}", serverUrl); this.server.enqueue(jsonResponse(discoveryResponse)); - String contextConfig = ISSUER_URI_XML_CONFIG.replace("${issuer-uri}", serverUrl); this.spring.context(contextConfig).autowire(); - - assertThat(clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class); - - ClientRegistration googleRegistration = clientRegistrationRepository.findByRegistrationId("google-login"); + assertThat(this.clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class); + ClientRegistration googleRegistration = this.clientRegistrationRepository.findByRegistrationId("google-login"); assertThat(googleRegistration).isNotNull(); assertThat(googleRegistration.getRegistrationId()).isEqualTo("google-login"); assertThat(googleRegistration.getClientId()).isEqualTo("google-client-id"); @@ -154,7 +156,6 @@ public class ClientRegistrationsBeanDefinitionParserTests { assertThat(googleRegistration.getRedirectUri()).isEqualTo("{baseUrl}/{action}/oauth2/code/{registrationId}"); assertThat(googleRegistration.getScopes()).isNull(); assertThat(googleRegistration.getClientName()).isEqualTo(serverUrl); - ProviderDetails googleProviderDetails = googleRegistration.getProviderDetails(); assertThat(googleProviderDetails).isNotNull(); assertThat(googleProviderDetails.getAuthorizationUri()).isEqualTo("https://example.com/o/oauth2/v2/auth"); @@ -170,11 +171,10 @@ public class ClientRegistrationsBeanDefinitionParserTests { @Test public void parseWhenMultipleClientsConfiguredThenAvailableInRepository() { - this.spring.configLocations(this.xml("MultiClientRegistration")).autowire(); - - assertThat(clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class); - - ClientRegistration googleRegistration = clientRegistrationRepository.findByRegistrationId("google-login"); + this.spring.configLocations(ClientRegistrationsBeanDefinitionParserTests.xml("MultiClientRegistration")) + .autowire(); + assertThat(this.clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class); + ClientRegistration googleRegistration = this.clientRegistrationRepository.findByRegistrationId("google-login"); assertThat(googleRegistration).isNotNull(); assertThat(googleRegistration.getRegistrationId()).isEqualTo("google-login"); assertThat(googleRegistration.getClientId()).isEqualTo("google-client-id"); @@ -182,9 +182,9 @@ public class ClientRegistrationsBeanDefinitionParserTests { assertThat(googleRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); assertThat(googleRegistration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(googleRegistration.getRedirectUri()).isEqualTo("{baseUrl}/login/oauth2/code/{registrationId}"); - assertThat(googleRegistration.getScopes()).isEqualTo(StringUtils.commaDelimitedListToSet("openid,profile,email")); + assertThat(googleRegistration.getScopes()) + .isEqualTo(StringUtils.commaDelimitedListToSet("openid,profile,email")); assertThat(googleRegistration.getClientName()).isEqualTo("Google"); - ProviderDetails googleProviderDetails = googleRegistration.getProviderDetails(); assertThat(googleProviderDetails).isNotNull(); assertThat(googleProviderDetails.getAuthorizationUri()) @@ -197,8 +197,7 @@ public class ClientRegistrationsBeanDefinitionParserTests { assertThat(googleProviderDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo("sub"); assertThat(googleProviderDetails.getJwkSetUri()).isEqualTo("https://www.googleapis.com/oauth2/v3/certs"); assertThat(googleProviderDetails.getIssuerUri()).isEqualTo("https://accounts.google.com"); - - ClientRegistration githubRegistration = clientRegistrationRepository.findByRegistrationId("github-login"); + ClientRegistration githubRegistration = this.clientRegistrationRepository.findByRegistrationId("github-login"); assertThat(githubRegistration).isNotNull(); assertThat(githubRegistration.getRegistrationId()).isEqualTo("github-login"); assertThat(githubRegistration.getClientId()).isEqualTo("github-client-id"); @@ -206,9 +205,9 @@ public class ClientRegistrationsBeanDefinitionParserTests { assertThat(githubRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); assertThat(githubRegistration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(githubRegistration.getRedirectUri()).isEqualTo("{baseUrl}/login/oauth2/code/{registrationId}"); - assertThat(googleRegistration.getScopes()).isEqualTo(StringUtils.commaDelimitedListToSet("openid,profile,email")); + assertThat(googleRegistration.getScopes()) + .isEqualTo(StringUtils.commaDelimitedListToSet("openid,profile,email")); assertThat(githubRegistration.getClientName()).isEqualTo("Github"); - ProviderDetails githubProviderDetails = githubRegistration.getProviderDetails(); assertThat(githubProviderDetails).isNotNull(); assertThat(githubProviderDetails.getAuthorizationUri()).isEqualTo("https://github.com/login/oauth/authorize"); @@ -220,12 +219,11 @@ public class ClientRegistrationsBeanDefinitionParserTests { } private static MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } private static String xml(String configName) { return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml"; } + } diff --git a/config/src/test/java/org/springframework/security/config/oauth2/client/CommonOAuth2ProviderTests.java b/config/src/test/java/org/springframework/security/config/oauth2/client/CommonOAuth2ProviderTests.java index ad18946aca..65d4ffd53a 100644 --- a/config/src/test/java/org/springframework/security/config/oauth2/client/CommonOAuth2ProviderTests.java +++ b/config/src/test/java/org/springframework/security/config/oauth2/client/CommonOAuth2ProviderTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.oauth2.client; import org.junit.Test; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -37,22 +39,15 @@ public class CommonOAuth2ProviderTests { public void getBuilderWhenGoogleShouldHaveGoogleSettings() { ClientRegistration registration = build(CommonOAuth2Provider.GOOGLE); ProviderDetails providerDetails = registration.getProviderDetails(); - assertThat(providerDetails.getAuthorizationUri()) - .isEqualTo("https://accounts.google.com/o/oauth2/v2/auth"); - assertThat(providerDetails.getTokenUri()) - .isEqualTo("https://www.googleapis.com/oauth2/v4/token"); + assertThat(providerDetails.getAuthorizationUri()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth"); + assertThat(providerDetails.getTokenUri()).isEqualTo("https://www.googleapis.com/oauth2/v4/token"); assertThat(providerDetails.getUserInfoEndpoint().getUri()) - .isEqualTo("https://www.googleapis.com/oauth2/v3/userinfo"); - assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()) - .isEqualTo(IdTokenClaimNames.SUB); - assertThat(providerDetails.getJwkSetUri()) - .isEqualTo("https://www.googleapis.com/oauth2/v3/certs"); - assertThat(providerDetails.getIssuerUri()) - .isEqualTo("https://accounts.google.com"); - assertThat(registration.getClientAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC); - assertThat(registration.getAuthorizationGrantType()) - .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + .isEqualTo("https://www.googleapis.com/oauth2/v3/userinfo"); + assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo(IdTokenClaimNames.SUB); + assertThat(providerDetails.getJwkSetUri()).isEqualTo("https://www.googleapis.com/oauth2/v3/certs"); + assertThat(providerDetails.getIssuerUri()).isEqualTo("https://accounts.google.com"); + assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); + assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(registration.getRedirectUri()).isEqualTo(DEFAULT_REDIRECT_URL); assertThat(registration.getScopes()).containsOnly("openid", "profile", "email"); assertThat(registration.getClientName()).isEqualTo("Google"); @@ -63,19 +58,13 @@ public class CommonOAuth2ProviderTests { public void getBuilderWhenGitHubShouldHaveGitHubSettings() { ClientRegistration registration = build(CommonOAuth2Provider.GITHUB); ProviderDetails providerDetails = registration.getProviderDetails(); - assertThat(providerDetails.getAuthorizationUri()) - .isEqualTo("https://github.com/login/oauth/authorize"); - assertThat(providerDetails.getTokenUri()) - .isEqualTo("https://github.com/login/oauth/access_token"); - assertThat(providerDetails.getUserInfoEndpoint().getUri()) - .isEqualTo("https://api.github.com/user"); - assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()) - .isEqualTo("id"); + assertThat(providerDetails.getAuthorizationUri()).isEqualTo("https://github.com/login/oauth/authorize"); + assertThat(providerDetails.getTokenUri()).isEqualTo("https://github.com/login/oauth/access_token"); + assertThat(providerDetails.getUserInfoEndpoint().getUri()).isEqualTo("https://api.github.com/user"); + assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo("id"); assertThat(providerDetails.getJwkSetUri()).isNull(); - assertThat(registration.getClientAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC); - assertThat(registration.getAuthorizationGrantType()) - .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); + assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(registration.getRedirectUri()).isEqualTo(DEFAULT_REDIRECT_URL); assertThat(registration.getScopes()).containsOnly("read:user"); assertThat(registration.getClientName()).isEqualTo("GitHub"); @@ -86,19 +75,14 @@ public class CommonOAuth2ProviderTests { public void getBuilderWhenFacebookShouldHaveFacebookSettings() { ClientRegistration registration = build(CommonOAuth2Provider.FACEBOOK); ProviderDetails providerDetails = registration.getProviderDetails(); - assertThat(providerDetails.getAuthorizationUri()) - .isEqualTo("https://www.facebook.com/v2.8/dialog/oauth"); - assertThat(providerDetails.getTokenUri()) - .isEqualTo("https://graph.facebook.com/v2.8/oauth/access_token"); + assertThat(providerDetails.getAuthorizationUri()).isEqualTo("https://www.facebook.com/v2.8/dialog/oauth"); + assertThat(providerDetails.getTokenUri()).isEqualTo("https://graph.facebook.com/v2.8/oauth/access_token"); assertThat(providerDetails.getUserInfoEndpoint().getUri()) - .isEqualTo("https://graph.facebook.com/me?fields=id,name,email"); - assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()) - .isEqualTo("id"); + .isEqualTo("https://graph.facebook.com/me?fields=id,name,email"); + assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo("id"); assertThat(providerDetails.getJwkSetUri()).isNull(); - assertThat(registration.getClientAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.POST); - assertThat(registration.getAuthorizationGrantType()) - .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.POST); + assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(registration.getRedirectUri()).isEqualTo(DEFAULT_REDIRECT_URL); assertThat(registration.getScopes()).containsOnly("public_profile", "email"); assertThat(registration.getClientName()).isEqualTo("Facebook"); @@ -108,22 +92,16 @@ public class CommonOAuth2ProviderTests { @Test public void getBuilderWhenOktaShouldHaveOktaSettings() { ClientRegistration registration = builder(CommonOAuth2Provider.OKTA) - .authorizationUri("https://example.com/auth") - .tokenUri("https://example.com/token") - .userInfoUri("https://example.com/info") - .jwkSetUri("https://example.com/jwkset").build(); + .authorizationUri("https://example.com/auth").tokenUri("https://example.com/token") + .userInfoUri("https://example.com/info").jwkSetUri("https://example.com/jwkset").build(); ProviderDetails providerDetails = registration.getProviderDetails(); - assertThat(providerDetails.getAuthorizationUri()) - .isEqualTo("https://example.com/auth"); + assertThat(providerDetails.getAuthorizationUri()).isEqualTo("https://example.com/auth"); assertThat(providerDetails.getTokenUri()).isEqualTo("https://example.com/token"); assertThat(providerDetails.getUserInfoEndpoint().getUri()).isEqualTo("https://example.com/info"); - assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()) - .isEqualTo(IdTokenClaimNames.SUB); + assertThat(providerDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo(IdTokenClaimNames.SUB); assertThat(providerDetails.getJwkSetUri()).isEqualTo("https://example.com/jwkset"); - assertThat(registration.getClientAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC); - assertThat(registration.getAuthorizationGrantType()) - .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); + assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(registration.getRedirectUri()).isEqualTo(DEFAULT_REDIRECT_URL); assertThat(registration.getScopes()).containsOnly("openid", "profile", "email"); assertThat(registration.getClientName()).isEqualTo("Okta"); @@ -135,9 +113,7 @@ public class CommonOAuth2ProviderTests { } private ClientRegistration.Builder builder(CommonOAuth2Provider provider) { - return provider.getBuilder("123") - .clientId("abcd") - .clientSecret("secret"); + return provider.getBuilder("123").clientId("abcd").clientSecret("secret"); } } diff --git a/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceITests.java b/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceITests.java index 5548c5df46..d1c937dd65 100644 --- a/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceITests.java +++ b/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceITests.java @@ -16,9 +16,9 @@ package org.springframework.security.config.provisioning; - import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -26,7 +26,7 @@ import org.springframework.security.provisioning.UserDetailsManager; import org.springframework.security.util.InMemoryResource; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -34,19 +34,23 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; */ @RunWith(SpringRunner.class) public class UserDetailsManagerResourceFactoryBeanPropertiesResourceITests { + @Autowired UserDetailsManager users; @Test public void loadUserByUsernameWhenUserFoundThenNotNull() { - assertThat(users.loadUserByUsername("user")).isNotNull(); + assertThat(this.users.loadUserByUsername("user")).isNotNull(); } @Configuration static class Config { + @Bean - public UserDetailsManagerResourceFactoryBean userDetailsService() { + UserDetailsManagerResourceFactoryBean userDetailsService() { return UserDetailsManagerResourceFactoryBean.fromResource(new InMemoryResource("user=password,ROLE_USER")); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceLocationITests.java b/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceLocationITests.java index 3c524961ce..5138dd984d 100644 --- a/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceLocationITests.java +++ b/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanPropertiesResourceLocationITests.java @@ -16,16 +16,16 @@ package org.springframework.security.config.provisioning; - import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.provisioning.UserDetailsManager; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -33,19 +33,23 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; */ @RunWith(SpringRunner.class) public class UserDetailsManagerResourceFactoryBeanPropertiesResourceLocationITests { + @Autowired UserDetailsManager users; @Test public void loadUserByUsernameWhenUserFoundThenNotNull() { - assertThat(users.loadUserByUsername("user")).isNotNull(); + assertThat(this.users.loadUserByUsername("user")).isNotNull(); } @Configuration static class Config { + @Bean - public UserDetailsManagerResourceFactoryBean userDetailsService() { + UserDetailsManagerResourceFactoryBean userDetailsService() { return UserDetailsManagerResourceFactoryBean.fromResourceLocation("classpath:users.properties"); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanStringITests.java b/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanStringITests.java index 1a3cab08f1..af4fccbd23 100644 --- a/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanStringITests.java +++ b/config/src/test/java/org/springframework/security/config/provisioning/UserDetailsManagerResourceFactoryBeanStringITests.java @@ -16,16 +16,16 @@ package org.springframework.security.config.provisioning; - import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.provisioning.UserDetailsManager; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -33,6 +33,7 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; */ @RunWith(SpringRunner.class) public class UserDetailsManagerResourceFactoryBeanStringITests { + @Autowired UserDetailsManager users; @@ -43,9 +44,12 @@ public class UserDetailsManagerResourceFactoryBeanStringITests { @Configuration static class Config { + @Bean - public UserDetailsManagerResourceFactoryBean userDetailsService() { + UserDetailsManagerResourceFactoryBean userDetailsService() { return UserDetailsManagerResourceFactoryBean.fromString("user=password,ROLE_USER"); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java b/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java index 09d580acc7..fc4182deb0 100644 --- a/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java +++ b/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java @@ -16,9 +16,19 @@ package org.springframework.security.config.test; +import java.io.Closeable; +import java.util.ArrayList; +import java.util.List; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; import org.springframework.mock.web.MockServletConfig; import org.springframework.mock.web.MockServletContext; +import org.springframework.security.config.BeanIds; import org.springframework.security.config.util.InMemoryXmlWebApplicationContext; import org.springframework.test.context.web.GenericXmlWebContextLoader; import org.springframework.test.web.servlet.MockMvc; @@ -32,15 +42,6 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import org.springframework.web.context.support.XmlWebApplicationContext; import org.springframework.web.filter.OncePerRequestFilter; -import javax.servlet.Filter; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.Closeable; -import java.util.ArrayList; -import java.util.List; - -import static org.springframework.security.config.BeanIds.SPRING_SECURITY_FILTER_CHAIN; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; /** @@ -48,6 +49,7 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv * @since 5.0 */ public class SpringTestContext implements Closeable { + private Object test; private ConfigurableWebApplicationContext context; @@ -62,7 +64,9 @@ public class SpringTestContext implements Closeable { public void close() { try { this.context.close(); - } catch(Exception e) {} + } + catch (Exception ex) { + } } public SpringTestContext context(ConfigurableWebApplicationContext context) { @@ -79,8 +83,7 @@ public class SpringTestContext implements Closeable { public SpringTestContext testConfigLocations(String... configLocations) { GenericXmlWebContextLoader loader = new GenericXmlWebContextLoader(); - String[] locations = loader.processLocations(this.test.getClass(), - configLocations); + String[] locations = loader.processLocations(this.test.getClass(), configLocations); return configLocations(locations); } @@ -100,8 +103,8 @@ public class SpringTestContext implements Closeable { public SpringTestContext mockMvcAfterSpringSecurityOk() { return addFilter(new OncePerRequestFilter() { @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) { + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) { response.setStatus(HttpServletResponse.SC_OK); } }); @@ -125,25 +128,29 @@ public class SpringTestContext implements Closeable { this.context.setServletContext(new MockServletContext()); this.context.setServletConfig(new MockServletConfig()); this.context.refresh(); - - if (this.context.containsBean(SPRING_SECURITY_FILTER_CHAIN)) { - MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.context) - .apply(springSecurity()) - .apply(new AddFilter()).build(); - this.context.getBeanFactory() - .registerResolvableDependency(MockMvc.class, mockMvc); + if (this.context.containsBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN)) { + // @formatter:off + MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.context). + apply(springSecurity()) + .apply(new AddFilter()) + .build(); + // @formatter:on + this.context.getBeanFactory().registerResolvableDependency(MockMvc.class, mockMvc); } - AutowiredAnnotationBeanPostProcessor bpp = new AutowiredAnnotationBeanPostProcessor(); bpp.setBeanFactory(this.context.getBeanFactory()); bpp.processInjection(this.test); } private class AddFilter implements MockMvcConfigurer { - public RequestPostProcessor beforeMockMvcCreated( - ConfigurableMockMvcBuilder builder, WebApplicationContext context) { + + @Override + public RequestPostProcessor beforeMockMvcCreated(ConfigurableMockMvcBuilder builder, + WebApplicationContext context) { builder.addFilters(SpringTestContext.this.filters.toArray(new Filter[0])); return null; } + } + } diff --git a/config/src/test/java/org/springframework/security/config/test/SpringTestRule.java b/config/src/test/java/org/springframework/security/config/test/SpringTestRule.java index 192830f5b5..df0834c895 100644 --- a/config/src/test/java/org/springframework/security/config/test/SpringTestRule.java +++ b/config/src/test/java/org/springframework/security/config/test/SpringTestRule.java @@ -19,6 +19,7 @@ package org.springframework.security.config.test; import org.junit.rules.MethodRule; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.Statement; + import org.springframework.security.test.context.TestSecurityContextHolder; /** @@ -26,19 +27,22 @@ import org.springframework.security.test.context.TestSecurityContextHolder; * @since 5.0 */ public class SpringTestRule extends SpringTestContext implements MethodRule { + @Override - public Statement apply(Statement base, FrameworkMethod method, - Object target) { + public Statement apply(Statement base, FrameworkMethod method, Object target) { return new Statement() { + @Override public void evaluate() throws Throwable { setTest(target); try { base.evaluate(); - } finally { + } + finally { TestSecurityContextHolder.clearContext(); close(); } } }; } + } diff --git a/config/src/test/java/org/springframework/security/config/users/AuthenticationTestConfiguration.java b/config/src/test/java/org/springframework/security/config/users/AuthenticationTestConfiguration.java index d00036b61f..c8960e2dc6 100644 --- a/config/src/test/java/org/springframework/security/config/users/AuthenticationTestConfiguration.java +++ b/config/src/test/java/org/springframework/security/config/users/AuthenticationTestConfiguration.java @@ -28,8 +28,10 @@ import org.springframework.security.provisioning.InMemoryUserDetailsManager; */ @Configuration public class AuthenticationTestConfiguration { + @Bean public static UserDetailsService userDetailsService() { return new InMemoryUserDetailsManager(PasswordEncodedUser.user(), PasswordEncodedUser.admin()); } + } diff --git a/config/src/test/java/org/springframework/security/config/users/ReactiveAuthenticationTestConfiguration.java b/config/src/test/java/org/springframework/security/config/users/ReactiveAuthenticationTestConfiguration.java index 27862b5294..c32addfc41 100644 --- a/config/src/test/java/org/springframework/security/config/users/ReactiveAuthenticationTestConfiguration.java +++ b/config/src/test/java/org/springframework/security/config/users/ReactiveAuthenticationTestConfiguration.java @@ -28,8 +28,10 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsService; */ @Configuration public class ReactiveAuthenticationTestConfiguration { + @Bean public static ReactiveUserDetailsService userDetailsService() { return new MapReactiveUserDetailsService(PasswordEncodedUser.user(), PasswordEncodedUser.admin()); } + } diff --git a/config/src/test/java/org/springframework/security/config/util/InMemoryXmlApplicationContext.java b/config/src/test/java/org/springframework/security/config/util/InMemoryXmlApplicationContext.java index ac7b485f75..865462456d 100644 --- a/config/src/test/java/org/springframework/security/config/util/InMemoryXmlApplicationContext.java +++ b/config/src/test/java/org/springframework/security/config/util/InMemoryXmlApplicationContext.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.util; import org.springframework.beans.factory.support.DefaultListableBeanFactory; @@ -26,6 +27,7 @@ import org.springframework.security.util.InMemoryResource; * @author Eddú Meléndez */ public class InMemoryXmlApplicationContext extends AbstractXmlApplicationContext { + static final String BEANS_OPENING = "\n" + xml + BEANS_CLOSE; - inMemoryXml = new InMemoryResource(fullXml); + this.inMemoryXml = new InMemoryResource(fullXml); setAllowBeanDefinitionOverriding(true); setParent(parent); refresh(); @@ -72,7 +72,9 @@ public class InMemoryXmlApplicationContext extends AbstractXmlApplicationContext }; } + @Override protected Resource[] getConfigResources() { - return new Resource[] { inMemoryXml }; + return new Resource[] { this.inMemoryXml }; } + } diff --git a/config/src/test/java/org/springframework/security/config/util/InMemoryXmlWebApplicationContext.java b/config/src/test/java/org/springframework/security/config/util/InMemoryXmlWebApplicationContext.java index 96c8e758f1..72992ecef0 100644 --- a/config/src/test/java/org/springframework/security/config/util/InMemoryXmlWebApplicationContext.java +++ b/config/src/test/java/org/springframework/security/config/util/InMemoryXmlWebApplicationContext.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.util; import org.springframework.beans.BeansException; @@ -23,28 +24,25 @@ import org.springframework.core.io.Resource; import org.springframework.security.util.InMemoryResource; import org.springframework.web.context.support.AbstractRefreshableWebApplicationContext; -import static org.springframework.security.config.util.InMemoryXmlApplicationContext.BEANS_CLOSE; -import static org.springframework.security.config.util.InMemoryXmlApplicationContext.BEANS_OPENING; -import static org.springframework.security.config.util.InMemoryXmlApplicationContext.SPRING_SECURITY_VERSION; - /** * @author Joe Grandja */ public class InMemoryXmlWebApplicationContext extends AbstractRefreshableWebApplicationContext { + private Resource inMemoryXml; public InMemoryXmlWebApplicationContext(String xml) { - this(xml, SPRING_SECURITY_VERSION, null); + this(xml, InMemoryXmlApplicationContext.SPRING_SECURITY_VERSION, null); } public InMemoryXmlWebApplicationContext(String xml, ApplicationContext parent) { - this(xml, SPRING_SECURITY_VERSION, parent); + this(xml, InMemoryXmlApplicationContext.SPRING_SECURITY_VERSION, parent); } - public InMemoryXmlWebApplicationContext(String xml, String secVersion, - ApplicationContext parent) { - String fullXml = BEANS_OPENING + secVersion + ".xsd'>\n" + xml + BEANS_CLOSE; - inMemoryXml = new InMemoryResource(fullXml); + public InMemoryXmlWebApplicationContext(String xml, String secVersion, ApplicationContext parent) { + String fullXml = InMemoryXmlApplicationContext.BEANS_OPENING + secVersion + ".xsd'>\n" + xml + + InMemoryXmlApplicationContext.BEANS_CLOSE; + this.inMemoryXml = new InMemoryResource(fullXml); setAllowBeanDefinitionOverriding(true); setParent(parent); } @@ -52,7 +50,7 @@ public class InMemoryXmlWebApplicationContext extends AbstractRefreshableWebAppl @Override protected void loadBeanDefinitions(DefaultListableBeanFactory beanFactory) throws BeansException { XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(beanFactory); - reader.loadBeanDefinitions(new Resource[] { inMemoryXml }); + reader.loadBeanDefinitions(new Resource[] { this.inMemoryXml }); } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeSpecTests.java index d19e5e953f..99fe4a2c1a 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeSpecTests.java @@ -17,6 +17,7 @@ package org.springframework.security.config.web.server; import org.junit.Test; + import org.springframework.http.HttpMethod; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; @@ -27,140 +28,112 @@ import org.springframework.test.web.reactive.server.WebTestClient; * @since 5.0 */ public class AuthorizeExchangeSpecTests { + ServerHttpSecurity http = ServerHttpSecurityConfigurationBuilder.httpWithDefaultAuthentication(); @Test public void antMatchersWhenMethodAndPatternsThenDiscriminatesByMethod() { - this.http - .csrf().disable() - .authorizeExchange() - .pathMatchers(HttpMethod.POST, "/a", "/b").denyAll() - .anyExchange().permitAll(); - + this.http.csrf().disable().authorizeExchange().pathMatchers(HttpMethod.POST, "/a", "/b").denyAll().anyExchange() + .permitAll(); WebTestClient client = buildClient(); - + // @formatter:off client.get() - .uri("/a") - .exchange() - .expectStatus().isOk(); - + .uri("/a") + .exchange() + .expectStatus().isOk(); client.get() - .uri("/b") - .exchange() - .expectStatus().isOk(); - + .uri("/b") + .exchange() + .expectStatus().isOk(); client.post() - .uri("/a") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/a") + .exchange() + .expectStatus().isUnauthorized(); client.post() - .uri("/b") - .exchange() - .expectStatus().isUnauthorized(); + .uri("/b") + .exchange() + .expectStatus().isUnauthorized(); + // @formatter:on } - @Test public void antMatchersWhenPatternsThenAnyMethod() { - this.http - .csrf().disable() - .authorizeExchange() - .pathMatchers("/a", "/b").denyAll() - .anyExchange().permitAll(); - + this.http.csrf().disable().authorizeExchange().pathMatchers("/a", "/b").denyAll().anyExchange().permitAll(); WebTestClient client = buildClient(); - + // @formatter:off client.get() - .uri("/a") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/a") + .exchange() + .expectStatus().isUnauthorized(); client.get() - .uri("/b") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/b") + .exchange() + .expectStatus().isUnauthorized(); client.post() - .uri("/a") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/a") + .exchange() + .expectStatus().isUnauthorized(); client.post() - .uri("/b") - .exchange() - .expectStatus().isUnauthorized(); + .uri("/b") + .exchange() + .expectStatus().isUnauthorized(); + // @formatter:on } @Test public void antMatchersWhenPatternsInLambdaThenAnyMethod() { - this.http - .csrf(ServerHttpSecurity.CsrfSpec::disable) - .authorizeExchange(exchanges -> - exchanges - .pathMatchers("/a", "/b").denyAll() - .anyExchange().permitAll() - ); - + this.http.csrf(ServerHttpSecurity.CsrfSpec::disable).authorizeExchange( + (exchanges) -> exchanges.pathMatchers("/a", "/b").denyAll().anyExchange().permitAll()); WebTestClient client = buildClient(); - + // @formatter:off client.get() - .uri("/a") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/a") + .exchange() + .expectStatus().isUnauthorized(); client.get() - .uri("/b") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/b") + .exchange() + .expectStatus().isUnauthorized(); client.post() - .uri("/a") - .exchange() - .expectStatus().isUnauthorized(); - + .uri("/a") + .exchange() + .expectStatus().isUnauthorized(); client.post() - .uri("/b") - .exchange() - .expectStatus().isUnauthorized(); + .uri("/b") + .exchange() + .expectStatus().isUnauthorized(); + // @formatter:on } @Test(expected = IllegalStateException.class) public void antMatchersWhenNoAccessAndAnotherMatcherThenThrowsException() { - this.http - .authorizeExchange() - .pathMatchers("/incomplete"); - this.http - .authorizeExchange() - .pathMatchers("/throws-exception"); + this.http.authorizeExchange().pathMatchers("/incomplete"); + this.http.authorizeExchange().pathMatchers("/throws-exception"); } @Test(expected = IllegalStateException.class) public void anyExchangeWhenFollowedByMatcherThenThrowsException() { - this.http - .authorizeExchange().anyExchange().denyAll() - .pathMatchers("/never-reached"); + // @formatter:off + this.http.authorizeExchange() + .anyExchange().denyAll() + .pathMatchers("/never-reached"); + // @formatter:on } @Test(expected = IllegalStateException.class) public void buildWhenMatcherDefinedWithNoAccessThenThrowsException() { - this.http - .authorizeExchange() - .pathMatchers("/incomplete"); + this.http.authorizeExchange().pathMatchers("/incomplete"); this.http.build(); } @Test(expected = IllegalStateException.class) public void buildWhenMatcherDefinedWithNoAccessInLambdaThenThrowsException() { - this.http - .authorizeExchange(exchanges -> - exchanges - .pathMatchers("/incomplete") - ); + this.http.authorizeExchange((exchanges) -> exchanges.pathMatchers("/incomplete")); this.http.build(); } private WebTestClient buildClient() { return WebTestClientBuilder.bindToWebFilters(this.http.build()).build(); } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java index 9cac1da459..71ee7b605b 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java @@ -16,11 +16,18 @@ package org.springframework.security.config.web.server; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.context.ApplicationContext; import org.springframework.core.ResolvableType; import org.springframework.http.HttpHeaders; @@ -30,15 +37,9 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.reactive.CorsConfigurationSource; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -46,8 +47,10 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class CorsSpecTests { + @Mock private CorsConfigurationSource source; + @Mock private ApplicationContext context; @@ -59,11 +62,10 @@ public class CorsSpecTests { @Before public void setup() { - this.http = new TestingServerHttpSecurity() - .applicationContext(this.context); + this.http = new TestingServerHttpSecurity().applicationContext(this.context); CorsConfiguration value = new CorsConfiguration(); value.setAllowedOrigins(Arrays.asList("*")); - when(this.source.getCorsConfiguration(any())).thenReturn(value); + given(this.source.getCorsConfiguration(any())).willReturn(value); } @Test @@ -76,7 +78,7 @@ public class CorsSpecTests { @Test public void corsWhenEnabledInLambdaThenAccessControlAllowOriginAndSecurityHeaders() { - this.http.cors(cors -> cors.configurationSource(this.source)); + this.http.cors((cors) -> cors.configurationSource(this.source)); this.expectedHeaders.set("Access-Control-Allow-Origin", "*"); this.expectedHeaders.set("X-Frame-Options", "DENY"); assertHeaders(); @@ -84,8 +86,9 @@ public class CorsSpecTests { @Test public void corsWhenCorsConfigurationSourceBeanThenAccessControlAllowOriginAndSecurityHeaders() { - when(this.context.getBeanNamesForType(any(ResolvableType.class))).thenReturn(new String[] {"source"}, new String[0]); - when(this.context.getBean("source")).thenReturn(this.source); + given(this.context.getBeanNamesForType(any(ResolvableType.class))).willReturn(new String[] { "source" }, + new String[0]); + given(this.context.getBean("source")).willReturn(this.source); this.expectedHeaders.set("Access-Control-Allow-Origin", "*"); this.expectedHeaders.set("X-Frame-Options", "DENY"); assertHeaders(); @@ -93,24 +96,23 @@ public class CorsSpecTests { @Test public void corsWhenNoConfigurationSourceThenNoCorsHeaders() { - when(this.context.getBeanNamesForType(any(ResolvableType.class))).thenReturn(new String[0]); + given(this.context.getBeanNamesForType(any(ResolvableType.class))).willReturn(new String[0]); this.headerNamesNotPresent.add("Access-Control-Allow-Origin"); assertHeaders(); } private void assertHeaders() { WebTestClient client = buildClient(); + // @formatter:off FluxExchangeResult response = client.get() - .uri("https://example.com/") - .headers(h -> h.setOrigin("https://origin.example.com")) - .exchange() - .returnResult(String.class); - + .uri("https://example.com/") + .headers((h) -> h.setOrigin("https://origin.example.com")) + .exchange() + .returnResult(String.class); + // @formatter:on Map> responseHeaders = response.getResponseHeaders(); - if (!this.expectedHeaders.isEmpty()) { - assertThat(responseHeaders).describedAs(response.toString()) - .containsAllEntriesOf(this.expectedHeaders); + assertThat(responseHeaders).describedAs(response.toString()).containsAllEntriesOf(this.expectedHeaders); } if (!this.headerNamesNotPresent.isEmpty()) { assertThat(responseHeaders.keySet()).doesNotContainAnyElementsOf(this.headerNamesNotPresent); @@ -118,8 +120,10 @@ public class CorsSpecTests { } private WebTestClient buildClient() { - return WebTestClientBuilder - .bindToWebFilters(this.http.build()) + // @formatter:off + return WebTestClientBuilder.bindToWebFilters(this.http.build()) .build(); + // @formatter:on } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/ExceptionHandlingSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/ExceptionHandlingSpecTests.java index beae57c576..c278be38f4 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ExceptionHandlingSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ExceptionHandlingSpecTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.web.server; import org.junit.Test; + import org.springframework.http.HttpStatus; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; @@ -33,198 +35,183 @@ import static org.springframework.security.config.Customizer.withDefaults; * @since 5.0.5 */ public class ExceptionHandlingSpecTests { + private ServerHttpSecurity http = ServerHttpSecurityConfigurationBuilder.httpWithDefaultAuthentication(); @Test public void defaultAuthenticationEntryPoint() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .csrf().disable() - .authorizeExchange() - .anyExchange().authenticated() - .and() - .exceptionHandling() - .and() - .build(); - + .csrf().disable() + .authorizeExchange() + .anyExchange().authenticated() + .and() + .exceptionHandling().and() + .build(); WebTestClient client = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - client - .get() - .uri("/test") - .exchange() - .expectStatus().isUnauthorized() - .expectHeader().valueMatches("WWW-Authenticate", "Basic.*"); + .bindToWebFilters(securityWebFilter) + .build(); + client.get() + .uri("/test") + .exchange() + .expectStatus().isUnauthorized() + .expectHeader().valueMatches("WWW-Authenticate", "Basic.*"); + // @formatter:on } @Test public void requestWhenExceptionHandlingWithDefaultsInLambdaThenDefaultAuthenticationEntryPointUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange(exchanges -> - exchanges - .anyExchange().authenticated() - ) - .exceptionHandling(withDefaults()) - .build(); - + .authorizeExchange((exchanges) -> exchanges + .anyExchange().authenticated() + ) + .exceptionHandling(withDefaults()) + .build(); WebTestClient client = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - client - .get() - .uri("/test") - .exchange() - .expectStatus().isUnauthorized() - .expectHeader().valueMatches("WWW-Authenticate", "Basic.*"); + .bindToWebFilters(securityWebFilter) + .build(); + client.get() + .uri("/test") + .exchange() + .expectStatus().isUnauthorized() + .expectHeader().valueMatches("WWW-Authenticate", "Basic.*"); + // @formatter:on } @Test public void customAuthenticationEntryPoint() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .csrf().disable() - .authorizeExchange() - .anyExchange().authenticated() - .and() - .exceptionHandling() - .authenticationEntryPoint(redirectServerAuthenticationEntryPoint("/auth")) - .and() - .build(); - + .csrf().disable() + .authorizeExchange() + .anyExchange().authenticated() + .and() + .exceptionHandling() + .authenticationEntryPoint(redirectServerAuthenticationEntryPoint("/auth")) + .and() + .build(); WebTestClient client = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - client - .get() - .uri("/test") - .exchange() - .expectStatus().isFound() - .expectHeader().valueMatches("Location", ".*"); + .bindToWebFilters(securityWebFilter) + .build(); + client.get() + .uri("/test") + .exchange() + .expectStatus().isFound() + .expectHeader().valueMatches("Location", ".*"); + // @formatter:on } @Test public void requestWhenCustomAuthenticationEntryPointInLambdaThenCustomAuthenticationEntryPointUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange(exchanges -> - exchanges + .authorizeExchange((exchanges) -> exchanges .anyExchange().authenticated() ) - .exceptionHandling(exceptionHandling -> - exceptionHandling + .exceptionHandling((exceptionHandling) -> exceptionHandling .authenticationEntryPoint(redirectServerAuthenticationEntryPoint("/auth")) ) .build(); - WebTestClient client = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - client - .get() - .uri("/test") - .exchange() - .expectStatus().isFound() - .expectHeader().valueMatches("Location", ".*"); + .bindToWebFilters(securityWebFilter) + .build(); + client.get() + .uri("/test") + .exchange() + .expectStatus().isFound() + .expectHeader().valueMatches("Location", ".*"); + // @formatter:on } @Test public void defaultAccessDeniedHandler() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .csrf().disable() - .httpBasic().and() - .authorizeExchange() - .anyExchange().hasRole("ADMIN") - .and() - .exceptionHandling() - .and() - .build(); - + .csrf().disable() + .httpBasic().and() + .authorizeExchange() + .anyExchange().hasRole("ADMIN") + .and() + .exceptionHandling().and() + .build(); WebTestClient client = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - client - .get() - .uri("/admin") - .headers(headers -> headers.setBasicAuth("user", "password")) - .exchange() - .expectStatus().isForbidden(); + .bindToWebFilters(securityWebFilter) + .build(); + client.get() + .uri("/admin") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus().isForbidden(); + // @formatter:on } @Test public void requestWhenExceptionHandlingWithDefaultsInLambdaThenDefaultAccessDeniedHandlerUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http .httpBasic(withDefaults()) - .authorizeExchange(exchanges -> - exchanges + .authorizeExchange((exchanges) -> exchanges .anyExchange().hasRole("ADMIN") ) .exceptionHandling(withDefaults()) .build(); - WebTestClient client = WebTestClientBuilder .bindToWebFilters(securityWebFilter) .build(); - - client - .get() + client.get() .uri("/admin") - .headers(headers -> headers.setBasicAuth("user", "password")) + .headers((headers) -> headers.setBasicAuth("user", "password")) .exchange() .expectStatus().isForbidden(); + // @formatter:on } @Test public void customAccessDeniedHandler() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .csrf().disable() - .httpBasic().and() - .authorizeExchange() - .anyExchange().hasRole("ADMIN") - .and() - .exceptionHandling() - .accessDeniedHandler(httpStatusServerAccessDeniedHandler(HttpStatus.BAD_REQUEST)) - .and() - .build(); - + .csrf().disable() + .httpBasic().and() + .authorizeExchange() + .anyExchange().hasRole("ADMIN") + .and() + .exceptionHandling() + .accessDeniedHandler(httpStatusServerAccessDeniedHandler(HttpStatus.BAD_REQUEST)) + .and() + .build(); WebTestClient client = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - client - .get() - .uri("/admin") - .headers(headers -> headers.setBasicAuth("user", "password")) - .exchange() - .expectStatus().isBadRequest(); + .bindToWebFilters(securityWebFilter) + .build(); + client.get() + .uri("/admin") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus().isBadRequest(); + // @formatter:on } @Test public void requestWhenCustomAccessDeniedHandlerInLambdaThenCustomAccessDeniedHandlerUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http .httpBasic(withDefaults()) - .authorizeExchange(exchanges -> - exchanges - .anyExchange().hasRole("ADMIN") + .authorizeExchange((exchanges) -> exchanges + .anyExchange().hasRole("ADMIN") ) - .exceptionHandling(exceptionHandling -> - exceptionHandling + .exceptionHandling((exceptionHandling) -> exceptionHandling .accessDeniedHandler(httpStatusServerAccessDeniedHandler(HttpStatus.BAD_REQUEST)) ) .build(); - WebTestClient client = WebTestClientBuilder .bindToWebFilters(securityWebFilter) .build(); - - client - .get() + client.get() .uri("/admin") - .headers(headers -> headers.setBasicAuth("user", "password")) + .headers((headers) -> headers.setBasicAuth("user", "password")) .exchange() .expectStatus().isBadRequest(); + // @formatter:on } private ServerAuthenticationEntryPoint redirectServerAuthenticationEntryPoint(String location) { @@ -234,4 +221,5 @@ public class ExceptionHandlingSpecTests { private ServerAccessDeniedHandler httpStatusServerAccessDeniedHandler(HttpStatus httpStatus) { return new HttpStatusServerAccessDeniedHandler(httpStatus); } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java index c0f1ff8938..ad09dbf50d 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java @@ -23,6 +23,8 @@ import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebElement; import org.openqa.selenium.support.FindBy; import org.openqa.selenium.support.PageFactory; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; @@ -43,11 +45,9 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; @@ -62,242 +62,205 @@ import static org.springframework.security.config.Customizer.withDefaults; * @since 5.0 */ public class FormLoginTests { + private ServerHttpSecurity http = ServerHttpSecurityConfigurationBuilder.httpWithDefaultAuthentication(); @Test public void defaultLoginPage() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin().and() - .build(); - + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin() + .and() + .build(); WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - + .bindToWebFilters(securityWebFilter) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class).assertAt(); + // @formatter:off loginPage = loginPage.loginForm() - .username("user") - .password("invalid") - .submit(DefaultLoginPage.class) - .assertError(); - + .username("user") + .password("invalid") + .submit(DefaultLoginPage.class) + .assertError(); HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on homePage.assertAt(); - - loginPage = DefaultLogoutPage.to(driver) - .assertAt() - .logout(); - - loginPage - .assertAt() - .assertLogout(); + loginPage = DefaultLogoutPage.to(driver).assertAt().logout(); + loginPage.assertAt().assertLogout(); } @Test public void formLoginWhenDefaultsInLambdaThenCreatesDefaultLoginPage() { SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange(exchanges -> - exchanges - .anyExchange().authenticated() - ) - .formLogin(withDefaults()) - .build(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class) - .assertAt(); - - loginPage = loginPage.loginForm() - .username("user") - .password("invalid") - .submit(DefaultLoginPage.class) - .assertError(); - + .authorizeExchange((exchanges) -> exchanges.anyExchange().authenticated()).formLogin(withDefaults()) + .build(); + WebTestClient webTestClient = WebTestClientBuilder.bindToWebFilters(securityWebFilter).build(); + WebDriver driver = WebTestClientHtmlUnitDriverBuilder.webTestClientSetup(webTestClient).build(); + DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class).assertAt(); + // @formatter:off + loginPage = loginPage + .loginForm() + .username("user") + .password("invalid") + .submit(DefaultLoginPage.class) + .assertError(); HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on homePage.assertAt(); - - loginPage = DefaultLogoutPage.to(driver) - .assertAt() - .logout(); - - loginPage - .assertAt() - .assertLogout(); + loginPage = DefaultLogoutPage.to(driver).assertAt().logout(); + loginPage.assertAt().assertLogout(); } @Test public void customLoginPage() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .pathMatchers("/login").permitAll() - .anyExchange().authenticated() - .and() - .formLogin() - .loginPage("/login") - .and() - .build(); - + .authorizeExchange() + .pathMatchers("/login").permitAll() + .anyExchange().authenticated() + .and() + .formLogin() + .loginPage("/login") + .and() + .build(); WebTestClient webTestClient = WebTestClient - .bindToController(new CustomLoginPageController(), new WebTestClientBuilder.Http200RestController()) - .webFilter(new WebFilterChainProxy(securityWebFilter)) - .build(); - + .bindToController(new CustomLoginPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - CustomLoginPage loginPage = HomePage.to(driver, CustomLoginPage.class) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + CustomLoginPage loginPage = HomePage.to(driver, CustomLoginPage.class).assertAt(); + // @formatter:off HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on homePage.assertAt(); } @Test public void formLoginWhenCustomLoginPageInLambdaThenUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange(exchanges -> - exchanges - .pathMatchers("/login").permitAll() - .anyExchange().authenticated() - ) - .formLogin(formLogin -> - formLogin - .loginPage("/login") - ) - .build(); - + .authorizeExchange((exchanges) -> exchanges + .pathMatchers("/login").permitAll() + .anyExchange().authenticated() + ) + .formLogin((formLogin) -> formLogin + .loginPage("/login") + ) + .build(); WebTestClient webTestClient = WebTestClient - .bindToController(new CustomLoginPageController(), new WebTestClientBuilder.Http200RestController()) - .webFilter(new WebFilterChainProxy(securityWebFilter)) - .build(); - + .bindToController(new CustomLoginPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - CustomLoginPage loginPage = HomePage.to(driver, CustomLoginPage.class) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + CustomLoginPage loginPage = HomePage.to(driver, CustomLoginPage.class).assertAt(); + // @formatter:off HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on homePage.assertAt(); } @Test public void formLoginWhenCustomAuthenticationFailureHandlerThenUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .pathMatchers("/login", "/failure").permitAll() - .anyExchange().authenticated() - .and() - .formLogin() - .authenticationFailureHandler(new RedirectServerAuthenticationFailureHandler("/failure")) - .and() - .build(); - + .authorizeExchange() + .pathMatchers("/login", "/failure").permitAll() + .anyExchange().authenticated() + .and() + .formLogin() + .authenticationFailureHandler(new RedirectServerAuthenticationFailureHandler("/failure")) + .and() + .build(); WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(securityWebFilter) .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder .webTestClientSetup(webTestClient) .build(); - - DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class) - .assertAt(); - + // @formatter:on + DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class).assertAt(); + // @formatter:off loginPage.loginForm() .username("invalid") .password("invalid") .submit(HomePage.class); - + // @formatter:on assertThat(driver.getCurrentUrl()).endsWith("/failure"); } @Test public void formLoginWhenCustomRequiresAuthenticationMatcherThenUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .pathMatchers("/login", "/sign-in").permitAll() - .anyExchange().authenticated() - .and() - .formLogin() - .requiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/sign-in")) - .and() - .build(); - + .authorizeExchange() + .pathMatchers("/login", "/sign-in").permitAll() + .anyExchange().authenticated() + .and() + .formLogin() + .requiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/sign-in")) + .and() + .build(); WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(securityWebFilter) .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder .webTestClientSetup(webTestClient) .build(); - + // @formatter:on driver.get("http://localhost/sign-in"); - assertThat(driver.getCurrentUrl()).endsWith("/login?error"); } @Test public void authenticationSuccess() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin() - .authenticationSuccessHandler(new RedirectServerAuthenticationSuccessHandler("/custom")) - .and() - .build(); - + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin() + .authenticationSuccessHandler(new RedirectServerAuthenticationSuccessHandler("/custom")) + .and() + .build(); WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - + .bindToWebFilters(securityWebFilter) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = DefaultLoginPage.to(driver) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + DefaultLoginPage loginPage = DefaultLoginPage.to(driver).assertAt(); + // @formatter:off HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on assertThat(driver.getCurrentUrl()).endsWith("/custom"); } @@ -305,35 +268,32 @@ public class FormLoginTests { public void customAuthenticationManager() { ReactiveAuthenticationManager defaultAuthenticationManager = mock(ReactiveAuthenticationManager.class); ReactiveAuthenticationManager customAuthenticationManager = mock(ReactiveAuthenticationManager.class); - - given(defaultAuthenticationManager.authenticate(any())).willThrow(new RuntimeException("should not interact with default auth manager")); - given(customAuthenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("user", "password", "ROLE_USER", "ROLE_ADMIN"))); - + given(defaultAuthenticationManager.authenticate(any())) + .willThrow(new RuntimeException("should not interact with default auth manager")); + given(customAuthenticationManager.authenticate(any())) + .willReturn(Mono.just(new TestingAuthenticationToken("user", "password", "ROLE_USER", "ROLE_ADMIN"))); + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authenticationManager(defaultAuthenticationManager) - .formLogin() - .authenticationManager(customAuthenticationManager) - .and() - .build(); - + .authenticationManager(defaultAuthenticationManager) + .formLogin() + .authenticationManager(customAuthenticationManager) + .and() + .build(); WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - + .bindToWebFilters(securityWebFilter) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = DefaultLoginPage.to(driver) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + DefaultLoginPage loginPage = DefaultLoginPage.to(driver).assertAt(); + // @formatter:off HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on homePage.assertAt(); - verifyZeroInteractions(defaultAuthenticationManager); } @@ -341,14 +301,12 @@ public class FormLoginTests { public void formLoginSecurityContextRepository() { ServerSecurityContextRepository defaultSecContextRepository = mock(ServerSecurityContextRepository.class); ServerSecurityContextRepository formLoginSecContextRepository = mock(ServerSecurityContextRepository.class); - TestingAuthenticationToken token = new TestingAuthenticationToken("rob", "rob", "ROLE_USER"); - given(defaultSecContextRepository.save(any(), any())).willReturn(Mono.empty()); given(defaultSecContextRepository.load(any())).willReturn(authentication(token)); given(formLoginSecContextRepository.save(any(), any())).willReturn(Mono.empty()); given(formLoginSecContextRepository.load(any())).willReturn(authentication(token)); - + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http .authorizeExchange() .anyExchange().authenticated() @@ -358,29 +316,31 @@ public class FormLoginTests { .securityContextRepository(formLoginSecContextRepository) .and() .build(); - WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(securityWebFilter) .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder .webTestClientSetup(webTestClient) .build(); - - DefaultLoginPage loginPage = DefaultLoginPage.to(driver) - .assertAt(); - + // @formatter:on + DefaultLoginPage loginPage = DefaultLoginPage.to(driver).assertAt(); + // @formatter:off HomePage homePage = loginPage.loginForm() .username("user") .password("password") .submit(HomePage.class); - + // @formatter:on homePage.assertAt(); - verify(defaultSecContextRepository, atLeastOnce()).load(any()); verify(formLoginSecContextRepository).save(any(), any()); } + Mono authentication(Authentication authentication) { + SecurityContext context = new SecurityContextImpl(); + context.setAuthentication(authentication); + return Mono.just(context); + } + public static class CustomLoginPage { private WebDriver driver; @@ -402,9 +362,13 @@ public class FormLoginTests { } public static class LoginForm { + private WebDriver driver; + private WebElement username; + private WebElement password; + @FindBy(css = "button[type=submit]") private WebElement submit; @@ -426,12 +390,15 @@ public class FormLoginTests { this.submit.click(); return PageFactory.initElements(this.driver, page); } + } + } public static class DefaultLoginPage { private WebDriver driver; + @FindBy(css = "div[role=alert]") private WebElement alert; @@ -463,8 +430,7 @@ public class FormLoginTests { } public DefaultLoginPage assertLoginFormNotPresent() { - assertThatThrownBy(() -> loginForm().username("")) - .isInstanceOf(NoSuchElementException.class); + assertThatExceptionOfType(NoSuchElementException.class).isThrownBy(() -> loginForm().username("")); return this; } @@ -485,9 +451,13 @@ public class FormLoginTests { } public static class LoginForm { + private WebDriver driver; + private WebElement username; + private WebElement password; + @FindBy(css = "button[type=submit]") private WebElement submit; @@ -509,28 +479,32 @@ public class FormLoginTests { this.submit.click(); return PageFactory.initElements(this.driver, page); } + } public class OAuth2Login { + public WebElement findClientRegistrationByName(String clientName) { return DefaultLoginPage.this.driver.findElement(By.linkText(clientName)); } public OAuth2Login assertClientRegistrationByName(String clientName) { - assertThatCode(() -> findClientRegistrationByName(clientName)) - .doesNotThrowAnyException(); + findClientRegistrationByName(clientName); return this; } public DefaultLoginPage and() { return DefaultLoginPage.this; } + } + } public static class DefaultLogoutPage { private WebDriver driver; + @FindBy(css = "button[type=submit]") private WebElement submit; @@ -554,7 +528,9 @@ public class FormLoginTests { } } + public static class HomePage { + private WebDriver driver; @FindBy(tagName = "body") @@ -572,48 +548,47 @@ public class FormLoginTests { driver.get("http://localhost/"); return PageFactory.initElements(driver, page); } + } @Controller public static class CustomLoginPageController { + @ResponseBody @GetMapping("/login") public Mono login(ServerWebExchange exchange) { Mono token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty()); - return token.map(t -> - "\n" - + "\n" - + " \n" - + " \n" - + " \n" - + " \n" - + " \n" - + " Custom Log In Page\n" - + " \n" - + " \n" - + "
    \n" - + "
    \n" - + "

    Please sign in

    \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + "

    \n" - + " \n" - + " \n" - + "

    \n" - + " \n" - + " \n" - + "
    \n" - + "
    \n" - + " \n" - + ""); + // @formatter:off + return token.map((t) -> "\n" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Custom Log In Page\n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + "

    Please sign in

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + "

    \n" + + " \n" + + " \n" + + "

    \n" + + " \n" + + " \n" + + "
    \n" + + "
    \n" + + " \n" + + ""); + // @formatter:on } + } - Mono authentication(Authentication authentication) { - SecurityContext context = new SecurityContextImpl(); - context.setAuthentication(authentication); - return Mono.just(context); - } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/HeaderSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/HeaderSpecTests.java index ea16551349..6e51e7c6ec 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/HeaderSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/HeaderSpecTests.java @@ -39,7 +39,7 @@ import org.springframework.security.web.server.header.XXssProtectionServerHttpHe import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.config.Customizer.withDefaults; /** @@ -51,8 +51,10 @@ import static org.springframework.security.config.Customizer.withDefaults; * @since 5.0 */ public class HeaderSpecTests { - private final static String CUSTOM_HEADER = "CUSTOM-HEADER"; - private final static String CUSTOM_VALUE = "CUSTOM-VALUE"; + + private static final String CUSTOM_HEADER = "CUSTOM-HEADER"; + + private static final String CUSTOM_VALUE = "CUSTOM-VALUE"; private ServerHttpSecurity http = ServerHttpSecurity.http(); @@ -62,54 +64,45 @@ public class HeaderSpecTests { @Before public void setup() { - this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, "max-age=31536000 ; includeSubDomains"); + this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, + "max-age=31536000 ; includeSubDomains"); this.expectedHeaders.add(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate"); this.expectedHeaders.add(HttpHeaders.PRAGMA, "no-cache"); this.expectedHeaders.add(HttpHeaders.EXPIRES, "0"); - this.expectedHeaders - .add(ContentTypeOptionsServerHttpHeadersWriter.X_CONTENT_OPTIONS, "nosniff"); + this.expectedHeaders.add(ContentTypeOptionsServerHttpHeadersWriter.X_CONTENT_OPTIONS, "nosniff"); this.expectedHeaders.add(XFrameOptionsServerHttpHeadersWriter.X_FRAME_OPTIONS, "DENY"); - this.expectedHeaders - .add(XXssProtectionServerHttpHeadersWriter.X_XSS_PROTECTION, "1 ; mode=block"); + this.expectedHeaders.add(XXssProtectionServerHttpHeadersWriter.X_XSS_PROTECTION, "1 ; mode=block"); } @Test public void headersWhenDisableThenNoSecurityHeaders() { new HashSet<>(this.expectedHeaders.keySet()).forEach(this::expectHeaderNamesNotPresent); - this.http.headers().disable(); - assertHeaders(); } @Test public void headersWhenDisableInLambdaThenNoSecurityHeaders() { new HashSet<>(this.expectedHeaders.keySet()).forEach(this::expectHeaderNamesNotPresent); - - this.http.headers(headers -> headers.disable()); - + this.http.headers((headers) -> headers.disable()); assertHeaders(); } @Test public void headersWhenDisableAndInvokedExplicitlyThenDefautsUsed() { - this.http.headers().disable() - .headers(); - + this.http.headers().disable().headers(); assertHeaders(); } @Test public void headersWhenDefaultsThenAllDefaultsWritten() { this.http.headers(); - assertHeaders(); } @Test public void headersWhenDefaultsInLambdaThenAllDefaultsWritten() { this.http.headers(withDefaults()); - assertHeaders(); } @@ -117,19 +110,17 @@ public class HeaderSpecTests { public void headersWhenCacheDisableThenCacheNotWritten() { expectHeaderNamesNotPresent(HttpHeaders.CACHE_CONTROL, HttpHeaders.PRAGMA, HttpHeaders.EXPIRES); this.http.headers().cache().disable(); - assertHeaders(); } - @Test public void headersWhenCacheDisableInLambdaThenCacheNotWritten() { expectHeaderNamesNotPresent(HttpHeaders.CACHE_CONTROL, HttpHeaders.PRAGMA, HttpHeaders.EXPIRES); - this.http - .headers(headers -> - headers.cache(cache -> cache.disable()) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .cache((cache) -> cache.disable()) + ); + // @formatter:on assertHeaders(); } @@ -137,18 +128,18 @@ public class HeaderSpecTests { public void headersWhenContentOptionsDisableThenContentTypeOptionsNotWritten() { expectHeaderNamesNotPresent(ContentTypeOptionsServerHttpHeadersWriter.X_CONTENT_OPTIONS); this.http.headers().contentTypeOptions().disable(); - assertHeaders(); } @Test public void headersWhenContentOptionsDisableInLambdaThenContentTypeOptionsNotWritten() { expectHeaderNamesNotPresent(ContentTypeOptionsServerHttpHeadersWriter.X_CONTENT_OPTIONS); + // @formatter:off this.http - .headers(headers -> - headers.contentTypeOptions(contentTypeOptions -> contentTypeOptions.disable()) - ); - + .headers((headers) -> headers + .contentTypeOptions((contentTypeOptions) -> contentTypeOptions.disable() + )); + // @formatter:on assertHeaders(); } @@ -156,148 +147,159 @@ public class HeaderSpecTests { public void headersWhenHstsDisableThenHstsNotWritten() { expectHeaderNamesNotPresent(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY); this.http.headers().hsts().disable(); - assertHeaders(); } @Test public void headersWhenHstsDisableInLambdaThenHstsNotWritten() { expectHeaderNamesNotPresent(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY); - this.http - .headers(headers -> - headers.hsts(hsts -> hsts.disable()) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .hsts((hsts) -> hsts.disable()) + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenHstsCustomThenCustomHstsWritten() { this.expectedHeaders.remove(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY); - this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, "max-age=60"); - this.http.headers().hsts() + this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, + "max-age=60"); + // @formatter:off + this.http.headers() + .hsts() .maxAge(Duration.ofSeconds(60)) .includeSubdomains(false); - + // @formatter:on assertHeaders(); } @Test public void headersWhenHstsCustomInLambdaThenCustomHstsWritten() { this.expectedHeaders.remove(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY); - this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, "max-age=60"); - this.http - .headers(headers -> - headers - .hsts(hsts -> - hsts + this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, + "max-age=60"); + // @formatter:off + this.http.headers( + (headers) -> headers + .hsts((hsts) -> hsts .maxAge(Duration.ofSeconds(60)) .includeSubdomains(false) ) - ); - + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenHstsCustomWithPreloadThenCustomHstsWritten() { this.expectedHeaders.remove(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY); - this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, "max-age=60 ; includeSubDomains ; preload"); - this.http.headers().hsts() - .maxAge(Duration.ofSeconds(60)) - .preload(true); - + this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, + "max-age=60 ; includeSubDomains ; preload"); + // @formatter:off + this.http.headers() + .hsts() + .maxAge(Duration.ofSeconds(60)) + .preload(true); + // @formatter:on assertHeaders(); } @Test public void headersWhenHstsCustomWithPreloadInLambdaThenCustomHstsWritten() { this.expectedHeaders.remove(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY); - this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, "max-age=60 ; includeSubDomains ; preload"); - this.http - .headers(headers -> - headers - .hsts(hsts -> - hsts - .maxAge(Duration.ofSeconds(60)) - .preload(true) - ) - ); - + this.expectedHeaders.add(StrictTransportSecurityServerHttpHeadersWriter.STRICT_TRANSPORT_SECURITY, + "max-age=60 ; includeSubDomains ; preload"); + // @formatter:off + this.http.headers((headers) -> headers + .hsts((hsts) -> hsts + .maxAge(Duration.ofSeconds(60)) + .preload(true) + ) + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenFrameOptionsDisableThenFrameOptionsNotWritten() { expectHeaderNamesNotPresent(XFrameOptionsServerHttpHeadersWriter.X_FRAME_OPTIONS); - this.http.headers().frameOptions().disable(); - + // @formatter:off + this.http.headers() + .frameOptions().disable(); + // @formatter:on assertHeaders(); } @Test public void headersWhenFrameOptionsDisableInLambdaThenFrameOptionsNotWritten() { expectHeaderNamesNotPresent(XFrameOptionsServerHttpHeadersWriter.X_FRAME_OPTIONS); - this.http - .headers(headers -> - headers.frameOptions(frameOptions -> frameOptions.disable()) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .frameOptions((frameOptions) -> frameOptions + .disable() + ) + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenFrameOptionsModeThenFrameOptionsCustomMode() { this.expectedHeaders.set(XFrameOptionsServerHttpHeadersWriter.X_FRAME_OPTIONS, "SAMEORIGIN"); + // @formatter:off this.http.headers() .frameOptions() .mode(XFrameOptionsServerHttpHeadersWriter.Mode.SAMEORIGIN); - + // @formatter:on assertHeaders(); } @Test public void headersWhenFrameOptionsModeInLambdaThenFrameOptionsCustomMode() { this.expectedHeaders.set(XFrameOptionsServerHttpHeadersWriter.X_FRAME_OPTIONS, "SAMEORIGIN"); - this.http - .headers(headers -> - headers - .frameOptions(frameOptions -> - frameOptions - .mode(XFrameOptionsServerHttpHeadersWriter.Mode.SAMEORIGIN) - ) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .frameOptions((frameOptions) -> frameOptions + .mode(XFrameOptionsServerHttpHeadersWriter.Mode.SAMEORIGIN) + ) + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenXssProtectionDisableThenXssProtectionNotWritten() { expectHeaderNamesNotPresent("X-Xss-Protection"); - this.http.headers().xssProtection().disable(); - + // @formatter:off + this.http.headers() + .xssProtection().disable(); + // @formatter:on assertHeaders(); } @Test public void headersWhenXssProtectionDisableInLambdaThenXssProtectionNotWritten() { expectHeaderNamesNotPresent("X-Xss-Protection"); - this.http - .headers(headers -> - headers.xssProtection(xssProtection -> xssProtection.disable()) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .xssProtection((xssProtection) -> xssProtection + .disable() + ) + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenFeaturePolicyEnabledThenFeaturePolicyWritten() { String policyDirectives = "Feature-Policy"; - this.expectedHeaders.add(FeaturePolicyServerHttpHeadersWriter.FEATURE_POLICY, - policyDirectives); - - this.http.headers().featurePolicy(policyDirectives); - + this.expectedHeaders.add(FeaturePolicyServerHttpHeadersWriter.FEATURE_POLICY, policyDirectives); + // @formatter:off + this.http.headers() + .featurePolicy(policyDirectives); + // @formatter:on assertHeaders(); } @@ -306,9 +308,10 @@ public class HeaderSpecTests { String policyDirectives = "default-src 'self'"; this.expectedHeaders.add(ContentSecurityPolicyServerHttpHeadersWriter.CONTENT_SECURITY_POLICY, policyDirectives); - - this.http.headers().contentSecurityPolicy(policyDirectives); - + // @formatter:off + this.http.headers() + .contentSecurityPolicy(policyDirectives); + // @formatter:on assertHeaders(); } @@ -317,12 +320,11 @@ public class HeaderSpecTests { String expectedPolicyDirectives = "default-src 'self'"; this.expectedHeaders.add(ContentSecurityPolicyServerHttpHeadersWriter.CONTENT_SECURITY_POLICY, expectedPolicyDirectives); - - this.http - .headers(headers -> - headers.contentSecurityPolicy(withDefaults()) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .contentSecurityPolicy(withDefaults()) + ); + // @formatter:on assertHeaders(); } @@ -331,16 +333,13 @@ public class HeaderSpecTests { String policyDirectives = "default-src 'self' *.trusted.com"; this.expectedHeaders.add(ContentSecurityPolicyServerHttpHeadersWriter.CONTENT_SECURITY_POLICY, policyDirectives); - - this.http - .headers(headers -> - headers - .contentSecurityPolicy(contentSecurityPolicy -> - contentSecurityPolicy - .policyDirectives(policyDirectives) - ) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .contentSecurityPolicy((csp) -> csp + .policyDirectives(policyDirectives) + ) + ); + // @formatter:on assertHeaders(); } @@ -348,8 +347,10 @@ public class HeaderSpecTests { public void headersWhenReferrerPolicyEnabledThenFeaturePolicyWritten() { this.expectedHeaders.add(ReferrerPolicyServerHttpHeadersWriter.REFERRER_POLICY, ReferrerPolicy.NO_REFERRER.getPolicy()); - this.http.headers().referrerPolicy(); - + // @formatter:off + this.http.headers() + .referrerPolicy(); + // @formatter:on assertHeaders(); } @@ -357,12 +358,12 @@ public class HeaderSpecTests { public void headersWhenReferrerPolicyEnabledInLambdaThenReferrerPolicyWritten() { this.expectedHeaders.add(ReferrerPolicyServerHttpHeadersWriter.REFERRER_POLICY, ReferrerPolicy.NO_REFERRER.getPolicy()); - this.http - .headers(headers -> - headers - .referrerPolicy(withDefaults()) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .referrerPolicy(withDefaults() + ) + ); + // @formatter:on assertHeaders(); } @@ -370,8 +371,10 @@ public class HeaderSpecTests { public void headersWhenReferrerPolicyCustomEnabledThenFeaturePolicyCustomWritten() { this.expectedHeaders.add(ReferrerPolicyServerHttpHeadersWriter.REFERRER_POLICY, ReferrerPolicy.NO_REFERRER_WHEN_DOWNGRADE.getPolicy()); - this.http.headers().referrerPolicy(ReferrerPolicy.NO_REFERRER_WHEN_DOWNGRADE); - + // @formatter:off + this.http.headers() + .referrerPolicy(ReferrerPolicy.NO_REFERRER_WHEN_DOWNGRADE); + // @formatter:on assertHeaders(); } @@ -379,29 +382,27 @@ public class HeaderSpecTests { public void headersWhenReferrerPolicyCustomEnabledInLambdaThenCustomReferrerPolicyWritten() { this.expectedHeaders.add(ReferrerPolicyServerHttpHeadersWriter.REFERRER_POLICY, ReferrerPolicy.NO_REFERRER_WHEN_DOWNGRADE.getPolicy()); - this.http - .headers(headers -> - headers - .referrerPolicy(referrerPolicy -> - referrerPolicy - .policy(ReferrerPolicy.NO_REFERRER_WHEN_DOWNGRADE) - ) - ); - + // @formatter:off + this.http.headers((headers) -> headers + .referrerPolicy((referrerPolicy) -> referrerPolicy + .policy(ReferrerPolicy.NO_REFERRER_WHEN_DOWNGRADE) + ) + ); + // @formatter:on assertHeaders(); } @Test public void headersWhenCustomHeadersWriter() { this.expectedHeaders.add(CUSTOM_HEADER, CUSTOM_VALUE); - this.http.headers(headers -> headers.writer(exchange -> { - return Mono.just(exchange) - .doOnNext(it -> { - it.getResponse().getHeaders().add(CUSTOM_HEADER, CUSTOM_VALUE); - }).then(); - - })); - + // @formatter:off + this.http.headers((headers) -> headers + .writer((exchange) -> Mono.just(exchange) + .doOnNext((it) -> it.getResponse().getHeaders().add(CUSTOM_HEADER, CUSTOM_VALUE)) + .then() + ) + ); + // @formatter:on assertHeaders(); } @@ -414,16 +415,11 @@ public class HeaderSpecTests { private void assertHeaders() { WebTestClient client = buildClient(); - FluxExchangeResult response = client.get() - .uri("https://example.com/") - .exchange() - .returnResult(String.class); - + FluxExchangeResult response = client.get().uri("https://example.com/").exchange() + .returnResult(String.class); Map> responseHeaders = response.getResponseHeaders(); - if (!this.expectedHeaders.isEmpty()) { - assertThat(responseHeaders).describedAs(response.toString()) - .containsAllEntriesOf(this.expectedHeaders); + assertThat(responseHeaders).describedAs(response.toString()).containsAllEntriesOf(this.expectedHeaders); } if (!this.headerNamesNotPresent.isEmpty()) { assertThat(responseHeaders.keySet()).doesNotContainAnyElementsOf(this.headerNamesNotPresent); @@ -433,4 +429,5 @@ public class HeaderSpecTests { private WebTestClient buildClient() { return WebTestClientBuilder.bindToWebFilters(this.http.build()).build(); } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/HttpsRedirectSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/HttpsRedirectSpecTests.java index fa0f89940d..044d0820d0 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/HttpsRedirectSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/HttpsRedirectSpecTests.java @@ -31,8 +31,8 @@ import org.springframework.security.web.server.util.matcher.PathPatternParserSer import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.reactive.config.EnableWebFlux; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; /** @@ -41,6 +41,7 @@ import static org.springframework.security.config.Customizer.withDefaults; * @author Josh Cummings */ public class HttpsRedirectSpecTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -48,133 +49,142 @@ public class HttpsRedirectSpecTests { @Autowired public void setApplicationContext(ApplicationContext context) { - this.client = WebTestClient.bindToApplicationContext(context).build(); + // @formatter:off + this.client = WebTestClient + .bindToApplicationContext(context) + .build(); + // @formatter:on } @Test public void getWhenSecureThenDoesNotRedirect() { this.spring.register(RedirectToHttpConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("https://localhost") .exchange() .expectStatus().isNotFound(); + // @formatter:on } @Test public void getWhenInsecureThenRespondsWithRedirectToSecure() { this.spring.register(RedirectToHttpConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("http://localhost") .exchange() .expectStatus().isFound() .expectHeader().valueEquals(HttpHeaders.LOCATION, "https://localhost"); + // @formatter:on } @Test public void getWhenInsecureAndRedirectConfiguredInLambdaThenRespondsWithRedirectToSecure() { this.spring.register(RedirectToHttpsInLambdaConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("http://localhost") .exchange() .expectStatus().isFound() .expectHeader().valueEquals(HttpHeaders.LOCATION, "https://localhost"); + // @formatter:on } @Test public void getWhenInsecureAndPathRequiresTransportSecurityThenRedirects() { this.spring.register(SometimesRedirectToHttpsConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("http://localhost:8080") .exchange() .expectStatus().isNotFound(); - this.client.get() .uri("http://localhost:8080/secure") .exchange() .expectStatus().isFound() .expectHeader().valueEquals(HttpHeaders.LOCATION, "https://localhost:8443/secure"); + // @formatter:on } @Test public void getWhenInsecureAndPathRequiresTransportSecurityInLambdaThenRedirects() { this.spring.register(SometimesRedirectToHttpsInLambdaConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("http://localhost:8080") .exchange() .expectStatus().isNotFound(); - this.client.get() .uri("http://localhost:8080/secure") .exchange() .expectStatus().isFound() .expectHeader().valueEquals(HttpHeaders.LOCATION, "https://localhost:8443/secure"); + // @formatter:on } @Test public void getWhenInsecureAndUsingCustomPortMapperThenRespondsWithRedirectToSecurePort() { this.spring.register(RedirectToHttpsViaCustomPortsConfig.class).autowire(); - PortMapper portMapper = this.spring.getContext().getBean(PortMapper.class); - when(portMapper.lookupHttpsPort(4080)).thenReturn(4443); - + given(portMapper.lookupHttpsPort(4080)).willReturn(4443); + // @formatter:off this.client.get() .uri("http://localhost:4080") .exchange() .expectStatus().isFound() .expectHeader().valueEquals(HttpHeaders.LOCATION, "https://localhost:4443"); + // @formatter:on } @Test public void getWhenInsecureAndUsingCustomPortMapperInLambdaThenRespondsWithRedirectToSecurePort() { this.spring.register(RedirectToHttpsViaCustomPortsInLambdaConfig.class).autowire(); - PortMapper portMapper = this.spring.getContext().getBean(PortMapper.class); - when(portMapper.lookupHttpsPort(4080)).thenReturn(4443); - + given(portMapper.lookupHttpsPort(4080)).willReturn(4443); + // @formatter:off this.client.get() .uri("http://localhost:4080") .exchange() .expectStatus().isFound() .expectHeader().valueEquals(HttpHeaders.LOCATION, "https://localhost:4443"); + // @formatter:on } @EnableWebFlux @EnableWebFluxSecurity static class RedirectToHttpConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off http .redirectToHttps(); // @formatter:on - return http.build(); } - } + } @EnableWebFlux @EnableWebFluxSecurity static class RedirectToHttpsInLambdaConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off http .redirectToHttps(withDefaults()); // @formatter:on - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class SometimesRedirectToHttpsConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -182,32 +192,33 @@ public class HttpsRedirectSpecTests { .redirectToHttps() .httpsRedirectWhen(new PathPatternParserServerWebExchangeMatcher("/secure")); // @formatter:on - return http.build(); } - } + } @EnableWebFlux @EnableWebFluxSecurity static class SometimesRedirectToHttpsInLambdaConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off http - .redirectToHttps(redirectToHttps -> + .redirectToHttps((redirectToHttps) -> redirectToHttps .httpsRedirectWhen(new PathPatternParserServerWebExchangeMatcher("/secure")) ); // @formatter:on - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class RedirectToHttpsViaCustomPortsConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -215,35 +226,37 @@ public class HttpsRedirectSpecTests { .redirectToHttps() .portMapper(portMapper()); // @formatter:on - return http.build(); } @Bean - public PortMapper portMapper() { + PortMapper portMapper() { return mock(PortMapper.class); } + } @EnableWebFlux @EnableWebFluxSecurity static class RedirectToHttpsViaCustomPortsInLambdaConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off http - .redirectToHttps(redirectToHttps -> + .redirectToHttps((redirectToHttps) -> redirectToHttps .portMapper(portMapper()) ); // @formatter:on - return http.build(); } @Bean - public PortMapper portMapper() { + PortMapper portMapper() { return mock(PortMapper.class); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java index 723251e4cf..1149e97c34 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/LogoutSpecTests.java @@ -18,13 +18,14 @@ package org.springframework.security.config.web.server; import org.junit.Test; import org.openqa.selenium.WebDriver; + import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; +import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.test.web.reactive.server.WebTestClient; -import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import static org.springframework.security.config.Customizer.withDefaults; @@ -38,209 +39,173 @@ public class LogoutSpecTests { @Test public void defaultLogout() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin().and() - .build(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) - .assertAt(); - - loginPage = loginPage.loginForm() - .username("user") - .password("invalid") - .submit(FormLoginTests.DefaultLoginPage.class) - .assertError(); - - FormLoginTests.HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(FormLoginTests.HomePage.class); - - homePage.assertAt(); - - loginPage = FormLoginTests.DefaultLogoutPage.to(driver) - .assertAt() - .logout(); - - loginPage - .assertAt() - .assertLogout(); - } - - @Test - public void customLogout() { - SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin().and() - .logout() - .requiresLogout(ServerWebExchangeMatchers.pathMatchers("/custom-logout")) - .and() - .build(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) - .assertAt(); - - loginPage = loginPage.loginForm() - .username("user") - .password("invalid") - .submit(FormLoginTests.DefaultLoginPage.class) - .assertError(); - - FormLoginTests.HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(FormLoginTests.HomePage.class); - - homePage.assertAt(); - - driver.get("http://localhost/custom-logout"); - - FormLoginTests.DefaultLoginPage.create(driver) - .assertAt() - .assertLogout(); - } - - @Test - public void logoutWhenCustomLogoutInLambdaThenCustomLogoutUsed() { - SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange(authorizeExchange -> - authorizeExchange - .anyExchange().authenticated() - ) - .formLogin(withDefaults()) - .logout(logout -> - logout - .requiresLogout(ServerWebExchangeMatchers.pathMatchers("/custom-logout")) - ) - .build(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) - .assertAt(); - - loginPage = loginPage.loginForm() - .username("user") - .password("invalid") - .submit(FormLoginTests.DefaultLoginPage.class) - .assertError(); - - FormLoginTests.HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(FormLoginTests.HomePage.class); - - homePage.assertAt(); - - driver.get("http://localhost/custom-logout"); - - FormLoginTests.DefaultLoginPage.create(driver) - .assertAt() - .assertLogout(); - } - - @Test - public void logoutWhenDisabledThenPostToLogoutDoesNothing() { - SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin().and() - .logout().disable() - .build(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(securityWebFilter) - .build(); - - WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) - .assertAt(); - - FormLoginTests.HomePage homePage = loginPage.loginForm() - .username("user") - .password("password") - .submit(FormLoginTests.HomePage.class); - - homePage.assertAt(); - - FormLoginTests.DefaultLogoutPage.to(driver) - .assertAt() - .logout(); - - homePage - .assertAt(); - } - - - @Test - public void logoutWhenCustomSecurityContextRepositoryThenLogsOut() { - WebSessionServerSecurityContextRepository repository = new WebSessionServerSecurityContextRepository(); - repository.setSpringSecurityContextAttrName("CUSTOM_CONTEXT_ATTR"); - SecurityWebFilterChain securityWebFilter = this.http - .securityContextRepository(repository) .authorizeExchange() .anyExchange().authenticated() .and() .formLogin() .and() - .logout() - .and() .build(); - WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(securityWebFilter) .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder .webTestClientSetup(webTestClient) .build(); - - FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) - .assertAt(); - + // @formatter:on + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage + .to(driver, FormLoginTests.DefaultLoginPage.class).assertAt(); + // @formatter:off + loginPage = loginPage.loginForm() + .username("user") + .password("invalid") + .submit(FormLoginTests.DefaultLoginPage.class) + .assertError(); FormLoginTests.HomePage homePage = loginPage.loginForm() .username("user") .password("password") .submit(FormLoginTests.HomePage.class); - + // @formatter:on homePage.assertAt(); - - FormLoginTests.DefaultLogoutPage.to(driver) - .assertAt() - .logout(); - - FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) - .assertAt(); + loginPage = FormLoginTests.DefaultLogoutPage.to(driver).assertAt().logout(); + loginPage.assertAt().assertLogout(); } + + @Test + public void customLogout() { + // @formatter:off + SecurityWebFilterChain securityWebFilter = this.http + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .logout() + .requiresLogout(ServerWebExchangeMatchers.pathMatchers("/custom-logout")) + .and() + .build(); + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(securityWebFilter) + .build(); + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage + .to(driver, FormLoginTests.DefaultLoginPage.class).assertAt(); + // @formatter:off + loginPage = loginPage.loginForm() + .username("user") + .password("invalid") + .submit(FormLoginTests.DefaultLoginPage.class) + .assertError(); + FormLoginTests.HomePage homePage = loginPage.loginForm() + .username("user") + .password("password") + .submit(FormLoginTests.HomePage.class); + homePage.assertAt(); + // @formatter:on + driver.get("http://localhost/custom-logout"); + FormLoginTests.DefaultLoginPage.create(driver).assertAt().assertLogout(); + } + + @Test + public void logoutWhenCustomLogoutInLambdaThenCustomLogoutUsed() { + // @formatter:off + SecurityWebFilterChain securityWebFilter = this.http + .authorizeExchange((exchange) -> exchange + .anyExchange().authenticated() + ) + .formLogin(withDefaults()) + .logout((logout) -> logout + .requiresLogout(ServerWebExchangeMatchers.pathMatchers("/custom-logout")) + ) + .build(); + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(securityWebFilter) + .build(); + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage + .to(driver, FormLoginTests.DefaultLoginPage.class).assertAt(); + // @formatter:off + loginPage = loginPage.loginForm() + .username("user") + .password("invalid") + .submit(FormLoginTests.DefaultLoginPage.class) + .assertError(); + FormLoginTests.HomePage homePage = loginPage.loginForm() + .username("user").password("password") + .submit(FormLoginTests.HomePage.class); + // @formatter:on + homePage.assertAt(); + driver.get("http://localhost/custom-logout"); + FormLoginTests.DefaultLoginPage.create(driver).assertAt().assertLogout(); + } + + @Test + public void logoutWhenDisabledThenPostToLogoutDoesNothing() { + // @formatter:off + SecurityWebFilterChain securityWebFilter = this.http + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .logout().disable() + .build(); + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(securityWebFilter) + .build(); + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage + .to(driver, FormLoginTests.DefaultLoginPage.class).assertAt(); + // @formatter:off + FormLoginTests.HomePage homePage = loginPage.loginForm() + .username("user") + .password("password") + .submit(FormLoginTests.HomePage.class); + // @formatter:on + homePage.assertAt(); + FormLoginTests.DefaultLogoutPage.to(driver).assertAt().logout(); + homePage.assertAt(); + } + + @Test + public void logoutWhenCustomSecurityContextRepositoryThenLogsOut() { + WebSessionServerSecurityContextRepository repository = new WebSessionServerSecurityContextRepository(); + repository.setSpringSecurityContextAttrName("CUSTOM_CONTEXT_ATTR"); + // @formatter:off + SecurityWebFilterChain securityWebFilter = this.http + .securityContextRepository(repository) + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .logout().and() + .build(); + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(securityWebFilter) + .build(); + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage + .to(driver, FormLoginTests.DefaultLoginPage.class).assertAt(); + // @formatter:off + FormLoginTests.HomePage homePage = loginPage.loginForm() + .username("user") + .password("password") + .submit(FormLoginTests.HomePage.class); + // @formatter:on + homePage.assertAt(); + FormLoginTests.DefaultLogoutPage.to(driver).assertAt().logout(); + FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class).assertAt(); + } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java index 0ea9c446da..171bcab955 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java @@ -21,6 +21,8 @@ import java.net.URI; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import reactor.core.publisher.Mono; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -56,12 +58,11 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.config.EnableWebFlux; -import reactor.core.publisher.Mono; -import static org.mockito.Mockito.any; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -71,6 +72,7 @@ import static org.mockito.Mockito.when; @RunWith(SpringRunner.class) @SecurityTestExecutionListeners public class OAuth2ClientSpecTests { + @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -80,7 +82,11 @@ public class OAuth2ClientSpecTests { @Autowired public void setApplicationContext(ApplicationContext context) { - this.client = WebTestClient.bindToApplicationContext(context).build(); + // @formatter:off + this.client = WebTestClient + .bindToApplicationContext(context) + .build(); + // @formatter:on } @Test @@ -89,13 +95,17 @@ public class OAuth2ClientSpecTests { this.spring.register(Config.class, AuthorizedClientController.class).autowire(); ReactiveClientRegistrationRepository repository = this.spring.getContext() .getBean(ReactiveClientRegistrationRepository.class); - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext().getBean(ServerOAuth2AuthorizedClientRepository.class); - when(repository.findByRegistrationId(any())).thenReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); - when(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); - - this.client.get().uri("/") - .exchange() - .expectStatus().is3xxRedirection(); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(ServerOAuth2AuthorizedClientRepository.class); + given(repository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + given(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + // @formatter:off + this.client.get() + .uri("/") + .exchange() + .expectStatus().is3xxRedirection(); + // @formatter:on } @Test @@ -103,22 +113,108 @@ public class OAuth2ClientSpecTests { this.spring.register(Config.class, AuthorizedClientController.class).autowire(); ReactiveClientRegistrationRepository repository = this.spring.getContext() .getBean(ReactiveClientRegistrationRepository.class); - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext().getBean(ServerOAuth2AuthorizedClientRepository.class); - when(repository.findByRegistrationId(any())).thenReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); - when(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); - - this.client.get().uri("/") + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(ServerOAuth2AuthorizedClientRepository.class); + given(repository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + given(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + // @formatter:off + this.client.get() + .uri("/") .exchange() .expectStatus().is3xxRedirection(); + // @formatter:on + } + + @Test + public void oauth2ClientWhenCustomObjectsThenUsed() { + this.spring.register(ClientRegistrationConfig.class, OAuth2ClientCustomConfig.class, + AuthorizedClientController.class).autowire(); + OAuth2ClientCustomConfig config = this.spring.getContext().getBean(OAuth2ClientCustomConfig.class); + ServerAuthenticationConverter converter = config.authenticationConverter; + ReactiveAuthenticationManager manager = config.manager; + ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; + ServerRequestCache requestCache = config.requestCache; + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/oauth2/code/registration-id").build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/oauth2/code/registration-id").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); + OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( + this.registration, authorizationExchange, accessToken); + given(authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(authorizationRequest)); + given(converter.convert(any())).willReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); + given(manager.authenticate(any())).willReturn(Mono.just(result)); + given(requestCache.getRedirectUri(any())).willReturn(Mono.just(URI.create("/saved-request"))); + // @formatter:off + this.client.get() + .uri((uriBuilder) -> uriBuilder + .path("/authorize/oauth2/code/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build() + ) + .exchange() + .expectStatus().is3xxRedirection(); + // @formatter:on + verify(converter).convert(any()); + verify(manager).authenticate(any()); + verify(requestCache).getRedirectUri(any()); + } + + @Test + public void oauth2ClientWhenCustomObjectsInLambdaThenUsed() { + this.spring.register(ClientRegistrationConfig.class, OAuth2ClientInLambdaCustomConfig.class, + AuthorizedClientController.class).autowire(); + OAuth2ClientInLambdaCustomConfig config = this.spring.getContext() + .getBean(OAuth2ClientInLambdaCustomConfig.class); + ServerAuthenticationConverter converter = config.authenticationConverter; + ReactiveAuthenticationManager manager = config.manager; + ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; + ServerRequestCache requestCache = config.requestCache; + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/oauth2/code/registration-id").build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/oauth2/code/registration-id").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); + OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( + this.registration, authorizationExchange, accessToken); + given(authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(authorizationRequest)); + given(converter.convert(any())).willReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); + given(manager.authenticate(any())).willReturn(Mono.just(result)); + given(requestCache.getRedirectUri(any())).willReturn(Mono.just(URI.create("/saved-request"))); + // @formatter:off + this.client.get() + .uri((uriBuilder) -> uriBuilder + .path("/authorize/oauth2/code/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build() + ) + .exchange() + .expectStatus().is3xxRedirection(); + // @formatter:on + verify(converter).convert(any()); + verify(manager).authenticate(any()); + verify(requestCache).getRedirectUri(any()); } @EnableWebFlux @EnableWebFluxSecurity static class Config { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off http .oauth2Client(); + // @formatter:on return http.build(); } @@ -131,157 +227,86 @@ public class OAuth2ClientSpecTests { ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { return mock(ServerOAuth2AuthorizedClientRepository.class); } + } @RestController static class AuthorizedClientController { + @GetMapping("/") String home(@RegisteredOAuth2AuthorizedClient("github") OAuth2AuthorizedClient authorizedClient) { return "home"; } - } - @Test - public void oauth2ClientWhenCustomObjectsThenUsed() { - this.spring.register(ClientRegistrationConfig.class, OAuth2ClientCustomConfig.class, AuthorizedClientController.class).autowire(); - - OAuth2ClientCustomConfig config = this.spring.getContext().getBean(OAuth2ClientCustomConfig.class); - - ServerAuthenticationConverter converter = config.authenticationConverter; - ReactiveAuthenticationManager manager = config.manager; - ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; - ServerRequestCache requestCache = config.requestCache; - - OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() - .redirectUri("/authorize/oauth2/code/registration-id") - .build(); - OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() - .redirectUri("/authorize/oauth2/code/registration-id") - .build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); - OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - - OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( - this.registration, authorizationExchange, accessToken); - - when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); - when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); - when(manager.authenticate(any())).thenReturn(Mono.just(result)); - when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request"))); - - this.client.get() - .uri(uriBuilder -> - uriBuilder.path("/authorize/oauth2/code/registration-id") - .queryParam(OAuth2ParameterNames.CODE, "code") - .queryParam(OAuth2ParameterNames.STATE, "state") - .build()) - .exchange() - .expectStatus().is3xxRedirection(); - - verify(converter).convert(any()); - verify(manager).authenticate(any()); - verify(requestCache).getRedirectUri(any()); } @EnableWebFlux @EnableWebFluxSecurity static class ClientRegistrationConfig { - private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() - .build(); + + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); @Bean InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { return new InMemoryReactiveClientRegistrationRepository(this.clientRegistration); } + } @Configuration static class OAuth2ClientCustomConfig { + ReactiveAuthenticationManager manager = mock(ReactiveAuthenticationManager.class); ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); - ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + ServerAuthorizationRequestRepository authorizationRequestRepository = mock( + ServerAuthorizationRequestRepository.class); ServerRequestCache requestCache = mock(ServerRequestCache.class); @Bean - public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + // @formatter:off http .oauth2Client() .authenticationConverter(this.authenticationConverter) .authenticationManager(this.manager) .authorizationRequestRepository(this.authorizationRequestRepository) .and() - .requestCache(c -> c.requestCache(this.requestCache)); + .requestCache((c) -> c.requestCache(this.requestCache)); + // @formatter:on return http.build(); } - } - @Test - public void oauth2ClientWhenCustomObjectsInLambdaThenUsed() { - this.spring.register(ClientRegistrationConfig.class, OAuth2ClientInLambdaCustomConfig.class, AuthorizedClientController.class).autowire(); - - OAuth2ClientInLambdaCustomConfig config = this.spring.getContext().getBean(OAuth2ClientInLambdaCustomConfig.class); - - ServerAuthenticationConverter converter = config.authenticationConverter; - ReactiveAuthenticationManager manager = config.manager; - ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; - ServerRequestCache requestCache = config.requestCache; - - OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() - .redirectUri("/authorize/oauth2/code/registration-id") - .build(); - OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() - .redirectUri("/authorize/oauth2/code/registration-id") - .build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); - OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - - OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( - this.registration, authorizationExchange, accessToken); - - when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); - when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); - when(manager.authenticate(any())).thenReturn(Mono.just(result)); - when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request"))); - - this.client.get() - .uri(uriBuilder -> - uriBuilder.path("/authorize/oauth2/code/registration-id") - .queryParam(OAuth2ParameterNames.CODE, "code") - .queryParam(OAuth2ParameterNames.STATE, "state") - .build()) - .exchange() - .expectStatus().is3xxRedirection(); - - verify(converter).convert(any()); - verify(manager).authenticate(any()); - verify(requestCache).getRedirectUri(any()); } @Configuration static class OAuth2ClientInLambdaCustomConfig { + ReactiveAuthenticationManager manager = mock(ReactiveAuthenticationManager.class); ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); - ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + ServerAuthorizationRequestRepository authorizationRequestRepository = mock( + ServerAuthorizationRequestRepository.class); ServerRequestCache requestCache = mock(ServerRequestCache.class); @Bean - public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + // @formatter:off http - .oauth2Client(oauth2Client -> + .oauth2Client((oauth2Client) -> oauth2Client .authenticationConverter(this.authenticationConverter) .authenticationManager(this.manager) .authorizationRequestRepository(this.authorizationRequestRepository)) - .requestCache(c -> c.requestCache(this.requestCache)); + .requestCache((c) -> c.requestCache(this.requestCache)); + // @formatter:on return http.build(); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index b1e5662a3e..211d353203 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -16,10 +16,16 @@ package org.springframework.security.config.web.server; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Rule; import org.junit.Test; import org.mockito.stubbing.Answer; import org.openqa.selenium.WebDriver; +import reactor.core.publisher.Mono; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -74,6 +80,7 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtValidationException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; @@ -95,19 +102,13 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebHandler; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * @author Rob Winch @@ -124,69 +125,54 @@ public class OAuth2LoginTests { @Autowired private WebFilterChainProxy springSecurity; - private static ClientRegistration github = CommonOAuth2Provider.GITHUB - .getBuilder("github") - .clientId("client") - .clientSecret("secret") - .build(); + private static ClientRegistration github = CommonOAuth2Provider.GITHUB.getBuilder("github").clientId("client") + .clientSecret("secret").build(); - private static ClientRegistration google = CommonOAuth2Provider.GOOGLE - .getBuilder("google") - .clientId("client") - .clientSecret("secret") - .build(); + private static ClientRegistration google = CommonOAuth2Provider.GOOGLE.getBuilder("google").clientId("client") + .clientSecret("secret").build(); @Autowired public void setApplicationContext(ApplicationContext context) { if (context.getBeanNamesForType(WebHandler.class).length > 0) { - this.client = WebTestClient.bindToApplicationContext(context) + // @formatter:off + this.client = WebTestClient + .bindToApplicationContext(context) .build(); + // @formatter:on } } @Test public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() { this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire(); - + // @formatter:off WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(this.springSecurity) .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder .webTestClientSetup(webTestClient) .build(); - - FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage - .to(driver, FormLoginTests.DefaultLoginPage.class) + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage.to(driver, FormLoginTests.DefaultLoginPage.class) .assertAt() .assertLoginFormNotPresent() .oauth2Login() - .assertClientRegistrationByName(this.github.getClientName()) + .assertClientRegistrationByName(OAuth2LoginTests.github.getClientName()) .and(); - } - - @EnableWebFluxSecurity - static class OAuth2LoginWithMultipleClientRegistrations { - @Bean - InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { - return new InMemoryReactiveClientRegistrationRepository(github, google); - } + // @formatter:on } @Test public void defaultLoginPageWithSingleClientRegistrationThenRedirect() { this.spring.register(OAuth2LoginWithSingleClientRegistrations.class).autowire(); - + // @formatter:off WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(new GitHubWebFilter(), this.springSecurity) .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder .webTestClientSetup(webTestClient) .build(); - + // @formatter:on driver.get("http://localhost/"); - assertThat(driver.getCurrentUrl()).startsWith("https://github.com/login/oauth/authorize"); } @@ -194,99 +180,46 @@ public class OAuth2LoginTests { @Test public void defaultLoginPageWithSingleClientRegistrationAndXhrRequestThenDoesNotRedirectForAuthorization() { this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, WebFluxConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("/") .header("X-Requested-With", "XMLHttpRequest") .exchange() .expectStatus().is3xxRedirection() .expectHeader().valueEquals(HttpHeaders.LOCATION, "/login"); - } - - @EnableWebFlux - static class WebFluxConfig { } - - @EnableWebFluxSecurity - static class OAuth2LoginWithSingleClientRegistrations { - @Bean - InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { - return new InMemoryReactiveClientRegistrationRepository(github); - } + // @formatter:on } @Test public void oauth2AuthorizeWhenCustomObjectsThenUsed() { - this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, - OAuth2AuthorizeWithMockObjectsConfig.class, + this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, OAuth2AuthorizeWithMockObjectsConfig.class, AuthorizedClientController.class).autowire(); - - OAuth2AuthorizeWithMockObjectsConfig config = this.spring.getContext().getBean(OAuth2AuthorizeWithMockObjectsConfig.class); - + OAuth2AuthorizeWithMockObjectsConfig config = this.spring.getContext() + .getBean(OAuth2AuthorizeWithMockObjectsConfig.class); ServerOAuth2AuthorizedClientRepository authorizedClientRepository = config.authorizedClientRepository; ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; ServerRequestCache requestCache = config.requestCache; - - when(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); - when(authorizationRequestRepository.saveAuthorizationRequest(any(), any())).thenReturn(Mono.empty()); - when(requestCache.removeMatchingRequest(any())).thenReturn(Mono.empty()); - when(requestCache.saveRequest(any())).thenReturn(Mono.empty()); - + given(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + given(authorizationRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); + given(requestCache.removeMatchingRequest(any())).willReturn(Mono.empty()); + given(requestCache.saveRequest(any())).willReturn(Mono.empty()); + // @formatter:off this.client.get() .uri("/") .exchange() .expectStatus().is3xxRedirection(); - + // @formatter:on verify(authorizedClientRepository).loadAuthorizedClient(any(), any(), any()); verify(authorizationRequestRepository).saveAuthorizationRequest(any(), any()); verify(requestCache).saveRequest(any()); } - @EnableWebFlux - static class OAuth2AuthorizeWithMockObjectsConfig { - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = - mock(ServerOAuth2AuthorizedClientRepository.class); - - ServerAuthorizationRequestRepository authorizationRequestRepository = - mock(ServerAuthorizationRequestRepository.class); - - ServerRequestCache requestCache = mock(ServerRequestCache.class); - - @Bean - SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { - http - .requestCache() - .requestCache(this.requestCache) - .and() - .oauth2Login() - .authorizationRequestRepository(this.authorizationRequestRepository); - return http.build(); - } - - @Bean - ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { - return this.authorizedClientRepository; - } - } - - @RestController - static class AuthorizedClientController { - @GetMapping("/") - String home(@RegisteredOAuth2AuthorizedClient("github") OAuth2AuthorizedClient authorizedClient) { - return "home"; - } - } - @Test public void oauth2LoginWhenCustomObjectsThenUsed() { this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, OAuth2LoginMockAuthenticationManagerConfig.class).autowire(); - String redirectLocation = "/custom-redirect-location"; - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(this.springSecurity) - .build(); - + WebTestClient webTestClient = WebTestClientBuilder.bindToWebFilters(this.springSecurity).build(); OAuth2LoginMockAuthenticationManagerConfig config = this.spring.getContext() .getBean(OAuth2LoginMockAuthenticationManagerConfig.class); ServerAuthenticationConverter converter = config.authenticationConverter; @@ -294,31 +227,28 @@ public class OAuth2LoginTests { ServerWebExchangeMatcher matcher = config.matcher; ServerOAuth2AuthorizationRequestResolver resolver = config.resolver; ServerAuthenticationSuccessHandler successHandler = config.successHandler; - OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); OAuth2User user = TestOAuth2Users.create(); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - - OAuth2LoginAuthenticationToken result = new OAuth2LoginAuthenticationToken(github, exchange, user, user.getAuthorities(), accessToken); - - when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); - when(manager.authenticate(any())).thenReturn(Mono.just(result)); - when(matcher.matches(any())).thenReturn(ServerWebExchangeMatcher.MatchResult.match()); - when(resolver.resolve(any())).thenReturn(Mono.empty()); - when(successHandler.onAuthenticationSuccess(any(), any())).thenAnswer((Answer>) invocation -> { + OAuth2LoginAuthenticationToken result = new OAuth2LoginAuthenticationToken(github, exchange, user, + user.getAuthorities(), accessToken); + given(converter.convert(any())).willReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); + given(manager.authenticate(any())).willReturn(Mono.just(result)); + given(matcher.matches(any())).willReturn(ServerWebExchangeMatcher.MatchResult.match()); + given(resolver.resolve(any())).willReturn(Mono.empty()); + given(successHandler.onAuthenticationSuccess(any(), any())).willAnswer((Answer>) (invocation) -> { WebFilterExchange webFilterExchange = invocation.getArgument(0); Authentication authentication = invocation.getArgument(1); - return new RedirectServerAuthenticationSuccessHandler(redirectLocation) .onAuthenticationSuccess(webFilterExchange, authentication); }); - + // @formatter:off webTestClient.get() - .uri("/login/oauth2/code/github") - .exchange() - .expectStatus().is3xxRedirection() - .expectHeader().valueEquals("Location", redirectLocation); - + .uri("/login/oauth2/code/github") + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals("Location", redirectLocation); + // @formatter:on verify(converter).convert(any()); verify(manager).authenticate(any()); verify(matcher).matches(any()); @@ -330,14 +260,13 @@ public class OAuth2LoginTests { public void oauth2LoginFailsWhenCustomObjectsThenUsed() { this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, OAuth2LoginMockAuthenticationManagerConfig.class).autowire(); - String redirectLocation = "/custom-redirect-location"; String failureRedirectLocation = "/failure-redirect-location"; - + // @formatter:off WebTestClient webTestClient = WebTestClientBuilder .bindToWebFilters(this.springSecurity) .build(); - + // @formatter:on OAuth2LoginMockAuthenticationManagerConfig config = this.spring.getContext() .getBean(OAuth2LoginMockAuthenticationManagerConfig.class); ServerAuthenticationConverter converter = config.authenticationConverter; @@ -346,32 +275,30 @@ public class OAuth2LoginTests { ServerOAuth2AuthorizationRequestResolver resolver = config.resolver; ServerAuthenticationSuccessHandler successHandler = config.successHandler; ServerAuthenticationFailureHandler failureHandler = config.failureHandler; - - when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); - when(manager.authenticate(any())).thenReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("error"), "message"))); - when(matcher.matches(any())).thenReturn(ServerWebExchangeMatcher.MatchResult.match()); - when(resolver.resolve(any())).thenReturn(Mono.empty()); - when(successHandler.onAuthenticationSuccess(any(), any())).thenAnswer((Answer>) invocation -> { + given(converter.convert(any())).willReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); + given(manager.authenticate(any())) + .willReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("error"), "message"))); + given(matcher.matches(any())).willReturn(ServerWebExchangeMatcher.MatchResult.match()); + given(resolver.resolve(any())).willReturn(Mono.empty()); + given(successHandler.onAuthenticationSuccess(any(), any())).willAnswer((Answer>) (invocation) -> { WebFilterExchange webFilterExchange = invocation.getArgument(0); Authentication authentication = invocation.getArgument(1); - return new RedirectServerAuthenticationSuccessHandler(redirectLocation) .onAuthenticationSuccess(webFilterExchange, authentication); }); - when(failureHandler.onAuthenticationFailure(any(), any())).thenAnswer((Answer>) invocation -> { + given(failureHandler.onAuthenticationFailure(any(), any())).willAnswer((Answer>) (invocation) -> { WebFilterExchange webFilterExchange = invocation.getArgument(0); AuthenticationException authenticationException = invocation.getArgument(1); - return new RedirectServerAuthenticationFailureHandler(failureRedirectLocation) .onAuthenticationFailure(webFilterExchange, authenticationException); }); - + // @formatter:off webTestClient.get() .uri("/login/oauth2/code/github") .exchange() .expectStatus().is3xxRedirection() .expectHeader().valueEquals("Location", failureRedirectLocation); - + // @formatter:on verify(converter).convert(any()); verify(manager).authenticate(any()); verify(matcher).matches(any()); @@ -379,8 +306,287 @@ public class OAuth2LoginTests { verify(failureHandler).onAuthenticationFailure(any(), any()); } + @Test + public void oauth2LoginWhenCustomObjectsInLambdaThenUsed() { + this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, + OAuth2LoginMockAuthenticationManagerInLambdaConfig.class).autowire(); + String redirectLocation = "/custom-redirect-location"; + WebTestClient webTestClient = WebTestClientBuilder.bindToWebFilters(this.springSecurity).build(); + OAuth2LoginMockAuthenticationManagerInLambdaConfig config = this.spring.getContext() + .getBean(OAuth2LoginMockAuthenticationManagerInLambdaConfig.class); + ServerAuthenticationConverter converter = config.authenticationConverter; + ReactiveAuthenticationManager manager = config.manager; + ServerWebExchangeMatcher matcher = config.matcher; + ServerOAuth2AuthorizationRequestResolver resolver = config.resolver; + ServerAuthenticationSuccessHandler successHandler = config.successHandler; + OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2User user = TestOAuth2Users.create(); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); + OAuth2LoginAuthenticationToken result = new OAuth2LoginAuthenticationToken(github, exchange, user, + user.getAuthorities(), accessToken); + given(converter.convert(any())).willReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); + given(manager.authenticate(any())).willReturn(Mono.just(result)); + given(matcher.matches(any())).willReturn(ServerWebExchangeMatcher.MatchResult.match()); + given(resolver.resolve(any())).willReturn(Mono.empty()); + given(successHandler.onAuthenticationSuccess(any(), any())).willAnswer((Answer>) (invocation) -> { + WebFilterExchange webFilterExchange = invocation.getArgument(0); + Authentication authentication = invocation.getArgument(1); + return new RedirectServerAuthenticationSuccessHandler(redirectLocation) + .onAuthenticationSuccess(webFilterExchange, authentication); + }); + // @formatter:off + webTestClient.get() + .uri("/login/oauth2/code/github") + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals("Location", redirectLocation); + // @formatter:on + verify(converter).convert(any()); + verify(manager).authenticate(any()); + verify(matcher).matches(any()); + verify(resolver).resolve(any()); + verify(successHandler).onAuthenticationSuccess(any(), any()); + } + + @Test + public void oauth2LoginWhenCustomBeansThenUsed() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, OAuth2LoginWithCustomBeansConfig.class) + .autowire(); + // @formatter:off + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + // @formatter:on + OAuth2LoginWithCustomBeansConfig config = this.spring.getContext() + .getBean(OAuth2LoginWithCustomBeansConfig.class); + OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); + OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken token = new OAuth2AuthorizationCodeAuthenticationToken(google, + exchange, accessToken); + ServerAuthenticationConverter converter = config.authenticationConverter; + given(converter.convert(any())).willReturn(Mono.just(token)); + ServerSecurityContextRepository securityContextRepository = config.securityContextRepository; + given(securityContextRepository.save(any(), any())).willReturn(Mono.empty()); + given(securityContextRepository.load(any())).willReturn(authentication(token)); + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()) + .additionalParameters(additionalParameters) + .build(); + // @formatter:on + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + given(tokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + OidcUser user = TestOidcUsers.create(); + ReactiveOAuth2UserService userService = config.userService; + given(userService.loadUser(any())).willReturn(Mono.just(user)); + // @formatter:off + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus().is3xxRedirection(); + // @formatter:on + verify(config.jwtDecoderFactory).createDecoder(any()); + verify(tokenResponseClient).getTokenResponse(any()); + verify(securityContextRepository).save(any(), any()); + } + + // gh-5562 + @Test + public void oauth2LoginWhenAccessTokenRequestFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, OAuth2LoginWithCustomBeansConfig.class) + .autowire(); + // @formatter:off + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests + .request() + .scope("openid") + .build(); + // @formatter:on + OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken( + google, exchange, accessToken); + OAuth2LoginWithCustomBeansConfig config = this.spring.getContext() + .getBean(OAuth2LoginWithCustomBeansConfig.class); + ServerAuthenticationConverter converter = config.authenticationConverter; + given(converter.convert(any())).willReturn(Mono.just(authenticationToken)); + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + OAuth2Error oauth2Error = new OAuth2Error("invalid_request", "Invalid request", null); + given(tokenResponseClient.getTokenResponse(any())).willThrow(new OAuth2AuthenticationException(oauth2Error)); + // @formatter:off + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals("Location", "/login?error"); + // @formatter:on + } + + // gh-6484 + @Test + public void oauth2LoginWhenIdTokenValidationFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, OAuth2LoginWithCustomBeansConfig.class) + .autowire(); + WebTestClient webTestClient = WebTestClientBuilder.bindToWebFilters(this.springSecurity).build(); + OAuth2LoginWithCustomBeansConfig config = this.spring.getContext() + .getBean(OAuth2LoginWithCustomBeansConfig.class); + // @formatter:off + OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests + .request() + .scope("openid") + .build(); + OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses + .success() + .build(); + // @formatter:on + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken( + google, exchange, accessToken); + ServerAuthenticationConverter converter = config.authenticationConverter; + given(converter.convert(any())).willReturn(Mono.just(authenticationToken)); + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()) + .additionalParameters(additionalParameters) + .build(); + // @formatter:on + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + given(tokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + ReactiveJwtDecoderFactory jwtDecoderFactory = config.jwtDecoderFactory; + OAuth2Error oauth2Error = new OAuth2Error("invalid_id_token", "Invalid ID Token", null); + given(jwtDecoderFactory.createDecoder(any())).willReturn((token) -> Mono + .error(new JwtValidationException("ID Token validation failed", Collections.singleton(oauth2Error)))); + // @formatter:off + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals("Location", "/login?error"); + // @formatter:on + } + + @Test + public void logoutWhenUsingOidcLogoutHandlerThenRedirects() { + this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire(); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, getBean(ClientRegistration.class).getRegistrationId()); + ServerSecurityContextRepository repository = getBean(ServerSecurityContextRepository.class); + given(repository.load(any())).willReturn(authentication(token)); + // @formatter:off + this.client.post() + .uri("/logout") + .exchange() + .expectHeader().valueEquals("Location", "https://logout?id_token_hint=id-token"); + // @formatter:on + } + + // gh-8609 + @Test + public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire(); + WebTestClient webTestClient = WebTestClientBuilder.bindToWebFilters(this.springSecurity).build(); + // @formatter:off + webTestClient.get() + .uri("/login/oauth2/code/google") + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals("Location", "/login?error"); + // @formatter:on + } + + Mono authentication(Authentication authentication) { + SecurityContext context = new SecurityContextImpl(); + context.setAuthentication(authentication); + return Mono.just(context); + } + + T getBean(Class beanClass) { + return this.spring.getContext().getBean(beanClass); + } + + @EnableWebFluxSecurity + static class OAuth2LoginWithMultipleClientRegistrations { + + @Bean + InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { + return new InMemoryReactiveClientRegistrationRepository(github, google); + } + + } + + @EnableWebFlux + static class WebFluxConfig { + + } + + @EnableWebFluxSecurity + static class OAuth2LoginWithSingleClientRegistrations { + + @Bean + InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { + return new InMemoryReactiveClientRegistrationRepository(github); + } + + } + + @EnableWebFlux + static class OAuth2AuthorizeWithMockObjectsConfig { + + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = mock( + ServerOAuth2AuthorizedClientRepository.class); + + ServerAuthorizationRequestRepository authorizationRequestRepository = mock( + ServerAuthorizationRequestRepository.class); + + ServerRequestCache requestCache = mock(ServerRequestCache.class); + + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off + http + .requestCache() + .requestCache(this.requestCache) + .and() + .oauth2Login() + .authorizationRequestRepository(this.authorizationRequestRepository); + // @formatter:on + return http.build(); + } + + @Bean + ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { + return this.authorizedClientRepository; + } + + } + + @RestController + static class AuthorizedClientController { + + @GetMapping("/") + String home(@RegisteredOAuth2AuthorizedClient("github") OAuth2AuthorizedClient authorizedClient) { + return "home"; + } + + } + @Configuration static class OAuth2LoginMockAuthenticationManagerConfig { + ReactiveAuthenticationManager manager = mock(ReactiveAuthenticationManager.class); ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); @@ -394,74 +600,28 @@ public class OAuth2LoginTests { ServerAuthenticationFailureHandler failureHandler = mock(ServerAuthenticationFailureHandler.class); @Bean - public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + // @formatter:off http .authorizeExchange() .anyExchange().authenticated() .and() .oauth2Login() - .authenticationConverter(authenticationConverter) - .authenticationManager(manager) - .authenticationMatcher(matcher) - .authorizationRequestResolver(resolver) - .authenticationSuccessHandler(successHandler) - .authenticationFailureHandler(failureHandler); + .authenticationConverter(this.authenticationConverter) + .authenticationManager(this.manager) + .authenticationMatcher(this.matcher) + .authorizationRequestResolver(this.resolver) + .authenticationSuccessHandler(this.successHandler) + .authenticationFailureHandler(this.failureHandler); + // @formatter:on return http.build(); } - } - @Test - public void oauth2LoginWhenCustomObjectsInLambdaThenUsed() { - this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, - OAuth2LoginMockAuthenticationManagerInLambdaConfig.class).autowire(); - - String redirectLocation = "/custom-redirect-location"; - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(this.springSecurity) - .build(); - - OAuth2LoginMockAuthenticationManagerInLambdaConfig config = this.spring.getContext() - .getBean(OAuth2LoginMockAuthenticationManagerInLambdaConfig.class); - ServerAuthenticationConverter converter = config.authenticationConverter; - ReactiveAuthenticationManager manager = config.manager; - ServerWebExchangeMatcher matcher = config.matcher; - ServerOAuth2AuthorizationRequestResolver resolver = config.resolver; - ServerAuthenticationSuccessHandler successHandler = config.successHandler; - - OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); - OAuth2User user = TestOAuth2Users.create(); - OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - - OAuth2LoginAuthenticationToken result = new OAuth2LoginAuthenticationToken(github, exchange, user, user.getAuthorities(), accessToken); - - when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); - when(manager.authenticate(any())).thenReturn(Mono.just(result)); - when(matcher.matches(any())).thenReturn(ServerWebExchangeMatcher.MatchResult.match()); - when(resolver.resolve(any())).thenReturn(Mono.empty()); - when(successHandler.onAuthenticationSuccess(any(), any())).thenAnswer((Answer>) invocation -> { - WebFilterExchange webFilterExchange = invocation.getArgument(0); - Authentication authentication = invocation.getArgument(1); - - return new RedirectServerAuthenticationSuccessHandler(redirectLocation) - .onAuthenticationSuccess(webFilterExchange, authentication); - }); - - webTestClient.get() - .uri("/login/oauth2/code/github") - .exchange() - .expectStatus().is3xxRedirection() - .expectHeader().valueEquals("Location", redirectLocation); - - verify(converter).convert(any()); - verify(manager).authenticate(any()); - verify(matcher).matches(any()); - verify(resolver).resolve(any()); - verify(successHandler).onAuthenticationSuccess(any(), any()); } @Configuration static class OAuth2LoginMockAuthenticationManagerInLambdaConfig { + ReactiveAuthenticationManager manager = mock(ReactiveAuthenticationManager.class); ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); @@ -473,150 +633,25 @@ public class OAuth2LoginTests { ServerAuthenticationSuccessHandler successHandler = mock(ServerAuthenticationSuccessHandler.class); @Bean - public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + // @formatter:off http - .authorizeExchange(exchanges -> + .authorizeExchange((exchanges) -> exchanges .anyExchange().authenticated() ) - .oauth2Login(oauth2Login -> + .oauth2Login((oauth2Login) -> oauth2Login - .authenticationConverter(authenticationConverter) - .authenticationManager(manager) - .authenticationMatcher(matcher) - .authorizationRequestResolver(resolver) - .authenticationSuccessHandler(successHandler) + .authenticationConverter(this.authenticationConverter) + .authenticationManager(this.manager) + .authenticationMatcher(this.matcher) + .authorizationRequestResolver(this.resolver) + .authenticationSuccessHandler(this.successHandler) ); + // @formatter:on return http.build(); } - } - @Test - public void oauth2LoginWhenCustomBeansThenUsed() { - this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, - OAuth2LoginWithCustomBeansConfig.class).autowire(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(this.springSecurity) - .build(); - - OAuth2LoginWithCustomBeansConfig config = this.spring.getContext() - .getBean(OAuth2LoginWithCustomBeansConfig.class); - - OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); - OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); - OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); - OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); - OAuth2AuthorizationCodeAuthenticationToken token = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken); - - ServerAuthenticationConverter converter = config.authenticationConverter; - when(converter.convert(any())).thenReturn(Mono.just(token)); - - ServerSecurityContextRepository securityContextRepository = config.securityContextRepository; - when(securityContextRepository.save(any(), any())).thenReturn(Mono.empty()); - when(securityContextRepository.load(any())).thenReturn(authentication(token)); - - Map additionalParameters = new HashMap<>(); - additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) - .tokenType(accessToken.getTokenType()) - .scopes(accessToken.getScopes()) - .additionalParameters(additionalParameters) - .build(); - ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; - when(tokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OidcUser user = TestOidcUsers.create(); - ReactiveOAuth2UserService userService = config.userService; - when(userService.loadUser(any())).thenReturn(Mono.just(user)); - - webTestClient.get() - .uri("/login/oauth2/code/google") - .exchange() - .expectStatus().is3xxRedirection(); - - verify(config.jwtDecoderFactory).createDecoder(any()); - verify(tokenResponseClient).getTokenResponse(any()); - verify(securityContextRepository).save(any(), any()); - } - - // gh-5562 - @Test - public void oauth2LoginWhenAccessTokenRequestFailsThenDefaultRedirectToLogin() { - this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, - OAuth2LoginWithCustomBeansConfig.class).autowire(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(this.springSecurity) - .build(); - - OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); - OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); - OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); - OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); - OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken); - - OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class); - - ServerAuthenticationConverter converter = config.authenticationConverter; - when(converter.convert(any())).thenReturn(Mono.just(authenticationToken)); - - ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; - OAuth2Error oauth2Error = new OAuth2Error("invalid_request", "Invalid request", null); - when(tokenResponseClient.getTokenResponse(any())).thenThrow(new OAuth2AuthenticationException(oauth2Error)); - - webTestClient.get() - .uri("/login/oauth2/code/google") - .exchange() - .expectStatus() - .is3xxRedirection() - .expectHeader() - .valueEquals("Location", "/login?error"); - } - - // gh-6484 - @Test - public void oauth2LoginWhenIdTokenValidationFailsThenDefaultRedirectToLogin() { - this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, - OAuth2LoginWithCustomBeansConfig.class).autowire(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(this.springSecurity) - .build(); - - OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class); - - OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); - OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); - OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); - OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); - OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken); - - ServerAuthenticationConverter converter = config.authenticationConverter; - when(converter.convert(any())).thenReturn(Mono.just(authenticationToken)); - - Map additionalParameters = new HashMap<>(); - additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) - .tokenType(accessToken.getTokenType()) - .scopes(accessToken.getScopes()) - .additionalParameters(additionalParameters) - .build(); - ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; - when(tokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - ReactiveJwtDecoderFactory jwtDecoderFactory = config.jwtDecoderFactory; - OAuth2Error oauth2Error = new OAuth2Error("invalid_id_token", "Invalid ID Token", null); - when(jwtDecoderFactory.createDecoder(any())).thenReturn(token -> - Mono.error(new JwtValidationException("ID Token validation failed", Collections.singleton(oauth2Error)))); - - webTestClient.get() - .uri("/login/oauth2/code/google") - .exchange() - .expectStatus() - .is3xxRedirection() - .expectHeader() - .valueEquals("Location", "/login?error"); } @Configuration @@ -624,8 +659,8 @@ public class OAuth2LoginTests { ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); - ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = - mock(ReactiveOAuth2AccessTokenResponseClient.class); + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = mock( + ReactiveOAuth2AccessTokenResponseClient.class); ReactiveOAuth2UserService userService = mock(ReactiveOAuth2UserService.class); @@ -634,35 +669,35 @@ public class OAuth2LoginTests { ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class); @Bean - public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { // @formatter:off http .authorizeExchange() .anyExchange().authenticated() .and() .oauth2Login() - .authenticationConverter(authenticationConverter) + .authenticationConverter(this.authenticationConverter) .authenticationManager(authenticationManager()) - .securityContextRepository(securityContextRepository); + .securityContextRepository(this.securityContextRepository); return http.build(); // @formatter:on } private ReactiveAuthenticationManager authenticationManager() { - OidcAuthorizationCodeReactiveAuthenticationManager oidc = - new OidcAuthorizationCodeReactiveAuthenticationManager(tokenResponseClient, userService); + OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager( + this.tokenResponseClient, this.userService); oidc.setJwtDecoderFactory(jwtDecoderFactory()); return oidc; } @Bean - public ReactiveJwtDecoderFactory jwtDecoderFactory() { - return jwtDecoderFactory; + ReactiveJwtDecoderFactory jwtDecoderFactory() { + return this.jwtDecoderFactory; } @Bean - public ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient() { - return tokenResponseClient; + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient() { + return this.tokenResponseClient; } private static class JwtDecoderFactory implements ReactiveJwtDecoderFactory { @@ -673,49 +708,34 @@ public class OAuth2LoginTests { } private ReactiveJwtDecoder getJwtDecoder() { - return token -> { + return (token) -> { Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.SUB, "subject"); claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer"); claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client")); claims.put(IdTokenClaimNames.AZP, "client"); - Jwt jwt = jwt().claims(c -> c.putAll(claims)).build(); + Jwt jwt = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); return Mono.just(jwt); }; } + } - } - @Test - public void logoutWhenUsingOidcLogoutHandlerThenRedirects() { - this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire(); - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - getBean(ClientRegistration.class).getRegistrationId()); - - ServerSecurityContextRepository repository = getBean(ServerSecurityContextRepository.class); - when(repository.load(any())).thenReturn(authentication(token)); - - this.client.post().uri("/logout") - .exchange() - .expectHeader().valueEquals("Location", "https://logout?id_token_hint=id-token"); } @EnableWebFlux @EnableWebFluxSecurity static class OAuth2LoginConfigWithOidcLogoutSuccessHandler { - private final ServerSecurityContextRepository repository = - mock(ServerSecurityContextRepository.class); - private final ClientRegistration withLogout = - TestClientRegistrations.clientRegistration() - .providerConfigurationMetadata(Collections.singletonMap( - "end_session_endpoint", "https://logout")).build(); + + private final ServerSecurityContextRepository repository = mock(ServerSecurityContextRepository.class); + + private final ClientRegistration withLogout = TestClientRegistrations.clientRegistration() + .providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://logout")) + .build(); @Bean - public SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { - + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off http .csrf().disable() .logout() @@ -726,7 +746,7 @@ public class OAuth2LoginTests { new InMemoryReactiveClientRegistrationRepository(this.withLogout))) .and() .securityContextRepository(this.repository); - + // @formatter:on return http.build(); } @@ -739,24 +759,7 @@ public class OAuth2LoginTests { ClientRegistration clientRegistration() { return this.withLogout; } - } - // gh-8609 - @Test - public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() { - this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire(); - - WebTestClient webTestClient = WebTestClientBuilder - .bindToWebFilters(this.springSecurity) - .build(); - - webTestClient.get() - .uri("/login/oauth2/code/google") - .exchange() - .expectStatus() - .is3xxRedirection() - .expectHeader() - .valueEquals("Location", "/login?error"); } static class GitHubWebFilter implements WebFilter { @@ -768,16 +771,7 @@ public class OAuth2LoginTests { } return chain.filter(exchange); } - } - Mono authentication(Authentication authentication) { - SecurityContext context = new SecurityContextImpl(); - context.setAuthentication(authentication); - return Mono.just(context); - } - - T getBean(Class beanClass) { - return this.spring.getContext().getBean(beanClass); } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index 943fe62d71..09c8b4ae45 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -27,6 +27,7 @@ import java.util.Base64; import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; + import javax.annotation.PreDestroy; import okhttp3.mockwebserver.Dispatcher; @@ -60,6 +61,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter; @@ -78,20 +80,21 @@ import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.hamcrest.CoreMatchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** - * Tests for {@link org.springframework.security.config.web.server.ServerHttpSecurity.OAuth2ResourceServerSpec} + * Tests for + * {@link org.springframework.security.config.web.server.ServerHttpSecurity.OAuth2ResourceServerSpec} */ @RunWith(SpringRunner.class) public class OAuth2ResourceServerSpecTests { + private String expired = "eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE1MzUwMzc4OTd9.jqZDDjfc2eysX44lHXEIr9XFd2S8vjIZHCccZU-dRWMRJNsQ1QN5VNnJGklqJBXJR4qgla6cmVqPOLkUHDb0sL0nxM5XuzQaG5ZzKP81RV88shFyAiT0fD-6nl1k-Fai-Fu-VkzSpNXgeONoTxDaYhdB-yxmgrgsApgmbOTE_9AcMk-FQDXQ-pL9kynccFGV0lZx4CA7cyknKN7KBxUilfIycvXODwgKCjj_1WddLTCNGYogJJSg__7NoxzqbyWd3udbHVjqYq7GsMMrGB4_2kBD4CkghOSNcRHbT_DIXowxfAVT7PAg7Q0E5ruZsr2zPZacEUDhJ6-wbvlA0FAOUg"; private String messageReadToken = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJtb2NrLXN1YmplY3QiLCJzY29wZSI6Im1lc3NhZ2U6cmVhZCIsImV4cCI6NDY4ODY0MTQxM30.cRl1bv_dDYcAN5U4NlIVKj8uu4mLMwjABF93P4dShiq-GQ-owzaqTSlB4YarNFgV3PKQvT9wxN1jBpGribvISljakoC0E8wDV-saDi8WxN-qvImYsn1zLzYFiZXCfRIxCmonJpydeiAPRxMTPtwnYDS9Ib0T_iA80TBGd-INhyxUUfrwRW5sqKRbjUciRJhpp7fW2ZYXmi9iPt3HDjRQA4IloJZ7f4-spt5Q9wl5HcQTv1t4XrX4eqhVbE5cCoIkFQnKPOc-jhVM44_eazLU6Xk-CCXP8C_UT5pX0luRS2cJrVFfHp2IR_AWxC-shItg6LNEmNFD4Zc-JLZcr0Q86Q"; @@ -100,35 +103,40 @@ public class OAuth2ResourceServerSpecTests { private String unsignedToken = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9."; - private String jwkSet = "{\n" + - " \"keys\":[\n" + - " {\n" + - " \"kty\":\"RSA\",\n" + - " \"e\":\"AQAB\",\n" + - " \"use\":\"sig\",\n" + - " \"kid\":\"one\",\n" + - " \"n\":\"0IUjrPZDz-3z0UE4ppcKU36v7hnh8FJjhu3lbJYj0qj9eZiwEJxi9HHUfSK1DhUQG7mJBbYTK1tPYCgre5EkfKh-64VhYUa-vz17zYCmuB8fFj4XHE3MLkWIG-AUn8hNbPzYYmiBTjfGnMKxLHjsbdTiF4mtn-85w366916R6midnAuiPD4HjZaZ1PAsuY60gr8bhMEDtJ8unz81hoQrozpBZJ6r8aR1PrsWb1OqPMloK9kAIutJNvWYKacp8WYAp2WWy72PxQ7Fb0eIA1br3A5dnp-Cln6JROJcZUIRJ-QvS6QONWeS2407uQmS-i-lybsqaH0ldYC7NBEBA5inPQ\"\n" + - " }\n" + - " ]\n" + - "}\n"; + // @formatter:off + private String jwkSet = "{\n" + + " \"keys\":[\n" + + " {\n" + + " \"kty\":\"RSA\",\n" + + " \"e\":\"AQAB\",\n" + + " \"use\":\"sig\",\n" + + " \"kid\":\"one\",\n" + + " \"n\":\"0IUjrPZDz-3z0UE4ppcKU36v7hnh8FJjhu3lbJYj0qj9eZiwEJxi9HHUfSK1DhUQG7mJBbYTK1tPYCgre5EkfKh-64VhYUa-vz17zYCmuB8fFj4XHE3MLkWIG-AUn8hNbPzYYmiBTjfGnMKxLHjsbdTiF4mtn-85w366916R6midnAuiPD4HjZaZ1PAsuY60gr8bhMEDtJ8unz81hoQrozpBZJ6r8aR1PrsWb1OqPMloK9kAIutJNvWYKacp8WYAp2WWy72PxQ7Fb0eIA1br3A5dnp-Cln6JROJcZUIRJ-QvS6QONWeS2407uQmS-i-lybsqaH0ldYC7NBEBA5inPQ\"\n" + + " }\n" + + " ]\n" + + "}\n"; + // @formatter:on - private Jwt jwt = jwt().build(); + private Jwt jwt = TestJwts.jwt().build(); private String clientId = "client"; - private String clientSecret = "secret"; - private String active = "{\n" + - " \"active\": true,\n" + - " \"client_id\": \"l238j323ds-23ij4\",\n" + - " \"username\": \"jdoe\",\n" + - " \"scope\": \"read write dolphin\",\n" + - " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + - " \"aud\": \"https://protected.example.net/resource\",\n" + - " \"iss\": \"https://server.example.com/\",\n" + - " \"exp\": 1419356238,\n" + - " \"iat\": 1419350238,\n" + - " \"extension_field\": \"twenty-seven\"\n" + - " }"; + private String clientSecret = "secret"; + + // @formatter:off + private String active = "{\n" + + " \"active\": true,\n" + + " \"client_id\": \"l238j323ds-23ij4\",\n" + + " \"username\": \"jdoe\",\n" + + " \"scope\": \"read write dolphin\",\n" + + " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + + " \"aud\": \"https://protected.example.net/resource\",\n" + + " \"iss\": \"https://server.example.com/\",\n" + + " \"exp\": 1419356238,\n" + + " \"iat\": 1419350238,\n" + + " \"extension_field\": \"twenty-seven\"\n" + + " }"; + // @formatter:on @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -143,241 +151,283 @@ public class OAuth2ResourceServerSpecTests { @Test public void getWhenValidThenReturnsOk() { this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken)) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenExpiredThenReturnsInvalidToken() { this.spring.register(PublicKeyConfig.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.expired)) + .headers((headers) -> headers + .setBearerAuth(this.expired)) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\"")); + // @formatter:on } @Test public void getWhenUnsignedThenReturnsInvalidToken() { this.spring.register(PublicKeyConfig.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.unsignedToken)) + .headers((headers) -> headers + .setBearerAuth(this.unsignedToken)) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\"")); + // @formatter:on } @Test public void getWhenEmptyBearerTokenThenReturnsInvalidToken() { this.spring.register(PublicKeyConfig.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.add("Authorization", "Bearer ")) + .headers((headers) -> headers + .add("Authorization", "Bearer ") + ) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\"")); + // @formatter:on } @Test public void getWhenValidTokenAndPublicKeyInLambdaThenReturnsOk() { this.spring.register(PublicKeyInLambdaConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenExpiredTokenAndPublicKeyInLambdaThenReturnsInvalidToken() { this.spring.register(PublicKeyInLambdaConfig.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.expired)) + .headers((headers) -> headers + .setBearerAuth(this.expired) + ) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\"")); + // @formatter:on } @Test public void getWhenValidUsingPlaceholderThenReturnsOk() { this.spring.register(PlaceholderConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenCustomDecoderThenAuthenticatesAccordingly() { this.spring.register(CustomDecoderConfig.class, RootController.class).autowire(); - ReactiveJwtDecoder jwtDecoder = this.spring.getContext().getBean(ReactiveJwtDecoder.class); - when(jwtDecoder.decode(anyString())).thenReturn(Mono.just(this.jwt)); - + given(jwtDecoder.decode(anyString())).willReturn(Mono.just(this.jwt)); + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth("token")) + .headers((headers) -> headers + .setBearerAuth("token") + ) .exchange() .expectStatus().isOk(); - + // @formatter:on verify(jwtDecoder).decode(anyString()); } @Test public void getWhenUsingJwkSetUriThenConsultsAccordingly() { this.spring.register(JwkSetUriConfig.class, RootController.class).autowire(); - MockWebServer mockWebServer = this.spring.getContext().getBean(MockWebServer.class); mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet)); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadTokenWithKid)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadTokenWithKid) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenUsingJwkSetUriInLambdaThenConsultsAccordingly() { this.spring.register(JwkSetUriInLambdaConfig.class, RootController.class).autowire(); - MockWebServer mockWebServer = this.spring.getContext().getBean(MockWebServer.class); mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet)); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadTokenWithKid)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadTokenWithKid) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenUsingCustomAuthenticationManagerThenUsesItAccordingly() { this.spring.register(CustomAuthenticationManagerConfig.class).autowire(); - - ReactiveAuthenticationManager authenticationManager = this.spring.getContext().getBean( - ReactiveAuthenticationManager.class); - when(authenticationManager.authenticate(any(Authentication.class))) - .thenReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); - + ReactiveAuthenticationManager authenticationManager = this.spring.getContext() + .getBean(ReactiveAuthenticationManager.class); + given(authenticationManager.authenticate(any(Authentication.class))) + .willReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"mock-failure\"")); + // @formatter:on } @Test public void getWhenUsingCustomAuthenticationManagerInLambdaThenUsesItAccordingly() { this.spring.register(CustomAuthenticationManagerInLambdaConfig.class).autowire(); - - ReactiveAuthenticationManager authenticationManager = this.spring.getContext().getBean( - ReactiveAuthenticationManager.class); - when(authenticationManager.authenticate(any(Authentication.class))) - .thenReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); - + ReactiveAuthenticationManager authenticationManager = this.spring.getContext() + .getBean(ReactiveAuthenticationManager.class); + given(authenticationManager.authenticate(any(Authentication.class))) + .willReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"mock-failure\"")); + // @formatter:on } @Test public void getWhenUsingCustomAuthenticationManagerResolverThenUsesItAccordingly() { this.spring.register(CustomAuthenticationManagerResolverConfig.class).autowire(); - - ReactiveAuthenticationManagerResolver authenticationManagerResolver = - this.spring.getContext().getBean(ReactiveAuthenticationManagerResolver.class); - - ReactiveAuthenticationManager authenticationManager = - this.spring.getContext().getBean(ReactiveAuthenticationManager.class); - - when(authenticationManagerResolver.resolve(any(ServerWebExchange.class))) - .thenReturn(Mono.just(authenticationManager)); - when(authenticationManager.authenticate(any(Authentication.class))) - .thenReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); - + ReactiveAuthenticationManagerResolver authenticationManagerResolver = this.spring + .getContext().getBean(ReactiveAuthenticationManagerResolver.class); + ReactiveAuthenticationManager authenticationManager = this.spring.getContext() + .getBean(ReactiveAuthenticationManager.class); + given(authenticationManagerResolver.resolve(any(ServerWebExchange.class))) + .willReturn(Mono.just(authenticationManager)); + given(authenticationManager.authenticate(any(Authentication.class))) + .willReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isUnauthorized() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"mock-failure\"")); + // @formatter:on } @Test public void postWhenSignedThenReturnsOk() { this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.post() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenTokenHasInsufficientScopeThenReturnsInsufficientScope() { this.spring.register(DenyAllConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isForbidden() .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"insufficient_scope\"")); + // @formatter:on } @Test public void postWhenMissingTokenThenReturnsForbidden() { this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.post() .exchange() .expectStatus().isForbidden(); + // @formatter:on } @Test public void getWhenCustomBearerTokenServerAuthenticationConverterThenResponds() { this.spring.register(CustomBearerTokenServerAuthenticationConverter.class, RootController.class).autowire(); - + // @formatter:off this.client.get() .cookie("TOKEN", this.messageReadToken) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenSignedAndCustomConverterThenConverts() { this.spring.register(CustomJwtAuthenticationConverterConfig.class, RootController.class).autowire(); - + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void getWhenCustomBearerTokenEntryPointThenResponds() { this.spring.register(CustomErrorHandlingConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("/authenticated") .exchange() .expectStatus().isEqualTo(HttpStatus.I_AM_A_TEAPOT); + // @formatter:on } @Test public void getWhenCustomBearerTokenDeniedHandlerThenResponds() { this.spring.register(CustomErrorHandlingConfig.class).autowire(); - + // @formatter:off this.client.get() .uri("/unobtainable") - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isEqualTo(HttpStatus.BANDWIDTH_LIMIT_EXCEEDED); + // @formatter:on } @Test @@ -385,14 +435,11 @@ public class OAuth2ResourceServerSpecTests { GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); ServerHttpSecurity http = new ServerHttpSecurity(); http.setApplicationContext(context); - ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class); ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class); context.registerBean(ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); - ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); jwt.jwtDecoder(dslWiredJwtDecoder); - assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder); } @@ -401,15 +448,12 @@ public class OAuth2ResourceServerSpecTests { GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); ServerHttpSecurity http = new ServerHttpSecurity(); http.setApplicationContext(context); - ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class); ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class); context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); - ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); jwt.jwtDecoder(dslWiredJwtDecoder); - assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder); } @@ -418,15 +462,11 @@ public class OAuth2ResourceServerSpecTests { GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); ServerHttpSecurity http = new ServerHttpSecurity(); http.setApplicationContext(context); - ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class); context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); - ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); - - assertThatCode(() -> jwt.getJwtDecoder()) - .isInstanceOf(NoUniqueBeanDefinitionException.class); + assertThatExceptionOfType(NoUniqueBeanDefinitionException.class).isThrownBy(() -> jwt.getJwtDecoder()); } @Test @@ -434,47 +474,102 @@ public class OAuth2ResourceServerSpecTests { GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); ServerHttpSecurity http = new ServerHttpSecurity(); http.setApplicationContext(context); - ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); - - assertThatCode(() -> jwt.getJwtDecoder()) - .isInstanceOf(NoSuchBeanDefinitionException.class); + assertThatExceptionOfType(NoSuchBeanDefinitionException.class).isThrownBy(() -> jwt.getJwtDecoder()); } @Test public void introspectWhenValidThenReturnsOk() { this.spring.register(IntrospectionConfig.class, RootController.class).autowire(); this.spring.getContext().getBean(MockWebServer.class) - .setDispatcher(requiresAuth(clientId, clientSecret, active)); - + .setDispatcher(requiresAuth(this.clientId, this.clientSecret, this.active)); + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void introspectWhenValidAndIntrospectionInLambdaThenReturnsOk() { this.spring.register(IntrospectionInLambdaConfig.class, RootController.class).autowire(); this.spring.getContext().getBean(MockWebServer.class) - .setDispatcher(requiresAuth(clientId, clientSecret, active)); - + .setDispatcher(requiresAuth(this.clientId, this.clientSecret, this.active)); + // @formatter:off this.client.get() - .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .headers((headers) -> headers + .setBearerAuth(this.messageReadToken) + ) .exchange() .expectStatus().isOk(); + // @formatter:on } @Test public void configureWhenUsingBothAuthenticationManagerResolverAndOpaqueThenWiringException() { - assertThatCode(() -> this.spring.register(AuthenticationManagerResolverPlusOtherConfig.class).autowire()) - .isInstanceOf(BeanCreationException.class) - .hasMessageContaining("authenticationManagerResolver"); + assertThatExceptionOfType(BeanCreationException.class) + .isThrownBy(() -> this.spring.register(AuthenticationManagerResolverPlusOtherConfig.class).autowire()) + .withMessageContaining("authenticationManagerResolver"); + } + + private static Dispatcher requiresAuth(String username, String password, String response) { + return new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + String authorization = request.getHeader(org.springframework.http.HttpHeaders.AUTHORIZATION); + // @formatter:off + return Optional.ofNullable(authorization) + .filter((a) -> isAuthorized(authorization, username, password)) + .map((a) -> ok(response)) + .orElse(unauthorized()); + // @formatter:on + } + }; + } + + private static boolean isAuthorized(String authorization, String username, String password) { + String[] values = new String(Base64.getDecoder().decode(authorization.substring(6))).split(":"); + return username.equals(values[0]) && password.equals(values[1]); + } + + private static MockResponse ok(String response) { + return new MockResponse().setBody(response).setHeader(org.springframework.http.HttpHeaders.CONTENT_TYPE, + MediaType.APPLICATION_JSON_VALUE); + } + + private static MockResponse unauthorized() { + return new MockResponse().setResponseCode(401); + } + + private static RSAPublicKey publicKey() { + String modulus = "26323220897278656456354815752829448539647589990395639665273015355787577386000316054335559633864476469390247312823732994485311378484154955583861993455004584140858982659817218753831620205191028763754231454775026027780771426040997832758235764611119743390612035457533732596799927628476322029280486807310749948064176545712270582940917249337311592011920620009965129181413510845780806191965771671528886508636605814099711121026468495328702234901200169245493126030184941412539949521815665744267183140084667383643755535107759061065656273783542590997725982989978433493861515415520051342321336460543070448417126615154138673620797"; + String exponent = "65537"; + RSAPublicKeySpec spec = new RSAPublicKeySpec(new BigInteger(modulus), new BigInteger(exponent)); + RSAPublicKey rsaPublicKey = null; + try { + KeyFactory factory = KeyFactory.getInstance("RSA"); + rsaPublicKey = (RSAPublicKey) factory.generatePublic(spec); + } + catch (NoSuchAlgorithmException | InvalidKeySpecException ex) { + ex.printStackTrace(); + } + return rsaPublicKey; + } + + private GenericWebApplicationContext autowireWebServerGenericWebApplicationContext() { + GenericWebApplicationContext context = new GenericWebApplicationContext(); + context.registerBean("webHandler", DispatcherHandler.class); + this.spring.context(context).autowire(); + return (GenericWebApplicationContext) this.spring.getContext(); } @EnableWebFlux @EnableWebFluxSecurity static class PublicKeyConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -486,39 +581,40 @@ public class OAuth2ResourceServerSpecTests { .jwt() .publicKey(publicKey()); // @formatter:on - - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class PublicKeyInLambdaConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off http - .authorizeExchange(exchanges -> + .authorizeExchange((exchanges) -> exchanges .anyExchange().hasAuthority("SCOPE_message:read") ) - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer - .jwt(jwt -> + .jwt((jwt) -> jwt .publicKey(publicKey()) ) ); // @formatter:on - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class PlaceholderConfig { + @Value("${classpath:org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests-simple.pub}") RSAPublicKey key; @@ -533,28 +629,26 @@ public class OAuth2ResourceServerSpecTests { .jwt() .publicKey(this.key); // @formatter:on - - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class JwkSetUriConfig { + private MockWebServer mockWebServer = new MockWebServer(); @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { String jwkSetUri = mockWebServer().url("/.well-known/jwks.json").toString(); - // @formatter:off http .oauth2ResourceServer() .jwt() .jwkSetUri(jwkSetUri); // @formatter:on - return http.build(); } @@ -567,28 +661,28 @@ public class OAuth2ResourceServerSpecTests { void shutdown() throws IOException { this.mockWebServer.shutdown(); } + } @EnableWebFlux @EnableWebFluxSecurity static class JwkSetUriInLambdaConfig { + private MockWebServer mockWebServer = new MockWebServer(); @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { String jwkSetUri = mockWebServer().url("/.well-known/jwks.json").toString(); - // @formatter:off http - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer - .jwt(jwt -> + .jwt((jwt) -> jwt .jwkSetUri(jwkSetUri) ) ); // @formatter:on - return http.build(); } @@ -601,11 +695,13 @@ public class OAuth2ResourceServerSpecTests { void shutdown() throws IOException { this.mockWebServer.shutdown(); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomDecoderConfig { + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); @Bean @@ -615,7 +711,6 @@ public class OAuth2ResourceServerSpecTests { .oauth2ResourceServer() .jwt(); // @formatter:on - return http.build(); } @@ -623,11 +718,13 @@ public class OAuth2ResourceServerSpecTests { ReactiveJwtDecoder jwtDecoder() { return this.jwtDecoder; } + } @EnableWebFlux @EnableWebFluxSecurity static class DenyAllConfig { + @Bean SecurityWebFilterChain authorization(ServerHttpSecurity http) { // @formatter:off @@ -639,14 +736,15 @@ public class OAuth2ResourceServerSpecTests { .jwt() .publicKey(publicKey()); // @formatter:on - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomAuthenticationManagerConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -655,7 +753,6 @@ public class OAuth2ResourceServerSpecTests { .jwt() .authenticationManager(authenticationManager()); // @formatter:on - return http.build(); } @@ -663,24 +760,25 @@ public class OAuth2ResourceServerSpecTests { ReactiveAuthenticationManager authenticationManager() { return mock(ReactiveAuthenticationManager.class); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomAuthenticationManagerInLambdaConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off http - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer - .jwt(jwt -> + .jwt((jwt) -> jwt .authenticationManager(authenticationManager()) ) ); // @formatter:on - return http.build(); } @@ -688,11 +786,13 @@ public class OAuth2ResourceServerSpecTests { ReactiveAuthenticationManager authenticationManager() { return mock(ReactiveAuthenticationManager.class); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomAuthenticationManagerResolverConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -703,7 +803,6 @@ public class OAuth2ResourceServerSpecTests { .oauth2ResourceServer() .authenticationManagerResolver(authenticationManagerResolver()); // @formatter:on - return http.build(); } @@ -716,11 +815,13 @@ public class OAuth2ResourceServerSpecTests { ReactiveAuthenticationManager authenticationManager() { return mock(ReactiveAuthenticationManager.class); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomBearerTokenServerAuthenticationConverter { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -733,20 +834,21 @@ public class OAuth2ResourceServerSpecTests { .jwt() .publicKey(publicKey()); // @formatter:on - return http.build(); } @Bean ServerAuthenticationConverter bearerTokenAuthenticationConverter() { - return exchange -> Mono.justOrEmpty(exchange.getRequest().getCookies().getFirst("TOKEN").getValue()) + return (exchange) -> Mono.justOrEmpty(exchange.getRequest().getCookies().getFirst("TOKEN").getValue()) .map(BearerTokenAuthenticationToken::new); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomJwtAuthenticationConverterConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -759,26 +861,25 @@ public class OAuth2ResourceServerSpecTests { .jwtAuthenticationConverter(jwtAuthenticationConverter()) .publicKey(publicKey()); // @formatter:on - return http.build(); } @Bean Converter> jwtAuthenticationConverter() { - JwtAuthenticationConverter converter = new JwtAuthenticationConverter(); - converter.setJwtGrantedAuthoritiesConverter(jwt -> { - String[] claims = ((String) jwt.getClaims().get("scope")).split(" "); - return Stream.of(claims).map(SimpleGrantedAuthority::new).collect(Collectors.toList()); - }); - + converter.setJwtGrantedAuthoritiesConverter((jwt) -> { + String[] claims = ((String) jwt.getClaims().get("scope")).split(" "); + return Stream.of(claims).map(SimpleGrantedAuthority::new).collect(Collectors.toList()); + }); return new ReactiveJwtAuthenticationConverterAdapter(converter); } + } @EnableWebFlux @EnableWebFluxSecurity static class CustomErrorHandlingConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -793,20 +894,20 @@ public class OAuth2ResourceServerSpecTests { .jwt() .publicKey(publicKey()); // @formatter:on - return http.build(); } + } @EnableWebFlux @EnableWebFluxSecurity static class IntrospectionConfig { + private MockWebServer mockWebServer = new MockWebServer(); @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { String introspectionUri = mockWebServer().url("/introspect").toString(); - // @formatter:off http .oauth2ResourceServer() @@ -814,7 +915,6 @@ public class OAuth2ResourceServerSpecTests { .introspectionUri(introspectionUri) .introspectionClientCredentials("client", "secret"); // @formatter:on - return http.build(); } @@ -827,29 +927,29 @@ public class OAuth2ResourceServerSpecTests { void shutdown() throws IOException { this.mockWebServer.shutdown(); } + } @EnableWebFlux @EnableWebFluxSecurity static class IntrospectionInLambdaConfig { + private MockWebServer mockWebServer = new MockWebServer(); @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { String introspectionUri = mockWebServer().url("/introspect").toString(); - // @formatter:off http - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer - .opaqueToken(opaqueToken -> + .opaqueToken((opaqueToken) -> opaqueToken .introspectionUri(introspectionUri) .introspectionClientCredentials("client", "secret") ) ); // @formatter:on - return http.build(); } @@ -862,11 +962,13 @@ public class OAuth2ResourceServerSpecTests { void shutdown() throws IOException { this.mockWebServer.shutdown(); } + } @EnableWebFlux @EnableWebFluxSecurity static class AuthenticationManagerResolverPlusOtherConfig { + @Bean SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { // @formatter:off @@ -877,13 +979,15 @@ public class OAuth2ResourceServerSpecTests { .oauth2ResourceServer() .authenticationManagerResolver(mock(ReactiveAuthenticationManagerResolver.class)) .opaqueToken(); - + // @formatter:on return http.build(); } + } @RestController static class RootController { + @GetMapping Mono get() { return Mono.just("ok"); @@ -893,54 +997,7 @@ public class OAuth2ResourceServerSpecTests { Mono post() { return Mono.just("ok"); } + } - private static Dispatcher requiresAuth(String username, String password, String response) { - return new Dispatcher() { - @Override - public MockResponse dispatch(RecordedRequest request) { - String authorization = request.getHeader(org.springframework.http.HttpHeaders.AUTHORIZATION); - return Optional.ofNullable(authorization) - .filter(a -> isAuthorized(authorization, username, password)) - .map(a -> ok(response)) - .orElse(unauthorized()); - } - }; - } - - private static boolean isAuthorized(String authorization, String username, String password) { - String[] values = new String(Base64.getDecoder().decode(authorization.substring(6))).split(":"); - return username.equals(values[0]) && password.equals(values[1]); - } - - private static MockResponse ok(String response) { - return new MockResponse().setBody(response) - .setHeader(org.springframework.http.HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); - } - - private static MockResponse unauthorized() { - return new MockResponse().setResponseCode(401); - } - - private static RSAPublicKey publicKey() { - String modulus = "26323220897278656456354815752829448539647589990395639665273015355787577386000316054335559633864476469390247312823732994485311378484154955583861993455004584140858982659817218753831620205191028763754231454775026027780771426040997832758235764611119743390612035457533732596799927628476322029280486807310749948064176545712270582940917249337311592011920620009965129181413510845780806191965771671528886508636605814099711121026468495328702234901200169245493126030184941412539949521815665744267183140084667383643755535107759061065656273783542590997725982989978433493861515415520051342321336460543070448417126615154138673620797"; - String exponent = "65537"; - - RSAPublicKeySpec spec = new RSAPublicKeySpec(new BigInteger(modulus), new BigInteger(exponent)); - RSAPublicKey rsaPublicKey = null; - try { - KeyFactory factory = KeyFactory.getInstance("RSA"); - rsaPublicKey = (RSAPublicKey) factory.generatePublic(spec); - } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { - e.printStackTrace(); - } - return rsaPublicKey; - } - - private GenericWebApplicationContext autowireWebServerGenericWebApplicationContext() { - GenericWebApplicationContext context = new GenericWebApplicationContext(); - context.registerBean("webHandler", DispatcherHandler.class); - this.spring.context(context).autowire(); - return (GenericWebApplicationContext) this.spring.getContext(); - } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java b/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java index 9f0d9a09ec..b8e39e02ab 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java @@ -19,6 +19,7 @@ package org.springframework.security.config.web.server; import org.junit.Test; import org.openqa.selenium.WebDriver; import org.openqa.selenium.support.PageFactory; + import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.config.web.server.FormLoginTests.DefaultLoginPage; import org.springframework.security.config.web.server.FormLoginTests.HomePage; @@ -41,104 +42,98 @@ import static org.springframework.security.config.Customizer.withDefaults; * @since 5.0 */ public class RequestCacheTests { + private ServerHttpSecurity http = ServerHttpSecurityConfigurationBuilder.httpWithDefaultAuthentication(); @Test public void defaultFormLoginRequestCache() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin().and() - .build(); - + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .build(); WebTestClient webTestClient = WebTestClient - .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) - .webFilter(new WebFilterChainProxy(securityWebFilter)) - .build(); - + .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class).assertAt(); + // @formatter:off SecuredPage securedPage = loginPage.loginForm() - .username("user") - .password("password") - .submit(SecuredPage.class); - + .username("user") + .password("password") + .submit(SecuredPage.class); + // @formatter:on securedPage.assertAt(); } @Test public void requestCacheNoOp() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange() - .anyExchange().authenticated() - .and() - .formLogin().and() - .requestCache() - .requestCache(NoOpServerRequestCache.getInstance()) - .and() - .build(); - + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .requestCache() + .requestCache(NoOpServerRequestCache.getInstance()) + .and() + .build(); WebTestClient webTestClient = WebTestClient - .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) - .webFilter(new WebFilterChainProxy(securityWebFilter)) - .build(); - + .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class).assertAt(); + // @formatter:off HomePage securedPage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on securedPage.assertAt(); } @Test public void requestWhenCustomRequestCacheInLambdaThenCustomCacheUsed() { + // @formatter:off SecurityWebFilterChain securityWebFilter = this.http - .authorizeExchange(authorizeExchange -> - authorizeExchange - .anyExchange().authenticated() - ) - .formLogin(withDefaults()) - .requestCache(requestCache -> - requestCache - .requestCache(NoOpServerRequestCache.getInstance()) - ) - .build(); - + .authorizeExchange((exchange) -> exchange + .anyExchange().authenticated() + ) + .formLogin(withDefaults()) + .requestCache((requestCache) -> requestCache + .requestCache(NoOpServerRequestCache.getInstance()) + ) + .build(); WebTestClient webTestClient = WebTestClient - .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) - .webFilter(new WebFilterChainProxy(securityWebFilter)) - .build(); - + .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient) - .build(); - - DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class) - .assertAt(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class).assertAt(); + // @formatter:off HomePage securedPage = loginPage.loginForm() - .username("user") - .password("password") - .submit(HomePage.class); - + .username("user") + .password("password") + .submit(HomePage.class); + // @formatter:on securedPage.assertAt(); } public static class SecuredPage { + private WebDriver driver; public SecuredPage(WebDriver driver) { @@ -153,23 +148,28 @@ public class RequestCacheTests { driver.get("http://localhost/secured"); return PageFactory.initElements(driver, page); } + } @Controller public static class SecuredPageController { + @ResponseBody @GetMapping("/secured") public String login(ServerWebExchange exchange) { - return - "\n" - + "\n" - + " \n" - + " Secured\n" - + " \n" - + " \n" - + "

    Secured

    \n" - + " \n" - + ""; + // @formatter:off + return "\n" + + "\n" + + " \n" + + " Secured\n" + + " \n" + + " \n" + + "

    Secured

    \n" + + " \n" + + ""; + // @formatter:on } + } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 052b6629a4..de1f4ec014 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -16,17 +16,6 @@ package org.springframework.security.config.web.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.BDDMockito.given; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.config.Customizer.withDefaults; -import static org.springframework.test.util.ReflectionTestUtils.getField; - import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -40,36 +29,39 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; - -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; -import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; -import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; -import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; -import org.springframework.security.web.server.savedrequest.ServerRequestCache; -import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import reactor.core.publisher.Mono; import reactor.test.publisher.TestPublisher; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; +import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; +import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; +import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; +import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; +import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; +import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; @@ -77,10 +69,16 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; -import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.web.server.WebFilterChain; -import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; -import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.springframework.security.config.Customizer.withDefaults; /** * @author Rob Winch @@ -89,10 +87,13 @@ import org.springframework.security.web.server.authentication.HttpBasicServerAut */ @RunWith(MockitoJUnitRunner.class) public class ServerHttpSecurityTests { + @Mock private ServerSecurityContextRepository contextRepository; + @Mock private ReactiveAuthenticationManager authenticationManager; + @Mock private ServerCsrfTokenRepository csrfTokenRepository; @@ -100,24 +101,22 @@ public class ServerHttpSecurityTests { @Before public void setup() { - this.http = ServerHttpSecurityConfigurationBuilder.http() - .authenticationManager(this.authenticationManager); + this.http = ServerHttpSecurityConfigurationBuilder.http().authenticationManager(this.authenticationManager); } @Test public void defaults() { TestPublisher securityContext = TestPublisher.create(); - when(this.contextRepository.load(any())).thenReturn(securityContext.mono()); + given(this.contextRepository.load(any())).willReturn(securityContext.mono()); this.http.securityContextRepository(this.contextRepository); - WebTestClient client = buildClient(); - + // @formatter:off FluxExchangeResult result = client.get() - .uri("/") - .exchange() - .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") - .returnResult(String.class); - + .uri("/") + .exchange() + .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") + .returnResult(String.class); + // @formatter:on assertThat(result.getResponseCookies()).isEmpty(); // there is no need to try and load the SecurityContext by default securityContext.assertWasNotSubscribed(); @@ -125,186 +124,192 @@ public class ServerHttpSecurityTests { @Test public void basic() { - given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); - + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); this.http.httpBasic(); this.http.authenticationManager(this.authenticationManager); ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange(); authorize.anyExchange().authenticated(); - WebTestClient client = buildClient(); - + // @formatter:off EntityExchangeResult result = client.get() - .uri("/") - .headers(headers -> headers.setBasicAuth("rob", "rob")) - .exchange() - .expectStatus().isOk() - .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") - .expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok")) - .returnResult(); - + .uri("/") + .headers((headers) -> headers + .setBasicAuth("rob", "rob") + ) + .exchange() + .expectStatus().isOk() + .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") + .expectBody(String.class).consumeWith((b) -> assertThat(b.getResponseBody()).isEqualTo("ok")) + .returnResult(); + // @formatter:on assertThat(result.getResponseCookies().getFirst("SESSION")).isNull(); } @Test public void basicWithGlobalWebSessionServerSecurityContextRepository() { - given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); - + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); this.http.securityContextRepository(new WebSessionServerSecurityContextRepository()); this.http.httpBasic(); this.http.authenticationManager(this.authenticationManager); ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange(); authorize.anyExchange().authenticated(); - WebTestClient client = buildClient(); - + // @formatter:off EntityExchangeResult result = client.get() .uri("/") - .headers(headers -> headers.setBasicAuth("rob", "rob")) + .headers((headers) -> headers + .setBasicAuth("rob", "rob") + ) .exchange() .expectStatus().isOk() .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") - .expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok")) + .expectBody(String.class).consumeWith((b) -> assertThat(b.getResponseBody()).isEqualTo("ok")) .returnResult(); - + // @formatter:on assertThat(result.getResponseCookies().getFirst("SESSION")).isNotNull(); } @Test public void basicWhenNoCredentialsThenUnauthorized() { this.http.authorizeExchange().anyExchange().authenticated(); - WebTestClient client = buildClient(); - client - .get() - .uri("/") - .exchange() - .expectStatus().isUnauthorized() - .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") - .expectBody().isEmpty(); + // @formatter:off + client.get().uri("/") + .exchange() + .expectStatus().isUnauthorized() + .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") + .expectBody().isEmpty(); + // @formatter:on } @Test public void buildWhenServerWebExchangeFromContextThenFound() { SecurityWebFilterChain filter = this.http.build(); - - WebTestClient client = WebTestClient.bindToController(new SubscriberContextController()) + // @formatter:off + WebTestClient client = WebTestClient + .bindToController(new SubscriberContextController()) .webFilter(new WebFilterChainProxy(filter)) .build(); - - client.get().uri("/foo/bar") + client.get() + .uri("/foo/bar") .exchange() .expectBody(String.class).isEqualTo("/foo/bar"); + // @formatter:on } @Test public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() { SecurityWebFilterChain securityWebFilterChain = this.http.csrf().disable().build(); - - assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) - .isNotPresent(); - + assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)).isNotPresent(); Optional logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) - .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); - - assertThat(logoutHandler) - .get() - .isExactlyInstanceOf(SecurityContextServerLogoutHandler.class); + .map((logoutWebFilter) -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, + LogoutWebFilter.class, "logoutHandler")); + assertThat(logoutHandler).get().isExactlyInstanceOf(SecurityContextServerLogoutHandler.class); } @Test public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() { - SecurityWebFilterChain securityWebFilterChain = this.http.csrf().csrfTokenRepository(this.csrfTokenRepository).and().build(); - - assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) - .get() - .extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository")) + SecurityWebFilterChain securityWebFilterChain = this.http.csrf().csrfTokenRepository(this.csrfTokenRepository) + .and().build(); + assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)).get() + .extracting((csrfWebFilter) -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository")) .isEqualTo(this.csrfTokenRepository); - Optional logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) - .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); - - assertThat(logoutHandler) - .get() - .isExactlyInstanceOf(DelegatingServerLogoutHandler.class) - .extracting(delegatingLogoutHandler -> - ((List) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() - .map(ServerLogoutHandler::getClass) - .collect(Collectors.toList())) + .map((logoutWebFilter) -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, + LogoutWebFilter.class, "logoutHandler")); + assertThat(logoutHandler).get().isExactlyInstanceOf(DelegatingServerLogoutHandler.class) + .extracting((delegatingLogoutHandler) -> ((List) ReflectionTestUtils + .getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() + .map(ServerLogoutHandler::getClass).collect(Collectors.toList())) .isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); } @Test @SuppressWarnings("unchecked") - public void addFilterAfterIsApplied(){ - SecurityWebFilterChain securityWebFilterChain = this.http.addFilterAfter(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build(); - List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block(); - - assertThat(filters).isNotNull() - .isNotEmpty() - .containsSequence(SecurityContextServerWebExchangeWebFilter.class, TestWebFilter.class); - + public void addFilterAfterIsApplied() { + SecurityWebFilterChain securityWebFilterChain = this.http + .addFilterAfter(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + .build(); + // @formatter:off + List filters = securityWebFilterChain.getWebFilters() + .map(WebFilter::getClass) + .collectList() + .block(); + // @formatter:on + assertThat(filters).isNotNull().isNotEmpty().containsSequence(SecurityContextServerWebExchangeWebFilter.class, + TestWebFilter.class); } @Test @SuppressWarnings("unchecked") - public void addFilterBeforeIsApplied(){ - SecurityWebFilterChain securityWebFilterChain = this.http.addFilterBefore(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build(); - List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block(); - - assertThat(filters).isNotNull() - .isNotEmpty() - .containsSequence(TestWebFilter.class, SecurityContextServerWebExchangeWebFilter.class); - + public void addFilterBeforeIsApplied() { + SecurityWebFilterChain securityWebFilterChain = this.http + .addFilterBefore(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE) + .build(); + // @formatter:off + List filters = securityWebFilterChain.getWebFilters() + .map(WebFilter::getClass) + .collectList() + .block(); + // @formatter:on + assertThat(filters).isNotNull().isNotEmpty().containsSequence(TestWebFilter.class, + SecurityContextServerWebExchangeWebFilter.class); } @Test - public void anonymous(){ - SecurityWebFilterChain securityFilterChain = this.http.anonymous().and().build(); - WebTestClient client = WebTestClientBuilder.bindToControllerAndWebFilters(AnonymousAuthenticationWebFilterTests.HttpMeController.class, - securityFilterChain).build(); - + public void anonymous() { + // @formatter:off + SecurityWebFilterChain securityFilterChain = this.http + .anonymous().and() + .build(); + WebTestClient client = WebTestClientBuilder + .bindToControllerAndWebFilters(AnonymousAuthenticationWebFilterTests.HttpMeController.class, securityFilterChain) + .build(); client.get() .uri("/me") .exchange() .expectStatus().isOk() .expectBody(String.class).isEqualTo("anonymousUser"); - + // @formatter:on } @Test public void getWhenAnonymousConfiguredThenAuthenticationIsAnonymous() { SecurityWebFilterChain securityFilterChain = this.http.anonymous(withDefaults()).build(); - WebTestClient client = WebTestClientBuilder.bindToControllerAndWebFilters(AnonymousAuthenticationWebFilterTests.HttpMeController.class, - securityFilterChain).build(); - + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToControllerAndWebFilters(AnonymousAuthenticationWebFilterTests.HttpMeController.class, securityFilterChain) + .build(); client.get() .uri("/me") .exchange() .expectStatus().isOk() .expectBody(String.class).isEqualTo("anonymousUser"); + // @formatter:on } @Test public void basicWithAnonymous() { - given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); - + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); this.http.httpBasic().and().anonymous(); this.http.authenticationManager(this.authenticationManager); ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange(); authorize.anyExchange().hasAuthority("ROLE_ADMIN"); - WebTestClient client = buildClient(); - + // @formatter:off EntityExchangeResult result = client.get() .uri("/") - .headers(headers -> headers.setBasicAuth("rob", "rob")) - .exchange() + .headers((headers) -> headers + .setBasicAuth("rob", "rob") + ).exchange() .expectStatus().isOk() .expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+") - .expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok")) + .expectBody(String.class).consumeWith((b) -> assertThat(b.getResponseBody()).isEqualTo("ok")) .returnResult(); - + // @formatter:on assertThat(result.getResponseCookies().getFirst("SESSION")).isNull(); } @@ -317,17 +322,16 @@ public class ServerHttpSecurityTests { this.http.authenticationManager(this.authenticationManager); ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange(); authorize.anyExchange().authenticated(); - WebTestClient client = buildClient(); - + // @formatter:off EntityExchangeResult result = client.get() .uri("/") .exchange() .expectStatus().isUnauthorized() - .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, value -> assertThat(value).contains("myrealm")) + .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, (value) -> assertThat(value).contains("myrealm")) .expectBody(String.class) .returnResult(); - + // @formatter:on assertThat(result.getResponseCookies().getFirst("SESSION")).isNull(); } @@ -336,42 +340,48 @@ public class ServerHttpSecurityTests { this.http.securityContextRepository(new WebSessionServerSecurityContextRepository()); HttpBasicServerAuthenticationEntryPoint authenticationEntryPoint = new HttpBasicServerAuthenticationEntryPoint(); authenticationEntryPoint.setRealm("myrealm"); - this.http.httpBasic(httpBasic -> - httpBasic.authenticationEntryPoint(authenticationEntryPoint) - ); + this.http.httpBasic((httpBasic) -> httpBasic.authenticationEntryPoint(authenticationEntryPoint)); this.http.authenticationManager(this.authenticationManager); ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange(); authorize.anyExchange().authenticated(); - WebTestClient client = buildClient(); - + // @formatter:off EntityExchangeResult result = client.get() .uri("/") .exchange() .expectStatus().isUnauthorized() - .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, value -> assertThat(value).contains("myrealm")) + .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, (value) -> assertThat(value).contains("myrealm")) .expectBody(String.class) .returnResult(); - + // @formatter:on assertThat(result.getResponseCookies().getFirst("SESSION")).isNull(); } @Test public void basicWithCustomAuthenticationManager() { ReactiveAuthenticationManager customAuthenticationManager = mock(ReactiveAuthenticationManager.class); - given(customAuthenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); - - SecurityWebFilterChain securityFilterChain = this.http.httpBasic().authenticationManager(customAuthenticationManager).and().build(); + given(customAuthenticationManager.authenticate(any())) + .willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); + // @formatter:off + SecurityWebFilterChain securityFilterChain = this.http + .httpBasic() + .authenticationManager(customAuthenticationManager) + .and() + .build(); + // @formatter:on WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(securityFilterChain); - WebTestClient client = WebTestClientBuilder.bindToWebFilters(springSecurityFilterChain).build(); - + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(springSecurityFilterChain) + .build(); client.get() - .uri("/") - .headers(headers -> headers.setBasicAuth("rob", "rob")) + .uri("/").headers((headers) -> headers + .setBasicAuth("rob", "rob") + ) .exchange() .expectStatus().isOk() - .expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok")); - + .expectBody(String.class).consumeWith((b) -> assertThat(b.getResponseBody()).isEqualTo("ok")); + // @formatter:on verifyZeroInteractions(this.authenticationManager); } @@ -380,22 +390,27 @@ public class ServerHttpSecurityTests { ReactiveAuthenticationManager customAuthenticationManager = mock(ReactiveAuthenticationManager.class); given(customAuthenticationManager.authenticate(any())) .willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN"))); - + // @formatter:off SecurityWebFilterChain securityFilterChain = this.http - .httpBasic(httpBasic -> - httpBasic.authenticationManager(customAuthenticationManager) + .httpBasic((httpBasic) -> httpBasic + .authenticationManager(customAuthenticationManager) ) .build(); + // @formatter:on WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(securityFilterChain); - WebTestClient client = WebTestClientBuilder.bindToWebFilters(springSecurityFilterChain).build(); - + // @formatter:off + WebTestClient client = WebTestClientBuilder + .bindToWebFilters(springSecurityFilterChain) + .build(); client.get() .uri("/") - .headers(headers -> headers.setBasicAuth("rob", "rob")) + .headers((headers) -> headers + .setBasicAuth("rob", "rob") + ) .exchange() .expectStatus().isOk() - .expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok")); - + .expectBody(String.class).consumeWith((b) -> assertThat(b.getResponseBody()).isEqualTo("ok")); + // @formatter:on verifyZeroInteractions(this.authenticationManager); verify(customAuthenticationManager).authenticate(any(Authentication.class)); } @@ -405,15 +420,9 @@ public class ServerHttpSecurityTests { public void addsX509FilterWhenX509AuthenticationIsConfigured() { X509PrincipalExtractor mockExtractor = mock(X509PrincipalExtractor.class); ReactiveAuthenticationManager mockAuthenticationManager = mock(ReactiveAuthenticationManager.class); - - this.http.x509() - .principalExtractor(mockExtractor) - .authenticationManager(mockAuthenticationManager) - .and(); - + this.http.x509().principalExtractor(mockExtractor).authenticationManager(mockAuthenticationManager).and(); SecurityWebFilterChain securityWebFilterChain = this.http.build(); WebFilter x509WebFilter = securityWebFilterChain.getWebFilters().filter(this::isX509Filter).blockFirst(); - assertThat(x509WebFilter).isNotNull(); } @@ -421,157 +430,125 @@ public class ServerHttpSecurityTests { public void x509WhenCustomizedThenAddsX509Filter() { X509PrincipalExtractor mockExtractor = mock(X509PrincipalExtractor.class); ReactiveAuthenticationManager mockAuthenticationManager = mock(ReactiveAuthenticationManager.class); - - this.http.x509(x509 -> - x509 - .principalExtractor(mockExtractor) - .authenticationManager(mockAuthenticationManager) - ); - + this.http.x509( + (x509) -> x509.principalExtractor(mockExtractor).authenticationManager(mockAuthenticationManager)); SecurityWebFilterChain securityWebFilterChain = this.http.build(); WebFilter x509WebFilter = securityWebFilterChain.getWebFilters().filter(this::isX509Filter).blockFirst(); - assertThat(x509WebFilter).isNotNull(); } @Test public void addsX509FilterWhenX509AuthenticationIsConfiguredWithDefaults() { this.http.x509(); - SecurityWebFilterChain securityWebFilterChain = this.http.build(); WebFilter x509WebFilter = securityWebFilterChain.getWebFilters().filter(this::isX509Filter).blockFirst(); - assertThat(x509WebFilter).isNotNull(); } @Test public void x509WhenDefaultsThenAddsX509Filter() { this.http.x509(withDefaults()); - SecurityWebFilterChain securityWebFilterChain = this.http.build(); WebFilter x509WebFilter = securityWebFilterChain.getWebFilters().filter(this::isX509Filter).blockFirst(); - assertThat(x509WebFilter).isNotNull(); } @Test public void postWhenCsrfDisabledThenPermitted() { - SecurityWebFilterChain securityFilterChain = this.http.csrf(csrf -> csrf.disable()).build(); + SecurityWebFilterChain securityFilterChain = this.http.csrf((csrf) -> csrf.disable()).build(); WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(securityFilterChain); WebTestClient client = WebTestClientBuilder.bindToWebFilters(springSecurityFilterChain).build(); - - client.post() - .uri("/") - .exchange() - .expectStatus().isOk(); + client.post().uri("/").exchange().expectStatus().isOk(); } @Test public void postWhenCustomCsrfTokenRepositoryThenUsed() { ServerCsrfTokenRepository customServerCsrfTokenRepository = mock(ServerCsrfTokenRepository.class); - when(customServerCsrfTokenRepository.loadToken(any(ServerWebExchange.class))).thenReturn(Mono.empty()); + given(customServerCsrfTokenRepository.loadToken(any(ServerWebExchange.class))).willReturn(Mono.empty()); SecurityWebFilterChain securityFilterChain = this.http - .csrf(csrf -> csrf.csrfTokenRepository(customServerCsrfTokenRepository)) - .build(); + .csrf((csrf) -> csrf.csrfTokenRepository(customServerCsrfTokenRepository)).build(); WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(securityFilterChain); WebTestClient client = WebTestClientBuilder.bindToWebFilters(springSecurityFilterChain).build(); - - client.post() - .uri("/") - .exchange() - .expectStatus().isForbidden(); - + client.post().uri("/").exchange().expectStatus().isForbidden(); verify(customServerCsrfTokenRepository).loadToken(any()); } @Test public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() { ServerRequestCache requestCache = spy(new WebSessionServerRequestCache()); - ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); - - SecurityWebFilterChain securityFilterChain = this.http - .oauth2Login() - .clientRegistrationRepository(clientRegistrationRepository) - .and() - .authorizeExchange().anyExchange().authenticated() - .and() - .requestCache(c -> c.requestCache(requestCache)) - .build(); - + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + SecurityWebFilterChain securityFilterChain = this.http.oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository).and().authorizeExchange().anyExchange() + .authenticated().and().requestCache((c) -> c.requestCache(requestCache)).build(); WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); client.get().uri("/test").exchange(); ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); verify(requestCache).saveRequest(captor.capture()); assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test"); - - - OAuth2LoginAuthenticationWebFilter authenticationWebFilter = - getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get(); - Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler"); - assertThat(getField(handler, "requestCache")).isSameAs(requestCache); + OAuth2LoginAuthenticationWebFilter authenticationWebFilter = getWebFilter(securityFilterChain, + OAuth2LoginAuthenticationWebFilter.class).get(); + Object handler = ReflectionTestUtils.getField(authenticationWebFilter, "authenticationSuccessHandler"); + assertThat(ReflectionTestUtils.getField(handler, "requestCache")).isSameAs(requestCache); } @Test public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { - ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); - ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); - + ServerAuthorizationRequestRepository authorizationRequestRepository = mock( + ServerAuthorizationRequestRepository.class); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().build(); - - when(authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); - - SecurityWebFilterChain securityFilterChain = this.http - .oauth2Login() + given(authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(authorizationRequest)); + SecurityWebFilterChain securityFilterChain = this.http.oauth2Login() .clientRegistrationRepository(clientRegistrationRepository) - .authorizationRequestRepository(authorizationRequestRepository) - .and() - .build(); - + .authorizationRequestRepository(authorizationRequestRepository).and().build(); WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); client.get().uri("/login/oauth2/code/registration-id").exchange(); - verify(authorizationRequestRepository).removeAuthorizationRequest(any()); } private boolean isX509Filter(WebFilter filter) { try { - Object converter = getField(filter, "authenticationConverter"); + Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class); - } catch (IllegalArgumentException e) { + } + catch (IllegalArgumentException ex) { // field doesn't exist return false; } } private Optional getWebFilter(SecurityWebFilterChain filterChain, Class filterClass) { - return (Optional) filterChain.getWebFilters() - .filter(Objects::nonNull) - .filter(filter -> filter.getClass().isAssignableFrom(filterClass)) - .singleOrEmpty() - .blockOptional(); + return (Optional) filterChain.getWebFilters().filter(Objects::nonNull) + .filter((filter) -> filter.getClass().isAssignableFrom(filterClass)).singleOrEmpty().blockOptional(); } private WebTestClient buildClient() { - WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy( - this.http.build()); + WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(this.http.build()); return WebTestClientBuilder.bindToWebFilters(springSecurityFilterChain).build(); } @RestController private static class SubscriberContextController { + @GetMapping("/**") Mono pathWithinApplicationFromContext() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)) - .map(e -> e.getRequest().getPath().pathWithinApplication().value()); + return Mono.subscriberContext().filter((c) -> c.hasKey(ServerWebExchange.class)) + .map((c) -> c.get(ServerWebExchange.class)) + .map((e) -> e.getRequest().getPath().pathWithinApplication().value()); } + } private static class TestWebFilter implements WebFilter { + @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return chain.filter(exchange); } + } + } diff --git a/config/src/test/java/org/springframework/security/config/web/server/TestingServerHttpSecurity.java b/config/src/test/java/org/springframework/security/config/web/server/TestingServerHttpSecurity.java index 50d2a3020d..60a997a9c3 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/TestingServerHttpSecurity.java +++ b/config/src/test/java/org/springframework/security/config/web/server/TestingServerHttpSecurity.java @@ -24,9 +24,10 @@ import org.springframework.context.ApplicationContext; * @since 5.1 */ public class TestingServerHttpSecurity extends ServerHttpSecurity { - public TestingServerHttpSecurity applicationContext(ApplicationContext applicationContext) - throws BeansException { + + public TestingServerHttpSecurity applicationContext(ApplicationContext applicationContext) throws BeansException { super.setApplicationContext(applicationContext); return this; } + } diff --git a/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTest.java b/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTests.java similarity index 82% rename from config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTest.java rename to config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTests.java index e05ee31365..bda928963a 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTest.java +++ b/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTests.java @@ -13,22 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.websocket; import org.junit.Test; + import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.GenericBeanDefinition; import org.springframework.beans.factory.support.SimpleBeanDefinitionRegistry; -public class MessageSecurityPostProcessorTest { +public class MessageSecurityPostProcessorTests { - private WebSocketMessageBrokerSecurityBeanDefinitionParser.MessageSecurityPostProcessor postProcessor = - new WebSocketMessageBrokerSecurityBeanDefinitionParser.MessageSecurityPostProcessor("id", false); + private WebSocketMessageBrokerSecurityBeanDefinitionParser.MessageSecurityPostProcessor postProcessor = new WebSocketMessageBrokerSecurityBeanDefinitionParser.MessageSecurityPostProcessor( + "id", false); @Test public void handlesBeansWithoutClass() { BeanDefinitionRegistry registry = new SimpleBeanDefinitionRegistry(); registry.registerBeanDefinition("beanWithoutClass", new GenericBeanDefinition()); - postProcessor.postProcessBeanDefinitionRegistry(registry); + this.postProcessor.postProcessBeanDefinitionRegistry(registry); } + } diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index a693eaf96a..3e72669f57 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.config.websocket; +import java.util.HashMap; +import java.util.Map; + import org.assertj.core.api.ThrowableAssert; -import org.assertj.core.api.ThrowableAssert.ThrowingCallable; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.config.BeanDefinition; @@ -62,10 +66,8 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeHandler; -import java.util.HashMap; -import java.util.Map; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** @@ -75,8 +77,8 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners public class WebSocketMessageBrokerConfigTests { - private static final String CONFIG_LOCATION_PREFIX = - "classpath:org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests"; + + private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests"; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -98,237 +100,178 @@ public class WebSocketMessageBrokerConfigTests { @Test public void sendWhenNoIdSpecifiedThenIntegratesWithClientInboundChannel() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - this.clientInboundChannel.send(message("/permitAll")); - - assertThatThrownBy(() -> this.clientInboundChannel.send(message("/denyAll"))) - .hasCauseInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(() -> this.clientInboundChannel.send(message("/denyAll"))) + .withCauseInstanceOf(AccessDeniedException.class); } @Test public void sendWhenAnonymousMessageWithConnectMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken()); - - assertThatCode(() -> this.clientInboundChannel.send(message("/permitAll", headers))) - .doesNotThrowAnyException(); + this.clientInboundChannel.send(message("/permitAll", headers)); } @Test public void sendWhenAnonymousMessageWithConnectAckMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.CONNECT_ACK); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithDisconnectMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.DISCONNECT); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithDisconnectAckMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.DISCONNECT_ACK); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithHeartbeatMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.HEARTBEAT); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithMessageMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.MESSAGE); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithOtherMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.OTHER); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithSubscribeMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.SUBSCRIBE); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenAnonymousMessageWithUnsubscribeMessageTypeThenPermitted() { this.spring.configLocations(xml("NoIdConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.UNSUBSCRIBE); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenConnectWithoutCsrfTokenThenDenied() { this.spring.configLocations(xml("SyncConfig")).autowire(); - Message message = message("/message", SimpMessageType.CONNECT); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(InvalidCsrfTokenException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(InvalidCsrfTokenException.class); } @Test public void sendWhenConnectWithSameOriginDisabledThenCsrfTokenNotRequired() { this.spring.configLocations(xml("SyncSameOriginDisabledConfig")).autowire(); - Message message = message("/message", SimpMessageType.CONNECT); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenInterceptWiredForMessageTypeThenDeniesOnTypeMismatch() { this.spring.configLocations(xml("MessageInterceptTypeConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.MESSAGE); - - assertThatCode(send(message)).doesNotThrowAnyException(); - + send(message); message = message("/permitAll", SimpMessageType.UNSUBSCRIBE); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); - + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); message = message("/anyOther", SimpMessageType.MESSAGE); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); } @Test public void sendWhenInterceptWiredForSubscribeTypeThenDeniesOnTypeMismatch() { this.spring.configLocations(xml("SubscribeInterceptTypeConfig")).autowire(); - Message message = message("/permitAll", SimpMessageType.SUBSCRIBE); - - assertThatCode(send(message)).doesNotThrowAnyException(); - + send(message); message = message("/permitAll", SimpMessageType.UNSUBSCRIBE); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); - + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); message = message("/anyOther", SimpMessageType.SUBSCRIBE); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); } - // -- invalid intercept types -- // - @Test public void configureWhenUsingConnectMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("ConnectInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("ConnectInterceptTypeConfig")).autowire()); } @Test public void configureWhenUsingConnectAckMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("ConnectAckInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("ConnectAckInterceptTypeConfig")).autowire()); } @Test public void configureWhenUsingDisconnectMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("DisconnectInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("DisconnectInterceptTypeConfig")).autowire()); } @Test public void configureWhenUsingDisconnectAckMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("DisconnectAckInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("DisconnectAckInterceptTypeConfig")).autowire()); } @Test public void configureWhenUsingHeartbeatMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("HeartbeatInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("HeartbeatInterceptTypeConfig")).autowire()); } @Test public void configureWhenUsingOtherMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("OtherInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("OtherInterceptTypeConfig")).autowire()); } @Test public void configureWhenUsingUnsubscribeMessageTypeThenAutowireFails() { - ThrowingCallable bad = () -> - this.spring.configLocations(xml("UnsubscribeInterceptTypeConfig")).autowire(); - - assertThatThrownBy(bad).isInstanceOf(BeanDefinitionParsingException.class); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> this.spring.configLocations(xml("UnsubscribeInterceptTypeConfig")).autowire()); } @Test public void sendWhenNoIdMessageThenAuthenticationPrincipalResolved() { this.spring.configLocations(xml("SyncConfig")).autowire(); - this.clientInboundChannel.send(message("/message")); - assertThat(this.messageController.username).isEqualTo("anonymous"); } @Test public void requestWhenConnectMessageThenUsesCsrfTokenHandshakeInterceptor() throws Exception { this.spring.configLocations(xml("SyncConfig")).autowire(); - - WebApplicationContext context = (WebApplicationContext) this.spring.getContext(); + WebApplicationContext context = this.spring.getContext(); MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); - String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - - MvcResult result = mvc.perform(get("/app") - .requestAttr(csrfAttributeName, this.token) - .sessionAttr(customAttributeName, "attributeValue")) - .andReturn(); - + MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token) + .sessionAttr(customAttributeName, "attributeValue")).andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - - assertThat(handshakeToken).isEqualTo(this.token) - .withFailMessage("CsrfToken is populated"); - + assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -336,25 +279,16 @@ public class WebSocketMessageBrokerConfigTests { @Test public void requestWhenConnectMessageAndUsingSockJsThenUsesCsrfTokenHandshakeInterceptor() throws Exception { this.spring.configLocations(xml("SyncSockJsConfig")).autowire(); - - WebApplicationContext context = (WebApplicationContext) this.spring.getContext(); + WebApplicationContext context = this.spring.getContext(); MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); - String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - - MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket") - .requestAttr(csrfAttributeName, this.token) - .sessionAttr(customAttributeName, "attributeValue")) - .andReturn(); - + MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token) + .sessionAttr(customAttributeName, "attributeValue")).andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - - assertThat(handshakeToken).isEqualTo(this.token) - .withFailMessage("CsrfToken is populated"); - + assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -362,61 +296,51 @@ public class WebSocketMessageBrokerConfigTests { @Test public void sendWhenNoIdSpecifiedThenCustomArgumentResolversAreNotOverridden() { this.spring.configLocations(xml("SyncCustomArgumentResolverConfig")).autowire(); - this.clientInboundChannel.send(message("/message-with-argument")); - assertThat(this.messageWithArgumentController.messageArgument).isNotNull(); } @Test public void sendWhenUsingCustomPathMatcherThenSecurityAppliesIt() { this.spring.configLocations(xml("CustomPathMatcherConfig")).autowire(); - Message message = message("/denyAll.a"); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); - + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); message = message("/denyAll.a.b"); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test public void sendWhenIdSpecifiedThenSecurityDoesNotIntegrateWithClientInboundChannel() { this.spring.configLocations(xml("IdConfig")).autowire(); - Message message = message("/denyAll"); - - assertThatCode(send(message)).doesNotThrowAnyException(); + send(message); } @Test @WithMockUser public void sendWhenIdSpecifiedAndExplicitlyIntegratedWhenBrokerUsesClientInboundChannel() { this.spring.configLocations(xml("IdIntegratedConfig")).autowire(); - Message message = message("/denyAll"); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); } @Test public void sendWhenNoIdSpecifiedThenSecurityDoesntOverrideCustomInterceptors() { this.spring.configLocations(xml("CustomInterceptorConfig")).autowire(); - Message message = message("/throwAll"); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(UnsupportedOperationException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(UnsupportedOperationException.class); } @Test @WithMockUser(username = "nile") public void sendWhenCustomExpressionHandlerThenAuthorizesAccordingly() { this.spring.configLocations(xml("CustomExpressionHandlerConfig")).autowire(); - Message message = message("/denyNile"); - - assertThatThrownBy(send(message)).hasCauseInstanceOf(AccessDeniedException.class); + assertThatExceptionOfType(Exception.class).isThrownBy(send(message)) + .withCauseInstanceOf(AccessDeniedException.class); } private String xml(String configName) { @@ -440,40 +364,42 @@ public class WebSocketMessageBrokerConfigTests { headers.setSessionId("123"); headers.setSessionAttributes(new HashMap<>()); headers.setDestination(destination); - if (SecurityContextHolder.getContext().getAuthentication() != null) { headers.setUser(SecurityContextHolder.getContext().getAuthentication()); } - headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token); - return new GenericMessage<>("hi", headers.getMessageHeaders()); } @Controller static class MessageController { + String username; @MessageMapping("/message") - public void authentication(@AuthenticationPrincipal String username) { + void authentication(@AuthenticationPrincipal String username) { this.username = username; } + } @Controller static class MessageWithArgumentController { + MessageArgument messageArgument; @MessageMapping("/message-with-argument") - public void myCustom(MessageArgument messageArgument) { + void myCustom(MessageArgument messageArgument) { this.messageArgument = messageArgument; } + } - static class MessageArgument { + MessageArgument(String notDefaultConstructor) { } + } static class MessageArgumentResolver implements HandlerMethodArgumentResolver { @@ -487,22 +413,21 @@ public class WebSocketMessageBrokerConfigTests { public Object resolveArgument(MethodParameter parameter, Message message) { return new MessageArgument(""); } + } static class TestHandshakeHandler implements HandshakeHandler { + Map attributes; @Override - public boolean doHandshake( - ServerHttpRequest request, - org.springframework.http.server.ServerHttpResponse response, - WebSocketHandler wsHandler, + public boolean doHandshake(ServerHttpRequest request, + org.springframework.http.server.ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { - this.attributes = attributes; - return true; } + } static class InboundExecutorPostProcessor implements BeanDefinitionRegistryPostProcessor { @@ -510,14 +435,14 @@ public class WebSocketMessageBrokerConfigTests { @Override public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { BeanDefinition inbound = registry.getBeanDefinition("clientInboundChannel"); - inbound.getConstructorArgumentValues() - .addIndexedArgumentValue(0, new RootBeanDefinition(SyncTaskExecutor.class)); + inbound.getConstructorArgumentValues().addIndexedArgumentValue(0, + new RootBeanDefinition(SyncTaskExecutor.class)); } @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { - } + } static class ExceptingInterceptor extends ChannelInterceptorAdapter { @@ -526,16 +451,14 @@ public class WebSocketMessageBrokerConfigTests { public Message preSend(Message message, MessageChannel channel) { throw new UnsupportedOperationException("no"); } + } - static class DenyNileMessageSecurityExpressionHandler - extends DefaultMessageSecurityExpressionHandler { + static class DenyNileMessageSecurityExpressionHandler extends DefaultMessageSecurityExpressionHandler { @Override - protected SecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, + protected SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, Message invocation) { - return new MessageSecurityExpressionRoot(authentication, invocation) { public boolean denyNile() { Authentication auth = getAuthentication(); @@ -543,5 +466,7 @@ public class WebSocketMessageBrokerConfigTests { } }; } + } + } diff --git a/config/src/test/java/org/springframework/security/htmlunit/server/HtmlUnitWebTestClient.java b/config/src/test/java/org/springframework/security/htmlunit/server/HtmlUnitWebTestClient.java index 4543deaae4..5efc812af0 100644 --- a/config/src/test/java/org/springframework/security/htmlunit/server/HtmlUnitWebTestClient.java +++ b/config/src/test/java/org/springframework/security/htmlunit/server/HtmlUnitWebTestClient.java @@ -27,8 +27,9 @@ import java.util.StringTokenizer; import com.gargoylesoftware.htmlunit.FormEncodingType; import com.gargoylesoftware.htmlunit.WebClient; import com.gargoylesoftware.htmlunit.WebRequest; - import com.gargoylesoftware.htmlunit.util.NameValuePair; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; @@ -43,7 +44,6 @@ import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; -import reactor.core.publisher.Mono; final class HtmlUnitWebTestClient { @@ -55,21 +55,23 @@ final class HtmlUnitWebTestClient { Assert.notNull(webClient, "WebClient must not be null"); Assert.notNull(webTestClient, "WebTestClient must not be null"); this.webClient = webClient; + // @formatter:off this.webTestClient = webTestClient.mutate() - .filter(new FollowRedirects()) - .filter(new CookieManager()) - .build(); + .filter(new FollowRedirects()) + .filter(new CookieManager()) + .build(); + // @formatter:on } - public FluxExchangeResult getResponse(WebRequest webRequest) { + FluxExchangeResult getResponse(WebRequest webRequest) { + // @formatter:off WebTestClient.RequestBodySpec request = this.webTestClient .method(httpMethod(webRequest)) .uri(uri(webRequest)); + // @formatter:on contentType(request, webRequest); cookies(request, webRequest); headers(request, webRequest); - - return content(request, webRequest).exchange().returnResult(String.class); } @@ -87,7 +89,7 @@ final class HtmlUnitWebTestClient { private MultiValueMap formData(List params) { MultiValueMap result = new LinkedMultiValueMap<>(params.size()); - params.forEach( pair -> result.add(pair.getName(), pair.getValue())); + params.forEach((pair) -> result.add(pair.getName(), pair.getValue())); return result; } @@ -99,7 +101,7 @@ final class HtmlUnitWebTestClient { contentType = encodingType.getName(); } } - MediaType mediaType = contentType == null ? MediaType.ALL : MediaType.parseMediaType(contentType); + MediaType mediaType = (contentType != null) ? MediaType.parseMediaType(contentType) : MediaType.ALL; request.contentType(mediaType); } @@ -109,14 +111,12 @@ final class HtmlUnitWebTestClient { StringTokenizer tokens = new StringTokenizer(cookieHeaderValue, "=;"); while (tokens.hasMoreTokens()) { String cookieName = tokens.nextToken().trim(); - Assert.isTrue(tokens.hasMoreTokens(), - () -> "Expected value for cookie name '" + cookieName + - "': full cookie header was [" + cookieHeaderValue + "]"); + Assert.isTrue(tokens.hasMoreTokens(), () -> "Expected value for cookie name '" + cookieName + + "': full cookie header was [" + cookieHeaderValue + "]"); String cookieValue = tokens.nextToken().trim(); request.cookie(cookieName, cookieValue); } } - Set managedCookies = this.webClient.getCookies(webRequest.getUrl()); for (com.gargoylesoftware.htmlunit.util.Cookie cookie : managedCookies) { request.cookie(cookie.getName(), cookie.getValue()); @@ -129,7 +129,7 @@ final class HtmlUnitWebTestClient { } private void headers(WebTestClient.RequestBodySpec request, WebRequest webRequest) { - webRequest.getAdditionalHeaders().forEach( (name, value) -> request.header(name, value)); + webRequest.getAdditionalHeaders().forEach((name, value) -> request.header(name, value)); } private HttpMethod httpMethod(WebRequest webRequest) { @@ -143,66 +143,66 @@ final class HtmlUnitWebTestClient { } static class FollowRedirects implements ExchangeFilterFunction { + @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return next.exchange(request) - .flatMap( response -> redirectIfNecessary(request, next, response)); + return next.exchange(request).flatMap((response) -> redirectIfNecessary(request, next, response)); } - private Mono redirectIfNecessary(ClientRequest request, ExchangeFunction next, ClientResponse response) { + private Mono redirectIfNecessary(ClientRequest request, ExchangeFunction next, + ClientResponse response) { URI location = response.headers().asHttpHeaders().getLocation(); String host = request.url().getHost(); String scheme = request.url().getScheme(); if (location != null) { String redirectUrl = location.toASCIIString(); if (location.getHost() == null) { - redirectUrl = scheme+ "://" + host + location.toASCIIString(); + redirectUrl = scheme + "://" + host + location.toASCIIString(); } + // @formatter:off ClientRequest redirect = ClientRequest.method(HttpMethod.GET, URI.create(redirectUrl)) - .headers(headers -> headers.addAll(request.headers())) - .cookies(cookies -> cookies.addAll(request.cookies())) - .attributes(attributes -> attributes.putAll(request.attributes())) - .build(); - - return next.exchange(redirect).flatMap( r -> redirectIfNecessary(request, next, r)); + .headers((headers) -> headers.addAll(request.headers())) + .cookies((cookies) -> cookies.addAll(request.cookies())) + .attributes((attributes) -> attributes.putAll(request.attributes())) + .build(); + // @formatter:on + return next.exchange(redirect).flatMap((r) -> redirectIfNecessary(request, next, r)); } - return Mono.just(response); } + } static class CookieManager implements ExchangeFilterFunction { + private Map cookies = new HashMap<>(); @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return next.exchange(withClientCookies(request)) - .doOnSuccess( response -> { - response.cookies().values().forEach( cookies -> { - cookies.forEach( cookie -> { - if (cookie.getMaxAge().isZero()) { - this.cookies.remove(cookie.getName()); - } else { - this.cookies.put(cookie.getName(), cookie); - } - }); + return next.exchange(withClientCookies(request)).doOnSuccess((response) -> { + response.cookies().values().forEach((cookies) -> { + cookies.forEach((cookie) -> { + if (cookie.getMaxAge().isZero()) { + this.cookies.remove(cookie.getName()); + } + else { + this.cookies.put(cookie.getName(), cookie); + } }); }); + }); } private ClientRequest withClientCookies(ClientRequest request) { - return ClientRequest.from(request) - .cookies( c -> { - c.addAll(clientCookies()); - }).build(); + return ClientRequest.from(request).cookies((c) -> c.addAll(clientCookies())).build(); } private MultiValueMap clientCookies() { MultiValueMap result = new LinkedMultiValueMap<>(this.cookies.size()); - this.cookies.values().forEach( cookie -> - result.add(cookie.getName(), cookie.getValue()) - ); + this.cookies.values().forEach((cookie) -> result.add(cookie.getName(), cookie.getValue())); return result; } + } + } diff --git a/config/src/test/java/org/springframework/security/htmlunit/server/MockWebResponseBuilder.java b/config/src/test/java/org/springframework/security/htmlunit/server/MockWebResponseBuilder.java index 15022b4596..ef1253c08d 100644 --- a/config/src/test/java/org/springframework/security/htmlunit/server/MockWebResponseBuilder.java +++ b/config/src/test/java/org/springframework/security/htmlunit/server/MockWebResponseBuilder.java @@ -35,13 +35,13 @@ import org.springframework.util.Assert; * @since 5.0 */ final class MockWebResponseBuilder { + private final long startTime; private final WebRequest webRequest; private final FluxExchangeResult exchangeResult; - MockWebResponseBuilder(long startTime, WebRequest webRequest, FluxExchangeResult exchangeResult) { Assert.notNull(webRequest, "WebRequest must not be null"); Assert.notNull(exchangeResult, "FluxExchangeResult must not be null"); @@ -50,8 +50,7 @@ final class MockWebResponseBuilder { this.exchangeResult = exchangeResult; } - - public WebResponse build() throws IOException { + WebResponse build() throws IOException { WebResponseData webResponseData = webResponseData(); long endTime = System.currentTimeMillis(); return new WebResponse(webResponseData, this.webRequest, endTime - this.startTime); @@ -60,17 +59,15 @@ final class MockWebResponseBuilder { private WebResponseData webResponseData() { List responseHeaders = responseHeaders(); HttpStatus status = this.exchangeResult.getStatus(); - return new WebResponseData(this.exchangeResult.getResponseBodyContent(), status.value(), status.getReasonPhrase(), responseHeaders); + return new WebResponseData(this.exchangeResult.getResponseBodyContent(), status.value(), + status.getReasonPhrase(), responseHeaders); } private List responseHeaders() { HttpHeaders responseHeaders = this.exchangeResult.getResponseHeaders(); List result = new ArrayList<>(responseHeaders.size()); - responseHeaders.forEach( (headerName, headerValues) -> - headerValues.forEach( headerValue -> - result.add(new NameValuePair(headerName, headerValue)) - ) - ); + responseHeaders.forEach((headerName, headerValues) -> headerValues + .forEach((headerValue) -> result.add(new NameValuePair(headerName, headerValue)))); return result; } diff --git a/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilder.java b/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilder.java index ae345738c6..db19b75056 100644 --- a/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilder.java +++ b/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilder.java @@ -19,6 +19,7 @@ package org.springframework.security.htmlunit.server; import com.gargoylesoftware.htmlunit.WebClient; import com.gargoylesoftware.htmlunit.WebConnection; import org.openqa.selenium.WebDriver; + import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.servlet.htmlunit.DelegatingWebConnection; import org.springframework.test.web.servlet.htmlunit.DelegatingWebConnection.DelegateWebConnection; @@ -29,7 +30,8 @@ import org.springframework.test.web.servlet.htmlunit.webdriver.WebConnectionHtml * @author Rob Winch * @since 5.0 */ -public class WebTestClientHtmlUnitDriverBuilder { +public final class WebTestClientHtmlUnitDriverBuilder { + private final WebTestClient webTestClient; private WebTestClientHtmlUnitDriverBuilder(WebTestClient webTestClient) { @@ -40,7 +42,8 @@ public class WebTestClientHtmlUnitDriverBuilder { WebConnectionHtmlUnitDriver driver = new WebConnectionHtmlUnitDriver(); WebClient webClient = driver.getWebClient(); WebTestClientWebConnection webClientConnection = new WebTestClientWebConnection(this.webTestClient, webClient); - WebConnection connection = new DelegatingWebConnection(driver.getWebConnection(), new DelegateWebConnection(new HostRequestMatcher("localhost"), webClientConnection)); + WebConnection connection = new DelegatingWebConnection(driver.getWebConnection(), + new DelegateWebConnection(new HostRequestMatcher("localhost"), webClientConnection)); driver.setWebConnection(connection); return driver; } @@ -48,4 +51,5 @@ public class WebTestClientHtmlUnitDriverBuilder { public static WebTestClientHtmlUnitDriverBuilder webTestClientSetup(WebTestClient webTestClient) { return new WebTestClientHtmlUnitDriverBuilder(webTestClient); } + } diff --git a/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilderTests.java b/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilderTests.java index 023b1fc7dc..e3f8bb4a13 100644 --- a/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilderTests.java +++ b/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientHtmlUnitDriverBuilderTests.java @@ -16,8 +16,13 @@ package org.springframework.security.htmlunit.server; +import java.net.URI; +import java.time.Duration; + import org.junit.Test; import org.openqa.selenium.WebDriver; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; @@ -28,10 +33,6 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.CookieValue; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.ResponseBody; -import reactor.core.publisher.Mono; - -import java.net.URI; -import java.time.Duration; import static org.assertj.core.api.Assertions.assertThat; @@ -43,26 +44,39 @@ public class WebTestClientHtmlUnitDriverBuilderTests { @Test public void helloWorld() { - WebTestClient webTestClient = WebTestClient - .bindToController(new HelloWorldController()) - .build(); + WebTestClient webTestClient = WebTestClient.bindToController(new HelloWorldController()).build(); + // @formatter:off WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient).build(); - + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on driver.get("http://localhost/"); - assertThat(driver.getPageSource()).contains("Hello World"); } - /** - * @author Rob Winch - * @since 5.0 - */ + @Test + public void cookies() { + // @formatter:off + WebTestClient webTestClient = WebTestClient + .bindToController(new CookieController()) + .build(); + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + // @formatter:on + driver.get("http://localhost/cookie"); + assertThat(driver.getPageSource()).contains("theCookie"); + driver.get("http://localhost/cookie/delete"); + assertThat(driver.getPageSource()).contains("null"); + } + @Controller class HelloWorldController { + @ResponseBody @GetMapping(produces = MediaType.TEXT_HTML_VALUE) - public String index() { + String index() { + // @formatter:off return "\n" + "\n" + "Hello World\n" @@ -71,43 +85,33 @@ public class WebTestClientHtmlUnitDriverBuilderTests { + "

    Hello World

    \n" + "\n" + ""; + // @formatter:on } - } - @Test - public void cookies() { - WebTestClient webTestClient = WebTestClient - .bindToController(new CookieController()) - .build(); - WebDriver driver = WebTestClientHtmlUnitDriverBuilder - .webTestClientSetup(webTestClient).build(); - - driver.get("http://localhost/cookie"); - - assertThat(driver.getPageSource()).contains("theCookie"); - - driver.get("http://localhost/cookie/delete"); - - assertThat(driver.getPageSource()).contains("null"); } @Controller @ResponseBody class CookieController { + @GetMapping(path = "/", produces = MediaType.TEXT_HTML_VALUE) - public String view(@CookieValue(required = false) String cookieName) { + String view(@CookieValue(required = false) String cookieName) { + // @formatter:off return "\n" + "\n" + "Hello World\n" + "\n" + "\n" - + "

    " + TextEscapeUtils.escapeEntities(cookieName) + "

    \n" + + "

    " + + TextEscapeUtils.escapeEntities(cookieName) + + "

    \n" + "\n" + ""; + // @formatter:on } @GetMapping("/cookie") - public Mono setCookie(ServerHttpResponse response) { + Mono setCookie(ServerHttpResponse response) { response.addCookie(ResponseCookie.from("cookieName", "theCookie").build()); return redirect(response); } @@ -119,10 +123,11 @@ public class WebTestClientHtmlUnitDriverBuilderTests { } @GetMapping("/cookie/delete") - public Mono deleteCookie(ServerHttpResponse response) { - response.addCookie( - ResponseCookie.from("cookieName", "").maxAge(Duration.ofSeconds(0)).build()); + Mono deleteCookie(ServerHttpResponse response) { + response.addCookie(ResponseCookie.from("cookieName", "").maxAge(Duration.ofSeconds(0)).build()); return redirect(response); } + } + } diff --git a/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientWebConnection.java b/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientWebConnection.java index c5d6828c03..a2fed90d58 100644 --- a/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientWebConnection.java +++ b/config/src/test/java/org/springframework/security/htmlunit/server/WebTestClientWebConnection.java @@ -37,11 +37,11 @@ public class WebTestClientWebConnection implements WebConnection { private final WebTestClient webTestClient; private final String contextPath; + private final HtmlUnitWebTestClient requestBuilder; private WebClient webClient; - public WebTestClientWebConnection(WebTestClient webTestClient, WebClient webClient) { this(webTestClient, webClient, ""); } @@ -50,7 +50,6 @@ public class WebTestClientWebConnection implements WebConnection { Assert.notNull(webTestClient, "MockMvc must not be null"); Assert.notNull(webClient, "WebClient must not be null"); validateContextPath(contextPath); - this.webClient = webClient; this.webTestClient = webTestClient; this.contextPath = contextPath; @@ -59,10 +58,11 @@ public class WebTestClientWebConnection implements WebConnection { /** * Validate the supplied {@code contextPath}. - *

    If the value is not {@code null}, it must conform to - * {@link javax.servlet.http.HttpServletRequest#getContextPath()} which - * states that it can be an empty string and otherwise must start with - * a "/" character and not end with a "/" character. + *

    + * If the value is not {@code null}, it must conform to + * {@link javax.servlet.http.HttpServletRequest#getContextPath()} which states that it + * can be an empty string and otherwise must start with a "/" character and not end + * with a "/" character. * @param contextPath the path to validate */ static void validateContextPath(@Nullable String contextPath) { @@ -73,7 +73,6 @@ public class WebTestClientWebConnection implements WebConnection { Assert.isTrue(!contextPath.endsWith("/"), () -> "contextPath '" + contextPath + "' must not end with '/'."); } - public void setWebClient(WebClient webClient) { Assert.notNull(webClient, "WebClient must not be null"); this.webClient = webClient; @@ -82,12 +81,13 @@ public class WebTestClientWebConnection implements WebConnection { @Override public WebResponse getResponse(WebRequest webRequest) throws IOException { long startTime = System.currentTimeMillis(); - FluxExchangeResult exchangeResult = this.requestBuilder.getResponse(webRequest); webRequest.setUrl(exchangeResult.getUrl().toURL()); return new MockWebResponseBuilder(startTime, webRequest, exchangeResult).build(); } @Override - public void close() {} + public void close() { + } + } diff --git a/config/src/test/java/org/springframework/security/intercept/method/aopalliance/MethodSecurityInterceptorWithAopConfigTests.java b/config/src/test/java/org/springframework/security/intercept/method/aopalliance/MethodSecurityInterceptorWithAopConfigTests.java index e77dfc900f..309386fb82 100644 --- a/config/src/test/java/org/springframework/security/intercept/method/aopalliance/MethodSecurityInterceptorWithAopConfigTests.java +++ b/config/src/test/java/org/springframework/security/intercept/method/aopalliance/MethodSecurityInterceptorWithAopConfigTests.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.intercept.method.aopalliance; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.intercept.method.aopalliance; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.context.support.AbstractXmlApplicationContext; import org.springframework.security.ITargetObject; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.config.util.InMemoryXmlApplicationContext; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.fail; + /** * Tests for SEC-428 (and SEC-1204). * @@ -33,19 +35,27 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Ben Alex */ public class MethodSecurityInterceptorWithAopConfigTests { + + // @formatter:off static final String AUTH_PROVIDER_XML = "" + " " + " " + " " + " " - + " " + " " + + " " + + " " + ""; + // @formatter:on + // @formatter:off static final String ACCESS_MANAGER_XML = "" + " " + " " - + " " + ""; + + " " + + ""; + // @formatter:on + // @formatter:off static final String TARGET_BEAN_AND_INTERCEPTOR = "" + "" + " " @@ -56,6 +66,7 @@ public class MethodSecurityInterceptorWithAopConfigTests { + " " + " " + ""; + // @formatter:on private AbstractXmlApplicationContext appContext; @@ -67,22 +78,24 @@ public class MethodSecurityInterceptorWithAopConfigTests { @After public void closeAppContext() { SecurityContextHolder.clearContext(); - if (appContext != null) { - appContext.close(); - appContext = null; + if (this.appContext != null) { + this.appContext.close(); + this.appContext = null; } } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void securityInterceptorIsAppliedWhenUsedWithAopConfig() { + // @formatter:off setContext("" + " " + " " - + "" + TARGET_BEAN_AND_INTERCEPTOR + AUTH_PROVIDER_XML + + "" + + TARGET_BEAN_AND_INTERCEPTOR + + AUTH_PROVIDER_XML + ACCESS_MANAGER_XML); - - ITargetObject target = (ITargetObject) appContext.getBean("target"); - + // @formatter:on + ITargetObject target = (ITargetObject) this.appContext.getBean("target"); // Check both against interface and class try { target.makeLowerCase("TEST"); @@ -90,12 +103,12 @@ public class MethodSecurityInterceptorWithAopConfigTests { } catch (AuthenticationCredentialsNotFoundException expected) { } - target.makeUpperCase("test"); } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void securityInterceptorIsAppliedWhenUsedWithBeanNameAutoProxyCreator() { + // @formatter:off setContext("" + " " + " " @@ -109,22 +122,23 @@ public class MethodSecurityInterceptorWithAopConfigTests { + " " + " " + "" - + TARGET_BEAN_AND_INTERCEPTOR + AUTH_PROVIDER_XML + ACCESS_MANAGER_XML); - - ITargetObject target = (ITargetObject) appContext.getBean("target"); + + TARGET_BEAN_AND_INTERCEPTOR + + AUTH_PROVIDER_XML + + ACCESS_MANAGER_XML); + // @formatter:on + ITargetObject target = (ITargetObject) this.appContext.getBean("target"); try { target.makeLowerCase("TEST"); fail("AuthenticationCredentialsNotFoundException expected"); } catch (AuthenticationCredentialsNotFoundException expected) { } - target.makeUpperCase("test"); - } private void setContext(String context) { - appContext = new InMemoryXmlApplicationContext(context); + this.appContext = new InMemoryXmlApplicationContext(context); } + } diff --git a/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-UsingSpel.xml b/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-UsingSpel.xml index 7cc3784f25..8146964536 100644 --- a/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-UsingSpel.xml +++ b/config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-UsingSpel.xml @@ -27,9 +27,9 @@ + default-target-url="#{T(org.springframework.security.config.http.WebConfigUtilsTests).URL}/default" + authentication-failure-url="#{T(org.springframework.security.config.http.WebConfigUtilsTests).URL}/failure" + login-page="#{T(org.springframework.security.config.http.WebConfigUtilsTests).URL}/login"/> diff --git a/core/src/main/java/org/springframework/security/access/AccessDecisionManager.java b/core/src/main/java/org/springframework/security/access/AccessDecisionManager.java index 298c754b1b..eabb82e25d 100644 --- a/core/src/main/java/org/springframework/security/access/AccessDecisionManager.java +++ b/core/src/main/java/org/springframework/security/access/AccessDecisionManager.java @@ -27,25 +27,20 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface AccessDecisionManager { - // ~ Methods - // ======================================================================================================== /** * Resolves an access control decision for the passed parameters. - * * @param authentication the caller invoking the method (not null) * @param object the secured object being called * @param configAttributes the configuration attributes associated with the secured * object being invoked - * * @throws AccessDeniedException if access is denied as the authentication does not * hold a required authority or ACL privilege * @throws InsufficientAuthenticationException if access is denied as the * authentication does not provide a sufficient level of trust */ - void decide(Authentication authentication, Object object, - Collection configAttributes) throws AccessDeniedException, - InsufficientAuthenticationException; + void decide(Authentication authentication, Object object, Collection configAttributes) + throws AccessDeniedException, InsufficientAuthenticationException; /** * Indicates whether this AccessDecisionManager is able to process @@ -56,10 +51,8 @@ public interface AccessDecisionManager { * AccessDecisionManager and/or RunAsManager and/or * AfterInvocationManager. *

    - * * @param attribute a configuration attribute that has been configured against the * AbstractSecurityInterceptor - * * @return true if this AccessDecisionManager can support the passed * configuration attribute */ @@ -68,10 +61,9 @@ public interface AccessDecisionManager { /** * Indicates whether the AccessDecisionManager implementation is able to * provide access control decisions for the indicated secured object type. - * * @param clazz the class that is being queried - * * @return true if the implementation can process the indicated class */ boolean supports(Class clazz); + } diff --git a/core/src/main/java/org/springframework/security/access/AccessDecisionVoter.java b/core/src/main/java/org/springframework/security/access/AccessDecisionVoter.java index bf4af5ba6a..e1d98fc8f3 100644 --- a/core/src/main/java/org/springframework/security/access/AccessDecisionVoter.java +++ b/core/src/main/java/org/springframework/security/access/AccessDecisionVoter.java @@ -30,15 +30,12 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface AccessDecisionVoter { - // ~ Static fields/initializers - // ===================================================================================== int ACCESS_GRANTED = 1; - int ACCESS_ABSTAIN = 0; - int ACCESS_DENIED = -1; - // ~ Methods - // ======================================================================================================== + int ACCESS_ABSTAIN = 0; + + int ACCESS_DENIED = -1; /** * Indicates whether this {@code AccessDecisionVoter} is able to vote on the passed @@ -47,10 +44,8 @@ public interface AccessDecisionVoter { * This allows the {@code AbstractSecurityInterceptor} to check every configuration * attribute can be consumed by the configured {@code AccessDecisionManager} and/or * {@code RunAsManager} and/or {@code AfterInvocationManager}. - * * @param attribute a configuration attribute that has been configured against the * {@code AbstractSecurityInterceptor} - * * @return true if this {@code AccessDecisionVoter} can support the passed * configuration attribute */ @@ -59,9 +54,7 @@ public interface AccessDecisionVoter { /** * Indicates whether the {@code AccessDecisionVoter} implementation is able to provide * access control votes for the indicated secured object type. - * * @param clazz the class that is being queried - * * @return true if the implementation can process the indicated class */ boolean supports(Class clazz); @@ -87,14 +80,12 @@ public interface AccessDecisionVoter { * parameter to maximise flexibility in making access control decisions, implementing * classes should not modify it or cause the represented invocation to take place (for * example, by calling {@code MethodInvocation.proceed()}). - * * @param authentication the caller making the invocation * @param object the secured object being invoked * @param attributes the configuration attributes associated with the secured object - * * @return either {@link #ACCESS_GRANTED}, {@link #ACCESS_ABSTAIN} or * {@link #ACCESS_DENIED} */ - int vote(Authentication authentication, S object, - Collection attributes); + int vote(Authentication authentication, S object, Collection attributes); + } diff --git a/core/src/main/java/org/springframework/security/access/AccessDeniedException.java b/core/src/main/java/org/springframework/security/access/AccessDeniedException.java index 8429d92a16..3bf6ceac5a 100644 --- a/core/src/main/java/org/springframework/security/access/AccessDeniedException.java +++ b/core/src/main/java/org/springframework/security/access/AccessDeniedException.java @@ -23,12 +23,9 @@ package org.springframework.security.access; * @author Ben Alex */ public class AccessDeniedException extends RuntimeException { - // ~ Constructors - // =================================================================================================== /** * Constructs an AccessDeniedException with the specified message. - * * @param msg the detail message */ public AccessDeniedException(String msg) { @@ -38,11 +35,11 @@ public class AccessDeniedException extends RuntimeException { /** * Constructs an AccessDeniedException with the specified message and * root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public AccessDeniedException(String msg, Throwable t) { - super(msg, t); + public AccessDeniedException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/access/AfterInvocationProvider.java b/core/src/main/java/org/springframework/security/access/AfterInvocationProvider.java index bca38afa57..92691877f0 100644 --- a/core/src/main/java/org/springframework/security/access/AfterInvocationProvider.java +++ b/core/src/main/java/org/springframework/security/access/AfterInvocationProvider.java @@ -28,12 +28,9 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface AfterInvocationProvider { - // ~ Methods - // ======================================================================================================== - Object decide(Authentication authentication, Object object, - Collection attributes, Object returnedObject) - throws AccessDeniedException; + Object decide(Authentication authentication, Object object, Collection attributes, + Object returnedObject) throws AccessDeniedException; /** * Indicates whether this AfterInvocationProvider is able to participate @@ -44,10 +41,8 @@ public interface AfterInvocationProvider { * AccessDecisionManager and/or RunAsManager and/or * AccessDecisionManager. *

    - * * @param attribute a configuration attribute that has been configured against the * AbstractSecurityInterceptor - * * @return true if this AfterInvocationProvider can support the passed * configuration attribute */ @@ -56,10 +51,9 @@ public interface AfterInvocationProvider { /** * Indicates whether the AfterInvocationProvider is able to provide * "after invocation" processing for the indicated secured object type. - * * @param clazz the class of secure object that is being queried - * * @return true if the implementation can process the indicated class */ boolean supports(Class clazz); + } diff --git a/core/src/main/java/org/springframework/security/access/AuthorizationServiceException.java b/core/src/main/java/org/springframework/security/access/AuthorizationServiceException.java index 920849a1e0..6952be563a 100644 --- a/core/src/main/java/org/springframework/security/access/AuthorizationServiceException.java +++ b/core/src/main/java/org/springframework/security/access/AuthorizationServiceException.java @@ -25,13 +25,10 @@ package org.springframework.security.access; * @author Ben Alex */ public class AuthorizationServiceException extends AccessDeniedException { - // ~ Constructors - // =================================================================================================== /** * Constructs an AuthorizationServiceException with the specified * message. - * * @param msg the detail message */ public AuthorizationServiceException(String msg) { @@ -41,11 +38,11 @@ public class AuthorizationServiceException extends AccessDeniedException { /** * Constructs an AuthorizationServiceException with the specified message * and root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public AuthorizationServiceException(String msg, Throwable t) { - super(msg, t); + public AuthorizationServiceException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/access/ConfigAttribute.java b/core/src/main/java/org/springframework/security/access/ConfigAttribute.java index 91212036d6..b1708a6e3c 100644 --- a/core/src/main/java/org/springframework/security/access/ConfigAttribute.java +++ b/core/src/main/java/org/springframework/security/access/ConfigAttribute.java @@ -37,8 +37,6 @@ import org.springframework.security.access.intercept.RunAsManager; * @author Ben Alex */ public interface ConfigAttribute extends Serializable { - // ~ Methods - // ======================================================================================================== /** * If the ConfigAttribute can be represented as a String and @@ -52,10 +50,10 @@ public interface ConfigAttribute extends Serializable { * null will require any relying classes to specifically support the * ConfigAttribute implementation, so returning null should * be avoided unless actually required. - * * @return a representation of the configuration attribute (or null if * the configuration attribute cannot be expressed as a String with * sufficient precision). */ String getAttribute(); + } diff --git a/core/src/main/java/org/springframework/security/access/PermissionCacheOptimizer.java b/core/src/main/java/org/springframework/security/access/PermissionCacheOptimizer.java index d27ebfc141..bfb0e1ba94 100644 --- a/core/src/main/java/org/springframework/security/access/PermissionCacheOptimizer.java +++ b/core/src/main/java/org/springframework/security/access/PermissionCacheOptimizer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access; import java.util.Collection; @@ -27,14 +28,15 @@ import org.springframework.security.core.Authentication; * @since 3.1 */ public interface PermissionCacheOptimizer extends AopInfrastructureBean { + /** * Optimises the permission cache for anticipated operation on the supplied collection * of objects. Usually this will entail batch loading of permissions for the objects * in the collection. - * * @param a the user for whom permissions should be obtained. * @param objects the (non-null) collection of domain objects for which permissions * should be retrieved. */ void cachePermissionsFor(Authentication a, Collection objects); + } diff --git a/core/src/main/java/org/springframework/security/access/PermissionEvaluator.java b/core/src/main/java/org/springframework/security/access/PermissionEvaluator.java index 58ea8f145a..0371efe29e 100644 --- a/core/src/main/java/org/springframework/security/access/PermissionEvaluator.java +++ b/core/src/main/java/org/springframework/security/access/PermissionEvaluator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access; import java.io.Serializable; @@ -24,13 +25,12 @@ import org.springframework.security.core.Authentication; * Strategy used in expression evaluation to determine whether a user has a permission or * permissions for a given domain object. * - * * @author Luke Taylor * @since 3.0 */ public interface PermissionEvaluator extends AopInfrastructureBean { + /** - * * @param authentication represents the user in question. Should not be null. * @param targetDomainObject the domain object for which permissions should be * checked. May be null in which case implementations should return false, as the null @@ -39,13 +39,11 @@ public interface PermissionEvaluator extends AopInfrastructureBean { * expression system. Not null. * @return true if the permission is granted, false otherwise */ - boolean hasPermission(Authentication authentication, Object targetDomainObject, - Object permission); + boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission); /** * Alternative method for evaluating a permission where only the identifier of the * target object is available, rather than the target instance itself. - * * @param authentication represents the user in question. Should not be null. * @param targetId the identifier for the object instance (usually a Long) * @param targetType a String representing the target's type (usually a Java @@ -54,6 +52,6 @@ public interface PermissionEvaluator extends AopInfrastructureBean { * expression system. Not null. * @return true if the permission is granted, false otherwise */ - boolean hasPermission(Authentication authentication, Serializable targetId, - String targetType, Object permission); + boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, Object permission); + } diff --git a/core/src/main/java/org/springframework/security/access/SecurityConfig.java b/core/src/main/java/org/springframework/security/access/SecurityConfig.java index 67823dfc15..4be11c2dfc 100644 --- a/core/src/main/java/org/springframework/security/access/SecurityConfig.java +++ b/core/src/main/java/org/springframework/security/access/SecurityConfig.java @@ -28,30 +28,20 @@ import org.springframework.util.StringUtils; * @author Ben Alex */ public class SecurityConfig implements ConfigAttribute { - // ~ Instance fields - // ================================================================================================ private final String attrib; - // ~ Constructors - // =================================================================================================== - public SecurityConfig(String config) { Assert.hasText(config, "You must provide a configuration attribute"); this.attrib = config; } - // ~ Methods - // ======================================================================================================== - @Override public boolean equals(Object obj) { if (obj instanceof ConfigAttribute) { ConfigAttribute attr = (ConfigAttribute) obj; - return this.attrib.equals(attr.getAttribute()); } - return false; } @@ -76,13 +66,11 @@ public class SecurityConfig implements ConfigAttribute { public static List createList(String... attributeNames) { Assert.notNull(attributeNames, "You must supply an array of attribute names"); - List attributes = new ArrayList<>( - attributeNames.length); - + List attributes = new ArrayList<>(attributeNames.length); for (String attribute : attributeNames) { attributes.add(new SecurityConfig(attribute.trim())); } - return attributes; } + } diff --git a/core/src/main/java/org/springframework/security/access/SecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/SecurityMetadataSource.java index 381d2c5875..6e504db9b6 100644 --- a/core/src/main/java/org/springframework/security/access/SecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/SecurityMetadataSource.java @@ -28,22 +28,16 @@ import org.springframework.security.access.intercept.AbstractSecurityInterceptor * @author Ben Alex */ public interface SecurityMetadataSource extends AopInfrastructureBean { - // ~ Methods - // ======================================================================================================== /** * Accesses the {@code ConfigAttribute}s that apply to a given secure object. - * * @param object the object being secured - * * @return the attributes that apply to the passed in secured object. Should return an * empty collection if there are no applicable attributes. - * * @throws IllegalArgumentException if the passed object is not of a type supported by * the SecurityMetadataSource implementation */ - Collection getAttributes(Object object) - throws IllegalArgumentException; + Collection getAttributes(Object object) throws IllegalArgumentException; /** * If available, returns all of the {@code ConfigAttribute}s defined by the @@ -51,7 +45,6 @@ public interface SecurityMetadataSource extends AopInfrastructureBean { *

    * This is used by the {@link AbstractSecurityInterceptor} to perform startup time * validation of each {@code ConfigAttribute} configured against it. - * * @return the {@code ConfigAttribute}s or {@code null} if unsupported */ Collection getAllConfigAttributes(); @@ -59,10 +52,9 @@ public interface SecurityMetadataSource extends AopInfrastructureBean { /** * Indicates whether the {@code SecurityMetadataSource} implementation is able to * provide {@code ConfigAttribute}s for the indicated secure object type. - * * @param clazz the class that is being queried - * * @return true if the implementation can process the indicated class */ boolean supports(Class clazz); + } diff --git a/core/src/main/java/org/springframework/security/access/annotation/AnnotationMetadataExtractor.java b/core/src/main/java/org/springframework/security/access/annotation/AnnotationMetadataExtractor.java index 8f4544b99b..274856f9dd 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/AnnotationMetadataExtractor.java +++ b/core/src/main/java/org/springframework/security/access/annotation/AnnotationMetadataExtractor.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; -import org.springframework.security.access.ConfigAttribute; - import java.lang.annotation.Annotation; -import java.util.*; +import java.util.Collection; + +import org.springframework.security.access.ConfigAttribute; /** * Strategy to process a custom security annotation to extract the relevant @@ -31,4 +32,5 @@ import java.util.*; public interface AnnotationMetadataExtractor { Collection extractAttributes(A securityAnnotation); + } diff --git a/core/src/main/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSource.java index 6772ccdff0..2cc9700280 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSource.java @@ -36,8 +36,7 @@ import org.springframework.security.access.method.AbstractFallbackMethodSecurity * @author Ben Alex * @since 2.0 */ -public class Jsr250MethodSecurityMetadataSource extends - AbstractFallbackMethodSecurityMetadataSource { +public class Jsr250MethodSecurityMetadataSource extends AbstractFallbackMethodSecurityMetadataSource { private String defaultRolePrefix = "ROLE_"; @@ -51,22 +50,23 @@ public class Jsr250MethodSecurityMetadataSource extends *

    * If null or empty, then no default role prefix is used. *

    - * * @param defaultRolePrefix the default prefix to add to roles. Default "ROLE_". */ public void setDefaultRolePrefix(String defaultRolePrefix) { this.defaultRolePrefix = defaultRolePrefix; } + @Override protected Collection findAttributes(Class clazz) { return processAnnotations(clazz.getAnnotations()); } - protected Collection findAttributes(Method method, - Class targetClass) { + @Override + protected Collection findAttributes(Method method, Class targetClass) { return processAnnotations(AnnotationUtils.getAnnotations(method)); } + @Override public Collection getAllConfigAttributes() { return null; } @@ -76,18 +76,17 @@ public class Jsr250MethodSecurityMetadataSource extends return null; } List attributes = new ArrayList<>(); - - for (Annotation a : annotations) { - if (a instanceof DenyAll) { + for (Annotation annotation : annotations) { + if (annotation instanceof DenyAll) { attributes.add(Jsr250SecurityConfig.DENY_ALL_ATTRIBUTE); return attributes; } - if (a instanceof PermitAll) { + if (annotation instanceof PermitAll) { attributes.add(Jsr250SecurityConfig.PERMIT_ALL_ATTRIBUTE); return attributes; } - if (a instanceof RolesAllowed) { - RolesAllowed ra = (RolesAllowed) a; + if (annotation instanceof RolesAllowed) { + RolesAllowed ra = (RolesAllowed) annotation; for (String allowed : ra.value()) { String defaultedAllowed = getRoleWithDefaultPrefix(allowed); @@ -103,12 +102,13 @@ public class Jsr250MethodSecurityMetadataSource extends if (role == null) { return role; } - if (defaultRolePrefix == null || defaultRolePrefix.length() == 0) { + if (this.defaultRolePrefix == null || this.defaultRolePrefix.length() == 0) { return role; } - if (role.startsWith(defaultRolePrefix)) { + if (role.startsWith(this.defaultRolePrefix)) { return role; } - return defaultRolePrefix + role; + return this.defaultRolePrefix + role; } + } diff --git a/core/src/main/java/org/springframework/security/access/annotation/Jsr250SecurityConfig.java b/core/src/main/java/org/springframework/security/access/annotation/Jsr250SecurityConfig.java index 25bab456ee..133498535b 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/Jsr250SecurityConfig.java +++ b/core/src/main/java/org/springframework/security/access/annotation/Jsr250SecurityConfig.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; -import org.springframework.security.access.SecurityConfig; - -import javax.annotation.security.PermitAll; import javax.annotation.security.DenyAll; +import javax.annotation.security.PermitAll; + +import org.springframework.security.access.SecurityConfig; /** * Security config applicable as a JSR 250 annotation attribute. @@ -27,13 +28,13 @@ import javax.annotation.security.DenyAll; * @since 2.0 */ public class Jsr250SecurityConfig extends SecurityConfig { - public static final Jsr250SecurityConfig PERMIT_ALL_ATTRIBUTE = new Jsr250SecurityConfig( - PermitAll.class.getName()); - public static final Jsr250SecurityConfig DENY_ALL_ATTRIBUTE = new Jsr250SecurityConfig( - DenyAll.class.getName()); + + public static final Jsr250SecurityConfig PERMIT_ALL_ATTRIBUTE = new Jsr250SecurityConfig(PermitAll.class.getName()); + + public static final Jsr250SecurityConfig DENY_ALL_ATTRIBUTE = new Jsr250SecurityConfig(DenyAll.class.getName()); public Jsr250SecurityConfig(String role) { super(role); } -} \ No newline at end of file +} diff --git a/core/src/main/java/org/springframework/security/access/annotation/Jsr250Voter.java b/core/src/main/java/org/springframework/security/access/annotation/Jsr250Voter.java index 9fb6f194e7..d94ae1d3da 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/Jsr250Voter.java +++ b/core/src/main/java/org/springframework/security/access/annotation/Jsr250Voter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; import java.util.Collection; @@ -33,20 +34,20 @@ public class Jsr250Voter implements AccessDecisionVoter { /** * The specified config attribute is supported if its an instance of a * {@link Jsr250SecurityConfig}. - * * @param configAttribute The config attribute. * @return whether the config attribute is supported. */ + @Override public boolean supports(ConfigAttribute configAttribute) { return configAttribute instanceof Jsr250SecurityConfig; } /** * All classes are supported. - * * @param clazz the class. * @return true */ + @Override public boolean supports(Class clazz) { return true; } @@ -56,25 +57,21 @@ public class Jsr250Voter implements AccessDecisionVoter { *

    * If no JSR-250 attributes are found, it will abstain, otherwise it will grant or * deny access based on the attributes that are found. - * * @param authentication The authentication object. * @param object The access object. * @param definition The configuration definition. * @return The vote. */ - public int vote(Authentication authentication, Object object, - Collection definition) { + @Override + public int vote(Authentication authentication, Object object, Collection definition) { boolean jsr250AttributeFound = false; - for (ConfigAttribute attribute : definition) { if (Jsr250SecurityConfig.PERMIT_ALL_ATTRIBUTE.equals(attribute)) { return ACCESS_GRANTED; } - if (Jsr250SecurityConfig.DENY_ALL_ATTRIBUTE.equals(attribute)) { return ACCESS_DENIED; } - if (supports(attribute)) { jsr250AttributeFound = true; // Attempt to find a matching granted authority @@ -85,7 +82,7 @@ public class Jsr250Voter implements AccessDecisionVoter { } } } - return jsr250AttributeFound ? ACCESS_DENIED : ACCESS_ABSTAIN; } + } diff --git a/core/src/main/java/org/springframework/security/access/annotation/Secured.java b/core/src/main/java/org/springframework/security/access/annotation/Secured.java index 28d509ed16..e8640e0af5 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/Secured.java +++ b/core/src/main/java/org/springframework/security/access/annotation/Secured.java @@ -50,10 +50,12 @@ import java.lang.annotation.Target; @Inherited @Documented public @interface Secured { + /** - * Returns the list of security configuration attributes (e.g. ROLE_USER, ROLE_ADMIN). - * + * Returns the list of security configuration attributes (e.g. ROLE_USER, + * ROLE_ADMIN). * @return String[] The secure method attributes */ String[] value(); + } diff --git a/core/src/main/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSource.java index e74300bb0a..152b52ec43 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSource.java @@ -18,7 +18,9 @@ package org.springframework.security.access.annotation; import java.lang.annotation.Annotation; import java.lang.reflect.Method; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import org.springframework.core.GenericTypeResolver; import org.springframework.core.annotation.AnnotationUtils; @@ -38,59 +40,56 @@ import org.springframework.util.Assert; * @author Luke Taylor */ @SuppressWarnings({ "unchecked" }) -public class SecuredAnnotationSecurityMetadataSource extends - AbstractFallbackMethodSecurityMetadataSource { +public class SecuredAnnotationSecurityMetadataSource extends AbstractFallbackMethodSecurityMetadataSource { + private AnnotationMetadataExtractor annotationExtractor; + private Class annotationType; public SecuredAnnotationSecurityMetadataSource() { this(new SecuredAnnotationMetadataExtractor()); } - public SecuredAnnotationSecurityMetadataSource( - AnnotationMetadataExtractor annotationMetadataExtractor) { + public SecuredAnnotationSecurityMetadataSource(AnnotationMetadataExtractor annotationMetadataExtractor) { Assert.notNull(annotationMetadataExtractor, "annotationMetadataExtractor cannot be null"); - annotationExtractor = annotationMetadataExtractor; - annotationType = (Class) GenericTypeResolver - .resolveTypeArgument(annotationExtractor.getClass(), - AnnotationMetadataExtractor.class); - Assert.notNull(annotationType, () -> annotationExtractor.getClass().getName() + this.annotationExtractor = annotationMetadataExtractor; + this.annotationType = (Class) GenericTypeResolver + .resolveTypeArgument(this.annotationExtractor.getClass(), AnnotationMetadataExtractor.class); + Assert.notNull(this.annotationType, () -> this.annotationExtractor.getClass().getName() + " must supply a generic parameter for AnnotationMetadataExtractor"); } + @Override protected Collection findAttributes(Class clazz) { - return processAnnotation(AnnotationUtils.findAnnotation(clazz, annotationType)); + return processAnnotation(AnnotationUtils.findAnnotation(clazz, this.annotationType)); } - protected Collection findAttributes(Method method, - Class targetClass) { - return processAnnotation(AnnotationUtils.findAnnotation(method, annotationType)); + @Override + protected Collection findAttributes(Method method, Class targetClass) { + return processAnnotation(AnnotationUtils.findAnnotation(method, this.annotationType)); } + @Override public Collection getAllConfigAttributes() { return null; } - private Collection processAnnotation(Annotation a) { - if (a == null) { - return null; + private Collection processAnnotation(Annotation annotation) { + return (annotation != null) ? this.annotationExtractor.extractAttributes(annotation) : null; + } + + static class SecuredAnnotationMetadataExtractor implements AnnotationMetadataExtractor { + + @Override + public Collection extractAttributes(Secured secured) { + String[] attributeTokens = secured.value(); + List attributes = new ArrayList<>(attributeTokens.length); + for (String token : attributeTokens) { + attributes.add(new SecurityConfig(token)); + } + return attributes; } - return annotationExtractor.extractAttributes(a); - } -} - -class SecuredAnnotationMetadataExtractor implements AnnotationMetadataExtractor { - - public Collection extractAttributes(Secured secured) { - String[] attributeTokens = secured.value(); - List attributes = new ArrayList<>( - attributeTokens.length); - - for (String token : attributeTokens) { - attributes.add(new SecurityConfig(token)); - } - - return attributes; } + } diff --git a/core/src/main/java/org/springframework/security/access/annotation/package-info.java b/core/src/main/java/org/springframework/security/access/annotation/package-info.java index 1d35e90ac5..269a726520 100644 --- a/core/src/main/java/org/springframework/security/access/annotation/package-info.java +++ b/core/src/main/java/org/springframework/security/access/annotation/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Support for JSR-250 and Spring Security {@code @Secured} annotations. */ diff --git a/core/src/main/java/org/springframework/security/access/event/AbstractAuthorizationEvent.java b/core/src/main/java/org/springframework/security/access/event/AbstractAuthorizationEvent.java index 0e902f7ba5..a285007615 100644 --- a/core/src/main/java/org/springframework/security/access/event/AbstractAuthorizationEvent.java +++ b/core/src/main/java/org/springframework/security/access/event/AbstractAuthorizationEvent.java @@ -24,15 +24,13 @@ import org.springframework.context.ApplicationEvent; * @author Ben Alex */ public abstract class AbstractAuthorizationEvent extends ApplicationEvent { - // ~ Constructors - // =================================================================================================== /** * Construct the event, passing in the secure object being intercepted. - * * @param secureObject the secure object */ public AbstractAuthorizationEvent(Object secureObject) { super(secureObject); } + } diff --git a/core/src/main/java/org/springframework/security/access/event/AuthenticationCredentialsNotFoundEvent.java b/core/src/main/java/org/springframework/security/access/event/AuthenticationCredentialsNotFoundEvent.java index a853f02a94..2474a81059 100644 --- a/core/src/main/java/org/springframework/security/access/event/AuthenticationCredentialsNotFoundEvent.java +++ b/core/src/main/java/org/springframework/security/access/event/AuthenticationCredentialsNotFoundEvent.java @@ -20,6 +20,7 @@ import java.util.Collection; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; +import org.springframework.util.Assert; /** * Indicates a secure object invocation failed because the Authentication @@ -28,46 +29,34 @@ import org.springframework.security.authentication.AuthenticationCredentialsNotF * @author Ben Alex */ public class AuthenticationCredentialsNotFoundEvent extends AbstractAuthorizationEvent { - // ~ Instance fields - // ================================================================================================ - private AuthenticationCredentialsNotFoundException credentialsNotFoundException; - private Collection configAttribs; + private final AuthenticationCredentialsNotFoundException credentialsNotFoundException; - // ~ Constructors - // =================================================================================================== + private final Collection configAttribs; /** * Construct the event. - * * @param secureObject the secure object * @param attributes that apply to the secure object * @param credentialsNotFoundException exception returned to the caller (contains * reason) * */ - public AuthenticationCredentialsNotFoundEvent(Object secureObject, - Collection attributes, + public AuthenticationCredentialsNotFoundEvent(Object secureObject, Collection attributes, AuthenticationCredentialsNotFoundException credentialsNotFoundException) { super(secureObject); - - if ((attributes == null) || (credentialsNotFoundException == null)) { - throw new IllegalArgumentException( - "All parameters are required and cannot be null"); - } - + Assert.isTrue(attributes != null && credentialsNotFoundException != null, + "All parameters are required and cannot be null"); this.configAttribs = attributes; this.credentialsNotFoundException = credentialsNotFoundException; } - // ~ Methods - // ======================================================================================================== - public Collection getConfigAttributes() { - return configAttribs; + return this.configAttribs; } public AuthenticationCredentialsNotFoundException getCredentialsNotFoundException() { - return credentialsNotFoundException; + return this.credentialsNotFoundException; } + } diff --git a/core/src/main/java/org/springframework/security/access/event/AuthorizationFailureEvent.java b/core/src/main/java/org/springframework/security/access/event/AuthorizationFailureEvent.java index 902df7a86d..de6a301c6a 100644 --- a/core/src/main/java/org/springframework/security/access/event/AuthorizationFailureEvent.java +++ b/core/src/main/java/org/springframework/security/access/event/AuthorizationFailureEvent.java @@ -21,6 +21,7 @@ import java.util.Collection; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; /** * Indicates a secure object invocation failed because the principal could not be @@ -35,55 +36,42 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public class AuthorizationFailureEvent extends AbstractAuthorizationEvent { - // ~ Instance fields - // ================================================================================================ - private AccessDeniedException accessDeniedException; - private Authentication authentication; - private Collection configAttributes; + private final AccessDeniedException accessDeniedException; - // ~ Constructors - // =================================================================================================== + private final Authentication authentication; + + private final Collection configAttributes; /** * Construct the event. - * * @param secureObject the secure object * @param attributes that apply to the secure object * @param authentication that was found in the SecurityContextHolder * @param accessDeniedException that was returned by the * AccessDecisionManager - * * @throws IllegalArgumentException if any null arguments are presented. */ - public AuthorizationFailureEvent(Object secureObject, - Collection attributes, Authentication authentication, - AccessDeniedException accessDeniedException) { + public AuthorizationFailureEvent(Object secureObject, Collection attributes, + Authentication authentication, AccessDeniedException accessDeniedException) { super(secureObject); - - if ((attributes == null) || (authentication == null) - || (accessDeniedException == null)) { - throw new IllegalArgumentException( - "All parameters are required and cannot be null"); - } - + Assert.isTrue(attributes != null && authentication != null && accessDeniedException != null, + "All parameters are required and cannot be null"); this.configAttributes = attributes; this.authentication = authentication; this.accessDeniedException = accessDeniedException; } - // ~ Methods - // ======================================================================================================== - public AccessDeniedException getAccessDeniedException() { - return accessDeniedException; + return this.accessDeniedException; } public Authentication getAuthentication() { - return authentication; + return this.authentication; } public Collection getConfigAttributes() { - return configAttributes; + return this.configAttributes; } + } diff --git a/core/src/main/java/org/springframework/security/access/event/AuthorizedEvent.java b/core/src/main/java/org/springframework/security/access/event/AuthorizedEvent.java index 43a4fc3c4d..f3b05d655b 100644 --- a/core/src/main/java/org/springframework/security/access/event/AuthorizedEvent.java +++ b/core/src/main/java/org/springframework/security/access/event/AuthorizedEvent.java @@ -20,6 +20,7 @@ import java.util.Collection; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; /** * Event indicating a secure object was invoked successfully. @@ -30,44 +31,31 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public class AuthorizedEvent extends AbstractAuthorizationEvent { - // ~ Instance fields - // ================================================================================================ - private Authentication authentication; - private Collection configAttributes; + private final Authentication authentication; - // ~ Constructors - // =================================================================================================== + private final Collection configAttributes; /** * Construct the event. - * * @param secureObject the secure object * @param attributes that apply to the secure object * @param authentication that successfully called the secure object * */ - public AuthorizedEvent(Object secureObject, Collection attributes, - Authentication authentication) { + public AuthorizedEvent(Object secureObject, Collection attributes, Authentication authentication) { super(secureObject); - - if ((attributes == null) || (authentication == null)) { - throw new IllegalArgumentException( - "All parameters are required and cannot be null"); - } - + Assert.isTrue(attributes != null && authentication != null, "All parameters are required and cannot be null"); this.configAttributes = attributes; this.authentication = authentication; } - // ~ Methods - // ======================================================================================================== - public Authentication getAuthentication() { - return authentication; + return this.authentication; } public Collection getConfigAttributes() { - return configAttributes; + return this.configAttributes; } + } diff --git a/core/src/main/java/org/springframework/security/access/event/LoggerListener.java b/core/src/main/java/org/springframework/security/access/event/LoggerListener.java index ed05cf8248..02247ceb15 100644 --- a/core/src/main/java/org/springframework/security/access/event/LoggerListener.java +++ b/core/src/main/java/org/springframework/security/access/event/LoggerListener.java @@ -18,7 +18,9 @@ package org.springframework.security.access.event; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.context.ApplicationListener; +import org.springframework.core.log.LogMessage; /** * Outputs interceptor-related application events to Commons Logging. @@ -30,58 +32,47 @@ import org.springframework.context.ApplicationListener; * @author Ben Alex */ public class LoggerListener implements ApplicationListener { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(LoggerListener.class); - // ~ Methods - // ======================================================================================================== - + @Override public void onApplicationEvent(AbstractAuthorizationEvent event) { if (event instanceof AuthenticationCredentialsNotFoundEvent) { - AuthenticationCredentialsNotFoundEvent authEvent = (AuthenticationCredentialsNotFoundEvent) event; - - if (logger.isWarnEnabled()) { - logger.warn("Security interception failed due to: " - + authEvent.getCredentialsNotFoundException() - + "; secure object: " + authEvent.getSource() - + "; configuration attributes: " - + authEvent.getConfigAttributes()); - } + onAuthenticationCredentialsNotFoundEvent((AuthenticationCredentialsNotFoundEvent) event); } - if (event instanceof AuthorizationFailureEvent) { - AuthorizationFailureEvent authEvent = (AuthorizationFailureEvent) event; - - if (logger.isWarnEnabled()) { - logger.warn("Security authorization failed due to: " - + authEvent.getAccessDeniedException() - + "; authenticated principal: " + authEvent.getAuthentication() - + "; secure object: " + authEvent.getSource() - + "; configuration attributes: " - + authEvent.getConfigAttributes()); - } + onAuthorizationFailureEvent((AuthorizationFailureEvent) event); } - if (event instanceof AuthorizedEvent) { - AuthorizedEvent authEvent = (AuthorizedEvent) event; - - if (logger.isInfoEnabled()) { - logger.info("Security authorized for authenticated principal: " - + authEvent.getAuthentication() + "; secure object: " - + authEvent.getSource() + "; configuration attributes: " - + authEvent.getConfigAttributes()); - } + onAuthorizedEvent((AuthorizedEvent) event); } - if (event instanceof PublicInvocationEvent) { - PublicInvocationEvent authEvent = (PublicInvocationEvent) event; - - if (logger.isInfoEnabled()) { - logger.info("Security interception not required for public secure object: " - + authEvent.getSource()); - } + onPublicInvocationEvent((PublicInvocationEvent) event); } } + + private void onAuthenticationCredentialsNotFoundEvent(AuthenticationCredentialsNotFoundEvent authEvent) { + logger.warn(LogMessage.format( + "Security interception failed due to: %s; secure object: %s; configuration attributes: %s", + authEvent.getCredentialsNotFoundException(), authEvent.getSource(), authEvent.getConfigAttributes())); + } + + private void onPublicInvocationEvent(PublicInvocationEvent event) { + logger.info(LogMessage.format("Security interception not required for public secure object: %s", + event.getSource())); + } + + private void onAuthorizedEvent(AuthorizedEvent authEvent) { + logger.info(LogMessage.format( + "Security authorized for authenticated principal: %s; secure object: %s; configuration attributes: %s", + authEvent.getAuthentication(), authEvent.getSource(), authEvent.getConfigAttributes())); + } + + private void onAuthorizationFailureEvent(AuthorizationFailureEvent authEvent) { + logger.warn(LogMessage.format( + "Security authorization failed due to: %s; authenticated principal: %s; secure object: %s; configuration attributes: %s", + authEvent.getAccessDeniedException(), authEvent.getAuthentication(), authEvent.getSource(), + authEvent.getConfigAttributes())); + } + } diff --git a/core/src/main/java/org/springframework/security/access/event/PublicInvocationEvent.java b/core/src/main/java/org/springframework/security/access/event/PublicInvocationEvent.java index 5c7c82a50d..2ce5db3a85 100644 --- a/core/src/main/java/org/springframework/security/access/event/PublicInvocationEvent.java +++ b/core/src/main/java/org/springframework/security/access/event/PublicInvocationEvent.java @@ -30,15 +30,13 @@ package org.springframework.security.access.event; * @author Ben Alex */ public class PublicInvocationEvent extends AbstractAuthorizationEvent { - // ~ Constructors - // =================================================================================================== /** * Construct the event, passing in the public secure object. - * * @param secureObject the public secure object */ public PublicInvocationEvent(Object secureObject) { super(secureObject); } + } diff --git a/core/src/main/java/org/springframework/security/access/event/package-info.java b/core/src/main/java/org/springframework/security/access/event/package-info.java index dbd745ed68..41488584ec 100644 --- a/core/src/main/java/org/springframework/security/access/event/package-info.java +++ b/core/src/main/java/org/springframework/security/access/event/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Authorization event and listener classes. */ package org.springframework.security.access.event; - diff --git a/core/src/main/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandler.java b/core/src/main/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandler.java index 67d4a80dae..6a70491d7c 100644 --- a/core/src/main/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandler.java +++ b/core/src/main/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; import org.springframework.context.ApplicationContext; @@ -36,15 +37,20 @@ import org.springframework.util.Assert; * @author Luke Taylor * @since 3.1 */ -public abstract class AbstractSecurityExpressionHandler implements - SecurityExpressionHandler, ApplicationContextAware { +public abstract class AbstractSecurityExpressionHandler + implements SecurityExpressionHandler, ApplicationContextAware { + private ExpressionParser expressionParser = new SpelExpressionParser(); - private BeanResolver br; + + private BeanResolver beanResolver; + private RoleHierarchy roleHierarchy; + private PermissionEvaluator permissionEvaluator = new DenyAllPermissionEvaluator(); + @Override public final ExpressionParser getExpressionParser() { - return expressionParser; + return this.expressionParser; } public final void setExpressionParser(ExpressionParser expressionParser) { @@ -55,21 +61,17 @@ public abstract class AbstractSecurityExpressionHandler implements /** * Invokes the internal template methods to create {@code StandardEvaluationContext} * and {@code SecurityExpressionRoot} objects. - * * @param authentication the current authentication object * @param invocation the invocation (filter, method, channel) * @return the context object for use in evaluating the expression, populated with a * suitable root object. */ - public final EvaluationContext createEvaluationContext(Authentication authentication, - T invocation) { - SecurityExpressionOperations root = createSecurityExpressionRoot(authentication, - invocation); - StandardEvaluationContext ctx = createEvaluationContextInternal(authentication, - invocation); - ctx.setBeanResolver(br); + @Override + public final EvaluationContext createEvaluationContext(Authentication authentication, T invocation) { + SecurityExpressionOperations root = createSecurityExpressionRoot(authentication, invocation); + StandardEvaluationContext ctx = createEvaluationContextInternal(authentication, invocation); + ctx.setBeanResolver(this.beanResolver); ctx.setRootObject(root); - return ctx; } @@ -79,30 +81,27 @@ public abstract class AbstractSecurityExpressionHandler implements * The returned object will have a {@code SecurityExpressionRootPropertyAccessor} * added, allowing beans in the {@code ApplicationContext} to be accessed via * expression properties. - * * @param authentication the current authentication object * @param invocation the invocation (filter, method, channel) * @return A {@code StandardEvaluationContext} or potentially a custom subclass if * overridden. */ - protected StandardEvaluationContext createEvaluationContextInternal( - Authentication authentication, T invocation) { + protected StandardEvaluationContext createEvaluationContextInternal(Authentication authentication, T invocation) { return new StandardEvaluationContext(); } /** * Implement in order to create a root object of the correct type for the supported * invocation type. - * * @param authentication the current authentication object * @param invocation the invocation (filter, method, channel) - * @return the object wh + * @return the object */ - protected abstract SecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, T invocation); + protected abstract SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, + T invocation); protected RoleHierarchy getRoleHierarchy() { - return roleHierarchy; + return this.roleHierarchy; } public void setRoleHierarchy(RoleHierarchy roleHierarchy) { @@ -110,14 +109,16 @@ public abstract class AbstractSecurityExpressionHandler implements } protected PermissionEvaluator getPermissionEvaluator() { - return permissionEvaluator; + return this.permissionEvaluator; } public void setPermissionEvaluator(PermissionEvaluator permissionEvaluator) { this.permissionEvaluator = permissionEvaluator; } + @Override public void setApplicationContext(ApplicationContext applicationContext) { - br = new BeanFactoryResolver(applicationContext); + this.beanResolver = new BeanFactoryResolver(applicationContext); } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/DenyAllPermissionEvaluator.java b/core/src/main/java/org/springframework/security/access/expression/DenyAllPermissionEvaluator.java index 4d9708ea14..2aae6b7a24 100644 --- a/core/src/main/java/org/springframework/security/access/expression/DenyAllPermissionEvaluator.java +++ b/core/src/main/java/org/springframework/security/access/expression/DenyAllPermissionEvaluator.java @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; import java.io.Serializable; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.core.Authentication; @@ -36,20 +39,21 @@ public class DenyAllPermissionEvaluator implements PermissionEvaluator { /** * @return false always */ - public boolean hasPermission(Authentication authentication, Object target, - Object permission) { - logger.warn("Denying user " + authentication.getName() + " permission '" - + permission + "' on object " + target); + @Override + public boolean hasPermission(Authentication authentication, Object target, Object permission) { + this.logger.warn(LogMessage.format("Denying user %s permission '%s' on object %s", authentication.getName(), + permission, target)); return false; } /** * @return false always */ - public boolean hasPermission(Authentication authentication, Serializable targetId, - String targetType, Object permission) { - logger.warn("Denying user " + authentication.getName() + " permission '" - + permission + "' on object with Id '" + targetId); + @Override + public boolean hasPermission(Authentication authentication, Serializable targetId, String targetType, + Object permission) { + this.logger.warn(LogMessage.format("Denying user %s permission '%s' on object with Id %s", + authentication.getName(), permission, targetId)); return false; } diff --git a/core/src/main/java/org/springframework/security/access/expression/ExpressionUtils.java b/core/src/main/java/org/springframework/security/access/expression/ExpressionUtils.java index 70cf8ab5a5..5296a3eacb 100644 --- a/core/src/main/java/org/springframework/security/access/expression/ExpressionUtils.java +++ b/core/src/main/java/org/springframework/security/access/expression/ExpressionUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; import org.springframework.expression.EvaluationContext; @@ -21,13 +22,17 @@ import org.springframework.expression.Expression; public final class ExpressionUtils { + private ExpressionUtils() { + } + public static boolean evaluateAsBoolean(Expression expr, EvaluationContext ctx) { try { return expr.getValue(ctx, Boolean.class); } - catch (EvaluationException e) { - throw new IllegalArgumentException("Failed to evaluate expression '" - + expr.getExpressionString() + "'", e); + catch (EvaluationException ex) { + throw new IllegalArgumentException("Failed to evaluate expression '" + expr.getExpressionString() + "'", + ex); } } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionHandler.java b/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionHandler.java index 73f5ac1636..39e171c4dd 100644 --- a/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionHandler.java +++ b/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; import org.springframework.aop.framework.AopInfrastructureBean; @@ -28,6 +29,7 @@ import org.springframework.security.core.Authentication; * @since 3.1 */ public interface SecurityExpressionHandler extends AopInfrastructureBean { + /** * @return an expression parser for the expressions used by the implementation. */ @@ -38,4 +40,5 @@ public interface SecurityExpressionHandler extends AopInfrastructureBean { * invocation type. */ EvaluationContext createEvaluationContext(Authentication authentication, T invocation); + } diff --git a/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionOperations.java b/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionOperations.java index 01ace0e04a..b06d56df11 100644 --- a/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionOperations.java +++ b/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionOperations.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; import org.springframework.security.core.Authentication; @@ -59,7 +60,6 @@ public interface SecurityExpressionOperations { * implementation may convert it to use "ROLE_USER" instead. The way in which the role * is converted may depend on the implementation settings. *

    - * * @param role the authority to test (i.e. "USER") * @return true if the authority is found, else false */ @@ -71,12 +71,11 @@ public interface SecurityExpressionOperations { * within {@link Authentication#getAuthorities()}. *

    *

    - * This is a similar to hasAnyAuthority except that this method implies - * that the String passed in is a role. For example, if "USER" is passed in the - * implementation may convert it to use "ROLE_USER" instead. The way in which the role - * is converted may depend on the implementation settings. + * This is a similar to hasAnyAuthority except that this method implies that the + * String passed in is a role. For example, if "USER" is passed in the implementation + * may convert it to use "ROLE_USER" instead. The way in which the role is converted + * may depend on the implementation settings. *

    - * * @param roles the authorities to test (i.e. "USER", "ADMIN") * @return true if any of the authorities is found, else false */ diff --git a/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionRoot.java b/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionRoot.java index 97b201d085..155b317638 100644 --- a/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionRoot.java +++ b/core/src/main/java/org/springframework/security/access/expression/SecurityExpressionRoot.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; import java.io.Serializable; @@ -33,22 +34,37 @@ import org.springframework.security.core.authority.AuthorityUtils; * @since 3.0 */ public abstract class SecurityExpressionRoot implements SecurityExpressionOperations { + protected final Authentication authentication; + private AuthenticationTrustResolver trustResolver; + private RoleHierarchy roleHierarchy; + private Set roles; + private String defaultRolePrefix = "ROLE_"; - /** Allows "permitAll" expression */ + /** + * Allows "permitAll" expression + */ public final boolean permitAll = true; - /** Allows "denyAll" expression */ + /** + * Allows "denyAll" expression + */ public final boolean denyAll = false; + private PermissionEvaluator permissionEvaluator; + public final String read = "read"; + public final String write = "write"; + public final String create = "create"; + public final String delete = "delete"; + public final String admin = "administration"; /** @@ -62,62 +78,71 @@ public abstract class SecurityExpressionRoot implements SecurityExpressionOperat this.authentication = authentication; } + @Override public final boolean hasAuthority(String authority) { return hasAnyAuthority(authority); } + @Override public final boolean hasAnyAuthority(String... authorities) { return hasAnyAuthorityName(null, authorities); } + @Override public final boolean hasRole(String role) { return hasAnyRole(role); } + @Override public final boolean hasAnyRole(String... roles) { - return hasAnyAuthorityName(defaultRolePrefix, roles); + return hasAnyAuthorityName(this.defaultRolePrefix, roles); } private boolean hasAnyAuthorityName(String prefix, String... roles) { Set roleSet = getAuthoritySet(); - for (String role : roles) { String defaultedRole = getRoleWithDefaultPrefix(prefix, role); if (roleSet.contains(defaultedRole)) { return true; } } - return false; } + @Override public final Authentication getAuthentication() { - return authentication; + return this.authentication; } + @Override public final boolean permitAll() { return true; } + @Override public final boolean denyAll() { return false; } + @Override public final boolean isAnonymous() { - return trustResolver.isAnonymous(authentication); + return this.trustResolver.isAnonymous(this.authentication); } + @Override public final boolean isAuthenticated() { return !isAnonymous(); } + @Override public final boolean isRememberMe() { - return trustResolver.isRememberMe(authentication); + return this.trustResolver.isRememberMe(this.authentication); } + @Override public final boolean isFullyAuthenticated() { - return !trustResolver.isAnonymous(authentication) - && !trustResolver.isRememberMe(authentication); + return !this.trustResolver.isAnonymous(this.authentication) + && !this.trustResolver.isRememberMe(this.authentication); } /** @@ -126,7 +151,7 @@ public abstract class SecurityExpressionRoot implements SecurityExpressionOperat * @return */ public Object getPrincipal() { - return authentication.getPrincipal(); + return this.authentication.getPrincipal(); } public void setTrustResolver(AuthenticationTrustResolver trustResolver) { @@ -148,7 +173,6 @@ public abstract class SecurityExpressionRoot implements SecurityExpressionOperat *

    * If null or empty, then no default role prefix is used. *

    - * * @param defaultRolePrefix the default prefix to add to roles. Default "ROLE_". */ public void setDefaultRolePrefix(String defaultRolePrefix) { @@ -156,28 +180,25 @@ public abstract class SecurityExpressionRoot implements SecurityExpressionOperat } private Set getAuthoritySet() { - if (roles == null) { - Collection userAuthorities = authentication - .getAuthorities(); - - if (roleHierarchy != null) { - userAuthorities = roleHierarchy - .getReachableGrantedAuthorities(userAuthorities); + if (this.roles == null) { + Collection userAuthorities = this.authentication.getAuthorities(); + if (this.roleHierarchy != null) { + userAuthorities = this.roleHierarchy.getReachableGrantedAuthorities(userAuthorities); } - - roles = AuthorityUtils.authorityListToSet(userAuthorities); + this.roles = AuthorityUtils.authorityListToSet(userAuthorities); } - - return roles; + return this.roles; } + @Override public boolean hasPermission(Object target, Object permission) { - return permissionEvaluator.hasPermission(authentication, target, permission); + return this.permissionEvaluator.hasPermission(this.authentication, target, permission); } + @Override public boolean hasPermission(Object targetId, String targetType, Object permission) { - return permissionEvaluator.hasPermission(authentication, (Serializable) targetId, - targetType, permission); + return this.permissionEvaluator.hasPermission(this.authentication, (Serializable) targetId, targetType, + permission); } public void setPermissionEvaluator(PermissionEvaluator permissionEvaluator) { @@ -187,7 +208,6 @@ public abstract class SecurityExpressionRoot implements SecurityExpressionOperat /** * Prefixes role with defaultRolePrefix if defaultRolePrefix is non-null and if role * does not already start with defaultRolePrefix. - * * @param defaultRolePrefix * @param role * @return @@ -204,4 +224,5 @@ public abstract class SecurityExpressionRoot implements SecurityExpressionOperat } return defaultRolePrefix + role; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/AbstractExpressionBasedMethodConfigAttribute.java b/core/src/main/java/org/springframework/security/access/expression/method/AbstractExpressionBasedMethodConfigAttribute.java index cafad53385..722505f22a 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/AbstractExpressionBasedMethodConfigAttribute.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/AbstractExpressionBasedMethodConfigAttribute.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.springframework.expression.Expression; @@ -33,41 +34,42 @@ import org.springframework.util.Assert; * @since 3.0 */ abstract class AbstractExpressionBasedMethodConfigAttribute implements ConfigAttribute { + private final Expression filterExpression; + private final Expression authorizeExpression; /** * Parses the supplied expressions as Spring-EL. */ - AbstractExpressionBasedMethodConfigAttribute(String filterExpression, - String authorizeExpression) throws ParseException { + AbstractExpressionBasedMethodConfigAttribute(String filterExpression, String authorizeExpression) + throws ParseException { Assert.isTrue(filterExpression != null || authorizeExpression != null, "Filter and authorization Expressions cannot both be null"); SpelExpressionParser parser = new SpelExpressionParser(); - this.filterExpression = filterExpression == null ? null : parser - .parseExpression(filterExpression); - this.authorizeExpression = authorizeExpression == null ? null : parser - .parseExpression(authorizeExpression); + this.filterExpression = (filterExpression != null) ? parser.parseExpression(filterExpression) : null; + this.authorizeExpression = (authorizeExpression != null) ? parser.parseExpression(authorizeExpression) : null; } - AbstractExpressionBasedMethodConfigAttribute(Expression filterExpression, - Expression authorizeExpression) throws ParseException { + AbstractExpressionBasedMethodConfigAttribute(Expression filterExpression, Expression authorizeExpression) + throws ParseException { Assert.isTrue(filterExpression != null || authorizeExpression != null, "Filter and authorization Expressions cannot both be null"); - this.filterExpression = filterExpression == null ? null : filterExpression; - this.authorizeExpression = authorizeExpression == null ? null - : authorizeExpression; + this.filterExpression = (filterExpression != null) ? filterExpression : null; + this.authorizeExpression = (authorizeExpression != null) ? authorizeExpression : null; } Expression getFilterExpression() { - return filterExpression; + return this.filterExpression; } Expression getAuthorizeExpression() { - return authorizeExpression; + return this.authorizeExpression; } + @Override public String getAttribute() { return null; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java b/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java index 26bdbf154a..6254ee087d 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import java.lang.reflect.Array; @@ -22,12 +23,14 @@ import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.*; +import java.util.stream.Stream; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.core.log.LogMessage; import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.expression.spel.support.StandardEvaluationContext; @@ -49,15 +52,17 @@ import org.springframework.util.Assert; * @author Luke Taylor * @since 3.0 */ -public class DefaultMethodSecurityExpressionHandler extends - AbstractSecurityExpressionHandler implements - MethodSecurityExpressionHandler { +public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpressionHandler + implements MethodSecurityExpressionHandler { protected final Log logger = LogFactory.getLog(getClass()); private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); + private ParameterNameDiscoverer parameterNameDiscoverer = new DefaultSecurityParameterNameDiscoverer(); + private PermissionCacheOptimizer permissionCacheOptimizer = null; + private String defaultRolePrefix = "ROLE_"; public DefaultMethodSecurityExpressionHandler() { @@ -67,159 +72,123 @@ public class DefaultMethodSecurityExpressionHandler extends * Uses a {@link MethodSecurityEvaluationContext} as the EvaluationContext * implementation. */ - public StandardEvaluationContext createEvaluationContextInternal(Authentication auth, - MethodInvocation mi) { + @Override + public StandardEvaluationContext createEvaluationContextInternal(Authentication auth, MethodInvocation mi) { return new MethodSecurityEvaluationContext(auth, mi, getParameterNameDiscoverer()); } /** * Creates the root object for expression evaluation. */ - protected MethodSecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, MethodInvocation invocation) { - MethodSecurityExpressionRoot root = new MethodSecurityExpressionRoot( - authentication); + @Override + protected MethodSecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, + MethodInvocation invocation) { + MethodSecurityExpressionRoot root = new MethodSecurityExpressionRoot(authentication); root.setThis(invocation.getThis()); root.setPermissionEvaluator(getPermissionEvaluator()); root.setTrustResolver(getTrustResolver()); root.setRoleHierarchy(getRoleHierarchy()); root.setDefaultRolePrefix(getDefaultRolePrefix()); - return root; } /** - * Filters the {@code filterTarget} object (which must be either a collection, array, map - * or stream), by evaluating the supplied expression. + * Filters the {@code filterTarget} object (which must be either a collection, array, + * map or stream), by evaluating the supplied expression. *

    - * If a {@code Collection} or {@code Map} is used, the original instance will be modified to contain - * the elements for which the permission expression evaluates to {@code true}. For an - * array, a new array instance will be returned. + * If a {@code Collection} or {@code Map} is used, the original instance will be + * modified to contain the elements for which the permission expression evaluates to + * {@code true}. For an array, a new array instance will be returned. */ - @SuppressWarnings("unchecked") - public Object filter(Object filterTarget, Expression filterExpression, - EvaluationContext ctx) { - MethodSecurityExpressionOperations rootObject = (MethodSecurityExpressionOperations) ctx - .getRootObject().getValue(); - final boolean debug = logger.isDebugEnabled(); - List retainList; - - if (debug) { - logger.debug("Filtering with expression: " - + filterExpression.getExpressionString()); - } - + @Override + public Object filter(Object filterTarget, Expression filterExpression, EvaluationContext ctx) { + MethodSecurityExpressionOperations rootObject = (MethodSecurityExpressionOperations) ctx.getRootObject() + .getValue(); + this.logger.debug(LogMessage.format("Filtering with expression: %s", filterExpression.getExpressionString())); if (filterTarget instanceof Collection) { - Collection collection = (Collection) filterTarget; - retainList = new ArrayList(collection.size()); - - if (debug) { - logger.debug("Filtering collection with " + collection.size() - + " elements"); - } - - if (permissionCacheOptimizer != null) { - permissionCacheOptimizer.cachePermissionsFor( - rootObject.getAuthentication(), collection); - } - - for (Object filterObject : (Collection) filterTarget) { - rootObject.setFilterObject(filterObject); - - if (ExpressionUtils.evaluateAsBoolean(filterExpression, ctx)) { - retainList.add(filterObject); - } - } - - if (debug) { - logger.debug("Retaining elements: " + retainList); - } - - collection.clear(); - collection.addAll(retainList); - - return filterTarget; + return filterCollection((Collection) filterTarget, filterExpression, ctx, rootObject); } - if (filterTarget.getClass().isArray()) { - Object[] array = (Object[]) filterTarget; - retainList = new ArrayList(array.length); - - if (debug) { - logger.debug("Filtering array with " + array.length + " elements"); - } - - if (permissionCacheOptimizer != null) { - permissionCacheOptimizer.cachePermissionsFor( - rootObject.getAuthentication(), Arrays.asList(array)); - } - - for (Object o : array) { - rootObject.setFilterObject(o); - - if (ExpressionUtils.evaluateAsBoolean(filterExpression, ctx)) { - retainList.add(o); - } - } - - if (debug) { - logger.debug("Retaining elements: " + retainList); - } - - Object[] filtered = (Object[]) Array.newInstance(filterTarget.getClass() - .getComponentType(), retainList.size()); - for (int i = 0; i < retainList.size(); i++) { - filtered[i] = retainList.get(i); - } - - return filtered; + return filterArray((Object[]) filterTarget, filterExpression, ctx, rootObject); } - if (filterTarget instanceof Map) { - final Map map = (Map) filterTarget; - final Map retainMap = new LinkedHashMap(map.size()); - - if (debug) { - logger.debug("Filtering map with " + map.size() + " elements"); - } - - for (Map.Entry filterObject : map.entrySet()) { - rootObject.setFilterObject(filterObject); - - if (ExpressionUtils.evaluateAsBoolean(filterExpression, ctx)) { - retainMap.put(filterObject.getKey(), filterObject.getValue()); - } - } - - if (debug) { - logger.debug("Retaining elements: " + retainMap); - } - - map.clear(); - map.putAll(retainMap); - - return filterTarget; + return filterMap((Map) filterTarget, filterExpression, ctx, rootObject); } - if (filterTarget instanceof Stream) { - final Stream original = (Stream) filterTarget; - - return original.filter(filterObject -> { - rootObject.setFilterObject(filterObject); - return ExpressionUtils.evaluateAsBoolean(filterExpression, ctx); - }) - .onClose(original::close); + return filterStream((Stream) filterTarget, filterExpression, ctx, rootObject); } - throw new IllegalArgumentException( - "Filter target must be a collection, array, map or stream type, but was " - + filterTarget); + "Filter target must be a collection, array, map or stream type, but was " + filterTarget); + } + + private Object filterCollection(Collection filterTarget, Expression filterExpression, EvaluationContext ctx, + MethodSecurityExpressionOperations rootObject) { + this.logger.debug(LogMessage.format("Filtering collection with %s elements", filterTarget.size())); + List retain = new ArrayList<>(filterTarget.size()); + if (this.permissionCacheOptimizer != null) { + this.permissionCacheOptimizer.cachePermissionsFor(rootObject.getAuthentication(), filterTarget); + } + for (T filterObject : filterTarget) { + rootObject.setFilterObject(filterObject); + if (ExpressionUtils.evaluateAsBoolean(filterExpression, ctx)) { + retain.add(filterObject); + } + } + this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); + filterTarget.clear(); + filterTarget.addAll(retain); + return filterTarget; + } + + private Object filterArray(Object[] filterTarget, Expression filterExpression, EvaluationContext ctx, + MethodSecurityExpressionOperations rootObject) { + List retain = new ArrayList<>(filterTarget.length); + this.logger.debug(LogMessage.format("Filtering array with %s elements", filterTarget.length)); + if (this.permissionCacheOptimizer != null) { + this.permissionCacheOptimizer.cachePermissionsFor(rootObject.getAuthentication(), + Arrays.asList(filterTarget)); + } + for (Object filterObject : filterTarget) { + rootObject.setFilterObject(filterObject); + if (ExpressionUtils.evaluateAsBoolean(filterExpression, ctx)) { + retain.add(filterObject); + } + } + this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); + Object[] filtered = (Object[]) Array.newInstance(filterTarget.getClass().getComponentType(), retain.size()); + for (int i = 0; i < retain.size(); i++) { + filtered[i] = retain.get(i); + } + return filtered; + } + + private Object filterMap(final Map filterTarget, Expression filterExpression, EvaluationContext ctx, + MethodSecurityExpressionOperations rootObject) { + Map retain = new LinkedHashMap<>(filterTarget.size()); + this.logger.debug(LogMessage.format("Filtering map with %s elements", filterTarget.size())); + for (Map.Entry filterObject : filterTarget.entrySet()) { + rootObject.setFilterObject(filterObject); + if (ExpressionUtils.evaluateAsBoolean(filterExpression, ctx)) { + retain.put(filterObject.getKey(), filterObject.getValue()); + } + } + this.logger.debug(LogMessage.format("Retaining elements: %s", retain)); + filterTarget.clear(); + filterTarget.putAll(retain); + return filterTarget; + } + + private Object filterStream(final Stream filterTarget, Expression filterExpression, EvaluationContext ctx, + MethodSecurityExpressionOperations rootObject) { + return filterTarget.filter((filterObject) -> { + rootObject.setFilterObject(filterObject); + return ExpressionUtils.evaluateAsBoolean(filterExpression, ctx); + }).onClose(filterTarget::close); } /** * Sets the {@link AuthenticationTrustResolver} to be used. The default is * {@link AuthenticationTrustResolverImpl}. - * * @param trustResolver the {@link AuthenticationTrustResolver} to use. Cannot be * null. */ @@ -230,9 +199,9 @@ public class DefaultMethodSecurityExpressionHandler extends /** * @return The current {@link AuthenticationTrustResolver} - */ + */ protected AuthenticationTrustResolver getTrustResolver() { - return trustResolver; + return this.trustResolver; } /** @@ -246,33 +215,33 @@ public class DefaultMethodSecurityExpressionHandler extends /** * @return The current {@link ParameterNameDiscoverer} - */ + */ protected ParameterNameDiscoverer getParameterNameDiscoverer() { - return parameterNameDiscoverer; + return this.parameterNameDiscoverer; } - public void setPermissionCacheOptimizer( - PermissionCacheOptimizer permissionCacheOptimizer) { + public void setPermissionCacheOptimizer(PermissionCacheOptimizer permissionCacheOptimizer) { this.permissionCacheOptimizer = permissionCacheOptimizer; } + @Override public void setReturnObject(Object returnObject, EvaluationContext ctx) { - ((MethodSecurityExpressionOperations) ctx.getRootObject().getValue()) - .setReturnObject(returnObject); + ((MethodSecurityExpressionOperations) ctx.getRootObject().getValue()).setReturnObject(returnObject); } /** *

    - * Sets the default prefix to be added to {@link org.springframework.security.access.expression.SecurityExpressionRoot#hasAnyRole(String...)} or - * {@link org.springframework.security.access.expression.SecurityExpressionRoot#hasRole(String)}. For example, if hasRole("ADMIN") or hasRole("ROLE_ADMIN") - * is passed in, then the role ROLE_ADMIN will be used when the defaultRolePrefix is - * "ROLE_" (default). + * Sets the default prefix to be added to + * {@link org.springframework.security.access.expression.SecurityExpressionRoot#hasAnyRole(String...)} + * or + * {@link org.springframework.security.access.expression.SecurityExpressionRoot#hasRole(String)}. + * For example, if hasRole("ADMIN") or hasRole("ROLE_ADMIN") is passed in, then the + * role ROLE_ADMIN will be used when the defaultRolePrefix is "ROLE_" (default). *

    * *

    * If null or empty, then no default role prefix is used. *

    - * * @param defaultRolePrefix the default prefix to add to roles. Default "ROLE_". */ public void setDefaultRolePrefix(String defaultRolePrefix) { @@ -281,8 +250,9 @@ public class DefaultMethodSecurityExpressionHandler extends /** * @return The default role prefix - */ + */ protected String getDefaultRolePrefix() { - return defaultRolePrefix; + return this.defaultRolePrefix; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedAnnotationAttributeFactory.java b/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedAnnotationAttributeFactory.java index 2d96262dfa..d6fb761966 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedAnnotationAttributeFactory.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedAnnotationAttributeFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.springframework.expression.Expression; @@ -30,53 +31,50 @@ import org.springframework.security.access.prepost.PrePostInvocationAttributeFac * @author Rob Winch * @since 3.0 */ -public class ExpressionBasedAnnotationAttributeFactory implements - PrePostInvocationAttributeFactory { +public class ExpressionBasedAnnotationAttributeFactory implements PrePostInvocationAttributeFactory { + private final Object parserLock = new Object(); + private ExpressionParser parser; + private MethodSecurityExpressionHandler handler; - public ExpressionBasedAnnotationAttributeFactory( - MethodSecurityExpressionHandler handler) { + public ExpressionBasedAnnotationAttributeFactory(MethodSecurityExpressionHandler handler) { this.handler = handler; } - public PreInvocationAttribute createPreInvocationAttribute(String preFilterAttribute, - String filterObject, String preAuthorizeAttribute) { + @Override + public PreInvocationAttribute createPreInvocationAttribute(String preFilterAttribute, String filterObject, + String preAuthorizeAttribute) { try { // TODO: Optimization of permitAll ExpressionParser parser = getParser(); - Expression preAuthorizeExpression = preAuthorizeAttribute == null ? parser - .parseExpression("permitAll") : parser - .parseExpression(preAuthorizeAttribute); - Expression preFilterExpression = preFilterAttribute == null ? null : parser - .parseExpression(preFilterAttribute); - return new PreInvocationExpressionAttribute(preFilterExpression, - filterObject, preAuthorizeExpression); + Expression preAuthorizeExpression = (preAuthorizeAttribute != null) + ? parser.parseExpression(preAuthorizeAttribute) : parser.parseExpression("permitAll"); + Expression preFilterExpression = (preFilterAttribute != null) ? parser.parseExpression(preFilterAttribute) + : null; + return new PreInvocationExpressionAttribute(preFilterExpression, filterObject, preAuthorizeExpression); } - catch (ParseException e) { - throw new IllegalArgumentException("Failed to parse expression '" - + e.getExpressionString() + "'", e); + catch (ParseException ex) { + throw new IllegalArgumentException("Failed to parse expression '" + ex.getExpressionString() + "'", ex); } } - public PostInvocationAttribute createPostInvocationAttribute( - String postFilterAttribute, String postAuthorizeAttribute) { + @Override + public PostInvocationAttribute createPostInvocationAttribute(String postFilterAttribute, + String postAuthorizeAttribute) { try { ExpressionParser parser = getParser(); - Expression postAuthorizeExpression = postAuthorizeAttribute == null ? null - : parser.parseExpression(postAuthorizeAttribute); - Expression postFilterExpression = postFilterAttribute == null ? null : parser - .parseExpression(postFilterAttribute); - + Expression postAuthorizeExpression = (postAuthorizeAttribute != null) + ? parser.parseExpression(postAuthorizeAttribute) : null; + Expression postFilterExpression = (postFilterAttribute != null) + ? parser.parseExpression(postFilterAttribute) : null; if (postFilterExpression != null || postAuthorizeExpression != null) { - return new PostInvocationExpressionAttribute(postFilterExpression, - postAuthorizeExpression); + return new PostInvocationExpressionAttribute(postFilterExpression, postAuthorizeExpression); } } - catch (ParseException e) { - throw new IllegalArgumentException("Failed to parse expression '" - + e.getExpressionString() + "'", e); + catch (ParseException ex) { + throw new IllegalArgumentException("Failed to parse expression '" + ex.getExpressionString() + "'", ex); } return null; @@ -84,17 +82,17 @@ public class ExpressionBasedAnnotationAttributeFactory implements /** * Delay the lookup of the {@link ExpressionParser} to prevent SEC-2136 - * * @return */ private ExpressionParser getParser() { if (this.parser != null) { return this.parser; } - synchronized (parserLock) { - this.parser = handler.getExpressionParser(); + synchronized (this.parserLock) { + this.parser = this.handler.getExpressionParser(); this.handler = null; } return this.parser; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPostInvocationAdvice.java b/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPostInvocationAdvice.java index 8dd739804f..664f97102e 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPostInvocationAdvice.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPostInvocationAdvice.java @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.security.access.AccessDeniedException; @@ -27,56 +30,41 @@ import org.springframework.security.access.prepost.PostInvocationAuthorizationAd import org.springframework.security.core.Authentication; /** - * * @author Luke Taylor * @since 3.0 */ -public class ExpressionBasedPostInvocationAdvice implements - PostInvocationAuthorizationAdvice { +public class ExpressionBasedPostInvocationAdvice implements PostInvocationAuthorizationAdvice { + protected final Log logger = LogFactory.getLog(getClass()); private final MethodSecurityExpressionHandler expressionHandler; - public ExpressionBasedPostInvocationAdvice( - MethodSecurityExpressionHandler expressionHandler) { + public ExpressionBasedPostInvocationAdvice(MethodSecurityExpressionHandler expressionHandler) { this.expressionHandler = expressionHandler; } - public Object after(Authentication authentication, MethodInvocation mi, - PostInvocationAttribute postAttr, Object returnedObject) - throws AccessDeniedException { + @Override + public Object after(Authentication authentication, MethodInvocation mi, PostInvocationAttribute postAttr, + Object returnedObject) throws AccessDeniedException { PostInvocationExpressionAttribute pia = (PostInvocationExpressionAttribute) postAttr; - EvaluationContext ctx = expressionHandler.createEvaluationContext(authentication, - mi); + EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, mi); Expression postFilter = pia.getFilterExpression(); Expression postAuthorize = pia.getAuthorizeExpression(); - if (postFilter != null) { - if (logger.isDebugEnabled()) { - logger.debug("Applying PostFilter expression " + postFilter); - } - + this.logger.debug(LogMessage.format("Applying PostFilter expression %s", postFilter)); if (returnedObject != null) { - returnedObject = expressionHandler - .filter(returnedObject, postFilter, ctx); + returnedObject = this.expressionHandler.filter(returnedObject, postFilter, ctx); } else { - if (logger.isDebugEnabled()) { - logger.debug("Return object is null, filtering will be skipped"); - } + this.logger.debug("Return object is null, filtering will be skipped"); } } - - expressionHandler.setReturnObject(returnedObject, ctx); - - if (postAuthorize != null - && !ExpressionUtils.evaluateAsBoolean(postAuthorize, ctx)) { - if (logger.isDebugEnabled()) { - logger.debug("PostAuthorize expression rejected access"); - } + this.expressionHandler.setReturnObject(returnedObject, ctx); + if (postAuthorize != null && !ExpressionUtils.evaluateAsBoolean(postAuthorize, ctx)) { + this.logger.debug("PostAuthorize expression rejected access"); throw new AccessDeniedException("Access is denied"); } - return returnedObject; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdvice.java b/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdvice.java index c4f11aa6f8..19c7659407 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdvice.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdvice.java @@ -13,20 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/** - * - */ + package org.springframework.security.access.expression.method; import java.util.Collection; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.security.access.expression.ExpressionUtils; import org.springframework.security.access.prepost.PreInvocationAttribute; import org.springframework.security.access.prepost.PreInvocationAuthorizationAdvice; import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; /** * Method pre-invocation handling based on expressions. @@ -34,68 +34,49 @@ import org.springframework.security.core.Authentication; * @author Luke Taylor * @since 3.0 */ -public class ExpressionBasedPreInvocationAdvice implements - PreInvocationAuthorizationAdvice { +public class ExpressionBasedPreInvocationAdvice implements PreInvocationAuthorizationAdvice { + private MethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler(); - public boolean before(Authentication authentication, MethodInvocation mi, - PreInvocationAttribute attr) { + @Override + public boolean before(Authentication authentication, MethodInvocation mi, PreInvocationAttribute attr) { PreInvocationExpressionAttribute preAttr = (PreInvocationExpressionAttribute) attr; - EvaluationContext ctx = expressionHandler.createEvaluationContext(authentication, - mi); + EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, mi); Expression preFilter = preAttr.getFilterExpression(); Expression preAuthorize = preAttr.getAuthorizeExpression(); - if (preFilter != null) { Object filterTarget = findFilterTarget(preAttr.getFilterTarget(), ctx, mi); - - expressionHandler.filter(filterTarget, preFilter, ctx); + this.expressionHandler.filter(filterTarget, preFilter, ctx); } - - if (preAuthorize == null) { - return true; - } - - return ExpressionUtils.evaluateAsBoolean(preAuthorize, ctx); + return (preAuthorize != null) ? ExpressionUtils.evaluateAsBoolean(preAuthorize, ctx) : true; } - private Object findFilterTarget(String filterTargetName, EvaluationContext ctx, - MethodInvocation mi) { + private Object findFilterTarget(String filterTargetName, EvaluationContext ctx, MethodInvocation invocation) { Object filterTarget = null; - if (filterTargetName.length() > 0) { filterTarget = ctx.lookupVariable(filterTargetName); - if (filterTarget == null) { - throw new IllegalArgumentException( - "Filter target was null, or no argument with name " - + filterTargetName + " found in method"); - } + Assert.notNull(filterTarget, + () -> "Filter target was null, or no argument with name " + filterTargetName + " found in method"); } - else if (mi.getArguments().length == 1) { - Object arg = mi.getArguments()[0]; + else if (invocation.getArguments().length == 1) { + Object arg = invocation.getArguments()[0]; if (arg.getClass().isArray() || arg instanceof Collection) { filterTarget = arg; } - if (filterTarget == null) { - throw new IllegalArgumentException( - "A PreFilter expression was set but the method argument type" - + arg.getClass() + " is not filterable"); - } - } else if (mi.getArguments().length > 1) { + Assert.notNull(filterTarget, () -> "A PreFilter expression was set but the method argument type" + + arg.getClass() + " is not filterable"); + } + else if (invocation.getArguments().length > 1) { throw new IllegalArgumentException( "Unable to determine the method argument for filtering. Specify the filter target."); } - - if (filterTarget.getClass().isArray()) { - throw new IllegalArgumentException( - "Pre-filtering on array types is not supported. " - + "Using a Collection will solve this problem"); - } - + Assert.isTrue(!filterTarget.getClass().isArray(), + "Pre-filtering on array types is not supported. Using a Collection will solve this problem"); return filterTarget; } public void setExpressionHandler(MethodSecurityExpressionHandler expressionHandler) { this.expressionHandler = expressionHandler; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContext.java b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContext.java index 5ce94731e7..5ba0a376b6 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContext.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContext.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import java.lang.reflect.Method; import org.aopalliance.intercept.MethodInvocation; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; + import org.springframework.aop.framework.AopProxyUtils; import org.springframework.aop.support.AopUtils; import org.springframework.context.expression.MethodBasedEvaluationContext; @@ -37,8 +37,6 @@ import org.springframework.security.core.parameters.DefaultSecurityParameterName * @since 3.0 */ class MethodSecurityEvaluationContext extends MethodBasedEvaluationContext { - private static final Log logger = LogFactory - .getLog(MethodSecurityEvaluationContext.class); /** * Intended for testing. Don't use in practice as it creates a new parameter resolver @@ -57,4 +55,5 @@ class MethodSecurityEvaluationContext extends MethodBasedEvaluationContext { private static Method getSpecificMethod(MethodInvocation mi) { return AopUtils.getMostSpecificMethod(mi.getMethod(), AopProxyUtils.ultimateTargetClass(mi.getThis())); } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionHandler.java b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionHandler.java index efebcca9b4..50e0bfa76a 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionHandler.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionHandler.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.security.access.expression.SecurityExpressionHandler; @@ -27,11 +29,10 @@ import org.springframework.security.access.expression.SecurityExpressionHandler; * @author Luke Taylor * @since 3.0 */ -public interface MethodSecurityExpressionHandler extends - SecurityExpressionHandler { +public interface MethodSecurityExpressionHandler extends SecurityExpressionHandler { + /** * Filters a target collection or array. Only applies to method invocations. - * * @param filterTarget the array or collection to be filtered. * @param filterExpression the expression which should be used as the filter * condition. If it returns false on evaluation, the object will be removed from the @@ -45,7 +46,6 @@ public interface MethodSecurityExpressionHandler extends /** * Used to inform the expression system of the return object for the given evaluation * context. Only applies to method invocations. - * * @param returnObject the return object value * @param ctx the context within which the object should be set (as created through a * call to diff --git a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionOperations.java b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionOperations.java index 24863d6bed..d3553eba33 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionOperations.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionOperations.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.springframework.security.access.expression.SecurityExpressionOperations; @@ -25,6 +26,7 @@ import org.springframework.security.access.expression.SecurityExpressionOperatio * @since 3.1.1 */ public interface MethodSecurityExpressionOperations extends SecurityExpressionOperations { + void setFilterObject(Object filterObject); Object getFilterObject(); @@ -34,4 +36,5 @@ public interface MethodSecurityExpressionOperations extends SecurityExpressionOp Object getReturnObject(); Object getThis(); + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRoot.java b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRoot.java index 0e8c47c7a5..a4d3f0a015 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRoot.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRoot.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.springframework.security.access.expression.SecurityExpressionRoot; @@ -24,44 +25,51 @@ import org.springframework.security.core.Authentication; * @author Luke Taylor * @since 3.0 */ -class MethodSecurityExpressionRoot extends SecurityExpressionRoot implements - MethodSecurityExpressionOperations { +class MethodSecurityExpressionRoot extends SecurityExpressionRoot implements MethodSecurityExpressionOperations { + private Object filterObject; + private Object returnObject; + private Object target; MethodSecurityExpressionRoot(Authentication a) { super(a); } + @Override public void setFilterObject(Object filterObject) { this.filterObject = filterObject; } + @Override public Object getFilterObject() { - return filterObject; + return this.filterObject; } + @Override public void setReturnObject(Object returnObject) { this.returnObject = returnObject; } + @Override public Object getReturnObject() { - return returnObject; + return this.returnObject; } /** * Sets the "this" property for use in expressions. Typically this will be the "this" * property of the {@code JoinPoint} representing the method invocation which is being * protected. - * * @param target the target object on which the method in is being invoked. */ void setThis(Object target) { this.target = target; } + @Override public Object getThis() { - return target; + return this.target; } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/PostInvocationExpressionAttribute.java b/core/src/main/java/org/springframework/security/access/expression/method/PostInvocationExpressionAttribute.java index d308dcb13f..08d9c6f2b8 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/PostInvocationExpressionAttribute.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/PostInvocationExpressionAttribute.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.springframework.expression.Expression; @@ -20,20 +21,18 @@ import org.springframework.expression.ParseException; import org.springframework.security.access.prepost.PostInvocationAttribute; /** - * * @author Luke Taylor * @since 3.0 */ -class PostInvocationExpressionAttribute extends - AbstractExpressionBasedMethodConfigAttribute implements PostInvocationAttribute { +class PostInvocationExpressionAttribute extends AbstractExpressionBasedMethodConfigAttribute + implements PostInvocationAttribute { - PostInvocationExpressionAttribute(String filterExpression, String authorizeExpression) - throws ParseException { + PostInvocationExpressionAttribute(String filterExpression, String authorizeExpression) throws ParseException { super(filterExpression, authorizeExpression); } - PostInvocationExpressionAttribute(Expression filterExpression, - Expression authorizeExpression) throws ParseException { + PostInvocationExpressionAttribute(Expression filterExpression, Expression authorizeExpression) + throws ParseException { super(filterExpression, authorizeExpression); } @@ -42,11 +41,9 @@ class PostInvocationExpressionAttribute extends StringBuilder sb = new StringBuilder(); Expression authorize = getAuthorizeExpression(); Expression filter = getFilterExpression(); - sb.append("[authorize: '").append( - authorize == null ? "null" : authorize.getExpressionString()); - sb.append("', filter: '") - .append(filter == null ? "null" : filter.getExpressionString()) - .append("']"); + sb.append("[authorize: '").append((authorize != null) ? authorize.getExpressionString() : "null"); + sb.append("', filter: '").append((filter != null) ? filter.getExpressionString() : "null").append("']"); return sb.toString(); } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/PreInvocationExpressionAttribute.java b/core/src/main/java/org/springframework/security/access/expression/method/PreInvocationExpressionAttribute.java index be871fa633..f8c3bae329 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/PreInvocationExpressionAttribute.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/PreInvocationExpressionAttribute.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import org.springframework.expression.Expression; @@ -20,37 +21,33 @@ import org.springframework.expression.ParseException; import org.springframework.security.access.prepost.PreInvocationAttribute; /** - * * @author Luke Taylor * @since 3.0 */ -class PreInvocationExpressionAttribute extends - AbstractExpressionBasedMethodConfigAttribute implements PreInvocationAttribute { +class PreInvocationExpressionAttribute extends AbstractExpressionBasedMethodConfigAttribute + implements PreInvocationAttribute { private final String filterTarget; - PreInvocationExpressionAttribute(String filterExpression, String filterTarget, - String authorizeExpression) throws ParseException { + PreInvocationExpressionAttribute(String filterExpression, String filterTarget, String authorizeExpression) + throws ParseException { super(filterExpression, authorizeExpression); - this.filterTarget = filterTarget; } - PreInvocationExpressionAttribute(Expression filterExpression, String filterTarget, - Expression authorizeExpression) throws ParseException { + PreInvocationExpressionAttribute(Expression filterExpression, String filterTarget, Expression authorizeExpression) + throws ParseException { super(filterExpression, authorizeExpression); - this.filterTarget = filterTarget; } /** * The parameter name of the target argument (must be a Collection) to which filtering * will be applied. - * * @return the method parameter name */ String getFilterTarget() { - return filterTarget; + return this.filterTarget; } @Override @@ -58,11 +55,10 @@ class PreInvocationExpressionAttribute extends StringBuilder sb = new StringBuilder(); Expression authorize = getAuthorizeExpression(); Expression filter = getFilterExpression(); - sb.append("[authorize: '").append( - authorize == null ? "null" : authorize.getExpressionString()); - sb.append("', filter: '").append( - filter == null ? "null" : filter.getExpressionString()); - sb.append("', filterTarget: '").append(filterTarget).append("']"); + sb.append("[authorize: '").append((authorize != null) ? authorize.getExpressionString() : "null"); + sb.append("', filter: '").append((filter != null) ? filter.getExpressionString() : "null"); + sb.append("', filterTarget: '").append(this.filterTarget).append("']"); return sb.toString(); } + } diff --git a/core/src/main/java/org/springframework/security/access/expression/method/package-info.java b/core/src/main/java/org/springframework/security/access/expression/method/package-info.java index 5c0e30b7b9..b559639982 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/package-info.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/package-info.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Implementation of expression-based method security. * * @since 3.0 */ package org.springframework.security.access.expression.method; - diff --git a/core/src/main/java/org/springframework/security/access/expression/package-info.java b/core/src/main/java/org/springframework/security/access/expression/package-info.java index c96ea42988..b8873c1ddb 100644 --- a/core/src/main/java/org/springframework/security/access/expression/package-info.java +++ b/core/src/main/java/org/springframework/security/access/expression/package-info.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Expression handling code to support the use of Spring-EL based expressions in {@code @PreAuthorize}, {@code @PreFilter}, - * {@code @PostAuthorize} and {@code @PostFilter} annotations. Mainly for internal framework use and liable to change. + * Expression handling code to support the use of Spring-EL based expressions in + * {@code @PreAuthorize}, {@code @PreFilter}, {@code @PostAuthorize} and + * {@code @PostFilter} annotations. Mainly for internal framework use and liable to + * change. * * @since 3.0 */ package org.springframework.security.access.expression; - diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/CycleInRoleHierarchyException.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/CycleInRoleHierarchyException.java index fe007cd5d7..75ef16fab7 100755 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/CycleInRoleHierarchyException.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/CycleInRoleHierarchyException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; /** diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/NullRoleHierarchy.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/NullRoleHierarchy.java index d29d307f4f..1e988862d0 100644 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/NullRoleHierarchy.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/NullRoleHierarchy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; import java.util.Collection; @@ -20,12 +21,12 @@ import java.util.Collection; import org.springframework.security.core.GrantedAuthority; /** - * * @author Luke Taylor * @since 3.0 */ public final class NullRoleHierarchy implements RoleHierarchy { + @Override public Collection getReachableGrantedAuthorities( Collection authorities) { return authorities; diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchy.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchy.java index 5ab09dda7b..c8c0a975d7 100755 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchy.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; import java.util.Collection; @@ -36,7 +37,6 @@ public interface RoleHierarchy { * Role hierarchy: ROLE_A > ROLE_B > ROLE_C.
    * Directly assigned authority: ROLE_A.
    * Reachable authorities: ROLE_A, ROLE_B, ROLE_C. - * * @param authorities - List of the directly assigned authorities. * @return List of all reachable authorities given the assigned authorities. */ diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapper.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapper.java index 9619d47500..a6d01008e6 100644 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapper.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapper.java @@ -13,25 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; +import java.util.Collection; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; -import java.util.*; - /** * @author Luke Taylor */ public class RoleHierarchyAuthoritiesMapper implements GrantedAuthoritiesMapper { + private final RoleHierarchy roleHierarchy; public RoleHierarchyAuthoritiesMapper(RoleHierarchy roleHierarchy) { this.roleHierarchy = roleHierarchy; } - public Collection mapAuthorities( - Collection authorities) { - return roleHierarchy.getReachableGrantedAuthorities(authorities); + @Override + public Collection mapAuthorities(Collection authorities) { + return this.roleHierarchy.getReachableGrantedAuthorities(authorities); } + } diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImpl.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImpl.java index 0a51abe2de..e0988445fe 100755 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImpl.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImpl.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; @@ -34,7 +36,8 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; * This class defines a role hierarchy for use with various access checking components. * *

    - * Here is an example configuration of a role hierarchy (hint: read the ">" sign as "includes"): + * Here is an example configuration of a role hierarchy (hint: read the ">" sign as + * "includes"): * *

      *     <property name="hierarchy">
    @@ -49,25 +52,26 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
      * 

    * Explanation of the above: *

      - *
    • In effect every user with ROLE_A also has ROLE_B, ROLE_AUTHENTICATED and ROLE_UNAUTHENTICATED;
    • + *
    • In effect every user with ROLE_A also has ROLE_B, ROLE_AUTHENTICATED and + * ROLE_UNAUTHENTICATED;
    • *
    • every user with ROLE_B also has ROLE_AUTHENTICATED and ROLE_UNAUTHENTICATED;
    • *
    • every user with ROLE_AUTHENTICATED also has ROLE_UNAUTHENTICATED.
    • *
    * *

    - * Hierarchical Roles will dramatically shorten your access rules (and also make the access rules - * much more elegant). + * Hierarchical Roles will dramatically shorten your access rules (and also make the + * access rules much more elegant). * *

    - * Consider this access rule for Spring Security's RoleVoter (background: every user that is - * authenticated should be able to log out): + * Consider this access rule for Spring Security's RoleVoter (background: every user that + * is authenticated should be able to log out): *

    /logout.html=ROLE_A,ROLE_B,ROLE_AUTHENTICATED
    * * With hierarchical roles this can now be shortened to: *
    /logout.html=ROLE_AUTHENTICATED
    * - * In addition to shorter rules this will also make your access rules more readable and your - * intentions clearer. + * In addition to shorter rules this will also make your access rules more readable and + * your intentions clearer. * * @author Michael Mayr */ @@ -76,20 +80,21 @@ public class RoleHierarchyImpl implements RoleHierarchy { private static final Log logger = LogFactory.getLog(RoleHierarchyImpl.class); /** - * Raw hierarchy configuration where each line represents single or multiple level role chain. + * Raw hierarchy configuration where each line represents single or multiple level + * role chain. */ private String roleHierarchyStringRepresentation = null; /** - * {@code rolesReachableInOneStepMap} is a Map that under the key of a specific role name - * contains a set of all roles reachable from this role in 1 step - * (i.e. parsed {@link #roleHierarchyStringRepresentation} grouped by the higher role) + * {@code rolesReachableInOneStepMap} is a Map that under the key of a specific role + * name contains a set of all roles reachable from this role in 1 step (i.e. parsed + * {@link #roleHierarchyStringRepresentation} grouped by the higher role) */ private Map> rolesReachableInOneStepMap = null; /** - * {@code rolesReachableInOneOrMoreStepsMap} is a Map that under the key of a specific role - * name contains a set of all roles reachable from this role in 1 or more steps + * {@code rolesReachableInOneOrMoreStepsMap} is a Map that under the key of a specific + * role name contains a set of all roles reachable from this role in 1 or more steps * (i.e. fully resolved hierarchy from {@link #rolesReachableInOneStepMap}) */ private Map> rolesReachableInOneOrMoreStepsMap = null; @@ -100,17 +105,12 @@ public class RoleHierarchyImpl implements RoleHierarchy { * is done for performance reasons (reachable roles can then be calculated in O(1) * time). During pre-calculation, cycles in role hierarchy are detected and will cause * a CycleInRoleHierarchyException to be thrown. - * * @param roleHierarchyStringRepresentation - String definition of the role hierarchy. */ public void setHierarchy(String roleHierarchyStringRepresentation) { this.roleHierarchyStringRepresentation = roleHierarchyStringRepresentation; - - if (logger.isDebugEnabled()) { - logger.debug("setHierarchy() - The following role hierarchy was set: " - + roleHierarchyStringRepresentation); - } - + logger.debug(LogMessage.format("setHierarchy() - The following role hierarchy was set: %s", + roleHierarchyStringRepresentation)); buildRolesReachableInOneStepMap(); buildRolesReachableInOneOrMoreStepsMap(); } @@ -121,10 +121,8 @@ public class RoleHierarchyImpl implements RoleHierarchy { if (authorities == null || authorities.isEmpty()) { return AuthorityUtils.NO_AUTHORITIES; } - Set reachableRoles = new HashSet<>(); Set processedNames = new HashSet<>(); - for (GrantedAuthority authority : authorities) { // Do not process authorities without string representation if (authority.getAuthority() == null) { @@ -148,17 +146,10 @@ public class RoleHierarchyImpl implements RoleHierarchy { } } } - - if (logger.isDebugEnabled()) { - logger.debug("getReachableGrantedAuthorities() - From the roles " - + authorities + " one can reach " + reachableRoles - + " in zero or more steps."); - } - - List reachableRoleList = new ArrayList<>(reachableRoles.size()); - reachableRoleList.addAll(reachableRoles); - - return reachableRoleList; + logger.debug(LogMessage.format( + "getReachableGrantedAuthorities() - From the roles %s one can reach %s in zero or more steps.", + authorities, reachableRoles)); + return new ArrayList<>(reachableRoles); } /** @@ -170,24 +161,21 @@ public class RoleHierarchyImpl implements RoleHierarchy { for (String line : this.roleHierarchyStringRepresentation.split("\n")) { // Split on > and trim excessive whitespace String[] roles = line.trim().split("\\s+>\\s+"); - for (int i = 1; i < roles.length; i++) { String higherRole = roles[i - 1]; GrantedAuthority lowerRole = new SimpleGrantedAuthority(roles[i]); - Set rolesReachableInOneStepSet; if (!this.rolesReachableInOneStepMap.containsKey(higherRole)) { rolesReachableInOneStepSet = new HashSet<>(); this.rolesReachableInOneStepMap.put(higherRole, rolesReachableInOneStepSet); - } else { + } + else { rolesReachableInOneStepSet = this.rolesReachableInOneStepMap.get(higherRole); } rolesReachableInOneStepSet.add(lowerRole); - - if (logger.isDebugEnabled()) { - logger.debug("buildRolesReachableInOneStepMap() - From role " + higherRole - + " one can reach role " + lowerRole + " in one step."); - } + logger.debug(LogMessage.format( + "buildRolesReachableInOneStepMap() - From role %s one can reach role %s in one step.", + higherRole, lowerRole)); } } } @@ -204,23 +192,23 @@ public class RoleHierarchyImpl implements RoleHierarchy { for (String roleName : this.rolesReachableInOneStepMap.keySet()) { Set rolesToVisitSet = new HashSet<>(this.rolesReachableInOneStepMap.get(roleName)); Set visitedRolesSet = new HashSet<>(); - while (!rolesToVisitSet.isEmpty()) { // take a role from the rolesToVisit set GrantedAuthority lowerRole = rolesToVisitSet.iterator().next(); rolesToVisitSet.remove(lowerRole); - if (!visitedRolesSet.add(lowerRole) || - !this.rolesReachableInOneStepMap.containsKey(lowerRole.getAuthority())) { + if (!visitedRolesSet.add(lowerRole) + || !this.rolesReachableInOneStepMap.containsKey(lowerRole.getAuthority())) { continue; // Already visited role or role with missing hierarchy - } else if (roleName.equals(lowerRole.getAuthority())) { + } + else if (roleName.equals(lowerRole.getAuthority())) { throw new CycleInRoleHierarchyException(); } rolesToVisitSet.addAll(this.rolesReachableInOneStepMap.get(lowerRole.getAuthority())); } this.rolesReachableInOneOrMoreStepsMap.put(roleName, visitedRolesSet); - - logger.debug("buildRolesReachableInOneOrMoreStepsMap() - From role " + roleName - + " one can reach " + visitedRolesSet + " in one or more steps."); + logger.debug(LogMessage.format( + "buildRolesReachableInOneOrMoreStepsMap() - From role %s one can reach %s in one or more steps.", + roleName, visitedRolesSet)); } } diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtils.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtils.java index 6b7e4afcbe..2143632d14 100644 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtils.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtils.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.hierarchicalroles; -import org.springframework.util.Assert; +package org.springframework.security.access.hierarchicalroles; import java.io.PrintWriter; import java.io.StringWriter; import java.util.List; import java.util.Map; +import org.springframework.util.Assert; + /** * Utility methods for {@link RoleHierarchy}. * @@ -35,35 +36,26 @@ public final class RoleHierarchyUtils { /** * Converts the supplied {@link Map} of role name to implied role name(s) to a string - * representation understood by {@link RoleHierarchyImpl#setHierarchy(String)}. - * The map key is the role name and the map value is a {@link List} of implied role name(s). - * + * representation understood by {@link RoleHierarchyImpl#setHierarchy(String)}. The + * map key is the role name and the map value is a {@link List} of implied role + * name(s). * @param roleHierarchyMap the mapping(s) of role name to implied role name(s) * @return a string representation of a role hierarchy - * @throws IllegalArgumentException if roleHierarchyMap is null or empty or if a role name is null or - * empty or if an implied role name(s) is null or empty - * + * @throws IllegalArgumentException if roleHierarchyMap is null or empty or if a role + * name is null or empty or if an implied role name(s) is null or empty */ public static String roleHierarchyFromMap(Map> roleHierarchyMap) { Assert.notEmpty(roleHierarchyMap, "roleHierarchyMap cannot be empty"); - - StringWriter roleHierarchyBuffer = new StringWriter(); - PrintWriter roleHierarchyWriter = new PrintWriter(roleHierarchyBuffer); - - for (Map.Entry> roleHierarchyEntry : roleHierarchyMap.entrySet()) { - String role = roleHierarchyEntry.getKey(); - List impliedRoles = roleHierarchyEntry.getValue(); - + StringWriter result = new StringWriter(); + PrintWriter writer = new PrintWriter(result); + roleHierarchyMap.forEach((role, impliedRoles) -> { Assert.hasLength(role, "role name must be supplied"); Assert.notEmpty(impliedRoles, "implied role name(s) cannot be empty"); - for (String impliedRole : impliedRoles) { - String roleMapping = role + " > " + impliedRole; - roleHierarchyWriter.println(roleMapping); + writer.println(role + " > " + impliedRole); } - } - - return roleHierarchyBuffer.toString(); + }); + return result.toString(); } } diff --git a/core/src/main/java/org/springframework/security/access/hierarchicalroles/package-info.java b/core/src/main/java/org/springframework/security/access/hierarchicalroles/package-info.java index bbcb668705..e161c07e18 100644 --- a/core/src/main/java/org/springframework/security/access/hierarchicalroles/package-info.java +++ b/core/src/main/java/org/springframework/security/access/hierarchicalroles/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Role hierarchy implementation. */ package org.springframework.security.access.hierarchicalroles; - diff --git a/core/src/main/java/org/springframework/security/access/intercept/AbstractSecurityInterceptor.java b/core/src/main/java/org/springframework/security/access/intercept/AbstractSecurityInterceptor.java index 2f3d9e7b76..d92a5ed500 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/AbstractSecurityInterceptor.java +++ b/core/src/main/java/org/springframework/security/access/intercept/AbstractSecurityInterceptor.java @@ -22,6 +22,7 @@ import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; @@ -29,6 +30,7 @@ import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; @@ -46,6 +48,7 @@ import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Abstract class that implements security interception for secure objects. @@ -54,7 +57,8 @@ import org.springframework.util.Assert; * configuration of the security interceptor. It will also implement the proper handling * of secure object invocations, namely: *
      - *
    1. Obtain the {@link Authentication} object from the {@link SecurityContextHolder}.
    2. + *
    3. Obtain the {@link Authentication} object from the + * {@link SecurityContextHolder}.
    4. *
    5. Determine if the request relates to a secured or public invocation by looking up * the secure object request against the {@link SecurityMetadataSource}.
    6. *
    7. For an invocation that is secured (there is a list of ConfigAttributes @@ -100,178 +104,130 @@ import org.springframework.util.Assert; * @author Ben Alex * @author Rob Winch */ -public abstract class AbstractSecurityInterceptor implements InitializingBean, - ApplicationEventPublisherAware, MessageSourceAware { - // ~ Static fields/initializers - // ===================================================================================== +public abstract class AbstractSecurityInterceptor + implements InitializingBean, ApplicationEventPublisherAware, MessageSourceAware { protected final Log logger = LogFactory.getLog(getClass()); - // ~ Instance fields - // ================================================================================================ - protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private ApplicationEventPublisher eventPublisher; + private AccessDecisionManager accessDecisionManager; + private AfterInvocationManager afterInvocationManager; + private AuthenticationManager authenticationManager = new NoOpAuthenticationManager(); + private RunAsManager runAsManager = new NullRunAsManager(); private boolean alwaysReauthenticate = false; + private boolean rejectPublicInvocations = false; + private boolean validateConfigAttributes = true; + private boolean publishAuthorizationSuccess = false; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(getSecureObjectClass(), - "Subclass must provide a non-null response to getSecureObjectClass()"); + Assert.notNull(getSecureObjectClass(), "Subclass must provide a non-null response to getSecureObjectClass()"); Assert.notNull(this.messages, "A message source must be set"); Assert.notNull(this.authenticationManager, "An AuthenticationManager is required"); Assert.notNull(this.accessDecisionManager, "An AccessDecisionManager is required"); Assert.notNull(this.runAsManager, "A RunAsManager is required"); - Assert.notNull(this.obtainSecurityMetadataSource(), - "An SecurityMetadataSource is required"); - Assert.isTrue(this.obtainSecurityMetadataSource() - .supports(getSecureObjectClass()), - () -> "SecurityMetadataSource does not support secure object class: " - + getSecureObjectClass()); + Assert.notNull(this.obtainSecurityMetadataSource(), "An SecurityMetadataSource is required"); + Assert.isTrue(this.obtainSecurityMetadataSource().supports(getSecureObjectClass()), + () -> "SecurityMetadataSource does not support secure object class: " + getSecureObjectClass()); Assert.isTrue(this.runAsManager.supports(getSecureObjectClass()), - () -> "RunAsManager does not support secure object class: " - + getSecureObjectClass()); + () -> "RunAsManager does not support secure object class: " + getSecureObjectClass()); Assert.isTrue(this.accessDecisionManager.supports(getSecureObjectClass()), - () -> "AccessDecisionManager does not support secure object class: " - + getSecureObjectClass()); - + () -> "AccessDecisionManager does not support secure object class: " + getSecureObjectClass()); if (this.afterInvocationManager != null) { Assert.isTrue(this.afterInvocationManager.supports(getSecureObjectClass()), - () -> "AfterInvocationManager does not support secure object class: " - + getSecureObjectClass()); + () -> "AfterInvocationManager does not support secure object class: " + getSecureObjectClass()); } - if (this.validateConfigAttributes) { - Collection attributeDefs = this - .obtainSecurityMetadataSource().getAllConfigAttributes(); - + Collection attributeDefs = this.obtainSecurityMetadataSource().getAllConfigAttributes(); if (attributeDefs == null) { - logger.warn("Could not validate configuration attributes as the SecurityMetadataSource did not return " - + "any attributes from getAllConfigAttributes()"); + this.logger.warn("Could not validate configuration attributes as the " + + "SecurityMetadataSource did not return any attributes from getAllConfigAttributes()"); return; } - - Set unsupportedAttrs = new HashSet<>(); - - for (ConfigAttribute attr : attributeDefs) { - if (!this.runAsManager.supports(attr) - && !this.accessDecisionManager.supports(attr) - && ((this.afterInvocationManager == null) || !this.afterInvocationManager - .supports(attr))) { - unsupportedAttrs.add(attr); - } - } - - if (unsupportedAttrs.size() != 0) { - throw new IllegalArgumentException( - "Unsupported configuration attributes: " + unsupportedAttrs); - } - - logger.debug("Validated configuration attributes"); + validateAttributeDefs(attributeDefs); } } + private void validateAttributeDefs(Collection attributeDefs) { + Set unsupportedAttrs = new HashSet<>(); + for (ConfigAttribute attr : attributeDefs) { + if (!this.runAsManager.supports(attr) && !this.accessDecisionManager.supports(attr) + && ((this.afterInvocationManager == null) || !this.afterInvocationManager.supports(attr))) { + unsupportedAttrs.add(attr); + } + } + if (unsupportedAttrs.size() != 0) { + throw new IllegalArgumentException("Unsupported configuration attributes: " + unsupportedAttrs); + } + this.logger.debug("Validated configuration attributes"); + } + protected InterceptorStatusToken beforeInvocation(Object object) { Assert.notNull(object, "Object was null"); - final boolean debug = logger.isDebugEnabled(); - if (!getSecureObjectClass().isAssignableFrom(object.getClass())) { - throw new IllegalArgumentException( - "Security invocation attempted for object " - + object.getClass().getName() - + " but AbstractSecurityInterceptor only configured to support secure objects of type: " - + getSecureObjectClass()); + throw new IllegalArgumentException("Security invocation attempted for object " + object.getClass().getName() + + " but AbstractSecurityInterceptor only configured to support secure objects of type: " + + getSecureObjectClass()); } - - Collection attributes = this.obtainSecurityMetadataSource() - .getAttributes(object); - - if (attributes == null || attributes.isEmpty()) { - if (rejectPublicInvocations) { - throw new IllegalArgumentException( - "Secure object invocation " - + object - + " was denied as public invocations are not allowed via this interceptor. " - + "This indicates a configuration error because the " - + "rejectPublicInvocations property is set to 'true'"); - } - - if (debug) { - logger.debug("Public object - authentication not attempted"); - } - + Collection attributes = this.obtainSecurityMetadataSource().getAttributes(object); + if (CollectionUtils.isEmpty(attributes)) { + Assert.isTrue(!this.rejectPublicInvocations, + () -> "Secure object invocation " + object + + " was denied as public invocations are not allowed via this interceptor. " + + "This indicates a configuration error because the " + + "rejectPublicInvocations property is set to 'true'"); + this.logger.debug("Public object - authentication not attempted"); publishEvent(new PublicInvocationEvent(object)); - return null; // no further work post-invocation } - - if (debug) { - logger.debug("Secure object: " + object + "; Attributes: " + attributes); - } - + this.logger.debug(LogMessage.format("Secure object: %s; Attributes: %s", object, attributes)); if (SecurityContextHolder.getContext().getAuthentication() == null) { - credentialsNotFound(messages.getMessage( - "AbstractSecurityInterceptor.authenticationNotFound", - "An Authentication object was not found in the SecurityContext"), - object, attributes); + credentialsNotFound(this.messages.getMessage("AbstractSecurityInterceptor.authenticationNotFound", + "An Authentication object was not found in the SecurityContext"), object, attributes); } - Authentication authenticated = authenticateIfRequired(); - // Attempt authorization - try { - this.accessDecisionManager.decide(authenticated, object, attributes); - } - catch (AccessDeniedException accessDeniedException) { - publishEvent(new AuthorizationFailureEvent(object, attributes, authenticated, - accessDeniedException)); - - throw accessDeniedException; - } - - if (debug) { - logger.debug("Authorization successful"); - } - - if (publishAuthorizationSuccess) { + attemptAuthorization(object, attributes, authenticated); + this.logger.debug("Authorization successful"); + if (this.publishAuthorizationSuccess) { publishEvent(new AuthorizedEvent(object, attributes, authenticated)); } // Attempt to run as a different user - Authentication runAs = this.runAsManager.buildRunAs(authenticated, object, - attributes); - - if (runAs == null) { - if (debug) { - logger.debug("RunAsManager did not change Authentication object"); - } - - // no further work post-invocation - return new InterceptorStatusToken(SecurityContextHolder.getContext(), false, - attributes, object); - } - else { - if (debug) { - logger.debug("Switching to RunAs Authentication: " + runAs); - } - + Authentication runAs = this.runAsManager.buildRunAs(authenticated, object, attributes); + if (runAs != null) { + this.logger.debug(LogMessage.format("Switching to RunAs Authentication: %s", runAs)); SecurityContext origCtx = SecurityContextHolder.getContext(); SecurityContextHolder.setContext(SecurityContextHolder.createEmptyContext()); SecurityContextHolder.getContext().setAuthentication(runAs); - // need to revert to token.Authenticated post-invocation return new InterceptorStatusToken(origCtx, true, attributes, object); } + this.logger.debug("RunAsManager did not change Authentication object"); + // no further work post-invocation + return new InterceptorStatusToken(SecurityContextHolder.getContext(), false, attributes, object); + + } + + private void attemptAuthorization(Object object, Collection attributes, + Authentication authenticated) { + try { + this.accessDecisionManager.decide(authenticated, object, attributes); + } + catch (AccessDeniedException ex) { + publishEvent(new AuthorizationFailureEvent(object, attributes, authenticated, ex)); + throw ex; + } } /** @@ -279,16 +235,12 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, * object invocation has been completed. This method should be invoked after the * secure object invocation and before afterInvocation regardless of the secure object * invocation returning successfully (i.e. it should be done in a finally block). - * * @param token as returned by the {@link #beforeInvocation(Object)} method */ protected void finallyInvocation(InterceptorStatusToken token) { if (token != null && token.isContextHolderRefreshRequired()) { - if (logger.isDebugEnabled()) { - logger.debug("Reverting to original Authentication: " - + token.getSecurityContext().getAuthentication()); - } - + this.logger.debug(LogMessage.of( + () -> "Reverting to original Authentication: " + token.getSecurityContext().getAuthentication())); SecurityContextHolder.setContext(token.getSecurityContext()); } } @@ -296,7 +248,6 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, /** * Completes the work of the AbstractSecurityInterceptor after the secure * object invocation has been completed. - * * @param token as returned by the {@link #beforeInvocation(Object)} method * @param returnedObject any object returned from the secure object invocation (may be * null) @@ -308,27 +259,19 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, // public object return returnedObject; } - finallyInvocation(token); // continue to clean in this method for passivity - - if (afterInvocationManager != null) { + if (this.afterInvocationManager != null) { // Attempt after invocation handling try { - returnedObject = afterInvocationManager.decide(token.getSecurityContext() - .getAuthentication(), token.getSecureObject(), token - .getAttributes(), returnedObject); + returnedObject = this.afterInvocationManager.decide(token.getSecurityContext().getAuthentication(), + token.getSecureObject(), token.getAttributes(), returnedObject); } - catch (AccessDeniedException accessDeniedException) { - AuthorizationFailureEvent event = new AuthorizationFailureEvent( - token.getSecureObject(), token.getAttributes(), token - .getSecurityContext().getAuthentication(), - accessDeniedException); - publishEvent(event); - - throw accessDeniedException; + catch (AccessDeniedException ex) { + publishEvent(new AuthorizationFailureEvent(token.getSecureObject(), token.getAttributes(), + token.getSecurityContext().getAuthentication(), ex)); + throw ex; } } - return returnedObject; } @@ -336,31 +279,18 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, * Checks the current authentication token and passes it to the AuthenticationManager * if {@link org.springframework.security.core.Authentication#isAuthenticated()} * returns false or the property alwaysReauthenticate has been set to true. - * * @return an authenticated Authentication object. */ private Authentication authenticateIfRequired() { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); - - if (authentication.isAuthenticated() && !alwaysReauthenticate) { - if (logger.isDebugEnabled()) { - logger.debug("Previously Authenticated: " + authentication); - } - + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication.isAuthenticated() && !this.alwaysReauthenticate) { + this.logger.debug(LogMessage.format("Previously Authenticated: %s", authentication)); return authentication; } - - authentication = authenticationManager.authenticate(authentication); - - // We don't authenticated.setAuthentication(true), because each provider should do - // that - if (logger.isDebugEnabled()) { - logger.debug("Successfully Authenticated: " + authentication); - } - + authentication = this.authenticationManager.authenticate(authentication); + // Don't authenticated.setAuthentication(true) because each provider does that + this.logger.debug(LogMessage.format("Successfully Authenticated: %s", authentication)); SecurityContextHolder.getContext().setAuthentication(authentication); - return authentication; } @@ -369,29 +299,24 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, * publishes an event to the application context. *

      * Always throws an exception. - * * @param reason to be provided in the exception detail * @param secureObject that was being called * @param configAttribs that were defined for the secureObject */ - private void credentialsNotFound(String reason, Object secureObject, - Collection configAttribs) { - AuthenticationCredentialsNotFoundException exception = new AuthenticationCredentialsNotFoundException( - reason); - - AuthenticationCredentialsNotFoundEvent event = new AuthenticationCredentialsNotFoundEvent( - secureObject, configAttribs, exception); + private void credentialsNotFound(String reason, Object secureObject, Collection configAttribs) { + AuthenticationCredentialsNotFoundException exception = new AuthenticationCredentialsNotFoundException(reason); + AuthenticationCredentialsNotFoundEvent event = new AuthenticationCredentialsNotFoundEvent(secureObject, + configAttribs, exception); publishEvent(event); - throw exception; } public AccessDecisionManager getAccessDecisionManager() { - return accessDecisionManager; + return this.accessDecisionManager; } public AfterInvocationManager getAfterInvocationManager() { - return afterInvocationManager; + return this.afterInvocationManager; } public AuthenticationManager getAuthenticationManager() { @@ -399,28 +324,27 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, } public RunAsManager getRunAsManager() { - return runAsManager; + return this.runAsManager; } /** * Indicates the type of secure objects the subclass will be presenting to the * abstract parent for processing. This is used to ensure collaborators wired to the * {@code AbstractSecurityInterceptor} all support the indicated secure object class. - * * @return the type of secure object the subclass provides services for */ public abstract Class getSecureObjectClass(); public boolean isAlwaysReauthenticate() { - return alwaysReauthenticate; + return this.alwaysReauthenticate; } public boolean isRejectPublicInvocations() { - return rejectPublicInvocations; + return this.rejectPublicInvocations; } public boolean isValidateConfigAttributes() { - return validateConfigAttributes; + return this.validateConfigAttributes; } public abstract SecurityMetadataSource obtainSecurityMetadataSource(); @@ -439,7 +363,6 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, * meaning by default the Authentication.isAuthenticated() property is * trusted and re-authentication will not occur if the principal has already been * authenticated. - * * @param alwaysReauthenticate true to force * AbstractSecurityInterceptor to disregard the value of * Authentication.isAuthenticated() and always re-authenticate the @@ -449,8 +372,8 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, this.alwaysReauthenticate = alwaysReauthenticate; } - public void setApplicationEventPublisher( - ApplicationEventPublisher applicationEventPublisher) { + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { this.eventPublisher = applicationEventPublisher; } @@ -458,6 +381,7 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, this.authenticationManager = newManager; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } @@ -465,7 +389,6 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, /** * Only {@code AuthorizationFailureEvent} will be published. If you set this property * to {@code true}, {@code AuthorizedEvent}s will also be published. - * * @param publishAuthorizationSuccess default value is {@code false} */ public void setPublishAuthorizationSuccess(boolean publishAuthorizationSuccess) { @@ -481,7 +404,6 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, * IllegalArgumentException will be thrown by the * AbstractSecurityInterceptor if you set this property to true and * an attempt is made to invoke a secure object that has no configuration attributes. - * * @param rejectPublicInvocations set to true to reject invocations of * secure objects that have no configuration attributes (by default it is * false which treats undeclared secure objects as "public" or @@ -507,10 +429,11 @@ public abstract class AbstractSecurityInterceptor implements InitializingBean, private static class NoOpAuthenticationManager implements AuthenticationManager { - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { - throw new AuthenticationServiceException("Cannot authenticate " - + authentication); + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + throw new AuthenticationServiceException("Cannot authenticate " + authentication); } + } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationManager.java b/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationManager.java index eae4f139fe..679b495d54 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationManager.java +++ b/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationManager.java @@ -43,31 +43,25 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface AfterInvocationManager { - // ~ Methods - // ======================================================================================================== /** * Given the details of a secure object invocation including its returned * Object, make an access control decision or optionally modify the * returned Object. - * * @param authentication the caller that invoked the method * @param object the secured object that was called * @param attributes the configuration attributes associated with the secured object * that was invoked * @param returnedObject the Object that was returned from the secure * object invocation - * * @return the Object that will ultimately be returned to the caller (if * an implementation does not wish to modify the object to be returned to the caller, * the implementation should simply return the same object it was passed by the * returnedObject method argument) - * * @throws AccessDeniedException if access is denied */ - Object decide(Authentication authentication, Object object, - Collection attributes, Object returnedObject) - throws AccessDeniedException; + Object decide(Authentication authentication, Object object, Collection attributes, + Object returnedObject) throws AccessDeniedException; /** * Indicates whether this AfterInvocationManager is able to process @@ -78,10 +72,8 @@ public interface AfterInvocationManager { * AccessDecisionManager and/or RunAsManager and/or * AfterInvocationManager. *

      - * * @param attribute a configuration attribute that has been configured against the * AbstractSecurityInterceptor - * * @return true if this AfterInvocationManager can support the passed * configuration attribute */ @@ -90,10 +82,9 @@ public interface AfterInvocationManager { /** * Indicates whether the AfterInvocationManager implementation is able to * provide access control decisions for the indicated secured object type. - * * @param clazz the class that is being queried - * * @return true if the implementation can process the indicated class */ boolean supports(Class clazz); + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationProviderManager.java b/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationProviderManager.java index bc8bea7fb3..44d42d850a 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationProviderManager.java +++ b/core/src/main/java/org/springframework/security/access/intercept/AfterInvocationProviderManager.java @@ -22,12 +22,15 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AfterInvocationProvider; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Provider-based implementation of {@link AfterInvocationManager}. @@ -45,43 +48,24 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public class AfterInvocationProviderManager implements AfterInvocationManager, - InitializingBean { - // ~ Static fields/initializers - // ===================================================================================== +public class AfterInvocationProviderManager implements AfterInvocationManager, InitializingBean { - protected static final Log logger = LogFactory - .getLog(AfterInvocationProviderManager.class); - - // ~ Instance fields - // ================================================================================================ + protected static final Log logger = LogFactory.getLog(AfterInvocationProviderManager.class); private List providers; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { checkIfValidList(this.providers); } - private void checkIfValidList(List listToCheck) { - if ((listToCheck == null) || (listToCheck.size() == 0)) { - throw new IllegalArgumentException( - "A list of AfterInvocationProviders is required"); - } - } - - public Object decide(Authentication authentication, Object object, - Collection config, Object returnedObject) - throws AccessDeniedException { - + @Override + public Object decide(Authentication authentication, Object object, Collection config, + Object returnedObject) throws AccessDeniedException { Object result = returnedObject; - - for (AfterInvocationProvider provider : providers) { + for (AfterInvocationProvider provider : this.providers) { result = provider.decide(authentication, object, config, result); } - return result; } @@ -91,27 +75,26 @@ public class AfterInvocationProviderManager implements AfterInvocationManager, public void setProviders(List newList) { checkIfValidList(newList); - providers = new ArrayList<>(newList.size()); - + this.providers = new ArrayList<>(newList.size()); for (Object currentObject : newList) { - Assert.isInstanceOf(AfterInvocationProvider.class, currentObject, - () -> "AfterInvocationProvider " + currentObject.getClass().getName() - + " must implement AfterInvocationProvider"); - providers.add((AfterInvocationProvider) currentObject); + Assert.isInstanceOf(AfterInvocationProvider.class, currentObject, () -> "AfterInvocationProvider " + + currentObject.getClass().getName() + " must implement AfterInvocationProvider"); + this.providers.add((AfterInvocationProvider) currentObject); } } - public boolean supports(ConfigAttribute attribute) { - for (AfterInvocationProvider provider : providers) { - if (logger.isDebugEnabled()) { - logger.debug("Evaluating " + attribute + " against " + provider); - } + private void checkIfValidList(List listToCheck) { + Assert.isTrue(!CollectionUtils.isEmpty(listToCheck), "A list of AfterInvocationProviders is required"); + } + @Override + public boolean supports(ConfigAttribute attribute) { + for (AfterInvocationProvider provider : this.providers) { + logger.debug(LogMessage.format("Evaluating %s against %s", attribute, provider)); if (provider.supports(attribute)) { return true; } } - return false; } @@ -121,20 +104,19 @@ public class AfterInvocationProviderManager implements AfterInvocationManager, *

      * If one or more providers cannot support the presented class, false is * returned. - * * @param clazz the secure object class being queries - * * @return if the AfterInvocationProviderManager can support the secure * object class, which requires every one of its AfterInvocationProviders * to support the secure object class */ + @Override public boolean supports(Class clazz) { - for (AfterInvocationProvider provider : providers) { + for (AfterInvocationProvider provider : this.providers) { if (!provider.supports(clazz)) { return false; } } - return true; } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/InterceptorStatusToken.java b/core/src/main/java/org/springframework/security/access/intercept/InterceptorStatusToken.java index fbe6f7acc6..de97aebf60 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/InterceptorStatusToken.java +++ b/core/src/main/java/org/springframework/security/access/intercept/InterceptorStatusToken.java @@ -31,42 +31,37 @@ import org.springframework.security.core.context.SecurityContext; * @author Ben Alex */ public class InterceptorStatusToken { - // ~ Instance fields - // ================================================================================================ private SecurityContext securityContext; + private Collection attr; + private Object secureObject; + private boolean contextHolderRefreshRequired; - // ~ Constructors - // =================================================================================================== - - public InterceptorStatusToken(SecurityContext securityContext, - boolean contextHolderRefreshRequired, Collection attributes, - Object secureObject) { + public InterceptorStatusToken(SecurityContext securityContext, boolean contextHolderRefreshRequired, + Collection attributes, Object secureObject) { this.securityContext = securityContext; this.contextHolderRefreshRequired = contextHolderRefreshRequired; this.attr = attributes; this.secureObject = secureObject; } - // ~ Methods - // ======================================================================================================== - public Collection getAttributes() { - return attr; + return this.attr; } public SecurityContext getSecurityContext() { - return securityContext; + return this.securityContext; } public Object getSecureObject() { - return secureObject; + return this.secureObject; } public boolean isContextHolderRefreshRequired() { - return contextHolderRefreshRequired; + return this.contextHolderRefreshRequired; } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/MethodInvocationPrivilegeEvaluator.java b/core/src/main/java/org/springframework/security/access/intercept/MethodInvocationPrivilegeEvaluator.java index e8551ca0ae..95b98e2622 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/MethodInvocationPrivilegeEvaluator.java +++ b/core/src/main/java/org/springframework/security/access/intercept/MethodInvocationPrivilegeEvaluator.java @@ -21,7 +21,9 @@ import java.util.Collection; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; @@ -43,67 +45,44 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class MethodInvocationPrivilegeEvaluator implements InitializingBean { - // ~ Static fields/initializers - // ===================================================================================== - protected static final Log logger = LogFactory - .getLog(MethodInvocationPrivilegeEvaluator.class); - - // ~ Instance fields - // ================================================================================================ + protected static final Log logger = LogFactory.getLog(MethodInvocationPrivilegeEvaluator.class); private AbstractSecurityInterceptor securityInterceptor; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(securityInterceptor, "SecurityInterceptor required"); + Assert.notNull(this.securityInterceptor, "SecurityInterceptor required"); } - public boolean isAllowed(MethodInvocation mi, Authentication authentication) { - Assert.notNull(mi, "MethodInvocation required"); - Assert.notNull(mi.getMethod(), - "MethodInvocation must provide a non-null getMethod()"); - - Collection attrs = securityInterceptor - .obtainSecurityMetadataSource().getAttributes(mi); - + public boolean isAllowed(MethodInvocation invocation, Authentication authentication) { + Assert.notNull(invocation, "MethodInvocation required"); + Assert.notNull(invocation.getMethod(), "MethodInvocation must provide a non-null getMethod()"); + Collection attrs = this.securityInterceptor.obtainSecurityMetadataSource() + .getAttributes(invocation); if (attrs == null) { - if (securityInterceptor.isRejectPublicInvocations()) { - return false; - } - - return true; + return !this.securityInterceptor.isRejectPublicInvocations(); } - if (authentication == null || authentication.getAuthorities().isEmpty()) { return false; } - try { - securityInterceptor.getAccessDecisionManager().decide(authentication, mi, - attrs); + this.securityInterceptor.getAccessDecisionManager().decide(authentication, invocation, attrs); + return true; } catch (AccessDeniedException unauthorized) { - if (logger.isDebugEnabled()) { - logger.debug(mi.toString() + " denied for " + authentication.toString(), - unauthorized); - } - + logger.debug(LogMessage.format("%s denied for %s", invocation, authentication), unauthorized); return false; } - - return true; } public void setSecurityInterceptor(AbstractSecurityInterceptor securityInterceptor) { Assert.notNull(securityInterceptor, "AbstractSecurityInterceptor cannot be null"); - Assert.isTrue( - MethodInvocation.class.equals(securityInterceptor.getSecureObjectClass()), + Assert.isTrue(MethodInvocation.class.equals(securityInterceptor.getSecureObjectClass()), "AbstractSecurityInterceptor does not support MethodInvocations"); Assert.notNull(securityInterceptor.getAccessDecisionManager(), "AbstractSecurityInterceptor must provide a non-null AccessDecisionManager"); this.securityInterceptor = securityInterceptor; } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/NullRunAsManager.java b/core/src/main/java/org/springframework/security/access/intercept/NullRunAsManager.java index 2aca1ad642..1cd0bf85d1 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/NullRunAsManager.java +++ b/core/src/main/java/org/springframework/security/access/intercept/NullRunAsManager.java @@ -30,19 +30,20 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ final class NullRunAsManager implements RunAsManager { - // ~ Methods - // ======================================================================================================== - public Authentication buildRunAs(Authentication authentication, Object object, - Collection config) { + @Override + public Authentication buildRunAs(Authentication authentication, Object object, Collection config) { return null; } + @Override public boolean supports(ConfigAttribute attribute) { return false; } + @Override public boolean supports(Class clazz) { return true; } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProvider.java b/core/src/main/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProvider.java index 4a1096a1eb..eb0fa743aa 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProvider.java @@ -16,18 +16,15 @@ package org.springframework.security.access.intercept; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.MessageSource; +import org.springframework.context.MessageSourceAware; +import org.springframework.context.support.MessageSourceAccessor; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.SpringSecurityMessageSource; - -import org.springframework.beans.factory.InitializingBean; - -import org.springframework.context.MessageSource; -import org.springframework.context.MessageSourceAware; -import org.springframework.context.support.MessageSourceAccessor; - import org.springframework.util.Assert; /** @@ -43,49 +40,43 @@ import org.springframework.util.Assert; * If the key does not match, a BadCredentialsException is thrown. *

      */ -public class RunAsImplAuthenticationProvider implements InitializingBean, - AuthenticationProvider, MessageSourceAware { - // ~ Instance fields - // ================================================================================================ +public class RunAsImplAuthenticationProvider implements InitializingBean, AuthenticationProvider, MessageSourceAware { protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private String key; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(key, - "A Key is required and should match that configured for the RunAsManagerImpl"); + Assert.notNull(this.key, "A Key is required and should match that configured for the RunAsManagerImpl"); } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { RunAsUserToken token = (RunAsUserToken) authentication; - - if (token.getKeyHash() == key.hashCode()) { - return authentication; - } - else { - throw new BadCredentialsException(messages.getMessage( - "RunAsImplAuthenticationProvider.incorrectKey", + if (token.getKeyHash() != this.key.hashCode()) { + throw new BadCredentialsException(this.messages.getMessage("RunAsImplAuthenticationProvider.incorrectKey", "The presented RunAsUserToken does not contain the expected key")); } + return authentication; } public String getKey() { - return key; + return this.key; } public void setKey(String key) { this.key = key; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } + @Override public boolean supports(Class authentication) { return RunAsUserToken.class.isAssignableFrom(authentication); } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/RunAsManager.java b/core/src/main/java/org/springframework/security/access/intercept/RunAsManager.java index 8fa6bee9fe..e121b9bb5b 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/RunAsManager.java +++ b/core/src/main/java/org/springframework/security/access/intercept/RunAsManager.java @@ -59,24 +59,19 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface RunAsManager { - // ~ Methods - // ======================================================================================================== /** * Returns a replacement Authentication object for the current secure * object invocation, or null if replacement not required. - * * @param authentication the caller invoking the secure object * @param object the secured object being called * @param attributes the configuration attributes associated with the secure object * being invoked - * * @return a replacement object to be used for duration of the secure object * invocation, or null if the Authentication should be left * as is */ - Authentication buildRunAs(Authentication authentication, Object object, - Collection attributes); + Authentication buildRunAs(Authentication authentication, Object object, Collection attributes); /** * Indicates whether this RunAsManager is able to process the passed @@ -87,10 +82,8 @@ public interface RunAsManager { * AccessDecisionManager and/or RunAsManager and/or * AfterInvocationManager. *

      - * * @param attribute a configuration attribute that has been configured against the * AbstractSecurityInterceptor - * * @return true if this RunAsManager can support the passed * configuration attribute */ @@ -99,10 +92,9 @@ public interface RunAsManager { /** * Indicates whether the RunAsManager implementation is able to provide * run-as replacement for the indicated secure object type. - * * @param clazz the class that is being queried - * * @return true if the implementation can process the indicated class */ boolean supports(Class clazz); + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/RunAsManagerImpl.java b/core/src/main/java/org/springframework/security/access/intercept/RunAsManagerImpl.java index 5efcd78d40..352ff66f9f 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/RunAsManagerImpl.java +++ b/core/src/main/java/org/springframework/security/access/intercept/RunAsManagerImpl.java @@ -54,25 +54,21 @@ import org.springframework.util.Assert; * @author colin sampaleanu */ public class RunAsManagerImpl implements RunAsManager, InitializingBean { - // ~ Instance fields - // ================================================================================================ private String key; + private String rolePrefix = "ROLE_"; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull( - key, + Assert.notNull(this.key, "A Key is required and should match that configured for the RunAsImplAuthenticationProvider"); } + @Override public Authentication buildRunAs(Authentication authentication, Object object, Collection attributes) { List newAuthorities = new ArrayList<>(); - for (ConfigAttribute attribute : attributes) { if (this.supports(attribute)) { GrantedAuthority extraAuthority = new SimpleGrantedAuthority( @@ -80,25 +76,21 @@ public class RunAsManagerImpl implements RunAsManager, InitializingBean { newAuthorities.add(extraAuthority); } } - if (newAuthorities.size() == 0) { return null; } - // Add existing authorities newAuthorities.addAll(authentication.getAuthorities()); - - return new RunAsUserToken(this.key, authentication.getPrincipal(), - authentication.getCredentials(), newAuthorities, - authentication.getClass()); + return new RunAsUserToken(this.key, authentication.getPrincipal(), authentication.getCredentials(), + newAuthorities, authentication.getClass()); } public String getKey() { - return key; + return this.key; } public String getRolePrefix() { - return rolePrefix; + return this.rolePrefix; } public void setKey(String key) { @@ -108,27 +100,26 @@ public class RunAsManagerImpl implements RunAsManager, InitializingBean { /** * Allows the default role prefix of ROLE_ to be overridden. May be set * to an empty value, although this is usually not desirable. - * * @param rolePrefix the new prefix */ public void setRolePrefix(String rolePrefix) { this.rolePrefix = rolePrefix; } + @Override public boolean supports(ConfigAttribute attribute) { - return attribute.getAttribute() != null - && attribute.getAttribute().startsWith("RUN_AS_"); + return attribute.getAttribute() != null && attribute.getAttribute().startsWith("RUN_AS_"); } /** * This implementation supports any type of class, because it does not query the * presented secure object. - * * @param clazz the secure object - * * @return always true */ + @Override public boolean supports(Class clazz) { return true; } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/RunAsUserToken.java b/core/src/main/java/org/springframework/security/access/intercept/RunAsUserToken.java index e050f4433f..1019947440 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/RunAsUserToken.java +++ b/core/src/main/java/org/springframework/security/access/intercept/RunAsUserToken.java @@ -33,16 +33,13 @@ public class RunAsUserToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private final Class originalAuthentication; - private final Object credentials; - private final Object principal; - private final int keyHash; - // ~ Constructors - // =================================================================================================== + private final Object credentials; + + private final Object principal; + + private final int keyHash; public RunAsUserToken(String key, Object principal, Object credentials, Collection authorities, @@ -55,9 +52,6 @@ public class RunAsUserToken extends AbstractAuthenticationToken { setAuthenticated(true); } - // ~ Methods - // ======================================================================================================== - @Override public Object getCredentials() { return this.credentials; @@ -79,10 +73,9 @@ public class RunAsUserToken extends AbstractAuthenticationToken { @Override public String toString() { StringBuilder sb = new StringBuilder(super.toString()); - String className = this.originalAuthentication == null ? null - : this.originalAuthentication.getName(); + String className = (this.originalAuthentication != null) ? this.originalAuthentication.getName() : null; sb.append("; Original Class: ").append(className); - return sb.toString(); } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptor.java b/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptor.java index bb7f8d236a..fadf5baf42 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptor.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptor.java @@ -16,14 +16,14 @@ package org.springframework.security.access.intercept.aopalliance; +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; + import org.springframework.security.access.SecurityMetadataSource; import org.springframework.security.access.intercept.AbstractSecurityInterceptor; import org.springframework.security.access.intercept.InterceptorStatusToken; import org.springframework.security.access.method.MethodSecurityMetadataSource; -import org.aopalliance.intercept.MethodInterceptor; -import org.aopalliance.intercept.MethodInvocation; - /** * Provides security interception of AOP Alliance based method invocations. *

      @@ -37,33 +37,25 @@ import org.aopalliance.intercept.MethodInvocation; * @author Ben Alex * @author Rob Winch */ -public class MethodSecurityInterceptor extends AbstractSecurityInterceptor implements - MethodInterceptor { - // ~ Instance fields - // ================================================================================================ +public class MethodSecurityInterceptor extends AbstractSecurityInterceptor implements MethodInterceptor { private MethodSecurityMetadataSource securityMetadataSource; - // ~ Methods - // ======================================================================================================== - + @Override public Class getSecureObjectClass() { return MethodInvocation.class; } /** * This method should be used to enforce security on a MethodInvocation. - * * @param mi The method being invoked which requires a security decision - * * @return The returned value from the method invocation (possibly modified by the * {@code AfterInvocationManager}). - * * @throws Throwable if any error occurs */ + @Override public Object invoke(MethodInvocation mi) throws Throwable { InterceptorStatusToken token = super.beforeInvocation(mi); - Object result; try { result = mi.proceed(); @@ -78,6 +70,7 @@ public class MethodSecurityInterceptor extends AbstractSecurityInterceptor imple return this.securityMetadataSource; } + @Override public SecurityMetadataSource obtainSecurityMetadataSource() { return this.securityMetadataSource; } @@ -85,4 +78,5 @@ public class MethodSecurityInterceptor extends AbstractSecurityInterceptor imple public void setSecurityMetadataSource(MethodSecurityMetadataSource newSource) { this.securityMetadataSource = newSource; } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisor.java b/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisor.java index d6ba04d102..1d8902d82c 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisor.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisor.java @@ -20,10 +20,10 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.Serializable; import java.lang.reflect.Method; -import java.util.*; import org.aopalliance.aop.Advice; import org.aopalliance.intercept.MethodInterceptor; + import org.springframework.aop.Pointcut; import org.springframework.aop.support.AbstractPointcutAdvisor; import org.springframework.aop.support.StaticMethodMatcherPointcut; @@ -32,14 +32,15 @@ import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.security.access.method.MethodSecurityMetadataSource; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Advisor driven by a {@link MethodSecurityMetadataSource}, used to exclude a * {@link MethodInterceptor} from public (non-secure) methods. *

      * Because the AOP framework caches advice calculations, this is normally faster than just - * letting the MethodInterceptor run and find out itself that it has - * no work to do. + * letting the MethodInterceptor run and find out itself that it has no work + * to do. *

      * This class also allows the use of Spring's {@code DefaultAdvisorAutoProxyCreator}, * which makes configuration easier than setup a ProxyFactoryBean for each @@ -51,21 +52,21 @@ import org.springframework.util.Assert; * @author Ben Alex * @author Luke Taylor */ -public class MethodSecurityMetadataSourceAdvisor extends AbstractPointcutAdvisor - implements BeanFactoryAware { - // ~ Instance fields - // ================================================================================================ +public class MethodSecurityMetadataSourceAdvisor extends AbstractPointcutAdvisor implements BeanFactoryAware { private transient MethodSecurityMetadataSource attributeSource; - private transient MethodInterceptor interceptor; - private final Pointcut pointcut = new MethodSecurityMetadataSourcePointcut(); - private BeanFactory beanFactory; - private final String adviceBeanName; - private final String metadataSourceBeanName; - private transient volatile Object adviceMonitor = new Object(); - // ~ Constructors - // =================================================================================================== + private transient MethodInterceptor interceptor; + + private final Pointcut pointcut = new MethodSecurityMetadataSourcePointcut(); + + private BeanFactory beanFactory; + + private final String adviceBeanName; + + private final String metadataSourceBeanName; + + private transient volatile Object adviceMonitor = new Object(); /** * Alternative constructor for situations where we want the advisor decoupled from the @@ -73,67 +74,59 @@ public class MethodSecurityMetadataSourceAdvisor extends AbstractPointcutAdvisor * instantiation of the interceptor (and hence the AuthenticationManager). See * SEC-773, for example. The metadataSourceBeanName is used rather than a direct * reference to support serialization via a bean factory lookup. - * * @param adviceBeanName name of the MethodSecurityInterceptor bean * @param attributeSource the SecurityMetadataSource (should be the same as the one * used on the interceptor) * @param attributeSourceBeanName the bean name of the attributeSource (required for * serialization) */ - public MethodSecurityMetadataSourceAdvisor(String adviceBeanName, - MethodSecurityMetadataSource attributeSource, String attributeSourceBeanName) { + public MethodSecurityMetadataSourceAdvisor(String adviceBeanName, MethodSecurityMetadataSource attributeSource, + String attributeSourceBeanName) { Assert.notNull(adviceBeanName, "The adviceBeanName cannot be null"); Assert.notNull(attributeSource, "The attributeSource cannot be null"); - Assert.notNull(attributeSourceBeanName, - "The attributeSourceBeanName cannot be null"); - + Assert.notNull(attributeSourceBeanName, "The attributeSourceBeanName cannot be null"); this.adviceBeanName = adviceBeanName; this.attributeSource = attributeSource; this.metadataSourceBeanName = attributeSourceBeanName; } - // ~ Methods - // ======================================================================================================== - + @Override public Pointcut getPointcut() { - return pointcut; + return this.pointcut; } + @Override public Advice getAdvice() { synchronized (this.adviceMonitor) { - if (interceptor == null) { - Assert.notNull(adviceBeanName, - "'adviceBeanName' must be set for use with bean factory lookup."); - Assert.state(beanFactory != null, - "BeanFactory must be set to resolve 'adviceBeanName'"); - interceptor = beanFactory.getBean(this.adviceBeanName, - MethodInterceptor.class); + if (this.interceptor == null) { + Assert.notNull(this.adviceBeanName, "'adviceBeanName' must be set for use with bean factory lookup."); + Assert.state(this.beanFactory != null, "BeanFactory must be set to resolve 'adviceBeanName'"); + this.interceptor = this.beanFactory.getBean(this.adviceBeanName, MethodInterceptor.class); } - return interceptor; + return this.interceptor; } } + @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { this.beanFactory = beanFactory; } - private void readObject(ObjectInputStream ois) throws IOException, - ClassNotFoundException { + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { ois.defaultReadObject(); - adviceMonitor = new Object(); - attributeSource = beanFactory.getBean(metadataSourceBeanName, + this.adviceMonitor = new Object(); + this.attributeSource = this.beanFactory.getBean(this.metadataSourceBeanName, MethodSecurityMetadataSource.class); } - // ~ Inner Classes - // ================================================================================================== + class MethodSecurityMetadataSourcePointcut extends StaticMethodMatcherPointcut implements Serializable { - class MethodSecurityMetadataSourcePointcut extends StaticMethodMatcherPointcut - implements Serializable { - @SuppressWarnings("unchecked") - public boolean matches(Method m, Class targetClass) { - Collection attributes = attributeSource.getAttributes(m, targetClass); - return attributes != null && !attributes.isEmpty(); + @Override + public boolean matches(Method m, Class targetClass) { + MethodSecurityMetadataSource source = MethodSecurityMetadataSourceAdvisor.this.attributeSource; + return !CollectionUtils.isEmpty(source.getAttributes(m, targetClass)); } + } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/aopalliance/package-info.java b/core/src/main/java/org/springframework/security/access/intercept/aopalliance/package-info.java index cdf2da36eb..ad8a9098ec 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aopalliance/package-info.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aopalliance/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Enforces security for AOP Alliance MethodInvocations, such as via Spring AOP. + * Enforces security for AOP Alliance MethodInvocations, such as via Spring + * AOP. */ package org.springframework.security.access.intercept.aopalliance; - diff --git a/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJCallback.java b/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJCallback.java index 7b76686229..f82855ac33 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJCallback.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJCallback.java @@ -24,8 +24,7 @@ package org.springframework.security.access.intercept.aspectj; * @author Ben Alex */ public interface AspectJCallback { - // ~ Methods - // ======================================================================================================== Object proceedWithObject(); + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptor.java b/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptor.java index c3c8fc25b1..78e450c4d1 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptor.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptor.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.intercept.aspectj; import org.aspectj.lang.JoinPoint; + import org.springframework.security.access.intercept.InterceptorStatusToken; import org.springframework.security.access.intercept.aopalliance.MethodSecurityInterceptor; @@ -36,7 +38,6 @@ public final class AspectJMethodSecurityInterceptor extends MethodSecurityInterc /** * Method that is suitable for user with @Aspect notation. - * * @param jp The AspectJ joint point being invoked which requires a security decision * @return The returned value from the method invocation * @throws Throwable if the invocation throws one @@ -47,17 +48,13 @@ public final class AspectJMethodSecurityInterceptor extends MethodSecurityInterc /** * Method that is suitable for user with traditional AspectJ-code aspects. - * * @param jp The AspectJ joint point being invoked which requires a security decision * @param advisorProceed the advice-defined anonymous class that implements * {@code AspectJCallback} containing a simple {@code return proceed();} statement - * * @return The returned value from the method invocation */ public Object invoke(JoinPoint jp, AspectJCallback advisorProceed) { - InterceptorStatusToken token = super - .beforeInvocation(new MethodInvocationAdapter(jp)); - + InterceptorStatusToken token = super.beforeInvocation(new MethodInvocationAdapter(jp)); Object result; try { result = advisorProceed.proceedWithObject(); @@ -65,7 +62,7 @@ public final class AspectJMethodSecurityInterceptor extends MethodSecurityInterc finally { super.finallyInvocation(token); } - return super.afterInvocation(token, result); } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/aspectj/MethodInvocationAdapter.java b/core/src/main/java/org/springframework/security/access/intercept/aspectj/MethodInvocationAdapter.java index d7dc2ea554..7030f99ee4 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aspectj/MethodInvocationAdapter.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aspectj/MethodInvocationAdapter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.intercept.aspectj; import java.lang.reflect.AccessibleObject; @@ -23,6 +24,8 @@ import org.aspectj.lang.JoinPoint; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.reflect.CodeSignature; +import org.springframework.util.Assert; + /** * Decorates a JoinPoint to allow it to be used with method-security infrastructure * classes which support {@code MethodInvocation} instances. @@ -31,41 +34,36 @@ import org.aspectj.lang.reflect.CodeSignature; * @since 3.0.3 */ public final class MethodInvocationAdapter implements MethodInvocation { + private final ProceedingJoinPoint jp; + private final Method method; + private final Object target; MethodInvocationAdapter(JoinPoint jp) { this.jp = (ProceedingJoinPoint) jp; if (jp.getTarget() != null) { - target = jp.getTarget(); + this.target = jp.getTarget(); } else { // SEC-1295: target may be null if an ITD is in use - target = jp.getSignature().getDeclaringType(); + this.target = jp.getSignature().getDeclaringType(); } String targetMethodName = jp.getStaticPart().getSignature().getName(); - Class[] types = ((CodeSignature) jp.getStaticPart().getSignature()) - .getParameterTypes(); + Class[] types = ((CodeSignature) jp.getStaticPart().getSignature()).getParameterTypes(); Class declaringType = jp.getStaticPart().getSignature().getDeclaringType(); - - method = findMethod(targetMethodName, declaringType, types); - - if (method == null) { - throw new IllegalArgumentException( - "Could not obtain target method from JoinPoint: '" + jp + "'"); - } + this.method = findMethod(targetMethodName, declaringType, types); + Assert.notNull(this.method, () -> "Could not obtain target method from JoinPoint: '" + jp + "'"); } private Method findMethod(String name, Class declaringType, Class[] params) { Method method = null; - try { method = declaringType.getMethod(name, params); } catch (NoSuchMethodException ignored) { } - if (method == null) { try { method = declaringType.getDeclaredMethod(name, params); @@ -73,27 +71,32 @@ public final class MethodInvocationAdapter implements MethodInvocation { catch (NoSuchMethodException ignored) { } } - return method; } + @Override public Method getMethod() { - return method; + return this.method; } + @Override public Object[] getArguments() { - return jp.getArgs(); + return this.jp.getArgs(); } + @Override public AccessibleObject getStaticPart() { - return method; + return this.method; } + @Override public Object getThis() { - return target; + return this.target; } + @Override public Object proceed() throws Throwable { - return jp.proceed(); + return this.jp.proceed(); } + } diff --git a/core/src/main/java/org/springframework/security/access/intercept/aspectj/package-info.java b/core/src/main/java/org/springframework/security/access/intercept/aspectj/package-info.java index 37139ad9b8..cf16a6f843 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/aspectj/package-info.java +++ b/core/src/main/java/org/springframework/security/access/intercept/aspectj/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Enforces security for AspectJ JointPoints, delegating secure object callbacks to the calling aspect. + * Enforces security for AspectJ JointPoints, delegating secure object + * callbacks to the calling aspect. */ package org.springframework.security.access.intercept.aspectj; - diff --git a/core/src/main/java/org/springframework/security/access/intercept/package-info.java b/core/src/main/java/org/springframework/security/access/intercept/package-info.java index e2c7501807..68b24a7798 100644 --- a/core/src/main/java/org/springframework/security/access/intercept/package-info.java +++ b/core/src/main/java/org/springframework/security/access/intercept/package-info.java @@ -13,25 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Abstract level security interception classes which are responsible for enforcing the * configured security constraints for a secure object. *

      - * A secure object is a term frequently used throughout the security - * system. It does not refer to a business object that is being - * secured, but instead refers to some infrastructure object that can have - * security facilities provided for it by Spring Security. - * For example, one secure object would be MethodInvocation, - * whilst another would be HTTP - * {@code org.springframework.security.web.FilterInvocation}. Note these are - * infrastructure objects and their design allows them to represent a large - * variety of actual resources that might need to be secured, such as business - * objects or HTTP request URLs. - *

      Each secure object typically has its own interceptor package. - * Each package usually includes a concrete security interceptor (which subclasses - * {@link org.springframework.security.access.intercept.AbstractSecurityInterceptor}) and an - * appropriate {@link org.springframework.security.access.SecurityMetadataSource} - * for the type of resources the secure object represents. + * A secure object is a term frequently used throughout the security system. It + * does not refer to a business object that is being secured, but instead refers to + * some infrastructure object that can have security facilities provided for it by Spring + * Security. For example, one secure object would be MethodInvocation, whilst + * another would be HTTP {@code org.springframework.security.web.FilterInvocation}. Note + * these are infrastructure objects and their design allows them to represent a large + * variety of actual resources that might need to be secured, such as business objects or + * HTTP request URLs. + *

      + * Each secure object typically has its own interceptor package. Each package usually + * includes a concrete security interceptor (which subclasses + * {@link org.springframework.security.access.intercept.AbstractSecurityInterceptor}) and + * an appropriate {@link org.springframework.security.access.SecurityMetadataSource} for + * the type of resources the secure object represents. */ package org.springframework.security.access.intercept; - diff --git a/core/src/main/java/org/springframework/security/access/method/AbstractFallbackMethodSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/method/AbstractFallbackMethodSecurityMetadataSource.java index ef65794b08..3b3be3effb 100644 --- a/core/src/main/java/org/springframework/security/access/method/AbstractFallbackMethodSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/method/AbstractFallbackMethodSecurityMetadataSource.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.method; import java.lang.reflect.Method; -import java.util.*; +import java.util.Collection; +import java.util.Collections; import org.springframework.aop.support.AopUtils; import org.springframework.security.access.ConfigAttribute; @@ -42,9 +44,9 @@ import org.springframework.security.access.ConfigAttribute; * @author Luke taylor * @since 2.0 */ -public abstract class AbstractFallbackMethodSecurityMetadataSource extends - AbstractMethodSecurityMetadataSource { +public abstract class AbstractFallbackMethodSecurityMetadataSource extends AbstractMethodSecurityMetadataSource { + @Override public Collection getAttributes(Method method, Class targetClass) { // The method may be on an interface, but we need attributes from the target // class. @@ -55,13 +57,11 @@ public abstract class AbstractFallbackMethodSecurityMetadataSource extends if (attr != null) { return attr; } - // Second try is the config attribute on the target class. attr = findAttributes(specificMethod.getDeclaringClass()); if (attr != null) { return attr; } - if (specificMethod != method || targetClass == null) { // Fallback is to look at the original method. attr = findAttributes(method, method.getDeclaringClass()); @@ -83,13 +83,11 @@ public abstract class AbstractFallbackMethodSecurityMetadataSource extends * may wish to provide advanced capabilities related to method metadata being * "registered" against a method even if the target class does not declare the method * (i.e. the subclass may only inherit the method). - * * @param method the method for the current invocation (never null) * @param targetClass the target class for the invocation (may be null) * @return the security metadata (or null if no metadata applies) */ - protected abstract Collection findAttributes(Method method, - Class targetClass); + protected abstract Collection findAttributes(Method method, Class targetClass); /** * Obtains the security metadata registered against the specified class. @@ -99,7 +97,6 @@ public abstract class AbstractFallbackMethodSecurityMetadataSource extends * should NOT aggregate metadata for each method registered against a class, as the * abstract superclass will separate invoke {@link #findAttributes(Method, Class)} for * individual methods as appropriate. - * * @param clazz the target class for the invocation (never null) * @return the security metadata (or null if no metadata applies) */ diff --git a/core/src/main/java/org/springframework/security/access/method/AbstractMethodSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/method/AbstractMethodSecurityMetadataSource.java index dcb4d987e3..3ed17acc70 100644 --- a/core/src/main/java/org/springframework/security/access/method/AbstractMethodSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/method/AbstractMethodSecurityMetadataSource.java @@ -21,6 +21,7 @@ import java.util.Collection; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.aop.framework.AopProxyUtils; import org.springframework.security.access.ConfigAttribute; @@ -31,22 +32,18 @@ import org.springframework.security.access.ConfigAttribute; * @author Ben Alex * @author Luke Taylor */ -public abstract class AbstractMethodSecurityMetadataSource implements - MethodSecurityMetadataSource { +public abstract class AbstractMethodSecurityMetadataSource implements MethodSecurityMetadataSource { protected final Log logger = LogFactory.getLog(getClass()); - // ~ Methods - // ======================================================================================================== - + @Override public final Collection getAttributes(Object object) { if (object instanceof MethodInvocation) { MethodInvocation mi = (MethodInvocation) object; Object target = mi.getThis(); Class targetClass = null; - if (target != null) { - targetClass = target instanceof Class ? (Class) target + targetClass = (target instanceof Class) ? (Class) target : AopProxyUtils.ultimateTargetClass(target); } Collection attrs = getAttributes(mi.getMethod(), targetClass); @@ -58,11 +55,12 @@ public abstract class AbstractMethodSecurityMetadataSource implements } return attrs; } - throw new IllegalArgumentException("Object must be a non-null MethodInvocation"); } + @Override public final boolean supports(Class clazz) { return (MethodInvocation.class.isAssignableFrom(clazz)); } + } diff --git a/core/src/main/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSource.java index 9d85983e33..2591d099fc 100644 --- a/core/src/main/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSource.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.method; import java.lang.reflect.Method; @@ -24,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.ConfigAttribute; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -36,59 +38,43 @@ import org.springframework.util.ObjectUtils; * @author Ben Alex * @author Luke Taylor */ -public final class DelegatingMethodSecurityMetadataSource extends - AbstractMethodSecurityMetadataSource { - private final static List NULL_CONFIG_ATTRIBUTE = Collections - .emptyList(); +public final class DelegatingMethodSecurityMetadataSource extends AbstractMethodSecurityMetadataSource { + + private static final List NULL_CONFIG_ATTRIBUTE = Collections.emptyList(); private final List methodSecurityMetadataSources; + private final Map> attributeCache = new HashMap<>(); - // ~ Constructor - // ==================================================================================================== - - public DelegatingMethodSecurityMetadataSource( - List methodSecurityMetadataSources) { - Assert.notNull(methodSecurityMetadataSources, - "MethodSecurityMetadataSources cannot be null"); + public DelegatingMethodSecurityMetadataSource(List methodSecurityMetadataSources) { + Assert.notNull(methodSecurityMetadataSources, "MethodSecurityMetadataSources cannot be null"); this.methodSecurityMetadataSources = methodSecurityMetadataSources; } - // ~ Methods - // ======================================================================================================== - + @Override public Collection getAttributes(Method method, Class targetClass) { DefaultCacheKey cacheKey = new DefaultCacheKey(method, targetClass); - synchronized (attributeCache) { - Collection cached = attributeCache.get(cacheKey); + synchronized (this.attributeCache) { + Collection cached = this.attributeCache.get(cacheKey); // Check for canonical value indicating there is no config attribute, - if (cached != null) { return cached; } - // No cached value, so query the sources to find a result Collection attributes = null; - for (MethodSecurityMetadataSource s : methodSecurityMetadataSources) { + for (MethodSecurityMetadataSource s : this.methodSecurityMetadataSources) { attributes = s.getAttributes(method, targetClass); if (attributes != null && !attributes.isEmpty()) { break; } } - // Put it in the cache. if (attributes == null || attributes.isEmpty()) { this.attributeCache.put(cacheKey, NULL_CONFIG_ATTRIBUTE); return NULL_CONFIG_ATTRIBUTE; } - - if (logger.isDebugEnabled()) { - logger.debug("Caching method [" + cacheKey + "] with attributes " - + attributes); - } - + this.logger.debug(LogMessage.format("Caching method [%s] with attributes %s", cacheKey, attributes)); this.attributeCache.put(cacheKey, attributes); - return attributes; } } @@ -96,7 +82,7 @@ public final class DelegatingMethodSecurityMetadataSource extends @Override public Collection getAllConfigAttributes() { Set set = new HashSet<>(); - for (MethodSecurityMetadataSource s : methodSecurityMetadataSources) { + for (MethodSecurityMetadataSource s : this.methodSecurityMetadataSources) { Collection attrs = s.getAllConfigAttributes(); if (attrs != null) { set.addAll(attrs); @@ -106,14 +92,13 @@ public final class DelegatingMethodSecurityMetadataSource extends } public List getMethodSecurityMetadataSources() { - return methodSecurityMetadataSources; + return this.methodSecurityMetadataSources; } - // ~ Inner Classes - // ================================================================================================== - private static class DefaultCacheKey { + private final Method method; + private final Class targetClass; DefaultCacheKey(Method method, Class targetClass) { @@ -124,20 +109,21 @@ public final class DelegatingMethodSecurityMetadataSource extends @Override public boolean equals(Object other) { DefaultCacheKey otherKey = (DefaultCacheKey) other; - return (this.method.equals(otherKey.method) && ObjectUtils.nullSafeEquals( - this.targetClass, otherKey.targetClass)); + return (this.method.equals(otherKey.method) + && ObjectUtils.nullSafeEquals(this.targetClass, otherKey.targetClass)); } @Override public int hashCode() { - return this.method.hashCode() * 21 - + (this.targetClass != null ? this.targetClass.hashCode() : 0); + return this.method.hashCode() * 21 + ((this.targetClass != null) ? this.targetClass.hashCode() : 0); } @Override public String toString() { - return "CacheKey[" + (targetClass == null ? "-" : targetClass.getName()) - + "; " + method + "]"; + String targetClassName = (this.targetClass != null) ? this.targetClass.getName() : "-"; + return "CacheKey[" + targetClassName + "; " + this.method + "]"; } + } + } diff --git a/core/src/main/java/org/springframework/security/access/method/MapBasedMethodSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/method/MapBasedMethodSecurityMetadataSource.java index df79e47b1e..4fb00cc78e 100644 --- a/core/src/main/java/org/springframework/security/access/method/MapBasedMethodSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/method/MapBasedMethodSecurityMetadataSource.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Set; import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.ConfigAttribute; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -42,22 +43,21 @@ import org.springframework.util.ClassUtils; * @author Ben Alex * @since 2.0 */ -public class MapBasedMethodSecurityMetadataSource extends - AbstractFallbackMethodSecurityMetadataSource implements BeanClassLoaderAware { +public class MapBasedMethodSecurityMetadataSource extends AbstractFallbackMethodSecurityMetadataSource + implements BeanClassLoaderAware { - // ~ Instance fields - // ================================================================================================ private ClassLoader beanClassLoader = ClassUtils.getDefaultClassLoader(); - /** Map from RegisteredMethod to ConfigAttribute list */ + /** + * Map from RegisteredMethod to ConfigAttribute list + */ protected final Map> methodMap = new HashMap<>(); - /** Map from RegisteredMethod to name pattern used for registration */ + /** + * Map from RegisteredMethod to name pattern used for registration + */ private final Map nameMap = new HashMap<>(); - // ~ Methods - // ======================================================================================================== - public MapBasedMethodSecurityMetadataSource() { } @@ -65,8 +65,7 @@ public class MapBasedMethodSecurityMetadataSource extends * Creates the MapBasedMethodSecurityMetadataSource from a * @param methodMap map of method names to ConfigAttributes. */ - public MapBasedMethodSecurityMetadataSource( - Map> methodMap) { + public MapBasedMethodSecurityMetadataSource(Map> methodMap) { for (Map.Entry> entry : methodMap.entrySet()) { addSecureMethod(entry.getKey(), entry.getValue()); } @@ -85,20 +84,17 @@ public class MapBasedMethodSecurityMetadataSource extends * applicable. */ @Override - protected Collection findAttributes(Method method, - Class targetClass) { + protected Collection findAttributes(Method method, Class targetClass) { if (targetClass == null) { return null; } - return findAttributesSpecifiedAgainst(method, targetClass); } - private List findAttributesSpecifiedAgainst(Method method, - Class clazz) { + private List findAttributesSpecifiedAgainst(Method method, Class clazz) { RegisteredMethod registeredMethod = new RegisteredMethod(method, clazz); - if (methodMap.containsKey(registeredMethod)) { - return methodMap.get(registeredMethod); + if (this.methodMap.containsKey(registeredMethod)) { + return this.methodMap.get(registeredMethod); } // Search superclass if (clazz.getSuperclass() != null) { @@ -110,82 +106,61 @@ public class MapBasedMethodSecurityMetadataSource extends /** * Add configuration attributes for a secure method. Method names can end or start * with * for matching multiple methods. - * * @param name type and method name, separated by a dot * @param attr the security attributes associated with the method */ private void addSecureMethod(String name, List attr) { int lastDotIndex = name.lastIndexOf("."); - - if (lastDotIndex == -1) { - throw new IllegalArgumentException("'" + name - + "' is not a valid method name: format is FQN.methodName"); - } - + Assert.isTrue(lastDotIndex != -1, () -> "'" + name + "' is not a valid method name: format is FQN.methodName"); String methodName = name.substring(lastDotIndex + 1); Assert.hasText(methodName, () -> "Method not found for '" + name + "'"); - String typeName = name.substring(0, lastDotIndex); Class type = ClassUtils.resolveClassName(typeName, this.beanClassLoader); - addSecureMethod(type, methodName, attr); } /** * Add configuration attributes for a secure method. Mapped method names can end or * start with * for matching multiple methods. - * * @param javaType target interface or class the security configuration attribute * applies to * @param mappedName mapped method name, which the javaType has declared or inherited * @param attr required authorities associated with the method */ - public void addSecureMethod(Class javaType, String mappedName, - List attr) { + public void addSecureMethod(Class javaType, String mappedName, List attr) { String name = javaType.getName() + '.' + mappedName; - - if (logger.isDebugEnabled()) { - logger.debug("Request to add secure method [" + name + "] with attributes [" - + attr + "]"); - } - + this.logger.debug(LogMessage.format("Request to add secure method [%s] with attributes [%s]", name, attr)); Method[] methods = javaType.getMethods(); List matchingMethods = new ArrayList<>(); - - for (Method m : methods) { - if (m.getName().equals(mappedName) || isMatch(m.getName(), mappedName)) { - matchingMethods.add(m); + for (Method method : methods) { + if (method.getName().equals(mappedName) || isMatch(method.getName(), mappedName)) { + matchingMethods.add(method); } } + Assert.notEmpty(matchingMethods, () -> "Couldn't find method '" + mappedName + "' on '" + javaType + "'"); + registerAllMatchingMethods(javaType, attr, name, matchingMethods); + } - if (matchingMethods.isEmpty()) { - throw new IllegalArgumentException("Couldn't find method '" + mappedName - + "' on '" + javaType + "'"); - } - - // register all matching methods + private void registerAllMatchingMethods(Class javaType, List attr, String name, + List matchingMethods) { for (Method method : matchingMethods) { RegisteredMethod registeredMethod = new RegisteredMethod(method, javaType); String regMethodName = this.nameMap.get(registeredMethod); - - if ((regMethodName == null) - || (!regMethodName.equals(name) && (regMethodName.length() <= name - .length()))) { + if ((regMethodName == null) || (!regMethodName.equals(name) && (regMethodName.length() <= name.length()))) { // no already registered method name, or more specific - // method name specification now -> (re-)register method + // method name specification (now) -> (re-)register method if (regMethodName != null) { - logger.debug("Replacing attributes for secure method [" + method - + "]: current name [" + name + "] is more specific than [" - + regMethodName + "]"); + this.logger.debug(LogMessage.format( + "Replacing attributes for secure method [%s]: current name [%s] is more specific than [%s]", + method, name, regMethodName)); } - this.nameMap.put(registeredMethod, name); addSecureMethod(registeredMethod, attr); } else { - logger.debug("Keeping attributes for secure method [" + method - + "]: current name [" + name + "] is not more specific than [" - + regMethodName + "]"); + this.logger.debug(LogMessage.format( + "Keeping attributes for secure method [%s]: current name [%s] is not more specific than [%s]", + method, name, regMethodName)); } } } @@ -199,66 +174,49 @@ public class MapBasedMethodSecurityMetadataSource extends *

      * This method should only be called during initialization of the {@code BeanFactory}. */ - public void addSecureMethod(Class javaType, Method method, - List attr) { + public void addSecureMethod(Class javaType, Method method, List attr) { RegisteredMethod key = new RegisteredMethod(method, javaType); - - if (methodMap.containsKey(key)) { - logger.debug("Method [" + method - + "] is already registered with attributes [" + methodMap.get(key) - + "]"); + if (this.methodMap.containsKey(key)) { + this.logger.debug(LogMessage.format("Method [%s] is already registered with attributes [%s]", method, + this.methodMap.get(key))); return; } - - methodMap.put(key, attr); + this.methodMap.put(key, attr); } /** * Add configuration attributes for a secure method. - * * @param method the method to be secured * @param attr required authorities associated with the method */ private void addSecureMethod(RegisteredMethod method, List attr) { Assert.notNull(method, "RegisteredMethod required"); Assert.notNull(attr, "Configuration attribute required"); - if (logger.isInfoEnabled()) { - logger.info("Adding secure method [" + method + "] with attributes [" + attr - + "]"); - } + this.logger.info(LogMessage.format("Adding secure method [%s] with attributes [%s]", method, attr)); this.methodMap.put(method, attr); } /** * Obtains the configuration attributes explicitly defined against this bean. - * * @return the attributes explicitly defined against this bean */ @Override public Collection getAllConfigAttributes() { Set allAttributes = new HashSet<>(); - - for (List attributeList : methodMap.values()) { - allAttributes.addAll(attributeList); - } - + this.methodMap.values().forEach(allAttributes::addAll); return allAttributes; } /** * Return if the given method name matches the mapped name. The default implementation * checks for "xxx" and "xxx" matches. - * * @param methodName the method name of the class * @param mappedName the name in the descriptor - * * @return if the names match */ private boolean isMatch(String methodName, String mappedName) { - return (mappedName.endsWith("*") && methodName.startsWith(mappedName.substring(0, - mappedName.length() - 1))) - || (mappedName.startsWith("*") && methodName.endsWith(mappedName - .substring(1, mappedName.length()))); + return (mappedName.endsWith("*") && methodName.startsWith(mappedName.substring(0, mappedName.length() - 1))) + || (mappedName.startsWith("*") && methodName.endsWith(mappedName.substring(1, mappedName.length()))); } @Override @@ -271,7 +229,7 @@ public class MapBasedMethodSecurityMetadataSource extends * @return map size (for unit tests and diagnostics) */ public int getMethodMapSize() { - return methodMap.size(); + return this.methodMap.size(); } /** @@ -284,7 +242,9 @@ public class MapBasedMethodSecurityMetadataSource extends * we're invoking against and the Method will provide details of the declared class. */ private static class RegisteredMethod { + private final Method method; + private final Class registeredJavaType; RegisteredMethod(Method method, Class registeredJavaType) { @@ -301,22 +261,21 @@ public class MapBasedMethodSecurityMetadataSource extends } if (obj != null && obj instanceof RegisteredMethod) { RegisteredMethod rhs = (RegisteredMethod) obj; - return method.equals(rhs.method) - && registeredJavaType.equals(rhs.registeredJavaType); + return this.method.equals(rhs.method) && this.registeredJavaType.equals(rhs.registeredJavaType); } return false; } @Override public int hashCode() { - return method.hashCode() * registeredJavaType.hashCode(); + return this.method.hashCode() * this.registeredJavaType.hashCode(); } @Override public String toString() { - return "RegisteredMethod[" + registeredJavaType.getName() + "; " + method - + "]"; + return "RegisteredMethod[" + this.registeredJavaType.getName() + "; " + this.method + "]"; } + } } diff --git a/core/src/main/java/org/springframework/security/access/method/MethodSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/method/MethodSecurityMetadataSource.java index adfc739ed7..5cfb3ff9ca 100644 --- a/core/src/main/java/org/springframework/security/access/method/MethodSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/method/MethodSecurityMetadataSource.java @@ -29,5 +29,7 @@ import org.springframework.security.access.SecurityMetadataSource; * @author Ben Alex */ public interface MethodSecurityMetadataSource extends SecurityMetadataSource { + Collection getAttributes(Method method, Class targetClass); + } diff --git a/core/src/main/java/org/springframework/security/access/method/P.java b/core/src/main/java/org/springframework/security/access/method/P.java index 0d8139e455..53fc1a71c2 100644 --- a/core/src/main/java/org/springframework/security/access/method/P.java +++ b/core/src/main/java/org/springframework/security/access/method/P.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.method; import java.lang.annotation.Documented; @@ -45,4 +46,5 @@ public @interface P { * @return */ String value(); -} \ No newline at end of file + +} diff --git a/core/src/main/java/org/springframework/security/access/method/package-info.java b/core/src/main/java/org/springframework/security/access/method/package-info.java index b90eb4948e..d7cf84ccdc 100644 --- a/core/src/main/java/org/springframework/security/access/method/package-info.java +++ b/core/src/main/java/org/springframework/security/access/method/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Provides {@code SecurityMetadataSource} implementations for securing Java method invocations via different - * AOP libraries. + * Provides {@code SecurityMetadataSource} implementations for securing Java method + * invocations via different AOP libraries. */ package org.springframework.security.access.method; - diff --git a/core/src/main/java/org/springframework/security/access/package-info.java b/core/src/main/java/org/springframework/security/access/package-info.java index 99908b931a..2055530ce8 100644 --- a/core/src/main/java/org/springframework/security/access/package-info.java +++ b/core/src/main/java/org/springframework/security/access/package-info.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Core access-control related code, including security metadata related classes, interception code, access control - * annotations, EL support and voter-based implementations of the central - * {@link org.springframework.security.access.AccessDecisionManager AccessDecisionManager} interface. + * Core access-control related code, including security metadata related classes, + * interception code, access control annotations, EL support and voter-based + * implementations of the central + * {@link org.springframework.security.access.AccessDecisionManager AccessDecisionManager} + * interface. */ package org.springframework.security.access; - diff --git a/core/src/main/java/org/springframework/security/access/prepost/PostAuthorize.java b/core/src/main/java/org/springframework/security/access/prepost/PostAuthorize.java index 18e5abe644..18a60ef88f 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PostAuthorize.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PostAuthorize.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.lang.annotation.Documented; @@ -34,9 +35,11 @@ import java.lang.annotation.Target; @Inherited @Documented public @interface PostAuthorize { + /** * @return the Spring-EL expression to be evaluated after invoking the protected * method */ String value(); + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PostFilter.java b/core/src/main/java/org/springframework/security/access/prepost/PostFilter.java index 73efa4ad3b..dc5d35b320 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PostFilter.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PostFilter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.lang.annotation.Documented; @@ -34,9 +35,11 @@ import java.lang.annotation.Target; @Inherited @Documented public @interface PostFilter { + /** * @return the Spring-EL expression to be evaluated after invoking the protected * method */ String value(); + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAdviceProvider.java b/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAdviceProvider.java index 6c171fb7e1..b63a39e8c7 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAdviceProvider.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAdviceProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.util.Collection; @@ -20,6 +21,7 @@ import java.util.Collection; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AfterInvocationProvider; import org.springframework.security.access.ConfigAttribute; @@ -35,6 +37,7 @@ import org.springframework.security.core.Authentication; * @since 3.0 */ public class PostInvocationAdviceProvider implements AfterInvocationProvider { + protected final Log logger = LogFactory.getLog(getClass()); private final PostInvocationAuthorizationAdvice postAdvice; @@ -43,36 +46,34 @@ public class PostInvocationAdviceProvider implements AfterInvocationProvider { this.postAdvice = postAdvice; } - public Object decide(Authentication authentication, Object object, - Collection config, Object returnedObject) - throws AccessDeniedException { - - PostInvocationAttribute pia = findPostInvocationAttribute(config); - - if (pia == null) { + @Override + public Object decide(Authentication authentication, Object object, Collection config, + Object returnedObject) throws AccessDeniedException { + PostInvocationAttribute postInvocationAttribute = findPostInvocationAttribute(config); + if (postInvocationAttribute == null) { return returnedObject; } - - return postAdvice.after(authentication, (MethodInvocation) object, pia, + return this.postAdvice.after(authentication, (MethodInvocation) object, postInvocationAttribute, returnedObject); } - private PostInvocationAttribute findPostInvocationAttribute( - Collection config) { + private PostInvocationAttribute findPostInvocationAttribute(Collection config) { for (ConfigAttribute attribute : config) { if (attribute instanceof PostInvocationAttribute) { return (PostInvocationAttribute) attribute; } } - return null; } + @Override public boolean supports(ConfigAttribute attribute) { return attribute instanceof PostInvocationAttribute; } + @Override public boolean supports(Class clazz) { return clazz.isAssignableFrom(MethodInvocation.class); } + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAttribute.java b/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAttribute.java index 9608345f3e..c5f6d280fa 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAttribute.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAttribute.java @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import org.springframework.security.access.ConfigAttribute; /** - * Marker interface for attributes which are created from combined @PostFilter and @PostAuthorize - * annotations. + * Marker interface for attributes which are created from combined @PostFilter + * and @PostAuthorize annotations. *

      * Consumed by a {@link PostInvocationAuthorizationAdvice}. * diff --git a/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAuthorizationAdvice.java b/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAuthorizationAdvice.java index 76166982ab..31501ce000 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAuthorizationAdvice.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PostInvocationAuthorizationAdvice.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.core.Authentication; @@ -28,7 +30,7 @@ import org.springframework.security.core.Authentication; */ public interface PostInvocationAuthorizationAdvice extends AopInfrastructureBean { - Object after(Authentication authentication, MethodInvocation mi, - PostInvocationAttribute pia, Object returnedObject) + Object after(Authentication authentication, MethodInvocation mi, PostInvocationAttribute pia, Object returnedObject) throws AccessDeniedException; + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PreAuthorize.java b/core/src/main/java/org/springframework/security/access/prepost/PreAuthorize.java index 65cc07a60f..ba71103053 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PreAuthorize.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PreAuthorize.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.lang.annotation.Documented; @@ -34,9 +35,11 @@ import java.lang.annotation.Target; @Inherited @Documented public @interface PreAuthorize { + /** * @return the Spring-EL expression to be evaluated before invoking the protected * method */ String value(); + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PreFilter.java b/core/src/main/java/org/springframework/security/access/prepost/PreFilter.java index 0432936576..7736714395 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PreFilter.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PreFilter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.lang.annotation.Documented; @@ -46,6 +47,7 @@ import java.lang.annotation.Target; @Inherited @Documented public @interface PreFilter { + /** * @return the Spring-EL expression to be evaluated before invoking the protected * method @@ -58,4 +60,5 @@ public @interface PreFilter { * attribute can be omitted. */ String filterTarget() default ""; + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAttribute.java b/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAttribute.java index a0fecd9e49..1fa478c32e 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAttribute.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAttribute.java @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import org.springframework.security.access.ConfigAttribute; /** - * Marker interface for attributes which are created from combined @PreFilter and @PreAuthorize - * annotations. + * Marker interface for attributes which are created from combined @PreFilter + * and @PreAuthorize annotations. *

      * Consumed by a {@link PreInvocationAuthorizationAdvice}. * diff --git a/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdvice.java b/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdvice.java index 72bd17009b..b0647568e7 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdvice.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdvice.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.security.core.Authentication; @@ -30,14 +32,13 @@ public interface PreInvocationAuthorizationAdvice extends AopInfrastructureBean /** * The "before" advice which should be executed to perform any filtering necessary and * to decide whether the method call is authorised. - * * @param authentication the information on the principal on whose account the * decision should be made * @param mi the method invocation being attempted - * @param preInvocationAttribute the attribute built from the @PreFilter and @PostFilter - * annotations. + * @param preInvocationAttribute the attribute built from the @PreFilter + * and @PostFilter annotations. * @return true if authorised, false otherwise */ - boolean before(Authentication authentication, MethodInvocation mi, - PreInvocationAttribute preInvocationAttribute); + boolean before(Authentication authentication, MethodInvocation mi, PreInvocationAttribute preInvocationAttribute); + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoter.java b/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoter.java index 4432fbc6cf..2b9979290a 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoter.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.util.Collection; @@ -20,6 +21,7 @@ import java.util.Collection; import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; @@ -37,8 +39,8 @@ import org.springframework.security.core.Authentication; * @author Luke Taylor * @since 3.0 */ -public class PreInvocationAuthorizationAdviceVoter implements - AccessDecisionVoter { +public class PreInvocationAuthorizationAdviceVoter implements AccessDecisionVoter { + protected final Log logger = LogFactory.getLog(getClass()); private final PreInvocationAuthorizationAdvice preAdvice; @@ -47,41 +49,35 @@ public class PreInvocationAuthorizationAdviceVoter implements this.preAdvice = pre; } + @Override public boolean supports(ConfigAttribute attribute) { return attribute instanceof PreInvocationAttribute; } + @Override public boolean supports(Class clazz) { return MethodInvocation.class.isAssignableFrom(clazz); } - public int vote(Authentication authentication, MethodInvocation method, - Collection attributes) { - + @Override + public int vote(Authentication authentication, MethodInvocation method, Collection attributes) { // Find prefilter and preauth (or combined) attributes - // if both null, abstain - // else call advice with them - + // if both null, abstain else call advice with them PreInvocationAttribute preAttr = findPreInvocationAttribute(attributes); - if (preAttr == null) { // No expression based metadata, so abstain return ACCESS_ABSTAIN; } - - boolean allowed = preAdvice.before(authentication, method, preAttr); - - return allowed ? ACCESS_GRANTED : ACCESS_DENIED; + return this.preAdvice.before(authentication, method, preAttr) ? ACCESS_GRANTED : ACCESS_DENIED; } - private PreInvocationAttribute findPreInvocationAttribute( - Collection config) { + private PreInvocationAttribute findPreInvocationAttribute(Collection config) { for (ConfigAttribute attribute : config) { if (attribute instanceof PreInvocationAttribute) { return (PreInvocationAttribute) attribute; } } - return null; } + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PrePostAdviceReactiveMethodInterceptor.java b/core/src/main/java/org/springframework/security/access/prepost/PrePostAdviceReactiveMethodInterceptor.java index 98ef8d43e1..4ed5e039fc 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PrePostAdviceReactiveMethodInterceptor.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PrePostAdviceReactiveMethodInterceptor.java @@ -16,9 +16,16 @@ package org.springframework.security.access.prepost; +import java.lang.reflect.Method; +import java.util.Collection; + import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; import org.reactivestreams.Publisher; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.method.MethodSecurityMetadataSource; @@ -28,23 +35,18 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.util.Assert; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.lang.reflect.Method; -import java.util.Collection; /** - * A {@link MethodInterceptor} that supports {@link PreAuthorize} and {@link PostAuthorize} for methods that return - * {@link Mono} or {@link Flux} + * A {@link MethodInterceptor} that supports {@link PreAuthorize} and + * {@link PostAuthorize} for methods that return {@link Mono} or {@link Flux} * * @author Rob Winch * @since 5.0 */ public class PrePostAdviceReactiveMethodInterceptor implements MethodInterceptor { + private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", - AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); private final MethodSecurityMetadataSource attributeSource; @@ -58,11 +60,12 @@ public class PrePostAdviceReactiveMethodInterceptor implements MethodInterceptor * @param preInvocationAdvice the {@link PreInvocationAuthorizationAdvice} to use * @param postInvocationAdvice the {@link PostInvocationAuthorizationAdvice} to use */ - public PrePostAdviceReactiveMethodInterceptor(MethodSecurityMetadataSource attributeSource, PreInvocationAuthorizationAdvice preInvocationAdvice, PostInvocationAuthorizationAdvice postInvocationAdvice) { + public PrePostAdviceReactiveMethodInterceptor(MethodSecurityMetadataSource attributeSource, + PreInvocationAuthorizationAdvice preInvocationAdvice, + PostInvocationAuthorizationAdvice postInvocationAdvice) { Assert.notNull(attributeSource, "attributeSource cannot be null"); Assert.notNull(preInvocationAdvice, "preInvocationAdvice cannot be null"); Assert.notNull(postInvocationAdvice, "postInvocationAdvice cannot be null"); - this.attributeSource = attributeSource; this.preInvocationAdvice = preInvocationAdvice; this.postAdvice = postInvocationAdvice; @@ -72,70 +75,59 @@ public class PrePostAdviceReactiveMethodInterceptor implements MethodInterceptor public Object invoke(final MethodInvocation invocation) { Method method = invocation.getMethod(); Class returnType = method.getReturnType(); - if (!Publisher.class.isAssignableFrom(returnType)) { - throw new IllegalStateException("The returnType " + returnType + " on " + method + " must return an instance of org.reactivestreams.Publisher (i.e. Mono / Flux) in order to support Reactor Context"); - } + Assert.state(Publisher.class.isAssignableFrom(returnType), + () -> "The returnType " + returnType + " on " + method + + " must return an instance of org.reactivestreams.Publisher " + + "(i.e. Mono / Flux) in order to support Reactor Context"); Class targetClass = invocation.getThis().getClass(); - Collection attributes = this.attributeSource - .getAttributes(method, targetClass); - + Collection attributes = this.attributeSource.getAttributes(method, targetClass); PreInvocationAttribute preAttr = findPreInvocationAttribute(attributes); + // @formatter:off Mono toInvoke = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(this.anonymous) - .filter( auth -> this.preInvocationAdvice.before(auth, invocation, preAttr)) - .switchIfEmpty(Mono.defer(() -> Mono.error(new AccessDeniedException("Denied")))); - - + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(this.anonymous) + .filter((auth) -> this.preInvocationAdvice.before(auth, invocation, preAttr)) + .switchIfEmpty(Mono.defer(() -> Mono.error(new AccessDeniedException("Denied")))); + // @formatter:on PostInvocationAttribute attr = findPostInvocationAttribute(attributes); - if (Mono.class.isAssignableFrom(returnType)) { - return toInvoke - .flatMap( auth -> this.>proceed(invocation) - .map( r -> attr == null ? r : this.postAdvice.after(auth, invocation, attr, r)) - ); + return toInvoke.flatMap((auth) -> PrePostAdviceReactiveMethodInterceptor.>proceed(invocation) + .map((r) -> (attr != null) ? this.postAdvice.after(auth, invocation, attr, r) : r)); } - if (Flux.class.isAssignableFrom(returnType)) { - return toInvoke - .flatMapMany( auth -> this.>proceed(invocation) - .map( r -> attr == null ? r : this.postAdvice.after(auth, invocation, attr, r)) - ); + return toInvoke.flatMapMany((auth) -> PrePostAdviceReactiveMethodInterceptor.>proceed(invocation) + .map((r) -> (attr != null) ? this.postAdvice.after(auth, invocation, attr, r) : r)); } - - return toInvoke - .flatMapMany( auth -> Flux.from(this.>proceed(invocation)) - .map( r -> attr == null ? r : this.postAdvice.after(auth, invocation, attr, r)) - ); + return toInvoke.flatMapMany( + (auth) -> Flux.from(PrePostAdviceReactiveMethodInterceptor.>proceed(invocation)) + .map((r) -> (attr != null) ? this.postAdvice.after(auth, invocation, attr, r) : r)); } private static > T proceed(final MethodInvocation invocation) { try { return (T) invocation.proceed(); - } catch(Throwable throwable) { + } + catch (Throwable throwable) { throw Exceptions.propagate(throwable); } } - private static PostInvocationAttribute findPostInvocationAttribute( - Collection config) { + private static PostInvocationAttribute findPostInvocationAttribute(Collection config) { for (ConfigAttribute attribute : config) { if (attribute instanceof PostInvocationAttribute) { return (PostInvocationAttribute) attribute; } } - return null; } - private static PreInvocationAttribute findPreInvocationAttribute( - Collection config) { + private static PreInvocationAttribute findPreInvocationAttribute(Collection config) { for (ConfigAttribute attribute : config) { if (attribute instanceof PreInvocationAttribute) { return (PreInvocationAttribute) attribute; } } - return null; } + } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PrePostAnnotationSecurityMetadataSource.java b/core/src/main/java/org/springframework/security/access/prepost/PrePostAnnotationSecurityMetadataSource.java index c6797debf1..8f85e8ad78 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PrePostAnnotationSecurityMetadataSource.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PrePostAnnotationSecurityMetadataSource.java @@ -13,23 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import java.lang.annotation.Annotation; import java.lang.reflect.Method; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.method.AbstractMethodSecurityMetadataSource; import org.springframework.util.ClassUtils; /** - * MethodSecurityMetadataSource which extracts metadata from the @PreFilter and @PreAuthorize - * annotations placed on a method. This class is merely responsible for locating the - * relevant annotations (if any). It delegates the actual ConfigAttribute - * creation to its {@link PrePostInvocationAttributeFactory}, thus decoupling itself from - * the mechanism which will enforce the annotations' behaviour. + * MethodSecurityMetadataSource which extracts metadata from the @PreFilter + * and @PreAuthorize annotations placed on a method. This class is merely responsible for + * locating the relevant annotations (if any). It delegates the actual + * ConfigAttribute creation to its {@link PrePostInvocationAttributeFactory}, + * thus decoupling itself from the mechanism which will enforce the annotations' + * behaviour. *

      * Annotations may be specified on classes or methods, and method-specific annotations * will take precedence. If you use any annotation and do not specify a pre-authorization @@ -40,71 +45,56 @@ import org.springframework.util.ClassUtils; * combine annotations defined in multiple locations for a single method - they may be * defined on the method itself, or at interface or class level. * - * @see PreInvocationAuthorizationAdviceVoter - * * @author Luke Taylor * @since 3.0 + * @see PreInvocationAuthorizationAdviceVoter */ -public class PrePostAnnotationSecurityMetadataSource extends - AbstractMethodSecurityMetadataSource { +public class PrePostAnnotationSecurityMetadataSource extends AbstractMethodSecurityMetadataSource { private final PrePostInvocationAttributeFactory attributeFactory; - public PrePostAnnotationSecurityMetadataSource( - PrePostInvocationAttributeFactory attributeFactory) { + public PrePostAnnotationSecurityMetadataSource(PrePostInvocationAttributeFactory attributeFactory) { this.attributeFactory = attributeFactory; } + @Override public Collection getAttributes(Method method, Class targetClass) { if (method.getDeclaringClass() == Object.class) { return Collections.emptyList(); } - - logger.trace("Looking for Pre/Post annotations for method '" + method.getName() - + "' on target class '" + targetClass + "'"); + this.logger.trace(LogMessage.format("Looking for Pre/Post annotations for method '%s' on target class '%s'", + method.getName(), targetClass)); PreFilter preFilter = findAnnotation(method, targetClass, PreFilter.class); - PreAuthorize preAuthorize = findAnnotation(method, targetClass, - PreAuthorize.class); + PreAuthorize preAuthorize = findAnnotation(method, targetClass, PreAuthorize.class); PostFilter postFilter = findAnnotation(method, targetClass, PostFilter.class); // TODO: Can we check for void methods and throw an exception here? - PostAuthorize postAuthorize = findAnnotation(method, targetClass, - PostAuthorize.class); - - if (preFilter == null && preAuthorize == null && postFilter == null - && postAuthorize == null) { + PostAuthorize postAuthorize = findAnnotation(method, targetClass, PostAuthorize.class); + if (preFilter == null && preAuthorize == null && postFilter == null && postAuthorize == null) { // There is no meta-data so return - logger.trace("No expression annotations found"); + this.logger.trace("No expression annotations found"); return Collections.emptyList(); } - - String preFilterAttribute = preFilter == null ? null : preFilter.value(); - String filterObject = preFilter == null ? null : preFilter.filterTarget(); - String preAuthorizeAttribute = preAuthorize == null ? null : preAuthorize.value(); - String postFilterAttribute = postFilter == null ? null : postFilter.value(); - String postAuthorizeAttribute = postAuthorize == null ? null : postAuthorize - .value(); - + String preFilterAttribute = (preFilter != null) ? preFilter.value() : null; + String filterObject = (preFilter != null) ? preFilter.filterTarget() : null; + String preAuthorizeAttribute = (preAuthorize != null) ? preAuthorize.value() : null; + String postFilterAttribute = (postFilter != null) ? postFilter.value() : null; + String postAuthorizeAttribute = (postAuthorize != null) ? postAuthorize.value() : null; ArrayList attrs = new ArrayList<>(2); - - PreInvocationAttribute pre = attributeFactory.createPreInvocationAttribute( - preFilterAttribute, filterObject, preAuthorizeAttribute); - + PreInvocationAttribute pre = this.attributeFactory.createPreInvocationAttribute(preFilterAttribute, + filterObject, preAuthorizeAttribute); if (pre != null) { attrs.add(pre); } - - PostInvocationAttribute post = attributeFactory.createPostInvocationAttribute( - postFilterAttribute, postAuthorizeAttribute); - + PostInvocationAttribute post = this.attributeFactory.createPostInvocationAttribute(postFilterAttribute, + postAuthorizeAttribute); if (post != null) { attrs.add(post); } - attrs.trimToSize(); - return attrs; } + @Override public Collection getAllConfigAttributes() { return null; } @@ -115,40 +105,32 @@ public class PrePostAnnotationSecurityMetadataSource extends * for the logic of this method. The ordering here is slightly different in that we * consider method-specific annotations on an interface before class-level ones. */ - private A findAnnotation(Method method, Class targetClass, - Class annotationClass) { + private A findAnnotation(Method method, Class targetClass, Class annotationClass) { // The method may be on an interface, but we need attributes from the target // class. // If the target class is null, the method will be unchanged. Method specificMethod = ClassUtils.getMostSpecificMethod(method, targetClass); A annotation = AnnotationUtils.findAnnotation(specificMethod, annotationClass); - if (annotation != null) { - logger.debug(annotation + " found on specific method: " + specificMethod); + this.logger.debug(LogMessage.format("%s found on specific method: %s", annotation, specificMethod)); return annotation; } - // Check the original (e.g. interface) method if (specificMethod != method) { annotation = AnnotationUtils.findAnnotation(method, annotationClass); - if (annotation != null) { - logger.debug(annotation + " found on: " + method); + this.logger.debug(LogMessage.format("%s found on: %s", annotation, method)); return annotation; } } - // Check the class-level (note declaringClass, not targetClass, which may not // actually implement the method) - annotation = AnnotationUtils.findAnnotation(specificMethod.getDeclaringClass(), - annotationClass); - + annotation = AnnotationUtils.findAnnotation(specificMethod.getDeclaringClass(), annotationClass); if (annotation != null) { - logger.debug(annotation + " found on: " - + specificMethod.getDeclaringClass().getName()); + this.logger.debug( + LogMessage.format("%s found on: %s", annotation, specificMethod.getDeclaringClass().getName())); return annotation; } - return null; } diff --git a/core/src/main/java/org/springframework/security/access/prepost/PrePostInvocationAttributeFactory.java b/core/src/main/java/org/springframework/security/access/prepost/PrePostInvocationAttributeFactory.java index bd1716555a..1feaba1395 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/PrePostInvocationAttributeFactory.java +++ b/core/src/main/java/org/springframework/security/access/prepost/PrePostInvocationAttributeFactory.java @@ -13,20 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.prepost; import org.springframework.aop.framework.AopInfrastructureBean; /** - * * @author Luke Taylor * @since 3.0 */ public interface PrePostInvocationAttributeFactory extends AopInfrastructureBean { - PreInvocationAttribute createPreInvocationAttribute(String preFilterAttribute, - String filterObject, String preAuthorizeAttribute); + PreInvocationAttribute createPreInvocationAttribute(String preFilterAttribute, String filterObject, + String preAuthorizeAttribute); + + PostInvocationAttribute createPostInvocationAttribute(String postFilterAttribute, String postAuthorizeAttribute); - PostInvocationAttribute createPostInvocationAttribute(String postFilterAttribute, - String postAuthorizeAttribute); } diff --git a/core/src/main/java/org/springframework/security/access/prepost/package-info.java b/core/src/main/java/org/springframework/security/access/prepost/package-info.java index 8db4a0005b..6fc3a3a8f9 100644 --- a/core/src/main/java/org/springframework/security/access/prepost/package-info.java +++ b/core/src/main/java/org/springframework/security/access/prepost/package-info.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Contains the infrastructure classes for handling the {@code @PreAuthorize}, {@code @PreFilter}, {@code @PostAuthorize} - * and {@code @PostFilter} annotations. + * Contains the infrastructure classes for handling the {@code @PreAuthorize}, + * {@code @PreFilter}, {@code @PostAuthorize} and {@code @PostFilter} annotations. *

      - * Other than the annotations themselves, the classes should be regarded as for internal framework use and - * are liable to change without notice. + * Other than the annotations themselves, the classes should be regarded as for internal + * framework use and are liable to change without notice. */ package org.springframework.security.access.prepost; - diff --git a/core/src/main/java/org/springframework/security/access/vote/AbstractAccessDecisionManager.java b/core/src/main/java/org/springframework/security/access/vote/AbstractAccessDecisionManager.java index 3cec181713..50d835df37 100644 --- a/core/src/main/java/org/springframework/security/access/vote/AbstractAccessDecisionManager.java +++ b/core/src/main/java/org/springframework/security/access/vote/AbstractAccessDecisionManager.java @@ -20,15 +20,16 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.MessageSource; +import org.springframework.context.MessageSourceAware; +import org.springframework.context.support.MessageSourceAccessor; import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.SpringSecurityMessageSource; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.context.MessageSource; -import org.springframework.context.MessageSourceAware; -import org.springframework.context.support.MessageSourceAccessor; import org.springframework.util.Assert; /** @@ -39,10 +40,9 @@ import org.springframework.util.Assert; * and the access control behaviour if all voters abstain from voting (defaults to deny * access). */ -public abstract class AbstractAccessDecisionManager implements AccessDecisionManager, - InitializingBean, MessageSourceAware { - // ~ Instance fields - // ================================================================================================ +public abstract class AbstractAccessDecisionManager + implements AccessDecisionManager, InitializingBean, MessageSourceAware { + protected final Log logger = LogFactory.getLog(getClass()); private List> decisionVoters; @@ -51,15 +51,12 @@ public abstract class AbstractAccessDecisionManager implements AccessDecisionMan private boolean allowIfAllAbstainDecisions = false; - protected AbstractAccessDecisionManager( - List> decisionVoters) { + protected AbstractAccessDecisionManager(List> decisionVoters) { Assert.notEmpty(decisionVoters, "A list of AccessDecisionVoters is required"); this.decisionVoters = decisionVoters; } - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { Assert.notEmpty(this.decisionVoters, "A list of AccessDecisionVoters is required"); Assert.notNull(this.messages, "A message source must be set"); @@ -67,8 +64,8 @@ public abstract class AbstractAccessDecisionManager implements AccessDecisionMan protected final void checkAllowIfAllAbstainDecisions() { if (!this.isAllowIfAllAbstainDecisions()) { - throw new AccessDeniedException(messages.getMessage( - "AbstractAccessDecisionManager.accessDenied", "Access is denied")); + throw new AccessDeniedException( + this.messages.getMessage("AbstractAccessDecisionManager.accessDenied", "Access is denied")); } } @@ -77,24 +74,25 @@ public abstract class AbstractAccessDecisionManager implements AccessDecisionMan } public boolean isAllowIfAllAbstainDecisions() { - return allowIfAllAbstainDecisions; + return this.allowIfAllAbstainDecisions; } public void setAllowIfAllAbstainDecisions(boolean allowIfAllAbstainDecisions) { this.allowIfAllAbstainDecisions = allowIfAllAbstainDecisions; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } + @Override public boolean supports(ConfigAttribute attribute) { - for (AccessDecisionVoter voter : this.decisionVoters) { + for (AccessDecisionVoter voter : this.decisionVoters) { if (voter.supports(attribute)) { return true; } } - return false; } @@ -104,17 +102,17 @@ public abstract class AbstractAccessDecisionManager implements AccessDecisionMan *

      * If one or more voters cannot support the presented class, false is * returned. - * * @param clazz the type of secured object being presented * @return true if this type is supported */ + @Override public boolean supports(Class clazz) { - for (AccessDecisionVoter voter : this.decisionVoters) { + for (AccessDecisionVoter voter : this.decisionVoters) { if (!voter.supports(clazz)) { return false; } } - return true; } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/AbstractAclVoter.java b/core/src/main/java/org/springframework/security/access/vote/AbstractAclVoter.java index 67fbe73bef..5e1fd32180 100644 --- a/core/src/main/java/org/springframework/security/access/vote/AbstractAclVoter.java +++ b/core/src/main/java/org/springframework/security/access/vote/AbstractAclVoter.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.vote; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AuthorizationServiceException; import org.springframework.util.Assert; @@ -27,51 +29,40 @@ import org.springframework.util.Assert; * @author Ben Alex */ public abstract class AbstractAclVoter implements AccessDecisionVoter { - // ~ Instance fields - // ================================================================================================ private Class processDomainObjectClass; - // ~ Methods - // ======================================================================================================== - protected Object getDomainObjectInstance(MethodInvocation invocation) { - Object[] args; - Class[] params; - - params = invocation.getMethod().getParameterTypes(); - args = invocation.getArguments(); - + Object[] args = invocation.getArguments(); + Class[] params = invocation.getMethod().getParameterTypes(); for (int i = 0; i < params.length; i++) { - if (processDomainObjectClass.isAssignableFrom(params[i])) { + if (this.processDomainObjectClass.isAssignableFrom(params[i])) { return args[i]; } } - throw new AuthorizationServiceException("MethodInvocation: " + invocation - + " did not provide any argument of type: " + processDomainObjectClass); + + " did not provide any argument of type: " + this.processDomainObjectClass); } public Class getProcessDomainObjectClass() { - return processDomainObjectClass; + return this.processDomainObjectClass; } public void setProcessDomainObjectClass(Class processDomainObjectClass) { - Assert.notNull(processDomainObjectClass, - "processDomainObjectClass cannot be set to null"); + Assert.notNull(processDomainObjectClass, "processDomainObjectClass cannot be set to null"); this.processDomainObjectClass = processDomainObjectClass; } /** * This implementation supports only MethodSecurityInterceptor, because * it queries the presented MethodInvocation. - * * @param clazz the secure object - * * @return true if the secure object is MethodInvocation, * false otherwise */ + @Override public boolean supports(Class clazz) { return (MethodInvocation.class.isAssignableFrom(clazz)); } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/AffirmativeBased.java b/core/src/main/java/org/springframework/security/access/vote/AffirmativeBased.java index 30f27da022..718ed9f15f 100644 --- a/core/src/main/java/org/springframework/security/access/vote/AffirmativeBased.java +++ b/core/src/main/java/org/springframework/security/access/vote/AffirmativeBased.java @@ -16,8 +16,10 @@ package org.springframework.security.access.vote; -import java.util.*; +import java.util.Collection; +import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; @@ -34,9 +36,6 @@ public class AffirmativeBased extends AbstractAccessDecisionManager { super(decisionVoters); } - // ~ Methods - // ======================================================================================================== - /** * This concrete implementation simply polls all configured * {@link AccessDecisionVoter}s and grants access if any @@ -47,45 +46,36 @@ public class AffirmativeBased extends AbstractAccessDecisionManager { * be based on the {@link #isAllowIfAllAbstainDecisions()} property (defaults to * false). *

      - * * @param authentication the caller invoking the method * @param object the secured object * @param configAttributes the configuration attributes associated with the method * being invoked - * * @throws AccessDeniedException if access is denied */ - public void decide(Authentication authentication, Object object, - Collection configAttributes) throws AccessDeniedException { + @Override + @SuppressWarnings({ "rawtypes", "unchecked" }) + public void decide(Authentication authentication, Object object, Collection configAttributes) + throws AccessDeniedException { int deny = 0; - for (AccessDecisionVoter voter : getDecisionVoters()) { int result = voter.vote(authentication, object, configAttributes); - - if (logger.isDebugEnabled()) { - logger.debug("Voter: " + voter + ", returned: " + result); - } - + this.logger.debug(LogMessage.format("Voter: %s, returned: %s", voter, result)); switch (result) { case AccessDecisionVoter.ACCESS_GRANTED: return; - case AccessDecisionVoter.ACCESS_DENIED: deny++; - break; - default: break; } } - if (deny > 0) { - throw new AccessDeniedException(messages.getMessage( - "AbstractAccessDecisionManager.accessDenied", "Access is denied")); + throw new AccessDeniedException( + this.messages.getMessage("AbstractAccessDecisionManager.accessDenied", "Access is denied")); } - // To get this far, every AccessDecisionVoter abstained checkAllowIfAllAbstainDecisions(); } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/AuthenticatedVoter.java b/core/src/main/java/org/springframework/security/access/vote/AuthenticatedVoter.java index beb357b1eb..eec33f2d53 100644 --- a/core/src/main/java/org/springframework/security/access/vote/AuthenticatedVoter.java +++ b/core/src/main/java/org/springframework/security/access/vote/AuthenticatedVoter.java @@ -47,87 +47,70 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class AuthenticatedVoter implements AccessDecisionVoter { - // ~ Static fields/initializers - // ===================================================================================== public static final String IS_AUTHENTICATED_FULLY = "IS_AUTHENTICATED_FULLY"; + public static final String IS_AUTHENTICATED_REMEMBERED = "IS_AUTHENTICATED_REMEMBERED"; + public static final String IS_AUTHENTICATED_ANONYMOUSLY = "IS_AUTHENTICATED_ANONYMOUSLY"; - // ~ Instance fields - // ================================================================================================ private AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); - // ~ Methods - // ======================================================================================================== - private boolean isFullyAuthenticated(Authentication authentication) { - return (!authenticationTrustResolver.isAnonymous(authentication) && !authenticationTrustResolver - .isRememberMe(authentication)); + return (!this.authenticationTrustResolver.isAnonymous(authentication) + && !this.authenticationTrustResolver.isRememberMe(authentication)); } - public void setAuthenticationTrustResolver( - AuthenticationTrustResolver authenticationTrustResolver) { - Assert.notNull(authenticationTrustResolver, - "AuthenticationTrustResolver cannot be set to null"); + public void setAuthenticationTrustResolver(AuthenticationTrustResolver authenticationTrustResolver) { + Assert.notNull(authenticationTrustResolver, "AuthenticationTrustResolver cannot be set to null"); this.authenticationTrustResolver = authenticationTrustResolver; } + @Override public boolean supports(ConfigAttribute attribute) { - if ((attribute.getAttribute() != null) - && (IS_AUTHENTICATED_FULLY.equals(attribute.getAttribute()) - || IS_AUTHENTICATED_REMEMBERED.equals(attribute.getAttribute()) || IS_AUTHENTICATED_ANONYMOUSLY - .equals(attribute.getAttribute()))) { - return true; - } - else { - return false; - } + return (attribute.getAttribute() != null) && (IS_AUTHENTICATED_FULLY.equals(attribute.getAttribute()) + || IS_AUTHENTICATED_REMEMBERED.equals(attribute.getAttribute()) + || IS_AUTHENTICATED_ANONYMOUSLY.equals(attribute.getAttribute())); } /** * This implementation supports any type of class, because it does not query the * presented secure object. - * * @param clazz the secure object type - * * @return always {@code true} */ + @Override public boolean supports(Class clazz) { return true; } - public int vote(Authentication authentication, Object object, - Collection attributes) { + @Override + public int vote(Authentication authentication, Object object, Collection attributes) { int result = ACCESS_ABSTAIN; - for (ConfigAttribute attribute : attributes) { if (this.supports(attribute)) { result = ACCESS_DENIED; - if (IS_AUTHENTICATED_FULLY.equals(attribute.getAttribute())) { if (isFullyAuthenticated(authentication)) { return ACCESS_GRANTED; } } - if (IS_AUTHENTICATED_REMEMBERED.equals(attribute.getAttribute())) { - if (authenticationTrustResolver.isRememberMe(authentication) + if (this.authenticationTrustResolver.isRememberMe(authentication) || isFullyAuthenticated(authentication)) { return ACCESS_GRANTED; } } - if (IS_AUTHENTICATED_ANONYMOUSLY.equals(attribute.getAttribute())) { - if (authenticationTrustResolver.isAnonymous(authentication) + if (this.authenticationTrustResolver.isAnonymous(authentication) || isFullyAuthenticated(authentication) - || authenticationTrustResolver.isRememberMe(authentication)) { + || this.authenticationTrustResolver.isRememberMe(authentication)) { return ACCESS_GRANTED; } } } } - return result; } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/ConsensusBased.java b/core/src/main/java/org/springframework/security/access/vote/ConsensusBased.java index 2b4be013a3..3a98a61a49 100644 --- a/core/src/main/java/org/springframework/security/access/vote/ConsensusBased.java +++ b/core/src/main/java/org/springframework/security/access/vote/ConsensusBased.java @@ -16,8 +16,10 @@ package org.springframework.security.access.vote; -import java.util.*; +import java.util.Collection; +import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; @@ -33,8 +35,6 @@ import org.springframework.security.core.Authentication; * {@link UnanimousBased}. */ public class ConsensusBased extends AbstractAccessDecisionManager { - // ~ Instance fields - // ================================================================================================ private boolean allowIfEqualGrantedDeniedDecisions = true; @@ -42,9 +42,6 @@ public class ConsensusBased extends AbstractAccessDecisionManager { super(decisionVoters); } - // ~ Methods - // ======================================================================================================== - /** * This concrete implementation simply polls all configured * {@link AccessDecisionVoter}s and upon completion determines the consensus of @@ -56,71 +53,56 @@ public class ConsensusBased extends AbstractAccessDecisionManager { * If every AccessDecisionVoter abstained from voting, the decision will * be based on the {@link #isAllowIfAllAbstainDecisions()} property (defaults to * false). - * * @param authentication the caller invoking the method * @param object the secured object * @param configAttributes the configuration attributes associated with the method * being invoked - * * @throws AccessDeniedException if access is denied */ - public void decide(Authentication authentication, Object object, - Collection configAttributes) throws AccessDeniedException { + @Override + @SuppressWarnings({ "rawtypes", "unchecked" }) + public void decide(Authentication authentication, Object object, Collection configAttributes) + throws AccessDeniedException { int grant = 0; int deny = 0; - for (AccessDecisionVoter voter : getDecisionVoters()) { int result = voter.vote(authentication, object, configAttributes); - - if (logger.isDebugEnabled()) { - logger.debug("Voter: " + voter + ", returned: " + result); - } - + this.logger.debug(LogMessage.format("Voter: %s, returned: %s", voter, result)); switch (result) { case AccessDecisionVoter.ACCESS_GRANTED: grant++; - break; - case AccessDecisionVoter.ACCESS_DENIED: deny++; - break; - default: break; } } - if (grant > deny) { return; } - if (deny > grant) { - throw new AccessDeniedException(messages.getMessage( - "AbstractAccessDecisionManager.accessDenied", "Access is denied")); + throw new AccessDeniedException( + this.messages.getMessage("AbstractAccessDecisionManager.accessDenied", "Access is denied")); } - if ((grant == deny) && (grant != 0)) { if (this.allowIfEqualGrantedDeniedDecisions) { return; } - else { - throw new AccessDeniedException(messages.getMessage( - "AbstractAccessDecisionManager.accessDenied", "Access is denied")); - } + throw new AccessDeniedException( + this.messages.getMessage("AbstractAccessDecisionManager.accessDenied", "Access is denied")); } - // To get this far, every AccessDecisionVoter abstained checkAllowIfAllAbstainDecisions(); } public boolean isAllowIfEqualGrantedDeniedDecisions() { - return allowIfEqualGrantedDeniedDecisions; + return this.allowIfEqualGrantedDeniedDecisions; } - public void setAllowIfEqualGrantedDeniedDecisions( - boolean allowIfEqualGrantedDeniedDecisions) { + public void setAllowIfEqualGrantedDeniedDecisions(boolean allowIfEqualGrantedDeniedDecisions) { this.allowIfEqualGrantedDeniedDecisions = allowIfEqualGrantedDeniedDecisions; } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/RoleHierarchyVoter.java b/core/src/main/java/org/springframework/security/access/vote/RoleHierarchyVoter.java index 32e1598e4d..6dc1cbcb6d 100644 --- a/core/src/main/java/org/springframework/security/access/vote/RoleHierarchyVoter.java +++ b/core/src/main/java/org/springframework/security/access/vote/RoleHierarchyVoter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.vote; import java.util.Collection; @@ -30,6 +31,7 @@ import org.springframework.util.Assert; * @since 2.0.4 */ public class RoleHierarchyVoter extends RoleVoter { + private RoleHierarchy roleHierarchy = null; public RoleHierarchyVoter(RoleHierarchy roleHierarchy) { @@ -41,9 +43,8 @@ public class RoleHierarchyVoter extends RoleVoter { * Calls the RoleHierarchy to obtain the complete set of user authorities. */ @Override - Collection extractAuthorities( - Authentication authentication) { - return roleHierarchy.getReachableGrantedAuthorities(authentication - .getAuthorities()); + Collection extractAuthorities(Authentication authentication) { + return this.roleHierarchy.getReachableGrantedAuthorities(authentication.getAuthorities()); } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/RoleVoter.java b/core/src/main/java/org/springframework/security/access/vote/RoleVoter.java index e4d1bd6319..546507df12 100644 --- a/core/src/main/java/org/springframework/security/access/vote/RoleVoter.java +++ b/core/src/main/java/org/springframework/security/access/vote/RoleVoter.java @@ -50,62 +50,48 @@ import org.springframework.security.core.GrantedAuthority; * @author colin sampaleanu */ public class RoleVoter implements AccessDecisionVoter { - // ~ Instance fields - // ================================================================================================ private String rolePrefix = "ROLE_"; - // ~ Methods - // ======================================================================================================== - public String getRolePrefix() { - return rolePrefix; + return this.rolePrefix; } /** * Allows the default role prefix of ROLE_ to be overridden. May be set * to an empty value, although this is usually not desirable. - * * @param rolePrefix the new prefix */ public void setRolePrefix(String rolePrefix) { this.rolePrefix = rolePrefix; } + @Override public boolean supports(ConfigAttribute attribute) { - if ((attribute.getAttribute() != null) - && attribute.getAttribute().startsWith(getRolePrefix())) { - return true; - } - else { - return false; - } + return (attribute.getAttribute() != null) && attribute.getAttribute().startsWith(getRolePrefix()); } /** * This implementation supports any type of class, because it does not query the * presented secure object. - * * @param clazz the secure object - * * @return always true */ + @Override public boolean supports(Class clazz) { return true; } - public int vote(Authentication authentication, Object object, - Collection attributes) { + @Override + public int vote(Authentication authentication, Object object, Collection attributes) { if (authentication == null) { return ACCESS_DENIED; } int result = ACCESS_ABSTAIN; Collection authorities = extractAuthorities(authentication); - for (ConfigAttribute attribute : attributes) { if (this.supports(attribute)) { result = ACCESS_DENIED; - // Attempt to find a matching granted authority for (GrantedAuthority authority : authorities) { if (attribute.getAttribute().equals(authority.getAuthority())) { @@ -114,12 +100,11 @@ public class RoleVoter implements AccessDecisionVoter { } } } - return result; } - Collection extractAuthorities( - Authentication authentication) { + Collection extractAuthorities(Authentication authentication) { return authentication.getAuthorities(); } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/UnanimousBased.java b/core/src/main/java/org/springframework/security/access/vote/UnanimousBased.java index af7f8eae9d..edadbb9f49 100644 --- a/core/src/main/java/org/springframework/security/access/vote/UnanimousBased.java +++ b/core/src/main/java/org/springframework/security/access/vote/UnanimousBased.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; @@ -36,9 +37,6 @@ public class UnanimousBased extends AbstractAccessDecisionManager { super(decisionVoters); } - // ~ Methods - // ======================================================================================================== - /** * This concrete implementation polls all configured {@link AccessDecisionVoter}s for * each {@link ConfigAttribute} and grants access if only grant (or abstain) @@ -52,55 +50,42 @@ public class UnanimousBased extends AbstractAccessDecisionManager { * If every AccessDecisionVoter abstained from voting, the decision will * be based on the {@link #isAllowIfAllAbstainDecisions()} property (defaults to * false). - * * @param authentication the caller invoking the method * @param object the secured object * @param attributes the configuration attributes associated with the method being * invoked - * * @throws AccessDeniedException if access is denied */ - public void decide(Authentication authentication, Object object, - Collection attributes) throws AccessDeniedException { - + @Override + @SuppressWarnings({ "rawtypes", "unchecked" }) + public void decide(Authentication authentication, Object object, Collection attributes) + throws AccessDeniedException { int grant = 0; - List singleAttributeList = new ArrayList<>(1); singleAttributeList.add(null); - for (ConfigAttribute attribute : attributes) { singleAttributeList.set(0, attribute); - for (AccessDecisionVoter voter : getDecisionVoters()) { int result = voter.vote(authentication, object, singleAttributeList); - - if (logger.isDebugEnabled()) { - logger.debug("Voter: " + voter + ", returned: " + result); - } - + this.logger.debug(LogMessage.format("Voter: %s, returned: %s", voter, result)); switch (result) { case AccessDecisionVoter.ACCESS_GRANTED: grant++; - break; - case AccessDecisionVoter.ACCESS_DENIED: - throw new AccessDeniedException(messages.getMessage( - "AbstractAccessDecisionManager.accessDenied", - "Access is denied")); - + throw new AccessDeniedException( + this.messages.getMessage("AbstractAccessDecisionManager.accessDenied", "Access is denied")); default: break; } } } - // To get this far, there were no deny votes if (grant > 0) { return; } - // To get this far, every AccessDecisionVoter abstained checkAllowIfAllAbstainDecisions(); } + } diff --git a/core/src/main/java/org/springframework/security/access/vote/package-info.java b/core/src/main/java/org/springframework/security/access/vote/package-info.java index 6c8cf27fea..2ec2ecedb2 100644 --- a/core/src/main/java/org/springframework/security/access/vote/package-info.java +++ b/core/src/main/java/org/springframework/security/access/vote/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Implements a vote-based approach to authorization decisions. */ package org.springframework.security.access.vote; - diff --git a/core/src/main/java/org/springframework/security/authentication/AbstractAuthenticationToken.java b/core/src/main/java/org/springframework/security/authentication/AbstractAuthenticationToken.java index c3b017b7c9..422ce529d3 100644 --- a/core/src/main/java/org/springframework/security/authentication/AbstractAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/authentication/AbstractAuthenticationToken.java @@ -21,12 +21,13 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.CredentialsContainer; import org.springframework.security.core.AuthenticatedPrincipal; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.CredentialsContainer; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.util.Assert; /** * Base class for Authentication objects. @@ -36,49 +37,36 @@ import org.springframework.security.core.userdetails.UserDetails; * @author Ben Alex * @author Luke Taylor */ -public abstract class AbstractAuthenticationToken implements Authentication, - CredentialsContainer { - // ~ Instance fields - // ================================================================================================ +public abstract class AbstractAuthenticationToken implements Authentication, CredentialsContainer { private final Collection authorities; - private Object details; - private boolean authenticated = false; - // ~ Constructors - // =================================================================================================== + private Object details; + + private boolean authenticated = false; /** * Creates a token with the supplied array of authorities. - * * @param authorities the collection of GrantedAuthoritys for the principal - * represented by this authentication object. + * represented by this authentication object. */ public AbstractAuthenticationToken(Collection authorities) { if (authorities == null) { this.authorities = AuthorityUtils.NO_AUTHORITIES; return; } - for (GrantedAuthority a : authorities) { - if (a == null) { - throw new IllegalArgumentException( - "Authorities collection cannot contain any null elements"); - } + Assert.notNull(a, "Authorities collection cannot contain any null elements"); } - ArrayList temp = new ArrayList<>( - authorities.size()); - temp.addAll(authorities); - this.authorities = Collections.unmodifiableList(temp); + this.authorities = Collections.unmodifiableList(new ArrayList<>(authorities)); } - // ~ Methods - // ======================================================================================================== - + @Override public Collection getAuthorities() { - return authorities; + return this.authorities; } + @Override public String getName() { if (this.getPrincipal() instanceof UserDetails) { return ((UserDetails) this.getPrincipal()).getUsername(); @@ -89,20 +77,22 @@ public abstract class AbstractAuthenticationToken implements Authentication, if (this.getPrincipal() instanceof Principal) { return ((Principal) this.getPrincipal()).getName(); } - return (this.getPrincipal() == null) ? "" : this.getPrincipal().toString(); } + @Override public boolean isAuthenticated() { - return authenticated; + return this.authenticated; } + @Override public void setAuthenticated(boolean authenticated) { this.authenticated = authenticated; } + @Override public Object getDetails() { - return details; + return this.details; } public void setDetails(Object details) { @@ -114,10 +104,11 @@ public abstract class AbstractAuthenticationToken implements Authentication, * invoking the {@code eraseCredentials} method on any which implement * {@link CredentialsContainer}. */ + @Override public void eraseCredentials() { eraseSecret(getCredentials()); eraseSecret(getPrincipal()); - eraseSecret(details); + eraseSecret(this.details); } private void eraseSecret(Object secret) { @@ -131,70 +122,52 @@ public abstract class AbstractAuthenticationToken implements Authentication, if (!(obj instanceof AbstractAuthenticationToken)) { return false; } - AbstractAuthenticationToken test = (AbstractAuthenticationToken) obj; - - if (!authorities.equals(test.authorities)) { + if (!this.authorities.equals(test.authorities)) { return false; } - if ((this.details == null) && (test.getDetails() != null)) { return false; } - if ((this.details != null) && (test.getDetails() == null)) { return false; } - if ((this.details != null) && (!this.details.equals(test.getDetails()))) { return false; } - if ((this.getCredentials() == null) && (test.getCredentials() != null)) { return false; } - - if ((this.getCredentials() != null) - && !this.getCredentials().equals(test.getCredentials())) { + if ((this.getCredentials() != null) && !this.getCredentials().equals(test.getCredentials())) { return false; } - if (this.getPrincipal() == null && test.getPrincipal() != null) { return false; } - - if (this.getPrincipal() != null - && !this.getPrincipal().equals(test.getPrincipal())) { + if (this.getPrincipal() != null && !this.getPrincipal().equals(test.getPrincipal())) { return false; } - return this.isAuthenticated() == test.isAuthenticated(); } @Override public int hashCode() { int code = 31; - - for (GrantedAuthority authority : authorities) { + for (GrantedAuthority authority : this.authorities) { code ^= authority.hashCode(); } - if (this.getPrincipal() != null) { code ^= this.getPrincipal().hashCode(); } - if (this.getCredentials() != null) { code ^= this.getCredentials().hashCode(); } - if (this.getDetails() != null) { code ^= this.getDetails().hashCode(); } - if (this.isAuthenticated()) { code ^= -37; } - return code; } @@ -206,22 +179,20 @@ public abstract class AbstractAuthenticationToken implements Authentication, sb.append("Credentials: [PROTECTED]; "); sb.append("Authenticated: ").append(this.isAuthenticated()).append("; "); sb.append("Details: ").append(this.getDetails()).append("; "); - - if (!authorities.isEmpty()) { + if (!this.authorities.isEmpty()) { sb.append("Granted Authorities: "); - int i = 0; - for (GrantedAuthority authority : authorities) { + for (GrantedAuthority authority : this.authorities) { if (i++ > 0) { sb.append(", "); } - sb.append(authority); } - } else { + } + else { sb.append("Not granted any authorities"); } - return sb.toString(); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AbstractUserDetailsReactiveAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/AbstractUserDetailsReactiveAuthenticationManager.java index cbc6e3b71b..4864ea12ee 100644 --- a/core/src/main/java/org/springframework/security/authentication/AbstractUserDetailsReactiveAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/AbstractUserDetailsReactiveAuthenticationManager.java @@ -33,8 +33,8 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.util.Assert; /** - * A base {@link ReactiveAuthenticationManager} that allows subclasses to override and work with - * {@link UserDetails} objects. + * A base {@link ReactiveAuthenticationManager} that allows subclasses to override and + * work with {@link UserDetails} objects. * *

      * Upon successful validation, a UsernamePasswordAuthenticationToken will be @@ -57,67 +57,70 @@ public abstract class AbstractUserDetailsReactiveAuthenticationManager implement private Scheduler scheduler = Schedulers.boundedElastic(); - private UserDetailsChecker preAuthenticationChecks = user -> { - if (!user.isAccountNonLocked()) { - logger.debug("User account is locked"); + private UserDetailsChecker preAuthenticationChecks = this::defaultPreAuthenticationChecks; - throw new LockedException(this.messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.locked", + private UserDetailsChecker postAuthenticationChecks = this::defaultPostAuthenticationChecks; + + private void defaultPreAuthenticationChecks(UserDetails user) { + if (!user.isAccountNonLocked()) { + this.logger.debug("User account is locked"); + throw new LockedException(this.messages.getMessage("AbstractUserDetailsAuthenticationProvider.locked", "User account is locked")); } - if (!user.isEnabled()) { - logger.debug("User account is disabled"); - - throw new DisabledException(this.messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.disabled", - "User is disabled")); + this.logger.debug("User account is disabled"); + throw new DisabledException( + this.messages.getMessage("AbstractUserDetailsAuthenticationProvider.disabled", "User is disabled")); } - if (!user.isAccountNonExpired()) { - logger.debug("User account is expired"); - - throw new AccountExpiredException(this.messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.expired", - "User account has expired")); + this.logger.debug("User account is expired"); + throw new AccountExpiredException(this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.expired", "User account has expired")); } - }; + } - private UserDetailsChecker postAuthenticationChecks = user -> { + private void defaultPostAuthenticationChecks(UserDetails user) { if (!user.isCredentialsNonExpired()) { - logger.debug("User account credentials have expired"); - + this.logger.debug("User account credentials have expired"); throw new CredentialsExpiredException(this.messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.credentialsExpired", - "User credentials have expired")); + "AbstractUserDetailsAuthenticationProvider.credentialsExpired", "User credentials have expired")); } - }; + } @Override public Mono authenticate(Authentication authentication) { - final String username = authentication.getName(); - final String presentedPassword = (String) authentication.getCredentials(); + String username = authentication.getName(); + String presentedPassword = (String) authentication.getCredentials(); + // @formatter:off return retrieveUser(username) .doOnNext(this.preAuthenticationChecks::check) .publishOn(this.scheduler) - .filter(u -> this.passwordEncoder.matches(presentedPassword, u.getPassword())) + .filter((userDetails) -> this.passwordEncoder.matches(presentedPassword, userDetails.getPassword())) .switchIfEmpty(Mono.defer(() -> Mono.error(new BadCredentialsException("Invalid Credentials")))) - .flatMap(u -> { - boolean upgradeEncoding = this.userDetailsPasswordService != null - && this.passwordEncoder.upgradeEncoding(u.getPassword()); - if (upgradeEncoding) { - String newPassword = this.passwordEncoder.encode(presentedPassword); - return this.userDetailsPasswordService.updatePassword(u, newPassword); - } - return Mono.just(u); - }) + .flatMap((userDetails) -> upgradeEncodingIfNecessary(userDetails, presentedPassword)) .doOnNext(this.postAuthenticationChecks::check) - .map(u -> new UsernamePasswordAuthenticationToken(u, u.getPassword(), u.getAuthorities()) ); + .map(this::createUsernamePasswordAuthenticationToken); + // @formatter:on + } + + private Mono upgradeEncodingIfNecessary(UserDetails userDetails, String presentedPassword) { + boolean upgradeEncoding = this.userDetailsPasswordService != null + && this.passwordEncoder.upgradeEncoding(userDetails.getPassword()); + if (upgradeEncoding) { + String newPassword = this.passwordEncoder.encode(presentedPassword); + return this.userDetailsPasswordService.updatePassword(userDetails, newPassword); + } + return Mono.just(userDetails); + } + + private UsernamePasswordAuthenticationToken createUsernamePasswordAuthenticationToken(UserDetails userDetails) { + return new UsernamePasswordAuthenticationToken(userDetails, userDetails.getPassword(), + userDetails.getAuthorities()); } /** - * The {@link PasswordEncoder} that is used for validating the password. The default is - * {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()} + * The {@link PasswordEncoder} that is used for validating the password. The default + * is {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()} * @param passwordEncoder the {@link PasswordEncoder} to use. Cannot be null */ public void setPasswordEncoder(PasswordEncoder passwordEncoder) { @@ -126,13 +129,14 @@ public abstract class AbstractUserDetailsReactiveAuthenticationManager implement } /** - * Sets the {@link Scheduler} used by the {@link UserDetailsRepositoryReactiveAuthenticationManager}. - * The default is {@code Schedulers.newParallel(String)} because modern password encoding is - * a CPU intensive task that is non blocking. This means validation is bounded by the - * number of CPUs. Some applications may want to customize the {@link Scheduler}. For - * example, if users are stuck using the insecure {@link org.springframework.security.crypto.password.NoOpPasswordEncoder} - * they might want to leverage {@code Schedulers.immediate()}. - * + * Sets the {@link Scheduler} used by the + * {@link UserDetailsRepositoryReactiveAuthenticationManager}. The default is + * {@code Schedulers.newParallel(String)} because modern password encoding is a CPU + * intensive task that is non blocking. This means validation is bounded by the number + * of CPUs. Some applications may want to customize the {@link Scheduler}. For + * example, if users are stuck using the insecure + * {@link org.springframework.security.crypto.password.NoOpPasswordEncoder} they might + * want to leverage {@code Schedulers.immediate()}. * @param scheduler the {@link Scheduler} to use. Cannot be null. * @since 5.0.6 */ @@ -145,15 +149,13 @@ public abstract class AbstractUserDetailsReactiveAuthenticationManager implement * Sets the service to use for upgrading passwords on successful authentication. * @param userDetailsPasswordService the service to use */ - public void setUserDetailsPasswordService( - ReactiveUserDetailsPasswordService userDetailsPasswordService) { + public void setUserDetailsPasswordService(ReactiveUserDetailsPasswordService userDetailsPasswordService) { this.userDetailsPasswordService = userDetailsPasswordService; } /** * Sets the strategy which will be used to validate the loaded UserDetails * object after authentication occurs. - * * @param postAuthenticationChecks The {@link UserDetailsChecker} * @since 5.2 */ @@ -163,9 +165,8 @@ public abstract class AbstractUserDetailsReactiveAuthenticationManager implement } /** - * Allows subclasses to retrieve the UserDetails - * from an implementation-specific location. - * + * Allows subclasses to retrieve the UserDetails from an + * implementation-specific location. * @param username The username to retrieve * @return the user information. If authentication fails, a Mono error is returned. */ diff --git a/core/src/main/java/org/springframework/security/authentication/AccountExpiredException.java b/core/src/main/java/org/springframework/security/authentication/AccountExpiredException.java index 2636cf0931..e8ef659882 100644 --- a/core/src/main/java/org/springframework/security/authentication/AccountExpiredException.java +++ b/core/src/main/java/org/springframework/security/authentication/AccountExpiredException.java @@ -23,12 +23,9 @@ package org.springframework.security.authentication; * @author Ben Alex */ public class AccountExpiredException extends AccountStatusException { - // ~ Constructors - // =================================================================================================== /** * Constructs a AccountExpiredException with the specified message. - * * @param msg the detail message */ public AccountExpiredException(String msg) { @@ -38,11 +35,11 @@ public class AccountExpiredException extends AccountStatusException { /** * Constructs a AccountExpiredException with the specified message and * root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public AccountExpiredException(String msg, Throwable t) { - super(msg, t); + public AccountExpiredException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AccountStatusException.java b/core/src/main/java/org/springframework/security/authentication/AccountStatusException.java index bed91273a9..f465f995d6 100644 --- a/core/src/main/java/org/springframework/security/authentication/AccountStatusException.java +++ b/core/src/main/java/org/springframework/security/authentication/AccountStatusException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; import org.springframework.security.core.AuthenticationException; @@ -24,11 +25,13 @@ import org.springframework.security.core.AuthenticationException; * @author Luke Taylor */ public abstract class AccountStatusException extends AuthenticationException { + public AccountStatusException(String msg) { super(msg); } - public AccountStatusException(String msg, Throwable t) { - super(msg, t); + public AccountStatusException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AccountStatusUserDetailsChecker.java b/core/src/main/java/org/springframework/security/authentication/AccountStatusUserDetailsChecker.java index 89512454bb..9770a62111 100644 --- a/core/src/main/java/org/springframework/security/authentication/AccountStatusUserDetailsChecker.java +++ b/core/src/main/java/org/springframework/security/authentication/AccountStatusUserDetailsChecker.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; +import org.springframework.context.support.MessageSourceAccessor; import org.springframework.security.core.SpringSecurityMessageSource; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsChecker; -import org.springframework.context.support.MessageSourceAccessor; import org.springframework.util.Assert; /** @@ -28,30 +29,25 @@ import org.springframework.util.Assert; */ public class AccountStatusUserDetailsChecker implements UserDetailsChecker, MessageSourceAware { - protected MessageSourceAccessor messages = SpringSecurityMessageSource - .getAccessor(); + protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + @Override public void check(UserDetails user) { if (!user.isAccountNonLocked()) { - throw new LockedException(messages.getMessage( - "AccountStatusUserDetailsChecker.locked", "User account is locked")); + throw new LockedException( + this.messages.getMessage("AccountStatusUserDetailsChecker.locked", "User account is locked")); } - if (!user.isEnabled()) { - throw new DisabledException(messages.getMessage( - "AccountStatusUserDetailsChecker.disabled", "User is disabled")); + throw new DisabledException( + this.messages.getMessage("AccountStatusUserDetailsChecker.disabled", "User is disabled")); } - if (!user.isAccountNonExpired()) { throw new AccountExpiredException( - messages.getMessage("AccountStatusUserDetailsChecker.expired", - "User account has expired")); + this.messages.getMessage("AccountStatusUserDetailsChecker.expired", "User account has expired")); } - if (!user.isCredentialsNonExpired()) { - throw new CredentialsExpiredException(messages.getMessage( - "AccountStatusUserDetailsChecker.credentialsExpired", - "User credentials have expired")); + throw new CredentialsExpiredException(this.messages + .getMessage("AccountStatusUserDetailsChecker.credentialsExpired", "User credentials have expired")); } } @@ -63,4 +59,5 @@ public class AccountStatusUserDetailsChecker implements UserDetailsChecker, Mess Assert.notNull(messageSource, "messageSource cannot be null"); this.messages = new MessageSourceAccessor(messageSource); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationProvider.java index ba0dfeff2a..dbff14cb4c 100644 --- a/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationProvider.java @@ -33,13 +33,10 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public class AnonymousAuthenticationProvider implements AuthenticationProvider, - MessageSourceAware { - - // ~ Instance fields - // ================================================================================================ +public class AnonymousAuthenticationProvider implements AuthenticationProvider, MessageSourceAware { protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private String key; public AnonymousAuthenticationProvider(String key) { @@ -47,35 +44,31 @@ public class AnonymousAuthenticationProvider implements AuthenticationProvider, this.key = key; } - // ~ Methods - // ======================================================================================================== - - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { if (!supports(authentication.getClass())) { return null; } - - if (this.key.hashCode() != ((AnonymousAuthenticationToken) authentication) - .getKeyHash()) { - throw new BadCredentialsException( - messages.getMessage("AnonymousAuthenticationProvider.incorrectKey", - "The presented AnonymousAuthenticationToken does not contain the expected key")); + if (this.key.hashCode() != ((AnonymousAuthenticationToken) authentication).getKeyHash()) { + throw new BadCredentialsException(this.messages.getMessage("AnonymousAuthenticationProvider.incorrectKey", + "The presented AnonymousAuthenticationToken does not contain the expected key")); } - return authentication; } public String getKey() { - return key; + return this.key; } + @Override public void setMessageSource(MessageSource messageSource) { Assert.notNull(messageSource, "messageSource cannot be null"); this.messages = new MessageSourceAccessor(messageSource); } + @Override public boolean supports(Class authentication) { return (AnonymousAuthenticationToken.class.isAssignableFrom(authentication)); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationToken.java b/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationToken.java index 5ec688d361..2e92d105f7 100644 --- a/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/authentication/AnonymousAuthenticationToken.java @@ -27,56 +27,43 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public class AnonymousAuthenticationToken extends AbstractAuthenticationToken implements - Serializable { - // ~ Instance fields - // ================================================================================================ +public class AnonymousAuthenticationToken extends AbstractAuthenticationToken implements Serializable { private static final long serialVersionUID = 1L; - private final Object principal; - private final int keyHash; - // ~ Constructors - // =================================================================================================== + private final Object principal; + + private final int keyHash; /** * Constructor. - * - * @param key to identify if this object made by an authorised client - * @param principal the principal (typically a UserDetails) + * @param key to identify if this object made by an authorised client + * @param principal the principal (typically a UserDetails) * @param authorities the authorities granted to the principal * @throws IllegalArgumentException if a null was passed */ public AnonymousAuthenticationToken(String key, Object principal, - Collection authorities) { + Collection authorities) { this(extractKeyHash(key), principal, authorities); } /** * Constructor helps in Jackson Deserialization - * - * @param keyHash hashCode of provided Key, constructed by above constructor - * @param principal the principal (typically a UserDetails) + * @param keyHash hashCode of provided Key, constructed by above constructor + * @param principal the principal (typically a UserDetails) * @param authorities the authorities granted to the principal * @since 4.2 */ private AnonymousAuthenticationToken(Integer keyHash, Object principal, - Collection authorities) { + Collection authorities) { super(authorities); - - if (principal == null || "".equals(principal)) { - throw new IllegalArgumentException("principal cannot be null or empty"); - } + Assert.isTrue(principal != null && !"".equals(principal), "principal cannot be null or empty"); Assert.notEmpty(authorities, "authorities cannot be null or empty"); - this.keyHash = keyHash; this.principal = principal; setAuthenticated(true); } - // ~ Methods - // ======================================================================================================== - private static Integer extractKeyHash(String key) { Assert.hasLength(key, "key cannot be empty or null"); return key.hashCode(); @@ -87,17 +74,10 @@ public class AnonymousAuthenticationToken extends AbstractAuthenticationToken im if (!super.equals(obj)) { return false; } - if (obj instanceof AnonymousAuthenticationToken) { AnonymousAuthenticationToken test = (AnonymousAuthenticationToken) obj; - - if (this.getKeyHash() != test.getKeyHash()) { - return false; - } - - return true; + return (this.getKeyHash() == test.getKeyHash()); } - return false; } @@ -110,7 +90,6 @@ public class AnonymousAuthenticationToken extends AbstractAuthenticationToken im /** * Always returns an empty String - * * @return an empty String */ @Override @@ -126,4 +105,5 @@ public class AnonymousAuthenticationToken extends AbstractAuthenticationToken im public Object getPrincipal() { return this.principal; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationCredentialsNotFoundException.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationCredentialsNotFoundException.java index 8d096c450a..91b5d616d8 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationCredentialsNotFoundException.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationCredentialsNotFoundException.java @@ -27,13 +27,10 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class AuthenticationCredentialsNotFoundException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== /** * Constructs an AuthenticationCredentialsNotFoundException with the * specified message. - * * @param msg the detail message */ public AuthenticationCredentialsNotFoundException(String msg) { @@ -43,11 +40,11 @@ public class AuthenticationCredentialsNotFoundException extends AuthenticationEx /** * Constructs an AuthenticationCredentialsNotFoundException with the * specified message and root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public AuthenticationCredentialsNotFoundException(String msg, Throwable t) { - super(msg, t); + public AuthenticationCredentialsNotFoundException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationDetailsSource.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationDetailsSource.java index 27f8550a88..780aed8eca 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationDetailsSource.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationDetailsSource.java @@ -23,17 +23,14 @@ package org.springframework.security.authentication; * @author Ben Alex */ public interface AuthenticationDetailsSource { - // ~ Methods - // ======================================================================================================== /** * Called by a class when it wishes a new authentication details instance to be * created. - * * @param context the request object, which may be used by the authentication details * object - * * @return a fully-configured authentication details instance */ T buildDetails(C context); + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationEventPublisher.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationEventPublisher.java index b467ebcfe2..9158a0559a 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationEventPublisher.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationEventPublisher.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; /** - * * @author Luke Taylor * @since 3.0 */ @@ -27,6 +27,6 @@ public interface AuthenticationEventPublisher { void publishAuthenticationSuccess(Authentication authentication); - void publishAuthenticationFailure(AuthenticationException exception, - Authentication authentication); + void publishAuthenticationFailure(AuthenticationException exception, Authentication authentication); + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationManager.java index 6288a6f179..b98c835149 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationManager.java @@ -25,8 +25,6 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public interface AuthenticationManager { - // ~ Methods - // ======================================================================================================== /** * Attempts to authenticate the passed {@link Authentication} object, returning a @@ -48,13 +46,10 @@ public interface AuthenticationManager { * above (i.e. if an account is disabled or locked, the authentication request is * immediately rejected and the credentials testing process is not performed). This * prevents credentials being tested against disabled or locked accounts. - * * @param authentication the authentication request object - * * @return a fully authenticated object including credentials - * * @throws AuthenticationException if authentication fails */ - Authentication authenticate(Authentication authentication) - throws AuthenticationException; + Authentication authenticate(Authentication authentication) throws AuthenticationException; + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationManagerResolver.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationManagerResolver.java index ad1ea989c8..8a0b8dc979 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationManagerResolver.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationManagerResolver.java @@ -17,7 +17,8 @@ package org.springframework.security.authentication; /** - * An interface for resolving an {@link AuthenticationManager} based on the provided context + * An interface for resolving an {@link AuthenticationManager} based on the provided + * context * * @author Josh Cummings * @since 5.2 @@ -30,4 +31,5 @@ public interface AuthenticationManagerResolver { * @return the {@link AuthenticationManager} to use */ AuthenticationManager resolve(C context); + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationProvider.java index 2c4e2763e1..86e4c6e27e 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationProvider.java @@ -26,26 +26,20 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public interface AuthenticationProvider { - // ~ Methods - // ======================================================================================================== /** * Performs authentication with the same contract as * {@link org.springframework.security.authentication.AuthenticationManager#authenticate(Authentication)} * . - * * @param authentication the authentication request object. - * * @return a fully authenticated object including credentials. May return * null if the AuthenticationProvider is unable to support * authentication of the passed Authentication object. In such a case, * the next AuthenticationProvider that supports the presented * Authentication class will be tried. - * * @throws AuthenticationException if authentication fails. */ - Authentication authenticate(Authentication authentication) - throws AuthenticationException; + Authentication authenticate(Authentication authentication) throws AuthenticationException; /** * Returns true if this AuthenticationProvider supports the @@ -62,11 +56,10 @@ public interface AuthenticationProvider { * Selection of an AuthenticationProvider capable of performing * authentication is conducted at runtime the ProviderManager. *

      - * * @param authentication - * * @return true if the implementation can more closely evaluate the * Authentication class presented */ boolean supports(Class authentication); + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationServiceException.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationServiceException.java index f6aef41cc3..69d7233bdf 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationServiceException.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationServiceException.java @@ -28,13 +28,10 @@ import org.springframework.security.core.AuthenticationException; * @see InternalAuthenticationServiceException */ public class AuthenticationServiceException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== /** * Constructs an AuthenticationServiceException with the specified * message. - * * @param msg the detail message */ public AuthenticationServiceException(String msg) { @@ -44,11 +41,11 @@ public class AuthenticationServiceException extends AuthenticationException { /** * Constructs an AuthenticationServiceException with the specified * message and root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public AuthenticationServiceException(String msg, Throwable t) { - super(msg, t); + public AuthenticationServiceException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolver.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolver.java index 4192b13765..b0de70d2ce 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolver.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolver.java @@ -24,8 +24,6 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public interface AuthenticationTrustResolver { - // ~ Methods - // ======================================================================================================== /** * Indicates whether the passed Authentication token represents an @@ -34,10 +32,8 @@ public interface AuthenticationTrustResolver { * rejection (i.e. as would be the case if the principal was non-anonymous/fully * authenticated) or direct the principal to attempt actual authentication (i.e. as * would be the case if the Authentication was merely anonymous). - * * @param authentication to test (may be null in which case the method * will always return false) - * * @return true the passed authentication token represented an anonymous * principal, false otherwise */ @@ -50,12 +46,11 @@ public interface AuthenticationTrustResolver { * The method is provided to assist with custom AccessDecisionVoters and * the like that you might develop. Of course, you don't need to use this method * either and can develop your own "trust level" hierarchy instead. - * * @param authentication to test (may be null in which case the method * will always return false) - * * @return true the passed authentication token represented a principal * authenticated using a remember-me token, false otherwise */ boolean isRememberMe(Authentication authentication); + } diff --git a/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolverImpl.java b/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolverImpl.java index 53a9aef58b..a34645cde0 100644 --- a/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolverImpl.java +++ b/core/src/main/java/org/springframework/security/authentication/AuthenticationTrustResolverImpl.java @@ -30,37 +30,33 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public class AuthenticationTrustResolverImpl implements AuthenticationTrustResolver { - // ~ Instance fields - // ================================================================================================ private Class anonymousClass = AnonymousAuthenticationToken.class; + private Class rememberMeClass = RememberMeAuthenticationToken.class; - // ~ Methods - // ======================================================================================================== - Class getAnonymousClass() { - return anonymousClass; + return this.anonymousClass; } Class getRememberMeClass() { - return rememberMeClass; + return this.rememberMeClass; } + @Override public boolean isAnonymous(Authentication authentication) { - if ((anonymousClass == null) || (authentication == null)) { + if ((this.anonymousClass == null) || (authentication == null)) { return false; } - - return anonymousClass.isAssignableFrom(authentication.getClass()); + return this.anonymousClass.isAssignableFrom(authentication.getClass()); } + @Override public boolean isRememberMe(Authentication authentication) { - if ((rememberMeClass == null) || (authentication == null)) { + if ((this.rememberMeClass == null) || (authentication == null)) { return false; } - - return rememberMeClass.isAssignableFrom(authentication.getClass()); + return this.rememberMeClass.isAssignableFrom(authentication.getClass()); } public void setAnonymousClass(Class anonymousClass) { @@ -70,4 +66,5 @@ public class AuthenticationTrustResolverImpl implements AuthenticationTrustResol public void setRememberMeClass(Class rememberMeClass) { this.rememberMeClass = rememberMeClass; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/BadCredentialsException.java b/core/src/main/java/org/springframework/security/authentication/BadCredentialsException.java index 7b050d9078..e202ef7b5a 100644 --- a/core/src/main/java/org/springframework/security/authentication/BadCredentialsException.java +++ b/core/src/main/java/org/springframework/security/authentication/BadCredentialsException.java @@ -25,12 +25,9 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class BadCredentialsException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== /** * Constructs a BadCredentialsException with the specified message. - * * @param msg the detail message */ public BadCredentialsException(String msg) { @@ -40,11 +37,11 @@ public class BadCredentialsException extends AuthenticationException { /** * Constructs a BadCredentialsException with the specified message and * root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public BadCredentialsException(String msg, Throwable t) { - super(msg, t); + public BadCredentialsException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/CachingUserDetailsService.java b/core/src/main/java/org/springframework/security/authentication/CachingUserDetailsService.java index ed5dc9241c..39dc1f6487 100644 --- a/core/src/main/java/org/springframework/security/authentication/CachingUserDetailsService.java +++ b/core/src/main/java/org/springframework/security/authentication/CachingUserDetailsService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; import org.springframework.security.core.userdetails.UserCache; @@ -22,12 +23,13 @@ import org.springframework.security.core.userdetails.cache.NullUserCache; import org.springframework.util.Assert; /** - * * @author Luke Taylor * @since 2.0 */ public class CachingUserDetailsService implements UserDetailsService { + private UserCache userCache = new NullUserCache(); + private final UserDetailsService delegate; public CachingUserDetailsService(UserDetailsService delegate) { @@ -35,26 +37,23 @@ public class CachingUserDetailsService implements UserDetailsService { } public UserCache getUserCache() { - return userCache; + return this.userCache; } public void setUserCache(UserCache userCache) { this.userCache = userCache; } + @Override public UserDetails loadUserByUsername(String username) { - UserDetails user = userCache.getUserFromCache(username); - + UserDetails user = this.userCache.getUserFromCache(username); if (user == null) { - user = delegate.loadUserByUsername(username); + user = this.delegate.loadUserByUsername(username); } - - Assert.notNull(user, () -> "UserDetailsService " + delegate - + " returned null for username " + username + ". " - + "This is an interface contract violation"); - - userCache.putUserInCache(user); - + Assert.notNull(user, () -> "UserDetailsService " + this.delegate + " returned null for username " + username + + ". " + "This is an interface contract violation"); + this.userCache.putUserInCache(user); return user; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/CredentialsExpiredException.java b/core/src/main/java/org/springframework/security/authentication/CredentialsExpiredException.java index 494ca04348..8e532169ae 100644 --- a/core/src/main/java/org/springframework/security/authentication/CredentialsExpiredException.java +++ b/core/src/main/java/org/springframework/security/authentication/CredentialsExpiredException.java @@ -23,12 +23,9 @@ package org.springframework.security.authentication; * @author Ben Alex */ public class CredentialsExpiredException extends AccountStatusException { - // ~ Constructors - // =================================================================================================== /** * Constructs a CredentialsExpiredException with the specified message. - * * @param msg the detail message */ public CredentialsExpiredException(String msg) { @@ -38,11 +35,11 @@ public class CredentialsExpiredException extends AccountStatusException { /** * Constructs a CredentialsExpiredException with the specified message * and root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public CredentialsExpiredException(String msg, Throwable t) { - super(msg, t); + public CredentialsExpiredException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisher.java b/core/src/main/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisher.java index 965a6ee2a9..a5e9ff8619 100644 --- a/core/src/main/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisher.java +++ b/core/src/main/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisher.java @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; -import java.util.*; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.security.authentication.event.AbstractAuthenticationEvent; @@ -58,58 +62,48 @@ import org.springframework.util.Assert; * @author Luke Taylor * @since 3.0 */ -public class DefaultAuthenticationEventPublisher implements AuthenticationEventPublisher, - ApplicationEventPublisherAware { +public class DefaultAuthenticationEventPublisher + implements AuthenticationEventPublisher, ApplicationEventPublisherAware { + private final Log logger = LogFactory.getLog(getClass()); private ApplicationEventPublisher applicationEventPublisher; + private final HashMap> exceptionMappings = new HashMap<>(); + private Constructor defaultAuthenticationFailureEventConstructor; public DefaultAuthenticationEventPublisher() { this(null); } - public DefaultAuthenticationEventPublisher( - ApplicationEventPublisher applicationEventPublisher) { + public DefaultAuthenticationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { this.applicationEventPublisher = applicationEventPublisher; - - addMapping(BadCredentialsException.class.getName(), - AuthenticationFailureBadCredentialsEvent.class); - addMapping(UsernameNotFoundException.class.getName(), - AuthenticationFailureBadCredentialsEvent.class); - addMapping(AccountExpiredException.class.getName(), - AuthenticationFailureExpiredEvent.class); - addMapping(ProviderNotFoundException.class.getName(), - AuthenticationFailureProviderNotFoundEvent.class); - addMapping(DisabledException.class.getName(), - AuthenticationFailureDisabledEvent.class); - addMapping(LockedException.class.getName(), - AuthenticationFailureLockedEvent.class); - addMapping(AuthenticationServiceException.class.getName(), - AuthenticationFailureServiceExceptionEvent.class); - addMapping(CredentialsExpiredException.class.getName(), - AuthenticationFailureCredentialsExpiredEvent.class); - addMapping( - "org.springframework.security.authentication.cas.ProxyUntrustedException", + addMapping(BadCredentialsException.class.getName(), AuthenticationFailureBadCredentialsEvent.class); + addMapping(UsernameNotFoundException.class.getName(), AuthenticationFailureBadCredentialsEvent.class); + addMapping(AccountExpiredException.class.getName(), AuthenticationFailureExpiredEvent.class); + addMapping(ProviderNotFoundException.class.getName(), AuthenticationFailureProviderNotFoundEvent.class); + addMapping(DisabledException.class.getName(), AuthenticationFailureDisabledEvent.class); + addMapping(LockedException.class.getName(), AuthenticationFailureLockedEvent.class); + addMapping(AuthenticationServiceException.class.getName(), AuthenticationFailureServiceExceptionEvent.class); + addMapping(CredentialsExpiredException.class.getName(), AuthenticationFailureCredentialsExpiredEvent.class); + addMapping("org.springframework.security.authentication.cas.ProxyUntrustedException", AuthenticationFailureProxyUntrustedEvent.class); - addMapping( - "org.springframework.security.oauth2.server.resource.InvalidBearerTokenException", + addMapping("org.springframework.security.oauth2.server.resource.InvalidBearerTokenException", AuthenticationFailureBadCredentialsEvent.class); } + @Override public void publishAuthenticationSuccess(Authentication authentication) { - if (applicationEventPublisher != null) { - applicationEventPublisher.publishEvent(new AuthenticationSuccessEvent( - authentication)); + if (this.applicationEventPublisher != null) { + this.applicationEventPublisher.publishEvent(new AuthenticationSuccessEvent(authentication)); } } - public void publishAuthenticationFailure(AuthenticationException exception, - Authentication authentication) { + @Override + public void publishAuthenticationFailure(AuthenticationException exception, Authentication authentication) { Constructor constructor = getEventConstructor(exception); AbstractAuthenticationEvent event = null; - if (constructor != null) { try { event = constructor.newInstance(authentication, exception); @@ -117,57 +111,50 @@ public class DefaultAuthenticationEventPublisher implements AuthenticationEventP catch (IllegalAccessException | InvocationTargetException | InstantiationException ignored) { } } - if (event != null) { - if (applicationEventPublisher != null) { - applicationEventPublisher.publishEvent(event); + if (this.applicationEventPublisher != null) { + this.applicationEventPublisher.publishEvent(event); } } else { - if (logger.isDebugEnabled()) { - logger.debug("No event was found for the exception " - + exception.getClass().getName()); + if (this.logger.isDebugEnabled()) { + this.logger.debug("No event was found for the exception " + exception.getClass().getName()); } } } private Constructor getEventConstructor(AuthenticationException exception) { - Constructor eventConstructor = - this.exceptionMappings.get(exception.getClass().getName()); - return (eventConstructor == null ? this.defaultAuthenticationFailureEventConstructor : eventConstructor); + Constructor eventConstructor = this.exceptionMappings + .get(exception.getClass().getName()); + return (eventConstructor != null) ? eventConstructor : this.defaultAuthenticationFailureEventConstructor; } - public void setApplicationEventPublisher( - ApplicationEventPublisher applicationEventPublisher) { + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { this.applicationEventPublisher = applicationEventPublisher; } /** * Sets additional exception to event mappings. These are automatically merged with * the default exception to event mappings that ProviderManager defines. - * * @param additionalExceptionMappings where keys are the fully-qualified string name * of the exception class and the values are the fully-qualified string name of the * event class to fire. - * * @deprecated use {@link #setAdditionalExceptionMappings(Map)} */ @Deprecated @SuppressWarnings({ "unchecked" }) public void setAdditionalExceptionMappings(Properties additionalExceptionMappings) { - Assert.notNull(additionalExceptionMappings, - "The exceptionMappings object must not be null"); + Assert.notNull(additionalExceptionMappings, "The exceptionMappings object must not be null"); for (Object exceptionClass : additionalExceptionMappings.keySet()) { String eventClass = (String) additionalExceptionMappings.get(exceptionClass); try { Class clazz = getClass().getClassLoader().loadClass(eventClass); Assert.isAssignable(AbstractAuthenticationFailureEvent.class, clazz); - addMapping((String) exceptionClass, - (Class) clazz); + addMapping((String) exceptionClass, (Class) clazz); } - catch (ClassNotFoundException e) { - throw new RuntimeException("Failed to load authentication event class " - + eventClass); + catch (ClassNotFoundException ex) { + throw new RuntimeException("Failed to load authentication event class " + eventClass); } } } @@ -175,29 +162,27 @@ public class DefaultAuthenticationEventPublisher implements AuthenticationEventP /** * Sets additional exception to event mappings. These are automatically merged with * the default exception to event mappings that ProviderManager defines. - * * @param mappings where keys are exception classes and values are event classes. * @since 5.3 */ - public void setAdditionalExceptionMappings(Map, - Class> mappings){ + public void setAdditionalExceptionMappings( + Map, Class> mappings) { Assert.notEmpty(mappings, "The mappings Map must not be empty nor null"); - for (Map.Entry, Class> entry - : mappings.entrySet()) { - Class exceptionClass = entry.getKey(); - Class eventClass = entry.getValue(); - Assert.notNull(exceptionClass, "exceptionClass cannot be null"); - Assert.notNull(eventClass, "eventClass cannot be null"); - addMapping(exceptionClass.getName(), (Class) eventClass); + for (Map.Entry, Class> entry : mappings + .entrySet()) { + Class exceptionClass = entry.getKey(); + Class eventClass = entry.getValue(); + Assert.notNull(exceptionClass, "exceptionClass cannot be null"); + Assert.notNull(eventClass, "eventClass cannot be null"); + addMapping(exceptionClass.getName(), (Class) eventClass); } } /** * Sets a default authentication failure event as a fallback event for any unmapped * exceptions not mapped in the exception mappings. - * - * @param defaultAuthenticationFailureEventClass is the authentication failure event class - * to be fired for unmapped exceptions. + * @param defaultAuthenticationFailureEventClass is the authentication failure event + * class to be fired for unmapped exceptions. */ public void setDefaultAuthenticationFailureEvent( Class defaultAuthenticationFailureEventClass) { @@ -206,22 +191,23 @@ public class DefaultAuthenticationEventPublisher implements AuthenticationEventP try { this.defaultAuthenticationFailureEventConstructor = defaultAuthenticationFailureEventClass .getConstructor(Authentication.class, AuthenticationException.class); - } catch (NoSuchMethodException e) { + } + catch (NoSuchMethodException ex) { throw new RuntimeException("Default Authentication Failure event class " + defaultAuthenticationFailureEventClass.getName() + " has no suitable constructor"); } } - private void addMapping(String exceptionClass, - Class eventClass) { + private void addMapping(String exceptionClass, Class eventClass) { try { Constructor constructor = eventClass .getConstructor(Authentication.class, AuthenticationException.class); - exceptionMappings.put(exceptionClass, constructor); + this.exceptionMappings.put(exceptionClass, constructor); } - catch (NoSuchMethodException e) { - throw new RuntimeException("Authentication event class " - + eventClass.getName() + " has no suitable constructor"); + catch (NoSuchMethodException ex) { + throw new RuntimeException( + "Authentication event class " + eventClass.getName() + " has no suitable constructor"); } } + } diff --git a/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java index 0a2a43d532..7a06a0695b 100644 --- a/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManager.java @@ -19,37 +19,40 @@ package org.springframework.security.authentication; import java.util.Arrays; import java.util.List; -import org.springframework.security.core.Authentication; -import org.springframework.util.Assert; - import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; + /** - * A {@link ReactiveAuthenticationManager} that delegates to other {@link ReactiveAuthenticationManager} instances using - * the result from the first non empty result. + * A {@link ReactiveAuthenticationManager} that delegates to other + * {@link ReactiveAuthenticationManager} instances using the result from the first non + * empty result. * * @author Rob Winch * @since 5.1 */ -public class DelegatingReactiveAuthenticationManager - implements ReactiveAuthenticationManager { +public class DelegatingReactiveAuthenticationManager implements ReactiveAuthenticationManager { + private final List delegates; - public DelegatingReactiveAuthenticationManager( - ReactiveAuthenticationManager... entryPoints) { + public DelegatingReactiveAuthenticationManager(ReactiveAuthenticationManager... entryPoints) { this(Arrays.asList(entryPoints)); } - public DelegatingReactiveAuthenticationManager( - List entryPoints) { + public DelegatingReactiveAuthenticationManager(List entryPoints) { Assert.notEmpty(entryPoints, "entryPoints cannot be null"); this.delegates = entryPoints; } + @Override public Mono authenticate(Authentication authentication) { + // @formatter:off return Flux.fromIterable(this.delegates) - .concatMap(m -> m.authenticate(authentication)) + .concatMap((m) -> m.authenticate(authentication)) .next(); + // @formatter:on } + } diff --git a/core/src/main/java/org/springframework/security/authentication/DisabledException.java b/core/src/main/java/org/springframework/security/authentication/DisabledException.java index bc48774b89..31a75ce0cc 100644 --- a/core/src/main/java/org/springframework/security/authentication/DisabledException.java +++ b/core/src/main/java/org/springframework/security/authentication/DisabledException.java @@ -23,12 +23,9 @@ package org.springframework.security.authentication; * @author Ben Alex */ public class DisabledException extends AccountStatusException { - // ~ Constructors - // =================================================================================================== /** * Constructs a DisabledException with the specified message. - * * @param msg the detail message */ public DisabledException(String msg) { @@ -38,11 +35,11 @@ public class DisabledException extends AccountStatusException { /** * Constructs a DisabledException with the specified message and root * cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public DisabledException(String msg, Throwable t) { - super(msg, t); + public DisabledException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/InsufficientAuthenticationException.java b/core/src/main/java/org/springframework/security/authentication/InsufficientAuthenticationException.java index e60723f546..0e072b527a 100644 --- a/core/src/main/java/org/springframework/security/authentication/InsufficientAuthenticationException.java +++ b/core/src/main/java/org/springframework/security/authentication/InsufficientAuthenticationException.java @@ -32,13 +32,10 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class InsufficientAuthenticationException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== /** * Constructs an InsufficientAuthenticationException with the specified * message. - * * @param msg the detail message */ public InsufficientAuthenticationException(String msg) { @@ -48,11 +45,11 @@ public class InsufficientAuthenticationException extends AuthenticationException /** * Constructs an InsufficientAuthenticationException with the specified * message and root cause. - * * @param msg the detail message - * @param t root cause + * @param cause root cause */ - public InsufficientAuthenticationException(String msg, Throwable t) { - super(msg, t); + public InsufficientAuthenticationException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/InternalAuthenticationServiceException.java b/core/src/main/java/org/springframework/security/authentication/InternalAuthenticationServiceException.java index 5c08ed4e71..140bbb8f2e 100644 --- a/core/src/main/java/org/springframework/security/authentication/InternalAuthenticationServiceException.java +++ b/core/src/main/java/org/springframework/security/authentication/InternalAuthenticationServiceException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; /** @@ -34,8 +35,7 @@ package org.springframework.security.authentication; * @author Rob Winch * */ -public class InternalAuthenticationServiceException extends - AuthenticationServiceException { +public class InternalAuthenticationServiceException extends AuthenticationServiceException { public InternalAuthenticationServiceException(String message, Throwable cause) { super(message, cause); @@ -44,4 +44,5 @@ public class InternalAuthenticationServiceException extends public InternalAuthenticationServiceException(String message) { super(message); } -} \ No newline at end of file + +} diff --git a/core/src/main/java/org/springframework/security/authentication/LockedException.java b/core/src/main/java/org/springframework/security/authentication/LockedException.java index ff27ec8cc7..9b2272b08f 100644 --- a/core/src/main/java/org/springframework/security/authentication/LockedException.java +++ b/core/src/main/java/org/springframework/security/authentication/LockedException.java @@ -23,12 +23,9 @@ package org.springframework.security.authentication; * @author Ben Alex */ public class LockedException extends AccountStatusException { - // ~ Constructors - // =================================================================================================== /** * Constructs a LockedException with the specified message. - * * @param msg the detail message. */ public LockedException(String msg) { @@ -38,11 +35,11 @@ public class LockedException extends AccountStatusException { /** * Constructs a LockedException with the specified message and root * cause. - * * @param msg the detail message. - * @param t root cause + * @param cause root cause */ - public LockedException(String msg, Throwable t) { - super(msg, t); + public LockedException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/ProviderManager.java b/core/src/main/java/org/springframework/security/authentication/ProviderManager.java index 959309c95e..622c0284b1 100644 --- a/core/src/main/java/org/springframework/security/authentication/ProviderManager.java +++ b/core/src/main/java/org/springframework/security/authentication/ProviderManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; import java.util.Arrays; @@ -21,10 +22,12 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.factory.InitializingBean; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.CredentialsContainer; @@ -80,31 +83,26 @@ import org.springframework.util.CollectionUtils; * {@code AuthenticationManager} if one has been set. So in this situation, the parent * should not generally be configured to publish events or there will be duplicates. * - * * @author Ben Alex * @author Luke Taylor - * * @see DefaultAuthenticationEventPublisher */ -public class ProviderManager implements AuthenticationManager, MessageSourceAware, - InitializingBean { - // ~ Static fields/initializers - // ===================================================================================== +public class ProviderManager implements AuthenticationManager, MessageSourceAware, InitializingBean { private static final Log logger = LogFactory.getLog(ProviderManager.class); - // ~ Instance fields - // ================================================================================================ - private AuthenticationEventPublisher eventPublisher = new NullEventPublisher(); + private List providers = Collections.emptyList(); + protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private AuthenticationManager parent; + private boolean eraseCredentialsAfterAuthentication = true; /** * Construct a {@link ProviderManager} using the given {@link AuthenticationProvider}s - * * @param providers the {@link AuthenticationProvider}s to use */ public ProviderManager(AuthenticationProvider... providers) { @@ -113,7 +111,6 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar /** * Construct a {@link ProviderManager} using the given {@link AuthenticationProvider}s - * * @param providers the {@link AuthenticationProvider}s to use */ public ProviderManager(List providers) { @@ -122,34 +119,26 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar /** * Construct a {@link ProviderManager} using the provided parameters - * * @param providers the {@link AuthenticationProvider}s to use * @param parent a parent {@link AuthenticationManager} to fall back to */ - public ProviderManager(List providers, - AuthenticationManager parent) { + public ProviderManager(List providers, AuthenticationManager parent) { Assert.notNull(providers, "providers list cannot be null"); this.providers = providers; this.parent = parent; checkState(); } - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { checkState(); } private void checkState() { - if (parent == null && providers.isEmpty()) { - throw new IllegalArgumentException( - "A parent AuthenticationManager or a list " - + "of AuthenticationProviders is required"); - } else if (CollectionUtils.contains(providers.iterator(), null)) { - throw new IllegalArgumentException( - "providers list cannot contain null values"); - } + Assert.isTrue(this.parent != null || !this.providers.isEmpty(), + "A parent AuthenticationManager or a list of AuthenticationProviders is required"); + Assert.isTrue(!CollectionUtils.contains(this.providers.iterator(), null), + "providers list cannot contain null values"); } /** @@ -161,139 +150,122 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar * attempted with that AuthenticationProvider. *

      * If more than one AuthenticationProvider supports the passed - * Authentication object, the first one able to successfully - * authenticate the Authentication object determines the - * result, overriding any possible AuthenticationException - * thrown by earlier supporting AuthenticationProviders. - * On successful authentication, no subsequent AuthenticationProviders - * will be tried. - * If authentication was not successful by any supporting - * AuthenticationProvider the last thrown - * AuthenticationException will be rethrown. - * + * Authentication object, the first one able to successfully authenticate + * the Authentication object determines the result, + * overriding any possible AuthenticationException thrown by earlier + * supporting AuthenticationProviders. On successful authentication, no + * subsequent AuthenticationProviders will be tried. If authentication + * was not successful by any supporting AuthenticationProvider the last + * thrown AuthenticationException will be rethrown. * @param authentication the authentication request object. - * * @return a fully authenticated object including credentials. - * * @throws AuthenticationException if authentication fails. */ - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { Class toTest = authentication.getClass(); AuthenticationException lastException = null; AuthenticationException parentException = null; Authentication result = null; Authentication parentResult = null; - boolean debug = logger.isDebugEnabled(); - for (AuthenticationProvider provider : getProviders()) { if (!provider.supports(toTest)) { continue; } - - if (debug) { - logger.debug("Authentication attempt using " - + provider.getClass().getName()); - } - + logger.debug(LogMessage.format("Authentication attempt using %s", provider.getClass().getName())); try { result = provider.authenticate(authentication); - if (result != null) { copyDetails(authentication, result); break; } } - catch (AccountStatusException | InternalAuthenticationServiceException e) { - prepareException(e, authentication); + catch (AccountStatusException | InternalAuthenticationServiceException ex) { + prepareException(ex, authentication); // SEC-546: Avoid polling additional providers if auth failure is due to // invalid account status - throw e; - } catch (AuthenticationException e) { - lastException = e; + throw ex; + } + catch (AuthenticationException ex) { + lastException = ex; } } - - if (result == null && parent != null) { + if (result == null && this.parent != null) { // Allow the parent to try. try { - result = parentResult = parent.authenticate(authentication); + parentResult = this.parent.authenticate(authentication); + result = parentResult; } - catch (ProviderNotFoundException e) { + catch (ProviderNotFoundException ex) { // ignore as we will throw below if no other exception occurred prior to // calling parent and the parent // may throw ProviderNotFound even though a provider in the child already // handled the request } - catch (AuthenticationException e) { - lastException = parentException = e; + catch (AuthenticationException ex) { + parentException = ex; + lastException = ex; } } - if (result != null) { - if (eraseCredentialsAfterAuthentication - && (result instanceof CredentialsContainer)) { + if (this.eraseCredentialsAfterAuthentication && (result instanceof CredentialsContainer)) { // Authentication is complete. Remove credentials and other secret data // from authentication ((CredentialsContainer) result).eraseCredentials(); } - - // If the parent AuthenticationManager was attempted and successful then it will publish an AuthenticationSuccessEvent - // This check prevents a duplicate AuthenticationSuccessEvent if the parent AuthenticationManager already published it + // If the parent AuthenticationManager was attempted and successful then it + // will publish an AuthenticationSuccessEvent + // This check prevents a duplicate AuthenticationSuccessEvent if the parent + // AuthenticationManager already published it if (parentResult == null) { - eventPublisher.publishAuthenticationSuccess(result); + this.eventPublisher.publishAuthenticationSuccess(result); } return result; } // Parent was null, or didn't authenticate (or throw an exception). - if (lastException == null) { - lastException = new ProviderNotFoundException(messages.getMessage( - "ProviderManager.providerNotFound", - new Object[] { toTest.getName() }, - "No AuthenticationProvider found for {0}")); + lastException = new ProviderNotFoundException(this.messages.getMessage("ProviderManager.providerNotFound", + new Object[] { toTest.getName() }, "No AuthenticationProvider found for {0}")); } - - // If the parent AuthenticationManager was attempted and failed then it will publish an AbstractAuthenticationFailureEvent - // This check prevents a duplicate AbstractAuthenticationFailureEvent if the parent AuthenticationManager already published it + // If the parent AuthenticationManager was attempted and failed then it will + // publish an AbstractAuthenticationFailureEvent + // This check prevents a duplicate AbstractAuthenticationFailureEvent if the + // parent AuthenticationManager already published it if (parentException == null) { prepareException(lastException, authentication); } - throw lastException; } @SuppressWarnings("deprecation") private void prepareException(AuthenticationException ex, Authentication auth) { - eventPublisher.publishAuthenticationFailure(ex, auth); + this.eventPublisher.publishAuthenticationFailure(ex, auth); } /** * Copies the authentication details from a source Authentication object to a * destination one, provided the latter does not already have one set. - * * @param source source authentication * @param dest the destination authentication object */ private void copyDetails(Authentication source, Authentication dest) { if ((dest instanceof AbstractAuthenticationToken) && (dest.getDetails() == null)) { AbstractAuthenticationToken token = (AbstractAuthenticationToken) dest; - token.setDetails(source.getDetails()); } } public List getProviders() { - return providers; + return this.providers; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } - public void setAuthenticationEventPublisher( - AuthenticationEventPublisher eventPublisher) { + public void setAuthenticationEventPublisher(AuthenticationEventPublisher eventPublisher) { Assert.notNull(eventPublisher, "AuthenticationEventPublisher cannot be null"); this.eventPublisher = eventPublisher; } @@ -303,7 +275,6 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar * {@code CredentialsContainer} interface will have its * {@link CredentialsContainer#eraseCredentials() eraseCredentials} method called * before it is returned from the {@code authenticate()} method. - * * @param eraseSecretData set to {@literal false} to retain the credentials data in * memory. Defaults to {@literal true}. */ @@ -312,15 +283,19 @@ public class ProviderManager implements AuthenticationManager, MessageSourceAwar } public boolean isEraseCredentialsAfterAuthentication() { - return eraseCredentialsAfterAuthentication; + return this.eraseCredentialsAfterAuthentication; } private static final class NullEventPublisher implements AuthenticationEventPublisher { - public void publishAuthenticationFailure(AuthenticationException exception, - Authentication authentication) { + + @Override + public void publishAuthenticationFailure(AuthenticationException exception, Authentication authentication) { } + @Override public void publishAuthenticationSuccess(Authentication authentication) { } + } + } diff --git a/core/src/main/java/org/springframework/security/authentication/ProviderNotFoundException.java b/core/src/main/java/org/springframework/security/authentication/ProviderNotFoundException.java index 62169c32f7..629a28e8c8 100644 --- a/core/src/main/java/org/springframework/security/authentication/ProviderNotFoundException.java +++ b/core/src/main/java/org/springframework/security/authentication/ProviderNotFoundException.java @@ -26,15 +26,13 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class ProviderNotFoundException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== /** * Constructs a ProviderNotFoundException with the specified message. - * * @param msg the detail message */ public ProviderNotFoundException(String msg) { super(msg); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManager.java index 72d554d48e..92f945a890 100644 --- a/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManager.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; -import org.springframework.security.core.Authentication; - import reactor.core.publisher.Mono; +import org.springframework.security.core.Authentication; + /** * Determines if the provided {@link Authentication} can be authenticated. * @@ -30,11 +31,11 @@ public interface ReactiveAuthenticationManager { /** * Attempts to authenticate the provided {@link Authentication} - * * @param authentication the {@link Authentication} to test * @return if authentication is successful an {@link Authentication} is returned. If * authentication cannot be determined, an empty Mono is returned. If authentication * fails, a Mono error is returned. */ Mono authenticate(Authentication authentication); + } diff --git a/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapter.java b/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapter.java index 087a169b94..9c163c4ef3 100644 --- a/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapter.java +++ b/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapter.java @@ -13,18 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; -import org.springframework.security.core.Authentication; -import org.springframework.util.Assert; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +import org.springframework.security.core.Authentication; +import org.springframework.util.Assert; + /** - * Adapts an AuthenticationManager to the reactive APIs. This is somewhat necessary because many of the ways that - * credentials are stored (i.e. JDBC, LDAP, etc) do not have reactive implementations. What's more is it is generally - * considered best practice to store passwords in a hash that is intentionally slow which would block ever request + * Adapts an AuthenticationManager to the reactive APIs. This is somewhat necessary + * because many of the ways that credentials are stored (i.e. JDBC, LDAP, etc) do not have + * reactive implementations. What's more is it is generally considered best practice to + * store passwords in a hash that is intentionally slow which would block ever request * from coming in unless it was put on another thread. * * @author Rob Winch @@ -32,6 +35,7 @@ import reactor.core.scheduler.Schedulers; * @since 5.0 */ public class ReactiveAuthenticationManagerAdapter implements ReactiveAuthenticationManager { + private final AuthenticationManager authenticationManager; private Scheduler scheduler = Schedulers.boundedElastic(); @@ -43,16 +47,21 @@ public class ReactiveAuthenticationManagerAdapter implements ReactiveAuthenticat @Override public Mono authenticate(Authentication token) { + // @formatter:off return Mono.just(token) - .publishOn(this.scheduler) - .flatMap( t -> { - try { - return Mono.just(authenticationManager.authenticate(t)); - } catch(Throwable error) { - return Mono.error(error); - } - }) - .filter( a -> a.isAuthenticated()); + .publishOn(this.scheduler) + .flatMap(this::doAuthenticate) + .filter(Authentication::isAuthenticated); + // @formatter:on + } + + private Mono doAuthenticate(Authentication authentication) { + try { + return Mono.just(this.authenticationManager.authenticate(authentication)); + } + catch (Throwable ex) { + return Mono.error(ex); + } } /** diff --git a/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerResolver.java b/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerResolver.java index 25c4fc8241..4697c70311 100644 --- a/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerResolver.java +++ b/core/src/main/java/org/springframework/security/authentication/ReactiveAuthenticationManagerResolver.java @@ -16,17 +16,18 @@ package org.springframework.security.authentication; -import org.springframework.security.authentication.ReactiveAuthenticationManager; - import reactor.core.publisher.Mono; /** - * An interface for resolving a {@link ReactiveAuthenticationManager} based on the provided context + * An interface for resolving a {@link ReactiveAuthenticationManager} based on the + * provided context * * @author Rafiullah Hamedy * @since 5.2 */ @FunctionalInterface public interface ReactiveAuthenticationManagerResolver { + Mono resolve(C context); -} \ No newline at end of file + +} diff --git a/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationProvider.java index 0b44ba0fb1..ed9df33a28 100644 --- a/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationProvider.java @@ -32,12 +32,10 @@ import org.springframework.util.Assert; * To be successfully validated, the {@link RememberMeAuthenticationToken#getKeyHash()} * must match this class' {@link #getKey()}. */ -public class RememberMeAuthenticationProvider implements AuthenticationProvider, - InitializingBean, MessageSourceAware { - // ~ Instance fields - // ================================================================================================ +public class RememberMeAuthenticationProvider implements AuthenticationProvider, InitializingBean, MessageSourceAware { protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private String key; public RememberMeAuthenticationProvider(String key) { @@ -45,38 +43,35 @@ public class RememberMeAuthenticationProvider implements AuthenticationProvider, this.key = key; } - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { Assert.notNull(this.messages, "A message source must be set"); } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { if (!supports(authentication.getClass())) { return null; } - - if (this.key.hashCode() != ((RememberMeAuthenticationToken) authentication) - .getKeyHash()) { - throw new BadCredentialsException( - messages.getMessage("RememberMeAuthenticationProvider.incorrectKey", - "The presented RememberMeAuthenticationToken does not contain the expected key")); + if (this.key.hashCode() != ((RememberMeAuthenticationToken) authentication).getKeyHash()) { + throw new BadCredentialsException(this.messages.getMessage("RememberMeAuthenticationProvider.incorrectKey", + "The presented RememberMeAuthenticationToken does not contain the expected key")); } - return authentication; } public String getKey() { - return key; + return this.key; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } + @Override public boolean supports(Class authentication) { return (RememberMeAuthenticationToken.class.isAssignableFrom(authentication)); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationToken.java b/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationToken.java index 1e75db9a50..8a62c9ca2a 100644 --- a/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/authentication/RememberMeAuthenticationToken.java @@ -34,33 +34,23 @@ public class RememberMeAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private final Object principal; - private final int keyHash; - // ~ Constructors - // =================================================================================================== + private final int keyHash; /** * Constructor. - * - * @param key to identify if this object made by an authorised client - * @param principal the principal (typically a UserDetails) + * @param key to identify if this object made by an authorised client + * @param principal the principal (typically a UserDetails) * @param authorities the authorities granted to the principal * @throws IllegalArgumentException if a null was passed */ public RememberMeAuthenticationToken(String key, Object principal, - Collection authorities) { + Collection authorities) { super(authorities); - - if ((key == null) || ("".equals(key)) || (principal == null) - || "".equals(principal)) { - throw new IllegalArgumentException( - "Cannot pass null or empty values to constructor"); + if ((key == null) || ("".equals(key)) || (principal == null) || "".equals(principal)) { + throw new IllegalArgumentException("Cannot pass null or empty values to constructor"); } - this.keyHash = key.hashCode(); this.principal = principal; setAuthenticated(true); @@ -68,26 +58,21 @@ public class RememberMeAuthenticationToken extends AbstractAuthenticationToken { /** * Private Constructor to help in Jackson deserialization. - * - * @param keyHash hashCode of above given key. - * @param principal the principal (typically a UserDetails) + * @param keyHash hashCode of above given key. + * @param principal the principal (typically a UserDetails) * @param authorities the authorities granted to the principal * @since 4.2 */ - private RememberMeAuthenticationToken(Integer keyHash, Object principal, Collection authorities) { + private RememberMeAuthenticationToken(Integer keyHash, Object principal, + Collection authorities) { super(authorities); - this.keyHash = keyHash; this.principal = principal; setAuthenticated(true); } - // ~ Methods - // ======================================================================================================== - /** * Always returns an empty String - * * @return an empty String */ @Override @@ -109,17 +94,13 @@ public class RememberMeAuthenticationToken extends AbstractAuthenticationToken { if (!super.equals(obj)) { return false; } - if (obj instanceof RememberMeAuthenticationToken) { - RememberMeAuthenticationToken test = (RememberMeAuthenticationToken) obj; - - if (this.getKeyHash() != test.getKeyHash()) { + RememberMeAuthenticationToken other = (RememberMeAuthenticationToken) obj; + if (this.getKeyHash() != other.getKeyHash()) { return false; } - return true; } - return false; } @@ -129,4 +110,5 @@ public class RememberMeAuthenticationToken extends AbstractAuthenticationToken { result = 31 * result + this.keyHash; return result; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationProvider.java index ff42fe79c6..ef22c5d1c2 100644 --- a/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationProvider.java @@ -33,15 +33,15 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class TestingAuthenticationProvider implements AuthenticationProvider { - // ~ Methods - // ======================================================================================================== - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { return authentication; } + @Override public boolean supports(Class authentication) { return TestingAuthenticationToken.class.isAssignableFrom(authentication); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationToken.java b/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationToken.java index 9aeb4446b7..8162b45965 100644 --- a/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/authentication/TestingAuthenticationToken.java @@ -16,11 +16,11 @@ package org.springframework.security.authentication; +import java.util.List; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; -import java.util.List; - /** * An {@link org.springframework.security.core.Authentication} implementation that is * designed for use whilst unit testing. @@ -30,15 +30,12 @@ import java.util.List; * @author Ben Alex */ public class TestingAuthenticationToken extends AbstractAuthenticationToken { - // ~ Instance fields - // ================================================================================================ private static final long serialVersionUID = 1L; - private final Object credentials; - private final Object principal; - // ~ Constructors - // =================================================================================================== + private final Object credentials; + + private final Object principal; public TestingAuthenticationToken(Object principal, Object credentials) { super(null); @@ -46,27 +43,25 @@ public class TestingAuthenticationToken extends AbstractAuthenticationToken { this.credentials = credentials; } - public TestingAuthenticationToken(Object principal, Object credentials, - String... authorities) { + public TestingAuthenticationToken(Object principal, Object credentials, String... authorities) { this(principal, credentials, AuthorityUtils.createAuthorityList(authorities)); } - public TestingAuthenticationToken(Object principal, Object credentials, - List authorities) { + public TestingAuthenticationToken(Object principal, Object credentials, List authorities) { super(authorities); this.principal = principal; this.credentials = credentials; setAuthenticated(true); } - // ~ Methods - // ======================================================================================================== - + @Override public Object getCredentials() { return this.credentials; } + @Override public Object getPrincipal() { return this.principal; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java index 5d273beb4b..a27ed5f68d 100644 --- a/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManager.java @@ -23,14 +23,15 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.util.Assert; /** - * A {@link ReactiveAuthenticationManager} that uses a {@link ReactiveUserDetailsService} to validate the provided - * username and password. + * A {@link ReactiveAuthenticationManager} that uses a {@link ReactiveUserDetailsService} + * to validate the provided username and password. * * @author Rob Winch * @author Eddú Meléndez * @since 5.0 */ -public class UserDetailsRepositoryReactiveAuthenticationManager extends AbstractUserDetailsReactiveAuthenticationManager { +public class UserDetailsRepositoryReactiveAuthenticationManager + extends AbstractUserDetailsReactiveAuthenticationManager { private ReactiveUserDetailsService userDetailsService; diff --git a/core/src/main/java/org/springframework/security/authentication/UsernamePasswordAuthenticationToken.java b/core/src/main/java/org/springframework/security/authentication/UsernamePasswordAuthenticationToken.java index 152d9961e6..55963150a6 100644 --- a/core/src/main/java/org/springframework/security/authentication/UsernamePasswordAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/authentication/UsernamePasswordAuthenticationToken.java @@ -20,6 +20,7 @@ import java.util.Collection; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.util.Assert; /** * An {@link org.springframework.security.core.Authentication} implementation that is @@ -36,14 +37,9 @@ public class UsernamePasswordAuthenticationToken extends AbstractAuthenticationT private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private final Object principal; - private Object credentials; - // ~ Constructors - // =================================================================================================== + private Object credentials; /** * This constructor can be safely used by any code that wishes to create a @@ -63,7 +59,6 @@ public class UsernamePasswordAuthenticationToken extends AbstractAuthenticationT * AuthenticationProvider implementations that are satisfied with * producing a trusted (i.e. {@link #isAuthenticated()} = true) * authentication token. - * * @param principal * @param credentials * @param authorities @@ -76,29 +71,27 @@ public class UsernamePasswordAuthenticationToken extends AbstractAuthenticationT super.setAuthenticated(true); // must use super, as we override } - // ~ Methods - // ======================================================================================================== - + @Override public Object getCredentials() { return this.credentials; } + @Override public Object getPrincipal() { return this.principal; } + @Override public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { - if (isAuthenticated) { - throw new IllegalArgumentException( - "Cannot set this token to trusted - use constructor which takes a GrantedAuthority list instead"); - } - + Assert.isTrue(!isAuthenticated, + "Cannot set this token to trusted - use constructor which takes a GrantedAuthority list instead"); super.setAuthenticated(false); } @Override public void eraseCredentials() { super.eraseCredentials(); - credentials = null; + this.credentials = null; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java index c752c81197..953e99455c 100644 --- a/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java @@ -18,6 +18,11 @@ package org.springframework.security.authentication.dao; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.context.MessageSource; +import org.springframework.context.MessageSourceAware; +import org.springframework.context.support.MessageSourceAccessor; import org.springframework.security.authentication.AccountExpiredException; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.BadCredentialsException; @@ -36,13 +41,6 @@ import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.core.userdetails.cache.NullUserCache; - -import org.springframework.beans.factory.InitializingBean; - -import org.springframework.context.MessageSource; -import org.springframework.context.MessageSourceAware; -import org.springframework.context.support.MessageSourceAccessor; - import org.springframework.util.Assert; /** @@ -76,24 +74,24 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public abstract class AbstractUserDetailsAuthenticationProvider implements - AuthenticationProvider, InitializingBean, MessageSourceAware { +public abstract class AbstractUserDetailsAuthenticationProvider + implements AuthenticationProvider, InitializingBean, MessageSourceAware { protected final Log logger = LogFactory.getLog(getClass()); - // ~ Instance fields - // ================================================================================================ - protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); - private UserCache userCache = new NullUserCache(); - private boolean forcePrincipalAsString = false; - protected boolean hideUserNotFoundExceptions = true; - private UserDetailsChecker preAuthenticationChecks = new DefaultPreAuthenticationChecks(); - private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks(); - private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); - // ~ Methods - // ======================================================================================================== + private UserCache userCache = new NullUserCache(); + + private boolean forcePrincipalAsString = false; + + protected boolean hideUserNotFoundExceptions = true; + + private UserDetailsChecker preAuthenticationChecks = new DefaultPreAuthenticationChecks(); + + private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks(); + + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); /** * Allows subclasses to perform any additional checks of a returned (or cached) @@ -103,100 +101,77 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements * properties of UserDetails and/or * UsernamePasswordAuthenticationToken, these should also appear in this * method. - * * @param userDetails as retrieved from the * {@link #retrieveUser(String, UsernamePasswordAuthenticationToken)} or * UserCache * @param authentication the current request that needs to be authenticated - * * @throws AuthenticationException AuthenticationException if the credentials could * not be validated (generally a BadCredentialsException, an * AuthenticationServiceException) */ protected abstract void additionalAuthenticationChecks(UserDetails userDetails, - UsernamePasswordAuthenticationToken authentication) - throws AuthenticationException; + UsernamePasswordAuthenticationToken authentication) throws AuthenticationException; + @Override public final void afterPropertiesSet() throws Exception { Assert.notNull(this.userCache, "A user cache must be set"); Assert.notNull(this.messages, "A message source must be set"); doAfterPropertiesSet(); } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { Assert.isInstanceOf(UsernamePasswordAuthenticationToken.class, authentication, - () -> messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.onlySupports", + () -> this.messages.getMessage("AbstractUserDetailsAuthenticationProvider.onlySupports", "Only UsernamePasswordAuthenticationToken is supported")); - - // Determine username - String username = (authentication.getPrincipal() == null) ? "NONE_PROVIDED" - : authentication.getName(); - + String username = determineUsername(authentication); boolean cacheWasUsed = true; UserDetails user = this.userCache.getUserFromCache(username); - if (user == null) { cacheWasUsed = false; - try { - user = retrieveUser(username, - (UsernamePasswordAuthenticationToken) authentication); + user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication); } - catch (UsernameNotFoundException notFound) { - logger.debug("User '" + username + "' not found"); - - if (hideUserNotFoundExceptions) { - throw new BadCredentialsException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.badCredentials", - "Bad credentials")); - } - else { - throw notFound; + catch (UsernameNotFoundException ex) { + this.logger.debug("User '" + username + "' not found"); + if (!this.hideUserNotFoundExceptions) { + throw ex; } + throw new BadCredentialsException(this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.badCredentials", "Bad credentials")); } - - Assert.notNull(user, - "retrieveUser returned null - a violation of the interface contract"); + Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract"); } - try { - preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, - (UsernamePasswordAuthenticationToken) authentication); + this.preAuthenticationChecks.check(user); + additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); } - catch (AuthenticationException exception) { - if (cacheWasUsed) { - // There was a problem, so try again after checking - // we're using latest data (i.e. not from the cache) - cacheWasUsed = false; - user = retrieveUser(username, - (UsernamePasswordAuthenticationToken) authentication); - preAuthenticationChecks.check(user); - additionalAuthenticationChecks(user, - (UsernamePasswordAuthenticationToken) authentication); - } - else { - throw exception; + catch (AuthenticationException ex) { + if (!cacheWasUsed) { + throw ex; } + // There was a problem, so try again after checking + // we're using latest data (i.e. not from the cache) + cacheWasUsed = false; + user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication); + this.preAuthenticationChecks.check(user); + additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication); } - - postAuthenticationChecks.check(user); - + this.postAuthenticationChecks.check(user); if (!cacheWasUsed) { this.userCache.putUserInCache(user); } - Object principalToReturn = user; - - if (forcePrincipalAsString) { + if (this.forcePrincipalAsString) { principalToReturn = user.getUsername(); } - return createSuccessAuthentication(principalToReturn, authentication, user); } + private String determineUsername(Authentication authentication) { + return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName(); + } + /** * Creates a successful {@link Authentication} object. *

      @@ -206,25 +181,21 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements * Subclasses will usually store the original credentials the user supplied (not * salted or encoded passwords) in the returned Authentication object. *

      - * * @param principal that should be the principal in the returned object (defined by * the {@link #isForcePrincipalAsString()} method) * @param authentication that was presented to the provider for validation * @param user that was loaded by the implementation - * * @return the successful authentication token */ - protected Authentication createSuccessAuthentication(Object principal, - Authentication authentication, UserDetails user) { + protected Authentication createSuccessAuthentication(Object principal, Authentication authentication, + UserDetails user) { // Ensure we return the original credentials the user supplied, // so subsequent attempts are successful even with encoded passwords. // Also ensure we return the original getDetails(), so that future // authentication events after cache expiry contain the details - UsernamePasswordAuthenticationToken result = new UsernamePasswordAuthenticationToken( - principal, authentication.getCredentials(), - authoritiesMapper.mapAuthorities(user.getAuthorities())); + UsernamePasswordAuthenticationToken result = new UsernamePasswordAuthenticationToken(principal, + authentication.getCredentials(), this.authoritiesMapper.mapAuthorities(user.getAuthorities())); result.setDetails(authentication.getDetails()); - return result; } @@ -232,15 +203,15 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements } public UserCache getUserCache() { - return userCache; + return this.userCache; } public boolean isForcePrincipalAsString() { - return forcePrincipalAsString; + return this.forcePrincipalAsString; } public boolean isHideUserNotFoundExceptions() { - return hideUserNotFoundExceptions; + return this.hideUserNotFoundExceptions; } /** @@ -271,21 +242,17 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements * so that code related to credentials validation need not be duplicated across two * methods. *

      - * * @param username The username to retrieve * @param authentication The authentication request, which subclasses may * need to perform a binding-based retrieval of the UserDetails - * * @return the user information (never null - instead an exception should * the thrown) - * * @throws AuthenticationException if the credentials could not be validated * (generally a BadCredentialsException, an * AuthenticationServiceException or * UsernameNotFoundException) */ - protected abstract UserDetails retrieveUser(String username, - UsernamePasswordAuthenticationToken authentication) + protected abstract UserDetails retrieveUser(String username, UsernamePasswordAuthenticationToken authentication) throws AuthenticationException; public void setForcePrincipalAsString(boolean forcePrincipalAsString) { @@ -299,7 +266,6 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements * UsernameNotFoundExceptions to be thrown instead for the former. Note * this is considered less secure than throwing BadCredentialsException * for both exceptions. - * * @param hideUserNotFoundExceptions set to false if you wish * UsernameNotFoundExceptions to be thrown instead of the non-specific * BadCredentialsException (defaults to true) @@ -308,6 +274,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements this.hideUserNotFoundExceptions = hideUserNotFoundExceptions; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } @@ -316,19 +283,18 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements this.userCache = userCache; } + @Override public boolean supports(Class authentication) { - return (UsernamePasswordAuthenticationToken.class - .isAssignableFrom(authentication)); + return (UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication)); } protected UserDetailsChecker getPreAuthenticationChecks() { - return preAuthenticationChecks; + return this.preAuthenticationChecks; } /** * Sets the policy will be used to verify the status of the loaded * UserDetails before validation of the credentials takes place. - * * @param preAuthenticationChecks strategy to be invoked prior to authentication. */ public void setPreAuthenticationChecks(UserDetailsChecker preAuthenticationChecks) { @@ -336,7 +302,7 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements } protected UserDetailsChecker getPostAuthenticationChecks() { - return postAuthenticationChecks; + return this.postAuthenticationChecks; } public void setPostAuthenticationChecks(UserDetailsChecker postAuthenticationChecks) { @@ -348,42 +314,40 @@ public abstract class AbstractUserDetailsAuthenticationProvider implements } private class DefaultPreAuthenticationChecks implements UserDetailsChecker { + + @Override public void check(UserDetails user) { if (!user.isAccountNonLocked()) { - logger.debug("User account is locked"); - - throw new LockedException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.locked", - "User account is locked")); + AbstractUserDetailsAuthenticationProvider.this.logger.debug("User account is locked"); + throw new LockedException(AbstractUserDetailsAuthenticationProvider.this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.locked", "User account is locked")); } - if (!user.isEnabled()) { - logger.debug("User account is disabled"); - - throw new DisabledException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.disabled", - "User is disabled")); + AbstractUserDetailsAuthenticationProvider.this.logger.debug("User account is disabled"); + throw new DisabledException(AbstractUserDetailsAuthenticationProvider.this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.disabled", "User is disabled")); } - if (!user.isAccountNonExpired()) { - logger.debug("User account is expired"); - - throw new AccountExpiredException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.expired", - "User account has expired")); + AbstractUserDetailsAuthenticationProvider.this.logger.debug("User account is expired"); + throw new AccountExpiredException(AbstractUserDetailsAuthenticationProvider.this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.expired", "User account has expired")); } } + } private class DefaultPostAuthenticationChecks implements UserDetailsChecker { + + @Override public void check(UserDetails user) { if (!user.isCredentialsNonExpired()) { - logger.debug("User account credentials have expired"); - - throw new CredentialsExpiredException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.credentialsExpired", - "User credentials have expired")); + AbstractUserDetailsAuthenticationProvider.this.logger.debug("User account credentials have expired"); + throw new CredentialsExpiredException(AbstractUserDetailsAuthenticationProvider.this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.credentialsExpired", + "User credentials have expired")); } } + } + } diff --git a/core/src/main/java/org/springframework/security/authentication/dao/DaoAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/dao/DaoAuthenticationProvider.java index 2771f9790a..245101592f 100644 --- a/core/src/main/java/org/springframework/security/authentication/dao/DaoAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/dao/DaoAuthenticationProvider.java @@ -23,11 +23,11 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.crypto.factory.PasswordEncoderFactories; import org.springframework.security.crypto.password.PasswordEncoder; -import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.util.Assert; /** @@ -38,25 +38,18 @@ import org.springframework.util.Assert; * @author Rob Winch */ public class DaoAuthenticationProvider extends AbstractUserDetailsAuthenticationProvider { - // ~ Static fields/initializers - // ===================================================================================== /** - * The plaintext password used to perform - * PasswordEncoder#matches(CharSequence, String)} on when the user is - * not found to avoid SEC-2056. + * The plaintext password used to perform PasswordEncoder#matches(CharSequence, + * String)} on when the user is not found to avoid SEC-2056. */ private static final String USER_NOT_FOUND_PASSWORD = "userNotFoundPassword"; - // ~ Instance fields - // ================================================================================================ - private PasswordEncoder passwordEncoder; /** - * The password used to perform - * {@link PasswordEncoder#matches(CharSequence, String)} on when the user is - * not found to avoid SEC-2056. This is necessary, because some + * The password used to perform {@link PasswordEncoder#matches(CharSequence, String)} + * on when the user is not found to avoid SEC-2056. This is necessary, because some * {@link PasswordEncoder} implementations will short circuit if the password is not * in a valid format. */ @@ -70,38 +63,30 @@ public class DaoAuthenticationProvider extends AbstractUserDetailsAuthentication setPasswordEncoder(PasswordEncoderFactories.createDelegatingPasswordEncoder()); } - // ~ Methods - // ======================================================================================================== - + @Override @SuppressWarnings("deprecation") protected void additionalAuthenticationChecks(UserDetails userDetails, - UsernamePasswordAuthenticationToken authentication) - throws AuthenticationException { + UsernamePasswordAuthenticationToken authentication) throws AuthenticationException { if (authentication.getCredentials() == null) { - logger.debug("Authentication failed: no credentials provided"); - - throw new BadCredentialsException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.badCredentials", - "Bad credentials")); + this.logger.debug("Authentication failed: no credentials provided"); + throw new BadCredentialsException(this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.badCredentials", "Bad credentials")); } - String presentedPassword = authentication.getCredentials().toString(); - - if (!passwordEncoder.matches(presentedPassword, userDetails.getPassword())) { - logger.debug("Authentication failed: password does not match stored value"); - - throw new BadCredentialsException(messages.getMessage( - "AbstractUserDetailsAuthenticationProvider.badCredentials", - "Bad credentials")); + if (!this.passwordEncoder.matches(presentedPassword, userDetails.getPassword())) { + this.logger.debug("Authentication failed: password does not match stored value"); + throw new BadCredentialsException(this.messages + .getMessage("AbstractUserDetailsAuthenticationProvider.badCredentials", "Bad credentials")); } } + @Override protected void doAfterPropertiesSet() { Assert.notNull(this.userDetailsService, "A UserDetailsService must be set"); } - protected final UserDetails retrieveUser(String username, - UsernamePasswordAuthenticationToken authentication) + @Override + protected final UserDetails retrieveUser(String username, UsernamePasswordAuthenticationToken authentication) throws AuthenticationException { prepareTimingAttackProtection(); try { @@ -125,8 +110,8 @@ public class DaoAuthenticationProvider extends AbstractUserDetailsAuthentication } @Override - protected Authentication createSuccessAuthentication(Object principal, - Authentication authentication, UserDetails user) { + protected Authentication createSuccessAuthentication(Object principal, Authentication authentication, + UserDetails user) { boolean upgradeEncoding = this.userDetailsPasswordService != null && this.passwordEncoder.upgradeEncoding(user.getPassword()); if (upgradeEncoding) { @@ -152,8 +137,8 @@ public class DaoAuthenticationProvider extends AbstractUserDetailsAuthentication /** * Sets the PasswordEncoder instance to be used to encode and validate passwords. If - * not set, the password will be compared using {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()} - * + * not set, the password will be compared using + * {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()} * @param passwordEncoder must be an instance of one of the {@code PasswordEncoder} * types. */ @@ -164,7 +149,7 @@ public class DaoAuthenticationProvider extends AbstractUserDetailsAuthentication } protected PasswordEncoder getPasswordEncoder() { - return passwordEncoder; + return this.passwordEncoder; } public void setUserDetailsService(UserDetailsService userDetailsService) { @@ -172,11 +157,11 @@ public class DaoAuthenticationProvider extends AbstractUserDetailsAuthentication } protected UserDetailsService getUserDetailsService() { - return userDetailsService; + return this.userDetailsService; } - public void setUserDetailsPasswordService( - UserDetailsPasswordService userDetailsPasswordService) { + public void setUserDetailsPasswordService(UserDetailsPasswordService userDetailsPasswordService) { this.userDetailsPasswordService = userDetailsPasswordService; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/dao/package-info.java b/core/src/main/java/org/springframework/security/authentication/dao/package-info.java index a171b83750..b5668f899e 100644 --- a/core/src/main/java/org/springframework/security/authentication/dao/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/dao/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * An {@code AuthenticationProvider} which relies upon a data access object. */ package org.springframework.security.authentication.dao; - diff --git a/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationEvent.java index 3102a7f688..20c4ed54ba 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationEvent.java @@ -16,9 +16,8 @@ package org.springframework.security.authentication.event; -import org.springframework.security.core.Authentication; - import org.springframework.context.ApplicationEvent; +import org.springframework.security.core.Authentication; /** * Represents an application authentication event. @@ -30,23 +29,18 @@ import org.springframework.context.ApplicationEvent; * @author Ben Alex */ public abstract class AbstractAuthenticationEvent extends ApplicationEvent { - // ~ Constructors - // =================================================================================================== public AbstractAuthenticationEvent(Authentication authentication) { super(authentication); } - // ~ Methods - // ======================================================================================================== - /** * Getters for the Authentication request that caused the event. Also * available from super.getSource(). - * * @return the authentication request */ public Authentication getAuthentication() { return (Authentication) super.getSource(); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationFailureEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationFailureEvent.java index 74dec87a9e..0f06d0b94c 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationFailureEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AbstractAuthenticationFailureEvent.java @@ -18,7 +18,6 @@ package org.springframework.security.authentication.event; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; - import org.springframework.util.Assert; /** @@ -26,27 +25,18 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public abstract class AbstractAuthenticationFailureEvent extends - AbstractAuthenticationEvent { - // ~ Instance fields - // ================================================================================================ +public abstract class AbstractAuthenticationFailureEvent extends AbstractAuthenticationEvent { private final AuthenticationException exception; - // ~ Constructors - // =================================================================================================== - - public AbstractAuthenticationFailureEvent(Authentication authentication, - AuthenticationException exception) { + public AbstractAuthenticationFailureEvent(Authentication authentication, AuthenticationException exception) { super(authentication); Assert.notNull(exception, "AuthenticationException is required"); this.exception = exception; } - // ~ Methods - // ======================================================================================================== - public AuthenticationException getException() { - return exception; + return this.exception; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureBadCredentialsEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureBadCredentialsEvent.java index c1edd6f79a..796690b0e6 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureBadCredentialsEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureBadCredentialsEvent.java @@ -25,13 +25,10 @@ import org.springframework.security.core.AuthenticationException; * * @author Ben Alex */ -public class AuthenticationFailureBadCredentialsEvent extends - AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== +public class AuthenticationFailureBadCredentialsEvent extends AbstractAuthenticationFailureEvent { - public AuthenticationFailureBadCredentialsEvent(Authentication authentication, - AuthenticationException exception) { + public AuthenticationFailureBadCredentialsEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureCredentialsExpiredEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureCredentialsExpiredEvent.java index 2414f05ae9..57f218a239 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureCredentialsExpiredEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureCredentialsExpiredEvent.java @@ -25,13 +25,11 @@ import org.springframework.security.core.AuthenticationException; * * @author Ben Alex */ -public class AuthenticationFailureCredentialsExpiredEvent extends - AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== +public class AuthenticationFailureCredentialsExpiredEvent extends AbstractAuthenticationFailureEvent { public AuthenticationFailureCredentialsExpiredEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureDisabledEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureDisabledEvent.java index d9a8dd22c1..3a4604354f 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureDisabledEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureDisabledEvent.java @@ -25,13 +25,10 @@ import org.springframework.security.core.AuthenticationException; * * @author Ben Alex */ -public class AuthenticationFailureDisabledEvent extends - AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== +public class AuthenticationFailureDisabledEvent extends AbstractAuthenticationFailureEvent { - public AuthenticationFailureDisabledEvent(Authentication authentication, - AuthenticationException exception) { + public AuthenticationFailureDisabledEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureExpiredEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureExpiredEvent.java index e95bcf3eae..086e16cb37 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureExpiredEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureExpiredEvent.java @@ -26,11 +26,9 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class AuthenticationFailureExpiredEvent extends AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== - public AuthenticationFailureExpiredEvent(Authentication authentication, - AuthenticationException exception) { + public AuthenticationFailureExpiredEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureLockedEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureLockedEvent.java index 0e0e88b6a7..544964cdec 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureLockedEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureLockedEvent.java @@ -26,11 +26,9 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class AuthenticationFailureLockedEvent extends AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== - public AuthenticationFailureLockedEvent(Authentication authentication, - AuthenticationException exception) { + public AuthenticationFailureLockedEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProviderNotFoundEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProviderNotFoundEvent.java index e3e9f2d818..1a1cf7c87e 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProviderNotFoundEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProviderNotFoundEvent.java @@ -25,13 +25,11 @@ import org.springframework.security.core.AuthenticationException; * * @author Ben Alex */ -public class AuthenticationFailureProviderNotFoundEvent extends - AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== +public class AuthenticationFailureProviderNotFoundEvent extends AbstractAuthenticationFailureEvent { public AuthenticationFailureProviderNotFoundEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProxyUntrustedEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProxyUntrustedEvent.java index 667076bc06..772774d3f1 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProxyUntrustedEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureProxyUntrustedEvent.java @@ -25,13 +25,10 @@ import org.springframework.security.core.AuthenticationException; * * @author Ben Alex */ -public class AuthenticationFailureProxyUntrustedEvent extends - AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== +public class AuthenticationFailureProxyUntrustedEvent extends AbstractAuthenticationFailureEvent { - public AuthenticationFailureProxyUntrustedEvent(Authentication authentication, - AuthenticationException exception) { + public AuthenticationFailureProxyUntrustedEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureServiceExceptionEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureServiceExceptionEvent.java index 52869e5357..167d5fae3b 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureServiceExceptionEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationFailureServiceExceptionEvent.java @@ -25,13 +25,11 @@ import org.springframework.security.core.AuthenticationException; * * @author Ben Alex */ -public class AuthenticationFailureServiceExceptionEvent extends - AbstractAuthenticationFailureEvent { - // ~ Constructors - // =================================================================================================== +public class AuthenticationFailureServiceExceptionEvent extends AbstractAuthenticationFailureEvent { public AuthenticationFailureServiceExceptionEvent(Authentication authentication, AuthenticationException exception) { super(authentication, exception); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationSuccessEvent.java b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationSuccessEvent.java index 67a7b18f11..5b3b9bcd24 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/AuthenticationSuccessEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/AuthenticationSuccessEvent.java @@ -24,10 +24,9 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public class AuthenticationSuccessEvent extends AbstractAuthenticationEvent { - // ~ Constructors - // =================================================================================================== public AuthenticationSuccessEvent(Authentication authentication) { super(authentication); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/InteractiveAuthenticationSuccessEvent.java b/core/src/main/java/org/springframework/security/authentication/event/InteractiveAuthenticationSuccessEvent.java index db60e208f7..c93d2a9165 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/InteractiveAuthenticationSuccessEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/event/InteractiveAuthenticationSuccessEvent.java @@ -17,7 +17,6 @@ package org.springframework.security.authentication.event; import org.springframework.security.core.Authentication; - import org.springframework.util.Assert; /** @@ -34,31 +33,22 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class InteractiveAuthenticationSuccessEvent extends AbstractAuthenticationEvent { - // ~ Instance fields - // ================================================================================================ private final Class generatedBy; - // ~ Constructors - // =================================================================================================== - - public InteractiveAuthenticationSuccessEvent(Authentication authentication, - Class generatedBy) { + public InteractiveAuthenticationSuccessEvent(Authentication authentication, Class generatedBy) { super(authentication); Assert.notNull(generatedBy, "generatedBy cannot be null"); this.generatedBy = generatedBy; } - // ~ Methods - // ======================================================================================================== - /** * Getter for the Class that generated this event. This can be useful for * generating additional logging information. - * * @return the class */ public Class getGeneratedBy() { - return generatedBy; + return this.generatedBy; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/LoggerListener.java b/core/src/main/java/org/springframework/security/authentication/event/LoggerListener.java index 0d23755dea..92fda0f52d 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/LoggerListener.java +++ b/core/src/main/java/org/springframework/security/authentication/event/LoggerListener.java @@ -18,7 +18,9 @@ package org.springframework.security.authentication.event; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.context.ApplicationListener; +import org.springframework.core.log.LogMessage; import org.springframework.util.ClassUtils; /** @@ -29,8 +31,6 @@ import org.springframework.util.ClassUtils; * @author Ben Alex */ public class LoggerListener implements ApplicationListener { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(LoggerListener.class); @@ -40,40 +40,35 @@ public class LoggerListener implements ApplicationListener getLogMessage(event))); + } - if (logger.isWarnEnabled()) { - final StringBuilder builder = new StringBuilder(); - builder.append("Authentication event "); - builder.append(ClassUtils.getShortName(event.getClass())); - builder.append(": "); - builder.append(event.getAuthentication().getName()); - builder.append("; details: "); - builder.append(event.getAuthentication().getDetails()); - - if (event instanceof AbstractAuthenticationFailureEvent) { - builder.append("; exception: "); - builder.append(((AbstractAuthenticationFailureEvent) event) - .getException().getMessage()); - } - - logger.warn(builder.toString()); + private String getLogMessage(AbstractAuthenticationEvent event) { + StringBuilder builder = new StringBuilder(); + builder.append("Authentication event "); + builder.append(ClassUtils.getShortName(event.getClass())); + builder.append(": "); + builder.append(event.getAuthentication().getName()); + builder.append("; details: "); + builder.append(event.getAuthentication().getDetails()); + if (event instanceof AbstractAuthenticationFailureEvent) { + builder.append("; exception: "); + builder.append(((AbstractAuthenticationFailureEvent) event).getException().getMessage()); } + return builder.toString(); } public boolean isLogInteractiveAuthenticationSuccessEvents() { - return logInteractiveAuthenticationSuccessEvents; + return this.logInteractiveAuthenticationSuccessEvents; } - public void setLogInteractiveAuthenticationSuccessEvents( - boolean logInteractiveAuthenticationSuccessEvents) { + public void setLogInteractiveAuthenticationSuccessEvents(boolean logInteractiveAuthenticationSuccessEvents) { this.logInteractiveAuthenticationSuccessEvents = logInteractiveAuthenticationSuccessEvents; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/event/package-info.java b/core/src/main/java/org/springframework/security/authentication/event/package-info.java index 006bfe1fd9..b55823abf4 100644 --- a/core/src/main/java/org/springframework/security/authentication/event/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/event/package-info.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Authentication success and failure events which can be published to the Spring application context. + * Authentication success and failure events which can be published to the Spring + * application context. * - * The ProviderManager automatically publishes events to the application context. These events are - * received by all registered Spring ApplicationListeners. + * The ProviderManager automatically publishes events to the application + * context. These events are received by all registered Spring + * ApplicationListeners. */ - package org.springframework.security.authentication.event; - diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/AbstractJaasAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/jaas/AbstractJaasAuthenticationProvider.java index 54552a1e0f..d9327ff437 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/AbstractJaasAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/AbstractJaasAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication.jaas; import java.io.IOException; @@ -35,6 +36,7 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.ApplicationListener; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.jaas.event.JaasAuthenticationFailedEvent; @@ -45,6 +47,7 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.session.SessionDestroyedEvent; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; /** @@ -64,8 +67,8 @@ import org.springframework.util.ObjectUtils; * *

      * When using JAAS login modules as the authentication source, sometimes the - * LoginContext will require CallbackHandlers. The + * "https://java.sun.com/j2se/1.5.0/docs/api/javax/security/auth/login/LoginContext.html" + * > LoginContext will require CallbackHandlers. The * AbstractJaasAuthenticationProvider uses an internal CallbackHandler to wrap the {@link JaasAuthenticationCallbackHandler}s configured @@ -113,21 +116,20 @@ import org.springframework.util.ObjectUtils; * @author Ray Krueger * @author Rob Winch */ -public abstract class AbstractJaasAuthenticationProvider - implements AuthenticationProvider, ApplicationEventPublisherAware, - InitializingBean, ApplicationListener { - // ~ Instance fields - // ================================================================================================ +public abstract class AbstractJaasAuthenticationProvider implements AuthenticationProvider, + ApplicationEventPublisherAware, InitializingBean, ApplicationListener { private ApplicationEventPublisher applicationEventPublisher; - private AuthorityGranter[] authorityGranters; - private JaasAuthenticationCallbackHandler[] callbackHandlers; - protected final Log log = LogFactory.getLog(getClass()); - private LoginExceptionResolver loginExceptionResolver = new DefaultLoginExceptionResolver(); - private String loginContextName = "SPRINGSECURITY"; - // ~ Methods - // ======================================================================================================== + private AuthorityGranter[] authorityGranters; + + private JaasAuthenticationCallbackHandler[] callbackHandlers; + + protected final Log log = LogFactory.getLog(getClass()); + + private LoginExceptionResolver loginExceptionResolver = new DefaultLoginExceptionResolver(); + + private String loginContextName = "SPRINGSECURITY"; /** * Validates the required properties are set. In addition, if @@ -135,149 +137,127 @@ public abstract class AbstractJaasAuthenticationProvider * called with valid handlers, initializes to use {@link JaasNameCallbackHandler} and * {@link JaasPasswordCallbackHandler}. */ + @Override public void afterPropertiesSet() throws Exception { - Assert.hasLength(this.loginContextName, - "loginContextName cannot be null or empty"); - Assert.notEmpty(this.authorityGranters, - "authorityGranters cannot be null or empty"); + Assert.hasLength(this.loginContextName, "loginContextName cannot be null or empty"); + Assert.notEmpty(this.authorityGranters, "authorityGranters cannot be null or empty"); if (ObjectUtils.isEmpty(this.callbackHandlers)) { - setCallbackHandlers(new JaasAuthenticationCallbackHandler[] { - new JaasNameCallbackHandler(), new JaasPasswordCallbackHandler() }); + setCallbackHandlers(new JaasAuthenticationCallbackHandler[] { new JaasNameCallbackHandler(), + new JaasPasswordCallbackHandler() }); } - Assert.notNull(this.loginExceptionResolver, - "loginExceptionResolver cannot be null"); + Assert.notNull(this.loginExceptionResolver, "loginExceptionResolver cannot be null"); } /** * Attempts to login the user given the Authentication objects principal and * credential - * * @param auth The Authentication object to be authenticated. - * * @return The authenticated Authentication object, with it's grantedAuthorities set. - * * @throws AuthenticationException This implementation does not handle 'locked' or * 'disabled' accounts. This method only throws a AuthenticationServiceException, with * the message of the LoginException that will be thrown, should the * loginContext.login() method fail. */ - public Authentication authenticate(Authentication auth) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication auth) throws AuthenticationException { if (!(auth instanceof UsernamePasswordAuthenticationToken)) { return null; } - UsernamePasswordAuthenticationToken request = (UsernamePasswordAuthenticationToken) auth; Set authorities; - try { // Create the LoginContext object, and pass our InternallCallbackHandler - LoginContext loginContext = createLoginContext( - new InternalCallbackHandler(auth)); - + LoginContext loginContext = createLoginContext(new InternalCallbackHandler(auth)); // Attempt to login the user, the LoginContext will call our // InternalCallbackHandler at this point. loginContext.login(); - - // Create a set to hold the authorities, and add any that have already been - // applied. - authorities = new HashSet<>(); - // Get the subject principals and pass them to each of the AuthorityGranters Set principals = loginContext.getSubject().getPrincipals(); - - for (Principal principal : principals) { - for (AuthorityGranter granter : this.authorityGranters) { - Set roles = granter.grant(principal); - - // If the granter doesn't wish to grant any authorities, it should - // return null. - if ((roles != null) && !roles.isEmpty()) { - for (String role : roles) { - authorities.add(new JaasGrantedAuthority(role, principal)); - } - } - } - } - + // Create a set to hold the authorities, and add any that have already been + // applied. + authorities = getAuthorities(principals); // Convert the authorities set back to an array and apply it to the token. - JaasAuthenticationToken result = new JaasAuthenticationToken( - request.getPrincipal(), request.getCredentials(), - new ArrayList<>(authorities), loginContext); - + JaasAuthenticationToken result = new JaasAuthenticationToken(request.getPrincipal(), + request.getCredentials(), new ArrayList<>(authorities), loginContext); // Publish the success event publishSuccessEvent(result); - // we're done, return the token. return result; } - catch (LoginException loginException) { - AuthenticationException ase = this.loginExceptionResolver - .resolveException(loginException); - - publishFailureEvent(request, ase); - throw ase; + catch (LoginException ex) { + AuthenticationException resolvedException = this.loginExceptionResolver.resolveException(ex); + publishFailureEvent(request, resolvedException); + throw resolvedException; } } + private Set getAuthorities(Set principals) { + Set authorities; + authorities = new HashSet<>(); + for (Principal principal : principals) { + for (AuthorityGranter granter : this.authorityGranters) { + Set roles = granter.grant(principal); + // If the granter doesn't wish to grant any authorities, + // it should return null. + if (!CollectionUtils.isEmpty(roles)) { + for (String role : roles) { + authorities.add(new JaasGrantedAuthority(role, principal)); + } + } + } + } + return authorities; + } + /** * Creates the LoginContext to be used for authentication. - * * @param handler The CallbackHandler that should be used for the LoginContext (never * null). * @return the LoginContext to use for authentication. * @throws LoginException */ - protected abstract LoginContext createLoginContext(CallbackHandler handler) - throws LoginException; + protected abstract LoginContext createLoginContext(CallbackHandler handler) throws LoginException; /** * Handles the logout by getting the security contexts for the destroyed session and * invoking {@code LoginContext.logout()} for any which contain a * {@code JaasAuthenticationToken}. - * - * * @param event the session event which contains the current session */ protected void handleLogout(SessionDestroyedEvent event) { List contexts = event.getSecurityContexts(); - if (contexts.isEmpty()) { this.log.debug("The destroyed session has no SecurityContexts"); - return; } - for (SecurityContext context : contexts) { Authentication auth = context.getAuthentication(); - if ((auth != null) && (auth instanceof JaasAuthenticationToken)) { JaasAuthenticationToken token = (JaasAuthenticationToken) auth; - try { LoginContext loginContext = token.getLoginContext(); - boolean debug = this.log.isDebugEnabled(); - if (loginContext != null) { - if (debug) { - this.log.debug("Logging principal: [" + token.getPrincipal() - + "] out of LoginContext"); - } - loginContext.logout(); - } - else if (debug) { - this.log.debug("Cannot logout principal: [" + token.getPrincipal() - + "] from LoginContext. " - + "The LoginContext is unavailable"); - } + logout(token, loginContext); } - catch (LoginException e) { - this.log.warn("Error error logging out of LoginContext", e); + catch (LoginException ex) { + this.log.warn("Error error logging out of LoginContext", ex); } } } } + private void logout(JaasAuthenticationToken token, LoginContext loginContext) throws LoginException { + if (loginContext != null) { + this.log.debug( + LogMessage.of(() -> "Logging principal: [" + token.getPrincipal() + "] out of LoginContext")); + loginContext.logout(); + return; + } + this.log.debug(LogMessage.of(() -> "Cannot logout principal: [" + token.getPrincipal() + + "] from LoginContext. The LoginContext is unavailable")); + } + + @Override public void onApplicationEvent(SessionDestroyedEvent event) { handleLogout(event); } @@ -285,28 +265,23 @@ public abstract class AbstractJaasAuthenticationProvider /** * Publishes the {@link JaasAuthenticationFailedEvent}. Can be overridden by * subclasses for different functionality - * * @param token The authentication token being processed * @param ase The excetion that caused the authentication failure */ - protected void publishFailureEvent(UsernamePasswordAuthenticationToken token, - AuthenticationException ase) { + protected void publishFailureEvent(UsernamePasswordAuthenticationToken token, AuthenticationException ase) { if (this.applicationEventPublisher != null) { - this.applicationEventPublisher - .publishEvent(new JaasAuthenticationFailedEvent(token, ase)); + this.applicationEventPublisher.publishEvent(new JaasAuthenticationFailedEvent(token, ase)); } } /** * Publishes the {@link JaasAuthenticationSuccessEvent}. Can be overridden by * subclasses for different functionality. - * * @param token The token being processed */ protected void publishSuccessEvent(UsernamePasswordAuthenticationToken token) { if (this.applicationEventPublisher != null) { - this.applicationEventPublisher - .publishEvent(new JaasAuthenticationSuccessEvent(token)); + this.applicationEventPublisher.publishEvent(new JaasAuthenticationSuccessEvent(token)); } } @@ -314,7 +289,6 @@ public abstract class AbstractJaasAuthenticationProvider * Returns the AuthorityGrannter array that was passed to the * {@link #setAuthorityGranters(AuthorityGranter[])} method, or null if it none were * ever set. - * * @return The AuthorityGranter array, or null * * @see #setAuthorityGranters(AuthorityGranter[]) @@ -326,7 +300,6 @@ public abstract class AbstractJaasAuthenticationProvider /** * Set the AuthorityGranters that should be consulted for role names to be granted to * the Authentication. - * * @param authorityGranters AuthorityGranter array * * @see JaasAuthenticationProvider @@ -338,7 +311,6 @@ public abstract class AbstractJaasAuthenticationProvider /** * Returns the current JaasAuthenticationCallbackHandler array, or null if none are * set. - * * @return the JAASAuthenticationCallbackHandlers. * * @see #setCallbackHandlers(JaasAuthenticationCallbackHandler[]) @@ -350,11 +322,9 @@ public abstract class AbstractJaasAuthenticationProvider /** * Set the JAASAuthentcationCallbackHandler array to handle callback objects generated * by the LoginContext.login method. - * * @param callbackHandlers Array of JAASAuthenticationCallbackHandlers */ - public void setCallbackHandlers( - JaasAuthenticationCallbackHandler[] callbackHandlers) { + public void setCallbackHandlers(JaasAuthenticationCallbackHandler[] callbackHandlers) { this.callbackHandlers = callbackHandlers; } @@ -365,7 +335,6 @@ public abstract class AbstractJaasAuthenticationProvider /** * Set the loginContextName, this name is used as the index to the configuration * specified in the loginConfig property. - * * @param loginContextName */ public void setLoginContextName(String loginContextName) { @@ -380,12 +349,13 @@ public abstract class AbstractJaasAuthenticationProvider this.loginExceptionResolver = loginExceptionResolver; } + @Override public boolean supports(Class aClass) { return UsernamePasswordAuthenticationToken.class.isAssignableFrom(aClass); } - public void setApplicationEventPublisher( - ApplicationEventPublisher applicationEventPublisher) { + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { this.applicationEventPublisher = applicationEventPublisher; } @@ -393,26 +363,26 @@ public abstract class AbstractJaasAuthenticationProvider return this.applicationEventPublisher; } - // ~ Inner Classes - // ================================================================================================== - /** * Wrapper class for JAASAuthenticationCallbackHandlers */ private class InternalCallbackHandler implements CallbackHandler { + private final Authentication authentication; InternalCallbackHandler(Authentication authentication) { this.authentication = authentication; } - public void handle(Callback[] callbacks) - throws IOException, UnsupportedCallbackException { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { for (JaasAuthenticationCallbackHandler handler : AbstractJaasAuthenticationProvider.this.callbackHandlers) { for (Callback callback : callbacks) { handler.handle(callback, this.authentication); } } } + } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/AuthorityGranter.java b/core/src/main/java/org/springframework/security/authentication/jaas/AuthorityGranter.java index f0eb9f1440..cc04dc8de5 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/AuthorityGranter.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/AuthorityGranter.java @@ -17,7 +17,6 @@ package org.springframework.security.authentication.jaas; import java.security.Principal; - import java.util.Set; /** @@ -30,8 +29,6 @@ import java.util.Set; * @author Ray Krueger */ public interface AuthorityGranter { - // ~ Methods - // ======================================================================================================== /** * The grant method is called for each principal returned from the LoginContext @@ -41,12 +38,11 @@ public interface AuthorityGranter { *

      * The set may contain any object as all objects in the returned set will be passed to * the JaasGrantedAuthority constructor using toString(). - * * @param principal One of the principals from the * LoginContext.getSubect().getPrincipals() method. - * * @return the role names to grant, or null, meaning no roles should be granted to the * principal. */ Set grant(Principal principal); + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProvider.java index 379a70c08b..924f570831 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication.jaas; import javax.security.auth.callback.CallbackHandler; @@ -84,16 +85,10 @@ import org.springframework.util.Assert; * @see AbstractJaasAuthenticationProvider * @see InMemoryConfiguration */ -public class DefaultJaasAuthenticationProvider - extends AbstractJaasAuthenticationProvider { - // ~ Instance fields - // ================================================================================================ +public class DefaultJaasAuthenticationProvider extends AbstractJaasAuthenticationProvider { private Configuration configuration; - // ~ Methods - // ======================================================================================================== - @Override public void afterPropertiesSet() throws Exception { super.afterPropertiesSet(); @@ -105,8 +100,7 @@ public class DefaultJaasAuthenticationProvider * {@link #setConfiguration(Configuration)}. */ @Override - protected LoginContext createLoginContext(CallbackHandler handler) - throws LoginException { + protected LoginContext createLoginContext(CallbackHandler handler) throws LoginException { return new LoginContext(getLoginContextName(), null, handler, getConfiguration()); } @@ -116,11 +110,11 @@ public class DefaultJaasAuthenticationProvider /** * Sets the Configuration to use for Authentication. - * * @param configuration the Configuration that is used when * {@link #createLoginContext(CallbackHandler)} is called. */ public void setConfiguration(Configuration configuration) { this.configuration = configuration; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/DefaultLoginExceptionResolver.java b/core/src/main/java/org/springframework/security/authentication/jaas/DefaultLoginExceptionResolver.java index 3f43cb325a..5b5a261a44 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/DefaultLoginExceptionResolver.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/DefaultLoginExceptionResolver.java @@ -16,11 +16,11 @@ package org.springframework.security.authentication.jaas; +import javax.security.auth.login.LoginException; + import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.AuthenticationException; -import javax.security.auth.login.LoginException; - /** * This LoginExceptionResolver simply wraps the LoginException with an * AuthenticationServiceException. @@ -28,10 +28,10 @@ import javax.security.auth.login.LoginException; * @author Ray Krueger */ public class DefaultLoginExceptionResolver implements LoginExceptionResolver { - // ~ Methods - // ======================================================================================================== - public AuthenticationException resolveException(LoginException e) { - return new AuthenticationServiceException(e.getMessage(), e); + @Override + public AuthenticationException resolveException(LoginException ex) { + return new AuthenticationServiceException(ex.getMessage(), ex); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationCallbackHandler.java b/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationCallbackHandler.java index 070ba6599f..3dd70455e6 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationCallbackHandler.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationCallbackHandler.java @@ -17,6 +17,7 @@ package org.springframework.security.authentication.jaas; import java.io.IOException; + import javax.security.auth.callback.Callback; import javax.security.auth.callback.UnsupportedCallbackException; @@ -26,8 +27,8 @@ import org.springframework.security.core.Authentication; * The JaasAuthenticationCallbackHandler is similar to the * javax.security.auth.callback.CallbackHandler interface in that it defines a handle * method. The JaasAuthenticationCallbackHandler is only asked to handle one Callback - * instance at time rather than an array of all Callbacks, as the javax... - * CallbackHandler defines. + * instance at time rather than an array of all Callbacks, as the javax... CallbackHandler + * defines. * *

      * Before a JaasAuthenticationCallbackHandler is asked to 'handle' any callbacks, it is @@ -36,18 +37,15 @@ import org.springframework.security.core.Authentication; *

      * * @author Ray Krueger - * * @see JaasNameCallbackHandler * @see JaasPasswordCallbackHandler - * @see Callback - * @see + * @see Callback + * @see * CallbackHandler */ public interface JaasAuthenticationCallbackHandler { - // ~ Methods - // ======================================================================================================== /** * Handle the Callback. The handle method will be called for every callback instance sent * from the LoginContext. Meaning that The handle method may be called multiple times * for a given JaasAuthenticationCallbackHandler. - * * @param callback * @param auth The Authentication object currently being authenticated. * */ - void handle(Callback callback, Authentication auth) throws IOException, - UnsupportedCallbackException; + void handle(Callback callback, Authentication auth) throws IOException, UnsupportedCallbackException; + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationProvider.java index 34a3f231ee..f9e38143c4 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationProvider.java @@ -30,6 +30,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.io.Resource; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.jaas.event.JaasAuthenticationFailedEvent; @@ -84,8 +85,8 @@ import org.springframework.util.Assert; * *

      * When using JAAS login modules as the authentication source, sometimes the - * LoginContext will require CallbackHandlers. The JaasAuthenticationProvider + * "https://java.sun.com/j2se/1.5.0/docs/api/javax/security/auth/login/LoginContext.html" + * > LoginContext will require CallbackHandlers. The JaasAuthenticationProvider * uses an internal CallbackHandler to wrap the {@link JaasAuthenticationCallbackHandler}s configured @@ -138,31 +139,21 @@ import org.springframework.util.Assert; * @author Rob Winch */ public class JaasAuthenticationProvider extends AbstractJaasAuthenticationProvider { - // ~ Static fields/initializers - // ===================================================================================== // exists for passivity protected static final Log log = LogFactory.getLog(JaasAuthenticationProvider.class); - // ~ Instance fields - // ================================================================================================ - private Resource loginConfig; - private boolean refreshConfigurationOnStartup = true; - // ~ Methods - // ======================================================================================================== + private boolean refreshConfigurationOnStartup = true; @Override public void afterPropertiesSet() throws Exception { // the superclass is not called because it does additional checks that are // non-passive - Assert.hasLength(getLoginContextName(), - () -> "loginContextName must be set on " + getClass()); - Assert.notNull(this.loginConfig, - () -> "loginConfig must be set on " + getClass()); + Assert.hasLength(getLoginContextName(), () -> "loginContextName must be set on " + getClass()); + Assert.notNull(this.loginConfig, () -> "loginConfig must be set on " + getClass()); configureJaas(this.loginConfig); - Assert.notNull(Configuration.getConfiguration(), "As per https://java.sun.com/j2se/1.5.0/docs/api/javax/security/auth/login/Configuration.html " + "\"If a Configuration object was set via the Configuration.setConfiguration method, then that object is " @@ -171,21 +162,17 @@ public class JaasAuthenticationProvider extends AbstractJaasAuthenticationProvid } @Override - protected LoginContext createLoginContext(CallbackHandler handler) - throws LoginException { + protected LoginContext createLoginContext(CallbackHandler handler) throws LoginException { return new LoginContext(getLoginContextName(), handler); } /** * Hook method for configuring Jaas. - * * @param loginConfig URL to Jaas login configuration - * * @throws IOException if there is a problem reading the config resource. */ protected void configureJaas(Resource loginConfig) throws IOException { configureJaasUsingLoop(); - if (this.refreshConfigurationOnStartup) { // Overcome issue in SEC-760 Configuration.getConfiguration().refresh(); @@ -201,42 +188,33 @@ public class JaasAuthenticationProvider extends AbstractJaasAuthenticationProvid private void configureJaasUsingLoop() throws IOException { String loginConfigUrl = convertLoginConfigToUrl(); boolean alreadySet = false; - int n = 1; final String prefix = "login.config.url."; String existing; - while ((existing = Security.getProperty(prefix + n)) != null) { alreadySet = existing.equals(loginConfigUrl); - if (alreadySet) { break; } - n++; } - if (!alreadySet) { String key = prefix + n; - log.debug("Setting security property [" + key + "] to: " + loginConfigUrl); + log.debug(LogMessage.format("Setting security property [%s] to: %s", key, loginConfigUrl)); Security.setProperty(key, loginConfigUrl); } } private String convertLoginConfigToUrl() throws IOException { String loginConfigPath; - try { - loginConfigPath = this.loginConfig.getFile().getAbsolutePath() - .replace(File.separatorChar, '/'); - + loginConfigPath = this.loginConfig.getFile().getAbsolutePath().replace(File.separatorChar, '/'); if (!loginConfigPath.startsWith("/")) { loginConfigPath = "/" + loginConfigPath; } - return new URL("file", "", loginConfigPath).toString(); } - catch (IOException e) { + catch (IOException ex) { // SEC-1700: May be inside a jar return this.loginConfig.getURL().toString(); } @@ -245,16 +223,13 @@ public class JaasAuthenticationProvider extends AbstractJaasAuthenticationProvid /** * Publishes the {@link JaasAuthenticationFailedEvent}. Can be overridden by * subclasses for different functionality - * * @param token The authentication token being processed * @param ase The excetion that caused the authentication failure */ @Override - protected void publishFailureEvent(UsernamePasswordAuthenticationToken token, - AuthenticationException ase) { + protected void publishFailureEvent(UsernamePasswordAuthenticationToken token, AuthenticationException ase) { // exists for passivity (the superclass does a null check before publishing) - getApplicationEventPublisher() - .publishEvent(new JaasAuthenticationFailedEvent(token, ase)); + getApplicationEventPublisher().publishEvent(new JaasAuthenticationFailedEvent(token, ase)); } public Resource getLoginConfig() { @@ -263,7 +238,6 @@ public class JaasAuthenticationProvider extends AbstractJaasAuthenticationProvid /** * Set the JAAS login configuration file. - * * @param loginConfig * * @see SEC-1320 - * * @param refresh set to {@code false} to disable reloading of the configuration. May * be useful in some environments. + * @see SEC-1320 */ public void setRefreshConfigurationOnStartup(boolean refresh) { this.refreshConfigurationOnStartup = refresh; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationToken.java b/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationToken.java index ced47d245f..410f7e5f06 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationToken.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/JaasAuthenticationToken.java @@ -18,12 +18,12 @@ package org.springframework.security.authentication.jaas; import java.util.List; +import javax.security.auth.login.LoginContext; + import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; -import javax.security.auth.login.LoginContext; - /** * UsernamePasswordAuthenticationToken extension to carry the Jaas LoginContext that the * user was logged into @@ -34,30 +34,21 @@ public class JaasAuthenticationToken extends UsernamePasswordAuthenticationToken private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private final transient LoginContext loginContext; - // ~ Constructors - // =================================================================================================== - - public JaasAuthenticationToken(Object principal, Object credentials, - LoginContext loginContext) { + public JaasAuthenticationToken(Object principal, Object credentials, LoginContext loginContext) { super(principal, credentials); this.loginContext = loginContext; } - public JaasAuthenticationToken(Object principal, Object credentials, - List authorities, LoginContext loginContext) { + public JaasAuthenticationToken(Object principal, Object credentials, List authorities, + LoginContext loginContext) { super(principal, credentials, authorities); this.loginContext = loginContext; } - // ~ Methods - // ======================================================================================================== - public LoginContext getLoginContext() { - return loginContext; + return this.loginContext; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/JaasGrantedAuthority.java b/core/src/main/java/org/springframework/security/authentication/jaas/JaasGrantedAuthority.java index d6fc0b9a4f..d0d76b89d1 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/JaasGrantedAuthority.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/JaasGrantedAuthority.java @@ -16,18 +16,17 @@ package org.springframework.security.authentication.jaas; +import java.security.Principal; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.security.Principal; - /** * {@code GrantedAuthority} which, in addition to the assigned role, holds the principal * that an {@link AuthorityGranter} used as a reason to grant this authority. * * @author Ray Krueger - * * @see AuthorityGranter */ public final class JaasGrantedAuthority implements GrantedAuthority { @@ -35,6 +34,7 @@ public final class JaasGrantedAuthority implements GrantedAuthority { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private final String role; + private final Principal principal; public JaasGrantedAuthority(String role, Principal principal) { @@ -44,16 +44,25 @@ public final class JaasGrantedAuthority implements GrantedAuthority { this.principal = principal; } - // ~ Methods - // ======================================================================================================== - public Principal getPrincipal() { - return principal; + return this.principal; } @Override public String getAuthority() { - return role; + return this.role; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj instanceof JaasGrantedAuthority) { + JaasGrantedAuthority jga = (JaasGrantedAuthority) obj; + return this.role.equals(jga.role) && this.principal.equals(jga.principal); + } + return false; } @Override @@ -63,22 +72,9 @@ public final class JaasGrantedAuthority implements GrantedAuthority { return result; } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (obj instanceof JaasGrantedAuthority) { - JaasGrantedAuthority jga = (JaasGrantedAuthority) obj; - return this.role.equals(jga.role) && this.principal.equals(jga.principal); - } - - return false; - } - @Override public String toString() { - return "Jaas Authority [" + role + "," + principal + "]"; + return "Jaas Authority [" + this.role + "," + this.principal + "]"; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/JaasNameCallbackHandler.java b/core/src/main/java/org/springframework/security/authentication/jaas/JaasNameCallbackHandler.java index 2a53197e64..1f7881283f 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/JaasNameCallbackHandler.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/JaasNameCallbackHandler.java @@ -16,52 +16,46 @@ package org.springframework.security.authentication.jaas; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.userdetails.UserDetails; - import javax.security.auth.callback.Callback; import javax.security.auth.callback.NameCallback; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UserDetails; + /** * The most basic Callbacks to be handled when using a LoginContext from JAAS, are the * NameCallback and PasswordCallback. Spring Security provides the JaasNameCallbackHandler * specifically tailored to handling the NameCallback.
      * * @author Ray Krueger - * - * @see Callback - * @see NameCallback + * @see Callback + * @see NameCallback */ public class JaasNameCallbackHandler implements JaasAuthenticationCallbackHandler { - // ~ Methods - // ======================================================================================================== /** * If the callback passed to the 'handle' method is an instance of NameCallback, the * JaasNameCallbackHandler will call, * callback.setName(authentication.getPrincipal().toString()). - * * @param callback * @param authentication * */ + @Override public void handle(Callback callback, Authentication authentication) { if (callback instanceof NameCallback) { - NameCallback ncb = (NameCallback) callback; - String username; - - Object principal = authentication.getPrincipal(); - - if (principal instanceof UserDetails) { - username = ((UserDetails) principal).getUsername(); - } - else { - username = principal.toString(); - } - - ncb.setName(username); + ((NameCallback) callback).setName(getUserName(authentication)); } } + + private String getUserName(Authentication authentication) { + Object principal = authentication.getPrincipal(); + if (principal instanceof UserDetails) { + return ((UserDetails) principal).getUsername(); + } + return principal.toString(); + } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/JaasPasswordCallbackHandler.java b/core/src/main/java/org/springframework/security/authentication/jaas/JaasPasswordCallbackHandler.java index c42f392a57..8b43440052 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/JaasPasswordCallbackHandler.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/JaasPasswordCallbackHandler.java @@ -16,41 +16,39 @@ package org.springframework.security.authentication.jaas; -import org.springframework.security.core.Authentication; - import javax.security.auth.callback.Callback; import javax.security.auth.callback.PasswordCallback; +import org.springframework.security.core.Authentication; + /** * The most basic Callbacks to be handled when using a LoginContext from JAAS, are the * NameCallback and PasswordCallback. Spring Security provides the - * JaasPasswordCallbackHandler specifically tailored to handling the PasswordCallback.
      + * JaasPasswordCallbackHandler specifically tailored to handling the PasswordCallback. + *
      * * @author Ray Krueger - * - * @see Callback - * @see + * @see Callback + * @see * PasswordCallback */ public class JaasPasswordCallbackHandler implements JaasAuthenticationCallbackHandler { - // ~ Methods - // ======================================================================================================== /** * If the callback passed to the 'handle' method is an instance of PasswordCallback, * the JaasPasswordCallbackHandler will call, * callback.setPassword(authentication.getCredentials().toString()). - * * @param callback * @param auth * */ + @Override public void handle(Callback callback, Authentication auth) { if (callback instanceof PasswordCallback) { - PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(auth.getCredentials().toString().toCharArray()); + ((PasswordCallback) callback).setPassword(auth.getCredentials().toString().toCharArray()); } } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/LoginExceptionResolver.java b/core/src/main/java/org/springframework/security/authentication/jaas/LoginExceptionResolver.java index 5fc36f4b3c..cdaaed8dfa 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/LoginExceptionResolver.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/LoginExceptionResolver.java @@ -16,10 +16,10 @@ package org.springframework.security.authentication.jaas; -import org.springframework.security.core.AuthenticationException; - import javax.security.auth.login.LoginException; +import org.springframework.security.core.AuthenticationException; + /** * The JaasAuthenticationProvider takes an instance of LoginExceptionResolver to resolve * LoginModule specific exceptions to Spring Security AuthenticationExceptions. @@ -31,16 +31,13 @@ import javax.security.auth.login.LoginException; * @author Ray Krueger */ public interface LoginExceptionResolver { - // ~ Methods - // ======================================================================================================== /** * Translates a Jaas LoginException to an SpringSecurityException. - * - * @param e The LoginException thrown by the configured LoginModule. - * + * @param ex The LoginException thrown by the configured LoginModule. * @return The AuthenticationException that the JaasAuthenticationProvider should * throw. */ - AuthenticationException resolveException(LoginException e); + AuthenticationException resolveException(LoginException ex); + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/SecurityContextLoginModule.java b/core/src/main/java/org/springframework/security/authentication/jaas/SecurityContextLoginModule.java index 38756da657..4cc8c8d402 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/SecurityContextLoginModule.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/SecurityContextLoginModule.java @@ -16,12 +16,6 @@ package org.springframework.security.authentication.jaas; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextHolder; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - import java.util.Map; import javax.security.auth.Subject; @@ -29,6 +23,12 @@ import javax.security.auth.callback.CallbackHandler; import javax.security.auth.login.LoginException; import javax.security.auth.spi.LoginModule; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; + /** * An implementation of {@link LoginModule} that uses a Spring Security * {@link org.springframework.security.core.context.SecurityContext SecurityContext} to @@ -52,65 +52,53 @@ import javax.security.auth.spi.LoginModule; * @author Ray Krueger */ public class SecurityContextLoginModule implements LoginModule { - // ~ Static fields/initializers - // ===================================================================================== private static final Log log = LogFactory.getLog(SecurityContextLoginModule.class); - // ~ Instance fields - // ================================================================================================ - private Authentication authen; - private Subject subject; - private boolean ignoreMissingAuthentication = false; - // ~ Methods - // ======================================================================================================== + private Subject subject; + + private boolean ignoreMissingAuthentication = false; /** * Abort the authentication process by forgetting the Spring Security * Authentication. - * * @return true if this method succeeded, or false if this LoginModule * should be ignored. - * * @exception LoginException if the abort fails */ + @Override public boolean abort() { - if (authen == null) { + if (this.authen == null) { return false; } - - authen = null; - + this.authen = null; return true; } /** * Authenticate the Subject (phase two) by adding the Spring Security * Authentication to the Subject's principals. - * * @return true if this method succeeded, or false if this LoginModule * should be ignored. - * * @exception LoginException if the commit fails */ + @Override public boolean commit() { - if (authen == null) { + if (this.authen == null) { return false; } - - subject.getPrincipals().add(authen); - + this.subject.getPrincipals().add(this.authen); return true; } Authentication getAuthentication() { - return authen; + return this.authen; } Subject getSubject() { - return subject; + return this.subject; } /** @@ -118,67 +106,55 @@ public class SecurityContextLoginModule implements LoginModule { * code establishing the LoginContext likely won't provide one that * understands Spring Security. Also ignores the sharedState and * options parameters, since none are recognized. - * * @param subject the Subject to be authenticated. * @param callbackHandler is ignored * @param sharedState is ignored * @param options are ignored */ + @Override @SuppressWarnings("unchecked") - public void initialize(Subject subject, CallbackHandler callbackHandler, - Map sharedState, Map options) { + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { this.subject = subject; - if (options != null) { - ignoreMissingAuthentication = "true".equals(options - .get("ignoreMissingAuthentication")); + this.ignoreMissingAuthentication = "true".equals(options.get("ignoreMissingAuthentication")); } } /** * Authenticate the Subject (phase one) by extracting the Spring Security * Authentication from the current SecurityContext. - * * @return true if the authentication succeeded, or false if this * LoginModule should be ignored. - * * @throws LoginException if the authentication fails */ + @Override public boolean login() throws LoginException { - authen = SecurityContextHolder.getContext().getAuthentication(); - - if (authen == null) { - String msg = "Login cannot complete, authentication not found in security context"; - - if (ignoreMissingAuthentication) { - log.warn(msg); - - return false; - } - else { - throw new LoginException(msg); - } + this.authen = SecurityContextHolder.getContext().getAuthentication(); + if (this.authen != null) { + return true; } - - return true; + String msg = "Login cannot complete, authentication not found in security context"; + if (!this.ignoreMissingAuthentication) { + throw new LoginException(msg); + } + log.warn(msg); + return false; } /** * Log out the Subject. - * * @return true if this method succeeded, or false if this LoginModule * should be ignored. - * * @exception LoginException if the logout fails */ + @Override public boolean logout() { - if (authen == null) { + if (this.authen == null) { return false; } - - subject.getPrincipals().remove(authen); - authen = null; - + this.subject.getPrincipals().remove(this.authen); + this.authen = null; return true; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationEvent.java b/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationEvent.java index 2c3332d59b..c5aa50f97d 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationEvent.java @@ -16,9 +16,8 @@ package org.springframework.security.authentication.jaas.event; -import org.springframework.security.core.Authentication; - import org.springframework.context.ApplicationEvent; +import org.springframework.security.core.Authentication; /** * Parent class for events fired by the @@ -28,27 +27,21 @@ import org.springframework.context.ApplicationEvent; * @author Ray Krueger */ public abstract class JaasAuthenticationEvent extends ApplicationEvent { - // ~ Constructors - // =================================================================================================== /** * The Authentication object is stored as the ApplicationEvent 'source'. - * * @param auth */ public JaasAuthenticationEvent(Authentication auth) { super(auth); } - // ~ Methods - // ======================================================================================================== - /** * Pre-casted method that returns the 'source' of the event. - * * @return the Authentication */ public Authentication getAuthentication() { - return (Authentication) source; + return (Authentication) this.source; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationFailedEvent.java b/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationFailedEvent.java index e3955c909f..4b70d77950 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationFailedEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationFailedEvent.java @@ -25,23 +25,16 @@ import org.springframework.security.core.Authentication; * @author Ray Krueger */ public class JaasAuthenticationFailedEvent extends JaasAuthenticationEvent { - // ~ Instance fields - // ================================================================================================ private final Exception exception; - // ~ Constructors - // =================================================================================================== - public JaasAuthenticationFailedEvent(Authentication auth, Exception exception) { super(auth); this.exception = exception; } - // ~ Methods - // ======================================================================================================== - public Exception getException() { - return exception; + return this.exception; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationSuccessEvent.java b/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationSuccessEvent.java index f57d4ab35d..0afa2b882b 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationSuccessEvent.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/event/JaasAuthenticationSuccessEvent.java @@ -27,10 +27,9 @@ import org.springframework.security.core.Authentication; * @author Ray Krueger */ public class JaasAuthenticationSuccessEvent extends JaasAuthenticationEvent { - // ~ Constructors - // =================================================================================================== public JaasAuthenticationSuccessEvent(Authentication auth) { super(auth); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/event/package-info.java b/core/src/main/java/org/springframework/security/authentication/jaas/event/package-info.java index 266efce2c0..802ac1c9d1 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/event/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/event/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * JAAS authentication events which can be published to the Spring application context by the JAAS authentication - * provider. + * JAAS authentication events which can be published to the Spring application context by + * the JAAS authentication provider. */ package org.springframework.security.authentication.jaas.event; - diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/memory/InMemoryConfiguration.java b/core/src/main/java/org/springframework/security/authentication/jaas/memory/InMemoryConfiguration.java index 181b118b98..8aa767aaa7 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/memory/InMemoryConfiguration.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/memory/InMemoryConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication.jaas.memory; import java.util.Collections; @@ -35,36 +36,28 @@ import org.springframework.util.Assert; * @author Rob Winch */ public class InMemoryConfiguration extends Configuration { - // ~ Instance fields - // ================================================================================================ private final AppConfigurationEntry[] defaultConfiguration; - private final Map mappedConfigurations; - // ~ Constructors - // =================================================================================================== + private final Map mappedConfigurations; /** * Creates a new instance with only a defaultConfiguration. Any configuration name * will result in defaultConfiguration being returned. - * * @param defaultConfiguration The result for any calls to * {@link #getAppConfigurationEntry(String)}. Can be null. */ public InMemoryConfiguration(AppConfigurationEntry[] defaultConfiguration) { - this(Collections.emptyMap(), - defaultConfiguration); + this(Collections.emptyMap(), defaultConfiguration); } /** * Creates a new instance with a mapping of login context name to an array of * {@link AppConfigurationEntry}s. - * * @param mappedConfigurations each key represents a login context name and each value * is an Array of {@link AppConfigurationEntry}s that should be used. */ - public InMemoryConfiguration( - Map mappedConfigurations) { + public InMemoryConfiguration(Map mappedConfigurations) { this(mappedConfigurations, null); } @@ -72,27 +65,22 @@ public class InMemoryConfiguration extends Configuration { * Creates a new instance with a mapping of login context name to an array of * {@link AppConfigurationEntry}s along with a default configuration that will be used * if no mapping is found for the given login context name. - * * @param mappedConfigurations each key represents a login context name and each value * is an Array of {@link AppConfigurationEntry}s that should be used. * @param defaultConfiguration The result for any calls to * {@link #getAppConfigurationEntry(String)}. Can be null. */ - public InMemoryConfiguration( - Map mappedConfigurations, + public InMemoryConfiguration(Map mappedConfigurations, AppConfigurationEntry[] defaultConfiguration) { Assert.notNull(mappedConfigurations, "mappedConfigurations cannot be null."); this.mappedConfigurations = mappedConfigurations; this.defaultConfiguration = defaultConfiguration; } - // ~ Methods - // ======================================================================================================== - @Override public AppConfigurationEntry[] getAppConfigurationEntry(String name) { AppConfigurationEntry[] mappedResult = this.mappedConfigurations.get(name); - return mappedResult == null ? this.defaultConfiguration : mappedResult; + return (mappedResult != null) ? mappedResult : this.defaultConfiguration; } /** @@ -101,4 +89,5 @@ public class InMemoryConfiguration extends Configuration { @Override public void refresh() { } -} \ No newline at end of file + +} diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/memory/package-info.java b/core/src/main/java/org/springframework/security/authentication/jaas/memory/package-info.java index 0657c03910..4e0476288b 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/memory/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/memory/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * An in memory JAAS implementation. */ package org.springframework.security.authentication.jaas.memory; - diff --git a/core/src/main/java/org/springframework/security/authentication/jaas/package-info.java b/core/src/main/java/org/springframework/security/authentication/jaas/package-info.java index 3204025524..a5bb8e318d 100644 --- a/core/src/main/java/org/springframework/security/authentication/jaas/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/jaas/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * An authentication provider for JAAS. */ package org.springframework.security.authentication.jaas; - diff --git a/core/src/main/java/org/springframework/security/authentication/package-info.java b/core/src/main/java/org/springframework/security/authentication/package-info.java index e56ab09664..b133c48be4 100644 --- a/core/src/main/java/org/springframework/security/authentication/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/package-info.java @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Core classes and interfaces related to user authentication, which are used throughout Spring Security. + * Core classes and interfaces related to user authentication, which are used throughout + * Spring Security. *

      - * Of key importance is the {@link org.springframework.security.authentication.AuthenticationManager AuthenticationManager} - * and its default implementation {@link org.springframework.security.authentication.ProviderManager - * ProviderManager}, which maintains a list {@link org.springframework.security.authentication.AuthenticationProvider + * Of key importance is the + * {@link org.springframework.security.authentication.AuthenticationManager + * AuthenticationManager} and its default implementation + * {@link org.springframework.security.authentication.ProviderManager ProviderManager}, + * which maintains a list + * {@link org.springframework.security.authentication.AuthenticationProvider * AuthenticationProvider}s to which it delegates authentication requests. */ package org.springframework.security.authentication; - diff --git a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationException.java b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationException.java index 8cb14cbdd3..fdfece6ba7 100644 --- a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationException.java +++ b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationException.java @@ -33,16 +33,13 @@ public class RemoteAuthenticationException extends NestedRuntimeException { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Constructors - // =================================================================================================== - /** * Constructs a RemoteAuthenticationException with the specified message * and no root cause. - * * @param msg the detail message */ public RemoteAuthenticationException(String msg) { super(msg); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManager.java b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManager.java index 4108468dfa..f2bec46c89 100644 --- a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManager.java +++ b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManager.java @@ -26,8 +26,6 @@ import org.springframework.security.core.GrantedAuthority; * @author Ben Alex */ public interface RemoteAuthenticationManager { - // ~ Methods - // ======================================================================================================== /** * Attempts to authenticate the remote client using the presented username and @@ -39,15 +37,13 @@ public interface RemoteAuthenticationManager { * required for remote clients to enable/disable relevant user interface commands etc. * There is nothing preventing users from implementing their own equivalent package * that works with more complex object types. - * * @param username the username the remote client wishes to authenticate with. * @param password the password the remote client wishes to authenticate with. - * * @return all of the granted authorities the specified username and password have * access to. - * * @throws RemoteAuthenticationException if the authentication failed. */ - Collection attemptAuthentication(String username, - String password) throws RemoteAuthenticationException; + Collection attemptAuthentication(String username, String password) + throws RemoteAuthenticationException; + } diff --git a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImpl.java b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImpl.java index 35e1a6d556..2f3063cdd3 100644 --- a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImpl.java +++ b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImpl.java @@ -33,38 +33,33 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public class RemoteAuthenticationManagerImpl implements RemoteAuthenticationManager, - InitializingBean { - // ~ Instance fields - // ================================================================================================ +public class RemoteAuthenticationManagerImpl implements RemoteAuthenticationManager, InitializingBean { private AuthenticationManager authenticationManager; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { Assert.notNull(this.authenticationManager, "authenticationManager is required"); } - public Collection attemptAuthentication(String username, - String password) throws RemoteAuthenticationException { - UsernamePasswordAuthenticationToken request = new UsernamePasswordAuthenticationToken( - username, password); - + @Override + public Collection attemptAuthentication(String username, String password) + throws RemoteAuthenticationException { + UsernamePasswordAuthenticationToken request = new UsernamePasswordAuthenticationToken(username, password); try { - return authenticationManager.authenticate(request).getAuthorities(); + return this.authenticationManager.authenticate(request).getAuthorities(); } - catch (AuthenticationException authEx) { - throw new RemoteAuthenticationException(authEx.getMessage()); + catch (AuthenticationException ex) { + throw new RemoteAuthenticationException(ex.getMessage()); } } protected AuthenticationManager getAuthenticationManager() { - return authenticationManager; + return this.authenticationManager; } public void setAuthenticationManager(AuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; } + } diff --git a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProvider.java b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProvider.java index 7e6b8c6b15..3ed938a248 100644 --- a/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProvider.java +++ b/core/src/main/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProvider.java @@ -50,43 +50,36 @@ import org.springframework.util.Assert; * * @author Ben Alex */ -public class RemoteAuthenticationProvider implements AuthenticationProvider, - InitializingBean { - // ~ Instance fields - // ================================================================================================ +public class RemoteAuthenticationProvider implements AuthenticationProvider, InitializingBean { private RemoteAuthenticationManager remoteAuthenticationManager; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(this.remoteAuthenticationManager, - "remoteAuthenticationManager is mandatory"); + Assert.notNull(this.remoteAuthenticationManager, "remoteAuthenticationManager is mandatory"); } - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { String username = authentication.getPrincipal().toString(); Object credentials = authentication.getCredentials(); - String password = credentials == null ? null : credentials.toString(); - Collection authorities = remoteAuthenticationManager + String password = (credentials != null) ? credentials.toString() : null; + Collection authorities = this.remoteAuthenticationManager .attemptAuthentication(username, password); - return new UsernamePasswordAuthenticationToken(username, password, authorities); } public RemoteAuthenticationManager getRemoteAuthenticationManager() { - return remoteAuthenticationManager; + return this.remoteAuthenticationManager; } - public void setRemoteAuthenticationManager( - RemoteAuthenticationManager remoteAuthenticationManager) { + public void setRemoteAuthenticationManager(RemoteAuthenticationManager remoteAuthenticationManager) { this.remoteAuthenticationManager = remoteAuthenticationManager; } + @Override public boolean supports(Class authentication) { - return (UsernamePasswordAuthenticationToken.class - .isAssignableFrom(authentication)); + return (UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication)); } + } diff --git a/core/src/main/java/org/springframework/security/authentication/rcp/package-info.java b/core/src/main/java/org/springframework/security/authentication/rcp/package-info.java index 443fb72df3..b4010d186b 100644 --- a/core/src/main/java/org/springframework/security/authentication/rcp/package-info.java +++ b/core/src/main/java/org/springframework/security/authentication/rcp/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Allows remote clients to authenticate and obtain a populated Authentication object. + * Allows remote clients to authenticate and obtain a populated + * Authentication object. */ package org.springframework.security.authentication.rcp; - diff --git a/core/src/main/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManager.java b/core/src/main/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManager.java index 8647f685fd..2bba4744f6 100644 --- a/core/src/main/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManager.java +++ b/core/src/main/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManager.java @@ -16,39 +16,45 @@ package org.springframework.security.authorization; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.core.Authentication; -import reactor.core.publisher.Mono; /** * A {@link ReactiveAuthorizationManager} that determines if the current user is * authenticated. * - * @author Rob Winch - * @since 5.0 * @param The type of object authorization is being performed against. This does not - * matter since the authorization decision does not use the object. + * @author Rob Winch + * @since 5.0 matter since the authorization decision does not use the object. */ public class AuthenticatedReactiveAuthorizationManager implements ReactiveAuthorizationManager { private AuthenticationTrustResolver authTrustResolver = new AuthenticationTrustResolverImpl(); + AuthenticatedReactiveAuthorizationManager() { + } + @Override public Mono check(Mono authentication, T object) { - return authentication - .filter(this::isNotAnonymous) - .map(a -> new AuthorizationDecision(a.isAuthenticated())) - .defaultIfEmpty(new AuthorizationDecision(false)); + return authentication.filter(this::isNotAnonymous).map(this::getAuthorizationDecision) + .defaultIfEmpty(new AuthorizationDecision(false)); + } + + private AuthorizationDecision getAuthorizationDecision(Authentication authentication) { + return new AuthorizationDecision(authentication.isAuthenticated()); } /** - * Verify (via {@link AuthenticationTrustResolver}) that the given authentication is not anonymous. + * Verify (via {@link AuthenticationTrustResolver}) that the given authentication is + * not anonymous. * @param authentication to be checked * @return true if not anonymous, otherwise false. */ private boolean isNotAnonymous(Authentication authentication) { - return !authTrustResolver.isAnonymous(authentication); + return !this.authTrustResolver.isAnonymous(authentication); } /** @@ -60,5 +66,4 @@ public class AuthenticatedReactiveAuthorizationManager implements ReactiveAut return new AuthenticatedReactiveAuthorizationManager<>(); } - private AuthenticatedReactiveAuthorizationManager() {} } diff --git a/core/src/main/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManager.java b/core/src/main/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManager.java index 5ab05c960c..9a8940ff0a 100644 --- a/core/src/main/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManager.java +++ b/core/src/main/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManager.java @@ -16,43 +16,46 @@ package org.springframework.security.authorization; -import org.springframework.security.core.Authentication; -import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - import java.util.Arrays; import java.util.List; +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.util.Assert; + /** * A {@link ReactiveAuthorizationManager} that determines if the current user is * authorized by evaluating if the {@link Authentication} contains a specified authority. * + * @param the type of object being authorized * @author Rob Winch * @since 5.0 - * @param the type of object being authorized */ public class AuthorityReactiveAuthorizationManager implements ReactiveAuthorizationManager { + private final List authorities; - private AuthorityReactiveAuthorizationManager(String... authorities) { + AuthorityReactiveAuthorizationManager(String... authorities) { this.authorities = Arrays.asList(authorities); } @Override public Mono check(Mono authentication, T object) { - return authentication - .filter(a -> a.isAuthenticated()) - .flatMapIterable( a -> a.getAuthorities()) - .map(g -> g.getAuthority()) - .any(a -> this.authorities.contains(a)) - .map( hasAuthority -> new AuthorizationDecision(hasAuthority)) - .defaultIfEmpty(new AuthorizationDecision(false)); + // @formatter:off + return authentication.filter((a) -> a.isAuthenticated()) + .flatMapIterable(Authentication::getAuthorities) + .map(GrantedAuthority::getAuthority) + .any(this.authorities::contains) + .map(AuthorizationDecision::new) + .defaultIfEmpty(new AuthorizationDecision(false)); + // @formatter:on } /** * Creates an instance of {@link AuthorityReactiveAuthorizationManager} with the * provided authority. - * * @param authority the authority to check for * @param the type of object being authorized * @return the new instance @@ -76,14 +79,12 @@ public class AuthorityReactiveAuthorizationManager implements ReactiveAuthori for (String authority : authorities) { Assert.notNull(authority, "authority cannot be null"); } - return new AuthorityReactiveAuthorizationManager<>(authorities); } /** * Creates an instance of {@link AuthorityReactiveAuthorizationManager} with the * provided authority. - * * @param role the authority to check for prefixed with "ROLE_" * @param the type of object being authorized * @return the new instance @@ -107,15 +108,15 @@ public class AuthorityReactiveAuthorizationManager implements ReactiveAuthori for (String role : roles) { Assert.notNull(role, "role cannot be null"); } - return hasAnyAuthority(toNamedRolesArray(roles)); } private static String[] toNamedRolesArray(String... roles) { String[] result = new String[roles.length]; - for (int i=0; i < roles.length; i++) { + for (int i = 0; i < roles.length; i++) { result[i] = "ROLE_" + roles[i]; } return result; } + } diff --git a/core/src/main/java/org/springframework/security/authorization/AuthorizationDecision.java b/core/src/main/java/org/springframework/security/authorization/AuthorizationDecision.java index 08fdf7a92d..026d5afdbc 100644 --- a/core/src/main/java/org/springframework/security/authorization/AuthorizationDecision.java +++ b/core/src/main/java/org/springframework/security/authorization/AuthorizationDecision.java @@ -21,6 +21,7 @@ package org.springframework.security.authorization; * @since 5.0 */ public class AuthorizationDecision { + private final boolean granted; public AuthorizationDecision(boolean granted) { @@ -28,6 +29,7 @@ public class AuthorizationDecision { } public boolean isGranted() { - return granted; + return this.granted; } + } diff --git a/core/src/main/java/org/springframework/security/authorization/ReactiveAuthorizationManager.java b/core/src/main/java/org/springframework/security/authorization/ReactiveAuthorizationManager.java index 6686658cbc..667cecf5af 100644 --- a/core/src/main/java/org/springframework/security/authorization/ReactiveAuthorizationManager.java +++ b/core/src/main/java/org/springframework/security/authorization/ReactiveAuthorizationManager.java @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authorization; +import reactor.core.publisher.Mono; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.core.Authentication; -import reactor.core.publisher.Mono; - /** - * A reactive authorization manager which can determine if an {@link Authentication} - * has access to a specific object. + * A reactive authorization manager which can determine if an {@link Authentication} has + * access to a specific object. * + * @param the type of object that the authorization check is being done one. * @author Rob Winch * @since 5.0 - * @param the type of object that the authorization check is being done one. */ public interface ReactiveAuthorizationManager { + /** * Determines if access is granted for a specific authentication and object. - * * @param authentication the Authentication to check * @param object the object to check * @return an decision or empty Mono if no decision could be made. @@ -40,17 +41,18 @@ public interface ReactiveAuthorizationManager { /** * Determines if access should be granted for a specific authentication and object - * - * @param authentication the Authentication to check * @param object the object to check * @return an empty Mono if authorization is granted or a Mono error if access is * denied */ default Mono verify(Mono authentication, T object) { + // @formatter:off return check(authentication, object) - .filter( d -> d.isGranted()) - .switchIfEmpty(Mono.defer(() -> Mono.error(new AccessDeniedException("Access Denied")))) - .flatMap( d -> Mono.empty() ); + .filter(AuthorizationDecision::isGranted) + .switchIfEmpty(Mono.defer(() -> Mono.error(new AccessDeniedException("Access Denied")))) + .flatMap((decision) -> Mono.empty()); + // @formatter:on } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java b/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java index 7cc9a342d8..e3e6e0267a 100644 --- a/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java +++ b/core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import java.util.concurrent.Callable; @@ -34,7 +35,6 @@ abstract class AbstractDelegatingSecurityContextSupport { /** * Creates a new {@link AbstractDelegatingSecurityContextSupport} that uses the * specified {@link SecurityContext}. - * * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} and each * {@link DelegatingSecurityContextCallable} or null to default to the current @@ -45,10 +45,11 @@ abstract class AbstractDelegatingSecurityContextSupport { } protected final Runnable wrap(Runnable delegate) { - return DelegatingSecurityContextRunnable.create(delegate, securityContext); + return DelegatingSecurityContextRunnable.create(delegate, this.securityContext); } protected final Callable wrap(Callable delegate) { - return DelegatingSecurityContextCallable.create(delegate, securityContext); + return DelegatingSecurityContextCallable.create(delegate, this.securityContext); } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java index d65272ec5d..50d1b89af0 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import java.util.concurrent.Callable; @@ -23,13 +24,13 @@ import org.springframework.util.Assert; /** *

      - * Wraps a delegate {@link Callable} with logic for setting up a - * {@link SecurityContext} before invoking the delegate {@link Callable} and - * then removing the {@link SecurityContext} after the delegate has completed. + * Wraps a delegate {@link Callable} with logic for setting up a {@link SecurityContext} + * before invoking the delegate {@link Callable} and then removing the + * {@link SecurityContext} after the delegate has completed. *

      *

      - * If there is a {@link SecurityContext} that already exists, it will be - * restored after the {@link #call()} method is invoked. + * If there is a {@link SecurityContext} that already exists, it will be restored after + * the {@link #call()} method is invoked. *

      * * @author Rob Winch @@ -39,16 +40,14 @@ public final class DelegatingSecurityContextCallable implements Callable { private final Callable delegate; - /** - * The {@link SecurityContext} that the delegate {@link Callable} will be - * ran as. + * The {@link SecurityContext} that the delegate {@link Callable} will be ran as. */ private final SecurityContext delegateSecurityContext; /** - * The {@link SecurityContext} that was on the {@link SecurityContextHolder} - * prior to being set to the delegateSecurityContext. + * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to + * being set to the delegateSecurityContext. */ private SecurityContext originalSecurityContext; @@ -60,8 +59,7 @@ public final class DelegatingSecurityContextCallable implements Callable { * @param securityContext the {@link SecurityContext} to establish for the delegate * {@link Callable}. Cannot be null. */ - public DelegatingSecurityContextCallable(Callable delegate, - SecurityContext securityContext) { + public DelegatingSecurityContextCallable(Callable delegate, SecurityContext securityContext) { Assert.notNull(delegate, "delegate cannot be null"); Assert.notNull(securityContext, "securityContext cannot be null"); this.delegate = delegate; @@ -81,17 +79,17 @@ public final class DelegatingSecurityContextCallable implements Callable { @Override public V call() throws Exception { this.originalSecurityContext = SecurityContextHolder.getContext(); - try { - SecurityContextHolder.setContext(delegateSecurityContext); - return delegate.call(); + SecurityContextHolder.setContext(this.delegateSecurityContext); + return this.delegate.call(); } finally { SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); - if (emptyContext.equals(originalSecurityContext)) { + if (emptyContext.equals(this.originalSecurityContext)) { SecurityContextHolder.clearContext(); - } else { - SecurityContextHolder.setContext(originalSecurityContext); + } + else { + SecurityContextHolder.setContext(this.originalSecurityContext); } this.originalSecurityContext = null; } @@ -99,7 +97,7 @@ public final class DelegatingSecurityContextCallable implements Callable { @Override public String toString() { - return delegate.toString(); + return this.delegate.toString(); } /** @@ -107,17 +105,15 @@ public final class DelegatingSecurityContextCallable implements Callable { * {@link Callable} and {@link SecurityContext}, but if the securityContext is null * will defaults to the current {@link SecurityContext} on the * {@link SecurityContextHolder} - * * @param delegate the delegate {@link DelegatingSecurityContextCallable} to run with * the specified {@link SecurityContext}. Cannot be null. * @param securityContext the {@link SecurityContext} to establish for the delegate * {@link Callable}. If null, defaults to {@link SecurityContextHolder#getContext()} * @return */ - public static Callable create(Callable delegate, - SecurityContext securityContext) { - return securityContext == null ? new DelegatingSecurityContextCallable<>( - delegate) : new DelegatingSecurityContextCallable<>(delegate, - securityContext); + public static Callable create(Callable delegate, SecurityContext securityContext) { + return (securityContext != null) ? new DelegatingSecurityContextCallable<>(delegate, securityContext) + : new DelegatingSecurityContextCallable<>(delegate); } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java index ba9c0a6217..c1af6a7546 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import java.util.concurrent.Executor; @@ -28,21 +29,19 @@ import org.springframework.util.Assert; * @author Rob Winch * @since 3.2 */ -public class DelegatingSecurityContextExecutor extends - AbstractDelegatingSecurityContextSupport implements Executor { +public class DelegatingSecurityContextExecutor extends AbstractDelegatingSecurityContextSupport implements Executor { + private final Executor delegate; /** * Creates a new {@link DelegatingSecurityContextExecutor} that uses the specified * {@link SecurityContext}. - * * @param delegateExecutor the {@link Executor} to delegate to. Cannot be null. * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} or null to default to the current * {@link SecurityContext} */ - public DelegatingSecurityContextExecutor(Executor delegateExecutor, - SecurityContext securityContext) { + public DelegatingSecurityContextExecutor(Executor delegateExecutor, SecurityContext securityContext) { super(securityContext); Assert.notNull(delegateExecutor, "delegateExecutor cannot be null"); this.delegate = delegateExecutor; @@ -52,19 +51,19 @@ public class DelegatingSecurityContextExecutor extends * Creates a new {@link DelegatingSecurityContextExecutor} that uses the current * {@link SecurityContext} from the {@link SecurityContextHolder} at the time the task * is submitted. - * * @param delegate the {@link Executor} to delegate to. Cannot be null. */ public DelegatingSecurityContextExecutor(Executor delegate) { this(delegate, null); } + @Override public final void execute(Runnable task) { - task = wrap(task); - delegate.execute(task); + this.delegate.execute(wrap(task)); } protected final Executor getDelegateExecutor() { - return delegate; + return this.delegate; } -} \ No newline at end of file + +} diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutorService.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutorService.java index 9f00bb4fab..289f9bec83 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutorService.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutorService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import java.util.ArrayList; @@ -36,90 +37,94 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Rob Winch * @since 3.2 */ -public class DelegatingSecurityContextExecutorService extends - DelegatingSecurityContextExecutor implements ExecutorService { +public class DelegatingSecurityContextExecutorService extends DelegatingSecurityContextExecutor + implements ExecutorService { + /** * Creates a new {@link DelegatingSecurityContextExecutorService} that uses the * specified {@link SecurityContext}. - * * @param delegateExecutorService the {@link ExecutorService} to delegate to. Cannot * be null. * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} and each * {@link DelegatingSecurityContextCallable}. */ - public DelegatingSecurityContextExecutorService( - ExecutorService delegateExecutorService, SecurityContext securityContext) { + public DelegatingSecurityContextExecutorService(ExecutorService delegateExecutorService, + SecurityContext securityContext) { super(delegateExecutorService, securityContext); } /** * Creates a new {@link DelegatingSecurityContextExecutorService} that uses the * current {@link SecurityContext} from the {@link SecurityContextHolder}. - * - * @param delegate the {@link ExecutorService} to delegate to. Cannot be - * null. + * @param delegate the {@link ExecutorService} to delegate to. Cannot be null. */ public DelegatingSecurityContextExecutorService(ExecutorService delegate) { this(delegate, null); } + @Override public final void shutdown() { getDelegate().shutdown(); } + @Override public final List shutdownNow() { return getDelegate().shutdownNow(); } + @Override public final boolean isShutdown() { return getDelegate().isShutdown(); } + @Override public final boolean isTerminated() { return getDelegate().isTerminated(); } - public final boolean awaitTermination(long timeout, TimeUnit unit) - throws InterruptedException { + @Override + public final boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { return getDelegate().awaitTermination(timeout, unit); } + @Override public final Future submit(Callable task) { - task = wrap(task); - return getDelegate().submit(task); + return getDelegate().submit(wrap(task)); } + @Override public final Future submit(Runnable task, T result) { - task = wrap(task); - return getDelegate().submit(task, result); + return getDelegate().submit(wrap(task), result); } + @Override public final Future submit(Runnable task) { - task = wrap(task); - return getDelegate().submit(task); + return getDelegate().submit(wrap(task)); } + @Override @SuppressWarnings({ "rawtypes", "unchecked" }) public final List invokeAll(Collection tasks) throws InterruptedException { tasks = createTasks(tasks); return getDelegate().invokeAll(tasks); } + @Override @SuppressWarnings({ "rawtypes", "unchecked" }) - public final List invokeAll(Collection tasks, long timeout, TimeUnit unit) - throws InterruptedException { + public final List invokeAll(Collection tasks, long timeout, TimeUnit unit) throws InterruptedException { tasks = createTasks(tasks); return getDelegate().invokeAll(tasks, timeout, unit); } + @Override @SuppressWarnings({ "rawtypes", "unchecked" }) - public final Object invokeAny(Collection tasks) throws InterruptedException, - ExecutionException { + public final Object invokeAny(Collection tasks) throws InterruptedException, ExecutionException { tasks = createTasks(tasks); return getDelegate().invokeAny(tasks); } + @Override @SuppressWarnings({ "rawtypes", "unchecked" }) public final Object invokeAny(Collection tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { @@ -141,4 +146,5 @@ public class DelegatingSecurityContextExecutorService extends private ExecutorService getDelegate() { return (ExecutorService) getDelegateExecutor(); } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java index addba8cfd9..24b0746641 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.springframework.security.core.context.SecurityContext; @@ -26,8 +27,8 @@ import org.springframework.util.Assert; * {@link SecurityContext} after the delegate has completed. *

      *

      - * If there is a {@link SecurityContext} that already exists, it will be - * restored after the {@link #run()} method is invoked. + * If there is a {@link SecurityContext} that already exists, it will be restored after + * the {@link #run()} method is invoked. *

      * * @author Rob Winch @@ -38,14 +39,13 @@ public final class DelegatingSecurityContextRunnable implements Runnable { private final Runnable delegate; /** - * The {@link SecurityContext} that the delegate {@link Runnable} will be - * ran as. + * The {@link SecurityContext} that the delegate {@link Runnable} will be ran as. */ private final SecurityContext delegateSecurityContext; /** - * The {@link SecurityContext} that was on the {@link SecurityContextHolder} - * prior to being set to the delegateSecurityContext. + * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to + * being set to the delegateSecurityContext. */ private SecurityContext originalSecurityContext; @@ -57,8 +57,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable { * @param securityContext the {@link SecurityContext} to establish for the delegate * {@link Runnable}. Cannot be null. */ - public DelegatingSecurityContextRunnable(Runnable delegate, - SecurityContext securityContext) { + public DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext) { Assert.notNull(delegate, "delegate cannot be null"); Assert.notNull(securityContext, "securityContext cannot be null"); this.delegate = delegate; @@ -78,17 +77,17 @@ public final class DelegatingSecurityContextRunnable implements Runnable { @Override public void run() { this.originalSecurityContext = SecurityContextHolder.getContext(); - try { - SecurityContextHolder.setContext(delegateSecurityContext); - delegate.run(); + SecurityContextHolder.setContext(this.delegateSecurityContext); + this.delegate.run(); } finally { SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); - if (emptyContext.equals(originalSecurityContext)) { + if (emptyContext.equals(this.originalSecurityContext)) { SecurityContextHolder.clearContext(); - } else { - SecurityContextHolder.setContext(originalSecurityContext); + } + else { + SecurityContextHolder.setContext(this.originalSecurityContext); } this.originalSecurityContext = null; } @@ -96,12 +95,11 @@ public final class DelegatingSecurityContextRunnable implements Runnable { @Override public String toString() { - return delegate.toString(); + return this.delegate.toString(); } /** * Factory method for creating a {@link DelegatingSecurityContextRunnable}. - * * @param delegate the original {@link Runnable} that will be delegated to after * establishing a {@link SecurityContext} on the {@link SecurityContextHolder}. Cannot * have null. @@ -112,7 +110,8 @@ public final class DelegatingSecurityContextRunnable implements Runnable { */ public static Runnable create(Runnable delegate, SecurityContext securityContext) { Assert.notNull(delegate, "delegate cannot be null"); - return securityContext == null ? new DelegatingSecurityContextRunnable(delegate) - : new DelegatingSecurityContextRunnable(delegate, securityContext); + return (securityContext != null) ? new DelegatingSecurityContextRunnable(delegate, securityContext) + : new DelegatingSecurityContextRunnable(delegate); } + } diff --git a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextScheduledExecutorService.java b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextScheduledExecutorService.java index f58fc68989..ee8ff98489 100644 --- a/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextScheduledExecutorService.java +++ b/core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextScheduledExecutorService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import java.util.concurrent.Callable; @@ -31,20 +32,19 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Rob Winch * @since 3.2 */ -public final class DelegatingSecurityContextScheduledExecutorService extends - DelegatingSecurityContextExecutorService implements ScheduledExecutorService { +public final class DelegatingSecurityContextScheduledExecutorService extends DelegatingSecurityContextExecutorService + implements ScheduledExecutorService { + /** * Creates a new {@link DelegatingSecurityContextScheduledExecutorService} that uses * the specified {@link SecurityContext}. - * * @param delegateScheduledExecutorService the {@link ScheduledExecutorService} to * delegate to. Cannot be null. * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} and each * {@link DelegatingSecurityContextCallable}. */ - public DelegatingSecurityContextScheduledExecutorService( - ScheduledExecutorService delegateScheduledExecutorService, + public DelegatingSecurityContextScheduledExecutorService(ScheduledExecutorService delegateScheduledExecutorService, SecurityContext securityContext) { super(delegateScheduledExecutorService, securityContext); } @@ -52,39 +52,35 @@ public final class DelegatingSecurityContextScheduledExecutorService extends /** * Creates a new {@link DelegatingSecurityContextScheduledExecutorService} that uses * the current {@link SecurityContext} from the {@link SecurityContextHolder}. - * * @param delegate the {@link ScheduledExecutorService} to delegate to. Cannot be * null. */ - public DelegatingSecurityContextScheduledExecutorService( - ScheduledExecutorService delegate) { + public DelegatingSecurityContextScheduledExecutorService(ScheduledExecutorService delegate) { this(delegate, null); } + @Override public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { - command = wrap(command); - return getDelegate().schedule(command, delay, unit); + return getDelegate().schedule(wrap(command), delay, unit); } - public ScheduledFuture schedule(Callable callable, long delay, - TimeUnit unit) { - callable = wrap(callable); - return getDelegate().schedule(callable, delay, unit); + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return getDelegate().schedule(wrap(callable), delay, unit); } - public ScheduledFuture scheduleAtFixedRate(Runnable command, - long initialDelay, long period, TimeUnit unit) { - command = wrap(command); - return getDelegate().scheduleAtFixedRate(command, initialDelay, period, unit); + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + return getDelegate().scheduleAtFixedRate(wrap(command), initialDelay, period, unit); } - public ScheduledFuture scheduleWithFixedDelay(Runnable command, - long initialDelay, long delay, TimeUnit unit) { - command = wrap(command); - return getDelegate().scheduleWithFixedDelay(command, initialDelay, delay, unit); + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { + return getDelegate().scheduleWithFixedDelay(wrap(command), initialDelay, delay, unit); } private ScheduledExecutorService getDelegate() { return (ScheduledExecutorService) getDelegateExecutor(); } + } diff --git a/core/src/main/java/org/springframework/security/context/DelegatingApplicationListener.java b/core/src/main/java/org/springframework/security/context/DelegatingApplicationListener.java index b1341a91df..38ced18d42 100644 --- a/core/src/main/java/org/springframework/security/context/DelegatingApplicationListener.java +++ b/core/src/main/java/org/springframework/security/context/DelegatingApplicationListener.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.context; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationListener; import org.springframework.context.event.SmartApplicationListener; import org.springframework.util.Assert; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; - /** * Used for delegating to a number of SmartApplicationListener instances. This is useful * when needing to register an SmartApplicationListener with the ApplicationContext @@ -30,15 +31,16 @@ import java.util.concurrent.CopyOnWriteArrayList; * * @author Rob Winch */ -public final class DelegatingApplicationListener implements - ApplicationListener { +public final class DelegatingApplicationListener implements ApplicationListener { + private List listeners = new CopyOnWriteArrayList<>(); + @Override public void onApplicationEvent(ApplicationEvent event) { if (event == null) { return; } - for (SmartApplicationListener listener : listeners) { + for (SmartApplicationListener listener : this.listeners) { Object source = event.getSource(); if (source != null && listener.supportsEventType(event.getClass()) && listener.supportsSourceType(source.getClass())) { @@ -49,13 +51,12 @@ public final class DelegatingApplicationListener implements /** * Adds a new SmartApplicationListener to use. - * * @param smartApplicationListener the SmartApplicationListener to use. Cannot be * null. */ public void addListener(SmartApplicationListener smartApplicationListener) { - Assert.notNull(smartApplicationListener, - "smartApplicationListener cannot be null"); - listeners.add(smartApplicationListener); + Assert.notNull(smartApplicationListener, "smartApplicationListener cannot be null"); + this.listeners.add(smartApplicationListener); } + } diff --git a/core/src/main/java/org/springframework/security/converter/RsaKeyConverters.java b/core/src/main/java/org/springframework/security/converter/RsaKeyConverters.java index c0aeb2a53d..c330d4121e 100644 --- a/core/src/main/java/org/springframework/security/converter/RsaKeyConverters.java +++ b/core/src/main/java/org/springframework/security/converter/RsaKeyConverters.java @@ -16,8 +16,8 @@ package org.springframework.security.converter; -import java.io.InputStream; import java.io.BufferedReader; +import java.io.InputStream; import java.io.InputStreamReader; import java.security.KeyFactory; import java.security.NoSuchAlgorithmException; @@ -25,8 +25,8 @@ import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.X509EncodedKeySpec; -import java.util.List; import java.util.Base64; +import java.util.List; import java.util.stream.Collectors; import org.springframework.core.convert.converter.Converter; @@ -38,34 +38,42 @@ import org.springframework.util.Assert; * @author Josh Cummings * @since 5.2 */ -public class RsaKeyConverters { +public final class RsaKeyConverters { + private static final String DASHES = "-----"; + private static final String PKCS8_PEM_HEADER = DASHES + "BEGIN PRIVATE KEY" + DASHES; + private static final String PKCS8_PEM_FOOTER = DASHES + "END PRIVATE KEY" + DASHES; + private static final String X509_PEM_HEADER = DASHES + "BEGIN PUBLIC KEY" + DASHES; + private static final String X509_PEM_FOOTER = DASHES + "END PUBLIC KEY" + DASHES; + private RsaKeyConverters() { + } + /** * Construct a {@link Converter} for converting a PEM-encoded PKCS#8 RSA Private Key * into a {@link RSAPrivateKey}. * - * Note that keys are often formatted in PKCS#1 and this can easily be identified by the header. - * If the key file begins with "-----BEGIN RSA PRIVATE KEY-----", then it is PKCS#1. If it is - * PKCS#8 formatted, then it begins with "-----BEGIN PRIVATE KEY-----". + * Note that keys are often formatted in PKCS#1 and this can easily be identified by + * the header. If the key file begins with "-----BEGIN RSA PRIVATE KEY-----", then it + * is PKCS#1. If it is PKCS#8 formatted, then it begins with "-----BEGIN PRIVATE + * KEY-----". * - * This converter does not close the {@link InputStream} in order to avoid making non-portable - * assumptions about the streams' origin and further use. - * - * @return A {@link Converter} that can read a PEM-encoded PKCS#8 RSA Private Key and return a - * {@link RSAPrivateKey}. + * This converter does not close the {@link InputStream} in order to avoid making + * non-portable assumptions about the streams' origin and further use. + * @return A {@link Converter} that can read a PEM-encoded PKCS#8 RSA Private Key and + * return a {@link RSAPrivateKey}. */ public static Converter pkcs8() { KeyFactory keyFactory = rsaFactory(); - return source -> { + return (source) -> { List lines = readAllLines(source); Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(PKCS8_PEM_HEADER), - "Key is not in PEM-encoded PKCS#8 format, " + - "please check that the header begins with -----" + PKCS8_PEM_HEADER + "-----"); + "Key is not in PEM-encoded PKCS#8 format, please check that the header begins with -----" + + PKCS8_PEM_HEADER + "-----"); StringBuilder base64Encoded = new StringBuilder(); for (String line : lines) { if (RsaKeyConverters.isNotPkcs8Wrapper(line)) { @@ -73,12 +81,11 @@ public class RsaKeyConverters { } } byte[] pkcs8 = Base64.getDecoder().decode(base64Encoded.toString()); - try { - return (RSAPrivateKey) keyFactory.generatePrivate( - new PKCS8EncodedKeySpec(pkcs8)); - } catch (Exception e) { - throw new IllegalArgumentException(e); + return (RSAPrivateKey) keyFactory.generatePrivate(new PKCS8EncodedKeySpec(pkcs8)); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); } }; } @@ -87,19 +94,18 @@ public class RsaKeyConverters { * Construct a {@link Converter} for converting a PEM-encoded X.509 RSA Public Key * into a {@link RSAPublicKey}. * - * This converter does not close the {@link InputStream} in order to avoid making non-portable - * assumptions about the streams' origin and further use. - * - * @return A {@link Converter} that can read a PEM-encoded X.509 RSA Public Key and return a - * {@link RSAPublicKey}. + * This converter does not close the {@link InputStream} in order to avoid making + * non-portable assumptions about the streams' origin and further use. + * @return A {@link Converter} that can read a PEM-encoded X.509 RSA Public Key and + * return a {@link RSAPublicKey}. */ public static Converter x509() { KeyFactory keyFactory = rsaFactory(); - return source -> { + return (source) -> { List lines = readAllLines(source); Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(X509_PEM_HEADER), - "Key is not in PEM-encoded X.509 format, " + - "please check that the header begins with -----" + X509_PEM_HEADER + "-----"); + "Key is not in PEM-encoded X.509 format, please check that the header begins with -----" + + X509_PEM_HEADER + "-----"); StringBuilder base64Encoded = new StringBuilder(); for (String line : lines) { if (RsaKeyConverters.isNotX509Wrapper(line)) { @@ -107,12 +113,11 @@ public class RsaKeyConverters { } } byte[] x509 = Base64.getDecoder().decode(base64Encoded.toString()); - try { - return (RSAPublicKey) keyFactory.generatePublic( - new X509EncodedKeySpec(x509)); - } catch (Exception e) { - throw new IllegalArgumentException(e); + return (RSAPublicKey) keyFactory.generatePublic(new X509EncodedKeySpec(x509)); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); } }; } @@ -125,8 +130,9 @@ public class RsaKeyConverters { private static KeyFactory rsaFactory() { try { return KeyFactory.getInstance("RSA"); - } catch (NoSuchAlgorithmException e) { - throw new IllegalStateException(e); + } + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException(ex); } } @@ -137,4 +143,5 @@ public class RsaKeyConverters { private static boolean isNotX509Wrapper(String line) { return !X509_PEM_HEADER.equals(line) && !X509_PEM_FOOTER.equals(line); } + } diff --git a/core/src/main/java/org/springframework/security/core/AuthenticatedPrincipal.java b/core/src/main/java/org/springframework/security/core/AuthenticatedPrincipal.java index 51739334bf..88eb49837d 100644 --- a/core/src/main/java/org/springframework/security/core/AuthenticatedPrincipal.java +++ b/core/src/main/java/org/springframework/security/core/AuthenticatedPrincipal.java @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core; import org.springframework.security.authentication.AuthenticationManager; /** * Representation of an authenticated Principal once an - * {@link Authentication} request has been successfully authenticated - * by the {@link AuthenticationManager#authenticate(Authentication)} method. + * {@link Authentication} request has been successfully authenticated by the + * {@link AuthenticationManager#authenticate(Authentication)} method. * * Implementors typically provide their own representation of a Principal, - * which usually contains information describing the Principal entity, - * such as, first/middle/last name, address, email, phone, id, etc. + * which usually contains information describing the Principal entity, such + * as, first/middle/last name, address, email, phone, id, etc. * - * This interface allows implementors to expose specific attributes - * of their custom representation of Principal in a generic way. + * This interface allows implementors to expose specific attributes of their custom + * representation of Principal in a generic way. * * @author Joe Grandja * @since 5.0 @@ -37,8 +38,8 @@ import org.springframework.security.authentication.AuthenticationManager; public interface AuthenticatedPrincipal { /** - * Returns the name of the authenticated Principal. Never null. - * + * Returns the name of the authenticated Principal. Never + * null. * @return the name of the authenticated Principal */ String getName(); diff --git a/core/src/main/java/org/springframework/security/core/Authentication.java b/core/src/main/java/org/springframework/security/core/Authentication.java index 0a09c26a20..12e04fdeb2 100644 --- a/core/src/main/java/org/springframework/security/core/Authentication.java +++ b/core/src/main/java/org/springframework/security/core/Authentication.java @@ -49,8 +49,6 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Ben Alex */ public interface Authentication extends Principal, Serializable { - // ~ Methods - // ======================================================================================================== /** * Set by an AuthenticationManager to indicate the authorities that the @@ -61,7 +59,6 @@ public interface Authentication extends Principal, Serializable { * do not affect the state of the Authentication object, or use an unmodifiable * instance. *

      - * * @return the authorities granted to the principal, or an empty collection if the * token has not been authenticated. Never null. */ @@ -71,7 +68,6 @@ public interface Authentication extends Principal, Serializable { * The credentials that prove the principal is correct. This is usually a password, * but could be anything relevant to the AuthenticationManager. Callers * are expected to populate the credentials. - * * @return the credentials that prove the identity of the Principal */ Object getCredentials(); @@ -79,7 +75,6 @@ public interface Authentication extends Principal, Serializable { /** * Stores additional details about the authentication request. These might be an IP * address, certificate serial number etc. - * * @return additional details about the authentication request, or null * if not used */ @@ -94,7 +89,6 @@ public interface Authentication extends Principal, Serializable { * Authentication containing richer information as the principal for use by * the application. Many of the authentication providers will create a * {@code UserDetails} object as the principal. - * * @return the Principal being authenticated or the authenticated * principal after authentication. */ @@ -114,7 +108,6 @@ public interface Authentication extends Principal, Serializable { * about returning true from this method unless they are either * immutable, or have some way of ensuring the properties have not been changed since * original creation. - * * @return true if the token has been authenticated and the * AbstractSecurityInterceptor does not need to present the token to the * AuthenticationManager again for re-authentication. @@ -130,14 +123,13 @@ public interface Authentication extends Principal, Serializable { * an invocation with a true parameter (which would indicate the * authentication token is trusted - a potential security risk) the implementation * should throw an {@link IllegalArgumentException}. - * * @param isAuthenticated true if the token should be trusted (which may * result in an exception) or false if the token should not be trusted - * * @throws IllegalArgumentException if an attempt to make the authentication token * trusted (by passing true as the argument) is rejected due to the * implementation being immutable or implementing its own alternative approach to * {@link #isAuthenticated()} */ void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException; + } diff --git a/core/src/main/java/org/springframework/security/core/AuthenticationException.java b/core/src/main/java/org/springframework/security/core/AuthenticationException.java index 0d85b76ef2..e634738b69 100644 --- a/core/src/main/java/org/springframework/security/core/AuthenticationException.java +++ b/core/src/main/java/org/springframework/security/core/AuthenticationException.java @@ -24,24 +24,19 @@ package org.springframework.security.core; */ public abstract class AuthenticationException extends RuntimeException { - // ~ Constructors - // =================================================================================================== - /** * Constructs an {@code AuthenticationException} with the specified message and root * cause. - * * @param msg the detail message - * @param t the root cause + * @param cause the root cause */ - public AuthenticationException(String msg, Throwable t) { - super(msg, t); + public AuthenticationException(String msg, Throwable cause) { + super(msg, cause); } /** * Constructs an {@code AuthenticationException} with the specified message and no * root cause. - * * @param msg the detail message */ public AuthenticationException(String msg) { diff --git a/core/src/main/java/org/springframework/security/core/ComparableVersion.java b/core/src/main/java/org/springframework/security/core/ComparableVersion.java index e5b6f247c2..bce13b7ac1 100644 --- a/core/src/main/java/org/springframework/security/core/ComparableVersion.java +++ b/core/src/main/java/org/springframework/security/core/ComparableVersion.java @@ -13,26 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core; -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ +package org.springframework.security.core; import java.math.BigInteger; import java.util.ArrayList; @@ -78,13 +60,14 @@ import java.util.Stack; * *

      * - * @see "Versioning" on - * Maven Wiki - * @author Kenney Westerhof - * @author Hervé Boutemy + * @author Kenney Westerhof + * @author Hervé Boutemy + * @see "Versioning" on Maven + * Wiki */ class ComparableVersion implements Comparable { + private String value; private String canonical; @@ -92,8 +75,11 @@ class ComparableVersion implements Comparable { private ListItem items; private interface Item { + int INTEGER_ITEM = 0; + int STRING_ITEM = 1; + int LIST_ITEM = 2; int compareTo(Item item); @@ -101,12 +87,14 @@ class ComparableVersion implements Comparable { int getType(); boolean isNull(); + } /** * Represents a numeric item in the version item list. */ private static class IntegerItem implements Item { + private static final BigInteger BigInteger_ZERO = new BigInteger("0"); private final BigInteger value; @@ -128,18 +116,18 @@ class ComparableVersion implements Comparable { @Override public boolean isNull() { - return BigInteger_ZERO.equals(value); + return BigInteger_ZERO.equals(this.value); } @Override public int compareTo(Item item) { if (item == null) { - return BigInteger_ZERO.equals(value) ? 0 : 1; // 1.0 == 1, 1.1 > 1 + return BigInteger_ZERO.equals(this.value) ? 0 : 1; // 1.0 == 1, 1.1 > 1 } switch (item.getType()) { case INTEGER_ITEM: - return value.compareTo(((IntegerItem) item).value); + return this.value.compareTo(((IntegerItem) item).value); case STRING_ITEM: return 1; // 1.1 > 1-sp @@ -154,16 +142,17 @@ class ComparableVersion implements Comparable { @Override public String toString() { - return value.toString(); + return this.value.toString(); } + } /** * Represents a string in the version item list, usually a qualifier. */ private static class StringItem implements Item { - private static final String[] QUALIFIERS = { "alpha", "beta", "milestone", "rc", - "snapshot", "", "sp" }; + + private static final String[] QUALIFIERS = { "alpha", "beta", "milestone", "rc", "snapshot", "", "sp" }; private static final List _QUALIFIERS = Arrays.asList(QUALIFIERS); @@ -179,8 +168,7 @@ class ComparableVersion implements Comparable { * determine if a given qualifier makes the version older than one without a * qualifier, or more recent. */ - private static final String RELEASE_VERSION_INDEX = String.valueOf(_QUALIFIERS - .indexOf("")); + private static final String RELEASE_VERSION_INDEX = String.valueOf(_QUALIFIERS.indexOf("")); private String value; @@ -209,7 +197,7 @@ class ComparableVersion implements Comparable { @Override public boolean isNull() { - return (comparableQualifier(value).compareTo(RELEASE_VERSION_INDEX) == 0); + return (comparableQualifier(this.value).compareTo(RELEASE_VERSION_INDEX) == 0); } /** @@ -222,7 +210,6 @@ class ComparableVersion implements Comparable { * if/then/else to check for -1 or QUALIFIERS.size and then resort to lexical * ordering. Most comparisons are decided by the first character, so this is still * fast. If more characters are needed then it requires a lexical sort anyway. - * * @param qualifier * @return an equivalent value that can be used with lexical comparison */ @@ -236,15 +223,14 @@ class ComparableVersion implements Comparable { public int compareTo(Item item) { if (item == null) { // 1-rc < 1, 1-ga > 1 - return comparableQualifier(value).compareTo(RELEASE_VERSION_INDEX); + return comparableQualifier(this.value).compareTo(RELEASE_VERSION_INDEX); } switch (item.getType()) { case INTEGER_ITEM: return -1; // 1.any < 1.1 ? case STRING_ITEM: - return comparableQualifier(value).compareTo( - comparableQualifier(((StringItem) item).value)); + return comparableQualifier(this.value).compareTo(comparableQualifier(((StringItem) item).value)); case LIST_ITEM: return -1; // 1.any < 1-1 @@ -256,8 +242,9 @@ class ComparableVersion implements Comparable { @Override public String toString() { - return value; + return this.value; } + } /** @@ -265,6 +252,7 @@ class ComparableVersion implements Comparable { * and for sub-lists (which start with '-(number)' in the version specification). */ private static class ListItem extends ArrayList implements Item { + @Override public int getType() { return LIST_ITEM; @@ -276,8 +264,7 @@ class ComparableVersion implements Comparable { } void normalize() { - for (ListIterator iterator = listIterator(size()); iterator - .hasPrevious();) { + for (ListIterator iterator = listIterator(size()); iterator.hasPrevious();) { Item item = iterator.previous(); if (item.isNull()) { iterator.remove(); // remove null trailing items: 0, "", empty list @@ -339,6 +326,7 @@ class ComparableVersion implements Comparable { buffer.append(')'); return buffer.toString(); } + } ComparableVersion(String version) { @@ -348,11 +336,11 @@ class ComparableVersion implements Comparable { public final void parseVersion(String version) { this.value = version; - items = new ListItem(); + this.items = new ListItem(); version = version.toLowerCase(Locale.ENGLISH); - ListItem list = items; + ListItem list = this.items; Stack stack = new Stack<>(); stack.push(list); @@ -385,8 +373,7 @@ class ComparableVersion implements Comparable { if (isDigit) { list.normalize(); // 1.0-* = 1-* - if ((i + 1 < version.length()) - && Character.isDigit(version.charAt(i + 1))) { + if ((i + 1 < version.length()) && Character.isDigit(version.charAt(i + 1))) { // new ListItem only if previous were digits and new char is a // digit, // ie need to differentiate only 1.1 from 1-1 @@ -423,7 +410,7 @@ class ComparableVersion implements Comparable { list.normalize(); } - canonical = items.toString(); + this.canonical = this.items.toString(); } private static Item parseItem(boolean isDigit, String buf) { @@ -432,22 +419,22 @@ class ComparableVersion implements Comparable { @Override public int compareTo(ComparableVersion o) { - return items.compareTo(o.items); + return this.items.compareTo(o.items); } @Override public String toString() { - return value; + return this.value; } @Override public boolean equals(Object o) { - return (o instanceof ComparableVersion) - && canonical.equals(((ComparableVersion) o).canonical); + return (o instanceof ComparableVersion) && this.canonical.equals(((ComparableVersion) o).canonical); } @Override public int hashCode() { - return canonical.hashCode(); + return this.canonical.hashCode(); } + } diff --git a/core/src/main/java/org/springframework/security/core/CredentialsContainer.java b/core/src/main/java/org/springframework/security/core/CredentialsContainer.java index 5dc1b520cf..0a5d020870 100644 --- a/core/src/main/java/org/springframework/security/core/CredentialsContainer.java +++ b/core/src/main/java/org/springframework/security/core/CredentialsContainer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core; /** @@ -29,5 +30,7 @@ package org.springframework.security.core; * @since 3.0.3 */ public interface CredentialsContainer { + void eraseCredentials(); + } diff --git a/core/src/main/java/org/springframework/security/core/GrantedAuthority.java b/core/src/main/java/org/springframework/security/core/GrantedAuthority.java index 7e29789fbe..463bf5432f 100644 --- a/core/src/main/java/org/springframework/security/core/GrantedAuthority.java +++ b/core/src/main/java/org/springframework/security/core/GrantedAuthority.java @@ -30,8 +30,6 @@ import org.springframework.security.access.AccessDecisionManager; * @author Ben Alex */ public interface GrantedAuthority extends Serializable { - // ~ Methods - // ======================================================================================================== /** * If the GrantedAuthority can be represented as a String @@ -44,10 +42,10 @@ public interface GrantedAuthority extends Serializable { * null will require an AccessDecisionManager (or delegate) * to specifically support the GrantedAuthority implementation, so * returning null should be avoided unless actually required. - * * @return a representation of the granted authority (or null if the * granted authority cannot be expressed as a String with sufficient * precision). */ String getAuthority(); + } diff --git a/core/src/main/java/org/springframework/security/core/SpringSecurityCoreVersion.java b/core/src/main/java/org/springframework/security/core/SpringSecurityCoreVersion.java index 89b47d111f..e67e0cafd3 100644 --- a/core/src/main/java/org/springframework/security/core/SpringSecurityCoreVersion.java +++ b/core/src/main/java/org/springframework/security/core/SpringSecurityCoreVersion.java @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core; +import java.io.IOException; +import java.util.Properties; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.SpringVersion; -import java.io.IOException; -import java.util.Properties; - /** * Internal class used for checking version compatibility in a deployed application. * * @author Luke Taylor * @author Rob Winch */ -public class SpringSecurityCoreVersion { - private static final String DISABLE_CHECKS = SpringSecurityCoreVersion.class.getName() - .concat(".DISABLE_CHECKS"); +public final class SpringSecurityCoreVersion { + + private static final String DISABLE_CHECKS = SpringSecurityCoreVersion.class.getName().concat(".DISABLE_CHECKS"); private static final Log logger = LogFactory.getLog(SpringSecurityCoreVersion.class); @@ -49,21 +50,15 @@ public class SpringSecurityCoreVersion { performVersionChecks(); } - public static String getVersion() { - Package pkg = SpringSecurityCoreVersion.class.getPackage(); - return (pkg != null ? pkg.getImplementationVersion() : null); + private SpringSecurityCoreVersion() { } - /** - * Performs version checks - */ private static void performVersionChecks() { performVersionChecks(MIN_SPRING_VERSION); } /** * Perform version checks with specific min Spring Version - * * @param minSpringVersion */ private static void performVersionChecks(String minSpringVersion) { @@ -73,29 +68,29 @@ public class SpringSecurityCoreVersion { // Check Spring Compatibility String springVersion = SpringVersion.getVersion(); String version = getVersion(); - if (disableChecks(springVersion, version)) { return; } - logger.info("You are running with Spring Security Core " + version); - if (new ComparableVersion(springVersion) - .compareTo(new ComparableVersion(minSpringVersion)) < 0) { + if (new ComparableVersion(springVersion).compareTo(new ComparableVersion(minSpringVersion)) < 0) { logger.warn("**** You are advised to use Spring " + minSpringVersion + " or later with this version. You are running: " + springVersion); } } + public static String getVersion() { + Package pkg = SpringSecurityCoreVersion.class.getPackage(); + return (pkg != null) ? pkg.getImplementationVersion() : null; + } + /** * Disable if springVersion and springSecurityVersion are the same to allow working * with Uber Jars. - * * @param springVersion * @param springSecurityVersion * @return */ - private static boolean disableChecks(String springVersion, - String springSecurityVersion) { + private static boolean disableChecks(String springVersion, String springSecurityVersion) { if (springVersion == null || springVersion.equals(springSecurityVersion)) { return true; } @@ -109,10 +104,13 @@ public class SpringSecurityCoreVersion { private static String getSpringVersion() { Properties properties = new Properties(); try { - properties.load(SpringSecurityCoreVersion.class.getClassLoader().getResourceAsStream("META-INF/spring-security.versions")); - } catch (IOException | NullPointerException e) { + properties.load(SpringSecurityCoreVersion.class.getClassLoader() + .getResourceAsStream("META-INF/spring-security.versions")); + } + catch (IOException | NullPointerException ex) { return null; } return properties.getProperty("org.springframework:spring-core"); } + } diff --git a/core/src/main/java/org/springframework/security/core/SpringSecurityMessageSource.java b/core/src/main/java/org/springframework/security/core/SpringSecurityMessageSource.java index 74959b6eee..3166e36224 100644 --- a/core/src/main/java/org/springframework/security/core/SpringSecurityMessageSource.java +++ b/core/src/main/java/org/springframework/security/core/SpringSecurityMessageSource.java @@ -32,17 +32,13 @@ import org.springframework.context.support.ResourceBundleMessageSource; * @author Ben Alex */ public class SpringSecurityMessageSource extends ResourceBundleMessageSource { - // ~ Constructors - // =================================================================================================== public SpringSecurityMessageSource() { setBasename("org.springframework.security.messages"); } - // ~ Methods - // ======================================================================================================== - public static MessageSourceAccessor getAccessor() { return new MessageSourceAccessor(new SpringSecurityMessageSource()); } + } diff --git a/core/src/main/java/org/springframework/security/core/Transient.java b/core/src/main/java/org/springframework/security/core/Transient.java index b5bb566e7f..a3d8d749a3 100644 --- a/core/src/main/java/org/springframework/security/core/Transient.java +++ b/core/src/main/java/org/springframework/security/core/Transient.java @@ -24,15 +24,16 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; /** - * A marker for {@link Authentication}s that should never be stored across requests, for example - * a bearer token authentication + * A marker for {@link Authentication}s that should never be stored across requests, for + * example a bearer token authentication * * @author Josh Cummings * @since 5.1 */ -@Target({ElementType.TYPE, ElementType.ANNOTATION_TYPE}) +@Target({ ElementType.TYPE, ElementType.ANNOTATION_TYPE }) @Retention(RetentionPolicy.RUNTIME) @Inherited @Documented public @interface Transient { + } diff --git a/core/src/main/java/org/springframework/security/core/annotation/AuthenticationPrincipal.java b/core/src/main/java/org/springframework/security/core/annotation/AuthenticationPrincipal.java index f75f2bfd28..1d544402f8 100644 --- a/core/src/main/java/org/springframework/security/core/annotation/AuthenticationPrincipal.java +++ b/core/src/main/java/org/springframework/security/core/annotation/AuthenticationPrincipal.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.annotation; import java.lang.annotation.Documented; @@ -42,7 +43,6 @@ public @interface AuthenticationPrincipal { /** * True if a {@link ClassCastException} should be thrown when the current * {@link Authentication#getPrincipal()} is the incorrect type. Default is false. - * * @return */ boolean errorOnInvalidType() default false; @@ -71,8 +71,8 @@ public @interface AuthenticationPrincipal { *
       	 * @AuthenticationPrincipal(expression = "customUser")
       	 * 
      - * * @return the expression to use. */ String expression() default ""; + } diff --git a/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java b/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java index c32adb405c..11301ad7c1 100644 --- a/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java +++ b/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.annotation; import java.lang.annotation.Documented; @@ -22,8 +23,8 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; /** - * Annotation that is used to resolve the {@link org.springframework.security.core.context.SecurityContext} as a method - * argument. + * Annotation that is used to resolve the + * {@link org.springframework.security.core.context.SecurityContext} as a method argument. * * @author Dan Zheng * @since 5.2 @@ -44,26 +45,27 @@ import java.lang.annotation.Target; @Retention(RetentionPolicy.RUNTIME) @Documented public @interface CurrentSecurityContext { + /** * True if a {@link ClassCastException} should be thrown when the current - * {@link org.springframework.security.core.context.SecurityContext} is the incorrect type. Default is false. - * + * {@link org.springframework.security.core.context.SecurityContext} is the incorrect + * type. Default is false. * @return whether or not to error on an invalid type */ boolean errorOnInvalidType() default false; /** - * If specified, will use the provided SpEL expression to resolve the security context. This - * is convenient if applications need to transform the result. + * If specified, will use the provided SpEL expression to resolve the security + * context. This is convenient if applications need to transform the result. * - * For example, if an application needs to extract its custom {@code Authentication} implementation, - * then it could specify the appropriate SpEL like so: + * For example, if an application needs to extract its custom {@code Authentication} + * implementation, then it could specify the appropriate SpEL like so: * *
       	 * @CurrentSecurityContext(expression = "authentication") CustomAuthentication authentication
       	 * 
      - * * @return the expression to use */ String expression() default ""; + } diff --git a/core/src/main/java/org/springframework/security/core/authority/AuthorityUtils.java b/core/src/main/java/org/springframework/security/core/authority/AuthorityUtils.java index df67a2aba1..babadbca1d 100644 --- a/core/src/main/java/org/springframework/security/core/authority/AuthorityUtils.java +++ b/core/src/main/java/org/springframework/security/core/authority/AuthorityUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority; import java.util.ArrayList; @@ -23,8 +24,8 @@ import java.util.List; import java.util.Set; import org.springframework.security.core.GrantedAuthority; -import org.springframework.util.StringUtils; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * Utility method for manipulating GrantedAuthority collections etc. @@ -33,20 +34,21 @@ import org.springframework.util.Assert; * * @author Luke Taylor */ -public abstract class AuthorityUtils { +public final class AuthorityUtils { + public static final List NO_AUTHORITIES = Collections.emptyList(); + private AuthorityUtils() { + } + /** * Creates a array of GrantedAuthority objects from a comma-separated string * representation (e.g. "ROLE_A, ROLE_B, ROLE_C"). - * * @param authorityString the comma-separated string * @return the authorities created by tokenizing the string */ - public static List commaSeparatedStringToAuthorityList( - String authorityString) { - return createAuthorityList(StringUtils - .tokenizeToStringArray(authorityString, ",")); + public static List commaSeparatedStringToAuthorityList(String authorityString) { + return createAuthorityList(StringUtils.tokenizeToStringArray(authorityString, ",")); } /** @@ -54,31 +56,26 @@ public abstract class AuthorityUtils { * @return a Set of the Strings obtained from each call to * GrantedAuthority.getAuthority() */ - public static Set authorityListToSet( - Collection userAuthorities) { + public static Set authorityListToSet(Collection userAuthorities) { Assert.notNull(userAuthorities, "userAuthorities cannot be null"); Set set = new HashSet<>(userAuthorities.size()); - for (GrantedAuthority authority : userAuthorities) { set.add(authority.getAuthority()); } - return set; } /** * Converts authorities into a List of GrantedAuthority objects. - * * @param authorities the authorities to convert * @return a List of GrantedAuthority objects */ public static List createAuthorityList(String... authorities) { List grantedAuthorities = new ArrayList<>(authorities.length); - for (String authority : authorities) { grantedAuthorities.add(new SimpleGrantedAuthority(authority)); } - return grantedAuthorities; } + } diff --git a/core/src/main/java/org/springframework/security/core/authority/GrantedAuthoritiesContainer.java b/core/src/main/java/org/springframework/security/core/authority/GrantedAuthoritiesContainer.java index 5564316755..5511964133 100644 --- a/core/src/main/java/org/springframework/security/core/authority/GrantedAuthoritiesContainer.java +++ b/core/src/main/java/org/springframework/security/core/authority/GrantedAuthoritiesContainer.java @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority; import java.io.Serializable; -import java.util.*; +import java.util.Collection; import org.springframework.security.core.GrantedAuthority; @@ -31,5 +32,7 @@ import org.springframework.security.core.GrantedAuthority; * @since 2.0 */ public interface GrantedAuthoritiesContainer extends Serializable { + Collection getGrantedAuthorities(); + } diff --git a/core/src/main/java/org/springframework/security/core/authority/SimpleGrantedAuthority.java b/core/src/main/java/org/springframework/security/core/authority/SimpleGrantedAuthority.java index 21e2803e2e..71719d4801 100644 --- a/core/src/main/java/org/springframework/security/core/authority/SimpleGrantedAuthority.java +++ b/core/src/main/java/org/springframework/security/core/authority/SimpleGrantedAuthority.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority; import org.springframework.security.core.GrantedAuthority; @@ -41,7 +42,7 @@ public final class SimpleGrantedAuthority implements GrantedAuthority { @Override public String getAuthority() { - return role; + return this.role; } @Override @@ -49,11 +50,9 @@ public final class SimpleGrantedAuthority implements GrantedAuthority { if (this == obj) { return true; } - if (obj instanceof SimpleGrantedAuthority) { - return role.equals(((SimpleGrantedAuthority) obj).role); + return this.role.equals(((SimpleGrantedAuthority) obj).role); } - return false; } @@ -66,4 +65,5 @@ public final class SimpleGrantedAuthority implements GrantedAuthority { public String toString() { return this.role; } + } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/Attributes2GrantedAuthoritiesMapper.java b/core/src/main/java/org/springframework/security/core/authority/mapping/Attributes2GrantedAuthoritiesMapper.java index a375df225d..10cd297225 100755 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/Attributes2GrantedAuthoritiesMapper.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/Attributes2GrantedAuthoritiesMapper.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; import java.util.Collection; @@ -27,16 +28,16 @@ import org.springframework.security.core.GrantedAuthority; * @since 2.0 */ public interface Attributes2GrantedAuthoritiesMapper { + /** * Implementations of this method should map the given collection of attributes to a * collection of Spring Security GrantedAuthorities. There are no restrictions for the * mapping process; a single attribute can be mapped to multiple Spring Security * GrantedAuthorities, all attributes can be mapped to a single Spring Security * {@code GrantedAuthority}, some attributes may not be mapped, etc. - * * @param attributes the attributes to be mapped * @return the collection of authorities created from the attributes */ - Collection getGrantedAuthorities( - Collection attributes); + Collection getGrantedAuthorities(Collection attributes); + } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/GrantedAuthoritiesMapper.java b/core/src/main/java/org/springframework/security/core/authority/mapping/GrantedAuthoritiesMapper.java index 99aa7e340e..0393ace9fa 100644 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/GrantedAuthoritiesMapper.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/GrantedAuthoritiesMapper.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; -import org.springframework.security.core.GrantedAuthority; +import java.util.Collection; -import java.util.*; +import org.springframework.security.core.GrantedAuthority; /** * Mapping interface which can be injected into the authentication layer to convert the @@ -27,6 +28,7 @@ import java.util.*; * @author Luke Taylor */ public interface GrantedAuthoritiesMapper { - Collection mapAuthorities( - Collection authorities); + + Collection mapAuthorities(Collection authorities); + } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapper.java b/core/src/main/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapper.java index dd3eb7e0cf..263973f5d2 100755 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapper.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapper.java @@ -13,9 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.StringTokenizer; import org.springframework.beans.factory.InitializingBean; import org.springframework.security.core.GrantedAuthority; @@ -31,78 +39,71 @@ import org.springframework.util.StringUtils; * * @author Ruud Senden */ -public class MapBasedAttributes2GrantedAuthoritiesMapper implements - Attributes2GrantedAuthoritiesMapper, MappableAttributesRetriever, - InitializingBean { +public class MapBasedAttributes2GrantedAuthoritiesMapper + implements Attributes2GrantedAuthoritiesMapper, MappableAttributesRetriever, InitializingBean { + private Map> attributes2grantedAuthoritiesMap = null; + private String stringSeparator = ","; + private Set mappableAttributes = null; + @Override public void afterPropertiesSet() { - Assert.notNull(attributes2grantedAuthoritiesMap, - "attributes2grantedAuthoritiesMap must be set"); + Assert.notNull(this.attributes2grantedAuthoritiesMap, "attributes2grantedAuthoritiesMap must be set"); } /** * Map the given array of attributes to Spring Security GrantedAuthorities. */ + @Override public List getGrantedAuthorities(Collection attributes) { - ArrayList gaList = new ArrayList<>(); + ArrayList result = new ArrayList<>(); for (String attribute : attributes) { - Collection c = attributes2grantedAuthoritiesMap - .get(attribute); - if (c != null) { - gaList.addAll(c); + Collection granted = this.attributes2grantedAuthoritiesMap.get(attribute); + if (granted != null) { + result.addAll(granted); } } - gaList.trimToSize(); - - return gaList; + result.trimToSize(); + return result; } /** * @return Returns the attributes2grantedAuthoritiesMap. */ public Map> getAttributes2grantedAuthoritiesMap() { - return attributes2grantedAuthoritiesMap; + return this.attributes2grantedAuthoritiesMap; } /** * @param attributes2grantedAuthoritiesMap The attributes2grantedAuthoritiesMap to * set. */ - public void setAttributes2grantedAuthoritiesMap( - final Map attributes2grantedAuthoritiesMap) { + public void setAttributes2grantedAuthoritiesMap(final Map attributes2grantedAuthoritiesMap) { Assert.notEmpty(attributes2grantedAuthoritiesMap, "A non-empty attributes2grantedAuthoritiesMap must be supplied"); this.attributes2grantedAuthoritiesMap = preProcessMap(attributes2grantedAuthoritiesMap); - - mappableAttributes = Collections - .unmodifiableSet(this.attributes2grantedAuthoritiesMap.keySet()); + this.mappableAttributes = Collections.unmodifiableSet(this.attributes2grantedAuthoritiesMap.keySet()); } /** * Preprocess the given map to convert all the values to GrantedAuthority collections - * * @param orgMap The map to process * @return the processed Map */ private Map> preProcessMap(Map orgMap) { - Map> result = new HashMap<>( - orgMap.size()); - + Map> result = new HashMap<>(orgMap.size()); for (Map.Entry entry : orgMap.entrySet()) { Assert.isInstanceOf(String.class, entry.getKey(), "attributes2grantedAuthoritiesMap contains non-String objects as keys"); - result.put((String) entry.getKey(), - getGrantedAuthorityCollection(entry.getValue())); + result.put((String) entry.getKey(), getGrantedAuthorityCollection(entry.getValue())); } return result; } /** * Convert the given value to a collection of Granted Authorities - * * @param value The value to convert to a GrantedAuthority Collection * @return Collection containing the GrantedAuthority Collection */ @@ -115,12 +116,10 @@ public class MapBasedAttributes2GrantedAuthoritiesMapper implements /** * Convert the given value to a collection of Granted Authorities, adding the result * to the given result collection. - * * @param value The value to convert to a GrantedAuthority Collection * @return Collection containing the GrantedAuthority Collection */ - private void addGrantedAuthorityCollection(Collection result, - Object value) { + private void addGrantedAuthorityCollection(Collection result, Object value) { if (value == null) { return; } @@ -137,32 +136,28 @@ public class MapBasedAttributes2GrantedAuthoritiesMapper implements result.add((GrantedAuthority) value); } else { - throw new IllegalArgumentException("Invalid object type: " - + value.getClass().getName()); + throw new IllegalArgumentException("Invalid object type: " + value.getClass().getName()); } } - private void addGrantedAuthorityCollection(Collection result, - Collection value) { + private void addGrantedAuthorityCollection(Collection result, Collection value) { for (Object elt : value) { addGrantedAuthorityCollection(result, elt); } } - private void addGrantedAuthorityCollection(Collection result, - Object[] value) { + private void addGrantedAuthorityCollection(Collection result, Object[] value) { for (Object aValue : value) { addGrantedAuthorityCollection(result, aValue); } } - private void addGrantedAuthorityCollection(Collection result, - String value) { - StringTokenizer st = new StringTokenizer(value, stringSeparator, false); - while (st.hasMoreTokens()) { - String nextToken = st.nextToken(); - if (StringUtils.hasText(nextToken)) { - result.add(new SimpleGrantedAuthority(nextToken)); + private void addGrantedAuthorityCollection(Collection result, String value) { + StringTokenizer tokenizer = new StringTokenizer(value, this.stringSeparator, false); + while (tokenizer.hasMoreTokens()) { + String token = tokenizer.nextToken(); + if (StringUtils.hasText(token)) { + result.add(new SimpleGrantedAuthority(token)); } } } @@ -171,15 +166,16 @@ public class MapBasedAttributes2GrantedAuthoritiesMapper implements * * @see org.springframework.security.core.authority.mapping.MappableAttributesRetriever#getMappableAttributes() */ + @Override public Set getMappableAttributes() { - return mappableAttributes; + return this.mappableAttributes; } /** * @return Returns the stringSeparator. */ public String getStringSeparator() { - return stringSeparator; + return this.stringSeparator; } /** diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/MappableAttributesRetriever.java b/core/src/main/java/org/springframework/security/core/authority/mapping/MappableAttributesRetriever.java index 67fd1d0374..dfb5d4a2a0 100755 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/MappableAttributesRetriever.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/MappableAttributesRetriever.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; import java.util.Set; @@ -26,11 +27,12 @@ import java.util.Set; * @since 2.0 */ public interface MappableAttributesRetriever { + /** * Implementations of this method should return a set of all string attributes which * can be mapped to GrantedAuthoritys. - * * @return set of all mappable roles */ Set getMappableAttributes(); + } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/NullAuthoritiesMapper.java b/core/src/main/java/org/springframework/security/core/authority/mapping/NullAuthoritiesMapper.java index dd22a1e649..66ae12909e 100644 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/NullAuthoritiesMapper.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/NullAuthoritiesMapper.java @@ -13,18 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; -import org.springframework.security.core.GrantedAuthority; +import java.util.Collection; -import java.util.*; +import org.springframework.security.core.GrantedAuthority; /** * @author Luke Taylor */ public class NullAuthoritiesMapper implements GrantedAuthoritiesMapper { - public Collection mapAuthorities( - Collection authorities) { + + @Override + public Collection mapAuthorities(Collection authorities) { return authorities; } + } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAttributes2GrantedAuthoritiesMapper.java b/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAttributes2GrantedAuthoritiesMapper.java index b3c0a92109..c07137bb8f 100755 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAttributes2GrantedAuthoritiesMapper.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAttributes2GrantedAuthoritiesMapper.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core.authority.mapping; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; +package org.springframework.security.core.authority.mapping; import java.util.ArrayList; import java.util.Collection; @@ -24,6 +22,8 @@ import java.util.List; import java.util.Locale; import org.springframework.beans.factory.InitializingBean; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.util.Assert; /** @@ -38,8 +38,9 @@ import org.springframework.util.Assert; * @author Ruud Senden * @since 2.0 */ -public class SimpleAttributes2GrantedAuthoritiesMapper implements - Attributes2GrantedAuthoritiesMapper, InitializingBean { +public class SimpleAttributes2GrantedAuthoritiesMapper + implements Attributes2GrantedAuthoritiesMapper, InitializingBean { + private String attributePrefix = "ROLE_"; private boolean convertAttributeToUpperCase = false; @@ -51,9 +52,9 @@ public class SimpleAttributes2GrantedAuthoritiesMapper implements /** * Check whether all properties have been set to correct values. */ + @Override public void afterPropertiesSet() { - Assert.isTrue( - !(isConvertAttributeToUpperCase() && isConvertAttributeToLowerCase()), + Assert.isTrue(!(isConvertAttributeToUpperCase() && isConvertAttributeToLowerCase()), "Either convertAttributeToUpperCase or convertAttributeToLowerCase can be set to true, but not both"); } @@ -61,6 +62,7 @@ public class SimpleAttributes2GrantedAuthoritiesMapper implements * Map the given list of string attributes one-to-one to Spring Security * GrantedAuthorities. */ + @Override public List getGrantedAuthorities(Collection attributes) { List result = new ArrayList<>(attributes.size()); for (String attribute : attributes) { @@ -72,7 +74,6 @@ public class SimpleAttributes2GrantedAuthoritiesMapper implements /** * Map the given role one-on-one to a Spring Security GrantedAuthority, optionally * doing case conversion and/or adding a prefix. - * * @param attribute The attribute for which to get a GrantedAuthority * @return GrantedAuthority representing the given role. */ @@ -92,35 +93,35 @@ public class SimpleAttributes2GrantedAuthoritiesMapper implements } private boolean isConvertAttributeToLowerCase() { - return convertAttributeToLowerCase; + return this.convertAttributeToLowerCase; } public void setConvertAttributeToLowerCase(boolean b) { - convertAttributeToLowerCase = b; + this.convertAttributeToLowerCase = b; } private boolean isConvertAttributeToUpperCase() { - return convertAttributeToUpperCase; + return this.convertAttributeToUpperCase; } public void setConvertAttributeToUpperCase(boolean b) { - convertAttributeToUpperCase = b; + this.convertAttributeToUpperCase = b; } private String getAttributePrefix() { - return attributePrefix == null ? "" : attributePrefix; + return (this.attributePrefix != null) ? this.attributePrefix : ""; } public void setAttributePrefix(String string) { - attributePrefix = string; + this.attributePrefix = string; } private boolean isAddPrefixIfAlreadyExisting() { - return addPrefixIfAlreadyExisting; + return this.addPrefixIfAlreadyExisting; } public void setAddPrefixIfAlreadyExisting(boolean b) { - addPrefixIfAlreadyExisting = b; + this.addPrefixIfAlreadyExisting = b; } } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAuthorityMapper.java b/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAuthorityMapper.java index b381d8fcb9..2bb66a73b6 100644 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAuthorityMapper.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleAuthorityMapper.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + import org.springframework.beans.factory.InitializingBean; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.util.Assert; -import java.util.*; - /** * Simple one-to-one {@code GrantedAuthoritiesMapper} which allows for case conversion of * the authority name and the addition of a string prefix (which defaults to {@code ROLE_} @@ -30,15 +33,19 @@ import java.util.*; * @author Luke Taylor * @since 3.1 */ -public final class SimpleAuthorityMapper implements GrantedAuthoritiesMapper, - InitializingBean { +public final class SimpleAuthorityMapper implements GrantedAuthoritiesMapper, InitializingBean { + private GrantedAuthority defaultAuthority; + private String prefix = "ROLE_"; + private boolean convertToUpperCase = false; + private boolean convertToLowerCase = false; + @Override public void afterPropertiesSet() { - Assert.isTrue(!(convertToUpperCase && convertToLowerCase), + Assert.isTrue(!(this.convertToUpperCase && this.convertToLowerCase), "Either convertToUpperCase or convertToLowerCase can be set to true, but not both"); } @@ -47,45 +54,37 @@ public final class SimpleAuthorityMapper implements GrantedAuthoritiesMapper, * prefix settings. The mapping will be one-to-one unless duplicates are produced * during the conversion. If a default authority has been set, this will also be * assigned to each mapping. - * * @param authorities the original authorities - * * @return the converted set of authorities */ - public Set mapAuthorities( - Collection authorities) { - HashSet mapped = new HashSet<>( - authorities.size()); + @Override + public Set mapAuthorities(Collection authorities) { + HashSet mapped = new HashSet<>(authorities.size()); for (GrantedAuthority authority : authorities) { mapped.add(mapAuthority(authority.getAuthority())); } - - if (defaultAuthority != null) { - mapped.add(defaultAuthority); + if (this.defaultAuthority != null) { + mapped.add(this.defaultAuthority); } - return mapped; } private GrantedAuthority mapAuthority(String name) { - if (convertToUpperCase) { + if (this.convertToUpperCase) { name = name.toUpperCase(); } - else if (convertToLowerCase) { + else if (this.convertToLowerCase) { name = name.toLowerCase(); } - - if (prefix.length() > 0 && !name.startsWith(prefix)) { - name = prefix + name; + if (this.prefix.length() > 0 && !name.startsWith(this.prefix)) { + name = this.prefix + name; } - return new SimpleGrantedAuthority(name); } /** * Sets the prefix which should be added to the authority name (if it doesn't already * exist) - * * @param prefix the prefix, typically to satisfy the behaviour of an * {@code AccessDecisionVoter}. */ @@ -96,7 +95,6 @@ public final class SimpleAuthorityMapper implements GrantedAuthoritiesMapper, /** * Whether to convert the authority value to upper case in the mapping. - * * @param convertToUpperCase defaults to {@code false} */ public void setConvertToUpperCase(boolean convertToUpperCase) { @@ -105,7 +103,6 @@ public final class SimpleAuthorityMapper implements GrantedAuthoritiesMapper, /** * Whether to convert the authority value to lower case in the mapping. - * * @param convertToLowerCase defaults to {@code false} */ public void setConvertToLowerCase(boolean convertToLowerCase) { @@ -114,11 +111,11 @@ public final class SimpleAuthorityMapper implements GrantedAuthoritiesMapper, /** * Sets a default authority to be assigned to all users - * * @param authority the name of the authority to be assigned to all users. */ public void setDefaultAuthority(String authority) { Assert.hasText(authority, "The authority name cannot be set to an empty value"); this.defaultAuthority = new SimpleGrantedAuthority(authority); } + } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleMappableAttributesRetriever.java b/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleMappableAttributesRetriever.java index 0ce561150e..057cc22d9e 100755 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleMappableAttributesRetriever.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/SimpleMappableAttributesRetriever.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; import java.util.Collections; @@ -27,15 +28,10 @@ import java.util.Set; * @since 2.0 */ public class SimpleMappableAttributesRetriever implements MappableAttributesRetriever { + private Set mappableAttributes = null; - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.core.authority.mapping.MappableAttributesRetriever - * #getMappableAttributes() - */ + @Override public Set getMappableAttributes() { return this.mappableAttributes; } diff --git a/core/src/main/java/org/springframework/security/core/authority/mapping/package-info.java b/core/src/main/java/org/springframework/security/core/authority/mapping/package-info.java index f52d4bf459..48f121691b 100644 --- a/core/src/main/java/org/springframework/security/core/authority/mapping/package-info.java +++ b/core/src/main/java/org/springframework/security/core/authority/mapping/package-info.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Strategies for mapping a list of attributes (such as roles or LDAP groups) to a list of * {@code GrantedAuthority}s. *

      - * Provides a layer of indirection between a security data repository and the logical authorities - * required within an application. + * Provides a layer of indirection between a security data repository and the logical + * authorities required within an application. */ package org.springframework.security.core.authority.mapping; - diff --git a/core/src/main/java/org/springframework/security/core/authority/package-info.java b/core/src/main/java/org/springframework/security/core/authority/package-info.java index 3d3822993d..b47cd17e4b 100644 --- a/core/src/main/java/org/springframework/security/core/authority/package-info.java +++ b/core/src/main/java/org/springframework/security/core/authority/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * The default implementation of the {@code GrantedAuthority} interface. */ package org.springframework.security.core.authority; - diff --git a/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java index 18ae5a2828..d8367c4ebd 100644 --- a/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java @@ -28,32 +28,31 @@ import org.springframework.util.Assert; * @author Ben Alex */ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolderStrategy { - // ~ Static fields/initializers - // ===================================================================================== private static SecurityContext contextHolder; - // ~ Methods - // ======================================================================================================== - + @Override public void clearContext() { contextHolder = null; } + @Override public SecurityContext getContext() { if (contextHolder == null) { contextHolder = new SecurityContextImpl(); } - return contextHolder; } + @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); contextHolder = context; } + @Override public SecurityContext createEmptyContext() { return new SecurityContextImpl(); } + } diff --git a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java index 5e9d5e395d..cb415500ca 100644 --- a/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java @@ -23,40 +23,36 @@ import org.springframework.util.Assert; * {@link org.springframework.security.core.context.SecurityContextHolderStrategy}. * * @author Ben Alex - * * @see java.lang.ThreadLocal */ -final class InheritableThreadLocalSecurityContextHolderStrategy implements - SecurityContextHolderStrategy { - // ~ Static fields/initializers - // ===================================================================================== +final class InheritableThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy { private static final ThreadLocal contextHolder = new InheritableThreadLocal<>(); - // ~ Methods - // ======================================================================================================== - + @Override public void clearContext() { contextHolder.remove(); } + @Override public SecurityContext getContext() { SecurityContext ctx = contextHolder.get(); - if (ctx == null) { ctx = createEmptyContext(); contextHolder.set(ctx); } - return ctx; } + @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); contextHolder.set(context); } + @Override public SecurityContext createEmptyContext() { return new SecurityContextImpl(); } + } diff --git a/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolder.java b/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolder.java index e6ed84ae2c..d94a5112b2 100644 --- a/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolder.java +++ b/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolder.java @@ -16,12 +16,12 @@ package org.springframework.security.core.context; +import java.util.function.Function; -import org.springframework.security.core.Authentication; import reactor.core.publisher.Mono; import reactor.util.context.Context; -import java.util.function.Function; +import org.springframework.security.core.Authentication; /** * Allows getting and setting the Spring {@link SecurityContext} into a {@link Context}. @@ -29,17 +29,31 @@ import java.util.function.Function; * @author Rob Winch * @since 5.0 */ -public class ReactiveSecurityContextHolder { +public final class ReactiveSecurityContextHolder { + private static final Class SECURITY_CONTEXT_KEY = SecurityContext.class; + private ReactiveSecurityContextHolder() { + } + /** * Gets the {@code Mono} from Reactor {@link Context} * @return the {@code Mono} */ public static Mono getContext() { + // @formatter:off return Mono.subscriberContext() - .filter( c -> c.hasKey(SECURITY_CONTEXT_KEY)) - .flatMap( c-> c.>get(SECURITY_CONTEXT_KEY)); + .filter(ReactiveSecurityContextHolder::hasSecurityContext) + .flatMap(ReactiveSecurityContextHolder::getSecurityContext); + // @formatter:on + } + + private static boolean hasSecurityContext(Context context) { + return context.hasKey(SECURITY_CONTEXT_KEY); + } + + private static Mono getSecurityContext(Context context) { + return context.>get(SECURITY_CONTEXT_KEY); } /** @@ -48,7 +62,7 @@ public class ReactiveSecurityContextHolder { * from clearing the context. */ public static Function clearContext() { - return context -> context.delete(SECURITY_CONTEXT_KEY); + return (context) -> context.delete(SECURITY_CONTEXT_KEY); } /** @@ -70,4 +84,5 @@ public class ReactiveSecurityContextHolder { public static Context withAuthentication(Authentication authentication) { return withSecurityContext(Mono.just(new SecurityContextImpl(authentication))); } + } diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContext.java b/core/src/main/java/org/springframework/security/core/context/SecurityContext.java index 1fd92790cc..bfc40672fd 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContext.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContext.java @@ -16,10 +16,10 @@ package org.springframework.security.core.context; -import org.springframework.security.core.Authentication; - import java.io.Serializable; +import org.springframework.security.core.Authentication; + /** * Interface defining the minimum security information associated with the current thread * of execution. @@ -31,12 +31,9 @@ import java.io.Serializable; * @author Ben Alex */ public interface SecurityContext extends Serializable { - // ~ Methods - // ======================================================================================================== /** * Obtains the currently authenticated principal, or an authentication request token. - * * @return the Authentication or null if no authentication * information is available */ @@ -45,9 +42,9 @@ public interface SecurityContext extends Serializable { /** * Changes the currently authenticated principal, or removes the authentication * information. - * * @param authentication the new Authentication token, or * null if no further authentication information should be stored */ void setAuthentication(Authentication authentication); + } diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java index 61c4231c35..810edef7c9 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java @@ -16,11 +16,11 @@ package org.springframework.security.core.context; +import java.lang.reflect.Constructor; + import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; -import java.lang.reflect.Constructor; - /** * Associates a given {@link SecurityContext} with the current execution thread. *

      @@ -48,57 +48,30 @@ import java.lang.reflect.Constructor; * */ public class SecurityContextHolder { - // ~ Static fields/initializers - // ===================================================================================== public static final String MODE_THREADLOCAL = "MODE_THREADLOCAL"; + public static final String MODE_INHERITABLETHREADLOCAL = "MODE_INHERITABLETHREADLOCAL"; + public static final String MODE_GLOBAL = "MODE_GLOBAL"; + public static final String SYSTEM_PROPERTY = "spring.security.strategy"; + private static String strategyName = System.getProperty(SYSTEM_PROPERTY); + private static SecurityContextHolderStrategy strategy; + private static int initializeCount = 0; static { initialize(); } - // ~ Methods - // ======================================================================================================== - - /** - * Explicitly clears the context value from the current thread. - */ - public static void clearContext() { - strategy.clearContext(); - } - - /** - * Obtain the current SecurityContext. - * - * @return the security context (never null) - */ - public static SecurityContext getContext() { - return strategy.getContext(); - } - - /** - * Primarily for troubleshooting purposes, this method shows how many times the class - * has re-initialized its SecurityContextHolderStrategy. - * - * @return the count (should be one unless you've called - * {@link #setStrategyName(String)} to switch to an alternate strategy. - */ - public static int getInitializeCount() { - return initializeCount; - } - private static void initialize() { if (!StringUtils.hasText(strategyName)) { // Set default strategyName = MODE_THREADLOCAL; } - if (strategyName.equals(MODE_THREADLOCAL)) { strategy = new ThreadLocalSecurityContextHolderStrategy(); } @@ -119,13 +92,36 @@ public class SecurityContextHolder { ReflectionUtils.handleReflectionException(ex); } } - initializeCount++; } + /** + * Explicitly clears the context value from the current thread. + */ + public static void clearContext() { + strategy.clearContext(); + } + + /** + * Obtain the current SecurityContext. + * @return the security context (never null) + */ + public static SecurityContext getContext() { + return strategy.getContext(); + } + + /** + * Primarily for troubleshooting purposes, this method shows how many times the class + * has re-initialized its SecurityContextHolderStrategy. + * @return the count (should be one unless you've called + * {@link #setStrategyName(String)} to switch to an alternate strategy. + */ + public static int getInitializeCount() { + return initializeCount; + } + /** * Associates a new SecurityContext with the current thread of execution. - * * @param context the new SecurityContext (may not be null) */ public static void setContext(SecurityContext context) { @@ -136,7 +132,6 @@ public class SecurityContextHolder { * Changes the preferred strategy. Do NOT call this method more than once for * a given JVM, as it will re-initialize the strategy and adversely affect any * existing threads using the old strategy. - * * @param strategyName the fully qualified class name of the strategy that should be * used. */ @@ -147,7 +142,6 @@ public class SecurityContextHolder { /** * Allows retrieval of the context strategy. See SEC-1188. - * * @return the configured strategy for storing the security context. */ public static SecurityContextHolderStrategy getContextHolderStrategy() { @@ -163,7 +157,7 @@ public class SecurityContextHolder { @Override public String toString() { - return "SecurityContextHolder[strategy='" + strategyName + "'; initializeCount=" - + initializeCount + "]"; + return "SecurityContextHolder[strategy='" + strategyName + "'; initializeCount=" + initializeCount + "]"; } + } diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java index aa35315f4c..4954db70aa 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java @@ -25,8 +25,6 @@ package org.springframework.security.core.context; * @author Ben Alex */ public interface SecurityContextHolderStrategy { - // ~ Methods - // ======================================================================================================== /** * Clears the current context. @@ -35,7 +33,6 @@ public interface SecurityContextHolderStrategy { /** * Obtains the current context. - * * @return a context (never null - create a default implementation if * necessary) */ @@ -43,7 +40,6 @@ public interface SecurityContextHolderStrategy { /** * Sets the current context. - * * @param context to the new argument (should never be null, although * implementations must check if null has been passed and throw an * IllegalArgumentException in such cases) @@ -54,8 +50,8 @@ public interface SecurityContextHolderStrategy { * Creates a new, empty context implementation, for use by * SecurityContextRepository implementations, when creating a new context for * the first time. - * * @return the empty context. */ SecurityContext createEmptyContext(); + } diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextImpl.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextImpl.java index bfd409c1c4..19c18abb24 100644 --- a/core/src/main/java/org/springframework/security/core/context/SecurityContextImpl.java +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextImpl.java @@ -18,6 +18,7 @@ package org.springframework.security.core.context; import org.springframework.security.core.Authentication; import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.util.ObjectUtils; /** * Base implementation of {@link SecurityContext}. @@ -30,51 +31,38 @@ public class SecurityContextImpl implements SecurityContext { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private Authentication authentication; - public SecurityContextImpl() {} + public SecurityContextImpl() { + } public SecurityContextImpl(Authentication authentication) { this.authentication = authentication; } - // ~ Methods - // ======================================================================================================== - @Override public boolean equals(Object obj) { if (obj instanceof SecurityContextImpl) { - SecurityContextImpl test = (SecurityContextImpl) obj; - - if ((this.getAuthentication() == null) && (test.getAuthentication() == null)) { + SecurityContextImpl other = (SecurityContextImpl) obj; + if ((this.getAuthentication() == null) && (other.getAuthentication() == null)) { return true; } - - if ((this.getAuthentication() != null) && (test.getAuthentication() != null) - && this.getAuthentication().equals(test.getAuthentication())) { + if ((this.getAuthentication() != null) && (other.getAuthentication() != null) + && this.getAuthentication().equals(other.getAuthentication())) { return true; } } - return false; } @Override public Authentication getAuthentication() { - return authentication; + return this.authentication; } @Override public int hashCode() { - if (this.authentication == null) { - return -1; - } - else { - return this.authentication.hashCode(); - } + return ObjectUtils.nullSafeHashCode(this.authentication); } @Override @@ -86,14 +74,13 @@ public class SecurityContextImpl implements SecurityContext { public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); - if (this.authentication == null) { sb.append(": Null authentication"); } else { sb.append(": Authentication: ").append(this.authentication); } - return sb.toString(); } + } diff --git a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java index c13a909c59..801f5c8207 100644 --- a/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java +++ b/core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java @@ -23,41 +23,37 @@ import org.springframework.util.Assert; * {@link SecurityContextHolderStrategy}. * * @author Ben Alex - * * @see java.lang.ThreadLocal * @see org.springframework.security.core.context.web.SecurityContextPersistenceFilter */ -final class ThreadLocalSecurityContextHolderStrategy implements - SecurityContextHolderStrategy { - // ~ Static fields/initializers - // ===================================================================================== +final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy { private static final ThreadLocal contextHolder = new ThreadLocal<>(); - // ~ Methods - // ======================================================================================================== - + @Override public void clearContext() { contextHolder.remove(); } + @Override public SecurityContext getContext() { SecurityContext ctx = contextHolder.get(); - if (ctx == null) { ctx = createEmptyContext(); contextHolder.set(ctx); } - return ctx; } + @Override public void setContext(SecurityContext context) { Assert.notNull(context, "Only non-null SecurityContext instances are permitted"); contextHolder.set(context); } + @Override public SecurityContext createEmptyContext() { return new SecurityContextImpl(); } + } diff --git a/core/src/main/java/org/springframework/security/core/context/package-info.java b/core/src/main/java/org/springframework/security/core/context/package-info.java index 6660feef43..61009683ca 100644 --- a/core/src/main/java/org/springframework/security/core/context/package-info.java +++ b/core/src/main/java/org/springframework/security/core/context/package-info.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Classes related to the establishment of a security context for the duration of a request (such as - * an HTTP or RMI invocation). + * Classes related to the establishment of a security context for the duration of a + * request (such as an HTTP or RMI invocation). *

      - * A security context is usually associated with the current execution thread for the duration of the request, - * making the authentication information it contains available throughout all the layers of an application. + * A security context is usually associated with the current execution thread for the + * duration of the request, making the authentication information it contains available + * throughout all the layers of an application. *

      - * The {@link org.springframework.security.core.context.SecurityContext SecurityContext} can be accessed at any point - * by calling the {@link org.springframework.security.core.context.SecurityContextHolder SecurityContextHolder}. + * The {@link org.springframework.security.core.context.SecurityContext SecurityContext} + * can be accessed at any point by calling the + * {@link org.springframework.security.core.context.SecurityContextHolder + * SecurityContextHolder}. */ package org.springframework.security.core.context; - diff --git a/core/src/main/java/org/springframework/security/core/package-info.java b/core/src/main/java/org/springframework/security/core/package-info.java index f27dcaf004..b8dd2eaa5a 100644 --- a/core/src/main/java/org/springframework/security/core/package-info.java +++ b/core/src/main/java/org/springframework/security/core/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Core classes and interfaces related to user authentication and authorization, as well as the maintenance of - * a security context. + * Core classes and interfaces related to user authentication and authorization, as well + * as the maintenance of a security context. */ package org.springframework.security.core; - diff --git a/core/src/main/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscoverer.java b/core/src/main/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscoverer.java index daac6e6d95..18d8099705 100644 --- a/core/src/main/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscoverer.java +++ b/core/src/main/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscoverer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.parameters; import java.lang.annotation.Annotation; @@ -81,13 +82,17 @@ import org.springframework.util.ReflectionUtils; * {@link PrioritizedParameterNameDiscoverer} are an all or nothing operation. *

      * - * @see DefaultSecurityParameterNameDiscoverer - * * @author Rob Winch * @since 3.2 + * @see DefaultSecurityParameterNameDiscoverer */ public class AnnotationParameterNameDiscoverer implements ParameterNameDiscoverer { + private static final ParameterNameFactory> CONSTRUCTOR_METHODPARAM_FACTORY = ( + constructor) -> constructor.getParameterAnnotations(); + + private static final ParameterNameFactory METHOD_METHODPARAM_FACTORY = Method::getParameterAnnotations; + private final Set annotationClassesToUse; public AnnotationParameterNameDiscoverer(String... annotationClassToUse) { @@ -95,29 +100,21 @@ public class AnnotationParameterNameDiscoverer implements ParameterNameDiscovere } public AnnotationParameterNameDiscoverer(Set annotationClassesToUse) { - Assert.notEmpty(annotationClassesToUse, - "annotationClassesToUse cannot be null or empty"); + Assert.notEmpty(annotationClassesToUse, "annotationClassesToUse cannot be null or empty"); this.annotationClassesToUse = annotationClassesToUse; } - /* - * (non-Javadoc) - * - * @see org.springframework.core.ParameterNameDiscoverer#getParameterNames(java - * .lang.reflect.Method) - */ + @Override public String[] getParameterNames(Method method) { Method originalMethod = BridgeMethodResolver.findBridgedMethod(method); - String[] paramNames = lookupParameterNames(METHOD_METHODPARAM_FACTORY, - originalMethod); + String[] paramNames = lookupParameterNames(METHOD_METHODPARAM_FACTORY, originalMethod); if (paramNames != null) { return paramNames; } Class declaringClass = method.getDeclaringClass(); Class[] interfaces = declaringClass.getInterfaces(); for (Class intrfc : interfaces) { - Method intrfcMethod = ReflectionUtils.findMethod(intrfc, method.getName(), - method.getParameterTypes()); + Method intrfcMethod = ReflectionUtils.findMethod(intrfc, method.getName(), method.getParameterTypes()); if (intrfcMethod != null) { return lookupParameterNames(METHOD_METHODPARAM_FACTORY, intrfcMethod); } @@ -125,26 +122,20 @@ public class AnnotationParameterNameDiscoverer implements ParameterNameDiscovere return paramNames; } - /* - * (non-Javadoc) - * - * @see org.springframework.core.ParameterNameDiscoverer#getParameterNames(java - * .lang.reflect.Constructor) - */ + @Override public String[] getParameterNames(Constructor constructor) { return lookupParameterNames(CONSTRUCTOR_METHODPARAM_FACTORY, constructor); } /** * Gets the parameter names or null if not found. - * * @param parameterNameFactory the {@link ParameterNameFactory} to use * @param t the {@link AccessibleObject} to find the parameter names on (i.e. Method * or Constructor) * @return the parameter names or null */ - private String[] lookupParameterNames( - ParameterNameFactory parameterNameFactory, T t) { + private String[] lookupParameterNames(ParameterNameFactory parameterNameFactory, + T t) { Annotation[][] parameterAnnotations = parameterNameFactory.findParameterAnnotations(t); int parameterCount = parameterAnnotations.length; String[] paramNames = new String[parameterCount]; @@ -164,32 +155,26 @@ public class AnnotationParameterNameDiscoverer implements ParameterNameDiscovere * Finds the parameter name from the provided {@link Annotation}s or null if it could * not find it. The search is done by looking at the value property of the * {@link #annotationClassesToUse}. - * * @param parameterAnnotations the {@link Annotation}'s to search. * @return */ private String findParameterName(Annotation[] parameterAnnotations) { for (Annotation paramAnnotation : parameterAnnotations) { - if (annotationClassesToUse.contains(paramAnnotation.annotationType() - .getName())) { + if (this.annotationClassesToUse.contains(paramAnnotation.annotationType().getName())) { return (String) AnnotationUtils.getValue(paramAnnotation, "value"); } } return null; } - private static final ParameterNameFactory> CONSTRUCTOR_METHODPARAM_FACTORY = constructor -> constructor.getParameterAnnotations(); - - private static final ParameterNameFactory METHOD_METHODPARAM_FACTORY = method -> method.getParameterAnnotations(); - /** * Strategy interface for looking up the parameter names. * + * @param the type to inspect (i.e. {@link Method} or {@link Constructor}) * @author Rob Winch * @since 3.2 - * - * @param the type to inspect (i.e. {@link Method} or {@link Constructor}) */ + @FunctionalInterface private interface ParameterNameFactory { /** @@ -199,5 +184,7 @@ public class AnnotationParameterNameDiscoverer implements ParameterNameDiscovere * @return */ Annotation[][] findParameterAnnotations(T t); + } + } diff --git a/core/src/main/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscoverer.java b/core/src/main/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscoverer.java index c7c10fe90c..7a7b339917 100644 --- a/core/src/main/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscoverer.java +++ b/core/src/main/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscoverer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.parameters; import java.util.Collections; @@ -20,8 +21,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.LocalVariableTableParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer; @@ -44,19 +43,15 @@ import org.springframework.util.ClassUtils; * {@link LocalVariableTableParameterNameDiscoverer} is added directly. * * - * @see AnnotationParameterNameDiscoverer - * * @author Rob Winch * @since 3.2 + * @see AnnotationParameterNameDiscoverer */ -public class DefaultSecurityParameterNameDiscoverer extends - PrioritizedParameterNameDiscoverer { - - private final Log logger = LogFactory.getLog(getClass()); +public class DefaultSecurityParameterNameDiscoverer extends PrioritizedParameterNameDiscoverer { private static final String DATA_PARAM_CLASSNAME = "org.springframework.data.repository.query.Param"; - private static final boolean DATA_PARAM_PRESENT = ClassUtils.isPresent( - DATA_PARAM_CLASSNAME, + + private static final boolean DATA_PARAM_PRESENT = ClassUtils.isPresent(DATA_PARAM_CLASSNAME, DefaultSecurityParameterNameDiscoverer.class.getClassLoader()); /** @@ -64,7 +59,7 @@ public class DefaultSecurityParameterNameDiscoverer extends * instances. */ public DefaultSecurityParameterNameDiscoverer() { - this(Collections. emptyList()); + this(Collections.emptyList()); } /** @@ -74,21 +69,19 @@ public class DefaultSecurityParameterNameDiscoverer extends * defaults. Cannot be null. */ @SuppressWarnings("unchecked") - public DefaultSecurityParameterNameDiscoverer( - List parameterNameDiscovers) { + public DefaultSecurityParameterNameDiscoverer(List parameterNameDiscovers) { Assert.notNull(parameterNameDiscovers, "parameterNameDiscovers cannot be null"); for (ParameterNameDiscoverer discover : parameterNameDiscovers) { addDiscoverer(discover); } - Set annotationClassesToUse = new HashSet<>(2); annotationClassesToUse.add("org.springframework.security.access.method.P"); annotationClassesToUse.add(P.class.getName()); if (DATA_PARAM_PRESENT) { annotationClassesToUse.add(DATA_PARAM_CLASSNAME); } - addDiscoverer(new AnnotationParameterNameDiscoverer(annotationClassesToUse)); addDiscoverer(new DefaultParameterNameDiscoverer()); } + } diff --git a/core/src/main/java/org/springframework/security/core/parameters/P.java b/core/src/main/java/org/springframework/security/core/parameters/P.java index f828e3795c..b6ece3bb30 100644 --- a/core/src/main/java/org/springframework/security/core/parameters/P.java +++ b/core/src/main/java/org/springframework/security/core/parameters/P.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.parameters; import java.lang.annotation.Documented; @@ -41,4 +42,5 @@ public @interface P { * @return */ String value(); + } diff --git a/core/src/main/java/org/springframework/security/core/session/AbstractSessionEvent.java b/core/src/main/java/org/springframework/security/core/session/AbstractSessionEvent.java index 5a43b2d6de..4c8c20da5c 100644 --- a/core/src/main/java/org/springframework/security/core/session/AbstractSessionEvent.java +++ b/core/src/main/java/org/springframework/security/core/session/AbstractSessionEvent.java @@ -29,4 +29,5 @@ public class AbstractSessionEvent extends ApplicationEvent { public AbstractSessionEvent(Object source) { super(source); } + } diff --git a/core/src/main/java/org/springframework/security/core/session/SessionCreationEvent.java b/core/src/main/java/org/springframework/security/core/session/SessionCreationEvent.java index 97471c85e1..3fab046e58 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionCreationEvent.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionCreationEvent.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.session; /** @@ -27,4 +28,5 @@ public abstract class SessionCreationEvent extends AbstractSessionEvent { public SessionCreationEvent(Object source) { super(source); } + } diff --git a/core/src/main/java/org/springframework/security/core/session/SessionDestroyedEvent.java b/core/src/main/java/org/springframework/security/core/session/SessionDestroyedEvent.java index 785c40c0e6..07e5d7d6f7 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionDestroyedEvent.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionDestroyedEvent.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.session; -import org.springframework.security.core.context.SecurityContext; +import java.util.List; -import java.util.*; +import org.springframework.security.core.context.SecurityContext; /** * Generic "session termination" event which indicates that a session (potentially @@ -35,7 +36,6 @@ public abstract class SessionDestroyedEvent extends AbstractSessionEvent { /** * Provides the {@code SecurityContext} instances which were associated with the * destroyed session. Usually there will be only one security context per session. - * * @return the {@code SecurityContext} instances which were stored in the current * session (an empty list if there are none). */ @@ -45,4 +45,5 @@ public abstract class SessionDestroyedEvent extends AbstractSessionEvent { * @return the identifier associated with the destroyed session. */ public abstract String getId(); + } diff --git a/core/src/main/java/org/springframework/security/core/session/SessionIdChangedEvent.java b/core/src/main/java/org/springframework/security/core/session/SessionIdChangedEvent.java index 640cdf7618..27229533e6 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionIdChangedEvent.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionIdChangedEvent.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.session; /** - * Generic "session ID changed" event which indicates that a session - * identifier (potentially represented by a security context) has changed. + * Generic "session ID changed" event which indicates that a session identifier + * (potentially represented by a security context) has changed. * * @since 5.4 */ @@ -29,16 +30,14 @@ public abstract class SessionIdChangedEvent extends AbstractSessionEvent { /** * Returns the old session ID. - * - * @return the identifier that was previously associated with - * the session. + * @return the identifier that was previously associated with the session. */ public abstract String getOldSessionId(); /** * Returns the new session ID. - * * @return the new identifier that is associated with the session. */ public abstract String getNewSessionId(); + } diff --git a/core/src/main/java/org/springframework/security/core/session/SessionInformation.java b/core/src/main/java/org/springframework/security/core/session/SessionInformation.java index fd7ae3ff73..54b05bbbb0 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionInformation.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionInformation.java @@ -16,12 +16,12 @@ package org.springframework.security.core.session; +import java.io.Serializable; +import java.util.Date; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.util.Date; -import java.io.Serializable; - /** * Represents a record of a session within the Spring Security framework. *

      @@ -41,16 +41,13 @@ public class SessionInformation implements Serializable { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private Date lastRequest; - private final Object principal; - private final String sessionId; - private boolean expired = false; - // ~ Constructors - // =================================================================================================== + private final Object principal; + + private final String sessionId; + + private boolean expired = false; public SessionInformation(Object principal, String sessionId, Date lastRequest) { Assert.notNull(principal, "Principal required"); @@ -61,27 +58,24 @@ public class SessionInformation implements Serializable { this.lastRequest = lastRequest; } - // ~ Methods - // ======================================================================================================== - public void expireNow() { this.expired = true; } public Date getLastRequest() { - return lastRequest; + return this.lastRequest; } public Object getPrincipal() { - return principal; + return this.principal; } public String getSessionId() { - return sessionId; + return this.sessionId; } public boolean isExpired() { - return expired; + return this.expired; } /** @@ -90,4 +84,5 @@ public class SessionInformation implements Serializable { public void refreshLastRequest() { this.lastRequest = new Date(); } + } diff --git a/core/src/main/java/org/springframework/security/core/session/SessionRegistry.java b/core/src/main/java/org/springframework/security/core/session/SessionRegistry.java index 68f4ab054a..107dfaf54c 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionRegistry.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionRegistry.java @@ -24,12 +24,9 @@ import java.util.List; * @author Ben Alex */ public interface SessionRegistry { - // ~ Methods - // ======================================================================================================== /** * Obtains all the known principals in the SessionRegistry. - * * @return each of the unique principals, which can then be presented to * {@link #getAllSessions(Object, boolean)}. */ @@ -39,22 +36,17 @@ public interface SessionRegistry { * Obtains all the known sessions for the specified principal. Sessions that have been * destroyed are not returned. Sessions that have expired may be returned, depending * on the passed argument. - * * @param principal to locate sessions for (should never be null) * @param includeExpiredSessions if true, the returned sessions will also * include those that have expired for the principal - * * @return the matching sessions for this principal (should not return null). */ - List getAllSessions(Object principal, - boolean includeExpiredSessions); + List getAllSessions(Object principal, boolean includeExpiredSessions); /** * Obtains the session information for the specified sessionId. Even * expired sessions are returned (although destroyed sessions are never returned). - * * @param sessionId to lookup (should never be null) - * * @return the session information, or null if not found */ SessionInformation getSessionInformation(String sessionId); @@ -63,7 +55,6 @@ public interface SessionRegistry { * Updates the given sessionId so its last request time is equal to the * present date and time. Silently returns if the given sessionId cannot * be found or the session is marked to expire. - * * @param sessionId for which to update the date and time of the last request (should * never be null) */ @@ -72,7 +63,6 @@ public interface SessionRegistry { /** * Registers a new session for the specified principal. The newly registered session * will not be marked for expiration. - * * @param sessionId to associate with the principal (should never be null * ) * @param principal to associate with the session (should never be null) @@ -83,8 +73,8 @@ public interface SessionRegistry { * Deletes all the session information being maintained for the specified * sessionId. If the sessionId is not found, the method * gracefully returns. - * * @param sessionId to delete information for (should never be null) */ void removeSessionInformation(String sessionId); + } diff --git a/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java b/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java index fa00255902..73623f1bc4 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java @@ -16,16 +16,23 @@ package org.springframework.security.core.session; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.context.ApplicationListener; -import org.springframework.util.Assert; - -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArraySet; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.context.ApplicationListener; +import org.springframework.core.log.LogMessage; +import org.springframework.util.Assert; + /** * Default implementation of * {@link org.springframework.security.core.session.SessionRegistry SessionRegistry} which @@ -33,165 +40,132 @@ import java.util.concurrent.CopyOnWriteArraySet; * SessionDestroyedEvent}s published in the Spring application context. *

      * For this class to function correctly in a web application, it is important that you - * register an HttpSessionEventPublisher + * register an HttpSessionEventPublisher * in the web.xml file so that this class is notified of sessions that expire. * * @author Ben Alex * @author Luke Taylor */ -public class SessionRegistryImpl implements SessionRegistry, - ApplicationListener { - - // ~ Instance fields - // ================================================================================================ +public class SessionRegistryImpl implements SessionRegistry, ApplicationListener { protected final Log logger = LogFactory.getLog(SessionRegistryImpl.class); - /** */ + // private final ConcurrentMap> principals; - /** */ - private final Map sessionIds; - // ~ Methods - // ======================================================================================================== + // + private final Map sessionIds; public SessionRegistryImpl() { this.principals = new ConcurrentHashMap<>(); this.sessionIds = new ConcurrentHashMap<>(); } - public SessionRegistryImpl(ConcurrentMap> principals, Map sessionIds) { - this.principals=principals; - this.sessionIds=sessionIds; + public SessionRegistryImpl(ConcurrentMap> principals, + Map sessionIds) { + this.principals = principals; + this.sessionIds = sessionIds; } + @Override public List getAllPrincipals() { - return new ArrayList<>(principals.keySet()); + return new ArrayList<>(this.principals.keySet()); } - public List getAllSessions(Object principal, - boolean includeExpiredSessions) { - final Set sessionsUsedByPrincipal = principals.get(principal); - + @Override + public List getAllSessions(Object principal, boolean includeExpiredSessions) { + Set sessionsUsedByPrincipal = this.principals.get(principal); if (sessionsUsedByPrincipal == null) { return Collections.emptyList(); } - - List list = new ArrayList<>( - sessionsUsedByPrincipal.size()); - + List list = new ArrayList<>(sessionsUsedByPrincipal.size()); for (String sessionId : sessionsUsedByPrincipal) { SessionInformation sessionInformation = getSessionInformation(sessionId); - if (sessionInformation == null) { continue; } - if (includeExpiredSessions || !sessionInformation.isExpired()) { list.add(sessionInformation); } } - return list; } + @Override public SessionInformation getSessionInformation(String sessionId) { Assert.hasText(sessionId, "SessionId required as per interface contract"); - - return sessionIds.get(sessionId); + return this.sessionIds.get(sessionId); } + @Override public void onApplicationEvent(AbstractSessionEvent event) { if (event instanceof SessionDestroyedEvent) { SessionDestroyedEvent sessionDestroyedEvent = (SessionDestroyedEvent) event; String sessionId = sessionDestroyedEvent.getId(); removeSessionInformation(sessionId); - } else if (event instanceof SessionIdChangedEvent) { + } + else if (event instanceof SessionIdChangedEvent) { SessionIdChangedEvent sessionIdChangedEvent = (SessionIdChangedEvent) event; String oldSessionId = sessionIdChangedEvent.getOldSessionId(); - Object principal = sessionIds.get(oldSessionId).getPrincipal(); + Object principal = this.sessionIds.get(oldSessionId).getPrincipal(); removeSessionInformation(oldSessionId); registerNewSession(sessionIdChangedEvent.getNewSessionId(), principal); } } + @Override public void refreshLastRequest(String sessionId) { Assert.hasText(sessionId, "SessionId required as per interface contract"); - SessionInformation info = getSessionInformation(sessionId); - if (info != null) { info.refreshLastRequest(); } } + @Override public void registerNewSession(String sessionId, Object principal) { Assert.hasText(sessionId, "SessionId required as per interface contract"); Assert.notNull(principal, "Principal required as per interface contract"); - if (getSessionInformation(sessionId) != null) { removeSessionInformation(sessionId); } - - if (logger.isDebugEnabled()) { - logger.debug("Registering session " + sessionId + ", for principal " - + principal); + if (this.logger.isDebugEnabled()) { + this.logger.debug(LogMessage.format("Registering session %s, for principal %s", sessionId, principal)); } - - sessionIds.put(sessionId, - new SessionInformation(principal, sessionId, new Date())); - - principals.compute(principal, (key, sessionsUsedByPrincipal) -> { + this.sessionIds.put(sessionId, new SessionInformation(principal, sessionId, new Date())); + this.principals.compute(principal, (key, sessionsUsedByPrincipal) -> { if (sessionsUsedByPrincipal == null) { sessionsUsedByPrincipal = new CopyOnWriteArraySet<>(); } sessionsUsedByPrincipal.add(sessionId); - - if (logger.isTraceEnabled()) { - logger.trace("Sessions used by '" + principal + "' : " - + sessionsUsedByPrincipal); - } + this.logger.trace(LogMessage.format("Sessions used by '%s' : %s", principal, sessionsUsedByPrincipal)); return sessionsUsedByPrincipal; }); } + @Override public void removeSessionInformation(String sessionId) { Assert.hasText(sessionId, "SessionId required as per interface contract"); - SessionInformation info = getSessionInformation(sessionId); - if (info == null) { return; } - - if (logger.isTraceEnabled()) { - logger.debug("Removing session " + sessionId - + " from set of registered sessions"); + if (this.logger.isTraceEnabled()) { + this.logger.debug("Removing session " + sessionId + " from set of registered sessions"); } - - sessionIds.remove(sessionId); - - principals.computeIfPresent(info.getPrincipal(), (key, sessionsUsedByPrincipal) -> { - if (logger.isDebugEnabled()) { - logger.debug("Removing session " + sessionId - + " from principal's set of registered sessions"); - } - + this.sessionIds.remove(sessionId); + this.principals.computeIfPresent(info.getPrincipal(), (key, sessionsUsedByPrincipal) -> { + this.logger.debug( + LogMessage.format("Removing session %s from principal's set of registered sessions", sessionId)); sessionsUsedByPrincipal.remove(sessionId); - if (sessionsUsedByPrincipal.isEmpty()) { // No need to keep object in principals Map anymore - if (logger.isDebugEnabled()) { - logger.debug("Removing principal " + info.getPrincipal() - + " from registry"); - } + this.logger.debug(LogMessage.format("Removing principal %s from registry", info.getPrincipal())); sessionsUsedByPrincipal = null; } - - if (logger.isTraceEnabled()) { - logger.trace("Sessions used by '" + info.getPrincipal() + "' : " - + sessionsUsedByPrincipal); - } + this.logger.trace( + LogMessage.format("Sessions used by '%s' : %s", info.getPrincipal(), sessionsUsedByPrincipal)); return sessionsUsedByPrincipal; }); } diff --git a/core/src/main/java/org/springframework/security/core/session/package-info.java b/core/src/main/java/org/springframework/security/core/session/package-info.java index 7b40d87320..bdda791db2 100644 --- a/core/src/main/java/org/springframework/security/core/session/package-info.java +++ b/core/src/main/java/org/springframework/security/core/session/package-info.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Session abstraction which is provided by the {@code org.springframework.security.core.session.SessionInformation - * SessionInformation} class. The {@link org.springframework.security.core.session.SessionRegistry SessionRegistry} - * is a core part of the web-based concurrent session control, but the code is not dependent on any of the servlet APIs. + * Session abstraction which is provided by the + * {@code org.springframework.security.core.session.SessionInformation + * SessionInformation} class. The + * {@link org.springframework.security.core.session.SessionRegistry SessionRegistry} is a + * core part of the web-based concurrent session control, but the code is not dependent on + * any of the servlet APIs. */ package org.springframework.security.core.session; - diff --git a/core/src/main/java/org/springframework/security/core/token/DefaultToken.java b/core/src/main/java/org/springframework/security/core/token/DefaultToken.java index f5db72c9d0..aaa89ffb6d 100644 --- a/core/src/main/java/org/springframework/security/core/token/DefaultToken.java +++ b/core/src/main/java/org/springframework/security/core/token/DefaultToken.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; import java.util.Date; @@ -26,8 +27,11 @@ import org.springframework.util.Assert; * @since 2.0.1 */ public class DefaultToken implements Token { + private final String key; + private final long keyCreationTime; + private final String extendedInformation; public DefaultToken(String key, long keyCreationTime, String extendedInformation) { @@ -40,25 +44,24 @@ public class DefaultToken implements Token { @Override public String getKey() { - return key; + return this.key; } @Override public long getKeyCreationTime() { - return keyCreationTime; + return this.keyCreationTime; } @Override public String getExtendedInformation() { - return extendedInformation; + return this.extendedInformation; } @Override public boolean equals(Object obj) { if (obj != null && obj instanceof DefaultToken) { DefaultToken rhs = (DefaultToken) obj; - return this.key.equals(rhs.key) - && this.keyCreationTime == rhs.keyCreationTime + return this.key.equals(rhs.key) && this.keyCreationTime == rhs.keyCreationTime && this.extendedInformation.equals(rhs.extendedInformation); } return false; @@ -67,16 +70,16 @@ public class DefaultToken implements Token { @Override public int hashCode() { int code = 979; - code = code * key.hashCode(); - code = code * new Long(keyCreationTime).hashCode(); - code = code * extendedInformation.hashCode(); + code = code * this.key.hashCode(); + code = code * new Long(this.keyCreationTime).hashCode(); + code = code * this.extendedInformation.hashCode(); return code; } @Override public String toString() { - return "DefaultToken[key=" + key + "; creation=" + new Date(keyCreationTime) - + "; extended=" + extendedInformation + "]"; + return "DefaultToken[key=" + this.key + "; creation=" + new Date(this.keyCreationTime) + "; extended=" + + this.extendedInformation + "]"; } } diff --git a/core/src/main/java/org/springframework/security/core/token/KeyBasedPersistenceTokenService.java b/core/src/main/java/org/springframework/security/core/token/KeyBasedPersistenceTokenService.java index b38b078665..cf3f0f0206 100644 --- a/core/src/main/java/org/springframework/security/core/token/KeyBasedPersistenceTokenService.java +++ b/core/src/main/java/org/springframework/security/core/token/KeyBasedPersistenceTokenService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; import java.security.SecureRandom; @@ -72,48 +73,49 @@ import org.springframework.util.StringUtils; * */ public class KeyBasedPersistenceTokenService implements TokenService, InitializingBean { + private int pseudoRandomNumberBytes = 32; + private String serverSecret; + private Integer serverInteger; + private SecureRandom secureRandom; + @Override public Token allocateToken(String extendedInformation) { - Assert.notNull(extendedInformation, - "Must provided non-null extendedInformation (but it can be empty)"); + Assert.notNull(extendedInformation, "Must provided non-null extendedInformation (but it can be empty)"); long creationTime = new Date().getTime(); String serverSecret = computeServerSecretApplicableAt(creationTime); String pseudoRandomNumber = generatePseudoRandomNumber(); - String content = creationTime + ":" + pseudoRandomNumber + ":" - + extendedInformation; - - // Compute key - String sha512Hex = Sha512DigestUtils.shaHex(content + ":" + serverSecret); - String keyPayload = content + ":" + sha512Hex; - String key = Utf8.decode(Base64.getEncoder().encode(Utf8.encode(keyPayload))); - + String content = creationTime + ":" + pseudoRandomNumber + ":" + extendedInformation; + String key = computeKey(serverSecret, content); return new DefaultToken(key, creationTime, extendedInformation); } + private String computeKey(String serverSecret, String content) { + String sha512Hex = Sha512DigestUtils.shaHex(content + ":" + serverSecret); + String keyPayload = content + ":" + sha512Hex; + return Utf8.decode(Base64.getEncoder().encode(Utf8.encode(keyPayload))); + } + + @Override public Token verifyToken(String key) { if (key == null || "".equals(key)) { return null; } - String[] tokens = StringUtils.delimitedListToStringArray( - Utf8.decode(Base64.getDecoder().decode(Utf8.encode(key))), ":"); - Assert.isTrue(tokens.length >= 4, () -> "Expected 4 or more tokens but found " - + tokens.length); - + String[] tokens = StringUtils + .delimitedListToStringArray(Utf8.decode(Base64.getDecoder().decode(Utf8.encode(key))), ":"); + Assert.isTrue(tokens.length >= 4, () -> "Expected 4 or more tokens but found " + tokens.length); long creationTime; try { creationTime = Long.decode(tokens[0]); } - catch (NumberFormatException nfe) { + catch (NumberFormatException ex) { throw new IllegalArgumentException("Expected number but found " + tokens[0]); } - String serverSecret = computeServerSecretApplicableAt(creationTime); String pseudoRandomNumber = tokens[1]; - // Permit extendedInfo to itself contain ":" characters StringBuilder extendedInfo = new StringBuilder(); for (int i = 2; i < tokens.length - 1; i++) { @@ -122,15 +124,11 @@ public class KeyBasedPersistenceTokenService implements TokenService, Initializi } extendedInfo.append(tokens[i]); } - String sha1Hex = tokens[tokens.length - 1]; - // Verification - String content = creationTime + ":" + pseudoRandomNumber + ":" - + extendedInfo.toString(); + String content = creationTime + ":" + pseudoRandomNumber + ":" + extendedInfo.toString(); String expectedSha512Hex = Sha512DigestUtils.shaHex(content + ":" + serverSecret); Assert.isTrue(expectedSha512Hex.equals(sha1Hex), "Key verification failure"); - return new DefaultToken(key, creationTime, extendedInfo.toString()); } @@ -138,13 +136,13 @@ public class KeyBasedPersistenceTokenService implements TokenService, Initializi * @return a pseduo random number (hex encoded) */ private String generatePseudoRandomNumber() { - byte[] randomBytes = new byte[pseudoRandomNumberBytes]; - secureRandom.nextBytes(randomBytes); + byte[] randomBytes = new byte[this.pseudoRandomNumberBytes]; + this.secureRandom.nextBytes(randomBytes); return new String(Hex.encode(randomBytes)); } private String computeServerSecretApplicableAt(long time) { - return serverSecret + ":" + new Long(time % serverInteger).intValue(); + return this.serverSecret + ":" + new Long(time % this.serverInteger).intValue(); } /** @@ -164,8 +162,7 @@ public class KeyBasedPersistenceTokenService implements TokenService, Initializi * defaults to 256) */ public void setPseudoRandomNumberBytes(int pseudoRandomNumberBytes) { - Assert.isTrue(pseudoRandomNumberBytes >= 0, - "Must have a positive pseudo random number bit size"); + Assert.isTrue(pseudoRandomNumberBytes >= 0, "Must have a positive pseudo random number bit size"); this.pseudoRandomNumberBytes = pseudoRandomNumberBytes; } @@ -173,9 +170,11 @@ public class KeyBasedPersistenceTokenService implements TokenService, Initializi this.serverInteger = serverInteger; } + @Override public void afterPropertiesSet() { - Assert.hasText(serverSecret, "Server secret required"); - Assert.notNull(serverInteger, "Server integer required"); - Assert.notNull(secureRandom, "SecureRandom instance required"); + Assert.hasText(this.serverSecret, "Server secret required"); + Assert.notNull(this.serverInteger, "Server integer required"); + Assert.notNull(this.secureRandom, "SecureRandom instance required"); } + } diff --git a/core/src/main/java/org/springframework/security/core/token/SecureRandomFactoryBean.java b/core/src/main/java/org/springframework/security/core/token/SecureRandomFactoryBean.java index 2d640de72a..2e73c5dcf7 100644 --- a/core/src/main/java/org/springframework/security/core/token/SecureRandomFactoryBean.java +++ b/core/src/main/java/org/springframework/security/core/token/SecureRandomFactoryBean.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; import java.io.InputStream; @@ -32,28 +33,29 @@ import org.springframework.util.FileCopyUtils; public class SecureRandomFactoryBean implements FactoryBean { private String algorithm = "SHA1PRNG"; + private Resource seed; + @Override public SecureRandom getObject() throws Exception { - SecureRandom rnd = SecureRandom.getInstance(algorithm); - + SecureRandom random = SecureRandom.getInstance(this.algorithm); // Request the next bytes, thus eagerly incurring the expense of default // seeding and to prevent the see from replacing the entire state - rnd.nextBytes(new byte[1]); - - if (seed != null) { + random.nextBytes(new byte[1]); + if (this.seed != null) { // Seed specified, so use it - byte[] seedBytes = FileCopyUtils.copyToByteArray(seed.getInputStream()); - rnd.setSeed(seedBytes); + byte[] seedBytes = FileCopyUtils.copyToByteArray(this.seed.getInputStream()); + random.setSeed(seedBytes); } - - return rnd; + return random; } + @Override public Class getObjectType() { return SecureRandom.class; } + @Override public boolean isSingleton() { return false; } @@ -61,7 +63,6 @@ public class SecureRandomFactoryBean implements FactoryBean { /** * Allows the Pseudo Random Number Generator (PRNG) algorithm to be nominated. * Defaults to "SHA1PRNG". - * * @param algorithm to use (mandatory) */ public void setAlgorithm(String algorithm) { @@ -76,10 +77,10 @@ public class SecureRandomFactoryBean implements FactoryBean { * {@link SecureRandom#setSeed(byte[])} method. Note that this will simply supplement, * rather than replace, the existing seed. As such, it is always safe to set a seed * using this method (it never reduces randomness). - * * @param seed to use, or null if no additional seeding is needed */ public void setSeed(Resource seed) { this.seed = seed; } + } diff --git a/core/src/main/java/org/springframework/security/core/token/Sha512DigestUtils.java b/core/src/main/java/org/springframework/security/core/token/Sha512DigestUtils.java index d1c354955b..3aaa7fc0ae 100644 --- a/core/src/main/java/org/springframework/security/core/token/Sha512DigestUtils.java +++ b/core/src/main/java/org/springframework/security/core/token/Sha512DigestUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; import java.security.MessageDigest; @@ -35,7 +36,6 @@ public abstract class Sha512DigestUtils { /** * Returns an SHA digest. - * * @return An SHA digest instance. * @throws RuntimeException when a {@link java.security.NoSuchAlgorithmException} is * caught. @@ -44,14 +44,13 @@ public abstract class Sha512DigestUtils { try { return MessageDigest.getInstance("SHA-512"); } - catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e.getMessage()); + catch (NoSuchAlgorithmException ex) { + throw new RuntimeException(ex.getMessage()); } } /** * Calculates the SHA digest and returns the value as a byte[]. - * * @param data Data to digest * @return SHA digest */ @@ -61,7 +60,6 @@ public abstract class Sha512DigestUtils { /** * Calculates the SHA digest and returns the value as a byte[]. - * * @param data Data to digest * @return SHA digest */ @@ -71,7 +69,6 @@ public abstract class Sha512DigestUtils { /** * Calculates the SHA digest and returns the value as a hex string. - * * @param data Data to digest * @return SHA digest as a hex string */ @@ -81,7 +78,6 @@ public abstract class Sha512DigestUtils { /** * Calculates the SHA digest and returns the value as a hex string. - * * @param data Data to digest * @return SHA digest as a hex string */ diff --git a/core/src/main/java/org/springframework/security/core/token/Token.java b/core/src/main/java/org/springframework/security/core/token/Token.java index 54f8a0ffb9..8ab94c0d13 100644 --- a/core/src/main/java/org/springframework/security/core/token/Token.java +++ b/core/src/main/java/org/springframework/security/core/token/Token.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; /** @@ -36,7 +37,6 @@ public interface Token { * Obtains the randomised, secure key assigned to this token. Presentation of this * token to {@link TokenService} will always return a Token that is equal * to the original Token issued for that key. - * * @return a key with appropriate randomness and security. */ String getKey(); @@ -45,7 +45,6 @@ public interface Token { * The time the token key was initially created is available from this method. Note * that a given token must never have this creation time changed. If necessary, a new * token can be requested from the {@link TokenService} to replace the original token. - * * @return the time this token key was created, in the same format as specified by * {@link java.util.Date#getTime()}. */ @@ -54,8 +53,8 @@ public interface Token { /** * Obtains the extended information associated within the token, which was presented * when the token was first created. - * * @return the user-specified extended information, if any */ String getExtendedInformation(); + } diff --git a/core/src/main/java/org/springframework/security/core/token/TokenService.java b/core/src/main/java/org/springframework/security/core/token/TokenService.java index b54af4d317..4c31ebc629 100644 --- a/core/src/main/java/org/springframework/security/core/token/TokenService.java +++ b/core/src/main/java/org/springframework/security/core/token/TokenService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; /** @@ -41,9 +42,9 @@ package org.springframework.security.core.token; * */ public interface TokenService { + /** * Forces the allocation of a new {@link Token}. - * * @param extendedInformation the extended information desired in the token (cannot be * null, but can be empty) * @return a new token that has not been issued previously, and is guaranteed to be @@ -55,11 +56,11 @@ public interface TokenService { /** * Permits verification the {@link Token#getKey()} was issued by this * TokenService and reconstructs the corresponding Token. - * * @param key as obtained from {@link Token#getKey()} and created by this * implementation * @return the token, or null if the token was not issued by this * TokenService */ Token verifyToken(String key); + } diff --git a/core/src/main/java/org/springframework/security/core/token/package-info.java b/core/src/main/java/org/springframework/security/core/token/package-info.java index b0cefaf120..87b72c1e23 100644 --- a/core/src/main/java/org/springframework/security/core/token/package-info.java +++ b/core/src/main/java/org/springframework/security/core/token/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * A service for building secure random tokens. */ package org.springframework.security.core.token; - diff --git a/core/src/main/java/org/springframework/security/core/userdetails/AuthenticationUserDetailsService.java b/core/src/main/java/org/springframework/security/core/userdetails/AuthenticationUserDetailsService.java index 7284ec5d93..7533051852 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/AuthenticationUserDetailsService.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/AuthenticationUserDetailsService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails; import org.springframework.security.core.Authentication; @@ -27,11 +28,11 @@ import org.springframework.security.core.Authentication; public interface AuthenticationUserDetailsService { /** - * * @param token The pre-authenticated authentication token * @return UserDetails for the given authentication token, never null. * @throws UsernameNotFoundException if no user details can be found for the given * authentication token */ UserDetails loadUserDetails(T token) throws UsernameNotFoundException; + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsService.java b/core/src/main/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsService.java index f0d8c12e56..10712776a7 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsService.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsService.java @@ -21,9 +21,10 @@ import java.util.Collection; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.util.Assert; import reactor.core.publisher.Mono; +import org.springframework.util.Assert; + /** * A {@link Map} based implementation of {@link ReactiveUserDetailsService} * @@ -31,6 +32,7 @@ import reactor.core.publisher.Mono; * @since 5.0 */ public class MapReactiveUserDetailsService implements ReactiveUserDetailsService, ReactiveUserDetailsPasswordService { + private final Map users; /** @@ -64,25 +66,32 @@ public class MapReactiveUserDetailsService implements ReactiveUserDetailsService @Override public Mono findByUsername(String username) { String key = getKey(username); - UserDetails result = users.get(key); - return result == null ? Mono.empty() : Mono.just(User.withUserDetails(result).build()); + UserDetails result = this.users.get(key); + return (result != null) ? Mono.just(User.withUserDetails(result).build()) : Mono.empty(); } @Override public Mono updatePassword(UserDetails user, String newPassword) { + // @formatter:off return Mono.just(user) - .map(u -> - User.withUserDetails(u) - .password(newPassword) - .build() - ) - .doOnNext(u -> { + .map((userDetails) -> withNewPassword(userDetails, newPassword)) + .doOnNext((userDetails) -> { String key = getKey(user.getUsername()); - this.users.put(key, u); + this.users.put(key, userDetails); }); + // @formatter:on + } + + private UserDetails withNewPassword(UserDetails userDetails, String newPassword) { + // @formatter:off + return User.withUserDetails(userDetails) + .password(newPassword) + .build(); + // @formatter:on } private String getKey(String username) { return username.toLowerCase(); } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsPasswordService.java b/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsPasswordService.java index 867c59758a..57e55bb2ce 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsPasswordService.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsPasswordService.java @@ -20,6 +20,7 @@ import reactor.core.publisher.Mono; /** * An API for changing a {@link UserDetails} password. + * * @author Rob Winch * @since 5.1 */ @@ -28,10 +29,10 @@ public interface ReactiveUserDetailsPasswordService { /** * Modify the specified user's password. This should change the user's password in the * persistent user repository (datbase, LDAP etc). - * * @param user the user to modify the password for * @param newPassword the password to change to * @return the updated UserDetails with the new password */ Mono updatePassword(UserDetails user, String newPassword); + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsService.java b/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsService.java index c66fbec0a6..63c9b6546d 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsService.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/ReactiveUserDetailsService.java @@ -32,4 +32,5 @@ public interface ReactiveUserDetailsService { * @return the {@link UserDetails}. Cannot be null */ Mono findByUsername(String username); + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/User.java b/core/src/main/java/org/springframework/security/core/userdetails/User.java index 05419a012e..e716f64bd7 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/User.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/User.java @@ -30,8 +30,9 @@ import java.util.function.Function; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.security.core.GrantedAuthority; + import org.springframework.security.core.CredentialsContainer; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; @@ -65,31 +66,30 @@ public class User implements UserDetails, CredentialsContainer { private static final Log logger = LogFactory.getLog(User.class); - // ~ Instance fields - // ================================================================================================ private String password; - private final String username; - private final Set authorities; - private final boolean accountNonExpired; - private final boolean accountNonLocked; - private final boolean credentialsNonExpired; - private final boolean enabled; - // ~ Constructors - // =================================================================================================== + private final String username; + + private final Set authorities; + + private final boolean accountNonExpired; + + private final boolean accountNonLocked; + + private final boolean credentialsNonExpired; + + private final boolean enabled; /** * Calls the more complex constructor with all boolean arguments set to {@code true}. */ - public User(String username, String password, - Collection authorities) { + public User(String username, String password, Collection authorities) { this(username, password, true, true, true, true, authorities); } /** * Construct the User with the details required by * {@link org.springframework.security.authentication.dao.DaoAuthenticationProvider}. - * * @param username the username presented to the * DaoAuthenticationProvider * @param password the password that should be presented to the @@ -101,19 +101,14 @@ public class User implements UserDetails, CredentialsContainer { * @param accountNonLocked set to true if the account is not locked * @param authorities the authorities that should be granted to the caller if they * presented the correct username and password and the user is enabled. Not null. - * * @throws IllegalArgumentException if a null value was passed either as * a parameter or as an element in the GrantedAuthority collection */ - public User(String username, String password, boolean enabled, - boolean accountNonExpired, boolean credentialsNonExpired, - boolean accountNonLocked, Collection authorities) { - - if (((username == null) || "".equals(username)) || (password == null)) { - throw new IllegalArgumentException( - "Cannot pass null or empty values to constructor"); - } - + public User(String username, String password, boolean enabled, boolean accountNonExpired, + boolean credentialsNonExpired, boolean accountNonLocked, + Collection authorities) { + Assert.isTrue(username != null && !"".equals(username) && password != null, + "Cannot pass null or empty values to constructor"); this.username = username; this.password = password; this.enabled = enabled; @@ -123,79 +118,58 @@ public class User implements UserDetails, CredentialsContainer { this.authorities = Collections.unmodifiableSet(sortAuthorities(authorities)); } - // ~ Methods - // ======================================================================================================== - + @Override public Collection getAuthorities() { - return authorities; + return this.authorities; } + @Override public String getPassword() { - return password; + return this.password; } + @Override public String getUsername() { - return username; + return this.username; } + @Override public boolean isEnabled() { - return enabled; + return this.enabled; } + @Override public boolean isAccountNonExpired() { - return accountNonExpired; + return this.accountNonExpired; } + @Override public boolean isAccountNonLocked() { - return accountNonLocked; + return this.accountNonLocked; } + @Override public boolean isCredentialsNonExpired() { - return credentialsNonExpired; + return this.credentialsNonExpired; } + @Override public void eraseCredentials() { - password = null; + this.password = null; } - private static SortedSet sortAuthorities( - Collection authorities) { + private static SortedSet sortAuthorities(Collection authorities) { Assert.notNull(authorities, "Cannot pass a null GrantedAuthority collection"); // Ensure array iteration order is predictable (as per // UserDetails.getAuthorities() contract and SEC-717) - SortedSet sortedAuthorities = new TreeSet<>( - new AuthorityComparator()); - + SortedSet sortedAuthorities = new TreeSet<>(new AuthorityComparator()); for (GrantedAuthority grantedAuthority : authorities) { - Assert.notNull(grantedAuthority, - "GrantedAuthority list cannot contain any null elements"); + Assert.notNull(grantedAuthority, "GrantedAuthority list cannot contain any null elements"); sortedAuthorities.add(grantedAuthority); } - return sortedAuthorities; } - private static class AuthorityComparator implements Comparator, - Serializable { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - - public int compare(GrantedAuthority g1, GrantedAuthority g2) { - // Neither should ever be null as each entry is checked before adding it to - // the set. - // If the authority is null, it is a custom authority and should precede - // others. - if (g2.getAuthority() == null) { - return -1; - } - - if (g1.getAuthority() == null) { - return 1; - } - - return g1.getAuthority().compareTo(g2.getAuthority()); - } - } - /** * Returns {@code true} if the supplied object is a {@code User} instance with the * same {@code username} value. @@ -204,9 +178,9 @@ public class User implements UserDetails, CredentialsContainer { * the same principal. */ @Override - public boolean equals(Object rhs) { - if (rhs instanceof User) { - return username.equals(((User) rhs).username); + public boolean equals(Object obj) { + if (obj instanceof User) { + return this.username.equals(((User) obj).username); } return false; } @@ -216,7 +190,7 @@ public class User implements UserDetails, CredentialsContainer { */ @Override public int hashCode() { - return username.hashCode(); + return this.username.hashCode(); } @Override @@ -227,33 +201,27 @@ public class User implements UserDetails, CredentialsContainer { sb.append("Password: [PROTECTED]; "); sb.append("Enabled: ").append(this.enabled).append("; "); sb.append("AccountNonExpired: ").append(this.accountNonExpired).append("; "); - sb.append("credentialsNonExpired: ").append(this.credentialsNonExpired) - .append("; "); + sb.append("credentialsNonExpired: ").append(this.credentialsNonExpired).append("; "); sb.append("AccountNonLocked: ").append(this.accountNonLocked).append("; "); - - if (!authorities.isEmpty()) { + if (!this.authorities.isEmpty()) { sb.append("Granted Authorities: "); - boolean first = true; - for (GrantedAuthority auth : authorities) { + for (GrantedAuthority auth : this.authorities) { if (!first) { sb.append(","); } first = false; - sb.append(auth); } } else { sb.append("Not granted any authorities"); } - return sb.toString(); } /** * Creates a UserBuilder with a specified user name - * * @param username the username to use * @return the UserBuilder */ @@ -263,7 +231,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Creates a UserBuilder - * * @return the UserBuilder */ public static UserBuilder builder() { @@ -272,8 +239,8 @@ public class User implements UserDetails, CredentialsContainer { /** *

      - * WARNING: This method is considered unsafe for production and is only intended - * for sample applications. + * WARNING: This method is considered unsafe for production and is only + * intended for sample applications. *

      *

      * Creates a user and automatically encodes the provided password using @@ -289,8 +256,7 @@ public class User implements UserDetails, CredentialsContainer { * .build(); * // outputs {bcrypt}$2a$10$dXJ3SW6G7P50lGmMkkmwe.20cQQubK3.HZWzG3YB1tlRy.fqvM/BG * System.out.println(user.getPassword()); - * - * + * * * This is not safe for production (it is intended for getting started experience) * because the password "password" is compiled into the source code and then is @@ -300,8 +266,8 @@ public class User implements UserDetails, CredentialsContainer { * securely hashed. This means if the UserDetails password is accidentally exposed, * the password is securely stored. * - * In a production setting, it is recommended to hash the password ahead of time. - * For example: + * In a production setting, it is recommended to hash the password ahead of time. For + * example: * *

       	 * 
      @@ -309,8 +275,7 @@ public class User implements UserDetails, CredentialsContainer {
       	 * // outputs {bcrypt}$2a$10$dXJ3SW6G7P50lGmMkkmwe.20cQQubK3.HZWzG3YB1tlRy.fqvM/BG
       	 * // remember the password that is printed out and use in the next step
       	 * System.out.println(encoder.encode("password"));
      -	 * 
      -	 * 
      + * * *
       	 * 
      @@ -318,47 +283,76 @@ public class User implements UserDetails, CredentialsContainer {
       	 *     .password("{bcrypt}$2a$10$dXJ3SW6G7P50lGmMkkmwe.20cQQubK3.HZWzG3YB1tlRy.fqvM/BG")
       	 *     .roles("USER")
       	 *     .build();
      -	 * 
      -	 * 
      - * + * * @return a UserBuilder that automatically encodes the password with the default * PasswordEncoder * @deprecated Using this method is not considered safe for production, but is * acceptable for demos and getting started. For production purposes, ensure the * password is encoded externally. See the method Javadoc for additional details. - * There are no plans to remove this support. It is deprecated to indicate - * that this is considered insecure for production purposes. + * There are no plans to remove this support. It is deprecated to indicate that this + * is considered insecure for production purposes. */ @Deprecated public static UserBuilder withDefaultPasswordEncoder() { - logger.warn("User.withDefaultPasswordEncoder() is considered unsafe for production and is only intended for sample applications."); + logger.warn("User.withDefaultPasswordEncoder() is considered unsafe for production " + + "and is only intended for sample applications."); PasswordEncoder encoder = PasswordEncoderFactories.createDelegatingPasswordEncoder(); return builder().passwordEncoder(encoder::encode); } public static UserBuilder withUserDetails(UserDetails userDetails) { + // @formatter:off return withUsername(userDetails.getUsername()) - .password(userDetails.getPassword()) - .accountExpired(!userDetails.isAccountNonExpired()) - .accountLocked(!userDetails.isAccountNonLocked()) - .authorities(userDetails.getAuthorities()) - .credentialsExpired(!userDetails.isCredentialsNonExpired()) - .disabled(!userDetails.isEnabled()); + .password(userDetails.getPassword()) + .accountExpired(!userDetails.isAccountNonExpired()) + .accountLocked(!userDetails.isAccountNonLocked()) + .authorities(userDetails.getAuthorities()) + .credentialsExpired(!userDetails.isCredentialsNonExpired()) + .disabled(!userDetails.isEnabled()); + // @formatter:on + } + + private static class AuthorityComparator implements Comparator, Serializable { + + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + + @Override + public int compare(GrantedAuthority g1, GrantedAuthority g2) { + // Neither should ever be null as each entry is checked before adding it to + // the set. If the authority is null, it is a custom authority and should + // precede others. + if (g2.getAuthority() == null) { + return -1; + } + if (g1.getAuthority() == null) { + return 1; + } + return g1.getAuthority().compareTo(g2.getAuthority()); + } + } /** * Builds the user to be added. At minimum the username, password, and authorities * should provided. The remaining attributes have reasonable defaults. */ - public static class UserBuilder { + public static final class UserBuilder { + private String username; + private String password; + private List authorities; + private boolean accountExpired; + private boolean accountLocked; + private boolean credentialsExpired; + private boolean disabled; - private Function passwordEncoder = password -> password; + + private Function passwordEncoder = (password) -> password; /** * Creates a new instance @@ -368,7 +362,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Populates the username. This attribute is required. - * * @param username the username. Cannot be null. * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -381,7 +374,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Populates the password. This attribute is required. - * * @param password the password. Cannot be null. * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -393,9 +385,8 @@ public class User implements UserDetails, CredentialsContainer { } /** - * Encodes the current password (if non-null) and any future passwords supplied - * to {@link #password(String)}. - * + * Encodes the current password (if non-null) and any future passwords supplied to + * {@link #password(String)}. * @param encoder the encoder to use * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -425,18 +416,16 @@ public class User implements UserDetails, CredentialsContainer { * This attribute is required, but can also be populated with * {@link #authorities(String...)}. *

      - * * @param roles the roles for this user (i.e. USER, ADMIN, etc). Cannot be null, * contain null values or start with "ROLE_" * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) */ public UserBuilder roles(String... roles) { - List authorities = new ArrayList<>( - roles.length); + List authorities = new ArrayList<>(roles.length); for (String role : roles) { - Assert.isTrue(!role.startsWith("ROLE_"), () -> role - + " cannot start with ROLE_ (it is automatically added)"); + Assert.isTrue(!role.startsWith("ROLE_"), + () -> role + " cannot start with ROLE_ (it is automatically added)"); authorities.add(new SimpleGrantedAuthority("ROLE_" + role)); } return authorities(authorities); @@ -444,7 +433,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Populates the authorities. This attribute is required. - * * @param authorities the authorities for this user. Cannot be null, or contain * null values * @return the {@link UserBuilder} for method chaining (i.e. to populate @@ -457,7 +445,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Populates the authorities. This attribute is required. - * * @param authorities the authorities for this user. Cannot be null, or contain * null values * @return the {@link UserBuilder} for method chaining (i.e. to populate @@ -471,7 +458,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Populates the authorities. This attribute is required. - * * @param authorities the authorities for this user (i.e. ROLE_USER, ROLE_ADMIN, * etc). Cannot be null, or contain null values * @return the {@link UserBuilder} for method chaining (i.e. to populate @@ -484,7 +470,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Defines if the account is expired or not. Default is false. - * * @param accountExpired true if the account is expired, false otherwise * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -496,7 +481,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Defines if the account is locked or not. Default is false. - * * @param accountLocked true if the account is locked, false otherwise * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -508,7 +492,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Defines if the credentials are expired or not. Default is false. - * * @param credentialsExpired true if the credentials are expired, false otherwise * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -520,7 +503,6 @@ public class User implements UserDetails, CredentialsContainer { /** * Defines if the account is disabled or not. Default is false. - * * @param disabled true if the account is disabled, false otherwise * @return the {@link UserBuilder} for method chaining (i.e. to populate * additional attributes for this user) @@ -531,9 +513,11 @@ public class User implements UserDetails, CredentialsContainer { } public UserDetails build() { - String encodedPassword = this.passwordEncoder.apply(password); - return new User(username, encodedPassword, !disabled, !accountExpired, - !credentialsExpired, !accountLocked, authorities); + String encodedPassword = this.passwordEncoder.apply(this.password); + return new User(this.username, encodedPassword, !this.disabled, !this.accountExpired, + !this.credentialsExpired, !this.accountLocked, this.authorities); } + } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UserCache.java b/core/src/main/java/org/springframework/security/core/userdetails/UserCache.java index 2f250e6ff6..d352a396b1 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/UserCache.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UserCache.java @@ -32,19 +32,14 @@ package org.springframework.security.core.userdetails; * configure a cache to store the UserDetails information rather than loading it * each time. * - * @see org.springframework.security.authentication.dao.AbstractUserDetailsAuthenticationProvider - * * @author Ben Alex + * @see org.springframework.security.authentication.dao.AbstractUserDetailsAuthenticationProvider */ public interface UserCache { - // ~ Methods - // ======================================================================================================== /** * Obtains a {@link UserDetails} from the cache. - * * @param username the {@link User#getUsername()} used to place the user in the cache - * * @return the populated UserDetails or null if the user * could not be found or if the cache entry has expired */ @@ -53,7 +48,6 @@ public interface UserCache { /** * Places a {@link UserDetails} in the cache. The username is the key * used to subsequently retrieve the UserDetails. - * * @param user the fully populated UserDetails to place in the cache */ void putUserInCache(UserDetails user); @@ -66,8 +60,8 @@ public interface UserCache { * Some cache implementations may not support eviction from the cache, in which case * they should provide appropriate behaviour to alter the user in either its * documentation, via an exception, or through a log message. - * * @param username to be evicted from the cache */ void removeUserFromCache(String username); + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UserDetails.java b/core/src/main/java/org/springframework/security/core/userdetails/UserDetails.java index d1f83da9ef..664725631e 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/UserDetails.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UserDetails.java @@ -16,12 +16,12 @@ package org.springframework.security.core.userdetails; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; - import java.io.Serializable; import java.util.Collection; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; + /** * Provides core user information. * @@ -36,32 +36,27 @@ import java.util.Collection; * {@link org.springframework.security.core.userdetails.User} for a reference * implementation (which you might like to extend or use in your code). * + * @author Ben Alex * @see UserDetailsService * @see UserCache - * - * @author Ben Alex */ public interface UserDetails extends Serializable { - // ~ Methods - // ======================================================================================================== /** * Returns the authorities granted to the user. Cannot return null. - * * @return the authorities, sorted by natural key (never null) */ Collection getAuthorities(); /** * Returns the password used to authenticate the user. - * * @return the password */ String getPassword(); /** - * Returns the username used to authenticate the user. Cannot return null. - * + * Returns the username used to authenticate the user. Cannot return + * null. * @return the username (never null) */ String getUsername(); @@ -69,7 +64,6 @@ public interface UserDetails extends Serializable { /** * Indicates whether the user's account has expired. An expired account cannot be * authenticated. - * * @return true if the user's account is valid (ie non-expired), * false if no longer valid (ie expired) */ @@ -78,7 +72,6 @@ public interface UserDetails extends Serializable { /** * Indicates whether the user is locked or unlocked. A locked user cannot be * authenticated. - * * @return true if the user is not locked, false otherwise */ boolean isAccountNonLocked(); @@ -86,7 +79,6 @@ public interface UserDetails extends Serializable { /** * Indicates whether the user's credentials (password) has expired. Expired * credentials prevent authentication. - * * @return true if the user's credentials are valid (ie non-expired), * false if no longer valid (ie expired) */ @@ -95,8 +87,8 @@ public interface UserDetails extends Serializable { /** * Indicates whether the user is enabled or disabled. A disabled user cannot be * authenticated. - * * @return true if the user is enabled, false otherwise */ boolean isEnabled(); + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapper.java b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapper.java index daea7fa063..f116f69104 100755 --- a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapper.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapper.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails; import org.springframework.beans.factory.InitializingBean; @@ -28,8 +29,9 @@ import org.springframework.util.Assert; * @author Scott Battaglia * @since 2.0 */ -public class UserDetailsByNameServiceWrapper implements - AuthenticationUserDetailsService, InitializingBean { +public class UserDetailsByNameServiceWrapper + implements AuthenticationUserDetailsService, InitializingBean { + private UserDetailsService userDetailsService = null; /** @@ -44,7 +46,6 @@ public class UserDetailsByNameServiceWrapper implement * Constructs a new wrapper using the supplied * {@link org.springframework.security.core.userdetails.UserDetailsService} as the * service to delegate to. - * * @param userDetailsService the UserDetailsService to delegate to. */ public UserDetailsByNameServiceWrapper(final UserDetailsService userDetailsService) { @@ -57,6 +58,7 @@ public class UserDetailsByNameServiceWrapper implement * * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet() */ + @Override public void afterPropertiesSet() { Assert.notNull(this.userDetailsService, "UserDetailsService must be set"); } @@ -64,16 +66,17 @@ public class UserDetailsByNameServiceWrapper implement /** * Get the UserDetails object from the wrapped UserDetailsService implementation */ + @Override public UserDetails loadUserDetails(T authentication) throws UsernameNotFoundException { return this.userDetailsService.loadUserByUsername(authentication.getName()); } /** * Set the wrapped UserDetailsService implementation - * * @param aUserDetailsService The wrapped UserDetailsService to set */ public void setUserDetailsService(UserDetailsService aUserDetailsService) { this.userDetailsService = aUserDetailsService; } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsChecker.java b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsChecker.java index 706022b283..a136b13ddc 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsChecker.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsChecker.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails; /** @@ -28,14 +29,15 @@ package org.springframework.security.core.userdetails; * * @author Luke Taylor * @since 2.0 - * * @see org.springframework.security.authentication.AccountStatusUserDetailsChecker * @see org.springframework.security.authentication.AccountStatusException */ public interface UserDetailsChecker { + /** * Examines the User * @param toCheck the UserDetails instance whose status should be checked. */ void check(UserDetails toCheck); + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsPasswordService.java b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsPasswordService.java index fd56992793..ed85e088aa 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsPasswordService.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsPasswordService.java @@ -18,6 +18,7 @@ package org.springframework.security.core.userdetails; /** * An API for changing a {@link UserDetails} password. + * * @author Rob Winch * @since 5.1 */ @@ -26,11 +27,11 @@ public interface UserDetailsPasswordService { /** * Modify the specified user's password. This should change the user's password in the * persistent user repository (database, LDAP etc). - * * @param user the user to modify the password for - * @param newPassword the password to change to, - * encoded by the configured {@code PasswordEncoder} + * @param newPassword the password to change to, encoded by the configured + * {@code PasswordEncoder} * @return the updated UserDetails with the new password */ UserDetails updatePassword(UserDetails user, String newPassword); + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsService.java b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsService.java index 82f5631e35..22ac216297 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsService.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UserDetailsService.java @@ -27,14 +27,11 @@ package org.springframework.security.core.userdetails; * The interface requires only one read-only method, which simplifies support for new * data-access strategies. * + * @author Ben Alex * @see org.springframework.security.authentication.dao.DaoAuthenticationProvider * @see UserDetails - * - * @author Ben Alex */ public interface UserDetailsService { - // ~ Methods - // ======================================================================================================== /** * Locates the user based on the username. In the actual implementation, the search @@ -42,13 +39,11 @@ public interface UserDetailsService { * implementation instance is configured. In this case, the UserDetails * object that comes back may have a username that is of a different case than what * was actually requested.. - * * @param username the username identifying the user whose data is required. - * * @return a fully populated user record (never null) - * * @throws UsernameNotFoundException if the user could not be found or the user has no * GrantedAuthority */ UserDetails loadUserByUsername(String username) throws UsernameNotFoundException; + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/UsernameNotFoundException.java b/core/src/main/java/org/springframework/security/core/userdetails/UsernameNotFoundException.java index af4e5f539a..22c3c1d8e5 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/UsernameNotFoundException.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/UsernameNotFoundException.java @@ -25,12 +25,9 @@ import org.springframework.security.core.AuthenticationException; * @author Ben Alex */ public class UsernameNotFoundException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== /** * Constructs a UsernameNotFoundException with the specified message. - * * @param msg the detail message. */ public UsernameNotFoundException(String msg) { @@ -40,11 +37,11 @@ public class UsernameNotFoundException extends AuthenticationException { /** * Constructs a {@code UsernameNotFoundException} with the specified message and root * cause. - * * @param msg the detail message. - * @param t root cause + * @param cause root cause */ - public UsernameNotFoundException(String msg, Throwable t) { - super(msg, t); + public UsernameNotFoundException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCache.java b/core/src/main/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCache.java index dacd1c69cf..201199ffeb 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCache.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCache.java @@ -18,80 +18,62 @@ package org.springframework.security.core.userdetails.cache; import net.sf.ehcache.Ehcache; import net.sf.ehcache.Element; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.beans.factory.InitializingBean; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.userdetails.UserCache; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.util.Assert; /** - * Caches User objects using a Spring IoC defined EHCACHE. + * Caches User objects using a Spring IoC defined + * EHCACHE. * * @author Ben Alex */ public class EhCacheBasedUserCache implements UserCache, InitializingBean { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(EhCacheBasedUserCache.class); - // ~ Instance fields - // ================================================================================================ - private Ehcache cache; - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.notNull(cache, "cache mandatory"); + Assert.notNull(this.cache, "cache mandatory"); } public Ehcache getCache() { - return cache; + return this.cache; } + @Override public UserDetails getUserFromCache(String username) { - Element element = cache.get(username); - - if (logger.isDebugEnabled()) { - logger.debug("Cache hit: " + (element != null) + "; username: " + username); - } - - if (element == null) { - return null; - } - else { - return (UserDetails) element.getValue(); - } + Element element = this.cache.get(username); + logger.debug(LogMessage.of(() -> "Cache hit: " + (element != null) + "; username: " + username)); + return (element != null) ? (UserDetails) element.getValue() : null; } + @Override public void putUserInCache(UserDetails user) { Element element = new Element(user.getUsername(), user); - - if (logger.isDebugEnabled()) { - logger.debug("Cache put: " + element.getKey()); - } - - cache.put(element); + logger.debug(LogMessage.of(() -> "Cache put: " + element.getKey())); + this.cache.put(element); } public void removeUserFromCache(UserDetails user) { - if (logger.isDebugEnabled()) { - logger.debug("Cache remove: " + user.getUsername()); - } - + logger.debug(LogMessage.of(() -> "Cache remove: " + user.getUsername())); this.removeUserFromCache(user.getUsername()); } + @Override public void removeUserFromCache(String username) { - cache.remove(username); + this.cache.remove(username); } public void setCache(Ehcache cache) { this.cache = cache; } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/cache/NullUserCache.java b/core/src/main/java/org/springframework/security/core/userdetails/cache/NullUserCache.java index 79f40369e5..d831685bd1 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/cache/NullUserCache.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/cache/NullUserCache.java @@ -25,16 +25,18 @@ import org.springframework.security.core.userdetails.UserDetails; * @author Ben Alex */ public class NullUserCache implements UserCache { - // ~ Methods - // ======================================================================================================== + @Override public UserDetails getUserFromCache(String username) { return null; } + @Override public void putUserInCache(UserDetails user) { } + @Override public void removeUserFromCache(String username) { } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCache.java b/core/src/main/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCache.java index 8150bd9023..f0cae216e5 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCache.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCache.java @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails.cache; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.cache.Cache; +import org.springframework.core.log.LogMessage; import org.springframework.security.core.userdetails.UserCache; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.util.Assert; @@ -30,58 +33,36 @@ import org.springframework.util.Assert; */ public class SpringCacheBasedUserCache implements UserCache { - // ~ Static fields/initializers - // ===================================================================================== - private static final Log logger = LogFactory.getLog(SpringCacheBasedUserCache.class); - // ~ Instance fields - // ================================================================================================ - private final Cache cache; - // ~ Constructors - // =================================================================================================== - public SpringCacheBasedUserCache(Cache cache) { Assert.notNull(cache, "cache mandatory"); this.cache = cache; } - // ~ Methods - // ======================================================================================================== - + @Override public UserDetails getUserFromCache(String username) { - Cache.ValueWrapper element = username != null ? cache.get(username) : null; - - if (logger.isDebugEnabled()) { - logger.debug("Cache hit: " + (element != null) + "; username: " + username); - } - - if (element == null) { - return null; - } - else { - return (UserDetails) element.get(); - } + Cache.ValueWrapper element = (username != null) ? this.cache.get(username) : null; + logger.debug(LogMessage.of(() -> "Cache hit: " + (element != null) + "; username: " + username)); + return (element != null) ? (UserDetails) element.get() : null; } + @Override public void putUserInCache(UserDetails user) { - if (logger.isDebugEnabled()) { - logger.debug("Cache put: " + user.getUsername()); - } - cache.put(user.getUsername(), user); + logger.debug(LogMessage.of(() -> "Cache put: " + user.getUsername())); + this.cache.put(user.getUsername(), user); } public void removeUserFromCache(UserDetails user) { - if (logger.isDebugEnabled()) { - logger.debug("Cache remove: " + user.getUsername()); - } - + logger.debug(LogMessage.of(() -> "Cache remove: " + user.getUsername())); this.removeUserFromCache(user.getUsername()); } + @Override public void removeUserFromCache(String username) { - cache.evict(username); + this.cache.evict(username); } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/cache/package-info.java b/core/src/main/java/org/springframework/security/core/userdetails/cache/package-info.java index 44ba285f8e..3ce5d7e93b 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/cache/package-info.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/cache/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Implementations of {@link org.springframework.security.core.userdetails.UserCache UserCache}. + * Implementations of {@link org.springframework.security.core.userdetails.UserCache + * UserCache}. */ package org.springframework.security.core.userdetails.cache; - diff --git a/core/src/main/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImpl.java b/core/src/main/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImpl.java index ab271a6799..a76d809bf7 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImpl.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImpl.java @@ -25,6 +25,7 @@ import org.springframework.context.ApplicationContextException; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.support.JdbcDaoSupport; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityMessageSource; @@ -107,35 +108,41 @@ import org.springframework.util.Assert; * @author colin sampaleanu * @author Luke Taylor */ -public class JdbcDaoImpl extends JdbcDaoSupport - implements UserDetailsService, MessageSourceAware { - // ~ Static fields/initializers - // ===================================================================================== +public class JdbcDaoImpl extends JdbcDaoSupport implements UserDetailsService, MessageSourceAware { + // @formatter:off public static final String DEF_USERS_BY_USERNAME_QUERY = "select username,password,enabled " - + "from users " + "where username = ?"; + + "from users " + + "where username = ?"; + // @formatter:on + + // @formatter:off public static final String DEF_AUTHORITIES_BY_USERNAME_QUERY = "select username,authority " - + "from authorities " + "where username = ?"; + + "from authorities " + + "where username = ?"; + // @formatter:on + + // @formatter:off public static final String DEF_GROUP_AUTHORITIES_BY_USERNAME_QUERY = "select g.id, g.group_name, ga.authority " + "from groups g, group_members gm, group_authorities ga " - + "where gm.username = ? " + "and g.id = ga.group_id " - + "and g.id = gm.group_id"; - - // ~ Instance fields - // ================================================================================================ + + "where gm.username = ? " + "and g.id = ga.group_id " + "and g.id = gm.group_id"; + // @formatter:on protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); private String authoritiesByUsernameQuery; - private String groupAuthoritiesByUsernameQuery; - private String usersByUsernameQuery; - private String rolePrefix = ""; - private boolean usernameBasedPrimaryKey = true; - private boolean enableAuthorities = true; - private boolean enableGroups; - // ~ Constructors - // =================================================================================================== + private String groupAuthoritiesByUsernameQuery; + + private String usersByUsernameQuery; + + private String rolePrefix = ""; + + private boolean usernameBasedPrimaryKey = true; + + private boolean enableAuthorities = true; + + private boolean enableGroups; public JdbcDaoImpl() { this.usersByUsernameQuery = DEF_USERS_BY_USERNAME_QUERY; @@ -143,9 +150,6 @@ public class JdbcDaoImpl extends JdbcDaoSupport this.groupAuthoritiesByUsernameQuery = DEF_GROUP_AUTHORITIES_BY_USERNAME_QUERY; } - // ~ Methods - // ======================================================================================================== - /** * @return the messages */ @@ -156,13 +160,11 @@ public class JdbcDaoImpl extends JdbcDaoSupport /** * Allows subclasses to add their own granted authorities to the list to be returned * in the UserDetails. - * * @param username the username, for use by finder methods * @param authorities the current granted authorities, as populated from the * authoritiesByUsername mapping */ - protected void addCustomAuthorities(String username, - List authorities) { + protected void addCustomAuthorities(String username, List authorities) { } public String getUsersByUsernameQuery() { @@ -176,43 +178,28 @@ public class JdbcDaoImpl extends JdbcDaoSupport } @Override - public UserDetails loadUserByUsername(String username) - throws UsernameNotFoundException { + public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException { List users = loadUsersByUsername(username); - if (users.size() == 0) { this.logger.debug("Query returned no results for user '" + username + "'"); - - throw new UsernameNotFoundException( - this.messages.getMessage("JdbcDaoImpl.notFound", - new Object[] { username }, "Username {0} not found")); + throw new UsernameNotFoundException(this.messages.getMessage("JdbcDaoImpl.notFound", + new Object[] { username }, "Username {0} not found")); } - UserDetails user = users.get(0); // contains no GrantedAuthority[] - Set dbAuthsSet = new HashSet<>(); - if (this.enableAuthorities) { dbAuthsSet.addAll(loadUserAuthorities(user.getUsername())); } - if (this.enableGroups) { dbAuthsSet.addAll(loadGroupAuthorities(user.getUsername())); } - List dbAuths = new ArrayList<>(dbAuthsSet); - addCustomAuthorities(user.getUsername(), dbAuths); - if (dbAuths.size() == 0) { - this.logger.debug("User '" + username - + "' has no authorities and will be treated as 'not found'"); - - throw new UsernameNotFoundException(this.messages.getMessage( - "JdbcDaoImpl.noAuthority", new Object[] { username }, - "User {0} has no GrantedAuthority")); + this.logger.debug("User '" + username + "' has no authorities and will be treated as 'not found'"); + throw new UsernameNotFoundException(this.messages.getMessage("JdbcDaoImpl.noAuthority", + new Object[] { username }, "User {0} has no GrantedAuthority")); } - return createUserDetails(username, user, dbAuths); } @@ -221,41 +208,37 @@ public class JdbcDaoImpl extends JdbcDaoSupport * objects. There should normally only be one matching user. */ protected List loadUsersByUsername(String username) { - return getJdbcTemplate().query(this.usersByUsernameQuery, - new String[] { username }, (rs, rowNum) -> { - String username1 = rs.getString(1); - String password = rs.getString(2); - boolean enabled = rs.getBoolean(3); - return new User(username1, password, enabled, true, true, true, - AuthorityUtils.NO_AUTHORITIES); - }); + // @formatter:off + RowMapper mapper = (rs, rowNum) -> { + String username1 = rs.getString(1); + String password = rs.getString(2); + boolean enabled = rs.getBoolean(3); + return new User(username1, password, enabled, true, true, true, AuthorityUtils.NO_AUTHORITIES); + }; + // @formatter:on + return getJdbcTemplate().query(this.usersByUsernameQuery, mapper, username); } /** * Loads authorities by executing the SQL from authoritiesByUsernameQuery. - * * @return a list of GrantedAuthority objects for the user */ protected List loadUserAuthorities(String username) { - return getJdbcTemplate().query(this.authoritiesByUsernameQuery, - new String[] { username }, (rs, rowNum) -> { - String roleName = JdbcDaoImpl.this.rolePrefix + rs.getString(2); - - return new SimpleGrantedAuthority(roleName); - }); + return getJdbcTemplate().query(this.authoritiesByUsernameQuery, new String[] { username }, (rs, rowNum) -> { + String roleName = JdbcDaoImpl.this.rolePrefix + rs.getString(2); + return new SimpleGrantedAuthority(roleName); + }); } /** * Loads authorities by executing the SQL from * groupAuthoritiesByUsernameQuery. - * * @return a list of GrantedAuthority objects for the user */ protected List loadGroupAuthorities(String username) { - return getJdbcTemplate().query(this.groupAuthoritiesByUsernameQuery, - new String[] { username }, (rs, rowNum) -> { + return getJdbcTemplate().query(this.groupAuthoritiesByUsernameQuery, new String[] { username }, + (rs, rowNum) -> { String roleName = getRolePrefix() + rs.getString(3); - return new SimpleGrantedAuthority(roleName); }); } @@ -263,33 +246,29 @@ public class JdbcDaoImpl extends JdbcDaoSupport /** * Can be overridden to customize the creation of the final UserDetailsObject which is * returned by the loadUserByUsername method. - * * @param username the name originally passed to loadUserByUsername * @param userFromUserQuery the object returned from the execution of the * @param combinedAuthorities the combined array of authorities from all the authority * loading queries. * @return the final UserDetails which should be used in the system. */ - protected UserDetails createUserDetails(String username, - UserDetails userFromUserQuery, List combinedAuthorities) { + protected UserDetails createUserDetails(String username, UserDetails userFromUserQuery, + List combinedAuthorities) { String returnUsername = userFromUserQuery.getUsername(); - if (!this.usernameBasedPrimaryKey) { returnUsername = username; } - - return new User(returnUsername, userFromUserQuery.getPassword(), - userFromUserQuery.isEnabled(), userFromUserQuery.isAccountNonExpired(), - userFromUserQuery.isCredentialsNonExpired(), userFromUserQuery.isAccountNonLocked(), combinedAuthorities); + return new User(returnUsername, userFromUserQuery.getPassword(), userFromUserQuery.isEnabled(), + userFromUserQuery.isAccountNonExpired(), userFromUserQuery.isCredentialsNonExpired(), + userFromUserQuery.isAccountNonLocked(), combinedAuthorities); } /** * Allows the default query string used to retrieve authorities based on username to * be overridden, if default table or column names need to be changed. The default * query is {@link #DEF_AUTHORITIES_BY_USERNAME_QUERY}; when modifying this query, - * ensure that all returned columns are mapped back to the same column positions as in the - * default query. - * + * ensure that all returned columns are mapped back to the same column positions as in + * the default query. * @param queryString The SQL query string to set */ public void setAuthoritiesByUsernameQuery(String queryString) { @@ -306,7 +285,6 @@ public class JdbcDaoImpl extends JdbcDaoSupport * default query is {@link #DEF_GROUP_AUTHORITIES_BY_USERNAME_QUERY}; when modifying * this query, ensure that all returned columns are mapped back to the same column * positions as in the default query. - * * @param queryString The SQL query string to set */ public void setGroupAuthoritiesByUsernameQuery(String queryString) { @@ -319,7 +297,6 @@ public class JdbcDaoImpl extends JdbcDaoSupport * example be used to add the ROLE_ prefix expected to exist in role names * (by default) by some other Spring Security classes, in the case that the prefix is * not already present in the db. - * * @param rolePrefix the new prefix */ public void setRolePrefix(String rolePrefix) { @@ -338,7 +315,6 @@ public class JdbcDaoImpl extends JdbcDaoSupport * UserDetails. If false, the class will use the * {@link #loadUserByUsername(String)} derived username in the returned * UserDetails. - * * @param usernameBasedPrimaryKey true if the mapping queries return the * username String, or false if the mapping returns a * database primary key. @@ -355,14 +331,13 @@ public class JdbcDaoImpl extends JdbcDaoSupport * Allows the default query string used to retrieve users based on username to be * overridden, if default table or column names need to be changed. The default query * is {@link #DEF_USERS_BY_USERNAME_QUERY}; when modifying this query, ensure that all - * returned columns are mapped back to the same column positions as in the default query. - * If the 'enabled' column does not exist in the source database, a permanent true - * value for this column may be returned by using a query similar to + * returned columns are mapped back to the same column positions as in the default + * query. If the 'enabled' column does not exist in the source database, a permanent + * true value for this column may be returned by using a query similar to * *
       	 * "select username,password,'true' as enabled from users where username = ?"
       	 * 
      - * * @param usersByUsernameQueryString The query string to set */ public void setUsersByUsernameQuery(String usersByUsernameQueryString) { @@ -397,4 +372,5 @@ public class JdbcDaoImpl extends JdbcDaoSupport Assert.notNull(messageSource, "messageSource cannot be null"); this.messages = new MessageSourceAccessor(messageSource); } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/jdbc/package-info.java b/core/src/main/java/org/springframework/security/core/userdetails/jdbc/package-info.java index 75238a665c..3c9399128f 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/jdbc/package-info.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/jdbc/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Exposes a JDBC-based authentication repository, implementing * {@code org.springframework.security.core.userdetails.UserDetailsService UserDetailsService}. */ package org.springframework.security.core.userdetails.jdbc; - diff --git a/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttribute.java b/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttribute.java index 0499931865..8e3643a223 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttribute.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttribute.java @@ -24,21 +24,18 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; /** - * Used by {@link org.springframework.security.provisioning.InMemoryUserDetailsManager} to temporarily store the attributes associated with a - * user. + * Used by {@link org.springframework.security.provisioning.InMemoryUserDetailsManager} to + * temporarily store the attributes associated with a user. * * @author Ben Alex */ public class UserAttribute { - // ~ Instance fields - // ================================================================================================ private List authorities = new Vector<>(); - private String password; - private boolean enabled = true; - // ~ Methods - // ======================================================================================================== + private String password; + + private boolean enabled = true; public void addAuthority(GrantedAuthority newAuthority) { this.authorities.add(newAuthority); @@ -50,7 +47,6 @@ public class UserAttribute { /** * Set all authorities for this user. - * * @param authorities {@link List} <{@link GrantedAuthority}> * @since 1.1 */ @@ -61,7 +57,6 @@ public class UserAttribute { /** * Set all authorities for this user from String values. It will create the necessary * {@link GrantedAuthority} objects. - * * @param authoritiesAsStrings {@link List} <{@link String}> * @since 1.1 */ @@ -73,20 +68,15 @@ public class UserAttribute { } public String getPassword() { - return password; + return this.password; } public boolean isEnabled() { - return enabled; + return this.enabled; } public boolean isValid() { - if ((this.password != null) && (authorities.size() > 0)) { - return true; - } - else { - return false; - } + return (this.password != null) && (this.authorities.size() > 0); } public void setEnabled(boolean enabled) { @@ -96,4 +86,5 @@ public class UserAttribute { public void setPassword(String password) { this.password = password; } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttributeEditor.java b/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttributeEditor.java index ab4248a767..09f10c18d0 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttributeEditor.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/memory/UserAttributeEditor.java @@ -29,45 +29,35 @@ import org.springframework.util.StringUtils; * @author Ben Alex */ public class UserAttributeEditor extends PropertyEditorSupport { - // ~ Methods - // ======================================================================================================== + @Override public void setAsText(String s) throws IllegalArgumentException { - if (StringUtils.hasText(s)) { - String[] tokens = StringUtils.commaDelimitedListToStringArray(s); - UserAttribute userAttrib = new UserAttribute(); - - List authoritiesAsStrings = new ArrayList<>(); - - for (int i = 0; i < tokens.length; i++) { - String currentToken = tokens[i].trim(); - - if (i == 0) { - userAttrib.setPassword(currentToken); - } - else { - if (currentToken.toLowerCase().equals("enabled")) { - userAttrib.setEnabled(true); - } - else if (currentToken.toLowerCase().equals("disabled")) { - userAttrib.setEnabled(false); - } - else { - authoritiesAsStrings.add(currentToken); - } - } - } - userAttrib.setAuthoritiesAsString(authoritiesAsStrings); - - if (userAttrib.isValid()) { - setValue(userAttrib); + if (!StringUtils.hasText(s)) { + setValue(null); + return; + } + String[] tokens = StringUtils.commaDelimitedListToStringArray(s); + UserAttribute userAttrib = new UserAttribute(); + List authoritiesAsStrings = new ArrayList<>(); + for (int i = 0; i < tokens.length; i++) { + String currentToken = tokens[i].trim(); + if (i == 0) { + userAttrib.setPassword(currentToken); } else { - setValue(null); + if (currentToken.toLowerCase().equals("enabled")) { + userAttrib.setEnabled(true); + } + else if (currentToken.toLowerCase().equals("disabled")) { + userAttrib.setEnabled(false); + } + else { + authoritiesAsStrings.add(currentToken); + } } } - else { - setValue(null); - } + userAttrib.setAuthoritiesAsString(authoritiesAsStrings); + setValue(userAttrib.isValid() ? userAttrib : null); } + } diff --git a/core/src/main/java/org/springframework/security/core/userdetails/memory/package-info.java b/core/src/main/java/org/springframework/security/core/userdetails/memory/package-info.java index cc7550c568..67b7f1ccbc 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/memory/package-info.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/memory/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Exposes an in-memory authentication repository. */ package org.springframework.security.core.userdetails.memory; - diff --git a/core/src/main/java/org/springframework/security/core/userdetails/package-info.java b/core/src/main/java/org/springframework/security/core/userdetails/package-info.java index 1f68a69eec..e18dc5218b 100644 --- a/core/src/main/java/org/springframework/security/core/userdetails/package-info.java +++ b/core/src/main/java/org/springframework/security/core/userdetails/package-info.java @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * The standard interfaces for implementing user data DAOs. *

      - * Can be the traditional {@link org.springframework.security.core.userdetails.UserDetailsService UserDetailsService} - * which uses a unique username to identify the user or, for more complex requirements, the - * {@link org.springframework.security.core.userdetails.AuthenticationUserDetailsService AuthenticationUserDetailsService}. + * Can be the traditional + * {@link org.springframework.security.core.userdetails.UserDetailsService + * UserDetailsService} which uses a unique username to identify the user or, for more + * complex requirements, the + * {@link org.springframework.security.core.userdetails.AuthenticationUserDetailsService + * AuthenticationUserDetailsService}. */ package org.springframework.security.core.userdetails; - diff --git a/core/src/main/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixin.java b/core/src/main/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixin.java index 02ad8851f4..3d8aa80232 100644 --- a/core/src/main/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixin.java +++ b/core/src/main/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixin.java @@ -16,17 +16,23 @@ package org.springframework.security.jackson2; -import com.fasterxml.jackson.annotation.*; -import org.springframework.security.core.GrantedAuthority; - import java.util.Collection; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import org.springframework.security.core.GrantedAuthority; + /** * This is a Jackson mixin class helps in serialize/deserialize - * {@link org.springframework.security.authentication.AnonymousAuthenticationToken} class. To use this class you need to register it - * with {@link com.fasterxml.jackson.databind.ObjectMapper} and {@link SimpleGrantedAuthorityMixin} because - * AnonymousAuthenticationToken contains SimpleGrantedAuthority. - *

      + * {@link org.springframework.security.authentication.AnonymousAuthenticationToken} class.
      + * To use this class you need to register it with
      + * {@link com.fasterxml.jackson.databind.ObjectMapper} and
      + * {@link SimpleGrantedAuthorityMixin} because AnonymousAuthenticationToken contains
      + * SimpleGrantedAuthority. 
        *     ObjectMapper mapper = new ObjectMapper();
        *     mapper.registerModule(new CoreJackson2Module());
        * 
      @@ -45,15 +51,17 @@ import java.util.Collection; class AnonymousAuthenticationTokenMixin { /** - * Constructor used by Jackson to create object of {@link org.springframework.security.authentication.AnonymousAuthenticationToken}. - * + * Constructor used by Jackson to create object of + * {@link org.springframework.security.authentication.AnonymousAuthenticationToken}. * @param keyHash hashCode of key provided at the time of token creation by using - * {@link org.springframework.security.authentication.AnonymousAuthenticationToken#AnonymousAuthenticationToken(String, Object, Collection)} + * {@link org.springframework.security.authentication.AnonymousAuthenticationToken#AnonymousAuthenticationToken(String, Object, Collection)} * @param principal the principal (typically a UserDetails) * @param authorities the authorities granted to the principal */ @JsonCreator - AnonymousAuthenticationTokenMixin(@JsonProperty("keyHash") Integer keyHash, @JsonProperty("principal") Object principal, - @JsonProperty("authorities") Collection authorities) { + AnonymousAuthenticationTokenMixin(@JsonProperty("keyHash") Integer keyHash, + @JsonProperty("principal") Object principal, + @JsonProperty("authorities") Collection authorities) { } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/BadCredentialsExceptionMixin.java b/core/src/main/java/org/springframework/security/jackson2/BadCredentialsExceptionMixin.java index 1c47950d7c..08287ae8dd 100644 --- a/core/src/main/java/org/springframework/security/jackson2/BadCredentialsExceptionMixin.java +++ b/core/src/main/java/org/springframework/security/jackson2/BadCredentialsExceptionMixin.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.jackson2; import com.fasterxml.jackson.annotation.JsonCreator; @@ -22,31 +23,33 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo; /** * This mixin class helps in serialize/deserialize - * {@link org.springframework.security.authentication.BadCredentialsException} class. To use this class you need to - * register it with {@link com.fasterxml.jackson.databind.ObjectMapper}. + * {@link org.springframework.security.authentication.BadCredentialsException} class. To + * use this class you need to register it with + * {@link com.fasterxml.jackson.databind.ObjectMapper}. * *
        *     ObjectMapper mapper = new ObjectMapper();
        *     mapper.registerModule(new CoreJackson2Module());
        * 
      * - * Note: This class will save TypeInfo (full class name) into a property called @class - * The cause and stackTrace are ignored in the serialization. + * Note: This class will save TypeInfo (full class name) into a property + * called @class The cause and stackTrace are ignored in the serialization. * * @author Yannick Lombardi * @see CoreJackson2Module * @since 5.0 */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) -@JsonIgnoreProperties(ignoreUnknown = true, value = {"cause", "stackTrace"}) +@JsonIgnoreProperties(ignoreUnknown = true, value = { "cause", "stackTrace" }) class BadCredentialsExceptionMixin { /** * Constructor used by Jackson to create * {@link org.springframework.security.authentication.BadCredentialsException} object. - * * @param message the detail message */ @JsonCreator - BadCredentialsExceptionMixin(@JsonProperty("message") String message) {} + BadCredentialsExceptionMixin(@JsonProperty("message") String message) { + } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/CoreJackson2Module.java b/core/src/main/java/org/springframework/security/jackson2/CoreJackson2Module.java index 8e50b2fff3..4621479ac4 100644 --- a/core/src/main/java/org/springframework/security/jackson2/CoreJackson2Module.java +++ b/core/src/main/java/org/springframework/security/jackson2/CoreJackson2Module.java @@ -16,8 +16,11 @@ package org.springframework.security.jackson2; +import java.util.Collections; + import com.fasterxml.jackson.core.Version; import com.fasterxml.jackson.databind.module.SimpleModule; + import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.RememberMeAuthenticationToken; @@ -25,20 +28,20 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.User; -import java.util.Collections; - /** - * Jackson module for spring-security-core. This module register {@link AnonymousAuthenticationTokenMixin}, - * {@link RememberMeAuthenticationTokenMixin}, {@link SimpleGrantedAuthorityMixin}, {@link UnmodifiableSetMixin}, - * {@link UserMixin} and {@link UsernamePasswordAuthenticationTokenMixin}. If no default typing enabled by default then - * it'll enable it because typing info is needed to properly serialize/deserialize objects. In order to use this module just - * add this module into your ObjectMapper configuration. + * Jackson module for spring-security-core. This module register + * {@link AnonymousAuthenticationTokenMixin}, {@link RememberMeAuthenticationTokenMixin}, + * {@link SimpleGrantedAuthorityMixin}, {@link UnmodifiableSetMixin}, {@link UserMixin} + * and {@link UsernamePasswordAuthenticationTokenMixin}. If no default typing enabled by + * default then it'll enable it because typing info is needed to properly + * serialize/deserialize objects. In order to use this module just add this module into + * your ObjectMapper configuration. * *
        *     ObjectMapper mapper = new ObjectMapper();
        *     mapper.registerModule(new CoreJackson2Module());
      - * 
      - * Note: use {@link SecurityJackson2Modules#getModules(ClassLoader)} to get list of all security modules. + *
      Note: use {@link SecurityJackson2Modules#getModules(ClassLoader)} to get list + * of all security modules. * * @author Jitendra Singh. * @see SecurityJackson2Modules @@ -57,10 +60,14 @@ public class CoreJackson2Module extends SimpleModule { context.setMixInAnnotations(AnonymousAuthenticationToken.class, AnonymousAuthenticationTokenMixin.class); context.setMixInAnnotations(RememberMeAuthenticationToken.class, RememberMeAuthenticationTokenMixin.class); context.setMixInAnnotations(SimpleGrantedAuthority.class, SimpleGrantedAuthorityMixin.class); - context.setMixInAnnotations(Collections.unmodifiableSet(Collections.emptySet()).getClass(), UnmodifiableSetMixin.class); - context.setMixInAnnotations(Collections.unmodifiableList(Collections.emptyList()).getClass(), UnmodifiableListMixin.class); + context.setMixInAnnotations(Collections.unmodifiableSet(Collections.emptySet()).getClass(), + UnmodifiableSetMixin.class); + context.setMixInAnnotations(Collections.unmodifiableList(Collections.emptyList()).getClass(), + UnmodifiableListMixin.class); context.setMixInAnnotations(User.class, UserMixin.class); - context.setMixInAnnotations(UsernamePasswordAuthenticationToken.class, UsernamePasswordAuthenticationTokenMixin.class); + context.setMixInAnnotations(UsernamePasswordAuthenticationToken.class, + UsernamePasswordAuthenticationTokenMixin.class); context.setMixInAnnotations(BadCredentialsException.class, BadCredentialsExceptionMixin.class); } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixin.java b/core/src/main/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixin.java index 3274a2c779..24d82cc0a2 100644 --- a/core/src/main/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixin.java +++ b/core/src/main/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixin.java @@ -16,20 +16,26 @@ package org.springframework.security.jackson2; -import com.fasterxml.jackson.annotation.*; -import org.springframework.security.core.GrantedAuthority; - import java.util.Collection; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import org.springframework.security.core.GrantedAuthority; + /** * This mixin class helps in serialize/deserialize - * {@link org.springframework.security.authentication.RememberMeAuthenticationToken} class. To use this class you need to register it - * with {@link com.fasterxml.jackson.databind.ObjectMapper} and 2 more mixin classes. + * {@link org.springframework.security.authentication.RememberMeAuthenticationToken} + * class. To use this class you need to register it with + * {@link com.fasterxml.jackson.databind.ObjectMapper} and 2 more mixin classes. * *
        - *
      1. {@link SimpleGrantedAuthorityMixin}
      2. - *
      3. {@link UserMixin}
      4. - *
      5. {@link UnmodifiableSetMixin}
      6. + *
      7. {@link SimpleGrantedAuthorityMixin}
      8. + *
      9. {@link UserMixin}
      10. + *
      11. {@link UnmodifiableSetMixin}
      12. *
      * *
      @@ -37,7 +43,8 @@ import java.util.Collection;
        *     mapper.registerModule(new CoreJackson2Module());
        * 
      * - * Note: This class will save TypeInfo (full class name) into a property called @class + * Note: This class will save TypeInfo (full class name) into a property + * called @class * * @author Jitendra Singh * @see CoreJackson2Module @@ -52,15 +59,16 @@ class RememberMeAuthenticationTokenMixin { /** * Constructor used by Jackson to create - * {@link org.springframework.security.authentication.RememberMeAuthenticationToken} object. - * + * {@link org.springframework.security.authentication.RememberMeAuthenticationToken} + * object. * @param keyHash hashCode of above given key. * @param principal the principal (typically a UserDetails) * @param authorities the authorities granted to the principal */ @JsonCreator RememberMeAuthenticationTokenMixin(@JsonProperty("keyHash") Integer keyHash, - @JsonProperty("principal") Object principal, - @JsonProperty("authorities") Collection authorities) { + @JsonProperty("principal") Object principal, + @JsonProperty("authorities") Collection authorities) { } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/SecurityJackson2Modules.java b/core/src/main/java/org/springframework/security/jackson2/SecurityJackson2Modules.java index 011ee95b1a..febd2b755c 100644 --- a/core/src/main/java/org/springframework/security/jackson2/SecurityJackson2Modules.java +++ b/core/src/main/java/org/springframework/security/jackson2/SecurityJackson2Modules.java @@ -16,6 +16,15 @@ package org.springframework.security.jackson2; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + import com.fasterxml.jackson.annotation.JacksonAnnotation; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.DatabindContext; @@ -31,17 +40,10 @@ import com.fasterxml.jackson.databind.jsontype.TypeIdResolver; import com.fasterxml.jackson.databind.jsontype.TypeResolverBuilder; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.util.ClassUtils; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.log.LogMessage; +import org.springframework.util.ClassUtils; /** * This utility class will find all the SecurityModules in classpath. @@ -50,8 +52,7 @@ import java.util.Set; *
        *     ObjectMapper mapper = new ObjectMapper();
        *     mapper.registerModules(SecurityJackson2Modules.getModules());
      - * 
      - * Above code is equivalent to + * Above code is equivalent to *

      *

        *     ObjectMapper mapper = new ObjectMapper();
      @@ -70,18 +71,18 @@ import java.util.Set;
       public final class SecurityJackson2Modules {
       
       	private static final Log logger = LogFactory.getLog(SecurityJackson2Modules.class);
      +
       	private static final List securityJackson2ModuleClasses = Arrays.asList(
       			"org.springframework.security.jackson2.CoreJackson2Module",
       			"org.springframework.security.cas.jackson2.CasJackson2Module",
       			"org.springframework.security.web.jackson2.WebJackson2Module",
      -			"org.springframework.security.web.server.jackson2.WebServerJackson2Module"
      -	);
      -	private static final String webServletJackson2ModuleClass =
      -			"org.springframework.security.web.jackson2.WebServletJackson2Module";
      -	private static final String oauth2ClientJackson2ModuleClass =
      -			"org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module";
      -	private static final String javaTimeJackson2ModuleClass =
      -			"com.fasterxml.jackson.datatype.jsr310.JavaTimeModule";
      +			"org.springframework.security.web.server.jackson2.WebServerJackson2Module");
      +
      +	private static final String webServletJackson2ModuleClass = "org.springframework.security.web.jackson2.WebServletJackson2Module";
      +
      +	private static final String oauth2ClientJackson2ModuleClass = "org.springframework.security.oauth2.client.jackson2.OAuth2ClientJackson2Module";
      +
      +	private static final String javaTimeJackson2ModuleClass = "com.fasterxml.jackson.datatype.jsr310.JavaTimeModule";
       
       	private SecurityJackson2Modules() {
       	}
      @@ -97,21 +98,17 @@ public final class SecurityJackson2Modules {
       
       	@SuppressWarnings("unchecked")
       	private static Module loadAndGetInstance(String className, ClassLoader loader) {
      -		Module instance = null;
       		try {
       			Class securityModule = (Class) ClassUtils.forName(className, loader);
       			if (securityModule != null) {
      -				if (logger.isDebugEnabled()) {
      -					logger.debug("Loaded module " + className + ", now registering");
      -				}
      -				instance = securityModule.newInstance();
      -			}
      -		} catch (Exception e) {
      -			if (logger.isDebugEnabled()) {
      -				logger.debug("Cannot load module " + className, e);
      +				logger.debug(LogMessage.format("Loaded module %s, now registering", className));
      +				return securityModule.newInstance();
       			}
       		}
      -		return instance;
      +		catch (Exception ex) {
      +			logger.debug(LogMessage.format("Cannot load module %s", className), ex);
      +		}
      +		return null;
       	}
       
       	/**
      @@ -136,8 +133,8 @@ public final class SecurityJackson2Modules {
       	}
       
       	/**
      -	 * @param loader    the ClassLoader to use
      -	 * @param modules   list of the modules to add
      +	 * @param loader the ClassLoader to use
      +	 * @param modules list of the modules to add
       	 * @param className name of the class to instantiate
       	 */
       	private static void addToModulesList(ClassLoader loader, List modules, String className) {
      @@ -152,60 +149,63 @@ public final class SecurityJackson2Modules {
       	 * @return a TypeResolverBuilder that restricts allowed types.
       	 */
       	private static TypeResolverBuilder createAllowlistedDefaultTyping() {
      -		TypeResolverBuilder  result = new AllowlistTypeResolverBuilder(ObjectMapper.DefaultTyping.NON_FINAL);
      +		TypeResolverBuilder result = new AllowlistTypeResolverBuilder(
      +				ObjectMapper.DefaultTyping.NON_FINAL);
       		result = result.init(JsonTypeInfo.Id.CLASS, null);
       		result = result.inclusion(JsonTypeInfo.As.PROPERTY);
       		return result;
       	}
       
       	/**
      -	 * An implementation of {@link ObjectMapper.DefaultTypeResolverBuilder}
      -	 * that inserts an {@code allow all} {@link PolymorphicTypeValidator}
      -	 * and overrides the {@code TypeIdResolver}
      +	 * An implementation of {@link ObjectMapper.DefaultTypeResolverBuilder} that inserts
      +	 * an {@code allow all} {@link PolymorphicTypeValidator} and overrides the
      +	 * {@code TypeIdResolver}
      +	 *
       	 * @author Rob Winch
       	 */
       	static class AllowlistTypeResolverBuilder extends ObjectMapper.DefaultTypeResolverBuilder {
       
       		AllowlistTypeResolverBuilder(ObjectMapper.DefaultTyping defaultTyping) {
      -			super(
      -					defaultTyping,
      -					//we do explicit validation in the TypeIdResolver
      -					BasicPolymorphicTypeValidator.builder()
      -							.allowIfSubType(Object.class)
      -							.build()
      -			);
      +			super(defaultTyping,
      +					// we do explicit validation in the TypeIdResolver
      +					BasicPolymorphicTypeValidator.builder().allowIfSubType(Object.class).build());
       		}
       
       		@Override
      -		protected TypeIdResolver idResolver(MapperConfig config,
      -				JavaType baseType,
      -				PolymorphicTypeValidator subtypeValidator,
      -				Collection subtypes, boolean forSer, boolean forDeser) {
      +		protected TypeIdResolver idResolver(MapperConfig config, JavaType baseType,
      +				PolymorphicTypeValidator subtypeValidator, Collection subtypes, boolean forSer,
      +				boolean forDeser) {
       			TypeIdResolver result = super.idResolver(config, baseType, subtypeValidator, subtypes, forSer, forDeser);
       			return new AllowlistTypeIdResolver(result);
       		}
      +
       	}
       
       	/**
      -	 * A {@link TypeIdResolver} that delegates to an existing implementation and throws an IllegalStateException if the
      -	 * class being looked up is not in the allowlist, does not provide an explicit mixin, and is not annotated with Jackson
      -	 * mappings. See https://github.com/spring-projects/spring-security/issues/4370
      +	 * A {@link TypeIdResolver} that delegates to an existing implementation and throws an
      +	 * IllegalStateException if the class being looked up is not in the allowlist, does
      +	 * not provide an explicit mixin, and is not annotated with Jackson mappings. See
      +	 * https://github.com/spring-projects/spring-security/issues/4370
       	 */
       	static class AllowlistTypeIdResolver implements TypeIdResolver {
      -		private static final Set ALLOWLIST_CLASS_NAMES = Collections.unmodifiableSet(new HashSet(Arrays.asList(
      -			"java.util.ArrayList",
      -			"java.util.Collections$EmptyList",
      -			"java.util.Collections$EmptyMap",
      -			"java.util.Collections$UnmodifiableRandomAccessList",
      -			"java.util.Collections$SingletonList",
      -			"java.util.Date",
      -			"java.time.Instant",
      -			"java.net.URL",
      -			"java.util.TreeMap",
      -			"java.util.HashMap",
      -			"java.util.LinkedHashMap",
      -			"org.springframework.security.core.context.SecurityContextImpl"
      -		)));
      +
      +		private static final Set ALLOWLIST_CLASS_NAMES;
      +		static {
      +			Set names = new HashSet<>();
      +			names.add("java.util.ArrayList");
      +			names.add("java.util.Collections$EmptyList");
      +			names.add("java.util.Collections$EmptyMap");
      +			names.add("java.util.Collections$UnmodifiableRandomAccessList");
      +			names.add("java.util.Collections$SingletonList");
      +			names.add("java.util.Date");
      +			names.add("java.time.Instant");
      +			names.add("java.net.URL");
      +			names.add("java.util.TreeMap");
      +			names.add("java.util.HashMap");
      +			names.add("java.util.LinkedHashMap");
      +			names.add("org.springframework.security.core.context.SecurityContextImpl");
      +			ALLOWLIST_CLASS_NAMES = Collections.unmodifiableSet(names);
      +		}
       
       		private final TypeIdResolver delegate;
       
      @@ -215,28 +215,28 @@ public final class SecurityJackson2Modules {
       
       		@Override
       		public void init(JavaType baseType) {
      -			delegate.init(baseType);
      +			this.delegate.init(baseType);
       		}
       
       		@Override
       		public String idFromValue(Object value) {
      -			return delegate.idFromValue(value);
      +			return this.delegate.idFromValue(value);
       		}
       
       		@Override
       		public String idFromValueAndType(Object value, Class suggestedType) {
      -			return delegate.idFromValueAndType(value, suggestedType);
      +			return this.delegate.idFromValueAndType(value, suggestedType);
       		}
       
       		@Override
       		public String idFromBaseType() {
      -			return delegate.idFromBaseType();
      +			return this.delegate.idFromBaseType();
       		}
       
       		@Override
       		public JavaType typeFromId(DatabindContext context, String id) throws IOException {
       			DeserializationConfig config = (DeserializationConfig) context.getConfig();
      -			JavaType result = delegate.typeFromId(context, id);
      +			JavaType result = this.delegate.typeFromId(context, id);
       			String className = result.getRawClass().getName();
       			if (isInAllowlist(className)) {
       				return result;
      @@ -245,14 +245,16 @@ public final class SecurityJackson2Modules {
       			if (isExplicitMixin) {
       				return result;
       			}
      -			JacksonAnnotation jacksonAnnotation = AnnotationUtils.findAnnotation(result.getRawClass(), JacksonAnnotation.class);
      +			JacksonAnnotation jacksonAnnotation = AnnotationUtils.findAnnotation(result.getRawClass(),
      +					JacksonAnnotation.class);
       			if (jacksonAnnotation != null) {
       				return result;
       			}
      -			throw new IllegalArgumentException("The class with " + id + " and name of " + className + " is not in the allowlist. " +
      -				"If you believe this class is safe to deserialize, please provide an explicit mapping using Jackson annotations or by providing a Mixin. " +
      -				"If the serialization is only done by a trusted source, you can also enable default typing. " +
      -				"See https://github.com/spring-projects/spring-security/issues/4370 for details");
      +			throw new IllegalArgumentException("The class with " + id + " and name of " + className
      +					+ " is not in the allowlist. "
      +					+ "If you believe this class is safe to deserialize, please provide an explicit mapping using Jackson annotations or by providing a Mixin. "
      +					+ "If the serialization is only done by a trusted source, you can also enable default typing. "
      +					+ "See https://github.com/spring-projects/spring-security/issues/4370 for details");
       		}
       
       		private boolean isInAllowlist(String id) {
      @@ -261,13 +263,14 @@ public final class SecurityJackson2Modules {
       
       		@Override
       		public String getDescForKnownTypeIds() {
      -			return delegate.getDescForKnownTypeIds();
      +			return this.delegate.getDescForKnownTypeIds();
       		}
       
       		@Override
       		public JsonTypeInfo.Id getMechanism() {
      -			return delegate.getMechanism();
      +			return this.delegate.getMechanism();
       		}
       
       	}
      +
       }
      diff --git a/core/src/main/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixin.java b/core/src/main/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixin.java
      index 4fa8558dc3..4c3beadcc5 100644
      --- a/core/src/main/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixin.java
      +++ b/core/src/main/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixin.java
      @@ -16,7 +16,11 @@
       
       package org.springframework.security.jackson2;
       
      -import com.fasterxml.jackson.annotation.*;
      +import com.fasterxml.jackson.annotation.JsonAutoDetect;
      +import com.fasterxml.jackson.annotation.JsonCreator;
      +import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
      +import com.fasterxml.jackson.annotation.JsonProperty;
      +import com.fasterxml.jackson.annotation.JsonTypeInfo;
       
       /**
        * Jackson Mixin class helps in serialize/deserialize
      @@ -26,14 +30,15 @@ import com.fasterxml.jackson.annotation.*;
        *     ObjectMapper mapper = new ObjectMapper();
        *     mapper.registerModule(new CoreJackson2Module());
        * 
      + * * @author Jitendra Singh * @see CoreJackson2Module * @see SecurityJackson2Modules * @since 4.2 */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) -@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.NONE, getterVisibility = JsonAutoDetect.Visibility.PUBLIC_ONLY, - isGetterVisibility = JsonAutoDetect.Visibility.NONE) +@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.NONE, + getterVisibility = JsonAutoDetect.Visibility.PUBLIC_ONLY, isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonIgnoreProperties(ignoreUnknown = true) public abstract class SimpleGrantedAuthorityMixin { @@ -44,4 +49,5 @@ public abstract class SimpleGrantedAuthorityMixin { @JsonCreator public SimpleGrantedAuthorityMixin(@JsonProperty("authority") String role) { } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListDeserializer.java b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListDeserializer.java index af4075f014..fd86dadccc 100644 --- a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListDeserializer.java +++ b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListDeserializer.java @@ -16,6 +16,11 @@ package org.springframework.security.jackson2; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationContext; @@ -24,17 +29,12 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - /** * Custom deserializer for {@link UnmodifiableListDeserializer}. * * @author Rob Winch - * @see UnmodifiableListMixin * @since 5.0.2 + * @see UnmodifiableListMixin */ class UnmodifiableListDeserializer extends JsonDeserializer { @@ -49,10 +49,12 @@ class UnmodifiableListDeserializer extends JsonDeserializer { for (JsonNode elementNode : arrayNode) { result.add(mapper.readValue(elementNode.traverse(mapper), Object.class)); } - } else { + } + else { result.add(mapper.readValue(node.traverse(mapper), Object.class)); } } return Collections.unmodifiableList(result); } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListMixin.java b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListMixin.java index b5fb86e76c..9483b79719 100644 --- a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListMixin.java +++ b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableListMixin.java @@ -16,12 +16,12 @@ package org.springframework.security.jackson2; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import java.util.Set; - /** * This mixin class used to deserialize java.util.Collections$UnmodifiableRandomAccessList * and used with various AuthenticationToken implementation's mixin classes. @@ -46,5 +46,7 @@ class UnmodifiableListMixin { * @param s the Set */ @JsonCreator - UnmodifiableListMixin(Set s) {} + UnmodifiableListMixin(Set s) { + } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetDeserializer.java b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetDeserializer.java index c25ca727f6..c26d6921b5 100644 --- a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetDeserializer.java +++ b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetDeserializer.java @@ -16,6 +16,11 @@ package org.springframework.security.jackson2; +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationContext; @@ -24,17 +29,12 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; -import java.io.IOException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - /** * Custom deserializer for {@link UnmodifiableSetMixin}. * * @author Jitendra Singh - * @see UnmodifiableSetMixin * @since 4.2 + * @see UnmodifiableSetMixin */ class UnmodifiableSetDeserializer extends JsonDeserializer { @@ -49,10 +49,12 @@ class UnmodifiableSetDeserializer extends JsonDeserializer { for (JsonNode elementNode : arrayNode) { resultSet.add(mapper.readValue(elementNode.traverse(mapper), Object.class)); } - } else { + } + else { resultSet.add(mapper.readValue(node.traverse(mapper), Object.class)); } } return Collections.unmodifiableSet(resultSet); } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetMixin.java b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetMixin.java index 09181a88e6..2dba600d68 100644 --- a/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetMixin.java +++ b/core/src/main/java/org/springframework/security/jackson2/UnmodifiableSetMixin.java @@ -16,15 +16,15 @@ package org.springframework.security.jackson2; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import java.util.Set; - /** - * This mixin class used to deserialize java.util.Collections$UnmodifiableSet and used with various AuthenticationToken - * implementation's mixin classes. + * This mixin class used to deserialize java.util.Collections$UnmodifiableSet and used + * with various AuthenticationToken implementation's mixin classes. * *
        *     ObjectMapper mapper = new ObjectMapper();
      @@ -46,5 +46,7 @@ class UnmodifiableSetMixin {
       	 * @param s the Set
       	 */
       	@JsonCreator
      -	UnmodifiableSetMixin(Set s) {}
      +	UnmodifiableSetMixin(Set s) {
      +	}
      +
       }
      diff --git a/core/src/main/java/org/springframework/security/jackson2/UserDeserializer.java b/core/src/main/java/org/springframework/security/jackson2/UserDeserializer.java
      index 96d3ffe748..8363acc332 100644
      --- a/core/src/main/java/org/springframework/security/jackson2/UserDeserializer.java
      +++ b/core/src/main/java/org/springframework/security/jackson2/UserDeserializer.java
      @@ -16,6 +16,9 @@
       
       package org.springframework.security.jackson2;
       
      +import java.io.IOException;
      +import java.util.Set;
      +
       import com.fasterxml.jackson.core.JsonParser;
       import com.fasterxml.jackson.core.JsonProcessingException;
       import com.fasterxml.jackson.core.type.TypeReference;
      @@ -24,28 +27,29 @@ import com.fasterxml.jackson.databind.JsonDeserializer;
       import com.fasterxml.jackson.databind.JsonNode;
       import com.fasterxml.jackson.databind.ObjectMapper;
       import com.fasterxml.jackson.databind.node.MissingNode;
      +
       import org.springframework.security.core.GrantedAuthority;
       import org.springframework.security.core.authority.SimpleGrantedAuthority;
       import org.springframework.security.core.userdetails.User;
       
      -import java.io.IOException;
      -import java.util.Set;
      -
       /**
      - * Custom Deserializer for {@link User} class. This is already registered with {@link UserMixin}.
      - * You can also use it directly with your mixin class.
      + * Custom Deserializer for {@link User} class. This is already registered with
      + * {@link UserMixin}. You can also use it directly with your mixin class.
        *
        * @author Jitendra Singh
      - * @see UserMixin
        * @since 4.2
      + * @see UserMixin
        */
       class UserDeserializer extends JsonDeserializer {
       
      +	private static final TypeReference> SIMPLE_GRANTED_AUTHORITY_SET = new TypeReference>() {
      +	};
      +
       	/**
      -	 * This method will create {@link User} object. It will ensure successful object creation even if password key is null in
      -	 * serialized json, because credentials may be removed from the {@link User} by invoking {@link User#eraseCredentials()}.
      -	 * In that case there won't be any password key in serialized json.
      -	 *
      +	 * This method will create {@link User} object. It will ensure successful object
      +	 * creation even if password key is null in serialized json, because credentials may
      +	 * be removed from the {@link User} by invoking {@link User#eraseCredentials()}. In
      +	 * that case there won't be any password key in serialized json.
       	 * @param jp the JsonParser
       	 * @param ctxt the DeserializationContext
       	 * @return the user
      @@ -56,20 +60,18 @@ class UserDeserializer extends JsonDeserializer {
       	public User deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException, JsonProcessingException {
       		ObjectMapper mapper = (ObjectMapper) jp.getCodec();
       		JsonNode jsonNode = mapper.readTree(jp);
      -		Set authorities =
      -				mapper.convertValue(
      -						jsonNode.get("authorities"),
      -						new TypeReference>() {}
      -				);
      -		JsonNode password = readJsonNode(jsonNode, "password");
      -		User result =  new User(
      -				readJsonNode(jsonNode, "username").asText(), password.asText(""),
      -				readJsonNode(jsonNode, "enabled").asBoolean(), readJsonNode(jsonNode, "accountNonExpired").asBoolean(),
      -				readJsonNode(jsonNode, "credentialsNonExpired").asBoolean(),
      -				readJsonNode(jsonNode, "accountNonLocked").asBoolean(), authorities
      -		);
      -
      -		if (password.asText(null) == null) {
      +		Set authorities = mapper.convertValue(jsonNode.get("authorities"),
      +				SIMPLE_GRANTED_AUTHORITY_SET);
      +		JsonNode passwordNode = readJsonNode(jsonNode, "password");
      +		String username = readJsonNode(jsonNode, "username").asText();
      +		String password = passwordNode.asText("");
      +		boolean enabled = readJsonNode(jsonNode, "enabled").asBoolean();
      +		boolean accountNonExpired = readJsonNode(jsonNode, "accountNonExpired").asBoolean();
      +		boolean credentialsNonExpired = readJsonNode(jsonNode, "credentialsNonExpired").asBoolean();
      +		boolean accountNonLocked = readJsonNode(jsonNode, "accountNonLocked").asBoolean();
      +		User result = new User(username, password, enabled, accountNonExpired, credentialsNonExpired, accountNonLocked,
      +				authorities);
      +		if (passwordNode.asText(null) == null) {
       			result.eraseCredentials();
       		}
       		return result;
      @@ -78,4 +80,5 @@ class UserDeserializer extends JsonDeserializer {
       	private JsonNode readJsonNode(JsonNode jsonNode, String field) {
       		return jsonNode.has(field) ? jsonNode.get(field) : MissingNode.getInstance();
       	}
      +
       }
      diff --git a/core/src/main/java/org/springframework/security/jackson2/UserMixin.java b/core/src/main/java/org/springframework/security/jackson2/UserMixin.java
      index 36fca4e808..0cb7b3c332 100644
      --- a/core/src/main/java/org/springframework/security/jackson2/UserMixin.java
      +++ b/core/src/main/java/org/springframework/security/jackson2/UserMixin.java
      @@ -22,12 +22,14 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
       import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
       
       /**
      - * This mixin class helps in serialize/deserialize {@link org.springframework.security.core.userdetails.User}.
      - * This class also register a custom deserializer {@link UserDeserializer} to deserialize User object successfully.
      - * In order to use this mixin you need to register two more mixin classes in your ObjectMapper configuration.
      + * This mixin class helps in serialize/deserialize
      + * {@link org.springframework.security.core.userdetails.User}. This class also register a
      + * custom deserializer {@link UserDeserializer} to deserialize User object successfully.
      + * In order to use this mixin you need to register two more mixin classes in your
      + * ObjectMapper configuration.
        * 
        - *
      1. {@link SimpleGrantedAuthorityMixin}
      2. - *
      3. {@link UnmodifiableSetMixin}
      4. + *
      5. {@link SimpleGrantedAuthorityMixin}
      6. + *
      7. {@link UnmodifiableSetMixin}
      8. *
      *
        *     ObjectMapper mapper = new ObjectMapper();
      @@ -46,4 +48,5 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
       		isGetterVisibility = JsonAutoDetect.Visibility.NONE)
       @JsonIgnoreProperties(ignoreUnknown = true)
       abstract class UserMixin {
      +
       }
      diff --git a/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenDeserializer.java b/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenDeserializer.java
      index 64a5b4a7b0..c5d815ad79 100644
      --- a/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenDeserializer.java
      +++ b/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenDeserializer.java
      @@ -19,11 +19,13 @@ package org.springframework.security.jackson2;
       import java.io.IOException;
       import java.util.List;
       
      +import com.fasterxml.jackson.core.JsonParseException;
       import com.fasterxml.jackson.core.JsonParser;
       import com.fasterxml.jackson.core.JsonProcessingException;
       import com.fasterxml.jackson.core.type.TypeReference;
       import com.fasterxml.jackson.databind.DeserializationContext;
       import com.fasterxml.jackson.databind.JsonDeserializer;
      +import com.fasterxml.jackson.databind.JsonMappingException;
       import com.fasterxml.jackson.databind.JsonNode;
       import com.fasterxml.jackson.databind.ObjectMapper;
       import com.fasterxml.jackson.databind.node.MissingNode;
      @@ -32,23 +34,31 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio
       import org.springframework.security.core.GrantedAuthority;
       
       /**
      - * Custom deserializer for {@link UsernamePasswordAuthenticationToken}. At the time of deserialization
      - * it will invoke suitable constructor depending on the value of authenticated property.
      - * It will ensure that the token's state must not change.
      + * Custom deserializer for {@link UsernamePasswordAuthenticationToken}. At the time of
      + * deserialization it will invoke suitable constructor depending on the value of
      + * authenticated property. It will ensure that the token's state must not change.
        * 

      - * This deserializer is already registered with {@link UsernamePasswordAuthenticationTokenMixin} but - * you can also registered it with your own mixin class. + * This deserializer is already registered with + * {@link UsernamePasswordAuthenticationTokenMixin} but you can also registered it with + * your own mixin class. * * @author Jitendra Singh * @author Greg Turnquist * @author Onur Kagan Ozcan - * @see UsernamePasswordAuthenticationTokenMixin * @since 4.2 + * @see UsernamePasswordAuthenticationTokenMixin */ class UsernamePasswordAuthenticationTokenDeserializer extends JsonDeserializer { + private static final TypeReference> GRANTED_AUTHORITY_LIST = new TypeReference>() { + }; + + private static final TypeReference OBJECT = new TypeReference() { + }; + /** - * This method construct {@link UsernamePasswordAuthenticationToken} object from serialized json. + * This method construct {@link UsernamePasswordAuthenticationToken} object from + * serialized json. * @param jp the JsonParser * @param ctxt the DeserializationContext * @return the user @@ -56,44 +66,48 @@ class UsernamePasswordAuthenticationTokenDeserializer extends JsonDeserializer authorities = mapper.readValue( - readJsonNode(jsonNode, "authorities").traverse(mapper), new TypeReference>() { - }); - if (authenticated) { - token = new UsernamePasswordAuthenticationToken(principal, credentials, authorities); - } else { - token = new UsernamePasswordAuthenticationToken(principal, credentials); - } + Object credentials = getCredentials(credentialsNode); + List authorities = mapper.readValue(readJsonNode(jsonNode, "authorities").traverse(mapper), + GRANTED_AUTHORITY_LIST); + UsernamePasswordAuthenticationToken token = (!authenticated) + ? new UsernamePasswordAuthenticationToken(principal, credentials) + : new UsernamePasswordAuthenticationToken(principal, credentials, authorities); JsonNode detailsNode = readJsonNode(jsonNode, "details"); if (detailsNode.isNull() || detailsNode.isMissingNode()) { token.setDetails(null); - } else { - Object details = mapper.readValue(detailsNode.toString(), new TypeReference() {}); + } + else { + Object details = mapper.readValue(detailsNode.toString(), OBJECT); token.setDetails(details); } return token; } + private Object getCredentials(JsonNode credentialsNode) { + if (credentialsNode.isNull() || credentialsNode.isMissingNode()) { + return null; + } + return credentialsNode.asText(); + } + + private Object getPrincipal(ObjectMapper mapper, JsonNode principalNode) + throws IOException, JsonParseException, JsonMappingException { + if (principalNode.isObject()) { + return mapper.readValue(principalNode.traverse(mapper), Object.class); + } + return principalNode.asText(); + } + private JsonNode readJsonNode(JsonNode jsonNode, String field) { return jsonNode.has(field) ? jsonNode.get(field) : MissingNode.getInstance(); } + } diff --git a/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixin.java b/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixin.java index 4c0b58e194..e2bc37e3f1 100644 --- a/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixin.java +++ b/core/src/main/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixin.java @@ -16,25 +16,28 @@ package org.springframework.security.jackson2; -import com.fasterxml.jackson.annotation.*; +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; /** * This mixin class is used to serialize / deserialize - * {@link org.springframework.security.authentication.UsernamePasswordAuthenticationToken}. This class register - * a custom deserializer {@link UsernamePasswordAuthenticationTokenDeserializer}. + * {@link org.springframework.security.authentication.UsernamePasswordAuthenticationToken}. + * This class register a custom deserializer + * {@link UsernamePasswordAuthenticationTokenDeserializer}. * * In order to use this mixin you'll need to add 3 more mixin classes. *
        - *
      1. {@link UnmodifiableSetMixin}
      2. - *
      3. {@link SimpleGrantedAuthorityMixin}
      4. - *
      5. {@link UserMixin}
      6. + *
      7. {@link UnmodifiableSetMixin}
      8. + *
      9. {@link SimpleGrantedAuthorityMixin}
      10. + *
      11. {@link UserMixin}
      12. *
      * *
        *     ObjectMapper mapper = new ObjectMapper();
        *     mapper.registerModule(new CoreJackson2Module());
        * 
      + * * @author Jitendra Singh * @see CoreJackson2Module * @see SecurityJackson2Modules @@ -45,4 +48,5 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonDeserialize(using = UsernamePasswordAuthenticationTokenDeserializer.class) abstract class UsernamePasswordAuthenticationTokenMixin { + } diff --git a/core/src/main/java/org/springframework/security/jackson2/package-info.java b/core/src/main/java/org/springframework/security/jackson2/package-info.java index 9e818414a2..a922b96576 100644 --- a/core/src/main/java/org/springframework/security/jackson2/package-info.java +++ b/core/src/main/java/org/springframework/security/jackson2/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Mix-in classes to add Jackson serialization support. * @@ -20,7 +21,3 @@ * @since 4.2 */ package org.springframework.security.jackson2; - -/** - * Package contains Jackson mixin classes. - */ \ No newline at end of file diff --git a/core/src/main/java/org/springframework/security/provisioning/GroupManager.java b/core/src/main/java/org/springframework/security/provisioning/GroupManager.java index 4dfcb6031b..763c2ca65e 100644 --- a/core/src/main/java/org/springframework/security/provisioning/GroupManager.java +++ b/core/src/main/java/org/springframework/security/provisioning/GroupManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.provisioning; import java.util.List; @@ -40,7 +41,6 @@ public interface GroupManager { /** * Locates the users who are members of a group - * * @param groupName the group whose members are required * @return the usernames of the group members */ @@ -48,7 +48,6 @@ public interface GroupManager { /** * Creates a new group with the specified list of authorities. - * * @param groupName the name for the new group * @param authorities the authorities which are to be allocated to this group. */ @@ -56,7 +55,6 @@ public interface GroupManager { /** * Removes a group, including all members and authorities. - * * @param groupName the group to remove. */ void deleteGroup(String groupName); @@ -68,7 +66,6 @@ public interface GroupManager { /** * Makes a user a member of a particular group. - * * @param username the user to be given membership. * @param group the name of the group to which the user will be added. */ @@ -76,7 +73,6 @@ public interface GroupManager { /** * Deletes a user's membership of a group. - * * @param username the user * @param groupName the group to remove them from */ @@ -96,4 +92,5 @@ public interface GroupManager { * Deletes an authority from those assigned to a group */ void removeGroupAuthority(String groupName, GrantedAuthority authority); + } diff --git a/core/src/main/java/org/springframework/security/provisioning/InMemoryUserDetailsManager.java b/core/src/main/java/org/springframework/security/provisioning/InMemoryUserDetailsManager.java index f70e1b0af6..346c481cf8 100644 --- a/core/src/main/java/org/springframework/security/provisioning/InMemoryUserDetailsManager.java +++ b/core/src/main/java/org/springframework/security/provisioning/InMemoryUserDetailsManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.provisioning; import java.util.Collection; @@ -23,6 +24,8 @@ import java.util.Properties; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -46,8 +49,8 @@ import org.springframework.util.Assert; * @author Luke Taylor * @since 3.1 */ -public class InMemoryUserDetailsManager implements UserDetailsManager, - UserDetailsPasswordService { +public class InMemoryUserDetailsManager implements UserDetailsManager, UserDetailsPasswordService { + protected final Log logger = LogFactory.getLog(getClass()); private final Map users = new HashMap<>(); @@ -72,71 +75,61 @@ public class InMemoryUserDetailsManager implements UserDetailsManager, public InMemoryUserDetailsManager(Properties users) { Enumeration names = users.propertyNames(); UserAttributeEditor editor = new UserAttributeEditor(); - while (names.hasMoreElements()) { String name = (String) names.nextElement(); editor.setAsText(users.getProperty(name)); UserAttribute attr = (UserAttribute) editor.getValue(); - UserDetails user = new User(name, attr.getPassword(), attr.isEnabled(), true, - true, true, attr.getAuthorities()); - createUser(user); + createUser(createUserDetails(name, attr)); } } + private User createUserDetails(String name, UserAttribute attr) { + return new User(name, attr.getPassword(), attr.isEnabled(), true, true, true, attr.getAuthorities()); + } + + @Override public void createUser(UserDetails user) { Assert.isTrue(!userExists(user.getUsername()), "user should not exist"); - - users.put(user.getUsername().toLowerCase(), new MutableUser(user)); + this.users.put(user.getUsername().toLowerCase(), new MutableUser(user)); } + @Override public void deleteUser(String username) { - users.remove(username.toLowerCase()); + this.users.remove(username.toLowerCase()); } + @Override public void updateUser(UserDetails user) { Assert.isTrue(userExists(user.getUsername()), "user should exist"); - - users.put(user.getUsername().toLowerCase(), new MutableUser(user)); + this.users.put(user.getUsername().toLowerCase(), new MutableUser(user)); } + @Override public boolean userExists(String username) { - return users.containsKey(username.toLowerCase()); + return this.users.containsKey(username.toLowerCase()); } + @Override public void changePassword(String oldPassword, String newPassword) { - Authentication currentUser = SecurityContextHolder.getContext() - .getAuthentication(); - + Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); if (currentUser == null) { // This would indicate bad coding somewhere throw new AccessDeniedException( - "Can't change password as no Authentication object found in context " - + "for current user."); + "Can't change password as no Authentication object found in context " + "for current user."); } - String username = currentUser.getName(); - - logger.debug("Changing password for user '" + username + "'"); - + this.logger.debug(LogMessage.format("Changing password for user '%s'", username)); // If an authentication manager has been set, re-authenticate the user with the // supplied password. - if (authenticationManager != null) { - logger.debug("Reauthenticating user '" + username - + "' for password change request."); - - authenticationManager.authenticate(new UsernamePasswordAuthenticationToken( - username, oldPassword)); + if (this.authenticationManager != null) { + this.logger.debug(LogMessage.format("Reauthenticating user '%s' for password change request.", username)); + this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(username, oldPassword)); } else { - logger.debug("No authentication manager set. Password won't be re-checked."); + this.logger.debug("No authentication manager set. Password won't be re-checked."); } - - MutableUserDetails user = users.get(username); - - if (user == null) { - throw new IllegalStateException("Current user doesn't exist in database."); - } - + MutableUserDetails user = this.users.get(username); + Assert.state(user != null, "Current user doesn't exist in database."); user.setPassword(newPassword); } @@ -148,20 +141,18 @@ public class InMemoryUserDetailsManager implements UserDetailsManager, return mutableUser; } - public UserDetails loadUserByUsername(String username) - throws UsernameNotFoundException { - UserDetails user = users.get(username.toLowerCase()); - + @Override + public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException { + UserDetails user = this.users.get(username.toLowerCase()); if (user == null) { throw new UsernameNotFoundException(username); } - - return new User(user.getUsername(), user.getPassword(), user.isEnabled(), - user.isAccountNonExpired(), user.isCredentialsNonExpired(), - user.isAccountNonLocked(), user.getAuthorities()); + return new User(user.getUsername(), user.getPassword(), user.isEnabled(), user.isAccountNonExpired(), + user.isCredentialsNonExpired(), user.isAccountNonLocked(), user.getAuthorities()); } public void setAuthenticationManager(AuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; } + } diff --git a/core/src/main/java/org/springframework/security/provisioning/JdbcUserDetailsManager.java b/core/src/main/java/org/springframework/security/provisioning/JdbcUserDetailsManager.java index 5d298568c0..75a96b478a 100644 --- a/core/src/main/java/org/springframework/security/provisioning/JdbcUserDetailsManager.java +++ b/core/src/main/java/org/springframework/security/provisioning/JdbcUserDetailsManager.java @@ -13,8 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.provisioning; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Collection; +import java.util.List; + +import javax.sql.DataSource; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.context.ApplicationContextException; +import org.springframework.core.log.LogMessage; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -29,18 +44,8 @@ import org.springframework.security.core.userdetails.UserCache; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.cache.NullUserCache; import org.springframework.security.core.userdetails.jdbc.JdbcDaoImpl; -import org.springframework.context.ApplicationContextException; -import org.springframework.dao.IncorrectResultSizeDataAccessException; -import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.util.Assert; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import javax.sql.DataSource; -import java.util.Collection; -import java.util.List; - /** * Jdbc user management service, based on the same table structure as its parent class, * JdbcDaoImpl. @@ -56,64 +61,90 @@ import java.util.List; * @author Luke Taylor * @since 2.0 */ -public class JdbcUserDetailsManager extends JdbcDaoImpl implements UserDetailsManager, - GroupManager { - // ~ Static fields/initializers - // ===================================================================================== +public class JdbcUserDetailsManager extends JdbcDaoImpl implements UserDetailsManager, GroupManager { - // UserDetailsManager SQL public static final String DEF_CREATE_USER_SQL = "insert into users (username, password, enabled) values (?,?,?)"; + public static final String DEF_DELETE_USER_SQL = "delete from users where username = ?"; + public static final String DEF_UPDATE_USER_SQL = "update users set password = ?, enabled = ? where username = ?"; + public static final String DEF_INSERT_AUTHORITY_SQL = "insert into authorities (username, authority) values (?,?)"; + public static final String DEF_DELETE_USER_AUTHORITIES_SQL = "delete from authorities where username = ?"; + public static final String DEF_USER_EXISTS_SQL = "select username from users where username = ?"; + public static final String DEF_CHANGE_PASSWORD_SQL = "update users set password = ? where username = ?"; - // GroupManager SQL public static final String DEF_FIND_GROUPS_SQL = "select group_name from groups"; + public static final String DEF_FIND_USERS_IN_GROUP_SQL = "select username from group_members gm, groups g " + "where gm.group_id = g.id and g.group_name = ?"; - public static final String DEF_INSERT_GROUP_SQL = "insert into groups (group_name) values (?)"; - public static final String DEF_FIND_GROUP_ID_SQL = "select id from groups where group_name = ?"; - public static final String DEF_INSERT_GROUP_AUTHORITY_SQL = "insert into group_authorities (group_id, authority) values (?,?)"; - public static final String DEF_DELETE_GROUP_SQL = "delete from groups where id = ?"; - public static final String DEF_DELETE_GROUP_AUTHORITIES_SQL = "delete from group_authorities where group_id = ?"; - public static final String DEF_DELETE_GROUP_MEMBERS_SQL = "delete from group_members where group_id = ?"; - public static final String DEF_RENAME_GROUP_SQL = "update groups set group_name = ? where group_name = ?"; - public static final String DEF_INSERT_GROUP_MEMBER_SQL = "insert into group_members (group_id, username) values (?,?)"; - public static final String DEF_DELETE_GROUP_MEMBER_SQL = "delete from group_members where group_id = ? and username = ?"; - public static final String DEF_GROUP_AUTHORITIES_QUERY_SQL = "select g.id, g.group_name, ga.authority " - + "from groups g, group_authorities ga " - + "where g.group_name = ? " - + "and g.id = ga.group_id "; - public static final String DEF_DELETE_GROUP_AUTHORITY_SQL = "delete from group_authorities where group_id = ? and authority = ?"; - // ~ Instance fields - // ================================================================================================ + public static final String DEF_INSERT_GROUP_SQL = "insert into groups (group_name) values (?)"; + + public static final String DEF_FIND_GROUP_ID_SQL = "select id from groups where group_name = ?"; + + public static final String DEF_INSERT_GROUP_AUTHORITY_SQL = "insert into group_authorities (group_id, authority) values (?,?)"; + + public static final String DEF_DELETE_GROUP_SQL = "delete from groups where id = ?"; + + public static final String DEF_DELETE_GROUP_AUTHORITIES_SQL = "delete from group_authorities where group_id = ?"; + + public static final String DEF_DELETE_GROUP_MEMBERS_SQL = "delete from group_members where group_id = ?"; + + public static final String DEF_RENAME_GROUP_SQL = "update groups set group_name = ? where group_name = ?"; + + public static final String DEF_INSERT_GROUP_MEMBER_SQL = "insert into group_members (group_id, username) values (?,?)"; + + public static final String DEF_DELETE_GROUP_MEMBER_SQL = "delete from group_members where group_id = ? and username = ?"; + + public static final String DEF_GROUP_AUTHORITIES_QUERY_SQL = "select g.id, g.group_name, ga.authority " + + "from groups g, group_authorities ga " + "where g.group_name = ? " + "and g.id = ga.group_id "; + + public static final String DEF_DELETE_GROUP_AUTHORITY_SQL = "delete from group_authorities where group_id = ? and authority = ?"; protected final Log logger = LogFactory.getLog(getClass()); private String createUserSql = DEF_CREATE_USER_SQL; + private String deleteUserSql = DEF_DELETE_USER_SQL; + private String updateUserSql = DEF_UPDATE_USER_SQL; + private String createAuthoritySql = DEF_INSERT_AUTHORITY_SQL; + private String deleteUserAuthoritiesSql = DEF_DELETE_USER_AUTHORITIES_SQL; + private String userExistsSql = DEF_USER_EXISTS_SQL; + private String changePasswordSql = DEF_CHANGE_PASSWORD_SQL; private String findAllGroupsSql = DEF_FIND_GROUPS_SQL; + private String findUsersInGroupSql = DEF_FIND_USERS_IN_GROUP_SQL; + private String insertGroupSql = DEF_INSERT_GROUP_SQL; + private String findGroupIdSql = DEF_FIND_GROUP_ID_SQL; + private String insertGroupAuthoritySql = DEF_INSERT_GROUP_AUTHORITY_SQL; + private String deleteGroupSql = DEF_DELETE_GROUP_SQL; + private String deleteGroupAuthoritiesSql = DEF_DELETE_GROUP_AUTHORITIES_SQL; + private String deleteGroupMembersSql = DEF_DELETE_GROUP_MEMBERS_SQL; + private String renameGroupSql = DEF_RENAME_GROUP_SQL; + private String insertGroupMemberSql = DEF_INSERT_GROUP_MEMBER_SQL; + private String deleteGroupMemberSql = DEF_DELETE_GROUP_MEMBER_SQL; + private String groupAuthoritiesSql = DEF_GROUP_AUTHORITIES_QUERY_SQL; + private String deleteGroupAuthoritySql = DEF_DELETE_GROUP_AUTHORITY_SQL; private AuthenticationManager authenticationManager; @@ -127,301 +158,260 @@ public class JdbcUserDetailsManager extends JdbcDaoImpl implements UserDetailsMa setDataSource(dataSource); } - // ~ Methods - // ======================================================================================================== - + @Override protected void initDao() throws ApplicationContextException { - if (authenticationManager == null) { - logger.info("No authentication manager set. Reauthentication of users when changing passwords will " - + "not be performed."); + if (this.authenticationManager == null) { + this.logger.info( + "No authentication manager set. Reauthentication of users when changing passwords will not be performed."); } - super.initDao(); } - // ~ UserDetailsManager implementation - // ============================================================================== - /** * Executes the SQL usersByUsernameQuery and returns a list of UserDetails * objects. There should normally only be one matching user. */ + @Override protected List loadUsersByUsername(String username) { - return getJdbcTemplate().query(getUsersByUsernameQuery(), new String[]{username}, - (rs, rowNum) -> { - - String userName = rs.getString(1); - String password = rs.getString(2); - boolean enabled = rs.getBoolean(3); - - boolean accLocked = false; - boolean accExpired = false; - boolean credsExpired = false; - - if (rs.getMetaData().getColumnCount() > 3) { - //NOTE: acc_locked, acc_expired and creds_expired are also to be loaded - accLocked = rs.getBoolean(4); - accExpired = rs.getBoolean(5); - credsExpired = rs.getBoolean(6); - } - return new User(userName, password, enabled, !accExpired, !credsExpired, !accLocked, - AuthorityUtils.NO_AUTHORITIES); - }); + return getJdbcTemplate().query(getUsersByUsernameQuery(), this::mapToUser, username); } + private UserDetails mapToUser(ResultSet rs, int rowNum) throws SQLException { + String userName = rs.getString(1); + String password = rs.getString(2); + boolean enabled = rs.getBoolean(3); + boolean accLocked = false; + boolean accExpired = false; + boolean credsExpired = false; + if (rs.getMetaData().getColumnCount() > 3) { + // NOTE: acc_locked, acc_expired and creds_expired are also to be loaded + accLocked = rs.getBoolean(4); + accExpired = rs.getBoolean(5); + credsExpired = rs.getBoolean(6); + } + return new User(userName, password, enabled, !accExpired, !credsExpired, !accLocked, + AuthorityUtils.NO_AUTHORITIES); + } + + @Override public void createUser(final UserDetails user) { validateUserDetails(user); - - getJdbcTemplate().update(createUserSql, ps -> { + getJdbcTemplate().update(this.createUserSql, (ps) -> { ps.setString(1, user.getUsername()); ps.setString(2, user.getPassword()); ps.setBoolean(3, user.isEnabled()); - int paramCount = ps.getParameterMetaData().getParameterCount(); if (paramCount > 3) { - //NOTE: acc_locked, acc_expired and creds_expired are also to be inserted + // NOTE: acc_locked, acc_expired and creds_expired are also to be inserted ps.setBoolean(4, !user.isAccountNonLocked()); ps.setBoolean(5, !user.isAccountNonExpired()); ps.setBoolean(6, !user.isCredentialsNonExpired()); } }); - if (getEnableAuthorities()) { insertUserAuthorities(user); } } + @Override public void updateUser(final UserDetails user) { validateUserDetails(user); - - getJdbcTemplate().update(updateUserSql, ps -> { + getJdbcTemplate().update(this.updateUserSql, (ps) -> { ps.setString(1, user.getPassword()); ps.setBoolean(2, user.isEnabled()); - int paramCount = ps.getParameterMetaData().getParameterCount(); if (paramCount == 3) { ps.setString(3, user.getUsername()); - } else { - //NOTE: acc_locked, acc_expired and creds_expired are also updated + } + else { + // NOTE: acc_locked, acc_expired and creds_expired are also updated ps.setBoolean(3, !user.isAccountNonLocked()); ps.setBoolean(4, !user.isAccountNonExpired()); ps.setBoolean(5, !user.isCredentialsNonExpired()); - ps.setString(6, user.getUsername()); } - }); - if (getEnableAuthorities()) { deleteUserAuthorities(user.getUsername()); insertUserAuthorities(user); } - - userCache.removeUserFromCache(user.getUsername()); + this.userCache.removeUserFromCache(user.getUsername()); } private void insertUserAuthorities(UserDetails user) { for (GrantedAuthority auth : user.getAuthorities()) { - getJdbcTemplate().update(createAuthoritySql, user.getUsername(), - auth.getAuthority()); + getJdbcTemplate().update(this.createAuthoritySql, user.getUsername(), auth.getAuthority()); } } + @Override public void deleteUser(String username) { if (getEnableAuthorities()) { deleteUserAuthorities(username); } - getJdbcTemplate().update(deleteUserSql, username); - userCache.removeUserFromCache(username); + getJdbcTemplate().update(this.deleteUserSql, username); + this.userCache.removeUserFromCache(username); } private void deleteUserAuthorities(String username) { - getJdbcTemplate().update(deleteUserAuthoritiesSql, username); + getJdbcTemplate().update(this.deleteUserAuthoritiesSql, username); } - public void changePassword(String oldPassword, String newPassword) - throws AuthenticationException { - Authentication currentUser = SecurityContextHolder.getContext() - .getAuthentication(); - + @Override + public void changePassword(String oldPassword, String newPassword) throws AuthenticationException { + Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); if (currentUser == null) { // This would indicate bad coding somewhere throw new AccessDeniedException( - "Can't change password as no Authentication object found in context " - + "for current user."); + "Can't change password as no Authentication object found in context " + "for current user."); } - String username = currentUser.getName(); - // If an authentication manager has been set, re-authenticate the user with the // supplied password. - if (authenticationManager != null) { - logger.debug("Reauthenticating user '" + username - + "' for password change request."); - - authenticationManager.authenticate(new UsernamePasswordAuthenticationToken( - username, oldPassword)); + if (this.authenticationManager != null) { + this.logger.debug(LogMessage.format("Reauthenticating user '%s' for password change request.", username)); + this.authenticationManager.authenticate(new UsernamePasswordAuthenticationToken(username, oldPassword)); } else { - logger.debug("No authentication manager set. Password won't be re-checked."); + this.logger.debug("No authentication manager set. Password won't be re-checked."); } - - logger.debug("Changing password for user '" + username + "'"); - - getJdbcTemplate().update(changePasswordSql, newPassword, username); - - SecurityContextHolder.getContext().setAuthentication( - createNewAuthentication(currentUser, newPassword)); - - userCache.removeUserFromCache(username); + this.logger.debug("Changing password for user '" + username + "'"); + getJdbcTemplate().update(this.changePasswordSql, newPassword, username); + SecurityContextHolder.getContext().setAuthentication(createNewAuthentication(currentUser, newPassword)); + this.userCache.removeUserFromCache(username); } - protected Authentication createNewAuthentication(Authentication currentAuth, - String newPassword) { + protected Authentication createNewAuthentication(Authentication currentAuth, String newPassword) { UserDetails user = loadUserByUsername(currentAuth.getName()); - - UsernamePasswordAuthenticationToken newAuthentication = new UsernamePasswordAuthenticationToken( - user, null, user.getAuthorities()); + UsernamePasswordAuthenticationToken newAuthentication = new UsernamePasswordAuthenticationToken(user, null, + user.getAuthorities()); newAuthentication.setDetails(currentAuth.getDetails()); - return newAuthentication; } + @Override public boolean userExists(String username) { - List users = getJdbcTemplate().queryForList(userExistsSql, - new String[] { username }, String.class); - + List users = getJdbcTemplate().queryForList(this.userExistsSql, new String[] { username }, + String.class); if (users.size() > 1) { - throw new IncorrectResultSizeDataAccessException( - "More than one user found with name '" + username + "'", 1); + throw new IncorrectResultSizeDataAccessException("More than one user found with name '" + username + "'", + 1); } - return users.size() == 1; } - // ~ GroupManager implementation - // ==================================================================================== - + @Override public List findAllGroups() { - return getJdbcTemplate().queryForList(findAllGroupsSql, String.class); + return getJdbcTemplate().queryForList(this.findAllGroupsSql, String.class); } + @Override public List findUsersInGroup(String groupName) { Assert.hasText(groupName, "groupName should have text"); - return getJdbcTemplate().queryForList(findUsersInGroupSql, - new String[] { groupName }, String.class); + return getJdbcTemplate().queryForList(this.findUsersInGroupSql, new String[] { groupName }, String.class); } - public void createGroup(final String groupName, - final List authorities) { + @Override + public void createGroup(final String groupName, final List authorities) { Assert.hasText(groupName, "groupName should have text"); Assert.notNull(authorities, "authorities cannot be null"); - - logger.debug("Creating new group '" + groupName + "' with authorities " + this.logger.debug("Creating new group '" + groupName + "' with authorities " + AuthorityUtils.authorityListToSet(authorities)); - - getJdbcTemplate().update(insertGroupSql, groupName); - - final int groupId = findGroupId(groupName); - + getJdbcTemplate().update(this.insertGroupSql, groupName); + int groupId = findGroupId(groupName); for (GrantedAuthority a : authorities) { - final String authority = a.getAuthority(); - getJdbcTemplate().update(insertGroupAuthoritySql, - ps -> { - ps.setInt(1, groupId); - ps.setString(2, authority); - }); + String authority = a.getAuthority(); + getJdbcTemplate().update(this.insertGroupAuthoritySql, (ps) -> { + ps.setInt(1, groupId); + ps.setString(2, authority); + }); } } + @Override public void deleteGroup(String groupName) { - logger.debug("Deleting group '" + groupName + "'"); + this.logger.debug("Deleting group '" + groupName + "'"); Assert.hasText(groupName, "groupName should have text"); - - final int id = findGroupId(groupName); - PreparedStatementSetter groupIdPSS = ps -> ps.setInt(1, id); - getJdbcTemplate().update(deleteGroupMembersSql, groupIdPSS); - getJdbcTemplate().update(deleteGroupAuthoritiesSql, groupIdPSS); - getJdbcTemplate().update(deleteGroupSql, groupIdPSS); + int id = findGroupId(groupName); + PreparedStatementSetter groupIdPSS = (ps) -> ps.setInt(1, id); + getJdbcTemplate().update(this.deleteGroupMembersSql, groupIdPSS); + getJdbcTemplate().update(this.deleteGroupAuthoritiesSql, groupIdPSS); + getJdbcTemplate().update(this.deleteGroupSql, groupIdPSS); } + @Override public void renameGroup(String oldName, String newName) { - logger.debug("Changing group name from '" + oldName + "' to '" + newName + "'"); + this.logger.debug("Changing group name from '" + oldName + "' to '" + newName + "'"); Assert.hasText(oldName, "oldName should have text"); Assert.hasText(newName, "newName should have text"); - - getJdbcTemplate().update(renameGroupSql, newName, oldName); + getJdbcTemplate().update(this.renameGroupSql, newName, oldName); } + @Override public void addUserToGroup(final String username, final String groupName) { - logger.debug("Adding user '" + username + "' to group '" + groupName + "'"); + this.logger.debug("Adding user '" + username + "' to group '" + groupName + "'"); Assert.hasText(username, "username should have text"); Assert.hasText(groupName, "groupName should have text"); - - final int id = findGroupId(groupName); - getJdbcTemplate().update(insertGroupMemberSql, ps -> { + int id = findGroupId(groupName); + getJdbcTemplate().update(this.insertGroupMemberSql, (ps) -> { ps.setInt(1, id); ps.setString(2, username); }); - - userCache.removeUserFromCache(username); + this.userCache.removeUserFromCache(username); } + @Override public void removeUserFromGroup(final String username, final String groupName) { - logger.debug("Removing user '" + username + "' to group '" + groupName + "'"); + this.logger.debug("Removing user '" + username + "' to group '" + groupName + "'"); Assert.hasText(username, "username should have text"); Assert.hasText(groupName, "groupName should have text"); - - final int id = findGroupId(groupName); - - getJdbcTemplate().update(deleteGroupMemberSql, ps -> { + int id = findGroupId(groupName); + getJdbcTemplate().update(this.deleteGroupMemberSql, (ps) -> { ps.setInt(1, id); ps.setString(2, username); }); - - userCache.removeUserFromCache(username); + this.userCache.removeUserFromCache(username); } + @Override public List findGroupAuthorities(String groupName) { - logger.debug("Loading authorities for group '" + groupName + "'"); + this.logger.debug("Loading authorities for group '" + groupName + "'"); Assert.hasText(groupName, "groupName should have text"); - - return getJdbcTemplate().query(groupAuthoritiesSql, new String[] { groupName }, - (rs, rowNum) -> { - String roleName = getRolePrefix() + rs.getString(3); - - return new SimpleGrantedAuthority(roleName); - }); + return getJdbcTemplate().query(this.groupAuthoritiesSql, new String[] { groupName }, + this::mapToGrantedAuthority); } + private GrantedAuthority mapToGrantedAuthority(ResultSet rs, int rowNum) throws SQLException { + String roleName = getRolePrefix() + rs.getString(3); + return new SimpleGrantedAuthority(roleName); + } + + @Override public void removeGroupAuthority(String groupName, final GrantedAuthority authority) { - logger.debug("Removing authority '" + authority + "' from group '" + groupName - + "'"); + this.logger.debug("Removing authority '" + authority + "' from group '" + groupName + "'"); Assert.hasText(groupName, "groupName should have text"); Assert.notNull(authority, "authority cannot be null"); - - final int id = findGroupId(groupName); - - getJdbcTemplate().update(deleteGroupAuthoritySql, ps -> { + int id = findGroupId(groupName); + getJdbcTemplate().update(this.deleteGroupAuthoritySql, (ps) -> { ps.setInt(1, id); ps.setString(2, authority.getAuthority()); }); } + @Override public void addGroupAuthority(final String groupName, final GrantedAuthority authority) { - logger.debug("Adding authority '" + authority + "' to group '" + groupName + "'"); + this.logger.debug("Adding authority '" + authority + "' to group '" + groupName + "'"); Assert.hasText(groupName, "groupName should have text"); Assert.notNull(authority, "authority cannot be null"); - - final int id = findGroupId(groupName); - getJdbcTemplate().update(insertGroupAuthoritySql, ps -> { + int id = findGroupId(groupName); + getJdbcTemplate().update(this.insertGroupAuthoritySql, (ps) -> { ps.setInt(1, id); ps.setString(2, authority.getAuthority()); }); } private int findGroupId(String group) { - return getJdbcTemplate().queryForObject(findGroupIdSql, Integer.class, group); + return getJdbcTemplate().queryForObject(this.findGroupIdSql, Integer.class, group); } public void setAuthenticationManager(AuthenticationManager authenticationManager) { @@ -532,7 +522,6 @@ public class JdbcUserDetailsManager extends JdbcDaoImpl implements UserDetailsMa * Optionally sets the UserCache if one is in use in the application. This allows the * user to be removed from the cache after updates have taken place to avoid stale * data. - * * @param userCache the cache used by the AuthenticationManager. */ public void setUserCache(UserCache userCache) { @@ -547,11 +536,10 @@ public class JdbcUserDetailsManager extends JdbcDaoImpl implements UserDetailsMa private void validateAuthorities(Collection authorities) { Assert.notNull(authorities, "Authorities list must not be null"); - for (GrantedAuthority authority : authorities) { Assert.notNull(authority, "Authorities list contains a null entry"); - Assert.hasText(authority.getAuthority(), - "getAuthority() method must return a non-empty string"); + Assert.hasText(authority.getAuthority(), "getAuthority() method must return a non-empty string"); } } + } diff --git a/core/src/main/java/org/springframework/security/provisioning/MutableUser.java b/core/src/main/java/org/springframework/security/provisioning/MutableUser.java index 04d62b42fc..40b28f611b 100644 --- a/core/src/main/java/org/springframework/security/provisioning/MutableUser.java +++ b/core/src/main/java/org/springframework/security/provisioning/MutableUser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.provisioning; import java.util.Collection; @@ -22,7 +23,6 @@ import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.core.userdetails.UserDetails; /** - * * @author Luke Taylor * @since 3.1 */ @@ -31,6 +31,7 @@ class MutableUser implements MutableUserDetails { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private String password; + private final UserDetails delegate; MutableUser(UserDetails user) { @@ -38,35 +39,44 @@ class MutableUser implements MutableUserDetails { this.password = user.getPassword(); } + @Override public String getPassword() { - return password; + return this.password; } + @Override public void setPassword(String password) { this.password = password; } + @Override public Collection getAuthorities() { - return delegate.getAuthorities(); + return this.delegate.getAuthorities(); } + @Override public String getUsername() { - return delegate.getUsername(); + return this.delegate.getUsername(); } + @Override public boolean isAccountNonExpired() { - return delegate.isAccountNonExpired(); + return this.delegate.isAccountNonExpired(); } + @Override public boolean isAccountNonLocked() { - return delegate.isAccountNonLocked(); + return this.delegate.isAccountNonLocked(); } + @Override public boolean isCredentialsNonExpired() { - return delegate.isCredentialsNonExpired(); + return this.delegate.isCredentialsNonExpired(); } + @Override public boolean isEnabled() { - return delegate.isEnabled(); + return this.delegate.isEnabled(); } + } diff --git a/core/src/main/java/org/springframework/security/provisioning/MutableUserDetails.java b/core/src/main/java/org/springframework/security/provisioning/MutableUserDetails.java index 10bc1f2d3a..2a911d668b 100644 --- a/core/src/main/java/org/springframework/security/provisioning/MutableUserDetails.java +++ b/core/src/main/java/org/springframework/security/provisioning/MutableUserDetails.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.provisioning; import org.springframework.security.core.userdetails.UserDetails; /** - * * @author Luke Taylor * @since 3.1 */ diff --git a/core/src/main/java/org/springframework/security/provisioning/UserDetailsManager.java b/core/src/main/java/org/springframework/security/provisioning/UserDetailsManager.java index 730bb6641e..631966c14d 100644 --- a/core/src/main/java/org/springframework/security/provisioning/UserDetailsManager.java +++ b/core/src/main/java/org/springframework/security/provisioning/UserDetailsManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.provisioning; import org.springframework.security.core.userdetails.UserDetails; @@ -45,7 +46,6 @@ public interface UserDetailsManager extends UserDetailsService { /** * Modify the current user's password. This should change the user's password in the * persistent user repository (datbase, LDAP etc). - * * @param oldPassword current password (for re-authentication if required) * @param newPassword the password to change to */ diff --git a/core/src/main/java/org/springframework/security/provisioning/package-info.java b/core/src/main/java/org/springframework/security/provisioning/package-info.java index 40f1195f63..f661aaa89d 100644 --- a/core/src/main/java/org/springframework/security/provisioning/package-info.java +++ b/core/src/main/java/org/springframework/security/provisioning/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Contains simple user and authority group account provisioning interfaces together with a a - * JDBC-based implementation. + * Contains simple user and authority group account provisioning interfaces together with + * a a JDBC-based implementation. */ package org.springframework.security.provisioning; - diff --git a/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextSchedulingTaskExecutor.java b/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextSchedulingTaskExecutor.java index 6fa2a09a18..6c517a472e 100644 --- a/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextSchedulingTaskExecutor.java +++ b/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextSchedulingTaskExecutor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.scheduling; import java.util.concurrent.Callable; @@ -32,21 +33,19 @@ import org.springframework.security.task.DelegatingSecurityContextAsyncTaskExecu * @author Rob Winch * @since 3.2 */ -public class DelegatingSecurityContextSchedulingTaskExecutor extends - DelegatingSecurityContextAsyncTaskExecutor implements SchedulingTaskExecutor { +public class DelegatingSecurityContextSchedulingTaskExecutor extends DelegatingSecurityContextAsyncTaskExecutor + implements SchedulingTaskExecutor { /** * Creates a new {@link DelegatingSecurityContextSchedulingTaskExecutor} that uses the * specified {@link SecurityContext}. - * * @param delegateSchedulingTaskExecutor the {@link SchedulingTaskExecutor} to * delegate to. Cannot be null. * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} and * {@link DelegatingSecurityContextCallable} */ - public DelegatingSecurityContextSchedulingTaskExecutor( - SchedulingTaskExecutor delegateSchedulingTaskExecutor, + public DelegatingSecurityContextSchedulingTaskExecutor(SchedulingTaskExecutor delegateSchedulingTaskExecutor, SecurityContext securityContext) { super(delegateSchedulingTaskExecutor, securityContext); } @@ -54,15 +53,14 @@ public class DelegatingSecurityContextSchedulingTaskExecutor extends /** * Creates a new {@link DelegatingSecurityContextSchedulingTaskExecutor} that uses the * current {@link SecurityContext}. - * * @param delegateAsyncTaskExecutor the {@link AsyncTaskExecutor} to delegate to. * Cannot be null. */ - public DelegatingSecurityContextSchedulingTaskExecutor( - SchedulingTaskExecutor delegateAsyncTaskExecutor) { + public DelegatingSecurityContextSchedulingTaskExecutor(SchedulingTaskExecutor delegateAsyncTaskExecutor) { this(delegateAsyncTaskExecutor, null); } + @Override public boolean prefersShortLivedTasks() { return getDelegate().prefersShortLivedTasks(); } @@ -70,4 +68,5 @@ public class DelegatingSecurityContextSchedulingTaskExecutor extends private SchedulingTaskExecutor getDelegate() { return (SchedulingTaskExecutor) getDelegateExecutor(); } + } diff --git a/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java b/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java index 9087732952..0925b7fa1f 100644 --- a/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java +++ b/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.scheduling; +import java.util.Date; +import java.util.concurrent.ScheduledFuture; + import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; import org.springframework.util.Assert; -import java.util.Date; -import java.util.concurrent.ScheduledFuture; - /** - * An implementation of {@link TaskScheduler} invoking it whenever the trigger - * indicates a next execution time. + * An implementation of {@link TaskScheduler} invoking it whenever the trigger indicates a + * next execution time. * * @author Richard Valdivieso * @since 5.1 @@ -35,7 +36,6 @@ public class DelegatingSecurityContextTaskScheduler implements TaskScheduler { /** * Creates a new {@link DelegatingSecurityContextTaskScheduler} - * * @param taskScheduler the {@link TaskScheduler} */ public DelegatingSecurityContextTaskScheduler(TaskScheduler taskScheduler) { @@ -45,31 +45,32 @@ public class DelegatingSecurityContextTaskScheduler implements TaskScheduler { @Override public ScheduledFuture schedule(Runnable task, Trigger trigger) { - return taskScheduler.schedule(task, trigger); + return this.taskScheduler.schedule(task, trigger); } @Override public ScheduledFuture schedule(Runnable task, Date startTime) { - return taskScheduler.schedule(task, startTime); + return this.taskScheduler.schedule(task, startTime); } @Override public ScheduledFuture scheduleAtFixedRate(Runnable task, Date startTime, long period) { - return taskScheduler.scheduleAtFixedRate(task, startTime, period); + return this.taskScheduler.scheduleAtFixedRate(task, startTime, period); } @Override public ScheduledFuture scheduleAtFixedRate(Runnable task, long period) { - return taskScheduler.scheduleAtFixedRate(task, period); + return this.taskScheduler.scheduleAtFixedRate(task, period); } @Override public ScheduledFuture scheduleWithFixedDelay(Runnable task, Date startTime, long delay) { - return taskScheduler.scheduleWithFixedDelay(task, startTime, delay); + return this.taskScheduler.scheduleWithFixedDelay(task, startTime, delay); } @Override public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { - return taskScheduler.scheduleWithFixedDelay(task, delay); + return this.taskScheduler.scheduleWithFixedDelay(task, delay); } + } diff --git a/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextAsyncTaskExecutor.java b/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextAsyncTaskExecutor.java index f622747b44..3bbba07a0e 100644 --- a/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextAsyncTaskExecutor.java +++ b/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextAsyncTaskExecutor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.task; import java.util.concurrent.Callable; @@ -31,52 +32,50 @@ import org.springframework.security.core.context.SecurityContext; * @author Rob Winch * @since 3.2 */ -public class DelegatingSecurityContextAsyncTaskExecutor extends - DelegatingSecurityContextTaskExecutor implements AsyncTaskExecutor { +public class DelegatingSecurityContextAsyncTaskExecutor extends DelegatingSecurityContextTaskExecutor + implements AsyncTaskExecutor { /** * Creates a new {@link DelegatingSecurityContextAsyncTaskExecutor} that uses the * specified {@link SecurityContext}. - * * @param delegateAsyncTaskExecutor the {@link AsyncTaskExecutor} to delegate to. * Cannot be null. * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} and * {@link DelegatingSecurityContextCallable} */ - public DelegatingSecurityContextAsyncTaskExecutor( - AsyncTaskExecutor delegateAsyncTaskExecutor, SecurityContext securityContext) { + public DelegatingSecurityContextAsyncTaskExecutor(AsyncTaskExecutor delegateAsyncTaskExecutor, + SecurityContext securityContext) { super(delegateAsyncTaskExecutor, securityContext); } /** * Creates a new {@link DelegatingSecurityContextAsyncTaskExecutor} that uses the * current {@link SecurityContext}. - * * @param delegateAsyncTaskExecutor the {@link AsyncTaskExecutor} to delegate to. * Cannot be null. */ - public DelegatingSecurityContextAsyncTaskExecutor( - AsyncTaskExecutor delegateAsyncTaskExecutor) { + public DelegatingSecurityContextAsyncTaskExecutor(AsyncTaskExecutor delegateAsyncTaskExecutor) { this(delegateAsyncTaskExecutor, null); } + @Override public final void execute(Runnable task, long startTimeout) { - task = wrap(task); - getDelegate().execute(task, startTimeout); + getDelegate().execute(wrap(task), startTimeout); } + @Override public final Future submit(Runnable task) { - task = wrap(task); - return getDelegate().submit(task); + return getDelegate().submit(wrap(task)); } + @Override public final Future submit(Callable task) { - task = wrap(task); - return getDelegate().submit(task); + return getDelegate().submit(wrap(task)); } private AsyncTaskExecutor getDelegate() { return (AsyncTaskExecutor) getDelegateExecutor(); } + } diff --git a/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextTaskExecutor.java b/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextTaskExecutor.java index 776b89df3c..c051e014b4 100644 --- a/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextTaskExecutor.java +++ b/core/src/main/java/org/springframework/security/task/DelegatingSecurityContextTaskExecutor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.task; import org.springframework.core.task.TaskExecutor; @@ -28,30 +29,27 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Rob Winch * @since 3.2 */ -public class DelegatingSecurityContextTaskExecutor extends - DelegatingSecurityContextExecutor implements TaskExecutor { +public class DelegatingSecurityContextTaskExecutor extends DelegatingSecurityContextExecutor implements TaskExecutor { + /** * Creates a new {@link DelegatingSecurityContextTaskExecutor} that uses the specified * {@link SecurityContext}. - * * @param delegateTaskExecutor the {@link TaskExecutor} to delegate to. Cannot be * null. * @param securityContext the {@link SecurityContext} to use for each * {@link DelegatingSecurityContextRunnable} */ - public DelegatingSecurityContextTaskExecutor(TaskExecutor delegateTaskExecutor, - SecurityContext securityContext) { + public DelegatingSecurityContextTaskExecutor(TaskExecutor delegateTaskExecutor, SecurityContext securityContext) { super(delegateTaskExecutor, securityContext); } /** * Creates a new {@link DelegatingSecurityContextTaskExecutor} that uses the current * {@link SecurityContext} from the {@link SecurityContextHolder}. - * - * @param delegate the {@link TaskExecutor} to delegate to. Cannot be - * null. + * @param delegate the {@link TaskExecutor} to delegate to. Cannot be null. */ public DelegatingSecurityContextTaskExecutor(TaskExecutor delegate) { this(delegate, null); } -} \ No newline at end of file + +} diff --git a/core/src/main/java/org/springframework/security/util/FieldUtils.java b/core/src/main/java/org/springframework/security/util/FieldUtils.java index 50cd8484d7..261d264b0d 100644 --- a/core/src/main/java/org/springframework/security/util/FieldUtils.java +++ b/core/src/main/java/org/springframework/security/util/FieldUtils.java @@ -16,12 +16,12 @@ package org.springframework.security.util; +import java.lang.reflect.Field; + import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; -import java.lang.reflect.Field; - /** * Offers static methods for directly manipulating fields. * @@ -29,34 +29,28 @@ import java.lang.reflect.Field; */ public final class FieldUtils { - // ~ Methods - // ======================================================================================================== + private FieldUtils() { + } + /** * Attempts to locate the specified field on the class. - * * @param clazz the class definition containing the field * @param fieldName the name of the field to locate - * * @return the Field (never null) - * * @throws IllegalStateException if field could not be found */ - public static Field getField(Class clazz, String fieldName) - throws IllegalStateException { + public static Field getField(Class clazz, String fieldName) throws IllegalStateException { Assert.notNull(clazz, "Class required"); Assert.hasText(fieldName, "Field name required"); - try { return clazz.getDeclaredField(fieldName); } - catch (NoSuchFieldException nsf) { + catch (NoSuchFieldException ex) { // Try superclass if (clazz.getSuperclass() != null) { return getField(clazz.getSuperclass(), fieldName); } - - throw new IllegalStateException("Could not locate field '" + fieldName - + "' on class " + clazz); + throw new IllegalStateException("Could not locate field '" + fieldName + "' on class " + clazz); } } @@ -66,14 +60,12 @@ public final class FieldUtils { * @param fieldName the field name, with "." separating nested properties * @return the value of the nested field */ - public static Object getFieldValue(Object bean, String fieldName) - throws IllegalAccessException { + public static Object getFieldValue(Object bean, String fieldName) throws IllegalAccessException { Assert.notNull(bean, "Bean cannot be null"); Assert.hasText(fieldName, "Field name required"); String[] nestedFields = StringUtils.tokenizeToStringArray(fieldName, "."); Class componentClass = bean.getClass(); Object value = bean; - for (String nestedField : nestedFields) { Field field = getField(componentClass, nestedField); field.setAccessible(true); @@ -82,30 +74,24 @@ public final class FieldUtils { componentClass = value.getClass(); } } - return value; } public static Object getProtectedFieldValue(String protectedField, Object object) { Field field = FieldUtils.getField(object.getClass(), protectedField); - try { field.setAccessible(true); - return field.get(object); } catch (Exception ex) { ReflectionUtils.handleReflectionException(ex); - return null; // unreachable - previous line throws exception } } - public static void setProtectedFieldValue(String protectedField, Object object, - Object newValue) { + public static void setProtectedFieldValue(String protectedField, Object object, Object newValue) { Field field = FieldUtils.getField(object.getClass(), protectedField); - try { field.setAccessible(true); field.set(object, newValue); @@ -114,4 +100,5 @@ public final class FieldUtils { ReflectionUtils.handleReflectionException(ex); } } + } diff --git a/core/src/main/java/org/springframework/security/util/InMemoryResource.java b/core/src/main/java/org/springframework/security/util/InMemoryResource.java index 65fe2f7bc6..30b3158b6a 100644 --- a/core/src/main/java/org/springframework/security/util/InMemoryResource.java +++ b/core/src/main/java/org/springframework/security/util/InMemoryResource.java @@ -16,13 +16,13 @@ package org.springframework.security.util; -import org.springframework.core.io.AbstractResource; -import org.springframework.util.Assert; - import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.Arrays; +import org.springframework.core.io.AbstractResource; +import org.springframework.util.Assert; + /** * An in memory implementation of Spring's {@link org.springframework.core.io.Resource} * interface. @@ -33,14 +33,10 @@ import java.util.Arrays; * @author Luke Taylor */ public class InMemoryResource extends AbstractResource { - // ~ Instance fields - // ================================================================================================ private final byte[] source; - private final String description; - // ~ Constructors - // =================================================================================================== + private final String description; public InMemoryResource(String source) { this(source.getBytes()); @@ -56,22 +52,14 @@ public class InMemoryResource extends AbstractResource { this.description = description; } - // ~ Methods - // ======================================================================================================== - @Override public String getDescription() { - return description; + return this.description; } @Override public InputStream getInputStream() { - return new ByteArrayInputStream(source); - } - - @Override - public int hashCode() { - return 1; + return new ByteArrayInputStream(this.source); } @Override @@ -79,7 +67,12 @@ public class InMemoryResource extends AbstractResource { if (!(res instanceof InMemoryResource)) { return false; } - - return Arrays.equals(source, ((InMemoryResource) res).source); + return Arrays.equals(this.source, ((InMemoryResource) res).source); } + + @Override + public int hashCode() { + return 1; + } + } diff --git a/core/src/main/java/org/springframework/security/util/MethodInvocationUtils.java b/core/src/main/java/org/springframework/security/util/MethodInvocationUtils.java index fa4c7cd8cb..26ade0d092 100644 --- a/core/src/main/java/org/springframework/security/util/MethodInvocationUtils.java +++ b/core/src/main/java/org/springframework/security/util/MethodInvocationUtils.java @@ -19,6 +19,7 @@ package org.springframework.security.util; import java.lang.reflect.Method; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.aop.framework.Advised; import org.springframework.aop.support.AopUtils; import org.springframework.util.Assert; @@ -34,37 +35,30 @@ import org.springframework.util.Assert; */ public final class MethodInvocationUtils { - // ~ Methods - // ======================================================================================================== + private MethodInvocationUtils() { + } /** * Generates a MethodInvocation for specified methodName on * the passed object, using the args to locate the method. - * * @param object the object that will be used to find the relevant Method * @param methodName the name of the method to find * @param args arguments that are required as part of the method signature (can be * empty) - * * @return a MethodInvocation, or null if there was a * problem */ - public static MethodInvocation create(Object object, String methodName, - Object... args) { + public static MethodInvocation create(Object object, String methodName, Object... args) { Assert.notNull(object, "Object required"); - Class[] classArgs = null; - if (args != null) { classArgs = new Class[args.length]; - for (int i = 0; i < args.length; i++) { classArgs[i] = args[i].getClass(); } } - - // Determine the type that declares the requested method, taking into account - // proxies + // Determine the type that declares the requested method, + // taking into account proxies Class target = AopUtils.getTargetClass(object); if (object instanceof Advised) { Advised a = (Advised) object; @@ -77,13 +71,12 @@ public final class MethodInvocationUtils { target = possibleInterface; break; } - catch (Exception ignored) { + catch (Exception ex) { // try the next one } } } } - return createFromClass(object, target, methodName, classArgs, args); } @@ -95,37 +88,29 @@ public final class MethodInvocationUtils { * through the declared methods on the class, until one is found matching the supplied * name. If more than one method name matches, an IllegalArgumentException * will be raised. - * * @param clazz the class of object that will be used to find the relevant * Method * @param methodName the name of the method to find - * * @return a MethodInvocation, or null if there was a * problem */ public static MethodInvocation createFromClass(Class clazz, String methodName) { - MethodInvocation mi = createFromClass(null, clazz, methodName, null, null); - - if (mi == null) { - for (Method m : clazz.getDeclaredMethods()) { - if (m.getName().equals(methodName)) { - if (mi != null) { - throw new IllegalArgumentException("The class " + clazz - + " has more than one method named" + " '" + methodName - + "'"); - } - mi = new SimpleMethodInvocation(null, m); + MethodInvocation invocation = createFromClass(null, clazz, methodName, null, null); + if (invocation == null) { + for (Method method : clazz.getDeclaredMethods()) { + if (method.getName().equals(methodName)) { + Assert.isTrue(invocation == null, + () -> "The class " + clazz + " has more than one method named" + " '" + methodName + "'"); + invocation = new SimpleMethodInvocation(null, method); } } } - - return mi; + return invocation; } /** * Generates a MethodInvocation for specified methodName on * the passed class, using the args to locate the method. - * * @param targetObject the object being invoked * @param clazz the class of object that will be used to find the relevant * Method @@ -136,20 +121,17 @@ public final class MethodInvocationUtils { * @return a MethodInvocation, or null if there was a * problem */ - public static MethodInvocation createFromClass(Object targetObject, Class clazz, - String methodName, Class[] classArgs, Object[] args) { + public static MethodInvocation createFromClass(Object targetObject, Class clazz, String methodName, + Class[] classArgs, Object[] args) { Assert.notNull(clazz, "Class required"); Assert.hasText(methodName, "MethodName required"); - - Method method; - try { - method = clazz.getMethod(methodName, classArgs); + Method method = clazz.getMethod(methodName, classArgs); + return new SimpleMethodInvocation(targetObject, method, args); } - catch (NoSuchMethodException e) { + catch (NoSuchMethodException ex) { return null; } - - return new SimpleMethodInvocation(targetObject, method, args); } + } diff --git a/core/src/main/java/org/springframework/security/util/SimpleMethodInvocation.java b/core/src/main/java/org/springframework/security/util/SimpleMethodInvocation.java index a930b22f2f..37977d3881 100644 --- a/core/src/main/java/org/springframework/security/util/SimpleMethodInvocation.java +++ b/core/src/main/java/org/springframework/security/util/SimpleMethodInvocation.java @@ -16,56 +16,56 @@ package org.springframework.security.util; -import org.aopalliance.intercept.MethodInvocation; - import java.lang.reflect.AccessibleObject; import java.lang.reflect.Method; +import org.aopalliance.intercept.MethodInvocation; + /** * Represents the AOP Alliance MethodInvocation. * * @author Ben Alex */ public class SimpleMethodInvocation implements MethodInvocation { - // ~ Instance fields - // ================================================================================================ private Method method; - private Object[] arguments; - private Object targetObject; - // ~ Constructors - // =================================================================================================== + private Object[] arguments; + + private Object targetObject; public SimpleMethodInvocation(Object targetObject, Method method, Object... arguments) { this.targetObject = targetObject; this.method = method; - this.arguments = arguments == null ? new Object[0] : arguments; + this.arguments = (arguments != null) ? arguments : new Object[0]; } public SimpleMethodInvocation() { } - // ~ Methods - // ======================================================================================================== - + @Override public Object[] getArguments() { - return arguments; + return this.arguments; } + @Override public Method getMethod() { - return method; + return this.method; } + @Override public AccessibleObject getStaticPart() { throw new UnsupportedOperationException("mock method not implemented"); } + @Override public Object getThis() { - return targetObject; + return this.targetObject; } + @Override public Object proceed() { throw new UnsupportedOperationException("mock method not implemented"); } + } diff --git a/core/src/main/java/org/springframework/security/util/package-info.java b/core/src/main/java/org/springframework/security/util/package-info.java index fba13f4bed..e8cbd0733d 100644 --- a/core/src/main/java/org/springframework/security/util/package-info.java +++ b/core/src/main/java/org/springframework/security/util/package-info.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * General utility classes used throughout the Spring Security framework. Intended for internal use. + * General utility classes used throughout the Spring Security framework. Intended for + * internal use. *

      - * This package should be standalone - it should not have dependencies on other parts of the framework, - * just on external libraries and the JDK. + * This package should be standalone - it should not have dependencies on other parts of + * the framework, just on external libraries and the JDK. */ package org.springframework.security.util; - diff --git a/core/src/test/java/org/springframework/security/ITargetObject.java b/core/src/test/java/org/springframework/security/ITargetObject.java index f0460b6233..b091350bb9 100644 --- a/core/src/test/java/org/springframework/security/ITargetObject.java +++ b/core/src/test/java/org/springframework/security/ITargetObject.java @@ -22,8 +22,6 @@ package org.springframework.security; * @author Ben Alex */ public interface ITargetObject { - // ~ Methods - // ======================================================================================================== Integer computeHashCode(String input); @@ -34,4 +32,5 @@ public interface ITargetObject { String makeUpperCase(String input); String publicMakeLowerCase(String input); + } diff --git a/core/src/test/java/org/springframework/security/OtherTargetObject.java b/core/src/test/java/org/springframework/security/OtherTargetObject.java index ae0d240d87..d57d1f41f4 100644 --- a/core/src/test/java/org/springframework/security/OtherTargetObject.java +++ b/core/src/test/java/org/springframework/security/OtherTargetObject.java @@ -34,18 +34,20 @@ package org.springframework.security; * @author Ben Alex */ public class OtherTargetObject extends TargetObject implements ITargetObject { - // ~ Methods - // ======================================================================================================== + @Override public String makeLowerCase(String input) { return super.makeLowerCase(input); } + @Override public String makeUpperCase(String input) { return super.makeUpperCase(input); } + @Override public String publicMakeLowerCase(String input) { return super.publicMakeLowerCase(input); } + } diff --git a/core/src/test/java/org/springframework/security/PopulatedDatabase.java b/core/src/test/java/org/springframework/security/PopulatedDatabase.java index af67a37170..2ff999577a 100644 --- a/core/src/test/java/org/springframework/security/PopulatedDatabase.java +++ b/core/src/test/java/org/springframework/security/PopulatedDatabase.java @@ -16,48 +16,42 @@ package org.springframework.security; -import org.springframework.jdbc.core.JdbcTemplate; - import javax.sql.DataSource; +import org.springframework.jdbc.core.JdbcTemplate; + /** * Singleton which provides a populated database connection for all JDBC-related unit * tests. * * @author Ben Alex */ -public class PopulatedDatabase { - // ~ Static fields/initializers - // ===================================================================================== +public final class PopulatedDatabase { private static TestDataSource dataSource = null; - // ~ Constructors - // =================================================================================================== - private PopulatedDatabase() { } - // ~ Methods - // ======================================================================================================== - public static DataSource getDataSource() { if (dataSource == null) { setupDataSource(); } - return dataSource; } private static void setupDataSource() { dataSource = new TestDataSource("springsecuritytest"); JdbcTemplate template = new JdbcTemplate(dataSource); - - template.execute("CREATE TABLE USERS(USERNAME VARCHAR_IGNORECASE(50) NOT NULL PRIMARY KEY,PASSWORD VARCHAR_IGNORECASE(500) NOT NULL,ENABLED BOOLEAN NOT NULL)"); - template.execute("CREATE TABLE AUTHORITIES(USERNAME VARCHAR_IGNORECASE(50) NOT NULL,AUTHORITY VARCHAR_IGNORECASE(50) NOT NULL,CONSTRAINT FK_AUTHORITIES_USERS FOREIGN KEY(USERNAME) REFERENCES USERS(USERNAME))"); + template.execute( + "CREATE TABLE USERS(USERNAME VARCHAR_IGNORECASE(50) NOT NULL PRIMARY KEY,PASSWORD VARCHAR_IGNORECASE(500) NOT NULL,ENABLED BOOLEAN NOT NULL)"); + template.execute( + "CREATE TABLE AUTHORITIES(USERNAME VARCHAR_IGNORECASE(50) NOT NULL,AUTHORITY VARCHAR_IGNORECASE(50) NOT NULL,CONSTRAINT FK_AUTHORITIES_USERS FOREIGN KEY(USERNAME) REFERENCES USERS(USERNAME))"); template.execute("CREATE UNIQUE INDEX IX_AUTH_USERNAME ON AUTHORITIES(USERNAME,AUTHORITY)"); - template.execute("CREATE TABLE ACL_OBJECT_IDENTITY(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) NOT NULL PRIMARY KEY,OBJECT_IDENTITY VARCHAR_IGNORECASE(250) NOT NULL,PARENT_OBJECT BIGINT,ACL_CLASS VARCHAR_IGNORECASE(250) NOT NULL,CONSTRAINT UNIQUE_OBJECT_IDENTITY UNIQUE(OBJECT_IDENTITY),CONSTRAINT SYS_FK_3 FOREIGN KEY(PARENT_OBJECT) REFERENCES ACL_OBJECT_IDENTITY(ID))"); - template.execute("CREATE TABLE ACL_PERMISSION(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) NOT NULL PRIMARY KEY,ACL_OBJECT_IDENTITY BIGINT NOT NULL,RECIPIENT VARCHAR_IGNORECASE(100) NOT NULL,MASK INTEGER NOT NULL,CONSTRAINT UNIQUE_RECIPIENT UNIQUE(ACL_OBJECT_IDENTITY,RECIPIENT),CONSTRAINT SYS_FK_7 FOREIGN KEY(ACL_OBJECT_IDENTITY) REFERENCES ACL_OBJECT_IDENTITY(ID))"); + template.execute( + "CREATE TABLE ACL_OBJECT_IDENTITY(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) NOT NULL PRIMARY KEY,OBJECT_IDENTITY VARCHAR_IGNORECASE(250) NOT NULL,PARENT_OBJECT BIGINT,ACL_CLASS VARCHAR_IGNORECASE(250) NOT NULL,CONSTRAINT UNIQUE_OBJECT_IDENTITY UNIQUE(OBJECT_IDENTITY),CONSTRAINT SYS_FK_3 FOREIGN KEY(PARENT_OBJECT) REFERENCES ACL_OBJECT_IDENTITY(ID))"); + template.execute( + "CREATE TABLE ACL_PERMISSION(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) NOT NULL PRIMARY KEY,ACL_OBJECT_IDENTITY BIGINT NOT NULL,RECIPIENT VARCHAR_IGNORECASE(100) NOT NULL,MASK INTEGER NOT NULL,CONSTRAINT UNIQUE_RECIPIENT UNIQUE(ACL_OBJECT_IDENTITY,RECIPIENT),CONSTRAINT SYS_FK_7 FOREIGN KEY(ACL_OBJECT_IDENTITY) REFERENCES ACL_OBJECT_IDENTITY(ID))"); template.execute("SET IGNORECASE TRUE"); template.execute("INSERT INTO USERS VALUES('dianne','emu',TRUE)"); template.execute("INSERT INTO USERS VALUES('rod','koala',TRUE)"); @@ -69,44 +63,49 @@ public class PopulatedDatabase { template.execute("INSERT INTO AUTHORITIES VALUES('dianne','ROLE_TELLER')"); template.execute("INSERT INTO AUTHORITIES VALUES('scott','ROLE_TELLER')"); template.execute("INSERT INTO AUTHORITIES VALUES('peter','ROLE_TELLER')"); - template.execute("INSERT INTO acl_object_identity VALUES (1, 'org.springframework.security.acl.DomainObject:1', null, 'org.springframework.security.acl.basic.SimpleAclEntry');"); - template.execute("INSERT INTO acl_object_identity VALUES (2, 'org.springframework.security.acl.DomainObject:2', 1, 'org.springframework.security.acl.basic.SimpleAclEntry');"); - template.execute("INSERT INTO acl_object_identity VALUES (3, 'org.springframework.security.acl.DomainObject:3', 1, 'org.springframework.security.acl.basic.SimpleAclEntry');"); - template.execute("INSERT INTO acl_object_identity VALUES (4, 'org.springframework.security.acl.DomainObject:4', 1, 'org.springframework.security.acl.basic.SimpleAclEntry');"); - template.execute("INSERT INTO acl_object_identity VALUES (5, 'org.springframework.security.acl.DomainObject:5', 3, 'org.springframework.security.acl.basic.SimpleAclEntry');"); - template.execute("INSERT INTO acl_object_identity VALUES (6, 'org.springframework.security.acl.DomainObject:6', 3, 'org.springframework.security.acl.basic.SimpleAclEntry');"); - + template.execute( + "INSERT INTO acl_object_identity VALUES (1, 'org.springframework.security.acl.DomainObject:1', null, 'org.springframework.security.acl.basic.SimpleAclEntry');"); + template.execute( + "INSERT INTO acl_object_identity VALUES (2, 'org.springframework.security.acl.DomainObject:2', 1, 'org.springframework.security.acl.basic.SimpleAclEntry');"); + template.execute( + "INSERT INTO acl_object_identity VALUES (3, 'org.springframework.security.acl.DomainObject:3', 1, 'org.springframework.security.acl.basic.SimpleAclEntry');"); + template.execute( + "INSERT INTO acl_object_identity VALUES (4, 'org.springframework.security.acl.DomainObject:4', 1, 'org.springframework.security.acl.basic.SimpleAclEntry');"); + template.execute( + "INSERT INTO acl_object_identity VALUES (5, 'org.springframework.security.acl.DomainObject:5', 3, 'org.springframework.security.acl.basic.SimpleAclEntry');"); + template.execute( + "INSERT INTO acl_object_identity VALUES (6, 'org.springframework.security.acl.DomainObject:6', 3, 'org.springframework.security.acl.basic.SimpleAclEntry');"); // ----- BEGIN deviation from normal sample data load script ----- - template.execute("INSERT INTO acl_object_identity VALUES (7, 'org.springframework.security.acl.DomainObject:7', 3, 'some.invalid.acl.entry.class');"); - + template.execute( + "INSERT INTO acl_object_identity VALUES (7, 'org.springframework.security.acl.DomainObject:7', 3, 'some.invalid.acl.entry.class');"); // ----- FINISH deviation from normal sample data load script ----- template.execute("INSERT INTO acl_permission VALUES (null, 1, 'ROLE_SUPERVISOR', 1);"); template.execute("INSERT INTO acl_permission VALUES (null, 2, 'ROLE_SUPERVISOR', 0);"); template.execute("INSERT INTO acl_permission VALUES (null, 2, 'rod', 2);"); template.execute("INSERT INTO acl_permission VALUES (null, 3, 'scott', 14);"); template.execute("INSERT INTO acl_permission VALUES (null, 6, 'scott', 1);"); - createGroupTables(template); insertGroupData(template); } public static void createGroupTables(JdbcTemplate template) { // Group tables and data - template.execute("CREATE TABLE GROUPS(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) PRIMARY KEY, GROUP_NAME VARCHAR_IGNORECASE(50) NOT NULL)"); - template.execute("CREATE TABLE GROUP_AUTHORITIES(GROUP_ID BIGINT NOT NULL, AUTHORITY VARCHAR(50) NOT NULL, CONSTRAINT FK_GROUP_AUTHORITIES_GROUP FOREIGN KEY(GROUP_ID) REFERENCES GROUPS(ID))"); - template.execute("CREATE TABLE GROUP_MEMBERS(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) PRIMARY KEY, USERNAME VARCHAR(50) NOT NULL, GROUP_ID BIGINT NOT NULL, CONSTRAINT FK_GROUP_MEMBERS_GROUP FOREIGN KEY(GROUP_ID) REFERENCES GROUPS(ID))"); + template.execute( + "CREATE TABLE GROUPS(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) PRIMARY KEY, GROUP_NAME VARCHAR_IGNORECASE(50) NOT NULL)"); + template.execute( + "CREATE TABLE GROUP_AUTHORITIES(GROUP_ID BIGINT NOT NULL, AUTHORITY VARCHAR(50) NOT NULL, CONSTRAINT FK_GROUP_AUTHORITIES_GROUP FOREIGN KEY(GROUP_ID) REFERENCES GROUPS(ID))"); + template.execute( + "CREATE TABLE GROUP_MEMBERS(ID BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 0) PRIMARY KEY, USERNAME VARCHAR(50) NOT NULL, GROUP_ID BIGINT NOT NULL, CONSTRAINT FK_GROUP_MEMBERS_GROUP FOREIGN KEY(GROUP_ID) REFERENCES GROUPS(ID))"); } public static void insertGroupData(JdbcTemplate template) { template.execute("INSERT INTO USERS VALUES('jerry','password',TRUE)"); template.execute("INSERT INTO USERS VALUES('tom','password',TRUE)"); - template.execute("INSERT INTO GROUPS VALUES (0, 'GROUP_0')"); template.execute("INSERT INTO GROUPS VALUES (1, 'GROUP_1')"); template.execute("INSERT INTO GROUPS VALUES (2, 'GROUP_2')"); // Group 3 isn't used template.execute("INSERT INTO GROUPS VALUES (3, 'GROUP_3')"); - template.execute("INSERT INTO GROUP_AUTHORITIES VALUES (0, 'ROLE_A')"); template.execute("INSERT INTO GROUP_AUTHORITIES VALUES (1, 'ROLE_B')"); template.execute("INSERT INTO GROUP_AUTHORITIES VALUES (1, 'ROLE_C')"); @@ -115,11 +114,11 @@ public class PopulatedDatabase { template.execute("INSERT INTO GROUP_AUTHORITIES VALUES (2, 'ROLE_C')"); template.execute("INSERT INTO GROUP_AUTHORITIES VALUES (3, 'ROLE_D')"); template.execute("INSERT INTO GROUP_AUTHORITIES VALUES (3, 'ROLE_E')"); - template.execute("INSERT INTO GROUP_MEMBERS VALUES (0, 'jerry', 0)"); template.execute("INSERT INTO GROUP_MEMBERS VALUES (1, 'jerry', 1)"); // tom has groups with overlapping roles template.execute("INSERT INTO GROUP_MEMBERS VALUES (2, 'tom', 1)"); template.execute("INSERT INTO GROUP_MEMBERS VALUES (3, 'tom', 2)"); } + } diff --git a/core/src/test/java/org/springframework/security/TargetObject.java b/core/src/test/java/org/springframework/security/TargetObject.java index 612d401b59..b936d69043 100644 --- a/core/src/test/java/org/springframework/security/TargetObject.java +++ b/core/src/test/java/org/springframework/security/TargetObject.java @@ -25,62 +25,57 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Ben Alex */ public class TargetObject implements ITargetObject { - // ~ Methods - // ======================================================================================================== + @Override public Integer computeHashCode(String input) { return input.hashCode(); } + @Override public int countLength(String input) { return input.length(); } /** * Returns the lowercase string, followed by security environment information. - * * @param input the message to make lowercase - * * @return the lowercase message, a space, the Authentication class that * was on the SecurityContext at the time of method invocation, and a * boolean indicating if the Authentication object is authenticated or * not */ + @Override public String makeLowerCase(String input) { Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - if (auth == null) { return input.toLowerCase() + " Authentication empty"; } else { - return input.toLowerCase() + " " + auth.getClass().getName() + " " - + auth.isAuthenticated(); + return input.toLowerCase() + " " + auth.getClass().getName() + " " + auth.isAuthenticated(); } } /** * Returns the uppercase string, followed by security environment information. - * * @param input the message to make uppercase - * * @return the uppercase message, a space, the Authentication class that * was on the SecurityContext at the time of method invocation, and a * boolean indicating if the Authentication object is authenticated or * not */ + @Override public String makeUpperCase(String input) { Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - - return input.toUpperCase() + " " + auth.getClass().getName() + " " - + auth.isAuthenticated(); + return input.toUpperCase() + " " + auth.getClass().getName() + " " + auth.isAuthenticated(); } /** * Delegates through to the {@link #makeLowerCase(String)} method. - * * @param input the message to be made lower-case */ + @Override public String publicMakeLowerCase(String input) { return this.makeLowerCase(input); } + } diff --git a/core/src/test/java/org/springframework/security/TestDataSource.java b/core/src/test/java/org/springframework/security/TestDataSource.java index 53e79e12f0..9a81dec420 100644 --- a/core/src/test/java/org/springframework/security/TestDataSource.java +++ b/core/src/test/java/org/springframework/security/TestDataSource.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security; -import org.springframework.jdbc.datasource.DriverManagerDataSource; -import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.beans.factory.DisposableBean; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DriverManagerDataSource; /** * A Datasource bean which starts an in-memory HSQL database with the supplied name and @@ -26,19 +27,22 @@ import org.springframework.beans.factory.DisposableBean; * @author Luke Taylor */ public class TestDataSource extends DriverManagerDataSource implements DisposableBean { + String name; public TestDataSource(String databaseName) { - name = databaseName; - System.out.println("Creating database: " + name); + this.name = databaseName; + System.out.println("Creating database: " + this.name); setDriverClassName("org.hsqldb.jdbcDriver"); setUrl("jdbc:hsqldb:mem:" + databaseName); setUsername("sa"); setPassword(""); } + @Override public void destroy() { - System.out.println("Shutting down database: " + name); + System.out.println("Shutting down database: " + this.name); new JdbcTemplate(this).execute("SHUTDOWN"); } + } diff --git a/core/src/test/java/org/springframework/security/access/AuthenticationCredentialsNotFoundEventTests.java b/core/src/test/java/org/springframework/security/access/AuthenticationCredentialsNotFoundEventTests.java index 29e8ba822c..e694b06466 100644 --- a/core/src/test/java/org/springframework/security/access/AuthenticationCredentialsNotFoundEventTests.java +++ b/core/src/test/java/org/springframework/security/access/AuthenticationCredentialsNotFoundEventTests.java @@ -17,7 +17,7 @@ package org.springframework.security.access; import org.junit.Test; -import org.springframework.security.access.SecurityConfig; + import org.springframework.security.access.event.AuthenticationCredentialsNotFoundEvent; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.util.SimpleMethodInvocation; @@ -31,8 +31,7 @@ public class AuthenticationCredentialsNotFoundEventTests { @Test(expected = IllegalArgumentException.class) public void testRejectsNulls() { - new AuthenticationCredentialsNotFoundEvent(null, - SecurityConfig.createList("TEST"), + new AuthenticationCredentialsNotFoundEvent(null, SecurityConfig.createList("TEST"), new AuthenticationCredentialsNotFoundException("test")); } @@ -44,7 +43,8 @@ public class AuthenticationCredentialsNotFoundEventTests { @Test(expected = IllegalArgumentException.class) public void testRejectsNulls3() { - new AuthenticationCredentialsNotFoundEvent(new SimpleMethodInvocation(), - SecurityConfig.createList("TEST"), null); + new AuthenticationCredentialsNotFoundEvent(new SimpleMethodInvocation(), SecurityConfig.createList("TEST"), + null); } + } diff --git a/core/src/test/java/org/springframework/security/access/AuthorizationFailureEventTests.java b/core/src/test/java/org/springframework/security/access/AuthorizationFailureEventTests.java index 31a198fddf..b6f5766cfe 100644 --- a/core/src/test/java/org/springframework/security/access/AuthorizationFailureEventTests.java +++ b/core/src/test/java/org/springframework/security/access/AuthorizationFailureEventTests.java @@ -16,14 +16,15 @@ package org.springframework.security.access; -import static org.assertj.core.api.Assertions.assertThat; +import java.util.List; import org.junit.Test; + import org.springframework.security.access.event.AuthorizationFailureEvent; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.util.SimpleMethodInvocation; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; /** * Tests {@link AuthorizationFailureEvent}. @@ -31,39 +32,40 @@ import java.util.*; * @author Ben Alex */ public class AuthorizationFailureEventTests { - private final UsernamePasswordAuthenticationToken foo = new UsernamePasswordAuthenticationToken( - "foo", "bar"); + + private final UsernamePasswordAuthenticationToken foo = new UsernamePasswordAuthenticationToken("foo", "bar"); + private List attributes = SecurityConfig.createList("TEST"); - private AccessDeniedException exception = new AuthorizationServiceException("error", - new Throwable()); + + private AccessDeniedException exception = new AuthorizationServiceException("error", new Throwable()); @Test(expected = IllegalArgumentException.class) public void rejectsNullSecureObject() { - new AuthorizationFailureEvent(null, attributes, foo, exception); + new AuthorizationFailureEvent(null, this.attributes, this.foo, this.exception); } @Test(expected = IllegalArgumentException.class) public void rejectsNullAttributesList() { - new AuthorizationFailureEvent(new SimpleMethodInvocation(), null, foo, exception); + new AuthorizationFailureEvent(new SimpleMethodInvocation(), null, this.foo, this.exception); } @Test(expected = IllegalArgumentException.class) public void rejectsNullAuthentication() { - new AuthorizationFailureEvent(new SimpleMethodInvocation(), attributes, null, - exception); + new AuthorizationFailureEvent(new SimpleMethodInvocation(), this.attributes, null, this.exception); } @Test(expected = IllegalArgumentException.class) public void rejectsNullException() { - new AuthorizationFailureEvent(new SimpleMethodInvocation(), attributes, foo, null); + new AuthorizationFailureEvent(new SimpleMethodInvocation(), this.attributes, this.foo, null); } @Test public void gettersReturnCtorSuppliedData() { - AuthorizationFailureEvent event = new AuthorizationFailureEvent(new Object(), - attributes, foo, exception); - assertThat(event.getConfigAttributes()).isSameAs(attributes); - assertThat(event.getAccessDeniedException()).isSameAs(exception); - assertThat(event.getAuthentication()).isSameAs(foo); + AuthorizationFailureEvent event = new AuthorizationFailureEvent(new Object(), this.attributes, this.foo, + this.exception); + assertThat(event.getConfigAttributes()).isSameAs(this.attributes); + assertThat(event.getAccessDeniedException()).isSameAs(this.exception); + assertThat(event.getAuthentication()).isSameAs(this.foo); } + } diff --git a/core/src/test/java/org/springframework/security/access/AuthorizedEventTests.java b/core/src/test/java/org/springframework/security/access/AuthorizedEventTests.java index dff9e761ae..9b6fee171e 100644 --- a/core/src/test/java/org/springframework/security/access/AuthorizedEventTests.java +++ b/core/src/test/java/org/springframework/security/access/AuthorizedEventTests.java @@ -17,7 +17,7 @@ package org.springframework.security.access; import org.junit.Test; -import org.springframework.security.access.SecurityConfig; + import org.springframework.security.access.event.AuthorizedEvent; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.util.SimpleMethodInvocation; @@ -37,14 +37,12 @@ public class AuthorizedEventTests { @Test(expected = IllegalArgumentException.class) public void testRejectsNulls2() { - - new AuthorizedEvent(new SimpleMethodInvocation(), null, - new UsernamePasswordAuthenticationToken("foo", "bar")); + new AuthorizedEvent(new SimpleMethodInvocation(), null, new UsernamePasswordAuthenticationToken("foo", "bar")); } @Test(expected = IllegalArgumentException.class) public void testRejectsNulls3() { - new AuthorizedEvent(new SimpleMethodInvocation(), - SecurityConfig.createList("TEST"), null); + new AuthorizedEvent(new SimpleMethodInvocation(), SecurityConfig.createList("TEST"), null); } + } diff --git a/core/src/test/java/org/springframework/security/access/SecurityConfigTests.java b/core/src/test/java/org/springframework/security/access/SecurityConfigTests.java index 3dc60932fd..7cc22aff20 100644 --- a/core/src/test/java/org/springframework/security/access/SecurityConfigTests.java +++ b/core/src/test/java/org/springframework/security/access/SecurityConfigTests.java @@ -16,13 +16,10 @@ package org.springframework.security.access; +import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import org.junit.Test; -import org.springframework.security.access.ConfigAttribute; -import org.springframework.security.access.SecurityConfig; - /** * Tests {@link SecurityConfig}. * @@ -30,9 +27,6 @@ import org.springframework.security.access.SecurityConfig; */ public class SecurityConfigTests { - // ~ Methods - // ======================================================================================================== - @Test public void testHashCode() { SecurityConfig config = new SecurityConfig("TEST"); @@ -59,23 +53,17 @@ public class SecurityConfigTests { SecurityConfig security1 = new SecurityConfig("TEST"); SecurityConfig security2 = new SecurityConfig("TEST"); assertThat(security2).isEqualTo(security1); - // SEC-311: Must observe symmetry requirement of Object.equals(Object) contract String securityString1 = "TEST"; assertThat(securityString1).isNotSameAs(security1); - String securityString2 = "NOT_EQUAL"; assertThat(!security1.equals(securityString2)).isTrue(); - SecurityConfig security3 = new SecurityConfig("NOT_EQUAL"); assertThat(!security1.equals(security3)).isTrue(); - MockConfigAttribute mock1 = new MockConfigAttribute("TEST"); assertThat(security1).isEqualTo(mock1); - MockConfigAttribute mock2 = new MockConfigAttribute("NOT_EQUAL"); assertThat(security1).isNotEqualTo(mock2); - Integer int1 = 987; assertThat(security1).isNotEqualTo(int1); } @@ -86,18 +74,19 @@ public class SecurityConfigTests { assertThat(config.toString()).isEqualTo("TEST"); } - // ~ Inner Classes - // ================================================================================================== - private class MockConfigAttribute implements ConfigAttribute { + private String attribute; MockConfigAttribute(String configuration) { this.attribute = configuration; } + @Override public String getAttribute() { return this.attribute; } + } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/BusinessService.java b/core/src/test/java/org/springframework/security/access/annotation/BusinessService.java index 2494083b2e..16c63b7e28 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/BusinessService.java +++ b/core/src/test/java/org/springframework/security/access/annotation/BusinessService.java @@ -29,8 +29,6 @@ import org.springframework.security.access.prepost.PreAuthorize; @Secured({ "ROLE_USER" }) @PermitAll public interface BusinessService extends Serializable { - // ~ Methods - // ======================================================================================================== @Secured({ "ROLE_ADMIN" }) @RolesAllowed({ "ROLE_ADMIN" }) diff --git a/core/src/test/java/org/springframework/security/access/annotation/BusinessServiceImpl.java b/core/src/test/java/org/springframework/security/access/annotation/BusinessServiceImpl.java index 5d20ea9a4a..0e732bf480 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/BusinessServiceImpl.java +++ b/core/src/test/java/org/springframework/security/access/annotation/BusinessServiceImpl.java @@ -13,29 +13,33 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; import java.util.ArrayList; import java.util.List; /** - * * @author Joe Scalise */ public class BusinessServiceImpl implements BusinessService { + @Override @Secured({ "ROLE_USER" }) public void someUserMethod1() { } + @Override @Secured({ "ROLE_USER" }) public void someUserMethod2() { } + @Override @Secured({ "ROLE_USER", "ROLE_ADMIN" }) public void someUserAndAdminMethod() { } + @Override @Secured({ "ROLE_ADMIN" }) public void someAdminMethod() { } @@ -44,27 +48,33 @@ public class BusinessServiceImpl implements BusinessService { return entity; } + @Override public int someOther(String s) { return 0; } + @Override public int someOther(int input) { return input; } + @Override public List methodReturningAList(List someList) { return someList; } + @Override public List methodReturningAList(String userName, String arg2) { return new ArrayList<>(); } + @Override public Object[] methodReturningAnArray(Object[] someArray) { return null; } + @Override public void rolesAllowedUser() { - } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/Entity.java b/core/src/test/java/org/springframework/security/access/annotation/Entity.java index 7d68adb48f..b6bc9fd7a9 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/Entity.java +++ b/core/src/test/java/org/springframework/security/access/annotation/Entity.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; /** @@ -22,6 +23,8 @@ package org.springframework.security.access.annotation; * */ public class Entity { + public Entity(String someParameter) { } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/ExpressionProtectedBusinessServiceImpl.java b/core/src/test/java/org/springframework/security/access/annotation/ExpressionProtectedBusinessServiceImpl.java index 9652dab648..9d1b066d01 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/ExpressionProtectedBusinessServiceImpl.java +++ b/core/src/test/java/org/springframework/security/access/annotation/ExpressionProtectedBusinessServiceImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; import java.util.ArrayList; @@ -24,36 +25,45 @@ import org.springframework.security.access.prepost.PreFilter; public class ExpressionProtectedBusinessServiceImpl implements BusinessService { + @Override public void someAdminMethod() { } + @Override public int someOther(String s) { return 0; } + @Override public int someOther(int input) { return 0; } + @Override public void someUserAndAdminMethod() { } + @Override public void someUserMethod1() { } + @Override public void someUserMethod2() { } + @Override @PreFilter(filterTarget = "someList", value = "filterObject == authentication.name or filterObject == 'sam'") @PostFilter("filterObject == 'bob'") public List methodReturningAList(List someList) { return someList; } + @Override public List methodReturningAList(String userName, String arg2) { return new ArrayList<>(); } + @Override @PostFilter("filterObject == 'bob'") public Object[] methodReturningAnArray(Object[] someArray) { return someArray; @@ -61,10 +71,10 @@ public class ExpressionProtectedBusinessServiceImpl implements BusinessService { @PreAuthorize("#x == 'x' and @number.intValue() == 1294 ") public void methodWithBeanNamePropertyAccessExpression(String x) { - } + @Override public void rolesAllowedUser() { - } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/Jsr250BusinessServiceImpl.java b/core/src/test/java/org/springframework/security/access/annotation/Jsr250BusinessServiceImpl.java index c09c0bab3c..09aa5ae48c 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/Jsr250BusinessServiceImpl.java +++ b/core/src/test/java/org/springframework/security/access/annotation/Jsr250BusinessServiceImpl.java @@ -13,59 +13,69 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation; import java.util.ArrayList; import java.util.List; -import javax.annotation.security.RolesAllowed; import javax.annotation.security.PermitAll; +import javax.annotation.security.RolesAllowed; /** - * * @author Luke Taylor */ @PermitAll public class Jsr250BusinessServiceImpl implements BusinessService { + @Override @RolesAllowed("ROLE_USER") public void someUserMethod1() { } + @Override @RolesAllowed("ROLE_USER") public void someUserMethod2() { } + @Override @RolesAllowed({ "ROLE_USER", "ROLE_ADMIN" }) public void someUserAndAdminMethod() { } + @Override @RolesAllowed("ROLE_ADMIN") public void someAdminMethod() { } + @Override public int someOther(String input) { return 0; } + @Override public int someOther(int input) { return input; } + @Override public List methodReturningAList(List someList) { return someList; } + @Override public List methodReturningAList(String userName, String arg2) { return new ArrayList<>(); } + @Override public Object[] methodReturningAnArray(Object[] someArray) { return null; } + @Override @RolesAllowed({ "USER" }) public void rolesAllowedUser() { - } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSourceTests.java b/core/src/test/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSourceTests.java index 2e2597160b..642674a887 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSourceTests.java +++ b/core/src/test/java/org/springframework/security/access/annotation/Jsr250MethodSecurityMetadataSourceTests.java @@ -49,8 +49,7 @@ public class Jsr250MethodSecurityMetadataSourceTests { } private ConfigAttribute[] findAttributes(String methodName) throws Exception { - return this.mds.findAttributes(this.a.getClass().getMethod(methodName), null) - .toArray(new ConfigAttribute[0]); + return this.mds.findAttributes(this.a.getClass().getMethod(methodName), null).toArray(new ConfigAttribute[0]); } @Test @@ -64,8 +63,7 @@ public class Jsr250MethodSecurityMetadataSourceTests { public void permitAllMethodHasPermitAllAttribute() throws Exception { ConfigAttribute[] accessAttributes = findAttributes("permitAllMethod"); assertThat(accessAttributes).hasSize(1); - assertThat(accessAttributes[0].toString()) - .isEqualTo("javax.annotation.security.PermitAll"); + assertThat(accessAttributes[0].toString()).isEqualTo("javax.annotation.security.PermitAll"); } @Test @@ -77,15 +75,15 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void classRoleIsAppliedToNoRoleMethod() throws Exception { - Collection accessAttributes = this.mds.findAttributes( - this.userAllowed.getClass().getMethod("noRoleMethod"), null); + Collection accessAttributes = this.mds + .findAttributes(this.userAllowed.getClass().getMethod("noRoleMethod"), null); assertThat(accessAttributes).isNull(); } @Test public void methodRoleOverridesClassRole() throws Exception { - Collection accessAttributes = this.mds.findAttributes( - this.userAllowed.getClass().getMethod("adminMethod"), null); + Collection accessAttributes = this.mds + .findAttributes(this.userAllowed.getClass().getMethod("adminMethod"), null); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes.toArray()[0].toString()).isEqualTo("ROLE_ADMIN"); } @@ -93,7 +91,6 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void customDefaultRolePrefix() throws Exception { this.mds.setDefaultRolePrefix("CUSTOMPREFIX_"); - ConfigAttribute[] accessAttributes = findAttributes("adminMethod"); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes[0].toString()).isEqualTo("CUSTOMPREFIX_ADMIN"); @@ -102,7 +99,6 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void emptyDefaultRolePrefix() throws Exception { this.mds.setDefaultRolePrefix(""); - ConfigAttribute[] accessAttributes = findAttributes("adminMethod"); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes[0].toString()).isEqualTo("ADMIN"); @@ -111,7 +107,6 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void nullDefaultRolePrefix() throws Exception { this.mds.setDefaultRolePrefix(null); - ConfigAttribute[] accessAttributes = findAttributes("adminMethod"); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes[0].toString()).isEqualTo("ADMIN"); @@ -125,32 +120,24 @@ public class Jsr250MethodSecurityMetadataSourceTests { } // JSR-250 Spec Tests - /** * Class-level annotations only affect the class they annotate and their members, that * is, its methods and fields. They never affect a member declared by a superclass, * even if it is not hidden or overridden by the class in question. - * * @throws Exception */ @Test - public void classLevelAnnotationsOnlyAffectTheClassTheyAnnotateAndTheirMembers() - throws Exception { + public void classLevelAnnotationsOnlyAffectTheClassTheyAnnotateAndTheirMembers() throws Exception { Child target = new Child(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "notOverriden"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "notOverriden"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).isNull(); } @Test - public void classLevelAnnotationsOnlyAffectTheClassTheyAnnotateAndTheirMembersOverriden() - throws Exception { + public void classLevelAnnotationsOnlyAffectTheClassTheyAnnotateAndTheirMembersOverriden() throws Exception { Child target = new Child(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "overriden"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "overriden"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes.toArray()[0].toString()).isEqualTo("ROLE_DERIVED"); @@ -159,21 +146,16 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void classLevelAnnotationsImpactMemberLevel() throws Exception { Child target = new Child(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "defaults"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "defaults"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes.toArray()[0].toString()).isEqualTo("ROLE_DERIVED"); } @Test - public void classLevelAnnotationsIgnoredByExplicitMemberAnnotation() - throws Exception { + public void classLevelAnnotationsIgnoredByExplicitMemberAnnotation() throws Exception { Child target = new Child(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "explicitMethod"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "explicitMethod"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes.toArray()[0].toString()).isEqualTo("ROLE_EXPLICIT"); @@ -182,15 +164,12 @@ public class Jsr250MethodSecurityMetadataSourceTests { /** * The interfaces implemented by a class never contribute annotations to the class * itself or any of its members. - * * @throws Exception */ @Test public void interfacesNeverContributeAnnotationsMethodLevel() throws Exception { Parent target = new Parent(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "interfaceMethod"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "interfaceMethod"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).isEmpty(); } @@ -198,9 +177,7 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void interfacesNeverContributeAnnotationsClassLevel() throws Exception { Parent target = new Parent(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "notOverriden"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "notOverriden"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).isEmpty(); } @@ -208,17 +185,12 @@ public class Jsr250MethodSecurityMetadataSourceTests { @Test public void annotationsOnOverriddenMemberIgnored() throws Exception { Child target = new Child(); - MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), - "overridenIgnored"); - + MockMethodInvocation mi = new MockMethodInvocation(target, target.getClass(), "overridenIgnored"); Collection accessAttributes = this.mds.getAttributes(mi); assertThat(accessAttributes).hasSize(1); assertThat(accessAttributes.toArray()[0].toString()).isEqualTo("ROLE_DERIVED"); } - // ~ Inner Classes - // ====================================================================================================== - public static class A { public void noRoleMethod() { @@ -235,6 +207,7 @@ public class Jsr250MethodSecurityMetadataSourceTests { @PermitAll public void permitAllMethod() { } + } @RolesAllowed("USER") @@ -246,19 +219,21 @@ public class Jsr250MethodSecurityMetadataSourceTests { @RolesAllowed("ADMIN") public void adminMethod() { } + } // JSR-250 Spec - @RolesAllowed("IPARENT") interface IParent { @RolesAllowed("INTERFACEMETHOD") void interfaceMethod(); + } static class Parent implements IParent { + @Override public void interfaceMethod() { } @@ -271,6 +246,7 @@ public class Jsr250MethodSecurityMetadataSourceTests { @RolesAllowed("OVERRIDENIGNORED") public void overridenIgnored() { } + } @RolesAllowed("DERIVED") @@ -290,5 +266,7 @@ public class Jsr250MethodSecurityMetadataSourceTests { @RolesAllowed("EXPLICIT") public void explicitMethod() { } + } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/Jsr250VoterTests.java b/core/src/test/java/org/springframework/security/access/annotation/Jsr250VoterTests.java index 6157b00309..412d2fe93f 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/Jsr250VoterTests.java +++ b/core/src/test/java/org/springframework/security/access/annotation/Jsr250VoterTests.java @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.annotation; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.access.annotation; import java.util.ArrayList; import java.util.List; import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.authentication.TestingAuthenticationToken; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Luke Taylor */ public class Jsr250VoterTests { @@ -37,24 +38,19 @@ public class Jsr250VoterTests { public void supportsMultipleRolesCorrectly() { List attrs = new ArrayList<>(); Jsr250Voter voter = new Jsr250Voter(); - attrs.add(new Jsr250SecurityConfig("A")); attrs.add(new Jsr250SecurityConfig("B")); attrs.add(new Jsr250SecurityConfig("C")); - - assertThat(voter.vote( - new TestingAuthenticationToken("user", "pwd", "A"), new Object(), attrs)).isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); - assertThat(voter.vote( - new TestingAuthenticationToken("user", "pwd", "B"), new Object(), attrs)).isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); - assertThat(voter.vote( - new TestingAuthenticationToken("user", "pwd", "C"), new Object(), attrs)).isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); - - assertThat(voter.vote( - new TestingAuthenticationToken("user", "pwd", "NONE"), new Object(), - attrs)).isEqualTo(AccessDecisionVoter.ACCESS_DENIED); - - assertThat(voter.vote( - new TestingAuthenticationToken("user", "pwd", "A"), new Object(), + assertThat(voter.vote(new TestingAuthenticationToken("user", "pwd", "A"), new Object(), attrs)) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + assertThat(voter.vote(new TestingAuthenticationToken("user", "pwd", "B"), new Object(), attrs)) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + assertThat(voter.vote(new TestingAuthenticationToken("user", "pwd", "C"), new Object(), attrs)) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + assertThat(voter.vote(new TestingAuthenticationToken("user", "pwd", "NONE"), new Object(), attrs)) + .isEqualTo(AccessDecisionVoter.ACCESS_DENIED); + assertThat(voter.vote(new TestingAuthenticationToken("user", "pwd", "A"), new Object(), SecurityConfig.createList("A", "B", "C"))).isEqualTo(AccessDecisionVoter.ACCESS_ABSTAIN); } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSourceTests.java b/core/src/test/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSourceTests.java index ad60044daa..a607b56874 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSourceTests.java +++ b/core/src/test/java/org/springframework/security/access/annotation/SecuredAnnotationSecurityMetadataSourceTests.java @@ -16,9 +16,6 @@ package org.springframework.security.access.annotation; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.lang.annotation.ElementType; import java.lang.annotation.Inherited; import java.lang.annotation.Retention; @@ -31,12 +28,16 @@ import java.util.EnumSet; import java.util.List; import org.junit.Test; + import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.annotation.sec2150.MethodInvocationFactory; import org.springframework.security.access.intercept.method.MockMethodInvocation; import org.springframework.security.core.GrantedAuthority; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests for * {@link org.springframework.security.access.annotation.SecuredAnnotationSecurityMetadataSource} @@ -47,106 +48,72 @@ import org.springframework.security.core.GrantedAuthority; * @author Luke Taylor */ public class SecuredAnnotationSecurityMetadataSourceTests { - // ~ Instance fields - // ================================================================================================ private SecuredAnnotationSecurityMetadataSource mds = new SecuredAnnotationSecurityMetadataSource(); - // ~ Methods - // ======================================================================================================== - @Test public void genericsSuperclassDeclarationsAreIncludedWhenSubclassesOverride() { Method method = null; - try { - method = DepartmentServiceImpl.class.getMethod("someUserMethod3", - new Class[] { Department.class }); + method = DepartmentServiceImpl.class.getMethod("someUserMethod3", new Class[] { Department.class }); } catch (NoSuchMethodException unexpected) { fail("Should be a superMethod called 'someUserMethod3' on class!"); } - - Collection attrs = mds.findAttributes(method, - DepartmentServiceImpl.class); - + Collection attrs = this.mds.findAttributes(method, DepartmentServiceImpl.class); assertThat(attrs).isNotNull(); - // expect 1 attribute assertThat(attrs.size() == 1).as("Did not find 1 attribute").isTrue(); - // should have 1 SecurityConfig for (ConfigAttribute sc : attrs) { - assertThat(sc.getAttribute()).as("Found an incorrect role").isEqualTo( - "ROLE_ADMIN"); + assertThat(sc.getAttribute()).as("Found an incorrect role").isEqualTo("ROLE_ADMIN"); } - Method superMethod = null; - try { - superMethod = DepartmentServiceImpl.class.getMethod("someUserMethod3", - new Class[] { Entity.class }); + superMethod = DepartmentServiceImpl.class.getMethod("someUserMethod3", new Class[] { Entity.class }); } catch (NoSuchMethodException unexpected) { fail("Should be a superMethod called 'someUserMethod3' on class!"); } - - Collection superAttrs = this.mds.findAttributes(superMethod, - DepartmentServiceImpl.class); - + Collection superAttrs = this.mds.findAttributes(superMethod, DepartmentServiceImpl.class); assertThat(superAttrs).isNotNull(); - // This part of the test relates to SEC-274 // expect 1 attribute assertThat(superAttrs).as("Did not find 1 attribute").hasSize(1); // should have 1 SecurityConfig for (ConfigAttribute sc : superAttrs) { - assertThat(sc.getAttribute()).as("Found an incorrect role").isEqualTo( - "ROLE_ADMIN"); + assertThat(sc.getAttribute()).as("Found an incorrect role").isEqualTo("ROLE_ADMIN"); } } @Test public void classLevelAttributesAreFound() { - Collection attrs = this.mds.findAttributes( - BusinessService.class); - + Collection attrs = this.mds.findAttributes(BusinessService.class); assertThat(attrs).isNotNull(); - // expect 1 annotation assertThat(attrs).hasSize(1); - // should have 1 SecurityConfig SecurityConfig sc = (SecurityConfig) attrs.toArray()[0]; - assertThat(sc.getAttribute()).isEqualTo("ROLE_USER"); } @Test public void methodLevelAttributesAreFound() { Method method = null; - try { - method = BusinessService.class.getMethod("someUserAndAdminMethod", - new Class[] {}); + method = BusinessService.class.getMethod("someUserAndAdminMethod", new Class[] {}); } catch (NoSuchMethodException unexpected) { fail("Should be a method called 'someUserAndAdminMethod' on class!"); } - - Collection attrs = this.mds.findAttributes(method, - BusinessService.class); - + Collection attrs = this.mds.findAttributes(method, BusinessService.class); // expect 2 attributes assertThat(attrs).hasSize(2); - boolean user = false; boolean admin = false; - // should have 2 SecurityConfigs for (ConfigAttribute sc : attrs) { assertThat(sc).isInstanceOf(SecurityConfig.class); - if (sc.getAttribute().equals("ROLE_USER")) { user = true; } @@ -154,7 +121,6 @@ public class SecuredAnnotationSecurityMetadataSourceTests { admin = true; } } - // expect to have ROLE_USER and ROLE_ADMIN assertThat(user).isEqualTo(admin).isTrue(); } @@ -164,20 +130,15 @@ public class SecuredAnnotationSecurityMetadataSourceTests { public void customAnnotationAttributesAreFound() { SecuredAnnotationSecurityMetadataSource mds = new SecuredAnnotationSecurityMetadataSource( new CustomSecurityAnnotationMetadataExtractor()); - Collection attrs = mds.findAttributes( - CustomAnnotatedService.class); + Collection attrs = mds.findAttributes(CustomAnnotatedService.class); assertThat(attrs).containsOnly(SecurityEnum.ADMIN); } @Test public void annotatedAnnotationAtClassLevelIsDetected() throws Exception { - MockMethodInvocation annotatedAtClassLevel = new MockMethodInvocation( - new AnnotatedAnnotationAtClassLevel(), ReturnVoid.class, "doSomething", - List.class); - - ConfigAttribute[] attrs = mds.getAttributes(annotatedAtClassLevel).toArray( - new ConfigAttribute[0]); - + MockMethodInvocation annotatedAtClassLevel = new MockMethodInvocation(new AnnotatedAnnotationAtClassLevel(), + ReturnVoid.class, "doSomething", List.class); + ConfigAttribute[] attrs = this.mds.getAttributes(annotatedAtClassLevel).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs).extracting("attribute").containsOnly("CUSTOM"); } @@ -185,24 +146,17 @@ public class SecuredAnnotationSecurityMetadataSourceTests { @Test public void annotatedAnnotationAtInterfaceLevelIsDetected() throws Exception { MockMethodInvocation annotatedAtInterfaceLevel = new MockMethodInvocation( - new AnnotatedAnnotationAtInterfaceLevel(), ReturnVoid2.class, - "doSomething", List.class); - - ConfigAttribute[] attrs = mds.getAttributes(annotatedAtInterfaceLevel).toArray( - new ConfigAttribute[0]); - + new AnnotatedAnnotationAtInterfaceLevel(), ReturnVoid2.class, "doSomething", List.class); + ConfigAttribute[] attrs = this.mds.getAttributes(annotatedAtInterfaceLevel).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs).extracting("attribute").containsOnly("CUSTOM"); } @Test public void annotatedAnnotationAtMethodLevelIsDetected() throws Exception { - MockMethodInvocation annotatedAtMethodLevel = new MockMethodInvocation( - new AnnotatedAnnotationAtMethodLevel(), ReturnVoid.class, "doSomething", - List.class); - ConfigAttribute[] attrs = mds.getAttributes(annotatedAtMethodLevel).toArray( - new ConfigAttribute[0]); - + MockMethodInvocation annotatedAtMethodLevel = new MockMethodInvocation(new AnnotatedAnnotationAtMethodLevel(), + ReturnVoid.class, "doSomething", List.class); + ConfigAttribute[] attrs = this.mds.getAttributes(annotatedAtMethodLevel).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs).extracting("attribute").containsOnly("CUSTOM"); } @@ -210,7 +164,7 @@ public class SecuredAnnotationSecurityMetadataSourceTests { @Test public void proxyFactoryInterfaceAttributesFound() throws Exception { MockMethodInvocation mi = MethodInvocationFactory.createSec2150MethodInvocation(); - Collection attributes = mds.getAttributes(mi); + Collection attributes = this.mds.getAttributes(mi); assertThat(attributes).hasSize(1); assertThat(attributes).extracting("attribute").containsOnly("ROLE_PERSON"); } @@ -221,61 +175,69 @@ public class SecuredAnnotationSecurityMetadataSourceTests { Department(String name) { super(name); } + } interface DepartmentService extends BusinessService { @Secured({ "ROLE_USER" }) Department someUserMethod3(Department dept); + } @SuppressWarnings("serial") - class DepartmentServiceImpl extends BusinessServiceImpl - implements DepartmentService { + class DepartmentServiceImpl extends BusinessServiceImpl implements DepartmentService { + @Override @Secured({ "ROLE_ADMIN" }) public Department someUserMethod3(final Department dept) { return super.someUserMethod3(dept); } + } // SEC-1491 Related classes. PoC for custom annotation with enum value. - @CustomSecurityAnnotation(SecurityEnum.ADMIN) interface CustomAnnotatedService { + } class CustomAnnotatedServiceImpl implements CustomAnnotatedService { + } enum SecurityEnum implements ConfigAttribute, GrantedAuthority { + ADMIN, USER; + @Override public String getAttribute() { return toString(); } + @Override public String getAuthority() { return toString(); } + } @Target({ ElementType.METHOD, ElementType.TYPE }) @Retention(RetentionPolicy.RUNTIME) @interface CustomSecurityAnnotation { - SecurityEnum[]value(); + SecurityEnum[] value(); + } - class CustomSecurityAnnotationMetadataExtractor - implements AnnotationMetadataExtractor { + class CustomSecurityAnnotationMetadataExtractor implements AnnotationMetadataExtractor { - public Collection extractAttributes( - CustomSecurityAnnotation securityAnnotation) { + @Override + public Collection extractAttributes(CustomSecurityAnnotation securityAnnotation) { SecurityEnum[] values = securityAnnotation.value(); - return EnumSet.copyOf(Arrays.asList(values)); } + } @Target({ ElementType.METHOD, ElementType.TYPE }) @@ -283,36 +245,46 @@ public class SecuredAnnotationSecurityMetadataSourceTests { @Inherited @Secured("CUSTOM") public @interface AnnotatedAnnotation { + } public interface ReturnVoid { void doSomething(List param); + } @AnnotatedAnnotation public interface ReturnVoid2 { void doSomething(List param); + } @AnnotatedAnnotation public static class AnnotatedAnnotationAtClassLevel implements ReturnVoid { + @Override public void doSomething(List param) { } + } public static class AnnotatedAnnotationAtInterfaceLevel implements ReturnVoid2 { + @Override public void doSomething(List param) { } + } public static class AnnotatedAnnotationAtMethodLevel implements ReturnVoid { + @Override @AnnotatedAnnotation public void doSomething(List param) { } + } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/sec2150/CrudRepository.java b/core/src/test/java/org/springframework/security/access/annotation/sec2150/CrudRepository.java index 67e38337bf..f0cefba8a9 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/sec2150/CrudRepository.java +++ b/core/src/test/java/org/springframework/security/access/annotation/sec2150/CrudRepository.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation.sec2150; public interface CrudRepository { Iterable findAll(); -} \ No newline at end of file + +} diff --git a/core/src/test/java/org/springframework/security/access/annotation/sec2150/MethodInvocationFactory.java b/core/src/test/java/org/springframework/security/access/annotation/sec2150/MethodInvocationFactory.java index fdc35340af..eb0c440713 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/sec2150/MethodInvocationFactory.java +++ b/core/src/test/java/org/springframework/security/access/annotation/sec2150/MethodInvocationFactory.java @@ -13,25 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation.sec2150; import org.springframework.aop.framework.ProxyFactory; import org.springframework.security.access.intercept.method.MockMethodInvocation; -public class MethodInvocationFactory { +public final class MethodInvocationFactory { + + private MethodInvocationFactory() { + } /** * In order to reproduce the bug for SEC-2150, we must have a proxy object that * implements TargetSourceAware and implements our annotated interface. - * - * @return + * @return the mock method invocation * @throws NoSuchMethodException */ - public static MockMethodInvocation createSec2150MethodInvocation() - throws NoSuchMethodException { + public static MockMethodInvocation createSec2150MethodInvocation() throws NoSuchMethodException { ProxyFactory factory = new ProxyFactory(new Class[] { PersonRepository.class }); factory.setTargetClass(CrudRepository.class); PersonRepository repository = (PersonRepository) factory.getProxy(); return new MockMethodInvocation(repository, PersonRepository.class, "findAll"); } + } diff --git a/core/src/test/java/org/springframework/security/access/annotation/sec2150/PersonRepository.java b/core/src/test/java/org/springframework/security/access/annotation/sec2150/PersonRepository.java index 59dcae3a1b..ed70ebaa59 100644 --- a/core/src/test/java/org/springframework/security/access/annotation/sec2150/PersonRepository.java +++ b/core/src/test/java/org/springframework/security/access/annotation/sec2150/PersonRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.annotation.sec2150; import org.springframework.security.access.annotation.Secured; @@ -28,4 +29,5 @@ import org.springframework.security.access.prepost.PreAuthorize; @Secured("ROLE_PERSON") @PreAuthorize("hasRole('ROLE_PERSON')") public interface PersonRepository extends CrudRepository { -} \ No newline at end of file + +} diff --git a/core/src/test/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandlerTests.java b/core/src/test/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandlerTests.java index 7b371e123c..f8c6b653a2 100644 --- a/core/src/test/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandlerTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/AbstractSecurityExpressionHandlerTests.java @@ -13,14 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; -import static org.assertj.core.api.Assertions.assertThat; - -import static org.mockito.Mockito.mock; - import org.junit.Before; import org.junit.Test; + import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -28,18 +26,22 @@ import org.springframework.expression.Expression; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + /** * @author Luke Taylor */ public class AbstractSecurityExpressionHandlerTests { + private AbstractSecurityExpressionHandler handler; @Before public void setUp() { - handler = new AbstractSecurityExpressionHandler() { + this.handler = new AbstractSecurityExpressionHandler() { @Override - protected SecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, Object o) { + protected SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, + Object o) { return new SecurityExpressionRoot(authentication) { }; } @@ -48,38 +50,38 @@ public class AbstractSecurityExpressionHandlerTests { @Test public void beanNamesAreCorrectlyResolved() { - handler.setApplicationContext(new AnnotationConfigApplicationContext( - TestConfiguration.class)); - - Expression expression = handler.getExpressionParser().parseExpression( - "@number10.compareTo(@number20) < 0"); - assertThat(expression.getValue(handler.createEvaluationContext( - mock(Authentication.class), new Object()))).isEqualTo(true); + this.handler.setApplicationContext(new AnnotationConfigApplicationContext(TestConfiguration.class)); + Expression expression = this.handler.getExpressionParser() + .parseExpression("@number10.compareTo(@number20) < 0"); + assertThat(expression.getValue(this.handler.createEvaluationContext(mock(Authentication.class), new Object()))) + .isEqualTo(true); } @Test(expected = IllegalArgumentException.class) public void setExpressionParserNull() { - handler.setExpressionParser(null); + this.handler.setExpressionParser(null); } @Test public void setExpressionParser() { SpelExpressionParser parser = new SpelExpressionParser(); - handler.setExpressionParser(parser); - assertThat(parser == handler.getExpressionParser()).isTrue(); - } -} - -@Configuration -class TestConfiguration { - - @Bean - Integer number10() { - return 10; - } - - @Bean - Integer number20() { - return 20; + this.handler.setExpressionParser(parser); + assertThat(parser == this.handler.getExpressionParser()).isTrue(); } + + @Configuration + static class TestConfiguration { + + @Bean + Integer number10() { + return 10; + } + + @Bean + Integer number20() { + return 20; + } + + } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/SecurityExpressionRootTests.java b/core/src/test/java/org/springframework/security/access/expression/SecurityExpressionRootTests.java index 3efa1cefc1..9cb6564f61 100644 --- a/core/src/test/java/org/springframework/security/access/expression/SecurityExpressionRootTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/SecurityExpressionRootTests.java @@ -13,126 +13,125 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.assertj.core.api.Assertions.*; - - import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** - * * @author Luke Taylor * @since 3.0 */ public class SecurityExpressionRootTests { - final static Authentication JOE = new TestingAuthenticationToken("joe", "pass", - "ROLE_A", "ROLE_B"); + + static final Authentication JOE = new TestingAuthenticationToken("joe", "pass", "ROLE_A", "ROLE_B"); SecurityExpressionRoot root; @Before public void setup() { - root = new SecurityExpressionRoot(JOE) { + this.root = new SecurityExpressionRoot(JOE) { }; } @Test public void denyAllIsFalsePermitAllTrue() { - assertThat(root.denyAll()).isFalse(); - assertThat(root.denyAll).isFalse(); - assertThat(root.permitAll()).isTrue(); - assertThat(root.permitAll).isTrue(); + assertThat(this.root.denyAll()).isFalse(); + assertThat(this.root.denyAll).isFalse(); + assertThat(this.root.permitAll()).isTrue(); + assertThat(this.root.permitAll).isTrue(); } @Test public void rememberMeIsCorrectlyDetected() { AuthenticationTrustResolver atr = mock(AuthenticationTrustResolver.class); - root.setTrustResolver(atr); - when(atr.isRememberMe(JOE)).thenReturn(true); - assertThat(root.isRememberMe()).isTrue(); - assertThat(root.isFullyAuthenticated()).isFalse(); + this.root.setTrustResolver(atr); + given(atr.isRememberMe(JOE)).willReturn(true); + assertThat(this.root.isRememberMe()).isTrue(); + assertThat(this.root.isFullyAuthenticated()).isFalse(); } @Test public void roleHierarchySupportIsCorrectlyUsedInEvaluatingRoles() { - root.setRoleHierarchy(authorities -> AuthorityUtils.createAuthorityList("ROLE_C")); - - assertThat(root.hasRole("C")).isTrue(); - assertThat(root.hasAuthority("ROLE_C")).isTrue(); - assertThat(root.hasRole("A")).isFalse(); - assertThat(root.hasRole("B")).isFalse(); - assertThat(root.hasAnyRole("C", "A", "B")).isTrue(); - assertThat(root.hasAnyAuthority("ROLE_C", "ROLE_A", "ROLE_B")).isTrue(); - assertThat(root.hasAnyRole("A", "B")).isFalse(); + this.root.setRoleHierarchy((authorities) -> AuthorityUtils.createAuthorityList("ROLE_C")); + assertThat(this.root.hasRole("C")).isTrue(); + assertThat(this.root.hasAuthority("ROLE_C")).isTrue(); + assertThat(this.root.hasRole("A")).isFalse(); + assertThat(this.root.hasRole("B")).isFalse(); + assertThat(this.root.hasAnyRole("C", "A", "B")).isTrue(); + assertThat(this.root.hasAnyAuthority("ROLE_C", "ROLE_A", "ROLE_B")).isTrue(); + assertThat(this.root.hasAnyRole("A", "B")).isFalse(); } @Test public void hasRoleAddsDefaultPrefix() { - assertThat(root.hasRole("A")).isTrue(); - assertThat(root.hasRole("NO")).isFalse(); + assertThat(this.root.hasRole("A")).isTrue(); + assertThat(this.root.hasRole("NO")).isFalse(); } @Test public void hasRoleEmptyPrefixDoesNotAddsDefaultPrefix() { - root.setDefaultRolePrefix(""); - assertThat(root.hasRole("A")).isFalse(); - assertThat(root.hasRole("ROLE_A")).isTrue(); + this.root.setDefaultRolePrefix(""); + assertThat(this.root.hasRole("A")).isFalse(); + assertThat(this.root.hasRole("ROLE_A")).isTrue(); } @Test public void hasRoleNullPrefixDoesNotAddsDefaultPrefix() { - root.setDefaultRolePrefix(null); - assertThat(root.hasRole("A")).isFalse(); - assertThat(root.hasRole("ROLE_A")).isTrue(); + this.root.setDefaultRolePrefix(null); + assertThat(this.root.hasRole("A")).isFalse(); + assertThat(this.root.hasRole("ROLE_A")).isTrue(); } @Test public void hasRoleDoesNotAddDefaultPrefixForAlreadyPrefixedRoles() { SecurityExpressionRoot root = new SecurityExpressionRoot(JOE) { }; - assertThat(root.hasRole("ROLE_A")).isTrue(); assertThat(root.hasRole("ROLE_NO")).isFalse(); } @Test public void hasAnyRoleAddsDefaultPrefix() { - assertThat(root.hasAnyRole("NO", "A")).isTrue(); - assertThat(root.hasAnyRole("NO", "NOT")).isFalse(); + assertThat(this.root.hasAnyRole("NO", "A")).isTrue(); + assertThat(this.root.hasAnyRole("NO", "NOT")).isFalse(); } @Test public void hasAnyRoleDoesNotAddDefaultPrefixForAlreadyPrefixedRoles() { - assertThat(root.hasAnyRole("ROLE_NO", "ROLE_A")).isTrue(); - assertThat(root.hasAnyRole("ROLE_NO", "ROLE_NOT")).isFalse(); + assertThat(this.root.hasAnyRole("ROLE_NO", "ROLE_A")).isTrue(); + assertThat(this.root.hasAnyRole("ROLE_NO", "ROLE_NOT")).isFalse(); } @Test public void hasAnyRoleEmptyPrefixDoesNotAddsDefaultPrefix() { - root.setDefaultRolePrefix(""); - assertThat(root.hasRole("A")).isFalse(); - assertThat(root.hasRole("ROLE_A")).isTrue(); + this.root.setDefaultRolePrefix(""); + assertThat(this.root.hasRole("A")).isFalse(); + assertThat(this.root.hasRole("ROLE_A")).isTrue(); } @Test public void hasAnyRoleNullPrefixDoesNotAddsDefaultPrefix() { - root.setDefaultRolePrefix(null); - assertThat(root.hasAnyRole("A")).isFalse(); - assertThat(root.hasAnyRole("ROLE_A")).isTrue(); + this.root.setDefaultRolePrefix(null); + assertThat(this.root.hasAnyRole("A")).isFalse(); + assertThat(this.root.hasAnyRole("ROLE_A")).isTrue(); } @Test public void hasAuthorityDoesNotAddDefaultPrefix() { - assertThat(root.hasAuthority("A")).isFalse(); - assertThat(root.hasAnyAuthority("NO", "A")).isFalse(); - assertThat(root.hasAnyAuthority("ROLE_A", "NOT")).isTrue(); + assertThat(this.root.hasAuthority("A")).isFalse(); + assertThat(this.root.hasAnyAuthority("NO", "A")).isFalse(); + assertThat(this.root.hasAnyAuthority("ROLE_A", "NOT")).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java b/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java index 85dd94acd4..0cc3343ca5 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import java.util.HashMap; @@ -37,24 +38,30 @@ import org.springframework.security.core.context.SecurityContextHolder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @RunWith(MockitoJUnitRunner.class) public class DefaultMethodSecurityExpressionHandlerTests { + private DefaultMethodSecurityExpressionHandler handler; @Mock private Authentication authentication; + @Mock private MethodInvocation methodInvocation; + @Mock private AuthenticationTrustResolver trustResolver; @Before public void setup() { - handler = new DefaultMethodSecurityExpressionHandler(); - when(methodInvocation.getThis()).thenReturn(new Foo()); - when(methodInvocation.getMethod()).thenReturn(Foo.class.getMethods()[0]); + this.handler = new DefaultMethodSecurityExpressionHandler(); + given(this.methodInvocation.getThis()).willReturn(new Foo()); + given(this.methodInvocation.getMethod()).willReturn(Foo.class.getMethods()[0]); } @After @@ -64,20 +71,16 @@ public class DefaultMethodSecurityExpressionHandlerTests { @Test(expected = IllegalArgumentException.class) public void setTrustResolverNull() { - handler.setTrustResolver(null); + this.handler.setTrustResolver(null); } @Test public void createEvaluationContextCustomTrustResolver() { - handler.setTrustResolver(trustResolver); - - Expression expression = handler.getExpressionParser() - .parseExpression("anonymous"); - EvaluationContext context = handler.createEvaluationContext(authentication, - methodInvocation); + this.handler.setTrustResolver(this.trustResolver); + Expression expression = this.handler.getExpressionParser().parseExpression("anonymous"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); expression.getValue(context, Boolean.class); - - verify(trustResolver).isAnonymous(authentication); + verify(this.trustResolver).isAnonymous(this.authentication); } @Test @@ -87,14 +90,9 @@ public class DefaultMethodSecurityExpressionHandlerTests { map.put("key1", "value1"); map.put("key2", "value2"); map.put("key3", "value3"); - - Expression expression = handler.getExpressionParser().parseExpression("filterObject.key eq 'key2'"); - - EvaluationContext context = handler.createEvaluationContext(authentication, - methodInvocation); - - Object filtered = handler.filter(map, expression, context); - + Expression expression = this.handler.getExpressionParser().parseExpression("filterObject.key eq 'key2'"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(map, expression, context); assertThat(filtered == map); Map result = ((Map) filtered); assertThat(result.size() == 1); @@ -109,14 +107,9 @@ public class DefaultMethodSecurityExpressionHandlerTests { map.put("key1", "value1"); map.put("key2", "value2"); map.put("key3", "value3"); - - Expression expression = handler.getExpressionParser().parseExpression("filterObject.value eq 'value3'"); - - EvaluationContext context = handler.createEvaluationContext(authentication, - methodInvocation); - - Object filtered = handler.filter(map, expression, context); - + Expression expression = this.handler.getExpressionParser().parseExpression("filterObject.value eq 'value3'"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(map, expression, context); assertThat(filtered == map); Map result = ((Map) filtered); assertThat(result.size() == 1); @@ -131,14 +124,10 @@ public class DefaultMethodSecurityExpressionHandlerTests { map.put("key1", "value1"); map.put("key2", "value2"); map.put("key3", "value3"); - - Expression expression = handler.getExpressionParser().parseExpression("(filterObject.key eq 'key1') or (filterObject.value eq 'value2')"); - - EvaluationContext context = handler.createEvaluationContext(authentication, - methodInvocation); - - Object filtered = handler.filter(map, expression, context); - + Expression expression = this.handler.getExpressionParser() + .parseExpression("(filterObject.key eq 'key1') or (filterObject.value eq 'value2')"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(map, expression, context); assertThat(filtered == map); Map result = ((Map) filtered); assertThat(result.size() == 2); @@ -150,14 +139,9 @@ public class DefaultMethodSecurityExpressionHandlerTests { @SuppressWarnings("unchecked") public void filterWhenUsingStreamThenFiltersStream() { final Stream stream = Stream.of("1", "2", "3"); - - Expression expression = handler.getExpressionParser().parseExpression("filterObject ne '2'"); - - EvaluationContext context = handler.createEvaluationContext(authentication, - methodInvocation); - - Object filtered = handler.filter(stream, expression, context); - + Expression expression = this.handler.getExpressionParser().parseExpression("filterObject ne '2'"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(stream, expression, context); assertThat(filtered).isInstanceOf(Stream.class); List list = ((Stream) filtered).collect(Collectors.toList()); assertThat(list).containsExactly("1", "3"); @@ -167,18 +151,17 @@ public class DefaultMethodSecurityExpressionHandlerTests { public void filterStreamWhenClosedThenUpstreamGetsClosed() { final Stream upstream = mock(Stream.class); doReturn(Stream.empty()).when(upstream).filter(any()); - - Expression expression = handler.getExpressionParser().parseExpression("true"); - - EvaluationContext context = handler.createEvaluationContext(authentication, - methodInvocation); - - ((Stream) handler.filter(upstream, expression, context)).close(); + Expression expression = this.handler.getExpressionParser().parseExpression("true"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + ((Stream) this.handler.filter(upstream, expression, context)).close(); verify(upstream).close(); } - private static class Foo { - public void bar(){ + static class Foo { + + void bar() { } + } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdviceTests.java b/core/src/test/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdviceTests.java index 27095641d8..8f942bb330 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdviceTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/ExpressionBasedPreInvocationAdviceTests.java @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.security.access.intercept.method.MockMethodInvocation; import org.springframework.security.access.prepost.PreInvocationAttribute; import org.springframework.security.core.Authentication; -import java.util.ArrayList; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -45,100 +47,69 @@ public class ExpressionBasedPreInvocationAdviceTests { @Before public void setUp() { - expressionBasedPreInvocationAdvice = new ExpressionBasedPreInvocationAdvice(); + this.expressionBasedPreInvocationAdvice = new ExpressionBasedPreInvocationAdvice(); } @Test(expected = IllegalArgumentException.class) public void findFilterTargetNameProvidedButNotMatch() throws Exception { - //given - PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", - "filterTargetDoesNotMatch", - null); + PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "filterTargetDoesNotMatch", + null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingCollection", - new Class[]{List.class}, - new Object[]{new ArrayList<>()}); - //when - then - expressionBasedPreInvocationAdvice.before(authentication, methodInvocation, attribute); + "doSomethingCollection", new Class[] { List.class }, new Object[] { new ArrayList<>() }); + this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, attribute); } @Test(expected = IllegalArgumentException.class) public void findFilterTargetNameProvidedArrayUnsupported() throws Exception { - //given - PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", - "param", null); + PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "param", null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingArray", - new Class[]{String[].class}, - new Object[]{new String[0]}); - //when - then - expressionBasedPreInvocationAdvice.before(authentication, methodInvocation, attribute); + "doSomethingArray", new Class[] { String[].class }, new Object[] { new String[0] }); + this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, attribute); } @Test public void findFilterTargetNameProvided() throws Exception { - //given PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "param", null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingCollection", - new Class[]{List.class}, - new Object[]{new ArrayList<>()}); - - //when - boolean result = expressionBasedPreInvocationAdvice - .before(authentication, methodInvocation, attribute); - //then + "doSomethingCollection", new Class[] { List.class }, new Object[] { new ArrayList<>() }); + boolean result = this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, + attribute); assertThat(result).isTrue(); } @Test(expected = IllegalArgumentException.class) public void findFilterTargetNameNotProvidedArrayUnsupported() throws Exception { - //given PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "", null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingArray", - new Class[]{String[].class}, - new Object[]{new String[0]}); - //when - then - expressionBasedPreInvocationAdvice.before(authentication, methodInvocation, attribute); + "doSomethingArray", new Class[] { String[].class }, new Object[] { new String[0] }); + this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, attribute); } @Test public void findFilterTargetNameNotProvided() throws Exception { - //given PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "", null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingCollection", - new Class[]{List.class}, - new Object[]{new ArrayList<>()}); - //when - boolean result = expressionBasedPreInvocationAdvice.before(authentication, methodInvocation, attribute); - //then + "doSomethingCollection", new Class[] { List.class }, new Object[] { new ArrayList<>() }); + boolean result = this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, + attribute); assertThat(result).isTrue(); } @Test(expected = IllegalArgumentException.class) public void findFilterTargetNameNotProvidedTypeNotSupported() throws Exception { - //given PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "", null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingString", - new Class[]{String.class}, - new Object[]{"param"}); - //when - then - expressionBasedPreInvocationAdvice.before(authentication, methodInvocation, attribute); + "doSomethingString", new Class[] { String.class }, new Object[] { "param" }); + this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, attribute); } @Test(expected = IllegalArgumentException.class) public void findFilterTargetNameNotProvidedMethodAcceptMoreThenOneArgument() throws Exception { - //given PreInvocationAttribute attribute = new PreInvocationExpressionAttribute("true", "", null); MockMethodInvocation methodInvocation = new MockMethodInvocation(new TestClass(), TestClass.class, - "doSomethingTwoArgs", - new Class[]{String.class, List.class}, - new Object[]{"param", new ArrayList<>()}); - //when - then - expressionBasedPreInvocationAdvice.before(authentication, methodInvocation, attribute); + "doSomethingTwoArgs", new Class[] { String.class, List.class }, + new Object[] { "param", new ArrayList<>() }); + this.expressionBasedPreInvocationAdvice.before(this.authentication, methodInvocation, attribute); } private class TestClass { @@ -158,5 +129,7 @@ public class ExpressionBasedPreInvocationAdviceTests { public Boolean doSomethingTwoArgs(String param, List list) { return Boolean.TRUE; } + } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/MethodExpressionVoterTests.java b/core/src/test/java/org/springframework/security/access/expression/method/MethodExpressionVoterTests.java index 0a738f2ec2..d409c4054d 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/MethodExpressionVoterTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/MethodExpressionVoterTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.expression.method; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.access.expression.method; import java.lang.reflect.Method; import java.util.ArrayList; @@ -25,57 +24,55 @@ import java.util.List; import org.aopalliance.intercept.MethodInvocation; import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; -import org.springframework.security.access.expression.method.PreInvocationExpressionAttribute; import org.springframework.security.access.prepost.PreInvocationAuthorizationAdviceVoter; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.util.SimpleMethodInvocation; +import static org.assertj.core.api.Assertions.assertThat; + @SuppressWarnings("unchecked") public class MethodExpressionVoterTests { - private TestingAuthenticationToken joe = new TestingAuthenticationToken("joe", - "joespass", "ROLE_blah"); + + private TestingAuthenticationToken joe = new TestingAuthenticationToken("joe", "joespass", "ROLE_blah"); + private PreInvocationAuthorizationAdviceVoter am = new PreInvocationAuthorizationAdviceVoter( new ExpressionBasedPreInvocationAdvice()); @Test public void hasRoleExpressionAllowsUserWithRole() throws Exception { - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingAnArray()); - assertThat(am.vote(joe, mi, - createAttributes(new PreInvocationExpressionAttribute(null, null, - "hasRole('blah')")))).isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingAnArray()); + assertThat(this.am.vote(this.joe, mi, + createAttributes(new PreInvocationExpressionAttribute(null, null, "hasRole('blah')")))) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); } @Test public void hasRoleExpressionDeniesUserWithoutRole() throws Exception { List cad = new ArrayList<>(1); cad.add(new PreInvocationExpressionAttribute(null, null, "hasRole('joedoesnt')")); - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingAnArray()); - assertThat(am.vote(joe, mi, cad)).isEqualTo(AccessDecisionVoter.ACCESS_DENIED); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingAnArray()); + assertThat(this.am.vote(this.joe, mi, cad)).isEqualTo(AccessDecisionVoter.ACCESS_DENIED); } @Test public void matchingArgAgainstAuthenticationNameIsSuccessful() throws Exception { - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingAString(), "joe"); - assertThat(am.vote(joe, mi, + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingAString(), "joe"); + assertThat(this.am.vote(this.joe, mi, createAttributes(new PreInvocationExpressionAttribute(null, null, "(#argument == principal) and (principal == 'joe')")))) - .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); } @Test public void accessIsGrantedIfNoPreAuthorizeAttributeIsUsed() throws Exception { Collection arg = createCollectionArg("joe", "bob", "sam"); - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingACollection(), arg); - assertThat(am.vote(joe, mi, - createAttributes(new PreInvocationExpressionAttribute( - "(filterObject == 'jim')", "collection", null)))) - .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingACollection(), arg); + assertThat(this.am.vote(this.joe, mi, + createAttributes(new PreInvocationExpressionAttribute("(filterObject == 'jim')", "collection", null)))) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); // All objects should have been removed, because the expression is always false assertThat(arg).isEmpty(); } @@ -83,45 +80,40 @@ public class MethodExpressionVoterTests { @Test public void collectionPreFilteringIsSuccessful() throws Exception { List arg = createCollectionArg("joe", "bob", "sam"); - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingACollection(), arg); - am.vote(joe, mi, createAttributes(new PreInvocationExpressionAttribute( - "(filterObject == 'joe' or filterObject == 'sam')", "collection", - "permitAll"))); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingACollection(), arg); + this.am.vote(this.joe, mi, createAttributes(new PreInvocationExpressionAttribute( + "(filterObject == 'joe' or filterObject == 'sam')", "collection", "permitAll"))); assertThat(arg).containsExactly("joe", "sam"); } @Test(expected = IllegalArgumentException.class) public void arraysCannotBePrefiltered() throws Exception { - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingAnArray(), createArrayArg("sam", "joe")); - am.vote(joe, mi, createAttributes(new PreInvocationExpressionAttribute( - "(filterObject == 'jim')", "someArray", null))); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingAnArray(), + createArrayArg("sam", "joe")); + this.am.vote(this.joe, mi, + createAttributes(new PreInvocationExpressionAttribute("(filterObject == 'jim')", "someArray", null))); } @Test(expected = IllegalArgumentException.class) public void incorrectFilterTargetNameIsRejected() throws Exception { - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingACollection(), createCollectionArg("joe", "bob")); - am.vote(joe, mi, createAttributes(new PreInvocationExpressionAttribute( - "(filterObject == 'joe')", "collcetion", null))); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingACollection(), + createCollectionArg("joe", "bob")); + this.am.vote(this.joe, mi, + createAttributes(new PreInvocationExpressionAttribute("(filterObject == 'joe')", "collcetion", null))); } @Test(expected = IllegalArgumentException.class) public void nullNamedFilterTargetIsRejected() throws Exception { - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingACollection(), new Object[] { null }); - am.vote(joe, mi, createAttributes(new PreInvocationExpressionAttribute( - "(filterObject == 'joe')", "collection", null))); + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingACollection(), + new Object[] { null }); + this.am.vote(this.joe, mi, + createAttributes(new PreInvocationExpressionAttribute("(filterObject == 'joe')", "collection", null))); } @Test public void ruleDefinedInAClassMethodIsApplied() throws Exception { - MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), - methodTakingAString(), "joe"); - assertThat( - - am.vote(joe, mi, + MethodInvocation mi = new SimpleMethodInvocation(new TargetImpl(), methodTakingAString(), "joe"); + assertThat(this.am.vote(this.joe, mi, createAttributes(new PreInvocationExpressionAttribute(null, null, "T(org.springframework.security.access.expression.method.SecurityRules).isJoe(#argument)")))) .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); @@ -155,26 +147,31 @@ public class MethodExpressionVoterTests { return Target.class.getMethod("methodTakingACollection", Collection.class); } - // ~ Inner Classes - // ================================================================================================== - private interface Target { + void methodTakingAnArray(Object[] args); void methodTakingAString(String argument); Collection methodTakingACollection(Collection collection); + } private static class TargetImpl implements Target { + + @Override public void methodTakingAnArray(Object[] args) { } + @Override public void methodTakingAString(String argument) { }; + @Override public Collection methodTakingACollection(Collection collection) { return collection; } + } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContextTests.java b/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContextTests.java index 18d7e9f787..a573263d25 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContextTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityEvaluationContextTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; import java.lang.reflect.Method; @@ -36,10 +37,13 @@ import static org.mockito.Mockito.doReturn; */ @RunWith(MockitoJUnitRunner.class) public class MethodSecurityEvaluationContextTests { + @Mock private ParameterNameDiscoverer paramNameDiscoverer; + @Mock private Authentication authentication; + @Mock private MethodInvocation methodInvocation; @@ -47,16 +51,16 @@ public class MethodSecurityEvaluationContextTests { public void lookupVariableWhenParameterNameNullThenNotSet() { Class type = String.class; Method method = ReflectionUtils.findMethod(String.class, "contains", CharSequence.class); - doReturn(new String[] {null}).when(paramNameDiscoverer).getParameterNames(method); - doReturn(new Object[]{null}).when(methodInvocation).getArguments(); - doReturn(type).when(methodInvocation).getThis(); - doReturn(method).when(methodInvocation).getMethod(); - NotNullVariableMethodSecurityEvaluationContext context= new NotNullVariableMethodSecurityEvaluationContext(authentication, methodInvocation, paramNameDiscoverer); + doReturn(new String[] { null }).when(this.paramNameDiscoverer).getParameterNames(method); + doReturn(new Object[] { null }).when(this.methodInvocation).getArguments(); + doReturn(type).when(this.methodInvocation).getThis(); + doReturn(method).when(this.methodInvocation).getMethod(); + NotNullVariableMethodSecurityEvaluationContext context = new NotNullVariableMethodSecurityEvaluationContext( + this.authentication, this.methodInvocation, this.paramNameDiscoverer); context.lookupVariable("testVariable"); } - private static class NotNullVariableMethodSecurityEvaluationContext - extends MethodSecurityEvaluationContext { + private static class NotNullVariableMethodSecurityEvaluationContext extends MethodSecurityEvaluationContext { NotNullVariableMethodSecurityEvaluationContext(Authentication auth, MethodInvocation mi, ParameterNameDiscoverer parameterNameDiscoverer) { @@ -65,12 +69,14 @@ public class MethodSecurityEvaluationContextTests { @Override public void setVariable(String name, @Nullable Object value) { - if ( name == null ) { + if (name == null) { throw new IllegalArgumentException("name should not be null"); } else { super.setVariable(name, value); } } + } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRootTests.java b/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRootTests.java index 2899956888..e6c8910fd8 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRootTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/MethodSecurityExpressionRootTests.java @@ -13,13 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.expression.method; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.access.expression.method; import org.junit.Before; import org.junit.Test; + import org.springframework.expression.Expression; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.expression.spel.support.StandardEvaluationContext; @@ -28,89 +27,94 @@ import org.springframework.security.access.expression.ExpressionUtils; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests for {@link MethodSecurityExpressionRoot} * * @author Luke Taylor */ public class MethodSecurityExpressionRootTests { + SpelExpressionParser parser = new SpelExpressionParser(); + MethodSecurityExpressionRoot root; + StandardEvaluationContext ctx; + private AuthenticationTrustResolver trustResolver; + private Authentication user; @Before public void createContext() { - user = mock(Authentication.class); - root = new MethodSecurityExpressionRoot(user); - ctx = new StandardEvaluationContext(); - ctx.setRootObject(root); - trustResolver = mock(AuthenticationTrustResolver.class); - root.setTrustResolver(trustResolver); + this.user = mock(Authentication.class); + this.root = new MethodSecurityExpressionRoot(this.user); + this.ctx = new StandardEvaluationContext(); + this.ctx.setRootObject(this.root); + this.trustResolver = mock(AuthenticationTrustResolver.class); + this.root.setTrustResolver(this.trustResolver); } @Test public void canCallMethodsOnVariables() { - ctx.setVariable("var", "somestring"); - Expression e = parser.parseExpression("#var.length() == 10"); - - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isTrue(); + this.ctx.setVariable("var", "somestring"); + Expression e = this.parser.parseExpression("#var.length() == 10"); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isTrue(); } @Test public void isAnonymousReturnsTrueIfTrustResolverReportsAnonymous() { - when(trustResolver.isAnonymous(user)).thenReturn(true); - assertThat(root.isAnonymous()).isTrue(); + given(this.trustResolver.isAnonymous(this.user)).willReturn(true); + assertThat(this.root.isAnonymous()).isTrue(); } @Test public void isAnonymousReturnsFalseIfTrustResolverReportsNonAnonymous() { - when(trustResolver.isAnonymous(user)).thenReturn(false); - assertThat(root.isAnonymous()).isFalse(); + given(this.trustResolver.isAnonymous(this.user)).willReturn(false); + assertThat(this.root.isAnonymous()).isFalse(); } @Test public void hasPermissionOnDomainObjectReturnsFalseIfPermissionEvaluatorDoes() { final Object dummyDomainObject = new Object(); final PermissionEvaluator pe = mock(PermissionEvaluator.class); - ctx.setVariable("domainObject", dummyDomainObject); - root.setPermissionEvaluator(pe); - when(pe.hasPermission(user, dummyDomainObject, "ignored")).thenReturn(false); - - assertThat(root.hasPermission(dummyDomainObject, "ignored")).isFalse(); - + this.ctx.setVariable("domainObject", dummyDomainObject); + this.root.setPermissionEvaluator(pe); + given(pe.hasPermission(this.user, dummyDomainObject, "ignored")).willReturn(false); + assertThat(this.root.hasPermission(dummyDomainObject, "ignored")).isFalse(); } @Test public void hasPermissionOnDomainObjectReturnsTrueIfPermissionEvaluatorDoes() { final Object dummyDomainObject = new Object(); final PermissionEvaluator pe = mock(PermissionEvaluator.class); - ctx.setVariable("domainObject", dummyDomainObject); - root.setPermissionEvaluator(pe); - when(pe.hasPermission(user, dummyDomainObject, "ignored")).thenReturn(true); - - assertThat(root.hasPermission(dummyDomainObject, "ignored")).isTrue(); + this.ctx.setVariable("domainObject", dummyDomainObject); + this.root.setPermissionEvaluator(pe); + given(pe.hasPermission(this.user, dummyDomainObject, "ignored")).willReturn(true); + assertThat(this.root.hasPermission(dummyDomainObject, "ignored")).isTrue(); } @Test public void hasPermissionOnDomainObjectWorksWithIntegerExpressions() { final Object dummyDomainObject = new Object(); - ctx.setVariable("domainObject", dummyDomainObject); + this.ctx.setVariable("domainObject", dummyDomainObject); final PermissionEvaluator pe = mock(PermissionEvaluator.class); - root.setPermissionEvaluator(pe); - when(pe.hasPermission(eq(user), eq(dummyDomainObject), any(Integer.class))) - .thenReturn(true).thenReturn(true).thenReturn(false); - - Expression e = parser.parseExpression("hasPermission(#domainObject, 0xA)"); + this.root.setPermissionEvaluator(pe); + given(pe.hasPermission(eq(this.user), eq(dummyDomainObject), any(Integer.class))).willReturn(true, true, false); + Expression e = this.parser.parseExpression("hasPermission(#domainObject, 0xA)"); // evaluator returns true - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isTrue(); - e = parser.parseExpression("hasPermission(#domainObject, 10)"); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isTrue(); + e = this.parser.parseExpression("hasPermission(#domainObject, 10)"); // evaluator returns true - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isTrue(); - e = parser.parseExpression("hasPermission(#domainObject, 0xFF)"); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isTrue(); + e = this.parser.parseExpression("hasPermission(#domainObject, 0xFF)"); // evaluator returns false, make sure return value matches - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isFalse(); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isFalse(); } @Test @@ -120,19 +124,18 @@ public class MethodSecurityExpressionRootTests { return "x"; } }; - root.setThis(targetObject); + this.root.setThis(targetObject); Integer i = 2; PermissionEvaluator pe = mock(PermissionEvaluator.class); - root.setPermissionEvaluator(pe); - when(pe.hasPermission(user, targetObject, i)).thenReturn(true).thenReturn(false); - when(pe.hasPermission(user, "x", i)).thenReturn(true); - - Expression e = parser.parseExpression("hasPermission(this, 2)"); - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isTrue(); - e = parser.parseExpression("hasPermission(this, 2)"); - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isFalse(); - - e = parser.parseExpression("hasPermission(this.x, 2)"); - assertThat(ExpressionUtils.evaluateAsBoolean(e, ctx)).isTrue(); + this.root.setPermissionEvaluator(pe); + given(pe.hasPermission(this.user, targetObject, i)).willReturn(true, false); + given(pe.hasPermission(this.user, "x", i)).willReturn(true); + Expression e = this.parser.parseExpression("hasPermission(this, 2)"); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isTrue(); + e = this.parser.parseExpression("hasPermission(this, 2)"); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isFalse(); + e = this.parser.parseExpression("hasPermission(this.x, 2)"); + assertThat(ExpressionUtils.evaluateAsBoolean(e, this.ctx)).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/PrePostAnnotationSecurityMetadataSourceTests.java b/core/src/test/java/org/springframework/security/access/expression/method/PrePostAnnotationSecurityMetadataSourceTests.java index 5c0cc78c49..c7e19fbf89 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/PrePostAnnotationSecurityMetadataSourceTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/PrePostAnnotationSecurityMetadataSourceTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.expression.method; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.access.expression.method; import java.lang.annotation.ElementType; import java.lang.annotation.Inherited; @@ -27,6 +26,7 @@ import java.util.List; import org.junit.Before; import org.junit.Test; + import org.springframework.expression.Expression; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.annotation.sec2150.MethodInvocationFactory; @@ -38,56 +38,56 @@ import org.springframework.security.access.prepost.PreFilter; import org.springframework.security.access.prepost.PrePostAnnotationSecurityMetadataSource; import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Luke Taylor * @since 3.0 */ public class PrePostAnnotationSecurityMetadataSourceTests { + private PrePostAnnotationSecurityMetadataSource mds = new PrePostAnnotationSecurityMetadataSource( - new ExpressionBasedAnnotationAttributeFactory( - new DefaultMethodSecurityExpressionHandler())); + new ExpressionBasedAnnotationAttributeFactory(new DefaultMethodSecurityExpressionHandler())); private MockMethodInvocation voidImpl1; + private MockMethodInvocation voidImpl2; + private MockMethodInvocation voidImpl3; + private MockMethodInvocation listImpl1; + private MockMethodInvocation notherListImpl1; + private MockMethodInvocation notherListImpl2; + private MockMethodInvocation annotatedAtClassLevel; + private MockMethodInvocation annotatedAtInterfaceLevel; + private MockMethodInvocation annotatedAtMethodLevel; @Before public void setUpData() throws Exception { - voidImpl1 = new MockMethodInvocation(new ReturnVoidImpl1(), ReturnVoid.class, + this.voidImpl1 = new MockMethodInvocation(new ReturnVoidImpl1(), ReturnVoid.class, "doSomething", List.class); + this.voidImpl2 = new MockMethodInvocation(new ReturnVoidImpl2(), ReturnVoid.class, "doSomething", List.class); + this.voidImpl3 = new MockMethodInvocation(new ReturnVoidImpl3(), ReturnVoid.class, "doSomething", List.class); + this.listImpl1 = new MockMethodInvocation(new ReturnAListImpl1(), ReturnAList.class, "doSomething", List.class); + this.notherListImpl1 = new MockMethodInvocation(new ReturnAnotherListImpl1(), ReturnAnotherList.class, "doSomething", List.class); - voidImpl2 = new MockMethodInvocation(new ReturnVoidImpl2(), ReturnVoid.class, + this.notherListImpl2 = new MockMethodInvocation(new ReturnAnotherListImpl2(), ReturnAnotherList.class, "doSomething", List.class); - voidImpl3 = new MockMethodInvocation(new ReturnVoidImpl3(), ReturnVoid.class, + this.annotatedAtClassLevel = new MockMethodInvocation(new CustomAnnotationAtClassLevel(), ReturnVoid.class, "doSomething", List.class); - listImpl1 = new MockMethodInvocation(new ReturnAListImpl1(), ReturnAList.class, + this.annotatedAtInterfaceLevel = new MockMethodInvocation(new CustomAnnotationAtInterfaceLevel(), + ReturnVoid2.class, "doSomething", List.class); + this.annotatedAtMethodLevel = new MockMethodInvocation(new CustomAnnotationAtMethodLevel(), ReturnVoid.class, "doSomething", List.class); - notherListImpl1 = new MockMethodInvocation(new ReturnAnotherListImpl1(), - ReturnAnotherList.class, "doSomething", List.class); - notherListImpl2 = new MockMethodInvocation(new ReturnAnotherListImpl2(), - ReturnAnotherList.class, "doSomething", List.class); - annotatedAtClassLevel = new MockMethodInvocation( - new CustomAnnotationAtClassLevel(), ReturnVoid.class, "doSomething", - List.class); - annotatedAtInterfaceLevel = new MockMethodInvocation( - new CustomAnnotationAtInterfaceLevel(), ReturnVoid2.class, "doSomething", - List.class); - annotatedAtMethodLevel = new MockMethodInvocation( - new CustomAnnotationAtMethodLevel(), ReturnVoid.class, "doSomething", - List.class); } @Test public void classLevelPreAnnotationIsPickedUpWhenNoMethodLevelExists() { - ConfigAttribute[] attrs = mds.getAttributes(voidImpl1).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.voidImpl1).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs[0] instanceof PreInvocationExpressionAttribute).isTrue(); PreInvocationExpressionAttribute pre = (PreInvocationExpressionAttribute) attrs[0]; @@ -98,9 +98,7 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Test public void mixedClassAndMethodPreAnnotationsAreBothIncluded() { - ConfigAttribute[] attrs = mds.getAttributes(voidImpl2).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.voidImpl2).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs[0] instanceof PreInvocationExpressionAttribute).isTrue(); PreInvocationExpressionAttribute pre = (PreInvocationExpressionAttribute) attrs[0]; @@ -111,9 +109,7 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Test public void methodWithPreFilterOnlyIsAllowed() { - ConfigAttribute[] attrs = mds.getAttributes(voidImpl3).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.voidImpl3).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs[0] instanceof PreInvocationExpressionAttribute).isTrue(); PreInvocationExpressionAttribute pre = (PreInvocationExpressionAttribute) attrs[0]; @@ -124,9 +120,7 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Test public void methodWithPostFilterOnlyIsAllowed() { - ConfigAttribute[] attrs = mds.getAttributes(listImpl1).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.listImpl1).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(2); assertThat(attrs[0] instanceof PreInvocationExpressionAttribute).isTrue(); assertThat(attrs[1] instanceof PostInvocationExpressionAttribute).isTrue(); @@ -139,9 +133,7 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Test public void interfaceAttributesAreIncluded() { - ConfigAttribute[] attrs = mds.getAttributes(notherListImpl1).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.notherListImpl1).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs[0] instanceof PreInvocationExpressionAttribute).isTrue(); PreInvocationExpressionAttribute pre = (PreInvocationExpressionAttribute) attrs[0]; @@ -153,9 +145,7 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Test public void classAttributesTakesPrecedeceOverInterfaceAttributes() { - ConfigAttribute[] attrs = mds.getAttributes(notherListImpl2).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.notherListImpl2).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); assertThat(attrs[0] instanceof PreInvocationExpressionAttribute).isTrue(); PreInvocationExpressionAttribute pre = (PreInvocationExpressionAttribute) attrs[0]; @@ -167,83 +157,95 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Test public void customAnnotationAtClassLevelIsDetected() { - ConfigAttribute[] attrs = mds.getAttributes(annotatedAtClassLevel).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.annotatedAtClassLevel).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); } @Test public void customAnnotationAtInterfaceLevelIsDetected() { - ConfigAttribute[] attrs = mds.getAttributes(annotatedAtInterfaceLevel).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.annotatedAtInterfaceLevel) + .toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); } @Test public void customAnnotationAtMethodLevelIsDetected() { - ConfigAttribute[] attrs = mds.getAttributes(annotatedAtMethodLevel).toArray( - new ConfigAttribute[0]); - + ConfigAttribute[] attrs = this.mds.getAttributes(this.annotatedAtMethodLevel).toArray(new ConfigAttribute[0]); assertThat(attrs).hasSize(1); } @Test public void proxyFactoryInterfaceAttributesFound() throws Exception { MockMethodInvocation mi = MethodInvocationFactory.createSec2150MethodInvocation(); - Collection attributes = mds.getAttributes(mi); + Collection attributes = this.mds.getAttributes(mi); assertThat(attributes).hasSize(1); - Expression expression = (Expression) ReflectionTestUtils.getField(attributes - .iterator().next(), "authorizeExpression"); + Expression expression = (Expression) ReflectionTestUtils.getField(attributes.iterator().next(), + "authorizeExpression"); assertThat(expression.getExpressionString()).isEqualTo("hasRole('ROLE_PERSON')"); } - // ~ Inner Classes - // ================================================================================================== - public interface ReturnVoid { + void doSomething(List param); + } public interface ReturnAList { + List doSomething(List param); + } @PreAuthorize("interfaceAuthzExpression") public interface ReturnAnotherList { + @PreAuthorize("interfaceMethodAuthzExpression") @PreFilter(filterTarget = "param", value = "interfacePreFilterExpression") List doSomething(List param); + } @PreAuthorize("someExpression") public static class ReturnVoidImpl1 implements ReturnVoid { + + @Override public void doSomething(List param) { } + } @PreAuthorize("someExpression") public static class ReturnVoidImpl2 implements ReturnVoid { + + @Override @PreFilter(filterTarget = "param", value = "somePreFilterExpression") public void doSomething(List param) { } + } public static class ReturnVoidImpl3 implements ReturnVoid { + + @Override @PreFilter(filterTarget = "param", value = "somePreFilterExpression") public void doSomething(List param) { } + } public static class ReturnAListImpl1 implements ReturnAList { + + @Override @PostFilter("somePostFilterExpression") public List doSomething(List param) { return param; } + } public static class ReturnAListImpl2 implements ReturnAList { + + @Override @PreAuthorize("someExpression") @PreFilter(filterTarget = "param", value = "somePreFilterExpression") @PostFilter("somePostFilterExpression") @@ -251,19 +253,26 @@ public class PrePostAnnotationSecurityMetadataSourceTests { public List doSomething(List param) { return param; } + } public static class ReturnAnotherListImpl1 implements ReturnAnotherList { + + @Override public List doSomething(List param) { return param; } + } public static class ReturnAnotherListImpl2 implements ReturnAnotherList { + + @Override @PreFilter(filterTarget = "param", value = "classMethodPreFilterExpression") public List doSomething(List param) { return param; } + } @Target({ ElementType.METHOD, ElementType.TYPE }) @@ -271,27 +280,40 @@ public class PrePostAnnotationSecurityMetadataSourceTests { @Inherited @PreAuthorize("customAnnotationExpression") public @interface CustomAnnotation { + } @CustomAnnotation public interface ReturnVoid2 { + void doSomething(List param); + } @CustomAnnotation public static class CustomAnnotationAtClassLevel implements ReturnVoid { + + @Override public void doSomething(List param) { } + } public static class CustomAnnotationAtInterfaceLevel implements ReturnVoid2 { + + @Override public void doSomething(List param) { } + } public static class CustomAnnotationAtMethodLevel implements ReturnVoid { + + @Override @CustomAnnotation public void doSomething(List param) { } + } + } diff --git a/core/src/test/java/org/springframework/security/access/expression/method/SecurityRules.java b/core/src/test/java/org/springframework/security/access/expression/method/SecurityRules.java index 2abe9a1118..af4715cf69 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/SecurityRules.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/SecurityRules.java @@ -13,9 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.expression.method; -public class SecurityRules { +public final class SecurityRules { + + private SecurityRules() { + } + public static boolean disallow() { return false; } @@ -27,4 +32,5 @@ public class SecurityRules { public static boolean isJoe(String s) { return "joe".equals(s); } + } diff --git a/core/src/test/java/org/springframework/security/access/hierarchicalroles/HierarchicalRolesTestHelper.java b/core/src/test/java/org/springframework/security/access/hierarchicalroles/HierarchicalRolesTestHelper.java index 901abbdf9d..b8df1e837e 100755 --- a/core/src/test/java/org/springframework/security/access/hierarchicalroles/HierarchicalRolesTestHelper.java +++ b/core/src/test/java/org/springframework/security/access/hierarchicalroles/HierarchicalRolesTestHelper.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; import java.util.ArrayList; import java.util.Collection; import java.util.List; -import org.springframework.security.core.GrantedAuthority; import org.apache.commons.collections.CollectionUtils; +import org.springframework.security.core.GrantedAuthority; + /** * Test helper class for the hierarchical roles tests. * @@ -29,13 +31,11 @@ import org.apache.commons.collections.CollectionUtils; */ public abstract class HierarchicalRolesTestHelper { - public static boolean containTheSameGrantedAuthorities( - Collection authorities1, + public static boolean containTheSameGrantedAuthorities(Collection authorities1, Collection authorities2) { if (authorities1 == null && authorities2 == null) { return true; } - if (authorities1 == null || authorities2 == null) { return false; } @@ -43,26 +43,21 @@ public abstract class HierarchicalRolesTestHelper { } public static boolean containTheSameGrantedAuthoritiesCompareByAuthorityString( - Collection authorities1, - Collection authorities2) { + Collection authorities1, Collection authorities2) { if (authorities1 == null && authorities2 == null) { return true; } - if (authorities1 == null || authorities2 == null) { return false; } - return CollectionUtils.isEqualCollection( - toCollectionOfAuthorityStrings(authorities1), + return CollectionUtils.isEqualCollection(toCollectionOfAuthorityStrings(authorities1), toCollectionOfAuthorityStrings(authorities2)); } - public static List toCollectionOfAuthorityStrings( - Collection authorities) { + public static List toCollectionOfAuthorityStrings(Collection authorities) { if (authorities == null) { return null; } - List result = new ArrayList<>(authorities.size()); for (GrantedAuthority authority : authorities) { result.add(authority.getAuthority()); @@ -72,12 +67,10 @@ public abstract class HierarchicalRolesTestHelper { public static List createAuthorityList(final String... roles) { List authorities = new ArrayList<>(roles.length); - for (final String role : roles) { // Use non SimpleGrantedAuthority (SEC-863) authorities.add((GrantedAuthority) () -> role); } - return authorities; } diff --git a/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapperTests.java b/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapperTests.java index 94f6ec2106..58beb183f3 100644 --- a/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapperTests.java +++ b/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyAuthoritiesMapperTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; -import static org.assertj.core.api.Assertions.assertThat; +import java.util.Collection; + +import org.junit.Test; -import org.junit.*; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor @@ -33,17 +35,12 @@ public class RoleHierarchyAuthoritiesMapperTests { RoleHierarchyImpl rh = new RoleHierarchyImpl(); rh.setHierarchy("ROLE_A > ROLE_B\nROLE_B > ROLE_C"); RoleHierarchyAuthoritiesMapper mapper = new RoleHierarchyAuthoritiesMapper(rh); - Collection authorities = mapper .mapAuthorities(AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_D")); - assertThat(authorities).hasSize(4); - mapper = new RoleHierarchyAuthoritiesMapper(new NullRoleHierarchy()); - - authorities = mapper.mapAuthorities(AuthorityUtils.createAuthorityList("ROLE_A", - "ROLE_D")); - + authorities = mapper.mapAuthorities(AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_D")); assertThat(authorities).hasSize(2); } + } diff --git a/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImplTests.java b/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImplTests.java index d25983b3ad..0bd68d1955 100644 --- a/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImplTests.java +++ b/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyImplTests.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.hierarchicalroles; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +package org.springframework.security.access.hierarchicalroles; import java.util.ArrayList; import java.util.List; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests for {@link RoleHierarchyImpl}. * @@ -36,155 +38,110 @@ public class RoleHierarchyImplTests { public void testRoleHierarchyWithNullOrEmptyAuthorities() { List authorities0 = null; List authorities1 = new ArrayList<>(); - RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B"); - - assertThat(roleHierarchyImpl.getReachableGrantedAuthorities( - authorities0)).isNotNull(); - assertThat( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities0)).isEmpty(); - - assertThat(roleHierarchyImpl.getReachableGrantedAuthorities( - authorities1)).isNotNull(); - assertThat( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1)).isEmpty(); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(authorities0)).isNotNull(); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(authorities0)).isEmpty(); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(authorities1)).isNotNull(); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(authorities1)).isEmpty(); } @Test public void testSimpleRoleHierarchy() { - - List authorities0 = AuthorityUtils.createAuthorityList( - "ROLE_0"); - List authorities1 = AuthorityUtils.createAuthorityList( - "ROLE_A"); - List authorities2 = AuthorityUtils.createAuthorityList("ROLE_A", - "ROLE_B"); - + List authorities0 = AuthorityUtils.createAuthorityList("ROLE_0"); + List authorities1 = AuthorityUtils.createAuthorityList("ROLE_A"); + List authorities2 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B"); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities0), - authorities0)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities0), authorities0)).isTrue(); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), - authorities2)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), authorities2)).isTrue(); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities2), - authorities2)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities2), authorities2)).isTrue(); } @Test public void testTransitiveRoleHierarchies() { - List authorities1 = AuthorityUtils.createAuthorityList( - "ROLE_A"); - List authorities2 = AuthorityUtils.createAuthorityList("ROLE_A", - "ROLE_B", "ROLE_C"); - List authorities3 = AuthorityUtils.createAuthorityList("ROLE_A", - "ROLE_B", "ROLE_C", "ROLE_D"); - + List authorities1 = AuthorityUtils.createAuthorityList("ROLE_A"); + List authorities2 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B", "ROLE_C"); + List authorities3 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B", "ROLE_C", + "ROLE_D"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\nROLE_B > ROLE_C"); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), - authorities2)).isTrue(); - - roleHierarchyImpl.setHierarchy( - "ROLE_A > ROLE_B\nROLE_B > ROLE_C\nROLE_C > ROLE_D"); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), authorities2)).isTrue(); + roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\nROLE_B > ROLE_C\nROLE_C > ROLE_D"); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), - authorities3)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), authorities3)).isTrue(); } @Test public void testComplexRoleHierarchy() { - List authoritiesInput1 = AuthorityUtils.createAuthorityList( - "ROLE_A"); - List authoritiesOutput1 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_B", "ROLE_C", "ROLE_D"); - List authoritiesInput2 = AuthorityUtils.createAuthorityList( - "ROLE_B"); - List authoritiesOutput2 = AuthorityUtils.createAuthorityList( - "ROLE_B", "ROLE_D"); - List authoritiesInput3 = AuthorityUtils.createAuthorityList( - "ROLE_C"); - List authoritiesOutput3 = AuthorityUtils.createAuthorityList( - "ROLE_C", "ROLE_D"); - List authoritiesInput4 = AuthorityUtils.createAuthorityList( + List authoritiesInput1 = AuthorityUtils.createAuthorityList("ROLE_A"); + List authoritiesOutput1 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B", "ROLE_C", "ROLE_D"); - List authoritiesOutput4 = AuthorityUtils.createAuthorityList( - "ROLE_D"); - + List authoritiesInput2 = AuthorityUtils.createAuthorityList("ROLE_B"); + List authoritiesOutput2 = AuthorityUtils.createAuthorityList("ROLE_B", "ROLE_D"); + List authoritiesInput3 = AuthorityUtils.createAuthorityList("ROLE_C"); + List authoritiesOutput3 = AuthorityUtils.createAuthorityList("ROLE_C", "ROLE_D"); + List authoritiesInput4 = AuthorityUtils.createAuthorityList("ROLE_D"); + List authoritiesOutput4 = AuthorityUtils.createAuthorityList("ROLE_D"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - roleHierarchyImpl.setHierarchy( - "ROLE_A > ROLE_B\nROLE_A > ROLE_C\nROLE_C > ROLE_D\nROLE_B > ROLE_D"); - + roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\nROLE_A > ROLE_C\nROLE_C > ROLE_D\nROLE_B > ROLE_D"); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput1), - authoritiesOutput1)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput1), authoritiesOutput1)).isTrue(); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput2), - authoritiesOutput2)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput2), authoritiesOutput2)).isTrue(); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput3), - authoritiesOutput3)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput3), authoritiesOutput3)).isTrue(); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput4), - authoritiesOutput4)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authoritiesInput4), authoritiesOutput4)).isTrue(); } @Test public void testCyclesInRoleHierarchy() { RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - try { roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_A"); fail("Cycle in role hierarchy was not detected!"); } - catch (CycleInRoleHierarchyException e) { + catch (CycleInRoleHierarchyException ex) { } - try { roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\nROLE_B > ROLE_A"); fail("Cycle in role hierarchy was not detected!"); } - catch (CycleInRoleHierarchyException e) { + catch (CycleInRoleHierarchyException ex) { } - try { - roleHierarchyImpl.setHierarchy( - "ROLE_A > ROLE_B\nROLE_B > ROLE_C\nROLE_C > ROLE_A"); + roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\nROLE_B > ROLE_C\nROLE_C > ROLE_A"); fail("Cycle in role hierarchy was not detected!"); } - catch (CycleInRoleHierarchyException e) { + catch (CycleInRoleHierarchyException ex) { } - try { roleHierarchyImpl.setHierarchy( "ROLE_A > ROLE_B\nROLE_B > ROLE_C\nROLE_C > ROLE_E\nROLE_E > ROLE_D\nROLE_D > ROLE_B"); fail("Cycle in role hierarchy was not detected!"); } - catch (CycleInRoleHierarchyException e) { + catch (CycleInRoleHierarchyException ex) { } - try { roleHierarchyImpl.setHierarchy("ROLE_C > ROLE_B\nROLE_B > ROLE_A\nROLE_A > ROLE_B"); fail("Cycle in role hierarchy was not detected!"); - } catch (CycleInRoleHierarchyException e) { + } + catch (CycleInRoleHierarchyException ex) { } } @Test public void testNoCyclesInRoleHierarchy() { RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - try { - roleHierarchyImpl.setHierarchy( - "ROLE_A > ROLE_B\nROLE_A > ROLE_C\nROLE_C > ROLE_D\nROLE_B > ROLE_D"); + roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\nROLE_A > ROLE_C\nROLE_C > ROLE_D\nROLE_B > ROLE_D"); } - catch (CycleInRoleHierarchyException e) { + catch (CycleInRoleHierarchyException ex) { fail("A cycle in role hierarchy was incorrectly detected!"); } } @@ -192,94 +149,70 @@ public class RoleHierarchyImplTests { // SEC-863 @Test public void testSimpleRoleHierarchyWithCustomGrantedAuthorityImplementation() { - - List authorities0 = HierarchicalRolesTestHelper.createAuthorityList( - "ROLE_0"); - List authorities1 = HierarchicalRolesTestHelper.createAuthorityList( - "ROLE_A"); - List authorities2 = HierarchicalRolesTestHelper.createAuthorityList( - "ROLE_A", "ROLE_B"); - + List authorities0 = HierarchicalRolesTestHelper.createAuthorityList("ROLE_0"); + List authorities1 = HierarchicalRolesTestHelper.createAuthorityList("ROLE_A"); + List authorities2 = HierarchicalRolesTestHelper.createAuthorityList("ROLE_A", "ROLE_B"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B"); - - assertThat( - HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities0), - authorities0)).isTrue(); - assertThat( - HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), - authorities2)).isTrue(); - assertThat( - HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities2), - authorities2)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString( + roleHierarchyImpl.getReachableGrantedAuthorities(authorities0), authorities0)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString( + roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), authorities2)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString( + roleHierarchyImpl.getReachableGrantedAuthorities(authorities2), authorities2)).isTrue(); } @Test public void testWhitespaceRoleHierarchies() { - List authorities1 = AuthorityUtils.createAuthorityList( - "ROLE A"); - List authorities2 = AuthorityUtils.createAuthorityList("ROLE A", - "ROLE B", "ROLE>C"); - List authorities3 = AuthorityUtils.createAuthorityList("ROLE A", - "ROLE B", "ROLE>C", "ROLE D"); - + List authorities1 = AuthorityUtils.createAuthorityList("ROLE A"); + List authorities2 = AuthorityUtils.createAuthorityList("ROLE A", "ROLE B", "ROLE>C"); + List authorities3 = AuthorityUtils.createAuthorityList("ROLE A", "ROLE B", "ROLE>C", + "ROLE D"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - roleHierarchyImpl.setHierarchy("ROLE A > ROLE B\nROLE B > ROLE>C"); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), - authorities2)).isTrue(); - - roleHierarchyImpl.setHierarchy( - "ROLE A > ROLE B\nROLE B > ROLE>C\nROLE>C > ROLE D"); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), authorities2)).isTrue(); + roleHierarchyImpl.setHierarchy("ROLE A > ROLE B\nROLE B > ROLE>C\nROLE>C > ROLE D"); assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), - authorities3)).isTrue(); + roleHierarchyImpl.getReachableGrantedAuthorities(authorities1), authorities3)).isTrue(); } // gh-6954 @Test public void testJavadoc() { - List flatAuthorities = AuthorityUtils.createAuthorityList( - "ROLE_A"); - List allAuthorities = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_B", "ROLE_AUTHENTICATED", "ROLE_UNAUTHENTICATED"); + List flatAuthorities = AuthorityUtils.createAuthorityList("ROLE_A"); + List allAuthorities = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B", + "ROLE_AUTHENTICATED", "ROLE_UNAUTHENTICATED"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B\n" - + "ROLE_B > ROLE_AUTHENTICATED\n" - + "ROLE_AUTHENTICATED > ROLE_UNAUTHENTICATED"); - - assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(flatAuthorities)).containsExactlyInAnyOrderElementsOf(allAuthorities); + roleHierarchyImpl.setHierarchy( + "ROLE_A > ROLE_B\n" + "ROLE_B > ROLE_AUTHENTICATED\n" + "ROLE_AUTHENTICATED > ROLE_UNAUTHENTICATED"); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(flatAuthorities)) + .containsExactlyInAnyOrderElementsOf(allAuthorities); } // gh-6954 @Test public void testInterfaceJavadoc() { - List flatAuthorities = AuthorityUtils.createAuthorityList( - "ROLE_HIGHEST"); - List allAuthorities = AuthorityUtils.createAuthorityList( - "ROLE_HIGHEST", "ROLE_HIGHER", "ROLE_LOW", "ROLE_LOWER"); + List flatAuthorities = AuthorityUtils.createAuthorityList("ROLE_HIGHEST"); + List allAuthorities = AuthorityUtils.createAuthorityList("ROLE_HIGHEST", "ROLE_HIGHER", + "ROLE_LOW", "ROLE_LOWER"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); - roleHierarchyImpl.setHierarchy("ROLE_HIGHEST > ROLE_HIGHER\n" - + "ROLE_HIGHER > ROLE_LOW\n" - + "ROLE_LOW > ROLE_LOWER"); - - assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(flatAuthorities)).containsExactlyInAnyOrderElementsOf(allAuthorities); + roleHierarchyImpl + .setHierarchy("ROLE_HIGHEST > ROLE_HIGHER\n" + "ROLE_HIGHER > ROLE_LOW\n" + "ROLE_LOW > ROLE_LOWER"); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(flatAuthorities)) + .containsExactlyInAnyOrderElementsOf(allAuthorities); } // gh-6954 @Test public void singleLineLargeHierarchy() { - List flatAuthorities = AuthorityUtils.createAuthorityList( - "ROLE_HIGHEST"); - List allAuthorities = AuthorityUtils.createAuthorityList( - "ROLE_HIGHEST", "ROLE_HIGHER", "ROLE_LOW", "ROLE_LOWER"); + List flatAuthorities = AuthorityUtils.createAuthorityList("ROLE_HIGHEST"); + List allAuthorities = AuthorityUtils.createAuthorityList("ROLE_HIGHEST", "ROLE_HIGHER", + "ROLE_LOW", "ROLE_LOWER"); RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); roleHierarchyImpl.setHierarchy("ROLE_HIGHEST > ROLE_HIGHER > ROLE_LOW > ROLE_LOWER"); - - assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(flatAuthorities)).containsExactlyInAnyOrderElementsOf(allAuthorities); + assertThat(roleHierarchyImpl.getReachableGrantedAuthorities(flatAuthorities)) + .containsExactlyInAnyOrderElementsOf(allAuthorities); } + } diff --git a/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtilsTests.java b/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtilsTests.java index 2fb503fa95..ae08fd1249 100644 --- a/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtilsTests.java +++ b/core/src/test/java/org/springframework/security/access/hierarchicalroles/RoleHierarchyUtilsTests.java @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.hierarchicalroles; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + import org.junit.Test; -import java.util.*; - -import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; /** @@ -28,6 +33,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class RoleHierarchyUtilsTests { + private static final String EOL = System.lineSeparator(); @Test @@ -38,14 +44,11 @@ public class RoleHierarchyUtilsTests { "ROLE_B > ROLE_D" + EOL + "ROLE_C > ROLE_D" + EOL; // @formatter:on - Map> roleHierarchyMap = new TreeMap<>(); - roleHierarchyMap.put("ROLE_A", asList("ROLE_B", "ROLE_C")); - roleHierarchyMap.put("ROLE_B", asList("ROLE_D")); - roleHierarchyMap.put("ROLE_C", asList("ROLE_D")); - + roleHierarchyMap.put("ROLE_A", Arrays.asList("ROLE_B", "ROLE_C")); + roleHierarchyMap.put("ROLE_B", Arrays.asList("ROLE_D")); + roleHierarchyMap.put("ROLE_C", Arrays.asList("ROLE_D")); String roleHierarchy = RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); - assertThat(roleHierarchy).isEqualTo(expectedRoleHierarchy); } @@ -62,16 +65,14 @@ public class RoleHierarchyUtilsTests { @Test(expected = IllegalArgumentException.class) public void roleHierarchyFromMapWhenRoleNullThenThrowsIllegalArgumentException() { Map> roleHierarchyMap = new HashMap<>(); - roleHierarchyMap.put(null, asList("ROLE_B", "ROLE_C")); - + roleHierarchyMap.put(null, Arrays.asList("ROLE_B", "ROLE_C")); RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); } @Test(expected = IllegalArgumentException.class) public void roleHierarchyFromMapWhenRoleEmptyThenThrowsIllegalArgumentException() { Map> roleHierarchyMap = new HashMap<>(); - roleHierarchyMap.put("", asList("ROLE_B", "ROLE_C")); - + roleHierarchyMap.put("", Arrays.asList("ROLE_B", "ROLE_C")); RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); } @@ -79,7 +80,6 @@ public class RoleHierarchyUtilsTests { public void roleHierarchyFromMapWhenImpliedRolesNullThenThrowsIllegalArgumentException() { Map> roleHierarchyMap = new HashMap<>(); roleHierarchyMap.put("ROLE_A", null); - RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); } @@ -87,7 +87,7 @@ public class RoleHierarchyUtilsTests { public void roleHierarchyFromMapWhenImpliedRolesEmptyThenThrowsIllegalArgumentException() { Map> roleHierarchyMap = new HashMap<>(); roleHierarchyMap.put("ROLE_A", Collections.emptyList()); - RoleHierarchyUtils.roleHierarchyFromMap(roleHierarchyMap); } + } diff --git a/core/src/test/java/org/springframework/security/access/hierarchicalroles/TestHelperTests.java b/core/src/test/java/org/springframework/security/access/hierarchicalroles/TestHelperTests.java index 21bf89df09..111b94b8f5 100644 --- a/core/src/test/java/org/springframework/security/access/hierarchicalroles/TestHelperTests.java +++ b/core/src/test/java/org/springframework/security/access/hierarchicalroles/TestHelperTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.hierarchicalroles; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.access.hierarchicalroles; import java.util.ArrayList; import java.util.Collection; @@ -23,9 +22,12 @@ import java.util.List; import org.apache.commons.collections.CollectionUtils; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests for {@link HierarchicalRolesTestHelper}. * @@ -35,159 +37,103 @@ public class TestHelperTests { @Test public void testContainTheSameGrantedAuthorities() { - List authorities1 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_B"); - List authorities2 = AuthorityUtils.createAuthorityList( - "ROLE_B", "ROLE_A"); - List authorities3 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_C"); - List authorities4 = AuthorityUtils - .createAuthorityList("ROLE_A"); - List authorities5 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_A"); - - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, - null)).isTrue(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities1)).isTrue(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities2)).isTrue(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities2, authorities1)).isTrue(); - - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, - authorities1)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, null)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities3)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities3, authorities1)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities4)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities4, authorities1)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities4, authorities5)).isFalse(); + List authorities1 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B"); + List authorities2 = AuthorityUtils.createAuthorityList("ROLE_B", "ROLE_A"); + List authorities3 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_C"); + List authorities4 = AuthorityUtils.createAuthorityList("ROLE_A"); + List authorities5 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_A"); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, null)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities1)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities2)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities2, authorities1)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, authorities1)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, null)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities3)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities3, authorities1)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities4)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities4, authorities1)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities4, authorities5)).isFalse(); } // SEC-863 @Test public void testToListOfAuthorityStrings() { - Collection authorities1 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_B"); - Collection authorities2 = AuthorityUtils.createAuthorityList( - "ROLE_B", "ROLE_A"); - Collection authorities3 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_C"); - Collection authorities4 = AuthorityUtils - .createAuthorityList("ROLE_A"); - Collection authorities5 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_A"); - + Collection authorities1 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B"); + Collection authorities2 = AuthorityUtils.createAuthorityList("ROLE_B", "ROLE_A"); + Collection authorities3 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_C"); + Collection authorities4 = AuthorityUtils.createAuthorityList("ROLE_A"); + Collection authorities5 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_A"); List authoritiesStrings1 = new ArrayList<>(); authoritiesStrings1.add("ROLE_A"); authoritiesStrings1.add("ROLE_B"); - List authoritiesStrings2 = new ArrayList<>(); authoritiesStrings2.add("ROLE_B"); authoritiesStrings2.add("ROLE_A"); - List authoritiesStrings3 = new ArrayList<>(); authoritiesStrings3.add("ROLE_A"); authoritiesStrings3.add("ROLE_C"); - List authoritiesStrings4 = new ArrayList<>(); authoritiesStrings4.add("ROLE_A"); - List authoritiesStrings5 = new ArrayList<>(); authoritiesStrings5.add("ROLE_A"); authoritiesStrings5.add("ROLE_A"); - assertThat(CollectionUtils.isEqualCollection( - HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities1), - authoritiesStrings1)).isTrue(); - + HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities1), authoritiesStrings1)) + .isTrue(); assertThat(CollectionUtils.isEqualCollection( - HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities2), - authoritiesStrings2)).isTrue(); - + HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities2), authoritiesStrings2)) + .isTrue(); assertThat(CollectionUtils.isEqualCollection( - HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities3), - authoritiesStrings3)).isTrue(); - + HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities3), authoritiesStrings3)) + .isTrue(); assertThat(CollectionUtils.isEqualCollection( - HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities4), - authoritiesStrings4)).isTrue(); - + HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities4), authoritiesStrings4)) + .isTrue(); assertThat(CollectionUtils.isEqualCollection( - HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities5), - authoritiesStrings5)).isTrue(); + HierarchicalRolesTestHelper.toCollectionOfAuthorityStrings(authorities5), authoritiesStrings5)) + .isTrue(); } // SEC-863 @Test public void testContainTheSameGrantedAuthoritiesCompareByAuthorityString() { - List authorities1 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_B"); - List authorities2 = AuthorityUtils.createAuthorityList( - "ROLE_B", "ROLE_A"); - List authorities3 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_C"); - List authorities4 = AuthorityUtils - .createAuthorityList("ROLE_A"); - List authorities5 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_A"); - - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, - null)).isTrue(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities1)).isTrue(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities2)).isTrue(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities2, authorities1)).isTrue(); - - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, - authorities1)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, null)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities3)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities3, authorities1)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities1, authorities4)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities4, authorities1)).isFalse(); - assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities( - authorities4, authorities5)).isFalse(); + List authorities1 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B"); + List authorities2 = AuthorityUtils.createAuthorityList("ROLE_B", "ROLE_A"); + List authorities3 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_C"); + List authorities4 = AuthorityUtils.createAuthorityList("ROLE_A"); + List authorities5 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_A"); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, null)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities1)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities2)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities2, authorities1)).isTrue(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(null, authorities1)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, null)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities3)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities3, authorities1)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities1, authorities4)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities4, authorities1)).isFalse(); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthorities(authorities4, authorities5)).isFalse(); } // SEC-863 @Test public void testContainTheSameGrantedAuthoritiesCompareByAuthorityStringWithAuthorityLists() { - List authorities1 = HierarchicalRolesTestHelper - .createAuthorityList("ROLE_A", "ROLE_B"); - List authorities2 = AuthorityUtils.createAuthorityList( - "ROLE_A", "ROLE_B"); - assertThat(HierarchicalRolesTestHelper - .containTheSameGrantedAuthoritiesCompareByAuthorityString(authorities1, - authorities2)).isTrue(); + List authorities1 = HierarchicalRolesTestHelper.createAuthorityList("ROLE_A", "ROLE_B"); + List authorities2 = AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B"); + assertThat(HierarchicalRolesTestHelper.containTheSameGrantedAuthoritiesCompareByAuthorityString(authorities1, + authorities2)).isTrue(); } // SEC-863 @Test public void testCreateAuthorityList() { - List authorities1 = HierarchicalRolesTestHelper - .createAuthorityList("ROLE_A"); + List authorities1 = HierarchicalRolesTestHelper.createAuthorityList("ROLE_A"); assertThat(authorities1).hasSize(1); assertThat(authorities1.get(0).getAuthority()).isEqualTo("ROLE_A"); - - List authorities2 = HierarchicalRolesTestHelper - .createAuthorityList("ROLE_A", "ROLE_C"); + List authorities2 = HierarchicalRolesTestHelper.createAuthorityList("ROLE_A", "ROLE_C"); assertThat(authorities2).hasSize(2); assertThat(authorities2.get(0).getAuthority()).isEqualTo("ROLE_A"); assertThat(authorities2.get(1).getAuthority()).isEqualTo("ROLE_C"); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/AbstractSecurityInterceptorTests.java b/core/src/test/java/org/springframework/security/access/intercept/AbstractSecurityInterceptorTests.java index e9bd20de69..6a4047cbae 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/AbstractSecurityInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/AbstractSecurityInterceptorTests.java @@ -16,14 +16,15 @@ package org.springframework.security.access.intercept; -import static org.mockito.Mockito.mock; - import org.junit.Test; + import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.access.SecurityMetadataSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.util.SimpleMethodInvocation; +import static org.mockito.Mockito.mock; + /** * Tests some {@link AbstractSecurityInterceptor} methods. Most of the testing for this * class is found in the {@code MethodSecurityInterceptorTests} class. @@ -31,13 +32,10 @@ import org.springframework.security.util.SimpleMethodInvocation; * @author Ben Alex */ public class AbstractSecurityInterceptorTests { - // ~ Methods - // ======================================================================================================== @Test(expected = IllegalArgumentException.class) public void detectsIfInvocationPassedIncompatibleSecureObject() { MockSecurityInterceptorWhichOnlySupportsStrings si = new MockSecurityInterceptorWhichOnlySupportsStrings(); - si.setRunAsManager(mock(RunAsManager.class)); si.setAuthenticationManager(mock(AuthenticationManager.class)); si.setAfterInvocationManager(mock(AfterInvocationManager.class)); @@ -57,41 +55,44 @@ public class AbstractSecurityInterceptorTests { si.afterPropertiesSet(); } - // ~ Inner Classes - // ================================================================================================== - private class MockSecurityInterceptorReturnsNull extends AbstractSecurityInterceptor { + private SecurityMetadataSource securityMetadataSource; + @Override public Class getSecureObjectClass() { return null; } + @Override public SecurityMetadataSource obtainSecurityMetadataSource() { - return securityMetadataSource; + return this.securityMetadataSource; } - public void setSecurityMetadataSource( - SecurityMetadataSource securityMetadataSource) { + void setSecurityMetadataSource(SecurityMetadataSource securityMetadataSource) { this.securityMetadataSource = securityMetadataSource; } + } - private class MockSecurityInterceptorWhichOnlySupportsStrings extends - AbstractSecurityInterceptor { + private class MockSecurityInterceptorWhichOnlySupportsStrings extends AbstractSecurityInterceptor { + private SecurityMetadataSource securityMetadataSource; + @Override public Class getSecureObjectClass() { return String.class; } + @Override public SecurityMetadataSource obtainSecurityMetadataSource() { - return securityMetadataSource; + return this.securityMetadataSource; } - public void setSecurityMetadataSource( - SecurityMetadataSource securityMetadataSource) { + void setSecurityMetadataSource(SecurityMetadataSource securityMetadataSource) { this.securityMetadataSource = securityMetadataSource; } + } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/AfterInvocationProviderManagerTests.java b/core/src/test/java/org/springframework/security/access/intercept/AfterInvocationProviderManagerTests.java index 90ce27038c..f6fc8ec922 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/AfterInvocationProviderManagerTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/AfterInvocationProviderManagerTests.java @@ -16,15 +16,13 @@ package org.springframework.security.access.intercept; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.Collection; import java.util.List; import java.util.Vector; import org.aopalliance.intercept.MethodInvocation; import org.junit.Test; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AfterInvocationProvider; import org.springframework.security.access.ConfigAttribute; @@ -32,6 +30,9 @@ import org.springframework.security.access.SecurityConfig; import org.springframework.security.core.Authentication; import org.springframework.security.util.SimpleMethodInvocation; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AfterInvocationProviderManager}. * @@ -40,54 +41,37 @@ import org.springframework.security.util.SimpleMethodInvocation; @SuppressWarnings("unchecked") public class AfterInvocationProviderManagerTests { - // ~ Methods - // ======================================================================================================== @Test public void testCorrectOperation() throws Exception { AfterInvocationProviderManager manager = new AfterInvocationProviderManager(); List list = new Vector(); - list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP1"))); - list.add(new MockAfterInvocationProvider("swap2", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP2"))); - list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP3"))); + list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP1"))); + list.add(new MockAfterInvocationProvider("swap2", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP2"))); + list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP3"))); manager.setProviders(list); assertThat(manager.getProviders()).isEqualTo(list); manager.afterPropertiesSet(); - - List attr1 = SecurityConfig.createList( - new String[] { "GIVE_ME_SWAP1" }); - List attr2 = SecurityConfig.createList( - new String[] { "GIVE_ME_SWAP2" }); - List attr3 = SecurityConfig.createList( - new String[] { "GIVE_ME_SWAP3" }); - List attr2and3 = SecurityConfig.createList( - new String[] { "GIVE_ME_SWAP2", "GIVE_ME_SWAP3" }); - List attr4 = SecurityConfig.createList( - new String[] { "NEVER_CAUSES_SWAP" }); - - assertThat(manager.decide(null, new SimpleMethodInvocation(), attr1, - "content-before-swapping")).isEqualTo("swap1"); - - assertThat(manager.decide(null, new SimpleMethodInvocation(), attr2, - "content-before-swapping")).isEqualTo("swap2"); - - assertThat(manager.decide(null, new SimpleMethodInvocation(), attr3, - "content-before-swapping")).isEqualTo("swap3"); - - assertThat(manager.decide(null, new SimpleMethodInvocation(), attr4, - "content-before-swapping")).isEqualTo("content-before-swapping"); - - assertThat(manager.decide(null, new SimpleMethodInvocation(), attr2and3, - "content-before-swapping")).isEqualTo("swap3"); + List attr1 = SecurityConfig.createList(new String[] { "GIVE_ME_SWAP1" }); + List attr2 = SecurityConfig.createList(new String[] { "GIVE_ME_SWAP2" }); + List attr3 = SecurityConfig.createList(new String[] { "GIVE_ME_SWAP3" }); + List attr2and3 = SecurityConfig.createList(new String[] { "GIVE_ME_SWAP2", "GIVE_ME_SWAP3" }); + List attr4 = SecurityConfig.createList(new String[] { "NEVER_CAUSES_SWAP" }); + assertThat(manager.decide(null, new SimpleMethodInvocation(), attr1, "content-before-swapping")) + .isEqualTo("swap1"); + assertThat(manager.decide(null, new SimpleMethodInvocation(), attr2, "content-before-swapping")) + .isEqualTo("swap2"); + assertThat(manager.decide(null, new SimpleMethodInvocation(), attr3, "content-before-swapping")) + .isEqualTo("swap3"); + assertThat(manager.decide(null, new SimpleMethodInvocation(), attr4, "content-before-swapping")) + .isEqualTo("content-before-swapping"); + assertThat(manager.decide(null, new SimpleMethodInvocation(), attr2and3, "content-before-swapping")) + .isEqualTo("swap3"); } @Test public void testRejectsEmptyProvidersList() { AfterInvocationProviderManager manager = new AfterInvocationProviderManager(); List list = new Vector(); - try { manager.setProviders(list); fail("Should have thrown IllegalArgumentException"); @@ -101,12 +85,9 @@ public class AfterInvocationProviderManagerTests { public void testRejectsNonAfterInvocationProviders() { AfterInvocationProviderManager manager = new AfterInvocationProviderManager(); List list = new Vector(); - list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP1"))); + list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP1"))); list.add(45); - list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP3"))); - + list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP3"))); try { manager.setProviders(list); fail("Should have thrown IllegalArgumentException"); @@ -119,7 +100,6 @@ public class AfterInvocationProviderManagerTests { @Test public void testRejectsNullProvidersList() throws Exception { AfterInvocationProviderManager manager = new AfterInvocationProviderManager(); - try { manager.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); @@ -133,15 +113,11 @@ public class AfterInvocationProviderManagerTests { public void testSupportsConfigAttributeIteration() throws Exception { AfterInvocationProviderManager manager = new AfterInvocationProviderManager(); List list = new Vector(); - list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP1"))); - list.add(new MockAfterInvocationProvider("swap2", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP2"))); - list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP3"))); + list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP1"))); + list.add(new MockAfterInvocationProvider("swap2", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP2"))); + list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP3"))); manager.setProviders(list); manager.afterPropertiesSet(); - assertThat(manager.supports(new SecurityConfig("UNKNOWN_ATTRIB"))).isFalse(); assertThat(manager.supports(new SecurityConfig("GIVE_ME_SWAP2"))).isTrue(); } @@ -150,22 +126,15 @@ public class AfterInvocationProviderManagerTests { public void testSupportsSecureObjectIteration() throws Exception { AfterInvocationProviderManager manager = new AfterInvocationProviderManager(); List list = new Vector(); - list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP1"))); - list.add(new MockAfterInvocationProvider("swap2", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP2"))); - list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, - new SecurityConfig("GIVE_ME_SWAP3"))); + list.add(new MockAfterInvocationProvider("swap1", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP1"))); + list.add(new MockAfterInvocationProvider("swap2", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP2"))); + list.add(new MockAfterInvocationProvider("swap3", MethodInvocation.class, new SecurityConfig("GIVE_ME_SWAP3"))); manager.setProviders(list); manager.afterPropertiesSet(); - // assertFalse(manager.supports(FilterInvocation.class)); assertThat(manager.supports(MethodInvocation.class)).isTrue(); } - // ~ Inner Classes - // ================================================================================================== - /** * Always returns the constructor-defined forceReturnObject, provided the * same configuration attribute was provided. Also stores the secure object it @@ -185,22 +154,25 @@ public class AfterInvocationProviderManagerTests { this.configAttribute = configAttribute; } - public Object decide(Authentication authentication, Object object, - Collection config, Object returnedObject) - throws AccessDeniedException { - if (config.contains(configAttribute)) { - return forceReturnObject; + @Override + public Object decide(Authentication authentication, Object object, Collection config, + Object returnedObject) throws AccessDeniedException { + if (config.contains(this.configAttribute)) { + return this.forceReturnObject; } - return returnedObject; } + @Override public boolean supports(Class clazz) { - return secureObject.isAssignableFrom(clazz); + return this.secureObject.isAssignableFrom(clazz); } + @Override public boolean supports(ConfigAttribute attribute) { - return attribute.equals(configAttribute); + return attribute.equals(this.configAttribute); } + } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/InterceptorStatusTokenTests.java b/core/src/test/java/org/springframework/security/access/intercept/InterceptorStatusTokenTests.java index f9f163aeb8..eb6947816a 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/InterceptorStatusTokenTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/InterceptorStatusTokenTests.java @@ -16,18 +16,19 @@ package org.springframework.security.access.intercept; -import static org.assertj.core.api.Assertions.*; - import java.util.List; import org.aopalliance.intercept.MethodInvocation; import org.junit.Test; + import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.util.SimpleMethodInvocation; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link InterceptorStatusToken}. * @@ -41,10 +42,10 @@ public class InterceptorStatusTokenTests { MethodInvocation mi = new SimpleMethodInvocation(); SecurityContext ctx = SecurityContextHolder.createEmptyContext(); InterceptorStatusToken token = new InterceptorStatusToken(ctx, true, attr, mi); - assertThat(token.isContextHolderRefreshRequired()).isTrue(); assertThat(token.getAttributes()).isEqualTo(attr); assertThat(token.getSecureObject()).isEqualTo(mi); assertThat(token.getSecurityContext()).isSameAs(ctx); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/NullRunAsManagerTests.java b/core/src/test/java/org/springframework/security/access/intercept/NullRunAsManagerTests.java index 3200cb954d..da84efa8ba 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/NullRunAsManagerTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/NullRunAsManagerTests.java @@ -16,19 +16,18 @@ package org.springframework.security.access.intercept; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.access.SecurityConfig; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link NullRunAsManager}. * * @author Ben Alex */ public class NullRunAsManagerTests { - // ~ Methods - // ======================================================================================================== @Test public void testAlwaysReturnsNull() { @@ -47,4 +46,5 @@ public class NullRunAsManagerTests { NullRunAsManager runAs = new NullRunAsManager(); assertThat(runAs.supports(new SecurityConfig("X"))).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProviderTests.java index ad2fda9bd9..620806f5ff 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/RunAsImplAuthenticationProviderTests.java @@ -16,15 +16,17 @@ package org.springframework.security.access.intercept; -import static org.assertj.core.api.Assertions.*; +import org.junit.Assert; +import org.junit.Test; -import org.junit.*; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link RunAsImplAuthenticationProvider}. */ @@ -33,27 +35,20 @@ public class RunAsImplAuthenticationProviderTests { @Test(expected = BadCredentialsException.class) public void testAuthenticationFailDueToWrongKey() { RunAsUserToken token = new RunAsUserToken("wrong_key", "Test", "Password", - AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), - UsernamePasswordAuthenticationToken.class); + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), UsernamePasswordAuthenticationToken.class); RunAsImplAuthenticationProvider provider = new RunAsImplAuthenticationProvider(); provider.setKey("hello_world"); - provider.authenticate(token); } @Test public void testAuthenticationSuccess() { RunAsUserToken token = new RunAsUserToken("my_password", "Test", "Password", - AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), - UsernamePasswordAuthenticationToken.class); + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), UsernamePasswordAuthenticationToken.class); RunAsImplAuthenticationProvider provider = new RunAsImplAuthenticationProvider(); provider.setKey("my_password"); - Authentication result = provider.authenticate(token); - - Assert.assertTrue("Should have returned RunAsUserToken", - result instanceof RunAsUserToken); - + Assert.assertTrue("Should have returned RunAsUserToken", result instanceof RunAsUserToken); RunAsUserToken resultCast = (RunAsUserToken) result; assertThat(resultCast.getKeyHash()).isEqualTo("my_password".hashCode()); } @@ -61,7 +56,6 @@ public class RunAsImplAuthenticationProviderTests { @Test(expected = IllegalArgumentException.class) public void testStartupFailsIfNoKey() throws Exception { RunAsImplAuthenticationProvider provider = new RunAsImplAuthenticationProvider(); - provider.afterPropertiesSet(); } @@ -79,4 +73,5 @@ public class RunAsImplAuthenticationProviderTests { assertThat(provider.supports(RunAsUserToken.class)).isTrue(); assertThat(!provider.supports(TestingAuthenticationToken.class)).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/RunAsManagerImplTests.java b/core/src/test/java/org/springframework/security/access/intercept/RunAsManagerImplTests.java index be93663948..31503300c3 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/RunAsManagerImplTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/RunAsManagerImplTests.java @@ -16,17 +16,18 @@ package org.springframework.security.access.intercept; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.Set; import org.junit.Test; + import org.springframework.security.access.SecurityConfig; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link RunAsManagerImpl}. * @@ -42,13 +43,10 @@ public class RunAsManagerImplTests { @Test public void testDoesNotReturnAdditionalAuthoritiesIfCalledWithoutARunAsSetting() { - UsernamePasswordAuthenticationToken inputToken = new UsernamePasswordAuthenticationToken( - "Test", "Password", + UsernamePasswordAuthenticationToken inputToken = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); - RunAsManagerImpl runAs = new RunAsManagerImpl(); runAs.setKey("my_password"); - Authentication resultingToken = runAs.buildRunAs(inputToken, new Object(), SecurityConfig.createList("SOMETHING_WE_IGNORE")); assertThat(resultingToken).isNull(); @@ -56,56 +54,41 @@ public class RunAsManagerImplTests { @Test public void testRespectsRolePrefix() { - UsernamePasswordAuthenticationToken inputToken = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.createAuthorityList("ONE", "TWO")); - + UsernamePasswordAuthenticationToken inputToken = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.createAuthorityList("ONE", "TWO")); RunAsManagerImpl runAs = new RunAsManagerImpl(); runAs.setKey("my_password"); runAs.setRolePrefix("FOOBAR_"); - Authentication result = runAs.buildRunAs(inputToken, new Object(), SecurityConfig.createList("RUN_AS_SOMETHING")); - - assertThat(result instanceof RunAsUserToken).withFailMessage( - "Should have returned a RunAsUserToken").isTrue(); + assertThat(result instanceof RunAsUserToken).withFailMessage("Should have returned a RunAsUserToken").isTrue(); assertThat(result.getPrincipal()).isEqualTo(inputToken.getPrincipal()); assertThat(result.getCredentials()).isEqualTo(inputToken.getCredentials()); - Set authorities = AuthorityUtils.authorityListToSet( - result.getAuthorities()); - + Set authorities = AuthorityUtils.authorityListToSet(result.getAuthorities()); assertThat(authorities.contains("FOOBAR_RUN_AS_SOMETHING")).isTrue(); assertThat(authorities.contains("ONE")).isTrue(); assertThat(authorities.contains("TWO")).isTrue(); - RunAsUserToken resultCast = (RunAsUserToken) result; assertThat(resultCast.getKeyHash()).isEqualTo("my_password".hashCode()); } @Test public void testReturnsAdditionalGrantedAuthorities() { - UsernamePasswordAuthenticationToken inputToken = new UsernamePasswordAuthenticationToken( - "Test", "Password", + UsernamePasswordAuthenticationToken inputToken = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); - RunAsManagerImpl runAs = new RunAsManagerImpl(); runAs.setKey("my_password"); - Authentication result = runAs.buildRunAs(inputToken, new Object(), SecurityConfig.createList("RUN_AS_SOMETHING")); - if (!(result instanceof RunAsUserToken)) { fail("Should have returned a RunAsUserToken"); } - assertThat(result.getPrincipal()).isEqualTo(inputToken.getPrincipal()); assertThat(result.getCredentials()).isEqualTo(inputToken.getCredentials()); - - Set authorities = AuthorityUtils.authorityListToSet( - result.getAuthorities()); + Set authorities = AuthorityUtils.authorityListToSet(result.getAuthorities()); assertThat(authorities.contains("ROLE_RUN_AS_SOMETHING")).isTrue(); assertThat(authorities.contains("ROLE_ONE")).isTrue(); assertThat(authorities.contains("ROLE_TWO")).isTrue(); - RunAsUserToken resultCast = (RunAsUserToken) result; assertThat(resultCast.getKeyHash()).isEqualTo("my_password".hashCode()); } @@ -113,13 +96,11 @@ public class RunAsManagerImplTests { @Test public void testStartupDetectsMissingKey() throws Exception { RunAsManagerImpl runAs = new RunAsManagerImpl(); - try { runAs.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @@ -138,4 +119,5 @@ public class RunAsManagerImplTests { assertThat(!runAs.supports(new SecurityConfig("ROLE_WHICH_IS_IGNORED"))).isTrue(); assertThat(!runAs.supports(new SecurityConfig("role_LOWER_CASE_FAILS"))).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/RunAsUserTokenTests.java b/core/src/test/java/org/springframework/security/access/intercept/RunAsUserTokenTests.java index a378d586ce..b8b151b27a 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/RunAsUserTokenTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/RunAsUserTokenTests.java @@ -16,13 +16,14 @@ package org.springframework.security.access.intercept; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import org.junit.Test; + import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link RunAsUserToken}. * @@ -33,8 +34,7 @@ public class RunAsUserTokenTests { @Test public void testAuthenticationSetting() { RunAsUserToken token = new RunAsUserToken("my_password", "Test", "Password", - AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), - UsernamePasswordAuthenticationToken.class); + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), UsernamePasswordAuthenticationToken.class); assertThat(token.isAuthenticated()).isTrue(); token.setAuthenticated(false); assertThat(!token.isAuthenticated()).isTrue(); @@ -43,19 +43,16 @@ public class RunAsUserTokenTests { @Test public void testGetters() { RunAsUserToken token = new RunAsUserToken("my_password", "Test", "Password", - AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), - UsernamePasswordAuthenticationToken.class); + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), UsernamePasswordAuthenticationToken.class); assertThat("Test").isEqualTo(token.getPrincipal()); assertThat("Password").isEqualTo(token.getCredentials()); assertThat("my_password".hashCode()).isEqualTo(token.getKeyHash()); - assertThat(UsernamePasswordAuthenticationToken.class).isEqualTo( - token.getOriginalAuthentication()); + assertThat(UsernamePasswordAuthenticationToken.class).isEqualTo(token.getOriginalAuthentication()); } @Test public void testNoArgConstructorDoesntExist() { Class clazz = RunAsUserToken.class; - try { clazz.getDeclaredConstructor((Class[]) null); fail("Should have thrown NoSuchMethodException"); @@ -68,10 +65,10 @@ public class RunAsUserTokenTests { @Test public void testToString() { RunAsUserToken token = new RunAsUserToken("my_password", "Test", "Password", - AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), - UsernamePasswordAuthenticationToken.class); - assertThat(token.toString().lastIndexOf("Original Class: " - + UsernamePasswordAuthenticationToken.class.getName().toString()) != -1).isTrue(); + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), UsernamePasswordAuthenticationToken.class); + assertThat(token.toString() + .lastIndexOf("Original Class: " + UsernamePasswordAuthenticationToken.class.getName().toString()) != -1) + .isTrue(); } // SEC-1792 @@ -81,4 +78,5 @@ public class RunAsUserTokenTests { AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"), null); assertThat(token.toString().lastIndexOf("Original Class: null") != -1).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptorTests.java b/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptorTests.java index b4e1f61c64..aa8ff61359 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityInterceptorTests.java @@ -16,11 +16,13 @@ package org.springframework.security.access.intercept.aopalliance; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import java.util.List; import org.aopalliance.intercept.MethodInvocation; -import org.junit.*; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + import org.springframework.aop.framework.ProxyFactory; import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.ITargetObject; @@ -44,7 +46,17 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests {@link MethodSecurityInterceptor}. @@ -54,31 +66,36 @@ import java.util.*; */ @SuppressWarnings("unchecked") public class MethodSecurityInterceptorTests { - private TestingAuthenticationToken token; - private MethodSecurityInterceptor interceptor; - private ITargetObject realTarget; - private ITargetObject advisedTarget; - private AccessDecisionManager adm; - private MethodSecurityMetadataSource mds; - private AuthenticationManager authman; - private ApplicationEventPublisher eventPublisher; - // ~ Methods - // ======================================================================================================== + private TestingAuthenticationToken token; + + private MethodSecurityInterceptor interceptor; + + private ITargetObject realTarget; + + private ITargetObject advisedTarget; + + private AccessDecisionManager adm; + + private MethodSecurityMetadataSource mds; + + private AuthenticationManager authman; + + private ApplicationEventPublisher eventPublisher; @Before public final void setUp() { SecurityContextHolder.clearContext(); - token = new TestingAuthenticationToken("Test", "Password"); - interceptor = new MethodSecurityInterceptor(); - adm = mock(AccessDecisionManager.class); - authman = mock(AuthenticationManager.class); - mds = mock(MethodSecurityMetadataSource.class); - eventPublisher = mock(ApplicationEventPublisher.class); - interceptor.setAccessDecisionManager(adm); - interceptor.setAuthenticationManager(authman); - interceptor.setSecurityMetadataSource(mds); - interceptor.setApplicationEventPublisher(eventPublisher); + this.token = new TestingAuthenticationToken("Test", "Password"); + this.interceptor = new MethodSecurityInterceptor(); + this.adm = mock(AccessDecisionManager.class); + this.authman = mock(AuthenticationManager.class); + this.mds = mock(MethodSecurityMetadataSource.class); + this.eventPublisher = mock(ApplicationEventPublisher.class); + this.interceptor.setAccessDecisionManager(this.adm); + this.interceptor.setAuthenticationManager(this.authman); + this.interceptor.setSecurityMetadataSource(this.mds); + this.interceptor.setApplicationEventPublisher(this.eventPublisher); createTarget(false); } @@ -88,261 +105,239 @@ public class MethodSecurityInterceptorTests { } private void createTarget(boolean useMock) { - realTarget = useMock ? mock(ITargetObject.class) : new TargetObject(); - ProxyFactory pf = new ProxyFactory(realTarget); - pf.addAdvice(interceptor); - advisedTarget = (ITargetObject) pf.getProxy(); + this.realTarget = useMock ? mock(ITargetObject.class) : new TargetObject(); + ProxyFactory pf = new ProxyFactory(this.realTarget); + pf.addAdvice(this.interceptor); + this.advisedTarget = (ITargetObject) pf.getProxy(); } @Test public void gettersReturnExpectedData() { RunAsManager runAs = mock(RunAsManager.class); AfterInvocationManager aim = mock(AfterInvocationManager.class); - interceptor.setRunAsManager(runAs); - interceptor.setAfterInvocationManager(aim); - assertThat(interceptor.getAccessDecisionManager()).isEqualTo(adm); - assertThat(interceptor.getRunAsManager()).isEqualTo(runAs); - assertThat(interceptor.getAuthenticationManager()).isEqualTo(authman); - assertThat(interceptor.getSecurityMetadataSource()).isEqualTo(mds); - assertThat(interceptor.getAfterInvocationManager()).isEqualTo(aim); + this.interceptor.setRunAsManager(runAs); + this.interceptor.setAfterInvocationManager(aim); + assertThat(this.interceptor.getAccessDecisionManager()).isEqualTo(this.adm); + assertThat(this.interceptor.getRunAsManager()).isEqualTo(runAs); + assertThat(this.interceptor.getAuthenticationManager()).isEqualTo(this.authman); + assertThat(this.interceptor.getSecurityMetadataSource()).isEqualTo(this.mds); + assertThat(this.interceptor.getAfterInvocationManager()).isEqualTo(aim); } @Test(expected = IllegalArgumentException.class) public void missingAccessDecisionManagerIsDetected() throws Exception { - interceptor.setAccessDecisionManager(null); - interceptor.afterPropertiesSet(); + this.interceptor.setAccessDecisionManager(null); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) public void missingAuthenticationManagerIsDetected() throws Exception { - interceptor.setAuthenticationManager(null); - interceptor.afterPropertiesSet(); + this.interceptor.setAuthenticationManager(null); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) public void missingMethodSecurityMetadataSourceIsRejected() throws Exception { - interceptor.setSecurityMetadataSource(null); - interceptor.afterPropertiesSet(); + this.interceptor.setSecurityMetadataSource(null); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) public void missingRunAsManagerIsRejected() throws Exception { - interceptor.setRunAsManager(null); - interceptor.afterPropertiesSet(); + this.interceptor.setRunAsManager(null); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) - public void initializationRejectsSecurityMetadataSourceThatDoesNotSupportMethodInvocation() - throws Throwable { - when(mds.supports(MethodInvocation.class)).thenReturn(false); - interceptor.afterPropertiesSet(); + public void initializationRejectsSecurityMetadataSourceThatDoesNotSupportMethodInvocation() throws Throwable { + given(this.mds.supports(MethodInvocation.class)).willReturn(false); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) - public void initializationRejectsAccessDecisionManagerThatDoesNotSupportMethodInvocation() - throws Exception { - when(mds.supports(MethodInvocation.class)).thenReturn(true); - when(adm.supports(MethodInvocation.class)).thenReturn(false); - interceptor.afterPropertiesSet(); + public void initializationRejectsAccessDecisionManagerThatDoesNotSupportMethodInvocation() throws Exception { + given(this.mds.supports(MethodInvocation.class)).willReturn(true); + given(this.adm.supports(MethodInvocation.class)).willReturn(false); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) - public void intitalizationRejectsRunAsManagerThatDoesNotSupportMethodInvocation() - throws Exception { + public void intitalizationRejectsRunAsManagerThatDoesNotSupportMethodInvocation() throws Exception { final RunAsManager ram = mock(RunAsManager.class); - when(ram.supports(MethodInvocation.class)).thenReturn(false); - interceptor.setRunAsManager(ram); - interceptor.afterPropertiesSet(); + given(ram.supports(MethodInvocation.class)).willReturn(false); + this.interceptor.setRunAsManager(ram); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) - public void intitalizationRejectsAfterInvocationManagerThatDoesNotSupportMethodInvocation() - throws Exception { + public void intitalizationRejectsAfterInvocationManagerThatDoesNotSupportMethodInvocation() throws Exception { final AfterInvocationManager aim = mock(AfterInvocationManager.class); - when(aim.supports(MethodInvocation.class)).thenReturn(false); - interceptor.setAfterInvocationManager(aim); - interceptor.afterPropertiesSet(); + given(aim.supports(MethodInvocation.class)).willReturn(false); + this.interceptor.setAfterInvocationManager(aim); + this.interceptor.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) - public void initializationFailsIfAccessDecisionManagerRejectsConfigAttributes() - throws Exception { - when(adm.supports(any(ConfigAttribute.class))).thenReturn(false); - interceptor.afterPropertiesSet(); + public void initializationFailsIfAccessDecisionManagerRejectsConfigAttributes() throws Exception { + given(this.adm.supports(any(ConfigAttribute.class))).willReturn(false); + this.interceptor.afterPropertiesSet(); } @Test - public void validationNotAttemptedIfIsValidateConfigAttributesSetToFalse() - throws Exception { - when(adm.supports(MethodInvocation.class)).thenReturn(true); - when(mds.supports(MethodInvocation.class)).thenReturn(true); - interceptor.setValidateConfigAttributes(false); - interceptor.afterPropertiesSet(); - verify(mds, never()).getAllConfigAttributes(); - verify(adm, never()).supports(any(ConfigAttribute.class)); + public void validationNotAttemptedIfIsValidateConfigAttributesSetToFalse() throws Exception { + given(this.adm.supports(MethodInvocation.class)).willReturn(true); + given(this.mds.supports(MethodInvocation.class)).willReturn(true); + this.interceptor.setValidateConfigAttributes(false); + this.interceptor.afterPropertiesSet(); + verify(this.mds, never()).getAllConfigAttributes(); + verify(this.adm, never()).supports(any(ConfigAttribute.class)); } @Test - public void validationNotAttemptedIfMethodSecurityMetadataSourceReturnsNullForAttributes() - throws Exception { - when(adm.supports(MethodInvocation.class)).thenReturn(true); - when(mds.supports(MethodInvocation.class)).thenReturn(true); - when(mds.getAllConfigAttributes()).thenReturn(null); - - interceptor.setValidateConfigAttributes(true); - interceptor.afterPropertiesSet(); - verify(adm, never()).supports(any(ConfigAttribute.class)); + public void validationNotAttemptedIfMethodSecurityMetadataSourceReturnsNullForAttributes() throws Exception { + given(this.adm.supports(MethodInvocation.class)).willReturn(true); + given(this.mds.supports(MethodInvocation.class)).willReturn(true); + given(this.mds.getAllConfigAttributes()).willReturn(null); + this.interceptor.setValidateConfigAttributes(true); + this.interceptor.afterPropertiesSet(); + verify(this.adm, never()).supports(any(ConfigAttribute.class)); } @Test public void callingAPublicMethodFacadeWillNotRepeatSecurityChecksWhenPassedToTheSecuredMethodItFronts() { mdsReturnsNull(); - String result = advisedTarget.publicMakeLowerCase("HELLO"); + String result = this.advisedTarget.publicMakeLowerCase("HELLO"); assertThat(result).isEqualTo("hello Authentication empty"); } @Test public void callingAPublicMethodWhenPresentingAnAuthenticationObjectDoesntChangeItsAuthenticatedProperty() { mdsReturnsNull(); - SecurityContextHolder.getContext().setAuthentication(token); - assertThat(advisedTarget.publicMakeLowerCase("HELLO")).isEqualTo("hello org.springframework.security.authentication.TestingAuthenticationToken false"); - assertThat(!token.isAuthenticated()).isTrue(); + SecurityContextHolder.getContext().setAuthentication(this.token); + assertThat(this.advisedTarget.publicMakeLowerCase("HELLO")) + .isEqualTo("hello org.springframework.security.authentication.TestingAuthenticationToken false"); + assertThat(!this.token.isAuthenticated()).isTrue(); } @Test(expected = AuthenticationException.class) public void callIsntMadeWhenAuthenticationManagerRejectsAuthentication() { - final TestingAuthenticationToken token = new TestingAuthenticationToken("Test", - "Password"); + final TestingAuthenticationToken token = new TestingAuthenticationToken("Test", "Password"); SecurityContextHolder.getContext().setAuthentication(token); - mdsReturnsUserRole(); - when(authman.authenticate(token)).thenThrow( - new BadCredentialsException("rejected")); - - advisedTarget.makeLowerCase("HELLO"); + given(this.authman.authenticate(token)).willThrow(new BadCredentialsException("rejected")); + this.advisedTarget.makeLowerCase("HELLO"); } @Test public void callSucceedsIfAccessDecisionManagerGrantsAccess() { - token.setAuthenticated(true); - interceptor.setPublishAuthorizationSuccess(true); - SecurityContextHolder.getContext().setAuthentication(token); + this.token.setAuthenticated(true); + this.interceptor.setPublishAuthorizationSuccess(true); + SecurityContextHolder.getContext().setAuthentication(this.token); mdsReturnsUserRole(); - - String result = advisedTarget.makeLowerCase("HELLO"); - + String result = this.advisedTarget.makeLowerCase("HELLO"); // Note we check the isAuthenticated remained true in following line - assertThat(result).isEqualTo("hello org.springframework.security.authentication.TestingAuthenticationToken true"); - verify(eventPublisher).publishEvent(any(AuthorizedEvent.class)); + assertThat(result) + .isEqualTo("hello org.springframework.security.authentication.TestingAuthenticationToken true"); + verify(this.eventPublisher).publishEvent(any(AuthorizedEvent.class)); } @Test public void callIsntMadeWhenAccessDecisionManagerRejectsAccess() { - SecurityContextHolder.getContext().setAuthentication(token); + SecurityContextHolder.getContext().setAuthentication(this.token); // Use mocked target to make sure invocation doesn't happen (not in expectations // so test would fail) createTarget(true); mdsReturnsUserRole(); - when(authman.authenticate(token)).thenReturn(token); - doThrow(new AccessDeniedException("rejected")).when(adm).decide( - any(Authentication.class), any(MethodInvocation.class), any(List.class)); - + given(this.authman.authenticate(this.token)).willReturn(this.token); + willThrow(new AccessDeniedException("rejected")).given(this.adm).decide(any(Authentication.class), + any(MethodInvocation.class), any(List.class)); try { - advisedTarget.makeUpperCase("HELLO"); + this.advisedTarget.makeUpperCase("HELLO"); fail("Expected Exception"); } catch (AccessDeniedException expected) { } - verify(eventPublisher).publishEvent(any(AuthorizationFailureEvent.class)); + verify(this.eventPublisher).publishEvent(any(AuthorizationFailureEvent.class)); } @Test(expected = IllegalArgumentException.class) public void rejectsNullSecuredObjects() throws Throwable { - interceptor.invoke(null); + this.interceptor.invoke(null); } @Test public void runAsReplacementIsCorrectlySet() { SecurityContext ctx = SecurityContextHolder.getContext(); - ctx.setAuthentication(token); - token.setAuthenticated(true); + ctx.setAuthentication(this.token); + this.token.setAuthenticated(true); final RunAsManager runAs = mock(RunAsManager.class); - final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", - token.getAuthorities(), TestingAuthenticationToken.class); - interceptor.setRunAsManager(runAs); + final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", this.token.getAuthorities(), + TestingAuthenticationToken.class); + this.interceptor.setRunAsManager(runAs); mdsReturnsUserRole(); - when(runAs.buildRunAs(eq(token), any(MethodInvocation.class), any(List.class))) - .thenReturn(runAsToken); - - String result = advisedTarget.makeUpperCase("hello"); + given(runAs.buildRunAs(eq(this.token), any(MethodInvocation.class), any(List.class))).willReturn(runAsToken); + String result = this.advisedTarget.makeUpperCase("hello"); assertThat(result).isEqualTo("HELLO org.springframework.security.access.intercept.RunAsUserToken true"); // Check we've changed back assertThat(SecurityContextHolder.getContext()).isSameAs(ctx); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(token); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.token); } // SEC-1967 @Test public void runAsReplacementCleansAfterException() { createTarget(true); - when(realTarget.makeUpperCase(anyString())).thenThrow(new RuntimeException()); + given(this.realTarget.makeUpperCase(anyString())).willThrow(new RuntimeException()); SecurityContext ctx = SecurityContextHolder.getContext(); - ctx.setAuthentication(token); - token.setAuthenticated(true); + ctx.setAuthentication(this.token); + this.token.setAuthenticated(true); final RunAsManager runAs = mock(RunAsManager.class); - final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", - token.getAuthorities(), TestingAuthenticationToken.class); - interceptor.setRunAsManager(runAs); + final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", this.token.getAuthorities(), + TestingAuthenticationToken.class); + this.interceptor.setRunAsManager(runAs); mdsReturnsUserRole(); - when(runAs.buildRunAs(eq(token), any(MethodInvocation.class), any(List.class))) - .thenReturn(runAsToken); - + given(runAs.buildRunAs(eq(this.token), any(MethodInvocation.class), any(List.class))).willReturn(runAsToken); try { - advisedTarget.makeUpperCase("hello"); + this.advisedTarget.makeUpperCase("hello"); fail("Expected Exception"); } catch (RuntimeException success) { } - // Check we've changed back assertThat(SecurityContextHolder.getContext()).isSameAs(ctx); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(token); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.token); } @Test(expected = AuthenticationCredentialsNotFoundException.class) public void emptySecurityContextIsRejected() { mdsReturnsUserRole(); - advisedTarget.makeUpperCase("hello"); + this.advisedTarget.makeUpperCase("hello"); } @Test public void afterInvocationManagerIsNotInvokedIfExceptionIsRaised() throws Throwable { MethodInvocation mi = mock(MethodInvocation.class); - token.setAuthenticated(true); - SecurityContextHolder.getContext().setAuthentication(token); + this.token.setAuthenticated(true); + SecurityContextHolder.getContext().setAuthentication(this.token); mdsReturnsUserRole(); - AfterInvocationManager aim = mock(AfterInvocationManager.class); - interceptor.setAfterInvocationManager(aim); - - when(mi.proceed()).thenThrow(new Throwable()); - + this.interceptor.setAfterInvocationManager(aim); + given(mi.proceed()).willThrow(new Throwable()); try { - interceptor.invoke(mi); + this.interceptor.invoke(mi); fail("Expected exception"); } catch (Throwable expected) { } - verifyZeroInteractions(aim); } void mdsReturnsNull() { - when(mds.getAttributes(any(MethodInvocation.class))).thenReturn(null); + given(this.mds.getAttributes(any(MethodInvocation.class))).willReturn(null); } void mdsReturnsUserRole() { - when(mds.getAttributes(any(MethodInvocation.class))).thenReturn( - SecurityConfig.createList("ROLE_USER")); + given(this.mds.getAttributes(any(MethodInvocation.class))).willReturn(SecurityConfig.createList("ROLE_USER")); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisorTests.java b/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisorTests.java index 6ddf453c9b..297705c6e5 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisorTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/aopalliance/MethodSecurityMetadataSourceAdvisorTests.java @@ -16,17 +16,18 @@ package org.springframework.security.access.intercept.aopalliance; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - import java.lang.reflect.Method; import org.junit.Test; + import org.springframework.security.TargetObject; import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.method.MethodSecurityMetadataSource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests {@link MethodSecurityMetadataSourceAdvisor}. * @@ -34,32 +35,24 @@ import org.springframework.security.access.method.MethodSecurityMetadataSource; */ public class MethodSecurityMetadataSourceAdvisorTests { - // ~ Methods - // ======================================================================================================== @Test public void testAdvisorReturnsFalseWhenMethodInvocationNotDefined() throws Exception { Class clazz = TargetObject.class; Method method = clazz.getMethod("makeLowerCase", new Class[] { String.class }); - MethodSecurityMetadataSource mds = mock(MethodSecurityMetadataSource.class); - when(mds.getAttributes(method, clazz)).thenReturn(null); - MethodSecurityMetadataSourceAdvisor advisor = new MethodSecurityMetadataSourceAdvisor( - "", mds, ""); - assertThat(advisor.getPointcut().getMethodMatcher().matches(method, - clazz)).isFalse(); + given(mds.getAttributes(method, clazz)).willReturn(null); + MethodSecurityMetadataSourceAdvisor advisor = new MethodSecurityMetadataSourceAdvisor("", mds, ""); + assertThat(advisor.getPointcut().getMethodMatcher().matches(method, clazz)).isFalse(); } @Test public void testAdvisorReturnsTrueWhenMethodInvocationIsDefined() throws Exception { Class clazz = TargetObject.class; Method method = clazz.getMethod("countLength", new Class[] { String.class }); - MethodSecurityMetadataSource mds = mock(MethodSecurityMetadataSource.class); - when(mds.getAttributes(method, clazz)).thenReturn( - SecurityConfig.createList("ROLE_A")); - MethodSecurityMetadataSourceAdvisor advisor = new MethodSecurityMetadataSourceAdvisor( - "", mds, ""); - assertThat( - advisor.getPointcut().getMethodMatcher().matches(method, clazz)).isTrue(); + given(mds.getAttributes(method, clazz)).willReturn(SecurityConfig.createList("ROLE_A")); + MethodSecurityMetadataSourceAdvisor advisor = new MethodSecurityMetadataSourceAdvisor("", mds, ""); + assertThat(advisor.getPointcut().getMethodMatcher().matches(method, clazz)).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptorTests.java b/core/src/test/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptorTests.java index e98c5764a6..08ef48c501 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/aspectj/AspectJMethodSecurityInterceptorTests.java @@ -16,8 +16,8 @@ package org.springframework.security.access.intercept.aspectj; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import java.lang.reflect.Method; +import java.util.List; import org.aopalliance.intercept.MethodInvocation; import org.aspectj.lang.JoinPoint; @@ -29,6 +29,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; + import org.springframework.security.TargetObject; import org.springframework.security.access.AccessDecisionManager; import org.springframework.security.access.AccessDeniedException; @@ -43,8 +44,16 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.ClassUtils; -import java.lang.reflect.Method; -import java.util.List; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests {@link AspectJMethodSecurityInterceptor}. @@ -54,42 +63,49 @@ import java.util.List; * @author Rob Winch */ public class AspectJMethodSecurityInterceptorTests { - private TestingAuthenticationToken token; - private AspectJMethodSecurityInterceptor interceptor; - private @Mock AccessDecisionManager adm; - private @Mock MethodSecurityMetadataSource mds; - private @Mock AuthenticationManager authman; - private @Mock AspectJCallback aspectJCallback; - private ProceedingJoinPoint joinPoint; - // ~ Methods - // ======================================================================================================== + private TestingAuthenticationToken token; + + private AspectJMethodSecurityInterceptor interceptor; + + @Mock + private AccessDecisionManager adm; + + @Mock + private MethodSecurityMetadataSource mds; + + @Mock + private AuthenticationManager authman; + + @Mock + private AspectJCallback aspectJCallback; + + private ProceedingJoinPoint joinPoint; @Before public final void setUp() { MockitoAnnotations.initMocks(this); SecurityContextHolder.clearContext(); - token = new TestingAuthenticationToken("Test", "Password"); - interceptor = new AspectJMethodSecurityInterceptor(); - interceptor.setAccessDecisionManager(adm); - interceptor.setAuthenticationManager(authman); - interceptor.setSecurityMetadataSource(mds); + this.token = new TestingAuthenticationToken("Test", "Password"); + this.interceptor = new AspectJMethodSecurityInterceptor(); + this.interceptor.setAccessDecisionManager(this.adm); + this.interceptor.setAuthenticationManager(this.authman); + this.interceptor.setSecurityMetadataSource(this.mds); // Set up joinpoint information for the countLength method on TargetObject - joinPoint = mock(ProceedingJoinPoint.class); // new MockJoinPoint(new - // TargetObject(), method); + this.joinPoint = mock(ProceedingJoinPoint.class); // new MockJoinPoint(new + // TargetObject(), method); Signature sig = mock(Signature.class); - when(sig.getDeclaringType()).thenReturn(TargetObject.class); + given(sig.getDeclaringType()).willReturn(TargetObject.class); JoinPoint.StaticPart staticPart = mock(JoinPoint.StaticPart.class); - when(joinPoint.getSignature()).thenReturn(sig); - when(joinPoint.getStaticPart()).thenReturn(staticPart); + given(this.joinPoint.getSignature()).willReturn(sig); + given(this.joinPoint.getStaticPart()).willReturn(staticPart); CodeSignature codeSig = mock(CodeSignature.class); - when(codeSig.getName()).thenReturn("countLength"); - when(codeSig.getDeclaringType()).thenReturn(TargetObject.class); - when(codeSig.getParameterTypes()).thenReturn(new Class[] { String.class }); - when(staticPart.getSignature()).thenReturn(codeSig); - when(mds.getAttributes(any())).thenReturn( - SecurityConfig.createList("ROLE_USER")); - when(authman.authenticate(token)).thenReturn(token); + given(codeSig.getName()).willReturn("countLength"); + given(codeSig.getDeclaringType()).willReturn(TargetObject.class); + given(codeSig.getParameterTypes()).willReturn(new Class[] { String.class }); + given(staticPart.getSignature()).willReturn(codeSig); + given(this.mds.getAttributes(any())).willReturn(SecurityConfig.createList("ROLE_USER")); + given(this.authman.authenticate(this.token)).willReturn(this.token); } @After @@ -99,39 +115,34 @@ public class AspectJMethodSecurityInterceptorTests { @Test public void callbackIsInvokedWhenPermissionGranted() throws Throwable { - SecurityContextHolder.getContext().setAuthentication(token); - interceptor.invoke(joinPoint, aspectJCallback); - verify(aspectJCallback).proceedWithObject(); - + SecurityContextHolder.getContext().setAuthentication(this.token); + this.interceptor.invoke(this.joinPoint, this.aspectJCallback); + verify(this.aspectJCallback).proceedWithObject(); // Just try the other method too - interceptor.invoke(joinPoint); + this.interceptor.invoke(this.joinPoint); } @SuppressWarnings("unchecked") @Test public void callbackIsNotInvokedWhenPermissionDenied() { - doThrow(new AccessDeniedException("denied")).when(adm).decide( - any(), any(), any()); - - SecurityContextHolder.getContext().setAuthentication(token); + willThrow(new AccessDeniedException("denied")).given(this.adm).decide(any(), any(), any()); + SecurityContextHolder.getContext().setAuthentication(this.token); try { - interceptor.invoke(joinPoint, aspectJCallback); + this.interceptor.invoke(this.joinPoint, this.aspectJCallback); fail("Expected AccessDeniedException"); } catch (AccessDeniedException expected) { } - verify(aspectJCallback, never()).proceedWithObject(); + verify(this.aspectJCallback, never()).proceedWithObject(); } @Test public void adapterHoldsCorrectData() { TargetObject to = new TargetObject(); - Method m = ClassUtils.getMethodIfAvailable(TargetObject.class, "countLength", - new Class[] { String.class }); - - when(joinPoint.getTarget()).thenReturn(to); - when(joinPoint.getArgs()).thenReturn(new Object[] { "Hi" }); - MethodInvocationAdapter mia = new MethodInvocationAdapter(joinPoint); + Method m = ClassUtils.getMethodIfAvailable(TargetObject.class, "countLength", new Class[] { String.class }); + given(this.joinPoint.getTarget()).willReturn(to); + given(this.joinPoint.getArgs()).willReturn(new Object[] { "Hi" }); + MethodInvocationAdapter mia = new MethodInvocationAdapter(this.joinPoint); assertThat(mia.getArguments()[0]).isEqualTo("Hi"); assertThat(mia.getStaticPart()).isEqualTo(m); assertThat(mia.getMethod()).isEqualTo(m); @@ -140,21 +151,17 @@ public class AspectJMethodSecurityInterceptorTests { @Test public void afterInvocationManagerIsNotInvokedIfExceptionIsRaised() { - token.setAuthenticated(true); - SecurityContextHolder.getContext().setAuthentication(token); - + this.token.setAuthenticated(true); + SecurityContextHolder.getContext().setAuthentication(this.token); AfterInvocationManager aim = mock(AfterInvocationManager.class); - interceptor.setAfterInvocationManager(aim); - - when(aspectJCallback.proceedWithObject()).thenThrow(new RuntimeException()); - + this.interceptor.setAfterInvocationManager(aim); + given(this.aspectJCallback.proceedWithObject()).willThrow(new RuntimeException()); try { - interceptor.invoke(joinPoint, aspectJCallback); + this.interceptor.invoke(this.joinPoint, this.aspectJCallback); fail("Expected exception"); } catch (RuntimeException expected) { } - verifyZeroInteractions(aim); } @@ -163,26 +170,23 @@ public class AspectJMethodSecurityInterceptorTests { @SuppressWarnings("unchecked") public void invokeWithAspectJCallbackRunAsReplacementCleansAfterException() { SecurityContext ctx = SecurityContextHolder.getContext(); - ctx.setAuthentication(token); - token.setAuthenticated(true); + ctx.setAuthentication(this.token); + this.token.setAuthenticated(true); final RunAsManager runAs = mock(RunAsManager.class); - final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", - token.getAuthorities(), TestingAuthenticationToken.class); - interceptor.setRunAsManager(runAs); - when(runAs.buildRunAs(eq(token), any(MethodInvocation.class), any(List.class))) - .thenReturn(runAsToken); - when(aspectJCallback.proceedWithObject()).thenThrow(new RuntimeException()); - + final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", this.token.getAuthorities(), + TestingAuthenticationToken.class); + this.interceptor.setRunAsManager(runAs); + given(runAs.buildRunAs(eq(this.token), any(MethodInvocation.class), any(List.class))).willReturn(runAsToken); + given(this.aspectJCallback.proceedWithObject()).willThrow(new RuntimeException()); try { - interceptor.invoke(joinPoint, aspectJCallback); + this.interceptor.invoke(this.joinPoint, this.aspectJCallback); fail("Expected Exception"); } catch (RuntimeException success) { } - // Check we've changed back assertThat(SecurityContextHolder.getContext()).isSameAs(ctx); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(token); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.token); } // SEC-1967 @@ -190,25 +194,23 @@ public class AspectJMethodSecurityInterceptorTests { @SuppressWarnings("unchecked") public void invokeRunAsReplacementCleansAfterException() throws Throwable { SecurityContext ctx = SecurityContextHolder.getContext(); - ctx.setAuthentication(token); - token.setAuthenticated(true); + ctx.setAuthentication(this.token); + this.token.setAuthenticated(true); final RunAsManager runAs = mock(RunAsManager.class); - final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", - token.getAuthorities(), TestingAuthenticationToken.class); - interceptor.setRunAsManager(runAs); - when(runAs.buildRunAs(eq(token), any(MethodInvocation.class), any(List.class))) - .thenReturn(runAsToken); - when(joinPoint.proceed()).thenThrow(new RuntimeException()); - + final RunAsUserToken runAsToken = new RunAsUserToken("key", "someone", "creds", this.token.getAuthorities(), + TestingAuthenticationToken.class); + this.interceptor.setRunAsManager(runAs); + given(runAs.buildRunAs(eq(this.token), any(MethodInvocation.class), any(List.class))).willReturn(runAsToken); + given(this.joinPoint.proceed()).willThrow(new RuntimeException()); try { - interceptor.invoke(joinPoint); + this.interceptor.invoke(this.joinPoint); fail("Expected Exception"); } catch (RuntimeException success) { } - // Check we've changed back assertThat(SecurityContextHolder.getContext()).isSameAs(ctx); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(token); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.token); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/method/MapBasedMethodSecurityMetadataSourceTests.java b/core/src/test/java/org/springframework/security/access/intercept/method/MapBasedMethodSecurityMetadataSourceTests.java index 89ae4f9822..ae3c44b91e 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/method/MapBasedMethodSecurityMetadataSourceTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/method/MapBasedMethodSecurityMetadataSourceTests.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.intercept.method; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.access.intercept.method; import java.lang.reflect.Method; import java.util.List; import org.junit.Before; import org.junit.Test; + import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.method.MapBasedMethodSecurityMetadataSource; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests for {@link MapBasedMethodSecurityMetadataSource}. * @@ -33,41 +35,48 @@ import org.springframework.security.access.method.MapBasedMethodSecurityMetadata * @since 2.0.4 */ public class MapBasedMethodSecurityMetadataSourceTests { + private final List ROLE_A = SecurityConfig.createList("ROLE_A"); + private final List ROLE_B = SecurityConfig.createList("ROLE_B"); + private MapBasedMethodSecurityMetadataSource mds; + private Method someMethodString; + private Method someMethodInteger; @Before public void initialize() throws Exception { - mds = new MapBasedMethodSecurityMetadataSource(); - someMethodString = MockService.class.getMethod("someMethod", String.class); - someMethodInteger = MockService.class.getMethod("someMethod", Integer.class); + this.mds = new MapBasedMethodSecurityMetadataSource(); + this.someMethodString = MockService.class.getMethod("someMethod", String.class); + this.someMethodInteger = MockService.class.getMethod("someMethod", Integer.class); } @Test public void wildcardedMatchIsOverwrittenByMoreSpecificMatch() { - mds.addSecureMethod(MockService.class, "some*", ROLE_A); - mds.addSecureMethod(MockService.class, "someMethod*", ROLE_B); - assertThat(mds.getAttributes(someMethodInteger, MockService.class)).isEqualTo(ROLE_B); + this.mds.addSecureMethod(MockService.class, "some*", this.ROLE_A); + this.mds.addSecureMethod(MockService.class, "someMethod*", this.ROLE_B); + assertThat(this.mds.getAttributes(this.someMethodInteger, MockService.class)).isEqualTo(this.ROLE_B); } @Test public void methodsWithDifferentArgumentsAreMatchedCorrectly() { - mds.addSecureMethod(MockService.class, someMethodInteger, ROLE_A); - mds.addSecureMethod(MockService.class, someMethodString, ROLE_B); - - assertThat(mds.getAttributes(someMethodInteger, MockService.class)).isEqualTo(ROLE_A); - assertThat(mds.getAttributes(someMethodString, MockService.class)).isEqualTo(ROLE_B); + this.mds.addSecureMethod(MockService.class, this.someMethodInteger, this.ROLE_A); + this.mds.addSecureMethod(MockService.class, this.someMethodString, this.ROLE_B); + assertThat(this.mds.getAttributes(this.someMethodInteger, MockService.class)).isEqualTo(this.ROLE_A); + assertThat(this.mds.getAttributes(this.someMethodString, MockService.class)).isEqualTo(this.ROLE_B); } @SuppressWarnings("unused") private class MockService { + public void someMethod(String s) { } public void someMethod(Integer i) { } + } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/method/MethodInvocationPrivilegeEvaluatorTests.java b/core/src/test/java/org/springframework/security/access/intercept/method/MethodInvocationPrivilegeEvaluatorTests.java index 80b6b3ac4d..a9e89fa1e5 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/method/MethodInvocationPrivilegeEvaluatorTests.java +++ b/core/src/test/java/org/springframework/security/access/intercept/method/MethodInvocationPrivilegeEvaluatorTests.java @@ -16,14 +16,12 @@ package org.springframework.security.access.intercept.method; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import java.util.*; +import java.util.List; import org.aopalliance.intercept.MethodInvocation; import org.junit.Before; import org.junit.Test; + import org.springframework.security.ITargetObject; import org.springframework.security.OtherTargetObject; import org.springframework.security.TargetObject; @@ -39,6 +37,11 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.util.MethodInvocationUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; + /** * Tests * {@link org.springframework.security.access.intercept.MethodInvocationPrivilegeEvaluator} @@ -47,79 +50,71 @@ import org.springframework.security.util.MethodInvocationUtils; * @author Ben Alex */ public class MethodInvocationPrivilegeEvaluatorTests { - private TestingAuthenticationToken token; - private MethodSecurityInterceptor interceptor; - private AccessDecisionManager adm; - private MethodSecurityMetadataSource mds; - private final List role = SecurityConfig.createList("ROLE_IGNORED"); - // ~ Methods - // ======================================================================================================== + private TestingAuthenticationToken token; + + private MethodSecurityInterceptor interceptor; + + private AccessDecisionManager adm; + + private MethodSecurityMetadataSource mds; + + private final List role = SecurityConfig.createList("ROLE_IGNORED"); @Before public final void setUp() { SecurityContextHolder.clearContext(); - interceptor = new MethodSecurityInterceptor(); - token = new TestingAuthenticationToken("Test", "Password", "ROLE_SOMETHING"); - adm = mock(AccessDecisionManager.class); + this.interceptor = new MethodSecurityInterceptor(); + this.token = new TestingAuthenticationToken("Test", "Password", "ROLE_SOMETHING"); + this.adm = mock(AccessDecisionManager.class); AuthenticationManager authman = mock(AuthenticationManager.class); - mds = mock(MethodSecurityMetadataSource.class); - interceptor.setAccessDecisionManager(adm); - interceptor.setAuthenticationManager(authman); - interceptor.setSecurityMetadataSource(mds); + this.mds = mock(MethodSecurityMetadataSource.class); + this.interceptor.setAccessDecisionManager(this.adm); + this.interceptor.setAuthenticationManager(authman); + this.interceptor.setSecurityMetadataSource(this.mds); } @Test public void allowsAccessUsingCreate() throws Exception { Object object = new TargetObject(); - final MethodInvocation mi = MethodInvocationUtils.create(object, "makeLowerCase", - "foobar"); - + final MethodInvocation mi = MethodInvocationUtils.create(object, "makeLowerCase", "foobar"); MethodInvocationPrivilegeEvaluator mipe = new MethodInvocationPrivilegeEvaluator(); - when(mds.getAttributes(mi)).thenReturn(role); - - mipe.setSecurityInterceptor(interceptor); + given(this.mds.getAttributes(mi)).willReturn(this.role); + mipe.setSecurityInterceptor(this.interceptor); mipe.afterPropertiesSet(); - - assertThat(mipe.isAllowed(mi, token)).isTrue(); + assertThat(mipe.isAllowed(mi, this.token)).isTrue(); } @Test public void allowsAccessUsingCreateFromClass() { - final MethodInvocation mi = MethodInvocationUtils.createFromClass( - new OtherTargetObject(), ITargetObject.class, "makeLowerCase", - new Class[] { String.class }, new Object[] { "Hello world" }); + final MethodInvocation mi = MethodInvocationUtils.createFromClass(new OtherTargetObject(), ITargetObject.class, + "makeLowerCase", new Class[] { String.class }, new Object[] { "Hello world" }); MethodInvocationPrivilegeEvaluator mipe = new MethodInvocationPrivilegeEvaluator(); - mipe.setSecurityInterceptor(interceptor); - when(mds.getAttributes(mi)).thenReturn(role); - - assertThat(mipe.isAllowed(mi, token)).isTrue(); + mipe.setSecurityInterceptor(this.interceptor); + given(this.mds.getAttributes(mi)).willReturn(this.role); + assertThat(mipe.isAllowed(mi, this.token)).isTrue(); } @Test public void declinesAccessUsingCreate() { Object object = new TargetObject(); - final MethodInvocation mi = MethodInvocationUtils.create(object, "makeLowerCase", - "foobar"); + final MethodInvocation mi = MethodInvocationUtils.create(object, "makeLowerCase", "foobar"); MethodInvocationPrivilegeEvaluator mipe = new MethodInvocationPrivilegeEvaluator(); - mipe.setSecurityInterceptor(interceptor); - when(mds.getAttributes(mi)).thenReturn(role); - doThrow(new AccessDeniedException("rejected")).when(adm).decide(token, mi, role); - - assertThat(mipe.isAllowed(mi, token)).isFalse(); + mipe.setSecurityInterceptor(this.interceptor); + given(this.mds.getAttributes(mi)).willReturn(this.role); + willThrow(new AccessDeniedException("rejected")).given(this.adm).decide(this.token, mi, this.role); + assertThat(mipe.isAllowed(mi, this.token)).isFalse(); } @Test public void declinesAccessUsingCreateFromClass() { - final MethodInvocation mi = MethodInvocationUtils.createFromClass( - new OtherTargetObject(), ITargetObject.class, "makeLowerCase", - new Class[] { String.class }, new Object[] { "helloWorld" }); - + final MethodInvocation mi = MethodInvocationUtils.createFromClass(new OtherTargetObject(), ITargetObject.class, + "makeLowerCase", new Class[] { String.class }, new Object[] { "helloWorld" }); MethodInvocationPrivilegeEvaluator mipe = new MethodInvocationPrivilegeEvaluator(); - mipe.setSecurityInterceptor(interceptor); - when(mds.getAttributes(mi)).thenReturn(role); - doThrow(new AccessDeniedException("rejected")).when(adm).decide(token, mi, role); - - assertThat(mipe.isAllowed(mi, token)).isFalse(); + mipe.setSecurityInterceptor(this.interceptor); + given(this.mds.getAttributes(mi)).willReturn(this.role); + willThrow(new AccessDeniedException("rejected")).given(this.adm).decide(this.token, mi, this.role); + assertThat(mipe.isAllowed(mi, this.token)).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/access/intercept/method/MockMethodInvocation.java b/core/src/test/java/org/springframework/security/access/intercept/method/MockMethodInvocation.java index 8c295000ff..f0e8c4b2de 100644 --- a/core/src/test/java/org/springframework/security/access/intercept/method/MockMethodInvocation.java +++ b/core/src/test/java/org/springframework/security/access/intercept/method/MockMethodInvocation.java @@ -13,48 +13,58 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.intercept.method; -import org.aopalliance.intercept.MethodInvocation; +package org.springframework.security.access.intercept.method; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Method; +import org.aopalliance.intercept.MethodInvocation; + @SuppressWarnings("unchecked") public class MockMethodInvocation implements MethodInvocation { + private Method method; + private Object targetObject; + private Object[] arguments = new Object[0]; public MockMethodInvocation(Object targetObject, Class clazz, String methodName, Class[] parameterTypes, - Object[] arguments) throws NoSuchMethodException { + Object[] arguments) throws NoSuchMethodException { this(targetObject, clazz, methodName, parameterTypes); this.arguments = arguments; } - public MockMethodInvocation(Object targetObject, Class clazz, String methodName, - Class... parameterTypes) throws NoSuchMethodException { + public MockMethodInvocation(Object targetObject, Class clazz, String methodName, Class... parameterTypes) + throws NoSuchMethodException { this.method = clazz.getMethod(methodName, parameterTypes); this.targetObject = targetObject; } + @Override public Object[] getArguments() { - return arguments; + return this.arguments; } + @Override public Method getMethod() { - return method; + return this.method; } + @Override public AccessibleObject getStaticPart() { return null; } + @Override public Object getThis() { - return targetObject; + return this.targetObject; } + @Override public Object proceed() { return null; } + } diff --git a/core/src/test/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSourceTests.java b/core/src/test/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSourceTests.java index b16f71b8bd..cdab7177ca 100644 --- a/core/src/test/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSourceTests.java +++ b/core/src/test/java/org/springframework/security/access/method/DelegatingMethodSecurityMetadataSourceTests.java @@ -13,42 +13,48 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.method; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import org.aopalliance.intercept.MethodInvocation; import org.junit.Test; import org.mockito.ArgumentMatchers; + import org.springframework.security.access.ConfigAttribute; import org.springframework.security.util.SimpleMethodInvocation; -import java.lang.reflect.Method; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; /** * @author Luke Taylor */ @SuppressWarnings({ "unchecked" }) public class DelegatingMethodSecurityMetadataSourceTests { + DelegatingMethodSecurityMetadataSource mds; @Test public void returnsEmptyListIfDelegateReturnsNull() throws Exception { List sources = new ArrayList(); MethodSecurityMetadataSource delegate = mock(MethodSecurityMetadataSource.class); - when(delegate.getAttributes(ArgumentMatchers. any(), ArgumentMatchers.any(Class.class))) - .thenReturn(null); + given(delegate.getAttributes(ArgumentMatchers.any(), ArgumentMatchers.any(Class.class))) + .willReturn(null); sources.add(delegate); - mds = new DelegatingMethodSecurityMetadataSource(sources); - assertThat(mds.getMethodSecurityMetadataSources()).isSameAs(sources); - assertThat(mds.getAllConfigAttributes().isEmpty()).isTrue(); - MethodInvocation mi = new SimpleMethodInvocation(null, - String.class.getMethod("toString")); - assertThat(mds.getAttributes(mi)).isEqualTo(Collections.emptyList()); + this.mds = new DelegatingMethodSecurityMetadataSource(sources); + assertThat(this.mds.getMethodSecurityMetadataSources()).isSameAs(sources); + assertThat(this.mds.getAllConfigAttributes().isEmpty()).isTrue(); + MethodInvocation mi = new SimpleMethodInvocation(null, String.class.getMethod("toString")); + assertThat(this.mds.getAttributes(mi)).isEqualTo(Collections.emptyList()); // Exercise the cached case - assertThat(mds.getAttributes(mi)).isEqualTo(Collections.emptyList()); + assertThat(this.mds.getAttributes(mi)).isEqualTo(Collections.emptyList()); } @Test @@ -58,17 +64,17 @@ public class DelegatingMethodSecurityMetadataSourceTests { ConfigAttribute ca = mock(ConfigAttribute.class); List attributes = Arrays.asList(ca); Method toString = String.class.getMethod("toString"); - when(delegate.getAttributes(toString, String.class)).thenReturn(attributes); + given(delegate.getAttributes(toString, String.class)).willReturn(attributes); sources.add(delegate); - mds = new DelegatingMethodSecurityMetadataSource(sources); - assertThat(mds.getMethodSecurityMetadataSources()).isSameAs(sources); - assertThat(mds.getAllConfigAttributes().isEmpty()).isTrue(); + this.mds = new DelegatingMethodSecurityMetadataSource(sources); + assertThat(this.mds.getMethodSecurityMetadataSources()).isSameAs(sources); + assertThat(this.mds.getAllConfigAttributes().isEmpty()).isTrue(); MethodInvocation mi = new SimpleMethodInvocation("", toString); - assertThat(mds.getAttributes(mi)).isSameAs(attributes); + assertThat(this.mds.getAttributes(mi)).isSameAs(attributes); // Exercise the cached case - assertThat(mds.getAttributes(mi)).isSameAs(attributes); - assertThat(mds.getAttributes( - new SimpleMethodInvocation(null, String.class.getMethod("length")))).isEmpty(); + assertThat(this.mds.getAttributes(mi)).isSameAs(attributes); + assertThat(this.mds.getAttributes(new SimpleMethodInvocation(null, String.class.getMethod("length")))) + .isEmpty(); } } diff --git a/core/src/test/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoterTests.java b/core/src/test/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoterTests.java index bb6f951680..417faedaab 100644 --- a/core/src/test/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoterTests.java +++ b/core/src/test/java/org/springframework/security/access/prepost/PreInvocationAuthorizationAdviceVoterTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.access.prepost; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.access.prepost; import org.aopalliance.intercept.MethodInvocation; import org.junit.Before; @@ -23,33 +22,39 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.aop.ProxyMethodInvocation; import org.springframework.security.access.intercept.aspectj.MethodInvocationAdapter; +import static org.assertj.core.api.Assertions.assertThat; + @RunWith(MockitoJUnitRunner.class) public class PreInvocationAuthorizationAdviceVoterTests { + @Mock private PreInvocationAuthorizationAdvice authorizationAdvice; + private PreInvocationAuthorizationAdviceVoter voter; @Before public void setUp() { - voter = new PreInvocationAuthorizationAdviceVoter(authorizationAdvice); + this.voter = new PreInvocationAuthorizationAdviceVoter(this.authorizationAdvice); } @Test public void supportsMethodInvocation() { - assertThat(voter.supports(MethodInvocation.class)).isTrue(); + assertThat(this.voter.supports(MethodInvocation.class)).isTrue(); } // SEC-2031 @Test public void supportsProxyMethodInvocation() { - assertThat(voter.supports(ProxyMethodInvocation.class)).isTrue(); + assertThat(this.voter.supports(ProxyMethodInvocation.class)).isTrue(); } @Test public void supportsMethodInvocationAdapter() { - assertThat(voter.supports(MethodInvocationAdapter.class)).isTrue(); + assertThat(this.voter.supports(MethodInvocationAdapter.class)).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/AbstractAccessDecisionManagerTests.java b/core/src/test/java/org/springframework/security/access/vote/AbstractAccessDecisionManagerTests.java index e3fb32d6bb..b0cfe45d60 100644 --- a/core/src/test/java/org/springframework/security/access/vote/AbstractAccessDecisionManagerTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/AbstractAccessDecisionManagerTests.java @@ -16,19 +16,20 @@ package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.Collection; import java.util.List; import java.util.Vector; import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AbstractAccessDecisionManager}. * @@ -37,8 +38,6 @@ import org.springframework.security.core.Authentication; @SuppressWarnings("unchecked") public class AbstractAccessDecisionManagerTests { - // ~ Methods - // ======================================================================================================== @Test public void testAllowIfAccessDecisionManagerDefaults() { List list = new Vector(); @@ -55,9 +54,7 @@ public class AbstractAccessDecisionManagerTests { List list = new Vector(); list.add(new DenyVoter()); list.add(new MockStringOnlyVoter()); - MockDecisionManagerImpl mock = new MockDecisionManagerImpl(list); - assertThat(mock.supports(String.class)).isTrue(); assertThat(!mock.supports(Integer.class)).isTrue(); } @@ -69,12 +66,9 @@ public class AbstractAccessDecisionManagerTests { DenyAgainVoter denyVoter = new DenyAgainVoter(); list.add(voter); list.add(denyVoter); - MockDecisionManagerImpl mock = new MockDecisionManagerImpl(list); - ConfigAttribute attr = new SecurityConfig("DENY_AGAIN_FOR_SURE"); assertThat(mock.supports(attr)).isTrue(); - ConfigAttribute badAttr = new SecurityConfig("WE_DONT_SUPPORT_THIS"); assertThat(!mock.supports(badAttr)).isTrue(); } @@ -93,13 +87,11 @@ public class AbstractAccessDecisionManagerTests { @Test public void testRejectsEmptyList() { List list = new Vector(); - try { new MockDecisionManagerImpl(list); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @@ -110,7 +102,6 @@ public class AbstractAccessDecisionManagerTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @@ -127,38 +118,38 @@ public class AbstractAccessDecisionManagerTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } - // ~ Inner Classes - // ================================================================================================== - private class MockDecisionManagerImpl extends AbstractAccessDecisionManager { - protected MockDecisionManagerImpl( - List> decisionVoters) { + protected MockDecisionManagerImpl(List> decisionVoters) { super(decisionVoters); } - public void decide(Authentication authentication, Object object, - Collection configAttributes) { + @Override + public void decide(Authentication authentication, Object object, Collection configAttributes) { } + } private class MockStringOnlyVoter implements AccessDecisionVoter { + @Override public boolean supports(Class clazz) { return String.class.isAssignableFrom(clazz); } + @Override public boolean supports(ConfigAttribute attribute) { throw new UnsupportedOperationException("mock method not implemented"); } - public int vote(Authentication authentication, Object object, - Collection attributes) { + @Override + public int vote(Authentication authentication, Object object, Collection attributes) { throw new UnsupportedOperationException("mock method not implemented"); } + } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/AbstractAclVoterTests.java b/core/src/test/java/org/springframework/security/access/vote/AbstractAclVoterTests.java index 4f66e67244..7cb15549ac 100644 --- a/core/src/test/java/org/springframework/security/access/vote/AbstractAclVoterTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/AbstractAclVoterTests.java @@ -13,28 +13,33 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.*; - -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; import org.aopalliance.intercept.MethodInvocation; import org.junit.Test; + import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; import org.springframework.security.util.MethodInvocationUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Luke Taylor */ public class AbstractAclVoterTests { + private AbstractAclVoter voter = new AbstractAclVoter() { + @Override public boolean supports(ConfigAttribute attribute) { return false; } + @Override public int vote(Authentication authentication, MethodInvocation object, Collection attributes) { return 0; @@ -43,28 +48,28 @@ public class AbstractAclVoterTests { @Test public void supportsMethodInvocations() { - assertThat(voter.supports(MethodInvocation.class)).isTrue(); - assertThat(voter.supports(String.class)).isFalse(); + assertThat(this.voter.supports(MethodInvocation.class)).isTrue(); + assertThat(this.voter.supports(String.class)).isFalse(); } @Test public void expectedDomainObjectArgumentIsReturnedFromMethodInvocation() { - voter.setProcessDomainObjectClass(String.class); - MethodInvocation mi = MethodInvocationUtils.create(new TestClass(), - "methodTakingAString", "The Argument"); - assertThat(voter.getDomainObjectInstance(mi)).isEqualTo("The Argument"); + this.voter.setProcessDomainObjectClass(String.class); + MethodInvocation mi = MethodInvocationUtils.create(new TestClass(), "methodTakingAString", "The Argument"); + assertThat(this.voter.getDomainObjectInstance(mi)).isEqualTo("The Argument"); } @Test public void correctArgumentIsSelectedFromMultipleArgs() { - voter.setProcessDomainObjectClass(String.class); - MethodInvocation mi = MethodInvocationUtils.create(new TestClass(), - "methodTakingAListAndAString", new ArrayList<>(), "The Argument"); - assertThat(voter.getDomainObjectInstance(mi)).isEqualTo("The Argument"); + this.voter.setProcessDomainObjectClass(String.class); + MethodInvocation mi = MethodInvocationUtils.create(new TestClass(), "methodTakingAListAndAString", + new ArrayList<>(), "The Argument"); + assertThat(this.voter.getDomainObjectInstance(mi)).isEqualTo("The Argument"); } @SuppressWarnings("unused") private static class TestClass { + public void methodTakingAString(String arg) { } @@ -73,6 +78,7 @@ public class AbstractAclVoterTests { public void methodTakingAListAndAString(ArrayList arg1, String arg2) { } + } } diff --git a/core/src/test/java/org/springframework/security/access/vote/AffirmativeBasedTests.java b/core/src/test/java/org/springframework/security/access/vote/AffirmativeBasedTests.java index bf237fb51b..d11135de93 100644 --- a/core/src/test/java/org/springframework/security/access/vote/AffirmativeBasedTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/AffirmativeBasedTests.java @@ -16,99 +16,101 @@ package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.assertThat; - -import static org.mockito.Mockito.*; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.junit.Before; import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests {@link AffirmativeBased}. * * @author Ben Alex */ public class AffirmativeBasedTests { + private final List attrs = new ArrayList<>(); - private final Authentication user = new TestingAuthenticationToken("somebody", - "password", "ROLE_1", "ROLE_2"); + + private final Authentication user = new TestingAuthenticationToken("somebody", "password", "ROLE_1", "ROLE_2"); + private AffirmativeBased mgr; + private AccessDecisionVoter grant; + private AccessDecisionVoter abstain; + private AccessDecisionVoter deny; @Before @SuppressWarnings("unchecked") public void setup() { - - grant = mock(AccessDecisionVoter.class); - abstain = mock(AccessDecisionVoter.class); - deny = mock(AccessDecisionVoter.class); - - when(grant.vote(any(Authentication.class), any(Object.class), any(List.class))) - .thenReturn(AccessDecisionVoter.ACCESS_GRANTED); - when(abstain.vote(any(Authentication.class), any(Object.class), any(List.class))) - .thenReturn(AccessDecisionVoter.ACCESS_ABSTAIN); - when(deny.vote(any(Authentication.class), any(Object.class), any(List.class))) - .thenReturn(AccessDecisionVoter.ACCESS_DENIED); + this.grant = mock(AccessDecisionVoter.class); + this.abstain = mock(AccessDecisionVoter.class); + this.deny = mock(AccessDecisionVoter.class); + given(this.grant.vote(any(Authentication.class), any(Object.class), any(List.class))) + .willReturn(AccessDecisionVoter.ACCESS_GRANTED); + given(this.abstain.vote(any(Authentication.class), any(Object.class), any(List.class))) + .willReturn(AccessDecisionVoter.ACCESS_ABSTAIN); + given(this.deny.vote(any(Authentication.class), any(Object.class), any(List.class))) + .willReturn(AccessDecisionVoter.ACCESS_DENIED); } @Test - public void oneAffirmativeVoteOneDenyVoteOneAbstainVoteGrantsAccess() - throws Exception { - - mgr = new AffirmativeBased(Arrays.> asList( - grant, deny, abstain)); - mgr.afterPropertiesSet(); - mgr.decide(user, new Object(), attrs); + public void oneAffirmativeVoteOneDenyVoteOneAbstainVoteGrantsAccess() throws Exception { + this.mgr = new AffirmativeBased( + Arrays.>asList(this.grant, this.deny, this.abstain)); + this.mgr.afterPropertiesSet(); + this.mgr.decide(this.user, new Object(), this.attrs); } @Test public void oneDenyVoteOneAbstainVoteOneAffirmativeVoteGrantsAccess() { - mgr = new AffirmativeBased(Arrays.> asList( - deny, abstain, grant)); - mgr.decide(user, new Object(), attrs); + this.mgr = new AffirmativeBased( + Arrays.>asList(this.deny, this.abstain, this.grant)); + this.mgr.decide(this.user, new Object(), this.attrs); } @Test public void oneAffirmativeVoteTwoAbstainVotesGrantsAccess() { - mgr = new AffirmativeBased(Arrays.> asList( - grant, abstain, abstain)); - mgr.decide(user, new Object(), attrs); + this.mgr = new AffirmativeBased( + Arrays.>asList(this.grant, this.abstain, this.abstain)); + this.mgr.decide(this.user, new Object(), this.attrs); } @Test(expected = AccessDeniedException.class) public void oneDenyVoteTwoAbstainVotesDeniesAccess() { - mgr = new AffirmativeBased(Arrays.> asList( - deny, abstain, abstain)); - mgr.decide(user, new Object(), attrs); + this.mgr = new AffirmativeBased( + Arrays.>asList(this.deny, this.abstain, this.abstain)); + this.mgr.decide(this.user, new Object(), this.attrs); } @Test(expected = AccessDeniedException.class) public void onlyAbstainVotesDeniesAccessWithDefault() { - mgr = new AffirmativeBased(Arrays.> asList( - abstain, abstain, abstain)); - assertThat(!mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check default - - mgr.decide(user, new Object(), attrs); + this.mgr = new AffirmativeBased( + Arrays.>asList(this.abstain, this.abstain, this.abstain)); + assertThat(!this.mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check default + this.mgr.decide(this.user, new Object(), this.attrs); } @Test public void testThreeAbstainVotesGrantsAccessIfAllowIfAllAbstainDecisionsIsSet() { - mgr = new AffirmativeBased(Arrays.> asList( - abstain, abstain, abstain)); - mgr.setAllowIfAllAbstainDecisions(true); - assertThat(mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check changed - - mgr.decide(user, new Object(), attrs); + this.mgr = new AffirmativeBased( + Arrays.>asList(this.abstain, this.abstain, this.abstain)); + this.mgr.setAllowIfAllAbstainDecisions(true); + assertThat(this.mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check changed + this.mgr.decide(this.user, new Object(), this.attrs); } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/AuthenticatedVoterTests.java b/core/src/test/java/org/springframework/security/access/vote/AuthenticatedVoterTests.java index 8a8d3f4043..595bd55fc9 100644 --- a/core/src/test/java/org/springframework/security/access/vote/AuthenticatedVoterTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/AuthenticatedVoterTests.java @@ -16,12 +16,10 @@ package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.List; import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; @@ -31,6 +29,9 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AuthenticatedVoter}. * @@ -39,8 +40,7 @@ import org.springframework.security.core.authority.AuthorityUtils; public class AuthenticatedVoterTests { private Authentication createAnonymous() { - return new AnonymousAuthenticationToken("ignored", "ignored", - AuthorityUtils.createAuthorityList("ignored")); + return new AnonymousAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("ignored")); } private Authentication createFullyAuthenticated() { @@ -49,59 +49,44 @@ public class AuthenticatedVoterTests { } private Authentication createRememberMe() { - return new RememberMeAuthenticationToken("ignored", "ignored", - AuthorityUtils.createAuthorityList("ignored")); + return new RememberMeAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("ignored")); } @Test public void testAnonymousWorks() { AuthenticatedVoter voter = new AuthenticatedVoter(); - List def = SecurityConfig.createList( - AuthenticatedVoter.IS_AUTHENTICATED_ANONYMOUSLY); - assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo( - voter.vote(createAnonymous(), null, def)); - assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo( - voter.vote(createRememberMe(), null, def)); - assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo( - voter.vote(createFullyAuthenticated(), null, def)); + List def = SecurityConfig.createList(AuthenticatedVoter.IS_AUTHENTICATED_ANONYMOUSLY); + assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo(voter.vote(createAnonymous(), null, def)); + assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo(voter.vote(createRememberMe(), null, def)); + assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo(voter.vote(createFullyAuthenticated(), null, def)); } @Test public void testFullyWorks() { AuthenticatedVoter voter = new AuthenticatedVoter(); - List def = SecurityConfig.createList( - AuthenticatedVoter.IS_AUTHENTICATED_FULLY); - assertThat(AccessDecisionVoter.ACCESS_DENIED).isEqualTo( - voter.vote(createAnonymous(), null, def)); - assertThat(AccessDecisionVoter.ACCESS_DENIED).isEqualTo( - voter.vote(createRememberMe(), null, def)); - assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo( - voter.vote(createFullyAuthenticated(), null, def)); + List def = SecurityConfig.createList(AuthenticatedVoter.IS_AUTHENTICATED_FULLY); + assertThat(AccessDecisionVoter.ACCESS_DENIED).isEqualTo(voter.vote(createAnonymous(), null, def)); + assertThat(AccessDecisionVoter.ACCESS_DENIED).isEqualTo(voter.vote(createRememberMe(), null, def)); + assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo(voter.vote(createFullyAuthenticated(), null, def)); } @Test public void testRememberMeWorks() { AuthenticatedVoter voter = new AuthenticatedVoter(); - List def = SecurityConfig.createList( - AuthenticatedVoter.IS_AUTHENTICATED_REMEMBERED); - assertThat(AccessDecisionVoter.ACCESS_DENIED).isEqualTo( - voter.vote(createAnonymous(), null, def)); - assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo( - voter.vote(createRememberMe(), null, def)); - assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo( - voter.vote(createFullyAuthenticated(), null, def)); + List def = SecurityConfig.createList(AuthenticatedVoter.IS_AUTHENTICATED_REMEMBERED); + assertThat(AccessDecisionVoter.ACCESS_DENIED).isEqualTo(voter.vote(createAnonymous(), null, def)); + assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo(voter.vote(createRememberMe(), null, def)); + assertThat(AccessDecisionVoter.ACCESS_GRANTED).isEqualTo(voter.vote(createFullyAuthenticated(), null, def)); } @Test public void testSetterRejectsNull() { AuthenticatedVoter voter = new AuthenticatedVoter(); - try { voter.setAuthenticationTrustResolver(null); fail("Expected IAE"); } catch (IllegalArgumentException expected) { - } } @@ -109,12 +94,10 @@ public class AuthenticatedVoterTests { public void testSupports() { AuthenticatedVoter voter = new AuthenticatedVoter(); assertThat(voter.supports(String.class)).isTrue(); - assertThat(voter.supports(new SecurityConfig( - AuthenticatedVoter.IS_AUTHENTICATED_ANONYMOUSLY))).isTrue(); - assertThat(voter.supports( - new SecurityConfig(AuthenticatedVoter.IS_AUTHENTICATED_FULLY))).isTrue(); - assertThat(voter.supports(new SecurityConfig( - AuthenticatedVoter.IS_AUTHENTICATED_REMEMBERED))).isTrue(); + assertThat(voter.supports(new SecurityConfig(AuthenticatedVoter.IS_AUTHENTICATED_ANONYMOUSLY))).isTrue(); + assertThat(voter.supports(new SecurityConfig(AuthenticatedVoter.IS_AUTHENTICATED_FULLY))).isTrue(); + assertThat(voter.supports(new SecurityConfig(AuthenticatedVoter.IS_AUTHENTICATED_REMEMBERED))).isTrue(); assertThat(voter.supports(new SecurityConfig("FOO"))).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/ConsensusBasedTests.java b/core/src/test/java/org/springframework/security/access/vote/ConsensusBasedTests.java index e8fca78fc2..647387d2aa 100644 --- a/core/src/test/java/org/springframework/security/access/vote/ConsensusBasedTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/ConsensusBasedTests.java @@ -16,16 +16,19 @@ package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.*; +import java.util.List; +import java.util.Vector; + +import org.junit.Test; -import org.junit.*; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.authentication.TestingAuthenticationToken; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** * Tests {@link ConsensusBased}. @@ -40,10 +43,7 @@ public class ConsensusBasedTests { ConsensusBased mgr = makeDecisionManager(); mgr.setAllowIfEqualGrantedDeniedDecisions(false); assertThat(!mgr.isAllowIfEqualGrantedDeniedDecisions()).isTrue(); // check changed - - List config = SecurityConfig.createList("ROLE_1", - "DENY_FOR_SURE"); - + List config = SecurityConfig.createList("ROLE_1", "DENY_FOR_SURE"); mgr.decide(auth, new Object(), config); } @@ -51,30 +51,22 @@ public class ConsensusBasedTests { public void testOneAffirmativeVoteOneDenyVoteOneAbstainVoteGrantsAccessWithDefault() { TestingAuthenticationToken auth = makeTestToken(); ConsensusBased mgr = makeDecisionManager(); - assertThat(mgr.isAllowIfEqualGrantedDeniedDecisions()).isTrue(); // check default - - List config = SecurityConfig.createList("ROLE_1", - "DENY_FOR_SURE"); - + List config = SecurityConfig.createList("ROLE_1", "DENY_FOR_SURE"); mgr.decide(auth, new Object(), config); - } @Test public void testOneAffirmativeVoteTwoAbstainVotesGrantsAccess() { TestingAuthenticationToken auth = makeTestToken(); ConsensusBased mgr = makeDecisionManager(); - mgr.decide(auth, new Object(), SecurityConfig.createList("ROLE_2")); - } @Test(expected = AccessDeniedException.class) public void testOneDenyVoteTwoAbstainVotesDeniesAccess() { TestingAuthenticationToken auth = makeTestToken(); ConsensusBased mgr = makeDecisionManager(); - mgr.decide(auth, new Object(), SecurityConfig.createList("ROLE_WE_DO_NOT_HAVE")); fail("Should have thrown AccessDeniedException"); } @@ -83,9 +75,7 @@ public class ConsensusBasedTests { public void testThreeAbstainVotesDeniesAccessWithDefault() { TestingAuthenticationToken auth = makeTestToken(); ConsensusBased mgr = makeDecisionManager(); - assertThat(!mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check default - mgr.decide(auth, new Object(), SecurityConfig.createList("IGNORED_BY_ALL")); } @@ -95,7 +85,6 @@ public class ConsensusBasedTests { ConsensusBased mgr = makeDecisionManager(); mgr.setAllowIfAllAbstainDecisions(true); assertThat(mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check changed - mgr.decide(auth, new Object(), SecurityConfig.createList("IGNORED_BY_ALL")); } @@ -103,7 +92,6 @@ public class ConsensusBasedTests { public void testTwoAffirmativeVotesTwoAbstainVotesGrantsAccess() { TestingAuthenticationToken auth = makeTestToken(); ConsensusBased mgr = makeDecisionManager(); - mgr.decide(auth, new Object(), SecurityConfig.createList("ROLE_1", "ROLE_2")); } @@ -115,11 +103,11 @@ public class ConsensusBasedTests { voters.add(roleVoter); voters.add(denyForSureVoter); voters.add(denyAgainForSureVoter); - return new ConsensusBased(voters); } private TestingAuthenticationToken makeTestToken() { return new TestingAuthenticationToken("somebody", "password", "ROLE_1", "ROLE_2"); } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/DenyAgainVoter.java b/core/src/test/java/org/springframework/security/access/vote/DenyAgainVoter.java index e48e0548f1..4d59a0173b 100644 --- a/core/src/test/java/org/springframework/security/access/vote/DenyAgainVoter.java +++ b/core/src/test/java/org/springframework/security/access/vote/DenyAgainVoter.java @@ -16,13 +16,13 @@ package org.springframework.security.access.vote; +import java.util.Collection; +import java.util.Iterator; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; -import java.util.Collection; -import java.util.Iterator; - /** * Implementation of an {@link AccessDecisionVoter} for unit testing. *

      @@ -34,34 +34,26 @@ import java.util.Iterator; * @author Ben Alex */ public class DenyAgainVoter implements AccessDecisionVoter { - // ~ Methods - // ======================================================================================================== + @Override public boolean supports(ConfigAttribute attribute) { - if ("DENY_AGAIN_FOR_SURE".equals(attribute.getAttribute())) { - return true; - } - else { - return false; - } + return "DENY_AGAIN_FOR_SURE".equals(attribute.getAttribute()); } + @Override public boolean supports(Class clazz) { return true; } - public int vote(Authentication authentication, Object object, - Collection attributes) { + @Override + public int vote(Authentication authentication, Object object, Collection attributes) { Iterator iter = attributes.iterator(); - while (iter.hasNext()) { ConfigAttribute attribute = iter.next(); - if (this.supports(attribute)) { return ACCESS_DENIED; } } - return ACCESS_ABSTAIN; } diff --git a/core/src/test/java/org/springframework/security/access/vote/DenyVoter.java b/core/src/test/java/org/springframework/security/access/vote/DenyVoter.java index 2db4297778..b20964b020 100644 --- a/core/src/test/java/org/springframework/security/access/vote/DenyVoter.java +++ b/core/src/test/java/org/springframework/security/access/vote/DenyVoter.java @@ -16,13 +16,13 @@ package org.springframework.security.access.vote; +import java.util.Collection; +import java.util.Iterator; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; -import java.util.Collection; -import java.util.Iterator; - /** * Implementation of an {@link AccessDecisionVoter} for unit testing. *

      @@ -36,34 +36,27 @@ import java.util.Iterator; * @author Ben Alex */ public class DenyVoter implements AccessDecisionVoter { - // ~ Methods - // ======================================================================================================== + @Override public boolean supports(ConfigAttribute attribute) { - if ("DENY_FOR_SURE".equals(attribute.getAttribute())) { - return true; - } - else { - return false; - } + return "DENY_FOR_SURE".equals(attribute.getAttribute()); } + @Override public boolean supports(Class clazz) { return true; } - public int vote(Authentication authentication, Object object, - Collection attributes) { + @Override + public int vote(Authentication authentication, Object object, Collection attributes) { Iterator iter = attributes.iterator(); - while (iter.hasNext()) { ConfigAttribute attribute = iter.next(); - if (this.supports(attribute)) { return ACCESS_DENIED; } } - return ACCESS_ABSTAIN; } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/RoleHierarchyVoterTests.java b/core/src/test/java/org/springframework/security/access/vote/RoleHierarchyVoterTests.java index 681c61ff12..806ec7416b 100644 --- a/core/src/test/java/org/springframework/security/access/vote/RoleHierarchyVoterTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/RoleHierarchyVoterTests.java @@ -13,28 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + +import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl; -import org.springframework.security.access.vote.RoleHierarchyVoter; import org.springframework.security.authentication.TestingAuthenticationToken; +import static org.assertj.core.api.Assertions.assertThat; + public class RoleHierarchyVoterTests { @Test public void hierarchicalRoleIsIncludedInDecision() { RoleHierarchyImpl roleHierarchyImpl = new RoleHierarchyImpl(); roleHierarchyImpl.setHierarchy("ROLE_A > ROLE_B"); - // User has role A, role B is required - TestingAuthenticationToken auth = new TestingAuthenticationToken("user", - "password", "ROLE_A"); + TestingAuthenticationToken auth = new TestingAuthenticationToken("user", "password", "ROLE_A"); RoleHierarchyVoter voter = new RoleHierarchyVoter(roleHierarchyImpl); - - assertThat(voter.vote(auth, new Object(), SecurityConfig.createList("ROLE_B"))).isEqualTo(RoleHierarchyVoter.ACCESS_GRANTED); + assertThat(voter.vote(auth, new Object(), SecurityConfig.createList("ROLE_B"))) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/RoleVoterTests.java b/core/src/test/java/org/springframework/security/access/vote/RoleVoterTests.java index 741a7bcc7a..ca6215e486 100644 --- a/core/src/test/java/org/springframework/security/access/vote/RoleVoterTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/RoleVoterTests.java @@ -13,28 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.SecurityConfig; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Luke Taylor */ public class RoleVoterTests { + @Test public void oneMatchingAttributeGrantsAccess() { RoleVoter voter = new RoleVoter(); voter.setRolePrefix(""); Authentication userAB = new TestingAuthenticationToken("user", "pass", "A", "B"); // Vote on attribute list that has two attributes A and C (i.e. only one matching) - assertThat(voter.vote(userAB, this, SecurityConfig.createList("A", "C"))).isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + assertThat(voter.vote(userAB, this, SecurityConfig.createList("A", "C"))) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); } // SEC-3128 @@ -43,6 +46,8 @@ public class RoleVoterTests { RoleVoter voter = new RoleVoter(); voter.setRolePrefix(""); Authentication notAuthenitcated = null; - assertThat(voter.vote(notAuthenitcated, this, SecurityConfig.createList("A"))).isEqualTo(AccessDecisionVoter.ACCESS_DENIED); + assertThat(voter.vote(notAuthenitcated, this, SecurityConfig.createList("A"))) + .isEqualTo(AccessDecisionVoter.ACCESS_DENIED); } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/SomeDomainObject.java b/core/src/test/java/org/springframework/security/access/vote/SomeDomainObject.java index ab2b8f0c67..93d47219e0 100644 --- a/core/src/test/java/org/springframework/security/access/vote/SomeDomainObject.java +++ b/core/src/test/java/org/springframework/security/access/vote/SomeDomainObject.java @@ -22,22 +22,15 @@ package org.springframework.security.access.vote; * @author Ben Alex */ public class SomeDomainObject { - // ~ Instance fields - // ================================================================================================ private String identity; - // ~ Constructors - // =================================================================================================== - public SomeDomainObject(String identity) { this.identity = identity; } - // ~ Methods - // ======================================================================================================== - public String getParent() { - return "parentOf" + identity; + return "parentOf" + this.identity; } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/SomeDomainObjectManager.java b/core/src/test/java/org/springframework/security/access/vote/SomeDomainObjectManager.java index f7431ff6a4..abde87e604 100644 --- a/core/src/test/java/org/springframework/security/access/vote/SomeDomainObjectManager.java +++ b/core/src/test/java/org/springframework/security/access/vote/SomeDomainObjectManager.java @@ -23,9 +23,8 @@ package org.springframework.security.access.vote; * @author Ben Alex */ public class SomeDomainObjectManager { - // ~ Methods - // ======================================================================================================== public void someServiceMethod(SomeDomainObject someDomainObject) { } + } diff --git a/core/src/test/java/org/springframework/security/access/vote/UnanimousBasedTests.java b/core/src/test/java/org/springframework/security/access/vote/UnanimousBasedTests.java index a00a611850..943d31da0a 100644 --- a/core/src/test/java/org/springframework/security/access/vote/UnanimousBasedTests.java +++ b/core/src/test/java/org/springframework/security/access/vote/UnanimousBasedTests.java @@ -16,19 +16,20 @@ package org.springframework.security.access.vote; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.List; import java.util.Vector; import org.junit.Test; + import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.authentication.TestingAuthenticationToken; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link UnanimousBased}. * @@ -36,9 +37,6 @@ import org.springframework.security.authentication.TestingAuthenticationToken; */ public class UnanimousBasedTests { - // ~ Methods - // ======================================================================================================== - private UnanimousBased makeDecisionManager() { RoleVoter roleVoter = new RoleVoter(); DenyVoter denyForSureVoter = new DenyVoter(); @@ -53,7 +51,6 @@ public class UnanimousBasedTests { private UnanimousBased makeDecisionManagerWithFooBarPrefix() { RoleVoter roleVoter = new RoleVoter(); roleVoter.setRolePrefix("FOOBAR_"); - DenyVoter denyForSureVoter = new DenyVoter(); DenyAgainVoter denyAgainForSureVoter = new DenyAgainVoter(); List> voters = new Vector<>(); @@ -68,18 +65,14 @@ public class UnanimousBasedTests { } private TestingAuthenticationToken makeTestTokenWithFooBarPrefix() { - return new TestingAuthenticationToken("somebody", "password", "FOOBAR_1", - "FOOBAR_2"); + return new TestingAuthenticationToken("somebody", "password", "FOOBAR_1", "FOOBAR_2"); } @Test public void testOneAffirmativeVoteOneDenyVoteOneAbstainVoteDeniesAccess() { TestingAuthenticationToken auth = makeTestToken(); UnanimousBased mgr = makeDecisionManager(); - - List config = SecurityConfig.createList( - new String[] { "ROLE_1", "DENY_FOR_SURE" }); - + List config = SecurityConfig.createList(new String[] { "ROLE_1", "DENY_FOR_SURE" }); try { mgr.decide(auth, new Object(), config); fail("Should have thrown AccessDeniedException"); @@ -92,9 +85,7 @@ public class UnanimousBasedTests { public void testOneAffirmativeVoteTwoAbstainVotesGrantsAccess() { TestingAuthenticationToken auth = makeTestToken(); UnanimousBased mgr = makeDecisionManager(); - List config = SecurityConfig.createList("ROLE_2"); - mgr.decide(auth, new Object(), config); } @@ -102,9 +93,7 @@ public class UnanimousBasedTests { public void testOneDenyVoteTwoAbstainVotesDeniesAccess() { TestingAuthenticationToken auth = makeTestToken(); UnanimousBased mgr = makeDecisionManager(); - List config = SecurityConfig.createList("ROLE_WE_DO_NOT_HAVE"); - try { mgr.decide(auth, new Object(), config); fail("Should have thrown AccessDeniedException"); @@ -117,10 +106,7 @@ public class UnanimousBasedTests { public void testRoleVoterPrefixObserved() { TestingAuthenticationToken auth = makeTestTokenWithFooBarPrefix(); UnanimousBased mgr = makeDecisionManagerWithFooBarPrefix(); - - List config = SecurityConfig.createList( - new String[] { "FOOBAR_1", "FOOBAR_2" }); - + List config = SecurityConfig.createList(new String[] { "FOOBAR_1", "FOOBAR_2" }); mgr.decide(auth, new Object(), config); } @@ -128,11 +114,8 @@ public class UnanimousBasedTests { public void testThreeAbstainVotesDeniesAccessWithDefault() { TestingAuthenticationToken auth = makeTestToken(); UnanimousBased mgr = makeDecisionManager(); - assertThat(!mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check default - List config = SecurityConfig.createList("IGNORED_BY_ALL"); - try { mgr.decide(auth, new Object(), config); fail("Should have thrown AccessDeniedException"); @@ -147,9 +130,7 @@ public class UnanimousBasedTests { UnanimousBased mgr = makeDecisionManager(); mgr.setAllowIfAllAbstainDecisions(true); assertThat(mgr.isAllowIfAllAbstainDecisions()).isTrue(); // check changed - List config = SecurityConfig.createList("IGNORED_BY_ALL"); - mgr.decide(auth, new Object(), config); } @@ -157,10 +138,8 @@ public class UnanimousBasedTests { public void testTwoAffirmativeVotesTwoAbstainVotesGrantsAccess() { TestingAuthenticationToken auth = makeTestToken(); UnanimousBased mgr = makeDecisionManager(); - - List config = SecurityConfig.createList( - new String[] { "ROLE_1", "ROLE_2" }); - + List config = SecurityConfig.createList(new String[] { "ROLE_1", "ROLE_2" }); mgr.decide(auth, new Object(), config); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/AbstractAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/authentication/AbstractAuthenticationTokenTests.java index 7f5638b31e..90c0f82f4a 100644 --- a/core/src/test/java/org/springframework/security/authentication/AbstractAuthenticationTokenTests.java +++ b/core/src/test/java/org/springframework/security/authentication/AbstractAuthenticationTokenTests.java @@ -16,16 +16,21 @@ package org.springframework.security.authentication; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; -import org.junit.*; import org.springframework.security.core.AuthenticatedPrincipal; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Tests {@link AbstractAuthenticationToken}. @@ -33,34 +38,25 @@ import java.util.*; * @author Ben Alex */ public class AbstractAuthenticationTokenTests { - // ~ Instance fields - // ================================================================================================ private List authorities = null; - // ~ Methods - // ======================================================================================================== - @Before public final void setUp() { - authorities = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); + this.authorities = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); } @Test(expected = UnsupportedOperationException.class) public void testAuthoritiesAreImmutable() { - MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", - authorities); - List gotAuthorities = (List) token - .getAuthorities(); - assertThat(gotAuthorities).isNotSameAs(authorities); - + MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", this.authorities); + List gotAuthorities = (List) token.getAuthorities(); + assertThat(gotAuthorities).isNotSameAs(this.authorities); gotAuthorities.set(0, new SimpleGrantedAuthority("ROLE_SUPER_USER")); } @Test public void testGetters() { - MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", - authorities); + MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", this.authorities); assertThat(token.getPrincipal()).isEqualTo("Test"); assertThat(token.getCredentials()).isEqualTo("Password"); assertThat(token.getName()).isEqualTo("Test"); @@ -68,56 +64,39 @@ public class AbstractAuthenticationTokenTests { @Test public void testHashCode() { - MockAuthenticationImpl token1 = new MockAuthenticationImpl("Test", "Password", - authorities); - MockAuthenticationImpl token2 = new MockAuthenticationImpl("Test", "Password", - authorities); - MockAuthenticationImpl token3 = new MockAuthenticationImpl(null, null, - AuthorityUtils.NO_AUTHORITIES); + MockAuthenticationImpl token1 = new MockAuthenticationImpl("Test", "Password", this.authorities); + MockAuthenticationImpl token2 = new MockAuthenticationImpl("Test", "Password", this.authorities); + MockAuthenticationImpl token3 = new MockAuthenticationImpl(null, null, AuthorityUtils.NO_AUTHORITIES); assertThat(token2.hashCode()).isEqualTo(token1.hashCode()); assertThat(token1.hashCode() != token3.hashCode()).isTrue(); - token2.setAuthenticated(true); - assertThat(token1.hashCode() != token2.hashCode()).isTrue(); } @Test public void testObjectsEquals() { - MockAuthenticationImpl token1 = new MockAuthenticationImpl("Test", "Password", - authorities); - MockAuthenticationImpl token2 = new MockAuthenticationImpl("Test", "Password", - authorities); + MockAuthenticationImpl token1 = new MockAuthenticationImpl("Test", "Password", this.authorities); + MockAuthenticationImpl token2 = new MockAuthenticationImpl("Test", "Password", this.authorities); assertThat(token2).isEqualTo(token1); - - MockAuthenticationImpl token3 = new MockAuthenticationImpl("Test", - "Password_Changed", authorities); + MockAuthenticationImpl token3 = new MockAuthenticationImpl("Test", "Password_Changed", this.authorities); assertThat(!token1.equals(token3)).isTrue(); - - MockAuthenticationImpl token4 = new MockAuthenticationImpl("Test_Changed", - "Password", authorities); + MockAuthenticationImpl token4 = new MockAuthenticationImpl("Test_Changed", "Password", this.authorities); assertThat(!token1.equals(token4)).isTrue(); - MockAuthenticationImpl token5 = new MockAuthenticationImpl("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO_CHANGED")); assertThat(!token1.equals(token5)).isTrue(); - MockAuthenticationImpl token6 = new MockAuthenticationImpl("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_ONE")); assertThat(!token1.equals(token6)).isTrue(); - - MockAuthenticationImpl token7 = new MockAuthenticationImpl("Test", "Password", - null); + MockAuthenticationImpl token7 = new MockAuthenticationImpl("Test", "Password", null); assertThat(!token1.equals(token7)).isTrue(); assertThat(!token7.equals(token1)).isTrue(); - assertThat(!token1.equals(100)).isTrue(); } @Test public void testSetAuthenticated() { - MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", - authorities); + MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", this.authorities); assertThat(!token.isAuthenticated()).isTrue(); token.setAuthenticated(true); assertThat(token.isAuthenticated()).isTrue(); @@ -125,35 +104,30 @@ public class AbstractAuthenticationTokenTests { @Test public void testToStringWithAuthorities() { - MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", - authorities); + MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", this.authorities); assertThat(token.toString().lastIndexOf("ROLE_TWO") != -1).isTrue(); } @Test public void testToStringWithNullAuthorities() { - MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", - null); + MockAuthenticationImpl token = new MockAuthenticationImpl("Test", "Password", null); assertThat(token.toString().lastIndexOf("Not granted any authorities") != -1).isTrue(); } @Test public void testGetNameWhenPrincipalIsAuthenticatedPrincipal() { String principalName = "test"; - AuthenticatedPrincipal principal = mock(AuthenticatedPrincipal.class); - when(principal.getName()).thenReturn(principalName); - - MockAuthenticationImpl token = new MockAuthenticationImpl(principal, "Password", authorities); + given(principal.getName()).willReturn(principalName); + MockAuthenticationImpl token = new MockAuthenticationImpl(principal, "Password", this.authorities); assertThat(token.getName()).isEqualTo(principalName); verify(principal, times(1)).getName(); } - // ~ Inner Classes - // ================================================================================================== - private class MockAuthenticationImpl extends AbstractAuthenticationToken { + private Object credentials; + private Object principal; MockAuthenticationImpl(Object principal, Object credentials, List authorities) { @@ -162,12 +136,16 @@ public class AbstractAuthenticationTokenTests { this.credentials = credentials; } + @Override public Object getCredentials() { return this.credentials; } + @Override public Object getPrincipal() { return this.principal; } + } + } diff --git a/core/src/test/java/org/springframework/security/authentication/AuthenticationTrustResolverImplTests.java b/core/src/test/java/org/springframework/security/authentication/AuthenticationTrustResolverImplTests.java index fbc05e04ac..fa476f0738 100644 --- a/core/src/test/java/org/springframework/security/authentication/AuthenticationTrustResolverImplTests.java +++ b/core/src/test/java/org/springframework/security/authentication/AuthenticationTrustResolverImplTests.java @@ -16,11 +16,12 @@ package org.springframework.security.authentication; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests * {@link org.springframework.security.authentication.AuthenticationTrustResolverImpl}. @@ -29,40 +30,37 @@ import org.springframework.security.core.authority.AuthorityUtils; */ public class AuthenticationTrustResolverImplTests { - // ~ Methods - // ======================================================================================================== @Test public void testCorrectOperationIsAnonymous() { AuthenticationTrustResolverImpl trustResolver = new AuthenticationTrustResolverImpl(); - assertThat(trustResolver.isAnonymous(new AnonymousAuthenticationToken("ignored", - "ignored", AuthorityUtils.createAuthorityList("ignored")))).isTrue(); - assertThat(trustResolver.isAnonymous(new TestingAuthenticationToken("ignored", - "ignored", AuthorityUtils.createAuthorityList("ignored")))).isFalse(); + assertThat(trustResolver.isAnonymous( + new AnonymousAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("ignored")))) + .isTrue(); + assertThat(trustResolver.isAnonymous( + new TestingAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("ignored")))) + .isFalse(); } @Test public void testCorrectOperationIsRememberMe() { AuthenticationTrustResolverImpl trustResolver = new AuthenticationTrustResolverImpl(); - assertThat(trustResolver.isRememberMe(new RememberMeAuthenticationToken("ignored", - "ignored", AuthorityUtils.createAuthorityList("ignored")))).isTrue(); - assertThat(trustResolver.isAnonymous(new TestingAuthenticationToken("ignored", - "ignored", AuthorityUtils.createAuthorityList("ignored")))).isFalse(); + assertThat(trustResolver.isRememberMe( + new RememberMeAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("ignored")))) + .isTrue(); + assertThat(trustResolver.isAnonymous( + new TestingAuthenticationToken("ignored", "ignored", AuthorityUtils.createAuthorityList("ignored")))) + .isFalse(); } @Test public void testGettersSetters() { AuthenticationTrustResolverImpl trustResolver = new AuthenticationTrustResolverImpl(); - - assertThat(AnonymousAuthenticationToken.class).isEqualTo( - trustResolver.getAnonymousClass()); + assertThat(AnonymousAuthenticationToken.class).isEqualTo(trustResolver.getAnonymousClass()); trustResolver.setAnonymousClass(TestingAuthenticationToken.class); - assertThat(trustResolver.getAnonymousClass()).isEqualTo( - TestingAuthenticationToken.class); - - assertThat(RememberMeAuthenticationToken.class).isEqualTo( - trustResolver.getRememberMeClass()); + assertThat(trustResolver.getAnonymousClass()).isEqualTo(TestingAuthenticationToken.class); + assertThat(RememberMeAuthenticationToken.class).isEqualTo(trustResolver.getRememberMeClass()); trustResolver.setRememberMeClass(TestingAuthenticationToken.class); - assertThat(trustResolver.getRememberMeClass()).isEqualTo( - TestingAuthenticationToken.class); + assertThat(trustResolver.getRememberMeClass()).isEqualTo(TestingAuthenticationToken.class); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisherTests.java b/core/src/test/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisherTests.java index a62ba0acef..6a7ac10f3e 100644 --- a/core/src/test/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisherTests.java +++ b/core/src/test/java/org/springframework/security/authentication/DefaultAuthenticationEventPublisherTests.java @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; -import static org.mockito.Mockito.*; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.junit.Test; -import org.junit.*; import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.authentication.event.AbstractAuthenticationFailureEvent; import org.springframework.security.authentication.event.AuthenticationFailureBadCredentialsEvent; @@ -28,194 +32,182 @@ import org.springframework.security.authentication.event.AuthenticationFailureLo import org.springframework.security.authentication.event.AuthenticationFailureProviderNotFoundEvent; import org.springframework.security.authentication.event.AuthenticationFailureServiceExceptionEvent; import org.springframework.security.authentication.event.AuthenticationSuccessEvent; -import org.springframework.security.authentication.event.AbstractAuthenticationFailureEvent; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; -import java.util.*; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; /** * @author Luke Taylor */ public class DefaultAuthenticationEventPublisherTests { + DefaultAuthenticationEventPublisher publisher; @Test public void expectedDefaultMappingsAreSatisfied() { - publisher = new DefaultAuthenticationEventPublisher(); + this.publisher = new DefaultAuthenticationEventPublisher(); ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - publisher.setApplicationEventPublisher(appPublisher); + this.publisher.setApplicationEventPublisher(appPublisher); Authentication a = mock(Authentication.class); - Exception cause = new Exception(); Object extraInfo = new Object(); - publisher.publishAuthenticationFailure(new BadCredentialsException(""), a); - publisher.publishAuthenticationFailure(new BadCredentialsException("", cause), a); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureBadCredentialsEvent.class)); + this.publisher.publishAuthenticationFailure(new BadCredentialsException(""), a); + this.publisher.publishAuthenticationFailure(new BadCredentialsException("", cause), a); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureBadCredentialsEvent.class)); reset(appPublisher); - publisher.publishAuthenticationFailure(new UsernameNotFoundException(""), a); - publisher.publishAuthenticationFailure(new UsernameNotFoundException("", cause), - a); - publisher.publishAuthenticationFailure(new AccountExpiredException(""), a); - publisher.publishAuthenticationFailure(new AccountExpiredException("", cause), a); - publisher.publishAuthenticationFailure(new ProviderNotFoundException(""), a); - publisher.publishAuthenticationFailure(new DisabledException(""), a); - publisher.publishAuthenticationFailure(new DisabledException("", cause), a); - publisher.publishAuthenticationFailure(new LockedException(""), a); - publisher.publishAuthenticationFailure(new LockedException("", cause), a); - publisher.publishAuthenticationFailure(new AuthenticationServiceException(""), a); - publisher.publishAuthenticationFailure(new AuthenticationServiceException("", - cause), a); - publisher.publishAuthenticationFailure(new CredentialsExpiredException(""), a); - publisher.publishAuthenticationFailure( - new CredentialsExpiredException("", cause), a); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureBadCredentialsEvent.class)); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureExpiredEvent.class)); - verify(appPublisher).publishEvent( - isA(AuthenticationFailureProviderNotFoundEvent.class)); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureDisabledEvent.class)); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureLockedEvent.class)); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureServiceExceptionEvent.class)); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureCredentialsExpiredEvent.class)); + this.publisher.publishAuthenticationFailure(new UsernameNotFoundException(""), a); + this.publisher.publishAuthenticationFailure(new UsernameNotFoundException("", cause), a); + this.publisher.publishAuthenticationFailure(new AccountExpiredException(""), a); + this.publisher.publishAuthenticationFailure(new AccountExpiredException("", cause), a); + this.publisher.publishAuthenticationFailure(new ProviderNotFoundException(""), a); + this.publisher.publishAuthenticationFailure(new DisabledException(""), a); + this.publisher.publishAuthenticationFailure(new DisabledException("", cause), a); + this.publisher.publishAuthenticationFailure(new LockedException(""), a); + this.publisher.publishAuthenticationFailure(new LockedException("", cause), a); + this.publisher.publishAuthenticationFailure(new AuthenticationServiceException(""), a); + this.publisher.publishAuthenticationFailure(new AuthenticationServiceException("", cause), a); + this.publisher.publishAuthenticationFailure(new CredentialsExpiredException(""), a); + this.publisher.publishAuthenticationFailure(new CredentialsExpiredException("", cause), a); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureBadCredentialsEvent.class)); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureExpiredEvent.class)); + verify(appPublisher).publishEvent(isA(AuthenticationFailureProviderNotFoundEvent.class)); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureDisabledEvent.class)); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureLockedEvent.class)); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureServiceExceptionEvent.class)); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureCredentialsExpiredEvent.class)); verifyNoMoreInteractions(appPublisher); } @Test public void authenticationSuccessIsPublished() { - publisher = new DefaultAuthenticationEventPublisher(); + this.publisher = new DefaultAuthenticationEventPublisher(); ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - publisher.setApplicationEventPublisher(appPublisher); - publisher.publishAuthenticationSuccess(mock(Authentication.class)); + this.publisher.setApplicationEventPublisher(appPublisher); + this.publisher.publishAuthenticationSuccess(mock(Authentication.class)); verify(appPublisher).publishEvent(isA(AuthenticationSuccessEvent.class)); - - publisher.setApplicationEventPublisher(null); + this.publisher.setApplicationEventPublisher(null); // Should be ignored with null app publisher - publisher.publishAuthenticationSuccess(mock(Authentication.class)); + this.publisher.publishAuthenticationSuccess(mock(Authentication.class)); } @Test public void additionalExceptionMappingsAreSupported() { - publisher = new DefaultAuthenticationEventPublisher(); + this.publisher = new DefaultAuthenticationEventPublisher(); Properties p = new Properties(); - p.put(MockAuthenticationException.class.getName(), - AuthenticationFailureDisabledEvent.class.getName()); - publisher.setAdditionalExceptionMappings(p); + p.put(MockAuthenticationException.class.getName(), AuthenticationFailureDisabledEvent.class.getName()); + this.publisher.setAdditionalExceptionMappings(p); ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - - publisher.setApplicationEventPublisher(appPublisher); - publisher.publishAuthenticationFailure(new MockAuthenticationException("test"), + this.publisher.setApplicationEventPublisher(appPublisher); + this.publisher.publishAuthenticationFailure(new MockAuthenticationException("test"), mock(Authentication.class)); verify(appPublisher).publishEvent(isA(AuthenticationFailureDisabledEvent.class)); } @Test(expected = RuntimeException.class) public void missingEventClassExceptionCausesException() { - publisher = new DefaultAuthenticationEventPublisher(); + this.publisher = new DefaultAuthenticationEventPublisher(); Properties p = new Properties(); p.put(MockAuthenticationException.class.getName(), "NoSuchClass"); - publisher.setAdditionalExceptionMappings(p); + this.publisher.setAdditionalExceptionMappings(p); } @Test public void unknownFailureExceptionIsIgnored() { - publisher = new DefaultAuthenticationEventPublisher(); + this.publisher = new DefaultAuthenticationEventPublisher(); Properties p = new Properties(); - p.put(MockAuthenticationException.class.getName(), - AuthenticationFailureDisabledEvent.class.getName()); - publisher.setAdditionalExceptionMappings(p); + p.put(MockAuthenticationException.class.getName(), AuthenticationFailureDisabledEvent.class.getName()); + this.publisher.setAdditionalExceptionMappings(p); ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - - publisher.setApplicationEventPublisher(appPublisher); - publisher.publishAuthenticationFailure(new AuthenticationException("") { + this.publisher.setApplicationEventPublisher(appPublisher); + this.publisher.publishAuthenticationFailure(new AuthenticationException("") { }, mock(Authentication.class)); verifyZeroInteractions(appPublisher); } @Test(expected = IllegalArgumentException.class) public void emptyMapCausesException() { - Map, - Class> mappings = new HashMap<>(); - publisher = new DefaultAuthenticationEventPublisher(); - publisher.setAdditionalExceptionMappings(mappings); + Map, Class> mappings = new HashMap<>(); + this.publisher = new DefaultAuthenticationEventPublisher(); + this.publisher.setAdditionalExceptionMappings(mappings); } @Test(expected = IllegalArgumentException.class) public void missingExceptionClassCausesException() { - Map, - Class> mappings = new HashMap<>(); + Map, Class> mappings = new HashMap<>(); mappings.put(null, AuthenticationFailureLockedEvent.class); - publisher = new DefaultAuthenticationEventPublisher(); - publisher.setAdditionalExceptionMappings(mappings); + this.publisher = new DefaultAuthenticationEventPublisher(); + this.publisher.setAdditionalExceptionMappings(mappings); } @Test(expected = IllegalArgumentException.class) public void missingEventClassAsMapValueCausesException() { - Map, - Class> mappings = new HashMap<>(); + Map, Class> mappings = new HashMap<>(); mappings.put(LockedException.class, null); - publisher = new DefaultAuthenticationEventPublisher(); - publisher.setAdditionalExceptionMappings(mappings); + this.publisher = new DefaultAuthenticationEventPublisher(); + this.publisher.setAdditionalExceptionMappings(mappings); } @Test public void additionalExceptionMappingsUsingMapAreSupported() { - publisher = new DefaultAuthenticationEventPublisher(); - Map, - Class> mappings = new HashMap<>(); + this.publisher = new DefaultAuthenticationEventPublisher(); + Map, Class> mappings = new HashMap<>(); mappings.put(MockAuthenticationException.class, AuthenticationFailureDisabledEvent.class); - publisher.setAdditionalExceptionMappings(mappings); + this.publisher.setAdditionalExceptionMappings(mappings); ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - - publisher.setApplicationEventPublisher(appPublisher); - publisher.publishAuthenticationFailure(new MockAuthenticationException("test"), + this.publisher.setApplicationEventPublisher(appPublisher); + this.publisher.publishAuthenticationFailure(new MockAuthenticationException("test"), mock(Authentication.class)); verify(appPublisher).publishEvent(isA(AuthenticationFailureDisabledEvent.class)); } @Test(expected = IllegalArgumentException.class) public void defaultAuthenticationFailureEventClassSetNullThen() { - publisher = new DefaultAuthenticationEventPublisher(); - publisher.setDefaultAuthenticationFailureEvent(null); + this.publisher = new DefaultAuthenticationEventPublisher(); + this.publisher.setDefaultAuthenticationFailureEvent(null); } @Test public void defaultAuthenticationFailureEventIsPublished() { - publisher = new DefaultAuthenticationEventPublisher(); - publisher.setDefaultAuthenticationFailureEvent(AuthenticationFailureBadCredentialsEvent.class); + this.publisher = new DefaultAuthenticationEventPublisher(); + this.publisher.setDefaultAuthenticationFailureEvent(AuthenticationFailureBadCredentialsEvent.class); ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - - publisher.setApplicationEventPublisher(appPublisher); - publisher.publishAuthenticationFailure(new AuthenticationException("") { + this.publisher.setApplicationEventPublisher(appPublisher); + this.publisher.publishAuthenticationFailure(new AuthenticationException("") { }, mock(Authentication.class)); verify(appPublisher).publishEvent(isA(AuthenticationFailureBadCredentialsEvent.class)); } @Test(expected = RuntimeException.class) public void defaultAuthenticationFailureEventMissingAppropriateConstructorThen() { - publisher = new DefaultAuthenticationEventPublisher(); - publisher.setDefaultAuthenticationFailureEvent(AuthenticationFailureEventWithoutAppropriateConstructor.class); + this.publisher = new DefaultAuthenticationEventPublisher(); + this.publisher + .setDefaultAuthenticationFailureEvent(AuthenticationFailureEventWithoutAppropriateConstructor.class); } - private static final class AuthenticationFailureEventWithoutAppropriateConstructor extends - AbstractAuthenticationFailureEvent { + private static final class AuthenticationFailureEventWithoutAppropriateConstructor + extends AbstractAuthenticationFailureEvent { + AuthenticationFailureEventWithoutAppropriateConstructor(Authentication auth) { - super(auth, new AuthenticationException("") {}); + super(auth, new AuthenticationException("") { + }); } + } - private static final class MockAuthenticationException extends - AuthenticationException { + private static final class MockAuthenticationException extends AuthenticationException { + MockAuthenticationException(String msg) { super(msg); } + } } diff --git a/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java b/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java index a9d89d7ac1..71e73e044c 100644 --- a/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/DelegatingReactiveAuthenticationManagerTests.java @@ -16,19 +16,20 @@ package org.springframework.security.authentication; +import java.time.Duration; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.security.core.Authentication; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import java.time.Duration; +import org.springframework.security.core.Authentication; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -36,6 +37,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class DelegatingReactiveAuthenticationManagerTests { + @Mock ReactiveAuthenticationManager delegate1; @@ -47,34 +49,31 @@ public class DelegatingReactiveAuthenticationManagerTests { @Test public void authenticateWhenEmptyAndNotThenReturnsNotEmpty() { - when(this.delegate1.authenticate(any())).thenReturn(Mono.empty()); - when(this.delegate2.authenticate(any())).thenReturn(Mono.just(this.authentication)); - - DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, this.delegate2); - + given(this.delegate1.authenticate(any())).willReturn(Mono.empty()); + given(this.delegate2.authenticate(any())).willReturn(Mono.just(this.authentication)); + DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, + this.delegate2); assertThat(manager.authenticate(this.authentication).block()).isEqualTo(this.authentication); } @Test public void authenticateWhenNotEmptyThenOtherDelegatesNotSubscribed() { - // delay to try and force delegate2 to finish (i.e. make sure we didn't use flatMap) - when(this.delegate1.authenticate(any())).thenReturn(Mono.just(this.authentication).delayElement(Duration.ofMillis(100))); - - DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, this.delegate2); - - StepVerifier.create(manager.authenticate(this.authentication)) - .expectNext(this.authentication) - .verifyComplete(); + // delay to try and force delegate2 to finish (i.e. make sure we didn't use + // flatMap) + given(this.delegate1.authenticate(any())) + .willReturn(Mono.just(this.authentication).delayElement(Duration.ofMillis(100))); + DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, + this.delegate2); + StepVerifier.create(manager.authenticate(this.authentication)).expectNext(this.authentication).verifyComplete(); } @Test public void authenticateWhenBadCredentialsThenDelegate2NotInvokedAndError() { - when(this.delegate1.authenticate(any())).thenReturn(Mono.error(new BadCredentialsException("Test"))); - - DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, this.delegate2); - - StepVerifier.create(manager.authenticate(this.authentication)) - .expectError(BadCredentialsException.class) - .verify(); + given(this.delegate1.authenticate(any())).willReturn(Mono.error(new BadCredentialsException("Test"))); + DelegatingReactiveAuthenticationManager manager = new DelegatingReactiveAuthenticationManager(this.delegate1, + this.delegate2); + StepVerifier.create(manager.authenticate(this.authentication)).expectError(BadCredentialsException.class) + .verify(); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java b/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java index a28c1c4e41..b75f9dcf8c 100644 --- a/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/ProviderManagerTests.java @@ -29,14 +29,14 @@ import org.springframework.security.core.AuthenticationException; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; /** * Tests {@link ProviderManager}. @@ -48,10 +48,12 @@ public class ProviderManagerTests { @Test(expected = ProviderNotFoundException.class) public void authenticationFailsWithUnsupportedToken() { Authentication token = new AbstractAuthenticationToken(null) { + @Override public Object getCredentials() { return ""; } + @Override public Object getPrincipal() { return ""; } @@ -63,12 +65,10 @@ public class ProviderManagerTests { @Test public void credentialsAreClearedByDefault() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password"); ProviderManager mgr = makeProviderManager(); Authentication result = mgr.authenticate(token); assertThat(result.getCredentials()).isNull(); - mgr.setEraseCredentialsAfterAuthentication(false); token = new UsernamePasswordAuthenticationToken("Test", "Password"); result = mgr.authenticate(token); @@ -81,7 +81,6 @@ public class ProviderManagerTests { ProviderManager mgr = new ProviderManager(createProviderWhichReturns(a)); AuthenticationEventPublisher publisher = mock(AuthenticationEventPublisher.class); mgr.setAuthenticationEventPublisher(publisher); - Authentication result = mgr.authenticate(a); assertThat(result).isEqualTo(a); verify(publisher).publishAuthenticationSuccess(result); @@ -90,11 +89,10 @@ public class ProviderManagerTests { @Test public void authenticationSucceedsWhenFirstProviderReturnsNullButSecondAuthenticates() { final Authentication a = mock(Authentication.class); - ProviderManager mgr = new ProviderManager(Arrays.asList( - createProviderWhichReturns(null), createProviderWhichReturns(a))); + ProviderManager mgr = new ProviderManager( + Arrays.asList(createProviderWhichReturns(null), createProviderWhichReturns(a))); AuthenticationEventPublisher publisher = mock(AuthenticationEventPublisher.class); mgr.setAuthenticationEventPublisher(publisher); - Authentication result = mgr.authenticate(a); assertThat(result).isSameAs(a); verify(publisher).publishAuthenticationSuccess(result); @@ -120,7 +118,7 @@ public class ProviderManagerTests { public void constructorWhenUsingListOfThenNoException() { List providers = spy(ArrayList.class); // List.of(null) in JDK 9 throws a NullPointerException - when(providers.contains(eq(null))).thenThrow(NullPointerException.class); + given(providers.contains(eq(null))).willThrow(NullPointerException.class); providers.add(mock(AuthenticationProvider.class)); new ProviderManager(providers); } @@ -129,25 +127,22 @@ public class ProviderManagerTests { public void detailsAreNotSetOnAuthenticationTokenIfAlreadySetByProvider() { Object requestDetails = "(Request Details)"; final Object resultDetails = "(Result Details)"; - // A provider which sets the details object AuthenticationProvider provider = new AuthenticationProvider() { - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { ((TestingAuthenticationToken) authentication).setDetails(resultDetails); return authentication; } + @Override public boolean supports(Class authentication) { return true; } }; - ProviderManager authMgr = new ProviderManager(provider); - TestingAuthenticationToken request = createAuthenticationToken(); request.setDetails(requestDetails); - Authentication result = authMgr.authenticate(request); assertThat(result.getDetails()).isEqualTo(resultDetails); } @@ -156,10 +151,8 @@ public class ProviderManagerTests { public void detailsAreSetOnAuthenticationTokenIfNotAlreadySetByProvider() { Object details = new Object(); ProviderManager authMgr = makeProviderManager(); - TestingAuthenticationToken request = createAuthenticationToken(); request.setDetails(details); - Authentication result = authMgr.authenticate(request); assertThat(result.getCredentials()).isNotNull(); assertThat(result.getDetails()).isSameAs(details); @@ -169,17 +162,15 @@ public class ProviderManagerTests { public void authenticationExceptionIsIgnoredIfLaterProviderAuthenticates() { final Authentication authReq = mock(Authentication.class); ProviderManager mgr = new ProviderManager( - createProviderWhichThrows(new BadCredentialsException("", - new Throwable())), createProviderWhichReturns(authReq)); + createProviderWhichThrows(new BadCredentialsException("", new Throwable())), + createProviderWhichReturns(authReq)); assertThat(mgr.authenticate(mock(Authentication.class))).isSameAs(authReq); } @Test public void authenticationExceptionIsRethrownIfNoLaterProviderAuthenticates() { - - ProviderManager mgr = new ProviderManager(Arrays.asList( - createProviderWhichThrows(new BadCredentialsException("")), - createProviderWhichReturns(null))); + ProviderManager mgr = new ProviderManager(Arrays + .asList(createProviderWhichThrows(new BadCredentialsException("")), createProviderWhichReturns(null))); try { mgr.authenticate(mock(Authentication.class)); fail("Expected BadCredentialsException"); @@ -191,14 +182,10 @@ public class ProviderManagerTests { // SEC-546 @Test public void accountStatusExceptionPreventsCallsToSubsequentProviders() { - AuthenticationProvider iThrowAccountStatusException = createProviderWhichThrows(new AccountStatusException( - "") { + AuthenticationProvider iThrowAccountStatusException = createProviderWhichThrows(new AccountStatusException("") { }); AuthenticationProvider otherProvider = mock(AuthenticationProvider.class); - - ProviderManager authMgr = new ProviderManager(Arrays.asList( - iThrowAccountStatusException, otherProvider)); - + ProviderManager authMgr = new ProviderManager(Arrays.asList(iThrowAccountStatusException, otherProvider)); try { authMgr.authenticate(mock(Authentication.class)); fail("Expected AccountStatusException"); @@ -212,20 +199,19 @@ public class ProviderManagerTests { public void parentAuthenticationIsUsedIfProvidersDontAuthenticate() { AuthenticationManager parent = mock(AuthenticationManager.class); Authentication authReq = mock(Authentication.class); - when(parent.authenticate(authReq)).thenReturn(authReq); - ProviderManager mgr = new ProviderManager( - Collections.singletonList(mock(AuthenticationProvider.class)), parent); + given(parent.authenticate(authReq)).willReturn(authReq); + ProviderManager mgr = new ProviderManager(Collections.singletonList(mock(AuthenticationProvider.class)), + parent); assertThat(mgr.authenticate(authReq)).isSameAs(authReq); } @Test public void parentIsNotCalledIfAccountStatusExceptionIsThrown() { - AuthenticationProvider iThrowAccountStatusException = createProviderWhichThrows(new AccountStatusException( - "", new Throwable()) { - }); + AuthenticationProvider iThrowAccountStatusException = createProviderWhichThrows( + new AccountStatusException("", new Throwable()) { + }); AuthenticationManager parent = mock(AuthenticationManager.class); - ProviderManager mgr = new ProviderManager( - Collections.singletonList(iThrowAccountStatusException), parent); + ProviderManager mgr = new ProviderManager(Collections.singletonList(iThrowAccountStatusException), parent); try { mgr.authenticate(mock(Authentication.class)); fail("Expected exception"); @@ -240,15 +226,12 @@ public class ProviderManagerTests { final Authentication authReq = mock(Authentication.class); AuthenticationEventPublisher publisher = mock(AuthenticationEventPublisher.class); AuthenticationManager parent = mock(AuthenticationManager.class); - when(parent.authenticate(authReq)).thenThrow(new ProviderNotFoundException("")); - + given(parent.authenticate(authReq)).willThrow(new ProviderNotFoundException("")); // Set a provider that throws an exception - this is the exception we expect to be // propagated ProviderManager mgr = new ProviderManager( - Collections.singletonList(createProviderWhichThrows(new BadCredentialsException(""))), - parent); + Collections.singletonList(createProviderWhichThrows(new BadCredentialsException(""))), parent); mgr.setAuthenticationEventPublisher(publisher); - try { mgr.authenticate(authReq); fail("Expected exception"); @@ -262,22 +245,20 @@ public class ProviderManagerTests { public void authenticationExceptionFromParentOverridesPreviousOnes() { AuthenticationManager parent = mock(AuthenticationManager.class); ProviderManager mgr = new ProviderManager( - Collections.singletonList(createProviderWhichThrows(new BadCredentialsException(""))), - parent); + Collections.singletonList(createProviderWhichThrows(new BadCredentialsException(""))), parent); final Authentication authReq = mock(Authentication.class); AuthenticationEventPublisher publisher = mock(AuthenticationEventPublisher.class); mgr.setAuthenticationEventPublisher(publisher); // Set a provider that throws an exception - this is the exception we expect to be // propagated - final BadCredentialsException expected = new BadCredentialsException( - "I'm the one from the parent"); - when(parent.authenticate(authReq)).thenThrow(expected); + final BadCredentialsException expected = new BadCredentialsException("I'm the one from the parent"); + given(parent.authenticate(authReq)).willThrow(expected); try { mgr.authenticate(authReq); fail("Expected exception"); } - catch (BadCredentialsException e) { - assertThat(e).isSameAs(expected); + catch (BadCredentialsException ex) { + assertThat(ex).isSameAs(expected); } } @@ -285,8 +266,8 @@ public class ProviderManagerTests { public void statusExceptionIsPublished() { AuthenticationManager parent = mock(AuthenticationManager.class); final LockedException expected = new LockedException(""); - ProviderManager mgr = new ProviderManager( - Collections.singletonList(createProviderWhichThrows(expected)), parent); + ProviderManager mgr = new ProviderManager(Collections.singletonList(createProviderWhichThrows(expected)), + parent); final Authentication authReq = mock(Authentication.class); AuthenticationEventPublisher publisher = mock(AuthenticationEventPublisher.class); mgr.setAuthenticationEventPublisher(publisher); @@ -294,8 +275,8 @@ public class ProviderManagerTests { mgr.authenticate(authReq); fail("Expected exception"); } - catch (LockedException e) { - assertThat(e).isSameAs(expected); + catch (LockedException ex) { + assertThat(ex).isSameAs(expected); } verify(publisher).publishAuthenticationFailure(expected, authReq); } @@ -303,13 +284,10 @@ public class ProviderManagerTests { // SEC-2367 @Test public void providerThrowsInternalAuthenticationServiceException() { - InternalAuthenticationServiceException expected = new InternalAuthenticationServiceException( - "Expected"); - ProviderManager mgr = new ProviderManager(Arrays.asList( - createProviderWhichThrows(expected), + InternalAuthenticationServiceException expected = new InternalAuthenticationServiceException("Expected"); + ProviderManager mgr = new ProviderManager(Arrays.asList(createProviderWhichThrows(expected), createProviderWhichThrows(new BadCredentialsException("Oops"))), null); final Authentication authReq = mock(Authentication.class); - try { mgr.authenticate(authReq); fail("Expected Exception"); @@ -323,46 +301,40 @@ public class ProviderManagerTests { public void authenticateWhenFailsInParentAndPublishesThenChildDoesNotPublish() { BadCredentialsException badCredentialsExParent = new BadCredentialsException("Bad Credentials in parent"); ProviderManager parentMgr = new ProviderManager(createProviderWhichThrows(badCredentialsExParent)); - ProviderManager childMgr = new ProviderManager(Collections.singletonList(createProviderWhichThrows( - new BadCredentialsException("Bad Credentials in child"))), parentMgr); - + ProviderManager childMgr = new ProviderManager(Collections.singletonList( + createProviderWhichThrows(new BadCredentialsException("Bad Credentials in child"))), parentMgr); AuthenticationEventPublisher publisher = mock(AuthenticationEventPublisher.class); parentMgr.setAuthenticationEventPublisher(publisher); childMgr.setAuthenticationEventPublisher(publisher); - final Authentication authReq = mock(Authentication.class); - try { childMgr.authenticate(authReq); fail("Expected exception"); } - catch (BadCredentialsException e) { - assertThat(e).isSameAs(badCredentialsExParent); + catch (BadCredentialsException ex) { + assertThat(ex).isSameAs(badCredentialsExParent); } - verify(publisher).publishAuthenticationFailure(badCredentialsExParent, authReq); // Parent publishes - verifyNoMoreInteractions(publisher); // Child should not publish (duplicate event) + verify(publisher).publishAuthenticationFailure(badCredentialsExParent, authReq); // Parent + // publishes + verifyNoMoreInteractions(publisher); // Child should not publish (duplicate event) } - private AuthenticationProvider createProviderWhichThrows( - final AuthenticationException e) { + private AuthenticationProvider createProviderWhichThrows(final AuthenticationException ex) { AuthenticationProvider provider = mock(AuthenticationProvider.class); - when(provider.supports(any(Class.class))).thenReturn(true); - when(provider.authenticate(any(Authentication.class))).thenThrow(e); - + given(provider.supports(any(Class.class))).willReturn(true); + given(provider.authenticate(any(Authentication.class))).willThrow(ex); return provider; } private AuthenticationProvider createProviderWhichReturns(final Authentication a) { AuthenticationProvider provider = mock(AuthenticationProvider.class); - when(provider.supports(any(Class.class))).thenReturn(true); - when(provider.authenticate(any(Authentication.class))).thenReturn(a); - + given(provider.supports(any(Class.class))).willReturn(true); + given(provider.authenticate(any(Authentication.class))).willReturn(a); return provider; } private TestingAuthenticationToken createAuthenticationToken() { - return new TestingAuthenticationToken("name", "password", - new ArrayList<>(0)); + return new TestingAuthenticationToken("name", "password", new ArrayList<>(0)); } private ProviderManager makeProviderManager() { @@ -370,12 +342,10 @@ public class ProviderManagerTests { return new ProviderManager(provider); } - // ~ Inner Classes - // ================================================================================================== - private static class MockProvider implements AuthenticationProvider { - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { if (supports(authentication.getClass())) { return authentication; } @@ -384,10 +354,12 @@ public class ProviderManagerTests { } } + @Override public boolean supports(Class authentication) { return TestingAuthenticationToken.class.isAssignableFrom(authentication) - || UsernamePasswordAuthenticationToken.class - .isAssignableFrom(authentication); + || UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication); } + } + } diff --git a/core/src/test/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapterTests.java b/core/src/test/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapterTests.java index 9f7390ab3d..7dc48107fc 100644 --- a/core/src/test/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapterTests.java +++ b/core/src/test/java/org/springframework/security/authentication/ReactiveAuthenticationManagerAdapterTests.java @@ -21,13 +21,14 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.security.core.Authentication; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.security.core.Authentication; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -35,8 +36,10 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class ReactiveAuthenticationManagerAdapterTests { + @Mock AuthenticationManager delegate; + @Mock Authentication authentication; @@ -44,7 +47,7 @@ public class ReactiveAuthenticationManagerAdapterTests { @Before public void setup() { - manager = new ReactiveAuthenticationManagerAdapter(delegate); + this.manager = new ReactiveAuthenticationManagerAdapter(this.delegate); } @Test(expected = IllegalArgumentException.class) @@ -59,31 +62,28 @@ public class ReactiveAuthenticationManagerAdapterTests { @Test public void authenticateWhenSuccessThenSuccess() { - when(delegate.authenticate(any())).thenReturn(authentication); - when(authentication.isAuthenticated()).thenReturn(true); - - Authentication result = manager.authenticate(authentication).block(); - - assertThat(result).isEqualTo(authentication); + given(this.delegate.authenticate(any())).willReturn(this.authentication); + given(this.authentication.isAuthenticated()).willReturn(true); + Authentication result = this.manager.authenticate(this.authentication).block(); + assertThat(result).isEqualTo(this.authentication); } @Test public void authenticateWhenReturnNotAuthenticatedThenError() { - when(delegate.authenticate(any())).thenReturn(authentication); - - Authentication result = manager.authenticate(authentication).block(); - + given(this.delegate.authenticate(any())).willReturn(this.authentication); + Authentication result = this.manager.authenticate(this.authentication).block(); assertThat(result).isNull(); } @Test public void authenticateWhenBadCredentialsThenError() { - when(delegate.authenticate(any())).thenThrow(new BadCredentialsException("Failed")); - - Mono result = manager.authenticate(authentication); - + given(this.delegate.authenticate(any())).willThrow(new BadCredentialsException("Failed")); + Mono result = this.manager.authenticate(this.authentication); + // @formatter:off StepVerifier.create(result) - .expectError(BadCredentialsException.class) - .verify(); + .expectError(BadCredentialsException.class) + .verify(); + // @formatter:on } + } diff --git a/core/src/test/java/org/springframework/security/authentication/ReactiveUserDetailsServiceAuthenticationManagerTests.java b/core/src/test/java/org/springframework/security/authentication/ReactiveUserDetailsServiceAuthenticationManagerTests.java index 73011471b0..7aaac05a44 100644 --- a/core/src/test/java/org/springframework/security/authentication/ReactiveUserDetailsServiceAuthenticationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/ReactiveUserDetailsServiceAuthenticationManagerTests.java @@ -13,27 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.authentication; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +package org.springframework.security.authentication; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.PasswordEncodedUser; -import org.springframework.security.core.userdetails.User; - import org.springframework.security.core.userdetails.ReactiveUserDetailsService; +import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.crypto.password.PasswordEncoder; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -41,18 +42,24 @@ import reactor.test.StepVerifier; */ @RunWith(MockitoJUnitRunner.class) public class ReactiveUserDetailsServiceAuthenticationManagerTests { - @Mock ReactiveUserDetailsService repository; + + @Mock + ReactiveUserDetailsService repository; + @Mock PasswordEncoder passwordEncoder; + UserDetailsRepositoryReactiveAuthenticationManager manager; + String username; + String password; @Before public void setup() { - manager = new UserDetailsRepositoryReactiveAuthenticationManager(repository); - username = "user"; - password = "pass"; + this.manager = new UserDetailsRepositoryReactiveAuthenticationManager(this.repository); + this.username = "user"; + this.password = "pass"; } @Test(expected = IllegalArgumentException.class) @@ -63,77 +70,77 @@ public class ReactiveUserDetailsServiceAuthenticationManagerTests { @Test public void authenticateWhenUserNotFoundThenBadCredentials() { - when(repository.findByUsername(username)).thenReturn(Mono.empty()); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(username, password); - Mono authentication = manager.authenticate(token); - - StepVerifier - .create(authentication) - .expectError(BadCredentialsException.class) - .verify(); + given(this.repository.findByUsername(this.username)).willReturn(Mono.empty()); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.username, + this.password); + Mono authentication = this.manager.authenticate(token); + // @formatter:off + StepVerifier.create(authentication) + .expectError(BadCredentialsException.class) + .verify(); + // @formatter:on } @Test public void authenticateWhenPasswordNotEqualThenBadCredentials() { + // @formatter:off UserDetails user = PasswordEncodedUser.withUsername(this.username) .password(this.password) .roles("USER") .build(); - when(repository.findByUsername(user.getUsername())).thenReturn(Mono.just(user)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(username, this.password + "INVALID"); - Mono authentication = manager.authenticate(token); - - StepVerifier - .create(authentication) - .expectError(BadCredentialsException.class) - .verify(); + // @formatter:on + given(this.repository.findByUsername(user.getUsername())).willReturn(Mono.just(user)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.username, + this.password + "INVALID"); + Mono authentication = this.manager.authenticate(token); + // @formatter:off + StepVerifier.create(authentication) + .expectError(BadCredentialsException.class) + .verify(); + // @formatter:on } @Test public void authenticateWhenSuccessThenSuccess() { + // @formatter:off UserDetails user = PasswordEncodedUser.withUsername(this.username) .password(this.password) .roles("USER") .build(); - when(repository.findByUsername(user.getUsername())).thenReturn(Mono.just(user)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(username, password); - Authentication authentication = manager.authenticate(token).block(); - + // @formatter:on + given(this.repository.findByUsername(user.getUsername())).willReturn(Mono.just(user)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.username, + this.password); + Authentication authentication = this.manager.authenticate(token).block(); assertThat(authentication).isEqualTo(authentication); } @Test public void authenticateWhenPasswordEncoderAndSuccessThenSuccess() { this.manager.setPasswordEncoder(this.passwordEncoder); - when(this.passwordEncoder.matches(any(), any())).thenReturn(true); + given(this.passwordEncoder.matches(any(), any())).willReturn(true); User user = new User(this.username, this.password, AuthorityUtils.createAuthorityList("ROLE_USER")); - when(this.repository.findByUsername(user.getUsername())).thenReturn(Mono.just(user)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.username, this.password); + given(this.repository.findByUsername(user.getUsername())).willReturn(Mono.just(user)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.username, + this.password); Authentication authentication = this.manager.authenticate(token).block(); - assertThat(authentication).isEqualTo(authentication); } @Test public void authenticateWhenPasswordEncoderAndFailThenFail() { this.manager.setPasswordEncoder(this.passwordEncoder); - when(this.passwordEncoder.matches(any(), any())).thenReturn(false); + given(this.passwordEncoder.matches(any(), any())).willReturn(false); User user = new User(this.username, this.password, AuthorityUtils.createAuthorityList("ROLE_USER")); - when(this.repository.findByUsername(user.getUsername())).thenReturn(Mono.just(user)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.username, this.password); - + given(this.repository.findByUsername(user.getUsername())).willReturn(Mono.just(user)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.username, + this.password); Mono authentication = this.manager.authenticate(token); - - StepVerifier - .create(authentication) - .expectError(BadCredentialsException.class) - .verify(); + // @formatter:off + StepVerifier.create(authentication) + .expectError(BadCredentialsException.class) + .verify(); + // @formatter:on } + } diff --git a/core/src/test/java/org/springframework/security/authentication/TestAuthentication.java b/core/src/test/java/org/springframework/security/authentication/TestAuthentication.java index ade69f9ca3..0583c42a48 100644 --- a/core/src/test/java/org/springframework/security/authentication/TestAuthentication.java +++ b/core/src/test/java/org/springframework/security/authentication/TestAuthentication.java @@ -37,4 +37,5 @@ public class TestAuthentication extends PasswordEncodedUser { public static Authentication autheticated(UserDetails user) { return new UsernamePasswordAuthenticationToken(user, null, user.getAuthorities()); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationProviderTests.java index 875a8bbf3b..1435300fc8 100644 --- a/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationProviderTests.java @@ -16,12 +16,13 @@ package org.springframework.security.authentication; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link TestingAuthenticationProvider}. * @@ -32,18 +33,13 @@ public class TestingAuthenticationProviderTests { @Test public void testAuthenticates() { TestingAuthenticationProvider provider = new TestingAuthenticationProvider(); - TestingAuthenticationToken token = new TestingAuthenticationToken("Test", - "Password", "ROLE_ONE", "ROLE_TWO"); + TestingAuthenticationToken token = new TestingAuthenticationToken("Test", "Password", "ROLE_ONE", "ROLE_TWO"); Authentication result = provider.authenticate(token); - assertThat(result instanceof TestingAuthenticationToken).isTrue(); - TestingAuthenticationToken castResult = (TestingAuthenticationToken) result; assertThat(castResult.getPrincipal()).isEqualTo("Test"); assertThat(castResult.getCredentials()).isEqualTo("Password"); - assertThat( - AuthorityUtils.authorityListToSet(castResult.getAuthorities())).contains( - "ROLE_ONE", "ROLE_TWO"); + assertThat(AuthorityUtils.authorityListToSet(castResult.getAuthorities())).contains("ROLE_ONE", "ROLE_TWO"); } @Test @@ -52,4 +48,5 @@ public class TestingAuthenticationProviderTests { assertThat(provider.supports(TestingAuthenticationToken.class)).isTrue(); assertThat(!provider.supports(String.class)).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationTokenTests.java index 473fa78945..8b7c6f4617 100644 --- a/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationTokenTests.java +++ b/core/src/test/java/org/springframework/security/authentication/TestingAuthenticationTokenTests.java @@ -13,41 +13,40 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication; -import org.junit.Test; -import org.springframework.security.core.authority.SimpleGrantedAuthority; - import java.util.Arrays; +import org.junit.Test; + +import org.springframework.security.core.authority.SimpleGrantedAuthority; + import static org.assertj.core.api.Assertions.assertThat; /** * @author Josh Cummings */ public class TestingAuthenticationTokenTests { + @Test public void constructorWhenNoAuthoritiesThenUnauthenticated() { - TestingAuthenticationToken unauthenticated = - new TestingAuthenticationToken("principal", "credentials"); - + TestingAuthenticationToken unauthenticated = new TestingAuthenticationToken("principal", "credentials"); assertThat(unauthenticated.isAuthenticated()).isFalse(); } @Test public void constructorWhenArityAuthoritiesThenAuthenticated() { - TestingAuthenticationToken authenticated = - new TestingAuthenticationToken("principal", "credentials", "authority"); - + TestingAuthenticationToken authenticated = new TestingAuthenticationToken("principal", "credentials", + "authority"); assertThat(authenticated.isAuthenticated()).isTrue(); } @Test public void constructorWhenCollectionAuthoritiesThenAuthenticated() { - TestingAuthenticationToken authenticated = - new TestingAuthenticationToken("principal", "credentials", + TestingAuthenticationToken authenticated = new TestingAuthenticationToken("principal", "credentials", Arrays.asList(new SimpleGrantedAuthority("authority"))); - assertThat(authenticated.isAuthenticated()).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java b/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java index 793edd5db3..7cd80768dc 100644 --- a/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/UserDetailsRepositoryReactiveAuthenticationManagerTests.java @@ -16,15 +16,11 @@ package org.springframework.security.authentication; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; - import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; @@ -37,6 +33,14 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsChecker; import org.springframework.security.crypto.password.PasswordEncoder; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; + /** * @author Rob Winch * @author Eddú Meléndez @@ -44,6 +48,7 @@ import org.springframework.security.crypto.password.PasswordEncoder; */ @RunWith(MockitoJUnitRunner.class) public class UserDetailsRepositoryReactiveAuthenticationManagerTests { + @Mock private ReactiveUserDetailsService userDetailsService; @@ -59,17 +64,18 @@ public class UserDetailsRepositoryReactiveAuthenticationManagerTests { @Mock private UserDetailsChecker postAuthenticationChecks; + // @formatter:off private UserDetails user = User.withUsername("user") .password("password") .roles("USER") .build(); - + // @formatter:on private UserDetailsRepositoryReactiveAuthenticationManager manager; @Before public void setup() { this.manager = new UserDetailsRepositoryReactiveAuthenticationManager(this.userDetailsService); - when(this.scheduler.schedule(any())).thenAnswer(a -> { + given(this.scheduler.schedule(any())).willAnswer((a) -> { Runnable r = a.getArgument(0); return Schedulers.immediate().schedule(r); }); @@ -77,150 +83,133 @@ public class UserDetailsRepositoryReactiveAuthenticationManagerTests { @Test public void setSchedulerWhenNullThenIllegalArgumentException() { - assertThatCode(() -> this.manager.setScheduler(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> this.manager.setScheduler(null)); } @Test public void authentiateWhenCustomSchedulerThenUsed() { - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); - when(this.encoder.matches(any(), any())).thenReturn(true); + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); + given(this.encoder.matches(any(), any())).willReturn(true); this.manager.setScheduler(this.scheduler); this.manager.setPasswordEncoder(this.encoder); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.user, this.user.getPassword()); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.user, + this.user.getPassword()); Authentication result = this.manager.authenticate(token).block(); - verify(this.scheduler).schedule(any()); } @Test public void authenticateWhenPasswordServiceThenUpdated() { String encodedPassword = "encoded"; - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); - when(this.encoder.matches(any(), any())).thenReturn(true); - when(this.encoder.upgradeEncoding(any())).thenReturn(true); - when(this.encoder.encode(any())).thenReturn(encodedPassword); - when(this.userDetailsPasswordService.updatePassword(any(), any())).thenReturn(Mono.just(this.user)); + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); + given(this.encoder.matches(any(), any())).willReturn(true); + given(this.encoder.upgradeEncoding(any())).willReturn(true); + given(this.encoder.encode(any())).willReturn(encodedPassword); + given(this.userDetailsPasswordService.updatePassword(any(), any())).willReturn(Mono.just(this.user)); this.manager.setPasswordEncoder(this.encoder); this.manager.setUserDetailsPasswordService(this.userDetailsPasswordService); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.user, this.user.getPassword()); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.user, + this.user.getPassword()); Authentication result = this.manager.authenticate(token).block(); - verify(this.encoder).encode(this.user.getPassword()); verify(this.userDetailsPasswordService).updatePassword(eq(this.user), eq(encodedPassword)); } @Test public void authenticateWhenPasswordServiceAndBadCredentialsThenNotUpdated() { - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); - when(this.encoder.matches(any(), any())).thenReturn(false); + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); + given(this.encoder.matches(any(), any())).willReturn(false); this.manager.setPasswordEncoder(this.encoder); this.manager.setUserDetailsPasswordService(this.userDetailsPasswordService); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.user, this.user.getPassword()); - - assertThatThrownBy(() -> this.manager.authenticate(token).block()) - .isInstanceOf(BadCredentialsException.class); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.user, + this.user.getPassword()); + assertThatExceptionOfType(BadCredentialsException.class) + .isThrownBy(() -> this.manager.authenticate(token).block()); verifyZeroInteractions(this.userDetailsPasswordService); } @Test public void authenticateWhenPasswordServiceAndUpgradeFalseThenNotUpdated() { - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); - when(this.encoder.matches(any(), any())).thenReturn(true); - when(this.encoder.upgradeEncoding(any())).thenReturn(false); + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); + given(this.encoder.matches(any(), any())).willReturn(true); + given(this.encoder.upgradeEncoding(any())).willReturn(false); this.manager.setPasswordEncoder(this.encoder); this.manager.setUserDetailsPasswordService(this.userDetailsPasswordService); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.user, this.user.getPassword()); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.user, + this.user.getPassword()); Authentication result = this.manager.authenticate(token).block(); - verifyZeroInteractions(this.userDetailsPasswordService); } @Test public void authenticateWhenPostAuthenticationChecksFail() { - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); - doThrow(new LockedException("account is locked")).when(this.postAuthenticationChecks).check(any()); - when(this.encoder.matches(any(), any())).thenReturn(true); + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); + willThrow(new LockedException("account is locked")).given(this.postAuthenticationChecks).check(any()); + given(this.encoder.matches(any(), any())).willReturn(true); this.manager.setPasswordEncoder(this.encoder); this.manager.setPostAuthenticationChecks(this.postAuthenticationChecks); - - assertThatExceptionOfType(LockedException.class) - .isThrownBy(() -> this.manager.authenticate(new UsernamePasswordAuthenticationToken(this.user, this.user.getPassword())).block()) + assertThatExceptionOfType(LockedException.class).isThrownBy(() -> this.manager + .authenticate(new UsernamePasswordAuthenticationToken(this.user, this.user.getPassword())).block()) .withMessage("account is locked"); - verify(this.postAuthenticationChecks).check(eq(this.user)); } @Test public void authenticateWhenPostAuthenticationChecksNotSet() { - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(this.user)); - when(this.encoder.matches(any(), any())).thenReturn(true); + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(this.user)); + given(this.encoder.matches(any(), any())).willReturn(true); this.manager.setPasswordEncoder(this.encoder); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - this.user, this.user.getPassword()); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(this.user, + this.user.getPassword()); this.manager.authenticate(token).block(); - verifyZeroInteractions(this.postAuthenticationChecks); } @Test(expected = AccountExpiredException.class) public void authenticateWhenAccountExpiredThenException() { this.manager.setPasswordEncoder(this.encoder); - + // @formatter:off UserDetails expiredUser = User.withUsername("user") .password("password") .roles("USER") .accountExpired(true) .build(); - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(expiredUser)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - expiredUser, expiredUser.getPassword()); - + // @formatter:on + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(expiredUser)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(expiredUser, + expiredUser.getPassword()); this.manager.authenticate(token).block(); } @Test(expected = LockedException.class) public void authenticateWhenAccountLockedThenException() { this.manager.setPasswordEncoder(this.encoder); - + // @formatter:off UserDetails lockedUser = User.withUsername("user") .password("password") .roles("USER") .accountLocked(true) .build(); - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(lockedUser)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - lockedUser, lockedUser.getPassword()); - + // @formatter:on + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(lockedUser)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(lockedUser, + lockedUser.getPassword()); this.manager.authenticate(token).block(); } @Test(expected = DisabledException.class) public void authenticateWhenAccountDisabledThenException() { this.manager.setPasswordEncoder(this.encoder); - + // @formatter:off UserDetails disabledUser = User.withUsername("user") .password("password") .roles("USER") .disabled(true) .build(); - when(this.userDetailsService.findByUsername(any())).thenReturn(Mono.just(disabledUser)); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - disabledUser, disabledUser.getPassword()); - + // @formatter:on + given(this.userDetailsService.findByUsername(any())).willReturn(Mono.just(disabledUser)); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(disabledUser, + disabledUser.getPassword()); this.manager.authenticate(token).block(); } diff --git a/core/src/test/java/org/springframework/security/authentication/UsernamePasswordAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/authentication/UsernamePasswordAuthenticationTokenTests.java index 71d599f8d7..61cd51ecef 100644 --- a/core/src/test/java/org/springframework/security/authentication/UsernamePasswordAuthenticationTokenTests.java +++ b/core/src/test/java/org/springframework/security/authentication/UsernamePasswordAuthenticationTokenTests.java @@ -16,12 +16,13 @@ package org.springframework.security.authentication; +import org.junit.Test; + +import org.springframework.security.core.authority.AuthorityUtils; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -import org.junit.Test; -import org.springframework.security.core.authority.AuthorityUtils; - /** * Tests {@link UsernamePasswordAuthenticationToken}. * @@ -29,33 +30,24 @@ import org.springframework.security.core.authority.AuthorityUtils; */ public class UsernamePasswordAuthenticationTokenTests { - // ~ Methods - // ======================================================================================================== - @Test public void authenticatedPropertyContractIsSatisfied() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", AuthorityUtils.NO_AUTHORITIES); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", + AuthorityUtils.NO_AUTHORITIES); // check default given we passed some GrantedAuthorty[]s (well, we passed empty // list) assertThat(token.isAuthenticated()).isTrue(); - // check explicit set to untrusted (we can safely go from trusted to untrusted, // but not the reverse) token.setAuthenticated(false); assertThat(!token.isAuthenticated()).isTrue(); - // Now let's create a UsernamePasswordAuthenticationToken without any // GrantedAuthorty[]s (different constructor) token = new UsernamePasswordAuthenticationToken("Test", "Password"); - assertThat(!token.isAuthenticated()).isTrue(); - // check we're allowed to still set it to untrusted token.setAuthenticated(false); assertThat(!token.isAuthenticated()).isTrue(); - // check denied changing it to trusted try { token.setAuthenticated(true); @@ -67,8 +59,7 @@ public class UsernamePasswordAuthenticationTokenTests { @Test public void gettersReturnCorrectData() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "Test", "Password", + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("Test", "Password", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); assertThat(token.getPrincipal()).isEqualTo("Test"); assertThat(token.getCredentials()).isEqualTo("Password"); @@ -81,4 +72,5 @@ public class UsernamePasswordAuthenticationTokenTests { Class clazz = UsernamePasswordAuthenticationToken.class; clazz.getDeclaredConstructor((Class[]) null); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationProviderTests.java index 11277b56bf..808cb36347 100644 --- a/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationProviderTests.java @@ -16,9 +16,8 @@ package org.springframework.security.authentication.anonymous; -import static org.assertj.core.api.Assertions.*; +import org.junit.Test; -import org.junit.*; import org.springframework.security.authentication.AnonymousAuthenticationProvider; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.BadCredentialsException; @@ -26,6 +25,9 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AnonymousAuthenticationProvider}. * @@ -33,18 +35,11 @@ import org.springframework.security.core.authority.AuthorityUtils; */ public class AnonymousAuthenticationProviderTests { - // ~ Methods - // ======================================================================================================== - @Test public void testDetectsAnInvalidKey() { - AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider( - "qwerty"); - - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken( - "WRONG_KEY", "Test", AuthorityUtils.createAuthorityList("ROLE_ONE", - "ROLE_TWO")); - + AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider("qwerty"); + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("WRONG_KEY", "Test", + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); try { aap.authenticate(token); fail("Should have thrown BadCredentialsException"); @@ -60,48 +55,38 @@ public class AnonymousAuthenticationProviderTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @Test public void testGettersSetters() { - AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider( - "qwerty"); + AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider("qwerty"); assertThat(aap.getKey()).isEqualTo("qwerty"); } @Test public void testIgnoresClassesItDoesNotSupport() { - AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider( - "qwerty"); - - TestingAuthenticationToken token = new TestingAuthenticationToken("user", - "password", "ROLE_A"); + AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider("qwerty"); + TestingAuthenticationToken token = new TestingAuthenticationToken("user", "password", "ROLE_A"); assertThat(aap.supports(TestingAuthenticationToken.class)).isFalse(); - // Try it anyway assertThat(aap.authenticate(token)).isNull(); } @Test public void testNormalOperation() { - AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider( - "qwerty"); - - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("qwerty", - "Test", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); - + AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider("qwerty"); + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("qwerty", "Test", + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); Authentication result = aap.authenticate(token); - assertThat(token).isEqualTo(result); } @Test public void testSupports() { - AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider( - "qwerty"); + AnonymousAuthenticationProvider aap = new AnonymousAuthenticationProvider("qwerty"); assertThat(aap.supports(AnonymousAuthenticationToken.class)).isTrue(); assertThat(aap.supports(TestingAuthenticationToken.class)).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationTokenTests.java index 1fec11d0c4..298a43e633 100644 --- a/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationTokenTests.java +++ b/core/src/test/java/org/springframework/security/authentication/anonymous/AnonymousAuthenticationTokenTests.java @@ -16,18 +16,19 @@ package org.springframework.security.authentication.anonymous; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.util.Collections; import java.util.List; import org.junit.Test; + import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AnonymousAuthenticationToken}. * @@ -35,11 +36,8 @@ import org.springframework.security.core.authority.AuthorityUtils; */ public class AnonymousAuthenticationTokenTests { - private final static List ROLES_12 = AuthorityUtils.createAuthorityList( - "ROLE_ONE", "ROLE_TWO"); + private static final List ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); - // ~ Methods - // ======================================================================================================== @Test public void testConstructorRejectsNulls() { try { @@ -48,25 +46,20 @@ public class AnonymousAuthenticationTokenTests { } catch (IllegalArgumentException expected) { } - try { new AnonymousAuthenticationToken("key", null, ROLES_12); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new AnonymousAuthenticationToken("key", "Test", - null); + new AnonymousAuthenticationToken("key", "Test", null); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { - new AnonymousAuthenticationToken("key", "Test", - AuthorityUtils.NO_AUTHORITIES); + new AnonymousAuthenticationToken("key", "Test", AuthorityUtils.NO_AUTHORITIES); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { @@ -75,31 +68,24 @@ public class AnonymousAuthenticationTokenTests { @Test public void testEqualsWhenEqual() { - AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); - AnonymousAuthenticationToken token2 = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); - + AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", "Test", ROLES_12); + AnonymousAuthenticationToken token2 = new AnonymousAuthenticationToken("key", "Test", ROLES_12); assertThat(token2).isEqualTo(token1); } @Test public void testGetters() { - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); - + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("key", "Test", ROLES_12); assertThat(token.getKeyHash()).isEqualTo("key".hashCode()); assertThat(token.getPrincipal()).isEqualTo("Test"); assertThat(token.getCredentials()).isEqualTo(""); - assertThat(AuthorityUtils.authorityListToSet(token.getAuthorities())).contains( - "ROLE_ONE", "ROLE_TWO"); + assertThat(AuthorityUtils.authorityListToSet(token.getAuthorities())).contains("ROLE_ONE", "ROLE_TWO"); assertThat(token.isAuthenticated()).isTrue(); } @Test public void testNoArgConstructorDoesntExist() { Class clazz = AnonymousAuthenticationToken.class; - try { clazz.getDeclaredConstructor((Class[]) null); fail("Should have thrown NoSuchMethodException"); @@ -110,39 +96,29 @@ public class AnonymousAuthenticationTokenTests { @Test public void testNotEqualsDueToAbstractParentEqualsCheck() { - AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); - AnonymousAuthenticationToken token2 = new AnonymousAuthenticationToken("key", - "DIFFERENT_PRINCIPAL", ROLES_12); - + AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", "Test", ROLES_12); + AnonymousAuthenticationToken token2 = new AnonymousAuthenticationToken("key", "DIFFERENT_PRINCIPAL", ROLES_12); assertThat(token1.equals(token2)).isFalse(); } @Test public void testNotEqualsDueToDifferentAuthenticationClass() { - AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); - UsernamePasswordAuthenticationToken token2 = new UsernamePasswordAuthenticationToken( - "Test", "Password", ROLES_12); - + AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", "Test", ROLES_12); + UsernamePasswordAuthenticationToken token2 = new UsernamePasswordAuthenticationToken("Test", "Password", + ROLES_12); assertThat(token1.equals(token2)).isFalse(); } @Test public void testNotEqualsDueToKey() { - AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); - - AnonymousAuthenticationToken token2 = new AnonymousAuthenticationToken( - "DIFFERENT_KEY", "Test", ROLES_12); - + AnonymousAuthenticationToken token1 = new AnonymousAuthenticationToken("key", "Test", ROLES_12); + AnonymousAuthenticationToken token2 = new AnonymousAuthenticationToken("DIFFERENT_KEY", "Test", ROLES_12); assertThat(token1.equals(token2)).isFalse(); } @Test public void testSetAuthenticatedIgnored() { - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("key", - "Test", ROLES_12); + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken("key", "Test", ROLES_12); assertThat(token.isAuthenticated()).isTrue(); token.setAuthenticated(false); assertThat(!token.isAuthenticated()).isTrue(); @@ -162,4 +138,5 @@ public class AnonymousAuthenticationTokenTests { public void constructorWhenPrincipalIsEmptyStringThenThrowIllegalArgumentException() { new AnonymousAuthenticationToken("key", "", ROLES_12); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java index 7d906af380..b045b798bb 100644 --- a/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java @@ -16,24 +16,12 @@ package org.springframework.security.authentication.dao; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; - import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import org.junit.Test; + import org.springframework.dao.DataRetrievalFailureException; import org.springframework.security.authentication.AccountExpiredException; import org.springframework.security.authentication.AuthenticationServiceException; @@ -50,6 +38,7 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.core.userdetails.UserDetailsPasswordService; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.core.userdetails.cache.EhCacheBasedUserCache; @@ -58,7 +47,19 @@ import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.factory.PasswordEncoderFactories; import org.springframework.security.crypto.password.NoOpPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; -import org.springframework.security.core.userdetails.UserDetailsPasswordService; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests {@link DaoAuthenticationProvider}. @@ -68,26 +69,19 @@ import org.springframework.security.core.userdetails.UserDetailsPasswordService; */ public class DaoAuthenticationProviderTests { - private static final List ROLES_12 = AuthorityUtils.createAuthorityList( - "ROLE_ONE", "ROLE_TWO"); + private static final List ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); - // ~ Methods - // ======================================================================================================== @Test public void testAuthenticateFailsForIncorrectPasswordCase() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "KOala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "KOala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @@ -97,113 +91,86 @@ public class DaoAuthenticationProviderTests { DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - - UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken( - "rod", null); + UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken("rod", null); try { provider.authenticate(authenticationToken); fail("Expected BadCredenialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticateFailsIfAccountExpired() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "peter", "opal"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("peter", "opal"); DaoAuthenticationProvider provider = createProvider(); - provider.setUserDetailsService( - new MockUserDetailsServiceUserPeterAccountExpired()); + provider.setUserDetailsService(new MockUserDetailsServiceUserPeterAccountExpired()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown AccountExpiredException"); } catch (AccountExpiredException expected) { - } } @Test public void testAuthenticateFailsIfAccountLocked() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "peter", "opal"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("peter", "opal"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserPeterAccountLocked()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown LockedException"); } catch (LockedException expected) { - } } @Test public void testAuthenticateFailsIfCredentialsExpired() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "peter", "opal"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("peter", "opal"); DaoAuthenticationProvider provider = createProvider(); - provider.setUserDetailsService( - new MockUserDetailsServiceUserPeterCredentialsExpired()); + provider.setUserDetailsService(new MockUserDetailsServiceUserPeterCredentialsExpired()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown CredentialsExpiredException"); } catch (CredentialsExpiredException expected) { - } - // Check that wrong password causes BadCredentialsException, rather than // CredentialsExpiredException token = new UsernamePasswordAuthenticationToken("peter", "wrong_password"); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticateFailsIfUserDisabled() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "peter", "opal"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("peter", "opal"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserPeter()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown DisabledException"); } catch (DisabledException expected) { - } } @Test public void testAuthenticateFailsWhenAuthenticationDaoHasBackendFailure() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceSimulateBackendError()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown InternalAuthenticationServiceException"); @@ -214,192 +181,146 @@ public class DaoAuthenticationProviderTests { @Test public void testAuthenticateFailsWithEmptyUsername() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - null, "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(null, "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticateFailsWithInvalidPassword() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "INVALID_PASSWORD"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "INVALID_PASSWORD"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticateFailsWithInvalidUsernameAndHideUserNotFoundExceptionFalse() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "INVALID_USER", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("INVALID_USER", "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setHideUserNotFoundExceptions(false); // we want // UsernameNotFoundExceptions provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown UsernameNotFoundException"); } catch (UsernameNotFoundException expected) { - } } @Test public void testAuthenticateFailsWithInvalidUsernameAndHideUserNotFoundExceptionsWithDefaultOfTrue() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "INVALID_USER", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("INVALID_USER", "koala"); DaoAuthenticationProvider provider = createProvider(); assertThat(provider.isHideUserNotFoundExceptions()).isTrue(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticateFailsWithInvalidUsernameAndChangePasswordEncoder() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "INVALID_USER", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("INVALID_USER", "koala"); DaoAuthenticationProvider provider = createProvider(); assertThat(provider.isHideUserNotFoundExceptions()).isTrue(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } - provider.setPasswordEncoder(PasswordEncoderFactories.createDelegatingPasswordEncoder()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticateFailsWithMixedCaseUsernameIfDefaultChanged() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "RoD", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("RoD", "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - try { provider.authenticate(token); fail("Should have thrown BadCredentialsException"); } catch (BadCredentialsException expected) { - } } @Test public void testAuthenticates() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "koala"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "koala"); token.setDetails("192.168.0.1"); - DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - Authentication result = provider.authenticate(token); - if (!(result instanceof UsernamePasswordAuthenticationToken)) { fail("Should have returned instance of UsernamePasswordAuthenticationToken"); } - UsernamePasswordAuthenticationToken castResult = (UsernamePasswordAuthenticationToken) result; assertThat(castResult.getPrincipal().getClass()).isEqualTo(User.class); assertThat(castResult.getCredentials()).isEqualTo("koala"); - assertThat( - AuthorityUtils.authorityListToSet(castResult.getAuthorities())).contains( - "ROLE_ONE", "ROLE_TWO"); + assertThat(AuthorityUtils.authorityListToSet(castResult.getAuthorities())).contains("ROLE_ONE", "ROLE_TWO"); assertThat(castResult.getDetails()).isEqualTo("192.168.0.1"); } @Test public void testAuthenticatesASecondTime() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); - Authentication result = provider.authenticate(token); - if (!(result instanceof UsernamePasswordAuthenticationToken)) { fail("Should have returned instance of UsernamePasswordAuthenticationToken"); } - // Now try to authenticate with the previous result (with its UserDetails) Authentication result2 = provider.authenticate(result); - if (!(result2 instanceof UsernamePasswordAuthenticationToken)) { fail("Should have returned instance of UsernamePasswordAuthenticationToken"); } - assertThat(result2.getCredentials()).isEqualTo(result.getCredentials()); } @Test public void testAuthenticatesWithForcePrincipalAsString() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); provider.setUserCache(new MockUserCache()); provider.setForcePrincipalAsString(true); - Authentication result = provider.authenticate(token); - if (!(result instanceof UsernamePasswordAuthenticationToken)) { fail("Should have returned instance of UsernamePasswordAuthenticationToken"); } - UsernamePasswordAuthenticationToken castResult = (UsernamePasswordAuthenticationToken) result; assertThat(castResult.getPrincipal().getClass()).isEqualTo(String.class); assertThat(castResult.getPrincipal()).isEqualTo("rod"); @@ -409,9 +330,7 @@ public class DaoAuthenticationProviderTests { public void authenticateWhenSuccessAndPasswordManagerThenUpdates() { String password = "password"; String encodedPassword = "encoded"; - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", password); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", password); PasswordEncoder encoder = mock(PasswordEncoder.class); UserDetailsService userDetailsService = mock(UserDetailsService.class); UserDetailsPasswordService passwordManager = mock(UserDetailsPasswordService.class); @@ -419,25 +338,20 @@ public class DaoAuthenticationProviderTests { provider.setPasswordEncoder(encoder); provider.setUserDetailsService(userDetailsService); provider.setUserDetailsPasswordService(passwordManager); - UserDetails user = PasswordEncodedUser.user(); - when(encoder.matches(any(), any())).thenReturn(true); - when(encoder.upgradeEncoding(any())).thenReturn(true); - when(encoder.encode(any())).thenReturn(encodedPassword); - when(userDetailsService.loadUserByUsername(any())).thenReturn(user); - when(passwordManager.updatePassword(any(), any())).thenReturn(user); - + given(encoder.matches(any(), any())).willReturn(true); + given(encoder.upgradeEncoding(any())).willReturn(true); + given(encoder.encode(any())).willReturn(encodedPassword); + given(userDetailsService.loadUserByUsername(any())).willReturn(user); + given(passwordManager.updatePassword(any(), any())).willReturn(user); Authentication result = provider.authenticate(token); - verify(encoder).encode(password); verify(passwordManager).updatePassword(eq(user), eq(encodedPassword)); } @Test public void authenticateWhenBadCredentialsAndPasswordManagerThenNoUpdate() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", "password"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password"); PasswordEncoder encoder = mock(PasswordEncoder.class); UserDetailsService userDetailsService = mock(UserDetailsService.class); UserDetailsPasswordService passwordManager = mock(UserDetailsPasswordService.class); @@ -445,22 +359,16 @@ public class DaoAuthenticationProviderTests { provider.setPasswordEncoder(encoder); provider.setUserDetailsService(userDetailsService); provider.setUserDetailsPasswordService(passwordManager); - UserDetails user = PasswordEncodedUser.user(); - when(encoder.matches(any(), any())).thenReturn(false); - when(userDetailsService.loadUserByUsername(any())).thenReturn(user); - - assertThatThrownBy(() -> provider.authenticate(token)) - .isInstanceOf(BadCredentialsException.class); - + given(encoder.matches(any(), any())).willReturn(false); + given(userDetailsService.loadUserByUsername(any())).willReturn(user); + assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> provider.authenticate(token)); verifyZeroInteractions(passwordManager); } @Test public void authenticateWhenNotUpgradeAndPasswordManagerThenNoUpdate() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", "password"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password"); PasswordEncoder encoder = mock(PasswordEncoder.class); UserDetailsService userDetailsService = mock(UserDetailsService.class); UserDetailsPasswordService passwordManager = mock(UserDetailsPasswordService.class); @@ -468,33 +376,26 @@ public class DaoAuthenticationProviderTests { provider.setPasswordEncoder(encoder); provider.setUserDetailsService(userDetailsService); provider.setUserDetailsPasswordService(passwordManager); - UserDetails user = PasswordEncodedUser.user(); - when(encoder.matches(any(), any())).thenReturn(true); - when(encoder.upgradeEncoding(any())).thenReturn(false); - when(userDetailsService.loadUserByUsername(any())).thenReturn(user); - + given(encoder.matches(any(), any())).willReturn(true); + given(encoder.upgradeEncoding(any())).willReturn(false); + given(userDetailsService.loadUserByUsername(any())).willReturn(user); Authentication result = provider.authenticate(token); - verifyZeroInteractions(passwordManager); } @Test public void testDetectsNullBeingReturnedFromAuthenticationDao() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "koala"); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(new MockUserDetailsServiceReturnsNull()); - try { provider.authenticate(token); fail("Should have thrown AuthenticationServiceException"); } catch (AuthenticationServiceException expected) { - assertThat( - "UserDetailsService returned null, which is an interface contract violation").isEqualTo( - expected.getMessage()); + assertThat("UserDetailsService returned null, which is an interface contract violation") + .isEqualTo(expected.getMessage()); } } @@ -502,13 +403,9 @@ public class DaoAuthenticationProviderTests { public void testGettersSetters() { DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setPasswordEncoder(new BCryptPasswordEncoder()); - assertThat(provider.getPasswordEncoder().getClass()).isEqualTo( - BCryptPasswordEncoder.class); - + assertThat(provider.getPasswordEncoder().getClass()).isEqualTo(BCryptPasswordEncoder.class); provider.setUserCache(new EhCacheBasedUserCache()); - assertThat(provider.getUserCache().getClass()).isEqualTo( - EhCacheBasedUserCache.class); - + assertThat(provider.getUserCache().getClass()).isEqualTo(EhCacheBasedUserCache.class); assertThat(provider.isForcePrincipalAsString()).isFalse(); provider.setForcePrincipalAsString(true); assertThat(provider.isForcePrincipalAsString()).isTrue(); @@ -516,44 +413,34 @@ public class DaoAuthenticationProviderTests { @Test public void testGoesBackToAuthenticationDaoToObtainLatestPasswordIfCachedPasswordSeemsIncorrect() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "rod", "koala"); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("rod", "koala"); MockUserDetailsServiceUserRod authenticationDao = new MockUserDetailsServiceUserRod(); MockUserCache cache = new MockUserCache(); DaoAuthenticationProvider provider = createProvider(); provider.setUserDetailsService(authenticationDao); provider.setUserCache(cache); - // This will work, as password still "koala" provider.authenticate(token); - // Check "rod = koala" ended up in the cache assertThat(cache.getUserFromCache("rod").getPassword()).isEqualTo("koala"); - // Now change the password the AuthenticationDao will return authenticationDao.setPassword("easternLongNeckTurtle"); - // Now try authentication again, with the new password token = new UsernamePasswordAuthenticationToken("rod", "easternLongNeckTurtle"); provider.authenticate(token); - // To get this far, the new password was accepted // Check the cache was updated - assertThat(cache.getUserFromCache("rod").getPassword()).isEqualTo( - "easternLongNeckTurtle"); + assertThat(cache.getUserFromCache("rod").getPassword()).isEqualTo("easternLongNeckTurtle"); } @Test public void testStartupFailsIfNoAuthenticationDao() throws Exception { DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); - try { provider.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @@ -563,13 +450,11 @@ public class DaoAuthenticationProviderTests { provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); assertThat(provider.getUserCache().getClass()).isEqualTo(NullUserCache.class); provider.setUserCache(null); - try { provider.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @@ -581,7 +466,6 @@ public class DaoAuthenticationProviderTests { provider.setUserCache(new MockUserCache()); assertThat(provider.getUserDetailsService()).isEqualTo(userDetailsService); provider.afterPropertiesSet(); - } @Test @@ -594,10 +478,9 @@ public class DaoAuthenticationProviderTests { // SEC-2056 @Test public void testUserNotFoundEncodesPassword() throws Exception { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "missing", "koala"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("missing", "koala"); PasswordEncoder encoder = mock(PasswordEncoder.class); - when(encoder.encode(anyString())).thenReturn("koala"); + given(encoder.encode(anyString())).willReturn("koala"); DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setHideUserNotFoundExceptions(false); provider.setPasswordEncoder(encoder); @@ -609,7 +492,6 @@ public class DaoAuthenticationProviderTests { } catch (UsernameNotFoundException success) { } - // ensure encoder invoked w/ non-null strings since PasswordEncoder impls may fail // if encoded password is null verify(encoder).matches(isA(String.class), isA(String.class)); @@ -617,15 +499,13 @@ public class DaoAuthenticationProviderTests { @Test public void testUserNotFoundBCryptPasswordEncoder() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "missing", "koala"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("missing", "koala"); PasswordEncoder encoder = new BCryptPasswordEncoder(); DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setHideUserNotFoundExceptions(false); provider.setPasswordEncoder(encoder); MockUserDetailsServiceUserRod userDetailsService = new MockUserDetailsServiceUserRod(); - userDetailsService.password = encoder.encode( - (CharSequence) token.getCredentials()); + userDetailsService.password = encoder.encode((CharSequence) token.getCredentials()); provider.setUserDetailsService(userDetailsService); try { provider.authenticate(token); @@ -637,8 +517,7 @@ public class DaoAuthenticationProviderTests { @Test public void testUserNotFoundDefaultEncoder() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "missing", null); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("missing", null); DaoAuthenticationProvider provider = createProvider(); provider.setHideUserNotFoundExceptions(false); provider.setUserDetailsService(new MockUserDetailsServiceUserRod()); @@ -656,28 +535,22 @@ public class DaoAuthenticationProviderTests { * SEC-2056 is fixed. */ public void IGNOREtestSec2056() { - UsernamePasswordAuthenticationToken foundUser = new UsernamePasswordAuthenticationToken( - "rod", "koala"); - UsernamePasswordAuthenticationToken notFoundUser = new UsernamePasswordAuthenticationToken( - "notFound", "koala"); + UsernamePasswordAuthenticationToken foundUser = new UsernamePasswordAuthenticationToken("rod", "koala"); + UsernamePasswordAuthenticationToken notFoundUser = new UsernamePasswordAuthenticationToken("notFound", "koala"); PasswordEncoder encoder = new BCryptPasswordEncoder(10, new SecureRandom()); DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setHideUserNotFoundExceptions(false); provider.setPasswordEncoder(encoder); MockUserDetailsServiceUserRod userDetailsService = new MockUserDetailsServiceUserRod(); - userDetailsService.password = encoder.encode( - (CharSequence) foundUser.getCredentials()); + userDetailsService.password = encoder.encode((CharSequence) foundUser.getCredentials()); provider.setUserDetailsService(userDetailsService); - int sampleSize = 100; - List userFoundTimes = new ArrayList<>(sampleSize); for (int i = 0; i < sampleSize; i++) { long start = System.currentTimeMillis(); provider.authenticate(foundUser); userFoundTimes.add(System.currentTimeMillis() - start); } - List userNotFoundTimes = new ArrayList<>(sampleSize); for (int i = 0; i < sampleSize; i++) { long start = System.currentTimeMillis(); @@ -689,13 +562,10 @@ public class DaoAuthenticationProviderTests { } userNotFoundTimes.add(System.currentTimeMillis() - start); } - double userFoundAvg = avg(userFoundTimes); double userNotFoundAvg = avg(userNotFoundTimes); - assertThat(Math.abs(userNotFoundAvg - userFoundAvg) <= 3).withFailMessage( - "User not found average " + userNotFoundAvg - + " should be within 3ms of user found average " - + userFoundAvg).isTrue(); + assertThat(Math.abs(userNotFoundAvg - userFoundAvg) <= 3).withFailMessage("User not found average " + + userNotFoundAvg + " should be within 3ms of user found average " + userFoundAvg).isTrue(); } private double avg(List counts) { @@ -708,8 +578,7 @@ public class DaoAuthenticationProviderTests { @Test public void testUserNotFoundNullCredentials() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "missing", null); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("missing", null); PasswordEncoder encoder = mock(PasswordEncoder.class); DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setHideUserNotFoundExceptions(false); @@ -721,91 +590,97 @@ public class DaoAuthenticationProviderTests { } catch (UsernameNotFoundException success) { } - verify(encoder, times(0)).matches(anyString(), anyString()); } - // ~ Inner Classes - // ================================================================================================== - - private class MockUserDetailsServiceReturnsNull implements UserDetailsService { - - public UserDetails loadUserByUsername(String username) { - return null; - } - } - - private class MockUserDetailsServiceSimulateBackendError - implements UserDetailsService { - - public UserDetails loadUserByUsername(String username) { - throw new DataRetrievalFailureException( - "This mock simulator is designed to fail"); - } - } - - private class MockUserDetailsServiceUserRod implements UserDetailsService { - - private String password = "koala"; - - public UserDetails loadUserByUsername(String username) { - if ("rod".equals(username)) { - return new User("rod", password, true, true, true, true, ROLES_12); - } - throw new UsernameNotFoundException("Could not find: " + username); - } - - public void setPassword(String password) { - this.password = password; - } - } - - private class MockUserDetailsServiceUserPeter implements UserDetailsService { - - public UserDetails loadUserByUsername(String username) { - if ("peter".equals(username)) { - return new User("peter", "opal", false, true, true, true, ROLES_12); - } - throw new UsernameNotFoundException("Could not find: " + username); - } - } - - private class MockUserDetailsServiceUserPeterAccountExpired - implements UserDetailsService { - - public UserDetails loadUserByUsername(String username) { - if ("peter".equals(username)) { - return new User("peter", "opal", true, false, true, true, ROLES_12); - } - throw new UsernameNotFoundException("Could not find: " + username); - } - } - - private class MockUserDetailsServiceUserPeterAccountLocked - implements UserDetailsService { - - public UserDetails loadUserByUsername(String username) { - if ("peter".equals(username)) { - return new User("peter", "opal", true, true, true, false, ROLES_12); - } - throw new UsernameNotFoundException("Could not find: " + username); - } - } - - private class MockUserDetailsServiceUserPeterCredentialsExpired - implements UserDetailsService { - - public UserDetails loadUserByUsername(String username) { - if ("peter".equals(username)) { - return new User("peter", "opal", true, true, false, true, ROLES_12); - } - throw new UsernameNotFoundException("Could not find: " + username); - } - } - private DaoAuthenticationProvider createProvider() { DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); provider.setPasswordEncoder(NoOpPasswordEncoder.getInstance()); return provider; } + + private class MockUserDetailsServiceReturnsNull implements UserDetailsService { + + @Override + public UserDetails loadUserByUsername(String username) { + return null; + } + + } + + private class MockUserDetailsServiceSimulateBackendError implements UserDetailsService { + + @Override + public UserDetails loadUserByUsername(String username) { + throw new DataRetrievalFailureException("This mock simulator is designed to fail"); + } + + } + + private class MockUserDetailsServiceUserRod implements UserDetailsService { + + private String password = "koala"; + + @Override + public UserDetails loadUserByUsername(String username) { + if ("rod".equals(username)) { + return new User("rod", this.password, true, true, true, true, ROLES_12); + } + throw new UsernameNotFoundException("Could not find: " + username); + } + + void setPassword(String password) { + this.password = password; + } + + } + + private class MockUserDetailsServiceUserPeter implements UserDetailsService { + + @Override + public UserDetails loadUserByUsername(String username) { + if ("peter".equals(username)) { + return new User("peter", "opal", false, true, true, true, ROLES_12); + } + throw new UsernameNotFoundException("Could not find: " + username); + } + + } + + private class MockUserDetailsServiceUserPeterAccountExpired implements UserDetailsService { + + @Override + public UserDetails loadUserByUsername(String username) { + if ("peter".equals(username)) { + return new User("peter", "opal", true, false, true, true, ROLES_12); + } + throw new UsernameNotFoundException("Could not find: " + username); + } + + } + + private class MockUserDetailsServiceUserPeterAccountLocked implements UserDetailsService { + + @Override + public UserDetails loadUserByUsername(String username) { + if ("peter".equals(username)) { + return new User("peter", "opal", true, true, true, false, ROLES_12); + } + throw new UsernameNotFoundException("Could not find: " + username); + } + + } + + private class MockUserDetailsServiceUserPeterCredentialsExpired implements UserDetailsService { + + @Override + public UserDetails loadUserByUsername(String username) { + if ("peter".equals(username)) { + return new User("peter", "opal", true, true, false, true, ROLES_12); + } + throw new UsernameNotFoundException("Could not find: " + username); + } + + } + } diff --git a/core/src/test/java/org/springframework/security/authentication/dao/MockUserCache.java b/core/src/test/java/org/springframework/security/authentication/dao/MockUserCache.java index ea7132f53b..5a2a0467ed 100644 --- a/core/src/test/java/org/springframework/security/authentication/dao/MockUserCache.java +++ b/core/src/test/java/org/springframework/security/authentication/dao/MockUserCache.java @@ -13,9 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/** - * - */ + package org.springframework.security.authentication.dao; import java.util.HashMap; @@ -25,17 +23,22 @@ import org.springframework.security.core.userdetails.UserCache; import org.springframework.security.core.userdetails.UserDetails; public class MockUserCache implements UserCache { + private Map cache = new HashMap<>(); + @Override public UserDetails getUserFromCache(String username) { - return cache.get(username); + return this.cache.get(username); } + @Override public void putUserInCache(UserDetails user) { - cache.put(user.getUsername(), user); + this.cache.put(user.getUsername(), user); } + @Override public void removeUserFromCache(String username) { - cache.remove(username); + this.cache.remove(username); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/event/AuthenticationEventTests.java b/core/src/test/java/org/springframework/security/authentication/event/AuthenticationEventTests.java index 1b08f45878..d843593f5c 100644 --- a/core/src/test/java/org/springframework/security/authentication/event/AuthenticationEventTests.java +++ b/core/src/test/java/org/springframework/security/authentication/event/AuthenticationEventTests.java @@ -16,29 +16,27 @@ package org.springframework.security.authentication.event; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import org.junit.Test; + import org.springframework.security.authentication.DisabledException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AbstractAuthenticationEvent} and its subclasses. * * @author Ben Alex */ public class AuthenticationEventTests { - // ~ Methods - // ======================================================================================================== private Authentication getAuthentication() { - UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken( - "Principal", "Credentials"); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("Principal", + "Credentials"); authentication.setDetails("127.0.0.1"); - return authentication; } @@ -53,8 +51,7 @@ public class AuthenticationEventTests { public void testAbstractAuthenticationFailureEvent() { Authentication auth = getAuthentication(); AuthenticationException exception = new DisabledException("TEST"); - AbstractAuthenticationFailureEvent event = new AuthenticationFailureDisabledEvent( - auth, exception); + AbstractAuthenticationFailureEvent event = new AuthenticationFailureDisabledEvent(auth, exception); assertThat(event.getAuthentication()).isEqualTo(auth); assertThat(event.getException()).isEqualTo(exception); } @@ -62,13 +59,11 @@ public class AuthenticationEventTests { @Test public void testRejectsNullAuthentication() { AuthenticationException exception = new DisabledException("TEST"); - try { new AuthenticationFailureDisabledEvent(null, exception); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @@ -79,7 +74,7 @@ public class AuthenticationEventTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } + } diff --git a/core/src/test/java/org/springframework/security/authentication/event/LoggerListenerTests.java b/core/src/test/java/org/springframework/security/authentication/event/LoggerListenerTests.java index fd8ac22222..4d788d4377 100644 --- a/core/src/test/java/org/springframework/security/authentication/event/LoggerListenerTests.java +++ b/core/src/test/java/org/springframework/security/authentication/event/LoggerListenerTests.java @@ -17,6 +17,7 @@ package org.springframework.security.authentication.event; import org.junit.Test; + import org.springframework.security.authentication.LockedException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; @@ -27,23 +28,20 @@ import org.springframework.security.core.Authentication; * @author Ben Alex */ public class LoggerListenerTests { - // ~ Methods - // ======================================================================================================== private Authentication getAuthentication() { - UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken( - "Principal", "Credentials"); + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("Principal", + "Credentials"); authentication.setDetails("127.0.0.1"); - return authentication; } @Test public void testLogsEvents() { - AuthenticationFailureDisabledEvent event = new AuthenticationFailureDisabledEvent( - getAuthentication(), new LockedException("TEST")); + AuthenticationFailureDisabledEvent event = new AuthenticationFailureDisabledEvent(getAuthentication(), + new LockedException("TEST")); LoggerListener listener = new LoggerListener(); listener.onApplicationEvent(event); - } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProviderTests.java index 66b117651a..e075c6184d 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/DefaultJaasAuthenticationProviderTests.java @@ -13,20 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication.jaas; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -import java.util.*; +import java.util.Arrays; +import java.util.Collections; import javax.security.auth.login.AppConfigurationEntry; import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag; @@ -38,12 +29,12 @@ import org.apache.commons.logging.Log; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.support.ClassPathXmlApplicationContext; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; -import org.springframework.security.authentication.jaas.DefaultJaasAuthenticationProvider; import org.springframework.security.authentication.jaas.event.JaasAuthenticationFailedEvent; import org.springframework.security.authentication.jaas.event.JaasAuthenticationSuccessEvent; import org.springframework.security.core.Authentication; @@ -52,82 +43,92 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.session.SessionDestroyedEvent; import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + public class DefaultJaasAuthenticationProviderTests { + private DefaultJaasAuthenticationProvider provider; + private UsernamePasswordAuthenticationToken token; + private ApplicationEventPublisher publisher; + private Log log; @Before public void setUp() throws Exception { Configuration configuration = mock(Configuration.class); - publisher = mock(ApplicationEventPublisher.class); - log = mock(Log.class); - provider = new DefaultJaasAuthenticationProvider(); - provider.setConfiguration(configuration); - provider.setApplicationEventPublisher(publisher); - provider.setAuthorityGranters(new AuthorityGranter[] { new TestAuthorityGranter() }); - provider.afterPropertiesSet(); - AppConfigurationEntry[] aces = new AppConfigurationEntry[] { new AppConfigurationEntry( - TestLoginModule.class.getName(), LoginModuleControlFlag.REQUIRED, - Collections. emptyMap()) }; - when(configuration.getAppConfigurationEntry(provider.getLoginContextName())) - .thenReturn(aces); - token = new UsernamePasswordAuthenticationToken("user", "password"); - ReflectionTestUtils.setField(provider, "log", log); - + this.publisher = mock(ApplicationEventPublisher.class); + this.log = mock(Log.class); + this.provider = new DefaultJaasAuthenticationProvider(); + this.provider.setConfiguration(configuration); + this.provider.setApplicationEventPublisher(this.publisher); + this.provider.setAuthorityGranters(new AuthorityGranter[] { new TestAuthorityGranter() }); + this.provider.afterPropertiesSet(); + AppConfigurationEntry[] aces = new AppConfigurationEntry[] { + new AppConfigurationEntry(TestLoginModule.class.getName(), LoginModuleControlFlag.REQUIRED, + Collections.emptyMap()) }; + given(configuration.getAppConfigurationEntry(this.provider.getLoginContextName())).willReturn(aces); + this.token = new UsernamePasswordAuthenticationToken("user", "password"); + ReflectionTestUtils.setField(this.provider, "log", this.log); } @Test(expected = IllegalArgumentException.class) public void afterPropertiesSetNullConfiguration() throws Exception { - provider.setConfiguration(null); - provider.afterPropertiesSet(); + this.provider.setConfiguration(null); + this.provider.afterPropertiesSet(); } @Test(expected = IllegalArgumentException.class) public void afterPropertiesSetNullAuthorityGranters() throws Exception { - provider.setAuthorityGranters(null); - provider.afterPropertiesSet(); + this.provider.setAuthorityGranters(null); + this.provider.afterPropertiesSet(); } @Test public void authenticateUnsupportedAuthentication() { - assertThat(provider.authenticate(new TestingAuthenticationToken("user", "password"))).isNull(); + assertThat(this.provider.authenticate(new TestingAuthenticationToken("user", "password"))).isNull(); } @Test public void authenticateSuccess() { - Authentication auth = provider.authenticate(token); - assertThat(auth.getPrincipal()).isEqualTo(token.getPrincipal()); - assertThat(auth.getCredentials()).isEqualTo(token.getCredentials()); + Authentication auth = this.provider.authenticate(this.token); + assertThat(auth.getPrincipal()).isEqualTo(this.token.getPrincipal()); + assertThat(auth.getCredentials()).isEqualTo(this.token.getCredentials()); assertThat(auth.isAuthenticated()).isEqualTo(true); assertThat(auth.getAuthorities().isEmpty()).isEqualTo(false); - verify(publisher).publishEvent(isA(JaasAuthenticationSuccessEvent.class)); - verifyNoMoreInteractions(publisher); + verify(this.publisher).publishEvent(isA(JaasAuthenticationSuccessEvent.class)); + verifyNoMoreInteractions(this.publisher); } @Test public void authenticateBadPassword() { try { - provider.authenticate(new UsernamePasswordAuthenticationToken("user", "asdf")); + this.provider.authenticate(new UsernamePasswordAuthenticationToken("user", "asdf")); fail("LoginException should have been thrown for the bad password"); } catch (AuthenticationException success) { } - verifyFailedLogin(); } @Test public void authenticateBadUser() { try { - provider.authenticate(new UsernamePasswordAuthenticationToken("asdf", - "password")); + this.provider.authenticate(new UsernamePasswordAuthenticationToken("asdf", "password")); fail("LoginException should have been thrown for the bad user"); } catch (AuthenticationException success) { } - verifyFailedLogin(); } @@ -137,13 +138,10 @@ public class DefaultJaasAuthenticationProviderTests { SecurityContext securityContext = mock(SecurityContext.class); JaasAuthenticationToken token = mock(JaasAuthenticationToken.class); LoginContext context = mock(LoginContext.class); - - when(event.getSecurityContexts()).thenReturn(Arrays.asList(securityContext)); - when(securityContext.getAuthentication()).thenReturn(token); - when(token.getLoginContext()).thenReturn(context); - - provider.onApplicationEvent(event); - + given(event.getSecurityContexts()).willReturn(Arrays.asList(securityContext)); + given(securityContext.getAuthentication()).willReturn(token); + given(token.getLoginContext()).willReturn(context); + this.provider.onApplicationEvent(event); verify(event).getSecurityContexts(); verify(securityContext).getAuthentication(); verify(token).getLoginContext(); @@ -154,11 +152,9 @@ public class DefaultJaasAuthenticationProviderTests { @Test public void logoutNullSession() { SessionDestroyedEvent event = mock(SessionDestroyedEvent.class); - - provider.handleLogout(event); - + this.provider.handleLogout(event); verify(event).getSecurityContexts(); - verify(log).debug(anyString()); + verify(this.log).debug(anyString()); verifyNoMoreInteractions(event); } @@ -166,11 +162,8 @@ public class DefaultJaasAuthenticationProviderTests { public void logoutNullAuthentication() { SessionDestroyedEvent event = mock(SessionDestroyedEvent.class); SecurityContext securityContext = mock(SecurityContext.class); - - when(event.getSecurityContexts()).thenReturn(Arrays.asList(securityContext)); - - provider.handleLogout(event); - + given(event.getSecurityContexts()).willReturn(Arrays.asList(securityContext)); + this.provider.handleLogout(event); verify(event).getSecurityContexts(); verify(event).getSecurityContexts(); verify(securityContext).getAuthentication(); @@ -181,12 +174,9 @@ public class DefaultJaasAuthenticationProviderTests { public void logoutNonJaasAuthentication() { SessionDestroyedEvent event = mock(SessionDestroyedEvent.class); SecurityContext securityContext = mock(SecurityContext.class); - - when(event.getSecurityContexts()).thenReturn(Arrays.asList(securityContext)); - when(securityContext.getAuthentication()).thenReturn(token); - - provider.handleLogout(event); - + given(event.getSecurityContexts()).willReturn(Arrays.asList(securityContext)); + given(securityContext.getAuthentication()).willReturn(this.token); + this.provider.handleLogout(event); verify(event).getSecurityContexts(); verify(event).getSecurityContexts(); verify(securityContext).getAuthentication(); @@ -198,15 +188,12 @@ public class DefaultJaasAuthenticationProviderTests { SessionDestroyedEvent event = mock(SessionDestroyedEvent.class); SecurityContext securityContext = mock(SecurityContext.class); JaasAuthenticationToken token = mock(JaasAuthenticationToken.class); - - when(event.getSecurityContexts()).thenReturn(Arrays.asList(securityContext)); - when(securityContext.getAuthentication()).thenReturn(token); - - provider.onApplicationEvent(event); + given(event.getSecurityContexts()).willReturn(Arrays.asList(securityContext)); + given(securityContext.getAuthentication()).willReturn(token); + this.provider.onApplicationEvent(event); verify(event).getSecurityContexts(); verify(securityContext).getAuthentication(); verify(token).getLoginContext(); - verifyNoMoreInteractions(event, securityContext, token); } @@ -217,42 +204,37 @@ public class DefaultJaasAuthenticationProviderTests { JaasAuthenticationToken token = mock(JaasAuthenticationToken.class); LoginContext context = mock(LoginContext.class); LoginException loginException = new LoginException("Failed Login"); - - when(event.getSecurityContexts()).thenReturn(Arrays.asList(securityContext)); - when(securityContext.getAuthentication()).thenReturn(token); - when(token.getLoginContext()).thenReturn(context); - doThrow(loginException).when(context).logout(); - - provider.onApplicationEvent(event); - + given(event.getSecurityContexts()).willReturn(Arrays.asList(securityContext)); + given(securityContext.getAuthentication()).willReturn(token); + given(token.getLoginContext()).willReturn(context); + willThrow(loginException).given(context).logout(); + this.provider.onApplicationEvent(event); verify(event).getSecurityContexts(); verify(securityContext).getAuthentication(); verify(token).getLoginContext(); verify(context).logout(); - verify(log).warn(anyString(), eq(loginException)); + verify(this.log).warn(anyString(), eq(loginException)); verifyNoMoreInteractions(event, securityContext, token, context); } @Test public void publishNullPublisher() { - provider.setApplicationEventPublisher(null); + this.provider.setApplicationEventPublisher(null); AuthenticationException ae = new BadCredentialsException("Failed to login"); - - provider.publishFailureEvent(token, ae); - provider.publishSuccessEvent(token); + this.provider.publishFailureEvent(this.token, ae); + this.provider.publishSuccessEvent(this.token); } @Test public void javadocExample() { String resName = "/" + getClass().getName().replace('.', '/') + ".xml"; - ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext( - resName); + ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext(resName); context.registerShutdownHook(); try { - provider = context.getBean(DefaultJaasAuthenticationProvider.class); - Authentication auth = provider.authenticate(token); + this.provider = context.getBean(DefaultJaasAuthenticationProvider.class); + Authentication auth = this.provider.authenticate(this.token); assertThat(auth.isAuthenticated()).isEqualTo(true); - assertThat(auth.getPrincipal()).isEqualTo(token.getPrincipal()); + assertThat(auth.getPrincipal()).isEqualTo(this.token.getPrincipal()); } finally { context.close(); @@ -260,10 +242,12 @@ public class DefaultJaasAuthenticationProviderTests { } private void verifyFailedLogin() { - ArgumentCaptor event = ArgumentCaptor.forClass(JaasAuthenticationFailedEvent.class); - verify(publisher).publishEvent(event.capture()); + ArgumentCaptor event = ArgumentCaptor + .forClass(JaasAuthenticationFailedEvent.class); + verify(this.publisher).publishEvent(event.capture()); assertThat(event.getValue()).isInstanceOf(JaasAuthenticationFailedEvent.class); assertThat(event.getValue().getException()).isNotNull(); - verifyNoMoreInteractions(publisher); + verifyNoMoreInteractions(this.publisher); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/JaasAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/jaas/JaasAuthenticationProviderTests.java index ea794e8126..6f59331bbf 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/JaasAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/JaasAuthenticationProviderTests.java @@ -16,21 +16,21 @@ package org.springframework.security.authentication.jaas; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import java.io.File; import java.io.FileOutputStream; import java.io.PrintWriter; import java.net.URL; import java.security.Security; -import java.util.*; +import java.util.Arrays; +import java.util.Collection; +import java.util.Set; import javax.security.auth.login.LoginContext; import javax.security.auth.login.LoginException; import org.junit.Before; import org.junit.Test; + import org.springframework.context.ApplicationContext; import org.springframework.context.support.ClassPathXmlApplicationContext; import org.springframework.core.io.FileSystemResource; @@ -45,68 +45,65 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.session.SessionDestroyedEvent; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests for the JaasAuthenticationProvider * * @author Ray Krueger */ public class JaasAuthenticationProviderTests { - // ~ Instance fields - // ================================================================================================ private ApplicationContext context; - private JaasAuthenticationProvider jaasProvider; - private JaasEventCheck eventCheck; - // ~ Methods - // ======================================================================================================== + private JaasAuthenticationProvider jaasProvider; + + private JaasEventCheck eventCheck; @Before public void setUp() { String resName = "/" + getClass().getName().replace('.', '/') + ".xml"; - context = new ClassPathXmlApplicationContext(resName); - eventCheck = (JaasEventCheck) context.getBean("eventCheck"); - jaasProvider = (JaasAuthenticationProvider) context - .getBean("jaasAuthenticationProvider"); + this.context = new ClassPathXmlApplicationContext(resName); + this.eventCheck = (JaasEventCheck) this.context.getBean("eventCheck"); + this.jaasProvider = (JaasAuthenticationProvider) this.context.getBean("jaasAuthenticationProvider"); } @Test public void testBadPassword() { try { - jaasProvider.authenticate(new UsernamePasswordAuthenticationToken("user", - "asdf")); + this.jaasProvider.authenticate(new UsernamePasswordAuthenticationToken("user", "asdf")); fail("LoginException should have been thrown for the bad password"); } - catch (AuthenticationException e) { + catch (AuthenticationException ex) { } - - assertThat(eventCheck.failedEvent).as("Failure event not fired").isNotNull(); - assertThat(eventCheck.failedEvent.getException()).withFailMessage("Failure event exception was null").isNotNull(); - assertThat(eventCheck.successEvent).as("Success event was fired").isNull(); + assertThat(this.eventCheck.failedEvent).as("Failure event not fired").isNotNull(); + assertThat(this.eventCheck.failedEvent.getException()).withFailMessage("Failure event exception was null") + .isNotNull(); + assertThat(this.eventCheck.successEvent).as("Success event was fired").isNull(); } @Test public void testBadUser() { try { - jaasProvider.authenticate(new UsernamePasswordAuthenticationToken("asdf", - "password")); + this.jaasProvider.authenticate(new UsernamePasswordAuthenticationToken("asdf", "password")); fail("LoginException should have been thrown for the bad user"); } - catch (AuthenticationException e) { + catch (AuthenticationException ex) { } - - assertThat(eventCheck.failedEvent).as("Failure event not fired").isNotNull(); - assertThat(eventCheck.failedEvent.getException()).withFailMessage("Failure event exception was null").isNotNull(); - assertThat(eventCheck.successEvent).as("Success event was fired").isNull(); + assertThat(this.eventCheck.failedEvent).as("Failure event not fired").isNotNull(); + assertThat(this.eventCheck.failedEvent.getException()).withFailMessage("Failure event exception was null") + .isNotNull(); + assertThat(this.eventCheck.successEvent).as("Success event was fired").isNull(); } @Test public void testConfigurationLoop() throws Exception { String resName = "/" + getClass().getName().replace('.', '/') + ".conf"; URL url = getClass().getResource(resName); - Security.setProperty("login.config.url.1", url.toString()); - setUp(); testFull(); } @@ -114,11 +111,10 @@ public class JaasAuthenticationProviderTests { @Test public void detectsMissingLoginConfig() throws Exception { JaasAuthenticationProvider myJaasProvider = new JaasAuthenticationProvider(); - myJaasProvider.setApplicationEventPublisher(context); - myJaasProvider.setAuthorityGranters(jaasProvider.getAuthorityGranters()); - myJaasProvider.setCallbackHandlers(jaasProvider.getCallbackHandlers()); - myJaasProvider.setLoginContextName(jaasProvider.getLoginContextName()); - + myJaasProvider.setApplicationEventPublisher(this.context); + myJaasProvider.setAuthorityGranters(this.jaasProvider.getAuthorityGranters()); + myJaasProvider.setCallbackHandlers(this.jaasProvider.getCallbackHandlers()); + myJaasProvider.setLoginContextName(this.jaasProvider.getLoginContextName()); try { myJaasProvider.afterPropertiesSet(); fail("Should have thrown ApplicationContextException"); @@ -133,10 +129,8 @@ public class JaasAuthenticationProviderTests { public void spacesInLoginConfigPathAreAccepted() throws Exception { File configFile; // Create temp directory with a space in the name - File configDir = new File(System.getProperty("java.io.tmpdir") + File.separator - + "jaas test"); + File configDir = new File(System.getProperty("java.io.tmpdir") + File.separator + "jaas test"); configDir.deleteOnExit(); - if (configDir.exists()) { configDir.delete(); } @@ -145,31 +139,27 @@ public class JaasAuthenticationProviderTests { configFile.deleteOnExit(); FileOutputStream fos = new FileOutputStream(configFile); PrintWriter pw = new PrintWriter(fos); - pw.append("JAASTestBlah {" - + "org.springframework.security.authentication.jaas.TestLoginModule required;" - + "};"); + pw.append( + "JAASTestBlah {" + "org.springframework.security.authentication.jaas.TestLoginModule required;" + "};"); pw.flush(); pw.close(); - JaasAuthenticationProvider myJaasProvider = new JaasAuthenticationProvider(); - myJaasProvider.setApplicationEventPublisher(context); + myJaasProvider.setApplicationEventPublisher(this.context); myJaasProvider.setLoginConfig(new FileSystemResource(configFile)); - myJaasProvider.setAuthorityGranters(jaasProvider.getAuthorityGranters()); - myJaasProvider.setCallbackHandlers(jaasProvider.getCallbackHandlers()); - myJaasProvider.setLoginContextName(jaasProvider.getLoginContextName()); - + myJaasProvider.setAuthorityGranters(this.jaasProvider.getAuthorityGranters()); + myJaasProvider.setCallbackHandlers(this.jaasProvider.getCallbackHandlers()); + myJaasProvider.setLoginContextName(this.jaasProvider.getLoginContextName()); myJaasProvider.afterPropertiesSet(); } @Test public void detectsMissingLoginContextName() throws Exception { JaasAuthenticationProvider myJaasProvider = new JaasAuthenticationProvider(); - myJaasProvider.setApplicationEventPublisher(context); - myJaasProvider.setAuthorityGranters(jaasProvider.getAuthorityGranters()); - myJaasProvider.setCallbackHandlers(jaasProvider.getCallbackHandlers()); - myJaasProvider.setLoginConfig(jaasProvider.getLoginConfig()); + myJaasProvider.setApplicationEventPublisher(this.context); + myJaasProvider.setAuthorityGranters(this.jaasProvider.getAuthorityGranters()); + myJaasProvider.setCallbackHandlers(this.jaasProvider.getCallbackHandlers()); + myJaasProvider.setLoginConfig(this.jaasProvider.getLoginConfig()); myJaasProvider.setLoginContextName(null); - try { myJaasProvider.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); @@ -177,9 +167,7 @@ public class JaasAuthenticationProviderTests { catch (IllegalArgumentException expected) { assertThat(expected.getMessage()).startsWith("loginContextName must be set on"); } - myJaasProvider.setLoginContextName(""); - try { myJaasProvider.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); @@ -191,111 +179,95 @@ public class JaasAuthenticationProviderTests { @Test public void testFull() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", "password", AuthorityUtils.createAuthorityList("ROLE_ONE")); - - assertThat(jaasProvider.supports(UsernamePasswordAuthenticationToken.class)).isTrue(); - - Authentication auth = jaasProvider.authenticate(token); - - assertThat(jaasProvider.getAuthorityGranters()).isNotNull(); - assertThat(jaasProvider.getCallbackHandlers()).isNotNull(); - assertThat(jaasProvider.getLoginConfig()).isNotNull(); - assertThat(jaasProvider.getLoginContextName()).isNotNull(); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password", + AuthorityUtils.createAuthorityList("ROLE_ONE")); + assertThat(this.jaasProvider.supports(UsernamePasswordAuthenticationToken.class)).isTrue(); + Authentication auth = this.jaasProvider.authenticate(token); + assertThat(this.jaasProvider.getAuthorityGranters()).isNotNull(); + assertThat(this.jaasProvider.getCallbackHandlers()).isNotNull(); + assertThat(this.jaasProvider.getLoginConfig()).isNotNull(); + assertThat(this.jaasProvider.getLoginContextName()).isNotNull(); Collection list = auth.getAuthorities(); Set set = AuthorityUtils.authorityListToSet(list); - - assertThat(set.contains("ROLE_ONE")).withFailMessage("GrantedAuthorities should not contain ROLE_ONE").isFalse(); + assertThat(set.contains("ROLE_ONE")).withFailMessage("GrantedAuthorities should not contain ROLE_ONE") + .isFalse(); assertThat(set.contains("ROLE_TEST1")).withFailMessage("GrantedAuthorities should contain ROLE_TEST1").isTrue(); assertThat(set.contains("ROLE_TEST2")).withFailMessage("GrantedAuthorities should contain ROLE_TEST2").isTrue(); boolean foundit = false; - for (GrantedAuthority a : list) { if (a instanceof JaasGrantedAuthority) { JaasGrantedAuthority grant = (JaasGrantedAuthority) a; - assertThat(grant.getPrincipal()).withFailMessage("Principal was null on JaasGrantedAuthority").isNotNull(); + assertThat(grant.getPrincipal()).withFailMessage("Principal was null on JaasGrantedAuthority") + .isNotNull(); foundit = true; } } - assertThat(foundit).as("Could not find a JaasGrantedAuthority").isTrue(); - - assertThat(eventCheck.successEvent).as("Success event should be fired").isNotNull(); - assertThat(eventCheck.successEvent.getAuthentication()).withFailMessage("Auth objects should be equal").isEqualTo(auth); - assertThat(eventCheck.failedEvent).as("Failure event should not be fired").isNull(); + assertThat(this.eventCheck.successEvent).as("Success event should be fired").isNotNull(); + assertThat(this.eventCheck.successEvent.getAuthentication()).withFailMessage("Auth objects should be equal") + .isEqualTo(auth); + assertThat(this.eventCheck.failedEvent).as("Failure event should not be fired").isNull(); } @Test public void testGetApplicationEventPublisher() { - assertThat(jaasProvider.getApplicationEventPublisher()).isNotNull(); + assertThat(this.jaasProvider.getApplicationEventPublisher()).isNotNull(); } @Test public void testLoginExceptionResolver() { - assertThat(jaasProvider.getLoginExceptionResolver()).isNotNull(); - jaasProvider.setLoginExceptionResolver(e -> new LockedException("This is just a test!")); - + assertThat(this.jaasProvider.getLoginExceptionResolver()).isNotNull(); + this.jaasProvider.setLoginExceptionResolver((e) -> new LockedException("This is just a test!")); try { - jaasProvider.authenticate(new UsernamePasswordAuthenticationToken("user", - "password")); + this.jaasProvider.authenticate(new UsernamePasswordAuthenticationToken("user", "password")); } - catch (LockedException e) { + catch (LockedException ex) { } - catch (Exception e) { + catch (Exception ex) { fail("LockedException should have been thrown and caught"); } } @Test public void testLogout() throws Exception { - MockLoginContext loginContext = new MockLoginContext( - jaasProvider.getLoginContextName()); - - JaasAuthenticationToken token = new JaasAuthenticationToken(null, null, - loginContext); - + MockLoginContext loginContext = new MockLoginContext(this.jaasProvider.getLoginContextName()); + JaasAuthenticationToken token = new JaasAuthenticationToken(null, null, loginContext); SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(token); - SessionDestroyedEvent event = mock(SessionDestroyedEvent.class); - when(event.getSecurityContexts()).thenReturn(Arrays.asList(context)); - - jaasProvider.handleLogout(event); - + given(event.getSecurityContexts()).willReturn(Arrays.asList(context)); + this.jaasProvider.handleLogout(event); assertThat(loginContext.loggedOut).isTrue(); } @Test public void testNullDefaultAuthorities() { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", "password"); - - assertThat(jaasProvider.supports(UsernamePasswordAuthenticationToken.class)).isTrue(); - - Authentication auth = jaasProvider.authenticate(token); - assertThat(auth - .getAuthorities()).withFailMessage("Only ROLE_TEST1 and ROLE_TEST2 should have been returned").hasSize(2); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password"); + assertThat(this.jaasProvider.supports(UsernamePasswordAuthenticationToken.class)).isTrue(); + Authentication auth = this.jaasProvider.authenticate(token); + assertThat(auth.getAuthorities()).withFailMessage("Only ROLE_TEST1 and ROLE_TEST2 should have been returned") + .hasSize(2); } @Test public void testUnsupportedAuthenticationObjectReturnsNull() { - assertThat(jaasProvider.authenticate(new TestingAuthenticationToken("foo", "bar", - AuthorityUtils.NO_AUTHORITIES))).isNull(); + assertThat(this.jaasProvider + .authenticate(new TestingAuthenticationToken("foo", "bar", AuthorityUtils.NO_AUTHORITIES))).isNull(); } - // ~ Inner Classes - // ================================================================================================== - private static class MockLoginContext extends LoginContext { + boolean loggedOut = false; MockLoginContext(String loginModule) throws LoginException { super(loginModule); } + @Override public void logout() { this.loggedOut = true; } + } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/JaasEventCheck.java b/core/src/test/java/org/springframework/security/authentication/jaas/JaasEventCheck.java index 9760f2863c..7efedeecb6 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/JaasEventCheck.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/JaasEventCheck.java @@ -25,22 +25,19 @@ import org.springframework.security.authentication.jaas.event.JaasAuthentication * @author Ray Krueger */ public class JaasEventCheck implements ApplicationListener { - // ~ Instance fields - // ================================================================================================ JaasAuthenticationFailedEvent failedEvent; + JaasAuthenticationSuccessEvent successEvent; - // ~ Methods - // ======================================================================================================== - + @Override public void onApplicationEvent(JaasAuthenticationEvent event) { if (event instanceof JaasAuthenticationFailedEvent) { - failedEvent = (JaasAuthenticationFailedEvent) event; + this.failedEvent = (JaasAuthenticationFailedEvent) event; } - if (event instanceof JaasAuthenticationSuccessEvent) { - successEvent = (JaasAuthenticationSuccessEvent) event; + this.successEvent = (JaasAuthenticationSuccessEvent) event; } } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/JaasGrantedAuthorityTests.java b/core/src/test/java/org/springframework/security/authentication/jaas/JaasGrantedAuthorityTests.java index 9b952fd0ec..8310693337 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/JaasGrantedAuthorityTests.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/JaasGrantedAuthorityTests.java @@ -18,31 +18,24 @@ package org.springframework.security.authentication.jaas; import org.junit.Test; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import org.springframework.security.authentication.jaas.JaasGrantedAuthority; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** - * * @author Clement Ng * */ public class JaasGrantedAuthorityTests { - /** - */ @Test public void authorityWithNullRoleFailsAssertion() { - assertThatThrownBy(() -> new JaasGrantedAuthority(null, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("role cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> new JaasGrantedAuthority(null, null)) + .withMessageContaining("role cannot be null"); } - /** - */ @Test public void authorityWithNullPrincipleFailsAssertion() { - assertThatThrownBy(() -> new JaasGrantedAuthority("role", null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("principal cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> new JaasGrantedAuthority("role", null)) + .withMessageContaining("principal cannot be null"); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/Sec760Tests.java b/core/src/test/java/org/springframework/security/authentication/jaas/Sec760Tests.java index e1d0010d4a..bec12f7be4 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/Sec760Tests.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/Sec760Tests.java @@ -13,21 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication.jaas; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; + import org.springframework.core.io.ClassPathResource; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; -import org.springframework.security.authentication.jaas.AuthorityGranter; -import org.springframework.security.authentication.jaas.JaasAuthenticationCallbackHandler; -import org.springframework.security.authentication.jaas.JaasAuthenticationProvider; -import org.springframework.security.authentication.jaas.JaasNameCallbackHandler; -import org.springframework.security.authentication.jaas.JaasPasswordCallbackHandler; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests bug reported in SEC-760. * @@ -37,45 +34,37 @@ import org.springframework.security.core.authority.AuthorityUtils; public class Sec760Tests { public String resolveConfigFile(String filename) { - String resName = "/" + getClass().getPackage().getName().replace('.', '/') - + filename; + String resName = "/" + getClass().getPackage().getName().replace('.', '/') + filename; return resName; } - private void testConfigureJaasCase(JaasAuthenticationProvider p1, - JaasAuthenticationProvider p2) throws Exception { + private void testConfigureJaasCase(JaasAuthenticationProvider p1, JaasAuthenticationProvider p2) throws Exception { p1.setLoginConfig(new ClassPathResource(resolveConfigFile("/test1.conf"))); p1.setLoginContextName("test1"); - p1.setCallbackHandlers(new JaasAuthenticationCallbackHandler[] { - new TestCallbackHandler(), new JaasNameCallbackHandler(), - new JaasPasswordCallbackHandler() }); + p1.setCallbackHandlers(new JaasAuthenticationCallbackHandler[] { new TestCallbackHandler(), + new JaasNameCallbackHandler(), new JaasPasswordCallbackHandler() }); p1.setAuthorityGranters(new AuthorityGranter[] { new TestAuthorityGranter() }); p1.afterPropertiesSet(); testAuthenticate(p1); - p2.setLoginConfig(new ClassPathResource(resolveConfigFile("/test2.conf"))); p2.setLoginContextName("test2"); - p2.setCallbackHandlers(new JaasAuthenticationCallbackHandler[] { - new TestCallbackHandler(), new JaasNameCallbackHandler(), - new JaasPasswordCallbackHandler() }); + p2.setCallbackHandlers(new JaasAuthenticationCallbackHandler[] { new TestCallbackHandler(), + new JaasNameCallbackHandler(), new JaasPasswordCallbackHandler() }); p2.setAuthorityGranters(new AuthorityGranter[] { new TestAuthorityGranter() }); p2.afterPropertiesSet(); testAuthenticate(p2); } private void testAuthenticate(JaasAuthenticationProvider p1) { - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - "user", "password", AuthorityUtils.createAuthorityList("ROLE_ONE", - "ROLE_TWO")); - + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("user", "password", + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); Authentication auth = p1.authenticate(token); assertThat(auth).isNotNull(); } @Test public void testConfigureJaas() throws Exception { - testConfigureJaasCase(new JaasAuthenticationProvider(), - new JaasAuthenticationProvider()); + testConfigureJaasCase(new JaasAuthenticationProvider(), new JaasAuthenticationProvider()); } } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/SecurityContextLoginModuleTests.java b/core/src/test/java/org/springframework/security/authentication/jaas/SecurityContextLoginModuleTests.java index f34222e28c..8090324520 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/SecurityContextLoginModuleTests.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/SecurityContextLoginModuleTests.java @@ -39,17 +39,13 @@ import static org.assertj.core.api.Assertions.fail; * @author Ray Krueger */ public class SecurityContextLoginModuleTests { - // ~ Instance fields - // ================================================================================================ private SecurityContextLoginModule module = null; - private Subject subject = new Subject(false, new HashSet<>(), - new HashSet<>(), new HashSet<>()); - private UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken( - "principal", "credentials"); - // ~ Methods - // ======================================================================================================== + private Subject subject = new Subject(false, new HashSet<>(), new HashSet<>(), new HashSet<>()); + + private UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("principal", + "credentials"); @Before public void setUp() { @@ -66,8 +62,7 @@ public class SecurityContextLoginModuleTests { @Test public void testAbort() throws Exception { - assertThat(this.module.abort()).as("Should return false, no auth is set") - .isFalse(); + assertThat(this.module.abort()).as("Should return false, no auth is set").isFalse(); SecurityContextHolder.getContext().setAuthentication(this.auth); this.module.login(); this.module.commit(); @@ -80,18 +75,15 @@ public class SecurityContextLoginModuleTests { this.module.login(); fail("LoginException expected, there is no Authentication in the SecurityContext"); } - catch (LoginException e) { + catch (LoginException ex) { } } @Test public void testLoginSuccess() throws Exception { SecurityContextHolder.getContext().setAuthentication(this.auth); - assertThat(this.module.login()) - .as("Login should succeed, there is an authentication set").isTrue(); - assertThat(this.module.commit()) - .withFailMessage( - "The authentication is not null, this should return true") + assertThat(this.module.login()).as("Login should succeed, there is an authentication set").isTrue(); + assertThat(this.module.commit()).withFailMessage("The authentication is not null, this should return true") .isTrue(); assertThat(this.subject.getPrincipals().contains(this.auth)) .withFailMessage("Principals should contain the authentication").isTrue(); @@ -102,13 +94,9 @@ public class SecurityContextLoginModuleTests { SecurityContextHolder.getContext().setAuthentication(this.auth); this.module.login(); assertThat(this.module.logout()).as("Should return true as it succeeds").isTrue(); - assertThat(this.module.getAuthentication()).as("Authentication should be null") - .isNull(); - + assertThat(this.module.getAuthentication()).as("Authentication should be null").isNull(); assertThat(this.subject.getPrincipals().contains(this.auth)) - .withFailMessage( - "Principals should not contain the authentication after logout") - .isFalse(); + .withFailMessage("Principals should not contain the authentication after logout").isFalse(); } @Test @@ -118,25 +106,23 @@ public class SecurityContextLoginModuleTests { this.module.login(); fail("LoginException expected, the authentication is null in the SecurityContext"); } - catch (Exception e) { + catch (Exception ex) { } } @Test public void testNullAuthenticationInSecurityContextIgnored() throws Exception { this.module = new SecurityContextLoginModule(); - Map options = new HashMap<>(); options.put("ignoreMissingAuthentication", "true"); - this.module.initialize(this.subject, null, null, options); SecurityContextHolder.getContext().setAuthentication(null); - assertThat(this.module.login()).as("Should return false and ask to be ignored") - .isFalse(); + assertThat(this.module.login()).as("Should return false and ask to be ignored").isFalse(); } @Test public void testNullLogout() throws Exception { assertThat(this.module.logout()).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/TestAuthorityGranter.java b/core/src/test/java/org/springframework/security/authentication/jaas/TestAuthorityGranter.java index 2f82d9ea53..2a557097a3 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/TestAuthorityGranter.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/TestAuthorityGranter.java @@ -17,28 +17,22 @@ package org.springframework.security.authentication.jaas; import java.security.Principal; - import java.util.HashSet; import java.util.Set; -import org.springframework.security.authentication.jaas.AuthorityGranter; - /** - * * @author Ray Krueger */ public class TestAuthorityGranter implements AuthorityGranter { - // ~ Methods - // ======================================================================================================== + @Override public Set grant(Principal principal) { Set rtnSet = new HashSet<>(); - if (principal.getName().equals("TEST_PRINCIPAL")) { rtnSet.add("ROLE_TEST1"); rtnSet.add("ROLE_TEST2"); } - return rtnSet; } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/TestCallbackHandler.java b/core/src/test/java/org/springframework/security/authentication/jaas/TestCallbackHandler.java index a32494f5fd..645b684709 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/TestCallbackHandler.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/TestCallbackHandler.java @@ -16,25 +16,24 @@ package org.springframework.security.authentication.jaas; -import org.springframework.security.authentication.jaas.JaasAuthenticationCallbackHandler; -import org.springframework.security.core.Authentication; - import javax.security.auth.callback.Callback; import javax.security.auth.callback.TextInputCallback; +import org.springframework.security.core.Authentication; + /** * TestCallbackHandler * * @author Ray Krueger */ public class TestCallbackHandler implements JaasAuthenticationCallbackHandler { - // ~ Methods - // ======================================================================================================== + @Override public void handle(Callback callback, Authentication auth) { if (callback instanceof TextInputCallback) { TextInputCallback tic = (TextInputCallback) callback; tic.setText(auth.getPrincipal().toString()); } } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/TestLoginModule.java b/core/src/test/java/org/springframework/security/authentication/jaas/TestLoginModule.java index 53ff92c19b..b00267effc 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/TestLoginModule.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/TestLoginModule.java @@ -19,7 +19,11 @@ package org.springframework.security.authentication.jaas; import java.util.Map; import javax.security.auth.Subject; -import javax.security.auth.callback.*; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.TextInputCallback; import javax.security.auth.login.LoginException; import javax.security.auth.spi.LoginModule; @@ -27,62 +31,56 @@ import javax.security.auth.spi.LoginModule; * @author Ray Krueger */ public class TestLoginModule implements LoginModule { - // ~ Instance fields - // ================================================================================================ private String password; + private String user; + private Subject subject; - // ~ Methods - // ======================================================================================================== - + @Override public boolean abort() { return true; } + @Override public boolean commit() { return true; } + @Override @SuppressWarnings("unchecked") - public void initialize(Subject subject, CallbackHandler callbackHandler, - Map sharedState, Map options) { + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { this.subject = subject; - try { TextInputCallback textCallback = new TextInputCallback("prompt"); NameCallback nameCallback = new NameCallback("prompt"); PasswordCallback passwordCallback = new PasswordCallback("prompt", false); - - callbackHandler.handle(new Callback[] { textCallback, nameCallback, - passwordCallback }); - - password = new String(passwordCallback.getPassword()); - user = nameCallback.getName(); + callbackHandler.handle(new Callback[] { textCallback, nameCallback, passwordCallback }); + this.password = new String(passwordCallback.getPassword()); + this.user = nameCallback.getName(); } - catch (Exception e) { - throw new RuntimeException(e); + catch (Exception ex) { + throw new RuntimeException(ex); } } + @Override public boolean login() throws LoginException { - if (!user.equals("user")) { + if (!this.user.equals("user")) { throw new LoginException("Bad User"); } - - if (!password.equals("password")) { + if (!this.password.equals("password")) { throw new LoginException("Bad Password"); } - - subject.getPrincipals().add(() -> "TEST_PRINCIPAL"); - - subject.getPrincipals().add(() -> "NULL_PRINCIPAL"); - + this.subject.getPrincipals().add(() -> "TEST_PRINCIPAL"); + this.subject.getPrincipals().add(() -> "NULL_PRINCIPAL"); return true; } + @Override public boolean logout() { return true; } + } diff --git a/core/src/test/java/org/springframework/security/authentication/jaas/memory/InMemoryConfigurationTests.java b/core/src/test/java/org/springframework/security/authentication/jaas/memory/InMemoryConfigurationTests.java index 6384c0a922..c1f4f8be7f 100644 --- a/core/src/test/java/org/springframework/security/authentication/jaas/memory/InMemoryConfigurationTests.java +++ b/core/src/test/java/org/springframework/security/authentication/jaas/memory/InMemoryConfigurationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.authentication.jaas.memory; import java.lang.reflect.Method; @@ -37,25 +38,21 @@ import static org.assertj.core.api.Assertions.assertThat; public class InMemoryConfigurationTests { private AppConfigurationEntry[] defaultEntries; + private Map mappedEntries; @Before public void setUp() { - this.defaultEntries = new AppConfigurationEntry[] { new AppConfigurationEntry( - TestLoginModule.class.getName(), LoginModuleControlFlag.REQUIRED, - Collections.emptyMap()) }; - - this.mappedEntries = Collections.singletonMap( - "name", - new AppConfigurationEntry[] { new AppConfigurationEntry( - TestLoginModule.class.getName(), LoginModuleControlFlag.OPTIONAL, - Collections.emptyMap()) }); + this.defaultEntries = new AppConfigurationEntry[] { new AppConfigurationEntry(TestLoginModule.class.getName(), + LoginModuleControlFlag.REQUIRED, Collections.emptyMap()) }; + this.mappedEntries = Collections.singletonMap("name", + new AppConfigurationEntry[] { new AppConfigurationEntry(TestLoginModule.class.getName(), + LoginModuleControlFlag.OPTIONAL, Collections.emptyMap()) }); } @Test public void constructorNullDefault() { - assertThat(new InMemoryConfiguration((AppConfigurationEntry[]) null) - .getAppConfigurationEntry("name")).isNull(); + assertThat(new InMemoryConfiguration((AppConfigurationEntry[]) null).getAppConfigurationEntry("name")).isNull(); } @Test(expected = IllegalArgumentException.class) @@ -65,16 +62,14 @@ public class InMemoryConfigurationTests { @Test public void constructorEmptyMap() { - assertThat(new InMemoryConfiguration( - Collections.emptyMap()) - .getAppConfigurationEntry("name")).isNull(); + assertThat(new InMemoryConfiguration(Collections.emptyMap()) + .getAppConfigurationEntry("name")).isNull(); } @Test public void constructorEmptyMapNullDefault() { - assertThat(new InMemoryConfiguration( - Collections.emptyMap(), null) - .getAppConfigurationEntry("name")).isNull(); + assertThat(new InMemoryConfiguration(Collections.emptyMap(), null) + .getAppConfigurationEntry("name")).isNull(); } @Test(expected = IllegalArgumentException.class) @@ -84,20 +79,15 @@ public class InMemoryConfigurationTests { @Test public void nonnullDefault() { - InMemoryConfiguration configuration = new InMemoryConfiguration( - this.defaultEntries); - assertThat(configuration.getAppConfigurationEntry("name")) - .isEqualTo(this.defaultEntries); + InMemoryConfiguration configuration = new InMemoryConfiguration(this.defaultEntries); + assertThat(configuration.getAppConfigurationEntry("name")).isEqualTo(this.defaultEntries); } @Test public void mappedNonnullDefault() { - InMemoryConfiguration configuration = new InMemoryConfiguration( - this.mappedEntries, this.defaultEntries); - assertThat(this.defaultEntries) - .isEqualTo(configuration.getAppConfigurationEntry("missing")); - assertThat(this.mappedEntries.get("name")) - .isEqualTo(configuration.getAppConfigurationEntry("name")); + InMemoryConfiguration configuration = new InMemoryConfiguration(this.mappedEntries, this.defaultEntries); + assertThat(this.defaultEntries).isEqualTo(configuration.getAppConfigurationEntry("missing")); + assertThat(this.mappedEntries.get("name")).isEqualTo(configuration.getAppConfigurationEntry("name")); } @Test @@ -105,4 +95,5 @@ public class InMemoryConfigurationTests { Method method = InMemoryConfiguration.class.getDeclaredMethod("refresh"); assertThat(method.getDeclaringClass()).isEqualTo(InMemoryConfiguration.class); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImplTests.java b/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImplTests.java index a7f5f25732..a52f73ace0 100644 --- a/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImplTests.java +++ b/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationManagerImplTests.java @@ -16,59 +16,54 @@ package org.springframework.security.authentication.rcp; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import org.junit.Test; + import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests {@link RemoteAuthenticationManagerImpl}. * * @author Ben Alex */ public class RemoteAuthenticationManagerImplTests { - // ~ Methods - // ======================================================================================================== @Test(expected = RemoteAuthenticationException.class) public void testFailedAuthenticationReturnsRemoteAuthenticationException() { RemoteAuthenticationManagerImpl manager = new RemoteAuthenticationManagerImpl(); AuthenticationManager am = mock(AuthenticationManager.class); - when(am.authenticate(any(Authentication.class))).thenThrow( - new BadCredentialsException("")); + given(am.authenticate(any(Authentication.class))).willThrow(new BadCredentialsException("")); manager.setAuthenticationManager(am); - manager.attemptAuthentication("rod", "password"); } @Test public void testStartupChecksAuthenticationManagerSet() throws Exception { RemoteAuthenticationManagerImpl manager = new RemoteAuthenticationManagerImpl(); - try { manager.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - manager.setAuthenticationManager(mock(AuthenticationManager.class)); manager.afterPropertiesSet(); - } @Test public void testSuccessfulAuthentication() { RemoteAuthenticationManagerImpl manager = new RemoteAuthenticationManagerImpl(); AuthenticationManager am = mock(AuthenticationManager.class); - when(am.authenticate(any(Authentication.class))).thenReturn( - new TestingAuthenticationToken("u", "p", "A")); + given(am.authenticate(any(Authentication.class))).willReturn(new TestingAuthenticationToken("u", "p", "A")); manager.setAuthenticationManager(am); - manager.attemptAuthentication("rod", "password"); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProviderTests.java index c66f3b503f..42b37a73c8 100644 --- a/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/rcp/RemoteAuthenticationProviderTests.java @@ -34,59 +34,44 @@ import static org.assertj.core.api.Assertions.fail; * @author Ben Alex */ public class RemoteAuthenticationProviderTests { - // ~ Methods - // ======================================================================================================== @Test public void testExceptionsGetPassedBackToCaller() { RemoteAuthenticationProvider provider = new RemoteAuthenticationProvider(); - provider.setRemoteAuthenticationManager( - new MockRemoteAuthenticationManager(false)); - + provider.setRemoteAuthenticationManager(new MockRemoteAuthenticationManager(false)); try { - provider.authenticate( - new UsernamePasswordAuthenticationToken("rod", "password")); + provider.authenticate(new UsernamePasswordAuthenticationToken("rod", "password")); fail("Should have thrown RemoteAuthenticationException"); } catch (RemoteAuthenticationException expected) { - } } @Test public void testGettersSetters() { RemoteAuthenticationProvider provider = new RemoteAuthenticationProvider(); - provider.setRemoteAuthenticationManager( - new MockRemoteAuthenticationManager(true)); + provider.setRemoteAuthenticationManager(new MockRemoteAuthenticationManager(true)); assertThat(provider.getRemoteAuthenticationManager()).isNotNull(); } @Test public void testStartupChecksAuthenticationManagerSet() throws Exception { RemoteAuthenticationProvider provider = new RemoteAuthenticationProvider(); - try { provider.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - - provider.setRemoteAuthenticationManager( - new MockRemoteAuthenticationManager(true)); + provider.setRemoteAuthenticationManager(new MockRemoteAuthenticationManager(true)); provider.afterPropertiesSet(); - } @Test public void testSuccessfulAuthenticationCreatesObject() { RemoteAuthenticationProvider provider = new RemoteAuthenticationProvider(); - provider.setRemoteAuthenticationManager( - new MockRemoteAuthenticationManager(true)); - - Authentication result = provider - .authenticate(new UsernamePasswordAuthenticationToken("rod", "password")); + provider.setRemoteAuthenticationManager(new MockRemoteAuthenticationManager(true)); + Authentication result = provider.authenticate(new UsernamePasswordAuthenticationToken("rod", "password")); assertThat(result.getPrincipal()).isEqualTo("rod"); assertThat(result.getCredentials()).isEqualTo("password"); assertThat(AuthorityUtils.authorityListToSet(result.getAuthorities())).contains("foo"); @@ -95,16 +80,13 @@ public class RemoteAuthenticationProviderTests { @Test public void testNullCredentialsDoesNotCauseNullPointerException() { RemoteAuthenticationProvider provider = new RemoteAuthenticationProvider(); - provider.setRemoteAuthenticationManager( - new MockRemoteAuthenticationManager(false)); - + provider.setRemoteAuthenticationManager(new MockRemoteAuthenticationManager(false)); try { provider.authenticate(new UsernamePasswordAuthenticationToken("rod", null)); fail("Expected Exception"); } catch (RemoteAuthenticationException success) { } - } @Test @@ -113,18 +95,17 @@ public class RemoteAuthenticationProviderTests { assertThat(provider.supports(UsernamePasswordAuthenticationToken.class)).isTrue(); } - // ~ Inner Classes - // ================================================================================================== - private class MockRemoteAuthenticationManager implements RemoteAuthenticationManager { + private boolean grantAccess; MockRemoteAuthenticationManager(boolean grantAccess) { this.grantAccess = grantAccess; } - public Collection attemptAuthentication( - String username, String password) throws RemoteAuthenticationException { + @Override + public Collection attemptAuthentication(String username, String password) + throws RemoteAuthenticationException { if (this.grantAccess) { return AuthorityUtils.createAuthorityList("foo"); } @@ -132,5 +113,7 @@ public class RemoteAuthenticationProviderTests { throw new RemoteAuthenticationException("as requested"); } } + } + } diff --git a/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationProviderTests.java b/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationProviderTests.java index c6a948e913..169e9802cd 100644 --- a/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationProviderTests.java +++ b/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationProviderTests.java @@ -34,17 +34,12 @@ import static org.assertj.core.api.Assertions.fail; * @author Ben Alex */ public class RememberMeAuthenticationProviderTests { - // ~ Methods - // ======================================================================================================== + @Test public void testDetectsAnInvalidKey() { - RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider( - "qwerty"); - - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken( - "WRONG_KEY", "Test", + RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider("qwerty"); + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("WRONG_KEY", "Test", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); - try { aap.authenticate(token); fail("Should have thrown BadCredentialsException"); @@ -60,49 +55,39 @@ public class RememberMeAuthenticationProviderTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @Test public void testGettersSetters() throws Exception { - RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider( - "qwerty"); + RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider("qwerty"); aap.afterPropertiesSet(); assertThat(aap.getKey()).isEqualTo("qwerty"); } @Test public void testIgnoresClassesItDoesNotSupport() { - RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider( - "qwerty"); - - TestingAuthenticationToken token = new TestingAuthenticationToken("user", - "password", "ROLE_A"); + RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider("qwerty"); + TestingAuthenticationToken token = new TestingAuthenticationToken("user", "password", "ROLE_A"); assertThat(aap.supports(TestingAuthenticationToken.class)).isFalse(); - // Try it anyway assertThat(aap.authenticate(token)).isNull(); } @Test public void testNormalOperation() { - RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider( - "qwerty"); - - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("qwerty", - "Test", AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); - + RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider("qwerty"); + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("qwerty", "Test", + AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); Authentication result = aap.authenticate(token); - assertThat(token).isEqualTo(result); } @Test public void testSupports() { - RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider( - "qwerty"); + RememberMeAuthenticationProvider aap = new RememberMeAuthenticationProvider("qwerty"); assertThat(aap.supports(RememberMeAuthenticationToken.class)).isTrue(); assertThat(aap.supports(TestingAuthenticationToken.class)).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationTokenTests.java b/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationTokenTests.java index dcd7ac2c8c..6bdf73bd5d 100644 --- a/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationTokenTests.java +++ b/core/src/test/java/org/springframework/security/authentication/rememberme/RememberMeAuthenticationTokenTests.java @@ -16,29 +16,28 @@ package org.springframework.security.authentication.rememberme; - -import static org.assertj.core.api.Assertions.*; - import java.util.ArrayList; import java.util.List; import org.junit.Test; + import org.springframework.security.authentication.RememberMeAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link RememberMeAuthenticationToken}. * * @author Ben Alex */ public class RememberMeAuthenticationTokenTests { - private static final List ROLES_12 = AuthorityUtils - .createAuthorityList("ROLE_ONE", "ROLE_TWO"); - // ~ Methods - // ======================================================================================================== + private static final List ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); + @Test public void testConstructorRejectsNulls() { try { @@ -46,17 +45,13 @@ public class RememberMeAuthenticationTokenTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - try { new RememberMeAuthenticationToken("key", null, ROLES_12); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } - try { List authsContainingNull = new ArrayList<>(); authsContainingNull.add(null); @@ -64,25 +59,19 @@ public class RememberMeAuthenticationTokenTests { fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @Test public void testEqualsWhenEqual() { - RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); - RememberMeAuthenticationToken token2 = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); - + RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", "Test", ROLES_12); + RememberMeAuthenticationToken token2 = new RememberMeAuthenticationToken("key", "Test", ROLES_12); assertThat(token2).isEqualTo(token1); } @Test public void testGetters() { - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); - + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("key", "Test", ROLES_12); assertThat(token.getKeyHash()).isEqualTo("key".hashCode()); assertThat(token.getPrincipal()).isEqualTo("Test"); assertThat(token.getCredentials()).isEqualTo(""); @@ -93,40 +82,33 @@ public class RememberMeAuthenticationTokenTests { @Test public void testNotEqualsDueToAbstractParentEqualsCheck() { - RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); - RememberMeAuthenticationToken token2 = new RememberMeAuthenticationToken("key", - "DIFFERENT_PRINCIPAL", ROLES_12); - + RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", "Test", ROLES_12); + RememberMeAuthenticationToken token2 = new RememberMeAuthenticationToken("key", "DIFFERENT_PRINCIPAL", + ROLES_12); assertThat(token1.equals(token2)).isFalse(); } @Test public void testNotEqualsDueToDifferentAuthenticationClass() { - RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); - UsernamePasswordAuthenticationToken token2 = new UsernamePasswordAuthenticationToken( - "Test", "Password", ROLES_12); - + RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", "Test", ROLES_12); + UsernamePasswordAuthenticationToken token2 = new UsernamePasswordAuthenticationToken("Test", "Password", + ROLES_12); assertThat(token1.equals(token2)).isFalse(); } @Test public void testNotEqualsDueToKey() { - RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); - RememberMeAuthenticationToken token2 = new RememberMeAuthenticationToken( - "DIFFERENT_KEY", "Test", ROLES_12); - + RememberMeAuthenticationToken token1 = new RememberMeAuthenticationToken("key", "Test", ROLES_12); + RememberMeAuthenticationToken token2 = new RememberMeAuthenticationToken("DIFFERENT_KEY", "Test", ROLES_12); assertThat(token1.equals(token2)).isFalse(); } @Test public void testSetAuthenticatedIgnored() { - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("key", - "Test", ROLES_12); + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken("key", "Test", ROLES_12); assertThat(token.isAuthenticated()).isTrue(); token.setAuthenticated(false); assertThat(!token.isAuthenticated()).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManagerTests.java b/core/src/test/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManagerTests.java index cb05abb276..6f56f58518 100644 --- a/core/src/test/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authorization/AuthenticatedReactiveAuthorizationManagerTests.java @@ -20,14 +20,15 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.security.authentication.AnonymousAuthenticationToken; -import org.springframework.security.core.Authentication; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; + import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -35,51 +36,47 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class AuthenticatedReactiveAuthorizationManagerTests { + @Mock Authentication authentication; AuthenticatedReactiveAuthorizationManager manager = AuthenticatedReactiveAuthorizationManager - .authenticated(); + .authenticated(); @Test public void checkWhenAuthenticatedThenReturnTrue() { - when(authentication.isAuthenticated()).thenReturn(true); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + given(this.authentication.isAuthenticated()).willReturn(true); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isTrue(); } @Test public void checkWhenNotAuthenticatedThenReturnFalse() { - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenEmptyThenReturnFalse() { - boolean granted = manager.check(Mono.empty(), null).block().isGranted(); - + boolean granted = this.manager.check(Mono.empty(), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenAnonymousAuthenticatedThenReturnFalse() { AnonymousAuthenticationToken anonymousAuthenticationToken = mock(AnonymousAuthenticationToken.class); - - boolean granted = manager.check(Mono.just(anonymousAuthenticationToken), null).block().isGranted(); - + boolean granted = this.manager.check(Mono.just(anonymousAuthenticationToken), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenErrorThenError() { - Mono result = manager.check(Mono.error(new RuntimeException("ooops")), null); - - StepVerifier - .create(result) - .expectError() - .verify(); + Mono result = this.manager.check(Mono.error(new RuntimeException("ooops")), null); + // @formatter:off + StepVerifier.create(result) + .expectError() + .verify(); + // @formatter:on } + } diff --git a/core/src/test/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManagerTests.java b/core/src/test/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManagerTests.java index 2fb371fa35..1ed793a989 100644 --- a/core/src/test/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authorization/AuthorityReactiveAuthorizationManagerTests.java @@ -16,19 +16,20 @@ package org.springframework.security.authorization; +import java.util.Collections; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.Authentication; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import java.util.Collections; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -36,101 +37,86 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class AuthorityReactiveAuthorizationManagerTests { + @Mock Authentication authentication; - AuthorityReactiveAuthorizationManager manager = AuthorityReactiveAuthorizationManager - .hasAuthority("ADMIN"); + AuthorityReactiveAuthorizationManager manager = AuthorityReactiveAuthorizationManager.hasAuthority("ADMIN"); @Test public void checkWhenHasAuthorityAndNotAuthenticatedThenReturnFalse() { - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenHasAuthorityAndEmptyThenReturnFalse() { - boolean granted = manager.check(Mono.empty(), null).block().isGranted(); - + boolean granted = this.manager.check(Mono.empty(), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenHasAuthorityAndErrorThenError() { - Mono result = manager.check(Mono.error(new RuntimeException("ooops")), null); - - StepVerifier - .create(result) - .expectError() - .verify(); + Mono result = this.manager.check(Mono.error(new RuntimeException("ooops")), null); + // @formatter:off + StepVerifier.create(result) + .expectError() + .verify(); + // @formatter:on } @Test public void checkWhenHasAuthorityAndAuthenticatedAndNoAuthoritiesThenReturnFalse() { - when(authentication.isAuthenticated()).thenReturn(true); - when(authentication.getAuthorities()).thenReturn(Collections.emptyList()); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + given(this.authentication.isAuthenticated()).willReturn(true); + given(this.authentication.getAuthorities()).willReturn(Collections.emptyList()); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenHasAuthorityAndAuthenticatedAndWrongAuthoritiesThenReturnFalse() { - authentication = new TestingAuthenticationToken("rob", "secret", "ROLE_ADMIN"); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + this.authentication = new TestingAuthenticationToken("rob", "secret", "ROLE_ADMIN"); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenHasAuthorityAndAuthorizedThenReturnTrue() { - authentication = new TestingAuthenticationToken("rob", "secret", "ADMIN"); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + this.authentication = new TestingAuthenticationToken("rob", "secret", "ADMIN"); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isTrue(); } @Test public void checkWhenHasRoleAndAuthorizedThenReturnTrue() { - manager = AuthorityReactiveAuthorizationManager.hasRole("ADMIN"); - authentication = new TestingAuthenticationToken("rob", "secret", "ROLE_ADMIN"); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + this.manager = AuthorityReactiveAuthorizationManager.hasRole("ADMIN"); + this.authentication = new TestingAuthenticationToken("rob", "secret", "ROLE_ADMIN"); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isTrue(); } @Test public void checkWhenHasRoleAndNotAuthorizedThenReturnFalse() { - manager = AuthorityReactiveAuthorizationManager.hasRole("ADMIN"); - authentication = new TestingAuthenticationToken("rob", "secret", "ADMIN"); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + this.manager = AuthorityReactiveAuthorizationManager.hasRole("ADMIN"); + this.authentication = new TestingAuthenticationToken("rob", "secret", "ADMIN"); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isFalse(); } @Test public void checkWhenHasAnyRoleAndAuthorizedThenReturnTrue() { - manager = AuthorityReactiveAuthorizationManager.hasAnyRole("GENERAL", "USER", "TEST"); - authentication = new TestingAuthenticationToken("rob", "secret", "ROLE_USER", "ROLE_AUDITING", "ROLE_ADMIN"); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + this.manager = AuthorityReactiveAuthorizationManager.hasAnyRole("GENERAL", "USER", "TEST"); + this.authentication = new TestingAuthenticationToken("rob", "secret", "ROLE_USER", "ROLE_AUDITING", + "ROLE_ADMIN"); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isTrue(); } @Test public void checkWhenHasAnyRoleAndNotAuthorizedThenReturnFalse() { - manager = AuthorityReactiveAuthorizationManager.hasAnyRole("GENERAL", "USER", "TEST"); - authentication = new TestingAuthenticationToken("rob", "secret", "USER", "AUDITING", "ADMIN"); - - boolean granted = manager.check(Mono.just(authentication), null).block().isGranted(); - + this.manager = AuthorityReactiveAuthorizationManager.hasAnyRole("GENERAL", "USER", "TEST"); + this.authentication = new TestingAuthenticationToken("rob", "secret", "USER", "AUDITING", "ADMIN"); + boolean granted = this.manager.check(Mono.just(this.authentication), null).block().isGranted(); assertThat(granted).isFalse(); } @@ -171,4 +157,5 @@ public class AuthorityReactiveAuthorizationManagerTests { String authority2 = null; AuthorityReactiveAuthorizationManager.hasAnyAuthority(authority1, authority2); } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorServiceTests.java b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorServiceTests.java index a88d1dadab..f8fcbc48cb 100644 --- a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorServiceTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorServiceTests.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.concurrent; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +package org.springframework.security.concurrent; import java.util.Arrays; import java.util.List; @@ -29,6 +26,10 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.verify; + /** * Abstract class for testing {@link DelegatingSecurityContextExecutorService} which * allows customization of how {@link DelegatingSecurityContextExecutorService} and its @@ -39,10 +40,12 @@ import org.mockito.Mock; * @see CurrentDelegatingSecurityContextExecutorServiceTests * @see ExplicitDelegatingSecurityContextExecutorServiceTests */ -public abstract class AbstractDelegatingSecurityContextExecutorServiceTests extends - AbstractDelegatingSecurityContextExecutorTests { +public abstract class AbstractDelegatingSecurityContextExecutorServiceTests + extends AbstractDelegatingSecurityContextExecutorTests { + @Mock private Future expectedFutureObject; + @Mock private Object resultArg; @@ -50,9 +53,10 @@ public abstract class AbstractDelegatingSecurityContextExecutorServiceTests exte @Before public final void setUpExecutorService() { - executor = create(); + this.executor = create(); } + @Override @Test(expected = IllegalArgumentException.class) public void constructorNullDelegate() { new DelegatingSecurityContextExecutorService(null); @@ -60,112 +64,108 @@ public abstract class AbstractDelegatingSecurityContextExecutorServiceTests exte @Test public void shutdown() { - executor.shutdown(); - verify(delegate).shutdown(); + this.executor.shutdown(); + verify(this.delegate).shutdown(); } @Test public void shutdownNow() { - List result = executor.shutdownNow(); - verify(delegate).shutdownNow(); - assertThat(result).isEqualTo(delegate.shutdownNow()).isNotNull(); + List result = this.executor.shutdownNow(); + verify(this.delegate).shutdownNow(); + assertThat(result).isEqualTo(this.delegate.shutdownNow()).isNotNull(); } @Test public void isShutdown() { - boolean result = executor.isShutdown(); - verify(delegate).isShutdown(); - assertThat(result).isEqualTo(delegate.isShutdown()).isNotNull(); + boolean result = this.executor.isShutdown(); + verify(this.delegate).isShutdown(); + assertThat(result).isEqualTo(this.delegate.isShutdown()).isNotNull(); } @Test public void isTerminated() { - boolean result = executor.isTerminated(); - verify(delegate).isTerminated(); - assertThat(result).isEqualTo(delegate.isTerminated()).isNotNull(); + boolean result = this.executor.isTerminated(); + verify(this.delegate).isTerminated(); + assertThat(result).isEqualTo(this.delegate.isTerminated()).isNotNull(); } @Test public void awaitTermination() throws InterruptedException { - boolean result = executor.awaitTermination(1, TimeUnit.SECONDS); - verify(delegate).awaitTermination(1, TimeUnit.SECONDS); - assertThat(result).isEqualTo(delegate.awaitTermination(1, TimeUnit.SECONDS)) - .isNotNull(); + boolean result = this.executor.awaitTermination(1, TimeUnit.SECONDS); + verify(this.delegate).awaitTermination(1, TimeUnit.SECONDS); + assertThat(result).isEqualTo(this.delegate.awaitTermination(1, TimeUnit.SECONDS)).isNotNull(); } @Test public void submitCallable() { - when(delegate.submit(wrappedCallable)).thenReturn(expectedFutureObject); - Future result = executor.submit(callable); - verify(delegate).submit(wrappedCallable); - assertThat(result).isEqualTo(expectedFutureObject); + given(this.delegate.submit(this.wrappedCallable)).willReturn(this.expectedFutureObject); + Future result = this.executor.submit(this.callable); + verify(this.delegate).submit(this.wrappedCallable); + assertThat(result).isEqualTo(this.expectedFutureObject); } @Test public void submitRunnableWithResult() { - when(delegate.submit(wrappedRunnable, resultArg)) - .thenReturn(expectedFutureObject); - Future result = executor.submit(runnable, resultArg); - verify(delegate).submit(wrappedRunnable, resultArg); - assertThat(result).isEqualTo(expectedFutureObject); + given(this.delegate.submit(this.wrappedRunnable, this.resultArg)).willReturn(this.expectedFutureObject); + Future result = this.executor.submit(this.runnable, this.resultArg); + verify(this.delegate).submit(this.wrappedRunnable, this.resultArg); + assertThat(result).isEqualTo(this.expectedFutureObject); } @Test @SuppressWarnings("unchecked") public void submitRunnable() { - when((Future) delegate.submit(wrappedRunnable)).thenReturn( - expectedFutureObject); - Future result = executor.submit(runnable); - verify(delegate).submit(wrappedRunnable); - assertThat(result).isEqualTo(expectedFutureObject); + given((Future) this.delegate.submit(this.wrappedRunnable)).willReturn(this.expectedFutureObject); + Future result = this.executor.submit(this.runnable); + verify(this.delegate).submit(this.wrappedRunnable); + assertThat(result).isEqualTo(this.expectedFutureObject); } @Test @SuppressWarnings("unchecked") public void invokeAll() throws Exception { - List> exectedResult = Arrays.asList(expectedFutureObject); - List> wrappedCallables = Arrays.asList(wrappedCallable); - when(delegate.invokeAll(wrappedCallables)).thenReturn(exectedResult); - List> result = executor.invokeAll(Arrays.asList(callable)); - verify(delegate).invokeAll(wrappedCallables); + List> exectedResult = Arrays.asList(this.expectedFutureObject); + List> wrappedCallables = Arrays.asList(this.wrappedCallable); + given(this.delegate.invokeAll(wrappedCallables)).willReturn(exectedResult); + List> result = this.executor.invokeAll(Arrays.asList(this.callable)); + verify(this.delegate).invokeAll(wrappedCallables); assertThat(result).isEqualTo(exectedResult); } @Test @SuppressWarnings("unchecked") public void invokeAllTimeout() throws Exception { - List> exectedResult = Arrays.asList(expectedFutureObject); - List> wrappedCallables = Arrays.asList(wrappedCallable); - when(delegate.invokeAll(wrappedCallables, 1, TimeUnit.SECONDS)).thenReturn( - exectedResult); - List> result = executor.invokeAll(Arrays.asList(callable), 1, - TimeUnit.SECONDS); - verify(delegate).invokeAll(wrappedCallables, 1, TimeUnit.SECONDS); + List> exectedResult = Arrays.asList(this.expectedFutureObject); + List> wrappedCallables = Arrays.asList(this.wrappedCallable); + given(this.delegate.invokeAll(wrappedCallables, 1, TimeUnit.SECONDS)).willReturn(exectedResult); + List> result = this.executor.invokeAll(Arrays.asList(this.callable), 1, TimeUnit.SECONDS); + verify(this.delegate).invokeAll(wrappedCallables, 1, TimeUnit.SECONDS); assertThat(result).isEqualTo(exectedResult); } @Test @SuppressWarnings("unchecked") public void invokeAny() throws Exception { - List> exectedResult = Arrays.asList(expectedFutureObject); - List> wrappedCallables = Arrays.asList(wrappedCallable); - when(delegate.invokeAny(wrappedCallables)).thenReturn(exectedResult); - Object result = executor.invokeAny(Arrays.asList(callable)); - verify(delegate).invokeAny(wrappedCallables); + List> exectedResult = Arrays.asList(this.expectedFutureObject); + List> wrappedCallables = Arrays.asList(this.wrappedCallable); + given(this.delegate.invokeAny(wrappedCallables)).willReturn(exectedResult); + Object result = this.executor.invokeAny(Arrays.asList(this.callable)); + verify(this.delegate).invokeAny(wrappedCallables); assertThat(result).isEqualTo(exectedResult); } @Test @SuppressWarnings("unchecked") public void invokeAnyTimeout() throws Exception { - List> exectedResult = Arrays.asList(expectedFutureObject); - List> wrappedCallables = Arrays.asList(wrappedCallable); - when(delegate.invokeAny(wrappedCallables, 1, TimeUnit.SECONDS)).thenReturn( - exectedResult); - Object result = executor.invokeAny(Arrays.asList(callable), 1, TimeUnit.SECONDS); - verify(delegate).invokeAny(wrappedCallables, 1, TimeUnit.SECONDS); + List> exectedResult = Arrays.asList(this.expectedFutureObject); + List> wrappedCallables = Arrays.asList(this.wrappedCallable); + given(this.delegate.invokeAny(wrappedCallables, 1, TimeUnit.SECONDS)).willReturn(exectedResult); + Object result = this.executor.invokeAny(Arrays.asList(this.callable), 1, TimeUnit.SECONDS); + verify(this.delegate).invokeAny(wrappedCallables, 1, TimeUnit.SECONDS); assertThat(result).isEqualTo(exectedResult); } + @Override protected abstract DelegatingSecurityContextExecutorService create(); + } diff --git a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorTests.java b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorTests.java index 3e3dfc5d4e..073670a93d 100644 --- a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextExecutorTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.concurrent; -import static org.mockito.Mockito.verify; +package org.springframework.security.concurrent; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -23,6 +22,8 @@ import java.util.concurrent.ScheduledExecutorService; import org.junit.Test; import org.mockito.Mock; +import static org.mockito.Mockito.verify; + /** * Abstract class for testing {@link DelegatingSecurityContextExecutor} which allows * customization of how {@link DelegatingSecurityContextExecutor} and its mocks are @@ -33,32 +34,30 @@ import org.mockito.Mock; * @see CurrentDelegatingSecurityContextExecutorTests * @see ExplicitDelegatingSecurityContextExecutorTests */ -public abstract class AbstractDelegatingSecurityContextExecutorTests extends - AbstractDelegatingSecurityContextTestSupport { +public abstract class AbstractDelegatingSecurityContextExecutorTests + extends AbstractDelegatingSecurityContextTestSupport { + @Mock protected ScheduledExecutorService delegate; private DelegatingSecurityContextExecutor executor; - // --- constructor --- - @Test(expected = IllegalArgumentException.class) public void constructorNullDelegate() { new DelegatingSecurityContextExecutor(null); } - // --- execute --- - @Test public void execute() { - executor = create(); - executor.execute(runnable); - verify(getExecutor()).execute(wrappedRunnable); + this.executor = create(); + this.executor.execute(this.runnable); + verify(getExecutor()).execute(this.wrappedRunnable); } protected Executor getExecutor() { - return delegate; + return this.delegate; } protected abstract DelegatingSecurityContextExecutor create(); + } diff --git a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextScheduledExecutorServiceTests.java b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextScheduledExecutorServiceTests.java index ec0d114bb7..25dc94c333 100644 --- a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextScheduledExecutorServiceTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextScheduledExecutorServiceTests.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.concurrent; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +package org.springframework.security.concurrent; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -26,6 +23,10 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import static org.assertj.core.api.Assertions.assertThatObject; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.verify; + /** * Abstract class for testing {@link DelegatingSecurityContextScheduledExecutorService} * which allows customization of how @@ -38,6 +39,7 @@ import org.mockito.Mock; */ public abstract class AbstractDelegatingSecurityContextScheduledExecutorServiceTests extends AbstractDelegatingSecurityContextExecutorServiceTests { + @Mock private ScheduledFuture expectedResult; @@ -45,55 +47,48 @@ public abstract class AbstractDelegatingSecurityContextScheduledExecutorServiceT @Before public final void setUpExecutor() { - executor = create(); + this.executor = create(); } @Test @SuppressWarnings("unchecked") public void scheduleRunnable() { - when( - (ScheduledFuture) delegate.schedule(wrappedRunnable, 1, - TimeUnit.SECONDS)).thenReturn(expectedResult); - ScheduledFuture result = executor.schedule(runnable, 1, TimeUnit.SECONDS); - assertThat(result).isEqualTo(expectedResult); - verify(delegate).schedule(wrappedRunnable, 1, TimeUnit.SECONDS); + given((ScheduledFuture) this.delegate.schedule(this.wrappedRunnable, 1, TimeUnit.SECONDS)) + .willReturn(this.expectedResult); + ScheduledFuture result = this.executor.schedule(this.runnable, 1, TimeUnit.SECONDS); + assertThatObject(result).isEqualTo(this.expectedResult); + verify(this.delegate).schedule(this.wrappedRunnable, 1, TimeUnit.SECONDS); } @Test public void scheduleCallable() { - when( - delegate.schedule(wrappedCallable, 1, - TimeUnit.SECONDS)).thenReturn(expectedResult); - ScheduledFuture result = executor.schedule(callable, 1, TimeUnit.SECONDS); - assertThat(result).isEqualTo(expectedResult); - verify(delegate).schedule(wrappedCallable, 1, TimeUnit.SECONDS); + given(this.delegate.schedule(this.wrappedCallable, 1, TimeUnit.SECONDS)).willReturn(this.expectedResult); + ScheduledFuture result = this.executor.schedule(this.callable, 1, TimeUnit.SECONDS); + assertThatObject(result).isEqualTo(this.expectedResult); + verify(this.delegate).schedule(this.wrappedCallable, 1, TimeUnit.SECONDS); } @Test @SuppressWarnings("unchecked") public void scheduleAtFixedRate() { - when( - (ScheduledFuture) delegate.scheduleAtFixedRate(wrappedRunnable, - 1, 2, TimeUnit.SECONDS)).thenReturn(expectedResult); - ScheduledFuture result = executor.scheduleAtFixedRate(runnable, 1, 2, - TimeUnit.SECONDS); - assertThat(result).isEqualTo(expectedResult); - verify(delegate).scheduleAtFixedRate(wrappedRunnable, 1, 2, TimeUnit.SECONDS); + given((ScheduledFuture) this.delegate.scheduleAtFixedRate(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS)) + .willReturn(this.expectedResult); + ScheduledFuture result = this.executor.scheduleAtFixedRate(this.runnable, 1, 2, TimeUnit.SECONDS); + assertThatObject(result).isEqualTo(this.expectedResult); + verify(this.delegate).scheduleAtFixedRate(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS); } @Test @SuppressWarnings("unchecked") public void scheduleWithFixedDelay() { - when( - (ScheduledFuture) delegate.scheduleWithFixedDelay( - wrappedRunnable, 1, 2, TimeUnit.SECONDS)).thenReturn( - expectedResult); - ScheduledFuture result = executor.scheduleWithFixedDelay(runnable, 1, 2, - TimeUnit.SECONDS); - assertThat(result).isEqualTo(expectedResult); - verify(delegate).scheduleWithFixedDelay(wrappedRunnable, 1, 2, TimeUnit.SECONDS); + given((ScheduledFuture) this.delegate.scheduleWithFixedDelay(this.wrappedRunnable, 1, 2, + TimeUnit.SECONDS)).willReturn(this.expectedResult); + ScheduledFuture result = this.executor.scheduleWithFixedDelay(this.runnable, 1, 2, TimeUnit.SECONDS); + assertThatObject(result).isEqualTo(this.expectedResult); + verify(this.delegate).scheduleWithFixedDelay(this.wrappedRunnable, 1, 2, TimeUnit.SECONDS); } @Override protected abstract DelegatingSecurityContextScheduledExecutorService create(); + } diff --git a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java index cead9643bb..8e1de868f0 100644 --- a/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java +++ b/core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.concurrent; -import static org.mockito.ArgumentMatchers.eq; -import static org.powermock.api.mockito.PowerMockito.doReturn; -import static org.powermock.api.mockito.PowerMockito.spy; +package org.springframework.security.concurrent; import java.util.concurrent.Callable; @@ -27,11 +24,15 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import static org.mockito.ArgumentMatchers.eq; + /** * Abstract base class for testing classes that extend * {@link AbstractDelegatingSecurityContextSupport} @@ -41,9 +42,9 @@ import org.springframework.security.core.context.SecurityContextHolder; * */ @RunWith(PowerMockRunner.class) -@PrepareForTest({ DelegatingSecurityContextRunnable.class, - DelegatingSecurityContextCallable.class }) +@PrepareForTest({ DelegatingSecurityContextRunnable.class, DelegatingSecurityContextCallable.class }) public abstract class AbstractDelegatingSecurityContextTestSupport { + @Mock protected SecurityContext securityContext; @@ -66,30 +67,31 @@ public abstract class AbstractDelegatingSecurityContextTestSupport { protected Runnable wrappedRunnable; public final void explicitSecurityContextPowermockSetup() throws Exception { - spy(DelegatingSecurityContextCallable.class); - doReturn(wrappedCallable).when(DelegatingSecurityContextCallable.class, "create", - eq(callable), securityContextCaptor.capture()); - spy(DelegatingSecurityContextRunnable.class); - doReturn(wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create", - eq(runnable), securityContextCaptor.capture()); + PowerMockito.spy(DelegatingSecurityContextCallable.class); + PowerMockito.doReturn(this.wrappedCallable).when(DelegatingSecurityContextCallable.class, "create", + eq(this.callable), this.securityContextCaptor.capture()); + PowerMockito.spy(DelegatingSecurityContextRunnable.class); + PowerMockito.doReturn(this.wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create", + eq(this.runnable), this.securityContextCaptor.capture()); } public final void currentSecurityContextPowermockSetup() throws Exception { - spy(DelegatingSecurityContextCallable.class); - doReturn(wrappedCallable).when(DelegatingSecurityContextCallable.class, "create", - callable, null); - spy(DelegatingSecurityContextRunnable.class); - doReturn(wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create", - runnable, null); + PowerMockito.spy(DelegatingSecurityContextCallable.class); + PowerMockito.doReturn(this.wrappedCallable).when(DelegatingSecurityContextCallable.class, "create", + this.callable, null); + PowerMockito.spy(DelegatingSecurityContextRunnable.class); + PowerMockito.doReturn(this.wrappedRunnable).when(DelegatingSecurityContextRunnable.class, "create", + this.runnable, null); } @Before public final void setContext() { - SecurityContextHolder.setContext(currentSecurityContext); + SecurityContextHolder.setContext(this.currentSecurityContext); } @After public final void clearContext() { SecurityContextHolder.clearContext(); } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorServiceTests.java b/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorServiceTests.java index 6e757524ff..7579e28448 100644 --- a/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorServiceTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorServiceTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.junit.Before; +import org.springframework.security.core.context.SecurityContext; + /** * Tests using the current {@link SecurityContext} on * {@link DelegatingSecurityContextExecutorService} @@ -25,8 +28,8 @@ import org.junit.Before; * @since 3.2 * */ -public class CurrentDelegatingSecurityContextExecutorServiceTests extends - AbstractDelegatingSecurityContextExecutorServiceTests { +public class CurrentDelegatingSecurityContextExecutorServiceTests + extends AbstractDelegatingSecurityContextExecutorServiceTests { @Before public void setUp() throws Exception { @@ -35,6 +38,7 @@ public class CurrentDelegatingSecurityContextExecutorServiceTests extends @Override protected DelegatingSecurityContextExecutorService create() { - return new DelegatingSecurityContextExecutorService(delegate); + return new DelegatingSecurityContextExecutorService(this.delegate); } -} \ No newline at end of file + +} diff --git a/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorTests.java b/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorTests.java index 6e25d87ef9..65a597aa9b 100644 --- a/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextExecutorTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.junit.Before; +import org.springframework.security.core.context.SecurityContext; + /** * Tests using the current {@link SecurityContext} on * {@link DelegatingSecurityContextExecutor} @@ -25,8 +28,7 @@ import org.junit.Before; * @since 3.2 * */ -public class CurrentDelegatingSecurityContextExecutorTests extends - AbstractDelegatingSecurityContextExecutorTests { +public class CurrentDelegatingSecurityContextExecutorTests extends AbstractDelegatingSecurityContextExecutorTests { @Before public void setUp() throws Exception { @@ -37,4 +39,5 @@ public class CurrentDelegatingSecurityContextExecutorTests extends protected DelegatingSecurityContextExecutor create() { return new DelegatingSecurityContextExecutor(getExecutor()); } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextScheduledExecutorServiceTests.java b/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextScheduledExecutorServiceTests.java index 29b19cc47b..b3eb423d54 100644 --- a/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextScheduledExecutorServiceTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/CurrentDelegatingSecurityContextScheduledExecutorServiceTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.junit.Before; +import org.springframework.security.core.context.SecurityContext; + /** * Tests using the current {@link SecurityContext} on * {@link DelegatingSecurityContextScheduledExecutorService} @@ -25,8 +28,8 @@ import org.junit.Before; * @since 3.2 * */ -public class CurrentDelegatingSecurityContextScheduledExecutorServiceTests extends - AbstractDelegatingSecurityContextScheduledExecutorServiceTests { +public class CurrentDelegatingSecurityContextScheduledExecutorServiceTests + extends AbstractDelegatingSecurityContextScheduledExecutorServiceTests { @Before public void setUp() throws Exception { @@ -35,7 +38,7 @@ public class CurrentDelegatingSecurityContextScheduledExecutorServiceTests exten @Override protected DelegatingSecurityContextScheduledExecutorService create() { - return new DelegatingSecurityContextScheduledExecutorService(delegate); + return new DelegatingSecurityContextScheduledExecutorService(this.delegate); } } diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java index f121c47e91..1783ee4806 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.concurrent; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +package org.springframework.security.concurrent; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; @@ -32,20 +29,27 @@ import org.mockito.Mock; import org.mockito.internal.stubbing.answers.Returns; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.verify; + /** - * * @author Rob Winch * @since 3.2 */ @RunWith(MockitoJUnitRunner.class) public class DelegatingSecurityContextCallableTests { + @Mock private Callable delegate; + @Mock private SecurityContext securityContext; + @Mock private Object callableResult; @@ -58,15 +62,16 @@ public class DelegatingSecurityContextCallableTests { @Before @SuppressWarnings("serial") public void setUp() throws Exception { - originalSecurityContext = SecurityContextHolder.createEmptyContext(); - when(delegate.call()).thenAnswer(new Returns(callableResult) { + this.originalSecurityContext = SecurityContextHolder.createEmptyContext(); + given(this.delegate.call()).willAnswer(new Returns(this.callableResult) { @Override public Object answer(InvocationOnMock invocation) throws Throwable { - assertThat(SecurityContextHolder.getContext()).isEqualTo(securityContext); + assertThat(SecurityContextHolder.getContext()) + .isEqualTo(DelegatingSecurityContextCallableTests.this.securityContext); return super.answer(invocation); } }); - executor = Executors.newFixedThreadPool(1); + this.executor = Executors.newFixedThreadPool(1); } @After @@ -74,8 +79,6 @@ public class DelegatingSecurityContextCallableTests { SecurityContextHolder.clearContext(); } - // --- constructor --- - @Test(expected = IllegalArgumentException.class) public void constructorNullDelegate() { new DelegatingSecurityContextCallable<>(null); @@ -83,7 +86,7 @@ public class DelegatingSecurityContextCallableTests { @Test(expected = IllegalArgumentException.class) public void constructorNullDelegateNonNullSecurityContext() { - new DelegatingSecurityContextCallable<>(null, securityContext); + new DelegatingSecurityContextCallable<>(null, this.securityContext); } @Test(expected = IllegalArgumentException.class) @@ -93,42 +96,36 @@ public class DelegatingSecurityContextCallableTests { @Test(expected = IllegalArgumentException.class) public void constructorNullSecurityContext() { - new DelegatingSecurityContextCallable<>(delegate, null); + new DelegatingSecurityContextCallable<>(this.delegate, null); } - // --- call --- - @Test public void call() throws Exception { - callable = new DelegatingSecurityContextCallable<>(delegate, - securityContext); - assertWrapped(callable); + this.callable = new DelegatingSecurityContextCallable<>(this.delegate, this.securityContext); + assertWrapped(this.callable); } @Test public void callDefaultSecurityContext() throws Exception { - SecurityContextHolder.setContext(securityContext); - callable = new DelegatingSecurityContextCallable<>(delegate); + SecurityContextHolder.setContext(this.securityContext); + this.callable = new DelegatingSecurityContextCallable<>(this.delegate); SecurityContextHolder.clearContext(); // ensure callable is what sets up the // SecurityContextHolder - assertWrapped(callable); + assertWrapped(this.callable); } // SEC-3031 @Test public void callOnSameThread() throws Exception { - originalSecurityContext = securityContext; - SecurityContextHolder.setContext(originalSecurityContext); - callable = new DelegatingSecurityContextCallable<>(delegate, - securityContext); - assertWrapped(callable.call()); + this.originalSecurityContext = this.securityContext; + SecurityContextHolder.setContext(this.originalSecurityContext); + this.callable = new DelegatingSecurityContextCallable<>(this.delegate, this.securityContext); + assertWrapped(this.callable.call()); } - // --- create --- - @Test(expected = IllegalArgumentException.class) public void createNullDelegate() { - DelegatingSecurityContextCallable.create(null, securityContext); + DelegatingSecurityContextCallable.create(null, this.securityContext); } @Test(expected = IllegalArgumentException.class) @@ -138,37 +135,34 @@ public class DelegatingSecurityContextCallableTests { @Test public void createNullSecurityContext() throws Exception { - SecurityContextHolder.setContext(securityContext); - callable = DelegatingSecurityContextCallable.create(delegate, null); + SecurityContextHolder.setContext(this.securityContext); + this.callable = DelegatingSecurityContextCallable.create(this.delegate, null); SecurityContextHolder.clearContext(); // ensure callable is what sets up the // SecurityContextHolder - assertWrapped(callable); + assertWrapped(this.callable); } @Test public void create() throws Exception { - callable = DelegatingSecurityContextCallable.create(delegate, securityContext); - assertWrapped(callable); + this.callable = DelegatingSecurityContextCallable.create(this.delegate, this.securityContext); + assertWrapped(this.callable); } - // --- toString - // SEC-2682 @Test public void toStringDelegates() { - callable = new DelegatingSecurityContextCallable<>(delegate, - securityContext); - assertThat(callable.toString()).isEqualTo(delegate.toString()); + this.callable = new DelegatingSecurityContextCallable<>(this.delegate, this.securityContext); + assertThat(this.callable.toString()).isEqualTo(this.delegate.toString()); } private void assertWrapped(Callable callable) throws Exception { - Future submit = executor.submit(callable); + Future submit = this.executor.submit(callable); assertWrapped(submit.get()); } private void assertWrapped(Object callableResult) throws Exception { - verify(delegate).call(); - assertThat(SecurityContextHolder.getContext()).isEqualTo( - originalSecurityContext); + verify(this.delegate).call(); + assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext); } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java index 0540b568a9..3c47bc6416 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.concurrent; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.verify; +package org.springframework.security.concurrent; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -30,22 +27,29 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer; + import org.springframework.core.task.SyncTaskExecutor; import org.springframework.core.task.support.ExecutorServiceAdapter; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.Mockito.verify; + /** - * * @author Rob Winch * @since 3.2 */ @RunWith(MockitoJUnitRunner.class) public class DelegatingSecurityContextRunnableTests { + @Mock private Runnable delegate; + @Mock private SecurityContext securityContext; + @Mock private Object callableResult; @@ -57,13 +61,12 @@ public class DelegatingSecurityContextRunnableTests { @Before public void setUp() { - originalSecurityContext = SecurityContextHolder.createEmptyContext(); - doAnswer((Answer) invocation -> { - assertThat(SecurityContextHolder.getContext()).isEqualTo(securityContext); + this.originalSecurityContext = SecurityContextHolder.createEmptyContext(); + willAnswer((Answer) (invocation) -> { + assertThat(SecurityContextHolder.getContext()).isEqualTo(this.securityContext); return null; - }).when(delegate).run(); - - executor = Executors.newFixedThreadPool(1); + }).given(this.delegate).run(); + this.executor = Executors.newFixedThreadPool(1); } @After @@ -71,8 +74,6 @@ public class DelegatingSecurityContextRunnableTests { SecurityContextHolder.clearContext(); } - // --- constructor --- - @Test(expected = IllegalArgumentException.class) public void constructorNullDelegate() { new DelegatingSecurityContextRunnable(null); @@ -80,7 +81,7 @@ public class DelegatingSecurityContextRunnableTests { @Test(expected = IllegalArgumentException.class) public void constructorNullDelegateNonNullSecurityContext() { - new DelegatingSecurityContextRunnable(null, securityContext); + new DelegatingSecurityContextRunnable(null, this.securityContext); } @Test(expected = IllegalArgumentException.class) @@ -90,42 +91,37 @@ public class DelegatingSecurityContextRunnableTests { @Test(expected = IllegalArgumentException.class) public void constructorNullSecurityContext() { - new DelegatingSecurityContextRunnable(delegate, null); + new DelegatingSecurityContextRunnable(this.delegate, null); } - // --- run --- - @Test public void call() throws Exception { - runnable = new DelegatingSecurityContextRunnable(delegate, securityContext); - assertWrapped(runnable); + this.runnable = new DelegatingSecurityContextRunnable(this.delegate, this.securityContext); + assertWrapped(this.runnable); } @Test public void callDefaultSecurityContext() throws Exception { - SecurityContextHolder.setContext(securityContext); - runnable = new DelegatingSecurityContextRunnable(delegate); + SecurityContextHolder.setContext(this.securityContext); + this.runnable = new DelegatingSecurityContextRunnable(this.delegate); SecurityContextHolder.clearContext(); // ensure runnable is what sets up the // SecurityContextHolder - assertWrapped(runnable); + assertWrapped(this.runnable); } // SEC-3031 @Test public void callOnSameThread() throws Exception { - originalSecurityContext = securityContext; - SecurityContextHolder.setContext(originalSecurityContext); - executor = synchronousExecutor(); - runnable = new DelegatingSecurityContextRunnable(delegate, - securityContext); - assertWrapped(runnable); + this.originalSecurityContext = this.securityContext; + SecurityContextHolder.setContext(this.originalSecurityContext); + this.executor = synchronousExecutor(); + this.runnable = new DelegatingSecurityContextRunnable(this.delegate, this.securityContext); + assertWrapped(this.runnable); } - // --- create --- - @Test(expected = IllegalArgumentException.class) public void createNullDelegate() { - DelegatingSecurityContextRunnable.create(null, securityContext); + DelegatingSecurityContextRunnable.create(null, this.securityContext); } @Test(expected = IllegalArgumentException.class) @@ -135,37 +131,35 @@ public class DelegatingSecurityContextRunnableTests { @Test public void createNullSecurityContext() throws Exception { - SecurityContextHolder.setContext(securityContext); - runnable = DelegatingSecurityContextRunnable.create(delegate, null); + SecurityContextHolder.setContext(this.securityContext); + this.runnable = DelegatingSecurityContextRunnable.create(this.delegate, null); SecurityContextHolder.clearContext(); // ensure runnable is what sets up the // SecurityContextHolder - assertWrapped(runnable); + assertWrapped(this.runnable); } @Test public void create() throws Exception { - runnable = DelegatingSecurityContextRunnable.create(delegate, securityContext); - assertWrapped(runnable); + this.runnable = DelegatingSecurityContextRunnable.create(this.delegate, this.securityContext); + assertWrapped(this.runnable); } - // --- toString - // SEC-2682 @Test public void toStringDelegates() { - runnable = new DelegatingSecurityContextRunnable(delegate, securityContext); - assertThat(runnable.toString()).isEqualTo(delegate.toString()); + this.runnable = new DelegatingSecurityContextRunnable(this.delegate, this.securityContext); + assertThat(this.runnable.toString()).isEqualTo(this.delegate.toString()); } private void assertWrapped(Runnable runnable) throws Exception { - Future submit = executor.submit(runnable); + Future submit = this.executor.submit(runnable); submit.get(); - verify(delegate).run(); - assertThat(SecurityContextHolder.getContext()).isEqualTo( - originalSecurityContext); + verify(this.delegate).run(); + assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext); } private static ExecutorService synchronousExecutor() { return new ExecutorServiceAdapter(new SyncTaskExecutor()); } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextSupportTests.java b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextSupportTests.java index f18ecb1c9c..f311782752 100644 --- a/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextSupportTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextSupportTests.java @@ -13,57 +13,60 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; +import org.junit.Test; + +import org.springframework.security.core.context.SecurityContext; + import static org.assertj.core.api.Assertions.assertThat; -import org.junit.Test; -import org.springframework.security.core.context.SecurityContext; - /** - * * @author Rob Winch * @since 3.2 * */ -public class DelegatingSecurityContextSupportTests extends - AbstractDelegatingSecurityContextTestSupport { +public class DelegatingSecurityContextSupportTests extends AbstractDelegatingSecurityContextTestSupport { + private AbstractDelegatingSecurityContextSupport support; @Test public void wrapCallable() throws Exception { explicitSecurityContextPowermockSetup(); - support = new ConcreteDelegatingSecurityContextSupport(securityContext); - assertThat(support.wrap(callable)).isSameAs(wrappedCallable); - assertThat(securityContextCaptor.getValue()).isSameAs(securityContext); + this.support = new ConcreteDelegatingSecurityContextSupport(this.securityContext); + assertThat(this.support.wrap(this.callable)).isSameAs(this.wrappedCallable); + assertThat(this.securityContextCaptor.getValue()).isSameAs(this.securityContext); } @Test public void wrapCallableNullSecurityContext() throws Exception { currentSecurityContextPowermockSetup(); - support = new ConcreteDelegatingSecurityContextSupport(null); - assertThat(support.wrap(callable)).isSameAs(wrappedCallable); + this.support = new ConcreteDelegatingSecurityContextSupport(null); + assertThat(this.support.wrap(this.callable)).isSameAs(this.wrappedCallable); } @Test public void wrapRunnable() throws Exception { explicitSecurityContextPowermockSetup(); - support = new ConcreteDelegatingSecurityContextSupport(securityContext); - assertThat(support.wrap(runnable)).isSameAs(wrappedRunnable); - assertThat(securityContextCaptor.getValue()).isSameAs(securityContext); + this.support = new ConcreteDelegatingSecurityContextSupport(this.securityContext); + assertThat(this.support.wrap(this.runnable)).isSameAs(this.wrappedRunnable); + assertThat(this.securityContextCaptor.getValue()).isSameAs(this.securityContext); } @Test public void wrapRunnableNullSecurityContext() throws Exception { currentSecurityContextPowermockSetup(); - support = new ConcreteDelegatingSecurityContextSupport(null); - assertThat(support.wrap(runnable)).isSameAs(wrappedRunnable); + this.support = new ConcreteDelegatingSecurityContextSupport(null); + assertThat(this.support.wrap(this.runnable)).isSameAs(this.wrappedRunnable); } - private static class ConcreteDelegatingSecurityContextSupport extends - AbstractDelegatingSecurityContextSupport { + private static class ConcreteDelegatingSecurityContextSupport extends AbstractDelegatingSecurityContextSupport { + ConcreteDelegatingSecurityContextSupport(SecurityContext securityContext) { super(securityContext); } + } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorServiceTests.java b/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorServiceTests.java index 46f4d2be28..e2b2681e5c 100644 --- a/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorServiceTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorServiceTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.junit.Before; +import org.springframework.security.core.context.SecurityContext; + /** * Tests Explicitly specifying the {@link SecurityContext} on * {@link DelegatingSecurityContextExecutorService} @@ -25,8 +28,8 @@ import org.junit.Before; * @since 3.2 * */ -public class ExplicitDelegatingSecurityContextExecutorServiceTests extends - AbstractDelegatingSecurityContextExecutorServiceTests { +public class ExplicitDelegatingSecurityContextExecutorServiceTests + extends AbstractDelegatingSecurityContextExecutorServiceTests { @Before public void setUp() throws Exception { @@ -35,6 +38,7 @@ public class ExplicitDelegatingSecurityContextExecutorServiceTests extends @Override protected DelegatingSecurityContextExecutorService create() { - return new DelegatingSecurityContextExecutorService(delegate, securityContext); + return new DelegatingSecurityContextExecutorService(this.delegate, this.securityContext); } -} \ No newline at end of file + +} diff --git a/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorTests.java b/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorTests.java index f4eb4fc0bd..232012d27f 100644 --- a/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextExecutorTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.junit.Before; + import org.springframework.security.core.context.SecurityContext; /** @@ -26,8 +28,7 @@ import org.springframework.security.core.context.SecurityContext; * @since 3.2 * */ -public class ExplicitDelegatingSecurityContextExecutorTests extends - AbstractDelegatingSecurityContextExecutorTests { +public class ExplicitDelegatingSecurityContextExecutorTests extends AbstractDelegatingSecurityContextExecutorTests { @Before public void setUp() throws Exception { @@ -36,6 +37,7 @@ public class ExplicitDelegatingSecurityContextExecutorTests extends @Override protected DelegatingSecurityContextExecutor create() { - return new DelegatingSecurityContextExecutor(getExecutor(), securityContext); + return new DelegatingSecurityContextExecutor(getExecutor(), this.securityContext); } + } diff --git a/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextScheduledExecutorServiceTests.java b/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextScheduledExecutorServiceTests.java index a8058a6e0a..a59bd8164f 100644 --- a/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextScheduledExecutorServiceTests.java +++ b/core/src/test/java/org/springframework/security/concurrent/ExplicitDelegatingSecurityContextScheduledExecutorServiceTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.concurrent; import org.junit.Before; + import org.springframework.security.core.context.SecurityContext; /** @@ -26,8 +28,8 @@ import org.springframework.security.core.context.SecurityContext; * @since 3.2 * */ -public class ExplicitDelegatingSecurityContextScheduledExecutorServiceTests extends - AbstractDelegatingSecurityContextScheduledExecutorServiceTests { +public class ExplicitDelegatingSecurityContextScheduledExecutorServiceTests + extends AbstractDelegatingSecurityContextScheduledExecutorServiceTests { @Before public void setUp() throws Exception { @@ -36,7 +38,7 @@ public class ExplicitDelegatingSecurityContextScheduledExecutorServiceTests exte @Override protected DelegatingSecurityContextScheduledExecutorService create() { - return new DelegatingSecurityContextScheduledExecutorService(delegate, - securityContext); + return new DelegatingSecurityContextScheduledExecutorService(this.delegate, this.securityContext); } -} \ No newline at end of file + +} diff --git a/core/src/test/java/org/springframework/security/context/DelegatingApplicationListenerTests.java b/core/src/test/java/org/springframework/security/context/DelegatingApplicationListenerTests.java index ed3cd8bbdc..86d242ad66 100644 --- a/core/src/test/java/org/springframework/security/context/DelegatingApplicationListenerTests.java +++ b/core/src/test/java/org/springframework/security/context/DelegatingApplicationListenerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.context; import org.junit.Before; @@ -20,16 +21,18 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.context.ApplicationEvent; import org.springframework.context.event.SmartApplicationListener; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class DelegatingApplicationListenerTests { + @Mock SmartApplicationListener delegate; @@ -39,46 +42,42 @@ public class DelegatingApplicationListenerTests { @Before public void setup() { - event = new ApplicationEvent(this) { + this.event = new ApplicationEvent(this) { }; - listener = new DelegatingApplicationListener(); - listener.addListener(delegate); + this.listener = new DelegatingApplicationListener(); + this.listener.addListener(this.delegate); } @Test public void processEventNull() { - listener.onApplicationEvent(null); - - verify(delegate, never()).onApplicationEvent(any(ApplicationEvent.class)); + this.listener.onApplicationEvent(null); + verify(this.delegate, never()).onApplicationEvent(any(ApplicationEvent.class)); } @Test public void processEventSuccess() { - when(delegate.supportsEventType(event.getClass())).thenReturn(true); - when(delegate.supportsSourceType(event.getSource().getClass())).thenReturn(true); - listener.onApplicationEvent(event); - - verify(delegate).onApplicationEvent(event); + given(this.delegate.supportsEventType(this.event.getClass())).willReturn(true); + given(this.delegate.supportsSourceType(this.event.getSource().getClass())).willReturn(true); + this.listener.onApplicationEvent(this.event); + verify(this.delegate).onApplicationEvent(this.event); } @Test public void processEventEventTypeNotSupported() { - listener.onApplicationEvent(event); - - verify(delegate, never()).onApplicationEvent(any(ApplicationEvent.class)); + this.listener.onApplicationEvent(this.event); + verify(this.delegate, never()).onApplicationEvent(any(ApplicationEvent.class)); } @Test public void processEventSourceTypeNotSupported() { - when(delegate.supportsEventType(event.getClass())).thenReturn(true); - listener.onApplicationEvent(event); - - verify(delegate, never()).onApplicationEvent(any(ApplicationEvent.class)); + given(this.delegate.supportsEventType(this.event.getClass())).willReturn(true); + this.listener.onApplicationEvent(this.event); + verify(this.delegate, never()).onApplicationEvent(any(ApplicationEvent.class)); } @Test(expected = IllegalArgumentException.class) public void addNull() { - listener.addListener(null); + this.listener.addListener(null); } } diff --git a/core/src/test/java/org/springframework/security/converter/RsaKeyConvertersTest.java b/core/src/test/java/org/springframework/security/converter/RsaKeyConvertersTest.java deleted file mode 100644 index f1ba277e12..0000000000 --- a/core/src/test/java/org/springframework/security/converter/RsaKeyConvertersTest.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.converter; - -import java.io.ByteArrayInputStream; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.security.interfaces.RSAPrivateCrtKey; -import java.security.interfaces.RSAPrivateKey; -import java.security.interfaces.RSAPublicKey; - -import org.assertj.core.api.Assertions; -import org.assertj.core.api.AssertionsForClassTypes; -import org.junit.Test; - -import org.springframework.core.convert.converter.Converter; - -/** - * Tests for {@link RsaKeyConverters} - */ -public class RsaKeyConvertersTest { - private static final String PKCS8_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n" + - "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCMk7CKSTfu3QoV\n" + - "HoPVXxwZO+qweztd36cVWYqGOZinrOR2crWFu50AgR2CsdIH0+cqo7F4Vx7/3O8i\n" + - "RpYYZPe2VoO5sumzJt8P6fS80/TAKjhJDAqgZKRJTgGN8KxCM6p/aJli1ZeDBqiV\n" + - "v7vJJe+ZgJuPGRS+HMNa/wPxEkqqXsglcJcQV1ZEtfKXSHB7jizKpRL38185SyAC\n" + - "pwyjvBu6Cmm1URfhQo88mf239ONh4dZ2HoDfzN1q6Ssu4F4hgutxr9B0DVLDP5u+\n" + - "WFrm3nsJ76zf99uJ+ntMUHJ+bY+gOjSlVWIVBIZeAaEGKCNWRk/knjvjbijpvm3U\n" + - "acGlgdL3AgMBAAECggEACxxxS7zVyu91qI2s5eSKmAQAXMqgup6+2hUluc47nqUv\n" + - "uZz/c/6MPkn2Ryo+65d4IgqmMFjSfm68B/2ER5FTcvoLl1Xo2twrrVpUmcg3BClS\n" + - "IZPuExdhVNnxjYKEWwcyZrehyAoR261fDdcFxLRW588efIUC+rPTTRHzAc7sT+Ln\n" + - "t/uFeYNWJm3LaegOLoOmlMAhJ5puAWSN1F0FxtRf/RVgzbLA9QC975SKHJsfWCSr\n" + - "IZyPsdeaqomKaF65l8nfqlE0Ua2L35gIOGKjUwb7uUE8nI362RWMtYdoi3zDDyoY\n" + - "hSFbgjylCHDM0u6iSh6KfqOHtkYyJ8tUYgVWl787wQKBgQDYO3wL7xuDdD101Lyl\n" + - "AnaDdFB9fxp83FG1cWr+t7LYm9YxGfEUsKHAJXN6TIayDkOOoVwIl+Gz0T3Z06Bm\n" + - "eBGLrB9mrVA7+C7NJwu5gTMlzP6HxUR9zKJIQ/VB1NUGM77LSmvOFbHc9Q0+z8EH\n" + - "X5WO516a3Z7lNtZJcCoPOtu2rwKBgQCmbj41Fh+SSEUApCEKms5ETRpe7LXQlJgx\n" + - "yW7zcJNNuIb1C3vBLPxjiOTMgYKOeMg5rtHTGLT43URHLh9ArjawasjSAr4AM3J4\n" + - "xpoi/sKGDdiKOsuDWIGfzdYL8qyTHSdpZLQsCTMRiRYgAHZFPgNa7SLZRfZicGlr\n" + - "GHN1rJW6OQKBgEjiM/upyrJSWeypUDSmUeAZMpA6aWkwsfHgmtnkfUn5rQa74cDB\n" + - "kKO9e+D7LmOR3z+SL/1NhGwh2SE07dncGr3jdGodfO/ZxZyszozmeaECKcEFwwJM\n" + - "GV8WWPKplGwUwPiwywmZ0mvRxXcoe73KgBS88+xrSwWjqDL0tZiQlEJNAoGATkei\n" + - "GMQMG3jEg9Wu+NbxV6zQT3+U0MNjhl9RQU1c63x0dcNt9OFc4NAdlZcAulRTENaK\n" + - "OHjxffBM0hH+fySx8m53gFfr2BpaqDX5f6ZGBlly1SlsWZ4CchCVsc71nshipi7I\n" + - "k8HL9F5/OpQdDNprJ5RMBNfkWE65Nrcsb1e6oPkCgYAxwgdiSOtNg8PjDVDmAhwT\n" + - "Mxj0Dtwi2fAqQ76RVrrXpNp3uCOIAu4CfruIb5llcJ3uak0ZbnWri32AxSgk80y3\n" + - "EWiRX/WEDu5znejF+5O3pI02atWWcnxifEKGGlxwkcMbQdA67MlrJLFaSnnGpNXo\n" + - "yPfcul058SOqhafIZQMEKQ==\n" + - "-----END PRIVATE KEY-----"; - - private static final String PKCS1_PRIVATE_KEY = - "-----BEGIN RSA PRIVATE KEY-----\n" + - "MIICWwIBAAKBgQDdlatRjRjogo3WojgGHFHYLugdUWAY9iR3fy4arWNA1KoS8kVw\n" + - "33cJibXr8bvwUAUparCwlvdbH6dvEOfou0/gCFQsHUfQrSDv+MuSUMAe8jzKE4qW\n" + - "+jK+xQU9a03GUnKHkkle+Q0pX/g6jXZ7r1/xAK5Do2kQ+X5xK9cipRgEKwIDAQAB\n" + - "AoGAD+onAtVye4ic7VR7V50DF9bOnwRwNXrARcDhq9LWNRrRGElESYYTQ6EbatXS\n" + - "3MCyjjX2eMhu/aF5YhXBwkppwxg+EOmXeh+MzL7Zh284OuPbkglAaGhV9bb6/5Cp\n" + - "uGb1esyPbYW+Ty2PC0GSZfIXkXs76jXAu9TOBvD0ybc2YlkCQQDywg2R/7t3Q2OE\n" + - "2+yo382CLJdrlSLVROWKwb4tb2PjhY4XAwV8d1vy0RenxTB+K5Mu57uVSTHtrMK0\n" + - "GAtFr833AkEA6avx20OHo61Yela/4k5kQDtjEf1N0LfI+BcWZtxsS3jDM3i1Hp0K\n" + - "Su5rsCPb8acJo5RO26gGVrfAsDcIXKC+bQJAZZ2XIpsitLyPpuiMOvBbzPavd4gY\n" + - "6Z8KWrfYzJoI/Q9FuBo6rKwl4BFoToD7WIUS+hpkagwWiz+6zLoX1dbOZwJACmH5\n" + - "fSSjAkLRi54PKJ8TFUeOP15h9sQzydI8zJU+upvDEKZsZc/UhT/SySDOxQ4G/523\n" + - "Y0sz/OZtSWcol/UMgQJALesy++GdvoIDLfJX5GBQpuFgFenRiRDabxrE9MNUZ2aP\n" + - "FaFp+DyAe+b4nDwuJaW2LURbr8AEZga7oQj0uYxcYw==\n" + - "-----END RSA PRIVATE KEY-----"; - - private static final String X509_PUBLIC_KEY = - "-----BEGIN PUBLIC KEY-----\n" + - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDdlatRjRjogo3WojgGHFHYLugd\n" + - "UWAY9iR3fy4arWNA1KoS8kVw33cJibXr8bvwUAUparCwlvdbH6dvEOfou0/gCFQs\n" + - "HUfQrSDv+MuSUMAe8jzKE4qW+jK+xQU9a03GUnKHkkle+Q0pX/g6jXZ7r1/xAK5D\n" + - "o2kQ+X5xK9cipRgEKwIDAQAB\n" + - "-----END PUBLIC KEY-----"; - - private static final String MALFORMED_X509_KEY = "malformed"; - - private final Converter x509 = RsaKeyConverters.x509(); - private final Converter pkcs8 = RsaKeyConverters.pkcs8(); - - @Test - public void pkcs8WhenConvertingPkcs8PrivateKeyThenOk() { - RSAPrivateKey key = this.pkcs8.convert(toInputStream(PKCS8_PRIVATE_KEY)); - Assertions.assertThat(key).isInstanceOf(RSAPrivateCrtKey.class); - Assertions.assertThat(key.getModulus().bitLength()).isEqualTo(2048); - } - - @Test - public void pkcs8WhenConvertingPkcs1PrivateKeyThenIllegalArgumentException() { - AssertionsForClassTypes.assertThatCode(() -> this.pkcs8.convert(toInputStream(PKCS1_PRIVATE_KEY))) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void x509WhenConverteringX509PublicKeyThenOk() { - RSAPublicKey key = this.x509.convert(toInputStream(X509_PUBLIC_KEY)); - Assertions.assertThat(key.getModulus().bitLength()).isEqualTo(1024); - } - - @Test - public void x509WhenConvertingDerEncodedX509PublicKeyThenIllegalArgumentException() { - AssertionsForClassTypes.assertThatCode(() -> this.x509.convert(toInputStream(MALFORMED_X509_KEY))) - .isInstanceOf(IllegalArgumentException.class); - } - - private static InputStream toInputStream(String string) { - return new ByteArrayInputStream(string.getBytes(StandardCharsets.UTF_8)); - } -} diff --git a/core/src/test/java/org/springframework/security/converter/RsaKeyConvertersTests.java b/core/src/test/java/org/springframework/security/converter/RsaKeyConvertersTests.java new file mode 100644 index 0000000000..3d70a84d93 --- /dev/null +++ b/core/src/test/java/org/springframework/security/converter/RsaKeyConvertersTests.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.converter; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.interfaces.RSAPrivateCrtKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; + +import org.assertj.core.api.Assertions; +import org.junit.Test; + +import org.springframework.core.convert.converter.Converter; + +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link RsaKeyConverters} + */ +public class RsaKeyConvertersTests { + + // @formatter:off + private static final String PKCS8_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n" + + "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCMk7CKSTfu3QoV\n" + + "HoPVXxwZO+qweztd36cVWYqGOZinrOR2crWFu50AgR2CsdIH0+cqo7F4Vx7/3O8i\n" + + "RpYYZPe2VoO5sumzJt8P6fS80/TAKjhJDAqgZKRJTgGN8KxCM6p/aJli1ZeDBqiV\n" + + "v7vJJe+ZgJuPGRS+HMNa/wPxEkqqXsglcJcQV1ZEtfKXSHB7jizKpRL38185SyAC\n" + + "pwyjvBu6Cmm1URfhQo88mf239ONh4dZ2HoDfzN1q6Ssu4F4hgutxr9B0DVLDP5u+\n" + + "WFrm3nsJ76zf99uJ+ntMUHJ+bY+gOjSlVWIVBIZeAaEGKCNWRk/knjvjbijpvm3U\n" + + "acGlgdL3AgMBAAECggEACxxxS7zVyu91qI2s5eSKmAQAXMqgup6+2hUluc47nqUv\n" + + "uZz/c/6MPkn2Ryo+65d4IgqmMFjSfm68B/2ER5FTcvoLl1Xo2twrrVpUmcg3BClS\n" + + "IZPuExdhVNnxjYKEWwcyZrehyAoR261fDdcFxLRW588efIUC+rPTTRHzAc7sT+Ln\n" + + "t/uFeYNWJm3LaegOLoOmlMAhJ5puAWSN1F0FxtRf/RVgzbLA9QC975SKHJsfWCSr\n" + + "IZyPsdeaqomKaF65l8nfqlE0Ua2L35gIOGKjUwb7uUE8nI362RWMtYdoi3zDDyoY\n" + + "hSFbgjylCHDM0u6iSh6KfqOHtkYyJ8tUYgVWl787wQKBgQDYO3wL7xuDdD101Lyl\n" + + "AnaDdFB9fxp83FG1cWr+t7LYm9YxGfEUsKHAJXN6TIayDkOOoVwIl+Gz0T3Z06Bm\n" + + "eBGLrB9mrVA7+C7NJwu5gTMlzP6HxUR9zKJIQ/VB1NUGM77LSmvOFbHc9Q0+z8EH\n" + + "X5WO516a3Z7lNtZJcCoPOtu2rwKBgQCmbj41Fh+SSEUApCEKms5ETRpe7LXQlJgx\n" + + "yW7zcJNNuIb1C3vBLPxjiOTMgYKOeMg5rtHTGLT43URHLh9ArjawasjSAr4AM3J4\n" + + "xpoi/sKGDdiKOsuDWIGfzdYL8qyTHSdpZLQsCTMRiRYgAHZFPgNa7SLZRfZicGlr\n" + + "GHN1rJW6OQKBgEjiM/upyrJSWeypUDSmUeAZMpA6aWkwsfHgmtnkfUn5rQa74cDB\n" + + "kKO9e+D7LmOR3z+SL/1NhGwh2SE07dncGr3jdGodfO/ZxZyszozmeaECKcEFwwJM\n" + + "GV8WWPKplGwUwPiwywmZ0mvRxXcoe73KgBS88+xrSwWjqDL0tZiQlEJNAoGATkei\n" + + "GMQMG3jEg9Wu+NbxV6zQT3+U0MNjhl9RQU1c63x0dcNt9OFc4NAdlZcAulRTENaK\n" + + "OHjxffBM0hH+fySx8m53gFfr2BpaqDX5f6ZGBlly1SlsWZ4CchCVsc71nshipi7I\n" + + "k8HL9F5/OpQdDNprJ5RMBNfkWE65Nrcsb1e6oPkCgYAxwgdiSOtNg8PjDVDmAhwT\n" + + "Mxj0Dtwi2fAqQ76RVrrXpNp3uCOIAu4CfruIb5llcJ3uak0ZbnWri32AxSgk80y3\n" + + "EWiRX/WEDu5znejF+5O3pI02atWWcnxifEKGGlxwkcMbQdA67MlrJLFaSnnGpNXo\n" + + "yPfcul058SOqhafIZQMEKQ==\n" + + "-----END PRIVATE KEY-----"; + // @formatter:on + + // @formatter:off + private static final String PKCS1_PRIVATE_KEY = "-----BEGIN RSA PRIVATE KEY-----\n" + + "MIICWwIBAAKBgQDdlatRjRjogo3WojgGHFHYLugdUWAY9iR3fy4arWNA1KoS8kVw\n" + + "33cJibXr8bvwUAUparCwlvdbH6dvEOfou0/gCFQsHUfQrSDv+MuSUMAe8jzKE4qW\n" + + "+jK+xQU9a03GUnKHkkle+Q0pX/g6jXZ7r1/xAK5Do2kQ+X5xK9cipRgEKwIDAQAB\n" + + "AoGAD+onAtVye4ic7VR7V50DF9bOnwRwNXrARcDhq9LWNRrRGElESYYTQ6EbatXS\n" + + "3MCyjjX2eMhu/aF5YhXBwkppwxg+EOmXeh+MzL7Zh284OuPbkglAaGhV9bb6/5Cp\n" + + "uGb1esyPbYW+Ty2PC0GSZfIXkXs76jXAu9TOBvD0ybc2YlkCQQDywg2R/7t3Q2OE\n" + + "2+yo382CLJdrlSLVROWKwb4tb2PjhY4XAwV8d1vy0RenxTB+K5Mu57uVSTHtrMK0\n" + + "GAtFr833AkEA6avx20OHo61Yela/4k5kQDtjEf1N0LfI+BcWZtxsS3jDM3i1Hp0K\n" + + "Su5rsCPb8acJo5RO26gGVrfAsDcIXKC+bQJAZZ2XIpsitLyPpuiMOvBbzPavd4gY\n" + + "6Z8KWrfYzJoI/Q9FuBo6rKwl4BFoToD7WIUS+hpkagwWiz+6zLoX1dbOZwJACmH5\n" + + "fSSjAkLRi54PKJ8TFUeOP15h9sQzydI8zJU+upvDEKZsZc/UhT/SySDOxQ4G/523\n" + + "Y0sz/OZtSWcol/UMgQJALesy++GdvoIDLfJX5GBQpuFgFenRiRDabxrE9MNUZ2aP\n" + + "FaFp+DyAe+b4nDwuJaW2LURbr8AEZga7oQj0uYxcYw==\n" + + "-----END RSA PRIVATE KEY-----"; + // @formatter:on + + // @formatter:off + private static final String X509_PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDdlatRjRjogo3WojgGHFHYLugd\n" + + "UWAY9iR3fy4arWNA1KoS8kVw33cJibXr8bvwUAUparCwlvdbH6dvEOfou0/gCFQs\n" + + "HUfQrSDv+MuSUMAe8jzKE4qW+jK+xQU9a03GUnKHkkle+Q0pX/g6jXZ7r1/xAK5D\n" + + "o2kQ+X5xK9cipRgEKwIDAQAB\n" + + "-----END PUBLIC KEY-----"; + // @formatter:on + + private static final String MALFORMED_X509_KEY = "malformed"; + + private final Converter x509 = RsaKeyConverters.x509(); + + private final Converter pkcs8 = RsaKeyConverters.pkcs8(); + + @Test + public void pkcs8WhenConvertingPkcs8PrivateKeyThenOk() { + RSAPrivateKey key = this.pkcs8.convert(toInputStream(PKCS8_PRIVATE_KEY)); + Assertions.assertThat(key).isInstanceOf(RSAPrivateCrtKey.class); + Assertions.assertThat(key.getModulus().bitLength()).isEqualTo(2048); + } + + @Test + public void pkcs8WhenConvertingPkcs1PrivateKeyThenIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.pkcs8.convert(toInputStream(PKCS1_PRIVATE_KEY))); + } + + @Test + public void x509WhenConverteringX509PublicKeyThenOk() { + RSAPublicKey key = this.x509.convert(toInputStream(X509_PUBLIC_KEY)); + Assertions.assertThat(key.getModulus().bitLength()).isEqualTo(1024); + } + + @Test + public void x509WhenConvertingDerEncodedX509PublicKeyThenIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.x509.convert(toInputStream(MALFORMED_X509_KEY))); + } + + private static InputStream toInputStream(String string) { + return new ByteArrayInputStream(string.getBytes(StandardCharsets.UTF_8)); + } + +} diff --git a/core/src/test/java/org/springframework/security/core/JavaVersionTests.java b/core/src/test/java/org/springframework/security/core/JavaVersionTests.java index f7a6ab7057..5f2602f7ab 100644 --- a/core/src/test/java/org/springframework/security/core/JavaVersionTests.java +++ b/core/src/test/java/org/springframework/security/core/JavaVersionTests.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core; -import org.junit.Test; +package org.springframework.security.core; import java.io.DataInputStream; import java.io.InputStream; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** - * * @author Rob Winch * */ @@ -47,4 +47,5 @@ public class JavaVersionTests { assertThat(major).isEqualTo(JDK8_CLASS_VERSION); } } + } diff --git a/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java b/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java index 27d137b690..33b297c385 100644 --- a/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java +++ b/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java @@ -13,16 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.powermock.api.mockito.PowerMockito.doReturn; -import static org.powermock.api.mockito.PowerMockito.spy; +package org.springframework.security.core; import org.apache.commons.logging.Log; import org.junit.After; @@ -30,11 +22,20 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.reflect.Whitebox; + import org.springframework.core.SpringVersion; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; + /** * Checks that the embedded version information is up to date. * @@ -50,7 +51,7 @@ public class SpringSecurityCoreVersionTests { @Before public void setup() { - Whitebox.setInternalState(SpringSecurityCoreVersion.class, logger); + Whitebox.setInternalState(SpringSecurityCoreVersion.class, this.logger); } @After @@ -62,98 +63,81 @@ public class SpringSecurityCoreVersionTests { public void springVersionIsUpToDate() { // Property is set by the build script String springVersion = System.getProperty("springVersion"); - assertThat(SpringSecurityCoreVersion.MIN_SPRING_VERSION).isEqualTo(springVersion); } @Test public void serialVersionMajorAndMinorVersionMatchBuildVersion() { String version = System.getProperty("springSecurityVersion"); - // Strip patch version - String serialVersion = String.valueOf( - SpringSecurityCoreVersion.SERIAL_VERSION_UID).substring(0, 2); - + String serialVersion = String.valueOf(SpringSecurityCoreVersion.SERIAL_VERSION_UID).substring(0, 2); assertThat(serialVersion.charAt(0)).isEqualTo(version.charAt(0)); assertThat(serialVersion.charAt(1)).isEqualTo(version.charAt(2)); - } // SEC-2295 @Test public void noLoggingIfVersionsAreEqual() throws Exception { String version = "1"; - spy(SpringSecurityCoreVersion.class); - spy(SpringVersion.class); - doReturn(version).when(SpringSecurityCoreVersion.class, "getVersion"); - doReturn(version).when(SpringVersion.class, "getVersion"); - + PowerMockito.spy(SpringSecurityCoreVersion.class); + PowerMockito.spy(SpringVersion.class); + PowerMockito.doReturn(version).when(SpringSecurityCoreVersion.class, "getVersion"); + PowerMockito.doReturn(version).when(SpringVersion.class, "getVersion"); performChecks(); - - verifyZeroInteractions(logger); + verifyZeroInteractions(this.logger); } @Test public void noLoggingIfSpringVersionNull() throws Exception { - spy(SpringSecurityCoreVersion.class); - spy(SpringVersion.class); - doReturn("1").when(SpringSecurityCoreVersion.class, "getVersion"); - doReturn(null).when(SpringVersion.class, "getVersion"); - + PowerMockito.spy(SpringSecurityCoreVersion.class); + PowerMockito.spy(SpringVersion.class); + PowerMockito.doReturn("1").when(SpringSecurityCoreVersion.class, "getVersion"); + PowerMockito.doReturn(null).when(SpringVersion.class, "getVersion"); performChecks(); - - verifyZeroInteractions(logger); + verifyZeroInteractions(this.logger); } @Test public void warnIfSpringVersionTooSmall() throws Exception { - spy(SpringSecurityCoreVersion.class); - spy(SpringVersion.class); - doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion"); - doReturn("2").when(SpringVersion.class, "getVersion"); - + PowerMockito.spy(SpringSecurityCoreVersion.class); + PowerMockito.spy(SpringVersion.class); + PowerMockito.doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion"); + PowerMockito.doReturn("2").when(SpringVersion.class, "getVersion"); performChecks(); - - verify(logger, times(1)).warn(any()); + verify(this.logger, times(1)).warn(any()); } @Test public void noWarnIfSpringVersionLarger() throws Exception { - spy(SpringSecurityCoreVersion.class); - spy(SpringVersion.class); - doReturn("4.0.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion"); - doReturn("4.0.0.RELEASE").when(SpringVersion.class, "getVersion"); - + PowerMockito.spy(SpringSecurityCoreVersion.class); + PowerMockito.spy(SpringVersion.class); + PowerMockito.doReturn("4.0.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion"); + PowerMockito.doReturn("4.0.0.RELEASE").when(SpringVersion.class, "getVersion"); performChecks(); - - verify(logger, never()).warn(any()); + verify(this.logger, never()).warn(any()); } // SEC-2697 @Test public void noWarnIfSpringPatchVersionDoubleDigits() throws Exception { String minSpringVersion = "3.2.8.RELEASE"; - spy(SpringSecurityCoreVersion.class); - spy(SpringVersion.class); - doReturn("3.2.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion"); - doReturn("3.2.10.RELEASE").when(SpringVersion.class, "getVersion"); - + PowerMockito.spy(SpringSecurityCoreVersion.class); + PowerMockito.spy(SpringVersion.class); + PowerMockito.doReturn("3.2.0.RELEASE").when(SpringSecurityCoreVersion.class, "getVersion"); + PowerMockito.doReturn("3.2.10.RELEASE").when(SpringVersion.class, "getVersion"); performChecks(minSpringVersion); - - verify(logger, never()).warn(any()); + verify(this.logger, never()).warn(any()); } @Test public void noLoggingIfPropertySet() throws Exception { - spy(SpringSecurityCoreVersion.class); - spy(SpringVersion.class); - doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion"); - doReturn("2").when(SpringVersion.class, "getVersion"); + PowerMockito.spy(SpringSecurityCoreVersion.class); + PowerMockito.spy(SpringVersion.class); + PowerMockito.doReturn("3").when(SpringSecurityCoreVersion.class, "getVersion"); + PowerMockito.doReturn("2").when(SpringVersion.class, "getVersion"); System.setProperty(getDisableChecksProperty(), Boolean.TRUE.toString()); - performChecks(); - - verifyZeroInteractions(logger); + verifyZeroInteractions(this.logger); } private String getDisableChecksProperty() { @@ -165,7 +149,7 @@ public class SpringSecurityCoreVersionTests { } private void performChecks(String minSpringVersion) throws Exception { - Whitebox.invokeMethod(SpringSecurityCoreVersion.class, "performVersionChecks", - minSpringVersion); + Whitebox.invokeMethod(SpringSecurityCoreVersion.class, "performVersionChecks", minSpringVersion); } + } diff --git a/core/src/test/java/org/springframework/security/core/SpringSecurityMessageSourceTests.java b/core/src/test/java/org/springframework/security/core/SpringSecurityMessageSourceTests.java index dca405de83..b795c3609b 100644 --- a/core/src/test/java/org/springframework/security/core/SpringSecurityMessageSourceTests.java +++ b/core/src/test/java/org/springframework/security/core/SpringSecurityMessageSourceTests.java @@ -16,27 +16,25 @@ package org.springframework.security.core; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Locale; import org.junit.Test; + import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.context.support.MessageSourceAccessor; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link org.springframework.security.core.SpringSecurityMessageSource}. */ public class SpringSecurityMessageSourceTests { - // ~ Methods - // ======================================================================================================== @Test public void testOperation() { SpringSecurityMessageSource msgs = new SpringSecurityMessageSource(); assertThat("\u4E0D\u5141\u8BB8\u8BBF\u95EE").isEqualTo( - msgs.getMessage("AbstractAccessDecisionManager.accessDenied", null, - Locale.SIMPLIFIED_CHINESE)); + msgs.getMessage("AbstractAccessDecisionManager.accessDenied", null, Locale.SIMPLIFIED_CHINESE)); } @Test @@ -44,13 +42,10 @@ public class SpringSecurityMessageSourceTests { // Change Locale to English Locale before = LocaleContextHolder.getLocale(); LocaleContextHolder.setLocale(Locale.FRENCH); - // Cause a message to be generated MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); - assertThat("Le jeton nonce est compromis FOOBAR").isEqualTo( - messages.getMessage("DigestAuthenticationFilter.nonceCompromised", - new Object[] { "FOOBAR" }, "ERROR - FAILED TO LOOKUP")); - + assertThat("Le jeton nonce est compromis FOOBAR").isEqualTo(messages.getMessage( + "DigestAuthenticationFilter.nonceCompromised", new Object[] { "FOOBAR" }, "ERROR - FAILED TO LOOKUP")); // Revert to original Locale LocaleContextHolder.setLocale(before); } @@ -60,16 +55,14 @@ public class SpringSecurityMessageSourceTests { public void germanSystemLocaleWithEnglishLocaleContextHolder() { Locale beforeSystem = Locale.getDefault(); Locale.setDefault(Locale.GERMAN); - Locale beforeHolder = LocaleContextHolder.getLocale(); LocaleContextHolder.setLocale(Locale.US); - MessageSourceAccessor msgs = SpringSecurityMessageSource.getAccessor(); - assertThat("Access is denied").isEqualTo( - msgs.getMessage("AbstractAccessDecisionManager.accessDenied", "Ooops")); - + assertThat("Access is denied") + .isEqualTo(msgs.getMessage("AbstractAccessDecisionManager.accessDenied", "Ooops")); // Revert to original Locale Locale.setDefault(beforeSystem); LocaleContextHolder.setLocale(beforeHolder); } + } diff --git a/core/src/test/java/org/springframework/security/core/authority/AuthorityUtilsTests.java b/core/src/test/java/org/springframework/security/core/authority/AuthorityUtilsTests.java index 34ad79bf24..34af08348f 100644 --- a/core/src/test/java/org/springframework/security/core/authority/AuthorityUtilsTests.java +++ b/core/src/test/java/org/springframework/security/core/authority/AuthorityUtilsTests.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core.authority; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.core.authority; import java.util.List; import java.util.Set; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.AuthorityUtils; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor @@ -33,13 +34,12 @@ public class AuthorityUtilsTests { public void commaSeparatedStringIsParsedCorrectly() { List authorityArray = AuthorityUtils .commaSeparatedStringToAuthorityList(" ROLE_A, B, C, ROLE_D\n,\n E "); - Set authorities = AuthorityUtils.authorityListToSet(authorityArray); - assertThat(authorities.contains("B")).isTrue(); assertThat(authorities.contains("C")).isTrue(); assertThat(authorities.contains("E")).isTrue(); assertThat(authorities.contains("ROLE_A")).isTrue(); assertThat(authorities.contains("ROLE_D")).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/core/authority/SimpleGrantedAuthorityTests.java b/core/src/test/java/org/springframework/security/core/authority/SimpleGrantedAuthorityTests.java index ab065d1a9d..09177a9fef 100644 --- a/core/src/test/java/org/springframework/security/core/authority/SimpleGrantedAuthorityTests.java +++ b/core/src/test/java/org/springframework/security/core/authority/SimpleGrantedAuthorityTests.java @@ -16,12 +16,13 @@ package org.springframework.security.core.authority; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import org.junit.Test; -import org.junit.*; import org.springframework.security.core.GrantedAuthority; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + /** * Tests {@link SimpleGrantedAuthority}. * @@ -34,14 +35,10 @@ public class SimpleGrantedAuthorityTests { SimpleGrantedAuthority auth1 = new SimpleGrantedAuthority("TEST"); assertThat(auth1).isEqualTo(auth1); assertThat(new SimpleGrantedAuthority("TEST")).isEqualTo(auth1); - assertThat(auth1.equals("TEST")).isFalse(); - SimpleGrantedAuthority auth3 = new SimpleGrantedAuthority("NOT_EQUAL"); assertThat(!auth1.equals(auth3)).isTrue(); - assertThat(auth1.equals(mock(GrantedAuthority.class))).isFalse(); - assertThat(auth1.equals(222)).isFalse(); } diff --git a/core/src/test/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapperTests.java b/core/src/test/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapperTests.java index dba49f94a2..b57bbc3801 100644 --- a/core/src/test/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapperTests.java +++ b/core/src/test/java/org/springframework/security/core/authority/mapping/MapBasedAttributes2GrantedAuthoritiesMapperTests.java @@ -13,18 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; -import static org.assertj.core.api.Assertions.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; /** - * * @author Ruud Senden */ @SuppressWarnings("unchecked") @@ -164,10 +169,10 @@ public class MapBasedAttributes2GrantedAuthoritiesMapperTests { @Test public void testMappingCombination() throws Exception { - String[] roles = { "role1", "role2", "role3", "role4", "role5", "role6", "role7", - "role8", "role9", "role10", "role11" }; - String[] expectedGas = { "ga1", "ga2", "ga3", "ga4", "ga5", "ga6", "ga7", "ga8", - "ga9", "ga10", "ga11", "ga12", "ga13", "ga14" }; + String[] roles = { "role1", "role2", "role3", "role4", "role5", "role6", "role7", "role8", "role9", "role10", + "role11" }; + String[] expectedGas = { "ga1", "ga2", "ga3", "ga4", "ga5", "ga6", "ga7", "ga8", "ga9", "ga10", "ga11", "ga12", + "ga13", "ga14" }; testGetGrantedAuthorities(getDefaultMapper(), roles, expectedGas); } @@ -177,10 +182,8 @@ public class MapBasedAttributes2GrantedAuthoritiesMapperTests { m.put("role2", new SimpleGrantedAuthority("ga2")); m.put("role3", Arrays.asList("ga3", new SimpleGrantedAuthority("ga4"))); m.put("role4", "ga5,ga6"); - m.put("role5", Arrays.asList("ga7", "ga8", - new Object[] { new SimpleGrantedAuthority("ga9") })); - m.put("role6", new Object[] { "ga10", "ga11", - new Object[] { new SimpleGrantedAuthority("ga12") } }); + m.put("role5", Arrays.asList("ga7", "ga8", new Object[] { new SimpleGrantedAuthority("ga9") })); + m.put("role6", new Object[] { "ga10", "ga11", new Object[] { new SimpleGrantedAuthority("ga12") } }); m.put("role7", new String[] { "ga13", "ga14" }); m.put("role8", new String[] { "ga13", "ga14", null }); m.put("role9", null); @@ -189,25 +192,24 @@ public class MapBasedAttributes2GrantedAuthoritiesMapperTests { return m; } - private MapBasedAttributes2GrantedAuthoritiesMapper getDefaultMapper() - throws Exception { + private MapBasedAttributes2GrantedAuthoritiesMapper getDefaultMapper() throws Exception { MapBasedAttributes2GrantedAuthoritiesMapper mapper = new MapBasedAttributes2GrantedAuthoritiesMapper(); mapper.setAttributes2grantedAuthoritiesMap(getValidAttributes2GrantedAuthoritiesMap()); mapper.afterPropertiesSet(); return mapper; } - private void testGetGrantedAuthorities( - MapBasedAttributes2GrantedAuthoritiesMapper mapper, String[] roles, + private void testGetGrantedAuthorities(MapBasedAttributes2GrantedAuthoritiesMapper mapper, String[] roles, String[] expectedGas) { - List result = mapper - .getGrantedAuthorities(Arrays.asList(roles)); + List result = mapper.getGrantedAuthorities(Arrays.asList(roles)); Collection resultColl = new ArrayList(result.size()); for (GrantedAuthority auth : result) { resultColl.add(auth.getAuthority()); } Collection expectedColl = Arrays.asList(expectedGas); - assertThat(resultColl.containsAll(expectedColl)).withFailMessage("Role collections should match; result: " + resultColl - + ", expected: " + expectedColl).isTrue(); + assertThat(resultColl.containsAll(expectedColl)) + .withFailMessage("Role collections should match; result: " + resultColl + ", expected: " + expectedColl) + .isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleAuthoritiesMapperTests.java b/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleAuthoritiesMapperTests.java index c70474b70e..35781b5f30 100644 --- a/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleAuthoritiesMapperTests.java +++ b/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleAuthoritiesMapperTests.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; -import static org.assertj.core.api.Assertions.*; +import java.util.List; +import java.util.Set; + +import org.junit.Test; -import org.junit.*; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor @@ -39,8 +42,8 @@ public class SimpleAuthoritiesMapperTests { @Test public void defaultPrefixIsCorrectlyApplied() { SimpleAuthorityMapper mapper = new SimpleAuthorityMapper(); - Set mapped = AuthorityUtils.authorityListToSet(mapper - .mapAuthorities(AuthorityUtils.createAuthorityList("AaA", "ROLE_bbb"))); + Set mapped = AuthorityUtils + .authorityListToSet(mapper.mapAuthorities(AuthorityUtils.createAuthorityList("AaA", "ROLE_bbb"))); assertThat(mapped.contains("ROLE_AaA")).isTrue(); assertThat(mapped.contains("ROLE_bbb")).isTrue(); } @@ -50,18 +53,15 @@ public class SimpleAuthoritiesMapperTests { SimpleAuthorityMapper mapper = new SimpleAuthorityMapper(); mapper.setPrefix(""); List toMap = AuthorityUtils.createAuthorityList("AaA", "Bbb"); - Set mapped = AuthorityUtils.authorityListToSet(mapper - .mapAuthorities(toMap)); + Set mapped = AuthorityUtils.authorityListToSet(mapper.mapAuthorities(toMap)); assertThat(mapped).hasSize(2); assertThat(mapped.contains("AaA")).isTrue(); assertThat(mapped.contains("Bbb")).isTrue(); - mapper.setConvertToLowerCase(true); mapped = AuthorityUtils.authorityListToSet(mapper.mapAuthorities(toMap)); assertThat(mapped).hasSize(2); assertThat(mapped.contains("aaa")).isTrue(); assertThat(mapped.contains("bbb")).isTrue(); - mapper.setConvertToLowerCase(false); mapper.setConvertToUpperCase(true); mapped = AuthorityUtils.authorityListToSet(mapper.mapAuthorities(toMap)); @@ -74,9 +74,8 @@ public class SimpleAuthoritiesMapperTests { public void duplicatesAreRemoved() { SimpleAuthorityMapper mapper = new SimpleAuthorityMapper(); mapper.setConvertToUpperCase(true); - - Set mapped = AuthorityUtils.authorityListToSet(mapper - .mapAuthorities(AuthorityUtils.createAuthorityList("AaA", "AAA"))); + Set mapped = AuthorityUtils + .authorityListToSet(mapper.mapAuthorities(AuthorityUtils.createAuthorityList("AaA", "AAA"))); assertThat(mapped).hasSize(1); } @@ -84,9 +83,9 @@ public class SimpleAuthoritiesMapperTests { public void defaultAuthorityIsAssignedIfSet() { SimpleAuthorityMapper mapper = new SimpleAuthorityMapper(); mapper.setDefaultAuthority("ROLE_USER"); - Set mapped = AuthorityUtils.authorityListToSet(mapper - .mapAuthorities(AuthorityUtils.NO_AUTHORITIES)); + Set mapped = AuthorityUtils.authorityListToSet(mapper.mapAuthorities(AuthorityUtils.NO_AUTHORITIES)); assertThat(mapped).hasSize(1); assertThat(mapped.contains("ROLE_USER")).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleMappableRolesRetrieverTests.java b/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleMappableRolesRetrieverTests.java index 3bee9e2f02..2c3cdab572 100644 --- a/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleMappableRolesRetrieverTests.java +++ b/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleMappableRolesRetrieverTests.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core.authority.mapping; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.core.authority.mapping; import java.util.Set; import org.junit.Test; + import org.springframework.util.StringUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author TSARDD * @since 18-okt-2007 */ @@ -35,10 +36,8 @@ public class SimpleMappableRolesRetrieverTests { SimpleMappableAttributesRetriever r = new SimpleMappableAttributesRetriever(); r.setMappableAttributes(roles); Set result = r.getMappableAttributes(); - assertThat( - roles.containsAll(result) && result.containsAll(roles)).withFailMessage( - "Role collections do not match; result: " + result - + ", expected: " + roles).isTrue(); + assertThat(roles.containsAll(result) && result.containsAll(roles)) + .withFailMessage("Role collections do not match; result: " + result + ", expected: " + roles).isTrue(); } } diff --git a/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleRoles2GrantedAuthoritiesMapperTests.java b/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleRoles2GrantedAuthoritiesMapperTests.java index 7388c594fb..da948dd54e 100644 --- a/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleRoles2GrantedAuthoritiesMapperTests.java +++ b/core/src/test/java/org/springframework/security/core/authority/mapping/SimpleRoles2GrantedAuthoritiesMapperTests.java @@ -13,17 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.authority.mapping; -import static org.assertj.core.api.Assertions.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** - * * @author TSARDD * @since 18-okt-2007 */ @@ -123,19 +128,17 @@ public class SimpleRoles2GrantedAuthoritiesMapperTests { testGetGrantedAuthorities(mapper, roles, expectedGas); } - private void testGetGrantedAuthorities( - SimpleAttributes2GrantedAuthoritiesMapper mapper, String[] roles, + private void testGetGrantedAuthorities(SimpleAttributes2GrantedAuthoritiesMapper mapper, String[] roles, String[] expectedGas) { - List result = mapper - .getGrantedAuthorities(Arrays.asList(roles)); + List result = mapper.getGrantedAuthorities(Arrays.asList(roles)); Collection resultColl = new ArrayList<>(result.size()); for (GrantedAuthority grantedAuthority : result) { resultColl.add(grantedAuthority.getAuthority()); } Collection expectedColl = Arrays.asList(expectedGas); - assertThat(expectedColl.containsAll(resultColl) - && resultColl.containsAll(expectedColl)).withFailMessage("Role collections do not match; result: " + resultColl - + ", expected: " + expectedColl).isTrue(); + assertThat(expectedColl.containsAll(resultColl) && resultColl.containsAll(expectedColl)) + .withFailMessage("Role collections do not match; result: " + resultColl + ", expected: " + expectedColl) + .isTrue(); } private SimpleAttributes2GrantedAuthoritiesMapper getDefaultMapper() { diff --git a/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java b/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java index 92414c6668..df166b481b 100644 --- a/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java +++ b/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java @@ -17,11 +17,12 @@ package org.springframework.security.core.context; import org.junit.Test; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.Authentication; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; + /** * @author Rob Winch * @since 5.0 @@ -31,38 +32,39 @@ public class ReactiveSecurityContextHolderTests { @Test public void getContextWhenEmpty() { Mono context = ReactiveSecurityContextHolder.getContext(); - + // @formatter:off StepVerifier.create(context) - .verifyComplete(); + .verifyComplete(); + // @formatter:on } @Test public void setContextAndGetContextThenEmitsContext() { SecurityContext expectedContext = new SecurityContextImpl( - new TestingAuthenticationToken("user", "password", "ROLE_USER")); - + new TestingAuthenticationToken("user", "password", "ROLE_USER")); Mono context = Mono.subscriberContext() - .flatMap( c -> ReactiveSecurityContextHolder.getContext()) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedContext))); - + .flatMap((c) -> ReactiveSecurityContextHolder.getContext()) + .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedContext))); + // @formatter:off StepVerifier.create(context) - .expectNext(expectedContext) - .verifyComplete(); + .expectNext(expectedContext) + .verifyComplete(); + // @formatter:on } @Test public void demo() { Authentication authentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); - + // @formatter:off Mono messageByUsername = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .map(Authentication::getName) - .flatMap(this::findMessageByUsername) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); - + .map(SecurityContext::getAuthentication) + .map(Authentication::getName) + .flatMap(this::findMessageByUsername) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); StepVerifier.create(messageByUsername) - .expectNext("Hi user") - .verifyComplete(); + .expectNext("Hi user") + .verifyComplete(); + // @formatter:on } private Mono findMessageByUsername(String username) { @@ -72,29 +74,29 @@ public class ReactiveSecurityContextHolderTests { @Test public void setContextAndClearAndGetContextThenEmitsEmpty() { SecurityContext expectedContext = new SecurityContextImpl( - new TestingAuthenticationToken("user", "password", "ROLE_USER")); - + new TestingAuthenticationToken("user", "password", "ROLE_USER")); + // @formatter:off Mono context = Mono.subscriberContext() - .flatMap( c -> ReactiveSecurityContextHolder.getContext()) - .subscriberContext(ReactiveSecurityContextHolder.clearContext()) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedContext))); - + .flatMap((c) -> ReactiveSecurityContextHolder.getContext()) + .subscriberContext(ReactiveSecurityContextHolder.clearContext()) + .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedContext))); StepVerifier.create(context) - .verifyComplete(); + .verifyComplete(); + // @formatter:on } @Test public void setAuthenticationAndGetContextThenEmitsContext() { - Authentication expectedAuthentication = new TestingAuthenticationToken("user", - "password", "ROLE_USER"); - + Authentication expectedAuthentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + // @formatter:off Mono authentication = Mono.subscriberContext() - .flatMap( c -> ReactiveSecurityContextHolder.getContext()) - .map(SecurityContext::getAuthentication) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expectedAuthentication)); - + .flatMap((c) -> ReactiveSecurityContextHolder.getContext()) + .map(SecurityContext::getAuthentication) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expectedAuthentication)); StepVerifier.create(authentication) - .expectNext(expectedAuthentication) - .verifyComplete(); + .expectNext(expectedAuthentication) + .verifyComplete(); + // @formatter:on } + } diff --git a/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java b/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java index 00a1b2cf6e..7ea8a2eca8 100644 --- a/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java +++ b/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java @@ -16,12 +16,13 @@ package org.springframework.security.core.context; -import static org.assertj.core.api.Assertions.*; - import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; -import org.springframework.security.core.context.SecurityContextImpl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** * Tests {@link SecurityContextHolder}. @@ -30,12 +31,9 @@ import org.springframework.security.core.context.SecurityContextImpl; */ public class SecurityContextHolderTests { - // ~ Methods - // ======================================================================================================== @Before public final void setUp() { - SecurityContextHolder - .setStrategyName(SecurityContextHolder.MODE_INHERITABLETHREADLOCAL); + SecurityContextHolder.setStrategyName(SecurityContextHolder.MODE_INHERITABLETHREADLOCAL); } @Test @@ -62,7 +60,7 @@ public class SecurityContextHolderTests { fail("Should have rejected null"); } catch (IllegalArgumentException expected) { - } } + } diff --git a/core/src/test/java/org/springframework/security/core/context/SecurityContextImplTests.java b/core/src/test/java/org/springframework/security/core/context/SecurityContextImplTests.java index 2cd01bf098..d6e2f5bbbf 100644 --- a/core/src/test/java/org/springframework/security/core/context/SecurityContextImplTests.java +++ b/core/src/test/java/org/springframework/security/core/context/SecurityContextImplTests.java @@ -16,12 +16,13 @@ package org.springframework.security.core.context; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link SecurityContextImpl}. * @@ -29,8 +30,6 @@ import org.springframework.security.core.Authentication; */ public class SecurityContextImplTests { - // ~ Methods - // ======================================================================================================== @Test public void testEmptyObjectsAreEquals() { SecurityContextImpl obj1 = new SecurityContextImpl(); @@ -46,4 +45,5 @@ public class SecurityContextImplTests { assertThat(context.getAuthentication()).isEqualTo(auth); assertThat(context.toString().lastIndexOf("rod") != -1).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscovererTests.java b/core/src/test/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscovererTests.java index 3627451887..2627f17d47 100644 --- a/core/src/test/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscovererTests.java +++ b/core/src/test/java/org/springframework/security/core/parameters/AnnotationParameterNameDiscovererTests.java @@ -13,141 +13,143 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core.parameters; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.core.parameters; import org.junit.Before; import org.junit.Test; + import org.springframework.security.access.method.P; import org.springframework.util.ReflectionUtils; +import static org.assertj.core.api.Assertions.assertThat; + public class AnnotationParameterNameDiscovererTests { + private AnnotationParameterNameDiscoverer discoverer; @Before public void setup() { - discoverer = new AnnotationParameterNameDiscoverer(P.class.getName()); + this.discoverer = new AnnotationParameterNameDiscoverer(P.class.getName()); } @Test public void getParameterNamesInterfaceSingleParam() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByTo", String.class))).isEqualTo( - new String[] { "to" }); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(Dao.class, "findMessageByTo", String.class))) + .isEqualTo(new String[] { "to" }); } @Test public void getParameterNamesInterfaceSingleParamAnnotatedWithMultiParams() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByToAndFrom", String.class, String.class))) - .isEqualTo(new String[] { "to", null }); + assertThat(this.discoverer.getParameterNames( + ReflectionUtils.findMethod(Dao.class, "findMessageByToAndFrom", String.class, String.class))) + .isEqualTo(new String[] { "to", null }); } @Test public void getParameterNamesInterfaceNoAnnotation() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByIdNoAnnotation", String.class))).isNull(); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(Dao.class, "findMessageByIdNoAnnotation", String.class))) + .isNull(); } @Test public void getParameterNamesClassSingleParam() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByTo", String.class))).isEqualTo( - new String[] { "to" }); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(Dao.class, "findMessageByTo", String.class))) + .isEqualTo(new String[] { "to" }); } @Test public void getParameterNamesClassSingleParamAnnotatedWithMultiParams() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByToAndFrom", String.class, String.class))) - .isEqualTo(new String[] { "to", null }); + assertThat(this.discoverer.getParameterNames( + ReflectionUtils.findMethod(Dao.class, "findMessageByToAndFrom", String.class, String.class))) + .isEqualTo(new String[] { "to", null }); } @Test public void getParameterNamesClassNoAnnotation() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByIdNoAnnotation", String.class))).isNull(); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(Dao.class, "findMessageByIdNoAnnotation", String.class))) + .isNull(); } @Test public void getParameterNamesConstructor() throws Exception { - assertThat(discoverer.getParameterNames(Impl.class.getDeclaredConstructor(String.class))) + assertThat(this.discoverer.getParameterNames(Impl.class.getDeclaredConstructor(String.class))) .isEqualTo(new String[] { "id" }); } @Test public void getParameterNamesConstructorNoAnnotation() throws Exception { - assertThat(discoverer.getParameterNames(Impl.class.getDeclaredConstructor(Long.class))) - .isNull(); + assertThat(this.discoverer.getParameterNames(Impl.class.getDeclaredConstructor(Long.class))).isNull(); } @Test public void getParameterNamesClassAnnotationOnInterface() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(DaoImpl.class, - "findMessageByTo", String.class))).isEqualTo( - new String[] { "to" }); - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByTo", String.class))).isEqualTo( - new String[] { "to" }); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(DaoImpl.class, "findMessageByTo", String.class))) + .isEqualTo(new String[] { "to" }); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(Dao.class, "findMessageByTo", String.class))) + .isEqualTo(new String[] { "to" }); } @Test public void getParameterNamesClassAnnotationOnImpl() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByToAndFrom", String.class, String.class))) - .isEqualTo(new String[] { "to", null }); - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(DaoImpl.class, - "findMessageByToAndFrom", String.class, String.class))) - .isEqualTo(new String[] { "to", "from" }); + assertThat(this.discoverer.getParameterNames( + ReflectionUtils.findMethod(Dao.class, "findMessageByToAndFrom", String.class, String.class))) + .isEqualTo(new String[] { "to", null }); + assertThat(this.discoverer.getParameterNames( + ReflectionUtils.findMethod(DaoImpl.class, "findMessageByToAndFrom", String.class, String.class))) + .isEqualTo(new String[] { "to", "from" }); } @Test public void getParameterNamesClassAnnotationOnBaseClass() { - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(Dao.class, - "findMessageByIdNoAnnotation", String.class))).isNull(); - assertThat( - discoverer.getParameterNames(ReflectionUtils.findMethod(DaoImpl.class, - "findMessageByIdNoAnnotation", String.class))).isEqualTo( - new String[] { "id" }); + assertThat(this.discoverer + .getParameterNames(ReflectionUtils.findMethod(Dao.class, "findMessageByIdNoAnnotation", String.class))) + .isNull(); + assertThat(this.discoverer.getParameterNames( + ReflectionUtils.findMethod(DaoImpl.class, "findMessageByIdNoAnnotation", String.class))) + .isEqualTo(new String[] { "id" }); } interface Dao { + String findMessageByTo(@P("to") String to); String findMessageByToAndFrom(@P("to") String to, String from); String findMessageByIdNoAnnotation(String id); + } static class BaseDaoImpl { + public String findMessageByIdNoAnnotation(@P("id") String id) { return null; } + } static class DaoImpl extends BaseDaoImpl implements Dao { + + @Override public String findMessageByTo(String to) { return null; } + @Override public String findMessageByToAndFrom(@P("to") String to, @P("from") String from) { return null; } + } static class Impl { + Impl(Long dataSourceId) { } @@ -165,5 +167,7 @@ public class AnnotationParameterNameDiscovererTests { String findMessageByIdNoAnnotation(String id) { return null; } + } + } diff --git a/core/src/test/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscovererTests.java b/core/src/test/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscovererTests.java index b929f457f4..bbaf08c7a1 100644 --- a/core/src/test/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscovererTests.java +++ b/core/src/test/java/org/springframework/security/core/parameters/DefaultSecurityParameterNameDiscovererTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core.parameters; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.core.parameters; import java.util.Arrays; import java.util.List; @@ -23,61 +22,55 @@ import java.util.Set; import org.junit.Before; import org.junit.Test; + import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.LocalVariableTableParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Rob Winch * @since 3.2 */ @SuppressWarnings("unchecked") public class DefaultSecurityParameterNameDiscovererTests { + private DefaultSecurityParameterNameDiscoverer discoverer; @Before public void setup() { - discoverer = new DefaultSecurityParameterNameDiscoverer(); + this.discoverer = new DefaultSecurityParameterNameDiscoverer(); } @Test public void constructorDefault() { List discoverers = (List) ReflectionTestUtils - .getField(discoverer, "parameterNameDiscoverers"); - + .getField(this.discoverer, "parameterNameDiscoverers"); assertThat(discoverers).hasSize(2); - ParameterNameDiscoverer annotationDisc = discoverers.get(0); assertThat(annotationDisc).isInstanceOf(AnnotationParameterNameDiscoverer.class); - Set annotationsToUse = (Set) ReflectionTestUtils.getField( - annotationDisc, "annotationClassesToUse"); + Set annotationsToUse = (Set) ReflectionTestUtils.getField(annotationDisc, + "annotationClassesToUse"); assertThat(annotationsToUse).containsOnly("org.springframework.security.access.method.P", P.class.getName()); - - assertThat(discoverers.get(1).getClass()).isEqualTo( - DefaultParameterNameDiscoverer.class); + assertThat(discoverers.get(1).getClass()).isEqualTo(DefaultParameterNameDiscoverer.class); } @Test public void constructorDiscoverers() { - discoverer = new DefaultSecurityParameterNameDiscoverer( + this.discoverer = new DefaultSecurityParameterNameDiscoverer( Arrays.asList(new LocalVariableTableParameterNameDiscoverer())); - List discoverers = (List) ReflectionTestUtils - .getField(discoverer, "parameterNameDiscoverers"); - + .getField(this.discoverer, "parameterNameDiscoverers"); assertThat(discoverers).hasSize(3); - assertThat(discoverers.get(0)).isInstanceOf( - LocalVariableTableParameterNameDiscoverer.class); - + assertThat(discoverers.get(0)).isInstanceOf(LocalVariableTableParameterNameDiscoverer.class); ParameterNameDiscoverer annotationDisc = discoverers.get(1); assertThat(annotationDisc).isInstanceOf(AnnotationParameterNameDiscoverer.class); - Set annotationsToUse = (Set) ReflectionTestUtils.getField( - annotationDisc, "annotationClassesToUse"); + Set annotationsToUse = (Set) ReflectionTestUtils.getField(annotationDisc, + "annotationClassesToUse"); assertThat(annotationsToUse).containsOnly("org.springframework.security.access.method.P", P.class.getName()); - - assertThat(discoverers.get(2)).isInstanceOf( - DefaultParameterNameDiscoverer.class); + assertThat(discoverers.get(2)).isInstanceOf(DefaultParameterNameDiscoverer.class); } + } diff --git a/core/src/test/java/org/springframework/security/core/session/SessionInformationTests.java b/core/src/test/java/org/springframework/security/core/session/SessionInformationTests.java index 3cfcf86c5b..626fcea0ed 100644 --- a/core/src/test/java/org/springframework/security/core/session/SessionInformationTests.java +++ b/core/src/test/java/org/springframework/security/core/session/SessionInformationTests.java @@ -16,12 +16,12 @@ package org.springframework.security.core.session; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Date; import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link SessionInformation}. * @@ -29,24 +29,18 @@ import org.junit.Test; */ public class SessionInformationTests { - // ~ Methods - // ======================================================================================================== @Test public void testObject() throws Exception { Object principal = "Some principal object"; String sessionId = "1234567890"; Date currentDate = new Date(); - - SessionInformation info = new SessionInformation(principal, sessionId, - currentDate); + SessionInformation info = new SessionInformation(principal, sessionId, currentDate); assertThat(info.getPrincipal()).isEqualTo(principal); assertThat(info.getSessionId()).isEqualTo(sessionId); assertThat(info.getLastRequest()).isEqualTo(currentDate); - Thread.sleep(10); - info.refreshLastRequest(); - assertThat(info.getLastRequest().after(currentDate)).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/core/session/SessionRegistryImplTests.java b/core/src/test/java/org/springframework/security/core/session/SessionRegistryImplTests.java index 7ed19b2bc6..df9ea8376d 100644 --- a/core/src/test/java/org/springframework/security/core/session/SessionRegistryImplTests.java +++ b/core/src/test/java/org/springframework/security/core/session/SessionRegistryImplTests.java @@ -16,17 +16,15 @@ package org.springframework.security.core.session; -import static org.assertj.core.api.Assertions.*; - import java.util.Date; import java.util.List; import org.junit.Before; import org.junit.Test; + import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.session.SessionDestroyedEvent; -import org.springframework.security.core.session.SessionInformation; -import org.springframework.security.core.session.SessionRegistryImpl; + +import static org.assertj.core.api.Assertions.assertThat; /** * Tests {@link SessionRegistryImpl}. @@ -34,26 +32,22 @@ import org.springframework.security.core.session.SessionRegistryImpl; * @author Ben Alex */ public class SessionRegistryImplTests { - private SessionRegistryImpl sessionRegistry; - // ~ Methods - // ======================================================================================================== + private SessionRegistryImpl sessionRegistry; @Before public void setUp() { - sessionRegistry = new SessionRegistryImpl(); + this.sessionRegistry = new SessionRegistryImpl(); } @Test public void sessionDestroyedEventRemovesSessionFromRegistry() { Object principal = "Some principal object"; final String sessionId = "zzzz"; - // Register new Session - sessionRegistry.registerNewSession(sessionId, principal); - + this.sessionRegistry.registerNewSession(sessionId, principal); // De-register session via an ApplicationEvent - sessionRegistry.onApplicationEvent(new SessionDestroyedEvent("") { + this.sessionRegistry.onApplicationEvent(new SessionDestroyedEvent("") { @Override public String getId() { return sessionId; @@ -64,9 +58,8 @@ public class SessionRegistryImplTests { return null; } }); - // Check attempts to retrieve cleared session return null - assertThat(sessionRegistry.getSessionInformation(sessionId)).isNull(); + assertThat(this.sessionRegistry.getSessionInformation(sessionId)).isNull(); } @Test @@ -74,12 +67,10 @@ public class SessionRegistryImplTests { Object principal = "Some principal object"; final String sessionId = "zzzz"; final String newSessionId = "123"; - // Register new Session - sessionRegistry.registerNewSession(sessionId, principal); - + this.sessionRegistry.registerNewSession(sessionId, principal); // De-register session via an ApplicationEvent - sessionRegistry.onApplicationEvent(new SessionIdChangedEvent("") { + this.sessionRegistry.onApplicationEvent(new SessionIdChangedEvent("") { @Override public String getOldSessionId() { return sessionId; @@ -90,10 +81,9 @@ public class SessionRegistryImplTests { return newSessionId; } }); - - assertThat(sessionRegistry.getSessionInformation(sessionId)).isNull(); - assertThat(sessionRegistry.getSessionInformation(newSessionId)).isNotNull(); - assertThat(sessionRegistry.getSessionInformation(newSessionId).getPrincipal()).isEqualTo(principal); + assertThat(this.sessionRegistry.getSessionInformation(sessionId)).isNull(); + assertThat(this.sessionRegistry.getSessionInformation(newSessionId)).isNotNull(); + assertThat(this.sessionRegistry.getSessionInformation(newSessionId).getPrincipal()).isEqualTo(principal); } @Test @@ -103,14 +93,12 @@ public class SessionRegistryImplTests { String sessionId1 = "1234567890"; String sessionId2 = "9876543210"; String sessionId3 = "5432109876"; - - sessionRegistry.registerNewSession(sessionId1, principal1); - sessionRegistry.registerNewSession(sessionId2, principal1); - sessionRegistry.registerNewSession(sessionId3, principal2); - - assertThat(sessionRegistry.getAllPrincipals()).hasSize(2); - assertThat(sessionRegistry.getAllPrincipals().contains(principal1)).isTrue(); - assertThat(sessionRegistry.getAllPrincipals().contains(principal2)).isTrue(); + this.sessionRegistry.registerNewSession(sessionId1, principal1); + this.sessionRegistry.registerNewSession(sessionId2, principal1); + this.sessionRegistry.registerNewSession(sessionId3, principal2); + assertThat(this.sessionRegistry.getAllPrincipals()).hasSize(2); + assertThat(this.sessionRegistry.getAllPrincipals().contains(principal1)).isTrue(); + assertThat(this.sessionRegistry.getAllPrincipals().contains(principal2)).isTrue(); } @Test @@ -118,37 +106,28 @@ public class SessionRegistryImplTests { Object principal = "Some principal object"; String sessionId = "1234567890"; // Register new Session - sessionRegistry.registerNewSession(sessionId, principal); - + this.sessionRegistry.registerNewSession(sessionId, principal); // Retrieve existing session by session ID - Date currentDateTime = sessionRegistry.getSessionInformation(sessionId) - .getLastRequest(); - assertThat(sessionRegistry.getSessionInformation(sessionId).getPrincipal()).isEqualTo(principal); - assertThat(sessionRegistry.getSessionInformation(sessionId).getSessionId()).isEqualTo(sessionId); - assertThat(sessionRegistry.getSessionInformation(sessionId).getLastRequest()).isNotNull(); - + Date currentDateTime = this.sessionRegistry.getSessionInformation(sessionId).getLastRequest(); + assertThat(this.sessionRegistry.getSessionInformation(sessionId).getPrincipal()).isEqualTo(principal); + assertThat(this.sessionRegistry.getSessionInformation(sessionId).getSessionId()).isEqualTo(sessionId); + assertThat(this.sessionRegistry.getSessionInformation(sessionId).getLastRequest()).isNotNull(); // Retrieve existing session by principal - assertThat(sessionRegistry.getAllSessions(principal, false)).hasSize(1); - + assertThat(this.sessionRegistry.getAllSessions(principal, false)).hasSize(1); // Sleep to ensure SessionRegistryImpl will update time Thread.sleep(1000); - // Update request date/time - sessionRegistry.refreshLastRequest(sessionId); - - Date retrieved = sessionRegistry.getSessionInformation(sessionId) - .getLastRequest(); + this.sessionRegistry.refreshLastRequest(sessionId); + Date retrieved = this.sessionRegistry.getSessionInformation(sessionId).getLastRequest(); assertThat(retrieved.after(currentDateTime)).isTrue(); - // Check it retrieves correctly when looked up via principal - assertThat(sessionRegistry.getAllSessions(principal, false).get(0).getLastRequest()).isCloseTo(retrieved, 2000L); - + assertThat(this.sessionRegistry.getAllSessions(principal, false).get(0).getLastRequest()).isCloseTo(retrieved, + 2000L); // Clear session information - sessionRegistry.removeSessionInformation(sessionId); - + this.sessionRegistry.removeSessionInformation(sessionId); // Check attempts to retrieve cleared session return null - assertThat(sessionRegistry.getSessionInformation(sessionId)).isNull(); - assertThat(sessionRegistry.getAllSessions(principal, false)).isEmpty(); + assertThat(this.sessionRegistry.getSessionInformation(sessionId)).isNull(); + assertThat(this.sessionRegistry.getAllSessions(principal, false)).isEmpty(); } @Test @@ -156,25 +135,20 @@ public class SessionRegistryImplTests { Object principal = "Some principal object"; String sessionId1 = "1234567890"; String sessionId2 = "9876543210"; - - sessionRegistry.registerNewSession(sessionId1, principal); - List sessions = sessionRegistry.getAllSessions(principal, - false); + this.sessionRegistry.registerNewSession(sessionId1, principal); + List sessions = this.sessionRegistry.getAllSessions(principal, false); assertThat(sessions).hasSize(1); assertThat(contains(sessionId1, principal)).isTrue(); - - sessionRegistry.registerNewSession(sessionId2, principal); - sessions = sessionRegistry.getAllSessions(principal, false); + this.sessionRegistry.registerNewSession(sessionId2, principal); + sessions = this.sessionRegistry.getAllSessions(principal, false); assertThat(sessions).hasSize(2); assertThat(contains(sessionId2, principal)).isTrue(); - // Expire one session - SessionInformation session = sessionRegistry.getSessionInformation(sessionId2); + SessionInformation session = this.sessionRegistry.getSessionInformation(sessionId2); session.expireNow(); - // Check retrieval still correct - assertThat(sessionRegistry.getSessionInformation(sessionId2).isExpired()).isTrue(); - assertThat(sessionRegistry.getSessionInformation(sessionId1).isExpired()).isFalse(); + assertThat(this.sessionRegistry.getSessionInformation(sessionId2).isExpired()).isTrue(); + assertThat(this.sessionRegistry.getSessionInformation(sessionId1).isExpired()).isFalse(); } @Test @@ -182,37 +156,31 @@ public class SessionRegistryImplTests { Object principal = "Some principal object"; String sessionId1 = "1234567890"; String sessionId2 = "9876543210"; - - sessionRegistry.registerNewSession(sessionId1, principal); - List sessions = sessionRegistry.getAllSessions(principal, - false); + this.sessionRegistry.registerNewSession(sessionId1, principal); + List sessions = this.sessionRegistry.getAllSessions(principal, false); assertThat(sessions).hasSize(1); assertThat(contains(sessionId1, principal)).isTrue(); - - sessionRegistry.registerNewSession(sessionId2, principal); - sessions = sessionRegistry.getAllSessions(principal, false); + this.sessionRegistry.registerNewSession(sessionId2, principal); + sessions = this.sessionRegistry.getAllSessions(principal, false); assertThat(sessions).hasSize(2); assertThat(contains(sessionId2, principal)).isTrue(); - - sessionRegistry.removeSessionInformation(sessionId1); - sessions = sessionRegistry.getAllSessions(principal, false); + this.sessionRegistry.removeSessionInformation(sessionId1); + sessions = this.sessionRegistry.getAllSessions(principal, false); assertThat(sessions).hasSize(1); assertThat(contains(sessionId2, principal)).isTrue(); - - sessionRegistry.removeSessionInformation(sessionId2); - assertThat(sessionRegistry.getSessionInformation(sessionId2)).isNull(); - assertThat(sessionRegistry.getAllSessions(principal, false)).isEmpty(); + this.sessionRegistry.removeSessionInformation(sessionId2); + assertThat(this.sessionRegistry.getSessionInformation(sessionId2)).isNull(); + assertThat(this.sessionRegistry.getAllSessions(principal, false)).isEmpty(); } private boolean contains(String sessionId, Object principal) { - List info = sessionRegistry.getAllSessions(principal, false); - + List info = this.sessionRegistry.getAllSessions(principal, false); for (SessionInformation sessionInformation : info) { if (sessionId.equals(sessionInformation.getSessionId())) { return true; } } - return false; } + } diff --git a/core/src/test/java/org/springframework/security/core/token/DefaultTokenTests.java b/core/src/test/java/org/springframework/security/core/token/DefaultTokenTests.java index 6d2adf21a4..1dc583b84b 100644 --- a/core/src/test/java/org/springframework/security/core/token/DefaultTokenTests.java +++ b/core/src/test/java/org/springframework/security/core/token/DefaultTokenTests.java @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; import java.util.Date; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; -import org.springframework.security.core.token.DefaultToken; + +import static org.assertj.core.api.Assertions.assertThat; /** * Tests {@link DefaultToken}. @@ -29,12 +29,12 @@ import org.springframework.security.core.token.DefaultToken; * */ public class DefaultTokenTests { + @Test public void testEquality() { String key = "key"; long created = new Date().getTime(); String extendedInformation = "extended"; - DefaultToken t1 = new DefaultToken(key, created, extendedInformation); DefaultToken t2 = new DefaultToken(key, created, extendedInformation); assertThat(t2).isEqualTo(t1); @@ -51,9 +51,9 @@ public class DefaultTokenTests { public void testEqualityWithDifferentExtendedInformation3() { String key = "key"; long created = new Date().getTime(); - DefaultToken t1 = new DefaultToken(key, created, "length1"); DefaultToken t2 = new DefaultToken(key, created, "longerLength2"); assertThat(t1).isNotEqualTo(t2); } + } diff --git a/core/src/test/java/org/springframework/security/core/token/KeyBasedPersistenceTokenServiceTests.java b/core/src/test/java/org/springframework/security/core/token/KeyBasedPersistenceTokenServiceTests.java index 4b9cfc0fb0..fd4efeebb9 100644 --- a/core/src/test/java/org/springframework/security/core/token/KeyBasedPersistenceTokenServiceTests.java +++ b/core/src/test/java/org/springframework/security/core/token/KeyBasedPersistenceTokenServiceTests.java @@ -13,18 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.core.token; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.core.token; import java.security.SecureRandom; import java.util.Date; import org.junit.Test; -import org.springframework.security.core.token.DefaultToken; -import org.springframework.security.core.token.KeyBasedPersistenceTokenService; -import org.springframework.security.core.token.SecureRandomFactoryBean; -import org.springframework.security.core.token.Token; + +import static org.assertj.core.api.Assertions.assertThat; /** * Tests {@link KeyBasedPersistenceTokenService}. @@ -44,8 +41,8 @@ public class KeyBasedPersistenceTokenServiceTests { service.setSecureRandom(rnd); service.afterPropertiesSet(); } - catch (Exception e) { - throw new RuntimeException(e); + catch (Exception ex) { + throw new RuntimeException(ex); } return service; } @@ -98,4 +95,5 @@ public class KeyBasedPersistenceTokenServiceTests { Token token = new DefaultToken(fake, new Date().getTime(), ""); service.verifyToken(token.getKey()); } + } diff --git a/core/src/test/java/org/springframework/security/core/token/SecureRandomFactoryBeanTests.java b/core/src/test/java/org/springframework/security/core/token/SecureRandomFactoryBeanTests.java index da7144dd11..b3ab4cc491 100644 --- a/core/src/test/java/org/springframework/security/core/token/SecureRandomFactoryBeanTests.java +++ b/core/src/test/java/org/springframework/security/core/token/SecureRandomFactoryBeanTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.token; import java.security.SecureRandom; @@ -31,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThat; * */ public class SecureRandomFactoryBeanTests { + @Test public void testObjectType() { SecureRandomFactoryBean factory = new SecureRandomFactoryBean(); @@ -61,9 +63,7 @@ public class SecureRandomFactoryBeanTests { factory.setSeed(resource); SecureRandom first = factory.getObject(); SecureRandom second = factory.getObject(); - assertThat(first.nextInt()) - .isNotEqualTo(0) - .isNotEqualTo(second.nextInt()); + assertThat(first.nextInt()).isNotEqualTo(0).isNotEqualTo(second.nextInt()); } } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsServiceTests.java b/core/src/test/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsServiceTests.java index 97154ac416..a6f8760871 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsServiceTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/MapReactiveUserDetailsServiceTests.java @@ -13,25 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails; - -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Arrays; import java.util.Collection; import java.util.Collections; import org.junit.Test; - import reactor.core.publisher.Mono; +import static org.assertj.core.api.Assertions.assertThat; + public class MapReactiveUserDetailsServiceTests { + + // @formatter:off private static final UserDetails USER_DETAILS = User.withUsername("user") .password("password") .roles("USER") .build(); - + // @formatter:on private MapReactiveUserDetailsService users = new MapReactiveUserDetailsService(Arrays.asList(USER_DETAILS)); @Test(expected = IllegalArgumentException.class) @@ -55,33 +56,33 @@ public class MapReactiveUserDetailsServiceTests { @Test public void findByUsernameWhenFoundThenReturns() { - assertThat((users.findByUsername(USER_DETAILS.getUsername()).block())).isEqualTo(USER_DETAILS); + assertThat((this.users.findByUsername(USER_DETAILS.getUsername()).block())).isEqualTo(USER_DETAILS); } @Test public void findByUsernameWhenDifferentCaseThenReturns() { - assertThat((users.findByUsername("uSeR").block())).isEqualTo(USER_DETAILS); + assertThat((this.users.findByUsername("uSeR").block())).isEqualTo(USER_DETAILS); } @Test public void findByUsernameWhenClearCredentialsThenFindByUsernameStillHasCredentials() { - User foundUser = users.findByUsername(USER_DETAILS.getUsername()).cast(User.class).block(); + User foundUser = this.users.findByUsername(USER_DETAILS.getUsername()).cast(User.class).block(); assertThat(foundUser.getPassword()).isNotEmpty(); foundUser.eraseCredentials(); assertThat(foundUser.getPassword()).isNull(); - - foundUser = users.findByUsername(USER_DETAILS.getUsername()).cast(User.class).block(); + foundUser = this.users.findByUsername(USER_DETAILS.getUsername()).cast(User.class).block(); assertThat(foundUser.getPassword()).isNotEmpty(); } @Test public void findByUsernameWhenNotFoundThenEmpty() { - assertThat((users.findByUsername("notfound"))).isEqualTo(Mono.empty()); + assertThat((this.users.findByUsername("notfound"))).isEqualTo(Mono.empty()); } @Test public void updatePassword() { - users.updatePassword(USER_DETAILS, "new").block(); - assertThat(users.findByUsername(USER_DETAILS.getUsername()).block().getPassword()).isEqualTo("new"); + this.users.updatePassword(USER_DETAILS, "new").block(); + assertThat(this.users.findByUsername(USER_DETAILS.getUsername()).block().getPassword()).isEqualTo("new"); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/MockUserDetailsService.java b/core/src/test/java/org/springframework/security/core/userdetails/MockUserDetailsService.java index 98e81972ec..5a771a6ed9 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/MockUserDetailsService.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/MockUserDetailsService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails; import java.util.HashMap; @@ -30,24 +31,25 @@ import org.springframework.security.core.authority.AuthorityUtils; * @author Luke Taylor */ public class MockUserDetailsService implements UserDetailsService { + private Map users = new HashMap<>(); - private List auths = AuthorityUtils - .createAuthorityList("ROLE_USER"); + + private List auths = AuthorityUtils.createAuthorityList("ROLE_USER"); public MockUserDetailsService() { - users.put("valid", new User("valid", "", true, true, true, true, auths)); - users.put("locked", new User("locked", "", true, true, true, false, auths)); - users.put("disabled", new User("disabled", "", false, true, true, true, auths)); - users.put("credentialsExpired", new User("credentialsExpired", "", true, true, - false, true, auths)); - users.put("expired", new User("expired", "", true, false, true, true, auths)); + this.users.put("valid", new User("valid", "", true, true, true, true, this.auths)); + this.users.put("locked", new User("locked", "", true, true, true, false, this.auths)); + this.users.put("disabled", new User("disabled", "", false, true, true, true, this.auths)); + this.users.put("credentialsExpired", new User("credentialsExpired", "", true, true, false, true, this.auths)); + this.users.put("expired", new User("expired", "", true, false, true, true, this.auths)); } + @Override public UserDetails loadUserByUsername(String username) { - if (users.get(username) == null) { + if (this.users.get(username) == null) { throw new UsernameNotFoundException("User not found: " + username); } - - return users.get(username); + return this.users.get(username); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/PasswordEncodedUser.java b/core/src/test/java/org/springframework/security/core/userdetails/PasswordEncodedUser.java index b5261ba561..d273cb9e6d 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/PasswordEncodedUser.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/PasswordEncodedUser.java @@ -23,6 +23,7 @@ import java.util.function.Function; * @since 5.0 */ public class PasswordEncodedUser { + private static final UserDetails USER = withUsername("user").password("password").roles("USER").build(); private static final UserDetails ADMIN = withUsername("admin").password("password").roles("USER", "ADMIN").build(); @@ -44,13 +45,18 @@ public class PasswordEncodedUser { } public static User.UserBuilder withUserDetails(UserDetails userDetails) { - return User.withUserDetails(userDetails) - .passwordEncoder(passwordEncoder()); + // @formatter:off + return User + .withUserDetails(userDetails) + .passwordEncoder(passwordEncoder()); + // @formatter:on } private static Function passwordEncoder() { - return rawPassword -> "{noop}" + rawPassword; + return (rawPassword) -> "{noop}" + rawPassword; + } + + protected PasswordEncodedUser() { } - protected PasswordEncodedUser() {} } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapperTests.java b/core/src/test/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapperTests.java index 6da8ad2af0..8364a1f360 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapperTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/UserDetailsByNameServiceWrapperTests.java @@ -13,21 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.core.userdetails; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** - * * @author TSARDD * @since 18-okt-2007 */ @SuppressWarnings("unchecked") -public class UserDetailsByNameServiceWrapperTests { +public class UserDetailsByNameServiceWrapperTests { @Test public final void testAfterPropertiesSet() { @@ -46,9 +48,8 @@ public class UserDetailsByNameServiceWrapperTests { @Test public final void testGetUserDetails() throws Exception { UserDetailsByNameServiceWrapper svc = new UserDetailsByNameServiceWrapper(); - final User user = new User("dummy", "dummy", true, true, true, true, - AuthorityUtils.NO_AUTHORITIES); - svc.setUserDetailsService(name -> { + final User user = new User("dummy", "dummy", true, true, true, true, AuthorityUtils.NO_AUTHORITIES); + svc.setUserDetailsService((name) -> { if (user != null && user.getUsername().equals(name)) { return user; } @@ -57,11 +58,9 @@ public class UserDetailsByNameServiceWrapperTests { } }); svc.afterPropertiesSet(); - UserDetails result1 = svc.loadUserDetails(new TestingAuthenticationToken("dummy", - "dummy")); + UserDetails result1 = svc.loadUserDetails(new TestingAuthenticationToken("dummy", "dummy")); assertThat(result1).as("Result doesn't match original user").isEqualTo(user); - UserDetails result2 = svc.loadUserDetails(new TestingAuthenticationToken( - "dummy2", "dummy")); + UserDetails result2 = svc.loadUserDetails(new TestingAuthenticationToken("dummy2", "dummy")); assertThat(result2).as("Result should have been null").isNull(); } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/UserTests.java b/core/src/test/java/org/springframework/security/core/userdetails/UserTests.java index 00478f6474..4ac32d65ba 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/UserTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/UserTests.java @@ -16,8 +16,6 @@ package org.springframework.security.core.userdetails; -import static org.assertj.core.api.Assertions.*; - import java.io.ByteArrayOutputStream; import java.io.ObjectOutputStream; import java.util.HashSet; @@ -26,31 +24,30 @@ import java.util.Set; import java.util.function.Function; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link User}. * * @author Ben Alex */ public class UserTests { - private static final List ROLE_12 = AuthorityUtils - .createAuthorityList("ROLE_ONE", "ROLE_TWO"); - // ~ Methods - // ======================================================================================================== + private static final List ROLE_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); @Test public void equalsReturnsTrueIfUsernamesAreTheSame() { User user1 = new User("rod", "koala", true, true, true, true, ROLE_12); - assertThat(user1).isNotNull(); assertThat(user1).isNotEqualTo("A STRING"); assertThat(user1).isEqualTo(user1); - assertThat(user1).isEqualTo((new User("rod", "notthesame", true, true, true, true, - ROLE_12))); + assertThat(user1).isEqualTo((new User("rod", "notthesame", true, true, true, true, ROLE_12))); } @Test @@ -58,19 +55,15 @@ public class UserTests { User user1 = new User("rod", "koala", true, true, true, true, ROLE_12); Set users = new HashSet<>(); users.add(user1); - - assertThat(users).contains(new User("rod", "koala", true, true, true, true, - ROLE_12)); - assertThat(users).contains(new User("rod", "anotherpass", false, false, false, - false, AuthorityUtils.createAuthorityList("ROLE_X"))); - assertThat(users).doesNotContain(new User("bod", "koala", true, true, true, true, - ROLE_12)); + assertThat(users).contains(new User("rod", "koala", true, true, true, true, ROLE_12)); + assertThat(users).contains(new User("rod", "anotherpass", false, false, false, false, + AuthorityUtils.createAuthorityList("ROLE_X"))); + assertThat(users).doesNotContain(new User("bod", "koala", true, true, true, true, ROLE_12)); } @Test public void testNoArgConstructorDoesntExist() { Class clazz = User.class; - try { clazz.getDeclaredConstructor((Class[]) null); fail("Should have thrown NoSuchMethodException"); @@ -87,14 +80,12 @@ public class UserTests { } catch (IllegalArgumentException expected) { } - try { new User("rod", null, true, true, true, true, ROLE_12); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { } - try { List auths = AuthorityUtils.createAuthorityList("ROLE_ONE"); auths.add(null); @@ -125,10 +116,8 @@ public class UserTests { assertThat(user.getUsername()).isEqualTo("rod"); assertThat(user.getPassword()).isEqualTo("koala"); assertThat(user.isEnabled()).isTrue(); - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains( - "ROLE_ONE"); - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains( - "ROLE_TWO"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_ONE"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_TWO"); assertThat(user.toString()).contains("rod"); } @@ -151,9 +140,7 @@ public class UserTests { @Test public void withUserDetailsWhenAllEnabled() { User expected = new User("rob", "pass", true, true, true, true, ROLE_12); - UserDetails actual = User.withUserDetails(expected).build(); - assertThat(actual.getUsername()).isEqualTo(expected.getUsername()); assertThat(actual.getPassword()).isEqualTo(expected.getPassword()); assertThat(actual.getAuthorities()).isEqualTo(expected.getAuthorities()); @@ -163,13 +150,10 @@ public class UserTests { assertThat(actual.isEnabled()).isEqualTo(expected.isEnabled()); } - @Test public void withUserDetailsWhenAllDisabled() { User expected = new User("rob", "pass", false, false, false, false, ROLE_12); - UserDetails actual = User.withUserDetails(expected).build(); - assertThat(actual.getUsername()).isEqualTo(expected.getUsername()); assertThat(actual.getPassword()).isEqualTo(expected.getPassword()); assertThat(actual.getAuthorities()).isEqualTo(expected.getAuthorities()); @@ -182,46 +166,42 @@ public class UserTests { @Test public void withUserWhenDetailsPasswordEncoderThenEncodes() { UserDetails userDetails = User.withUsername("user").password("password").roles("USER").build(); - - UserDetails withEncodedPassword = User.withUserDetails(userDetails) - .passwordEncoder(p -> p + "encoded") - .build(); - + UserDetails withEncodedPassword = User.withUserDetails(userDetails).passwordEncoder((p) -> p + "encoded") + .build(); assertThat(withEncodedPassword.getPassword()).isEqualTo("passwordencoded"); } @Test public void withUsernameWhenPasswordEncoderAndPasswordThenEncodes() { - UserDetails withEncodedPassword = User.withUsername("user") - .password("password") - .passwordEncoder(p -> p + "encoded") - .roles("USER") - .build(); - + UserDetails withEncodedPassword = User.withUsername("user").password("password") + .passwordEncoder((p) -> p + "encoded").roles("USER").build(); assertThat(withEncodedPassword.getPassword()).isEqualTo("passwordencoded"); } @Test public void withUsernameWhenPasswordAndPasswordEncoderThenEncodes() { + // @formatter:off UserDetails withEncodedPassword = User.withUsername("user") - .passwordEncoder(p -> p + "encoded") + .passwordEncoder((p) -> p + "encoded") .password("password") .roles("USER") .build(); - + // @formatter:on assertThat(withEncodedPassword.getPassword()).isEqualTo("passwordencoded"); } @Test public void withUsernameWhenPasswordAndPasswordEncoderTwiceThenEncodesOnce() { - Function encoder = p -> p + "encoded"; + Function encoder = (p) -> p + "encoded"; + // @formatter:off UserDetails withEncodedPassword = User.withUsername("user") .passwordEncoder(encoder) .password("password") .passwordEncoder(encoder) .roles("USER") .build(); - + // @formatter:on assertThat(withEncodedPassword.getPassword()).isEqualTo("passwordencoded"); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCacheTests.java b/core/src/test/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCacheTests.java index 2336eb7ed5..45f1887853 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCacheTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/cache/EhCacheBasedUserCacheTests.java @@ -16,17 +16,18 @@ package org.springframework.security.core.userdetails.cache; -import static org.assertj.core.api.Assertions.*; import net.sf.ehcache.Cache; import net.sf.ehcache.CacheManager; import net.sf.ehcache.Ehcache; - import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; + import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.User; -import org.springframework.security.core.userdetails.cache.EhCacheBasedUserCache; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** * Tests {@link EhCacheBasedUserCache}. @@ -34,15 +35,13 @@ import org.springframework.security.core.userdetails.cache.EhCacheBasedUserCache * @author Ben Alex */ public class EhCacheBasedUserCacheTests { + private static CacheManager cacheManager; - // ~ Methods - // ======================================================================================================== @BeforeClass public static void initCacheManaer() { cacheManager = CacheManager.create(); - cacheManager.addCache(new Cache("ehcacheusercachetests", 500, false, false, 30, - 30)); + cacheManager.addCache(new Cache("ehcacheusercachetests", 500, false, false, 30, 30)); } @AfterClass @@ -54,7 +53,6 @@ public class EhCacheBasedUserCacheTests { private Ehcache getCache() { Ehcache cache = cacheManager.getCache("ehcacheusercachetests"); cache.removeAll(); - return cache; } @@ -68,15 +66,12 @@ public class EhCacheBasedUserCacheTests { EhCacheBasedUserCache cache = new EhCacheBasedUserCache(); cache.setCache(getCache()); cache.afterPropertiesSet(); - // Check it gets stored in the cache cache.putUserInCache(getUser()); assertThat(getUser().getPassword()).isEqualTo(cache.getUserFromCache(getUser().getUsername()).getPassword()); - // Check it gets removed from the cache cache.removeUserFromCache(getUser()); assertThat(cache.getUserFromCache(getUser().getUsername())).isNull(); - // Check it doesn't return values for null or unknown users assertThat(cache.getUserFromCache(null)).isNull(); assertThat(cache.getUserFromCache("UNKNOWN_USER")).isNull(); @@ -85,12 +80,11 @@ public class EhCacheBasedUserCacheTests { @Test(expected = IllegalArgumentException.class) public void startupDetectsMissingCache() throws Exception { EhCacheBasedUserCache cache = new EhCacheBasedUserCache(); - cache.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); - Ehcache myCache = getCache(); cache.setCache(myCache); assertThat(cache.getCache()).isEqualTo(myCache); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/cache/NullUserCacheTests.java b/core/src/test/java/org/springframework/security/core/userdetails/cache/NullUserCacheTests.java index 809edd2bca..4b626f5bc7 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/cache/NullUserCacheTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/cache/NullUserCacheTests.java @@ -16,12 +16,13 @@ package org.springframework.security.core.userdetails.cache; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.User; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link NullUserCache}. * @@ -29,9 +30,6 @@ import org.springframework.security.core.userdetails.User; */ public class NullUserCacheTests { - // ~ Methods - // ======================================================================================================== - private User getUser() { return new User("john", "password", true, true, true, true, AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO")); @@ -44,4 +42,5 @@ public class NullUserCacheTests { assertThat(cache.getUserFromCache(null)).isNull(); cache.removeUserFromCache(null); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCacheTests.java b/core/src/test/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCacheTests.java index 1346bc879f..44bf44bade 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCacheTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/cache/SpringCacheBasedUserCacheTests.java @@ -19,13 +19,14 @@ package org.springframework.security.core.userdetails.cache; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; + import org.springframework.cache.Cache; import org.springframework.cache.CacheManager; import org.springframework.cache.concurrent.ConcurrentMapCacheManager; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.User; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * Tests @@ -36,10 +37,9 @@ import static org.assertj.core.api.Assertions.*; * */ public class SpringCacheBasedUserCacheTests { + private static CacheManager cacheManager; - // ~ Methods - // ======================================================================================================== @BeforeClass public static void initCacheManaer() { cacheManager = new ConcurrentMapCacheManager(); @@ -64,15 +64,12 @@ public class SpringCacheBasedUserCacheTests { @Test public void cacheOperationsAreSuccessful() throws Exception { SpringCacheBasedUserCache cache = new SpringCacheBasedUserCache(getCache()); - // Check it gets stored in the cache cache.putUserInCache(getUser()); assertThat(getUser().getPassword()).isEqualTo(cache.getUserFromCache(getUser().getUsername()).getPassword()); - // Check it gets removed from the cache cache.removeUserFromCache(getUser()); assertThat(cache.getUserFromCache(getUser().getUsername())).isNull(); - // Check it doesn't return values for null or unknown users assertThat(cache.getUserFromCache(null)).isNull(); assertThat(cache.getUserFromCache("UNKNOWN_USER")).isNull(); @@ -82,4 +79,5 @@ public class SpringCacheBasedUserCacheTests { public void startupDetectsMissingCache() throws Exception { new SpringCacheBasedUserCache(null); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImplTests.java b/core/src/test/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImplTests.java index e7a479e2b6..90f45656b4 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImplTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/jdbc/JdbcDaoImplTests.java @@ -39,14 +39,10 @@ import static org.mockito.Mockito.verify; */ public class JdbcDaoImplTests { - // ~ Methods - // ======================================================================================================== - private JdbcDaoImpl makePopulatedJdbcDao() { JdbcDaoImpl dao = new JdbcDaoImpl(); dao.setDataSource(PopulatedDatabase.getDataSource()); dao.afterPropertiesSet(); - return dao; } @@ -55,7 +51,6 @@ public class JdbcDaoImplTests { dao.setDataSource(PopulatedDatabase.getDataSource()); dao.setRolePrefix("ARBITRARY_PREFIX_"); dao.afterPropertiesSet(); - return dao; } @@ -66,21 +61,16 @@ public class JdbcDaoImplTests { assertThat(user.getUsername()).isEqualTo("rod"); assertThat(user.getPassword()).isEqualTo("koala"); assertThat(user.isEnabled()).isTrue(); - - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ROLE_TELLER"); - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ROLE_SUPERVISOR"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_TELLER"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_SUPERVISOR"); } @Test - public void testCheckDaoOnlyReturnsGrantedAuthoritiesGrantedToUser() - throws Exception { + public void testCheckDaoOnlyReturnsGrantedAuthoritiesGrantedToUser() throws Exception { JdbcDaoImpl dao = makePopulatedJdbcDao(); UserDetails user = dao.loadUserByUsername("scott"); assertThat(user.getAuthorities()).hasSize(1); - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ROLE_TELLER"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_TELLER"); } @Test @@ -95,7 +85,6 @@ public class JdbcDaoImplTests { JdbcDaoImpl dao = new JdbcDaoImpl(); dao.setAuthoritiesByUsernameQuery("SELECT * FROM FOO"); assertThat(dao.getAuthoritiesByUsernameQuery()).isEqualTo("SELECT * FROM FOO"); - dao.setUsersByUsernameQuery("SELECT USERS FROM FOO"); assertThat(dao.getUsersByUsernameQuery()).isEqualTo("SELECT USERS FROM FOO"); } @@ -103,7 +92,6 @@ public class JdbcDaoImplTests { @Test public void testLookupFailsIfUserHasNoGrantedAuthorities() throws Exception { JdbcDaoImpl dao = makePopulatedJdbcDao(); - try { dao.loadUserByUsername("cooper"); fail("Should have thrown UsernameNotFoundException"); @@ -115,13 +103,11 @@ public class JdbcDaoImplTests { @Test public void testLookupFailsWithWrongUsername() throws Exception { JdbcDaoImpl dao = makePopulatedJdbcDao(); - try { dao.loadUserByUsername("UNKNOWN_USER"); fail("Should have thrown UsernameNotFoundException"); } catch (UsernameNotFoundException expected) { - } } @@ -136,13 +122,10 @@ public class JdbcDaoImplTests { public void testRolePrefixWorks() throws Exception { JdbcDaoImpl dao = makePopulatedJdbcDaoWithRolePrefix(); assertThat(dao.getRolePrefix()).isEqualTo("ARBITRARY_PREFIX_"); - UserDetails user = dao.loadUserByUsername("rod"); assertThat(user.getUsername()).isEqualTo("rod"); assertThat(user.getAuthorities()).hasSize(2); - - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ARBITRARY_PREFIX_ROLE_TELLER"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ARBITRARY_PREFIX_ROLE_TELLER"); assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) .contains("ARBITRARY_PREFIX_ROLE_SUPERVISOR"); } @@ -152,7 +135,6 @@ public class JdbcDaoImplTests { JdbcDaoImpl dao = makePopulatedJdbcDao(); dao.setEnableAuthorities(false); dao.setEnableGroups(true); - UserDetails jerry = dao.loadUserByUsername("jerry"); assertThat(jerry.getAuthorities()).hasSize(3); } @@ -170,34 +152,29 @@ public class JdbcDaoImplTests { @Test public void testStartupFailsIfDataSourceNotSet() { JdbcDaoImpl dao = new JdbcDaoImpl(); - try { dao.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @Test public void testStartupFailsIfUserMapSetToNull() { JdbcDaoImpl dao = new JdbcDaoImpl(); - try { dao.setDataSource(null); dao.afterPropertiesSet(); fail("Should have thrown IllegalArgumentException"); } catch (IllegalArgumentException expected) { - } } @Test(expected = IllegalArgumentException.class) public void setMessageSourceWhenNullThenThrowsException() { JdbcDaoImpl dao = new JdbcDaoImpl(); - dao.setMessageSource(null); } @@ -207,9 +184,8 @@ public class JdbcDaoImplTests { JdbcDaoImpl dao = new JdbcDaoImpl(); dao.setMessageSource(source); String code = "code"; - dao.getMessages().getMessage(code); - verify(source).getMessage(eq(code), any(), any()); } + } diff --git a/core/src/test/java/org/springframework/security/core/userdetails/memory/UserAttributeEditorTests.java b/core/src/test/java/org/springframework/security/core/userdetails/memory/UserAttributeEditorTests.java index e92a382dd7..1d700855e8 100644 --- a/core/src/test/java/org/springframework/security/core/userdetails/memory/UserAttributeEditorTests.java +++ b/core/src/test/java/org/springframework/security/core/userdetails/memory/UserAttributeEditorTests.java @@ -16,10 +16,10 @@ package org.springframework.security.core.userdetails.memory; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link UserAttributeEditor} and associated {@link UserAttribute}. * @@ -31,7 +31,6 @@ public class UserAttributeEditorTests { public void testCorrectOperationWithTrailingSpaces() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("password ,ROLE_ONE,ROLE_TWO "); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user.getPassword()).isEqualTo("password"); assertThat(user.getAuthorities()).hasSize(2); @@ -43,7 +42,6 @@ public class UserAttributeEditorTests { public void testCorrectOperationWithoutEnabledDisabledKeyword() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("password,ROLE_ONE,ROLE_TWO"); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user.isValid()).isTrue(); assertThat(user.isEnabled()).isTrue(); // default @@ -57,7 +55,6 @@ public class UserAttributeEditorTests { public void testDisabledKeyword() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("password,disabled,ROLE_ONE,ROLE_TWO"); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user.isValid()).isTrue(); assertThat(!user.isEnabled()).isTrue(); @@ -71,7 +68,6 @@ public class UserAttributeEditorTests { public void testEmptyStringReturnsNull() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText(""); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user == null).isTrue(); } @@ -80,7 +76,6 @@ public class UserAttributeEditorTests { public void testEnabledKeyword() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("password,ROLE_ONE,enabled,ROLE_TWO"); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user.isValid()).isTrue(); assertThat(user.isEnabled()).isTrue(); @@ -94,7 +89,6 @@ public class UserAttributeEditorTests { public void testMalformedStringReturnsNull() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("MALFORMED_STRING"); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user == null).isTrue(); } @@ -103,7 +97,6 @@ public class UserAttributeEditorTests { public void testNoPasswordOrRolesReturnsNull() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("disabled"); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user == null).isTrue(); } @@ -112,7 +105,6 @@ public class UserAttributeEditorTests { public void testNoRolesReturnsNull() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText("password,enabled"); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user == null).isTrue(); } @@ -121,8 +113,8 @@ public class UserAttributeEditorTests { public void testNullReturnsNull() { UserAttributeEditor editor = new UserAttributeEditor(); editor.setAsText(null); - UserAttribute user = (UserAttribute) editor.getValue(); assertThat(user == null).isTrue(); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/AbstractMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/AbstractMixinTests.java index 89302f2687..a29801bd5a 100644 --- a/core/src/test/java/org/springframework/security/jackson2/AbstractMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/AbstractMixinTests.java @@ -27,13 +27,14 @@ import org.springframework.security.core.userdetails.User; * @since 4.2 */ public abstract class AbstractMixinTests { + protected ObjectMapper mapper; @Before public void setup() { - mapper = new ObjectMapper(); + this.mapper = new ObjectMapper(); ClassLoader loader = getClass().getClassLoader(); - mapper.registerModules(SecurityJackson2Modules.getModules(loader)); + this.mapper.registerModules(SecurityJackson2Modules.getModules(loader)); } User createDefaultUser() { @@ -43,4 +44,5 @@ public abstract class AbstractMixinTests { User createUser(String username, String password, String authority) { return new User(username, password, AuthorityUtils.createAuthorityList(authority)); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixinTests.java index 7225c03aa1..924364f352 100644 --- a/core/src/test/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/AnonymousAuthenticationTokenMixinTests.java @@ -48,21 +48,17 @@ public class AnonymousAuthenticationTokenMixinTests extends AbstractMixinTests { + "\"authorities\": " + SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON + "}"; // @formatter:on - @Test public void serializeAnonymousAuthenticationTokenTest() throws JsonProcessingException, JSONException { User user = createDefaultUser(); - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken( - HASH_KEY, user, user.getAuthorities() - ); - String actualJson = mapper.writeValueAsString(token); + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken(HASH_KEY, user, user.getAuthorities()); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(ANONYMOUS_JSON, actualJson, true); } @Test public void deserializeAnonymousAuthenticationTokenTest() throws IOException { - AnonymousAuthenticationToken token = mapper - .readValue(ANONYMOUS_JSON, AnonymousAuthenticationToken.class); + AnonymousAuthenticationToken token = this.mapper.readValue(ANONYMOUS_JSON, AnonymousAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getKeyHash()).isEqualTo(HASH_KEY.hashCode()); assertThat(token.getAuthorities()).isNotNull().hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); @@ -70,20 +66,20 @@ public class AnonymousAuthenticationTokenMixinTests extends AbstractMixinTests { @Test(expected = JsonMappingException.class) public void deserializeAnonymousAuthenticationTokenWithoutAuthoritiesTest() throws IOException { - String jsonString = "{\"@class\": \"org.springframework.security.authentication.AnonymousAuthenticationToken\", \"details\": null," + - "\"principal\": \"user\", \"authenticated\": true, \"keyHash\": " + HASH_KEY.hashCode() + "," + - "\"authorities\": [\"java.util.ArrayList\", []]}"; - mapper.readValue(jsonString, AnonymousAuthenticationToken.class); + String jsonString = "{\"@class\": \"org.springframework.security.authentication.AnonymousAuthenticationToken\", \"details\": null," + + "\"principal\": \"user\", \"authenticated\": true, \"keyHash\": " + HASH_KEY.hashCode() + "," + + "\"authorities\": [\"java.util.ArrayList\", []]}"; + this.mapper.readValue(jsonString, AnonymousAuthenticationToken.class); } @Test - public void serializeAnonymousAuthenticationTokenMixinAfterEraseCredentialTest() throws JsonProcessingException, JSONException { + public void serializeAnonymousAuthenticationTokenMixinAfterEraseCredentialTest() + throws JsonProcessingException, JSONException { User user = createDefaultUser(); - AnonymousAuthenticationToken token = new AnonymousAuthenticationToken( - HASH_KEY, user, user.getAuthorities() - ); + AnonymousAuthenticationToken token = new AnonymousAuthenticationToken(HASH_KEY, user, user.getAuthorities()); token.eraseCredentials(); - String actualJson = mapper.writeValueAsString(token); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(ANONYMOUS_JSON.replace(UserDeserializerTests.USER_PASSWORD, "null"), actualJson, true); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/BadCredentialsExceptionMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/BadCredentialsExceptionMixinTests.java index 44d0bb6784..91dbb6750e 100644 --- a/core/src/test/java/org/springframework/security/jackson2/BadCredentialsExceptionMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/BadCredentialsExceptionMixinTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.jackson2; +import java.io.IOException; + import com.fasterxml.jackson.core.JsonProcessingException; import org.json.JSONException; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; -import org.springframework.security.authentication.BadCredentialsException; -import java.io.IOException; +import org.springframework.security.authentication.BadCredentialsException; import static org.assertj.core.api.Assertions.assertThat; @@ -39,20 +41,20 @@ public class BadCredentialsExceptionMixinTests extends AbstractMixinTests { + "\"suppressed\": [\"[Ljava.lang.Throwable;\",[]]" + "}"; // @formatter:on - @Test public void serializeBadCredentialsExceptionMixinTest() throws JsonProcessingException, JSONException { BadCredentialsException exception = new BadCredentialsException("message"); - String serializedJson = mapper.writeValueAsString(exception); + String serializedJson = this.mapper.writeValueAsString(exception); JSONAssert.assertEquals(EXCEPTION_JSON, serializedJson, true); } @Test public void deserializeBadCredentialsExceptionMixinTest() throws IOException { - BadCredentialsException exception = mapper.readValue(EXCEPTION_JSON, BadCredentialsException.class); + BadCredentialsException exception = this.mapper.readValue(EXCEPTION_JSON, BadCredentialsException.class); assertThat(exception).isNotNull(); assertThat(exception.getCause()).isNull(); assertThat(exception.getMessage()).isEqualTo("message"); assertThat(exception.getLocalizedMessage()).isEqualTo("message"); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixinTests.java index 41a54b17c4..85b05860e8 100644 --- a/core/src/test/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/RememberMeAuthenticationTokenMixinTests.java @@ -16,18 +16,19 @@ package org.springframework.security.jackson2; +import java.io.IOException; +import java.util.Collections; + import com.fasterxml.jackson.core.JsonProcessingException; import org.json.JSONException; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; + import org.springframework.security.authentication.RememberMeAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.User; -import java.io.IOException; -import java.util.Collections; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -47,7 +48,6 @@ public class RememberMeAuthenticationTokenMixinTests extends AbstractMixinTests + "\"authorities\": " + SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON + "}"; // @formatter:on - // @formatter:off private static final String REMEMBERME_AUTH_STRINGPRINCIPAL_JSON = "{" + "\"@class\": \"org.springframework.security.authentication.RememberMeAuthenticationToken\"," @@ -58,7 +58,6 @@ public class RememberMeAuthenticationTokenMixinTests extends AbstractMixinTests + "\"authorities\": " + SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON + "}"; // @formatter:on - @Test(expected = IllegalArgumentException.class) public void testWithNullPrincipal() { new RememberMeAuthenticationToken("key", null, Collections.emptyList()); @@ -71,31 +70,37 @@ public class RememberMeAuthenticationTokenMixinTests extends AbstractMixinTests @Test public void serializeRememberMeAuthenticationToken() throws JsonProcessingException, JSONException { - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken(REMEMBERME_KEY, "admin", Collections.singleton(new SimpleGrantedAuthority("ROLE_USER"))); - String actualJson = mapper.writeValueAsString(token); + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken(REMEMBERME_KEY, "admin", + Collections.singleton(new SimpleGrantedAuthority("ROLE_USER"))); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(REMEMBERME_AUTH_STRINGPRINCIPAL_JSON, actualJson, true); } @Test public void serializeRememberMeAuthenticationWithUserToken() throws JsonProcessingException, JSONException { User user = createDefaultUser(); - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken(REMEMBERME_KEY, user, user.getAuthorities()); - String actualJson = mapper.writeValueAsString(token); + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken(REMEMBERME_KEY, user, + user.getAuthorities()); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(String.format(REMEMBERME_AUTH_JSON, "\"password\""), actualJson, true); } @Test - public void serializeRememberMeAuthenticationWithUserTokenAfterEraseCredential() throws JsonProcessingException, JSONException { + public void serializeRememberMeAuthenticationWithUserTokenAfterEraseCredential() + throws JsonProcessingException, JSONException { User user = createDefaultUser(); - RememberMeAuthenticationToken token = new RememberMeAuthenticationToken(REMEMBERME_KEY, user, user.getAuthorities()); + RememberMeAuthenticationToken token = new RememberMeAuthenticationToken(REMEMBERME_KEY, user, + user.getAuthorities()); token.eraseCredentials(); - String actualJson = mapper.writeValueAsString(token); - JSONAssert.assertEquals(REMEMBERME_AUTH_JSON.replace(UserDeserializerTests.USER_PASSWORD, "null"), actualJson, true); + String actualJson = this.mapper.writeValueAsString(token); + JSONAssert.assertEquals(REMEMBERME_AUTH_JSON.replace(UserDeserializerTests.USER_PASSWORD, "null"), actualJson, + true); } @Test public void deserializeRememberMeAuthenticationToken() throws IOException { - RememberMeAuthenticationToken token = mapper.readValue(REMEMBERME_AUTH_STRINGPRINCIPAL_JSON, RememberMeAuthenticationToken.class); + RememberMeAuthenticationToken token = this.mapper.readValue(REMEMBERME_AUTH_STRINGPRINCIPAL_JSON, + RememberMeAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getPrincipal()).isNotNull().isEqualTo("admin").isEqualTo(token.getName()); assertThat(token.getAuthorities()).hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); @@ -103,14 +108,16 @@ public class RememberMeAuthenticationTokenMixinTests extends AbstractMixinTests @Test public void deserializeRememberMeAuthenticationTokenWithUserTest() throws IOException { - RememberMeAuthenticationToken token = mapper - .readValue(String.format(REMEMBERME_AUTH_JSON, "\"password\""), RememberMeAuthenticationToken.class); + RememberMeAuthenticationToken token = this.mapper.readValue(String.format(REMEMBERME_AUTH_JSON, "\"password\""), + RememberMeAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getPrincipal()).isNotNull().isInstanceOf(User.class); assertThat(((User) token.getPrincipal()).getUsername()).isEqualTo("admin"); assertThat(((User) token.getPrincipal()).getPassword()).isEqualTo("1234"); - assertThat(((User) token.getPrincipal()).getAuthorities()).hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); + assertThat(((User) token.getPrincipal()).getAuthorities()).hasSize(1) + .contains(new SimpleGrantedAuthority("ROLE_USER")); assertThat(token.getAuthorities()).hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); assertThat(((User) token.getPrincipal()).isEnabled()).isEqualTo(true); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/SecurityContextMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/SecurityContextMixinTests.java index d0757828c6..8f2806079f 100644 --- a/core/src/test/java/org/springframework/security/jackson2/SecurityContextMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/SecurityContextMixinTests.java @@ -16,21 +16,21 @@ package org.springframework.security.jackson2; -import static org.assertj.core.api.Assertions.assertThat; - import java.io.IOException; import java.util.Collection; import java.util.Collections; +import com.fasterxml.jackson.core.JsonProcessingException; import org.json.JSONException; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; + import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; -import com.fasterxml.jackson.core.JsonProcessingException; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Jitendra Singh @@ -44,18 +44,18 @@ public class SecurityContextMixinTests extends AbstractMixinTests { + "\"authentication\": " + UsernamePasswordAuthenticationTokenMixinTests.AUTHENTICATED_STRINGPRINCIPAL_JSON + "}"; // @formatter:on - @Test public void securityContextSerializeTest() throws JsonProcessingException, JSONException { SecurityContext context = new SecurityContextImpl(); - context.setAuthentication(new UsernamePasswordAuthenticationToken("admin", "1234", Collections.singleton(new SimpleGrantedAuthority("ROLE_USER")))); - String actualJson = mapper.writeValueAsString(context); + context.setAuthentication(new UsernamePasswordAuthenticationToken("admin", "1234", + Collections.singleton(new SimpleGrantedAuthority("ROLE_USER")))); + String actualJson = this.mapper.writeValueAsString(context); JSONAssert.assertEquals(SECURITY_CONTEXT_JSON, actualJson, true); } @Test public void securityContextDeserializeTest() throws IOException { - SecurityContext context = mapper.readValue(SECURITY_CONTEXT_JSON, SecurityContextImpl.class); + SecurityContext context = this.mapper.readValue(SECURITY_CONTEXT_JSON, SecurityContextImpl.class); assertThat(context).isNotNull(); assertThat(context.getAuthentication()).isNotNull().isInstanceOf(UsernamePasswordAuthenticationToken.class); assertThat(context.getAuthentication().getPrincipal()).isEqualTo("admin"); @@ -65,4 +65,5 @@ public class SecurityContextMixinTests extends AbstractMixinTests { assertThat(authorities).hasSize(1); assertThat(authorities).contains(new SimpleGrantedAuthority("ROLE_USER")); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/SecurityJackson2ModulesTests.java b/core/src/test/java/org/springframework/security/jackson2/SecurityJackson2ModulesTests.java index e28aaf0ffd..cbbbddc130 100644 --- a/core/src/test/java/org/springframework/security/jackson2/SecurityJackson2ModulesTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/SecurityJackson2ModulesTests.java @@ -14,9 +14,15 @@ * limitations under the License. */ - package org.springframework.security.jackson2; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.HashMap; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreType; @@ -25,108 +31,110 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.Before; import org.junit.Test; -import java.lang.annotation.*; -import java.util.HashMap; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** -* @author Rob Winch -* @since 5.0 -*/ + * @author Rob Winch + * @since 5.0 + */ public class SecurityJackson2ModulesTests { + private ObjectMapper mapper; @Before public void setup() { - mapper = new ObjectMapper(); - SecurityJackson2Modules.enableDefaultTyping(mapper); + this.mapper = new ObjectMapper(); + SecurityJackson2Modules.enableDefaultTyping(this.mapper); } @Test public void readValueWhenNotAllowedOrMappedThenThrowsException() { String content = "{\"@class\":\"org.springframework.security.jackson2.SecurityJackson2ModulesTests$NotAllowlisted\",\"property\":\"bar\"}"; - assertThatThrownBy(() -> { - mapper.readValue(content, Object.class); - } - ).hasStackTraceContaining("allowlist"); + // @formatter:off + assertThatExceptionOfType(Exception.class) + .isThrownBy(() -> this.mapper.readValue(content, Object.class)) + .withStackTraceContaining("allowlist"); + // @formatter:on } @Test public void readValueWhenExplicitDefaultTypingAfterSecuritySetupThenReadsAsSpecificType() throws Exception { - mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY); + this.mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY); String content = "{\"@class\":\"org.springframework.security.jackson2.SecurityJackson2ModulesTests$NotAllowlisted\",\"property\":\"bar\"}"; - - assertThat(mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlisted.class); + assertThat(this.mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlisted.class); } @Test public void readValueWhenExplicitDefaultTypingBeforeSecuritySetupThenReadsAsSpecificType() throws Exception { - mapper = new ObjectMapper(); - mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY); - SecurityJackson2Modules.enableDefaultTyping(mapper); + this.mapper = new ObjectMapper(); + this.mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY); + SecurityJackson2Modules.enableDefaultTyping(this.mapper); String content = "{\"@class\":\"org.springframework.security.jackson2.SecurityJackson2ModulesTests$NotAllowlisted\",\"property\":\"bar\"}"; - - assertThat(mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlisted.class); + assertThat(this.mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlisted.class); } @Test public void readValueWhenAnnotatedThenReadsAsSpecificType() throws Exception { String content = "{\"@class\":\"org.springframework.security.jackson2.SecurityJackson2ModulesTests$NotAllowlistedButAnnotated\",\"property\":\"bar\"}"; - - assertThat(mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlistedButAnnotated.class); + assertThat(this.mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlistedButAnnotated.class); } @Test public void readValueWhenMixinProvidedThenReadsAsSpecificType() throws Exception { - mapper.addMixIn(NotAllowlisted.class, NotAllowlistedMixin.class); + this.mapper.addMixIn(NotAllowlisted.class, NotAllowlistedMixin.class); String content = "{\"@class\":\"org.springframework.security.jackson2.SecurityJackson2ModulesTests$NotAllowlisted\",\"property\":\"bar\"}"; - - assertThat(mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlisted.class); + assertThat(this.mapper.readValue(content, Object.class)).isInstanceOf(NotAllowlisted.class); } @Test public void readValueWhenHashMapThenReadsAsSpecificType() throws Exception { - mapper.addMixIn(NotAllowlisted.class, NotAllowlistedMixin.class); + this.mapper.addMixIn(NotAllowlisted.class, NotAllowlistedMixin.class); String content = "{\"@class\":\"java.util.HashMap\"}"; - - assertThat(mapper.readValue(content, Object.class)).isInstanceOf(HashMap.class); + assertThat(this.mapper.readValue(content, Object.class)).isInstanceOf(HashMap.class); } @Target({ ElementType.TYPE, ElementType.ANNOTATION_TYPE }) @Retention(RetentionPolicy.RUNTIME) @Documented - public @interface NotJacksonAnnotation {} + public @interface NotJacksonAnnotation { + + } @NotJacksonAnnotation static class NotAllowlisted { + private String property = "bar"; - public String getProperty() { - return property; + String getProperty() { + return this.property; } - public void setProperty(String property) { + void setProperty(String property) { } + } @JsonIgnoreType(false) static class NotAllowlistedButAnnotated { + private String property = "bar"; - public String getProperty() { - return property; + String getProperty() { + return this.property; } - public void setProperty(String property) { + void setProperty(String property) { } + } @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, - isGetterVisibility = JsonAutoDetect.Visibility.NONE) + isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonIgnoreProperties(ignoreUnknown = true) abstract class NotAllowlistedMixin { } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixinTests.java index a7699e2d20..05d67d7323 100644 --- a/core/src/test/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/SimpleGrantedAuthorityMixinTests.java @@ -16,16 +16,17 @@ package org.springframework.security.jackson2; +import java.io.IOException; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonMappingException; import org.json.JSONException; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; + import org.springframework.security.core.authority.SimpleGrantedAuthority; -import java.io.IOException; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Jitendra Singh @@ -35,28 +36,22 @@ public class SimpleGrantedAuthorityMixinTests extends AbstractMixinTests { // @formatter:off public static final String AUTHORITY_JSON = "{\"@class\": \"org.springframework.security.core.authority.SimpleGrantedAuthority\", \"authority\": \"ROLE_USER\"}"; - public static final String AUTHORITIES_ARRAYLIST_JSON = "[\"java.util.Collections$UnmodifiableRandomAccessList\", [" + AUTHORITY_JSON + "]]"; - public static final String AUTHORITIES_SET_JSON = "[\"java.util.Collections$UnmodifiableSet\", [" + AUTHORITY_JSON + "]]"; - public static final String NO_AUTHORITIES_ARRAYLIST_JSON = "[\"java.util.Collections$UnmodifiableRandomAccessList\", []]"; - public static final String EMPTY_AUTHORITIES_ARRAYLIST_JSON = "[\"java.util.Collections$EmptyList\", []]"; - public static final String NO_AUTHORITIES_SET_JSON = "[\"java.util.Collections$UnmodifiableSet\", []]"; // @formatter:on - @Test public void serializeSimpleGrantedAuthorityTest() throws JsonProcessingException, JSONException { SimpleGrantedAuthority authority = new SimpleGrantedAuthority("ROLE_USER"); - String serializeJson = mapper.writeValueAsString(authority); + String serializeJson = this.mapper.writeValueAsString(authority); JSONAssert.assertEquals(AUTHORITY_JSON, serializeJson, true); } @Test public void deserializeGrantedAuthorityTest() throws IOException { - SimpleGrantedAuthority authority = mapper.readValue(AUTHORITY_JSON, SimpleGrantedAuthority.class); + SimpleGrantedAuthority authority = this.mapper.readValue(AUTHORITY_JSON, SimpleGrantedAuthority.class); assertThat(authority).isNotNull(); assertThat(authority.getAuthority()).isNotNull().isEqualTo("ROLE_USER"); } @@ -64,6 +59,7 @@ public class SimpleGrantedAuthorityMixinTests extends AbstractMixinTests { @Test(expected = JsonMappingException.class) public void deserializeGrantedAuthorityWithoutRoleTest() throws IOException { String json = "{\"@class\": \"org.springframework.security.core.authority.SimpleGrantedAuthority\"}"; - mapper.readValue(json, SimpleGrantedAuthority.class); + this.mapper.readValue(json, SimpleGrantedAuthority.class); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/UserDeserializerTests.java b/core/src/test/java/org/springframework/security/jackson2/UserDeserializerTests.java index 397491ee0a..299b4f9026 100644 --- a/core/src/test/java/org/springframework/security/jackson2/UserDeserializerTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/UserDeserializerTests.java @@ -38,13 +38,14 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 4.2 */ public class UserDeserializerTests extends AbstractMixinTests { + public static final String USER_PASSWORD = "\"1234\""; // @formatter:off public static final String USER_JSON = "{" + "\"@class\": \"org.springframework.security.core.userdetails.User\", " + "\"username\": \"admin\"," - + " \"password\": "+ USER_PASSWORD +", " + + " \"password\": " + USER_PASSWORD + ", " + "\"accountNonExpired\": true, " + "\"accountNonLocked\": true, " + "\"credentialsNonExpired\": true, " @@ -52,33 +53,31 @@ public class UserDeserializerTests extends AbstractMixinTests { + "\"authorities\": " + SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON + "}"; // @formatter:on - @Test public void serializeUserTest() throws JsonProcessingException, JSONException { User user = createDefaultUser(); - String userJson = mapper.writeValueAsString(user); + String userJson = this.mapper.writeValueAsString(user); JSONAssert.assertEquals(userWithPasswordJson(user.getPassword()), userJson, true); } @Test public void serializeUserWithoutAuthority() throws JsonProcessingException, JSONException { User user = new User("admin", "1234", Collections.emptyList()); - String userJson = mapper.writeValueAsString(user); + String userJson = this.mapper.writeValueAsString(user); JSONAssert.assertEquals(userWithNoAuthoritiesJson(), userJson, true); } @Test(expected = IllegalArgumentException.class) public void deserializeUserWithNullPasswordEmptyAuthorityTest() throws IOException { - String userJsonWithoutPasswordString = USER_JSON.replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON, "[]"); - - mapper.readValue(userJsonWithoutPasswordString, User.class); + String userJsonWithoutPasswordString = USER_JSON.replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON, + "[]"); + this.mapper.readValue(userJsonWithoutPasswordString, User.class); } @Test public void deserializeUserWithNullPasswordNoAuthorityTest() throws Exception { - String userJsonWithoutPasswordString = removeNode(userWithNoAuthoritiesJson(), mapper, "password"); - - User user = mapper.readValue(userJsonWithoutPasswordString, User.class); + String userJsonWithoutPasswordString = removeNode(userWithNoAuthoritiesJson(), this.mapper, "password"); + User user = this.mapper.readValue(userJsonWithoutPasswordString, User.class); assertThat(user).isNotNull(); assertThat(user.getUsername()).isEqualTo("admin"); assertThat(user.getPassword()).isNull(); @@ -88,13 +87,14 @@ public class UserDeserializerTests extends AbstractMixinTests { @Test(expected = IllegalArgumentException.class) public void deserializeUserWithNoClassIdInAuthoritiesTest() throws Exception { - String userJson = USER_JSON.replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON, "[{\"authority\": \"ROLE_USER\"}]"); - mapper.readValue(userJson, User.class); + String userJson = USER_JSON.replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON, + "[{\"authority\": \"ROLE_USER\"}]"); + this.mapper.readValue(userJson, User.class); } @Test public void deserializeUserWithClassIdInAuthoritiesTest() throws IOException { - User user = mapper.readValue(userJson(), User.class); + User user = this.mapper.readValue(userJson(), User.class); assertThat(user).isNotNull(); assertThat(user.getUsername()).isEqualTo("admin"); assertThat(user.getPassword()).isEqualTo("1234"); @@ -104,7 +104,6 @@ public class UserDeserializerTests extends AbstractMixinTests { private String removeNode(String json, ObjectMapper mapper, String toRemove) throws Exception { ObjectNode node = mapper.getFactory().createParser(json).readValueAsTree(); node.remove(toRemove); - String result = mapper.writeValueAsString(node); JSONAssert.assertNotEquals(json, result, false); return result; @@ -115,10 +114,12 @@ public class UserDeserializerTests extends AbstractMixinTests { } public static String userWithPasswordJson(String password) { - return userJson().replaceAll(Pattern.quote(USER_PASSWORD), "\""+ password +"\""); + return userJson().replaceAll(Pattern.quote(USER_PASSWORD), "\"" + password + "\""); } public static String userWithNoAuthoritiesJson() { - return userJson().replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON, SimpleGrantedAuthorityMixinTests.NO_AUTHORITIES_SET_JSON); + return userJson().replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_SET_JSON, + SimpleGrantedAuthorityMixinTests.NO_AUTHORITIES_SET_JSON); } + } diff --git a/core/src/test/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixinTests.java b/core/src/test/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixinTests.java index 7475476fe6..130c9ea0ab 100644 --- a/core/src/test/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixinTests.java +++ b/core/src/test/java/org/springframework/security/jackson2/UsernamePasswordAuthenticationTokenMixinTests.java @@ -20,6 +20,8 @@ import java.io.IOException; import java.util.ArrayList; import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonInclude.Value; import com.fasterxml.jackson.core.JsonProcessingException; import org.json.JSONException; import org.junit.Test; @@ -29,10 +31,6 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.userdetails.User; -import static com.fasterxml.jackson.annotation.JsonInclude.Include.ALWAYS; -import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_ABSENT; -import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; -import static com.fasterxml.jackson.annotation.JsonInclude.Value.construct; import static org.assertj.core.api.Assertions.assertThat; /** @@ -42,64 +40,56 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 4.2 */ public class UsernamePasswordAuthenticationTokenMixinTests extends AbstractMixinTests { - // @formatter:off + private static final String AUTHENTICATED_JSON = "{" - + "\"@class\": \"org.springframework.security.authentication.UsernamePasswordAuthenticationToken\"," - + "\"principal\": "+ UserDeserializerTests.USER_JSON + ", " - + "\"credentials\": \"1234\", " - + "\"authenticated\": true, " - + "\"details\": null, " - + "\"authorities\": "+ SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON - + "}"; - // @formatter:on + + "\"@class\": \"org.springframework.security.authentication.UsernamePasswordAuthenticationToken\"," + + "\"principal\": " + UserDeserializerTests.USER_JSON + ", " + "\"credentials\": \"1234\", " + + "\"authenticated\": true, " + "\"details\": null, " + "\"authorities\": " + + SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON + "}"; - // @formatter:off - public static final String AUTHENTICATED_STRINGPRINCIPAL_JSON = AUTHENTICATED_JSON.replace( UserDeserializerTests.USER_JSON, "\"admin\""); - // @formatter:on + public static final String AUTHENTICATED_STRINGPRINCIPAL_JSON = AUTHENTICATED_JSON + .replace(UserDeserializerTests.USER_JSON, "\"admin\""); - // @formatter:off private static final String NON_USER_PRINCIPAL_JSON = "{" - + "\"@class\": \"org.springframework.security.jackson2.UsernamePasswordAuthenticationTokenMixinTests$NonUserPrincipal\", " - + "\"username\": \"admin\"" - + "}"; - // @formatter:on + + "\"@class\": \"org.springframework.security.jackson2.UsernamePasswordAuthenticationTokenMixinTests$NonUserPrincipal\", " + + "\"username\": \"admin\"" + "}"; - // @formatter:off - private static final String AUTHENTICATED_STRINGDETAILS_JSON = AUTHENTICATED_JSON.replace("\"details\": null, ", "\"details\": \"details\", "); - // @formatter:on + private static final String AUTHENTICATED_STRINGDETAILS_JSON = AUTHENTICATED_JSON.replace("\"details\": null, ", + "\"details\": \"details\", "); - // @formatter:off private static final String AUTHENTICATED_NON_USER_PRINCIPAL_JSON = AUTHENTICATED_JSON - .replace(UserDeserializerTests.USER_JSON, NON_USER_PRINCIPAL_JSON) - .replaceAll(UserDeserializerTests.USER_PASSWORD, "null") - .replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON, SimpleGrantedAuthorityMixinTests.NO_AUTHORITIES_ARRAYLIST_JSON); - // @formatter:on + .replace(UserDeserializerTests.USER_JSON, NON_USER_PRINCIPAL_JSON) + .replaceAll(UserDeserializerTests.USER_PASSWORD, "null") + .replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON, + SimpleGrantedAuthorityMixinTests.NO_AUTHORITIES_ARRAYLIST_JSON); - // @formatter:off private static final String UNAUTHENTICATED_STRINGPRINCIPAL_JSON = AUTHENTICATED_STRINGPRINCIPAL_JSON - .replace("\"authenticated\": true, ", "\"authenticated\": false, ") - .replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON, SimpleGrantedAuthorityMixinTests.EMPTY_AUTHORITIES_ARRAYLIST_JSON); - // @formatter:on + .replace("\"authenticated\": true, ", "\"authenticated\": false, ") + .replace(SimpleGrantedAuthorityMixinTests.AUTHORITIES_ARRAYLIST_JSON, + SimpleGrantedAuthorityMixinTests.EMPTY_AUTHORITIES_ARRAYLIST_JSON); @Test - public void serializeUnauthenticatedUsernamePasswordAuthenticationTokenMixinTest() throws JsonProcessingException, JSONException { + public void serializeUnauthenticatedUsernamePasswordAuthenticationTokenMixinTest() + throws JsonProcessingException, JSONException { UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken("admin", "1234"); - String serializedJson = mapper.writeValueAsString(token); + String serializedJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(UNAUTHENTICATED_STRINGPRINCIPAL_JSON, serializedJson, true); } @Test - public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinTest() throws JsonProcessingException, JSONException { + public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinTest() + throws JsonProcessingException, JSONException { User user = createDefaultUser(); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(user.getUsername(), user.getPassword(), user.getAuthorities()); - String serializedJson = mapper.writeValueAsString(token); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(user.getUsername(), + user.getPassword(), user.getAuthorities()); + String serializedJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(AUTHENTICATED_STRINGPRINCIPAL_JSON, serializedJson, true); } @Test public void deserializeUnauthenticatedUsernamePasswordAuthenticationTokenMixinTest() throws IOException { - UsernamePasswordAuthenticationToken token = mapper - .readValue(UNAUTHENTICATED_STRINGPRINCIPAL_JSON, UsernamePasswordAuthenticationToken.class); + UsernamePasswordAuthenticationToken token = this.mapper.readValue(UNAUTHENTICATED_STRINGPRINCIPAL_JSON, + UsernamePasswordAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.isAuthenticated()).isEqualTo(false); assertThat(token.getAuthorities()).isNotNull().hasSize(0); @@ -108,64 +98,71 @@ public class UsernamePasswordAuthenticationTokenMixinTests extends AbstractMixin @Test public void deserializeAuthenticatedUsernamePasswordAuthenticationTokenMixinTest() throws IOException { UsernamePasswordAuthenticationToken expectedToken = createToken(); - UsernamePasswordAuthenticationToken token = mapper - .readValue(AUTHENTICATED_STRINGPRINCIPAL_JSON, UsernamePasswordAuthenticationToken.class); + UsernamePasswordAuthenticationToken token = this.mapper.readValue(AUTHENTICATED_STRINGPRINCIPAL_JSON, + UsernamePasswordAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.isAuthenticated()).isTrue(); assertThat(token.getAuthorities()).isEqualTo(expectedToken.getAuthorities()); } @Test - public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinWithUserTest() throws JsonProcessingException, JSONException { + public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinWithUserTest() + throws JsonProcessingException, JSONException { UsernamePasswordAuthenticationToken token = createToken(); - String actualJson = mapper.writeValueAsString(token); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(AUTHENTICATED_JSON, actualJson, true); } @Test public void deserializeAuthenticatedUsernamePasswordAuthenticationTokenWithUserTest() throws IOException { - UsernamePasswordAuthenticationToken token = mapper - .readValue(AUTHENTICATED_JSON, UsernamePasswordAuthenticationToken.class); + UsernamePasswordAuthenticationToken token = this.mapper.readValue(AUTHENTICATED_JSON, + UsernamePasswordAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getPrincipal()).isNotNull().isInstanceOf(User.class); - assertThat(((User) token.getPrincipal()).getAuthorities()).isNotNull().hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); + assertThat(((User) token.getPrincipal()).getAuthorities()).isNotNull().hasSize(1) + .contains(new SimpleGrantedAuthority("ROLE_USER")); assertThat(token.isAuthenticated()).isEqualTo(true); assertThat(token.getAuthorities()).hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); } @Test - public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinAfterEraseCredentialInvoked() throws JsonProcessingException, JSONException { + public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinAfterEraseCredentialInvoked() + throws JsonProcessingException, JSONException { UsernamePasswordAuthenticationToken token = createToken(); token.eraseCredentials(); - String actualJson = mapper.writeValueAsString(token); - JSONAssert.assertEquals(AUTHENTICATED_JSON.replaceAll(UserDeserializerTests.USER_PASSWORD, "null"), actualJson, true); + String actualJson = this.mapper.writeValueAsString(token); + JSONAssert.assertEquals(AUTHENTICATED_JSON.replaceAll(UserDeserializerTests.USER_PASSWORD, "null"), actualJson, + true); } @Test - public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinWithNonUserPrincipalTest() throws JsonProcessingException, JSONException { + public void serializeAuthenticatedUsernamePasswordAuthenticationTokenMixinWithNonUserPrincipalTest() + throws JsonProcessingException, JSONException { NonUserPrincipal principal = new NonUserPrincipal(); principal.setUsername("admin"); - UsernamePasswordAuthenticationToken token = - new UsernamePasswordAuthenticationToken(principal, null, new ArrayList<>()); - String actualJson = mapper.writeValueAsString(token); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(principal, null, + new ArrayList<>()); + String actualJson = this.mapper.writeValueAsString(token); JSONAssert.assertEquals(AUTHENTICATED_NON_USER_PRINCIPAL_JSON, actualJson, true); } @Test - public void deserializeAuthenticatedUsernamePasswordAuthenticationTokenWithNonUserPrincipalTest() throws IOException { - UsernamePasswordAuthenticationToken token = mapper - .readValue(AUTHENTICATED_NON_USER_PRINCIPAL_JSON, UsernamePasswordAuthenticationToken.class); + public void deserializeAuthenticatedUsernamePasswordAuthenticationTokenWithNonUserPrincipalTest() + throws IOException { + UsernamePasswordAuthenticationToken token = this.mapper.readValue(AUTHENTICATED_NON_USER_PRINCIPAL_JSON, + UsernamePasswordAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getPrincipal()).isNotNull().isInstanceOf(NonUserPrincipal.class); } @Test public void deserializeAuthenticatedUsernamePasswordAuthenticationTokenWithDetailsTest() throws IOException { - UsernamePasswordAuthenticationToken token = mapper - .readValue(AUTHENTICATED_STRINGDETAILS_JSON, UsernamePasswordAuthenticationToken.class); + UsernamePasswordAuthenticationToken token = this.mapper.readValue(AUTHENTICATED_STRINGDETAILS_JSON, + UsernamePasswordAuthenticationToken.class); assertThat(token).isNotNull(); assertThat(token.getPrincipal()).isNotNull().isInstanceOf(User.class); - assertThat(((User) token.getPrincipal()).getAuthorities()).isNotNull().hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); + assertThat(((User) token.getPrincipal()).getAuthorities()).isNotNull().hasSize(1) + .contains(new SimpleGrantedAuthority("ROLE_USER")); assertThat(token.isAuthenticated()).isEqualTo(true); assertThat(token.getAuthorities()).hasSize(1).contains(new SimpleGrantedAuthority("ROLE_USER")); assertThat(token.getDetails()).isExactlyInstanceOf(String.class).isEqualTo("details"); @@ -173,48 +170,44 @@ public class UsernamePasswordAuthenticationTokenMixinTests extends AbstractMixin @Test public void serializingThenDeserializingWithNoCredentialsOrDetailsShouldWork() throws IOException { - // given UsernamePasswordAuthenticationToken original = new UsernamePasswordAuthenticationToken("Frodo", null); - - // when String serialized = this.mapper.writeValueAsString(original); - UsernamePasswordAuthenticationToken deserialized = this.mapper.readValue(serialized, UsernamePasswordAuthenticationToken.class); - - // then + UsernamePasswordAuthenticationToken deserialized = this.mapper.readValue(serialized, + UsernamePasswordAuthenticationToken.class); assertThat(deserialized).isEqualTo(original); } @Test public void serializingThenDeserializingWithConfiguredObjectMapperShouldWork() throws IOException { - // given - this.mapper.setDefaultPropertyInclusion(construct(ALWAYS, NON_NULL)).setSerializationInclusion(NON_ABSENT); + this.mapper.setDefaultPropertyInclusion(Value.construct(Include.ALWAYS, Include.NON_NULL)) + .setSerializationInclusion(Include.NON_ABSENT); UsernamePasswordAuthenticationToken original = new UsernamePasswordAuthenticationToken("Frodo", null); - - // when String serialized = this.mapper.writeValueAsString(original); - UsernamePasswordAuthenticationToken deserialized = - this.mapper.readValue(serialized, UsernamePasswordAuthenticationToken.class); - - // then + UsernamePasswordAuthenticationToken deserialized = this.mapper.readValue(serialized, + UsernamePasswordAuthenticationToken.class); assertThat(deserialized).isEqualTo(original); } private UsernamePasswordAuthenticationToken createToken() { User user = createDefaultUser(); - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(user, user.getPassword(), user.getAuthorities()); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(user, user.getPassword(), + user.getAuthorities()); return token; } @JsonClassDescription public static class NonUserPrincipal { + private String username; public String getUsername() { - return username; + return this.username; } public void setUsername(String username) { this.username = username; } + } + } diff --git a/core/src/test/java/org/springframework/security/provisioning/InMemoryUserDetailsManagerTests.java b/core/src/test/java/org/springframework/security/provisioning/InMemoryUserDetailsManagerTests.java index e86a2a55af..cca7f4d75b 100644 --- a/core/src/test/java/org/springframework/security/provisioning/InMemoryUserDetailsManagerTests.java +++ b/core/src/test/java/org/springframework/security/provisioning/InMemoryUserDetailsManagerTests.java @@ -17,17 +17,19 @@ package org.springframework.security.provisioning; import org.junit.Test; + import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch * @since 5.1 */ public class InMemoryUserDetailsManagerTests { + private final UserDetails user = PasswordEncodedUser.user(); private InMemoryUserDetailsManager manager = new InMemoryUserDetailsManager(this.user); @@ -41,12 +43,11 @@ public class InMemoryUserDetailsManagerTests { @Test public void changePasswordWhenUsernameIsNotInLowercase() { - UserDetails userNotLowerCase = User.withUserDetails(PasswordEncodedUser.user()) - .username("User") - .build(); - + UserDetails userNotLowerCase = User.withUserDetails(PasswordEncodedUser.user()).username("User").build(); String newPassword = "newPassword"; this.manager.updatePassword(userNotLowerCase, newPassword); - assertThat(this.manager.loadUserByUsername(userNotLowerCase.getUsername()).getPassword()).isEqualTo(newPassword); + assertThat(this.manager.loadUserByUsername(userNotLowerCase.getUsername()).getPassword()) + .isEqualTo(newPassword); } + } diff --git a/core/src/test/java/org/springframework/security/provisioning/JdbcUserDetailsManagerTests.java b/core/src/test/java/org/springframework/security/provisioning/JdbcUserDetailsManagerTests.java index ebced08d6e..ddb9a46d7e 100644 --- a/core/src/test/java/org/springframework/security/provisioning/JdbcUserDetailsManagerTests.java +++ b/core/src/test/java/org/springframework/security/provisioning/JdbcUserDetailsManagerTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.provisioning; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.provisioning; import java.util.Collections; import java.util.HashMap; @@ -28,6 +26,7 @@ import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; + import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.security.PopulatedDatabase; import org.springframework.security.TestDataSource; @@ -44,21 +43,32 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserCache; import org.springframework.security.core.userdetails.UserDetails; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests for {@link JdbcUserDetailsManager} * * @author Luke Taylor */ public class JdbcUserDetailsManagerTests { + private static final String SELECT_JOE_SQL = "select * from users where username = 'joe'"; + private static final String SELECT_JOE_AUTHORITIES_SQL = "select * from authorities where username = 'joe'"; - private static final UserDetails joe = new User("joe", "password", true, true, true, - true, AuthorityUtils.createAuthorityList("A", "C", "B")); + private static final UserDetails joe = new User("joe", "password", true, true, true, true, + AuthorityUtils.createAuthorityList("A", "C", "B")); private static TestDataSource dataSource; + private JdbcUserDetailsManager manager; + private MockUserCache cache; + private JdbcTemplate template; @BeforeClass @@ -74,81 +84,74 @@ public class JdbcUserDetailsManagerTests { @Before public void initializeManagerAndCreateTables() { - manager = new JdbcUserDetailsManager(); - cache = new MockUserCache(); - manager.setUserCache(cache); - manager.setDataSource(dataSource); - manager.setCreateUserSql(JdbcUserDetailsManager.DEF_CREATE_USER_SQL); - manager.setUpdateUserSql(JdbcUserDetailsManager.DEF_UPDATE_USER_SQL); - manager.setUserExistsSql(JdbcUserDetailsManager.DEF_USER_EXISTS_SQL); - manager.setCreateAuthoritySql(JdbcUserDetailsManager.DEF_INSERT_AUTHORITY_SQL); - manager.setDeleteUserAuthoritiesSql(JdbcUserDetailsManager.DEF_DELETE_USER_AUTHORITIES_SQL); - manager.setDeleteUserSql(JdbcUserDetailsManager.DEF_DELETE_USER_SQL); - manager.setChangePasswordSql(JdbcUserDetailsManager.DEF_CHANGE_PASSWORD_SQL); - manager.initDao(); - template = manager.getJdbcTemplate(); - - template.execute("create table users(username varchar(20) not null primary key," + this.manager = new JdbcUserDetailsManager(); + this.cache = new MockUserCache(); + this.manager.setUserCache(this.cache); + this.manager.setDataSource(dataSource); + this.manager.setCreateUserSql(JdbcUserDetailsManager.DEF_CREATE_USER_SQL); + this.manager.setUpdateUserSql(JdbcUserDetailsManager.DEF_UPDATE_USER_SQL); + this.manager.setUserExistsSql(JdbcUserDetailsManager.DEF_USER_EXISTS_SQL); + this.manager.setCreateAuthoritySql(JdbcUserDetailsManager.DEF_INSERT_AUTHORITY_SQL); + this.manager.setDeleteUserAuthoritiesSql(JdbcUserDetailsManager.DEF_DELETE_USER_AUTHORITIES_SQL); + this.manager.setDeleteUserSql(JdbcUserDetailsManager.DEF_DELETE_USER_SQL); + this.manager.setChangePasswordSql(JdbcUserDetailsManager.DEF_CHANGE_PASSWORD_SQL); + this.manager.initDao(); + this.template = this.manager.getJdbcTemplate(); + this.template.execute("create table users(username varchar(20) not null primary key," + "password varchar(20) not null, enabled boolean not null)"); - template.execute("create table authorities (username varchar(20) not null, authority varchar(20) not null, " - + "constraint fk_authorities_users foreign key(username) references users(username))"); - PopulatedDatabase.createGroupTables(template); - PopulatedDatabase.insertGroupData(template); + this.template + .execute("create table authorities (username varchar(20) not null, authority varchar(20) not null, " + + "constraint fk_authorities_users foreign key(username) references users(username))"); + PopulatedDatabase.createGroupTables(this.template); + PopulatedDatabase.insertGroupData(this.template); } @After public void dropTablesAndClearContext() { - template.execute("drop table authorities"); - template.execute("drop table users"); - template.execute("drop table group_authorities"); - template.execute("drop table group_members"); - template.execute("drop table groups"); + this.template.execute("drop table authorities"); + this.template.execute("drop table users"); + this.template.execute("drop table group_authorities"); + this.template.execute("drop table group_members"); + this.template.execute("drop table groups"); SecurityContextHolder.clearContext(); } private void setUpAccLockingColumns() { - template.execute("alter table users add column acc_locked boolean default false not null"); - template.execute("alter table users add column acc_expired boolean default false not null"); - template.execute("alter table users add column creds_expired boolean default false not null"); - - manager.setUsersByUsernameQuery( + this.template.execute("alter table users add column acc_locked boolean default false not null"); + this.template.execute("alter table users add column acc_expired boolean default false not null"); + this.template.execute("alter table users add column creds_expired boolean default false not null"); + this.manager.setUsersByUsernameQuery( "select username,password,enabled, acc_locked, acc_expired, creds_expired from users where username = ?"); - manager.setCreateUserSql( + this.manager.setCreateUserSql( "insert into users (username, password, enabled, acc_locked, acc_expired, creds_expired) values (?,?,?,?,?,?)"); - manager.setUpdateUserSql( + this.manager.setUpdateUserSql( "update users set password = ?, enabled = ?, acc_locked=?, acc_expired=?, creds_expired=? where username = ?"); } @Test public void createUserInsertsCorrectData() { - manager.createUser(joe); - - UserDetails joe2 = manager.loadUserByUsername("joe"); - + this.manager.createUser(joe); + UserDetails joe2 = this.manager.loadUserByUsername("joe"); assertThat(joe2).isEqualTo(joe); } @Test public void createUserInsertsCorrectDataWithLocking() { setUpAccLockingColumns(); - UserDetails user = new User("joe", "pass", true, false, true, false, AuthorityUtils.createAuthorityList("A", "B")); - manager.createUser(user); - - UserDetails user2 = manager.loadUserByUsername(user.getUsername()); - + this.manager.createUser(user); + UserDetails user2 = this.manager.loadUserByUsername(user.getUsername()); assertThat(user2).isEqualToComparingFieldByField(user); } @Test public void deleteUserRemovesUserDataAndAuthoritiesAndClearsCache() { insertJoe(); - manager.deleteUser("joe"); - - assertThat(template.queryForList(SELECT_JOE_SQL)).isEmpty(); - assertThat(template.queryForList(SELECT_JOE_AUTHORITIES_SQL)).isEmpty(); - assertThat(cache.getUserMap().containsKey("joe")).isFalse(); + this.manager.deleteUser("joe"); + assertThat(this.template.queryForList(SELECT_JOE_SQL)).isEmpty(); + assertThat(this.template.queryForList(SELECT_JOE_AUTHORITIES_SQL)).isEmpty(); + assertThat(this.cache.getUserMap().containsKey("joe")).isFalse(); } @Test @@ -156,59 +159,49 @@ public class JdbcUserDetailsManagerTests { insertJoe(); User newJoe = new User("joe", "newpassword", false, true, true, true, AuthorityUtils.createAuthorityList(new String[] { "D", "F", "E" })); - - manager.updateUser(newJoe); - - UserDetails joe = manager.loadUserByUsername("joe"); - + this.manager.updateUser(newJoe); + UserDetails joe = this.manager.loadUserByUsername("joe"); assertThat(joe).isEqualTo(newJoe); - assertThat(cache.getUserMap().containsKey("joe")).isFalse(); + assertThat(this.cache.getUserMap().containsKey("joe")).isFalse(); } @Test public void updateUserChangesDataCorrectlyAndClearsCacheWithLocking() { setUpAccLockingColumns(); - insertJoe(); - User newJoe = new User("joe", "newpassword", false, false, false, true, AuthorityUtils.createAuthorityList("D", "F", "E")); - - manager.updateUser(newJoe); - - UserDetails joe = manager.loadUserByUsername(newJoe.getUsername()); - + this.manager.updateUser(newJoe); + UserDetails joe = this.manager.loadUserByUsername(newJoe.getUsername()); assertThat(joe).isEqualToComparingFieldByField(newJoe); - assertThat(cache.getUserMap().containsKey(newJoe.getUsername())).isFalse(); + assertThat(this.cache.getUserMap().containsKey(newJoe.getUsername())).isFalse(); } - @Test public void userExistsReturnsFalseForNonExistentUsername() { - assertThat(manager.userExists("joe")).isFalse(); + assertThat(this.manager.userExists("joe")).isFalse(); } @Test public void userExistsReturnsTrueForExistingUsername() { insertJoe(); - assertThat(manager.userExists("joe")).isTrue(); - assertThat(cache.getUserMap().containsKey("joe")).isTrue(); + assertThat(this.manager.userExists("joe")).isTrue(); + assertThat(this.cache.getUserMap().containsKey("joe")).isTrue(); } @Test(expected = AccessDeniedException.class) public void changePasswordFailsForUnauthenticatedUser() { - manager.changePassword("password", "newPassword"); + this.manager.changePassword("password", "newPassword"); } @Test public void changePasswordSucceedsWithAuthenticatedUserAndNoAuthenticationManagerSet() { insertJoe(); authenticateJoe(); - manager.changePassword("wrongpassword", "newPassword"); - UserDetails newJoe = manager.loadUserByUsername("joe"); - + this.manager.changePassword("wrongpassword", "newPassword"); + UserDetails newJoe = this.manager.loadUserByUsername("joe"); assertThat(newJoe.getPassword()).isEqualTo("newPassword"); - assertThat(cache.getUserMap().containsKey("joe")).isFalse(); + assertThat(this.cache.getUserMap().containsKey("joe")).isFalse(); } @Test @@ -216,19 +209,17 @@ public class JdbcUserDetailsManagerTests { insertJoe(); Authentication currentAuth = authenticateJoe(); AuthenticationManager am = mock(AuthenticationManager.class); - when(am.authenticate(currentAuth)).thenReturn(currentAuth); - - manager.setAuthenticationManager(am); - manager.changePassword("password", "newPassword"); - UserDetails newJoe = manager.loadUserByUsername("joe"); - + given(am.authenticate(currentAuth)).willReturn(currentAuth); + this.manager.setAuthenticationManager(am); + this.manager.changePassword("password", "newPassword"); + UserDetails newJoe = this.manager.loadUserByUsername("joe"); assertThat(newJoe.getPassword()).isEqualTo("newPassword"); // The password in the context should also be altered Authentication newAuth = SecurityContextHolder.getContext().getAuthentication(); assertThat(newAuth.getName()).isEqualTo("joe"); assertThat(newAuth.getDetails()).isEqualTo(currentAuth.getDetails()); assertThat(newAuth.getCredentials()).isNull(); - assertThat(cache.getUserMap().containsKey("joe")).isFalse(); + assertThat(this.cache.getUserMap().containsKey("joe")).isFalse(); } @Test @@ -236,30 +227,25 @@ public class JdbcUserDetailsManagerTests { insertJoe(); authenticateJoe(); AuthenticationManager am = mock(AuthenticationManager.class); - when(am.authenticate(any(Authentication.class))).thenThrow( - new BadCredentialsException("")); - - manager.setAuthenticationManager(am); - + given(am.authenticate(any(Authentication.class))).willThrow(new BadCredentialsException("")); + this.manager.setAuthenticationManager(am); try { - manager.changePassword("password", "newPassword"); + this.manager.changePassword("password", "newPassword"); fail("Expected BadCredentialsException"); } catch (BadCredentialsException expected) { } - // Check password hasn't changed. - UserDetails newJoe = manager.loadUserByUsername("joe"); + UserDetails newJoe = this.manager.loadUserByUsername("joe"); assertThat(newJoe.getPassword()).isEqualTo("password"); assertThat(SecurityContextHolder.getContext().getAuthentication().getCredentials()).isEqualTo("password"); - assertThat(cache.getUserMap().containsKey("joe")).isTrue(); + assertThat(this.cache.getUserMap().containsKey("joe")).isTrue(); } @Test public void findAllGroupsReturnsExpectedGroupNames() { - List groups = manager.findAllGroups(); + List groups = this.manager.findAllGroups(); assertThat(groups).hasSize(4); - Collections.sort(groups); assertThat(groups.get(0)).isEqualTo("GROUP_0"); assertThat(groups.get(1)).isEqualTo("GROUP_1"); @@ -269,154 +255,142 @@ public class JdbcUserDetailsManagerTests { @Test public void findGroupMembersReturnsCorrectData() { - List groupMembers = manager.findUsersInGroup("GROUP_0"); + List groupMembers = this.manager.findUsersInGroup("GROUP_0"); assertThat(groupMembers).hasSize(1); assertThat(groupMembers.get(0)).isEqualTo("jerry"); - groupMembers = manager.findUsersInGroup("GROUP_1"); + groupMembers = this.manager.findUsersInGroup("GROUP_1"); assertThat(groupMembers).hasSize(2); } @Test @SuppressWarnings("unchecked") public void createGroupInsertsCorrectData() { - manager.createGroup("TEST_GROUP", - AuthorityUtils.createAuthorityList("ROLE_X", "ROLE_Y")); - - List roles = template - .queryForList("select ga.authority from groups g, group_authorities ga " - + "where ga.group_id = g.id " + "and g.group_name = 'TEST_GROUP'"); - + this.manager.createGroup("TEST_GROUP", AuthorityUtils.createAuthorityList("ROLE_X", "ROLE_Y")); + List roles = this.template.queryForList("select ga.authority from groups g, group_authorities ga " + + "where ga.group_id = g.id " + "and g.group_name = 'TEST_GROUP'"); assertThat(roles).hasSize(2); } @Test public void deleteGroupRemovesData() { - manager.deleteGroup("GROUP_0"); - manager.deleteGroup("GROUP_1"); - manager.deleteGroup("GROUP_2"); - manager.deleteGroup("GROUP_3"); - - assertThat(template.queryForList("select * from group_authorities")).isEmpty(); - assertThat(template.queryForList("select * from group_members")).isEmpty(); - assertThat(template.queryForList("select id from groups")).isEmpty(); + this.manager.deleteGroup("GROUP_0"); + this.manager.deleteGroup("GROUP_1"); + this.manager.deleteGroup("GROUP_2"); + this.manager.deleteGroup("GROUP_3"); + assertThat(this.template.queryForList("select * from group_authorities")).isEmpty(); + assertThat(this.template.queryForList("select * from group_members")).isEmpty(); + assertThat(this.template.queryForList("select id from groups")).isEmpty(); } @Test public void renameGroupIsSuccessful() { - manager.renameGroup("GROUP_0", "GROUP_X"); - - assertThat(template.queryForObject("select id from groups where group_name = 'GROUP_X'", - Integer.class)).isZero(); + this.manager.renameGroup("GROUP_0", "GROUP_X"); + assertThat(this.template.queryForObject("select id from groups where group_name = 'GROUP_X'", Integer.class)) + .isZero(); } @Test public void addingGroupUserSetsCorrectData() { - manager.addUserToGroup("tom", "GROUP_0"); - - assertThat( - template.queryForList( - "select username from group_members where group_id = 0")).hasSize(2); + this.manager.addUserToGroup("tom", "GROUP_0"); + assertThat(this.template.queryForList("select username from group_members where group_id = 0")).hasSize(2); } @Test public void removeUserFromGroupDeletesGroupMemberRow() { - manager.removeUserFromGroup("jerry", "GROUP_1"); - - assertThat( - template.queryForList( - "select group_id from group_members where username = 'jerry'")).hasSize(1); + this.manager.removeUserFromGroup("jerry", "GROUP_1"); + assertThat(this.template.queryForList("select group_id from group_members where username = 'jerry'")) + .hasSize(1); } @Test public void findGroupAuthoritiesReturnsCorrectAuthorities() { - assertThat(AuthorityUtils.createAuthorityList("ROLE_A")).isEqualTo(manager.findGroupAuthorities("GROUP_0")); + assertThat(AuthorityUtils.createAuthorityList("ROLE_A")) + .isEqualTo(this.manager.findGroupAuthorities("GROUP_0")); } @Test public void addGroupAuthorityInsertsCorrectGroupAuthorityRow() { GrantedAuthority auth = new SimpleGrantedAuthority("ROLE_X"); - manager.addGroupAuthority("GROUP_0", auth); - - template.queryForObject( - "select authority from group_authorities where authority = 'ROLE_X' and group_id = 0", - String.class); + this.manager.addGroupAuthority("GROUP_0", auth); + this.template.queryForObject( + "select authority from group_authorities where authority = 'ROLE_X' and group_id = 0", String.class); } @Test public void deleteGroupAuthorityRemovesCorrectRows() { GrantedAuthority auth = new SimpleGrantedAuthority("ROLE_A"); - manager.removeGroupAuthority("GROUP_0", auth); - assertThat( - template.queryForList( - "select authority from group_authorities where group_id = 0")).isEmpty(); - - manager.removeGroupAuthority("GROUP_2", auth); - assertThat( - template.queryForList( - "select authority from group_authorities where group_id = 2")).hasSize(2); + this.manager.removeGroupAuthority("GROUP_0", auth); + assertThat(this.template.queryForList("select authority from group_authorities where group_id = 0")).isEmpty(); + this.manager.removeGroupAuthority("GROUP_2", auth); + assertThat(this.template.queryForList("select authority from group_authorities where group_id = 2")).hasSize(2); } // SEC-1156 @Test public void createUserDoesNotSaveAuthoritiesIfEnableAuthoritiesIsFalse() { - manager.setEnableAuthorities(false); - manager.createUser(joe); - assertThat(template.queryForList(SELECT_JOE_AUTHORITIES_SQL)).isEmpty(); + this.manager.setEnableAuthorities(false); + this.manager.createUser(joe); + assertThat(this.template.queryForList(SELECT_JOE_AUTHORITIES_SQL)).isEmpty(); } // SEC-1156 @Test public void updateUserDoesNotSaveAuthoritiesIfEnableAuthoritiesIsFalse() { - manager.setEnableAuthorities(false); + this.manager.setEnableAuthorities(false); insertJoe(); - template.execute("delete from authorities where username='joe'"); - manager.updateUser(joe); - assertThat(template.queryForList(SELECT_JOE_AUTHORITIES_SQL)).isEmpty(); + this.template.execute("delete from authorities where username='joe'"); + this.manager.updateUser(joe); + assertThat(this.template.queryForList(SELECT_JOE_AUTHORITIES_SQL)).isEmpty(); } // SEC-2166 @Test public void createNewAuthenticationUsesNullPasswordToKeepPassordsSave() { insertJoe(); - UsernamePasswordAuthenticationToken currentAuth = new UsernamePasswordAuthenticationToken( - "joe", null, AuthorityUtils.createAuthorityList("ROLE_USER")); - Authentication updatedAuth = manager.createNewAuthentication(currentAuth, "new"); + UsernamePasswordAuthenticationToken currentAuth = new UsernamePasswordAuthenticationToken("joe", null, + AuthorityUtils.createAuthorityList("ROLE_USER")); + Authentication updatedAuth = this.manager.createNewAuthentication(currentAuth, "new"); assertThat(updatedAuth.getCredentials()).isNull(); } private Authentication authenticateJoe() { - UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken( - "joe", "password", joe.getAuthorities()); + UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("joe", "password", + joe.getAuthorities()); SecurityContextHolder.getContext().setAuthentication(auth); - return auth; } private void insertJoe() { - template.execute("insert into users (username, password, enabled) values ('joe','password','true')"); - template.execute("insert into authorities (username, authority) values ('joe','A')"); - template.execute("insert into authorities (username, authority) values ('joe','B')"); - template.execute("insert into authorities (username, authority) values ('joe','C')"); - cache.putUserInCache(joe); + this.template.execute("insert into users (username, password, enabled) values ('joe','password','true')"); + this.template.execute("insert into authorities (username, authority) values ('joe','A')"); + this.template.execute("insert into authorities (username, authority) values ('joe','B')"); + this.template.execute("insert into authorities (username, authority) values ('joe','C')"); + this.cache.putUserInCache(joe); } private class MockUserCache implements UserCache { + private Map cache = new HashMap<>(); + @Override public UserDetails getUserFromCache(String username) { - return cache.get(username); + return this.cache.get(username); } + @Override public void putUserInCache(UserDetails user) { - cache.put(user.getUsername(), user); + this.cache.put(user.getUsername(), user); } + @Override public void removeUserFromCache(String username) { - cache.remove(username); + this.cache.remove(username); } Map getUserMap() { - return cache; + return this.cache; } + } + } diff --git a/core/src/test/java/org/springframework/security/scheduling/AbstractSecurityContextSchedulingTaskExecutorTests.java b/core/src/test/java/org/springframework/security/scheduling/AbstractSecurityContextSchedulingTaskExecutorTests.java index 7ae6cfa990..635451cbef 100644 --- a/core/src/test/java/org/springframework/security/scheduling/AbstractSecurityContextSchedulingTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/scheduling/AbstractSecurityContextSchedulingTaskExecutorTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.scheduling; -import static org.mockito.Mockito.verify; +package org.springframework.security.scheduling; import org.junit.Test; import org.mockito.Mock; + import org.springframework.scheduling.SchedulingTaskExecutor; import org.springframework.security.task.AbstractDelegatingSecurityContextAsyncTaskExecutorTests; +import static org.mockito.Mockito.verify; + /** * Abstract class for testing {@link DelegatingSecurityContextSchedulingTaskExecutor} * which allows customization of how @@ -32,8 +34,8 @@ import org.springframework.security.task.AbstractDelegatingSecurityContextAsyncT * @see CurrentSecurityContextSchedulingTaskExecutorTests * @see ExplicitSecurityContextSchedulingTaskExecutorTests */ -public abstract class AbstractSecurityContextSchedulingTaskExecutorTests extends - AbstractDelegatingSecurityContextAsyncTaskExecutorTests { +public abstract class AbstractSecurityContextSchedulingTaskExecutorTests + extends AbstractDelegatingSecurityContextAsyncTaskExecutorTests { @Mock protected SchedulingTaskExecutor taskExecutorDelegate; @@ -42,14 +44,17 @@ public abstract class AbstractSecurityContextSchedulingTaskExecutorTests extends @Test public void prefersShortLivedTasks() { - executor = create(); - executor.prefersShortLivedTasks(); - verify(taskExecutorDelegate).prefersShortLivedTasks(); + this.executor = create(); + this.executor.prefersShortLivedTasks(); + verify(this.taskExecutorDelegate).prefersShortLivedTasks(); } + @Override protected SchedulingTaskExecutor getExecutor() { - return taskExecutorDelegate; + return this.taskExecutorDelegate; } + @Override protected abstract DelegatingSecurityContextSchedulingTaskExecutor create(); + } diff --git a/core/src/test/java/org/springframework/security/scheduling/CurrentSecurityContextSchedulingTaskExecutorTests.java b/core/src/test/java/org/springframework/security/scheduling/CurrentSecurityContextSchedulingTaskExecutorTests.java index 3ee5bff24a..9635b24c27 100644 --- a/core/src/test/java/org/springframework/security/scheduling/CurrentSecurityContextSchedulingTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/scheduling/CurrentSecurityContextSchedulingTaskExecutorTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.scheduling; import org.junit.Before; + import org.springframework.security.core.context.SecurityContext; /** @@ -26,15 +28,17 @@ import org.springframework.security.core.context.SecurityContext; * @since 3.2 * */ -public class CurrentSecurityContextSchedulingTaskExecutorTests extends - AbstractSecurityContextSchedulingTaskExecutorTests { +public class CurrentSecurityContextSchedulingTaskExecutorTests + extends AbstractSecurityContextSchedulingTaskExecutorTests { @Before public void setUp() throws Exception { currentSecurityContextPowermockSetup(); } + @Override protected DelegatingSecurityContextSchedulingTaskExecutor create() { - return new DelegatingSecurityContextSchedulingTaskExecutor(taskExecutorDelegate); + return new DelegatingSecurityContextSchedulingTaskExecutor(this.taskExecutorDelegate); } + } diff --git a/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java b/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java index 33f49dce4e..d52d530956 100644 --- a/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java +++ b/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java @@ -13,24 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.scheduling; +import java.time.Duration; +import java.time.Instant; +import java.util.Date; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; + import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; -import java.time.Duration; -import java.time.Instant; -import java.util.Date; - -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.verify; /** - * Test An implementation of {@link TaskScheduler} invoking it whenever the trigger + * Test An implementation of {@link TaskScheduler} invoking it whenever the trigger * indicates a next execution time. * * @author Richard Valdivieso @@ -40,8 +45,10 @@ public class DelegatingSecurityContextTaskSchedulerTests { @Mock private TaskScheduler scheduler; + @Mock private Runnable runnable; + @Mock private Trigger trigger; @@ -50,43 +57,44 @@ public class DelegatingSecurityContextTaskSchedulerTests { @Before public void setup() { MockitoAnnotations.initMocks(this); - delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(scheduler); + this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler); } @After public void cleanup() { - delegatingSecurityContextTaskScheduler = null; + this.delegatingSecurityContextTaskScheduler = null; } @Test(expected = IllegalArgumentException.class) public void testSchedulerIsNotNull() { - delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(null); + this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(null); } @Test public void testSchedulerWithRunnableAndTrigger() { - delegatingSecurityContextTaskScheduler.schedule(runnable, trigger); - verify(scheduler).schedule(any(Runnable.class), any(Trigger.class)); + this.delegatingSecurityContextTaskScheduler.schedule(this.runnable, this.trigger); + verify(this.scheduler).schedule(any(Runnable.class), any(Trigger.class)); } @Test public void testSchedulerWithRunnableAndInstant() { Instant date = Instant.now(); - delegatingSecurityContextTaskScheduler.schedule(runnable, date); - verify(scheduler).schedule(any(Runnable.class), any(Date.class)); + this.delegatingSecurityContextTaskScheduler.schedule(this.runnable, date); + verify(this.scheduler).schedule(any(Runnable.class), any(Date.class)); } @Test public void testScheduleAtFixedRateWithRunnableAndDate() { Date date = new Date(1544751374L); Duration duration = Duration.ofSeconds(4L); - delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(runnable, date, 1000L); - verify(scheduler).scheduleAtFixedRate(isA(Runnable.class), isA(Date.class), eq(1000L)); + this.delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(this.runnable, date, 1000L); + verify(this.scheduler).scheduleAtFixedRate(isA(Runnable.class), isA(Date.class), eq(1000L)); } @Test public void testScheduleAtFixedRateWithRunnableAndLong() { - delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(runnable, 1000L); - verify(scheduler).scheduleAtFixedRate(isA(Runnable.class), eq(1000L)); + this.delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(this.runnable, 1000L); + verify(this.scheduler).scheduleAtFixedRate(isA(Runnable.class), eq(1000L)); } + } diff --git a/core/src/test/java/org/springframework/security/scheduling/ExplicitSecurityContextSchedulingTaskExecutorTests.java b/core/src/test/java/org/springframework/security/scheduling/ExplicitSecurityContextSchedulingTaskExecutorTests.java index d69f2ba664..c6eb7e0d1a 100644 --- a/core/src/test/java/org/springframework/security/scheduling/ExplicitSecurityContextSchedulingTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/scheduling/ExplicitSecurityContextSchedulingTaskExecutorTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.scheduling; import org.junit.Before; + import org.springframework.security.core.context.SecurityContext; /** @@ -26,16 +28,17 @@ import org.springframework.security.core.context.SecurityContext; * @since 3.2 * */ -public class ExplicitSecurityContextSchedulingTaskExecutorTests extends - AbstractSecurityContextSchedulingTaskExecutorTests { +public class ExplicitSecurityContextSchedulingTaskExecutorTests + extends AbstractSecurityContextSchedulingTaskExecutorTests { @Before public void setUp() throws Exception { explicitSecurityContextPowermockSetup(); } + @Override protected DelegatingSecurityContextSchedulingTaskExecutor create() { - return new DelegatingSecurityContextSchedulingTaskExecutor(taskExecutorDelegate, - securityContext); + return new DelegatingSecurityContextSchedulingTaskExecutor(this.taskExecutorDelegate, this.securityContext); } + } diff --git a/core/src/test/java/org/springframework/security/task/AbstractDelegatingSecurityContextAsyncTaskExecutorTests.java b/core/src/test/java/org/springframework/security/task/AbstractDelegatingSecurityContextAsyncTaskExecutorTests.java index 29e4943a76..4558c6a9cb 100644 --- a/core/src/test/java/org/springframework/security/task/AbstractDelegatingSecurityContextAsyncTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/task/AbstractDelegatingSecurityContextAsyncTaskExecutorTests.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.task; -import static org.mockito.Mockito.verify; +package org.springframework.security.task; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; + import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.security.concurrent.AbstractDelegatingSecurityContextExecutorTests; +import static org.mockito.Mockito.verify; + /** * Abstract class for testing {@link DelegatingSecurityContextAsyncTaskExecutor} which * allows customization of how {@link DelegatingSecurityContextAsyncTaskExecutor} and its @@ -33,8 +35,9 @@ import org.springframework.security.concurrent.AbstractDelegatingSecurityContext * @see CurrentDelegatingSecurityContextAsyncTaskExecutorTests * @see ExplicitDelegatingSecurityContextAsyncTaskExecutorTests */ -public abstract class AbstractDelegatingSecurityContextAsyncTaskExecutorTests extends - AbstractDelegatingSecurityContextExecutorTests { +public abstract class AbstractDelegatingSecurityContextAsyncTaskExecutorTests + extends AbstractDelegatingSecurityContextExecutorTests { + @Mock protected AsyncTaskExecutor taskExecutorDelegate; @@ -42,30 +45,33 @@ public abstract class AbstractDelegatingSecurityContextAsyncTaskExecutorTests ex @Before public final void setUpExecutor() { - executor = create(); + this.executor = create(); } @Test public void executeStartTimeout() { - executor.execute(runnable, 1); - verify(getExecutor()).execute(wrappedRunnable, 1); + this.executor.execute(this.runnable, 1); + verify(getExecutor()).execute(this.wrappedRunnable, 1); } @Test public void submit() { - executor.submit(runnable); - verify(getExecutor()).submit(wrappedRunnable); + this.executor.submit(this.runnable); + verify(getExecutor()).submit(this.wrappedRunnable); } @Test public void submitCallable() { - executor.submit(callable); - verify(getExecutor()).submit(wrappedCallable); + this.executor.submit(this.callable); + verify(getExecutor()).submit(this.wrappedCallable); } + @Override protected AsyncTaskExecutor getExecutor() { - return taskExecutorDelegate; + return this.taskExecutorDelegate; } + @Override protected abstract DelegatingSecurityContextAsyncTaskExecutor create(); + } diff --git a/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextAsyncTaskExecutorTests.java b/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextAsyncTaskExecutorTests.java index 059f09b912..a3e33a742b 100644 --- a/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextAsyncTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextAsyncTaskExecutorTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.task; import org.junit.Before; +import org.springframework.security.core.context.SecurityContext; + /** * Tests using the current {@link SecurityContext} on * {@link DelegatingSecurityContextAsyncTaskExecutor} @@ -25,8 +28,8 @@ import org.junit.Before; * @since 3.2 * */ -public class CurrentDelegatingSecurityContextAsyncTaskExecutorTests extends - AbstractDelegatingSecurityContextAsyncTaskExecutorTests { +public class CurrentDelegatingSecurityContextAsyncTaskExecutorTests + extends AbstractDelegatingSecurityContextAsyncTaskExecutorTests { @Before public void setUp() throws Exception { @@ -35,7 +38,7 @@ public class CurrentDelegatingSecurityContextAsyncTaskExecutorTests extends @Override protected DelegatingSecurityContextAsyncTaskExecutor create() { - return new DelegatingSecurityContextAsyncTaskExecutor(taskExecutorDelegate); + return new DelegatingSecurityContextAsyncTaskExecutor(this.taskExecutorDelegate); } } diff --git a/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextTaskExecutorTests.java b/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextTaskExecutorTests.java index 402a098bfa..7ec62fbf73 100644 --- a/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/task/CurrentDelegatingSecurityContextTaskExecutorTests.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.task; import java.util.concurrent.Executor; import org.junit.Before; import org.mockito.Mock; + import org.springframework.core.task.TaskExecutor; -import org.springframework.security.concurrent.DelegatingSecurityContextExecutor; import org.springframework.security.concurrent.AbstractDelegatingSecurityContextExecutorTests; +import org.springframework.security.concurrent.DelegatingSecurityContextExecutor; +import org.springframework.security.core.context.SecurityContext; /** * Tests using the current {@link SecurityContext} on @@ -31,8 +34,8 @@ import org.springframework.security.concurrent.AbstractDelegatingSecurityContext * @since 3.2 * */ -public class CurrentDelegatingSecurityContextTaskExecutorTests extends - AbstractDelegatingSecurityContextExecutorTests { +public class CurrentDelegatingSecurityContextTaskExecutorTests extends AbstractDelegatingSecurityContextExecutorTests { + @Mock private TaskExecutor taskExecutorDelegate; @@ -41,11 +44,14 @@ public class CurrentDelegatingSecurityContextTaskExecutorTests extends currentSecurityContextPowermockSetup(); } + @Override protected Executor getExecutor() { - return taskExecutorDelegate; + return this.taskExecutorDelegate; } + @Override protected DelegatingSecurityContextExecutor create() { - return new DelegatingSecurityContextTaskExecutor(taskExecutorDelegate); + return new DelegatingSecurityContextTaskExecutor(this.taskExecutorDelegate); } + } diff --git a/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextAsyncTaskExecutorTests.java b/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextAsyncTaskExecutorTests.java index 05bee79214..d1464332a6 100644 --- a/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextAsyncTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextAsyncTaskExecutorTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.task; import org.junit.Before; +import org.springframework.security.core.context.SecurityContext; + /** * Tests using an explicit {@link SecurityContext} on * {@link DelegatingSecurityContextAsyncTaskExecutor} @@ -25,8 +28,8 @@ import org.junit.Before; * @since 3.2 * */ -public class ExplicitDelegatingSecurityContextAsyncTaskExecutorTests extends - AbstractDelegatingSecurityContextAsyncTaskExecutorTests { +public class ExplicitDelegatingSecurityContextAsyncTaskExecutorTests + extends AbstractDelegatingSecurityContextAsyncTaskExecutorTests { @Before public void setUp() throws Exception { @@ -35,8 +38,7 @@ public class ExplicitDelegatingSecurityContextAsyncTaskExecutorTests extends @Override protected DelegatingSecurityContextAsyncTaskExecutor create() { - return new DelegatingSecurityContextAsyncTaskExecutor(taskExecutorDelegate, - securityContext); + return new DelegatingSecurityContextAsyncTaskExecutor(this.taskExecutorDelegate, this.securityContext); } } diff --git a/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextTaskExecutorTests.java b/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextTaskExecutorTests.java index 02d93d9aee..aced1ad9d5 100644 --- a/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextTaskExecutorTests.java +++ b/core/src/test/java/org/springframework/security/task/ExplicitDelegatingSecurityContextTaskExecutorTests.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.task; import java.util.concurrent.Executor; import org.junit.Before; import org.mockito.Mock; + import org.springframework.core.task.TaskExecutor; -import org.springframework.security.concurrent.DelegatingSecurityContextExecutor; import org.springframework.security.concurrent.AbstractDelegatingSecurityContextExecutorTests; +import org.springframework.security.concurrent.DelegatingSecurityContextExecutor; +import org.springframework.security.core.context.SecurityContext; /** * Tests using an explicit {@link SecurityContext} on @@ -31,8 +34,8 @@ import org.springframework.security.concurrent.AbstractDelegatingSecurityContext * @since 3.2 * */ -public class ExplicitDelegatingSecurityContextTaskExecutorTests extends - AbstractDelegatingSecurityContextExecutorTests { +public class ExplicitDelegatingSecurityContextTaskExecutorTests extends AbstractDelegatingSecurityContextExecutorTests { + @Mock private TaskExecutor taskExecutorDelegate; @@ -41,12 +44,14 @@ public class ExplicitDelegatingSecurityContextTaskExecutorTests extends explicitSecurityContextPowermockSetup(); } + @Override protected Executor getExecutor() { - return taskExecutorDelegate; + return this.taskExecutorDelegate; } + @Override protected DelegatingSecurityContextExecutor create() { - return new DelegatingSecurityContextTaskExecutor(taskExecutorDelegate, - securityContext); + return new DelegatingSecurityContextTaskExecutor(this.taskExecutorDelegate, this.securityContext); } + } diff --git a/core/src/test/java/org/springframework/security/util/FieldUtilsTests.java b/core/src/test/java/org/springframework/security/util/FieldUtilsTests.java index 50b90c7886..2d7317b7a1 100644 --- a/core/src/test/java/org/springframework/security/util/FieldUtilsTests.java +++ b/core/src/test/java/org/springframework/security/util/FieldUtilsTests.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.util; +import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import org.junit.*; - /** * @author Luke Taylor */ @@ -27,30 +27,32 @@ public class FieldUtilsTests { @Test public void gettingAndSettingProtectedFieldIsSuccessful() throws Exception { - new FieldUtils(); - Object tc = new TestClass(); - assertThat(FieldUtils.getProtectedFieldValue("protectedField", tc)).isEqualTo("x"); assertThat(FieldUtils.getFieldValue(tc, "nested.protectedField")).isEqualTo("z"); FieldUtils.setProtectedFieldValue("protectedField", tc, "y"); assertThat(FieldUtils.getProtectedFieldValue("protectedField", tc)).isEqualTo("y"); - try { FieldUtils.getProtectedFieldValue("nonExistentField", tc); } catch (IllegalStateException expected) { } } -} -@SuppressWarnings("unused") -class TestClass { - private String protectedField = "x"; - private Nested nested = new Nested(); -} + @SuppressWarnings("unused") + static class TestClass { + + private String protectedField = "x"; + + private Nested nested = new Nested(); + + } + + @SuppressWarnings("unused") + static class Nested { + + private String protectedField = "z"; + + } -@SuppressWarnings("unused") -class Nested { - private String protectedField = "z"; } diff --git a/core/src/test/java/org/springframework/security/util/InMemoryResourceTests.java b/core/src/test/java/org/springframework/security/util/InMemoryResourceTests.java index b1b7fe8098..23edf7030a 100644 --- a/core/src/test/java/org/springframework/security/util/InMemoryResourceTests.java +++ b/core/src/test/java/org/springframework/security/util/InMemoryResourceTests.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.util; -import static org.assertj.core.api.Assertions.*; +import org.junit.Test; -import org.junit.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor @@ -38,4 +39,5 @@ public class InMemoryResourceTests { assertThat(new InMemoryResource("xxx").equals(new InMemoryResource("xxxx"))).isFalse(); assertThat(new InMemoryResource("xxx").equals(new Object())).isFalse(); } + } diff --git a/core/src/test/java/org/springframework/security/util/MethodInvocationUtilsTests.java b/core/src/test/java/org/springframework/security/util/MethodInvocationUtilsTests.java index 2c055d33a3..7bc823f2bb 100644 --- a/core/src/test/java/org/springframework/security/util/MethodInvocationUtilsTests.java +++ b/core/src/test/java/org/springframework/security/util/MethodInvocationUtilsTests.java @@ -13,49 +13,46 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.util; -import static org.assertj.core.api.Assertions.*; - -import org.aopalliance.intercept.MethodInvocation; -import org.junit.*; -import org.springframework.aop.framework.AdvisedSupport; -import org.springframework.security.access.annotation.BusinessServiceImpl; - import java.io.Serializable; +import org.aopalliance.intercept.MethodInvocation; +import org.junit.Test; + +import org.springframework.aop.framework.AdvisedSupport; +import org.springframework.security.access.annotation.BusinessServiceImpl; + +import static org.assertj.core.api.Assertions.assertThat; + /** - * * @author Luke Taylor */ public class MethodInvocationUtilsTests { @Test public void createFromClassReturnsMethodWithNoArgInfoForMethodWithNoArgs() { - new MethodInvocationUtils(); - - MethodInvocation mi = MethodInvocationUtils.createFromClass(String.class, - "length"); + MethodInvocation mi = MethodInvocationUtils.createFromClass(String.class, "length"); assertThat(mi).isNotNull(); } @Test public void createFromClassReturnsMethodIfArgInfoOmittedAndMethodNameIsUnique() { - MethodInvocation mi = MethodInvocationUtils.createFromClass( - BusinessServiceImpl.class, "methodReturningAnArray"); + MethodInvocation mi = MethodInvocationUtils.createFromClass(BusinessServiceImpl.class, + "methodReturningAnArray"); assertThat(mi).isNotNull(); } @Test(expected = IllegalArgumentException.class) public void exceptionIsRaisedIfArgInfoOmittedAndMethodNameIsNotUnique() { - MethodInvocationUtils.createFromClass(BusinessServiceImpl.class, - "methodReturningAList"); + MethodInvocationUtils.createFromClass(BusinessServiceImpl.class, "methodReturningAList"); } @Test public void createFromClassReturnsMethodIfGivenArgInfoForMethodWithArgs() { - MethodInvocation mi = MethodInvocationUtils.createFromClass(null, String.class, - "compareTo", new Class[] { String.class }, new Object[] { "" }); + MethodInvocation mi = MethodInvocationUtils.createFromClass(null, String.class, "compareTo", + new Class[] { String.class }, new Object[] { "" }); assertThat(mi).isNotNull(); } @@ -63,25 +60,27 @@ public class MethodInvocationUtilsTests { public void createFromObjectLocatesExistingMethods() { AdvisedTarget t = new AdvisedTarget(); // Just lie about interfaces - t.setInterfaces(new Class[] { Serializable.class, MethodInvocation.class, - Blah.class }); - + t.setInterfaces(new Class[] { Serializable.class, MethodInvocation.class, Blah.class }); MethodInvocation mi = MethodInvocationUtils.create(t, "blah"); assertThat(mi).isNotNull(); - t.setProxyTargetClass(true); mi = MethodInvocationUtils.create(t, "blah"); assertThat(mi).isNotNull(); - assertThat(MethodInvocationUtils.create(t, "blah", "non-existent arg")).isNull(); } interface Blah { + void blah(); + } class AdvisedTarget extends AdvisedSupport implements Blah { + + @Override public void blah() { } + } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2EncodingUtils.java b/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2EncodingUtils.java index 9c1920e2bf..4502d58ee0 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2EncodingUtils.java +++ b/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2EncodingUtils.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.argon2; import java.util.Base64; + import org.bouncycastle.crypto.params.Argon2Parameters; import org.bouncycastle.util.Arrays; @@ -27,127 +29,125 @@ import org.bouncycastle.util.Arrays; * @author Simeon Macke * @since 5.3 */ -class Argon2EncodingUtils { +final class Argon2EncodingUtils { + private static final Base64.Encoder b64encoder = Base64.getEncoder().withoutPadding(); + private static final Base64.Decoder b64decoder = Base64.getDecoder(); + private Argon2EncodingUtils() { + } + /** - * Encodes a raw Argon2-hash and its parameters into the standard Argon2-hash-string as specified in the reference - * implementation (https://github.com/P-H-C/phc-winner-argon2/blob/master/src/encoding.c#L244): + * Encodes a raw Argon2-hash and its parameters into the standard Argon2-hash-string + * as specified in the reference implementation + * (https://github.com/P-H-C/phc-winner-argon2/blob/master/src/encoding.c#L244): * * {@code $argon2[$v=]$m=,t=,p=$$} * - * where {@code } is either 'd', 'id', or 'i', {@code } is a decimal integer (positive, - * fits in an 'unsigned long'), and {@code } is Base64-encoded data (no '=' padding - * characters, no newline or whitespace). - * - * The last two binary chunks (encoded in Base64) are, in that order, - * the salt and the output. If no salt has been used, the salt will be omitted. + * where {@code } is either 'd', 'id', or 'i', {@code } is a decimal integer + * (positive, fits in an 'unsigned long'), and {@code } is Base64-encoded data + * (no '=' padding characters, no newline or whitespace). * + * The last two binary chunks (encoded in Base64) are, in that order, the salt and the + * output. If no salt has been used, the salt will be omitted. * @param hash the raw Argon2 hash in binary format * @param parameters the Argon2 parameters that were used to create the hash * @return the encoded Argon2-hash-string as described above * @throws IllegalArgumentException if the Argon2Parameters are invalid */ - public static String encode(byte[] hash, Argon2Parameters parameters) throws IllegalArgumentException { + static String encode(byte[] hash, Argon2Parameters parameters) throws IllegalArgumentException { StringBuilder stringBuilder = new StringBuilder(); - switch (parameters.getType()) { - case Argon2Parameters.ARGON2_d: stringBuilder.append("$argon2d"); break; - case Argon2Parameters.ARGON2_i: stringBuilder.append("$argon2i"); break; - case Argon2Parameters.ARGON2_id: stringBuilder.append("$argon2id"); break; - default: throw new IllegalArgumentException("Invalid algorithm type: "+parameters.getType()); + case Argon2Parameters.ARGON2_d: + stringBuilder.append("$argon2d"); + break; + case Argon2Parameters.ARGON2_i: + stringBuilder.append("$argon2i"); + break; + case Argon2Parameters.ARGON2_id: + stringBuilder.append("$argon2id"); + break; + default: + throw new IllegalArgumentException("Invalid algorithm type: " + parameters.getType()); } - stringBuilder.append("$v=").append(parameters.getVersion()) - .append("$m=").append(parameters.getMemory()) - .append(",t=").append(parameters.getIterations()) - .append(",p=").append(parameters.getLanes()); - + stringBuilder.append("$v=").append(parameters.getVersion()).append("$m=").append(parameters.getMemory()) + .append(",t=").append(parameters.getIterations()).append(",p=").append(parameters.getLanes()); if (parameters.getSalt() != null) { - stringBuilder.append("$") - .append(b64encoder.encodeToString(parameters.getSalt())); + stringBuilder.append("$").append(b64encoder.encodeToString(parameters.getSalt())); } - - stringBuilder.append("$") - .append(b64encoder.encodeToString(hash)); - + stringBuilder.append("$").append(b64encoder.encodeToString(hash)); return stringBuilder.toString(); } /** * Decodes an Argon2 hash string as specified in the reference implementation - * (https://github.com/P-H-C/phc-winner-argon2/blob/master/src/encoding.c#L244) into the raw hash and the used - * parameters. + * (https://github.com/P-H-C/phc-winner-argon2/blob/master/src/encoding.c#L244) into + * the raw hash and the used parameters. * * The hash has to be formatted as follows: * {@code $argon2[$v=]$m=,t=,p=$$} * - * where {@code } is either 'd', 'id', or 'i', {@code } is a decimal integer (positive, - * fits in an 'unsigned long'), and {@code } is Base64-encoded data (no '=' padding - * characters, no newline or whitespace). + * where {@code } is either 'd', 'id', or 'i', {@code } is a decimal integer + * (positive, fits in an 'unsigned long'), and {@code } is Base64-encoded data + * (no '=' padding characters, no newline or whitespace). * - * The last two binary chunks (encoded in Base64) are, in that order, - * the salt and the output. Both are required. The binary salt length and the - * output length must be in the allowed ranges defined in argon2.h. + * The last two binary chunks (encoded in Base64) are, in that order, the salt and the + * output. Both are required. The binary salt length and the output length must be in + * the allowed ranges defined in argon2.h. * @param encodedHash the Argon2 hash string as described above - * @return an {@link Argon2Hash} object containing the raw hash and the {@link Argon2Parameters}. + * @return an {@link Argon2Hash} object containing the raw hash and the + * {@link Argon2Parameters}. * @throws IllegalArgumentException if the encoded hash is malformed */ - public static Argon2Hash decode(String encodedHash) throws IllegalArgumentException { + static Argon2Hash decode(String encodedHash) throws IllegalArgumentException { Argon2Parameters.Builder paramsBuilder; - String[] parts = encodedHash.split("\\$"); - if (parts.length < 4) { throw new IllegalArgumentException("Invalid encoded Argon2-hash"); } - int currentPart = 1; - switch (parts[currentPart++]) { - case "argon2d": paramsBuilder = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_d); break; - case "argon2i": paramsBuilder = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_i); break; - case "argon2id": paramsBuilder = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_id); break; - default: throw new IllegalArgumentException("Invalid algorithm type: "+parts[0]); + case "argon2d": + paramsBuilder = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_d); + break; + case "argon2i": + paramsBuilder = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_i); + break; + case "argon2id": + paramsBuilder = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_id); + break; + default: + throw new IllegalArgumentException("Invalid algorithm type: " + parts[0]); } - if (parts[currentPart].startsWith("v=")) { paramsBuilder.withVersion(Integer.parseInt(parts[currentPart].substring(2))); currentPart++; } - String[] performanceParams = parts[currentPart++].split(","); - if (performanceParams.length != 3) { throw new IllegalArgumentException("Amount of performance parameters invalid"); } - - if (performanceParams[0].startsWith("m=")) { - paramsBuilder.withMemoryAsKB(Integer.parseInt(performanceParams[0].substring(2))); - } else { + if (!performanceParams[0].startsWith("m=")) { throw new IllegalArgumentException("Invalid memory parameter"); } - - if (performanceParams[1].startsWith("t=")) { - paramsBuilder.withIterations(Integer.parseInt(performanceParams[1].substring(2))); - } else { + paramsBuilder.withMemoryAsKB(Integer.parseInt(performanceParams[0].substring(2))); + if (!performanceParams[1].startsWith("t=")) { throw new IllegalArgumentException("Invalid iterations parameter"); } - - if (performanceParams[2].startsWith("p=")) { - paramsBuilder.withParallelism(Integer.parseInt(performanceParams[2].substring(2))); - } else { + paramsBuilder.withIterations(Integer.parseInt(performanceParams[1].substring(2))); + if (!performanceParams[2].startsWith("p=")) { throw new IllegalArgumentException("Invalid parallelity parameter"); } - + paramsBuilder.withParallelism(Integer.parseInt(performanceParams[2].substring(2))); paramsBuilder.withSalt(b64decoder.decode(parts[currentPart++])); - return new Argon2Hash(b64decoder.decode(parts[currentPart]), paramsBuilder.build()); } public static class Argon2Hash { private byte[] hash; + private Argon2Parameters parameters; Argon2Hash(byte[] hash, Argon2Parameters parameters) { @@ -156,7 +156,7 @@ class Argon2EncodingUtils { } public byte[] getHash() { - return Arrays.clone(hash); + return Arrays.clone(this.hash); } public void setHash(byte[] hash) { @@ -164,11 +164,13 @@ class Argon2EncodingUtils { } public Argon2Parameters getParameters() { - return parameters; + return this.parameters; } public void setParameters(Argon2Parameters parameters) { this.parameters = parameters; } + } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoder.java index 0e38c1072b..b3991cb8c2 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoder.java +++ b/crypto/src/main/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoder.java @@ -20,22 +20,26 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.bouncycastle.crypto.generators.Argon2BytesGenerator; import org.bouncycastle.crypto.params.Argon2Parameters; + import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.KeyGenerators; import org.springframework.security.crypto.password.PasswordEncoder; /** *

      - * Implementation of PasswordEncoder that uses the Argon2 hashing function. - * Clients can optionally supply the length of the salt to use, the length - * of the generated hash, a cpu cost parameter, a memory cost parameter - * and a parallelization parameter. + * Implementation of PasswordEncoder that uses the Argon2 hashing function. Clients can + * optionally supply the length of the salt to use, the length of the generated hash, a + * cpu cost parameter, a memory cost parameter and a parallelization parameter. *

      * - *

      Note:

      - *

      The currently implementation uses Bouncy castle which does not exploit - * parallelism/optimizations that password crackers will, so there is an - * unnecessary asymmetry between attacker and defender.

      + *

      + * Note: + *

      + *

      + * The currently implementation uses Bouncy castle which does not exploit + * parallelism/optimizations that password crackers will, so there is an unnecessary + * asymmetry between attacker and defender. + *

      * * @author Simeon Macke * @since 5.3 @@ -43,85 +47,86 @@ import org.springframework.security.crypto.password.PasswordEncoder; public class Argon2PasswordEncoder implements PasswordEncoder { private static final int DEFAULT_SALT_LENGTH = 16; + private static final int DEFAULT_HASH_LENGTH = 32; + private static final int DEFAULT_PARALLELISM = 1; + private static final int DEFAULT_MEMORY = 1 << 12; + private static final int DEFAULT_ITERATIONS = 3; private final Log logger = LogFactory.getLog(getClass()); private final int hashLength; + private final int parallelism; + private final int memory; + private final int iterations; private final BytesKeyGenerator saltGenerator; + public Argon2PasswordEncoder() { + this(DEFAULT_SALT_LENGTH, DEFAULT_HASH_LENGTH, DEFAULT_PARALLELISM, DEFAULT_MEMORY, DEFAULT_ITERATIONS); + } + public Argon2PasswordEncoder(int saltLength, int hashLength, int parallelism, int memory, int iterations) { this.hashLength = hashLength; this.parallelism = parallelism; this.memory = memory; this.iterations = iterations; - this.saltGenerator = KeyGenerators.secureRandom(saltLength); } - public Argon2PasswordEncoder() { - this(DEFAULT_SALT_LENGTH, DEFAULT_HASH_LENGTH, DEFAULT_PARALLELISM, DEFAULT_MEMORY, DEFAULT_ITERATIONS); - } - @Override public String encode(CharSequence rawPassword) { - byte[] salt = saltGenerator.generateKey(); - byte[] hash = new byte[hashLength]; - - Argon2Parameters params = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_id). - withSalt(salt). - withParallelism(parallelism). - withMemoryAsKB(memory). - withIterations(iterations). - build(); + byte[] salt = this.saltGenerator.generateKey(); + byte[] hash = new byte[this.hashLength]; + // @formatter:off + Argon2Parameters params = new Argon2Parameters + .Builder(Argon2Parameters.ARGON2_id) + .withSalt(salt) + .withParallelism(this.parallelism) + .withMemoryAsKB(this.memory) + .withIterations(this.iterations) + .build(); + // @formatter:on Argon2BytesGenerator generator = new Argon2BytesGenerator(); generator.init(params); generator.generateBytes(rawPassword.toString().toCharArray(), hash); - return Argon2EncodingUtils.encode(hash, params); } @Override public boolean matches(CharSequence rawPassword, String encodedPassword) { if (encodedPassword == null) { - logger.warn("password hash is null"); + this.logger.warn("password hash is null"); return false; } - Argon2EncodingUtils.Argon2Hash decoded; - try { decoded = Argon2EncodingUtils.decode(encodedPassword); - } catch (IllegalArgumentException e) { - logger.warn("Malformed password hash", e); + } + catch (IllegalArgumentException ex) { + this.logger.warn("Malformed password hash", ex); return false; } - byte[] hashBytes = new byte[decoded.getHash().length]; - Argon2BytesGenerator generator = new Argon2BytesGenerator(); generator.init(decoded.getParameters()); generator.generateBytes(rawPassword.toString().toCharArray(), hashBytes); - return constantTimeArrayEquals(decoded.getHash(), hashBytes); } @Override public boolean upgradeEncoding(String encodedPassword) { if (encodedPassword == null || encodedPassword.length() == 0) { - logger.warn("password hash is null"); + this.logger.warn("password hash is null"); return false; } - Argon2Parameters parameters = Argon2EncodingUtils.decode(encodedPassword).getParameters(); - return parameters.getMemory() < this.memory || parameters.getIterations() < this.iterations; } @@ -129,7 +134,6 @@ public class Argon2PasswordEncoder implements PasswordEncoder { if (expected.length != actual.length) { return false; } - int result = 0; for (int i = 0; i < expected.length; i++) { result |= expected[i] ^ actual[i]; diff --git a/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCrypt.java b/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCrypt.java index ab9a9d7a8a..1b545e5977 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCrypt.java +++ b/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCrypt.java @@ -11,33 +11,32 @@ // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + package org.springframework.security.crypto.bcrypt; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; -import java.util.Arrays; import java.security.SecureRandom; +import java.util.Arrays; /** - * BCrypt implements OpenBSD-style Blowfish password hashing using - * the scheme described in "A Future-Adaptable Password Scheme" by - * Niels Provos and David Mazieres. + * BCrypt implements OpenBSD-style Blowfish password hashing using the scheme described in + * "A Future-Adaptable Password Scheme" by Niels Provos and David Mazieres. *

      - * This password hashing system tries to thwart off-line password - * cracking using a computationally-intensive hashing algorithm, - * based on Bruce Schneier's Blowfish cipher. The work factor of - * the algorithm is parameterised, so it can be increased as - * computers get faster. + * This password hashing system tries to thwart off-line password cracking using a + * computationally-intensive hashing algorithm, based on Bruce Schneier's Blowfish cipher. + * The work factor of the algorithm is parameterised, so it can be increased as computers + * get faster. *

      - * Usage is really simple. To hash a password for the first time, - * call the hashpw method with a random salt, like this: + * Usage is really simple. To hash a password for the first time, call the hashpw method + * with a random salt, like this: *

      * * String pw_hash = BCrypt.hashpw(plain_password, BCrypt.gensalt());
      *
      *

      - * To check whether a plaintext password matches one that has been - * hashed previously, use the checkpw method: + * To check whether a plaintext password matches one that has been hashed previously, use + * the checkpw method: *

      * * if (BCrypt.checkpw(candidate_password, stored_hash))
      @@ -46,347 +45,185 @@ import java.security.SecureRandom; *     System.out.println("It does not match");
      *
      *

      - * The gensalt() method takes an optional parameter (log_rounds) - * that determines the computational complexity of the hashing: + * The gensalt() method takes an optional parameter (log_rounds) that determines the + * computational complexity of the hashing: *

      * * String strong_salt = BCrypt.gensalt(10)
      * String stronger_salt = BCrypt.gensalt(12)
      *
      *

      - * The amount of work increases exponentially (2**log_rounds), so - * each increment is twice as much work. The default log_rounds is - * 10, and the valid range is 4 to 31. + * The amount of work increases exponentially (2**log_rounds), so each increment is twice + * as much work. The default log_rounds is 10, and the valid range is 4 to 31. * * @author Damien Miller * @version 0.3 */ public class BCrypt { + // BCrypt parameters private static final int GENSALT_DEFAULT_LOG2_ROUNDS = 10; + private static final int BCRYPT_SALT_LEN = 16; // Blowfish parameters private static final int BLOWFISH_NUM_ROUNDS = 16; // Initial contents of key schedule - private static final int P_orig[] = { - 0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, - 0xa4093822, 0x299f31d0, 0x082efa98, 0xec4e6c89, - 0x452821e6, 0x38d01377, 0xbe5466cf, 0x34e90c6c, - 0xc0ac29b7, 0xc97c50dd, 0x3f84d5b5, 0xb5470917, - 0x9216d5d9, 0x8979fb1b - }; - private static final int S_orig[] = { - 0xd1310ba6, 0x98dfb5ac, 0x2ffd72db, 0xd01adfb7, - 0xb8e1afed, 0x6a267e96, 0xba7c9045, 0xf12c7f99, - 0x24a19947, 0xb3916cf7, 0x0801f2e2, 0x858efc16, - 0x636920d8, 0x71574e69, 0xa458fea3, 0xf4933d7e, - 0x0d95748f, 0x728eb658, 0x718bcd58, 0x82154aee, - 0x7b54a41d, 0xc25a59b5, 0x9c30d539, 0x2af26013, - 0xc5d1b023, 0x286085f0, 0xca417918, 0xb8db38ef, - 0x8e79dcb0, 0x603a180e, 0x6c9e0e8b, 0xb01e8a3e, - 0xd71577c1, 0xbd314b27, 0x78af2fda, 0x55605c60, - 0xe65525f3, 0xaa55ab94, 0x57489862, 0x63e81440, - 0x55ca396a, 0x2aab10b6, 0xb4cc5c34, 0x1141e8ce, - 0xa15486af, 0x7c72e993, 0xb3ee1411, 0x636fbc2a, - 0x2ba9c55d, 0x741831f6, 0xce5c3e16, 0x9b87931e, - 0xafd6ba33, 0x6c24cf5c, 0x7a325381, 0x28958677, - 0x3b8f4898, 0x6b4bb9af, 0xc4bfe81b, 0x66282193, - 0x61d809cc, 0xfb21a991, 0x487cac60, 0x5dec8032, - 0xef845d5d, 0xe98575b1, 0xdc262302, 0xeb651b88, - 0x23893e81, 0xd396acc5, 0x0f6d6ff3, 0x83f44239, - 0x2e0b4482, 0xa4842004, 0x69c8f04a, 0x9e1f9b5e, - 0x21c66842, 0xf6e96c9a, 0x670c9c61, 0xabd388f0, - 0x6a51a0d2, 0xd8542f68, 0x960fa728, 0xab5133a3, - 0x6eef0b6c, 0x137a3be4, 0xba3bf050, 0x7efb2a98, - 0xa1f1651d, 0x39af0176, 0x66ca593e, 0x82430e88, - 0x8cee8619, 0x456f9fb4, 0x7d84a5c3, 0x3b8b5ebe, - 0xe06f75d8, 0x85c12073, 0x401a449f, 0x56c16aa6, - 0x4ed3aa62, 0x363f7706, 0x1bfedf72, 0x429b023d, - 0x37d0d724, 0xd00a1248, 0xdb0fead3, 0x49f1c09b, - 0x075372c9, 0x80991b7b, 0x25d479d8, 0xf6e8def7, - 0xe3fe501a, 0xb6794c3b, 0x976ce0bd, 0x04c006ba, - 0xc1a94fb6, 0x409f60c4, 0x5e5c9ec2, 0x196a2463, - 0x68fb6faf, 0x3e6c53b5, 0x1339b2eb, 0x3b52ec6f, - 0x6dfc511f, 0x9b30952c, 0xcc814544, 0xaf5ebd09, - 0xbee3d004, 0xde334afd, 0x660f2807, 0x192e4bb3, - 0xc0cba857, 0x45c8740f, 0xd20b5f39, 0xb9d3fbdb, - 0x5579c0bd, 0x1a60320a, 0xd6a100c6, 0x402c7279, - 0x679f25fe, 0xfb1fa3cc, 0x8ea5e9f8, 0xdb3222f8, - 0x3c7516df, 0xfd616b15, 0x2f501ec8, 0xad0552ab, - 0x323db5fa, 0xfd238760, 0x53317b48, 0x3e00df82, - 0x9e5c57bb, 0xca6f8ca0, 0x1a87562e, 0xdf1769db, - 0xd542a8f6, 0x287effc3, 0xac6732c6, 0x8c4f5573, - 0x695b27b0, 0xbbca58c8, 0xe1ffa35d, 0xb8f011a0, - 0x10fa3d98, 0xfd2183b8, 0x4afcb56c, 0x2dd1d35b, - 0x9a53e479, 0xb6f84565, 0xd28e49bc, 0x4bfb9790, - 0xe1ddf2da, 0xa4cb7e33, 0x62fb1341, 0xcee4c6e8, - 0xef20cada, 0x36774c01, 0xd07e9efe, 0x2bf11fb4, - 0x95dbda4d, 0xae909198, 0xeaad8e71, 0x6b93d5a0, - 0xd08ed1d0, 0xafc725e0, 0x8e3c5b2f, 0x8e7594b7, - 0x8ff6e2fb, 0xf2122b64, 0x8888b812, 0x900df01c, - 0x4fad5ea0, 0x688fc31c, 0xd1cff191, 0xb3a8c1ad, - 0x2f2f2218, 0xbe0e1777, 0xea752dfe, 0x8b021fa1, - 0xe5a0cc0f, 0xb56f74e8, 0x18acf3d6, 0xce89e299, - 0xb4a84fe0, 0xfd13e0b7, 0x7cc43b81, 0xd2ada8d9, - 0x165fa266, 0x80957705, 0x93cc7314, 0x211a1477, - 0xe6ad2065, 0x77b5fa86, 0xc75442f5, 0xfb9d35cf, - 0xebcdaf0c, 0x7b3e89a0, 0xd6411bd3, 0xae1e7e49, - 0x00250e2d, 0x2071b35e, 0x226800bb, 0x57b8e0af, - 0x2464369b, 0xf009b91e, 0x5563911d, 0x59dfa6aa, - 0x78c14389, 0xd95a537f, 0x207d5ba2, 0x02e5b9c5, - 0x83260376, 0x6295cfa9, 0x11c81968, 0x4e734a41, - 0xb3472dca, 0x7b14a94a, 0x1b510052, 0x9a532915, - 0xd60f573f, 0xbc9bc6e4, 0x2b60a476, 0x81e67400, - 0x08ba6fb5, 0x571be91f, 0xf296ec6b, 0x2a0dd915, - 0xb6636521, 0xe7b9f9b6, 0xff34052e, 0xc5855664, - 0x53b02d5d, 0xa99f8fa1, 0x08ba4799, 0x6e85076a, - 0x4b7a70e9, 0xb5b32944, 0xdb75092e, 0xc4192623, - 0xad6ea6b0, 0x49a7df7d, 0x9cee60b8, 0x8fedb266, - 0xecaa8c71, 0x699a17ff, 0x5664526c, 0xc2b19ee1, - 0x193602a5, 0x75094c29, 0xa0591340, 0xe4183a3e, - 0x3f54989a, 0x5b429d65, 0x6b8fe4d6, 0x99f73fd6, - 0xa1d29c07, 0xefe830f5, 0x4d2d38e6, 0xf0255dc1, - 0x4cdd2086, 0x8470eb26, 0x6382e9c6, 0x021ecc5e, - 0x09686b3f, 0x3ebaefc9, 0x3c971814, 0x6b6a70a1, - 0x687f3584, 0x52a0e286, 0xb79c5305, 0xaa500737, - 0x3e07841c, 0x7fdeae5c, 0x8e7d44ec, 0x5716f2b8, - 0xb03ada37, 0xf0500c0d, 0xf01c1f04, 0x0200b3ff, - 0xae0cf51a, 0x3cb574b2, 0x25837a58, 0xdc0921bd, - 0xd19113f9, 0x7ca92ff6, 0x94324773, 0x22f54701, - 0x3ae5e581, 0x37c2dadc, 0xc8b57634, 0x9af3dda7, - 0xa9446146, 0x0fd0030e, 0xecc8c73e, 0xa4751e41, - 0xe238cd99, 0x3bea0e2f, 0x3280bba1, 0x183eb331, - 0x4e548b38, 0x4f6db908, 0x6f420d03, 0xf60a04bf, - 0x2cb81290, 0x24977c79, 0x5679b072, 0xbcaf89af, - 0xde9a771f, 0xd9930810, 0xb38bae12, 0xdccf3f2e, - 0x5512721f, 0x2e6b7124, 0x501adde6, 0x9f84cd87, - 0x7a584718, 0x7408da17, 0xbc9f9abc, 0xe94b7d8c, - 0xec7aec3a, 0xdb851dfa, 0x63094366, 0xc464c3d2, - 0xef1c1847, 0x3215d908, 0xdd433b37, 0x24c2ba16, - 0x12a14d43, 0x2a65c451, 0x50940002, 0x133ae4dd, - 0x71dff89e, 0x10314e55, 0x81ac77d6, 0x5f11199b, - 0x043556f1, 0xd7a3c76b, 0x3c11183b, 0x5924a509, - 0xf28fe6ed, 0x97f1fbfa, 0x9ebabf2c, 0x1e153c6e, - 0x86e34570, 0xeae96fb1, 0x860e5e0a, 0x5a3e2ab3, - 0x771fe71c, 0x4e3d06fa, 0x2965dcb9, 0x99e71d0f, - 0x803e89d6, 0x5266c825, 0x2e4cc978, 0x9c10b36a, - 0xc6150eba, 0x94e2ea78, 0xa5fc3c53, 0x1e0a2df4, - 0xf2f74ea7, 0x361d2b3d, 0x1939260f, 0x19c27960, - 0x5223a708, 0xf71312b6, 0xebadfe6e, 0xeac31f66, - 0xe3bc4595, 0xa67bc883, 0xb17f37d1, 0x018cff28, - 0xc332ddef, 0xbe6c5aa5, 0x65582185, 0x68ab9802, - 0xeecea50f, 0xdb2f953b, 0x2aef7dad, 0x5b6e2f84, - 0x1521b628, 0x29076170, 0xecdd4775, 0x619f1510, - 0x13cca830, 0xeb61bd96, 0x0334fe1e, 0xaa0363cf, - 0xb5735c90, 0x4c70a239, 0xd59e9e0b, 0xcbaade14, - 0xeecc86bc, 0x60622ca7, 0x9cab5cab, 0xb2f3846e, - 0x648b1eaf, 0x19bdf0ca, 0xa02369b9, 0x655abb50, - 0x40685a32, 0x3c2ab4b3, 0x319ee9d5, 0xc021b8f7, - 0x9b540b19, 0x875fa099, 0x95f7997e, 0x623d7da8, - 0xf837889a, 0x97e32d77, 0x11ed935f, 0x16681281, - 0x0e358829, 0xc7e61fd6, 0x96dedfa1, 0x7858ba99, - 0x57f584a5, 0x1b227263, 0x9b83c3ff, 0x1ac24696, - 0xcdb30aeb, 0x532e3054, 0x8fd948e4, 0x6dbc3128, - 0x58ebf2ef, 0x34c6ffea, 0xfe28ed61, 0xee7c3c73, - 0x5d4a14d9, 0xe864b7e3, 0x42105d14, 0x203e13e0, - 0x45eee2b6, 0xa3aaabea, 0xdb6c4f15, 0xfacb4fd0, - 0xc742f442, 0xef6abbb5, 0x654f3b1d, 0x41cd2105, - 0xd81e799e, 0x86854dc7, 0xe44b476a, 0x3d816250, - 0xcf62a1f2, 0x5b8d2646, 0xfc8883a0, 0xc1c7b6a3, - 0x7f1524c3, 0x69cb7492, 0x47848a0b, 0x5692b285, - 0x095bbf00, 0xad19489d, 0x1462b174, 0x23820e00, - 0x58428d2a, 0x0c55f5ea, 0x1dadf43e, 0x233f7061, - 0x3372f092, 0x8d937e41, 0xd65fecf1, 0x6c223bdb, - 0x7cde3759, 0xcbee7460, 0x4085f2a7, 0xce77326e, - 0xa6078084, 0x19f8509e, 0xe8efd855, 0x61d99735, - 0xa969a7aa, 0xc50c06c2, 0x5a04abfc, 0x800bcadc, - 0x9e447a2e, 0xc3453484, 0xfdd56705, 0x0e1e9ec9, - 0xdb73dbd3, 0x105588cd, 0x675fda79, 0xe3674340, - 0xc5c43465, 0x713e38d8, 0x3d28f89e, 0xf16dff20, - 0x153e21e7, 0x8fb03d4a, 0xe6e39f2b, 0xdb83adf7, - 0xe93d5a68, 0x948140f7, 0xf64c261c, 0x94692934, - 0x411520f7, 0x7602d4f7, 0xbcf46b2e, 0xd4a20068, - 0xd4082471, 0x3320f46a, 0x43b7d4b7, 0x500061af, - 0x1e39f62e, 0x97244546, 0x14214f74, 0xbf8b8840, - 0x4d95fc1d, 0x96b591af, 0x70f4ddd3, 0x66a02f45, - 0xbfbc09ec, 0x03bd9785, 0x7fac6dd0, 0x31cb8504, - 0x96eb27b3, 0x55fd3941, 0xda2547e6, 0xabca0a9a, - 0x28507825, 0x530429f4, 0x0a2c86da, 0xe9b66dfb, - 0x68dc1462, 0xd7486900, 0x680ec0a4, 0x27a18dee, - 0x4f3ffea2, 0xe887ad8c, 0xb58ce006, 0x7af4d6b6, - 0xaace1e7c, 0xd3375fec, 0xce78a399, 0x406b2a42, - 0x20fe9e35, 0xd9f385b9, 0xee39d7ab, 0x3b124e8b, - 0x1dc9faf7, 0x4b6d1856, 0x26a36631, 0xeae397b2, - 0x3a6efa74, 0xdd5b4332, 0x6841e7f7, 0xca7820fb, - 0xfb0af54e, 0xd8feb397, 0x454056ac, 0xba489527, - 0x55533a3a, 0x20838d87, 0xfe6ba9b7, 0xd096954b, - 0x55a867bc, 0xa1159a58, 0xcca92963, 0x99e1db33, - 0xa62a4a56, 0x3f3125f9, 0x5ef47e1c, 0x9029317c, - 0xfdf8e802, 0x04272f70, 0x80bb155c, 0x05282ce3, - 0x95c11548, 0xe4c66d22, 0x48c1133f, 0xc70f86dc, - 0x07f9c9ee, 0x41041f0f, 0x404779a4, 0x5d886e17, - 0x325f51eb, 0xd59bc0d1, 0xf2bcc18f, 0x41113564, - 0x257b7834, 0x602a9c60, 0xdff8e8a3, 0x1f636c1b, - 0x0e12b4c2, 0x02e1329e, 0xaf664fd1, 0xcad18115, - 0x6b2395e0, 0x333e92e1, 0x3b240b62, 0xeebeb922, - 0x85b2a20e, 0xe6ba0d99, 0xde720c8c, 0x2da2f728, - 0xd0127845, 0x95b794fd, 0x647d0862, 0xe7ccf5f0, - 0x5449a36f, 0x877d48fa, 0xc39dfd27, 0xf33e8d1e, - 0x0a476341, 0x992eff74, 0x3a6f6eab, 0xf4f8fd37, - 0xa812dc60, 0xa1ebddf8, 0x991be14c, 0xdb6e6b0d, - 0xc67b5510, 0x6d672c37, 0x2765d43b, 0xdcd0e804, - 0xf1290dc7, 0xcc00ffa3, 0xb5390f92, 0x690fed0b, - 0x667b9ffb, 0xcedb7d9c, 0xa091cf0b, 0xd9155ea3, - 0xbb132f88, 0x515bad24, 0x7b9479bf, 0x763bd6eb, - 0x37392eb3, 0xcc115979, 0x8026e297, 0xf42e312d, - 0x6842ada7, 0xc66a2b3b, 0x12754ccc, 0x782ef11c, - 0x6a124237, 0xb79251e7, 0x06a1bbe6, 0x4bfb6350, - 0x1a6b1018, 0x11caedfa, 0x3d25bdd8, 0xe2e1c3c9, - 0x44421659, 0x0a121386, 0xd90cec6e, 0xd5abea2a, - 0x64af674e, 0xda86a85f, 0xbebfe988, 0x64e4c3fe, - 0x9dbc8057, 0xf0f7c086, 0x60787bf8, 0x6003604d, - 0xd1fd8346, 0xf6381fb0, 0x7745ae04, 0xd736fccc, - 0x83426b33, 0xf01eab71, 0xb0804187, 0x3c005e5f, - 0x77a057be, 0xbde8ae24, 0x55464299, 0xbf582e61, - 0x4e58f48f, 0xf2ddfda2, 0xf474ef38, 0x8789bdc2, - 0x5366f9c3, 0xc8b38e74, 0xb475f255, 0x46fcd9b9, - 0x7aeb2661, 0x8b1ddf84, 0x846a0e79, 0x915f95e2, - 0x466e598e, 0x20b45770, 0x8cd55591, 0xc902de4c, - 0xb90bace1, 0xbb8205d0, 0x11a86248, 0x7574a99e, - 0xb77f19b6, 0xe0a9dc09, 0x662d09a1, 0xc4324633, - 0xe85a1f02, 0x09f0be8c, 0x4a99a025, 0x1d6efe10, - 0x1ab93d1d, 0x0ba5a4df, 0xa186f20f, 0x2868f169, - 0xdcb7da83, 0x573906fe, 0xa1e2ce9b, 0x4fcd7f52, - 0x50115e01, 0xa70683fa, 0xa002b5c4, 0x0de6d027, - 0x9af88c27, 0x773f8641, 0xc3604c06, 0x61a806b5, - 0xf0177a28, 0xc0f586e0, 0x006058aa, 0x30dc7d62, - 0x11e69ed7, 0x2338ea63, 0x53c2dd94, 0xc2c21634, - 0xbbcbee56, 0x90bcb6de, 0xebfc7da1, 0xce591d76, - 0x6f05e409, 0x4b7c0188, 0x39720a3d, 0x7c927c24, - 0x86e3725f, 0x724d9db9, 0x1ac15bb4, 0xd39eb8fc, - 0xed545578, 0x08fca5b5, 0xd83d7cd3, 0x4dad0fc4, - 0x1e50ef5e, 0xb161e6f8, 0xa28514d9, 0x6c51133c, - 0x6fd5c7e7, 0x56e14ec4, 0x362abfce, 0xddc6c837, - 0xd79a3234, 0x92638212, 0x670efa8e, 0x406000e0, - 0x3a39ce37, 0xd3faf5cf, 0xabc27737, 0x5ac52d1b, - 0x5cb0679e, 0x4fa33742, 0xd3822740, 0x99bc9bbe, - 0xd5118e9d, 0xbf0f7315, 0xd62d1c7e, 0xc700c47b, - 0xb78c1b6b, 0x21a19045, 0xb26eb1be, 0x6a366eb4, - 0x5748ab2f, 0xbc946e79, 0xc6a376d2, 0x6549c2c8, - 0x530ff8ee, 0x468dde7d, 0xd5730a1d, 0x4cd04dc6, - 0x2939bbdb, 0xa9ba4650, 0xac9526e8, 0xbe5ee304, - 0xa1fad5f0, 0x6a2d519a, 0x63ef8ce2, 0x9a86ee22, - 0xc089c2b8, 0x43242ef6, 0xa51e03aa, 0x9cf2d0a4, - 0x83c061ba, 0x9be96a4d, 0x8fe51550, 0xba645bd6, - 0x2826a2f9, 0xa73a3ae1, 0x4ba99586, 0xef5562e9, - 0xc72fefd3, 0xf752f7da, 0x3f046f69, 0x77fa0a59, - 0x80e4a915, 0x87b08601, 0x9b09e6ad, 0x3b3ee593, - 0xe990fd5a, 0x9e34d797, 0x2cf0b7d9, 0x022b8b51, - 0x96d5ac3a, 0x017da67d, 0xd1cf3ed6, 0x7c7d2d28, - 0x1f9f25cf, 0xadf2b89b, 0x5ad6b472, 0x5a88f54c, - 0xe029ac71, 0xe019a5e6, 0x47b0acfd, 0xed93fa9b, - 0xe8d3c48d, 0x283b57cc, 0xf8d56629, 0x79132e28, - 0x785f0191, 0xed756055, 0xf7960e44, 0xe3d35e8c, - 0x15056dd4, 0x88f46dba, 0x03a16125, 0x0564f0bd, - 0xc3eb9e15, 0x3c9057a2, 0x97271aec, 0xa93a072a, - 0x1b3f6d9b, 0x1e6321f5, 0xf59c66fb, 0x26dcf319, - 0x7533d928, 0xb155fdf5, 0x03563482, 0x8aba3cbb, - 0x28517711, 0xc20ad9f8, 0xabcc5167, 0xccad925f, - 0x4de81751, 0x3830dc8e, 0x379d5862, 0x9320f991, - 0xea7a90c2, 0xfb3e7bce, 0x5121ce64, 0x774fbe32, - 0xa8b6e37e, 0xc3293d46, 0x48de5369, 0x6413e680, - 0xa2ae0810, 0xdd6db224, 0x69852dfd, 0x09072166, - 0xb39a460a, 0x6445c0dd, 0x586cdecf, 0x1c20c8ae, - 0x5bbef7dd, 0x1b588d40, 0xccd2017f, 0x6bb4e3bb, - 0xdda26a7e, 0x3a59ff45, 0x3e350a44, 0xbcb4cdd5, - 0x72eacea8, 0xfa6484bb, 0x8d6612ae, 0xbf3c6f47, - 0xd29be463, 0x542f5d9e, 0xaec2771b, 0xf64e6370, - 0x740e0d8d, 0xe75b1357, 0xf8721671, 0xaf537d5d, - 0x4040cb08, 0x4eb4e2cc, 0x34d2466a, 0x0115af84, - 0xe1b00428, 0x95983a1d, 0x06b89fb4, 0xce6ea048, - 0x6f3f3b82, 0x3520ab82, 0x011a1d4b, 0x277227f8, - 0x611560b1, 0xe7933fdc, 0xbb3a792b, 0x344525bd, - 0xa08839e1, 0x51ce794b, 0x2f32c9b7, 0xa01fbac9, - 0xe01cc87e, 0xbcc7d1f6, 0xcf0111c3, 0xa1e8aac7, - 0x1a908749, 0xd44fbd9a, 0xd0dadecb, 0xd50ada38, - 0x0339c32a, 0xc6913667, 0x8df9317c, 0xe0b12b4f, - 0xf79e59b7, 0x43f5bb3a, 0xf2d519ff, 0x27d9459c, - 0xbf97222c, 0x15e6fc2a, 0x0f91fc71, 0x9b941525, - 0xfae59361, 0xceb69ceb, 0xc2a86459, 0x12baa8d1, - 0xb6c1075e, 0xe3056a0c, 0x10d25065, 0xcb03a442, - 0xe0ec6e0e, 0x1698db3b, 0x4c98a0be, 0x3278e964, - 0x9f1f9532, 0xe0d392df, 0xd3a0342b, 0x8971f21e, - 0x1b0a7441, 0x4ba3348c, 0xc5be7120, 0xc37632d8, - 0xdf359f8d, 0x9b992f2e, 0xe60b6f47, 0x0fe3f11d, - 0xe54cda54, 0x1edad891, 0xce6279cf, 0xcd3e7e6f, - 0x1618b166, 0xfd2c1d05, 0x848fd2c5, 0xf6fb2299, - 0xf523f357, 0xa6327623, 0x93a83531, 0x56cccd02, - 0xacf08162, 0x5a75ebb5, 0x6e163697, 0x88d273cc, - 0xde966292, 0x81b949d0, 0x4c50901b, 0x71c65614, - 0xe6c6c7bd, 0x327a140a, 0x45e1d006, 0xc3f27b9a, - 0xc9aa53fd, 0x62a80f00, 0xbb25bfe2, 0x35bdd2f6, - 0x71126905, 0xb2040222, 0xb6cbcf7c, 0xcd769c2b, - 0x53113ec0, 0x1640e3d3, 0x38abbd60, 0x2547adf0, - 0xba38209c, 0xf746ce76, 0x77afa1c5, 0x20756060, - 0x85cbfe4e, 0x8ae88dd8, 0x7aaaf9b0, 0x4cf9aa7e, - 0x1948c25c, 0x02fb8a8c, 0x01c36ae4, 0xd6ebe1f9, - 0x90d4f869, 0xa65cdea0, 0x3f09252d, 0xc208e69f, - 0xb74e6132, 0xce77e25b, 0x578fdfe3, 0x3ac372e6 - }; + private static final int P_orig[] = { 0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, 0xa4093822, 0x299f31d0, + 0x082efa98, 0xec4e6c89, 0x452821e6, 0x38d01377, 0xbe5466cf, 0x34e90c6c, 0xc0ac29b7, 0xc97c50dd, 0x3f84d5b5, + 0xb5470917, 0x9216d5d9, 0x8979fb1b }; + + private static final int S_orig[] = { 0xd1310ba6, 0x98dfb5ac, 0x2ffd72db, 0xd01adfb7, 0xb8e1afed, 0x6a267e96, + 0xba7c9045, 0xf12c7f99, 0x24a19947, 0xb3916cf7, 0x0801f2e2, 0x858efc16, 0x636920d8, 0x71574e69, 0xa458fea3, + 0xf4933d7e, 0x0d95748f, 0x728eb658, 0x718bcd58, 0x82154aee, 0x7b54a41d, 0xc25a59b5, 0x9c30d539, 0x2af26013, + 0xc5d1b023, 0x286085f0, 0xca417918, 0xb8db38ef, 0x8e79dcb0, 0x603a180e, 0x6c9e0e8b, 0xb01e8a3e, 0xd71577c1, + 0xbd314b27, 0x78af2fda, 0x55605c60, 0xe65525f3, 0xaa55ab94, 0x57489862, 0x63e81440, 0x55ca396a, 0x2aab10b6, + 0xb4cc5c34, 0x1141e8ce, 0xa15486af, 0x7c72e993, 0xb3ee1411, 0x636fbc2a, 0x2ba9c55d, 0x741831f6, 0xce5c3e16, + 0x9b87931e, 0xafd6ba33, 0x6c24cf5c, 0x7a325381, 0x28958677, 0x3b8f4898, 0x6b4bb9af, 0xc4bfe81b, 0x66282193, + 0x61d809cc, 0xfb21a991, 0x487cac60, 0x5dec8032, 0xef845d5d, 0xe98575b1, 0xdc262302, 0xeb651b88, 0x23893e81, + 0xd396acc5, 0x0f6d6ff3, 0x83f44239, 0x2e0b4482, 0xa4842004, 0x69c8f04a, 0x9e1f9b5e, 0x21c66842, 0xf6e96c9a, + 0x670c9c61, 0xabd388f0, 0x6a51a0d2, 0xd8542f68, 0x960fa728, 0xab5133a3, 0x6eef0b6c, 0x137a3be4, 0xba3bf050, + 0x7efb2a98, 0xa1f1651d, 0x39af0176, 0x66ca593e, 0x82430e88, 0x8cee8619, 0x456f9fb4, 0x7d84a5c3, 0x3b8b5ebe, + 0xe06f75d8, 0x85c12073, 0x401a449f, 0x56c16aa6, 0x4ed3aa62, 0x363f7706, 0x1bfedf72, 0x429b023d, 0x37d0d724, + 0xd00a1248, 0xdb0fead3, 0x49f1c09b, 0x075372c9, 0x80991b7b, 0x25d479d8, 0xf6e8def7, 0xe3fe501a, 0xb6794c3b, + 0x976ce0bd, 0x04c006ba, 0xc1a94fb6, 0x409f60c4, 0x5e5c9ec2, 0x196a2463, 0x68fb6faf, 0x3e6c53b5, 0x1339b2eb, + 0x3b52ec6f, 0x6dfc511f, 0x9b30952c, 0xcc814544, 0xaf5ebd09, 0xbee3d004, 0xde334afd, 0x660f2807, 0x192e4bb3, + 0xc0cba857, 0x45c8740f, 0xd20b5f39, 0xb9d3fbdb, 0x5579c0bd, 0x1a60320a, 0xd6a100c6, 0x402c7279, 0x679f25fe, + 0xfb1fa3cc, 0x8ea5e9f8, 0xdb3222f8, 0x3c7516df, 0xfd616b15, 0x2f501ec8, 0xad0552ab, 0x323db5fa, 0xfd238760, + 0x53317b48, 0x3e00df82, 0x9e5c57bb, 0xca6f8ca0, 0x1a87562e, 0xdf1769db, 0xd542a8f6, 0x287effc3, 0xac6732c6, + 0x8c4f5573, 0x695b27b0, 0xbbca58c8, 0xe1ffa35d, 0xb8f011a0, 0x10fa3d98, 0xfd2183b8, 0x4afcb56c, 0x2dd1d35b, + 0x9a53e479, 0xb6f84565, 0xd28e49bc, 0x4bfb9790, 0xe1ddf2da, 0xa4cb7e33, 0x62fb1341, 0xcee4c6e8, 0xef20cada, + 0x36774c01, 0xd07e9efe, 0x2bf11fb4, 0x95dbda4d, 0xae909198, 0xeaad8e71, 0x6b93d5a0, 0xd08ed1d0, 0xafc725e0, + 0x8e3c5b2f, 0x8e7594b7, 0x8ff6e2fb, 0xf2122b64, 0x8888b812, 0x900df01c, 0x4fad5ea0, 0x688fc31c, 0xd1cff191, + 0xb3a8c1ad, 0x2f2f2218, 0xbe0e1777, 0xea752dfe, 0x8b021fa1, 0xe5a0cc0f, 0xb56f74e8, 0x18acf3d6, 0xce89e299, + 0xb4a84fe0, 0xfd13e0b7, 0x7cc43b81, 0xd2ada8d9, 0x165fa266, 0x80957705, 0x93cc7314, 0x211a1477, 0xe6ad2065, + 0x77b5fa86, 0xc75442f5, 0xfb9d35cf, 0xebcdaf0c, 0x7b3e89a0, 0xd6411bd3, 0xae1e7e49, 0x00250e2d, 0x2071b35e, + 0x226800bb, 0x57b8e0af, 0x2464369b, 0xf009b91e, 0x5563911d, 0x59dfa6aa, 0x78c14389, 0xd95a537f, 0x207d5ba2, + 0x02e5b9c5, 0x83260376, 0x6295cfa9, 0x11c81968, 0x4e734a41, 0xb3472dca, 0x7b14a94a, 0x1b510052, 0x9a532915, + 0xd60f573f, 0xbc9bc6e4, 0x2b60a476, 0x81e67400, 0x08ba6fb5, 0x571be91f, 0xf296ec6b, 0x2a0dd915, 0xb6636521, + 0xe7b9f9b6, 0xff34052e, 0xc5855664, 0x53b02d5d, 0xa99f8fa1, 0x08ba4799, 0x6e85076a, 0x4b7a70e9, 0xb5b32944, + 0xdb75092e, 0xc4192623, 0xad6ea6b0, 0x49a7df7d, 0x9cee60b8, 0x8fedb266, 0xecaa8c71, 0x699a17ff, 0x5664526c, + 0xc2b19ee1, 0x193602a5, 0x75094c29, 0xa0591340, 0xe4183a3e, 0x3f54989a, 0x5b429d65, 0x6b8fe4d6, 0x99f73fd6, + 0xa1d29c07, 0xefe830f5, 0x4d2d38e6, 0xf0255dc1, 0x4cdd2086, 0x8470eb26, 0x6382e9c6, 0x021ecc5e, 0x09686b3f, + 0x3ebaefc9, 0x3c971814, 0x6b6a70a1, 0x687f3584, 0x52a0e286, 0xb79c5305, 0xaa500737, 0x3e07841c, 0x7fdeae5c, + 0x8e7d44ec, 0x5716f2b8, 0xb03ada37, 0xf0500c0d, 0xf01c1f04, 0x0200b3ff, 0xae0cf51a, 0x3cb574b2, 0x25837a58, + 0xdc0921bd, 0xd19113f9, 0x7ca92ff6, 0x94324773, 0x22f54701, 0x3ae5e581, 0x37c2dadc, 0xc8b57634, 0x9af3dda7, + 0xa9446146, 0x0fd0030e, 0xecc8c73e, 0xa4751e41, 0xe238cd99, 0x3bea0e2f, 0x3280bba1, 0x183eb331, 0x4e548b38, + 0x4f6db908, 0x6f420d03, 0xf60a04bf, 0x2cb81290, 0x24977c79, 0x5679b072, 0xbcaf89af, 0xde9a771f, 0xd9930810, + 0xb38bae12, 0xdccf3f2e, 0x5512721f, 0x2e6b7124, 0x501adde6, 0x9f84cd87, 0x7a584718, 0x7408da17, 0xbc9f9abc, + 0xe94b7d8c, 0xec7aec3a, 0xdb851dfa, 0x63094366, 0xc464c3d2, 0xef1c1847, 0x3215d908, 0xdd433b37, 0x24c2ba16, + 0x12a14d43, 0x2a65c451, 0x50940002, 0x133ae4dd, 0x71dff89e, 0x10314e55, 0x81ac77d6, 0x5f11199b, 0x043556f1, + 0xd7a3c76b, 0x3c11183b, 0x5924a509, 0xf28fe6ed, 0x97f1fbfa, 0x9ebabf2c, 0x1e153c6e, 0x86e34570, 0xeae96fb1, + 0x860e5e0a, 0x5a3e2ab3, 0x771fe71c, 0x4e3d06fa, 0x2965dcb9, 0x99e71d0f, 0x803e89d6, 0x5266c825, 0x2e4cc978, + 0x9c10b36a, 0xc6150eba, 0x94e2ea78, 0xa5fc3c53, 0x1e0a2df4, 0xf2f74ea7, 0x361d2b3d, 0x1939260f, 0x19c27960, + 0x5223a708, 0xf71312b6, 0xebadfe6e, 0xeac31f66, 0xe3bc4595, 0xa67bc883, 0xb17f37d1, 0x018cff28, 0xc332ddef, + 0xbe6c5aa5, 0x65582185, 0x68ab9802, 0xeecea50f, 0xdb2f953b, 0x2aef7dad, 0x5b6e2f84, 0x1521b628, 0x29076170, + 0xecdd4775, 0x619f1510, 0x13cca830, 0xeb61bd96, 0x0334fe1e, 0xaa0363cf, 0xb5735c90, 0x4c70a239, 0xd59e9e0b, + 0xcbaade14, 0xeecc86bc, 0x60622ca7, 0x9cab5cab, 0xb2f3846e, 0x648b1eaf, 0x19bdf0ca, 0xa02369b9, 0x655abb50, + 0x40685a32, 0x3c2ab4b3, 0x319ee9d5, 0xc021b8f7, 0x9b540b19, 0x875fa099, 0x95f7997e, 0x623d7da8, 0xf837889a, + 0x97e32d77, 0x11ed935f, 0x16681281, 0x0e358829, 0xc7e61fd6, 0x96dedfa1, 0x7858ba99, 0x57f584a5, 0x1b227263, + 0x9b83c3ff, 0x1ac24696, 0xcdb30aeb, 0x532e3054, 0x8fd948e4, 0x6dbc3128, 0x58ebf2ef, 0x34c6ffea, 0xfe28ed61, + 0xee7c3c73, 0x5d4a14d9, 0xe864b7e3, 0x42105d14, 0x203e13e0, 0x45eee2b6, 0xa3aaabea, 0xdb6c4f15, 0xfacb4fd0, + 0xc742f442, 0xef6abbb5, 0x654f3b1d, 0x41cd2105, 0xd81e799e, 0x86854dc7, 0xe44b476a, 0x3d816250, 0xcf62a1f2, + 0x5b8d2646, 0xfc8883a0, 0xc1c7b6a3, 0x7f1524c3, 0x69cb7492, 0x47848a0b, 0x5692b285, 0x095bbf00, 0xad19489d, + 0x1462b174, 0x23820e00, 0x58428d2a, 0x0c55f5ea, 0x1dadf43e, 0x233f7061, 0x3372f092, 0x8d937e41, 0xd65fecf1, + 0x6c223bdb, 0x7cde3759, 0xcbee7460, 0x4085f2a7, 0xce77326e, 0xa6078084, 0x19f8509e, 0xe8efd855, 0x61d99735, + 0xa969a7aa, 0xc50c06c2, 0x5a04abfc, 0x800bcadc, 0x9e447a2e, 0xc3453484, 0xfdd56705, 0x0e1e9ec9, 0xdb73dbd3, + 0x105588cd, 0x675fda79, 0xe3674340, 0xc5c43465, 0x713e38d8, 0x3d28f89e, 0xf16dff20, 0x153e21e7, 0x8fb03d4a, + 0xe6e39f2b, 0xdb83adf7, 0xe93d5a68, 0x948140f7, 0xf64c261c, 0x94692934, 0x411520f7, 0x7602d4f7, 0xbcf46b2e, + 0xd4a20068, 0xd4082471, 0x3320f46a, 0x43b7d4b7, 0x500061af, 0x1e39f62e, 0x97244546, 0x14214f74, 0xbf8b8840, + 0x4d95fc1d, 0x96b591af, 0x70f4ddd3, 0x66a02f45, 0xbfbc09ec, 0x03bd9785, 0x7fac6dd0, 0x31cb8504, 0x96eb27b3, + 0x55fd3941, 0xda2547e6, 0xabca0a9a, 0x28507825, 0x530429f4, 0x0a2c86da, 0xe9b66dfb, 0x68dc1462, 0xd7486900, + 0x680ec0a4, 0x27a18dee, 0x4f3ffea2, 0xe887ad8c, 0xb58ce006, 0x7af4d6b6, 0xaace1e7c, 0xd3375fec, 0xce78a399, + 0x406b2a42, 0x20fe9e35, 0xd9f385b9, 0xee39d7ab, 0x3b124e8b, 0x1dc9faf7, 0x4b6d1856, 0x26a36631, 0xeae397b2, + 0x3a6efa74, 0xdd5b4332, 0x6841e7f7, 0xca7820fb, 0xfb0af54e, 0xd8feb397, 0x454056ac, 0xba489527, 0x55533a3a, + 0x20838d87, 0xfe6ba9b7, 0xd096954b, 0x55a867bc, 0xa1159a58, 0xcca92963, 0x99e1db33, 0xa62a4a56, 0x3f3125f9, + 0x5ef47e1c, 0x9029317c, 0xfdf8e802, 0x04272f70, 0x80bb155c, 0x05282ce3, 0x95c11548, 0xe4c66d22, 0x48c1133f, + 0xc70f86dc, 0x07f9c9ee, 0x41041f0f, 0x404779a4, 0x5d886e17, 0x325f51eb, 0xd59bc0d1, 0xf2bcc18f, 0x41113564, + 0x257b7834, 0x602a9c60, 0xdff8e8a3, 0x1f636c1b, 0x0e12b4c2, 0x02e1329e, 0xaf664fd1, 0xcad18115, 0x6b2395e0, + 0x333e92e1, 0x3b240b62, 0xeebeb922, 0x85b2a20e, 0xe6ba0d99, 0xde720c8c, 0x2da2f728, 0xd0127845, 0x95b794fd, + 0x647d0862, 0xe7ccf5f0, 0x5449a36f, 0x877d48fa, 0xc39dfd27, 0xf33e8d1e, 0x0a476341, 0x992eff74, 0x3a6f6eab, + 0xf4f8fd37, 0xa812dc60, 0xa1ebddf8, 0x991be14c, 0xdb6e6b0d, 0xc67b5510, 0x6d672c37, 0x2765d43b, 0xdcd0e804, + 0xf1290dc7, 0xcc00ffa3, 0xb5390f92, 0x690fed0b, 0x667b9ffb, 0xcedb7d9c, 0xa091cf0b, 0xd9155ea3, 0xbb132f88, + 0x515bad24, 0x7b9479bf, 0x763bd6eb, 0x37392eb3, 0xcc115979, 0x8026e297, 0xf42e312d, 0x6842ada7, 0xc66a2b3b, + 0x12754ccc, 0x782ef11c, 0x6a124237, 0xb79251e7, 0x06a1bbe6, 0x4bfb6350, 0x1a6b1018, 0x11caedfa, 0x3d25bdd8, + 0xe2e1c3c9, 0x44421659, 0x0a121386, 0xd90cec6e, 0xd5abea2a, 0x64af674e, 0xda86a85f, 0xbebfe988, 0x64e4c3fe, + 0x9dbc8057, 0xf0f7c086, 0x60787bf8, 0x6003604d, 0xd1fd8346, 0xf6381fb0, 0x7745ae04, 0xd736fccc, 0x83426b33, + 0xf01eab71, 0xb0804187, 0x3c005e5f, 0x77a057be, 0xbde8ae24, 0x55464299, 0xbf582e61, 0x4e58f48f, 0xf2ddfda2, + 0xf474ef38, 0x8789bdc2, 0x5366f9c3, 0xc8b38e74, 0xb475f255, 0x46fcd9b9, 0x7aeb2661, 0x8b1ddf84, 0x846a0e79, + 0x915f95e2, 0x466e598e, 0x20b45770, 0x8cd55591, 0xc902de4c, 0xb90bace1, 0xbb8205d0, 0x11a86248, 0x7574a99e, + 0xb77f19b6, 0xe0a9dc09, 0x662d09a1, 0xc4324633, 0xe85a1f02, 0x09f0be8c, 0x4a99a025, 0x1d6efe10, 0x1ab93d1d, + 0x0ba5a4df, 0xa186f20f, 0x2868f169, 0xdcb7da83, 0x573906fe, 0xa1e2ce9b, 0x4fcd7f52, 0x50115e01, 0xa70683fa, + 0xa002b5c4, 0x0de6d027, 0x9af88c27, 0x773f8641, 0xc3604c06, 0x61a806b5, 0xf0177a28, 0xc0f586e0, 0x006058aa, + 0x30dc7d62, 0x11e69ed7, 0x2338ea63, 0x53c2dd94, 0xc2c21634, 0xbbcbee56, 0x90bcb6de, 0xebfc7da1, 0xce591d76, + 0x6f05e409, 0x4b7c0188, 0x39720a3d, 0x7c927c24, 0x86e3725f, 0x724d9db9, 0x1ac15bb4, 0xd39eb8fc, 0xed545578, + 0x08fca5b5, 0xd83d7cd3, 0x4dad0fc4, 0x1e50ef5e, 0xb161e6f8, 0xa28514d9, 0x6c51133c, 0x6fd5c7e7, 0x56e14ec4, + 0x362abfce, 0xddc6c837, 0xd79a3234, 0x92638212, 0x670efa8e, 0x406000e0, 0x3a39ce37, 0xd3faf5cf, 0xabc27737, + 0x5ac52d1b, 0x5cb0679e, 0x4fa33742, 0xd3822740, 0x99bc9bbe, 0xd5118e9d, 0xbf0f7315, 0xd62d1c7e, 0xc700c47b, + 0xb78c1b6b, 0x21a19045, 0xb26eb1be, 0x6a366eb4, 0x5748ab2f, 0xbc946e79, 0xc6a376d2, 0x6549c2c8, 0x530ff8ee, + 0x468dde7d, 0xd5730a1d, 0x4cd04dc6, 0x2939bbdb, 0xa9ba4650, 0xac9526e8, 0xbe5ee304, 0xa1fad5f0, 0x6a2d519a, + 0x63ef8ce2, 0x9a86ee22, 0xc089c2b8, 0x43242ef6, 0xa51e03aa, 0x9cf2d0a4, 0x83c061ba, 0x9be96a4d, 0x8fe51550, + 0xba645bd6, 0x2826a2f9, 0xa73a3ae1, 0x4ba99586, 0xef5562e9, 0xc72fefd3, 0xf752f7da, 0x3f046f69, 0x77fa0a59, + 0x80e4a915, 0x87b08601, 0x9b09e6ad, 0x3b3ee593, 0xe990fd5a, 0x9e34d797, 0x2cf0b7d9, 0x022b8b51, 0x96d5ac3a, + 0x017da67d, 0xd1cf3ed6, 0x7c7d2d28, 0x1f9f25cf, 0xadf2b89b, 0x5ad6b472, 0x5a88f54c, 0xe029ac71, 0xe019a5e6, + 0x47b0acfd, 0xed93fa9b, 0xe8d3c48d, 0x283b57cc, 0xf8d56629, 0x79132e28, 0x785f0191, 0xed756055, 0xf7960e44, + 0xe3d35e8c, 0x15056dd4, 0x88f46dba, 0x03a16125, 0x0564f0bd, 0xc3eb9e15, 0x3c9057a2, 0x97271aec, 0xa93a072a, + 0x1b3f6d9b, 0x1e6321f5, 0xf59c66fb, 0x26dcf319, 0x7533d928, 0xb155fdf5, 0x03563482, 0x8aba3cbb, 0x28517711, + 0xc20ad9f8, 0xabcc5167, 0xccad925f, 0x4de81751, 0x3830dc8e, 0x379d5862, 0x9320f991, 0xea7a90c2, 0xfb3e7bce, + 0x5121ce64, 0x774fbe32, 0xa8b6e37e, 0xc3293d46, 0x48de5369, 0x6413e680, 0xa2ae0810, 0xdd6db224, 0x69852dfd, + 0x09072166, 0xb39a460a, 0x6445c0dd, 0x586cdecf, 0x1c20c8ae, 0x5bbef7dd, 0x1b588d40, 0xccd2017f, 0x6bb4e3bb, + 0xdda26a7e, 0x3a59ff45, 0x3e350a44, 0xbcb4cdd5, 0x72eacea8, 0xfa6484bb, 0x8d6612ae, 0xbf3c6f47, 0xd29be463, + 0x542f5d9e, 0xaec2771b, 0xf64e6370, 0x740e0d8d, 0xe75b1357, 0xf8721671, 0xaf537d5d, 0x4040cb08, 0x4eb4e2cc, + 0x34d2466a, 0x0115af84, 0xe1b00428, 0x95983a1d, 0x06b89fb4, 0xce6ea048, 0x6f3f3b82, 0x3520ab82, 0x011a1d4b, + 0x277227f8, 0x611560b1, 0xe7933fdc, 0xbb3a792b, 0x344525bd, 0xa08839e1, 0x51ce794b, 0x2f32c9b7, 0xa01fbac9, + 0xe01cc87e, 0xbcc7d1f6, 0xcf0111c3, 0xa1e8aac7, 0x1a908749, 0xd44fbd9a, 0xd0dadecb, 0xd50ada38, 0x0339c32a, + 0xc6913667, 0x8df9317c, 0xe0b12b4f, 0xf79e59b7, 0x43f5bb3a, 0xf2d519ff, 0x27d9459c, 0xbf97222c, 0x15e6fc2a, + 0x0f91fc71, 0x9b941525, 0xfae59361, 0xceb69ceb, 0xc2a86459, 0x12baa8d1, 0xb6c1075e, 0xe3056a0c, 0x10d25065, + 0xcb03a442, 0xe0ec6e0e, 0x1698db3b, 0x4c98a0be, 0x3278e964, 0x9f1f9532, 0xe0d392df, 0xd3a0342b, 0x8971f21e, + 0x1b0a7441, 0x4ba3348c, 0xc5be7120, 0xc37632d8, 0xdf359f8d, 0x9b992f2e, 0xe60b6f47, 0x0fe3f11d, 0xe54cda54, + 0x1edad891, 0xce6279cf, 0xcd3e7e6f, 0x1618b166, 0xfd2c1d05, 0x848fd2c5, 0xf6fb2299, 0xf523f357, 0xa6327623, + 0x93a83531, 0x56cccd02, 0xacf08162, 0x5a75ebb5, 0x6e163697, 0x88d273cc, 0xde966292, 0x81b949d0, 0x4c50901b, + 0x71c65614, 0xe6c6c7bd, 0x327a140a, 0x45e1d006, 0xc3f27b9a, 0xc9aa53fd, 0x62a80f00, 0xbb25bfe2, 0x35bdd2f6, + 0x71126905, 0xb2040222, 0xb6cbcf7c, 0xcd769c2b, 0x53113ec0, 0x1640e3d3, 0x38abbd60, 0x2547adf0, 0xba38209c, + 0xf746ce76, 0x77afa1c5, 0x20756060, 0x85cbfe4e, 0x8ae88dd8, 0x7aaaf9b0, 0x4cf9aa7e, 0x1948c25c, 0x02fb8a8c, + 0x01c36ae4, 0xd6ebe1f9, 0x90d4f869, 0xa65cdea0, 0x3f09252d, 0xc208e69f, 0xb74e6132, 0xce77e25b, 0x578fdfe3, + 0x3ac372e6 }; // bcrypt IV: "OrpheanBeholderScryDoubt" - static private final int bf_crypt_ciphertext[] = { - 0x4f727068, 0x65616e42, 0x65686f6c, - 0x64657253, 0x63727944, 0x6f756274 - }; + static private final int bf_crypt_ciphertext[] = { 0x4f727068, 0x65616e42, 0x65686f6c, 0x64657253, 0x63727944, + 0x6f756274 }; // Table for Base64 encoding - static private final char base64_code[] = { - '.', '/', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', - 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', - 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', - 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', - '6', '7', '8', '9' - }; + static private final char base64_code[] = { '.', '/', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', + 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', + 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', + '2', '3', '4', '5', '6', '7', '8', '9' }; // Table for Base64 decoding - static private final byte index_64[] = { - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, 0, 1, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, -1, -1, - -1, -1, -1, -1, -1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, - -1, -1, -1, -1, -1, -1, 28, 29, 30, - 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, - 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, - 51, 52, 53, -1, -1, -1, -1, -1 - }; + static private final byte index_64[] = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 1, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, -1, -1, -1, -1, -1, -1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, -1, -1, -1, -1, -1, -1, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, -1, -1, -1, -1, -1 }; static final int MIN_LOG_ROUNDS = 4; static final int MAX_LOG_ROUNDS = 31; // Expanded Blowfish key private int P[]; + private int S[]; /** * Encode a byte array using bcrypt's slightly-modified base64 encoding scheme. Note * that this is not compatible with the standard MIME-base64 * encoding. - * * @param d the byte array to encode * @param len the number of bytes to encode * @param rs the destination buffer for the base64-encoded string * @exception IllegalArgumentException if the length is invalid */ - static void encode_base64(byte d[], int len, StringBuilder rs) - throws IllegalArgumentException { + static void encode_base64(byte d[], int len, StringBuilder rs) throws IllegalArgumentException { int off = 0; int c1, c2; @@ -418,54 +255,58 @@ public class BCrypt { } /** - * Look up the 3 bits base64-encoded by the specified character, - * range-checking againt conversion table - * @param x the base64-encoded value - * @return the decoded value of x + * Look up the 3 bits base64-encoded by the specified character, range-checking againt + * conversion table + * @param x the base64-encoded value + * @return the decoded value of x */ private static byte char64(char x) { - if ((int) x < 0 || (int) x >= index_64.length) + if (x < 0 || x >= index_64.length) { return -1; - return index_64[(int) x]; + } + return index_64[x]; } /** - * Decode a string encoded using bcrypt's base64 scheme to a - * byte array. Note that this is *not* compatible with - * the standard MIME-base64 encoding. - * @param s the string to decode - * @param maxolen the maximum number of bytes to decode - * @return an array containing the decoded bytes + * Decode a string encoded using bcrypt's base64 scheme to a byte array. Note that + * this is *not* compatible with the standard MIME-base64 encoding. + * @param s the string to decode + * @param maxolen the maximum number of bytes to decode + * @return an array containing the decoded bytes * @throws IllegalArgumentException if maxolen is invalid */ - static byte[] decode_base64(String s, int maxolen) - throws IllegalArgumentException { + static byte[] decode_base64(String s, int maxolen) throws IllegalArgumentException { StringBuilder rs = new StringBuilder(); int off = 0, slen = s.length(), olen = 0; byte ret[]; byte c1, c2, c3, c4, o; - if (maxolen <= 0) - throw new IllegalArgumentException ("Invalid maxolen"); + if (maxolen <= 0) { + throw new IllegalArgumentException("Invalid maxolen"); + } while (off < slen - 1 && olen < maxolen) { c1 = char64(s.charAt(off++)); c2 = char64(s.charAt(off++)); - if (c1 == -1 || c2 == -1) + if (c1 == -1 || c2 == -1) { break; + } o = (byte) (c1 << 2); o |= (c2 & 0x30) >> 4; rs.append((char) o); - if (++olen >= maxolen || off >= slen) + if (++olen >= maxolen || off >= slen) { break; + } c3 = char64(s.charAt(off++)); - if (c3 == -1) + if (c3 == -1) { break; + } o = (byte) ((c2 & 0x0f) << 4); o |= (c3 & 0x3c) >> 2; rs.append((char) o); - if (++olen >= maxolen || off >= slen) + if (++olen >= maxolen || off >= slen) { break; + } c4 = char64(s.charAt(off++)); o = (byte) ((c3 & 0x03) << 6); o |= c4; @@ -474,48 +315,47 @@ public class BCrypt { } ret = new byte[olen]; - for (off = 0; off < olen; off++) + for (off = 0; off < olen; off++) { ret[off] = (byte) rs.charAt(off); + } return ret; } /** - * Blowfish encipher a single 64-bit block encoded as - * two 32-bit halves - * @param lr an array containing the two 32-bit half blocks - * @param off the position in the array of the blocks + * Blowfish encipher a single 64-bit block encoded as two 32-bit halves + * @param lr an array containing the two 32-bit half blocks + * @param off the position in the array of the blocks */ private void encipher(int lr[], int off) { int i, n, l = lr[off], r = lr[off + 1]; - l ^= P[0]; + l ^= this.P[0]; for (i = 0; i <= BLOWFISH_NUM_ROUNDS - 2;) { // Feistel substitution on left word - n = S[(l >> 24) & 0xff]; - n += S[0x100 | ((l >> 16) & 0xff)]; - n ^= S[0x200 | ((l >> 8) & 0xff)]; - n += S[0x300 | (l & 0xff)]; - r ^= n ^ P[++i]; + n = this.S[(l >> 24) & 0xff]; + n += this.S[0x100 | ((l >> 16) & 0xff)]; + n ^= this.S[0x200 | ((l >> 8) & 0xff)]; + n += this.S[0x300 | (l & 0xff)]; + r ^= n ^ this.P[++i]; // Feistel substitution on right word - n = S[(r >> 24) & 0xff]; - n += S[0x100 | ((r >> 16) & 0xff)]; - n ^= S[0x200 | ((r >> 8) & 0xff)]; - n += S[0x300 | (r & 0xff)]; - l ^= n ^ P[++i]; + n = this.S[(r >> 24) & 0xff]; + n += this.S[0x100 | ((r >> 16) & 0xff)]; + n ^= this.S[0x200 | ((r >> 8) & 0xff)]; + n += this.S[0x300 | (r & 0xff)]; + l ^= n ^ this.P[++i]; } - lr[off] = r ^ P[BLOWFISH_NUM_ROUNDS + 1]; + lr[off] = r ^ this.P[BLOWFISH_NUM_ROUNDS + 1]; lr[off + 1] = l; } /** * Cycically extract a word of key material - * @param data the string to extract the data from - * @param offp a "pointer" (as a one-entry array) to the - * current offset into data - * @param signp a "pointer" (as a one-entry array) to the - * cumulative flag for non-benign sign extension - * @return correct and buggy next word of material from data as int[2] + * @param data the string to extract the data from + * @param offp a "pointer" (as a one-entry array) to the current offset into data + * @param signp a "pointer" (as a one-entry array) to the cumulative flag for + * non-benign sign extension + * @return correct and buggy next word of material from data as int[2] */ private static int[] streamtowords(byte data[], int offp[], int signp[]) { int i; @@ -525,8 +365,10 @@ public class BCrypt { for (i = 0; i < 4; i++) { words[0] = (words[0] << 8) | (data[off] & 0xff); - words[1] = (words[1] << 8) | (int) data[off]; // sign extension bug - if (i > 0) sign |= words[1] & 0x80; + words[1] = (words[1] << 8) | data[off]; // sign extension bug + if (i > 0) { + sign |= words[1] & 0x80; + } off = (off + 1) % data.length; } @@ -537,10 +379,9 @@ public class BCrypt { /** * Cycically extract a word of key material - * @param data the string to extract the data from - * @param offp a "pointer" (as a one-entry array) to the - * current offset into data - * @return the next word of material from data + * @param data the string to extract the data from + * @param offp a "pointer" (as a one-entry array) to the current offset into data + * @return the next word of material from data */ private static int streamtoword(byte data[], int offp[]) { int signp[] = { 0 }; @@ -549,10 +390,9 @@ public class BCrypt { /** * Cycically extract a word of key material, with sign-extension bug - * @param data the string to extract the data from - * @param offp a "pointer" (as a one-entry array) to the - * current offset into data - * @return the next word of material from data + * @param data the string to extract the data from + * @param offp a "pointer" (as a one-entry array) to the current offset into data + * @return the next word of material from data */ private static int streamtoword_bug(byte data[], int offp[]) { int signp[] = { 0 }; @@ -563,75 +403,76 @@ public class BCrypt { * Initialise the Blowfish key schedule */ private void init_key() { - P = P_orig.clone(); - S = S_orig.clone(); + this.P = P_orig.clone(); + this.S = S_orig.clone(); } /** * Key the Blowfish cipher - * @param key an array containing the key - * @param sign_ext_bug true to implement the 2x bug - * @param safety bit 16 is set when the safety measure is requested + * @param key an array containing the key + * @param sign_ext_bug true to implement the 2x bug + * @param safety bit 16 is set when the safety measure is requested */ private void key(byte key[], boolean sign_ext_bug, int safety) { int i; int koffp[] = { 0 }; int lr[] = { 0, 0 }; - int plen = P.length, slen = S.length; + int plen = this.P.length, slen = this.S.length; - for (i = 0; i < plen; i++) - if (!sign_ext_bug) - P[i] = P[i] ^ streamtoword(key, koffp); - else - P[i] = P[i] ^ streamtoword_bug(key, koffp); + for (i = 0; i < plen; i++) { + if (!sign_ext_bug) { + this.P[i] = this.P[i] ^ streamtoword(key, koffp); + } + else { + this.P[i] = this.P[i] ^ streamtoword_bug(key, koffp); + } + } for (i = 0; i < plen; i += 2) { encipher(lr, 0); - P[i] = lr[0]; - P[i + 1] = lr[1]; + this.P[i] = lr[0]; + this.P[i + 1] = lr[1]; } for (i = 0; i < slen; i += 2) { encipher(lr, 0); - S[i] = lr[0]; - S[i + 1] = lr[1]; + this.S[i] = lr[0]; + this.S[i + 1] = lr[1]; } } /** - * Perform the "enhanced key schedule" step described by - * Provos and Mazieres in "A Future-Adaptable Password Scheme" - * https://www.openbsd.org/papers/bcrypt-paper.ps - * @param data salt information - * @param key password information - * @param sign_ext_bug true to implement the 2x bug - * @param safety bit 16 is set when the safety measure is requested + * Perform the "enhanced key schedule" step described by Provos and Mazieres in "A + * Future-Adaptable Password Scheme" https://www.openbsd.org/papers/bcrypt-paper.ps + * @param data salt information + * @param key password information + * @param sign_ext_bug true to implement the 2x bug + * @param safety bit 16 is set when the safety measure is requested */ - private void ekskey(byte data[], byte key[], - boolean sign_ext_bug, int safety) { + private void ekskey(byte data[], byte key[], boolean sign_ext_bug, int safety) { int i; int koffp[] = { 0 }, doffp[] = { 0 }; int lr[] = { 0, 0 }; - int plen = P.length, slen = S.length; + int plen = this.P.length, slen = this.S.length; int signp[] = { 0 }; // non-benign sign-extension flag - int diff = 0; // zero iff correct and buggy are same + int diff = 0; // zero iff correct and buggy are same for (i = 0; i < plen; i++) { int words[] = streamtowords(key, koffp, signp); diff |= words[0] ^ words[1]; - P[i] = P[i] ^ words[sign_ext_bug ? 1 : 0]; + this.P[i] = this.P[i] ^ words[sign_ext_bug ? 1 : 0]; } int sign = signp[0]; /* * At this point, "diff" is zero iff the correct and buggy algorithms produced - * exactly the same result. If so and if "sign" is non-zero, which indicates - * that there was a non-benign sign extension, this means that we have a - * collision between the correctly computed hash for this password and a set of - * passwords that could be supplied to the buggy algorithm. Our safety measure - * is meant to protect from such many-buggy to one-correct collisions, by - * deviating from the correct algorithm in such cases. Let's check for this. + * exactly the same result. If so and if "sign" is non-zero, which indicates that + * there was a non-benign sign extension, this means that we have a collision + * between the correctly computed hash for this password and a set of passwords + * that could be supplied to the buggy algorithm. Our safety measure is meant to + * protect from such many-buggy to one-correct collisions, by deviating from the + * correct algorithm in such cases. Let's check for this. */ diff |= diff >> 16; /* still zero iff exact match */ diff &= 0xffff; /* ditto */ @@ -640,32 +481,32 @@ public class BCrypt { sign &= ~diff & safety; /* action needed? */ /* - * If we have determined that we need to deviate from the correct algorithm, - * flip bit 16 in initial expanded key. (The choice of 16 is arbitrary, but - * let's stick to it now. It came out of the approach we used above, and it's - * not any worse than any other choice we could make.) + * If we have determined that we need to deviate from the correct algorithm, flip + * bit 16 in initial expanded key. (The choice of 16 is arbitrary, but let's stick + * to it now. It came out of the approach we used above, and it's not any worse + * than any other choice we could make.) * * It is crucial that we don't do the same to the expanded key used in the main - * Eksblowfish loop. By doing it to only one of these two, we deviate from a - * state that could be directly specified by a password to the buggy algorithm - * (and to the fully correct one as well, but that's a side-effect). + * Eksblowfish loop. By doing it to only one of these two, we deviate from a state + * that could be directly specified by a password to the buggy algorithm (and to + * the fully correct one as well, but that's a side-effect). */ - P[0] ^= sign; + this.P[0] ^= sign; for (i = 0; i < plen; i += 2) { lr[0] ^= streamtoword(data, doffp); lr[1] ^= streamtoword(data, doffp); encipher(lr, 0); - P[i] = lr[0]; - P[i + 1] = lr[1]; + this.P[i] = lr[0]; + this.P[i + 1] = lr[1]; } for (i = 0; i < slen; i += 2) { lr[0] ^= streamtoword(data, doffp); lr[1] ^= streamtoword(data, doffp); encipher(lr, 0); - S[i] = lr[0]; - S[i + 1] = lr[1]; + this.S[i] = lr[0]; + this.S[i + 1] = lr[1]; } } @@ -677,28 +518,27 @@ public class BCrypt { } /** - * Perform the central password hashing step in the - * bcrypt scheme - * @param password the password to hash - * @param salt the binary salt to hash with the password - * @param log_rounds the binary logarithm of the number - * of rounds of hashing to apply - * @param sign_ext_bug true to implement the 2x bug - * @param safety bit 16 is set when the safety measure is requested - * @return an array containing the binary hashed password + * Perform the central password hashing step in the bcrypt scheme + * @param password the password to hash + * @param salt the binary salt to hash with the password + * @param log_rounds the binary logarithm of the number of rounds of hashing to apply + * @param sign_ext_bug true to implement the 2x bug + * @param safety bit 16 is set when the safety measure is requested + * @return an array containing the binary hashed password */ - private byte[] crypt_raw(byte password[], byte salt[], int log_rounds, - boolean sign_ext_bug, int safety) { + private byte[] crypt_raw(byte password[], byte salt[], int log_rounds, boolean sign_ext_bug, int safety) { int rounds, i, j; - int cdata[] = bf_crypt_ciphertext.clone(); + int cdata[] = bf_crypt_ciphertext.clone(); int clen = cdata.length; byte ret[]; - if (log_rounds < 4 || log_rounds > 31) - throw new IllegalArgumentException ("Bad number of rounds"); + if (log_rounds < 4 || log_rounds > 31) { + throw new IllegalArgumentException("Bad number of rounds"); + } rounds = 1 << log_rounds; - if (salt.length != BCRYPT_SALT_LEN) - throw new IllegalArgumentException ("Bad salt length"); + if (salt.length != BCRYPT_SALT_LEN) { + throw new IllegalArgumentException("Bad salt length"); + } init_key(); ekskey(salt, password, sign_ext_bug, safety); @@ -708,8 +548,9 @@ public class BCrypt { } for (i = 0; i < 64; i++) { - for (j = 0; j < (clen >> 1); j++) + for (j = 0; j < (clen >> 1); j++) { encipher(cdata, j << 1); + } } ret = new byte[clen * 4]; @@ -724,10 +565,9 @@ public class BCrypt { /** * Hash a password using the OpenBSD bcrypt scheme - * @param password the password to hash - * @param salt the salt to hash with (perhaps generated - * using BCrypt.gensalt) - * @return the hashed password + * @param password the password to hash + * @param salt the salt to hash with (perhaps generated using BCrypt.gensalt) + * @return the hashed password */ public static String hashpw(String password, String salt) { byte passwordb[]; @@ -739,10 +579,9 @@ public class BCrypt { /** * Hash a password using the OpenBSD bcrypt scheme - * @param passwordb the password to hash, as a byte array - * @param salt the salt to hash with (perhaps generated - * using BCrypt.gensalt) - * @return the hashed password + * @param passwordb the password to hash, as a byte array + * @param salt the salt to hash with (perhaps generated using BCrypt.gensalt) + * @return the hashed password */ public static String hashpw(byte passwordb[], String salt) { BCrypt B; @@ -762,21 +601,24 @@ public class BCrypt { throw new IllegalArgumentException("Invalid salt"); } - if (salt.charAt(0) != '$' || salt.charAt(1) != '2') - throw new IllegalArgumentException ("Invalid salt version"); - if (salt.charAt(2) == '$') + if (salt.charAt(0) != '$' || salt.charAt(1) != '2') { + throw new IllegalArgumentException("Invalid salt version"); + } + if (salt.charAt(2) == '$') { off = 3; + } else { minor = salt.charAt(2); - if ((minor != 'a' && minor != 'x' && minor != 'y' && minor != 'b') - || salt.charAt(3) != '$') - throw new IllegalArgumentException ("Invalid salt revision"); + if ((minor != 'a' && minor != 'x' && minor != 'y' && minor != 'b') || salt.charAt(3) != '$') { + throw new IllegalArgumentException("Invalid salt revision"); + } off = 4; } // Extract number of rounds - if (salt.charAt(off + 2) > '$') - throw new IllegalArgumentException ("Missing salt rounds"); + if (salt.charAt(off + 2) > '$') { + throw new IllegalArgumentException("Missing salt rounds"); + } if (off == 4 && saltLength < 29) { throw new IllegalArgumentException("Invalid salt"); @@ -786,18 +628,21 @@ public class BCrypt { real_salt = salt.substring(off + 3, off + 25); saltb = decode_base64(real_salt, BCRYPT_SALT_LEN); - if (minor >= 'a') // add null terminator + if (minor >= 'a') { passwordb = Arrays.copyOf(passwordb, passwordb.length + 1); + } B = new BCrypt(); hashed = B.crypt_raw(passwordb, saltb, rounds, minor == 'x', minor == 'a' ? 0x10000 : 0); rs.append("$2"); - if (minor >= 'a') + if (minor >= 'a') { rs.append(minor); + } rs.append("$"); - if (rounds < 10) + if (rounds < 10) { rs.append("0"); + } rs.append(rounds); rs.append("$"); encode_base64(saltb, saltb.length, rs); @@ -807,26 +652,23 @@ public class BCrypt { /** * Generate a salt for use with the BCrypt.hashpw() method - * @param prefix the prefix value (default $2a) - * @param log_rounds the log2 of the number of rounds of - * hashing to apply - the work factor therefore increases as - * 2**log_rounds. - * @param random an instance of SecureRandom to use - * @return an encoded salt value + * @param prefix the prefix value (default $2a) + * @param log_rounds the log2 of the number of rounds of hashing to apply - the work + * factor therefore increases as 2**log_rounds. + * @param random an instance of SecureRandom to use + * @return an encoded salt value * @exception IllegalArgumentException if prefix or log_rounds is invalid */ - public static String gensalt(String prefix, int log_rounds, SecureRandom random) - throws IllegalArgumentException { + public static String gensalt(String prefix, int log_rounds, SecureRandom random) throws IllegalArgumentException { StringBuilder rs = new StringBuilder(); byte rnd[] = new byte[BCRYPT_SALT_LEN]; - if (!prefix.startsWith("$2") || - (prefix.charAt(2) != 'a' && prefix.charAt(2) != 'y' && - prefix.charAt(2) != 'b')) { - throw new IllegalArgumentException ("Invalid prefix"); + if (!prefix.startsWith("$2") + || (prefix.charAt(2) != 'a' && prefix.charAt(2) != 'y' && prefix.charAt(2) != 'b')) { + throw new IllegalArgumentException("Invalid prefix"); } if (log_rounds < 4 || log_rounds > 31) { - throw new IllegalArgumentException ("Invalid log_rounds"); + throw new IllegalArgumentException("Invalid log_rounds"); } random.nextBytes(rnd); @@ -834,8 +676,9 @@ public class BCrypt { rs.append("$2"); rs.append(prefix.charAt(2)); rs.append("$"); - if (log_rounds < 10) + if (log_rounds < 10) { rs.append("0"); + } rs.append(log_rounds); rs.append("$"); encode_base64(rnd, rnd.length, rs); @@ -844,42 +687,36 @@ public class BCrypt { /** * Generate a salt for use with the BCrypt.hashpw() method - * @param prefix the prefix value (default $2a) - * @param log_rounds the log2 of the number of rounds of - * hashing to apply - the work factor therefore increases as - * 2**log_rounds. - * @return an encoded salt value + * @param prefix the prefix value (default $2a) + * @param log_rounds the log2 of the number of rounds of hashing to apply - the work + * factor therefore increases as 2**log_rounds. + * @return an encoded salt value * @exception IllegalArgumentException if prefix or log_rounds is invalid */ - public static String gensalt(String prefix, int log_rounds) - throws IllegalArgumentException { + public static String gensalt(String prefix, int log_rounds) throws IllegalArgumentException { return gensalt(prefix, log_rounds, new SecureRandom()); } /** * Generate a salt for use with the BCrypt.hashpw() method - * @param log_rounds the log2 of the number of rounds of - * hashing to apply - the work factor therefore increases as - * 2**log_rounds. - * @param random an instance of SecureRandom to use - * @return an encoded salt value + * @param log_rounds the log2 of the number of rounds of hashing to apply - the work + * factor therefore increases as 2**log_rounds. + * @param random an instance of SecureRandom to use + * @return an encoded salt value * @exception IllegalArgumentException if log_rounds is invalid */ - public static String gensalt(int log_rounds, SecureRandom random) - throws IllegalArgumentException { + public static String gensalt(int log_rounds, SecureRandom random) throws IllegalArgumentException { return gensalt("$2a", log_rounds, random); } /** * Generate a salt for use with the BCrypt.hashpw() method - * @param log_rounds the log2 of the number of rounds of - * hashing to apply - the work factor therefore increases as - * 2**log_rounds. - * @return an encoded salt value + * @param log_rounds the log2 of the number of rounds of hashing to apply - the work + * factor therefore increases as 2**log_rounds. + * @return an encoded salt value * @exception IllegalArgumentException if log_rounds is invalid */ - public static String gensalt(int log_rounds) - throws IllegalArgumentException { + public static String gensalt(int log_rounds) throws IllegalArgumentException { return gensalt(log_rounds, new SecureRandom()); } @@ -888,32 +725,29 @@ public class BCrypt { } /** - * Generate a salt for use with the BCrypt.hashpw() method, - * selecting a reasonable default for the number of hashing - * rounds to apply - * @return an encoded salt value + * Generate a salt for use with the BCrypt.hashpw() method, selecting a reasonable + * default for the number of hashing rounds to apply + * @return an encoded salt value */ public static String gensalt() { return gensalt(GENSALT_DEFAULT_LOG2_ROUNDS); } /** - * Check that a plaintext password matches a previously hashed - * one - * @param plaintext the plaintext password to verify - * @param hashed the previously-hashed password - * @return true if the passwords match, false otherwise + * Check that a plaintext password matches a previously hashed one + * @param plaintext the plaintext password to verify + * @param hashed the previously-hashed password + * @return true if the passwords match, false otherwise */ public static boolean checkpw(String plaintext, String hashed) { return equalsNoEarlyReturn(hashed, hashpw(plaintext, hashed)); } /** - * Check that a password (as a byte array) matches a previously hashed - * one - * @param passwordb the password to verify, as a byte array - * @param hashed the previously-hashed password - * @return true if the passwords match, false otherwise + * Check that a password (as a byte array) matches a previously hashed one + * @param passwordb the password to verify, as a byte array + * @param hashed the previously-hashed password + * @return true if the passwords match, false otherwise * @since 5.3 */ public static boolean checkpw(byte[] passwordb, String hashed) { @@ -923,4 +757,5 @@ public class BCrypt { static boolean equalsNoEarlyReturn(String a, String b) { return MessageDigest.isEqual(a.getBytes(StandardCharsets.UTF_8), b.getBytes(StandardCharsets.UTF_8)); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoder.java index dd787a9ea4..d17511b033 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoder.java +++ b/crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoder.java @@ -13,35 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.bcrypt; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.security.crypto.password.PasswordEncoder; +package org.springframework.security.crypto.bcrypt; import java.security.SecureRandom; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.security.crypto.password.PasswordEncoder; + /** * Implementation of PasswordEncoder that uses the BCrypt strong hashing function. Clients - * can optionally supply a "version" ($2a, $2b, $2y) and a "strength" (a.k.a. log rounds in BCrypt) - * and a SecureRandom instance. The larger the strength parameter the more work will have to be done - * (exponentially) to hash the passwords. The default value is 10. + * can optionally supply a "version" ($2a, $2b, $2y) and a "strength" (a.k.a. log rounds + * in BCrypt) and a SecureRandom instance. The larger the strength parameter the more work + * will have to be done (exponentially) to hash the passwords. The default value is 10. * * @author Dave Syer */ public class BCryptPasswordEncoder implements PasswordEncoder { - private Pattern BCRYPT_PATTERN = Pattern - .compile("\\A\\$2(a|y|b)?\\$(\\d\\d)\\$[./0-9A-Za-z]{53}"); + + private Pattern BCRYPT_PATTERN = Pattern.compile("\\A\\$2(a|y|b)?\\$(\\d\\d)\\$[./0-9A-Za-z]{53}"); + private final Log logger = LogFactory.getLog(getClass()); private final int strength; + private final BCryptVersion version; private final SecureRandom random; - public BCryptPasswordEncoder() { this(-1); } @@ -62,7 +65,7 @@ public class BCryptPasswordEncoder implements PasswordEncoder { /** * @param version the version of bcrypt, can be 2a,2b,2y - * @param random the secure random instance to use + * @param random the secure random instance to use */ public BCryptPasswordEncoder(BCryptVersion version, SecureRandom random) { this(version, -1, random); @@ -70,14 +73,14 @@ public class BCryptPasswordEncoder implements PasswordEncoder { /** * @param strength the log rounds to use, between 4 and 31 - * @param random the secure random instance to use + * @param random the secure random instance to use */ public BCryptPasswordEncoder(int strength, SecureRandom random) { this(BCryptVersion.$2A, strength, random); } /** - * @param version the version of bcrypt, can be 2a,2b,2y + * @param version the version of bcrypt, can be 2a,2b,2y * @param strength the log rounds to use, between 4 and 31 */ public BCryptPasswordEncoder(BCryptVersion version, int strength) { @@ -85,66 +88,63 @@ public class BCryptPasswordEncoder implements PasswordEncoder { } /** - * @param version the version of bcrypt, can be 2a,2b,2y + * @param version the version of bcrypt, can be 2a,2b,2y * @param strength the log rounds to use, between 4 and 31 - * @param random the secure random instance to use + * @param random the secure random instance to use */ public BCryptPasswordEncoder(BCryptVersion version, int strength, SecureRandom random) { if (strength != -1 && (strength < BCrypt.MIN_LOG_ROUNDS || strength > BCrypt.MAX_LOG_ROUNDS)) { throw new IllegalArgumentException("Bad strength"); } this.version = version; - this.strength = strength == -1 ? 10 : strength; + this.strength = (strength == -1) ? 10 : strength; this.random = random; } + @Override public String encode(CharSequence rawPassword) { if (rawPassword == null) { throw new IllegalArgumentException("rawPassword cannot be null"); } - - String salt; - if (random != null) { - salt = BCrypt.gensalt(version.getVersion(), strength, random); - } else { - salt = BCrypt.gensalt(version.getVersion(), strength); - } + String salt = getSalt(); return BCrypt.hashpw(rawPassword.toString(), salt); } + private String getSalt() { + if (this.random != null) { + return BCrypt.gensalt(this.version.getVersion(), this.strength, this.random); + } + return BCrypt.gensalt(this.version.getVersion(), this.strength); + } + + @Override public boolean matches(CharSequence rawPassword, String encodedPassword) { if (rawPassword == null) { throw new IllegalArgumentException("rawPassword cannot be null"); } - if (encodedPassword == null || encodedPassword.length() == 0) { - logger.warn("Empty encoded password"); + this.logger.warn("Empty encoded password"); return false; } - - if (!BCRYPT_PATTERN.matcher(encodedPassword).matches()) { - logger.warn("Encoded password does not look like BCrypt"); + if (!this.BCRYPT_PATTERN.matcher(encodedPassword).matches()) { + this.logger.warn("Encoded password does not look like BCrypt"); return false; } - return BCrypt.checkpw(rawPassword.toString(), encodedPassword); } @Override public boolean upgradeEncoding(String encodedPassword) { if (encodedPassword == null || encodedPassword.length() == 0) { - logger.warn("Empty encoded password"); + this.logger.warn("Empty encoded password"); return false; } - - Matcher matcher = BCRYPT_PATTERN.matcher(encodedPassword); + Matcher matcher = this.BCRYPT_PATTERN.matcher(encodedPassword); if (!matcher.matches()) { throw new IllegalArgumentException("Encoded password does not look like BCrypt: " + encodedPassword); } - else { - int strength = Integer.parseInt(matcher.group(2)); - return strength < this.strength; - } + int strength = Integer.parseInt(matcher.group(2)); + return strength < this.strength; } /** @@ -153,8 +153,11 @@ public class BCryptPasswordEncoder implements PasswordEncoder { * @author Lin Feng */ public enum BCryptVersion { + $2A("$2a"), + $2Y("$2y"), + $2B("$2b"); private final String version; @@ -166,5 +169,7 @@ public class BCryptPasswordEncoder implements PasswordEncoder { public String getVersion() { return this.version; } + } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/codec/Base64.java b/crypto/src/main/java/org/springframework/security/crypto/codec/Base64.java index 8269e5f313..998b4df10e 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/codec/Base64.java +++ b/crypto/src/main/java/org/springframework/security/crypto/codec/Base64.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.codec; /** * Base64 encoder which is a reduced version of Robert Harder's public domain - * implementation (version 2.3.7). See http://iharder.sourceforge.net/current/java/base64/ for more information. + * implementation (version 2.3.7). See http://iharder.sourceforge.net/current/java/base64/ + * for more information. *

      * For internal use only. * @@ -30,68 +32,66 @@ package org.springframework.security.crypto.codec; public final class Base64 { /** No options specified. Value is zero. */ - public final static int NO_OPTIONS = 0; + public static final int NO_OPTIONS = 0; /** Specify encoding in first bit. Value is one. */ - public final static int ENCODE = 1; + public static final int ENCODE = 1; /** Specify decoding in first bit. Value is zero. */ - public final static int DECODE = 0; + public static final int DECODE = 0; /** Do break lines when encoding. Value is 8. */ - public final static int DO_BREAK_LINES = 8; + public static final int DO_BREAK_LINES = 8; /** * Encode using Base64-like encoding that is URL- and Filename-safe as described in - * Section 4 of RFC3548: https://tools.ietf.org/html/rfc3548. - * It is important to note that data encoded this way is - * not officially valid Base64, or at the very least should not be called - * Base64 without also specifying that is was encoded using the URL- and Filename-safe - * dialect. + * Section 4 of RFC3548: https://tools.ietf.org/html/rfc3548. It + * is important to note that data encoded this way is not officially valid + * Base64, or at the very least should not be called Base64 without also specifying + * that is was encoded using the URL- and Filename-safe dialect. */ - public final static int URL_SAFE = 16; + public static final int URL_SAFE = 16; /** * Encode using the special "ordered" dialect of Base64. */ - public final static int ORDERED = 32; + public static final int ORDERED = 32; /** Maximum line length (76) of Base64 output. */ - private final static int MAX_LINE_LENGTH = 76; + private static final int MAX_LINE_LENGTH = 76; /** The equals sign (=) as a byte. */ - private final static byte EQUALS_SIGN = (byte) '='; + private static final byte EQUALS_SIGN = (byte) '='; /** The new line character (\n) as a byte. */ - private final static byte NEW_LINE = (byte) '\n'; + private static final byte NEW_LINE = (byte) '\n'; - private final static byte WHITE_SPACE_ENC = -5; // Indicates white space in encoding - private final static byte EQUALS_SIGN_ENC = -1; // Indicates equals sign in encoding + private static final byte WHITE_SPACE_ENC = -5; // Indicates white space in encoding + + private static final byte EQUALS_SIGN_ENC = -1; // Indicates equals sign in encoding /* ******** S T A N D A R D B A S E 6 4 A L P H A B E T ******** */ /** The 64 valid Base64 values. */ /* Host platform me be something funny like EBCDIC, so we hardcode these values. */ - private final static byte[] _STANDARD_ALPHABET = { (byte) 'A', (byte) 'B', - (byte) 'C', (byte) 'D', (byte) 'E', (byte) 'F', (byte) 'G', (byte) 'H', - (byte) 'I', (byte) 'J', (byte) 'K', (byte) 'L', (byte) 'M', (byte) 'N', - (byte) 'O', (byte) 'P', (byte) 'Q', (byte) 'R', (byte) 'S', (byte) 'T', - (byte) 'U', (byte) 'V', (byte) 'W', (byte) 'X', (byte) 'Y', (byte) 'Z', - (byte) 'a', (byte) 'b', (byte) 'c', (byte) 'd', (byte) 'e', (byte) 'f', - (byte) 'g', (byte) 'h', (byte) 'i', (byte) 'j', (byte) 'k', (byte) 'l', - (byte) 'm', (byte) 'n', (byte) 'o', (byte) 'p', (byte) 'q', (byte) 'r', - (byte) 's', (byte) 't', (byte) 'u', (byte) 'v', (byte) 'w', (byte) 'x', - (byte) 'y', (byte) 'z', (byte) '0', (byte) '1', (byte) '2', (byte) '3', - (byte) '4', (byte) '5', (byte) '6', (byte) '7', (byte) '8', (byte) '9', - (byte) '+', (byte) '/' }; + private static final byte[] _STANDARD_ALPHABET = { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D', (byte) 'E', + (byte) 'F', (byte) 'G', (byte) 'H', (byte) 'I', (byte) 'J', (byte) 'K', (byte) 'L', (byte) 'M', (byte) 'N', + (byte) 'O', (byte) 'P', (byte) 'Q', (byte) 'R', (byte) 'S', (byte) 'T', (byte) 'U', (byte) 'V', (byte) 'W', + (byte) 'X', (byte) 'Y', (byte) 'Z', (byte) 'a', (byte) 'b', (byte) 'c', (byte) 'd', (byte) 'e', (byte) 'f', + (byte) 'g', (byte) 'h', (byte) 'i', (byte) 'j', (byte) 'k', (byte) 'l', (byte) 'm', (byte) 'n', (byte) 'o', + (byte) 'p', (byte) 'q', (byte) 'r', (byte) 's', (byte) 't', (byte) 'u', (byte) 'v', (byte) 'w', (byte) 'x', + (byte) 'y', (byte) 'z', (byte) '0', (byte) '1', (byte) '2', (byte) '3', (byte) '4', (byte) '5', (byte) '6', + (byte) '7', (byte) '8', (byte) '9', (byte) '+', (byte) '/' }; /** * Translates a Base64 value to either its 6-bit reconstruction value or a negative * number indicating some other meaning. **/ - private final static byte[] _STANDARD_DECODABET = { -9, -9, -9, -9, -9, -9, -9, -9, - -9, // Decimal 0 - 8 + private static final byte[] _STANDARD_DECODABET = { -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal + // 0 + // - + // 8 -5, -5, // Whitespace: Tab and Linefeed -9, -9, // Decimal 11 - 12 -5, // Whitespace: Carriage Return @@ -111,8 +111,8 @@ public final class Base64 { -9, -9, -9, -9, -9, -9, // Decimal 91 - 96 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, // Letters 'a' through 'm' 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // Letters 'n' through 'z' - -9, -9, -9, -9, -9 // Decimal 123 - 127 - , -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 128 - 139 + -9, -9, -9, -9, -9, // Decimal 123 - 127 + -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 128 - 139 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 140 - 152 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 153 - 165 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 166 - 178 @@ -127,30 +127,28 @@ public final class Base64 { /* ******** U R L S A F E B A S E 6 4 A L P H A B E T ******** */ /** - * Used in the URL- and Filename-safe dialect described in Section 4 of RFC3548: https://tools.ietf.org/html/rfc3548. * Notice that the last two bytes become "hyphen" and "underscore" instead of "plus" * and "slash." */ - private final static byte[] _URL_SAFE_ALPHABET = { (byte) 'A', (byte) 'B', - (byte) 'C', (byte) 'D', (byte) 'E', (byte) 'F', (byte) 'G', (byte) 'H', - (byte) 'I', (byte) 'J', (byte) 'K', (byte) 'L', (byte) 'M', (byte) 'N', - (byte) 'O', (byte) 'P', (byte) 'Q', (byte) 'R', (byte) 'S', (byte) 'T', - (byte) 'U', (byte) 'V', (byte) 'W', (byte) 'X', (byte) 'Y', (byte) 'Z', - (byte) 'a', (byte) 'b', (byte) 'c', (byte) 'd', (byte) 'e', (byte) 'f', - (byte) 'g', (byte) 'h', (byte) 'i', (byte) 'j', (byte) 'k', (byte) 'l', - (byte) 'm', (byte) 'n', (byte) 'o', (byte) 'p', (byte) 'q', (byte) 'r', - (byte) 's', (byte) 't', (byte) 'u', (byte) 'v', (byte) 'w', (byte) 'x', - (byte) 'y', (byte) 'z', (byte) '0', (byte) '1', (byte) '2', (byte) '3', - (byte) '4', (byte) '5', (byte) '6', (byte) '7', (byte) '8', (byte) '9', - (byte) '-', (byte) '_' }; + private static final byte[] _URL_SAFE_ALPHABET = { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D', (byte) 'E', + (byte) 'F', (byte) 'G', (byte) 'H', (byte) 'I', (byte) 'J', (byte) 'K', (byte) 'L', (byte) 'M', (byte) 'N', + (byte) 'O', (byte) 'P', (byte) 'Q', (byte) 'R', (byte) 'S', (byte) 'T', (byte) 'U', (byte) 'V', (byte) 'W', + (byte) 'X', (byte) 'Y', (byte) 'Z', (byte) 'a', (byte) 'b', (byte) 'c', (byte) 'd', (byte) 'e', (byte) 'f', + (byte) 'g', (byte) 'h', (byte) 'i', (byte) 'j', (byte) 'k', (byte) 'l', (byte) 'm', (byte) 'n', (byte) 'o', + (byte) 'p', (byte) 'q', (byte) 'r', (byte) 's', (byte) 't', (byte) 'u', (byte) 'v', (byte) 'w', (byte) 'x', + (byte) 'y', (byte) 'z', (byte) '0', (byte) '1', (byte) '2', (byte) '3', (byte) '4', (byte) '5', (byte) '6', + (byte) '7', (byte) '8', (byte) '9', (byte) '-', (byte) '_' }; /** * Used in decoding URL- and Filename-safe dialects of Base64. */ - private final static byte[] _URL_SAFE_DECODABET = { -9, -9, -9, -9, -9, -9, -9, -9, - -9, // Decimal 0 - 8 + private static final byte[] _URL_SAFE_DECODABET = { -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal + // 0 + // - + // 8 -5, -5, // Whitespace: Tab and Linefeed -9, -9, // Decimal 11 - 12 -5, // Whitespace: Carriage Return @@ -174,8 +172,8 @@ public final class Base64 { -9, // Decimal 96 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, // Letters 'a' through 'm' 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // Letters 'n' through 'z' - -9, -9, -9, -9, -9 // Decimal 123 - 127 - , -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 128 - 139 + -9, -9, -9, -9, -9, // Decimal 123 - 127 + -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 128 - 139 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 140 - 152 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 153 - 165 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 166 - 178 @@ -189,24 +187,20 @@ public final class Base64 { /* ******** O R D E R E D B A S E 6 4 A L P H A B E T ******** */ - private final static byte[] _ORDERED_ALPHABET = { (byte) '-', (byte) '0', (byte) '1', - (byte) '2', (byte) '3', (byte) '4', (byte) '5', (byte) '6', (byte) '7', - (byte) '8', (byte) '9', (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D', - (byte) 'E', (byte) 'F', (byte) 'G', (byte) 'H', (byte) 'I', (byte) 'J', - (byte) 'K', (byte) 'L', (byte) 'M', (byte) 'N', (byte) 'O', (byte) 'P', - (byte) 'Q', (byte) 'R', (byte) 'S', (byte) 'T', (byte) 'U', (byte) 'V', - (byte) 'W', (byte) 'X', (byte) 'Y', (byte) 'Z', (byte) '_', (byte) 'a', - (byte) 'b', (byte) 'c', (byte) 'd', (byte) 'e', (byte) 'f', (byte) 'g', - (byte) 'h', (byte) 'i', (byte) 'j', (byte) 'k', (byte) 'l', (byte) 'm', - (byte) 'n', (byte) 'o', (byte) 'p', (byte) 'q', (byte) 'r', (byte) 's', - (byte) 't', (byte) 'u', (byte) 'v', (byte) 'w', (byte) 'x', (byte) 'y', - (byte) 'z' }; + private static final byte[] _ORDERED_ALPHABET = { (byte) '-', (byte) '0', (byte) '1', (byte) '2', (byte) '3', + (byte) '4', (byte) '5', (byte) '6', (byte) '7', (byte) '8', (byte) '9', (byte) 'A', (byte) 'B', (byte) 'C', + (byte) 'D', (byte) 'E', (byte) 'F', (byte) 'G', (byte) 'H', (byte) 'I', (byte) 'J', (byte) 'K', (byte) 'L', + (byte) 'M', (byte) 'N', (byte) 'O', (byte) 'P', (byte) 'Q', (byte) 'R', (byte) 'S', (byte) 'T', (byte) 'U', + (byte) 'V', (byte) 'W', (byte) 'X', (byte) 'Y', (byte) 'Z', (byte) '_', (byte) 'a', (byte) 'b', (byte) 'c', + (byte) 'd', (byte) 'e', (byte) 'f', (byte) 'g', (byte) 'h', (byte) 'i', (byte) 'j', (byte) 'k', (byte) 'l', + (byte) 'm', (byte) 'n', (byte) 'o', (byte) 'p', (byte) 'q', (byte) 'r', (byte) 's', (byte) 't', (byte) 'u', + (byte) 'v', (byte) 'w', (byte) 'x', (byte) 'y', (byte) 'z' }; /** * Used in decoding the "ordered" dialect of Base64. */ - private final static byte[] _ORDERED_DECODABET = { -9, -9, -9, -9, -9, -9, -9, -9, - -9, // Decimal 0 - 8 + private static final byte[] _ORDERED_DECODABET = { -9, -9, -9, -9, -9, -9, -9, -9, -9, + // Decimal 0 - 8 -5, -5, // Whitespace: Tab and Linefeed -9, -9, // Decimal 11 - 12 -5, // Whitespace: Carriage Return @@ -230,8 +224,8 @@ public final class Base64 { -9, // Decimal 96 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, // Letters 'a' through 'm' 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // Letters 'n' through 'z' - -9, -9, -9, -9, -9 // Decimal 123 - 127 - , -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 128 - 139 + -9, -9, -9, -9, -9, // Decimal 123 - 127 + -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 128 - 139 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 140 - 152 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 153 - 165 -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, // Decimal 166 - 178 @@ -243,6 +237,9 @@ public final class Base64 { -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9 // Decimal 244 - 255 }; + private Base64() { + } + public static byte[] decode(byte[] bytes) { return decode(bytes, 0, bytes.length, NO_OPTIONS); } @@ -255,7 +252,7 @@ public final class Base64 { try { decode(bytes); } - catch (InvalidBase64CharacterException e) { + catch (InvalidBase64CharacterException ex) { return false; } return true; @@ -312,7 +309,6 @@ public final class Base64 { *

      * This is the lowest level of the encoding methods with all possible parameters. *

      - * * @param source the array to convert * @param srcOffset the index where conversion begins * @param numSigBytes the number of significant bytes in your array @@ -321,8 +317,8 @@ public final class Base64 { * @return the destination array * @since 1.3 */ - private static byte[] encode3to4(byte[] source, int srcOffset, int numSigBytes, - byte[] destination, int destOffset, int options) { + private static byte[] encode3to4(byte[] source, int srcOffset, int numSigBytes, byte[] destination, int destOffset, + int options) { byte[] ALPHABET = getAlphabet(options); @@ -337,9 +333,9 @@ public final class Base64 { // significant bytes passed in the array. // We have to shift left 24 in order to flush out the 1's that appear // when Java treats a value as negative that is cast from a byte to an int. - int inBuff = (numSigBytes > 0 ? ((source[srcOffset] << 24) >>> 8) : 0) - | (numSigBytes > 1 ? ((source[srcOffset + 1] << 24) >>> 16) : 0) - | (numSigBytes > 2 ? ((source[srcOffset + 2] << 24) >>> 24) : 0); + int inBuff = ((numSigBytes > 0) ? ((source[srcOffset] << 24) >>> 8) : 0) + | ((numSigBytes > 1) ? ((source[srcOffset + 1] << 24) >>> 16) : 0) + | ((numSigBytes > 2) ? ((source[srcOffset + 2] << 24) >>> 24) : 0); switch (numSigBytes) { case 3: @@ -369,17 +365,16 @@ public final class Base64 { } /** - * * @param source The data to convert * @param off Offset in array where conversion should begin * @param len Length of data to convert * @param options Specified options * @return The Base64-encoded data as a String - * @see Base64#DO_BREAK_LINES * @throws java.io.IOException if there is an error * @throws NullPointerException if source array is null * @throws IllegalArgumentException if source array, offset, or length are invalid * @since 2.3.1 + * @see Base64#DO_BREAK_LINES */ private static byte[] encodeBytesToBytes(byte[] source, int off, int len, int options) { @@ -397,8 +392,7 @@ public final class Base64 { if (off + len > source.length) { throw new IllegalArgumentException(String.format( - "Cannot have offset of %d and length of %d with array of length %d", - off, len, source.length)); + "Cannot have offset of %d and length of %d with array of length %d", off, len, source.length)); } // end if: off < 0 boolean breakLines = (options & DO_BREAK_LINES) > 0; @@ -410,8 +404,10 @@ public final class Base64 { // Try to determine more precisely how big the array needs to be. // If we get it right, we don't have to do an array copy, and // we save a bunch of memory. - int encLen = (len / 3) * 4 + (len % 3 > 0 ? 4 : 0); // Bytes needed for actual - // encoding + + // Bytes needed for actual encoding + int encLen = (len / 3) * 4 + ((len % 3 > 0) ? 4 : 0); + if (breakLines) { encLen += encLen / MAX_LINE_LENGTH; // Plus extra newline characters } @@ -464,8 +460,6 @@ public final class Base64 { *

      * This is the lowest level of the decoding methods with all possible parameters. *

      - * - * * @param source the array to convert * @param srcOffset the index where conversion begins * @param destination the array to hold the conversion @@ -477,8 +471,8 @@ public final class Base64 { * not enough room in the array. * @since 1.3 */ - private static int decode4to3(final byte[] source, final int srcOffset, - final byte[] destination, final int destOffset, final int options) { + private static int decode4to3(final byte[] source, final int srcOffset, final byte[] destination, + final int destOffset, final int options) { // Lots of error checking and exception throwing if (source == null) { @@ -489,15 +483,13 @@ public final class Base64 { } // end if if (srcOffset < 0 || srcOffset + 3 >= source.length) { throw new IllegalArgumentException( - String.format( - "Source array with length %d cannot have offset of %d and still process four bytes.", + String.format("Source array with length %d cannot have offset of %d and still process four bytes.", source.length, srcOffset)); } // end if if (destOffset < 0 || destOffset + 2 >= destination.length) { - throw new IllegalArgumentException( - String.format( - "Destination array with length %d cannot have offset of %d and still store three bytes.", - destination.length, destOffset)); + throw new IllegalArgumentException(String.format( + "Destination array with length %d cannot have offset of %d and still store three bytes.", + destination.length, destOffset)); } // end if byte[] DECODABET = getDecodabet(options); @@ -538,8 +530,7 @@ public final class Base64 { // | ( ( DECODABET[ source[ srcOffset + 3 ] ] << 24 ) >>> 24 ); int outBuff = ((DECODABET[source[srcOffset]] & 0xFF) << 18) | ((DECODABET[source[srcOffset + 1]] & 0xFF) << 12) - | ((DECODABET[source[srcOffset + 2]] & 0xFF) << 6) - | ((DECODABET[source[srcOffset + 3]] & 0xFF)); + | ((DECODABET[source[srcOffset + 2]] & 0xFF) << 6) | ((DECODABET[source[srcOffset + 3]] & 0xFF)); destination[destOffset] = (byte) (outBuff >> 16); destination[destOffset + 1] = (byte) (outBuff >> 8); @@ -555,7 +546,6 @@ public final class Base64 { * recommended method, although it is used internally as part of the decoding process. * Special case: if len = 0, an empty array is returned. Still, if you need more speed * and reduced memory footprint (and aren't gzipping), consider this method. - * * @param source The Base64 encoded data * @param off The offset of where to begin decoding * @param len The length of characters to decode @@ -563,8 +553,7 @@ public final class Base64 { * @return decoded data * @throws IllegalArgumentException If bogus characters exist in source data */ - private static byte[] decode(final byte[] source, final int off, final int len, - final int options) { + private static byte[] decode(final byte[] source, final int off, final int len, final int options) { // Lots of error checking and exception throwing if (source == null) { @@ -572,8 +561,7 @@ public final class Base64 { } // end if if (off < 0 || off + len > source.length) { throw new IllegalArgumentException( - String.format( - "Source array with length %d cannot have offset of %d and process %d bytes.", + String.format("Source array with length %d cannot have offset of %d and process %d bytes.", source.length, off, len)); } // end if @@ -582,8 +570,7 @@ public final class Base64 { } else if (len < 4) { throw new IllegalArgumentException( - "Base64-encoded string must have at least four characters, but length specified was " - + len); + "Base64-encoded string must have at least four characters, but length specified was " + len); } // end if byte[] DECODABET = getDecodabet(options); @@ -620,9 +607,8 @@ public final class Base64 { } else { // There's a bad input character in the Base64 stream. - throw new InvalidBase64CharacterException(String.format( - "Bad Base64 input character decimal %d in array position %d", - ((int) source[i]) & 0xFF, i)); + throw new InvalidBase64CharacterException(String + .format("Bad Base64 input character decimal %d in array position %d", (source[i]) & 0xFF, i)); } } @@ -630,11 +616,13 @@ public final class Base64 { System.arraycopy(outBuff, 0, out, 0, outBuffPosn); return out; } -} -class InvalidBase64CharacterException extends IllegalArgumentException { + static class InvalidBase64CharacterException extends IllegalArgumentException { + + InvalidBase64CharacterException(String message) { + super(message); + } - InvalidBase64CharacterException(String message) { - super(message); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/codec/Hex.java b/crypto/src/main/java/org/springframework/security/crypto/codec/Hex.java index 7555abd5b7..2691e75550 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/codec/Hex.java +++ b/crypto/src/main/java/org/springframework/security/crypto/codec/Hex.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.codec; /** @@ -26,13 +27,14 @@ package org.springframework.security.crypto.codec; */ public final class Hex { - private static final char[] HEX = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', - 'a', 'b', 'c', 'd', 'e', 'f' }; + private static final char[] HEX = "0123456789abcdef".toCharArray(); + + private Hex() { + } public static char[] encode(byte[] bytes) { final int nBytes = bytes.length; char[] result = new char[2 * nBytes]; - int j = 0; for (byte aByte : bytes) { // Char for top 4 bits @@ -40,27 +42,21 @@ public final class Hex { // Bottom 4 result[j++] = HEX[(0x0F & aByte)]; } - return result; } public static byte[] decode(CharSequence s) { int nChars = s.length(); - if (nChars % 2 != 0) { - throw new IllegalArgumentException( - "Hex-encoded string must have an even number of characters"); + throw new IllegalArgumentException("Hex-encoded string must have an even number of characters"); } - byte[] result = new byte[nChars / 2]; - for (int i = 0; i < nChars; i += 2) { int msb = Character.digit(s.charAt(i), 16); int lsb = Character.digit(s.charAt(i + 1), 16); - if (msb < 0 || lsb < 0) { throw new IllegalArgumentException( - "Detected a Non-hex character at " + (i + 1) + " or " + (i + 2) + " position"); + "Detected a Non-hex character at " + (i + 1) + " or " + (i + 2) + " position"); } result[i / 2] = (byte) ((msb << 4) | lsb); } diff --git a/crypto/src/main/java/org/springframework/security/crypto/codec/Utf8.java b/crypto/src/main/java/org/springframework/security/crypto/codec/Utf8.java index a2274fa75d..093de89fdd 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/codec/Utf8.java +++ b/crypto/src/main/java/org/springframework/security/crypto/codec/Utf8.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.codec; import java.nio.ByteBuffer; @@ -29,8 +30,12 @@ import java.nio.charset.StandardCharsets; * @author Luke Taylor */ public final class Utf8 { + private static final Charset CHARSET = StandardCharsets.UTF_8; + private Utf8() { + } + /** * Get the bytes of the String in UTF-8 encoded form. */ @@ -39,11 +44,10 @@ public final class Utf8 { ByteBuffer bytes = CHARSET.newEncoder().encode(CharBuffer.wrap(string)); byte[] bytesCopy = new byte[bytes.limit()]; System.arraycopy(bytes.array(), 0, bytesCopy, 0, bytes.limit()); - return bytesCopy; } - catch (CharacterCodingException e) { - throw new IllegalArgumentException("Encoding failed", e); + catch (CharacterCodingException ex) { + throw new IllegalArgumentException("Encoding failed", ex); } } @@ -54,8 +58,9 @@ public final class Utf8 { try { return CHARSET.newDecoder().decode(ByteBuffer.wrap(bytes)).toString(); } - catch (CharacterCodingException e) { - throw new IllegalArgumentException("Decoding failed", e); + catch (CharacterCodingException ex) { + throw new IllegalArgumentException("Decoding failed", ex); } } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/codec/package-info.java b/crypto/src/main/java/org/springframework/security/crypto/codec/package-info.java index 0875553e36..e560395216 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/codec/package-info.java +++ b/crypto/src/main/java/org/springframework/security/crypto/codec/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Internal codec classes. Only intended for use within the framework. */ package org.springframework.security.crypto.codec; - diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/AesBytesEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/AesBytesEncryptor.java index 95c7b4067f..d8f2fc12b6 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/AesBytesEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/AesBytesEncryptor.java @@ -13,14 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.encrypt; -import static org.springframework.security.crypto.encrypt.CipherUtils.doFinal; -import static org.springframework.security.crypto.encrypt.CipherUtils.initCipher; -import static org.springframework.security.crypto.encrypt.CipherUtils.newCipher; -import static org.springframework.security.crypto.encrypt.CipherUtils.newSecretKey; -import static org.springframework.security.crypto.util.EncodingUtils.concatenate; -import static org.springframework.security.crypto.util.EncodingUtils.subArray; +package org.springframework.security.crypto.encrypt; import java.security.spec.AlgorithmParameterSpec; @@ -34,6 +28,7 @@ import javax.crypto.spec.SecretKeySpec; import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.KeyGenerators; +import org.springframework.security.crypto.util.EncodingUtils; /** * Encryptor that uses AES encryption. @@ -57,12 +52,89 @@ public final class AesBytesEncryptor implements BytesEncryptor { private static final String AES_GCM_ALGORITHM = "AES/GCM/NoPadding"; + public AesBytesEncryptor(String password, CharSequence salt) { + this(password, salt, null); + } + + public AesBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator) { + this(password, salt, ivGenerator, CipherAlgorithm.CBC); + } + + public AesBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator, CipherAlgorithm alg) { + this(CipherUtils.newSecretKey("PBKDF2WithHmacSHA1", + new PBEKeySpec(password.toCharArray(), Hex.decode(salt), 1024, 256)), ivGenerator, alg); + } + + /** + * Constructs an encryptor that uses AES encryption. + * @param secretKey the secret (symmetric) key + * @param ivGenerator the generator used to generate the initialization vector. If + * null, then a default algorithm will be used based on the provided + * {@link CipherAlgorithm} + * @param alg the {@link CipherAlgorithm} to be used + */ + public AesBytesEncryptor(SecretKey secretKey, BytesKeyGenerator ivGenerator, CipherAlgorithm alg) { + this.secretKey = new SecretKeySpec(secretKey.getEncoded(), "AES"); + this.alg = alg; + this.encryptor = alg.createCipher(); + this.decryptor = alg.createCipher(); + this.ivGenerator = (ivGenerator != null) ? ivGenerator : alg.defaultIvGenerator(); + } + + @Override + public byte[] encrypt(byte[] bytes) { + synchronized (this.encryptor) { + byte[] iv = this.ivGenerator.generateKey(); + CipherUtils.initCipher(this.encryptor, Cipher.ENCRYPT_MODE, this.secretKey, this.alg.getParameterSpec(iv)); + byte[] encrypted = CipherUtils.doFinal(this.encryptor, bytes); + return (this.ivGenerator != NULL_IV_GENERATOR) ? EncodingUtils.concatenate(iv, encrypted) : encrypted; + } + } + + @Override + public byte[] decrypt(byte[] encryptedBytes) { + synchronized (this.decryptor) { + byte[] iv = iv(encryptedBytes); + CipherUtils.initCipher(this.decryptor, Cipher.DECRYPT_MODE, this.secretKey, this.alg.getParameterSpec(iv)); + return CipherUtils.doFinal(this.decryptor, + (this.ivGenerator != NULL_IV_GENERATOR) ? encrypted(encryptedBytes, iv.length) : encryptedBytes); + } + } + + private byte[] iv(byte[] encrypted) { + return (this.ivGenerator != NULL_IV_GENERATOR) + ? EncodingUtils.subArray(encrypted, 0, this.ivGenerator.getKeyLength()) + : NULL_IV_GENERATOR.generateKey(); + } + + private byte[] encrypted(byte[] encryptedBytes, int ivLength) { + return EncodingUtils.subArray(encryptedBytes, ivLength, encryptedBytes.length); + } + + private static final BytesKeyGenerator NULL_IV_GENERATOR = new BytesKeyGenerator() { + + private final byte[] VALUE = new byte[16]; + + @Override + public int getKeyLength() { + return this.VALUE.length; + } + + @Override + public byte[] generateKey() { + return this.VALUE; + } + + }; + public enum CipherAlgorithm { - CBC(AES_CBC_ALGORITHM, NULL_IV_GENERATOR), GCM(AES_GCM_ALGORITHM, KeyGenerators - .secureRandom(16)); + CBC(AES_CBC_ALGORITHM, NULL_IV_GENERATOR), + + GCM(AES_GCM_ALGORITHM, KeyGenerators.secureRandom(16)); private BytesKeyGenerator ivGenerator; + private String name; CipherAlgorithm(String name, BytesKeyGenerator ivGenerator) { @@ -76,94 +148,17 @@ public final class AesBytesEncryptor implements BytesEncryptor { } public AlgorithmParameterSpec getParameterSpec(byte[] iv) { - return this == CBC ? new IvParameterSpec(iv) : new GCMParameterSpec(128, iv); + return (this != CBC) ? new GCMParameterSpec(128, iv) : new IvParameterSpec(iv); } public Cipher createCipher() { - return newCipher(this.toString()); + return CipherUtils.newCipher(this.toString()); } public BytesKeyGenerator defaultIvGenerator() { return this.ivGenerator; } + } - public AesBytesEncryptor(String password, CharSequence salt) { - this(password, salt, null); - } - - public AesBytesEncryptor(String password, CharSequence salt, - BytesKeyGenerator ivGenerator) { - this(password, salt, ivGenerator, CipherAlgorithm.CBC); - } - - public AesBytesEncryptor(String password, CharSequence salt, - BytesKeyGenerator ivGenerator, CipherAlgorithm alg) { - this(newSecretKey("PBKDF2WithHmacSHA1", new PBEKeySpec(password.toCharArray(), Hex.decode(salt), - 1024, 256)), ivGenerator, alg); - } - - /** - * Constructs an encryptor that uses AES encryption. - * - * @param secretKey the secret (symmetric) key - * @param ivGenerator the generator used to generate the initialization vector. If null, - * then a default algorithm will be used based on the provided {@link CipherAlgorithm} - * @param alg the {@link CipherAlgorithm} to be used - */ - public AesBytesEncryptor(SecretKey secretKey, BytesKeyGenerator ivGenerator, CipherAlgorithm alg) { - this.secretKey = new SecretKeySpec(secretKey.getEncoded(), "AES"); - this.alg = alg; - this.encryptor = alg.createCipher(); - this.decryptor = alg.createCipher(); - this.ivGenerator = ivGenerator != null ? ivGenerator : alg.defaultIvGenerator(); - } - - public byte[] encrypt(byte[] bytes) { - synchronized (this.encryptor) { - byte[] iv = this.ivGenerator.generateKey(); - initCipher(this.encryptor, Cipher.ENCRYPT_MODE, this.secretKey, - this.alg.getParameterSpec(iv)); - byte[] encrypted = doFinal(this.encryptor, bytes); - return this.ivGenerator != NULL_IV_GENERATOR ? concatenate(iv, encrypted) - : encrypted; - } - } - - public byte[] decrypt(byte[] encryptedBytes) { - synchronized (this.decryptor) { - byte[] iv = iv(encryptedBytes); - initCipher(this.decryptor, Cipher.DECRYPT_MODE, this.secretKey, - this.alg.getParameterSpec(iv)); - return doFinal( - this.decryptor, - this.ivGenerator != NULL_IV_GENERATOR ? encrypted(encryptedBytes, - iv.length) : encryptedBytes); - } - } - - // internal helpers - - private byte[] iv(byte[] encrypted) { - return this.ivGenerator != NULL_IV_GENERATOR ? subArray(encrypted, 0, - this.ivGenerator.getKeyLength()) : NULL_IV_GENERATOR.generateKey(); - } - - private byte[] encrypted(byte[] encryptedBytes, int ivLength) { - return subArray(encryptedBytes, ivLength, encryptedBytes.length); - } - - private static final BytesKeyGenerator NULL_IV_GENERATOR = new BytesKeyGenerator() { - - private final byte[] VALUE = new byte[16]; - - public int getKeyLength() { - return this.VALUE.length; - } - - public byte[] generateKey() { - return this.VALUE; - } - - }; } diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptor.java index 6797854479..7322ae5ff5 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptor.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import org.bouncycastle.crypto.PBEParametersGenerator; import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator; import org.bouncycastle.crypto.params.KeyParameter; + import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.KeyGenerators; @@ -31,22 +33,22 @@ import org.springframework.security.crypto.keygen.KeyGenerators; abstract class BouncyCastleAesBytesEncryptor implements BytesEncryptor { final KeyParameter secretKey; + final BytesKeyGenerator ivGenerator; BouncyCastleAesBytesEncryptor(String password, CharSequence salt) { this(password, salt, KeyGenerators.secureRandom(16)); } - BouncyCastleAesBytesEncryptor(String password, CharSequence salt, - BytesKeyGenerator ivGenerator) { + BouncyCastleAesBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator) { if (ivGenerator.getKeyLength() != 16) { throw new IllegalArgumentException("ivGenerator key length != block size 16"); } this.ivGenerator = ivGenerator; PBEParametersGenerator keyGenerator = new PKCS5S2ParametersGenerator(); - byte[] pkcs12PasswordBytes = PBEParametersGenerator - .PKCS5PasswordToUTF8Bytes(password.toCharArray()); + byte[] pkcs12PasswordBytes = PBEParametersGenerator.PKCS5PasswordToUTF8Bytes(password.toCharArray()); keyGenerator.init(pkcs12PasswordBytes, Hex.decode(salt), 1024); this.secretKey = (KeyParameter) keyGenerator.generateDerivedParameters(256); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesCbcBytesEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesCbcBytesEncryptor.java index d0e4c874c9..3aa6e6eab3 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesCbcBytesEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesCbcBytesEncryptor.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.encrypt; -import static org.springframework.security.crypto.util.EncodingUtils.concatenate; -import static org.springframework.security.crypto.util.EncodingUtils.subArray; +package org.springframework.security.crypto.encrypt; import org.bouncycastle.crypto.BufferedBlockCipher; import org.bouncycastle.crypto.InvalidCipherTextException; @@ -24,16 +22,17 @@ import org.bouncycastle.crypto.modes.CBCBlockCipher; import org.bouncycastle.crypto.paddings.PKCS7Padding; import org.bouncycastle.crypto.paddings.PaddedBufferedBlockCipher; import org.bouncycastle.crypto.params.ParametersWithIV; + import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; import org.springframework.security.crypto.keygen.BytesKeyGenerator; +import org.springframework.security.crypto.util.EncodingUtils; /** - * An Encryptor equivalent to {@link AesBytesEncryptor} using - * {@link CipherAlgorithm#CBC} that uses Bouncy Castle instead of JCE. The - * algorithm is equivalent to "AES/CBC/PKCS5Padding". + * An Encryptor equivalent to {@link AesBytesEncryptor} using {@link CipherAlgorithm#CBC} + * that uses Bouncy Castle instead of JCE. The algorithm is equivalent to + * "AES/CBC/PKCS5Padding". * * @author William Tran - * */ public class BouncyCastleAesCbcBytesEncryptor extends BouncyCastleAesBytesEncryptor { @@ -41,33 +40,29 @@ public class BouncyCastleAesCbcBytesEncryptor extends BouncyCastleAesBytesEncryp super(password, salt); } - public BouncyCastleAesCbcBytesEncryptor(String password, CharSequence salt, - BytesKeyGenerator ivGenerator) { + public BouncyCastleAesCbcBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator) { super(password, salt, ivGenerator); } @Override + @SuppressWarnings("deprecation") public byte[] encrypt(byte[] bytes) { byte[] iv = this.ivGenerator.generateKey(); - - @SuppressWarnings("deprecation") PaddedBufferedBlockCipher blockCipher = new PaddedBufferedBlockCipher( new CBCBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()), new PKCS7Padding()); - blockCipher.init(true, new ParametersWithIV(secretKey, iv)); + blockCipher.init(true, new ParametersWithIV(this.secretKey, iv)); byte[] encrypted = process(blockCipher, bytes); - return iv != null ? concatenate(iv, encrypted) : encrypted; + return (iv != null) ? EncodingUtils.concatenate(iv, encrypted) : encrypted; } @Override + @SuppressWarnings("deprecation") public byte[] decrypt(byte[] encryptedBytes) { - byte[] iv = subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength()); - encryptedBytes = subArray(encryptedBytes, this.ivGenerator.getKeyLength(), - encryptedBytes.length); - - @SuppressWarnings("deprecation") + byte[] iv = EncodingUtils.subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength()); + encryptedBytes = EncodingUtils.subArray(encryptedBytes, this.ivGenerator.getKeyLength(), encryptedBytes.length); PaddedBufferedBlockCipher blockCipher = new PaddedBufferedBlockCipher( new CBCBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()), new PKCS7Padding()); - blockCipher.init(false, new ParametersWithIV(secretKey, iv)); + blockCipher.init(false, new ParametersWithIV(this.secretKey, iv)); return process(blockCipher, encryptedBytes); } @@ -77,8 +72,8 @@ public class BouncyCastleAesCbcBytesEncryptor extends BouncyCastleAesBytesEncryp try { bytesWritten += blockCipher.doFinal(buf, bytesWritten); } - catch (InvalidCipherTextException e) { - throw new IllegalStateException("unable to encrypt/decrypt", e); + catch (InvalidCipherTextException ex) { + throw new IllegalStateException("unable to encrypt/decrypt", ex); } if (bytesWritten == buf.length) { return buf; @@ -87,4 +82,5 @@ public class BouncyCastleAesCbcBytesEncryptor extends BouncyCastleAesBytesEncryp System.arraycopy(buf, 0, out, 0, bytesWritten); return out; } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesGcmBytesEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesGcmBytesEncryptor.java index 5c81414117..cce6dd6d99 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesGcmBytesEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BouncyCastleAesGcmBytesEncryptor.java @@ -13,22 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.encrypt; -import static org.springframework.security.crypto.util.EncodingUtils.concatenate; -import static org.springframework.security.crypto.util.EncodingUtils.subArray; +package org.springframework.security.crypto.encrypt; import org.bouncycastle.crypto.InvalidCipherTextException; import org.bouncycastle.crypto.modes.AEADBlockCipher; import org.bouncycastle.crypto.modes.GCMBlockCipher; import org.bouncycastle.crypto.params.AEADParameters; + import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; import org.springframework.security.crypto.keygen.BytesKeyGenerator; +import org.springframework.security.crypto.util.EncodingUtils; /** - * An Encryptor equivalent to {@link AesBytesEncryptor} using - * {@link CipherAlgorithm#GCM} that uses Bouncy Castle instead of JCE. The - * algorithm is equivalent to "AES/GCM/NoPadding". + * An Encryptor equivalent to {@link AesBytesEncryptor} using {@link CipherAlgorithm#GCM} + * that uses Bouncy Castle instead of JCE. The algorithm is equivalent to + * "AES/GCM/NoPadding". * * @author William Tran * @@ -39,32 +39,27 @@ public class BouncyCastleAesGcmBytesEncryptor extends BouncyCastleAesBytesEncryp super(password, salt); } - public BouncyCastleAesGcmBytesEncryptor(String password, CharSequence salt, - BytesKeyGenerator ivGenerator) { + public BouncyCastleAesGcmBytesEncryptor(String password, CharSequence salt, BytesKeyGenerator ivGenerator) { super(password, salt, ivGenerator); } @Override + @SuppressWarnings("deprecation") public byte[] encrypt(byte[] bytes) { byte[] iv = this.ivGenerator.generateKey(); - - @SuppressWarnings("deprecation") GCMBlockCipher blockCipher = new GCMBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()); - blockCipher.init(true, new AEADParameters(secretKey, 128, iv, null)); - + blockCipher.init(true, new AEADParameters(this.secretKey, 128, iv, null)); byte[] encrypted = process(blockCipher, bytes); - return iv != null ? concatenate(iv, encrypted) : encrypted; + return (iv != null) ? EncodingUtils.concatenate(iv, encrypted) : encrypted; } @Override + @SuppressWarnings("deprecation") public byte[] decrypt(byte[] encryptedBytes) { - byte[] iv = subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength()); - encryptedBytes = subArray(encryptedBytes, this.ivGenerator.getKeyLength(), - encryptedBytes.length); - - @SuppressWarnings("deprecation") + byte[] iv = EncodingUtils.subArray(encryptedBytes, 0, this.ivGenerator.getKeyLength()); + encryptedBytes = EncodingUtils.subArray(encryptedBytes, this.ivGenerator.getKeyLength(), encryptedBytes.length); GCMBlockCipher blockCipher = new GCMBlockCipher(new org.bouncycastle.crypto.engines.AESFastEngine()); - blockCipher.init(false, new AEADParameters(secretKey, 128, iv, null)); + blockCipher.init(false, new AEADParameters(this.secretKey, 128, iv, null)); return process(blockCipher, encryptedBytes); } @@ -74,8 +69,8 @@ public class BouncyCastleAesGcmBytesEncryptor extends BouncyCastleAesBytesEncryp try { bytesWritten += blockCipher.doFinal(buf, bytesWritten); } - catch (InvalidCipherTextException e) { - throw new IllegalStateException("unable to encrypt/decrypt", e); + catch (InvalidCipherTextException ex) { + throw new IllegalStateException("unable to encrypt/decrypt", ex); } if (bytesWritten == buf.length) { return buf; diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BytesEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BytesEncryptor.java index 4285224627..37b4be273b 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/BytesEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/BytesEncryptor.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; /** * Service interface for symmetric data encryption. + * * @author Keith Donald */ public interface BytesEncryptor { diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/CipherUtils.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/CipherUtils.java index d77fd9f4fb..723208caff 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/CipherUtils.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/CipherUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import java.security.InvalidAlgorithmParameterException; @@ -33,81 +34,82 @@ import javax.crypto.spec.PBEParameterSpec; /** * Static helper for working with the Cipher API. + * * @author Keith Donald */ -class CipherUtils { +final class CipherUtils { + + private CipherUtils() { + } /** * Generates a SecretKey. */ - public static SecretKey newSecretKey(String algorithm, String password) { + static SecretKey newSecretKey(String algorithm, String password) { return newSecretKey(algorithm, new PBEKeySpec(password.toCharArray())); } /** * Generates a SecretKey. */ - public static SecretKey newSecretKey(String algorithm, PBEKeySpec keySpec) { + static SecretKey newSecretKey(String algorithm, PBEKeySpec keySpec) { try { SecretKeyFactory factory = SecretKeyFactory.getInstance(algorithm); return factory.generateSecret(keySpec); } - catch (NoSuchAlgorithmException e) { - throw new IllegalArgumentException("Not a valid encryption algorithm", e); + catch (NoSuchAlgorithmException ex) { + throw new IllegalArgumentException("Not a valid encryption algorithm", ex); } - catch (InvalidKeySpecException e) { - throw new IllegalArgumentException("Not a valid secret key", e); + catch (InvalidKeySpecException ex) { + throw new IllegalArgumentException("Not a valid secret key", ex); } } /** * Constructs a new Cipher. */ - public static Cipher newCipher(String algorithm) { + static Cipher newCipher(String algorithm) { try { return Cipher.getInstance(algorithm); } - catch (NoSuchAlgorithmException e) { - throw new IllegalArgumentException("Not a valid encryption algorithm", e); + catch (NoSuchAlgorithmException ex) { + throw new IllegalArgumentException("Not a valid encryption algorithm", ex); } - catch (NoSuchPaddingException e) { - throw new IllegalStateException("Should not happen", e); + catch (NoSuchPaddingException ex) { + throw new IllegalStateException("Should not happen", ex); } } /** * Initializes the Cipher for use. */ - public static T getParameterSpec(Cipher cipher, - Class parameterSpecClass) { + static T getParameterSpec(Cipher cipher, Class parameterSpecClass) { try { return cipher.getParameters().getParameterSpec(parameterSpecClass); } - catch (InvalidParameterSpecException e) { - throw new IllegalArgumentException("Unable to access parameter", e); + catch (InvalidParameterSpecException ex) { + throw new IllegalArgumentException("Unable to access parameter", ex); } } /** * Initializes the Cipher for use. */ - public static void initCipher(Cipher cipher, int mode, SecretKey secretKey) { + static void initCipher(Cipher cipher, int mode, SecretKey secretKey) { initCipher(cipher, mode, secretKey, null); } /** * Initializes the Cipher for use. */ - public static void initCipher(Cipher cipher, int mode, SecretKey secretKey, - byte[] salt, int iterationCount) { + static void initCipher(Cipher cipher, int mode, SecretKey secretKey, byte[] salt, int iterationCount) { initCipher(cipher, mode, secretKey, new PBEParameterSpec(salt, iterationCount)); } /** * Initializes the Cipher for use. */ - public static void initCipher(Cipher cipher, int mode, SecretKey secretKey, - AlgorithmParameterSpec parameterSpec) { + static void initCipher(Cipher cipher, int mode, SecretKey secretKey, AlgorithmParameterSpec parameterSpec) { try { if (parameterSpec != null) { cipher.init(mode, secretKey, parameterSpec); @@ -116,13 +118,11 @@ class CipherUtils { cipher.init(mode, secretKey); } } - catch (InvalidKeyException e) { - throw new IllegalArgumentException( - "Unable to initialize due to invalid secret key", e); + catch (InvalidKeyException ex) { + throw new IllegalArgumentException("Unable to initialize due to invalid secret key", ex); } - catch (InvalidAlgorithmParameterException e) { - throw new IllegalStateException( - "Unable to initialize due to invalid decryption parameter spec", e); + catch (InvalidAlgorithmParameterException ex) { + throw new IllegalStateException("Unable to initialize due to invalid decryption parameter spec", ex); } } @@ -130,21 +130,16 @@ class CipherUtils { * Invokes the Cipher to perform encryption or decryption (depending on the * initialized mode). */ - public static byte[] doFinal(Cipher cipher, byte[] input) { + static byte[] doFinal(Cipher cipher, byte[] input) { try { return cipher.doFinal(input); } - catch (IllegalBlockSizeException e) { - throw new IllegalStateException( - "Unable to invoke Cipher due to illegal block size", e); + catch (IllegalBlockSizeException ex) { + throw new IllegalStateException("Unable to invoke Cipher due to illegal block size", ex); } - catch (BadPaddingException e) { - throw new IllegalStateException("Unable to invoke Cipher due to bad padding", - e); + catch (BadPaddingException ex) { + throw new IllegalStateException("Unable to invoke Cipher due to bad padding", ex); } } - private CipherUtils() { - } - } diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/Encryptors.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/Encryptors.java index 7ebfb5a356..ff7505b289 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/Encryptors.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/Encryptors.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; @@ -24,7 +25,10 @@ import org.springframework.security.crypto.keygen.KeyGenerators; * * @author Keith Donald */ -public class Encryptors { +public final class Encryptors { + + private Encryptors() { + } /** * Creates a standard password-based bytes encryptor using 256 bit AES encryption with @@ -34,15 +38,13 @@ public class Encryptors { * hex-encoded; it should be random and at least 8 bytes in length. Also applies a * random 16-byte initialization vector to ensure each encrypted message will be * unique. Requires Java 6. - * * @param password the password used to generate the encryptor's secret key; should * not be shared * @param salt a hex-encoded, random, site-global salt value to use to generate the * key */ public static BytesEncryptor stronger(CharSequence password, CharSequence salt) { - return new AesBytesEncryptor(password.toString(), salt, - KeyGenerators.secureRandom(16), CipherAlgorithm.GCM); + return new AesBytesEncryptor(password.toString(), salt, KeyGenerators.secureRandom(16), CipherAlgorithm.GCM); } /** @@ -51,30 +53,25 @@ public class Encryptors { * Function #2). Salts the password to prevent dictionary attacks against the key. The * provided salt is expected to be hex-encoded; it should be random and at least 8 * bytes in length. Also applies a random 16-byte initialization vector to ensure each - * encrypted message will be unique. Requires Java 6. - * NOTE: This mode is not + * encrypted message will be unique. Requires Java 6. NOTE: This mode is not * authenticated - * and does not provide any guarantees about the authenticity of the data. - * For a more secure alternative, users should prefer + * and does not provide any guarantees about the authenticity of the data. For a more + * secure alternative, users should prefer * {@link #stronger(CharSequence, CharSequence)}. - * * @param password the password used to generate the encryptor's secret key; should * not be shared * @param salt a hex-encoded, random, site-global salt value to use to generate the * key * - * @see #stronger(CharSequence, CharSequence), which uses the significatly more secure - * GCM (instead of CBC) + * @see #stronger(CharSequence, CharSequence) */ public static BytesEncryptor standard(CharSequence password, CharSequence salt) { - return new AesBytesEncryptor(password.toString(), salt, - KeyGenerators.secureRandom(16)); + return new AesBytesEncryptor(password.toString(), salt, KeyGenerators.secureRandom(16)); } /** * Creates a text encryptor that uses "stronger" password-based encryption. Encrypted * text is hex-encoded. - * * @param password the password used to generate the encryptor's secret key; should * not be shared * @see Encryptors#stronger(CharSequence, CharSequence) @@ -86,7 +83,6 @@ public class Encryptors { /** * Creates a text encryptor that uses "standard" password-based encryption. Encrypted * text is hex-encoded. - * * @param password the password used to generate the encryptor's secret key; should * not be shared * @see Encryptors#standard(CharSequence, CharSequence) @@ -100,7 +96,6 @@ public class Encryptors { * encryption. Uses a 16-byte all-zero initialization vector so encrypting the same * data results in the same encryption result. This is done to allow encrypted data to * be queried against. Encrypted text is hex-encoded. - * * @param password the password used to generate the encryptor's secret key; should * not be shared * @param salt a hex-encoded, random, site-global salt value to use to generate the @@ -110,8 +105,7 @@ public class Encryptors { */ @Deprecated public static TextEncryptor queryableText(CharSequence password, CharSequence salt) { - return new HexEncodingTextEncryptor(new AesBytesEncryptor(password.toString(), - salt)); + return new HexEncodingTextEncryptor(new AesBytesEncryptor(password.toString(), salt)); } /** @@ -119,20 +113,19 @@ public class Encryptors { * environments where working with plain text strings is desired for simplicity. */ public static TextEncryptor noOpText() { - return NO_OP_TEXT_INSTANCE; + return NoOpTextEncryptor.INSTANCE; } - private Encryptors() { - } - - private static final TextEncryptor NO_OP_TEXT_INSTANCE = new NoOpTextEncryptor(); - private static final class NoOpTextEncryptor implements TextEncryptor { + static final TextEncryptor INSTANCE = new NoOpTextEncryptor(); + + @Override public String encrypt(String text) { return text; } + @Override public String decrypt(String encryptedText) { return encryptedText; } diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/HexEncodingTextEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/HexEncodingTextEncryptor.java index a4ed4c8699..80026b63b5 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/HexEncodingTextEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/HexEncodingTextEncryptor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import org.springframework.security.crypto.codec.Hex; @@ -22,6 +23,7 @@ import org.springframework.security.crypto.codec.Utf8; * Delegates to an {@link BytesEncryptor} to encrypt text strings. Raw text strings are * UTF-8 encoded before being passed to the encryptor. Encrypted strings are returned * hex-encoded. + * * @author Keith Donald */ final class HexEncodingTextEncryptor implements TextEncryptor { @@ -32,12 +34,14 @@ final class HexEncodingTextEncryptor implements TextEncryptor { this.encryptor = encryptor; } + @Override public String encrypt(String text) { - return new String(Hex.encode(encryptor.encrypt(Utf8.encode(text)))); + return new String(Hex.encode(this.encryptor.encrypt(Utf8.encode(text)))); } + @Override public String decrypt(String encryptedText) { - return Utf8.decode(encryptor.decrypt(Hex.decode(encryptedText))); + return Utf8.decode(this.encryptor.decrypt(Hex.decode(encryptedText))); } } diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/TextEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/TextEncryptor.java index 03ff65a4d7..34eb1e5828 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/encrypt/TextEncryptor.java +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/TextEncryptor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; /** diff --git a/crypto/src/main/java/org/springframework/security/crypto/factory/PasswordEncoderFactories.java b/crypto/src/main/java/org/springframework/security/crypto/factory/PasswordEncoderFactories.java index c0a54ba2cf..eb27d4a058 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/factory/PasswordEncoderFactories.java +++ b/crypto/src/main/java/org/springframework/security/crypto/factory/PasswordEncoderFactories.java @@ -16,6 +16,9 @@ package org.springframework.security.crypto.factory; +import java.util.HashMap; +import java.util.Map; + import org.springframework.security.crypto.argon2.Argon2PasswordEncoder; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.DelegatingPasswordEncoder; @@ -23,15 +26,16 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.crypto.password.Pbkdf2PasswordEncoder; import org.springframework.security.crypto.scrypt.SCryptPasswordEncoder; -import java.util.HashMap; -import java.util.Map; - /** * Used for creating {@link PasswordEncoder} instances + * * @author Rob Winch * @since 5.0 */ -public class PasswordEncoderFactories { +public final class PasswordEncoderFactories { + + private PasswordEncoderFactories() { + } /** * Creates a {@link DelegatingPasswordEncoder} with default mappings. Additional @@ -41,18 +45,21 @@ public class PasswordEncoderFactories { * *
        *
      • bcrypt - {@link BCryptPasswordEncoder} (Also used for encoding)
      • - *
      • ldap - {@link org.springframework.security.crypto.password.LdapShaPasswordEncoder}
      • - *
      • MD4 - {@link org.springframework.security.crypto.password.Md4PasswordEncoder}
      • + *
      • ldap - + * {@link org.springframework.security.crypto.password.LdapShaPasswordEncoder}
      • + *
      • MD4 - + * {@link org.springframework.security.crypto.password.Md4PasswordEncoder}
      • *
      • MD5 - {@code new MessageDigestPasswordEncoder("MD5")}
      • - *
      • noop - {@link org.springframework.security.crypto.password.NoOpPasswordEncoder}
      • + *
      • noop - + * {@link org.springframework.security.crypto.password.NoOpPasswordEncoder}
      • *
      • pbkdf2 - {@link Pbkdf2PasswordEncoder}
      • *
      • scrypt - {@link SCryptPasswordEncoder}
      • *
      • SHA-1 - {@code new MessageDigestPasswordEncoder("SHA-1")}
      • *
      • SHA-256 - {@code new MessageDigestPasswordEncoder("SHA-256")}
      • - *
      • sha256 - {@link org.springframework.security.crypto.password.StandardPasswordEncoder}
      • + *
      • sha256 - + * {@link org.springframework.security.crypto.password.StandardPasswordEncoder}
      • *
      • argon2 - {@link Argon2PasswordEncoder}
      • *
      - * * @return the {@link PasswordEncoder} to use */ @SuppressWarnings("deprecation") @@ -67,12 +74,11 @@ public class PasswordEncoderFactories { encoders.put("pbkdf2", new Pbkdf2PasswordEncoder()); encoders.put("scrypt", new SCryptPasswordEncoder()); encoders.put("SHA-1", new org.springframework.security.crypto.password.MessageDigestPasswordEncoder("SHA-1")); - encoders.put("SHA-256", new org.springframework.security.crypto.password.MessageDigestPasswordEncoder("SHA-256")); + encoders.put("SHA-256", + new org.springframework.security.crypto.password.MessageDigestPasswordEncoder("SHA-256")); encoders.put("sha256", new org.springframework.security.crypto.password.StandardPasswordEncoder()); encoders.put("argon2", new Argon2PasswordEncoder()); - return new DelegatingPasswordEncoder(encodingId, encoders); } - private PasswordEncoderFactories() {} } diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/Base64StringKeyGenerator.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/Base64StringKeyGenerator.java index bba2113ed9..347328bba0 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/Base64StringKeyGenerator.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/Base64StringKeyGenerator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; import java.util.Base64; @@ -26,8 +27,11 @@ import java.util.Base64; * @since 5.0 */ public class Base64StringKeyGenerator implements StringKeyGenerator { + private static final int DEFAULT_KEY_LENGTH = 32; + private final BytesKeyGenerator keyGenerator; + private final Base64.Encoder encoder; /** @@ -76,4 +80,5 @@ public class Base64StringKeyGenerator implements StringKeyGenerator { byte[] base64EncodedKey = this.encoder.encode(key); return new String(base64EncodedKey); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/BytesKeyGenerator.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/BytesKeyGenerator.java index 6dbff29f58..51c18df36c 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/BytesKeyGenerator.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/BytesKeyGenerator.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; /** * A generator for unique byte array-based keys. + * * @author Keith Donald */ public interface BytesKeyGenerator { diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/HexEncodingStringKeyGenerator.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/HexEncodingStringKeyGenerator.java index 005a8e6e63..9451c9dd4f 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/HexEncodingStringKeyGenerator.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/HexEncodingStringKeyGenerator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; import org.springframework.security.crypto.codec.Hex; @@ -20,6 +21,7 @@ import org.springframework.security.crypto.codec.Hex; /** * A StringKeyGenerator that generates hex-encoded String keys. Delegates to a * {@link BytesKeyGenerator} for the actual key generation. + * * @author Keith Donald */ final class HexEncodingStringKeyGenerator implements StringKeyGenerator { @@ -30,8 +32,9 @@ final class HexEncodingStringKeyGenerator implements StringKeyGenerator { this.keyGenerator = keyGenerator; } + @Override public String generateKey() { - return new String(Hex.encode(keyGenerator.generateKey())); + return new String(Hex.encode(this.keyGenerator.generateKey())); } } diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/KeyGenerators.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/KeyGenerators.java index b5c5989704..4b73049678 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/KeyGenerators.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/KeyGenerators.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; import java.security.SecureRandom; @@ -20,9 +21,13 @@ import java.security.SecureRandom; /** * Factory for commonly used key generators. Public API for constructing a * {@link BytesKeyGenerator} or {@link StringKeyGenerator}. + * * @author Keith Donald */ -public class KeyGenerators { +public final class KeyGenerators { + + private KeyGenerators() { + } /** * Create a {@link BytesKeyGenerator} that uses a {@link SecureRandom} to generate @@ -58,8 +63,4 @@ public class KeyGenerators { return new HexEncodingStringKeyGenerator(secureRandom()); } - // internal helpers - - private KeyGenerators() { - } } diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/SecureRandomBytesKeyGenerator.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/SecureRandomBytesKeyGenerator.java index f04577802d..e3cb10adfe 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/SecureRandomBytesKeyGenerator.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/SecureRandomBytesKeyGenerator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; import java.security.SecureRandom; @@ -27,6 +28,8 @@ import java.security.SecureRandom; */ final class SecureRandomBytesKeyGenerator implements BytesKeyGenerator { + private static final int DEFAULT_KEY_LENGTH = 8; + private final SecureRandom random; private final int keyLength; @@ -46,16 +49,16 @@ final class SecureRandomBytesKeyGenerator implements BytesKeyGenerator { this.keyLength = keyLength; } + @Override public int getKeyLength() { - return keyLength; + return this.keyLength; } + @Override public byte[] generateKey() { - byte[] bytes = new byte[keyLength]; - random.nextBytes(bytes); + byte[] bytes = new byte[this.keyLength]; + this.random.nextBytes(bytes); return bytes; } - private static final int DEFAULT_KEY_LENGTH = 8; - } diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/SharedKeyGenerator.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/SharedKeyGenerator.java index 5b9e7f5638..1e39d6323a 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/SharedKeyGenerator.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/SharedKeyGenerator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; /** @@ -30,12 +31,14 @@ final class SharedKeyGenerator implements BytesKeyGenerator { this.sharedKey = sharedKey; } + @Override public int getKeyLength() { - return sharedKey.length; + return this.sharedKey.length; } + @Override public byte[] generateKey() { - return sharedKey; + return this.sharedKey; } } diff --git a/crypto/src/main/java/org/springframework/security/crypto/keygen/StringKeyGenerator.java b/crypto/src/main/java/org/springframework/security/crypto/keygen/StringKeyGenerator.java index 53e3422f79..1f5ed5a169 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/keygen/StringKeyGenerator.java +++ b/crypto/src/main/java/org/springframework/security/crypto/keygen/StringKeyGenerator.java @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.keygen; /** * A generator for unique string keys. + * * @author Keith Donald */ public interface StringKeyGenerator { String generateKey(); -} \ No newline at end of file +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/AbstractPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/AbstractPasswordEncoder.java index f70e289adc..180fb1863b 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/password/AbstractPasswordEncoder.java +++ b/crypto/src/main/java/org/springframework/security/crypto/password/AbstractPasswordEncoder.java @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; +import java.security.MessageDigest; + import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.KeyGenerators; - -import java.security.MessageDigest; - -import static org.springframework.security.crypto.util.EncodingUtils.concatenate; -import static org.springframework.security.crypto.util.EncodingUtils.subArray; +import org.springframework.security.crypto.util.EncodingUtils; /** * Abstract base class for password encoders @@ -47,14 +46,14 @@ public abstract class AbstractPasswordEncoder implements PasswordEncoder { @Override public boolean matches(CharSequence rawPassword, String encodedPassword) { byte[] digested = Hex.decode(encodedPassword); - byte[] salt = subArray(digested, 0, this.saltGenerator.getKeyLength()); + byte[] salt = EncodingUtils.subArray(digested, 0, this.saltGenerator.getKeyLength()); return matches(digested, encodeAndConcatenate(rawPassword, salt)); } protected abstract byte[] encode(CharSequence rawPassword, byte[] salt); protected byte[] encodeAndConcatenate(CharSequence rawPassword, byte[] salt) { - return concatenate(salt, encode(rawPassword, salt)); + return EncodingUtils.concatenate(salt, encode(rawPassword, salt)); } /** @@ -63,4 +62,5 @@ public abstract class AbstractPasswordEncoder implements PasswordEncoder { protected static boolean matches(byte[] expected, byte[] actual) { return MessageDigest.isEqual(expected, actual); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/DelegatingPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/DelegatingPasswordEncoder.java index 4bc4e68083..e14b5111af 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/password/DelegatingPasswordEncoder.java +++ b/crypto/src/main/java/org/springframework/security/crypto/password/DelegatingPasswordEncoder.java @@ -31,7 +31,7 @@ import java.util.Map; * *
        * String idForEncode = "bcrypt";
      - * Map encoders = new HashMap<>();
      + * Map<String,PasswordEncoder> encoders = new HashMap<>();
        * encoders.put(idForEncode, new BCryptPasswordEncoder());
        * encoders.put("noop", NoOpPasswordEncoder.getInstance());
        * encoders.put("pbkdf2", new Pbkdf2PasswordEncoder());
      @@ -50,8 +50,8 @@ import java.util.Map;
        * {id}encodedPassword
        * 
      * - * Such that "id" is an identifier used to look up which {@link PasswordEncoder} should - * be used and "encodedPassword" is the original encoded password for the selected + * Such that "id" is an identifier used to look up which {@link PasswordEncoder} should be + * used and "encodedPassword" is the original encoded password for the selected * {@link PasswordEncoder}. The "id" must be at the beginning of the password, start with * "{" and end with "}". If the "id" cannot be found, the "id" will be null. * @@ -70,8 +70,8 @@ import java.util.Map; * *
        *
      1. The first password would have a {@code PasswordEncoder} id of "bcrypt" and - * encodedPassword of "$2a$10$dXJ3SW6G7P50lGmMkkmwe.20cQQubK3.HZWzG3YB1tlRy.fqvM/BG". - * When matching it would delegate to + * encodedPassword of "$2a$10$dXJ3SW6G7P50lGmMkkmwe.20cQQubK3.HZWzG3YB1tlRy.fqvM/BG". When + * matching it would delegate to * {@link org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder}
      2. *
      3. The second password would have a {@code PasswordEncoder} id of "noop" and * encodedPassword of "password". When matching it would delegate to @@ -114,18 +114,23 @@ import java.util.Map; * {@link IllegalArgumentException}. This behavior can be customized using * {@link #setDefaultPasswordEncoderForMatches(PasswordEncoder)}. * - * @see org.springframework.security.crypto.factory.PasswordEncoderFactories - * * @author Rob Winch * @author Michael Simons * @since 5.0 + * @see org.springframework.security.crypto.factory.PasswordEncoderFactories */ public class DelegatingPasswordEncoder implements PasswordEncoder { + private static final String PREFIX = "{"; + private static final String SUFFIX = "}"; + private final String idForEncode; + private final PasswordEncoder passwordEncoderForEncode; + private final Map idToPasswordEncoder; + private PasswordEncoder defaultPasswordEncoderForMatches = new UnmappedIdPasswordEncoder(); /** @@ -133,15 +138,16 @@ public class DelegatingPasswordEncoder implements PasswordEncoder { * @param idForEncode the id used to lookup which {@link PasswordEncoder} should be * used for {@link #encode(CharSequence)} * @param idToPasswordEncoder a Map of id to {@link PasswordEncoder} used to determine - * which {@link PasswordEncoder} should be used for {@link #matches(CharSequence, String)} + * which {@link PasswordEncoder} should be used for + * {@link #matches(CharSequence, String)} */ - public DelegatingPasswordEncoder(String idForEncode, - Map idToPasswordEncoder) { + public DelegatingPasswordEncoder(String idForEncode, Map idToPasswordEncoder) { if (idForEncode == null) { throw new IllegalArgumentException("idForEncode cannot be null"); } if (!idToPasswordEncoder.containsKey(idForEncode)) { - throw new IllegalArgumentException("idForEncode " + idForEncode + "is not found in idToPasswordEncoder " + idToPasswordEncoder); + throw new IllegalArgumentException( + "idForEncode " + idForEncode + "is not found in idToPasswordEncoder " + idToPasswordEncoder); } for (String id : idToPasswordEncoder.keySet()) { if (id == null) { @@ -165,16 +171,15 @@ public class DelegatingPasswordEncoder implements PasswordEncoder { * {@link PasswordEncoder}. * *

        - The encodedPassword provided will be the full password - * passed in including the {"id"} portion.* For example, if the password of - * "{notmapped}foobar" was used, the "id" would be "notmapped" and the encodedPassword - * passed into the {@link PasswordEncoder} would be "{notmapped}foobar". + * The encodedPassword provided will be the full password passed in including the + * {"id"} portion.* For example, if the password of "{notmapped}foobar" was used, the + * "id" would be "notmapped" and the encodedPassword passed into the + * {@link PasswordEncoder} would be "{notmapped}foobar". *

        - * @param defaultPasswordEncoderForMatches the encoder to use. The default is to - * throw an {@link IllegalArgumentException} + * @param defaultPasswordEncoderForMatches the encoder to use. The default is to throw + * an {@link IllegalArgumentException} */ - public void setDefaultPasswordEncoderForMatches( - PasswordEncoder defaultPasswordEncoderForMatches) { + public void setDefaultPasswordEncoderForMatches(PasswordEncoder defaultPasswordEncoderForMatches) { if (defaultPasswordEncoderForMatches == null) { throw new IllegalArgumentException("defaultPasswordEncoderForMatches cannot be null"); } @@ -194,8 +199,7 @@ public class DelegatingPasswordEncoder implements PasswordEncoder { String id = extractId(prefixEncodedPassword); PasswordEncoder delegate = this.idToPasswordEncoder.get(id); if (delegate == null) { - return this.defaultPasswordEncoderForMatches - .matches(rawPassword, prefixEncodedPassword); + return this.defaultPasswordEncoderForMatches.matches(rawPassword, prefixEncodedPassword); } String encodedPassword = extractEncodedPassword(prefixEncodedPassword); return delegate.matches(rawPassword, encodedPassword); @@ -244,10 +248,11 @@ public class DelegatingPasswordEncoder implements PasswordEncoder { } @Override - public boolean matches(CharSequence rawPassword, - String prefixEncodedPassword) { + public boolean matches(CharSequence rawPassword, String prefixEncodedPassword) { String id = extractId(prefixEncodedPassword); throw new IllegalArgumentException("There is no PasswordEncoder mapped for the id \"" + id + "\""); } + } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/Digester.java b/crypto/src/main/java/org/springframework/security/crypto/password/Digester.java index b0b57bc4b4..e7c8dffd3c 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/password/Digester.java +++ b/crypto/src/main/java/org/springframework/security/crypto/password/Digester.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; import java.security.MessageDigest; @@ -45,9 +46,9 @@ final class Digester { setIterations(iterations); } - public byte[] digest(byte[] value) { - MessageDigest messageDigest = createDigest(algorithm); - for (int i = 0; i < iterations; i++) { + byte[] digest(byte[] value) { + MessageDigest messageDigest = createDigest(this.algorithm); + for (int i = 0; i < this.iterations; i++) { value = messageDigest.digest(value); } return value; @@ -64,8 +65,9 @@ final class Digester { try { return MessageDigest.getInstance(algorithm); } - catch (NoSuchAlgorithmException e) { - throw new IllegalStateException("No such hashing algorithm", e); + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException("No such hashing algorithm", ex); } } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/LdapShaPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/LdapShaPasswordEncoder.java index 90252f3c14..eb9687c0fc 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/password/LdapShaPasswordEncoder.java +++ b/crypto/src/main/java/org/springframework/security/crypto/password/LdapShaPasswordEncoder.java @@ -16,13 +16,13 @@ package org.springframework.security.crypto.password; +import java.security.MessageDigest; +import java.util.Base64; + import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.KeyGenerators; -import java.security.MessageDigest; -import java.util.Base64; - /** * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered * secure. @@ -39,30 +39,27 @@ import java.util.Base64; * @deprecated Digest based password encoding is not considered secure. Instead use an * adaptive one way function like BCryptPasswordEncoder, Pbkdf2PasswordEncoder, or * SCryptPasswordEncoder. Even better use {@link DelegatingPasswordEncoder} which supports - * password upgrades. There are no plans to remove this support. It is deprecated to indicate - * that this is a legacy implementation and using it is considered insecure. + * password upgrades. There are no plans to remove this support. It is deprecated to + * indicate that this is a legacy implementation and using it is considered insecure. */ @Deprecated public class LdapShaPasswordEncoder implements PasswordEncoder { - // ~ Static fields/initializers - // ===================================================================================== /** The number of bytes in a SHA hash */ private static final int SHA_LENGTH = 20; + private static final String SSHA_PREFIX = "{SSHA}"; + private static final String SSHA_PREFIX_LC = SSHA_PREFIX.toLowerCase(); + private static final String SHA_PREFIX = "{SHA}"; + private static final String SHA_PREFIX_LC = SHA_PREFIX.toLowerCase(); - // ~ Instance fields - // ================================================================================================ private BytesKeyGenerator saltGenerator; private boolean forceLowerCasePrefix; - // ~ Constructors - // =================================================================================================== - public LdapShaPasswordEncoder() { this(KeyGenerators.secureRandom()); } @@ -74,18 +71,13 @@ public class LdapShaPasswordEncoder implements PasswordEncoder { this.saltGenerator = saltGenerator; } - // ~ Methods - // ======================================================================================================== - private byte[] combineHashAndSalt(byte[] hash, byte[] salt) { if (salt == null) { return hash; } - byte[] hashAndSalt = new byte[hash.length + salt.length]; System.arraycopy(hash, 0, hashAndSalt, 0, hash.length); System.arraycopy(salt, 0, hashAndSalt, hash.length, salt.length); - return hashAndSalt; } @@ -93,97 +85,85 @@ public class LdapShaPasswordEncoder implements PasswordEncoder { * Calculates the hash of password (and salt bytes, if supplied) and returns a base64 * encoded concatenation of the hash and salt, prefixed with {SHA} (or {SSHA} if salt * was used). - * * @param rawPass the password to be encoded. - * * @return the encoded password in the specified format * */ + @Override public String encode(CharSequence rawPass) { byte[] salt = this.saltGenerator.generateKey(); return encode(rawPass, salt); } - private String encode(CharSequence rawPassword, byte[] salt) { - MessageDigest sha; - - try { - sha = MessageDigest.getInstance("SHA"); - sha.update(Utf8.encode(rawPassword)); - } - catch (java.security.NoSuchAlgorithmException e) { - throw new IllegalStateException("No SHA implementation available!"); - } - + MessageDigest sha = getSha(rawPassword); if (salt != null) { sha.update(salt); } - byte[] hash = combineHashAndSalt(sha.digest(), salt); - - String prefix; - - if (salt == null || salt.length == 0) { - prefix = forceLowerCasePrefix ? SHA_PREFIX_LC : SHA_PREFIX; - } - else { - prefix = forceLowerCasePrefix ? SSHA_PREFIX_LC : SSHA_PREFIX; - } - + String prefix = getPrefix(salt); return prefix + Utf8.decode(Base64.getEncoder().encode(hash)); } + private MessageDigest getSha(CharSequence rawPassword) { + try { + MessageDigest sha = MessageDigest.getInstance("SHA"); + sha.update(Utf8.encode(rawPassword)); + return sha; + } + catch (java.security.NoSuchAlgorithmException ex) { + throw new IllegalStateException("No SHA implementation available!"); + } + } + + private String getPrefix(byte[] salt) { + if (salt == null || salt.length == 0) { + return this.forceLowerCasePrefix ? SHA_PREFIX_LC : SHA_PREFIX; + } + return this.forceLowerCasePrefix ? SSHA_PREFIX_LC : SSHA_PREFIX; + } + private byte[] extractSalt(String encPass) { String encPassNoLabel = encPass.substring(6); - byte[] hashAndSalt = Base64.getDecoder().decode(encPassNoLabel.getBytes()); int saltLength = hashAndSalt.length - SHA_LENGTH; byte[] salt = new byte[saltLength]; System.arraycopy(hashAndSalt, SHA_LENGTH, salt, 0, saltLength); - return salt; } /** * Checks the validity of an unencoded password against an encoded one in the form * "{SSHA}sQuQF8vj8Eg2Y1hPdh3bkQhCKQBgjhQI". - * * @param rawPassword unencoded password to be verified. * @param encodedPassword the actual SSHA or SHA encoded password - * * @return true if they match (independent of the case of the prefix). */ + @Override public boolean matches(CharSequence rawPassword, String encodedPassword) { - return matches(rawPassword == null ? null : rawPassword.toString(), encodedPassword); + return matches((rawPassword != null) ? rawPassword.toString() : null, encodedPassword); } private boolean matches(String rawPassword, String encodedPassword) { String prefix = extractPrefix(encodedPassword); - if (prefix == null) { return PasswordEncoderUtils.equals(encodedPassword, rawPassword); } - - byte[] salt; - if (prefix.equals(SSHA_PREFIX) || prefix.equals(SSHA_PREFIX_LC)) { - salt = extractSalt(encodedPassword); - } - else if (!prefix.equals(SHA_PREFIX) && !prefix.equals(SHA_PREFIX_LC)) { - throw new IllegalArgumentException("Unsupported password prefix '" + prefix - + "'"); - } - else { - // Standard SHA - salt = null; - } - + byte[] salt = getSalt(encodedPassword, prefix); int startOfHash = prefix.length(); - String encodedRawPass = encode(rawPassword, salt).substring(startOfHash); + return PasswordEncoderUtils.equals(encodedRawPass, encodedPassword.substring(startOfHash)); + } - return PasswordEncoderUtils - .equals(encodedRawPass, encodedPassword.substring(startOfHash)); + private byte[] getSalt(String encodedPassword, String prefix) { + if (prefix.equals(SSHA_PREFIX) || prefix.equals(SSHA_PREFIX_LC)) { + return extractSalt(encodedPassword); + } + if (!prefix.equals(SHA_PREFIX) && !prefix.equals(SHA_PREFIX_LC)) { + throw new IllegalArgumentException("Unsupported password prefix '" + prefix + "'"); + } + // Standard SHA + return null; } /** @@ -193,18 +173,15 @@ public class LdapShaPasswordEncoder implements PasswordEncoder { if (!encPass.startsWith("{")) { return null; } - int secondBrace = encPass.lastIndexOf('}'); - if (secondBrace < 0) { - throw new IllegalArgumentException( - "Couldn't find closing brace for SHA prefix"); + throw new IllegalArgumentException("Couldn't find closing brace for SHA prefix"); } - return encPass.substring(0, secondBrace + 1); } public void setForceLowerCasePrefix(boolean forceLowerCasePrefix) { this.forceLowerCasePrefix = forceLowerCasePrefix; } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/Md4.java b/crypto/src/main/java/org/springframework/security/crypto/password/Md4.java index 89c4db51b0..d201531ed0 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/password/Md4.java +++ b/crypto/src/main/java/org/springframework/security/crypto/password/Md4.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; /** @@ -22,28 +23,35 @@ package org.springframework.security.crypto.password; * @author Alan Stewart */ class Md4 { + private static final int BLOCK_SIZE = 64; + private static final int HASH_SIZE = 16; + private final byte[] buffer = new byte[BLOCK_SIZE]; + private int bufferOffset; + private long byteCount; + private final int[] state = new int[4]; + private final int[] tmp = new int[16]; Md4() { reset(); } - public void reset() { - bufferOffset = 0; - byteCount = 0; - state[0] = 0x67452301; - state[1] = 0xEFCDAB89; - state[2] = 0x98BADCFE; - state[3] = 0x10325476; + void reset() { + this.bufferOffset = 0; + this.byteCount = 0; + this.state[0] = 0x67452301; + this.state[1] = 0xEFCDAB89; + this.state[2] = 0x98BADCFE; + this.state[3] = 0x10325476; } - public byte[] digest() { + byte[] digest() { byte[] resBuf = new byte[HASH_SIZE]; digest(resBuf, 0, HASH_SIZE); return resBuf; @@ -52,7 +60,7 @@ class Md4 { private void digest(byte[] buffer, int off) { for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { - buffer[off + (i * 4 + j)] = (byte) (state[i] >>> (8 * j)); + buffer[off + (i * 4 + j)] = (byte) (this.state[i] >>> (8 * j)); } } } @@ -68,22 +76,19 @@ class Md4 { update(this.buffer, 0); this.bufferOffset = 0; } - while (this.bufferOffset < C) { this.buffer[this.bufferOffset++] = (byte) 0x00; } - - long bitCount = byteCount * 8; + long bitCount = this.byteCount * 8; for (int i = 0; i < 64; i += 8) { this.buffer[this.bufferOffset++] = (byte) (bitCount >>> (i)); } - update(this.buffer, 0); digest(buffer, offset); } - public void update(byte[] input, int offset, int length) { - byteCount += length; + void update(byte[] input, int offset, int length) { + this.byteCount += length; int todo; while (length >= (todo = BLOCK_SIZE - this.bufferOffset)) { System.arraycopy(input, offset, this.buffer, this.bufferOffset, todo); @@ -94,75 +99,75 @@ class Md4 { } System.arraycopy(input, offset, this.buffer, this.bufferOffset, length); - bufferOffset += length; + this.bufferOffset += length; } private void update(byte[] block, int offset) { for (int i = 0; i < 16; i++) { - tmp[i] = (block[offset++] & 0xFF) | (block[offset++] & 0xFF) << 8 - | (block[offset++] & 0xFF) << 16 | (block[offset++] & 0xFF) << 24; + this.tmp[i] = (block[offset++] & 0xFF) | (block[offset++] & 0xFF) << 8 | (block[offset++] & 0xFF) << 16 + | (block[offset++] & 0xFF) << 24; } - int A = state[0]; - int B = state[1]; - int C = state[2]; - int D = state[3]; + int A = this.state[0]; + int B = this.state[1]; + int C = this.state[2]; + int D = this.state[3]; - A = FF(A, B, C, D, tmp[0], 3); - D = FF(D, A, B, C, tmp[1], 7); - C = FF(C, D, A, B, tmp[2], 11); - B = FF(B, C, D, A, tmp[3], 19); - A = FF(A, B, C, D, tmp[4], 3); - D = FF(D, A, B, C, tmp[5], 7); - C = FF(C, D, A, B, tmp[6], 11); - B = FF(B, C, D, A, tmp[7], 19); - A = FF(A, B, C, D, tmp[8], 3); - D = FF(D, A, B, C, tmp[9], 7); - C = FF(C, D, A, B, tmp[10], 11); - B = FF(B, C, D, A, tmp[11], 19); - A = FF(A, B, C, D, tmp[12], 3); - D = FF(D, A, B, C, tmp[13], 7); - C = FF(C, D, A, B, tmp[14], 11); - B = FF(B, C, D, A, tmp[15], 19); + A = FF(A, B, C, D, this.tmp[0], 3); + D = FF(D, A, B, C, this.tmp[1], 7); + C = FF(C, D, A, B, this.tmp[2], 11); + B = FF(B, C, D, A, this.tmp[3], 19); + A = FF(A, B, C, D, this.tmp[4], 3); + D = FF(D, A, B, C, this.tmp[5], 7); + C = FF(C, D, A, B, this.tmp[6], 11); + B = FF(B, C, D, A, this.tmp[7], 19); + A = FF(A, B, C, D, this.tmp[8], 3); + D = FF(D, A, B, C, this.tmp[9], 7); + C = FF(C, D, A, B, this.tmp[10], 11); + B = FF(B, C, D, A, this.tmp[11], 19); + A = FF(A, B, C, D, this.tmp[12], 3); + D = FF(D, A, B, C, this.tmp[13], 7); + C = FF(C, D, A, B, this.tmp[14], 11); + B = FF(B, C, D, A, this.tmp[15], 19); - A = GG(A, B, C, D, tmp[0], 3); - D = GG(D, A, B, C, tmp[4], 5); - C = GG(C, D, A, B, tmp[8], 9); - B = GG(B, C, D, A, tmp[12], 13); - A = GG(A, B, C, D, tmp[1], 3); - D = GG(D, A, B, C, tmp[5], 5); - C = GG(C, D, A, B, tmp[9], 9); - B = GG(B, C, D, A, tmp[13], 13); - A = GG(A, B, C, D, tmp[2], 3); - D = GG(D, A, B, C, tmp[6], 5); - C = GG(C, D, A, B, tmp[10], 9); - B = GG(B, C, D, A, tmp[14], 13); - A = GG(A, B, C, D, tmp[3], 3); - D = GG(D, A, B, C, tmp[7], 5); - C = GG(C, D, A, B, tmp[11], 9); - B = GG(B, C, D, A, tmp[15], 13); + A = GG(A, B, C, D, this.tmp[0], 3); + D = GG(D, A, B, C, this.tmp[4], 5); + C = GG(C, D, A, B, this.tmp[8], 9); + B = GG(B, C, D, A, this.tmp[12], 13); + A = GG(A, B, C, D, this.tmp[1], 3); + D = GG(D, A, B, C, this.tmp[5], 5); + C = GG(C, D, A, B, this.tmp[9], 9); + B = GG(B, C, D, A, this.tmp[13], 13); + A = GG(A, B, C, D, this.tmp[2], 3); + D = GG(D, A, B, C, this.tmp[6], 5); + C = GG(C, D, A, B, this.tmp[10], 9); + B = GG(B, C, D, A, this.tmp[14], 13); + A = GG(A, B, C, D, this.tmp[3], 3); + D = GG(D, A, B, C, this.tmp[7], 5); + C = GG(C, D, A, B, this.tmp[11], 9); + B = GG(B, C, D, A, this.tmp[15], 13); - A = HH(A, B, C, D, tmp[0], 3); - D = HH(D, A, B, C, tmp[8], 9); - C = HH(C, D, A, B, tmp[4], 11); - B = HH(B, C, D, A, tmp[12], 15); - A = HH(A, B, C, D, tmp[2], 3); - D = HH(D, A, B, C, tmp[10], 9); - C = HH(C, D, A, B, tmp[6], 11); - B = HH(B, C, D, A, tmp[14], 15); - A = HH(A, B, C, D, tmp[1], 3); - D = HH(D, A, B, C, tmp[9], 9); - C = HH(C, D, A, B, tmp[5], 11); - B = HH(B, C, D, A, tmp[13], 15); - A = HH(A, B, C, D, tmp[3], 3); - D = HH(D, A, B, C, tmp[11], 9); - C = HH(C, D, A, B, tmp[7], 11); - B = HH(B, C, D, A, tmp[15], 15); + A = HH(A, B, C, D, this.tmp[0], 3); + D = HH(D, A, B, C, this.tmp[8], 9); + C = HH(C, D, A, B, this.tmp[4], 11); + B = HH(B, C, D, A, this.tmp[12], 15); + A = HH(A, B, C, D, this.tmp[2], 3); + D = HH(D, A, B, C, this.tmp[10], 9); + C = HH(C, D, A, B, this.tmp[6], 11); + B = HH(B, C, D, A, this.tmp[14], 15); + A = HH(A, B, C, D, this.tmp[1], 3); + D = HH(D, A, B, C, this.tmp[9], 9); + C = HH(C, D, A, B, this.tmp[5], 11); + B = HH(B, C, D, A, this.tmp[13], 15); + A = HH(A, B, C, D, this.tmp[3], 3); + D = HH(D, A, B, C, this.tmp[11], 9); + C = HH(C, D, A, B, this.tmp[7], 11); + B = HH(B, C, D, A, this.tmp[15], 15); - state[0] += A; - state[1] += B; - state[2] += C; - state[3] += D; + this.state[0] += A; + this.state[1] += B; + this.state[2] += C; + this.state[3] += D; } private int FF(int a, int b, int c, int d, int x, int s) { @@ -179,4 +184,5 @@ class Md4 { int t = a + (b ^ c ^ d) + x + 0x6ED9EBA1; return t << s | t >>> (32 - s); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/Md4PasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/Md4PasswordEncoder.java index 7fe06a979d..f02583ae7f 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/password/Md4PasswordEncoder.java +++ b/crypto/src/main/java/org/springframework/security/crypto/password/Md4PasswordEncoder.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; +import java.util.Base64; + import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.codec.Utf8; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; -import java.util.Base64; - /** - * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered secure. + * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered + * secure. * * Encodes passwords using MD4. The general format of the password is: * @@ -33,8 +35,7 @@ import java.util.Base64; * * * Such that "salt" is the salt, md4 is the digest method, and password is the actual - * password. For example with a password of "password", and a salt of - * "thisissalt": + * password. For example with a password of "password", and a salt of "thisissalt": * *
          * String s = salt == null ? "" : "{" + salt + "}";
        @@ -55,9 +56,9 @@ import java.util.Base64;
          * "{}" + md4(password + "{}")
          * 
        * - * The format is intended to work with the Md4PasswordEncoder that was found in the - * Spring Security core module. However, the passwords will need to be migrated to include - * any salt with the password since this API provides Salt internally vs making it the + * The format is intended to work with the Md4PasswordEncoder that was found in the Spring + * Security core module. However, the passwords will need to be migrated to include any + * salt with the password since this API provides Salt internally vs making it the * responsibility of the user. To migrate passwords from the SaltSource use the following: * *
        @@ -73,16 +74,19 @@ import java.util.Base64;
          * @deprecated Digest based password encoding is not considered secure. Instead use an
          * adaptive one way function like BCryptPasswordEncoder, Pbkdf2PasswordEncoder, or
          * SCryptPasswordEncoder. Even better use {@link DelegatingPasswordEncoder} which supports
        - * password upgrades. There are no plans to remove this support. It is deprecated to indicate
        - * that this is a legacy implementation and using it is considered insecure.
        + * password upgrades. There are no plans to remove this support. It is deprecated to
        + * indicate that this is a legacy implementation and using it is considered insecure.
          */
         @Deprecated
         public class Md4PasswordEncoder implements PasswordEncoder {
        -	private static final String PREFIX = "{";
        -	private static final String SUFFIX = "}";
        -	private StringKeyGenerator saltGenerator = new Base64StringKeyGenerator();
        -	private boolean encodeHashAsBase64;
         
        +	private static final String PREFIX = "{";
        +
        +	private static final String SUFFIX = "}";
        +
        +	private StringKeyGenerator saltGenerator = new Base64StringKeyGenerator();
        +
        +	private boolean encodeHashAsBase64;
         
         	public void setEncodeHashAsBase64(boolean encodeHashAsBase64) {
         		this.encodeHashAsBase64 = encodeHashAsBase64;
        @@ -91,11 +95,11 @@ public class Md4PasswordEncoder implements PasswordEncoder {
         	/**
         	 * Encodes the rawPass using a MessageDigest. If a salt is specified it will be merged
         	 * with the password before encoding.
        -	 *
         	 * @param rawPassword The plain text password
         	 * @return Hex string of password digest (or base64 encoded string if
         	 * encodeHashAsBase64 is enabled.
         	 */
        +	@Override
         	public String encode(CharSequence rawPassword) {
         		String salt = PREFIX + this.saltGenerator.generateKey() + SUFFIX;
         		return digest(salt, rawPassword);
        @@ -107,10 +111,8 @@ public class Md4PasswordEncoder implements PasswordEncoder {
         		}
         		String saltedPassword = rawPassword + salt;
         		byte[] saltedPasswordBytes = Utf8.encode(saltedPassword);
        -
         		Md4 md4 = new Md4();
         		md4.update(saltedPasswordBytes, 0, saltedPasswordBytes.length);
        -
         		byte[] digest = md4.digest();
         		String encoded = encode(digest);
         		return salt + encoded;
        @@ -120,19 +122,17 @@ public class Md4PasswordEncoder implements PasswordEncoder {
         		if (this.encodeHashAsBase64) {
         			return Utf8.decode(Base64.getEncoder().encode(digest));
         		}
        -		else {
        -			return new String(Hex.encode(digest));
        -		}
        +		return new String(Hex.encode(digest));
         	}
         
         	/**
         	 * Takes a previously encoded password and compares it with a rawpassword after mixing
         	 * in the salt and encoding that value
        -	 *
         	 * @param rawPassword plain text password
         	 * @param encodedPassword previously encoded password
         	 * @return true or false
         	 */
        +	@Override
         	public boolean matches(CharSequence rawPassword, String encodedPassword) {
         		String salt = extractSalt(encodedPassword);
         		String rawPasswordEncoded = digest(salt, rawPassword);
        @@ -150,4 +150,5 @@ public class Md4PasswordEncoder implements PasswordEncoder {
         		}
         		return prefixEncodedPassword.substring(start, end + 1);
         	}
        +
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoder.java
        index 620bd6bf95..cdc6032677 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoder.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoder.java
        @@ -13,18 +13,20 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.password;
         
        +import java.security.MessageDigest;
        +import java.util.Base64;
        +
         import org.springframework.security.crypto.codec.Hex;
         import org.springframework.security.crypto.codec.Utf8;
         import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
         import org.springframework.security.crypto.keygen.StringKeyGenerator;
         
        -import java.security.MessageDigest;
        -import java.util.Base64;
        -
         /**
        - * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered secure.
        + * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered
        + * secure.
          *
          * Encodes passwords using the passed in {@link MessageDigest}.
          *
        @@ -76,14 +78,18 @@ import java.util.Base64;
          * @deprecated Digest based password encoding is not considered secure. Instead use an
          * adaptive one way function like BCryptPasswordEncoder, Pbkdf2PasswordEncoder, or
          * SCryptPasswordEncoder. Even better use {@link DelegatingPasswordEncoder} which supports
        - * password upgrades. There are no plans to remove this support. It is deprecated to indicate
        - * that this is a legacy implementation and using it is considered insecure.
        + * password upgrades. There are no plans to remove this support. It is deprecated to
        + * indicate that this is a legacy implementation and using it is considered insecure.
          */
         @Deprecated
         public class MessageDigestPasswordEncoder implements PasswordEncoder {
        +
         	private static final String PREFIX = "{";
        +
         	private static final String SUFFIX = "}";
        +
         	private StringKeyGenerator saltGenerator = new Base64StringKeyGenerator();
        +
         	private boolean encodeHashAsBase64;
         
         	private Digester digester;
        @@ -92,7 +98,6 @@ public class MessageDigestPasswordEncoder implements PasswordEncoder {
         	 * The digest algorithm to use Supports the named
         	 * 
         	 * Message Digest Algorithms in the Java environment.
        -	 *
         	 * @param algorithm
         	 */
         	public MessageDigestPasswordEncoder(String algorithm) {
        @@ -106,11 +111,11 @@ public class MessageDigestPasswordEncoder implements PasswordEncoder {
         	/**
         	 * Encodes the rawPass using a MessageDigest. If a salt is specified it will be merged
         	 * with the password before encoding.
        -	 *
         	 * @param rawPassword The plain text password
         	 * @return Hex string of password digest (or base64 encoded string if
         	 * encodeHashAsBase64 is enabled.
         	 */
        +	@Override
         	public String encode(CharSequence rawPassword) {
         		String salt = PREFIX + this.saltGenerator.generateKey() + SUFFIX;
         		return digest(salt, rawPassword);
        @@ -118,7 +123,6 @@ public class MessageDigestPasswordEncoder implements PasswordEncoder {
         
         	private String digest(String salt, CharSequence rawPassword) {
         		String saltedPassword = rawPassword + salt;
        -
         		byte[] digest = this.digester.digest(Utf8.encode(saltedPassword));
         		String encoded = encode(digest);
         		return salt + encoded;
        @@ -128,19 +132,17 @@ public class MessageDigestPasswordEncoder implements PasswordEncoder {
         		if (this.encodeHashAsBase64) {
         			return Utf8.decode(Base64.getEncoder().encode(digest));
         		}
        -		else {
        -			return new String(Hex.encode(digest));
        -		}
        +		return new String(Hex.encode(digest));
         	}
         
         	/**
         	 * Takes a previously encoded password and compares it with a rawpassword after mixing
         	 * in the salt and encoding that value
        -	 *
         	 * @param rawPassword plain text password
         	 * @param encodedPassword previously encoded password
         	 * @return true or false
         	 */
        +	@Override
         	public boolean matches(CharSequence rawPassword, String encodedPassword) {
         		String salt = extractSalt(encodedPassword);
         		String rawPasswordEncoded = digest(salt, rawPassword);
        @@ -152,7 +154,6 @@ public class MessageDigestPasswordEncoder implements PasswordEncoder {
         	 * "stretched". If this is greater than one, the initial digest is calculated, the
         	 * digest function will be called repeatedly on the result for the additional number
         	 * of iterations.
        -	 *
         	 * @param iterations the number of iterations which will be executed on the hashed
         	 * password/salt value. Defaults to 1.
         	 */
        @@ -171,4 +172,5 @@ public class MessageDigestPasswordEncoder implements PasswordEncoder {
         		}
         		return prefixEncodedPassword.substring(start, end + 1);
         	}
        +
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/NoOpPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/NoOpPasswordEncoder.java
        index 3c28e16741..d092acf3ee 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/password/NoOpPasswordEncoder.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/password/NoOpPasswordEncoder.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.password;
         
         /**
        @@ -23,19 +24,26 @@ package org.springframework.security.crypto.password;
          * passwords may be preferred.
          *
          * @author Keith Donald
        - * @deprecated This PasswordEncoder is not secure. Instead use an
        - * adaptive one way function like BCryptPasswordEncoder, Pbkdf2PasswordEncoder, or
        - * SCryptPasswordEncoder. Even better use {@link DelegatingPasswordEncoder} which supports
        - * password upgrades. There are no plans to remove this support. It is deprecated to indicate that
        - * this is a legacy implementation and using it is considered insecure.
        + * @deprecated This PasswordEncoder is not secure. Instead use an adaptive one way
        + * function like BCryptPasswordEncoder, Pbkdf2PasswordEncoder, or SCryptPasswordEncoder.
        + * Even better use {@link DelegatingPasswordEncoder} which supports password upgrades.
        + * There are no plans to remove this support. It is deprecated to indicate that this is a
        + * legacy implementation and using it is considered insecure.
          */
         @Deprecated
         public final class NoOpPasswordEncoder implements PasswordEncoder {
         
        +	private static final PasswordEncoder INSTANCE = new NoOpPasswordEncoder();
        +
        +	private NoOpPasswordEncoder() {
        +	}
        +
        +	@Override
         	public String encode(CharSequence rawPassword) {
         		return rawPassword.toString();
         	}
         
        +	@Override
         	public boolean matches(CharSequence rawPassword, String encodedPassword) {
         		return rawPassword.toString().equals(encodedPassword);
         	}
        @@ -47,9 +55,4 @@ public final class NoOpPasswordEncoder implements PasswordEncoder {
         		return INSTANCE;
         	}
         
        -	private static final PasswordEncoder INSTANCE = new NoOpPasswordEncoder();
        -
        -	private NoOpPasswordEncoder() {
        -	}
        -
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoder.java
        index 251acd0251..3d3f8ee425 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoder.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoder.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.password;
         
         /**
        @@ -34,7 +35,6 @@ public interface PasswordEncoder {
         	 * Verify the encoded password obtained from storage matches the submitted raw
         	 * password after it too is encoded. Returns true if the passwords match, false if
         	 * they do not. The stored password itself is never decoded.
        -	 *
         	 * @param rawPassword the raw password to encode and match
         	 * @param encodedPassword the encoded password from storage to compare with
         	 * @return true if the raw password, after encoding, matches the encoded password from
        @@ -52,4 +52,5 @@ public interface PasswordEncoder {
         	default boolean upgradeEncoding(String encodedPassword) {
         		return false;
         	}
        +
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoderUtils.java b/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoderUtils.java
        index 6bf5b937ae..4e1a69a949 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoderUtils.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/password/PasswordEncoderUtils.java
        @@ -13,18 +13,22 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.password;
         
        -import org.springframework.security.crypto.codec.Utf8;
        -
         import java.security.MessageDigest;
         
        +import org.springframework.security.crypto.codec.Utf8;
        +
         /**
          * Utility for constant time comparison to prevent against timing attacks.
          *
          * @author Rob Winch
          */
        -class PasswordEncoderUtils {
        +final class PasswordEncoderUtils {
        +
        +	private PasswordEncoderUtils() {
        +	}
         
         	/**
         	 * Constant time comparison to prevent against timing attacks.
        @@ -35,18 +39,13 @@ class PasswordEncoderUtils {
         	static boolean equals(String expected, String actual) {
         		byte[] expectedBytes = bytesUtf8(expected);
         		byte[] actualBytes = bytesUtf8(actual);
        -
         		return MessageDigest.isEqual(expectedBytes, actualBytes);
         	}
         
         	private static byte[] bytesUtf8(String s) {
        -		if (s == null) {
        -			return null;
        -		}
        -
        -		return Utf8.encode(s); // need to check if Utf8.encode() runs in constant time (probably not). This may leak length of string.
        +		// need to check if Utf8.encode() runs in constant time (probably not).
        +		// This may leak length of string.
        +		return (s != null) ? Utf8.encode(s) : null;
         	}
         
        -	private PasswordEncoderUtils() {
        -	}
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoder.java
        index 3fa38e6774..73e660417c 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoder.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoder.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.password;
         
         import java.security.GeneralSecurityException;
        @@ -27,9 +28,7 @@ import org.springframework.security.crypto.codec.Hex;
         import org.springframework.security.crypto.codec.Utf8;
         import org.springframework.security.crypto.keygen.BytesKeyGenerator;
         import org.springframework.security.crypto.keygen.KeyGenerators;
        -
        -import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
        -import static org.springframework.security.crypto.util.EncodingUtils.subArray;
        +import org.springframework.security.crypto.util.EncodingUtils;
         
         /**
          * A {@code PasswordEncoder} implementation that uses PBKDF2 with a configurable number of
        @@ -46,21 +45,27 @@ import static org.springframework.security.crypto.util.EncodingUtils.subArray;
         public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         
         	private static final int DEFAULT_HASH_WIDTH = 256;
        +
         	private static final int DEFAULT_ITERATIONS = 185000;
         
         	private final BytesKeyGenerator saltGenerator = KeyGenerators.secureRandom();
         
         	private final byte[] secret;
        +
         	private final int hashWidth;
        +
         	private final int iterations;
        +
         	private String algorithm = SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA1.name();
        +
         	private boolean encodeHashAsBase64;
         
         	/**
         	 * Constructs a PBKDF2 password encoder with no additional secret value. There will be
        -	 * {@value DEFAULT_ITERATIONS} iterations and a hash width of {@value DEFAULT_HASH_WIDTH}. The default is based upon aiming for .5
        -	 * seconds to validate the password when this class was added.. Users should tune
        -	 * password verification to their own systems.
        +	 * {@value DEFAULT_ITERATIONS} iterations and a hash width of
        +	 * {@value DEFAULT_HASH_WIDTH}. The default is based upon aiming for .5 seconds to
        +	 * validate the password when this class was added.. Users should tune password
        +	 * verification to their own systems.
         	 */
         	public Pbkdf2PasswordEncoder() {
         		this("");
        @@ -68,8 +73,8 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         
         	/**
         	 * Constructs a standard password encoder with a secret value which is also included
        -	 * in the password hash. There will be {@value DEFAULT_ITERATIONS} iterations and a hash width of {@value DEFAULT_HASH_WIDTH}.
        -	 *
        +	 * in the password hash. There will be {@value DEFAULT_ITERATIONS} iterations and a
        +	 * hash width of {@value DEFAULT_HASH_WIDTH}.
         	 * @param secret the secret key used in the encoding process (should not be shared)
         	 */
         	public Pbkdf2PasswordEncoder(CharSequence secret) {
        @@ -79,7 +84,6 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         	/**
         	 * Constructs a standard password encoder with a secret value as well as iterations
         	 * and hash.
        -	 *
         	 * @param secret the secret
         	 * @param iterations the number of iterations. Users should aim for taking about .5
         	 * seconds on their own system.
        @@ -92,8 +96,9 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         	}
         
         	/**
        -	 * Sets the algorithm to use. See
        -	 * SecretKeyFactory Algorithms
        +	 * Sets the algorithm to use. See SecretKeyFactory
        +	 * Algorithms
         	 * @param secretKeyFactoryAlgorithm the algorithm to use (i.e.
         	 * {@code SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA1},
         	 * {@code SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA256},
        @@ -107,11 +112,11 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         		String algorithmName = secretKeyFactoryAlgorithm.name();
         		try {
         			SecretKeyFactory.getInstance(algorithmName);
        +			this.algorithm = algorithmName;
         		}
        -		catch (NoSuchAlgorithmException e) {
        -			throw new IllegalArgumentException("Invalid algorithm '" + algorithmName + "'.", e);
        +		catch (NoSuchAlgorithmException ex) {
        +			throw new IllegalArgumentException("Invalid algorithm '" + algorithmName + "'.", ex);
         		}
        -		this.algorithm = algorithmName;
         	}
         
         	/**
        @@ -141,7 +146,7 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         	@Override
         	public boolean matches(CharSequence rawPassword, String encodedPassword) {
         		byte[] digested = decode(encodedPassword);
        -		byte[] salt = subArray(digested, 0, this.saltGenerator.getKeyLength());
        +		byte[] salt = EncodingUtils.subArray(digested, 0, this.saltGenerator.getKeyLength());
         		return MessageDigest.isEqual(digested, encode(rawPassword, salt));
         	}
         
        @@ -155,12 +160,12 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         	private byte[] encode(CharSequence rawPassword, byte[] salt) {
         		try {
         			PBEKeySpec spec = new PBEKeySpec(rawPassword.toString().toCharArray(),
        -					concatenate(salt, this.secret), this.iterations, this.hashWidth);
        +					EncodingUtils.concatenate(salt, this.secret), this.iterations, this.hashWidth);
         			SecretKeyFactory skf = SecretKeyFactory.getInstance(this.algorithm);
        -			return concatenate(salt, skf.generateSecret(spec).getEncoded());
        +			return EncodingUtils.concatenate(salt, skf.generateSecret(spec).getEncoded());
         		}
        -		catch (GeneralSecurityException e) {
        -			throw new IllegalStateException("Could not create hash", e);
        +		catch (GeneralSecurityException ex) {
        +			throw new IllegalStateException("Could not create hash", ex);
         		}
         	}
         
        @@ -170,8 +175,9 @@ public class Pbkdf2PasswordEncoder implements PasswordEncoder {
         	 * @since 5.0
         	 */
         	public enum SecretKeyFactoryAlgorithm {
        -		PBKDF2WithHmacSHA1,
        -		PBKDF2WithHmacSHA256,
        -		PBKDF2WithHmacSHA512
        +
        +		PBKDF2WithHmacSHA1, PBKDF2WithHmacSHA256, PBKDF2WithHmacSHA512
        +
         	}
        +
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/password/StandardPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/password/StandardPasswordEncoder.java
        index c07fb5bf9f..e19e445ce2 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/password/StandardPasswordEncoder.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/password/StandardPasswordEncoder.java
        @@ -13,17 +13,16 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.password;
         
        -import static org.springframework.security.crypto.util.EncodingUtils.concatenate;
        -import static org.springframework.security.crypto.util.EncodingUtils.subArray;
        +import java.security.MessageDigest;
         
         import org.springframework.security.crypto.codec.Hex;
         import org.springframework.security.crypto.codec.Utf8;
         import org.springframework.security.crypto.keygen.BytesKeyGenerator;
         import org.springframework.security.crypto.keygen.KeyGenerators;
        -
        -import java.security.MessageDigest;
        +import org.springframework.security.crypto.util.EncodingUtils;
         
         /**
          * This {@link PasswordEncoder} is provided for legacy purposes only and is not considered
        @@ -45,12 +44,14 @@ import java.security.MessageDigest;
          * @deprecated Digest based password encoding is not considered secure. Instead use an
          * adaptive one way function like BCryptPasswordEncoder, Pbkdf2PasswordEncoder, or
          * SCryptPasswordEncoder. Even better use {@link DelegatingPasswordEncoder} which supports
        - * password upgrades. There are no plans to remove this support. It is deprecated to indicate
        - * that this is a legacy implementation and using it is considered insecure.
        + * password upgrades. There are no plans to remove this support. It is deprecated to
        + * indicate that this is a legacy implementation and using it is considered insecure.
          */
         @Deprecated
         public final class StandardPasswordEncoder implements PasswordEncoder {
         
        +	private static final int DEFAULT_ITERATIONS = 1024;
        +
         	private final Digester digester;
         
         	private final byte[] secret;
        @@ -67,25 +68,24 @@ public final class StandardPasswordEncoder implements PasswordEncoder {
         	/**
         	 * Constructs a standard password encoder with a secret value which is also included
         	 * in the password hash.
        -	 *
         	 * @param secret the secret key used in the encoding process (should not be shared)
         	 */
         	public StandardPasswordEncoder(CharSequence secret) {
         		this("SHA-256", secret);
         	}
         
        +	@Override
         	public String encode(CharSequence rawPassword) {
        -		return encode(rawPassword, saltGenerator.generateKey());
        +		return encode(rawPassword, this.saltGenerator.generateKey());
         	}
         
        +	@Override
         	public boolean matches(CharSequence rawPassword, String encodedPassword) {
         		byte[] digested = decode(encodedPassword);
        -		byte[] salt = subArray(digested, 0, saltGenerator.getKeyLength());
        +		byte[] salt = EncodingUtils.subArray(digested, 0, this.saltGenerator.getKeyLength());
         		return MessageDigest.isEqual(digested, digest(rawPassword, salt));
         	}
         
        -	// internal helpers
        -
         	private StandardPasswordEncoder(String algorithm, CharSequence secret) {
         		this.digester = new Digester(algorithm, DEFAULT_ITERATIONS);
         		this.secret = Utf8.encode(secret);
        @@ -98,15 +98,12 @@ public final class StandardPasswordEncoder implements PasswordEncoder {
         	}
         
         	private byte[] digest(CharSequence rawPassword, byte[] salt) {
        -		byte[] digest = digester.digest(concatenate(salt, secret,
        -				Utf8.encode(rawPassword)));
        -		return concatenate(salt, digest);
        +		byte[] digest = this.digester.digest(EncodingUtils.concatenate(salt, this.secret, Utf8.encode(rawPassword)));
        +		return EncodingUtils.concatenate(salt, digest);
         	}
         
         	private byte[] decode(CharSequence encodedPassword) {
         		return Hex.decode(encodedPassword);
         	}
         
        -	private static final int DEFAULT_ITERATIONS = 1024;
        -
         }
        diff --git a/crypto/src/main/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoder.java b/crypto/src/main/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoder.java
        index 189c0ab266..98bafd4be2 100644
        --- a/crypto/src/main/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoder.java
        +++ b/crypto/src/main/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoder.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.crypto.scrypt;
         
         import java.security.MessageDigest;
        @@ -21,6 +22,7 @@ import java.util.Base64;
         import org.apache.commons.logging.Log;
         import org.apache.commons.logging.LogFactory;
         import org.bouncycastle.crypto.generators.SCrypt;
        +
         import org.springframework.security.crypto.codec.Utf8;
         import org.springframework.security.crypto.keygen.BytesKeyGenerator;
         import org.springframework.security.crypto.keygen.KeyGenerators;
        @@ -28,9 +30,9 @@ import org.springframework.security.crypto.password.PasswordEncoder;
         
         /**
          * 

        - * Implementation of PasswordEncoder that uses the SCrypt hashing function. - * Clients can optionally supply a cpu cost parameter, a memory cost parameter - * and a parallelization parameter. + * Implementation of PasswordEncoder that uses the SCrypt hashing function. Clients can + * optionally supply a cpu cost parameter, a memory cost parameter and a parallelization + * parameter. *

        * *

        @@ -41,13 +43,13 @@ import org.springframework.security.crypto.password.PasswordEncoder; * *

          *
        • The currently implementation uses Bouncy castle which does not exploit - * parallelism/optimizations that password crackers will, so there is an - * unnecessary asymmetry between attacker and defender.
        • - *
        • Scrypt is based on Salsa20 which performs poorly in Java (on par with - * AES) but performs awesome (~4-5x faster) on SIMD capable platforms
        • + * parallelism/optimizations that password crackers will, so there is an unnecessary + * asymmetry between attacker and defender. + *
        • Scrypt is based on Salsa20 which performs poorly in Java (on par with AES) but + * performs awesome (~4-5x faster) on SIMD capable platforms
        • *
        • While there are some that would disagree, consider reading - - * - * Why I Don't Recommend Scrypt (for password storage)
        • + * Why I + * Don't Recommend Scrypt (for password storage) *
        * * @author Shazin Sadakath @@ -74,25 +76,17 @@ public class SCryptPasswordEncoder implements PasswordEncoder { /** * Creates a new instance - * - * @param cpuCost - * cpu cost of the algorithm (as defined in scrypt this is N). - * must be power of 2 greater than 1. Default is currently 16,384 - * or 2^14) - * @param memoryCost - * memory cost of the algorithm (as defined in scrypt this is r) - * Default is currently 8. - * @param parallelization - * the parallelization of the algorithm (as defined in scrypt - * this is p) Default is currently 1. Note that the - * implementation does not currently take advantage of - * parallelization. - * @param keyLength - * key length for the algorithm (as defined in scrypt this is - * dkLen). The default is currently 32. - * @param saltLength - * salt length (as defined in scrypt this is the length of S). - * The default is currently 64. + * @param cpuCost cpu cost of the algorithm (as defined in scrypt this is N). must be + * power of 2 greater than 1. Default is currently 16,384 or 2^14) + * @param memoryCost memory cost of the algorithm (as defined in scrypt this is r) + * Default is currently 8. + * @param parallelization the parallelization of the algorithm (as defined in scrypt + * this is p) Default is currently 1. Note that the implementation does not currently + * take advantage of parallelization. + * @param keyLength key length for the algorithm (as defined in scrypt this is dkLen). + * The default is currently 32. + * @param saltLength salt length (as defined in scrypt this is the length of S). The + * default is currently 64. */ public SCryptPasswordEncoder(int cpuCost, int memoryCost, int parallelization, int keyLength, int saltLength) { if (cpuCost <= 1) { @@ -115,7 +109,6 @@ public class SCryptPasswordEncoder implements PasswordEncoder { if (saltLength < 1 || saltLength > Integer.MAX_VALUE) { throw new IllegalArgumentException("Salt length must be >= 1 and <= " + Integer.MAX_VALUE); } - this.cpuCost = cpuCost; this.memoryCost = memoryCost; this.parallelization = parallelization; @@ -123,13 +116,15 @@ public class SCryptPasswordEncoder implements PasswordEncoder { this.saltGenerator = KeyGenerators.secureRandom(saltLength); } + @Override public String encode(CharSequence rawPassword) { - return digest(rawPassword, saltGenerator.generateKey()); + return digest(rawPassword, this.saltGenerator.generateKey()); } + @Override public boolean matches(CharSequence rawPassword, String encodedPassword) { - if (encodedPassword == null || encodedPassword.length() < keyLength) { - logger.warn("Empty encoded password"); + if (encodedPassword == null || encodedPassword.length() < this.keyLength) { + this.logger.warn("Empty encoded password"); return false; } return decodeAndCheckMatches(rawPassword, encodedPassword); @@ -140,56 +135,43 @@ public class SCryptPasswordEncoder implements PasswordEncoder { if (encodedPassword == null || encodedPassword.isEmpty()) { return false; } - String[] parts = encodedPassword.split("\\$"); - if (parts.length != 4) { throw new IllegalArgumentException("Encoded password does not look like SCrypt: " + encodedPassword); } - long params = Long.parseLong(parts[1], 16); - int cpuCost = (int) Math.pow(2, params >> 16 & 0xffff); int memoryCost = (int) params >> 8 & 0xff; int parallelization = (int) params & 0xff; - - return cpuCost < this.cpuCost - || memoryCost < this.memoryCost - || parallelization < this.parallelization; + return cpuCost < this.cpuCost || memoryCost < this.memoryCost || parallelization < this.parallelization; } private boolean decodeAndCheckMatches(CharSequence rawPassword, String encodedPassword) { String[] parts = encodedPassword.split("\\$"); - if (parts.length != 4) { return false; } - long params = Long.parseLong(parts[1], 16); byte[] salt = decodePart(parts[2]); byte[] derived = decodePart(parts[3]); - int cpuCost = (int) Math.pow(2, params >> 16 & 0xffff); int memoryCost = (int) params >> 8 & 0xff; int parallelization = (int) params & 0xff; - byte[] generated = SCrypt.generate(Utf8.encode(rawPassword), salt, cpuCost, memoryCost, parallelization, - keyLength); - + this.keyLength); return MessageDigest.isEqual(derived, generated); } private String digest(CharSequence rawPassword, byte[] salt) { - byte[] derived = SCrypt.generate(Utf8.encode(rawPassword), salt, cpuCost, memoryCost, parallelization, keyLength); - - String params = Long - .toString(((int) (Math.log(cpuCost) / Math.log(2)) << 16L) | memoryCost << 8 | parallelization, 16); - + byte[] derived = SCrypt.generate(Utf8.encode(rawPassword), salt, this.cpuCost, this.memoryCost, + this.parallelization, this.keyLength); + String params = Long.toString( + ((int) (Math.log(this.cpuCost) / Math.log(2)) << 16L) | this.memoryCost << 8 | this.parallelization, + 16); StringBuilder sb = new StringBuilder((salt.length + derived.length) * 2); sb.append("$").append(params).append('$'); sb.append(encodePart(salt)).append('$'); sb.append(encodePart(derived)); - return sb.toString(); } @@ -200,4 +182,5 @@ public class SCryptPasswordEncoder implements PasswordEncoder { private String encodePart(byte[] part) { return Utf8.decode(Base64.getEncoder().encode(part)); } + } diff --git a/crypto/src/main/java/org/springframework/security/crypto/util/EncodingUtils.java b/crypto/src/main/java/org/springframework/security/crypto/util/EncodingUtils.java index 4a49aaedfc..f95bd7bd0b 100644 --- a/crypto/src/main/java/org/springframework/security/crypto/util/EncodingUtils.java +++ b/crypto/src/main/java/org/springframework/security/crypto/util/EncodingUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.util; /** @@ -22,7 +23,10 @@ package org.springframework.security.crypto.util; * * @author Keith Donald */ -public class EncodingUtils { +public final class EncodingUtils { + + private EncodingUtils() { + } /** * Combine the individual byte arrays into one array. @@ -54,7 +58,4 @@ public class EncodingUtils { return subarray; } - private EncodingUtils() { - } - } diff --git a/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2EncodingUtilsTests.java b/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2EncodingUtilsTests.java index 660bd5347f..5bab1e821f 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2EncodingUtilsTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2EncodingUtilsTests.java @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.argon2; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Base64; + import org.bouncycastle.crypto.params.Argon2Parameters; import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Simeon Macke */ @@ -30,48 +32,44 @@ public class Argon2EncodingUtilsTests { private TestDataEntry testDataEntry1 = new TestDataEntry( "$argon2i$v=19$m=1024,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs", - new Argon2EncodingUtils.Argon2Hash(decoder.decode("cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"), - (new Argon2Parameters.Builder(Argon2Parameters.ARGON2_i)). - withVersion(19).withMemoryAsKB(1024).withIterations(3).withParallelism(2). - withSalt("cRdFbCw23gz2Mlxk".getBytes()).build() - )); + new Argon2EncodingUtils.Argon2Hash(this.decoder.decode("cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"), + (new Argon2Parameters.Builder(Argon2Parameters.ARGON2_i)).withVersion(19).withMemoryAsKB(1024) + .withIterations(3).withParallelism(2).withSalt("cRdFbCw23gz2Mlxk".getBytes()).build())); private TestDataEntry testDataEntry2 = new TestDataEntry( "$argon2id$v=19$m=333,t=5,p=2$JDR8N3k1QWx0$+PrEoHOHsWkU9lnsxqnOFrWTVEuOh7ZRIUIbe2yUG8FgTYNCWJfHQI09JAAFKzr2JAvoejEpTMghUt0WsntQYA", - new Argon2EncodingUtils.Argon2Hash(decoder.decode("+PrEoHOHsWkU9lnsxqnOFrWTVEuOh7ZRIUIbe2yUG8FgTYNCWJfHQI09JAAFKzr2JAvoejEpTMghUt0WsntQYA"), - (new Argon2Parameters.Builder(Argon2Parameters.ARGON2_id)). - withVersion(19).withMemoryAsKB(333).withIterations(5).withParallelism(2). - withSalt("$4|7y5Alt".getBytes()).build() - )); + new Argon2EncodingUtils.Argon2Hash( + this.decoder.decode( + "+PrEoHOHsWkU9lnsxqnOFrWTVEuOh7ZRIUIbe2yUG8FgTYNCWJfHQI09JAAFKzr2JAvoejEpTMghUt0WsntQYA"), + (new Argon2Parameters.Builder(Argon2Parameters.ARGON2_id)).withVersion(19).withMemoryAsKB(333) + .withIterations(5).withParallelism(2).withSalt("$4|7y5Alt".getBytes()).build())); @Test public void decodeWhenValidEncodedHashWithIThenDecodeCorrectly() { - assertArgon2HashEquals(testDataEntry1.decoded, Argon2EncodingUtils.decode(testDataEntry1.encoded)); + assertArgon2HashEquals(this.testDataEntry1.decoded, Argon2EncodingUtils.decode(this.testDataEntry1.encoded)); } @Test public void decodeWhenValidEncodedHashWithIDThenDecodeCorrectly() { - assertArgon2HashEquals(testDataEntry2.decoded, Argon2EncodingUtils.decode(testDataEntry2.encoded)); + assertArgon2HashEquals(this.testDataEntry2.decoded, Argon2EncodingUtils.decode(this.testDataEntry2.encoded)); } @Test public void encodeWhenValidArgumentsWithIThenEncodeToCorrectHash() { - assertThat(Argon2EncodingUtils - .encode(testDataEntry1.decoded.getHash(), testDataEntry1.decoded.getParameters())) - .isEqualTo(testDataEntry1.encoded); + assertThat(Argon2EncodingUtils.encode(this.testDataEntry1.decoded.getHash(), + this.testDataEntry1.decoded.getParameters())).isEqualTo(this.testDataEntry1.encoded); } @Test public void encodeWhenValidArgumentsWithID2ThenEncodeToCorrectHash() { - assertThat(Argon2EncodingUtils - .encode(testDataEntry2.decoded.getHash(), testDataEntry2.decoded.getParameters())) - .isEqualTo(testDataEntry2.encoded); + assertThat(Argon2EncodingUtils.encode(this.testDataEntry2.decoded.getHash(), + this.testDataEntry2.decoded.getParameters())).isEqualTo(this.testDataEntry2.encoded); } @Test(expected = IllegalArgumentException.class) public void encodeWhenNonexistingAlgorithmThenThrowException() { - Argon2EncodingUtils.encode(new byte[]{0, 1, 2, 3}, (new Argon2Parameters.Builder(3)). - withVersion(19).withMemoryAsKB(333).withIterations(5).withParallelism(2).build()); + Argon2EncodingUtils.encode(new byte[] { 0, 1, 2, 3 }, (new Argon2Parameters.Builder(3)).withVersion(19) + .withMemoryAsKB(333).withIterations(5).withParallelism(2).build()); } @Test(expected = IllegalArgumentException.class) @@ -81,70 +79,80 @@ public class Argon2EncodingUtilsTests { @Test(expected = IllegalArgumentException.class) public void decodeWhenNonexistingAlgorithmThenThrowException() { - Argon2EncodingUtils.decode("$argon2x$v=19$m=1024,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils.decode( + "$argon2x$v=19$m=1024,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenIllegalVersionParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=x$m=1024,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils.decode( + "$argon2i$v=x$m=1024,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenIllegalMemoryParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=19$m=x,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils + .decode("$argon2i$v=19$m=x,t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenIllegalIterationsParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=19$m=1024,t=x,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils.decode( + "$argon2i$v=19$m=1024,t=x,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenIllegalParallelityParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=19$m=1024,t=3,p=x$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils.decode( + "$argon2i$v=19$m=1024,t=3,p=x$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenMissingVersionParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$m=1024,t=3,p=x$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils + .decode("$argon2i$m=1024,t=3,p=x$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenMissingMemoryParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=19$t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils + .decode("$argon2i$v=19$t=3,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenMissingIterationsParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=19$m=1024,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils + .decode("$argon2i$v=19$m=1024,p=2$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } @Test(expected = IllegalArgumentException.class) public void decodeWhenMissingParallelityParameterThenThrowException() { - Argon2EncodingUtils.decode("$argon2i$v=19$m=1024,t=3$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); + Argon2EncodingUtils + .decode("$argon2i$v=19$m=1024,t=3$Y1JkRmJDdzIzZ3oyTWx4aw$cGE5Cbd/cx7micVhXVBdH5qTr66JI1iUyuNNVAnErXs"); } - private void assertArgon2HashEquals(Argon2EncodingUtils.Argon2Hash expected, Argon2EncodingUtils.Argon2Hash actual) { + private void assertArgon2HashEquals(Argon2EncodingUtils.Argon2Hash expected, + Argon2EncodingUtils.Argon2Hash actual) { assertThat(actual.getHash()).isEqualTo(expected.getHash()); assertThat(actual.getParameters().getSalt()).isEqualTo(expected.getParameters().getSalt()); assertThat(actual.getParameters().getType()).isEqualTo(expected.getParameters().getType()); - assertThat(actual.getParameters().getVersion()) - .isEqualTo(expected.getParameters().getVersion()); - assertThat(actual.getParameters().getMemory()) - .isEqualTo(expected.getParameters().getMemory()); - assertThat(actual.getParameters().getIterations()) - .isEqualTo(expected.getParameters().getIterations()); - assertThat(actual.getParameters().getLanes()) - .isEqualTo(expected.getParameters().getLanes()); + assertThat(actual.getParameters().getVersion()).isEqualTo(expected.getParameters().getVersion()); + assertThat(actual.getParameters().getMemory()).isEqualTo(expected.getParameters().getMemory()); + assertThat(actual.getParameters().getIterations()).isEqualTo(expected.getParameters().getIterations()); + assertThat(actual.getParameters().getLanes()).isEqualTo(expected.getParameters().getLanes()); } private static class TestDataEntry { + String encoded; + Argon2EncodingUtils.Argon2Hash decoded; TestDataEntry(String encoded, Argon2EncodingUtils.Argon2Hash decoded) { this.encoded = encoded; this.decoded = decoded; } + } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoderTests.java index 229b43227f..23fde39954 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/argon2/Argon2PasswordEncoderTests.java @@ -13,19 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.argon2; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.crypto.argon2; import java.lang.reflect.Field; import java.util.Arrays; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.security.crypto.keygen.BytesKeyGenerator; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Simeon Macke */ @@ -39,54 +42,53 @@ public class Argon2PasswordEncoderTests { @Test public void encodeDoesNotEqualPassword() { - String result = encoder.encode("password"); + String result = this.encoder.encode("password"); assertThat(result).isNotEqualTo("password"); } @Test public void encodeWhenEqualPasswordThenMatches() { - String result = encoder.encode("password"); - assertThat(encoder.matches("password", result)).isTrue(); + String result = this.encoder.encode("password"); + assertThat(this.encoder.matches("password", result)).isTrue(); } @Test public void encodeWhenEqualWithUnicodeThenMatches() { - String result = encoder.encode("passw\u9292rd"); - assertThat(encoder.matches("pass\u9292\u9292rd", result)).isFalse(); - assertThat(encoder.matches("passw\u9292rd", result)).isTrue(); + String result = this.encoder.encode("passw\u9292rd"); + assertThat(this.encoder.matches("pass\u9292\u9292rd", result)).isFalse(); + assertThat(this.encoder.matches("passw\u9292rd", result)).isTrue(); } @Test public void encodeWhenNotEqualThenNotMatches() { - String result = encoder.encode("password"); - assertThat(encoder.matches("bogus", result)).isFalse(); + String result = this.encoder.encode("password"); + assertThat(this.encoder.matches("bogus", result)).isFalse(); } @Test public void encodeWhenEqualPasswordWithCustomParamsThenMatches() { - encoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); - String result = encoder.encode("password"); - assertThat(encoder.matches("password", result)).isTrue(); + this.encoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); + String result = this.encoder.encode("password"); + assertThat(this.encoder.matches("password", result)).isTrue(); } @Test public void encodeWhenRanTwiceThenResultsNotEqual() { String password = "secret"; - assertThat(encoder.encode(password)).isNotEqualTo(encoder.encode(password)); + assertThat(this.encoder.encode(password)).isNotEqualTo(this.encoder.encode(password)); } @Test public void encodeWhenRanTwiceWithCustomParamsThenNotEquals() { - encoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); + this.encoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); String password = "secret"; - assertThat(encoder.encode(password)).isNotEqualTo(encoder.encode(password)); + assertThat(this.encoder.encode(password)).isNotEqualTo(this.encoder.encode(password)); } @Test public void matchesWhenGeneratedWithDifferentEncoderThenTrue() { Argon2PasswordEncoder oldEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); Argon2PasswordEncoder newEncoder = new Argon2PasswordEncoder(); - String password = "secret"; String oldEncodedPassword = oldEncoder.encode(password); assertThat(newEncoder.matches(password, oldEncodedPassword)).isTrue(); @@ -94,52 +96,46 @@ public class Argon2PasswordEncoderTests { @Test public void matchesWhenEncodedPassIsNullThenFalse() { - assertThat(encoder.matches("password", null)).isFalse(); + assertThat(this.encoder.matches("password", null)).isFalse(); } @Test public void matchesWhenEncodedPassIsEmptyThenFalse() { - assertThat(encoder.matches("password", "")).isFalse(); + assertThat(this.encoder.matches("password", "")).isFalse(); } @Test public void matchesWhenEncodedPassIsBogusThenFalse() { - assertThat(encoder.matches("password", "012345678901234567890123456789")).isFalse(); + assertThat(this.encoder.matches("password", "012345678901234567890123456789")).isFalse(); } @Test public void encodeWhenUsingPredictableSaltThenEqualTestHash() throws Exception { injectPredictableSaltGen(); - - String hash = encoder.encode("sometestpassword"); - + String hash = this.encoder.encode("sometestpassword"); assertThat(hash).isEqualTo( "$argon2id$v=19$m=4096,t=3,p=1$QUFBQUFBQUFBQUFBQUFBQQ$hmmTNyJlwbb6HAvFoHFWF+u03fdb0F2qA+39oPlcAqo"); } @Test public void encodeWhenUsingPredictableSaltWithCustomParamsThenEqualTestHash() throws Exception { - encoder = new Argon2PasswordEncoder(16, 32, 4, 512, 5); + this.encoder = new Argon2PasswordEncoder(16, 32, 4, 512, 5); injectPredictableSaltGen(); - String hash = encoder.encode("sometestpassword"); - + String hash = this.encoder.encode("sometestpassword"); assertThat(hash).isEqualTo( "$argon2id$v=19$m=512,t=5,p=4$QUFBQUFBQUFBQUFBQUFBQQ$PNv4C3K50bz3rmON+LtFpdisD7ePieLNq+l5iUHgc1k"); } @Test public void upgradeEncodingWhenSameEncodingThenFalse() { - String hash = encoder.encode("password"); - - assertThat(encoder.upgradeEncoding(hash)).isFalse(); + String hash = this.encoder.encode("password"); + assertThat(this.encoder.upgradeEncoding(hash)).isFalse(); } @Test public void upgradeEncodingWhenSameStandardParamsThenFalse() { Argon2PasswordEncoder newEncoder = new Argon2PasswordEncoder(); - - String hash = encoder.encode("password"); - + String hash = this.encoder.encode("password"); assertThat(newEncoder.upgradeEncoding(hash)).isFalse(); } @@ -147,9 +143,7 @@ public class Argon2PasswordEncoderTests { public void upgradeEncodingWhenSameCustomParamsThenFalse() { Argon2PasswordEncoder oldEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); Argon2PasswordEncoder newEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); - String hash = oldEncoder.encode("password"); - assertThat(newEncoder.upgradeEncoding(hash)).isFalse(); } @@ -157,9 +151,7 @@ public class Argon2PasswordEncoderTests { public void upgradeEncodingWhenHashHasLowerMemoryThenTrue() { Argon2PasswordEncoder oldEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); Argon2PasswordEncoder newEncoder = new Argon2PasswordEncoder(20, 64, 4, 512, 4); - String hash = oldEncoder.encode("password"); - assertThat(newEncoder.upgradeEncoding(hash)).isTrue(); } @@ -167,9 +159,7 @@ public class Argon2PasswordEncoderTests { public void upgradeEncodingWhenHashHasLowerIterationsThenTrue() { Argon2PasswordEncoder oldEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); Argon2PasswordEncoder newEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 5); - String hash = oldEncoder.encode("password"); - assertThat(newEncoder.upgradeEncoding(hash)).isTrue(); } @@ -177,38 +167,36 @@ public class Argon2PasswordEncoderTests { public void upgradeEncodingWhenHashHasHigherParamsThenFalse() { Argon2PasswordEncoder oldEncoder = new Argon2PasswordEncoder(20, 64, 4, 256, 4); Argon2PasswordEncoder newEncoder = new Argon2PasswordEncoder(20, 64, 4, 128, 3); - String hash = oldEncoder.encode("password"); - assertThat(newEncoder.upgradeEncoding(hash)).isFalse(); } @Test public void upgradeEncodingWhenEncodedPassIsNullThenFalse() { - assertThat(encoder.upgradeEncoding(null)).isFalse(); + assertThat(this.encoder.upgradeEncoding(null)).isFalse(); } @Test public void upgradeEncodingWhenEncodedPassIsEmptyThenFalse() { - assertThat(encoder.upgradeEncoding("")).isFalse(); + assertThat(this.encoder.upgradeEncoding("")).isFalse(); } @Test(expected = IllegalArgumentException.class) public void upgradeEncodingWhenEncodedPassIsBogusThenThrowException() { - encoder.upgradeEncoding("thisIsNoValidHash"); + this.encoder.upgradeEncoding("thisIsNoValidHash"); } - private void injectPredictableSaltGen() throws Exception { byte[] bytes = new byte[16]; Arrays.fill(bytes, (byte) 0x41); - Mockito.when(keyGeneratorMock.generateKey()).thenReturn(bytes); - - //we can't use the @InjectMock-annotation because the salt-generator is set in the constructor - //and Mockito will only inject mocks if they are null - Field saltGen = encoder.getClass().getDeclaredField("saltGenerator"); + Mockito.when(this.keyGeneratorMock.generateKey()).thenReturn(bytes); + // we can't use the @InjectMock-annotation because the salt-generator is set in + // the constructor + // and Mockito will only inject mocks if they are null + Field saltGen = this.encoder.getClass().getDeclaredField("saltGenerator"); saltGen.setAccessible(true); - saltGen.set(encoder, keyGeneratorMock); + saltGen.set(this.encoder, this.keyGeneratorMock); saltGen.setAccessible(false); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoderTests.java index 1ae357f019..b9c9c1072f 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoderTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.bcrypt; -import org.junit.Test; - import java.security.SecureRandom; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -79,8 +80,7 @@ public class BCryptPasswordEncoderTests { @Test public void $2bUnicode() { - BCryptPasswordEncoder encoder = - new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2B); + BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2B); String result = encoder.encode("passw\u9292rd"); assertThat(encoder.matches("pass\u9292\u9292rd", result)).isFalse(); assertThat(encoder.matches("passw\u9292rd", result)).isTrue(); @@ -96,16 +96,14 @@ public class BCryptPasswordEncoderTests { @Test public void $2aNotMatches() { - BCryptPasswordEncoder encoder = - new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2A); + BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2A); String result = encoder.encode("password"); assertThat(encoder.matches("bogus", result)).isFalse(); } @Test public void $2bNotMatches() { - BCryptPasswordEncoder encoder = - new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2B); + BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2B); String result = encoder.encode("password"); assertThat(encoder.matches("bogus", result)).isFalse(); } @@ -115,21 +113,18 @@ public class BCryptPasswordEncoderTests { BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(8); String result = encoder.encode("password"); assertThat(encoder.matches("password", result)).isTrue(); - } @Test public void $2aCustomStrength() { - BCryptPasswordEncoder encoder = - new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2A, 8); + BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2A, 8); String result = encoder.encode("password"); assertThat(encoder.matches("password", result)).isTrue(); } @Test public void $2bCustomStrength() { - BCryptPasswordEncoder encoder = - new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2B, 8); + BCryptPasswordEncoder encoder = new BCryptPasswordEncoder(BCryptPasswordEncoder.BCryptVersion.$2B, 8); String result = encoder.encode("password"); assertThat(encoder.matches("password", result)).isTrue(); } @@ -173,16 +168,15 @@ public class BCryptPasswordEncoderTests { public void upgradeFromLowerStrength() { BCryptPasswordEncoder weakEncoder = new BCryptPasswordEncoder(5); BCryptPasswordEncoder strongEncoder = new BCryptPasswordEncoder(15); - String weakPassword = weakEncoder.encode("password"); String strongPassword = strongEncoder.encode("password"); - assertThat(weakEncoder.upgradeEncoding(strongPassword)).isFalse(); assertThat(strongEncoder.upgradeEncoding(weakPassword)).isTrue(); } /** - * @see https://github.com/spring-projects/spring-security/pull/7042#issuecomment-506755496 + * @see https://github.com/spring-projects/spring-security/pull/7042#issuecomment-506755496 */ @Test public void upgradeFromNullOrEmpty() { @@ -192,7 +186,8 @@ public class BCryptPasswordEncoderTests { } /** - * @see https://github.com/spring-projects/spring-security/pull/7042#issuecomment-506755496 + * @see https://github.com/spring-projects/spring-security/pull/7042#issuecomment-506755496 */ @Test(expected = IllegalArgumentException.class) public void upgradeFromNonBCrypt() { diff --git a/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptTests.java b/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptTests.java index 743a85fa8b..010c1e9c8e 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptTests.java @@ -11,27 +11,30 @@ // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - package org.springframework.security.crypto.bcrypt; -import org.junit.BeforeClass; -import org.junit.Test; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.junit.BeforeClass; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** * JUnit unit tests for BCrypt routines + * * @author Damien Miller */ public class BCryptTests { private static class TestObject { + private final T password; + private final String salt; + private final String expected; private TestObject(T password, String salt, String expected) { @@ -39,6 +42,7 @@ public class BCryptTests { this.salt = salt; this.expected = expected; } + } private static void print(String s) { @@ -136,23 +140,22 @@ public class BCryptTests { "$2y$06$sYDFHqOcXTjBgOsqC0WCKeMd3T1UhHuWQSxncLGtXDLMrcE6vFDti")); testObjectsString.add(new TestObject<>("~!@#$%^&*() ~!@#$%^&*()PNBFRD", "$2y$06$6Xm0gCw4g7ZNDCEp4yTise", "$2y$06$6Xm0gCw4g7ZNDCEp4yTisez0kSdpXEl66MvdxGidnmChIe8dFmMnq")); - testObjectsByteArray = new ArrayList<>(); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2a$06$fPIsBO8qRqkjj273rfaOI.", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2a$06$fPIsBO8qRqkjj273rfaOI.", "$2a$06$fPIsBO8qRqkjj273rfaOI.uiVGfgi6Z1Iz.vZr11mi/38o09TUVCy")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2a$08$Eq2r4G/76Wv39MzSX262hu", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2a$08$Eq2r4G/76Wv39MzSX262hu", "$2a$08$Eq2r4G/76Wv39MzSX262hu2lrqIItOWKIkPsMMvm5LAFD.iVB7Nmm")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2a$10$LgfYWkbzEvQ4JakH7rOvHe", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2a$10$LgfYWkbzEvQ4JakH7rOvHe", "$2a$10$LgfYWkbzEvQ4JakH7rOvHeU6pINYiHnazYxe4GikGWx9MaUr27Vpa")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2a$12$WApznUOJfkEGSmYRfnkrPO", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2a$12$WApznUOJfkEGSmYRfnkrPO", "$2a$12$WApznUOJfkEGSmYRfnkrPONS3wcUvmKuh3LpjxSs6g78T77gZta3W")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2b$06$FGWA8OlY6RtQhXBXuCJ8Wu", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2b$06$FGWA8OlY6RtQhXBXuCJ8Wu", "$2b$06$FGWA8OlY6RtQhXBXuCJ8Wu5oPJaT8BeCRmS273I6cpp5RwwjAWn7S")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2b$06$G6aYU7UhUEUDJBdTgq3CRe", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2b$06$G6aYU7UhUEUDJBdTgq3CRe", "$2b$06$G6aYU7UhUEUDJBdTgq3CRebzUYAyG8MCS3WdBk0CcPb9bfj1.3cSG")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2y$06$sYDFHqOcXTjBgOsqC0WCKe", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2y$06$sYDFHqOcXTjBgOsqC0WCKe", "$2y$06$sYDFHqOcXTjBgOsqC0WCKeOv88fqPKkuV1yGVh./TROmn1mL8gYh2")); - testObjectsByteArray.add(new TestObject<>(new byte[] { }, "$2y$06$6Xm0gCw4g7ZNDCEp4yTise", + testObjectsByteArray.add(new TestObject<>(new byte[] {}, "$2y$06$6Xm0gCw4g7ZNDCEp4yTise", "$2y$06$6Xm0gCw4g7ZNDCEp4yTisecBqTHmLJBHxTNZa8w2hupJKsIhPWOgG")); testObjectsByteArray.add(new TestObject<>(new byte[] { -11 }, "$2a$06$fPIsBO8qRqkjj273rfaOI.", "$2a$06$fPIsBO8qRqkjj273rfaOI.AyMTPwvUEmZ2EdJM/p0S0eP3UQpBas.")); @@ -310,11 +313,9 @@ public class BCryptTests { print("BCrypt.hashpw w/ international chars: "); String pw1 = "ππππππππ"; String pw2 = "????????"; - String h1 = BCrypt.hashpw(pw1, BCrypt.gensalt()); assertThat(BCrypt.checkpw(pw2, h1)).isFalse(); print("."); - String h2 = BCrypt.hashpw(pw2, BCrypt.gensalt()); assertThat(BCrypt.checkpw(pw1, h2)).isFalse(); print("."); @@ -342,8 +343,7 @@ public class BCryptTests { BCrypt.decode_base64("", 0); } - private static String encode_base64(byte d[], int len) - throws IllegalArgumentException { + private static String encode_base64(byte d[], int len) throws IllegalArgumentException { StringBuilder rs = new StringBuilder(); BCrypt.encode_base64(d, len, rs); return rs.toString(); @@ -353,7 +353,7 @@ public class BCryptTests { public void testBase64EncodeSimpleByteArrays() { assertThat(encode_base64(new byte[] { 0 }, 1)).isEqualTo(".."); assertThat(encode_base64(new byte[] { 0, 0 }, 2)).isEqualTo("..."); - assertThat(encode_base64(new byte[] { 0, 0 , 0 }, 3)).isEqualTo("...."); + assertThat(encode_base64(new byte[] { 0, 0, 0 }, 3)).isEqualTo("...."); } @Test @@ -382,15 +382,12 @@ public class BCryptTests { @Test public void testBase64EncodeDecode() { byte[] ba = new byte[3]; - for (int b = 0; b <= 0xFF; b++) { for (int i = 0; i < ba.length; i++) { Arrays.fill(ba, (byte) 0); ba[i] = (byte) b; - String s = encode_base64(ba, 3); assertThat(s.length()).isEqualTo(4); - byte[] decoded = BCrypt.decode_base64(s, 3); assertThat(decoded).isEqualTo(ba); } @@ -435,8 +432,8 @@ public class BCryptTests { @Test public void hashpwWorksWithOldRevision() { - assertThat(BCrypt.hashpw("password", "$2$05$......................")).isEqualTo( - "$2$05$......................bvpG2UfzdyW/S0ny/4YyEZrmczoJfVm"); + assertThat(BCrypt.hashpw("password", "$2$05$......................")) + .isEqualTo("$2$05$......................bvpG2UfzdyW/S0ny/4YyEZrmczoJfVm"); } @Test(expected = IllegalArgumentException.class) @@ -448,10 +445,9 @@ public class BCryptTests { public void equalsOnStringsIsCorrect() { assertThat(BCrypt.equalsNoEarlyReturn("", "")).isTrue(); assertThat(BCrypt.equalsNoEarlyReturn("test", "test")).isTrue(); - assertThat(BCrypt.equalsNoEarlyReturn("test", "")).isFalse(); assertThat(BCrypt.equalsNoEarlyReturn("", "test")).isFalse(); - assertThat(BCrypt.equalsNoEarlyReturn("test", "pass")).isFalse(); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/codec/Base64Tests.java b/crypto/src/test/java/org/springframework/security/crypto/codec/Base64Tests.java index 8da8ffe9b8..6518c4b007 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/codec/Base64Tests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/codec/Base64Tests.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.codec; -import static org.assertj.core.api.Assertions.*; +import org.junit.Test; -import org.junit.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor @@ -27,17 +28,13 @@ public class Base64Tests { @Test public void isBase64ReturnsTrueForValidBase64() { - new Base64(); // unused - - assertThat(Base64.isBase64(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', - (byte) 'D' })).isTrue(); + assertThat(Base64.isBase64(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' })).isTrue(); } @Test public void isBase64ReturnsFalseForInvalidBase64() { // Include invalid '`' character - assertThat(Base64.isBase64(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', - (byte) '`' })).isFalse(); + assertThat(Base64.isBase64(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) '`' })).isFalse(); } @Test(expected = NullPointerException.class) @@ -49,4 +46,5 @@ public class Base64Tests { public void isBase64RejectsInvalidLength() { Base64.isBase64(new byte[] { (byte) 'A' }); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java b/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java index a85b69264a..ce02345c92 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.codec; import org.junit.Rule; @@ -33,8 +34,8 @@ public class HexTests { @Test public void encode() { - assertThat(Hex.encode(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', - (byte) 'D' })).isEqualTo(new char[] {'4', '1', '4', '2', '4', '3', '4', '4'}); + assertThat(Hex.encode(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' })) + .isEqualTo(new char[] { '4', '1', '4', '2', '4', '3', '4', '4' }); } @Test @@ -44,8 +45,7 @@ public class HexTests { @Test public void decode() { - assertThat(Hex.decode("41424344")).isEqualTo(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', - (byte) 'D' }); + assertThat(Hex.decode("41424344")).isEqualTo(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' }); } @Test @@ -55,29 +55,29 @@ public class HexTests { @Test public void decodeNotEven() { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Hex-encoded string must have an even number of characters"); + this.expectedException.expect(IllegalArgumentException.class); + this.expectedException.expectMessage("Hex-encoded string must have an even number of characters"); Hex.decode("414243444"); } @Test public void decodeExistNonHexCharAtFirst() { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Detected a Non-hex character at 1 or 2 position"); + this.expectedException.expect(IllegalArgumentException.class); + this.expectedException.expectMessage("Detected a Non-hex character at 1 or 2 position"); Hex.decode("G0"); } @Test public void decodeExistNonHexCharAtSecond() { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Detected a Non-hex character at 3 or 4 position"); + this.expectedException.expect(IllegalArgumentException.class); + this.expectedException.expectMessage("Detected a Non-hex character at 3 or 4 position"); Hex.decode("410G"); } @Test public void decodeExistNonHexCharAtBoth() { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Detected a Non-hex character at 5 or 6 position"); + this.expectedException.expect(IllegalArgumentException.class); + this.expectedException.expectMessage("Detected a Non-hex character at 5 or 6 position"); Hex.decode("4142GG"); } diff --git a/crypto/src/test/java/org/springframework/security/crypto/codec/Utf8Tests.java b/crypto/src/test/java/org/springframework/security/crypto/codec/Utf8Tests.java index a78107ad87..4c1b202df8 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/codec/Utf8Tests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/codec/Utf8Tests.java @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.codec; -import static org.assertj.core.api.Assertions.*; +import java.util.Arrays; -import org.junit.*; +import org.junit.Test; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Luke Taylor @@ -32,9 +33,8 @@ public class Utf8Tests { byte[] bytes = Utf8.encode("6048b75ed560785c"); assertThat(bytes).hasSize(16); assertThat(Arrays.equals("6048b75ed560785c".getBytes("UTF-8"), bytes)).isTrue(); - String decoded = Utf8.decode(bytes); - assertThat(decoded).isEqualTo("6048b75ed560785c"); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/AesBytesEncryptorTests.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/AesBytesEncryptorTests.java index ce95884e9d..d806b028cb 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/encrypt/AesBytesEncryptorTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/AesBytesEncryptorTests.java @@ -16,28 +16,30 @@ package org.springframework.security.crypto.encrypt; +import javax.crypto.SecretKey; +import javax.crypto.spec.PBEKeySpec; + import org.junit.Before; import org.junit.Test; import org.springframework.security.crypto.codec.Hex; +import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; import org.springframework.security.crypto.keygen.BytesKeyGenerator; - -import javax.crypto.SecretKey; -import javax.crypto.spec.PBEKeySpec; +import org.springframework.security.crypto.password.Pbkdf2PasswordEncoder.SecretKeyFactoryAlgorithm; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm.GCM; -import static org.springframework.security.crypto.encrypt.CipherUtils.newSecretKey; -import static org.springframework.security.crypto.password.Pbkdf2PasswordEncoder.SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA1; /** * Tests for {@link AesBytesEncryptor} */ public class AesBytesEncryptorTests { + private String secret = "value"; + private String password = "password"; + private String hexSalt = "deadbeef"; private BytesKeyGenerator generator; @@ -45,8 +47,8 @@ public class AesBytesEncryptorTests { @Before public void setUp() { this.generator = mock(BytesKeyGenerator.class); - when(this.generator.generateKey()).thenReturn(Hex.decode("4b0febebd439db7ca77153cb254520c3")); - when(this.generator.getKeyLength()).thenReturn(16); + given(this.generator.generateKey()).willReturn(Hex.decode("4b0febebd439db7ca77153cb254520c3")); + given(this.generator.getKeyLength()).willReturn(16); } @Test @@ -65,7 +67,6 @@ public class AesBytesEncryptorTests { byte[] encryption = encryptor.encrypt(this.secret.getBytes()); assertThat(new String(Hex.encode(encryption))) .isEqualTo("4b0febebd439db7ca77153cb254520c3b7232ac29355d07869433f1ecf55fe94"); - byte[] decryption = encryptor.decrypt(encryption); assertThat(new String(decryption)).isEqualTo(this.secret); } @@ -73,12 +74,11 @@ public class AesBytesEncryptorTests { @Test public void roundtripWhenUsingGcmThenEncryptsAndDecrypts() { CryptoAssumptions.assumeGCMJCE(); - AesBytesEncryptor encryptor = new AesBytesEncryptor(this.password, this.hexSalt, this.generator, GCM); - + AesBytesEncryptor encryptor = new AesBytesEncryptor(this.password, this.hexSalt, this.generator, + CipherAlgorithm.GCM); byte[] encryption = encryptor.encrypt(this.secret.getBytes()); assertThat(new String(Hex.encode(encryption))) .isEqualTo("4b0febebd439db7ca77153cb254520c3e4d61ae38207b4e42b820d311dc3d4e0e2f37ed5ee"); - byte[] decryption = encryptor.decrypt(encryption); assertThat(new String(decryption)).isEqualTo(this.secret); } @@ -86,15 +86,12 @@ public class AesBytesEncryptorTests { @Test public void roundtripWhenUsingSecretKeyThenEncryptsAndDecrypts() { CryptoAssumptions.assumeGCMJCE(); - PBEKeySpec keySpec = new PBEKeySpec(this.password.toCharArray(), Hex.decode(this.hexSalt), - 1024, 256); - SecretKey secretKey = newSecretKey(PBKDF2WithHmacSHA1.name(), keySpec); - AesBytesEncryptor encryptor = new AesBytesEncryptor(secretKey, this.generator, GCM); - + PBEKeySpec keySpec = new PBEKeySpec(this.password.toCharArray(), Hex.decode(this.hexSalt), 1024, 256); + SecretKey secretKey = CipherUtils.newSecretKey(SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA1.name(), keySpec); + AesBytesEncryptor encryptor = new AesBytesEncryptor(secretKey, this.generator, CipherAlgorithm.GCM); byte[] encryption = encryptor.encrypt(this.secret.getBytes()); assertThat(new String(Hex.encode(encryption))) .isEqualTo("4b0febebd439db7ca77153cb254520c3e4d61ae38207b4e42b820d311dc3d4e0e2f37ed5ee"); - byte[] decryption = encryptor.decrypt(encryption); assertThat(new String(decryption)).isEqualTo(this.secret); } diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorEquivalencyTest.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorEquivalencyTests.java similarity index 71% rename from crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorEquivalencyTest.java rename to crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorEquivalencyTests.java index 548b3791af..44506004d1 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorEquivalencyTest.java +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorEquivalencyTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import java.security.SecureRandom; @@ -22,34 +23,38 @@ import java.util.UUID; import org.junit.Assert; import org.junit.Before; import org.junit.Test; + import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; import org.springframework.security.crypto.keygen.BytesKeyGenerator; import org.springframework.security.crypto.keygen.KeyGenerators; -public class BouncyCastleAesBytesEncryptorEquivalencyTest { +public class BouncyCastleAesBytesEncryptorEquivalencyTests { private byte[] testData; + private String password; + private String salt; + private SecureRandom secureRandom = new SecureRandom(); @Before public void setup() { // generate random password, salt, and test data - password = UUID.randomUUID().toString(); - /** insecure salt byte, recommend 64 or larger than 64*/ + this.password = UUID.randomUUID().toString(); + /** insecure salt byte, recommend 64 or larger than 64 */ byte[] saltBytes = new byte[16]; - secureRandom.nextBytes(saltBytes); - salt = new String(Hex.encode(saltBytes)); + this.secureRandom.nextBytes(saltBytes); + this.salt = new String(Hex.encode(saltBytes)); } @Test public void bouncyCastleAesCbcWithPredictableIvEquvalent() throws Exception { CryptoAssumptions.assumeCBCJCE(); - BytesEncryptor bcEncryptor = new BouncyCastleAesCbcBytesEncryptor(password, salt, + BytesEncryptor bcEncryptor = new BouncyCastleAesCbcBytesEncryptor(this.password, this.salt, new PredictableRandomBytesKeyGenerator(16)); - BytesEncryptor jceEncryptor = new AesBytesEncryptor(password, salt, + BytesEncryptor jceEncryptor = new AesBytesEncryptor(this.password, this.salt, new PredictableRandomBytesKeyGenerator(16)); testEquivalence(bcEncryptor, jceEncryptor); } @@ -57,19 +62,18 @@ public class BouncyCastleAesBytesEncryptorEquivalencyTest { @Test public void bouncyCastleAesCbcWithSecureIvCompatible() throws Exception { CryptoAssumptions.assumeCBCJCE(); - BytesEncryptor bcEncryptor = new BouncyCastleAesCbcBytesEncryptor(password, salt, - KeyGenerators.secureRandom(16)); - BytesEncryptor jceEncryptor = new AesBytesEncryptor(password, salt, + BytesEncryptor bcEncryptor = new BouncyCastleAesCbcBytesEncryptor(this.password, this.salt, KeyGenerators.secureRandom(16)); + BytesEncryptor jceEncryptor = new AesBytesEncryptor(this.password, this.salt, KeyGenerators.secureRandom(16)); testCompatibility(bcEncryptor, jceEncryptor); } @Test public void bouncyCastleAesGcmWithPredictableIvEquvalent() throws Exception { CryptoAssumptions.assumeGCMJCE(); - BytesEncryptor bcEncryptor = new BouncyCastleAesGcmBytesEncryptor(password, salt, + BytesEncryptor bcEncryptor = new BouncyCastleAesGcmBytesEncryptor(this.password, this.salt, new PredictableRandomBytesKeyGenerator(16)); - BytesEncryptor jceEncryptor = new AesBytesEncryptor(password, salt, + BytesEncryptor jceEncryptor = new AesBytesEncryptor(this.password, this.salt, new PredictableRandomBytesKeyGenerator(16), CipherAlgorithm.GCM); testEquivalence(bcEncryptor, jceEncryptor); } @@ -77,42 +81,41 @@ public class BouncyCastleAesBytesEncryptorEquivalencyTest { @Test public void bouncyCastleAesGcmWithSecureIvCompatible() throws Exception { CryptoAssumptions.assumeGCMJCE(); - BytesEncryptor bcEncryptor = new BouncyCastleAesGcmBytesEncryptor(password, salt, + BytesEncryptor bcEncryptor = new BouncyCastleAesGcmBytesEncryptor(this.password, this.salt, KeyGenerators.secureRandom(16)); - BytesEncryptor jceEncryptor = new AesBytesEncryptor(password, salt, - KeyGenerators.secureRandom(16), CipherAlgorithm.GCM); + BytesEncryptor jceEncryptor = new AesBytesEncryptor(this.password, this.salt, KeyGenerators.secureRandom(16), + CipherAlgorithm.GCM); testCompatibility(bcEncryptor, jceEncryptor); } private void testEquivalence(BytesEncryptor left, BytesEncryptor right) { for (int size = 1; size < 2048; size++) { - testData = new byte[size]; - secureRandom.nextBytes(testData); + this.testData = new byte[size]; + this.secureRandom.nextBytes(this.testData); // tests that right and left generate the same encrypted bytes // and can decrypt back to the original input - byte[] leftEncrypted = left.encrypt(testData); - byte[] rightEncrypted = right.encrypt(testData); + byte[] leftEncrypted = left.encrypt(this.testData); + byte[] rightEncrypted = right.encrypt(this.testData); Assert.assertArrayEquals(leftEncrypted, rightEncrypted); byte[] leftDecrypted = left.decrypt(leftEncrypted); byte[] rightDecrypted = right.decrypt(rightEncrypted); - Assert.assertArrayEquals(testData, leftDecrypted); - Assert.assertArrayEquals(testData, rightDecrypted); + Assert.assertArrayEquals(this.testData, leftDecrypted); + Assert.assertArrayEquals(this.testData, rightDecrypted); } - } private void testCompatibility(BytesEncryptor left, BytesEncryptor right) { // tests that right can decrypt what left encrypted and vice versa // and that the decypted data is the same as the original for (int size = 1; size < 2048; size++) { - testData = new byte[size]; - secureRandom.nextBytes(testData); - byte[] leftEncrypted = left.encrypt(testData); - byte[] rightEncrypted = right.encrypt(testData); + this.testData = new byte[size]; + this.secureRandom.nextBytes(this.testData); + byte[] leftEncrypted = left.encrypt(this.testData); + byte[] rightEncrypted = right.encrypt(this.testData); byte[] leftDecrypted = left.decrypt(rightEncrypted); byte[] rightDecrypted = right.decrypt(leftEncrypted); - Assert.assertArrayEquals(testData, leftDecrypted); - Assert.assertArrayEquals(testData, rightDecrypted); + Assert.assertArrayEquals(this.testData, leftDecrypted); + Assert.assertArrayEquals(this.testData, rightDecrypted); } } @@ -130,13 +133,15 @@ public class BouncyCastleAesBytesEncryptorEquivalencyTest { this.keyLength = keyLength; } + @Override public int getKeyLength() { - return keyLength; + return this.keyLength; } + @Override public byte[] generateKey() { - byte[] bytes = new byte[keyLength]; - random.nextBytes(bytes); + byte[] bytes = new byte[this.keyLength]; + this.random.nextBytes(bytes); return bytes; } diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorTest.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorTests.java similarity index 74% rename from crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorTest.java rename to crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorTests.java index 3efe507387..88d0712c5c 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorTest.java +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/BouncyCastleAesBytesEncryptorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import java.security.SecureRandom; @@ -22,58 +23,60 @@ import org.bouncycastle.util.Arrays; import org.junit.Assert; import org.junit.Before; import org.junit.Test; + import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.keygen.KeyGenerators; -public class BouncyCastleAesBytesEncryptorTest { +public class BouncyCastleAesBytesEncryptorTests { private byte[] testData; + private String password; + private String salt; @Before public void setup() { // generate random password, salt, and test data SecureRandom secureRandom = new SecureRandom(); - password = UUID.randomUUID().toString(); + this.password = UUID.randomUUID().toString(); byte[] saltBytes = new byte[16]; secureRandom.nextBytes(saltBytes); - salt = new String(Hex.encode(saltBytes)); - testData = new byte[1024 * 1024]; - secureRandom.nextBytes(testData); + this.salt = new String(Hex.encode(saltBytes)); + this.testData = new byte[1024 * 1024]; + secureRandom.nextBytes(this.testData); } @Test public void bcCbcWithSecureIvGeneratesDifferentMessages() { - BytesEncryptor bcEncryptor = new BouncyCastleAesCbcBytesEncryptor(password, salt); + BytesEncryptor bcEncryptor = new BouncyCastleAesCbcBytesEncryptor(this.password, this.salt); generatesDifferentCipherTexts(bcEncryptor); } @Test public void bcGcmWithSecureIvGeneratesDifferentMessages() { - BytesEncryptor bcEncryptor = new BouncyCastleAesGcmBytesEncryptor(password, salt); + BytesEncryptor bcEncryptor = new BouncyCastleAesGcmBytesEncryptor(this.password, this.salt); generatesDifferentCipherTexts(bcEncryptor); } private void generatesDifferentCipherTexts(BytesEncryptor bcEncryptor) { - byte[] encrypted1 = bcEncryptor.encrypt(testData); - byte[] encrypted2 = bcEncryptor.encrypt(testData); + byte[] encrypted1 = bcEncryptor.encrypt(this.testData); + byte[] encrypted2 = bcEncryptor.encrypt(this.testData); Assert.assertFalse(Arrays.areEqual(encrypted1, encrypted2)); byte[] decrypted1 = bcEncryptor.decrypt(encrypted1); byte[] decrypted2 = bcEncryptor.decrypt(encrypted2); - Assert.assertArrayEquals(testData, decrypted1); - Assert.assertArrayEquals(testData, decrypted2); + Assert.assertArrayEquals(this.testData, decrypted1); + Assert.assertArrayEquals(this.testData, decrypted2); } @Test(expected = IllegalArgumentException.class) public void bcCbcWithWrongLengthIv() { - new BouncyCastleAesCbcBytesEncryptor(password, salt, - KeyGenerators.secureRandom(8)); + new BouncyCastleAesCbcBytesEncryptor(this.password, this.salt, KeyGenerators.secureRandom(8)); } @Test(expected = IllegalArgumentException.class) public void bcGcmWithWrongLengthIv() { - new BouncyCastleAesGcmBytesEncryptor(password, salt, - KeyGenerators.secureRandom(8)); + new BouncyCastleAesGcmBytesEncryptor(this.password, this.salt, KeyGenerators.secureRandom(8)); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/CryptoAssumptions.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/CryptoAssumptions.java index 188e61385e..3fca2601c8 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/encrypt/CryptoAssumptions.java +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/CryptoAssumptions.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; import java.security.NoSuchAlgorithmException; @@ -22,9 +23,13 @@ import javax.crypto.NoSuchPaddingException; import org.junit.Assume; import org.junit.AssumptionViolatedException; + import org.springframework.security.crypto.encrypt.AesBytesEncryptor.CipherAlgorithm; -public class CryptoAssumptions { +public final class CryptoAssumptions { + + private CryptoAssumptions() { + } public static void assumeGCMJCE() { assumeAes256(CipherAlgorithm.GCM); @@ -40,18 +45,13 @@ public class CryptoAssumptions { Cipher.getInstance(cipherAlgorithm.toString()); aes256Available = Cipher.getMaxAllowedKeyLength("AES") >= 256; } - catch (NoSuchAlgorithmException e) { - throw new AssumptionViolatedException( - cipherAlgorithm + " not available, skipping test", e); + catch (NoSuchAlgorithmException ex) { + throw new AssumptionViolatedException(cipherAlgorithm + " not available, skipping test", ex); } - catch (NoSuchPaddingException e) { - throw new AssumptionViolatedException( - cipherAlgorithm + " padding not available, skipping test", e); + catch (NoSuchPaddingException ex) { + throw new AssumptionViolatedException(cipherAlgorithm + " padding not available, skipping test", ex); } - Assume.assumeTrue( - "AES key length of 256 not allowed, skipping test", - aes256Available); - + Assume.assumeTrue("AES key length of 256 not allowed, skipping test", aes256Available); } } diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/EncryptorsTests.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/EncryptorsTests.java index d0c69da5ae..470cc796f7 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/encrypt/EncryptorsTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/EncryptorsTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.encrypt; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + public class EncryptorsTests { @Test @@ -29,8 +30,7 @@ public class EncryptorsTests { assertThat(result).isNotNull(); assertThat(new String(result).equals("text")).isFalse(); assertThat(new String(encryptor.decrypt(result))).isEqualTo("text"); - assertThat(new String(result)).isNotEqualTo( - new String(encryptor.encrypt("text".getBytes()))); + assertThat(new String(result)).isNotEqualTo(new String(encryptor.encrypt("text".getBytes()))); } @Test @@ -41,8 +41,7 @@ public class EncryptorsTests { assertThat(result).isNotNull(); assertThat(new String(result).equals("text")).isFalse(); assertThat(new String(encryptor.decrypt(result))).isEqualTo("text"); - assertThat(new String(result)).isNotEqualTo( - new String(encryptor.encrypt("text".getBytes()))); + assertThat(new String(result)).isNotEqualTo(new String(encryptor.encrypt("text".getBytes()))); } @Test @@ -70,8 +69,7 @@ public class EncryptorsTests { @Test public void queryableText() { CryptoAssumptions.assumeCBCJCE(); - TextEncryptor encryptor = Encryptors.queryableText("password", - "5c0744940b5c369b"); + TextEncryptor encryptor = Encryptors.queryableText("password", "5c0744940b5c369b"); String result = encryptor.encrypt("text"); assertThat(result).isNotNull(); assertThat(result.equals("text")).isFalse(); diff --git a/crypto/src/test/java/org/springframework/security/crypto/factory/PasswordEncoderFactoriesTests.java b/crypto/src/test/java/org/springframework/security/crypto/factory/PasswordEncoderFactoriesTests.java index f11207784a..89143fae4e 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/factory/PasswordEncoderFactoriesTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/factory/PasswordEncoderFactoriesTests.java @@ -17,6 +17,7 @@ package org.springframework.security.crypto.factory; import org.junit.Test; + import org.springframework.security.crypto.password.PasswordEncoder; import static org.assertj.core.api.Assertions.assertThat; @@ -26,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 5.0 */ public class PasswordEncoderFactoriesTests { + private PasswordEncoder encoder = PasswordEncoderFactories.createDelegatingPasswordEncoder(); private String rawPassword = "password"; @@ -33,7 +35,6 @@ public class PasswordEncoderFactoriesTests { @Test public void encodeWhenDefaultThenBCryptUsed() { String encodedPassword = this.encoder.encode(this.rawPassword); - assertThat(encodedPassword).startsWith("{bcrypt}"); assertThat(this.encoder.matches(this.rawPassword, encodedPassword)).isTrue(); } diff --git a/crypto/src/test/java/org/springframework/security/crypto/keygen/Base64StringKeyGeneratorTests.java b/crypto/src/test/java/org/springframework/security/crypto/keygen/Base64StringKeyGeneratorTests.java index 951622f2da..ca108238b1 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/keygen/Base64StringKeyGeneratorTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/keygen/Base64StringKeyGeneratorTests.java @@ -16,17 +16,18 @@ package org.springframework.security.crypto.keygen; -import org.junit.Test; - import java.util.Base64; -import static org.assertj.core.api.Assertions.*; +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch * @since 5.0 */ public class Base64StringKeyGeneratorTests { + @Test(expected = IllegalArgumentException.class) public void constructorIntWhenLessThan32ThenIllegalArgumentException() { new Base64StringKeyGenerator(31); @@ -63,4 +64,5 @@ public class Base64StringKeyGeneratorTests { String result = new Base64StringKeyGenerator(Base64.getUrlEncoder(), size).generateKey(); assertThat(Base64.getUrlDecoder().decode(result.getBytes())).hasSize(size); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/keygen/KeyGeneratorsTests.java b/crypto/src/test/java/org/springframework/security/crypto/keygen/KeyGeneratorsTests.java index 665323fe2e..f0ec884464 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/keygen/KeyGeneratorsTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/keygen/KeyGeneratorsTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.keygen; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.crypto.keygen; import java.util.Arrays; import org.junit.Test; + import org.springframework.security.crypto.codec.Hex; +import static org.assertj.core.api.Assertions.assertThat; + public class KeyGeneratorsTests { @Test diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/DelegatingPasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/DelegatingPasswordEncoderTests.java index 8bca684658..e123f27c5f 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/DelegatingPasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/DelegatingPasswordEncoderTests.java @@ -16,22 +16,22 @@ package org.springframework.security.crypto.password; +import java.util.HashMap; +import java.util.Hashtable; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import java.util.HashMap; -import java.util.Hashtable; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -40,6 +40,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class DelegatingPasswordEncoderTests { + @Mock private PasswordEncoder bcrypt; @@ -68,7 +69,6 @@ public class DelegatingPasswordEncoderTests { this.delegates = new HashMap<>(); this.delegates.put(this.bcryptId, this.bcrypt); this.delegates.put("noop", this.noop); - this.passwordEncoder = new DelegatingPasswordEncoder(this.bcryptId, this.delegates); } @@ -91,95 +91,82 @@ public class DelegatingPasswordEncoderTests { public void matchesWhenCustomDefaultPasswordEncoderForMatchesThenDelegates() { String encodedPassword = "{unmapped}" + this.rawPassword; this.passwordEncoder.setDefaultPasswordEncoderForMatches(this.invalidId); - assertThat(this.passwordEncoder.matches(this.rawPassword, encodedPassword)).isFalse(); - verify(this.invalidId).matches(this.rawPassword, encodedPassword); verifyZeroInteractions(this.bcrypt, this.noop); } @Test public void encodeWhenValidThenUsesIdForEncode() { - when(this.bcrypt.encode(this.rawPassword)).thenReturn(this.encodedPassword); - + given(this.bcrypt.encode(this.rawPassword)).willReturn(this.encodedPassword); assertThat(this.passwordEncoder.encode(this.rawPassword)).isEqualTo(this.bcryptEncodedPassword); } @Test public void matchesWhenBCryptThenDelegatesToBCrypt() { - when(this.bcrypt.matches(this.rawPassword, this.encodedPassword)).thenReturn(true); - + given(this.bcrypt.matches(this.rawPassword, this.encodedPassword)).willReturn(true); assertThat(this.passwordEncoder.matches(this.rawPassword, this.bcryptEncodedPassword)).isTrue(); - verify(this.bcrypt).matches(this.rawPassword, this.encodedPassword); verifyZeroInteractions(this.noop); } @Test public void matchesWhenNoopThenDelegatesToNoop() { - when(this.noop.matches(this.rawPassword, this.encodedPassword)).thenReturn(true); - + given(this.noop.matches(this.rawPassword, this.encodedPassword)).willReturn(true); assertThat(this.passwordEncoder.matches(this.rawPassword, this.noopEncodedPassword)).isTrue(); - verify(this.noop).matches(this.rawPassword, this.encodedPassword); verifyZeroInteractions(this.bcrypt); } @Test public void matchesWhenUnMappedThenIllegalArgumentException() { - assertThatThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "{unmapped}" + this.rawPassword)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("There is no PasswordEncoder mapped for the id \"unmapped\""); - + assertThatIllegalArgumentException() + .isThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "{unmapped}" + this.rawPassword)) + .withMessage("There is no PasswordEncoder mapped for the id \"unmapped\""); verifyZeroInteractions(this.bcrypt, this.noop); } @Test public void matchesWhenNoClosingPrefixStringThenIllegalArgumentExcetion() { - assertThatThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "{bcrypt" + this.rawPassword)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("There is no PasswordEncoder mapped for the id \"null\""); - + assertThatIllegalArgumentException() + .isThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "{bcrypt" + this.rawPassword)) + .withMessage("There is no PasswordEncoder mapped for the id \"null\""); verifyZeroInteractions(this.bcrypt, this.noop); } @Test public void matchesWhenNoStartingPrefixStringThenFalse() { - assertThatThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "bcrypt}" + this.rawPassword)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("There is no PasswordEncoder mapped for the id \"null\""); - + assertThatIllegalArgumentException() + .isThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "bcrypt}" + this.rawPassword)) + .withMessage("There is no PasswordEncoder mapped for the id \"null\""); verifyZeroInteractions(this.bcrypt, this.noop); } @Test public void matchesWhenNoIdStringThenFalse() { - assertThatThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "{}" + this.rawPassword)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("There is no PasswordEncoder mapped for the id \"\""); - + assertThatIllegalArgumentException() + .isThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "{}" + this.rawPassword)) + .withMessage("There is no PasswordEncoder mapped for the id \"\""); verifyZeroInteractions(this.bcrypt, this.noop); } @Test public void matchesWhenPrefixInMiddleThenFalse() { - assertThatThrownBy(() -> this.passwordEncoder.matches(this.rawPassword, "invalid" + this.bcryptEncodedPassword)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("There is no PasswordEncoder mapped for the id \"null\""); - + assertThatIllegalArgumentException() + .isThrownBy( + () -> this.passwordEncoder.matches(this.rawPassword, "invalid" + this.bcryptEncodedPassword)) + .isInstanceOf(IllegalArgumentException.class) + .withMessage("There is no PasswordEncoder mapped for the id \"null\""); verifyZeroInteractions(this.bcrypt, this.noop); } @Test public void matchesWhenIdIsNullThenFalse() { this.delegates = new Hashtable<>(this.delegates); - DelegatingPasswordEncoder passwordEncoder = new DelegatingPasswordEncoder(this.bcryptId, this.delegates); - - assertThatThrownBy(() -> passwordEncoder.matches(this.rawPassword, this.rawPassword)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("There is no PasswordEncoder mapped for the id \"null\""); - + assertThatIllegalArgumentException() + .isThrownBy(() -> passwordEncoder.matches(this.rawPassword, this.rawPassword)) + .withMessage("There is no PasswordEncoder mapped for the id \"null\""); verifyZeroInteractions(this.bcrypt, this.noop); } @@ -187,10 +174,8 @@ public class DelegatingPasswordEncoderTests { public void matchesWhenNullIdThenDelegatesToInvalidId() { this.delegates.put(null, this.invalidId); this.passwordEncoder = new DelegatingPasswordEncoder(this.bcryptId, this.delegates); - when(this.invalidId.matches(this.rawPassword, this.encodedPassword)).thenReturn(true); - + given(this.invalidId.matches(this.rawPassword, this.encodedPassword)).willReturn(true); assertThat(this.passwordEncoder.matches(this.rawPassword, this.encodedPassword)).isTrue(); - verify(this.invalidId).matches(this.rawPassword, this.encodedPassword); verifyZeroInteractions(this.bcrypt, this.noop); } @@ -212,29 +197,26 @@ public class DelegatingPasswordEncoderTests { @Test public void upgradeEncodingWhenIdInvalidFormatThenTrue() { - assertThat(this.passwordEncoder.upgradeEncoding("{bcrypt"+ this.encodedPassword)).isTrue(); + assertThat(this.passwordEncoder.upgradeEncoding("{bcrypt" + this.encodedPassword)).isTrue(); } @Test public void upgradeEncodingWhenSameIdAndEncoderFalseThenEncoderDecidesFalse() { assertThat(this.passwordEncoder.upgradeEncoding(this.bcryptEncodedPassword)).isFalse(); - - verify(bcrypt).upgradeEncoding(this.encodedPassword); + verify(this.bcrypt).upgradeEncoding(this.encodedPassword); } @Test public void upgradeEncodingWhenSameIdAndEncoderTrueThenEncoderDecidesTrue() { - when(this.bcrypt.upgradeEncoding(any())).thenReturn(true); - + given(this.bcrypt.upgradeEncoding(any())).willReturn(true); assertThat(this.passwordEncoder.upgradeEncoding(this.bcryptEncodedPassword)).isTrue(); - - verify(bcrypt).upgradeEncoding(this.encodedPassword); + verify(this.bcrypt).upgradeEncoding(this.encodedPassword); } @Test public void upgradeEncodingWhenDifferentIdThenTrue() { assertThat(this.passwordEncoder.upgradeEncoding(this.noopEncodedPassword)).isTrue(); - - verifyZeroInteractions(bcrypt); + verifyZeroInteractions(this.bcrypt); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/DigesterTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/DigesterTests.java index 1f270dc2d8..b5a2c24351 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/DigesterTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/DigesterTests.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.security.crypto.codec.Hex; import org.springframework.security.crypto.codec.Utf8; -import org.springframework.security.crypto.password.Digester; + +import static org.assertj.core.api.Assertions.assertThat; public class DigesterTests { diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/LdapShaPasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/LdapShaPasswordEncoderTests.java index 36e900aabf..c2ba10087f 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/LdapShaPasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/LdapShaPasswordEncoderTests.java @@ -17,6 +17,7 @@ package org.springframework.security.crypto.password; import org.junit.Test; + import org.springframework.security.crypto.keygen.KeyGenerators; import static org.assertj.core.api.Assertions.assertThat; @@ -28,14 +29,9 @@ import static org.assertj.core.api.Assertions.assertThat; */ @SuppressWarnings("deprecation") public class LdapShaPasswordEncoderTests { - // ~ Instance fields - // ================================================================================================ LdapShaPasswordEncoder sha = new LdapShaPasswordEncoder(); - // ~ Methods - // ======================================================================================================== - @Test public void invalidPasswordFails() { assertThat(this.sha.matches("wrongpassword", "{SHA}ddSFGmjXYPbZC+NXR2kCzBRjqiE=")).isFalse(); @@ -87,14 +83,11 @@ public class LdapShaPasswordEncoderTests { public void correctPrefixCaseIsUsed() { this.sha.setForceLowerCasePrefix(false); assertThat(this.sha.encode("somepassword").startsWith("{SSHA}")); - this.sha.setForceLowerCasePrefix(true); assertThat(this.sha.encode("somepassword").startsWith("{ssha}")); - this.sha = new LdapShaPasswordEncoder(KeyGenerators.shared(0)); this.sha.setForceLowerCasePrefix(false); assertThat(this.sha.encode("somepassword").startsWith("{SHA}")); - this.sha.setForceLowerCasePrefix(true); assertThat(this.sha.encode("somepassword").startsWith("{SSHA}")); } @@ -109,4 +102,5 @@ public class LdapShaPasswordEncoderTests { // No right brace this.sha.matches("somepassword", "{SSHA25ro4PKC8jhQZ26jVsozhX/xaP0suHgX"); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/Md4PasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/Md4PasswordEncoderTests.java index 4a4d9fe3f8..a1de26c6c3 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/Md4PasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/Md4PasswordEncoderTests.java @@ -16,11 +16,10 @@ package org.springframework.security.crypto.password; +import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import org.junit.Test; - @SuppressWarnings("deprecation") public class Md4PasswordEncoderTests { @@ -63,7 +62,6 @@ public class Md4PasswordEncoderTests { String rawPassword = "password"; Md4PasswordEncoder md4 = new Md4PasswordEncoder(); String encodedPassword = md4.encode(rawPassword); - assertThat(md4.matches(rawPassword, encodedPassword)).isTrue(); } @@ -72,5 +70,5 @@ public class Md4PasswordEncoderTests { Md4PasswordEncoder encoder = new Md4PasswordEncoder(); assertThat(encoder.matches("password", "{thisissalt}6cc7924dad12ade79dfb99e424f25260")); } -} +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoderTests.java index 60d3001959..057545ca41 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/MessageDigestPasswordEncoderTests.java @@ -32,14 +32,12 @@ import static org.assertj.core.api.Assertions.assertThat; */ @SuppressWarnings("deprecation") public class MessageDigestPasswordEncoderTests { - // ~ Methods - // ======================================================================================================== @Test public void md5BasicFunctionality() { MessageDigestPasswordEncoder pe = new MessageDigestPasswordEncoder("MD5"); String raw = "abc123"; - assertThat(pe.matches( raw, "{THIS_IS_A_SALT}a68aafd90299d0b137de28fb4bb68573")).isTrue(); + assertThat(pe.matches(raw, "{THIS_IS_A_SALT}a68aafd90299d0b137de28fb4bb68573")).isTrue(); } @Test @@ -97,7 +95,6 @@ public class MessageDigestPasswordEncoderTests { MessageDigestPasswordEncoder pe = new MessageDigestPasswordEncoder("SHA-1"); String raw = "abc123"; assertThat(pe.matches(raw, "{THIS_IS_A_SALT}b2f50ffcbd3407fe9415c062d55f54731f340d32")); - } @Test @@ -119,4 +116,5 @@ public class MessageDigestPasswordEncoderTests { public void testInvalidStrength() { new MessageDigestPasswordEncoder("SHA-666"); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/PasswordEncoderUtilsTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/PasswordEncoderUtilsTests.java index 183bb1f088..beeadc065a 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/PasswordEncoderUtilsTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/PasswordEncoderUtilsTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; import org.junit.Test; @@ -67,4 +68,5 @@ public class PasswordEncoderUtilsTests { public void equalsWhenSameThenTrue() { assertThat(PasswordEncoderUtils.equals("abcdef", "abcdef")).isTrue(); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoderTests.java index 51a499ebaf..bd54171718 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/Pbkdf2PasswordEncoderTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; import java.util.Arrays; @@ -25,6 +26,7 @@ import org.springframework.security.crypto.keygen.KeyGenerators; import static org.assertj.core.api.Assertions.assertThat; public class Pbkdf2PasswordEncoderTests { + private Pbkdf2PasswordEncoder encoder = new Pbkdf2PasswordEncoder("secret"); @Test @@ -37,8 +39,7 @@ public class Pbkdf2PasswordEncoderTests { @Test public void matchesLengthChecked() { String result = this.encoder.encode("password"); - assertThat(this.encoder.matches("password", - result.substring(0, result.length() - 2))).isFalse(); + assertThat(this.encoder.matches("password", result.substring(0, result.length() - 2))).isFalse(); } @Test @@ -68,17 +69,14 @@ public class Pbkdf2PasswordEncoderTests { String encodedPassword = "ab1146a8458d4ce4e65789e5a3f60e423373cfa10b01abd23739e5ae2fdc37f8e9ede4ae6da65264"; String originalEncodedPassword = "ab1146a8458d4ce4ab1146a8458d4ce4e65789e5a3f60e423373cfa10b01abd23739e5ae2fdc37f8e9ede4ae6da65264"; byte[] originalBytes = Hex.decode(originalEncodedPassword); - byte[] fixedBytes = Arrays.copyOfRange(originalBytes, saltLength, - originalBytes.length); + byte[] fixedBytes = Arrays.copyOfRange(originalBytes, saltLength, originalBytes.length); String fixedHex = String.valueOf(Hex.encode(fixedBytes)); - assertThat(fixedHex).isEqualTo(encodedPassword); } @Test public void encodeAndMatchWhenBase64ThenSuccess() { this.encoder.setEncodeHashAsBase64(true); - String rawPassword = "password"; String encodedPassword = this.encoder.encode(rawPassword); assertThat(this.encoder.matches(rawPassword, encodedPassword)).isTrue(); @@ -89,15 +87,14 @@ public class Pbkdf2PasswordEncoderTests { this.encoder.setEncodeHashAsBase64(true); String rawPassword = "password"; String encodedPassword = "3FOwOMcDgxP+z1x/sv184LFY2WVD+ZGMgYP3LPOSmCcDmk1XPYvcCQ=="; - assertThat(this.encoder.matches(rawPassword, encodedPassword)).isTrue(); - java.util.Base64.getDecoder().decode(encodedPassword); // validate can decode as Base64 + java.util.Base64.getDecoder().decode(encodedPassword); // validate can decode as + // Base64 } @Test public void encodeAndMatchWhenSha256ThenSuccess() { this.encoder.setAlgorithm(Pbkdf2PasswordEncoder.SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA256); - String rawPassword = "password"; String encodedPassword = this.encoder.encode(rawPassword); assertThat(this.encoder.matches(rawPassword, encodedPassword)).isTrue(); @@ -106,11 +103,11 @@ public class Pbkdf2PasswordEncoderTests { @Test public void matchWhenSha256ThenSuccess() { this.encoder.setAlgorithm(Pbkdf2PasswordEncoder.SecretKeyFactoryAlgorithm.PBKDF2WithHmacSHA256); - String rawPassword = "password"; String encodedPassword = "821447f994e2b04c5014e31fa9fca4ae1cc9f2188c4ed53d3ddb5ba7980982b51a0ecebfc0b81a79"; assertThat(this.encoder.matches(rawPassword, encodedPassword)).isTrue(); } + /** * Used to find the iteration count that takes .5 seconds. */ @@ -126,8 +123,7 @@ public class Pbkdf2PasswordEncoderTests { long avg = 0; while (avg < HALF_SECOND) { iterations += 10000; - Pbkdf2PasswordEncoder encoder = new Pbkdf2PasswordEncoder("", iterations, - 256); + Pbkdf2PasswordEncoder encoder = new Pbkdf2PasswordEncoder("", iterations, 256); String encoded = encoder.encode("password"); System.out.println("Trying " + iterations); long start = System.currentTimeMillis(); @@ -141,4 +137,5 @@ public class Pbkdf2PasswordEncoderTests { } System.out.println("Iterations " + iterations); } + } diff --git a/crypto/src/test/java/org/springframework/security/crypto/password/StandardPasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/password/StandardPasswordEncoderTests.java index ed9bc39a3d..0c89dbcfd5 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/password/StandardPasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/password/StandardPasswordEncoderTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.password; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + @SuppressWarnings("deprecation") public class StandardPasswordEncoderTests { @@ -26,21 +27,21 @@ public class StandardPasswordEncoderTests { @Test public void matches() { - String result = encoder.encode("password"); + String result = this.encoder.encode("password"); assertThat(result).isNotEqualTo("password"); - assertThat(encoder.matches("password", result)).isTrue(); + assertThat(this.encoder.matches("password", result)).isTrue(); } @Test public void matchesLengthChecked() { - String result = encoder.encode("password"); - assertThat(encoder.matches("password", result.substring(0, result.length() - 2))).isFalse(); + String result = this.encoder.encode("password"); + assertThat(this.encoder.matches("password", result.substring(0, result.length() - 2))).isFalse(); } @Test public void notMatches() { - String result = encoder.encode("password"); - assertThat(encoder.matches("bogus", result)).isFalse(); + String result = this.encoder.encode("password"); + assertThat(this.encoder.matches("bogus", result)).isFalse(); } } diff --git a/crypto/src/test/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoderTests.java b/crypto/src/test/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoderTests.java index 5ac30814bc..6dcd99865a 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoderTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/scrypt/SCryptPasswordEncoderTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.crypto.scrypt; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Shazin Sadakath * @@ -67,7 +68,6 @@ public class SCryptPasswordEncoderTests { public void samePasswordWithDifferentParams() { SCryptPasswordEncoder oldEncoder = new SCryptPasswordEncoder(16384, 8, 1, 32, 64); SCryptPasswordEncoder newEncoder = new SCryptPasswordEncoder(); - String password = "secret"; String oldEncodedPassword = oldEncoder.encode(password); assertThat(newEncoder.matches(password, oldEncodedPassword)).isTrue(); @@ -139,10 +139,8 @@ public class SCryptPasswordEncoderTests { public void upgradeEncodingWhenWeakerToStrongerThenFalse() { SCryptPasswordEncoder weakEncoder = new SCryptPasswordEncoder((int) Math.pow(2, 10), 4, 1, 32, 64); SCryptPasswordEncoder strongEncoder = new SCryptPasswordEncoder((int) Math.pow(2, 16), 8, 1, 32, 64); - String weakPassword = weakEncoder.encode("password"); String strongPassword = strongEncoder.encode("password"); - assertThat(weakEncoder.upgradeEncoding(strongPassword)).isFalse(); } @@ -150,10 +148,8 @@ public class SCryptPasswordEncoderTests { public void upgradeEncodingWhenStrongerToWeakerThenTrue() { SCryptPasswordEncoder weakEncoder = new SCryptPasswordEncoder((int) Math.pow(2, 10), 4, 1, 32, 64); SCryptPasswordEncoder strongEncoder = new SCryptPasswordEncoder((int) Math.pow(2, 16), 8, 1, 32, 64); - String weakPassword = weakEncoder.encode("password"); String strongPassword = strongEncoder.encode("password"); - assertThat(strongEncoder.upgradeEncoding(weakPassword)).isTrue(); } @@ -161,5 +157,5 @@ public class SCryptPasswordEncoderTests { public void upgradeEncodingWhenInvalidInputThenException() { new SCryptPasswordEncoder().upgradeEncoding("not-a-scrypt-password"); } -} +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/util/EncodingUtilsTests.java b/crypto/src/test/java/org/springframework/security/crypto/util/EncodingUtilsTests.java index 3c6fe20f87..cc233a211e 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/util/EncodingUtilsTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/util/EncodingUtilsTests.java @@ -13,48 +13,49 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.crypto.util; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.crypto.util; import java.util.Arrays; import org.junit.Test; + import org.springframework.security.crypto.codec.Hex; +import static org.assertj.core.api.Assertions.assertThat; + public class EncodingUtilsTests { @Test public void hexEncode() { - byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, - (byte) 67, (byte) 0xC0, (byte) 0xC1, (byte) 0xC2 }; + byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, (byte) 67, (byte) 0xC0, (byte) 0xC1, + (byte) 0xC2 }; String result = new String(Hex.encode(bytes)); assertThat(result).isEqualTo("01ff414243c0c1c2"); } @Test public void hexDecode() { - byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, - (byte) 67, (byte) 0xC0, (byte) 0xC1, (byte) 0xC2 }; + byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, (byte) 67, (byte) 0xC0, (byte) 0xC1, + (byte) 0xC2 }; byte[] result = Hex.decode("01ff414243c0c1c2"); assertThat(Arrays.equals(bytes, result)).isTrue(); } @Test public void concatenate() { - byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, - (byte) 67, (byte) 0xC0, (byte) 0xC1, (byte) 0xC2 }; + byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, (byte) 67, (byte) 0xC0, (byte) 0xC1, + (byte) 0xC2 }; byte[] one = new byte[] { (byte) 0x01 }; byte[] two = new byte[] { (byte) 0xFF, (byte) 65, (byte) 66 }; byte[] three = new byte[] { (byte) 67, (byte) 0xC0, (byte) 0xC1, (byte) 0xC2 }; - assertThat(Arrays.equals(bytes, - EncodingUtils.concatenate(one, two, three))).isTrue(); + assertThat(Arrays.equals(bytes, EncodingUtils.concatenate(one, two, three))).isTrue(); } @Test public void subArray() { - byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, - (byte) 67, (byte) 0xC0, (byte) 0xC1, (byte) 0xC2 }; + byte[] bytes = new byte[] { (byte) 0x01, (byte) 0xFF, (byte) 65, (byte) 66, (byte) 67, (byte) 0xC0, (byte) 0xC1, + (byte) 0xC2 }; byte[] two = new byte[] { (byte) 0xFF, (byte) 65, (byte) 66 }; byte[] subArray = EncodingUtils.subArray(bytes, 1, 4); assertThat(subArray).hasSize(3); diff --git a/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java b/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java index 936ccf6f25..3696904a9a 100644 --- a/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java +++ b/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.data.repository.query; import org.springframework.data.spel.spi.EvaluationContextExtension; @@ -75,11 +76,11 @@ import org.springframework.security.core.context.SecurityContextHolder; * This works because the principal in this instance is a User which has an id field on * it. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ -public class SecurityEvaluationContextExtension - implements EvaluationContextExtension { +public class SecurityEvaluationContextExtension implements EvaluationContextExtension { + private Authentication authentication; /** @@ -91,13 +92,13 @@ public class SecurityEvaluationContextExtension /** * Creates a new instance that always uses the same {@link Authentication} object. - * * @param authentication the {@link Authentication} to use */ public SecurityEvaluationContextExtension(Authentication authentication) { this.authentication = authentication; } + @Override public String getExtensionId() { return "security"; } @@ -113,8 +114,8 @@ public class SecurityEvaluationContextExtension if (this.authentication != null) { return this.authentication; } - SecurityContext context = SecurityContextHolder.getContext(); return context.getAuthentication(); } + } diff --git a/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java b/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java index 2efcb97c7f..b4937afebb 100644 --- a/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java +++ b/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.data.repository.query; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.security.access.expression.SecurityExpressionRoot; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; @@ -25,11 +27,12 @@ import org.springframework.security.core.context.SecurityContextHolder; import static org.assertj.core.api.Assertions.assertThat; public class SecurityEvaluationContextExtensionTests { + SecurityEvaluationContextExtension securityExtension; @Before public void setup() { - securityExtension = new SecurityEvaluationContextExtension(); + this.securityExtension = new SecurityEvaluationContextExtension(); } @After @@ -44,36 +47,29 @@ public class SecurityEvaluationContextExtensionTests { @Test public void getRootObjectSecurityContextHolderAuthentication() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken( - "user", "password", "ROLE_USER"); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); SecurityContextHolder.getContext().setAuthentication(authentication); - assertThat(getRoot().getAuthentication()).isSameAs(authentication); } @Test public void getRootObjectExplicitAuthenticationOverridesSecurityContextHolder() { - TestingAuthenticationToken explicit = new TestingAuthenticationToken("explicit", - "password", "ROLE_EXPLICIT"); - securityExtension = new SecurityEvaluationContextExtension(explicit); - - TestingAuthenticationToken authentication = new TestingAuthenticationToken( - "user", "password", "ROLE_USER"); + TestingAuthenticationToken explicit = new TestingAuthenticationToken("explicit", "password", "ROLE_EXPLICIT"); + this.securityExtension = new SecurityEvaluationContextExtension(explicit); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); SecurityContextHolder.getContext().setAuthentication(authentication); - assertThat(getRoot().getAuthentication()).isSameAs(explicit); } @Test public void getRootObjectExplicitAuthentication() { - TestingAuthenticationToken explicit = new TestingAuthenticationToken("explicit", - "password", "ROLE_EXPLICIT"); - securityExtension = new SecurityEvaluationContextExtension(explicit); - + TestingAuthenticationToken explicit = new TestingAuthenticationToken("explicit", "password", "ROLE_EXPLICIT"); + this.securityExtension = new SecurityEvaluationContextExtension(explicit); assertThat(getRoot().getAuthentication()).isSameAs(explicit); } private SecurityExpressionRoot getRoot() { - return (SecurityExpressionRoot) securityExtension.getRootObject(); + return this.securityExtension.getRootObject(); } -} \ No newline at end of file + +} diff --git a/docs/manual/src/docs/asciidoc/_includes/reactive/oauth2/resource-server.adoc b/docs/manual/src/docs/asciidoc/_includes/reactive/oauth2/resource-server.adoc index 59eca2b899..1c5ff1821b 100644 --- a/docs/manual/src/docs/asciidoc/_includes/reactive/oauth2/resource-server.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/reactive/oauth2/resource-server.adoc @@ -896,9 +896,9 @@ fun jwtDecoder(): ReactiveJwtDecoder { [[webflux-oauth2resourceserver-opaque-minimaldependencies]] === Minimal Dependencies for Introspection -As described in <> most of Resource Server support is collected in `spring-security-oauth2-resource-server`. -However unless a custom <> is provided, the Resource Server will fallback to ReactiveOpaqueTokenIntrospector. -Meaning that both `spring-security-oauth2-resource-server` and `oauth2-oidc-sdk` are necessary in order to have a working minimal Resource Server that supports opaque Bearer Tokens. +As described in <> most of Resource Server support is collected in `spring-security-oauth2-resource-server`. +However unless a custom <> is provided, the Resource Server will fallback to ReactiveOpaqueTokenIntrospector. +Meaning that both `spring-security-oauth2-resource-server` and `oauth2-oidc-sdk` are necessary in order to have a working minimal Resource Server that supports opaque Bearer Tokens. Please refer to `spring-security-oauth2-resource-server` in order to determin the correct version for `oauth2-oidc-sdk`. [[webflux-oauth2resourceserver-opaque-minimalconfiguration]] @@ -1486,8 +1486,8 @@ public class JwtOpaqueTokenIntrospector implements ReactiveOpaqueTokenIntrospect public Mono convert(JWT jwt) { try { return Mono.just(jwt.getJWTClaimsSet()); - } catch (Exception e) { - return Mono.error(e); + } catch (Exception ex) { + return Mono.error(ex); } } } diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/architecture/exception-translation-filter.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/architecture/exception-translation-filter.adoc index 5fa7ecc452..67cad58ddf 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/architecture/exception-translation-filter.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/architecture/exception-translation-filter.adoc @@ -36,8 +36,8 @@ The pseudocode for `ExceptionTranslationFilter` looks something like this: ---- try { filterChain.doFilter(request, response); // <1> -} catch (AccessDeniedException | AuthenticationException e) { - if (!authenticated || e instanceof AuthenticationException) { +} catch (AccessDeniedException | AuthenticationException ex) { + if (!authenticated || ex instanceof AuthenticationException) { startAuthentication(); // <2> } else { accessDenied(); // <3> diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/integrations/servlet-api.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/integrations/servlet-api.adoc index 9e957f39a0..e57d545c15 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/integrations/servlet-api.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/integrations/servlet-api.adoc @@ -75,7 +75,7 @@ For example, the following would attempt to authenticate with the username "user ---- try { httpServletRequest.login("user","password"); -} catch(ServletException e) { +} catch(ServletException ex) { // fail to authenticate } ---- @@ -111,8 +111,8 @@ async.start(new Runnable() { asyncResponse.setStatus(HttpServletResponse.SC_OK); asyncResponse.getWriter().write(String.valueOf(authentication)); async.complete(); - } catch(Exception e) { - throw new RuntimeException(e); + } catch(Exception ex) { + throw new RuntimeException(ex); } } }); @@ -174,8 +174,8 @@ new Thread("AsyncThread") { // Write to and commit the httpServletResponse httpServletResponse.getOutputStream().flush(); - } catch (Exception e) { - e.printStackTrace(); + } catch (Exception ex) { + ex.printStackTrace(); } } }.start(); diff --git a/docs/manual/src/docs/asciidoc/_includes/servlet/oauth2/oauth2-resourceserver.adoc b/docs/manual/src/docs/asciidoc/_includes/servlet/oauth2/oauth2-resourceserver.adoc index e7aa3e2624..0aae178bcd 100644 --- a/docs/manual/src/docs/asciidoc/_includes/servlet/oauth2/oauth2-resourceserver.adoc +++ b/docs/manual/src/docs/asciidoc/_includes/servlet/oauth2/oauth2-resourceserver.adoc @@ -1055,9 +1055,9 @@ To do so, remember that `NimbusJwtDecoder` ships with a constructor that takes N [[oauth2resourceserver-opaque-minimaldependencies]] === Minimal Dependencies for Introspection -As described in <> most of Resource Server support is collected in `spring-security-oauth2-resource-server`. -However unless a custom <> is provided, the Resource Server will fallback to NimbusOpaqueTokenIntrospector. -Meaning that both `spring-security-oauth2-resource-server` and `oauth2-oidc-sdk` are necessary in order to have a working minimal Resource Server that supports opaque Bearer Tokens. +As described in <> most of Resource Server support is collected in `spring-security-oauth2-resource-server`. +However unless a custom <> is provided, the Resource Server will fallback to NimbusOpaqueTokenIntrospector. +Meaning that both `spring-security-oauth2-resource-server` and `oauth2-oidc-sdk` are necessary in order to have a working minimal Resource Server that supports opaque Bearer Tokens. Please refer to `spring-security-oauth2-resource-server` in order to determin the correct version for `oauth2-oidc-sdk`. [[oauth2resourceserver-opaque-minimalconfiguration]] @@ -1626,8 +1626,8 @@ public class JwtOpaqueTokenIntrospector implements OpaqueTokenIntrospector { try { Jwt jwt = this.jwtDecoder.decode(token); return new DefaultOAuth2AuthenticatedPrincipal(jwt.getClaims(), NO_AUTHORITIES); - } catch (JwtException e) { - throw new OAuth2IntrospectionException(e); + } catch (JwtException ex) { + throw new OAuth2IntrospectionException(ex); } } @@ -1899,8 +1899,8 @@ public class TenantJWSKeySelector private JWSKeySelector fromUri(String uri) { try { return JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL(new URL(uri)); <4> - } catch (Exception e) { - throw new IllegalArgumentException(e); + } catch (Exception ex) { + throw new IllegalArgumentException(ex); } } } diff --git a/etc/checkstyle/checkstyle-suppressions.xml b/etc/checkstyle/checkstyle-suppressions.xml new file mode 100644 index 0000000000..8ad15ce4af --- /dev/null +++ b/etc/checkstyle/checkstyle-suppressions.xml @@ -0,0 +1,52 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/etc/checkstyle/checkstyle.xml b/etc/checkstyle/checkstyle.xml index d7b392014d..48399d6ea3 100644 --- a/etc/checkstyle/checkstyle.xml +++ b/etc/checkstyle/checkstyle.xml @@ -1,51 +1,26 @@ - - - + + - + - - - - - + + + - - - - - - - - - - - - - + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - diff --git a/etc/checkstyle/header.txt b/etc/checkstyle/header.txt index 5e5d28b99f..a0eb45dbdf 100644 --- a/etc/checkstyle/header.txt +++ b/etc/checkstyle/header.txt @@ -1,5 +1,5 @@ ^\Q/*\E$ -^\Q * Copyright\E (\d{4}(\-\d{4})? the original author or authors\.|(\d{4}, )*(\d{4}) Acegi Technology Pty Limited)$ +^\Q * Copyright \E20\d\d(-20\d\d)?(, \d\d\d\d)*\ (Acegi Technology Pty Limited|the original author or authors.)$ ^\Q *\E$ ^\Q * Licensed under the Apache License, Version 2.0 (the "License");\E$ ^\Q * you may not use this file except in compliance with the License.\E$ @@ -13,4 +13,4 @@ ^\Q * See the License for the specific language governing permissions and\E$ ^\Q * limitations under the License.\E$ ^\Q */\E$ -^.*$ +^$ diff --git a/etc/checkstyle/suppressions.xml b/etc/checkstyle/suppressions.xml deleted file mode 100644 index 297f0624c9..0000000000 --- a/etc/checkstyle/suppressions.xml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - - - - - - - diff --git a/gradle.properties b/gradle.properties index 8030f9eddc..aca683ccc0 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,6 @@ aspectjVersion=1.9.3 gaeVersion=1.9.80 +springJavaformatVersion=0.0.24 springBootVersion=2.4.0-M1 version=5.4.0-SNAPSHOT kotlinVersion=1.3.72 diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java index ecf8f79605..f4f42b4c4a 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpNamespaceWithMultipleInterceptorsTests.java @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.integration; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.integration; import javax.servlet.http.HttpSession; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; @@ -33,6 +33,8 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import static org.assertj.core.api.Assertions.assertThat; + @ContextConfiguration(locations = { "/http-extra-fsi-app-context.xml" }) @RunWith(SpringJUnit4ClassRunner.class) public class HttpNamespaceWithMultipleInterceptorsTests { @@ -47,7 +49,7 @@ public class HttpNamespaceWithMultipleInterceptorsTests { request.setServletPath("/somefile.html"); request.setSession(createAuthenticatedSession("ROLE_0", "ROLE_1", "ROLE_2")); MockHttpServletResponse response = new MockHttpServletResponse(); - fcp.doFilter(request, response, new MockFilterChain()); + this.fcp.doFilter(request, response, new MockFilterChain()); assertThat(response.getStatus()).isEqualTo(200); } @@ -59,16 +61,15 @@ public class HttpNamespaceWithMultipleInterceptorsTests { request.setServletPath("/secure/somefile.html"); request.setSession(createAuthenticatedSession("ROLE_0")); MockHttpServletResponse response = new MockHttpServletResponse(); - fcp.doFilter(request, response, new MockFilterChain()); + this.fcp.doFilter(request, response, new MockFilterChain()); assertThat(response.getStatus()).isEqualTo(403); } public HttpSession createAuthenticatedSession(String... roles) { MockHttpSession session = new MockHttpSession(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("bob", "bobspassword", roles)); - session.setAttribute( - HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("bob", "bobspassword", roles)); + session.setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, SecurityContextHolder.getContext()); SecurityContextHolder.clearContext(); return session; diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpPathParameterStrippingTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpPathParameterStrippingTests.java index 79d1f2c4e7..eebe083e15 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/HttpPathParameterStrippingTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/HttpPathParameterStrippingTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; -import static org.assertj.core.api.Assertions.assertThat; +import javax.servlet.http.HttpSession; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; @@ -32,7 +34,7 @@ import org.springframework.security.web.firewall.RequestRejectedException; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import javax.servlet.http.HttpSession; +import static org.assertj.core.api.Assertions.assertThat; @ContextConfiguration(locations = { "/http-path-param-stripping-app-context.xml" }) @RunWith(SpringJUnit4ClassRunner.class) @@ -42,13 +44,12 @@ public class HttpPathParameterStrippingTests { private FilterChainProxy fcp; @Test(expected = RequestRejectedException.class) - public void securedFilterChainCannotBeBypassedByAddingPathParameters() - throws Exception { + public void securedFilterChainCannotBeBypassedByAddingPathParameters() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.setPathInfo("/secured;x=y/admin.html"); request.setSession(createAuthenticatedSession("ROLE_USER")); MockHttpServletResponse response = new MockHttpServletResponse(); - fcp.doFilter(request, response, new MockFilterChain()); + this.fcp.doFilter(request, response, new MockFilterChain()); } @Test(expected = RequestRejectedException.class) @@ -57,10 +58,9 @@ public class HttpPathParameterStrippingTests { request.setServletPath("/secured/admin.html;x=user.html"); request.setSession(createAuthenticatedSession("ROLE_USER")); MockHttpServletResponse response = new MockHttpServletResponse(); - fcp.doFilter(request, response, new MockFilterChain()); + this.fcp.doFilter(request, response, new MockFilterChain()); } - @Test(expected = RequestRejectedException.class) public void adminFilePatternCannotBeBypassedByAddingPathParametersWithPathInfo() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); @@ -68,16 +68,15 @@ public class HttpPathParameterStrippingTests { request.setPathInfo("/admin.html;x=user.html"); request.setSession(createAuthenticatedSession("ROLE_USER")); MockHttpServletResponse response = new MockHttpServletResponse(); - fcp.doFilter(request, response, new MockFilterChain()); + this.fcp.doFilter(request, response, new MockFilterChain()); assertThat(response.getStatus()).isEqualTo(403); } public HttpSession createAuthenticatedSession(String... roles) { MockHttpSession session = new MockHttpSession(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("bob", "bobspassword", roles)); - session.setAttribute( - HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken("bob", "bobspassword", roles)); + session.setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, SecurityContextHolder.getContext()); SecurityContextHolder.clearContext(); return session; diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/MultiAnnotationTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/MultiAnnotationTests.java index 724790c65a..90fe40b464 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/MultiAnnotationTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/MultiAnnotationTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -35,15 +37,17 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @ContextConfiguration(locations = { "/multi-sec-annotation-app-context.xml" }) @RunWith(SpringJUnit4ClassRunner.class) public class MultiAnnotationTests { - private final TestingAuthenticationToken joe_a = new TestingAuthenticationToken( - "joe", "pass", "ROLE_A"); - private final TestingAuthenticationToken joe_b = new TestingAuthenticationToken( - "joe", "pass", "ROLE_B"); + + private final TestingAuthenticationToken joe_a = new TestingAuthenticationToken("joe", "pass", "ROLE_A"); + + private final TestingAuthenticationToken joe_b = new TestingAuthenticationToken("joe", "pass", "ROLE_B"); @Autowired MultiAnnotationService service; + @Autowired PreAuthorizeService preService; + @Autowired SecuredService secService; @@ -55,49 +59,50 @@ public class MultiAnnotationTests { @Test(expected = AccessDeniedException.class) public void preAuthorizeDeniedIsDenied() { - SecurityContextHolder.getContext().setAuthentication(joe_a); - service.preAuthorizeDenyAllMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_a); + this.service.preAuthorizeDenyAllMethod(); } @Test(expected = AccessDeniedException.class) public void preAuthorizeRoleAIsDeniedIfRoleMissing() { - SecurityContextHolder.getContext().setAuthentication(joe_b); - service.preAuthorizeHasRoleAMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_b); + this.service.preAuthorizeHasRoleAMethod(); } @Test public void preAuthorizeRoleAIsAllowedIfRolePresent() { - SecurityContextHolder.getContext().setAuthentication(joe_a); - service.preAuthorizeHasRoleAMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_a); + this.service.preAuthorizeHasRoleAMethod(); } @Test public void securedAnonymousIsAllowed() { - SecurityContextHolder.getContext().setAuthentication(joe_a); - service.securedAnonymousMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_a); + this.service.securedAnonymousMethod(); } @Test(expected = AccessDeniedException.class) public void securedRoleAIsDeniedIfRoleMissing() { - SecurityContextHolder.getContext().setAuthentication(joe_b); - service.securedRoleAMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_b); + this.service.securedRoleAMethod(); } @Test public void securedRoleAIsAllowedIfRolePresent() { - SecurityContextHolder.getContext().setAuthentication(joe_a); - service.securedRoleAMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_a); + this.service.securedRoleAMethod(); } @Test(expected = AccessDeniedException.class) public void preAuthorizedOnlyServiceDeniesIfRoleMissing() { - SecurityContextHolder.getContext().setAuthentication(joe_b); - preService.preAuthorizedMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_b); + this.preService.preAuthorizedMethod(); } @Test(expected = AccessDeniedException.class) public void securedOnlyRoleAServiceDeniesIfRoleMissing() { - SecurityContextHolder.getContext().setAuthentication(joe_b); - secService.securedMethod(); + SecurityContextHolder.getContext().setAuthentication(this.joe_b); + this.secService.securedMethod(); } + } diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/SEC933ApplicationContextTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/SEC933ApplicationContextTests.java index 00c4fc62e2..dd9fc54bd4 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/SEC933ApplicationContextTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/SEC933ApplicationContextTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.integration; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.integration; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import static org.assertj.core.api.Assertions.assertThat; + @ContextConfiguration(locations = { "/sec-933-app-context.xml" }) @RunWith(SpringJUnit4ClassRunner.class) public class SEC933ApplicationContextTests { @@ -33,6 +35,7 @@ public class SEC933ApplicationContextTests { @Test public void testSimpleApplicationContextBootstrap() { - assertThat(userDetailsService).isNotNull(); + assertThat(this.userDetailsService).isNotNull(); } + } diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/SEC936ApplicationContextTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/SEC936ApplicationContextTests.java index 861a98139d..03414f32ea 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/SEC936ApplicationContextTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/SEC936ApplicationContextTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -32,15 +34,18 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @ContextConfiguration(locations = { "/sec-936-app-context.xml" }) @RunWith(SpringJUnit4ClassRunner.class) public class SEC936ApplicationContextTests { + @Autowired - /** SessionRegistry is used as the test service interface (nothing to do with the test) */ + /** + * SessionRegistry is used as the test service interface (nothing to do with the test) + */ private SessionRegistry sessionRegistry; @Test(expected = AccessDeniedException.class) public void securityInterceptorHandlesCallWithNoTargetObject() { - SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("bob", "bobspassword")); - sessionRegistry.getAllPrincipals(); + SecurityContextHolder.getContext() + .setAuthentication(new UsernamePasswordAuthenticationToken("bob", "bobspassword")); + this.sessionRegistry.getAllPrincipals(); } } diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/StubUserRepository.java b/itest/context/src/integration-test/java/org/springframework/security/integration/StubUserRepository.java index 099fe44f09..a61081d7e4 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/StubUserRepository.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/StubUserRepository.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; public class StubUserRepository implements UserRepository { + @Override public void doSomething() { } -} \ No newline at end of file + +} diff --git a/itest/context/src/integration-test/java/org/springframework/security/integration/python/PythonInterpreterBasedSecurityTests.java b/itest/context/src/integration-test/java/org/springframework/security/integration/python/PythonInterpreterBasedSecurityTests.java index df288f1f4e..b01eeb24c2 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/integration/python/PythonInterpreterBasedSecurityTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/integration/python/PythonInterpreterBasedSecurityTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; @@ -32,11 +34,12 @@ public class PythonInterpreterBasedSecurityTests { @Test public void serviceMethod() { - SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("bob", "bobspassword")); + SecurityContextHolder.getContext() + .setAuthentication(new UsernamePasswordAuthenticationToken("bob", "bobspassword")); // for (int i=0; i < 1000; i++) { - service.someMethod(); + this.service.someMethod(); // } } + } diff --git a/itest/context/src/integration-test/java/org/springframework/security/performance/FilterChainPerformanceTests.java b/itest/context/src/integration-test/java/org/springframework/security/performance/FilterChainPerformanceTests.java index 9155ef1794..eb0f760253 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/performance/FilterChainPerformanceTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/performance/FilterChainPerformanceTests.java @@ -13,10 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.performance; -import org.junit.*; +import java.util.Arrays; +import java.util.List; + +import javax.servlet.http.HttpSession; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.mock.web.MockFilterChain; @@ -33,24 +43,24 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.util.StopWatch; -import javax.servlet.http.HttpSession; -import java.util.*; - /** - * * @author Luke Taylor * @since 2.0 */ @ContextConfiguration(locations = { "/filter-chain-performance-app-context.xml" }) @RunWith(SpringJUnit4ClassRunner.class) public class FilterChainPerformanceTests { + // Adjust as required private static final int N_INVOCATIONS = 1; // 1000 + private static final int N_AUTHORITIES = 2; // 200 + private static StopWatch sw = new StopWatch("Filter Chain Performance Tests"); - private final UsernamePasswordAuthenticationToken user = new UsernamePasswordAuthenticationToken( - "bob", "bobspassword", createRoles(N_AUTHORITIES)); + private final UsernamePasswordAuthenticationToken user = new UsernamePasswordAuthenticationToken("bob", + "bobspassword", createRoles(N_AUTHORITIES)); + private HttpSession session; @Autowired @@ -63,10 +73,9 @@ public class FilterChainPerformanceTests { @Before public void createAuthenticatedSession() { - session = new MockHttpSession(); - SecurityContextHolder.getContext().setAuthentication(user); - session.setAttribute( - HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + this.session = new MockHttpSession(); + SecurityContextHolder.getContext().setAuthentication(this.user); + this.session.setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, SecurityContextHolder.getContext()); SecurityContextHolder.clearContext(); } @@ -83,7 +92,7 @@ public class FilterChainPerformanceTests { private MockHttpServletRequest createRequest(String url) { MockHttpServletRequest request = new MockHttpServletRequest(); - request.setSession(session); + request.setSession(this.session); request.setServletPath(url); request.setMethod("GET"); return request; @@ -93,21 +102,21 @@ public class FilterChainPerformanceTests { for (int i = 0; i < N_INVOCATIONS; i++) { MockHttpServletRequest request = createRequest("/somefile.html"); stack.doFilter(request, new MockHttpServletResponse(), new MockFilterChain()); - session = request.getSession(); + this.session = request.getSession(); } } @Test public void minimalStackInvocation() throws Exception { sw.start("Run with Minimal Filter Stack"); - runWithStack(minimalStack); + runWithStack(this.minimalStack); sw.stop(); } @Test public void fullStackInvocation() throws Exception { sw.start("Run with Full Filter Stack"); - runWithStack(fullStack); + runWithStack(this.fullStack); sw.stop(); } @@ -119,16 +128,14 @@ public class FilterChainPerformanceTests { public void provideDataOnScalingWithNumberOfAuthoritiesUserHas() throws Exception { StopWatch sw = new StopWatch("Scaling with nAuthorities"); for (int user = 0; user < N_AUTHORITIES / 10; user++) { - int nAuthorities = user == 0 ? 1 : user * 10; + int nAuthorities = (user != 0) ? user * 10 : 1; SecurityContextHolder.getContext().setAuthentication( - new UsernamePasswordAuthenticationToken("bob", "bobspassword", - createRoles(nAuthorities))); - session.setAttribute( - HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, + new UsernamePasswordAuthenticationToken("bob", "bobspassword", createRoles(nAuthorities))); + this.session.setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, SecurityContextHolder.getContext()); SecurityContextHolder.clearContext(); sw.start(nAuthorities + " authorities"); - runWithStack(minimalStack); + runWithStack(this.minimalStack); System.out.println(sw.shortSummary()); sw.stop(); } @@ -146,4 +153,5 @@ public class FilterChainPerformanceTests { return Arrays.asList(roles); } + } diff --git a/itest/context/src/integration-test/java/org/springframework/security/performance/ProtectPointcutPerformanceTests.java b/itest/context/src/integration-test/java/org/springframework/security/performance/ProtectPointcutPerformanceTests.java index 10a6d87e2c..bd3aa82544 100644 --- a/itest/context/src/integration-test/java/org/springframework/security/performance/ProtectPointcutPerformanceTests.java +++ b/itest/context/src/integration-test/java/org/springframework/security/performance/ProtectPointcutPerformanceTests.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.performance; -import static org.assertj.core.api.Assertions.fail; +package org.springframework.security.performance; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; @@ -30,6 +30,8 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.util.StopWatch; +import static org.assertj.core.api.Assertions.fail; + /** * @author Luke Taylor */ @@ -51,8 +53,7 @@ public class ProtectPointcutPerformanceTests implements ApplicationContextAware sw.start(); for (int i = 0; i < 1000; i++) { try { - SessionRegistry reg = (SessionRegistry) ctx.getBean( - "sessionRegistryPrototype"); + SessionRegistry reg = (SessionRegistry) this.ctx.getBean("sessionRegistryPrototype"); reg.getAllPrincipals(); fail("Expected AuthenticationCredentialsNotFoundException"); } @@ -64,8 +65,9 @@ public class ProtectPointcutPerformanceTests implements ApplicationContextAware } - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - ctx = applicationContext; + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.ctx = applicationContext; } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/UserDetailsServiceImpl.java b/itest/context/src/main/java/org/springframework/security/integration/UserDetailsServiceImpl.java index 8da64c0336..ad732690fb 100755 --- a/itest/context/src/main/java/org/springframework/security/integration/UserDetailsServiceImpl.java +++ b/itest/context/src/main/java/org/springframework/security/integration/UserDetailsServiceImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; import org.springframework.beans.factory.annotation.Required; @@ -25,6 +26,7 @@ public class UserDetailsServiceImpl implements UserDetailsService { @SuppressWarnings({ "unused", "FieldCanBeLocal" }) private UserRepository userRepository; + @Override @Transactional(readOnly = true) public UserDetails loadUserByUsername(String username) { return null; @@ -34,4 +36,5 @@ public class UserDetailsServiceImpl implements UserDetailsService { public void setUserRepository(UserRepository userRepository) { this.userRepository = userRepository; } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/UserRepository.java b/itest/context/src/main/java/org/springframework/security/integration/UserRepository.java index 4d76711c8e..0c74be0044 100755 --- a/itest/context/src/main/java/org/springframework/security/integration/UserRepository.java +++ b/itest/context/src/main/java/org/springframework/security/integration/UserRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; public interface UserRepository { diff --git a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationService.java b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationService.java index 8e2bb8874d..a40559ec4f 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationService.java +++ b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.multiannotation; import org.springframework.security.access.annotation.Secured; @@ -36,4 +37,5 @@ public interface MultiAnnotationService { @Secured("ROLE_A") void securedRoleAMethod(); + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationServiceImpl.java b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationServiceImpl.java index 50e8a1c282..1105f09340 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationServiceImpl.java +++ b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/MultiAnnotationServiceImpl.java @@ -13,19 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.multiannotation; public class MultiAnnotationServiceImpl implements MultiAnnotationService { + @Override public void preAuthorizeDenyAllMethod() { } + @Override public void preAuthorizeHasRoleAMethod() { } + @Override public void securedAnonymousMethod() { } + @Override public void securedRoleAMethod() { } diff --git a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeService.java b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeService.java index a12a0b4172..8573a7b764 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeService.java +++ b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeService.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.multiannotation; import org.springframework.security.access.prepost.PreAuthorize; /** - * * @author Luke Taylor */ public interface PreAuthorizeService { @PreAuthorize("hasRole('ROLE_A')") void preAuthorizedMethod(); + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeServiceImpl.java b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeServiceImpl.java index f8ff122e21..f7be682050 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeServiceImpl.java +++ b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/PreAuthorizeServiceImpl.java @@ -13,12 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.multiannotation; /** * @author Luke Taylor */ public class PreAuthorizeServiceImpl implements PreAuthorizeService { + + @Override public void preAuthorizedMethod() { } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredService.java b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredService.java index ee5f36b6f3..74bab792b2 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredService.java +++ b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredService.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.multiannotation; import org.springframework.security.access.annotation.Secured; /** - * * @author Luke Taylor */ public interface SecuredService { + @Secured("ROLE_A") void securedMethod(); + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredServiceImpl.java b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredServiceImpl.java index d580b5867c..a43a23c0e8 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredServiceImpl.java +++ b/itest/context/src/main/java/org/springframework/security/integration/multiannotation/SecuredServiceImpl.java @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.multiannotation; /** - * * @author Luke Taylor */ public class SecuredServiceImpl implements SecuredService { + + @Override public void securedMethod() { } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPostInvocationAdvice.java b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPostInvocationAdvice.java index 7c180e5850..8b5c700292 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPostInvocationAdvice.java +++ b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPostInvocationAdvice.java @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; import org.aopalliance.intercept.MethodInvocation; + import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.prepost.PostInvocationAttribute; import org.springframework.security.access.prepost.PostInvocationAuthorizationAdvice; import org.springframework.security.core.Authentication; -public class PythonInterpreterPostInvocationAdvice implements - PostInvocationAuthorizationAdvice { +public class PythonInterpreterPostInvocationAdvice implements PostInvocationAuthorizationAdvice { - public Object after(Authentication authentication, MethodInvocation mi, - PostInvocationAttribute pia, Object returnedObject) - throws AccessDeniedException { + @Override + public Object after(Authentication authentication, MethodInvocation mi, PostInvocationAttribute pia, + Object returnedObject) throws AccessDeniedException { return returnedObject; } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAdvice.java b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAdvice.java index d72861f797..475e778237 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAdvice.java +++ b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAdvice.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; import java.io.IOException; @@ -24,6 +25,7 @@ import org.aopalliance.intercept.MethodInvocation; import org.python.core.Py; import org.python.core.PyObject; import org.python.util.PythonInterpreter; + import org.springframework.core.LocalVariableTableParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.io.Resource; @@ -33,12 +35,12 @@ import org.springframework.security.access.prepost.PreInvocationAuthorizationAdv import org.springframework.security.core.Authentication; import org.springframework.util.ClassUtils; -public class PythonInterpreterPreInvocationAdvice implements - PreInvocationAuthorizationAdvice { +public class PythonInterpreterPreInvocationAdvice implements PreInvocationAuthorizationAdvice { + private final ParameterNameDiscoverer parameterNameDiscoverer = new LocalVariableTableParameterNameDiscoverer(); - public boolean before(Authentication authentication, MethodInvocation mi, - PreInvocationAttribute preAttr) { + @Override + public boolean before(Authentication authentication, MethodInvocation mi, PreInvocationAttribute preAttr) { PythonInterpreterPreInvocationAttribute pythonAttr = (PythonInterpreterPreInvocationAttribute) preAttr; String script = pythonAttr.getScript(); @@ -46,14 +48,13 @@ public class PythonInterpreterPreInvocationAdvice implements python.set("authentication", authentication); python.set("args", createArgumentMap(mi)); python.set("method", mi.getMethod().getName()); - Resource scriptResource = new PathMatchingResourcePatternResolver() - .getResource(script); + Resource scriptResource = new PathMatchingResourcePatternResolver().getResource(script); try { python.execfile(scriptResource.getInputStream()); } - catch (IOException e) { - throw new IllegalArgumentException("Couldn't run python script, " + script, e); + catch (IOException ex) { + throw new IllegalArgumentException("Couldn't run python script, " + script, ex); } PyObject allowed = python.get("allow"); @@ -68,9 +69,8 @@ public class PythonInterpreterPreInvocationAdvice implements private Map createArgumentMap(MethodInvocation mi) { Object[] args = mi.getArguments(); Object targetObject = mi.getThis(); - Method method = ClassUtils.getMostSpecificMethod(mi.getMethod(), - targetObject.getClass()); - String[] paramNames = parameterNameDiscoverer.getParameterNames(method); + Method method = ClassUtils.getMostSpecificMethod(mi.getMethod(), targetObject.getClass()); + String[] paramNames = this.parameterNameDiscoverer.getParameterNames(method); Map argMap = new HashMap<>(); for (int i = 0; i < args.length; i++) { @@ -79,4 +79,5 @@ public class PythonInterpreterPreInvocationAdvice implements return argMap; } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAttribute.java b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAttribute.java index d033af1d4f..b50e2f95f5 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAttribute.java +++ b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPreInvocationAttribute.java @@ -13,22 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; import org.springframework.security.access.prepost.PreInvocationAttribute; public class PythonInterpreterPreInvocationAttribute implements PreInvocationAttribute { + private final String script; PythonInterpreterPreInvocationAttribute(String script) { this.script = script; } + @Override public String getAttribute() { return null; } public String getScript() { - return script; + return this.script; } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPrePostInvocationAttributeFactory.java b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPrePostInvocationAttributeFactory.java index 84c11e94aa..c14516b323 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPrePostInvocationAttributeFactory.java +++ b/itest/context/src/main/java/org/springframework/security/integration/python/PythonInterpreterPrePostInvocationAttributeFactory.java @@ -13,27 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; import org.python.util.PythonInterpreter; + import org.springframework.security.access.prepost.PostInvocationAttribute; import org.springframework.security.access.prepost.PreInvocationAttribute; import org.springframework.security.access.prepost.PrePostInvocationAttributeFactory; -public class PythonInterpreterPrePostInvocationAttributeFactory implements - PrePostInvocationAttributeFactory { +public class PythonInterpreterPrePostInvocationAttributeFactory implements PrePostInvocationAttributeFactory { public PythonInterpreterPrePostInvocationAttributeFactory() { PythonInterpreter.initialize(System.getProperties(), null, new String[] {}); } - public PreInvocationAttribute createPreInvocationAttribute(String preFilterAttribute, - String filterObject, String preAuthorizeAttribute) { + @Override + public PreInvocationAttribute createPreInvocationAttribute(String preFilterAttribute, String filterObject, + String preAuthorizeAttribute) { return new PythonInterpreterPreInvocationAttribute(preAuthorizeAttribute); } - public PostInvocationAttribute createPostInvocationAttribute( - String postFilterAttribute, String postAuthorizeAttribute) { + @Override + public PostInvocationAttribute createPostInvocationAttribute(String postFilterAttribute, + String postAuthorizeAttribute) { return null; } + } diff --git a/itest/context/src/main/java/org/springframework/security/integration/python/TestService.java b/itest/context/src/main/java/org/springframework/security/integration/python/TestService.java index 7bc9b50628..6a97c9028b 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/python/TestService.java +++ b/itest/context/src/main/java/org/springframework/security/integration/python/TestService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; import org.springframework.security.access.prepost.PreAuthorize; diff --git a/itest/context/src/main/java/org/springframework/security/integration/python/TestServiceImpl.java b/itest/context/src/main/java/org/springframework/security/integration/python/TestServiceImpl.java index 98081e347e..dda588cb7e 100644 --- a/itest/context/src/main/java/org/springframework/security/integration/python/TestServiceImpl.java +++ b/itest/context/src/main/java/org/springframework/security/integration/python/TestServiceImpl.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration.python; public class TestServiceImpl implements TestService { + @Override public void someMethod() { System.out.print("Invoked someMethod()"); } diff --git a/itest/misc/src/integration-test/java/org/springframework/security/context/SecurityContextHolderMTTests.java b/itest/misc/src/integration-test/java/org/springframework/security/context/SecurityContextHolderMTTests.java index bd79033cb3..4a09b0afc3 100644 --- a/itest/misc/src/integration-test/java/org/springframework/security/context/SecurityContextHolderMTTests.java +++ b/itest/misc/src/integration-test/java/org/springframework/security/context/SecurityContextHolderMTTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.context; import java.util.Random; diff --git a/itest/web/src/integration-test/java/org/springframework/security/integration/AbstractWebServerIntegrationTests.java b/itest/web/src/integration-test/java/org/springframework/security/integration/AbstractWebServerIntegrationTests.java index 2d61b7ee89..8f1ef5b8e1 100644 --- a/itest/web/src/integration-test/java/org/springframework/security/integration/AbstractWebServerIntegrationTests.java +++ b/itest/web/src/integration-test/java/org/springframework/security/integration/AbstractWebServerIntegrationTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; - import org.junit.After; + import org.springframework.context.ConfigurableApplicationContext; import org.springframework.mock.web.MockServletContext; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.support.XmlWebApplicationContext; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; + /** * Base class which allows the application to be started with a particular Spring * application context. Subclasses override the getContextConfigLocations method @@ -33,12 +35,13 @@ import org.springframework.web.context.support.XmlWebApplicationContext; * @author Luke Taylor */ public abstract class AbstractWebServerIntegrationTests { + protected ConfigurableApplicationContext context; @After public void close() { - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @@ -53,9 +56,11 @@ public abstract class AbstractWebServerIntegrationTests { context.refresh(); this.context = context; - return MockMvcBuilders - .webAppContextSetup(context) - .apply(springSecurity()) - .build(); + // @formatter:off + return MockMvcBuilders.webAppContextSetup(context) + .apply(springSecurity()) + .build(); + // @formatter:on } + } diff --git a/itest/web/src/integration-test/java/org/springframework/security/integration/BasicAuthenticationTests.java b/itest/web/src/integration-test/java/org/springframework/security/integration/BasicAuthenticationTests.java index ddec15c81c..3b69dd4899 100644 --- a/itest/web/src/integration-test/java/org/springframework/security/integration/BasicAuthenticationTests.java +++ b/itest/web/src/integration-test/java/org/springframework/security/integration/BasicAuthenticationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; import org.junit.Test; @@ -28,17 +29,23 @@ public class BasicAuthenticationTests extends AbstractWebServerIntegrationTests @Test public void httpBasicWhenAuthenticationRequiredAndNotAuthenticatedThen401() throws Exception { - MockMvc mockMvc = createMockMvc("classpath:/spring/http-security-basic.xml", "classpath:/spring/in-memory-provider.xml", "classpath:/spring/testapp-servlet.xml"); + MockMvc mockMvc = createMockMvc("classpath:/spring/http-security-basic.xml", + "classpath:/spring/in-memory-provider.xml", "classpath:/spring/testapp-servlet.xml"); + // @formatter:off mockMvc.perform(get("/secure/index")) - .andExpect(status().isUnauthorized()); + .andExpect(status().isUnauthorized()); + // @formatter:on } @Test public void httpBasicWhenProvidedThen200() throws Exception { - MockMvc mockMvc = createMockMvc("classpath:/spring/http-security-basic.xml", "classpath:/spring/in-memory-provider.xml", "classpath:/spring/testapp-servlet.xml"); + MockMvc mockMvc = createMockMvc("classpath:/spring/http-security-basic.xml", + "classpath:/spring/in-memory-provider.xml", "classpath:/spring/testapp-servlet.xml"); + // @formatter:off MockHttpServletRequestBuilder request = get("/secure/index") .with(httpBasic("johnc", "johncspassword")); - mockMvc.perform(request) - .andExpect(status().isOk()); + // @formatter:on + mockMvc.perform(request).andExpect(status().isOk()); } + } diff --git a/itest/web/src/integration-test/java/org/springframework/security/integration/ConcurrentSessionManagementTests.java b/itest/web/src/integration-test/java/org/springframework/security/integration/ConcurrentSessionManagementTests.java index bf6bb9153c..6d93da3973 100644 --- a/itest/web/src/integration-test/java/org/springframework/security/integration/ConcurrentSessionManagementTests.java +++ b/itest/web/src/integration-test/java/org/springframework/security/integration/ConcurrentSessionManagementTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.integration; import java.util.Collections; @@ -46,55 +47,63 @@ public class ConcurrentSessionManagementTests extends AbstractWebServerIntegrati final MockHttpSession session1 = new MockHttpSession(); final MockHttpSession session2 = new MockHttpSession(); - MockMvc mockMvc = createMockMvc("classpath:/spring/http-security-concurrency.xml", "classpath:/spring/in-memory-provider.xml", "classpath:/spring/testapp-servlet.xml"); + MockMvc mockMvc = createMockMvc("classpath:/spring/http-security-concurrency.xml", + "classpath:/spring/in-memory-provider.xml", "classpath:/spring/testapp-servlet.xml"); + // @formatter:off mockMvc.perform(get("/secure/index").session(session1)) - .andExpect(status().is3xxRedirection()); + .andExpect(status().is3xxRedirection()); + // @formatter:on - MockHttpServletRequestBuilder login1 = login() - .session(session1); - mockMvc. - perform(login1) - .andExpect(authenticated().withUsername("jimi")); + MockHttpServletRequestBuilder login1 = login().session(session1); + mockMvc.perform(login1).andExpect(authenticated().withUsername("jimi")); - - MockHttpServletRequestBuilder login2 = login() - .session(session2); + MockHttpServletRequestBuilder login2 = login().session(session2); + // @formatter:off mockMvc.perform(login2) - .andExpect(redirectedUrl("/login.jsp?login_error=true")); + .andExpect(redirectedUrl("/login.jsp?login_error=true")); + // @formatter:on Exception exception = (Exception) session2.getAttribute("SPRING_SECURITY_LAST_EXCEPTION"); assertThat(exception).isNotNull(); assertThat(exception.getMessage()).contains("Maximum sessions of 1 for this principal exceeded"); // Now logout to kill first session + // @formatter:off mockMvc.perform(post("/logout").with(csrf())) - .andExpect(status().is3xxRedirection()) - .andDo(result -> context.publishEvent(new SessionDestroyedEvent(session1) { - @Override - public List getSecurityContexts() { - return Collections.emptyList(); - } + .andExpect(status().is3xxRedirection()) + .andDo((result) -> this.context.publishEvent(new SessionDestroyedEvent(session1) { + @Override + public List getSecurityContexts() { + return Collections.emptyList(); + } - @Override - public String getId() { - return session1.getId(); - } - })); + @Override + public String getId() { + return session1.getId(); + } + })); + // @formatter:on // Try second session again - login2 = login() - .session(session2); + login2 = login().session(session2); + // @formatter:off mockMvc.perform(login2) - .andExpect(authenticated().withUsername("jimi")); + .andExpect(authenticated().withUsername("jimi")); + // @formatter:on + // @formatter:off mockMvc.perform(get("/secure/index").session(session2)) - .andExpect(content().string(containsString("A Secure Page"))); + .andExpect(content().string(containsString("A Secure Page"))); + // @formatter:on } private MockHttpServletRequestBuilder login() { + // @formatter:off return post("/login") - .param("username", "jimi") - .param("password", "jimispassword") - .with(csrf()); + .param("username", "jimi") + .param("password", "jimispassword") + .with(csrf()); + // @formatter:on } + } diff --git a/itest/web/src/main/java/org/springframework/security/itest/web/TestController.java b/itest/web/src/main/java/org/springframework/security/itest/web/TestController.java index 307ce68f6c..e66df34c57 100644 --- a/itest/web/src/main/java/org/springframework/security/itest/web/TestController.java +++ b/itest/web/src/main/java/org/springframework/security/itest/web/TestController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.itest.web; import org.springframework.web.bind.annotation.RequestMapping; @@ -38,4 +39,5 @@ public class TestController { public String secure() { return "A Secure Page"; } + } diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/ApacheDsContainerConfig.java b/ldap/src/integration-test/java/org/springframework/security/ldap/ApacheDsContainerConfig.java index ec6cd0fc7a..b1a2aca60d 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/ApacheDsContainerConfig.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/ApacheDsContainerConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap; import javax.annotation.PreDestroy; @@ -32,16 +33,15 @@ public class ApacheDsContainerConfig { @Bean ApacheDSContainer ldapContainer() throws Exception { - this.container = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:test-server.ldif"); + this.container = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif"); this.container.setPort(0); return this.container; } @Bean ContextSource contextSource(ApacheDSContainer ldapContainer) throws Exception { - return new DefaultSpringSecurityContextSource("ldap://127.0.0.1:" - + ldapContainer.getLocalPort() + "/dc=springframework,dc=org"); + return new DefaultSpringSecurityContextSource( + "ldap://127.0.0.1:" + ldapContainer.getLocalPort() + "/dc=springframework,dc=org"); } @PreDestroy diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/DefaultSpringSecurityContextSourceTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/DefaultSpringSecurityContextSourceTests.java index e2da339fcf..85d9c99e9f 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/DefaultSpringSecurityContextSourceTests.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/DefaultSpringSecurityContextSourceTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.ldap; import java.util.ArrayList; import java.util.Hashtable; @@ -32,6 +31,8 @@ import org.springframework.ldap.core.support.AbstractContextSource; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Luke Taylor * @author Eddú Meléndez @@ -53,8 +54,7 @@ public class DefaultSpringSecurityContextSourceTests { @Test public void supportsSpacesInUrl() { - new DefaultSpringSecurityContextSource( - "ldap://myhost:10389/dc=spring%20framework,dc=org"); + new DefaultSpringSecurityContextSource("ldap://myhost:10389/dc=spring%20framework,dc=org"); } @Test @@ -64,8 +64,8 @@ public class DefaultSpringSecurityContextSourceTests { ctxSrc.setUserDn("manager"); ctxSrc.setPassword("password"); ctxSrc.afterPropertiesSet(); - assertThat(ctxSrc.getAuthenticatedEnvForTest("manager", "password")).containsKey( - AbstractContextSource.SUN_LDAP_POOLING_FLAG); + assertThat(ctxSrc.getAuthenticatedEnvForTest("manager", "password")) + .containsKey(AbstractContextSource.SUN_LDAP_POOLING_FLAG); } @Test @@ -75,47 +75,41 @@ public class DefaultSpringSecurityContextSourceTests { ctxSrc.setUserDn("manager"); ctxSrc.setPassword("password"); ctxSrc.afterPropertiesSet(); - assertThat(ctxSrc.getAuthenticatedEnvForTest("user", "password")).doesNotContainKey( - AbstractContextSource.SUN_LDAP_POOLING_FLAG); + assertThat(ctxSrc.getAuthenticatedEnvForTest("user", "password")) + .doesNotContainKey(AbstractContextSource.SUN_LDAP_POOLING_FLAG); } // SEC-1145. Confirms that there is no issue here with pooling. @Test(expected = AuthenticationException.class) - public void cantBindWithWrongPasswordImmediatelyAfterSuccessfulBind() - throws Exception { + public void cantBindWithWrongPasswordImmediatelyAfterSuccessfulBind() throws Exception { DirContext ctx = null; try { - ctx = this.contextSource.getContext( - "uid=Bob,ou=people,dc=springframework,dc=org", "bobspassword"); + ctx = this.contextSource.getContext("uid=Bob,ou=people,dc=springframework,dc=org", "bobspassword"); } - catch (Exception e) { + catch (Exception ex) { } assertThat(ctx).isNotNull(); // com.sun.jndi.ldap.LdapPoolManager.showStats(System.out); ctx.close(); // com.sun.jndi.ldap.LdapPoolManager.showStats(System.out); // Now get it gain, with wrong password. Should fail. - ctx = this.contextSource.getContext( - "uid=Bob,ou=people,dc=springframework,dc=org", "wrongpassword"); + ctx = this.contextSource.getContext("uid=Bob,ou=people,dc=springframework,dc=org", "wrongpassword"); ctx.close(); } @Test public void serverUrlWithSpacesIsSupported() { DefaultSpringSecurityContextSource contextSource = new DefaultSpringSecurityContextSource( - this.contextSource.getUrls()[0] - + "ou=space%20cadets,dc=springframework,dc=org"); + this.contextSource.getUrls()[0] + "ou=space%20cadets,dc=springframework,dc=org"); contextSource.afterPropertiesSet(); - contextSource.getContext( - "uid=space cadet,ou=space cadets,dc=springframework,dc=org", - "spacecadetspassword"); + contextSource.getContext("uid=space cadet,ou=space cadets,dc=springframework,dc=org", "spacecadetspassword"); } @Test(expected = IllegalArgumentException.class) public void instantiationFailsWithEmptyServerList() { List serverUrls = new ArrayList<>(); - DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource( - serverUrls, "dc=springframework,dc=org"); + DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource(serverUrls, + "dc=springframework,dc=org"); ctxSrc.afterPropertiesSet(); } @@ -125,8 +119,8 @@ public class DefaultSpringSecurityContextSourceTests { serverUrls.add("ldap://foo:789"); serverUrls.add("ldap://bar:389"); serverUrls.add("ldaps://blah:636"); - DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource( - serverUrls, "dc=springframework,dc=org"); + DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource(serverUrls, + "dc=springframework,dc=org"); assertThat(ctxSrc.isAnonymousReadOnly()).isFalse(); assertThat(ctxSrc.isPooled()).isTrue(); @@ -140,8 +134,7 @@ public class DefaultSpringSecurityContextSourceTests { serverUrls.add("ldap://foo:789"); serverUrls.add("ldap://bar:389"); serverUrls.add("ldaps://blah:636"); - DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource( - serverUrls, baseDn); + DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource(serverUrls, baseDn); assertThat(ctxSrc.isAnonymousReadOnly()).isFalse(); assertThat(ctxSrc.isPooled()).isTrue(); @@ -154,12 +147,12 @@ public class DefaultSpringSecurityContextSourceTests { serverUrls.add("ldaps://blah:636/"); // this url should be rejected because the root DN goes into a separate parameter serverUrls.add("ldap://bar:389/dc=foobar,dc=org"); - DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource( - serverUrls, "dc=springframework,dc=org"); + DefaultSpringSecurityContextSource ctxSrc = new DefaultSpringSecurityContextSource(serverUrls, + "dc=springframework,dc=org"); } - static class EnvExposingDefaultSpringSecurityContextSource extends - DefaultSpringSecurityContextSource { + static class EnvExposingDefaultSpringSecurityContextSource extends DefaultSpringSecurityContextSource { + EnvExposingDefaultSpringSecurityContextSource(String providerUrl) { super(providerUrl); } @@ -168,5 +161,7 @@ public class DefaultSpringSecurityContextSourceTests { Hashtable getAuthenticatedEnvForTest(String userDn, String password) { return getAuthenticatedEnv(userDn, password); } + } + } diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateITests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateITests.java index 697a9f9bfb..aadb2faaf9 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateITests.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateITests.java @@ -16,8 +16,6 @@ package org.springframework.security.ldap; -import static org.assertj.core.api.Assertions.*; - import java.util.List; import java.util.Map; import java.util.Set; @@ -28,7 +26,8 @@ import javax.naming.directory.DirContext; import javax.naming.directory.SearchControls; import javax.naming.directory.SearchResult; -import org.junit.*; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; @@ -38,6 +37,9 @@ import org.springframework.security.crypto.codec.Utf8; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * @author Luke Taylor * @author Eddú Meléndez @@ -45,41 +47,35 @@ import org.springframework.test.context.junit4.SpringRunner; @RunWith(SpringRunner.class) @ContextConfiguration(classes = ApacheDsContainerConfig.class) public class SpringSecurityLdapTemplateITests { - // ~ Instance fields - // ================================================================================================ @Autowired private DefaultSpringSecurityContextSource contextSource; - private SpringSecurityLdapTemplate template; - // ~ Methods - // ======================================================================================================== + private SpringSecurityLdapTemplate template; @Before public void setUp() { - template = new SpringSecurityLdapTemplate(this.contextSource); + this.template = new SpringSecurityLdapTemplate(this.contextSource); } @Test public void compareOfCorrectValueSucceeds() { - assertThat(template.compare("uid=bob,ou=people", "uid", "bob")).isTrue(); + assertThat(this.template.compare("uid=bob,ou=people", "uid", "bob")).isTrue(); } @Test public void compareOfCorrectByteValueSucceeds() { - assertThat(template.compare("uid=bob,ou=people", "userPassword", - Utf8.encode("bobspassword"))).isTrue(); + assertThat(this.template.compare("uid=bob,ou=people", "userPassword", Utf8.encode("bobspassword"))).isTrue(); } @Test public void compareOfWrongByteValueFails() { - assertThat(template.compare("uid=bob,ou=people", "userPassword", - Utf8.encode("wrongvalue"))).isFalse(); + assertThat(this.template.compare("uid=bob,ou=people", "userPassword", Utf8.encode("wrongvalue"))).isFalse(); } @Test public void compareOfWrongValueFails() { - assertThat(template.compare("uid=bob,ou=people", "uid", "wrongvalue")).isFalse(); + assertThat(this.template.compare("uid=bob,ou=people", "uid", "wrongvalue")).isFalse(); } // @Test @@ -95,7 +91,7 @@ public class SpringSecurityLdapTemplateITests { @Test public void namingExceptionIsTranslatedCorrectly() { try { - template.executeReadOnly((ContextExecutor) dirContext -> { + this.template.executeReadOnly((ContextExecutor) (dirContext) -> { throw new NamingException(); }); fail("Expected UncategorizedLdapException on NamingException"); @@ -108,8 +104,8 @@ public class SpringSecurityLdapTemplateITests { public void roleSearchReturnsCorrectNumberOfRoles() { String param = "uid=ben,ou=people,dc=springframework,dc=org"; - Set values = template.searchForSingleAttributeValues("ou=groups", - "(member={0})", new String[] { param }, "ou"); + Set values = this.template.searchForSingleAttributeValues("ou=groups", "(member={0})", + new String[] { param }, "ou"); assertThat(values).as("Expected 3 results from search").hasSize(3); assertThat(values.contains("developer")).isTrue(); @@ -119,14 +115,12 @@ public class SpringSecurityLdapTemplateITests { @Test public void testMultiAttributeRetrievalWithNullAttributeNames() { - Set>> values = template - .searchForMultipleAttributeValues("ou=people", "(uid={0})", - new String[] { "bob" }, null); + Set>> values = this.template.searchForMultipleAttributeValues("ou=people", "(uid={0})", + new String[] { "bob" }, null); assertThat(values).hasSize(1); Map> record = values.iterator().next(); assertAttributeValue(record, "uid", "bob"); - assertAttributeValue(record, "objectclass", "top", "person", - "organizationalPerson", "inetOrgPerson"); + assertAttributeValue(record, "objectclass", "top", "person", "organizationalPerson", "inetOrgPerson"); assertAttributeValue(record, "cn", "Bob Hamilton"); assertAttributeValue(record, "sn", "Hamilton"); assertThat(record.containsKey("userPassword")).isFalse(); @@ -134,14 +128,12 @@ public class SpringSecurityLdapTemplateITests { @Test public void testMultiAttributeRetrievalWithZeroLengthAttributeNames() { - Set>> values = template - .searchForMultipleAttributeValues("ou=people", "(uid={0})", - new String[] { "bob" }, new String[0]); + Set>> values = this.template.searchForMultipleAttributeValues("ou=people", "(uid={0})", + new String[] { "bob" }, new String[0]); assertThat(values).hasSize(1); Map> record = values.iterator().next(); assertAttributeValue(record, "uid", "bob"); - assertAttributeValue(record, "objectclass", "top", "person", - "organizationalPerson", "inetOrgPerson"); + assertAttributeValue(record, "objectclass", "top", "person", "organizationalPerson", "inetOrgPerson"); assertAttributeValue(record, "cn", "Bob Hamilton"); assertAttributeValue(record, "sn", "Hamilton"); assertThat(record.containsKey("userPassword")).isFalse(); @@ -149,9 +141,8 @@ public class SpringSecurityLdapTemplateITests { @Test public void testMultiAttributeRetrievalWithSpecifiedAttributeNames() { - Set>> values = template - .searchForMultipleAttributeValues("ou=people", "(uid={0})", - new String[] { "bob" }, new String[] { "uid", "cn", "sn" }); + Set>> values = this.template.searchForMultipleAttributeValues("ou=people", "(uid={0})", + new String[] { "bob" }, new String[] { "uid", "cn", "sn" }); assertThat(values).hasSize(1); Map> record = values.iterator().next(); assertAttributeValue(record, "uid", "bob"); @@ -161,8 +152,7 @@ public class SpringSecurityLdapTemplateITests { assertThat(record.containsKey("objectclass")).isFalse(); } - protected void assertAttributeValue(Map> record, - String attributeName, String... values) { + protected void assertAttributeValue(Map> record, String attributeName, String... values) { assertThat(record.containsKey(attributeName)).isTrue(); assertThat(record.get(attributeName)).hasSize(values.length); for (int i = 0; i < values.length; i++) { @@ -174,8 +164,8 @@ public class SpringSecurityLdapTemplateITests { public void testRoleSearchForMissingAttributeFailsGracefully() { String param = "uid=ben,ou=people,dc=springframework,dc=org"; - Set values = template.searchForSingleAttributeValues("ou=groups", - "(member={0})", new String[] { param }, "mail"); + Set values = this.template.searchForSingleAttributeValues("ou=groups", "(member={0})", + new String[] { param }, "mail"); assertThat(values).isEmpty(); } @@ -184,8 +174,8 @@ public class SpringSecurityLdapTemplateITests { public void roleSearchWithEscapedCharacterSucceeds() { String param = "cn=mouse\\, jerry,ou=people,dc=springframework,dc=org"; - Set values = template.searchForSingleAttributeValues("ou=groups", - "(member={0})", new String[] { param }, "cn"); + Set values = this.template.searchForSingleAttributeValues("ou=groups", "(member={0})", + new String[] { param }, "cn"); assertThat(values).hasSize(1); } @@ -205,9 +195,8 @@ public class SpringSecurityLdapTemplateITests { controls.setReturningAttributes(null); String param = "cn=mouse\\, jerry,ou=people,dc=springframework,dc=org"; - javax.naming.NamingEnumeration results = ctx.search( - "ou=groups,dc=springframework,dc=org", "(member={0})", - new String[] { param }, controls); + javax.naming.NamingEnumeration results = ctx.search("ou=groups,dc=springframework,dc=org", + "(member={0})", new String[] { param }, controls); assertThat(results.hasMore()).as("Expected a result").isTrue(); } @@ -216,7 +205,7 @@ public class SpringSecurityLdapTemplateITests { public void searchForSingleEntryWithEscapedCharsInDnSucceeds() { String param = "mouse, jerry"; - template.searchForSingleEntry("ou=people", "(cn={0})", new String[] { param }); + this.template.searchForSingleEntry("ou=people", "(cn={0})", new String[] { param }); } } diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/BindAuthenticatorTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/BindAuthenticatorTests.java index 072cfebb96..e574a396f5 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/BindAuthenticatorTests.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/BindAuthenticatorTests.java @@ -35,7 +35,6 @@ import org.springframework.test.context.junit4.SpringRunner; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; - /** * Tests for {@link BindAuthenticator}. * @@ -45,16 +44,13 @@ import static org.assertj.core.api.Assertions.fail; @RunWith(SpringRunner.class) @ContextConfiguration(classes = ApacheDsContainerConfig.class) public class BindAuthenticatorTests { - // ~ Instance fields - // ================================================================================================ @Autowired private DefaultSpringSecurityContextSource contextSource; - private BindAuthenticator authenticator; - private Authentication bob; - // ~ Methods - // ======================================================================================================== + private BindAuthenticator authenticator; + + private Authentication bob; @Before public void setUp() { @@ -66,19 +62,16 @@ public class BindAuthenticatorTests { @Test(expected = BadCredentialsException.class) public void emptyPasswordIsRejected() { - this.authenticator - .authenticate(new UsernamePasswordAuthenticationToken("jen", "")); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("jen", "")); } @Test public void testAuthenticationWithCorrectPasswordSucceeds() { - this.authenticator.setUserDnPatterns( - new String[] { "uid={0},ou=people", "cn={0},ou=people" }); + this.authenticator.setUserDnPatterns(new String[] { "uid={0},ou=people", "cn={0},ou=people" }); DirContextOperations user = this.authenticator.authenticate(this.bob); assertThat(user.getStringAttribute("uid")).isEqualTo("bob"); - this.authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "mouse, jerry", "jerryspassword")); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("mouse, jerry", "jerryspassword")); } @Test @@ -86,8 +79,7 @@ public class BindAuthenticatorTests { this.authenticator.setUserDnPatterns(new String[] { "uid={0},ou=people" }); try { - this.authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "nonexistentsuser", "password")); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("nonexistentsuser", "password")); fail("Shouldn't be able to bind with invalid username"); } catch (BadCredentialsException expected) { @@ -98,28 +90,21 @@ public class BindAuthenticatorTests { public void testAuthenticationWithUserSearch() throws Exception { // DirContextAdapter ctx = new DirContextAdapter(new // DistinguishedName("uid=bob,ou=people")); - this.authenticator.setUserSearch(new FilterBasedLdapUserSearch("ou=people", - "(uid={0})", this.contextSource)); + this.authenticator.setUserSearch(new FilterBasedLdapUserSearch("ou=people", "(uid={0})", this.contextSource)); this.authenticator.afterPropertiesSet(); DirContextOperations result = this.authenticator.authenticate(this.bob); - //ensure we are getting the same attributes back + // ensure we are getting the same attributes back assertThat(result.getStringAttribute("cn")).isEqualTo("Bob Hamilton"); // SEC-1444 - this.authenticator.setUserSearch(new FilterBasedLdapUserSearch("ou=people", - "(cn={0})", this.contextSource)); - this.authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "mouse, jerry", "jerryspassword")); - this.authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "slash/guy", "slashguyspassword")); + this.authenticator.setUserSearch(new FilterBasedLdapUserSearch("ou=people", "(cn={0})", this.contextSource)); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("mouse, jerry", "jerryspassword")); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("slash/guy", "slashguyspassword")); // SEC-1661 - this.authenticator.setUserSearch(new FilterBasedLdapUserSearch( - "ou=\\\"quoted people\\\"", "(cn={0})", this.contextSource)); - this.authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "quote\"guy", "quoteguyspassword")); this.authenticator.setUserSearch( - new FilterBasedLdapUserSearch("", "(cn={0})", this.contextSource)); - this.authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "quote\"guy", "quoteguyspassword")); + new FilterBasedLdapUserSearch("ou=\\\"quoted people\\\"", "(cn={0})", this.contextSource)); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("quote\"guy", "quoteguyspassword")); + this.authenticator.setUserSearch(new FilterBasedLdapUserSearch("", "(cn={0})", this.contextSource)); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("quote\"guy", "quoteguyspassword")); } /* @@ -148,8 +133,7 @@ public class BindAuthenticatorTests { this.authenticator.setUserDnPatterns(new String[] { "uid={0},ou=people" }); try { - this.authenticator.authenticate( - new UsernamePasswordAuthenticationToken("bob", "wrongpassword")); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("bob", "wrongpassword")); fail("Shouldn't be able to bind with wrong password"); } catch (BadCredentialsException expected) { @@ -159,7 +143,7 @@ public class BindAuthenticatorTests { @Test public void testUserDnPatternReturnsCorrectDn() { this.authenticator.setUserDnPatterns(new String[] { "cn={0},ou=people" }); - assertThat(this.authenticator.getUserDns("Joe").get(0)) - .isEqualTo("cn=Joe,ou=people"); + assertThat(this.authenticator.getUserDns("Joe").get(0)).isEqualTo("cn=Joe,ou=people"); } + } diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorTests.java index 1fe8f97f8f..54ddd40d8c 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorTests.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorTests.java @@ -16,10 +16,13 @@ package org.springframework.security.ldap.authentication; -import org.junit.*; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.ldap.core.DirContextAdapter; +import org.springframework.ldap.core.DistinguishedName; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; @@ -27,15 +30,13 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.crypto.keygen.KeyGenerators; import org.springframework.security.crypto.password.LdapShaPasswordEncoder; import org.springframework.security.crypto.password.NoOpPasswordEncoder; - -import org.springframework.ldap.core.DirContextAdapter; -import org.springframework.ldap.core.DistinguishedName; import org.springframework.security.ldap.ApacheDsContainerConfig; import org.springframework.security.ldap.DefaultSpringSecurityContextSource; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** * Tests for {@link PasswordComparisonAuthenticator}. @@ -46,44 +47,42 @@ import static org.assertj.core.api.Assertions.*; @RunWith(SpringRunner.class) @ContextConfiguration(classes = ApacheDsContainerConfig.class) public class PasswordComparisonAuthenticatorTests { - // ~ Instance fields - // ================================================================================================ @Autowired private DefaultSpringSecurityContextSource contextSource; - private PasswordComparisonAuthenticator authenticator; - private Authentication bob; - private Authentication ben; - // ~ Methods - // ======================================================================================================== + private PasswordComparisonAuthenticator authenticator; + + private Authentication bob; + + private Authentication ben; @Before public void setUp() { - authenticator = new PasswordComparisonAuthenticator(this.contextSource); - authenticator.setPasswordEncoder(NoOpPasswordEncoder.getInstance()); - authenticator.setUserDnPatterns(new String[] { "uid={0},ou=people" }); - bob = new UsernamePasswordAuthenticationToken("bob", "bobspassword"); - ben = new UsernamePasswordAuthenticationToken("ben", "benspassword"); + this.authenticator = new PasswordComparisonAuthenticator(this.contextSource); + this.authenticator.setPasswordEncoder(NoOpPasswordEncoder.getInstance()); + this.authenticator.setUserDnPatterns(new String[] { "uid={0},ou=people" }); + this.bob = new UsernamePasswordAuthenticationToken("bob", "bobspassword"); + this.ben = new UsernamePasswordAuthenticationToken("ben", "benspassword"); } @Test public void testAllAttributesAreRetrievedByDefault() { - DirContextAdapter user = (DirContextAdapter) authenticator.authenticate(bob); + DirContextAdapter user = (DirContextAdapter) this.authenticator.authenticate(this.bob); // System.out.println(user.getAttributes().toString()); assertThat(user.getAttributes().size()).withFailMessage("User should have 5 attributes").isEqualTo(5); } @Test public void testFailedSearchGivesUserNotFoundException() throws Exception { - authenticator = new PasswordComparisonAuthenticator(this.contextSource); - assertThat(authenticator.getUserDns("Bob")).withFailMessage("User DN matches shouldn't be available").isEmpty(); - authenticator.setUserSearch(new MockUserSearch(null)); - authenticator.afterPropertiesSet(); + this.authenticator = new PasswordComparisonAuthenticator(this.contextSource); + assertThat(this.authenticator.getUserDns("Bob")).withFailMessage("User DN matches shouldn't be available") + .isEmpty(); + this.authenticator.setUserSearch(new MockUserSearch(null)); + this.authenticator.afterPropertiesSet(); try { - authenticator.authenticate(new UsernamePasswordAuthenticationToken("Joe", - "pass")); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("Joe", "pass")); fail("Expected exception on failed user search"); } catch (UsernameNotFoundException expected) { @@ -93,74 +92,70 @@ public class PasswordComparisonAuthenticatorTests { @Test(expected = BadCredentialsException.class) public void testLdapPasswordCompareFailsWithWrongPassword() { // Don't retrieve the password - authenticator.setUserAttributes(new String[] { "uid", "cn", "sn" }); - authenticator.authenticate(new UsernamePasswordAuthenticationToken("bob", - "wrongpass")); + this.authenticator.setUserAttributes(new String[] { "uid", "cn", "sn" }); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("bob", "wrongpass")); } @Test public void testMultipleDnPatternsWorkOk() { - authenticator.setUserDnPatterns(new String[] { "uid={0},ou=nonexistent", - "uid={0},ou=people" }); - authenticator.authenticate(bob); + this.authenticator.setUserDnPatterns(new String[] { "uid={0},ou=nonexistent", "uid={0},ou=people" }); + this.authenticator.authenticate(this.bob); } @Test public void testOnlySpecifiedAttributesAreRetrieved() { - authenticator.setUserAttributes(new String[] { "uid", "userPassword" }); + this.authenticator.setUserAttributes(new String[] { "uid", "userPassword" }); - DirContextAdapter user = (DirContextAdapter) authenticator.authenticate(bob); - assertThat(user - .getAttributes().size()).withFailMessage("Should have retrieved 2 attribute (uid)").isEqualTo(2); + DirContextAdapter user = (DirContextAdapter) this.authenticator.authenticate(this.bob); + assertThat(user.getAttributes().size()).withFailMessage("Should have retrieved 2 attribute (uid)").isEqualTo(2); } @Test public void testLdapCompareSucceedsWithCorrectPassword() { // Don't retrieve the password - authenticator.setUserAttributes(new String[] { "uid" }); - authenticator.authenticate(bob); + this.authenticator.setUserAttributes(new String[] { "uid" }); + this.authenticator.authenticate(this.bob); } @Test public void testLdapCompareSucceedsWithShaEncodedPassword() { // Don't retrieve the password - authenticator.setUserAttributes(new String[] { "uid" }); - authenticator.setPasswordEncoder(new LdapShaPasswordEncoder(KeyGenerators.shared(0))); - authenticator.setUsePasswordAttrCompare(false); - authenticator.authenticate(ben); + this.authenticator.setUserAttributes(new String[] { "uid" }); + this.authenticator.setPasswordEncoder(new LdapShaPasswordEncoder(KeyGenerators.shared(0))); + this.authenticator.setUsePasswordAttrCompare(false); + this.authenticator.authenticate(this.ben); } @Test(expected = IllegalArgumentException.class) public void testPasswordEncoderCantBeNull() { - authenticator.setPasswordEncoder(null); + this.authenticator.setPasswordEncoder(null); } @Test public void testUseOfDifferentPasswordAttributeSucceeds() { - authenticator.setPasswordAttributeName("uid"); - authenticator.authenticate(new UsernamePasswordAuthenticationToken("bob", "bob")); + this.authenticator.setPasswordAttributeName("uid"); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("bob", "bob")); } @Test public void testLdapCompareWithDifferentPasswordAttributeSucceeds() { - authenticator.setUserAttributes(new String[] { "uid" }); - authenticator.setPasswordAttributeName("cn"); - authenticator.authenticate(new UsernamePasswordAuthenticationToken("ben", - "Ben Alex")); + this.authenticator.setUserAttributes(new String[] { "uid" }); + this.authenticator.setPasswordAttributeName("cn"); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("ben", "Ben Alex")); } @Test public void testWithUserSearch() { - authenticator = new PasswordComparisonAuthenticator(this.contextSource); - authenticator.setPasswordEncoder(NoOpPasswordEncoder.getInstance()); - assertThat(authenticator.getUserDns("Bob")).withFailMessage("User DN matches shouldn't be available").isEmpty(); + this.authenticator = new PasswordComparisonAuthenticator(this.contextSource); + this.authenticator.setPasswordEncoder(NoOpPasswordEncoder.getInstance()); + assertThat(this.authenticator.getUserDns("Bob")).withFailMessage("User DN matches shouldn't be available") + .isEmpty(); - DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName( - "uid=Bob,ou=people")); + DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName("uid=Bob,ou=people")); ctx.setAttributeValue("userPassword", "bobspassword"); - authenticator.setUserSearch(new MockUserSearch(ctx)); - authenticator.authenticate(new UsernamePasswordAuthenticationToken( - "shouldntbeused", "bobspassword")); + this.authenticator.setUserSearch(new MockUserSearch(ctx)); + this.authenticator.authenticate(new UsernamePasswordAuthenticationToken("shouldntbeused", "bobspassword")); } + } diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearchTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearchTests.java index 789e89f4ee..517a4f67cf 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearchTests.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearchTests.java @@ -16,8 +16,6 @@ package org.springframework.security.ldap.search; -import static org.assertj.core.api.Assertions.assertThat; - import javax.naming.ldap.LdapName; import org.junit.Test; @@ -32,6 +30,8 @@ import org.springframework.security.ldap.DefaultSpringSecurityContextSource; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests for FilterBasedLdapUserSearch. * @@ -47,8 +47,7 @@ public class FilterBasedLdapUserSearchTests { @Test public void basicSearchSucceeds() throws Exception { - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", - "(uid={0})", this.contextSource); + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", "(uid={0})", this.contextSource); locator.setSearchSubtree(false); locator.setSearchTimeLimit(0); locator.setDerefLinkFlag(false); @@ -61,8 +60,7 @@ public class FilterBasedLdapUserSearchTests { @Test public void searchForNameWithCommaSucceeds() throws Exception { - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", - "(uid={0})", this.contextSource); + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", "(uid={0})", this.contextSource); locator.setSearchSubtree(false); DirContextOperations jerry = locator.searchForUser("jerry"); @@ -74,8 +72,7 @@ public class FilterBasedLdapUserSearchTests { // Try some funny business with filters. @Test public void extraFilterPartToExcludeBob() { - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch( - "ou=people", + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", "(&(cn=*)(!(|(uid={0})(uid=rod)(uid=jerry)(uid=slashguy)(uid=javadude)(uid=groovydude)(uid=closuredude)(uid=scaladude))))", this.contextSource); @@ -86,23 +83,20 @@ public class FilterBasedLdapUserSearchTests { @Test(expected = IncorrectResultSizeDataAccessException.class) public void searchFailsOnMultipleMatches() { - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", - "(cn=*)", this.contextSource); + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", "(cn=*)", this.contextSource); locator.searchForUser("Ignored"); } @Test(expected = UsernameNotFoundException.class) public void searchForInvalidUserFails() { - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", - "(uid={0})", this.contextSource); + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=people", "(uid={0})", this.contextSource); locator.searchForUser("Joe"); } @Test public void subTreeSearchSucceeds() throws Exception { // Don't set the searchBase, so search from the root. - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("", "(cn={0})", - this.contextSource); + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("", "(cn={0})", this.contextSource); locator.setSearchSubtree(true); DirContextOperations ben = locator.searchForUser("Ben Alex"); @@ -113,8 +107,8 @@ public class FilterBasedLdapUserSearchTests { @Test public void searchWithDifferentSearchBaseIsSuccessful() { - FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch( - "ou=otherpeople", "(cn={0})", this.contextSource); + FilterBasedLdapUserSearch locator = new FilterBasedLdapUserSearch("ou=otherpeople", "(cn={0})", + this.contextSource); DirContextOperations joe = locator.searchForUser("Joe Smeth"); assertThat(joe.getStringAttribute("cn")).isEqualTo("Joe Smeth"); } diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSContainerTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSContainerTests.java index add74ddf45..8a14656709 100644 --- a/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSContainerTests.java +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSContainerTests.java @@ -16,9 +16,6 @@ package org.springframework.security.ldap.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.io.File; import java.io.FileOutputStream; import java.io.IOException; @@ -30,9 +27,14 @@ import java.util.List; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; + import org.springframework.core.io.ClassPathResource; import org.springframework.util.FileCopyUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.fail; + /** * Useful for debugging the container by itself. * @@ -50,10 +52,8 @@ public class ApacheDSContainerTests { // SEC-2162 @Test public void failsToStartThrowsException() throws Exception { - ApacheDSContainer server1 = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:test-server.ldif"); - ApacheDSContainer server2 = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:missing.ldif"); + ApacheDSContainer server1 = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif"); + ApacheDSContainer server2 = new ApacheDSContainer("dc=springframework,dc=org", "classpath:missing.ldif"); List ports = getDefaultPorts(1); server1.setPort(ports.get(0)); server2.setPort(ports.get(0)); @@ -70,12 +70,12 @@ public class ApacheDSContainerTests { try { server1.destroy(); } - catch (Throwable t) { + catch (Throwable ex) { } try { server2.destroy(); } - catch (Throwable t) { + catch (Throwable ex) { } } } @@ -83,10 +83,8 @@ public class ApacheDSContainerTests { // SEC-2161 @Test public void multipleInstancesSimultanciously() throws Exception { - ApacheDSContainer server1 = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:test-server.ldif"); - ApacheDSContainer server2 = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:test-server.ldif"); + ApacheDSContainer server1 = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif"); + ApacheDSContainer server2 = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif"); List ports = getDefaultPorts(2); server1.setPort(ports.get(0)); server2.setPort(ports.get(1)); @@ -98,20 +96,19 @@ public class ApacheDSContainerTests { try { server1.destroy(); } - catch (Throwable t) { + catch (Throwable ex) { } try { server2.destroy(); } - catch (Throwable t) { + catch (Throwable ex) { } } } @Test public void startWithLdapOverSslWithoutCertificate() throws Exception { - ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:test-server.ldif"); + ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif"); List ports = getDefaultPorts(1); server.setPort(ports.get(0)); server.setLdapOverSslEnabled(true); @@ -120,21 +117,21 @@ public class ApacheDSContainerTests { server.afterPropertiesSet(); fail("Expected an IllegalArgumentException to be thrown."); } - catch (IllegalArgumentException e){ - assertThat(e).hasMessage("When LdapOverSsl is enabled, the keyStoreFile property must be set."); + catch (IllegalArgumentException ex) { + assertThat(ex).hasMessage("When LdapOverSsl is enabled, the keyStoreFile property must be set."); } } @Test public void startWithLdapOverSslWithWrongPassword() throws Exception { - final ClassPathResource keyStoreResource = new ClassPathResource("/org/springframework/security/ldap/server/spring.keystore"); - final File temporaryKeyStoreFile = new File(temporaryFolder.getRoot(), "spring.keystore"); + final ClassPathResource keyStoreResource = new ClassPathResource( + "/org/springframework/security/ldap/server/spring.keystore"); + final File temporaryKeyStoreFile = new File(this.temporaryFolder.getRoot(), "spring.keystore"); FileCopyUtils.copy(keyStoreResource.getInputStream(), new FileOutputStream(temporaryKeyStoreFile)); assertThat(temporaryKeyStoreFile).isFile(); - ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org", - "classpath:test-server.ldif"); + ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif"); List ports = getDefaultPorts(1); server.setPort(ports.get(0)); @@ -142,20 +139,13 @@ public class ApacheDSContainerTests { server.setLdapOverSslEnabled(true); server.setKeyStoreFile(temporaryKeyStoreFile); server.setCertificatePassord("incorrect-password"); - - try { - server.afterPropertiesSet(); - fail("Expected a RuntimeException to be thrown."); - } - catch (RuntimeException e){ - assertThat(e).hasMessage("Server startup failed"); - assertThat(e).hasRootCauseInstanceOf(UnrecoverableKeyException.class); - } + assertThatExceptionOfType(RuntimeException.class).isThrownBy(server::afterPropertiesSet) + .withMessage("Server startup failed").withRootCauseInstanceOf(UnrecoverableKeyException.class); } /** - * This test starts an LDAP server using LDAPs (LDAP over SSL). A self-signed certificate is being used, which was - * previously generated with: + * This test starts an LDAP server using LDAPs (LDAP over SSL). A self-signed + * certificate is being used, which was previously generated with: * *
         	 * {@code
        @@ -168,14 +158,14 @@ public class ApacheDSContainerTests {
         	@Test
         	public void startWithLdapOverSsl() throws Exception {
         
        -		final ClassPathResource keyStoreResource = new ClassPathResource("/org/springframework/security/ldap/server/spring.keystore");
        -		final File temporaryKeyStoreFile = new File(temporaryFolder.getRoot(), "spring.keystore");
        +		final ClassPathResource keyStoreResource = new ClassPathResource(
        +				"/org/springframework/security/ldap/server/spring.keystore");
        +		final File temporaryKeyStoreFile = new File(this.temporaryFolder.getRoot(), "spring.keystore");
         		FileCopyUtils.copy(keyStoreResource.getInputStream(), new FileOutputStream(temporaryKeyStoreFile));
         
         		assertThat(temporaryKeyStoreFile).isFile();
         
        -		ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org",
        -				"classpath:test-server.ldif");
        +		ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif");
         
         		List ports = getDefaultPorts(1);
         		server.setPort(ports.get(0));
        @@ -191,7 +181,7 @@ public class ApacheDSContainerTests {
         			try {
         				server.destroy();
         			}
        -			catch (Throwable t) {
        +			catch (Throwable ex) {
         			}
         		}
         	}
        @@ -216,8 +206,7 @@ public class ApacheDSContainerTests {
         
         	@Test
         	public void afterPropertiesSetWhenPortIsZeroThenRandomPortIsSelected() throws Exception {
        -		ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org",
        -				"classpath:test-server.ldif");
        +		ApacheDSContainer server = new ApacheDSContainer("dc=springframework,dc=org", "classpath:test-server.ldif");
         		server.setPort(0);
         		try {
         			server.afterPropertiesSet();
        @@ -229,4 +218,5 @@ public class ApacheDSContainerTests {
         			server.destroy();
         		}
         	}
        +
         }
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSEmbeddedLdifTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSEmbeddedLdifTests.java
        index 09556a01bb..d166c00933 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSEmbeddedLdifTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/server/ApacheDSEmbeddedLdifTests.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap.server;
         
         import org.junit.After;
        @@ -34,16 +35,17 @@ import static org.assertj.core.api.Assertions.assertThat;
         public class ApacheDSEmbeddedLdifTests {
         
         	private static final String LDAP_ROOT = "ou=ssattributes,dc=springframework,dc=org";
        +
         	private static final int LDAP_PORT = 52389;
         
         	private ApacheDSContainer server;
        +
         	private SpringSecurityLdapTemplate ldapTemplate;
         
         	@Before
         	public void setUp() throws Exception {
         		// TODO: InMemoryXmlApplicationContext would be useful here, but it is not visible
        -		this.server = new ApacheDSContainer(LDAP_ROOT,
        -				"classpath:test-server-custom-attribute-types.ldif");
        +		this.server = new ApacheDSContainer(LDAP_ROOT, "classpath:test-server-custom-attribute-types.ldif");
         		this.server.setPort(LDAP_PORT);
         		this.server.afterPropertiesSet();
         
        @@ -68,10 +70,10 @@ public class ApacheDSEmbeddedLdifTests {
         	@Ignore // Not fixed yet
         	@Test // SEC-2387
         	public void customAttributeTypesShouldBeProperlyCreatedWhenLoadedFromLdif() {
        -		assertThat(this.ldapTemplate.compare("uid=objectWithCustomAttribute1", "uid",
        -				"objectWithCustomAttribute1")).isTrue();
        -		assertThat(this.ldapTemplate.compare("uid=objectWithCustomAttribute1",
        -				"customAttribute", "I am custom")).isTrue();
        +		assertThat(this.ldapTemplate.compare("uid=objectWithCustomAttribute1", "uid", "objectWithCustomAttribute1"))
        +				.isTrue();
        +		assertThat(this.ldapTemplate.compare("uid=objectWithCustomAttribute1", "customAttribute", "I am custom"))
        +				.isTrue();
         	}
         
         }
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerLdifTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerLdifTests.java
        index e5346d56f5..645da24654 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerLdifTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerLdifTests.java
        @@ -13,10 +13,14 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap.server;
         
        +import javax.annotation.PreDestroy;
        +
         import org.junit.After;
         import org.junit.Test;
        +
         import org.springframework.context.annotation.AnnotationConfigApplicationContext;
         import org.springframework.context.annotation.Bean;
         import org.springframework.context.annotation.Configuration;
        @@ -24,8 +28,6 @@ import org.springframework.ldap.core.ContextSource;
         import org.springframework.security.ldap.DefaultSpringSecurityContextSource;
         import org.springframework.security.ldap.SpringSecurityLdapTemplate;
         
        -import javax.annotation.PreDestroy;
        -
         import static org.assertj.core.api.Assertions.assertThat;
         import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown;
         
        @@ -40,25 +42,66 @@ public class UnboundIdContainerLdifTests {
         
         	@After
         	public void closeAppContext() {
        -		if (appCtx != null) {
        -			appCtx.close();
        -			appCtx = null;
        +		if (this.appCtx != null) {
        +			this.appCtx.close();
        +			this.appCtx = null;
         		}
         	}
         
         	@Test
         	public void unboundIdContainerWhenCustomLdifNameThenLdifLoaded() {
        -		appCtx = new AnnotationConfigApplicationContext(CustomLdifConfig.class);
        +		this.appCtx = new AnnotationConfigApplicationContext(CustomLdifConfig.class);
         
        -		DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) appCtx
        +		DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) this.appCtx
         				.getBean(ContextSource.class);
         
         		SpringSecurityLdapTemplate template = new SpringSecurityLdapTemplate(contextSource);
         		assertThat(template.compare("uid=bob,ou=people", "uid", "bob")).isTrue();
         	}
         
        +	@Test
        +	public void unboundIdContainerWhenWildcardLdifNameThenLdifLoaded() {
        +		this.appCtx = new AnnotationConfigApplicationContext(WildcardLdifConfig.class);
        +
        +		DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) this.appCtx
        +				.getBean(ContextSource.class);
        +
        +		SpringSecurityLdapTemplate template = new SpringSecurityLdapTemplate(contextSource);
        +		assertThat(template.compare("uid=bob,ou=people", "uid", "bob")).isTrue();
        +	}
        +
        +	@Test
        +	public void unboundIdContainerWhenMalformedLdifThenException() {
        +		try {
        +			this.appCtx = new AnnotationConfigApplicationContext(MalformedLdifConfig.class);
        +			failBecauseExceptionWasNotThrown(IllegalStateException.class);
        +		}
        +		catch (Exception ex) {
        +			assertThat(ex.getCause()).isInstanceOf(IllegalStateException.class);
        +			assertThat(ex.getMessage()).contains("Unable to load LDIF classpath:test-server-malformed.txt");
        +		}
        +	}
        +
        +	@Test
        +	public void unboundIdContainerWhenMissingLdifThenException() {
        +		try {
        +			this.appCtx = new AnnotationConfigApplicationContext(MissingLdifConfig.class);
        +			failBecauseExceptionWasNotThrown(IllegalStateException.class);
        +		}
        +		catch (Exception ex) {
        +			assertThat(ex.getCause()).isInstanceOf(IllegalStateException.class);
        +			assertThat(ex.getMessage()).contains("Unable to load LDIF classpath:does-not-exist.ldif");
        +		}
        +	}
        +
        +	@Test
        +	public void unboundIdContainerWhenWildcardLdifNotFoundThenProceeds() {
        +		new AnnotationConfigApplicationContext(WildcardNoLdifConfig.class);
        +	}
        +
         	@Configuration
         	static class CustomLdifConfig {
        +
         		private UnboundIdContainer container = new UnboundIdContainer("dc=springframework,dc=org",
         				"classpath:test-server.ldif");
         
        @@ -70,29 +113,20 @@ public class UnboundIdContainerLdifTests {
         
         		@Bean
         		ContextSource contextSource(UnboundIdContainer container) {
        -			return new DefaultSpringSecurityContextSource("ldap://127.0.0.1:"
        -					+ container.getPort() + "/dc=springframework,dc=org");
        +			return new DefaultSpringSecurityContextSource(
        +					"ldap://127.0.0.1:" + container.getPort() + "/dc=springframework,dc=org");
         		}
         
         		@PreDestroy
         		void shutdown() {
         			this.container.stop();
         		}
        -	}
         
        -	@Test
        -	public void unboundIdContainerWhenWildcardLdifNameThenLdifLoaded() {
        -		appCtx = new AnnotationConfigApplicationContext(WildcardLdifConfig.class);
        -
        -		DefaultSpringSecurityContextSource contextSource = (DefaultSpringSecurityContextSource) appCtx
        -				.getBean(ContextSource.class);
        -
        -		SpringSecurityLdapTemplate template = new SpringSecurityLdapTemplate(contextSource);
        -		assertThat(template.compare("uid=bob,ou=people", "uid", "bob")).isTrue();
         	}
         
         	@Configuration
         	static class WildcardLdifConfig {
        +
         		private UnboundIdContainer container = new UnboundIdContainer("dc=springframework,dc=org",
         				"classpath*:test-server.ldif");
         
        @@ -104,29 +138,20 @@ public class UnboundIdContainerLdifTests {
         
         		@Bean
         		ContextSource contextSource(UnboundIdContainer container) {
        -			return new DefaultSpringSecurityContextSource("ldap://127.0.0.1:"
        -					+ container.getPort() + "/dc=springframework,dc=org");
        +			return new DefaultSpringSecurityContextSource(
        +					"ldap://127.0.0.1:" + container.getPort() + "/dc=springframework,dc=org");
         		}
         
         		@PreDestroy
         		void shutdown() {
         			this.container.stop();
         		}
        -	}
         
        -	@Test
        -	public void unboundIdContainerWhenMalformedLdifThenException() {
        -		try {
        -			appCtx = new AnnotationConfigApplicationContext(MalformedLdifConfig.class);
        -			failBecauseExceptionWasNotThrown(IllegalStateException.class);
        -		} catch (Exception e) {
        -			assertThat(e.getCause()).isInstanceOf(IllegalStateException.class);
        -			assertThat(e.getMessage()).contains("Unable to load LDIF classpath:test-server-malformed.txt");
        -		}
         	}
         
         	@Configuration
         	static class MalformedLdifConfig {
        +
         		private UnboundIdContainer container = new UnboundIdContainer("dc=springframework,dc=org",
         				"classpath:test-server-malformed.txt");
         
        @@ -140,21 +165,12 @@ public class UnboundIdContainerLdifTests {
         		void shutdown() {
         			this.container.stop();
         		}
        -	}
         
        -	@Test
        -	public void unboundIdContainerWhenMissingLdifThenException() {
        -		try {
        -			appCtx = new AnnotationConfigApplicationContext(MissingLdifConfig.class);
        -			failBecauseExceptionWasNotThrown(IllegalStateException.class);
        -		} catch (Exception e) {
        -			assertThat(e.getCause()).isInstanceOf(IllegalStateException.class);
        -			assertThat(e.getMessage()).contains("Unable to load LDIF classpath:does-not-exist.ldif");
        -		}
         	}
         
         	@Configuration
         	static class MissingLdifConfig {
        +
         		private UnboundIdContainer container = new UnboundIdContainer("dc=springframework,dc=org",
         				"classpath:does-not-exist.ldif");
         
        @@ -168,15 +184,12 @@ public class UnboundIdContainerLdifTests {
         		void shutdown() {
         			this.container.stop();
         		}
        -	}
         
        -	@Test
        -	public void unboundIdContainerWhenWildcardLdifNotFoundThenProceeds() {
        -		new AnnotationConfigApplicationContext(WildcardNoLdifConfig.class);
         	}
         
         	@Configuration
         	static class WildcardNoLdifConfig {
        +
         		private UnboundIdContainer container = new UnboundIdContainer("dc=springframework,dc=org",
         				"classpath*:*.test.ldif");
         
        @@ -190,5 +203,7 @@ public class UnboundIdContainerLdifTests {
         		void shutdown() {
         			this.container.stop();
         		}
        +
         	}
        +
         }
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerTests.java
        index 5b1a7bc59d..a86de8c005 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/server/UnboundIdContainerTests.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap.server;
         
         import java.io.IOException;
        @@ -33,8 +34,7 @@ public class UnboundIdContainerTests {
         
         	@Test
         	public void startLdapServer() throws Exception {
        -		UnboundIdContainer server = new UnboundIdContainer("dc=springframework,dc=org",
        -				"classpath:test-server.ldif");
        +		UnboundIdContainer server = new UnboundIdContainer("dc=springframework,dc=org", "classpath:test-server.ldif");
         		server.setApplicationContext(new GenericApplicationContext());
         		List ports = getDefaultPorts(1);
         		server.setPort(ports.get(0));
        @@ -42,7 +42,8 @@ public class UnboundIdContainerTests {
         		try {
         			server.afterPropertiesSet();
         			assertThat(server.getPort()).isEqualTo(ports.get(0));
        -		} finally {
        +		}
        +		finally {
         			server.destroy();
         		}
         	}
        @@ -55,7 +56,8 @@ public class UnboundIdContainerTests {
         		try {
         			server.afterPropertiesSet();
         			assertThat(server.getPort()).isNotEqualTo(0);
        -		} finally {
        +		}
        +		finally {
         			server.destroy();
         		}
         	}
        @@ -70,7 +72,8 @@ public class UnboundIdContainerTests {
         				availablePorts.add(socket.getLocalPort());
         			}
         			return availablePorts;
        -		} finally {
        +		}
        +		finally {
         			for (ServerSocket conn : connections) {
         				conn.close();
         			}
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulatorTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulatorTests.java
        index f1789129ae..37cb58cc02 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulatorTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulatorTests.java
        @@ -16,14 +16,14 @@
         
         package org.springframework.security.ldap.userdetails;
         
        -import static org.assertj.core.api.Assertions.*;
        -
         import java.util.Collection;
         import java.util.HashSet;
         import java.util.Set;
         
        -import org.junit.*;
        +import org.junit.Before;
        +import org.junit.Test;
         import org.junit.runner.RunWith;
        +
         import org.springframework.beans.factory.annotation.Autowired;
         import org.springframework.ldap.core.ContextSource;
         import org.springframework.ldap.core.DirContextAdapter;
        @@ -36,8 +36,9 @@ import org.springframework.security.ldap.SpringSecurityLdapTemplate;
         import org.springframework.test.context.ContextConfiguration;
         import org.springframework.test.context.junit4.SpringRunner;
         
        +import static org.assertj.core.api.Assertions.assertThat;
        +
         /**
        - *
          * @author Luke Taylor
          * @author Eddú Meléndez
          */
        @@ -48,56 +49,51 @@ public class DefaultLdapAuthoritiesPopulatorTests {
         
         	@Autowired
         	private ContextSource contextSource;
        -	private DefaultLdapAuthoritiesPopulator populator;
         
        -	// ~ Methods
        -	// ========================================================================================================
        +	private DefaultLdapAuthoritiesPopulator populator;
         
         	@Before
         	public void setUp() {
        -		populator = new DefaultLdapAuthoritiesPopulator(this.contextSource, "ou=groups");
        -		populator.setIgnorePartialResultException(false);
        +		this.populator = new DefaultLdapAuthoritiesPopulator(this.contextSource, "ou=groups");
        +		this.populator.setIgnorePartialResultException(false);
         	}
         
         	@Test
         	public void defaultRoleIsAssignedWhenSet() {
        -		populator.setDefaultRole("ROLE_USER");
        -		assertThat(populator.getContextSource()).isSameAs(this.contextSource);
        +		this.populator.setDefaultRole("ROLE_USER");
        +		assertThat(this.populator.getContextSource()).isSameAs(this.contextSource);
         
        -		DirContextAdapter ctx = new DirContextAdapter(
        -				new DistinguishedName("cn=notfound"));
        +		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName("cn=notfound"));
         
        -		Collection authorities = populator.getGrantedAuthorities(ctx,
        -				"notfound");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "notfound");
         		assertThat(authorities).hasSize(1);
         		assertThat(AuthorityUtils.authorityListToSet(authorities).contains("ROLE_USER")).isTrue();
         	}
         
         	@Test
         	public void nullSearchBaseIsAccepted() {
        -		populator = new DefaultLdapAuthoritiesPopulator(this.contextSource, null);
        -		populator.setDefaultRole("ROLE_USER");
        +		this.populator = new DefaultLdapAuthoritiesPopulator(this.contextSource, null);
        +		this.populator.setDefaultRole("ROLE_USER");
         
        -		Collection authorities = populator.getGrantedAuthorities(
        -				new DirContextAdapter(new DistinguishedName("cn=notused")), "notused");
        +		Collection authorities = this.populator
        +				.getGrantedAuthorities(new DirContextAdapter(new DistinguishedName("cn=notused")), "notused");
         		assertThat(authorities).hasSize(1);
         		assertThat(AuthorityUtils.authorityListToSet(authorities).contains("ROLE_USER")).isTrue();
         	}
         
         	@Test
         	public void groupSearchReturnsExpectedRoles() {
        -		populator.setRolePrefix("ROLE_");
        -		populator.setGroupRoleAttribute("ou");
        -		populator.setSearchSubtree(true);
        -		populator.setSearchSubtree(false);
        -		populator.setConvertToUpperCase(true);
        -		populator.setGroupSearchFilter("(member={0})");
        +		this.populator.setRolePrefix("ROLE_");
        +		this.populator.setGroupRoleAttribute("ou");
        +		this.populator.setSearchSubtree(true);
        +		this.populator.setSearchSubtree(false);
        +		this.populator.setConvertToUpperCase(true);
        +		this.populator.setGroupSearchFilter("(member={0})");
         
        -		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName(
        -				"uid=ben,ou=people,dc=springframework,dc=org"));
        +		DirContextAdapter ctx = new DirContextAdapter(
        +				new DistinguishedName("uid=ben,ou=people,dc=springframework,dc=org"));
         
        -		Set authorities = AuthorityUtils.authorityListToSet(populator
        -				.getGrantedAuthorities(ctx, "ben"));
        +		Set authorities = AuthorityUtils.authorityListToSet(this.populator.getGrantedAuthorities(ctx, "ben"));
         
         		assertThat(authorities).as("Should have 2 roles").hasSize(2);
         
        @@ -107,15 +103,15 @@ public class DefaultLdapAuthoritiesPopulatorTests {
         
         	@Test
         	public void useOfUsernameParameterReturnsExpectedRoles() {
        -		populator.setGroupRoleAttribute("ou");
        -		populator.setConvertToUpperCase(true);
        -		populator.setGroupSearchFilter("(ou={1})");
        +		this.populator.setGroupRoleAttribute("ou");
        +		this.populator.setConvertToUpperCase(true);
        +		this.populator.setGroupSearchFilter("(ou={1})");
         
        -		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName(
        -				"uid=ben,ou=people,dc=springframework,dc=org"));
        +		DirContextAdapter ctx = new DirContextAdapter(
        +				new DistinguishedName("uid=ben,ou=people,dc=springframework,dc=org"));
         
        -		Set authorities = AuthorityUtils.authorityListToSet(populator
        -				.getGrantedAuthorities(ctx, "manager"));
        +		Set authorities = AuthorityUtils
        +				.authorityListToSet(this.populator.getGrantedAuthorities(ctx, "manager"));
         
         		assertThat(authorities).as("Should have 1 role").hasSize(1);
         		assertThat(authorities.contains("ROLE_MANAGER")).isTrue();
        @@ -123,14 +119,14 @@ public class DefaultLdapAuthoritiesPopulatorTests {
         
         	@Test
         	public void subGroupRolesAreNotFoundByDefault() {
        -		populator.setGroupRoleAttribute("ou");
        -		populator.setConvertToUpperCase(true);
        +		this.populator.setGroupRoleAttribute("ou");
        +		this.populator.setConvertToUpperCase(true);
         
        -		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName(
        -				"uid=ben,ou=people,dc=springframework,dc=org"));
        +		DirContextAdapter ctx = new DirContextAdapter(
        +				new DistinguishedName("uid=ben,ou=people,dc=springframework,dc=org"));
         
        -		Set authorities = AuthorityUtils.authorityListToSet(populator
        -				.getGrantedAuthorities(ctx, "manager"));
        +		Set authorities = AuthorityUtils
        +				.authorityListToSet(this.populator.getGrantedAuthorities(ctx, "manager"));
         
         		assertThat(authorities).as("Should have 2 roles").hasSize(2);
         		assertThat(authorities.contains("ROLE_MANAGER")).isTrue();
        @@ -139,15 +135,15 @@ public class DefaultLdapAuthoritiesPopulatorTests {
         
         	@Test
         	public void subGroupRolesAreFoundWhenSubtreeSearchIsEnabled() {
        -		populator.setGroupRoleAttribute("ou");
        -		populator.setConvertToUpperCase(true);
        -		populator.setSearchSubtree(true);
        +		this.populator.setGroupRoleAttribute("ou");
        +		this.populator.setConvertToUpperCase(true);
        +		this.populator.setSearchSubtree(true);
         
        -		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName(
        -				"uid=ben,ou=people,dc=springframework,dc=org"));
        +		DirContextAdapter ctx = new DirContextAdapter(
        +				new DistinguishedName("uid=ben,ou=people,dc=springframework,dc=org"));
         
        -		Set authorities = AuthorityUtils.authorityListToSet(populator
        -				.getGrantedAuthorities(ctx, "manager"));
        +		Set authorities = AuthorityUtils
        +				.authorityListToSet(this.populator.getGrantedAuthorities(ctx, "manager"));
         
         		assertThat(authorities).as("Should have 3 roles").hasSize(3);
         		assertThat(authorities.contains("ROLE_MANAGER")).isTrue();
        @@ -157,32 +153,30 @@ public class DefaultLdapAuthoritiesPopulatorTests {
         
         	@Test
         	public void extraRolesAreAdded() {
        -		populator = new DefaultLdapAuthoritiesPopulator(this.contextSource, null) {
        +		this.populator = new DefaultLdapAuthoritiesPopulator(this.contextSource, null) {
         			@Override
        -			protected Set getAdditionalRoles(DirContextOperations user,
        -					String username) {
        -				return new HashSet<>(
        -						AuthorityUtils.createAuthorityList("ROLE_EXTRA"));
        +			protected Set getAdditionalRoles(DirContextOperations user, String username) {
        +				return new HashSet<>(AuthorityUtils.createAuthorityList("ROLE_EXTRA"));
         			}
         		};
         
        -		Collection authorities = populator.getGrantedAuthorities(
        -				new DirContextAdapter(new DistinguishedName("cn=notused")), "notused");
        +		Collection authorities = this.populator
        +				.getGrantedAuthorities(new DirContextAdapter(new DistinguishedName("cn=notused")), "notused");
         		assertThat(authorities).hasSize(1);
         		assertThat(AuthorityUtils.authorityListToSet(authorities).contains("ROLE_EXTRA")).isTrue();
         	}
         
         	@Test
         	public void userDnWithEscapedCharacterParameterReturnsExpectedRoles() {
        -		populator.setGroupRoleAttribute("ou");
        -		populator.setConvertToUpperCase(true);
        -		populator.setGroupSearchFilter("(member={0})");
        +		this.populator.setGroupRoleAttribute("ou");
        +		this.populator.setConvertToUpperCase(true);
        +		this.populator.setGroupSearchFilter("(member={0})");
         
        -		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName(
        -				"cn=mouse\\, jerry,ou=people,dc=springframework,dc=org"));
        +		DirContextAdapter ctx = new DirContextAdapter(
        +				new DistinguishedName("cn=mouse\\, jerry,ou=people,dc=springframework,dc=org"));
         
        -		Set authorities = AuthorityUtils.authorityListToSet(populator
        -				.getGrantedAuthorities(ctx, "notused"));
        +		Set authorities = AuthorityUtils
        +				.authorityListToSet(this.populator.getGrantedAuthorities(ctx, "notused"));
         
         		assertThat(authorities).as("Should have 1 role").hasSize(1);
         		assertThat(authorities.contains("ROLE_MANAGER")).isTrue();
        @@ -190,22 +184,23 @@ public class DefaultLdapAuthoritiesPopulatorTests {
         
         	@Test
         	public void customAuthoritiesMappingFunction() {
        -		populator.setAuthorityMapper(record -> {
        +		this.populator.setAuthorityMapper((record) -> {
         			String dn = record.get(SpringSecurityLdapTemplate.DN_KEY).get(0);
        -			String role = record.get(populator.getGroupRoleAttribute()).get(0);
        +			String role = record.get(this.populator.getGroupRoleAttribute()).get(0);
         			return new LdapAuthority(role, dn);
         		});
         
        -		DirContextAdapter ctx = new DirContextAdapter(new DistinguishedName(
        -				"cn=mouse\\, jerry,ou=people,dc=springframework,dc=org"));
        +		DirContextAdapter ctx = new DirContextAdapter(
        +				new DistinguishedName("cn=mouse\\, jerry,ou=people,dc=springframework,dc=org"));
         
        -		Collection authorities = populator.getGrantedAuthorities(ctx, "notused");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "notused");
         
         		assertThat(authorities).allMatch(LdapAuthority.class::isInstance);
         	}
         
         	@Test(expected = IllegalArgumentException.class)
         	public void customAuthoritiesMappingFunctionThrowsIfNull() {
        -		populator.setAuthorityMapper(null);
        +		this.populator.setAuthorityMapper(null);
         	}
        +
         }
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerModifyPasswordTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerModifyPasswordTests.java
        index 72f2b7a109..ab2e6e3992 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerModifyPasswordTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerModifyPasswordTests.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap.userdetails;
         
         import javax.annotation.PreDestroy;
        @@ -35,7 +36,7 @@ import org.springframework.test.context.ContextConfiguration;
         import org.springframework.test.context.junit4.SpringRunner;
         
         import static org.assertj.core.api.Assertions.assertThat;
        -import static org.assertj.core.api.Assertions.assertThatCode;
        +import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
         
         /**
          * Tests for {@link LdapUserDetailsManager#changePassword}, specifically relating to the
        @@ -44,7 +45,7 @@ import static org.assertj.core.api.Assertions.assertThatCode;
          * @author Josh Cummings
          */
         @RunWith(SpringRunner.class)
        -@ContextConfiguration(classes=LdapUserDetailsManagerModifyPasswordTests.UnboundIdContainerConfiguration.class)
        +@ContextConfiguration(classes = LdapUserDetailsManagerModifyPasswordTests.UnboundIdContainerConfiguration.class)
         public class LdapUserDetailsManagerModifyPasswordTests {
         
         	LdapUserDetailsManager userDetailsManager;
        @@ -60,15 +61,14 @@ public class LdapUserDetailsManagerModifyPasswordTests {
         	}
         
         	@Test
        -	@WithMockUser(username="bob", password="bobspassword", authorities="ROLE_USER")
        +	@WithMockUser(username = "bob", password = "bobspassword", authorities = "ROLE_USER")
         	public void changePasswordWhenOldPasswordIsIncorrectThenThrowsException() {
        -		assertThatCode(() ->
        -				this.userDetailsManager.changePassword("wrongoldpassword", "bobsnewpassword"))
        -				.isInstanceOf(BadCredentialsException.class);
        +		assertThatExceptionOfType(BadCredentialsException.class)
        +				.isThrownBy(() -> this.userDetailsManager.changePassword("wrongoldpassword", "bobsnewpassword"));
         	}
         
         	@Test
        -	@WithMockUser(username="bob", password="bobspassword", authorities="ROLE_USER")
        +	@WithMockUser(username = "bob", password = "bobspassword", authorities = "ROLE_USER")
         	public void changePasswordWhenOldPasswordIsCorrectThenPasses() {
         		SpringSecurityLdapTemplate template = new SpringSecurityLdapTemplate(this.contextSource);
         
        @@ -76,11 +76,13 @@ public class LdapUserDetailsManagerModifyPasswordTests {
         				"bobsshinynewandformidablylongandnearlyimpossibletorememberthoughdemonstrablyhardtocrackduetoitshighlevelofentropypasswordofjustice");
         
         		assertThat(template.compare("uid=bob,ou=people", "userPassword",
        -				"bobsshinynewandformidablylongandnearlyimpossibletorememberthoughdemonstrablyhardtocrackduetoitshighlevelofentropypasswordofjustice")).isTrue();
        +				"bobsshinynewandformidablylongandnearlyimpossibletorememberthoughdemonstrablyhardtocrackduetoitshighlevelofentropypasswordofjustice"))
        +						.isTrue();
         	}
         
         	@Configuration
         	static class UnboundIdContainerConfiguration {
        +
         		private UnboundIdContainer container = new UnboundIdContainer("dc=springframework,dc=org",
         				"classpath:test-server.ldif");
         
        @@ -92,13 +94,15 @@ public class LdapUserDetailsManagerModifyPasswordTests {
         
         		@Bean
         		ContextSource contextSource(UnboundIdContainer container) {
        -			return new DefaultSpringSecurityContextSource("ldap://127.0.0.1:"
        -					+ container.getPort() + "/dc=springframework,dc=org");
        +			return new DefaultSpringSecurityContextSource(
        +					"ldap://127.0.0.1:" + container.getPort() + "/dc=springframework,dc=org");
         		}
         
         		@PreDestroy
         		void shutdown() {
         			this.container.stop();
         		}
        +
         	}
        +
         }
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerTests.java
        index eaf24d8b25..c02d3c3564 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManagerTests.java
        @@ -16,9 +16,6 @@
         
         package org.springframework.security.ldap.userdetails;
         
        -import static org.assertj.core.api.Assertions.assertThat;
        -import static org.assertj.core.api.Assertions.fail;
        -
         import java.util.List;
         
         import org.junit.After;
        @@ -41,6 +38,9 @@ import org.springframework.security.ldap.SpringSecurityLdapTemplate;
         import org.springframework.test.context.ContextConfiguration;
         import org.springframework.test.context.junit4.SpringRunner;
         
        +import static org.assertj.core.api.Assertions.assertThat;
        +import static org.assertj.core.api.Assertions.fail;
        +
         /**
          * @author Luke Taylor
          * @author Eddú Meléndez
        @@ -52,8 +52,8 @@ public class LdapUserDetailsManagerTests {
         	@Autowired
         	private ContextSource contextSource;
         
        -	private static final List TEST_AUTHORITIES = AuthorityUtils.createAuthorityList(
        -			"ROLE_CLOWNS", "ROLE_ACROBATS");
        +	private static final List TEST_AUTHORITIES = AuthorityUtils.createAuthorityList("ROLE_CLOWNS",
        +			"ROLE_ACROBATS");
         
         	private LdapUserDetailsManager mgr;
         
        @@ -61,33 +61,32 @@ public class LdapUserDetailsManagerTests {
         
         	@Before
         	public void setUp() {
        -		mgr = new LdapUserDetailsManager(this.contextSource);
        -		template = new SpringSecurityLdapTemplate(this.contextSource);
        +		this.mgr = new LdapUserDetailsManager(this.contextSource);
        +		this.template = new SpringSecurityLdapTemplate(this.contextSource);
         		DirContextAdapter ctx = new DirContextAdapter();
         
         		ctx.setAttributeValue("objectclass", "organizationalUnit");
         		ctx.setAttributeValue("ou", "test people");
        -		template.bind("ou=test people", ctx, null);
        +		this.template.bind("ou=test people", ctx, null);
         
         		ctx.setAttributeValue("ou", "testgroups");
        -		template.bind("ou=testgroups", ctx, null);
        +		this.template.bind("ou=testgroups", ctx, null);
         
         		DirContextAdapter group = new DirContextAdapter();
         
         		group.setAttributeValue("objectclass", "groupOfNames");
         		group.setAttributeValue("cn", "clowns");
        -		group.setAttributeValue("member",
        -				"cn=nobody,ou=test people,dc=springframework,dc=org");
        -		template.bind("cn=clowns,ou=testgroups", group, null);
        +		group.setAttributeValue("member", "cn=nobody,ou=test people,dc=springframework,dc=org");
        +		this.template.bind("cn=clowns,ou=testgroups", group, null);
         
         		group.setAttributeValue("cn", "acrobats");
        -		template.bind("cn=acrobats,ou=testgroups", group, null);
        +		this.template.bind("cn=acrobats,ou=testgroups", group, null);
         
        -		mgr.setUsernameMapper(new DefaultLdapUsernameToDnMapper("ou=test people", "uid"));
        -		mgr.setGroupSearchBase("ou=testgroups");
        -		mgr.setGroupRoleAttributeName("cn");
        -		mgr.setGroupMemberAttributeName("member");
        -		mgr.setUserDetailsMapper(new PersonContextMapper());
        +		this.mgr.setUsernameMapper(new DefaultLdapUsernameToDnMapper("ou=test people", "uid"));
        +		this.mgr.setGroupSearchBase("ou=testgroups");
        +		this.mgr.setGroupRoleAttributeName("cn");
        +		this.mgr.setGroupMemberAttributeName("member");
        +		this.mgr.setUserDetailsMapper(new PersonContextMapper());
         	}
         
         	@After
        @@ -101,17 +100,17 @@ public class LdapUserDetailsManagerTests {
         		// template.unbind((String) people.next() + ",ou=testpeople");
         		// }
         
        -		template.unbind("ou=test people", true);
        -		template.unbind("ou=testgroups", true);
        +		this.template.unbind("ou=test people", true);
        +		this.template.unbind("ou=testgroups", true);
         
         		SecurityContextHolder.clearContext();
         	}
         
         	@Test
         	public void testLoadUserByUsernameReturnsCorrectData() {
        -		mgr.setUsernameMapper(new DefaultLdapUsernameToDnMapper("ou=people", "uid"));
        -		mgr.setGroupSearchBase("ou=groups");
        -		LdapUserDetails bob = (LdapUserDetails) mgr.loadUserByUsername("bob");
        +		this.mgr.setUsernameMapper(new DefaultLdapUsernameToDnMapper("ou=people", "uid"));
        +		this.mgr.setGroupSearchBase("ou=groups");
        +		LdapUserDetails bob = (LdapUserDetails) this.mgr.loadUserByUsername("bob");
         		assertThat(bob.getUsername()).isEqualTo("bob");
         		assertThat(bob.getDn()).isEqualTo("uid=bob,ou=people,dc=springframework,dc=org");
         		assertThat(bob.getPassword()).isEqualTo("bobspassword");
        @@ -121,18 +120,18 @@ public class LdapUserDetailsManagerTests {
         
         	@Test(expected = UsernameNotFoundException.class)
         	public void testLoadingInvalidUsernameThrowsUsernameNotFoundException() {
        -		mgr.loadUserByUsername("jim");
        +		this.mgr.loadUserByUsername("jim");
         	}
         
         	@Test
         	public void testUserExistsReturnsTrueForValidUser() {
        -		mgr.setUsernameMapper(new DefaultLdapUsernameToDnMapper("ou=people", "uid"));
        -		assertThat(mgr.userExists("bob")).isTrue();
        +		this.mgr.setUsernameMapper(new DefaultLdapUsernameToDnMapper("ou=people", "uid"));
        +		assertThat(this.mgr.userExists("bob")).isTrue();
         	}
         
         	@Test
         	public void testUserExistsReturnsFalseForInValidUser() {
        -		assertThat(mgr.userExists("jim")).isFalse();
        +		assertThat(this.mgr.userExists("jim")).isFalse();
         	}
         
         	@Test
        @@ -155,7 +154,7 @@ public class LdapUserDetailsManagerTests {
         
         		p.setAuthorities(TEST_AUTHORITIES);
         
        -		mgr.createUser(p.createUserDetails());
        +		this.mgr.createUser(p.createUserDetails());
         	}
         
         	@Test
        @@ -167,17 +166,17 @@ public class LdapUserDetailsManagerTests {
         		p.setUid("don");
         		p.setAuthorities(TEST_AUTHORITIES);
         
        -		mgr.createUser(p.createUserDetails());
        -		mgr.setUserDetailsMapper(new InetOrgPersonContextMapper());
        +		this.mgr.createUser(p.createUserDetails());
        +		this.mgr.setUserDetailsMapper(new InetOrgPersonContextMapper());
         
        -		InetOrgPerson don = (InetOrgPerson) mgr.loadUserByUsername("don");
        +		InetOrgPerson don = (InetOrgPerson) this.mgr.loadUserByUsername("don");
         
         		assertThat(don.getAuthorities()).hasSize(2);
         
        -		mgr.deleteUser("don");
        +		this.mgr.deleteUser("don");
         
         		try {
        -			mgr.loadUserByUsername("don");
        +			this.mgr.loadUserByUsername("don");
         			fail("Expected UsernameNotFoundException after deleting user");
         		}
         		catch (UsernameNotFoundException expected) {
        @@ -185,9 +184,7 @@ public class LdapUserDetailsManagerTests {
         		}
         
         		// Check that no authorities are left
        -		assertThat(
        -				mgr.getUserAuthorities(mgr.usernameMapper.buildDn("don"), "don")).hasSize(
        -						0);
        +		assertThat(this.mgr.getUserAuthorities(this.mgr.usernameMapper.buildDn("don"), "don")).hasSize(0);
         	}
         
         	@Test
        @@ -200,16 +197,15 @@ public class LdapUserDetailsManagerTests {
         		p.setPassword("yossarianspassword");
         		p.setAuthorities(TEST_AUTHORITIES);
         
        -		mgr.createUser(p.createUserDetails());
        +		this.mgr.createUser(p.createUserDetails());
         
         		SecurityContextHolder.getContext().setAuthentication(
        -				new UsernamePasswordAuthenticationToken("johnyossarian",
        -						"yossarianspassword", TEST_AUTHORITIES));
        +				new UsernamePasswordAuthenticationToken("johnyossarian", "yossarianspassword", TEST_AUTHORITIES));
         
        -		mgr.changePassword("yossarianspassword", "yossariansnewpassword");
        +		this.mgr.changePassword("yossarianspassword", "yossariansnewpassword");
         
        -		assertThat(template.compare("uid=johnyossarian,ou=test people", "userPassword",
        -				"yossariansnewpassword")).isTrue();
        +		assertThat(this.template.compare("uid=johnyossarian,ou=test people", "userPassword", "yossariansnewpassword"))
        +				.isTrue();
         	}
         
         	@Test(expected = BadCredentialsException.class)
        @@ -222,12 +218,12 @@ public class LdapUserDetailsManagerTests {
         		p.setPassword("yossarianspassword");
         		p.setAuthorities(TEST_AUTHORITIES);
         
        -		mgr.createUser(p.createUserDetails());
        +		this.mgr.createUser(p.createUserDetails());
         
         		SecurityContextHolder.getContext().setAuthentication(
        -				new UsernamePasswordAuthenticationToken("johnyossarian",
        -						"yossarianspassword", TEST_AUTHORITIES));
        +				new UsernamePasswordAuthenticationToken("johnyossarian", "yossarianspassword", TEST_AUTHORITIES));
         
        -		mgr.changePassword("wrongpassword", "yossariansnewpassword");
        +		this.mgr.changePassword("wrongpassword", "yossariansnewpassword");
         	}
        +
         }
        diff --git a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulatorTests.java b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulatorTests.java
        index 70ff12f642..e2c917507b 100644
        --- a/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulatorTests.java
        +++ b/ldap/src/integration-test/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulatorTests.java
        @@ -13,8 +13,13 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap.userdetails;
         
        +import java.util.Arrays;
        +import java.util.Collection;
        +import java.util.HashSet;
        +
         import org.junit.Before;
         import org.junit.Test;
         import org.junit.runner.RunWith;
        @@ -27,11 +32,7 @@ import org.springframework.security.ldap.ApacheDsContainerConfig;
         import org.springframework.test.context.ContextConfiguration;
         import org.springframework.test.context.junit4.SpringRunner;
         
        -import java.util.Arrays;
        -import java.util.Collection;
        -import java.util.HashSet;
        -
        -import static org.assertj.core.api.Assertions.*;
        +import static org.assertj.core.api.Assertions.assertThat;
         
         /**
          * @author Filip Hanik
        @@ -43,94 +44,86 @@ public class NestedLdapAuthoritiesPopulatorTests {
         
         	@Autowired
         	private ContextSource contextSource;
        -	private NestedLdapAuthoritiesPopulator populator;
        -	private LdapAuthority javaDevelopers;
        -	private LdapAuthority groovyDevelopers;
        -	private LdapAuthority scalaDevelopers;
        -	private LdapAuthority closureDevelopers;
        -	private LdapAuthority jDevelopers;
        -	private LdapAuthority circularJavaDevelopers;
         
        -	// ~ Methods
        -	// ========================================================================================================
        +	private NestedLdapAuthoritiesPopulator populator;
        +
        +	private LdapAuthority javaDevelopers;
        +
        +	private LdapAuthority groovyDevelopers;
        +
        +	private LdapAuthority scalaDevelopers;
        +
        +	private LdapAuthority closureDevelopers;
        +
        +	private LdapAuthority jDevelopers;
        +
        +	private LdapAuthority circularJavaDevelopers;
         
         	@Before
         	public void setUp() {
        -		populator = new NestedLdapAuthoritiesPopulator(this.contextSource,
        -				"ou=jdeveloper");
        -		populator.setGroupSearchFilter("(member={0})");
        -		populator.setIgnorePartialResultException(false);
        -		populator.setRolePrefix("");
        -		populator.setSearchSubtree(true);
        -		populator.setConvertToUpperCase(false);
        -		jDevelopers = new LdapAuthority("j-developers",
        -				"cn=j-developers,ou=jdeveloper,dc=springframework,dc=org");
        -		javaDevelopers = new LdapAuthority("java-developers",
        +		this.populator = new NestedLdapAuthoritiesPopulator(this.contextSource, "ou=jdeveloper");
        +		this.populator.setGroupSearchFilter("(member={0})");
        +		this.populator.setIgnorePartialResultException(false);
        +		this.populator.setRolePrefix("");
        +		this.populator.setSearchSubtree(true);
        +		this.populator.setConvertToUpperCase(false);
        +		this.jDevelopers = new LdapAuthority("j-developers", "cn=j-developers,ou=jdeveloper,dc=springframework,dc=org");
        +		this.javaDevelopers = new LdapAuthority("java-developers",
         				"cn=java-developers,ou=jdeveloper,dc=springframework,dc=org");
        -		groovyDevelopers = new LdapAuthority("groovy-developers",
        +		this.groovyDevelopers = new LdapAuthority("groovy-developers",
         				"cn=groovy-developers,ou=jdeveloper,dc=springframework,dc=org");
        -		scalaDevelopers = new LdapAuthority("scala-developers",
        +		this.scalaDevelopers = new LdapAuthority("scala-developers",
         				"cn=scala-developers,ou=jdeveloper,dc=springframework,dc=org");
        -		closureDevelopers = new LdapAuthority("closure-developers",
        +		this.closureDevelopers = new LdapAuthority("closure-developers",
         				"cn=closure-developers,ou=jdeveloper,dc=springframework,dc=org");
        -		circularJavaDevelopers = new LdapAuthority("circular-java-developers",
        +		this.circularJavaDevelopers = new LdapAuthority("circular-java-developers",
         				"cn=circular-java-developers,ou=jdeveloper,dc=springframework,dc=org");
         	}
         
         	@Test
         	public void testScalaDudeJDevelopersAuthorities() {
        -		DirContextAdapter ctx = new DirContextAdapter(
        -				"uid=scaladude,ou=people,dc=springframework,dc=org");
        -		Collection authorities = populator.getGrantedAuthorities(ctx,
        -				"scaladude");
        +		DirContextAdapter ctx = new DirContextAdapter("uid=scaladude,ou=people,dc=springframework,dc=org");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "scaladude");
         		assertThat(authorities).hasSize(5);
        -		assertThat(authorities).isEqualTo(Arrays.asList(javaDevelopers, circularJavaDevelopers,
        -				scalaDevelopers, groovyDevelopers, jDevelopers));
        +		assertThat(authorities).isEqualTo(Arrays.asList(this.javaDevelopers, this.circularJavaDevelopers,
        +				this.scalaDevelopers, this.groovyDevelopers, this.jDevelopers));
         	}
         
         	@Test
         	public void testJavaDudeJDevelopersAuthorities() {
        -		DirContextAdapter ctx = new DirContextAdapter(
        -				"uid=javadude,ou=people,dc=springframework,dc=org");
        -		Collection authorities = populator.getGrantedAuthorities(ctx,
        -				"javadude");
        +		DirContextAdapter ctx = new DirContextAdapter("uid=javadude,ou=people,dc=springframework,dc=org");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "javadude");
         		assertThat(authorities).hasSize(4);
        -		assertThat(authorities).contains(javaDevelopers);
        +		assertThat(authorities).contains(this.javaDevelopers);
         	}
         
         	@Test
         	public void testScalaDudeJDevelopersAuthoritiesWithSearchLimit() {
        -		populator.setMaxSearchDepth(1);
        -		DirContextAdapter ctx = new DirContextAdapter(
        -				"uid=scaladude,ou=people,dc=springframework,dc=org");
        -		Collection authorities = populator.getGrantedAuthorities(ctx,
        -				"scaladude");
        +		this.populator.setMaxSearchDepth(1);
        +		DirContextAdapter ctx = new DirContextAdapter("uid=scaladude,ou=people,dc=springframework,dc=org");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "scaladude");
         		assertThat(authorities).hasSize(1);
        -		assertThat(authorities).isEqualTo(Arrays.asList(scalaDevelopers));
        +		assertThat(authorities).isEqualTo(Arrays.asList(this.scalaDevelopers));
         	}
         
         	@Test
         	public void testGroovyDudeJDevelopersAuthorities() {
        -		DirContextAdapter ctx = new DirContextAdapter(
        -				"uid=groovydude,ou=people,dc=springframework,dc=org");
        -		Collection authorities = populator.getGrantedAuthorities(ctx,
        -				"groovydude");
        +		DirContextAdapter ctx = new DirContextAdapter("uid=groovydude,ou=people,dc=springframework,dc=org");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "groovydude");
         		assertThat(authorities).hasSize(4);
        -		assertThat(authorities).isEqualTo(Arrays.asList(javaDevelopers, circularJavaDevelopers, groovyDevelopers,
        -				jDevelopers));
        +		assertThat(authorities).isEqualTo(Arrays.asList(this.javaDevelopers, this.circularJavaDevelopers,
        +				this.groovyDevelopers, this.jDevelopers));
         	}
         
         	@Test
         	public void testClosureDudeJDevelopersWithMembershipAsAttributeValues() {
        -		populator.setAttributeNames(new HashSet(Arrays.asList("member")));
        +		this.populator.setAttributeNames(new HashSet(Arrays.asList("member")));
         
        -		DirContextAdapter ctx = new DirContextAdapter(
        -				"uid=closuredude,ou=people,dc=springframework,dc=org");
        -		Collection authorities = populator.getGrantedAuthorities(ctx,
        -				"closuredude");
        +		DirContextAdapter ctx = new DirContextAdapter("uid=closuredude,ou=people,dc=springframework,dc=org");
        +		Collection authorities = this.populator.getGrantedAuthorities(ctx, "closuredude");
         		assertThat(authorities).hasSize(5);
        -		assertThat(authorities).isEqualTo(Arrays.asList(javaDevelopers, circularJavaDevelopers,
        -				closureDevelopers, groovyDevelopers, jDevelopers));
        +		assertThat(authorities).isEqualTo(Arrays.asList(this.javaDevelopers, this.circularJavaDevelopers,
        +				this.closureDevelopers, this.groovyDevelopers, this.jDevelopers));
         
         		LdapAuthority[] ldapAuthorities = authorities.toArray(new LdapAuthority[0]);
         		assertThat(ldapAuthorities).hasSize(5);
        @@ -138,21 +131,23 @@ public class NestedLdapAuthoritiesPopulatorTests {
         		assertThat(ldapAuthorities[0].getAttributes().containsKey("member")).isTrue();
         		assertThat(ldapAuthorities[0].getAttributes().get("member")).isNotNull();
         		assertThat(ldapAuthorities[0].getAttributes().get("member")).hasSize(3);
        -		assertThat(ldapAuthorities[0].getFirstAttributeValue("member")).isEqualTo("cn=groovy-developers,ou=jdeveloper,dc=springframework,dc=org");
        +		assertThat(ldapAuthorities[0].getFirstAttributeValue("member"))
        +				.isEqualTo("cn=groovy-developers,ou=jdeveloper,dc=springframework,dc=org");
         
         		// java group
         		assertThat(ldapAuthorities[1].getAttributes().containsKey("member")).isTrue();
         		assertThat(ldapAuthorities[1].getAttributes().get("member")).isNotNull();
         		assertThat(ldapAuthorities[1].getAttributes().get("member")).hasSize(3);
        -		assertThat(groovyDevelopers.getDn()).isEqualTo(ldapAuthorities[1].getFirstAttributeValue("member"));
        -		assertThat(ldapAuthorities[2]
        -				.getAttributes().get("member")).contains("uid=closuredude,ou=people,dc=springframework,dc=org");
        +		assertThat(this.groovyDevelopers.getDn()).isEqualTo(ldapAuthorities[1].getFirstAttributeValue("member"));
        +		assertThat(ldapAuthorities[2].getAttributes().get("member"))
        +				.contains("uid=closuredude,ou=people,dc=springframework,dc=org");
         
         		// test non existent attribute
         		assertThat(ldapAuthorities[2].getFirstAttributeValue("test")).isNull();
         		assertThat(ldapAuthorities[2].getAttributeValues("test")).isNotNull();
         		assertThat(ldapAuthorities[2].getAttributeValues("test")).isEmpty();
         		// test role name
        -		assertThat(ldapAuthorities[3].getAuthority()).isEqualTo(groovyDevelopers.getAuthority());
        +		assertThat(ldapAuthorities[3].getAuthority()).isEqualTo(this.groovyDevelopers.getAuthority());
         	}
        +
         }
        diff --git a/ldap/src/main/java/org/springframework/security/ldap/DefaultLdapUsernameToDnMapper.java b/ldap/src/main/java/org/springframework/security/ldap/DefaultLdapUsernameToDnMapper.java
        index bbb80ccf3d..e4c0ac0e08 100644
        --- a/ldap/src/main/java/org/springframework/security/ldap/DefaultLdapUsernameToDnMapper.java
        +++ b/ldap/src/main/java/org/springframework/security/ldap/DefaultLdapUsernameToDnMapper.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap;
         
         import org.springframework.ldap.core.DistinguishedName;
        @@ -26,7 +27,9 @@ import org.springframework.ldap.core.DistinguishedName;
          * @author Luke Taylor
          */
         public class DefaultLdapUsernameToDnMapper implements LdapUsernameToDnMapper {
        +
         	private final String userDnBase;
        +
         	private final String usernameAttribute;
         
         	/**
        @@ -41,11 +44,11 @@ public class DefaultLdapUsernameToDnMapper implements LdapUsernameToDnMapper {
         	/**
         	 * Assembles the Distinguished Name that should be used the given username.
         	 */
        +	@Override
         	public DistinguishedName buildDn(String username) {
        -		DistinguishedName dn = new DistinguishedName(userDnBase);
        -
        -		dn.add(usernameAttribute, username);
        -
        +		DistinguishedName dn = new DistinguishedName(this.userDnBase);
        +		dn.add(this.usernameAttribute, username);
         		return dn;
         	}
        +
         }
        diff --git a/ldap/src/main/java/org/springframework/security/ldap/DefaultSpringSecurityContextSource.java b/ldap/src/main/java/org/springframework/security/ldap/DefaultSpringSecurityContextSource.java
        index f9790d606d..f9c6293712 100644
        --- a/ldap/src/main/java/org/springframework/security/ldap/DefaultSpringSecurityContextSource.java
        +++ b/ldap/src/main/java/org/springframework/security/ldap/DefaultSpringSecurityContextSource.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.ldap;
         
         import java.util.ArrayList;
        @@ -47,67 +48,54 @@ import org.springframework.util.Assert;
          * @since 2.0
          */
         public class DefaultSpringSecurityContextSource extends LdapContextSource {
        -	protected final Log logger = LogFactory.getLog(getClass());
         
        -	private String rootDn;
        +	protected final Log logger = LogFactory.getLog(getClass());
         
         	/**
         	 * Create and initialize an instance which will connect to the supplied LDAP URL. If
         	 * you want to use more than one server for fail-over, rather use the
         	 * {@link #DefaultSpringSecurityContextSource(List, String)} constructor.
        -	 *
         	 * @param providerUrl an LDAP URL of the form
         	 * ldap://localhost:389/base_dn
         	 */
         	public DefaultSpringSecurityContextSource(String providerUrl) {
         		Assert.hasLength(providerUrl, "An LDAP connection URL must be supplied.");
        -
        -		StringTokenizer st = new StringTokenizer(providerUrl);
        -
        +		StringTokenizer tokenizer = new StringTokenizer(providerUrl);
         		ArrayList urls = new ArrayList<>();
        -
         		// Work out rootDn from the first URL and check that the other URLs (if any) match
        -		while (st.hasMoreTokens()) {
        -			String url = st.nextToken();
        +		String rootDn = null;
        +		while (tokenizer.hasMoreTokens()) {
        +			String url = tokenizer.nextToken();
         			String urlRootDn = LdapUtils.parseRootDnFromUrl(url);
        -
         			urls.add(url.substring(0, url.lastIndexOf(urlRootDn)));
        -
         			this.logger.info(" URL '" + url + "', root DN is '" + urlRootDn + "'");
        -
        -			if (this.rootDn == null) {
        -				this.rootDn = urlRootDn;
        -			}
        -			else if (!this.rootDn.equals(urlRootDn)) {
        -				throw new IllegalArgumentException(
        -						"Root DNs must be the same when using multiple URLs");
        -			}
        +			Assert.isTrue(rootDn == null || rootDn.equals(urlRootDn),
        +					"Root DNs must be the same when using multiple URLs");
        +			rootDn = (rootDn != null) ? rootDn : urlRootDn;
         		}
        -
         		setUrls(urls.toArray(new String[0]));
        -		setBase(this.rootDn);
        +		setBase(rootDn);
         		setPooled(true);
         		setAuthenticationStrategy(new SimpleDirContextAuthenticationStrategy() {
        +
         			@Override
         			@SuppressWarnings("rawtypes")
         			public void setupEnvironment(Hashtable env, String dn, String password) {
         				super.setupEnvironment(env, dn, password);
        -				// Remove the pooling flag unless we are authenticating as the 'manager'
        -				// user.
        +				// Remove the pooling flag unless authenticating as the 'manager' user.
         				if (!DefaultSpringSecurityContextSource.this.userDn.equals(dn)
         						&& env.containsKey(SUN_LDAP_POOLING_FLAG)) {
        -					DefaultSpringSecurityContextSource.this.logger
        -							.debug("Removing pooling flag for user " + dn);
        +					DefaultSpringSecurityContextSource.this.logger.debug("Removing pooling flag for user " + dn);
         					env.remove(SUN_LDAP_POOLING_FLAG);
         				}
         			}
        +
         		});
         	}
         
         	/**
         	 * Create and initialize an instance which will connect of the LDAP Spring Security
         	 * Context Source. It will connect to any of the provided LDAP server URLs.
        -	 *
         	 * @param urls A list of string values which are LDAP server URLs. An example would be
         	 * ldap://ldap.company.com:389. LDAPS URLs (SSL-secured) may be used as
         	 * well, given that Spring Security is able to connect to the server. Note that these
        @@ -128,7 +116,6 @@ public class DefaultSpringSecurityContextSource extends LdapContextSource {
         	 * Builds a Spring LDAP-compliant Provider URL string, i.e. a space-separated list of
         	 * LDAP servers with their base DNs. As the base DN must be identical for all servers,
         	 * it needs to be supplied only once.
        -	 *
         	 * @param urls A list of string values which are LDAP server URLs. An example would be
         	 *
         	 * 
        @@ -149,16 +136,13 @@ public class DefaultSpringSecurityContextSource extends LdapContextSource {
         	private static String buildProviderUrl(List urls, String baseDn) {
         		Assert.notNull(baseDn, "The Base DN for the LDAP server must not be null.");
         		Assert.notEmpty(urls, "At least one LDAP server URL must be provided.");
        -
         		String trimmedBaseDn = baseDn.trim();
         		StringBuilder providerUrl = new StringBuilder();
        -
         		for (String serverUrl : urls) {
         			String trimmedUrl = serverUrl.trim();
         			if ("".equals(trimmedUrl)) {
         				continue;
         			}
        -
         			providerUrl.append(trimmedUrl);
         			if (!trimmedUrl.endsWith("/")) {
         				providerUrl.append("/");
        @@ -166,7 +150,6 @@ public class DefaultSpringSecurityContextSource extends LdapContextSource {
         			providerUrl.append(trimmedBaseDn);
         			providerUrl.append(" ");
         		}
        -
         		return providerUrl.toString();
         
         	}
        diff --git a/ldap/src/main/java/org/springframework/security/ldap/LdapEncoder.java b/ldap/src/main/java/org/springframework/security/ldap/LdapEncoder.java
        index 43316dab64..a3911aa180 100644
        --- a/ldap/src/main/java/org/springframework/security/ldap/LdapEncoder.java
        +++ b/ldap/src/main/java/org/springframework/security/ldap/LdapEncoder.java
        @@ -1,245 +1,195 @@
        -/*
        - * Copyright 2005-2010 the original author or authors.
        - *
        - * Licensed under the Apache License, Version 2.0 (the "License");
        - * you may not use this file except in compliance with the License.
        - * You may obtain a copy of the License at
        - *
        - *      https://www.apache.org/licenses/LICENSE-2.0
        - *
        - * Unless required by applicable law or agreed to in writing, software
        - * distributed under the License is distributed on an "AS IS" BASIS,
        - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        - * See the License for the specific language governing permissions and
        - * limitations under the License.
        - */
        -
        -package org.springframework.security.ldap;
        -
        -import org.springframework.ldap.BadLdapGrammarException;
        -
        -/**
        - * Helper class to encode and decode ldap names and values.
        - *
        - * 

        - * NOTE: This is a copy from Spring LDAP so that both Spring LDAP 1.x and 2.x can be - * supported without reflection. - *

        - * - * @author Adam Skogman - * @author Mattias Hellborg Arthursson - */ -final class LdapEncoder { - - private static final int HEX = 16; - private static String[] NAME_ESCAPE_TABLE = new String[96]; - - private static String[] FILTER_ESCAPE_TABLE = new String['\\' + 1]; - - static { - - // Name encoding table ------------------------------------- - - // all below 0x20 (control chars) - for (char c = 0; c < ' '; c++) { - NAME_ESCAPE_TABLE[c] = "\\" + toTwoCharHex(c); - } - - NAME_ESCAPE_TABLE['#'] = "\\#"; - NAME_ESCAPE_TABLE[','] = "\\,"; - NAME_ESCAPE_TABLE[';'] = "\\;"; - NAME_ESCAPE_TABLE['='] = "\\="; - NAME_ESCAPE_TABLE['+'] = "\\+"; - NAME_ESCAPE_TABLE['<'] = "\\<"; - NAME_ESCAPE_TABLE['>'] = "\\>"; - NAME_ESCAPE_TABLE['\"'] = "\\\""; - NAME_ESCAPE_TABLE['\\'] = "\\\\"; - - // Filter encoding table ------------------------------------- - - // fill with char itself - for (char c = 0; c < FILTER_ESCAPE_TABLE.length; c++) { - FILTER_ESCAPE_TABLE[c] = String.valueOf(c); - } - - // escapes (RFC2254) - FILTER_ESCAPE_TABLE['*'] = "\\2a"; - FILTER_ESCAPE_TABLE['('] = "\\28"; - FILTER_ESCAPE_TABLE[')'] = "\\29"; - FILTER_ESCAPE_TABLE['\\'] = "\\5c"; - FILTER_ESCAPE_TABLE[0] = "\\00"; - - } - - /** - * All static methods - not to be instantiated. - */ - private LdapEncoder() { - } - - protected static String toTwoCharHex(char c) { - - String raw = Integer.toHexString(c).toUpperCase(); - - if (raw.length() > 1) { - return raw; - } - else { - return "0" + raw; - } - } - - /** - * Escape a value for use in a filter. - * - * @param value the value to escape. - * @return a properly escaped representation of the supplied value. - */ - public static String filterEncode(String value) { - - if (value == null) { - return null; - } - - // make buffer roomy - StringBuilder encodedValue = new StringBuilder(value.length() * 2); - - int length = value.length(); - - for (int i = 0; i < length; i++) { - - char c = value.charAt(i); - - if (c < FILTER_ESCAPE_TABLE.length) { - encodedValue.append(FILTER_ESCAPE_TABLE[c]); - } - else { - // default: add the char - encodedValue.append(c); - } - } - - return encodedValue.toString(); - } - - /** - * LDAP Encodes a value for use with a DN. Escapes for LDAP, not JNDI! - * - *
        - * Escapes:
        - * ' ' [space] - "\ " [if first or last]
        - * '#' [hash] - "\#"
        - * ',' [comma] - "\,"
        - * ';' [semicolon] - "\;"
        - * '= [equals] - "\="
        - * '+' [plus] - "\+"
        - * '<' [less than] - "\<"
        - * '>' [greater than] - "\>"
        - * '"' [double quote] - "\""
        - * '\' [backslash] - "\\"
        - * - * @param value the value to escape. - * @return The escaped value. - */ - public static String nameEncode(String value) { - - if (value == null) { - return null; - } - - // make buffer roomy - StringBuilder encodedValue = new StringBuilder(value.length() * 2); - - int length = value.length(); - int last = length - 1; - - for (int i = 0; i < length; i++) { - - char c = value.charAt(i); - - // space first or last - if (c == ' ' && (i == 0 || i == last)) { - encodedValue.append("\\ "); - continue; - } - - if (c < NAME_ESCAPE_TABLE.length) { - // check in table for escapes - String esc = NAME_ESCAPE_TABLE[c]; - - if (esc != null) { - encodedValue.append(esc); - continue; - } - } - - // default: add the char - encodedValue.append(c); - } - - return encodedValue.toString(); - - } - - /** - * Decodes a value. Converts escaped chars to ordinary chars. - * - * @param value Trimmed value, so no leading an trailing blanks, except an escaped - * space last. - * @return The decoded value as a string. - * @throws BadLdapGrammarException - */ - static public String nameDecode(String value) throws BadLdapGrammarException { - - if (value == null) { - return null; - } - - // make buffer same size - StringBuilder decoded = new StringBuilder(value.length()); - - int i = 0; - while (i < value.length()) { - char currentChar = value.charAt(i); - if (currentChar == '\\') { - if (value.length() <= i + 1) { - // Ending with a single backslash is not allowed - throw new BadLdapGrammarException( - "Unexpected end of value " + "unterminated '\\'"); - } - else { - char nextChar = value.charAt(i + 1); - if (nextChar == ',' || nextChar == '=' || nextChar == '+' - || nextChar == '<' || nextChar == '>' || nextChar == '#' - || nextChar == ';' || nextChar == '\\' || nextChar == '\"' - || nextChar == ' ') { - // Normal backslash escape - decoded.append(nextChar); - i += 2; - } - else { - if (value.length() <= i + 2) { - throw new BadLdapGrammarException("Unexpected end of value " - + "expected special or hex, found '" + nextChar - + "'"); - } - else { - // This should be a hex value - String hexString = "" + nextChar + value.charAt(i + 2); - decoded.append((char) Integer.parseInt(hexString, HEX)); - i += 3; - } - } - } - } - else { - // This character wasn't escaped - just append it - decoded.append(currentChar); - i++; - } - } - - return decoded.toString(); - - } -} +/* + * Copyright 2005-2010 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.ldap; + +import org.springframework.ldap.BadLdapGrammarException; + +/** + * Helper class to encode and decode ldap names and values. + * + *

        + * NOTE: This is a copy from Spring LDAP so that both Spring LDAP 1.x and 2.x can be + * supported without reflection. + *

        + * + * @author Adam Skogman + * @author Mattias Hellborg Arthursson + */ +final class LdapEncoder { + + private static final int HEX = 16; + + private static String[] NAME_ESCAPE_TABLE = new String[96]; + static { + // all below 0x20 (control chars) + for (char c = 0; c < ' '; c++) { + NAME_ESCAPE_TABLE[c] = "\\" + toTwoCharHex(c); + } + NAME_ESCAPE_TABLE['#'] = "\\#"; + NAME_ESCAPE_TABLE[','] = "\\,"; + NAME_ESCAPE_TABLE[';'] = "\\;"; + NAME_ESCAPE_TABLE['='] = "\\="; + NAME_ESCAPE_TABLE['+'] = "\\+"; + NAME_ESCAPE_TABLE['<'] = "\\<"; + NAME_ESCAPE_TABLE['>'] = "\\>"; + NAME_ESCAPE_TABLE['\"'] = "\\\""; + NAME_ESCAPE_TABLE['\\'] = "\\\\"; + } + + private static String[] FILTER_ESCAPE_TABLE = new String['\\' + 1]; + + static { + // fill with char itself + for (char c = 0; c < FILTER_ESCAPE_TABLE.length; c++) { + FILTER_ESCAPE_TABLE[c] = String.valueOf(c); + } + // escapes (RFC2254) + FILTER_ESCAPE_TABLE['*'] = "\\2a"; + FILTER_ESCAPE_TABLE['('] = "\\28"; + FILTER_ESCAPE_TABLE[')'] = "\\29"; + FILTER_ESCAPE_TABLE['\\'] = "\\5c"; + FILTER_ESCAPE_TABLE[0] = "\\00"; + } + + /** + * All static methods - not to be instantiated. + */ + private LdapEncoder() { + } + + protected static String toTwoCharHex(char c) { + String raw = Integer.toHexString(c).toUpperCase(); + return (raw.length() > 1) ? raw : "0" + raw; + } + + /** + * Escape a value for use in a filter. + * @param value the value to escape. + * @return a properly escaped representation of the supplied value. + */ + static String filterEncode(String value) { + if (value == null) { + return null; + } + StringBuilder encodedValue = new StringBuilder(value.length() * 2); + int length = value.length(); + for (int i = 0; i < length; i++) { + char ch = value.charAt(i); + encodedValue.append((ch < FILTER_ESCAPE_TABLE.length) ? FILTER_ESCAPE_TABLE[ch] : ch); + } + return encodedValue.toString(); + } + + /** + * LDAP Encodes a value for use with a DN. Escapes for LDAP, not JNDI! + * + *
        + * Escapes:
        + * ' ' [space] - "\ " [if first or last]
        + * '#' [hash] - "\#"
        + * ',' [comma] - "\,"
        + * ';' [semicolon] - "\;"
        + * '= [equals] - "\="
        + * '+' [plus] - "\+"
        + * '<' [less than] - "\<"
        + * '>' [greater than] - "\>"
        + * '"' [double quote] - "\""
        + * '\' [backslash] - "\\"
        + * @param value the value to escape. + * @return The escaped value. + */ + static String nameEncode(String value) { + if (value == null) { + return null; + } + StringBuilder encodedValue = new StringBuilder(value.length() * 2); + int length = value.length(); + int last = length - 1; + for (int i = 0; i < length; i++) { + char c = value.charAt(i); + // space first or last + if (c == ' ' && (i == 0 || i == last)) { + encodedValue.append("\\ "); + continue; + } + // check in table for escapes + if (c < NAME_ESCAPE_TABLE.length) { + String esc = NAME_ESCAPE_TABLE[c]; + if (esc != null) { + encodedValue.append(esc); + continue; + } + } + // default: add the char + encodedValue.append(c); + } + return encodedValue.toString(); + } + + /** + * Decodes a value. Converts escaped chars to ordinary chars. + * @param value Trimmed value, so no leading an trailing blanks, except an escaped + * space last. + * @return The decoded value as a string. + * @throws BadLdapGrammarException + */ + static String nameDecode(String value) throws BadLdapGrammarException { + if (value == null) { + return null; + } + StringBuilder decoded = new StringBuilder(value.length()); + int i = 0; + while (i < value.length()) { + char currentChar = value.charAt(i); + if (currentChar == '\\') { + // Ending with a single backslash is not allowed + if (value.length() <= i + 1) { + throw new BadLdapGrammarException("Unexpected end of value " + "unterminated '\\'"); + } + char nextChar = value.charAt(i + 1); + if (isNormalBackslashEscape(nextChar)) { + decoded.append(nextChar); + i += 2; + } + else { + if (value.length() <= i + 2) { + throw new BadLdapGrammarException( + "Unexpected end of value " + "expected special or hex, found '" + nextChar + "'"); + } + // This should be a hex value + String hexString = "" + nextChar + value.charAt(i + 2); + decoded.append((char) Integer.parseInt(hexString, HEX)); + i += 3; + } + } + else { + // This character wasn't escaped - just append it + decoded.append(currentChar); + i++; + } + } + + return decoded.toString(); + + } + + private static boolean isNormalBackslashEscape(char nextChar) { + return nextChar == ',' || nextChar == '=' || nextChar == '+' || nextChar == '<' || nextChar == '>' + || nextChar == '#' || nextChar == ';' || nextChar == '\\' || nextChar == '\"' || nextChar == ' '; + } + +} diff --git a/ldap/src/main/java/org/springframework/security/ldap/LdapUsernameToDnMapper.java b/ldap/src/main/java/org/springframework/security/ldap/LdapUsernameToDnMapper.java index c8521e359a..43d38b8177 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/LdapUsernameToDnMapper.java +++ b/ldap/src/main/java/org/springframework/security/ldap/LdapUsernameToDnMapper.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap; import org.springframework.ldap.core.DistinguishedName; @@ -23,5 +24,7 @@ import org.springframework.ldap.core.DistinguishedName; * @author Luke Taylor */ public interface LdapUsernameToDnMapper { + DistinguishedName buildDn(String username); + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/LdapUtils.java b/ldap/src/main/java/org/springframework/security/ldap/LdapUtils.java index b4a7a7ae82..a222bf48a8 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/LdapUtils.java +++ b/ldap/src/main/java/org/springframework/security/ldap/LdapUtils.java @@ -16,19 +16,20 @@ package org.springframework.security.ldap; -import org.springframework.ldap.core.DirContextAdapter; -import org.springframework.ldap.core.DistinguishedName; -import org.springframework.security.crypto.codec.Utf8; -import org.springframework.util.Assert; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import java.net.URI; +import java.net.URISyntaxException; import javax.naming.Context; import javax.naming.NamingEnumeration; import javax.naming.NamingException; -import java.net.URI; -import java.net.URISyntaxException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.ldap.core.DirContextAdapter; +import org.springframework.ldap.core.DistinguishedName; +import org.springframework.security.crypto.codec.Utf8; +import org.springframework.util.Assert; /** * LDAP Utility methods. @@ -36,32 +37,23 @@ import java.net.URISyntaxException; * @author Luke Taylor */ public final class LdapUtils { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(LdapUtils.class); - // ~ Constructors - // =================================================================================================== - private LdapUtils() { } - // ~ Methods - // ======================================================================================================== - public static void closeContext(Context ctx) { if (ctx instanceof DirContextAdapter) { return; } - try { if (ctx != null) { ctx.close(); } } - catch (NamingException e) { - logger.error("Failed to close context.", e); + catch (NamingException ex) { + logger.error("Failed to close context.", ex); } } @@ -71,8 +63,8 @@ public final class LdapUtils { ne.close(); } } - catch (NamingException e) { - logger.error("Failed to close enumeration.", e); + catch (NamingException ex) { + logger.error("Failed to close enumeration.", ex); } } @@ -82,34 +74,23 @@ public final class LdapUtils { * If the DN is "cn=bob,ou=people,dc=springframework,dc=org" and the base context name * is "ou=people,dc=springframework,dc=org" it would return "cn=bob". *

        - * * @param fullDn the DN * @param baseCtx the context to work out the name relative to. - * * @return the - * * @throws NamingException any exceptions thrown by the context are propagated. */ - public static String getRelativeName(String fullDn, Context baseCtx) - throws NamingException { - + public static String getRelativeName(String fullDn, Context baseCtx) throws NamingException { String baseDn = baseCtx.getNameInNamespace(); - if (baseDn.length() == 0) { return fullDn; } - DistinguishedName base = new DistinguishedName(baseDn); DistinguishedName full = new DistinguishedName(fullDn); - if (base.equals(full)) { return ""; } - Assert.isTrue(full.startsWith(base), "Full DN does not start with base DN"); - full.removeFirst(base); - return full.toString(); } @@ -117,32 +98,24 @@ public final class LdapUtils { * Gets the full dn of a name by prepending the name of the context it is relative to. * If the name already contains the base name, it is returned unaltered. */ - public static DistinguishedName getFullDn(DistinguishedName dn, Context baseCtx) - throws NamingException { + public static DistinguishedName getFullDn(DistinguishedName dn, Context baseCtx) throws NamingException { DistinguishedName baseDn = new DistinguishedName(baseCtx.getNameInNamespace()); - if (dn.contains(baseDn)) { return dn; } - baseDn.append(dn); - return baseDn; } public static String convertPasswordToString(Object passObj) { Assert.notNull(passObj, "Password object to convert must not be null"); - if (passObj instanceof byte[]) { return Utf8.decode((byte[]) passObj); } - else if (passObj instanceof String) { + if (passObj instanceof String) { return (String) passObj; } - else { - throw new IllegalArgumentException( - "Password object was not a String or byte array."); - } + throw new IllegalArgumentException("Password object was not a String or byte array."); } /** @@ -151,16 +124,12 @@ public final class LdapUtils { * For example, the URL ldap://monkeymachine:11389/dc=springframework,dc=org * has the root DN "dc=springframework,dc=org". *

        - * * @param url the LDAP URL - * * @return the root DN */ public static String parseRootDnFromUrl(String url) { Assert.hasLength(url, "url must have length"); - String urlRootDn; - if (url.startsWith("ldap:") || url.startsWith("ldaps:")) { URI uri = parseLdapUrl(url); urlRootDn = uri.getRawPath(); @@ -169,11 +138,9 @@ public final class LdapUtils { // Assume it's an embedded server urlRootDn = url; } - if (urlRootDn.startsWith("/")) { urlRootDn = urlRootDn.substring(1); } - return urlRootDn; } @@ -188,15 +155,12 @@ public final class LdapUtils { private static URI parseLdapUrl(String url) { Assert.hasLength(url, "url must have length"); - try { return new URI(url); } - catch (URISyntaxException e) { - IllegalArgumentException iae = new IllegalArgumentException( - "Unable to parse url: " + url); - iae.initCause(e); - throw iae; + catch (URISyntaxException ex) { + throw new IllegalArgumentException("Unable to parse url: " + url, ex); } } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/SpringSecurityLdapTemplate.java b/ldap/src/main/java/org/springframework/security/ldap/SpringSecurityLdapTemplate.java index fe0d9c0002..08c281b1e9 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/SpringSecurityLdapTemplate.java +++ b/ldap/src/main/java/org/springframework/security/ldap/SpringSecurityLdapTemplate.java @@ -13,19 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.dao.IncorrectResultSizeDataAccessException; -import org.springframework.ldap.core.ContextExecutor; -import org.springframework.ldap.core.ContextMapper; -import org.springframework.ldap.core.ContextSource; -import org.springframework.ldap.core.DirContextAdapter; -import org.springframework.ldap.core.DirContextOperations; -import org.springframework.ldap.core.DistinguishedName; -import org.springframework.ldap.core.LdapTemplate; -import org.springframework.util.Assert; +import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import javax.naming.NamingEnumeration; import javax.naming.NamingException; @@ -35,14 +33,21 @@ import javax.naming.directory.Attributes; import javax.naming.directory.DirContext; import javax.naming.directory.SearchControls; import javax.naming.directory.SearchResult; -import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.ldap.core.ContextExecutor; +import org.springframework.ldap.core.ContextMapper; +import org.springframework.ldap.core.ContextSource; +import org.springframework.ldap.core.DirContextAdapter; +import org.springframework.ldap.core.DirContextOperations; +import org.springframework.ldap.core.DistinguishedName; +import org.springframework.ldap.core.LdapTemplate; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * Extension of Spring LDAP's LdapTemplate class which adds extra functionality required @@ -54,8 +59,7 @@ import java.util.Set; * @since 2.0 */ public class SpringSecurityLdapTemplate extends LdapTemplate { - // ~ Static fields/initializers - // ===================================================================================== + private static final Log logger = LogFactory.getLog(SpringSecurityLdapTemplate.class); public static final String[] NO_ATTRS = new String[0]; @@ -68,77 +72,47 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { private static final boolean RETURN_OBJECT = true; - // ~ Instance fields - // ================================================================================================ - /** Default search controls */ private SearchControls searchControls = new SearchControls(); - // ~ Constructors - // =================================================================================================== - public SpringSecurityLdapTemplate(ContextSource contextSource) { Assert.notNull(contextSource, "ContextSource cannot be null"); setContextSource(contextSource); - - searchControls.setSearchScope(SearchControls.SUBTREE_SCOPE); + this.searchControls.setSearchScope(SearchControls.SUBTREE_SCOPE); } - // ~ Methods - // ======================================================================================================== - /** * Performs an LDAP compare operation of the value of an attribute for a particular * directory entry. - * * @param dn the entry who's attribute is to be used * @param attributeName the attribute who's value we want to compare * @param value the value to be checked against the directory value - * * @return true if the supplied value matches that in the directory */ - public boolean compare(final String dn, final String attributeName, final Object value) { - final String comparisonFilter = "(" + attributeName + "={0})"; - - class LdapCompareCallback implements ContextExecutor { - - public Object executeWithContext(DirContext ctx) throws NamingException { - SearchControls ctls = new SearchControls(); - ctls.setReturningAttributes(NO_ATTRS); - ctls.setSearchScope(SearchControls.OBJECT_SCOPE); - - NamingEnumeration results = ctx.search(dn, - comparisonFilter, new Object[] { value }, ctls); - - Boolean match = results.hasMore(); - LdapUtils.closeEnumeration(results); - - return match; - } - } - - Boolean matches = (Boolean) executeReadOnly(new LdapCompareCallback()); - - return matches; + public boolean compare(String dn, String attributeName, Object value) { + String comparisonFilter = "(" + attributeName + "={0})"; + return executeReadOnly((ctx) -> { + SearchControls searchControls = new SearchControls(); + searchControls.setReturningAttributes(NO_ATTRS); + searchControls.setSearchScope(SearchControls.OBJECT_SCOPE); + Object[] params = new Object[] { value }; + NamingEnumeration results = ctx.search(dn, comparisonFilter, params, searchControls); + Boolean match = results.hasMore(); + LdapUtils.closeEnumeration(results); + return match; + }); } /** * Composes an object from the attributes of the given DN. - * * @param dn the directory entry which will be read * @param attributesToRetrieve the named attributes which will be retrieved from the * directory entry. - * * @return the object created by the mapper */ - public DirContextOperations retrieveEntry(final String dn, - final String[] attributesToRetrieve) { - - return (DirContextOperations) executeReadOnly((ContextExecutor) ctx -> { + public DirContextOperations retrieveEntry(final String dn, final String[] attributesToRetrieve) { + return (DirContextOperations) executeReadOnly((ContextExecutor) (ctx) -> { Attributes attrs = ctx.getAttributes(dn, attributesToRetrieve); - - // Object object = ctx.lookup(LdapUtils.getRelativeName(dn, ctx)); - return new DirContextAdapter(attrs, new DistinguishedName(dn), new DistinguishedName(ctx.getNameInNamespace())); }); @@ -149,20 +123,18 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { * the named attribute found in all entries matched by the search. Note that one * directory entry may have several values for the attribute. Intended for role * searches and similar scenarios. - * * @param base the DN to search in * @param filter search filter to use * @param params the parameters to substitute in the search filter * @param attributeName the attribute who's values are to be retrieved. - * * @return the set of String values for the attribute as a union of the values found * in all the matching entries. */ - public Set searchForSingleAttributeValues(final String base, - final String filter, final Object[] params, final String attributeName) { + public Set searchForSingleAttributeValues(final String base, final String filter, final Object[] params, + final String attributeName) { String[] attributeNames = new String[] { attributeName }; - Set>> multipleAttributeValues = searchForMultipleAttributeValues( - base, filter, params, attributeNames); + Set>> multipleAttributeValues = searchForMultipleAttributeValues(base, filter, params, + attributeNames); Set result = new HashSet<>(); for (Map> map : multipleAttributeValues) { List values = map.get(attributeName); @@ -178,45 +150,36 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { * attribute found in all entries matched by the search. Note that one directory entry * may have several values for the attribute. Intended for role searches and similar * scenarios. - * * @param base the DN to search in * @param filter search filter to use * @param params the parameters to substitute in the search filter * @param attributeNames the attributes' values that are to be retrieved. - * * @return the set of String values for each attribute found in all the matching * entries. The attribute name is the key for each set of values. In addition each map * contains the DN as a String with the key predefined key {@link #DN_KEY}. */ - public Set>> searchForMultipleAttributeValues( - final String base, final String filter, final Object[] params, - final String[] attributeNames) { + public Set>> searchForMultipleAttributeValues(String base, String filter, Object[] params, + String[] attributeNames) { // Escape the params acording to RFC2254 Object[] encodedParams = new String[params.length]; - for (int i = 0; i < params.length; i++) { encodedParams[i] = LdapEncoder.filterEncode(params[i].toString()); } - String formattedFilter = MessageFormat.format(filter, encodedParams); - logger.debug("Using filter: " + formattedFilter); - - final HashSet>> set = new HashSet<>(); - - ContextMapper roleMapper = ctx -> { + logger.debug(LogMessage.format("Using filter: %s", formattedFilter)); + HashSet>> result = new HashSet<>(); + ContextMapper roleMapper = (ctx) -> { DirContextAdapter adapter = (DirContextAdapter) ctx; Map> record = new HashMap<>(); - if (attributeNames == null || attributeNames.length == 0) { + if (ObjectUtils.isEmpty(attributeNames)) { try { - for (NamingEnumeration ae = adapter.getAttributes().getAll(); ae - .hasMore();) { - Attribute attr = (Attribute) ae.next(); + for (NamingEnumeration enumeration = adapter.getAttributes().getAll(); enumeration.hasMore();) { + Attribute attr = (Attribute) enumeration.next(); extractStringAttributeValues(adapter, record, attr.getID()); } } - catch (NamingException x) { - org.springframework.ldap.support.LdapUtils - .convertLdapException(x); + catch (NamingException ex) { + org.springframework.ldap.support.LdapUtils.convertLdapException(ex); } } else { @@ -225,18 +188,14 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { } } record.put(DN_KEY, Arrays.asList(getAdapterDN(adapter))); - set.add(record); + result.add(record); return null; }; - SearchControls ctls = new SearchControls(); - ctls.setSearchScope(searchControls.getSearchScope()); - ctls.setReturningAttributes(attributeNames != null && attributeNames.length > 0 ? attributeNames - : null); - + ctls.setSearchScope(this.searchControls.getSearchScope()); + ctls.setReturningAttributes((attributeNames != null && attributeNames.length > 0) ? attributeNames : null); search(base, formattedFilter, ctls, roleMapper); - - return set; + return result; } /** @@ -256,37 +215,31 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { * Extracts String values for a specified attribute name and places them in the map * representing the ldap record If a value is not of type String, it will derive it's * value from the {@link Object#toString()} - * * @param adapter - the adapter that contains the values * @param record - the map holding the attribute names and values * @param attributeName - the name for which to fetch the values from */ - private void extractStringAttributeValues(DirContextAdapter adapter, - Map> record, String attributeName) { + private void extractStringAttributeValues(DirContextAdapter adapter, Map> record, + String attributeName) { Object[] values = adapter.getObjectAttributes(attributeName); if (values == null || values.length == 0) { - if (logger.isDebugEnabled()) { - logger.debug("No attribute value found for '" + attributeName + "'"); - } + logger.debug(LogMessage.format("No attribute value found for '%s'", attributeName)); return; } - List svalues = new ArrayList<>(); - for (Object o : values) { - if (o != null) { - if (String.class.isAssignableFrom(o.getClass())) { - svalues.add((String) o); + List stringValues = new ArrayList<>(); + for (Object value : values) { + if (value != null) { + if (String.class.isAssignableFrom(value.getClass())) { + stringValues.add((String) value); } else { - if (logger.isDebugEnabled()) { - logger.debug("Attribute:" + attributeName - + " contains a non string value of type[" + o.getClass() - + "]"); - } - svalues.add(o.toString()); + logger.debug(LogMessage.format("Attribute:%s contains a non string value of type[%s]", + attributeName, value.getClass())); + stringValues.add(value.toString()); } } } - record.put(attributeName, svalues); + record.put(attributeName, stringValues); } /** @@ -295,68 +248,47 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { *

        * Ignores PartialResultException if thrown, for compatibility with Active * Directory (see {@link LdapTemplate#setIgnorePartialResultException(boolean)}). - * * @param base the search base, relative to the base context supplied by the context * source. * @param filter the LDAP search filter * @param params parameters to be substituted in the search. - * * @return a DirContextOperations instance created from the matching entry. - * * @throws IncorrectResultSizeDataAccessException if no results are found or the * search returns more than one result. */ - public DirContextOperations searchForSingleEntry(final String base, - final String filter, final Object[] params) { - - return (DirContextOperations) executeReadOnly((ContextExecutor) ctx -> searchForSingleEntryInternal(ctx, searchControls, base, filter, - params)); + public DirContextOperations searchForSingleEntry(String base, String filter, Object[] params) { + return (DirContextOperations) executeReadOnly((ContextExecutor) (ctx) -> searchForSingleEntryInternal(ctx, + this.searchControls, base, filter, params)); } /** * Internal method extracted to avoid code duplication in AD search. */ - public static DirContextOperations searchForSingleEntryInternal(DirContext ctx, - SearchControls searchControls, String base, String filter, Object[] params) - throws NamingException { - final DistinguishedName ctxBaseDn = new DistinguishedName( - ctx.getNameInNamespace()); + public static DirContextOperations searchForSingleEntryInternal(DirContext ctx, SearchControls searchControls, + String base, String filter, Object[] params) throws NamingException { + final DistinguishedName ctxBaseDn = new DistinguishedName(ctx.getNameInNamespace()); final DistinguishedName searchBaseDn = new DistinguishedName(base); - final NamingEnumeration resultsEnum = ctx.search(searchBaseDn, - filter, params, buildControls(searchControls)); - - if (logger.isDebugEnabled()) { - logger.debug("Searching for entry under DN '" + ctxBaseDn + "', base = '" - + searchBaseDn + "', filter = '" + filter + "'"); - } - + final NamingEnumeration resultsEnum = ctx.search(searchBaseDn, filter, params, + buildControls(searchControls)); + logger.debug(LogMessage.format("Searching for entry under DN '%s', base = '%s', filter = '%s'", ctxBaseDn, + searchBaseDn, filter)); Set results = new HashSet<>(); try { while (resultsEnum.hasMore()) { SearchResult searchResult = resultsEnum.next(); DirContextAdapter dca = (DirContextAdapter) searchResult.getObject(); - Assert.notNull(dca, - "No object returned by search, DirContext is not correctly configured"); - - if (logger.isDebugEnabled()) { - logger.debug("Found DN: " + dca.getDn()); - } + Assert.notNull(dca, "No object returned by search, DirContext is not correctly configured"); + logger.debug(LogMessage.format("Found DN: %s", dca.getDn())); results.add(dca); } } - catch (PartialResultException e) { + catch (PartialResultException ex) { LdapUtils.closeEnumeration(resultsEnum); logger.info("Ignoring PartialResultException"); } - - if (results.size() == 0) { - throw new IncorrectResultSizeDataAccessException(1, 0); - } - - if (results.size() > 1) { + if (results.size() != 1) { throw new IncorrectResultSizeDataAccessException(1, results.size()); } - return results.iterator().next(); } @@ -367,19 +299,18 @@ public class SpringSecurityLdapTemplate extends LdapTemplate { * @return */ private static SearchControls buildControls(SearchControls originalControls) { - return new SearchControls(originalControls.getSearchScope(), - originalControls.getCountLimit(), originalControls.getTimeLimit(), - originalControls.getReturningAttributes(), RETURN_OBJECT, + return new SearchControls(originalControls.getSearchScope(), originalControls.getCountLimit(), + originalControls.getTimeLimit(), originalControls.getReturningAttributes(), RETURN_OBJECT, originalControls.getDerefLinkFlag()); } /** * Sets the search controls which will be used for search operations by the template. - * * @param searchControls the SearchControls instance which will be cached in the * template. */ public void setSearchControls(SearchControls searchControls) { this.searchControls = searchControls; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticationProvider.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticationProvider.java index 7fc3544929..69dd517de6 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticationProvider.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication; import java.util.Collection; @@ -23,6 +24,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.BadCredentialsException; @@ -46,80 +48,66 @@ import org.springframework.util.StringUtils; * @author Luke Taylor * @since 3.1 */ -public abstract class AbstractLdapAuthenticationProvider - implements AuthenticationProvider, MessageSourceAware { +public abstract class AbstractLdapAuthenticationProvider implements AuthenticationProvider, MessageSourceAware { + protected final Log logger = LogFactory.getLog(getClass()); + protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); + private boolean useAuthenticationRequestCredentials = true; + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); + protected UserDetailsContextMapper userDetailsContextMapper = new LdapUserDetailsMapper(); - public Authentication authenticate(Authentication authentication) - throws AuthenticationException { + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { Assert.isInstanceOf(UsernamePasswordAuthenticationToken.class, authentication, () -> this.messages.getMessage("LdapAuthenticationProvider.onlySupports", "Only UsernamePasswordAuthenticationToken is supported")); - - final UsernamePasswordAuthenticationToken userToken = (UsernamePasswordAuthenticationToken) authentication; - + UsernamePasswordAuthenticationToken userToken = (UsernamePasswordAuthenticationToken) authentication; String username = userToken.getName(); String password = (String) authentication.getCredentials(); - - if (this.logger.isDebugEnabled()) { - this.logger.debug("Processing authentication request for user: " + username); - } - + this.logger.debug(LogMessage.format("Processing authentication request for user: %s", username)); if (!StringUtils.hasLength(username)) { - throw new BadCredentialsException(this.messages.getMessage( - "LdapAuthenticationProvider.emptyUsername", "Empty Username")); + throw new BadCredentialsException( + this.messages.getMessage("LdapAuthenticationProvider.emptyUsername", "Empty Username")); } - if (!StringUtils.hasLength(password)) { - throw new BadCredentialsException(this.messages.getMessage( - "AbstractLdapAuthenticationProvider.emptyPassword", - "Empty Password")); + throw new BadCredentialsException( + this.messages.getMessage("AbstractLdapAuthenticationProvider.emptyPassword", "Empty Password")); } - Assert.notNull(password, "Null password was supplied in authentication token"); - DirContextOperations userData = doAuthentication(userToken); - - UserDetails user = this.userDetailsContextMapper.mapUserFromContext(userData, - authentication.getName(), - loadUserAuthorities(userData, authentication.getName(), - (String) authentication.getCredentials())); - + UserDetails user = this.userDetailsContextMapper.mapUserFromContext(userData, authentication.getName(), + loadUserAuthorities(userData, authentication.getName(), (String) authentication.getCredentials())); return createSuccessfulAuthentication(userToken, user); } - protected abstract DirContextOperations doAuthentication( - UsernamePasswordAuthenticationToken auth); + protected abstract DirContextOperations doAuthentication(UsernamePasswordAuthenticationToken auth); - protected abstract Collection loadUserAuthorities( - DirContextOperations userData, String username, String password); + protected abstract Collection loadUserAuthorities(DirContextOperations userData, + String username, String password); /** * Creates the final {@code Authentication} object which will be returned from the * {@code authenticate} method. - * * @param authentication the original authentication request token * @param user the UserDetails instance returned by the configured * UserDetailsContextMapper. * @return the Authentication object for the fully authenticated user. */ - protected Authentication createSuccessfulAuthentication( - UsernamePasswordAuthenticationToken authentication, UserDetails user) { - Object password = this.useAuthenticationRequestCredentials - ? authentication.getCredentials() : user.getPassword(); - - UsernamePasswordAuthenticationToken result = new UsernamePasswordAuthenticationToken( - user, password, + protected Authentication createSuccessfulAuthentication(UsernamePasswordAuthenticationToken authentication, + UserDetails user) { + Object password = this.useAuthenticationRequestCredentials ? authentication.getCredentials() + : user.getPassword(); + UsernamePasswordAuthenticationToken result = new UsernamePasswordAuthenticationToken(user, password, this.authoritiesMapper.mapAuthorities(user.getAuthorities())); result.setDetails(authentication.getDetails()); - return result; } + @Override public boolean supports(Class authentication) { return UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication); } @@ -130,14 +118,13 @@ public abstract class AbstractLdapAuthenticationProvider * obtained from the UserDetails object created by the configured * {@code UserDetailsContextMapper}. Often it will not be possible to read the * password from the directory, so defaults to true. - * * @param useAuthenticationRequestCredentials */ - public void setUseAuthenticationRequestCredentials( - boolean useAuthenticationRequestCredentials) { + public void setUseAuthenticationRequestCredentials(boolean useAuthenticationRequestCredentials) { this.useAuthenticationRequestCredentials = useAuthenticationRequestCredentials; } + @Override public void setMessageSource(MessageSource messageSource) { this.messages = new MessageSourceAccessor(messageSource); } @@ -151,14 +138,11 @@ public abstract class AbstractLdapAuthenticationProvider * will be stored as the principal in the Authentication returned by the * {@link #createSuccessfulAuthentication(org.springframework.security.authentication.UsernamePasswordAuthenticationToken, org.springframework.security.core.userdetails.UserDetails)} * method. - * * @param userDetailsContextMapper the strategy instance. If not set, defaults to a * simple LdapUserDetailsMapper. */ - public void setUserDetailsContextMapper( - UserDetailsContextMapper userDetailsContextMapper) { - Assert.notNull(userDetailsContextMapper, - "UserDetailsContextMapper must not be null"); + public void setUserDetailsContextMapper(UserDetailsContextMapper userDetailsContextMapper) { + Assert.notNull(userDetailsContextMapper, "UserDetailsContextMapper must not be null"); this.userDetailsContextMapper = userDetailsContextMapper; } @@ -169,4 +153,5 @@ public abstract class AbstractLdapAuthenticationProvider protected UserDetailsContextMapper getUserDetailsContextMapper() { return this.userDetailsContextMapper; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticator.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticator.java index 0fb836a6ae..346aa1ca4d 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/AbstractLdapAuthenticator.java @@ -16,29 +16,26 @@ package org.springframework.security.ldap.authentication; -import org.springframework.security.core.SpringSecurityMessageSource; -import org.springframework.security.ldap.search.LdapUserSearch; +import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + import org.springframework.beans.factory.InitializingBean; import org.springframework.context.MessageSource; import org.springframework.context.MessageSourceAware; import org.springframework.context.support.MessageSourceAccessor; import org.springframework.ldap.core.ContextSource; +import org.springframework.security.core.SpringSecurityMessageSource; +import org.springframework.security.ldap.search.LdapUserSearch; import org.springframework.util.Assert; -import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - /** * Base class for the authenticator implementations. * * @author Luke Taylor */ -public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, - InitializingBean, MessageSourceAware { - // ~ Instance fields - // ================================================================================================ +public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, InitializingBean, MessageSourceAware { private final ContextSource contextSource; @@ -47,6 +44,7 @@ public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, * isn't sufficient */ private LdapUserSearch userSearch; + protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor(); /** @@ -59,12 +57,8 @@ public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, /** Stores the patterns which are used as potential DN matches */ private MessageFormat[] userDnFormat = null; - // ~ Constructors - // =================================================================================================== - /** * Create an initialized instance with the {@link ContextSource} provided. - * * @param contextSource */ public AbstractLdapAuthenticator(ContextSource contextSource) { @@ -72,52 +66,46 @@ public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, this.contextSource = contextSource; } - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { - Assert.isTrue((userDnFormat != null) || (userSearch != null), + Assert.isTrue((this.userDnFormat != null) || (this.userSearch != null), "Either an LdapUserSearch or DN pattern (or both) must be supplied."); } protected ContextSource getContextSource() { - return contextSource; + return this.contextSource; } public String[] getUserAttributes() { - return userAttributes; + return this.userAttributes; } /** * Builds list of possible DNs for the user, worked out from the * userDnPatterns property. - * * @param username the user's login name - * * @return the list of possible DN matches, empty if userDnPatterns wasn't * set. */ protected List getUserDns(String username) { - if (userDnFormat == null) { + if (this.userDnFormat == null) { return Collections.emptyList(); } - - List userDns = new ArrayList<>(userDnFormat.length); + List userDns = new ArrayList<>(this.userDnFormat.length); String[] args = new String[] { LdapEncoder.nameEncode(username) }; - - synchronized (userDnFormat) { - for (MessageFormat formatter : userDnFormat) { + synchronized (this.userDnFormat) { + for (MessageFormat formatter : this.userDnFormat) { userDns.add(formatter.format(args)); } } - return userDns; } protected LdapUserSearch getUserSearch() { - return userSearch; + return this.userSearch; } + @Override public void setMessageSource(MessageSource messageSource) { Assert.notNull(messageSource, "Message source must not be null"); this.messages = new MessageSourceAccessor(messageSource); @@ -125,12 +113,10 @@ public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, /** * Sets the user attributes which will be retrieved from the directory. - * * @param userAttributes */ public void setUserAttributes(String[] userAttributes) { - Assert.notNull(userAttributes, - "The userAttributes property cannot be set to null"); + Assert.notNull(userAttributes, "The userAttributes property cannot be set to null"); this.userAttributes = userAttributes; } @@ -138,17 +124,15 @@ public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, * Sets the pattern which will be used to supply a DN for the user. The pattern should * be the name relative to the root DN. The pattern argument {0} will contain the * username. An example would be "cn={0},ou=people". - * * @param dnPattern the array of patterns which will be tried when converting a * username to a DN. */ public void setUserDnPatterns(String[] dnPattern) { Assert.notNull(dnPattern, "The array of DN patterns cannot be set to null"); // this.userDnPattern = dnPattern; - userDnFormat = new MessageFormat[dnPattern.length]; - + this.userDnFormat = new MessageFormat[dnPattern.length]; for (int i = 0; i < dnPattern.length; i++) { - userDnFormat[i] = new MessageFormat(dnPattern[i]); + this.userDnFormat[i] = new MessageFormat(dnPattern[i]); } } @@ -156,4 +140,5 @@ public abstract class AbstractLdapAuthenticator implements LdapAuthenticator, Assert.notNull(userSearch, "The userSearch cannot be set to null"); this.userSearch = userSearch; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/BindAuthenticator.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/BindAuthenticator.java index 26dea8c4a7..1c4fa66eff 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/BindAuthenticator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/BindAuthenticator.java @@ -16,8 +16,13 @@ package org.springframework.security.ldap.authentication; +import javax.naming.directory.Attributes; +import javax.naming.directory.DirContext; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.ldap.NamingException; import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DirContextOperations; @@ -32,133 +37,101 @@ import org.springframework.security.ldap.ppolicy.PasswordPolicyControlExtractor; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import javax.naming.directory.Attributes; -import javax.naming.directory.DirContext; - /** * An authenticator which binds as a user. * * @author Luke Taylor - * * @see AbstractLdapAuthenticator */ public class BindAuthenticator extends AbstractLdapAuthenticator { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(BindAuthenticator.class); - // ~ Constructors - // =================================================================================================== - /** * Create an initialized instance using the {@link BaseLdapPathContextSource} * provided. - * * @param contextSource the BaseLdapPathContextSource instance against which bind * operations will be performed. - * */ public BindAuthenticator(BaseLdapPathContextSource contextSource) { super(contextSource); } - // ~ Methods - // ======================================================================================================== - + @Override public DirContextOperations authenticate(Authentication authentication) { DirContextOperations user = null; Assert.isInstanceOf(UsernamePasswordAuthenticationToken.class, authentication, "Can only process UsernamePasswordAuthenticationToken objects"); - String username = authentication.getName(); String password = (String) authentication.getCredentials(); - if (!StringUtils.hasLength(password)) { - logger.debug("Rejecting empty password for user " + username); - throw new BadCredentialsException(messages.getMessage( - "BindAuthenticator.emptyPassword", "Empty Password")); + logger.debug(LogMessage.format("Rejecting empty password for user %s", username)); + throw new BadCredentialsException( + this.messages.getMessage("BindAuthenticator.emptyPassword", "Empty Password")); } - // If DN patterns are configured, try authenticating with them directly for (String dn : getUserDns(username)) { user = bindWithDn(dn, username, password); - if (user != null) { break; } } - // Otherwise use the configured search object to find the user and authenticate // with the returned DN. if (user == null && getUserSearch() != null) { DirContextOperations userFromSearch = getUserSearch().searchForUser(username); - user = bindWithDn(userFromSearch.getDn().toString(), username, password, - userFromSearch.getAttributes()); + user = bindWithDn(userFromSearch.getDn().toString(), username, password, userFromSearch.getAttributes()); } - if (user == null) { - throw new BadCredentialsException(messages.getMessage( - "BindAuthenticator.badCredentials", "Bad credentials")); + throw new BadCredentialsException( + this.messages.getMessage("BindAuthenticator.badCredentials", "Bad credentials")); } - return user; } - private DirContextOperations bindWithDn(String userDnStr, String username, - String password) { + private DirContextOperations bindWithDn(String userDnStr, String username, String password) { return bindWithDn(userDnStr, username, password, null); } - private DirContextOperations bindWithDn(String userDnStr, String username, - String password, Attributes attrs) { + private DirContextOperations bindWithDn(String userDnStr, String username, String password, Attributes attrs) { BaseLdapPathContextSource ctxSource = (BaseLdapPathContextSource) getContextSource(); DistinguishedName userDn = new DistinguishedName(userDnStr); DistinguishedName fullDn = new DistinguishedName(userDn); fullDn.prepend(ctxSource.getBaseLdapPath()); - - logger.debug("Attempting to bind as " + fullDn); - + logger.debug(LogMessage.format("Attempting to bind as %s", fullDn)); DirContext ctx = null; try { ctx = getContextSource().getContext(fullDn.toString(), password); // Check for password policy control - PasswordPolicyControl ppolicy = PasswordPolicyControlExtractor - .extractControl(ctx); - + PasswordPolicyControl ppolicy = PasswordPolicyControlExtractor.extractControl(ctx); logger.debug("Retrieving attributes..."); - if (attrs == null || attrs.size()==0) { + if (attrs == null || attrs.size() == 0) { attrs = ctx.getAttributes(userDn, getUserAttributes()); } - - DirContextAdapter result = new DirContextAdapter(attrs, userDn, - ctxSource.getBaseLdapPath()); - + DirContextAdapter result = new DirContextAdapter(attrs, userDn, ctxSource.getBaseLdapPath()); if (ppolicy != null) { result.setAttributeValue(ppolicy.getID(), ppolicy); } - return result; } - catch (NamingException e) { + catch (NamingException ex) { // This will be thrown if an invalid user name is used and the method may // be called multiple times to try different names, so we trap the exception // unless a subclass wishes to implement more specialized behaviour. - if ((e instanceof org.springframework.ldap.AuthenticationException) - || (e instanceof org.springframework.ldap.OperationNotSupportedException)) { - handleBindException(userDnStr, username, e); + if ((ex instanceof org.springframework.ldap.AuthenticationException) + || (ex instanceof org.springframework.ldap.OperationNotSupportedException)) { + handleBindException(userDnStr, username, ex); } else { - throw e; + throw ex; } } - catch (javax.naming.NamingException e) { - throw LdapUtils.convertLdapException(e); + catch (javax.naming.NamingException ex) { + throw LdapUtils.convertLdapException(ex); } finally { LdapUtils.closeContext(ctx); } - return null; } @@ -168,8 +141,7 @@ public class BindAuthenticator extends AbstractLdapAuthenticator { * logger. */ protected void handleBindException(String userDn, String username, Throwable cause) { - if (logger.isDebugEnabled()) { - logger.debug("Failed to bind as " + userDn + ": " + cause); - } + logger.debug(LogMessage.format("Failed to bind as %s: %s", userDn, cause)); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticationProvider.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticationProvider.java index 8381f2efcb..9b4573a0db 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticationProvider.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticationProvider.java @@ -108,37 +108,30 @@ import org.springframework.util.Assert; * this means that if the LDAP directory is configured to allow unauthenticated access, it * might be possible to authenticate as any user just by supplying an empty * password. More information on the misuse of unauthenticated access can be found in - * draft - * -ietf-ldapbis-authmeth-19.txt. - * + * + * draft -ietf-ldapbis-authmeth-19.txt. * * @author Luke Taylor - * * @see BindAuthenticator * @see DefaultLdapAuthoritiesPopulator */ public class LdapAuthenticationProvider extends AbstractLdapAuthenticationProvider { - // ~ Instance fields - // ================================================================================================ private LdapAuthenticator authenticator; - private LdapAuthoritiesPopulator authoritiesPopulator; - private boolean hideUserNotFoundExceptions = true; - // ~ Constructors - // =================================================================================================== + private LdapAuthoritiesPopulator authoritiesPopulator; + + private boolean hideUserNotFoundExceptions = true; /** * Create an instance with the supplied authenticator and authorities populator * implementations. - * * @param authenticator the authentication strategy (bind, password comparison, etc) * to be used by this provider for authenticating users. * @param authoritiesPopulator the strategy for obtaining the authorities for a given * user after they've been authenticated. */ - public LdapAuthenticationProvider(LdapAuthenticator authenticator, - LdapAuthoritiesPopulator authoritiesPopulator) { + public LdapAuthenticationProvider(LdapAuthenticator authenticator, LdapAuthoritiesPopulator authoritiesPopulator) { this.setAuthenticator(authenticator); this.setAuthoritiesPopulator(authoritiesPopulator); } @@ -146,7 +139,6 @@ public class LdapAuthenticationProvider extends AbstractLdapAuthenticationProvid /** * Creates an instance with the supplied authenticator and a null authorities * populator. In this case, the authorities must be mapped from the user context. - * * @param authenticator the authenticator strategy. */ public LdapAuthenticationProvider(LdapAuthenticator authenticator) { @@ -154,9 +146,6 @@ public class LdapAuthenticationProvider extends AbstractLdapAuthenticationProvid this.setAuthoritiesPopulator(new NullLdapAuthoritiesPopulator()); } - // ~ Methods - // ======================================================================================================== - private void setAuthenticator(LdapAuthenticator authenticator) { Assert.notNull(authenticator, "An LdapAuthenticator must be supplied"); this.authenticator = authenticator; @@ -167,8 +156,7 @@ public class LdapAuthenticationProvider extends AbstractLdapAuthenticationProvid } private void setAuthoritiesPopulator(LdapAuthoritiesPopulator authoritiesPopulator) { - Assert.notNull(authoritiesPopulator, - "An LdapAuthoritiesPopulator must be supplied"); + Assert.notNull(authoritiesPopulator, "An LdapAuthoritiesPopulator must be supplied"); this.authoritiesPopulator = authoritiesPopulator; } @@ -181,35 +169,31 @@ public class LdapAuthenticationProvider extends AbstractLdapAuthenticationProvid } @Override - protected DirContextOperations doAuthentication( - UsernamePasswordAuthenticationToken authentication) { + protected DirContextOperations doAuthentication(UsernamePasswordAuthenticationToken authentication) { try { return getAuthenticator().authenticate(authentication); } - catch (PasswordPolicyException ppe) { + catch (PasswordPolicyException ex) { // The only reason a ppolicy exception can occur during a bind is that the // account is locked. - throw new LockedException(this.messages.getMessage( - ppe.getStatus().getErrorCode(), ppe.getStatus().getDefaultMessage())); + throw new LockedException( + this.messages.getMessage(ex.getStatus().getErrorCode(), ex.getStatus().getDefaultMessage())); } - catch (UsernameNotFoundException notFound) { + catch (UsernameNotFoundException ex) { if (this.hideUserNotFoundExceptions) { - throw new BadCredentialsException(this.messages.getMessage( - "LdapAuthenticationProvider.badCredentials", "Bad credentials")); - } - else { - throw notFound; + throw new BadCredentialsException( + this.messages.getMessage("LdapAuthenticationProvider.badCredentials", "Bad credentials")); } + throw ex; } - catch (NamingException ldapAccessFailure) { - throw new InternalAuthenticationServiceException( - ldapAccessFailure.getMessage(), ldapAccessFailure); + catch (NamingException ex) { + throw new InternalAuthenticationServiceException(ex.getMessage(), ex); } } @Override - protected Collection loadUserAuthorities( - DirContextOperations userData, String username, String password) { + protected Collection loadUserAuthorities(DirContextOperations userData, String username, + String password) { return getAuthoritiesPopulator().getGrantedAuthorities(userData, username); } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticator.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticator.java index 93cd57730c..f1492aa34b 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapAuthenticator.java @@ -16,8 +16,8 @@ package org.springframework.security.ldap.authentication; -import org.springframework.security.core.Authentication; import org.springframework.ldap.core.DirContextOperations; +import org.springframework.security.core.Authentication; /** * The strategy interface for locating and authenticating an Ldap user. @@ -26,19 +26,16 @@ import org.springframework.ldap.core.DirContextOperations; * the information for that user from the directory. * * @author Luke Taylor - * * @see org.springframework.security.ldap.userdetails.DefaultLdapAuthoritiesPopulator * @see org.springframework.security.ldap.authentication.UserDetailsServiceLdapAuthoritiesPopulator */ public interface LdapAuthenticator { - // ~ Methods - // ======================================================================================================== /** * Authenticates as a user and obtains additional user information from the directory. - * * @param authentication * @return the details of the successfully authenticated user. */ DirContextOperations authenticate(Authentication authentication); + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapEncoder.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapEncoder.java index 30ed9e4493..f79f4843ae 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapEncoder.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/LdapEncoder.java @@ -1,245 +1,195 @@ -/* - * Copyright 2005-2010 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.ldap.authentication; - -import org.springframework.ldap.BadLdapGrammarException; - -/** - * Helper class to encode and decode ldap names and values. - * - *

        - * NOTE: This is a copy from Spring LDAP so that both Spring LDAP 1.x and 2.x can be - * supported without reflection. - *

        - * - * @author Adam Skogman - * @author Mattias Hellborg Arthursson - */ -final class LdapEncoder { - - private static final int HEX = 16; - private static String[] NAME_ESCAPE_TABLE = new String[96]; - - private static String[] FILTER_ESCAPE_TABLE = new String['\\' + 1]; - - static { - - // Name encoding table ------------------------------------- - - // all below 0x20 (control chars) - for (char c = 0; c < ' '; c++) { - NAME_ESCAPE_TABLE[c] = "\\" + toTwoCharHex(c); - } - - NAME_ESCAPE_TABLE['#'] = "\\#"; - NAME_ESCAPE_TABLE[','] = "\\,"; - NAME_ESCAPE_TABLE[';'] = "\\;"; - NAME_ESCAPE_TABLE['='] = "\\="; - NAME_ESCAPE_TABLE['+'] = "\\+"; - NAME_ESCAPE_TABLE['<'] = "\\<"; - NAME_ESCAPE_TABLE['>'] = "\\>"; - NAME_ESCAPE_TABLE['\"'] = "\\\""; - NAME_ESCAPE_TABLE['\\'] = "\\\\"; - - // Filter encoding table ------------------------------------- - - // fill with char itself - for (char c = 0; c < FILTER_ESCAPE_TABLE.length; c++) { - FILTER_ESCAPE_TABLE[c] = String.valueOf(c); - } - - // escapes (RFC2254) - FILTER_ESCAPE_TABLE['*'] = "\\2a"; - FILTER_ESCAPE_TABLE['('] = "\\28"; - FILTER_ESCAPE_TABLE[')'] = "\\29"; - FILTER_ESCAPE_TABLE['\\'] = "\\5c"; - FILTER_ESCAPE_TABLE[0] = "\\00"; - - } - - /** - * All static methods - not to be instantiated. - */ - private LdapEncoder() { - } - - protected static String toTwoCharHex(char c) { - - String raw = Integer.toHexString(c).toUpperCase(); - - if (raw.length() > 1) { - return raw; - } - else { - return "0" + raw; - } - } - - /** - * Escape a value for use in a filter. - * - * @param value the value to escape. - * @return a properly escaped representation of the supplied value. - */ - public static String filterEncode(String value) { - - if (value == null) { - return null; - } - - // make buffer roomy - StringBuilder encodedValue = new StringBuilder(value.length() * 2); - - int length = value.length(); - - for (int i = 0; i < length; i++) { - - char c = value.charAt(i); - - if (c < FILTER_ESCAPE_TABLE.length) { - encodedValue.append(FILTER_ESCAPE_TABLE[c]); - } - else { - // default: add the char - encodedValue.append(c); - } - } - - return encodedValue.toString(); - } - - /** - * LDAP Encodes a value for use with a DN. Escapes for LDAP, not JNDI! - * - *
        - * Escapes:
        - * ' ' [space] - "\ " [if first or last]
        - * '#' [hash] - "\#"
        - * ',' [comma] - "\,"
        - * ';' [semicolon] - "\;"
        - * '= [equals] - "\="
        - * '+' [plus] - "\+"
        - * '<' [less than] - "\<"
        - * '>' [greater than] - "\>"
        - * '"' [double quote] - "\""
        - * '\' [backslash] - "\\"
        - * - * @param value the value to escape. - * @return The escaped value. - */ - public static String nameEncode(String value) { - - if (value == null) { - return null; - } - - // make buffer roomy - StringBuilder encodedValue = new StringBuilder(value.length() * 2); - - int length = value.length(); - int last = length - 1; - - for (int i = 0; i < length; i++) { - - char c = value.charAt(i); - - // space first or last - if (c == ' ' && (i == 0 || i == last)) { - encodedValue.append("\\ "); - continue; - } - - if (c < NAME_ESCAPE_TABLE.length) { - // check in table for escapes - String esc = NAME_ESCAPE_TABLE[c]; - - if (esc != null) { - encodedValue.append(esc); - continue; - } - } - - // default: add the char - encodedValue.append(c); - } - - return encodedValue.toString(); - - } - - /** - * Decodes a value. Converts escaped chars to ordinary chars. - * - * @param value Trimmed value, so no leading an trailing blanks, except an escaped - * space last. - * @return The decoded value as a string. - * @throws BadLdapGrammarException - */ - static public String nameDecode(String value) throws BadLdapGrammarException { - - if (value == null) { - return null; - } - - // make buffer same size - StringBuilder decoded = new StringBuilder(value.length()); - - int i = 0; - while (i < value.length()) { - char currentChar = value.charAt(i); - if (currentChar == '\\') { - if (value.length() <= i + 1) { - // Ending with a single backslash is not allowed - throw new BadLdapGrammarException( - "Unexpected end of value " + "unterminated '\\'"); - } - else { - char nextChar = value.charAt(i + 1); - if (nextChar == ',' || nextChar == '=' || nextChar == '+' - || nextChar == '<' || nextChar == '>' || nextChar == '#' - || nextChar == ';' || nextChar == '\\' || nextChar == '\"' - || nextChar == ' ') { - // Normal backslash escape - decoded.append(nextChar); - i += 2; - } - else { - if (value.length() <= i + 2) { - throw new BadLdapGrammarException("Unexpected end of value " - + "expected special or hex, found '" + nextChar - + "'"); - } - else { - // This should be a hex value - String hexString = "" + nextChar + value.charAt(i + 2); - decoded.append((char) Integer.parseInt(hexString, HEX)); - i += 3; - } - } - } - } - else { - // This character wasn't escaped - just append it - decoded.append(currentChar); - i++; - } - } - - return decoded.toString(); - - } -} +/* + * Copyright 2005-2010 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.ldap.authentication; + +import org.springframework.ldap.BadLdapGrammarException; + +/** + * Helper class to encode and decode ldap names and values. + * + *

        + * NOTE: This is a copy from Spring LDAP so that both Spring LDAP 1.x and 2.x can be + * supported without reflection. + *

        + * + * @author Adam Skogman + * @author Mattias Hellborg Arthursson + */ +final class LdapEncoder { + + private static final int HEX = 16; + + private static String[] NAME_ESCAPE_TABLE = new String[96]; + static { + // all below 0x20 (control chars) + for (char c = 0; c < ' '; c++) { + NAME_ESCAPE_TABLE[c] = "\\" + toTwoCharHex(c); + } + NAME_ESCAPE_TABLE['#'] = "\\#"; + NAME_ESCAPE_TABLE[','] = "\\,"; + NAME_ESCAPE_TABLE[';'] = "\\;"; + NAME_ESCAPE_TABLE['='] = "\\="; + NAME_ESCAPE_TABLE['+'] = "\\+"; + NAME_ESCAPE_TABLE['<'] = "\\<"; + NAME_ESCAPE_TABLE['>'] = "\\>"; + NAME_ESCAPE_TABLE['\"'] = "\\\""; + NAME_ESCAPE_TABLE['\\'] = "\\\\"; + } + + private static String[] FILTER_ESCAPE_TABLE = new String['\\' + 1]; + + static { + // fill with char itself + for (char c = 0; c < FILTER_ESCAPE_TABLE.length; c++) { + FILTER_ESCAPE_TABLE[c] = String.valueOf(c); + } + // escapes (RFC2254) + FILTER_ESCAPE_TABLE['*'] = "\\2a"; + FILTER_ESCAPE_TABLE['('] = "\\28"; + FILTER_ESCAPE_TABLE[')'] = "\\29"; + FILTER_ESCAPE_TABLE['\\'] = "\\5c"; + FILTER_ESCAPE_TABLE[0] = "\\00"; + } + + /** + * All static methods - not to be instantiated. + */ + private LdapEncoder() { + } + + protected static String toTwoCharHex(char c) { + String raw = Integer.toHexString(c).toUpperCase(); + return (raw.length() > 1) ? raw : "0" + raw; + } + + /** + * Escape a value for use in a filter. + * @param value the value to escape. + * @return a properly escaped representation of the supplied value. + */ + static String filterEncode(String value) { + if (value == null) { + return null; + } + StringBuilder encodedValue = new StringBuilder(value.length() * 2); + int length = value.length(); + for (int i = 0; i < length; i++) { + char ch = value.charAt(i); + encodedValue.append((ch < FILTER_ESCAPE_TABLE.length) ? FILTER_ESCAPE_TABLE[ch] : ch); + } + return encodedValue.toString(); + } + + /** + * LDAP Encodes a value for use with a DN. Escapes for LDAP, not JNDI! + * + *
        + * Escapes:
        + * ' ' [space] - "\ " [if first or last]
        + * '#' [hash] - "\#"
        + * ',' [comma] - "\,"
        + * ';' [semicolon] - "\;"
        + * '= [equals] - "\="
        + * '+' [plus] - "\+"
        + * '<' [less than] - "\<"
        + * '>' [greater than] - "\>"
        + * '"' [double quote] - "\""
        + * '\' [backslash] - "\\"
        + * @param value the value to escape. + * @return The escaped value. + */ + static String nameEncode(String value) { + if (value == null) { + return null; + } + StringBuilder encodedValue = new StringBuilder(value.length() * 2); + int length = value.length(); + int last = length - 1; + for (int i = 0; i < length; i++) { + char c = value.charAt(i); + // space first or last + if (c == ' ' && (i == 0 || i == last)) { + encodedValue.append("\\ "); + continue; + } + // check in table for escapes + if (c < NAME_ESCAPE_TABLE.length) { + String esc = NAME_ESCAPE_TABLE[c]; + if (esc != null) { + encodedValue.append(esc); + continue; + } + } + // default: add the char + encodedValue.append(c); + } + return encodedValue.toString(); + } + + /** + * Decodes a value. Converts escaped chars to ordinary chars. + * @param value Trimmed value, so no leading an trailing blanks, except an escaped + * space last. + * @return The decoded value as a string. + * @throws BadLdapGrammarException + */ + static String nameDecode(String value) throws BadLdapGrammarException { + if (value == null) { + return null; + } + StringBuilder decoded = new StringBuilder(value.length()); + int i = 0; + while (i < value.length()) { + char currentChar = value.charAt(i); + if (currentChar == '\\') { + // Ending with a single backslash is not allowed + if (value.length() <= i + 1) { + throw new BadLdapGrammarException("Unexpected end of value " + "unterminated '\\'"); + } + char nextChar = value.charAt(i + 1); + if (isNormalBackslashEscape(nextChar)) { + decoded.append(nextChar); + i += 2; + } + else { + if (value.length() <= i + 2) { + throw new BadLdapGrammarException( + "Unexpected end of value " + "expected special or hex, found '" + nextChar + "'"); + } + // This should be a hex value + String hexString = "" + nextChar + value.charAt(i + 2); + decoded.append((char) Integer.parseInt(hexString, HEX)); + i += 3; + } + } + else { + // This character wasn't escaped - just append it + decoded.append(currentChar); + i++; + } + } + + return decoded.toString(); + + } + + private static boolean isNormalBackslashEscape(char nextChar) { + return nextChar == ',' || nextChar == '=' || nextChar == '+' || nextChar == '<' || nextChar == '>' + || nextChar == '#' || nextChar == ';' || nextChar == '\\' || nextChar == '\"' || nextChar == ' '; + } + +} diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/NullLdapAuthoritiesPopulator.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/NullLdapAuthoritiesPopulator.java index a14f1ffb3c..d0d80e5f08 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/NullLdapAuthoritiesPopulator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/NullLdapAuthoritiesPopulator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication; import java.util.Collection; @@ -23,13 +24,14 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.ldap.userdetails.LdapAuthoritiesPopulator; /** - * * @author Luke Taylor * @since 3.0 */ public final class NullLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator { - public Collection getGrantedAuthorities( - DirContextOperations userDetails, String username) { + + @Override + public Collection getGrantedAuthorities(DirContextOperations userDetails, String username) { return AuthorityUtils.NO_AUTHORITIES; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticator.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticator.java index 348c16dcb4..a64b8b0368 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticator.java @@ -18,6 +18,8 @@ package org.springframework.security.ldap.authentication; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.ldap.NameNotFoundException; import org.springframework.ldap.core.DirContextOperations; import org.springframework.ldap.core.support.BaseLdapPathContextSource; @@ -47,41 +49,28 @@ import org.springframework.util.Assert; * @author Luke Taylor */ public final class PasswordComparisonAuthenticator extends AbstractLdapAuthenticator { - // ~ Static fields/initializers - // ===================================================================================== - private static final Log logger = LogFactory - .getLog(PasswordComparisonAuthenticator.class); - - // ~ Instance fields - // ================================================================================================ + private static final Log logger = LogFactory.getLog(PasswordComparisonAuthenticator.class); private PasswordEncoder passwordEncoder = new LdapShaPasswordEncoder(KeyGenerators.shared(0)); - private String passwordAttributeName = "userPassword"; - private boolean usePasswordAttrCompare = false; - // ~ Constructors - // =================================================================================================== + private String passwordAttributeName = "userPassword"; + + private boolean usePasswordAttrCompare = false; public PasswordComparisonAuthenticator(BaseLdapPathContextSource contextSource) { super(contextSource); } - // ~ Methods - // ======================================================================================================== - + @Override public DirContextOperations authenticate(final Authentication authentication) { Assert.isInstanceOf(UsernamePasswordAuthenticationToken.class, authentication, "Can only process UsernamePasswordAuthenticationToken objects"); // locate the user and check the password - DirContextOperations user = null; String username = authentication.getName(); String password = (String) authentication.getCredentials(); - - SpringSecurityLdapTemplate ldapTemplate = new SpringSecurityLdapTemplate( - getContextSource()); - + SpringSecurityLdapTemplate ldapTemplate = new SpringSecurityLdapTemplate(getContextSource()); for (String userDn : getUserDns(username)) { try { user = ldapTemplate.retrieveEntry(userDn, getUserAttributes()); @@ -92,33 +81,29 @@ public final class PasswordComparisonAuthenticator extends AbstractLdapAuthentic break; } } - if (user == null && getUserSearch() != null) { user = getUserSearch().searchForUser(username); } - if (user == null) { throw new UsernameNotFoundException("User not found: " + username); } - if (logger.isDebugEnabled()) { - logger.debug("Performing LDAP compare of password attribute '" - + passwordAttributeName + "' for user '" + user.getDn() + "'"); + logger.debug(LogMessage.format("Performing LDAP compare of password attribute '%s' for user '%s'", + this.passwordAttributeName, user.getDn())); } - - if (usePasswordAttrCompare && isPasswordAttrCompare(user, password)) { + if (this.usePasswordAttrCompare && isPasswordAttrCompare(user, password)) { return user; } - else if (isLdapPasswordCompare(user, ldapTemplate, password)) { + if (isLdapPasswordCompare(user, ldapTemplate, password)) { return user; } - throw new BadCredentialsException(messages.getMessage( - "PasswordComparisonAuthenticator.badCredentials", "Bad credentials")); + throw new BadCredentialsException( + this.messages.getMessage("PasswordComparisonAuthenticator.badCredentials", "Bad credentials")); } private boolean isPasswordAttrCompare(DirContextOperations user, String password) { String passwordAttrValue = getPassword(user); - return passwordEncoder.matches(password, passwordAttrValue); + return this.passwordEncoder.matches(password, passwordAttrValue); } private String getPassword(DirContextOperations user) { @@ -132,17 +117,15 @@ public final class PasswordComparisonAuthenticator extends AbstractLdapAuthentic return String.valueOf(passwordAttrValue); } - private boolean isLdapPasswordCompare(DirContextOperations user, - SpringSecurityLdapTemplate ldapTemplate, String password) { - String encodedPassword = passwordEncoder.encode(password); + private boolean isLdapPasswordCompare(DirContextOperations user, SpringSecurityLdapTemplate ldapTemplate, + String password) { + String encodedPassword = this.passwordEncoder.encode(password); byte[] passwordBytes = Utf8.encode(encodedPassword); - return ldapTemplate.compare(user.getDn().toString(), passwordAttributeName, - passwordBytes); + return ldapTemplate.compare(user.getDn().toString(), this.passwordAttributeName, passwordBytes); } public void setPasswordAttributeName(String passwordAttribute) { - Assert.hasLength(passwordAttribute, - "passwordAttributeName must not be empty or null"); + Assert.hasLength(passwordAttribute, "passwordAttributeName must not be empty or null"); this.passwordAttributeName = passwordAttribute; } @@ -155,4 +138,5 @@ public final class PasswordComparisonAuthenticator extends AbstractLdapAuthentic this.passwordEncoder = passwordEncoder; setUsePasswordAttrCompare(true); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/SpringSecurityAuthenticationSource.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/SpringSecurityAuthenticationSource.java index 0ee20693e0..14584623fc 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/SpringSecurityAuthenticationSource.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/SpringSecurityAuthenticationSource.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.ldap.core.AuthenticationSource; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.ldap.userdetails.LdapUserDetails; -import org.springframework.ldap.core.AuthenticationSource; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; /** * An AuthenticationSource to retrieve authentication information stored in Spring @@ -36,54 +37,44 @@ import org.apache.commons.logging.LogFactory; * @since 2.0 */ public class SpringSecurityAuthenticationSource implements AuthenticationSource { - private static final Log log = LogFactory - .getLog(SpringSecurityAuthenticationSource.class); + + private static final Log log = LogFactory.getLog(SpringSecurityAuthenticationSource.class); /** * Get the principals of the logged in user, in this case the distinguished name. - * * @return the distinguished name of the logged in user. */ + @Override public String getPrincipal() { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); - + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); if (authentication == null) { log.warn("No Authentication object set in SecurityContext - returning empty String as Principal"); return ""; } - Object principal = authentication.getPrincipal(); - if (principal instanceof LdapUserDetails) { LdapUserDetails details = (LdapUserDetails) principal; return details.getDn(); } - else if (authentication instanceof AnonymousAuthenticationToken) { - if (log.isDebugEnabled()) { - log.debug("Anonymous Authentication, returning empty String as Principal"); - } + if (authentication instanceof AnonymousAuthenticationToken) { + log.debug("Anonymous Authentication, returning empty String as Principal"); return ""; } - else { - throw new IllegalArgumentException( - "The principal property of the authentication object" - + "needs to be an LdapUserDetails."); - } + throw new IllegalArgumentException( + "The principal property of the authentication object" + "needs to be an LdapUserDetails."); } /** * @see org.springframework.ldap.core.AuthenticationSource#getCredentials() */ + @Override public String getCredentials() { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); - + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); if (authentication == null) { log.warn("No Authentication object set in SecurityContext - returning empty String as Credentials"); return ""; } - return (String) authentication.getCredentials(); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/UserDetailsServiceLdapAuthoritiesPopulator.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/UserDetailsServiceLdapAuthoritiesPopulator.java index 3f91dedfb4..64ea83fb08 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/UserDetailsServiceLdapAuthoritiesPopulator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/UserDetailsServiceLdapAuthoritiesPopulator.java @@ -13,26 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication; import java.util.Collection; +import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.ldap.userdetails.LdapAuthoritiesPopulator; -import org.springframework.ldap.core.DirContextOperations; import org.springframework.util.Assert; /** * Simple LdapAuthoritiesPopulator which delegates to a UserDetailsService, using the name * which was supplied at login as the username. * - * * @author Luke Taylor * @since 2.0 */ -public class UserDetailsServiceLdapAuthoritiesPopulator implements - LdapAuthoritiesPopulator { +public class UserDetailsServiceLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator { + private final UserDetailsService userDetailsService; public UserDetailsServiceLdapAuthoritiesPopulator(UserDetailsService userService) { @@ -40,8 +40,10 @@ public class UserDetailsServiceLdapAuthoritiesPopulator implements this.userDetailsService = userService; } - public Collection getGrantedAuthorities( - DirContextOperations userData, String username) { - return userDetailsService.loadUserByUsername(username).getAuthorities(); + @Override + public Collection getGrantedAuthorities(DirContextOperations userData, + String username) { + return this.userDetailsService.loadUserByUsername(username).getAuthorities(); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryAuthenticationException.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryAuthenticationException.java index 9f5c287e71..f4d5199b23 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryAuthenticationException.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryAuthenticationException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication.ad; import org.springframework.security.core.AuthenticationException; @@ -41,15 +42,16 @@ import org.springframework.security.core.AuthenticationException; */ @SuppressWarnings("serial") public final class ActiveDirectoryAuthenticationException extends AuthenticationException { + private final String dataCode; - ActiveDirectoryAuthenticationException(String dataCode, String message, - Throwable cause) { + ActiveDirectoryAuthenticationException(String dataCode, String message, Throwable cause) { super(message, cause); this.dataCode = dataCode; } public String getDataCode() { - return dataCode; + return this.dataCode; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProvider.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProvider.java index afcb2fa87e..381b3c3179 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProvider.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProvider.java @@ -13,8 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication.ad; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Hashtable; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.naming.AuthenticationException; +import javax.naming.Context; +import javax.naming.NamingException; +import javax.naming.OperationNotSupportedException; +import javax.naming.directory.DirContext; +import javax.naming.directory.SearchControls; +import javax.naming.ldap.InitialLdapContext; + +import org.springframework.core.log.LogMessage; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.ldap.CommunicationException; import org.springframework.ldap.core.DirContextOperations; @@ -37,24 +58,12 @@ import org.springframework.security.ldap.authentication.AbstractLdapAuthenticati import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import javax.naming.AuthenticationException; -import javax.naming.Context; -import javax.naming.NamingException; -import javax.naming.OperationNotSupportedException; -import javax.naming.directory.DirContext; -import javax.naming.directory.SearchControls; -import javax.naming.ldap.InitialLdapContext; -import java.io.Serializable; -import java.util.*; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - /** * Specialized LDAP authentication provider which uses Active Directory configuration * conventions. *

        - * It will authenticate using the Active Directory + * It will authenticate using the Active Directory + * * {@code userPrincipalName} or a custom {@link #setSearchFilter(String) searchFilter} * in the form {@code username@domain}. If the username does not already end with the * domain name, the {@code userPrincipalName} will be built by appending the configured @@ -90,26 +99,37 @@ import java.util.regex.Pattern; * @author Rob Winch * @since 3.1 */ -public final class ActiveDirectoryLdapAuthenticationProvider extends - AbstractLdapAuthenticationProvider { - private static final Pattern SUB_ERROR_CODE = Pattern - .compile(".*data\\s([0-9a-f]{3,4}).*"); +public final class ActiveDirectoryLdapAuthenticationProvider extends AbstractLdapAuthenticationProvider { + + private static final Pattern SUB_ERROR_CODE = Pattern.compile(".*data\\s([0-9a-f]{3,4}).*"); // Error codes private static final int USERNAME_NOT_FOUND = 0x525; + private static final int INVALID_PASSWORD = 0x52e; + private static final int NOT_PERMITTED = 0x530; + private static final int PASSWORD_EXPIRED = 0x532; + private static final int ACCOUNT_DISABLED = 0x533; + private static final int ACCOUNT_EXPIRED = 0x701; + private static final int PASSWORD_NEEDS_RESET = 0x773; + private static final int ACCOUNT_LOCKED = 0x775; private final String domain; + private final String rootDn; + private final String url; + private boolean convertSubErrorCodesToExceptions; + private String searchFilter = "(&(objectClass=user)(userPrincipalName={0}))"; + private Map contextEnvironmentProperties = new HashMap<>(); // Only used to allow tests to substitute a mock LdapContext @@ -120,8 +140,7 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends * @param url an LDAP url (or multiple URLs) * @param rootDn the root DN (may be null or empty) */ - public ActiveDirectoryLdapAuthenticationProvider(String domain, String url, - String rootDn) { + public ActiveDirectoryLdapAuthenticationProvider(String domain, String url, String rootDn) { Assert.isTrue(StringUtils.hasText(url), "Url cannot be empty"); this.domain = StringUtils.hasText(domain) ? domain.toLowerCase() : null; this.url = url; @@ -136,27 +155,24 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends Assert.isTrue(StringUtils.hasText(url), "Url cannot be empty"); this.domain = StringUtils.hasText(domain) ? domain.toLowerCase() : null; this.url = url; - rootDn = this.domain == null ? null : rootDnFromDomain(this.domain); + this.rootDn = (this.domain != null) ? rootDnFromDomain(this.domain) : null; } @Override - protected DirContextOperations doAuthentication( - UsernamePasswordAuthenticationToken auth) { + protected DirContextOperations doAuthentication(UsernamePasswordAuthenticationToken auth) { String username = auth.getName(); String password = (String) auth.getCredentials(); DirContext ctx = null; - try { ctx = bindAsUser(username, password); return searchForUser(ctx, username); } - catch (CommunicationException e) { - throw badLdapConnection(e); + catch (CommunicationException ex) { + throw badLdapConnection(ex); } - catch (NamingException e) { - logger.error("Failed to locate directory entry for authenticated user: " - + username, e); - throw badCredentials(e); + catch (NamingException ex) { + this.logger.error("Failed to locate directory entry for authenticated user: " + username, ex); + throw badCredentials(ex); } finally { LdapUtils.closeContext(ctx); @@ -168,35 +184,26 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends * obtained from the user's Active Directory entry. */ @Override - protected Collection loadUserAuthorities( - DirContextOperations userData, String username, String password) { + protected Collection loadUserAuthorities(DirContextOperations userData, String username, + String password) { String[] groups = userData.getStringAttributes("memberOf"); - if (groups == null) { - logger.debug("No values for 'memberOf' attribute."); - + this.logger.debug("No values for 'memberOf' attribute."); return AuthorityUtils.NO_AUTHORITIES; } - - if (logger.isDebugEnabled()) { - logger.debug("'memberOf' attribute values: " + Arrays.asList(groups)); + if (this.logger.isDebugEnabled()) { + this.logger.debug("'memberOf' attribute values: " + Arrays.asList(groups)); } - - ArrayList authorities = new ArrayList<>( - groups.length); - + List authorities = new ArrayList<>(groups.length); for (String group : groups) { - authorities.add(new SimpleGrantedAuthority(new DistinguishedName(group) - .removeLast().getValue())); + authorities.add(new SimpleGrantedAuthority(new DistinguishedName(group).removeLast().getValue())); } - return authorities; } private DirContext bindAsUser(String username, String password) { // TODO. add DNS lookup based on domain - final String bindUrl = url; - + final String bindUrl = this.url; Hashtable env = new Hashtable<>(); env.put(Context.SECURITY_AUTHENTICATION, "simple"); String bindPrincipal = createBindPrincipal(username); @@ -206,39 +213,29 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory"); env.put(Context.OBJECT_FACTORIES, DefaultDirObjectFactory.class.getName()); env.putAll(this.contextEnvironmentProperties); - try { - return contextFactory.createContext(env); + return this.contextFactory.createContext(env); } - catch (NamingException e) { - if ((e instanceof AuthenticationException) - || (e instanceof OperationNotSupportedException)) { - handleBindException(bindPrincipal, e); - throw badCredentials(e); - } else { - throw LdapUtils.convertLdapException(e); + catch (NamingException ex) { + if ((ex instanceof AuthenticationException) || (ex instanceof OperationNotSupportedException)) { + handleBindException(bindPrincipal, ex); + throw badCredentials(ex); } + throw LdapUtils.convertLdapException(ex); } } private void handleBindException(String bindPrincipal, NamingException exception) { - if (logger.isDebugEnabled()) { - logger.debug("Authentication for " + bindPrincipal + " failed:" + exception); - } - + this.logger.debug(LogMessage.format("Authentication for %s failed:%s", bindPrincipal, exception)); handleResolveObj(exception); - int subErrorCode = parseSubErrorCode(exception.getMessage()); - if (subErrorCode <= 0) { - logger.debug("Failed to locate AD-specific sub-error code in message"); + this.logger.debug("Failed to locate AD-specific sub-error code in message"); return; } - - logger.info("Active Directory authentication failed: " - + subCodeToLogMessage(subErrorCode)); - - if (convertSubErrorCodesToExceptions) { + this.logger.info( + LogMessage.of(() -> "Active Directory authentication failed: " + subCodeToLogMessage(subErrorCode))); + if (this.convertSubErrorCodesToExceptions) { raiseExceptionForErrorCode(subErrorCode, exception); } } @@ -252,34 +249,29 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends } private int parseSubErrorCode(String message) { - Matcher m = SUB_ERROR_CODE.matcher(message); - - if (m.matches()) { - return Integer.parseInt(m.group(1), 16); + Matcher matcher = SUB_ERROR_CODE.matcher(message); + if (matcher.matches()) { + return Integer.parseInt(matcher.group(1), 16); } - return -1; } private void raiseExceptionForErrorCode(int code, NamingException exception) { String hexString = Integer.toHexString(code); - Throwable cause = new ActiveDirectoryAuthenticationException(hexString, - exception.getMessage(), exception); + Throwable cause = new ActiveDirectoryAuthenticationException(hexString, exception.getMessage(), exception); switch (code) { case PASSWORD_EXPIRED: - throw new CredentialsExpiredException(messages.getMessage( - "LdapAuthenticationProvider.credentialsExpired", - "User credentials have expired"), cause); + throw new CredentialsExpiredException(this.messages.getMessage( + "LdapAuthenticationProvider.credentialsExpired", "User credentials have expired"), cause); case ACCOUNT_DISABLED: - throw new DisabledException(messages.getMessage( - "LdapAuthenticationProvider.disabled", "User is disabled"), cause); + throw new DisabledException( + this.messages.getMessage("LdapAuthenticationProvider.disabled", "User is disabled"), cause); case ACCOUNT_EXPIRED: - throw new AccountExpiredException(messages.getMessage( - "LdapAuthenticationProvider.expired", "User account has expired"), - cause); + throw new AccountExpiredException( + this.messages.getMessage("LdapAuthenticationProvider.expired", "User account has expired"), cause); case ACCOUNT_LOCKED: - throw new LockedException(messages.getMessage( - "LdapAuthenticationProvider.locked", "User account is locked"), cause); + throw new LockedException( + this.messages.getMessage("LdapAuthenticationProvider.locked", "User account is locked"), cause); default: throw badCredentials(cause); } @@ -304,13 +296,12 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends case ACCOUNT_LOCKED: return "Account locked"; } - return "Unknown (error code " + Integer.toHexString(code) + ")"; } private BadCredentialsException badCredentials() { - return new BadCredentialsException(messages.getMessage( - "LdapAuthenticationProvider.badCredentials", "Bad credentials")); + return new BadCredentialsException( + this.messages.getMessage("LdapAuthenticationProvider.badCredentials", "Bad credentials")); } private BadCredentialsException badCredentials(Throwable cause) { @@ -318,74 +309,62 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends } private InternalAuthenticationServiceException badLdapConnection(Throwable cause) { - return new InternalAuthenticationServiceException(messages.getMessage( - "LdapAuthenticationProvider.badLdapConnection", - "Connection to LDAP server failed."), cause); + return new InternalAuthenticationServiceException(this.messages.getMessage( + "LdapAuthenticationProvider.badLdapConnection", "Connection to LDAP server failed."), cause); } - private DirContextOperations searchForUser(DirContext context, String username) - throws NamingException { + private DirContextOperations searchForUser(DirContext context, String username) throws NamingException { SearchControls searchControls = new SearchControls(); searchControls.setSearchScope(SearchControls.SUBTREE_SCOPE); - String bindPrincipal = createBindPrincipal(username); - String searchRoot = rootDn != null ? rootDn - : searchRootFromPrincipal(bindPrincipal); + String searchRoot = (this.rootDn != null) ? this.rootDn : searchRootFromPrincipal(bindPrincipal); try { - return SpringSecurityLdapTemplate.searchForSingleEntryInternal(context, - searchControls, searchRoot, searchFilter, - new Object[] { bindPrincipal, username }); + return SpringSecurityLdapTemplate.searchForSingleEntryInternal(context, searchControls, searchRoot, + this.searchFilter, new Object[] { bindPrincipal, username }); } - catch (CommunicationException ldapCommunicationException) { - throw badLdapConnection(ldapCommunicationException); + catch (CommunicationException ex) { + throw badLdapConnection(ex); } - catch (IncorrectResultSizeDataAccessException incorrectResults) { - // Search should never return multiple results if properly configured - just - // rethrow - if (incorrectResults.getActualSize() != 0) { - throw incorrectResults; + catch (IncorrectResultSizeDataAccessException ex) { + // Search should never return multiple results if properly configured - + if (ex.getActualSize() != 0) { + throw ex; } // If we found no results, then the username/password did not match UsernameNotFoundException userNameNotFoundException = new UsernameNotFoundException( - "User " + username + " not found in directory.", incorrectResults); + "User " + username + " not found in directory.", ex); throw badCredentials(userNameNotFoundException); } } private String searchRootFromPrincipal(String bindPrincipal) { int atChar = bindPrincipal.lastIndexOf('@'); - if (atChar < 0) { - logger.debug("User principal '" + bindPrincipal + this.logger.debug("User principal '" + bindPrincipal + "' does not contain the domain, and no domain has been configured"); throw badCredentials(); } - - return rootDnFromDomain(bindPrincipal.substring(atChar + 1, - bindPrincipal.length())); + return rootDnFromDomain(bindPrincipal.substring(atChar + 1, bindPrincipal.length())); } private String rootDnFromDomain(String domain) { String[] tokens = StringUtils.tokenizeToStringArray(domain, "."); StringBuilder root = new StringBuilder(); - for (String token : tokens) { if (root.length() > 0) { root.append(','); } root.append("dc=").append(token); } - return root.toString(); } String createBindPrincipal(String username) { - if (domain == null || username.toLowerCase().endsWith(domain)) { + if (this.domain == null || username.toLowerCase().endsWith(this.domain)) { return username; } - - return username + "@" + domain; + return username + "@" + this.domain; } /** @@ -398,12 +377,10 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends * {@link AccountExpiredException} or {@link LockedException} will be thrown for the * corresponding codes. All other codes will result in the default * {@code BadCredentialsException}. - * * @param convertSubErrorCodesToExceptions {@code true} to raise an exception based on * the AD error code. */ - public void setConvertSubErrorCodesToExceptions( - boolean convertSubErrorCodesToExceptions) { + public void setConvertSubErrorCodesToExceptions(boolean convertSubErrorCodesToExceptions) { this.convertSubErrorCodesToExceptions = convertSubErrorCodesToExceptions; } @@ -414,7 +391,6 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends *

        * Defaults to: {@code (&(objectClass=user)(userPrincipalName={0}))} *

        - * * @param searchFilter the filter string * * @since 3.2.6 @@ -426,8 +402,8 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends /** * Allows a custom environment properties to be used to create initial LDAP context. - * - * @param environment the additional environment parameters to use when creating the LDAP Context + * @param environment the additional environment parameters to use when creating the + * LDAP Context */ public void setContextEnvironmentProperties(Map environment) { Assert.notEmpty(environment, "environment must not be empty"); @@ -435,8 +411,11 @@ public final class ActiveDirectoryLdapAuthenticationProvider extends } static class ContextFactory { + DirContext createContext(Hashtable env) throws NamingException { return new InitialLdapContext(env, null); } + } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/authentication/package-info.java b/ldap/src/main/java/org/springframework/security/ldap/authentication/package-info.java index fee0732440..eed3e65804 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/authentication/package-info.java +++ b/ldap/src/main/java/org/springframework/security/ldap/authentication/package-info.java @@ -13,14 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * The LDAP authentication provider package. Interfaces are provided for - * both authentication and retrieval of user roles from an LDAP server. + * The LDAP authentication provider package. Interfaces are provided for both + * authentication and retrieval of user roles from an LDAP server. *

        - * The main provider class is LdapAuthenticationProvider. - * This is configured with an LdapAuthenticator instance and - * an LdapAuthoritiesPopulator. The latter is used to obtain the - * list of roles for the user. + * The main provider class is LdapAuthenticationProvider. This is configured with + * an LdapAuthenticator instance and an LdapAuthoritiesPopulator. The + * latter is used to obtain the list of roles for the user. */ package org.springframework.security.ldap.authentication; - diff --git a/ldap/src/main/java/org/springframework/security/ldap/package-info.java b/ldap/src/main/java/org/springframework/security/ldap/package-info.java index 70dc7c7526..891f2c4874 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/package-info.java +++ b/ldap/src/main/java/org/springframework/security/ldap/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Spring Security's LDAP module. */ package org.springframework.security.ldap; - diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSource.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSource.java index 249f3736a0..6fb79ffd4f 100755 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSource.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSource.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; import java.util.Hashtable; @@ -22,6 +23,7 @@ import javax.naming.directory.DirContext; import javax.naming.ldap.Control; import javax.naming.ldap.LdapContext; +import org.springframework.core.log.LogMessage; import org.springframework.ldap.support.LdapUtils; import org.springframework.security.ldap.DefaultSpringSecurityContextSource; @@ -44,53 +46,34 @@ public class PasswordPolicyAwareContextSource extends DefaultSpringSecurityConte } @Override - public DirContext getContext(String principal, String credentials) - throws PasswordPolicyException { - if (principal.equals(userDn)) { + public DirContext getContext(String principal, String credentials) throws PasswordPolicyException { + if (principal.equals(this.userDn)) { return super.getContext(principal, credentials); } - - final boolean debug = logger.isDebugEnabled(); - - if (debug) { - logger.debug("Binding as '" + userDn + "', prior to reconnect as user '" - + principal + "'"); - } - + this.logger + .debug(LogMessage.format("Binding as '%s', prior to reconnect as user '%s'", this.userDn, principal)); // First bind as manager user before rebinding as the specific principal. - LdapContext ctx = (LdapContext) super.getContext(userDn, password); - + LdapContext ctx = (LdapContext) super.getContext(this.userDn, this.password); Control[] rctls = { new PasswordPolicyControl(false) }; - try { ctx.addToEnvironment(Context.SECURITY_PRINCIPAL, principal); ctx.addToEnvironment(Context.SECURITY_CREDENTIALS, credentials); ctx.reconnect(rctls); } - catch (javax.naming.NamingException ne) { - PasswordPolicyResponseControl ctrl = PasswordPolicyControlExtractor - .extractControl(ctx); - if (debug) { - logger.debug("Failed to obtain context", ne); - logger.debug("Password policy response: " + ctrl); + catch (javax.naming.NamingException ex) { + PasswordPolicyResponseControl ctrl = PasswordPolicyControlExtractor.extractControl(ctx); + if (this.logger.isDebugEnabled()) { + this.logger.debug("Failed to obtain context", ex); + this.logger.debug("Password policy response: " + ctrl); } - LdapUtils.closeContext(ctx); - - if (ctrl != null) { - if (ctrl.isLocked()) { - throw new PasswordPolicyException(ctrl.getErrorStatus()); - } + if (ctrl != null && ctrl.isLocked()) { + throw new PasswordPolicyException(ctrl.getErrorStatus()); } - - throw LdapUtils.convertLdapException(ne); + throw LdapUtils.convertLdapException(ex); } - - if (debug) { - logger.debug("PPolicy control returned: " - + PasswordPolicyControlExtractor.extractControl(ctx)); - } - + this.logger.debug( + LogMessage.of(() -> "PPolicy control returned: " + PasswordPolicyControlExtractor.extractControl(ctx))); return ctx; } @@ -98,10 +81,8 @@ public class PasswordPolicyAwareContextSource extends DefaultSpringSecurityConte @SuppressWarnings("unchecked") protected Hashtable getAuthenticatedEnv(String principal, String credentials) { Hashtable env = super.getAuthenticatedEnv(principal, credentials); - - env.put(LdapContext.CONTROL_FACTORIES, - PasswordPolicyControlFactory.class.getName()); - + env.put(LdapContext.CONTROL_FACTORIES, PasswordPolicyControlFactory.class.getName()); return env; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControl.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControl.java index 9d45957fa3..84eb48cdf9 100755 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControl.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControl.java @@ -28,24 +28,17 @@ import javax.naming.ldap.Control; * * @author Stefan Zoerner * @author Luke Taylor - * * @see PasswordPolicyResponseControl */ public class PasswordPolicyControl implements Control { - // ~ Static fields/initializers - // ===================================================================================== - /** OID of the Password Policy Control */ + /** + * OID of the Password Policy Control + */ public static final String OID = "1.3.6.1.4.1.42.2.27.8.5.1"; - // ~ Instance fields - // ================================================================================================ - private final boolean critical; - // ~ Constructors - // =================================================================================================== - /** * Creates a non-critical (request) control. */ @@ -55,22 +48,18 @@ public class PasswordPolicyControl implements Control { /** * Creates a (request) control. - * * @param critical indicates whether the control is critical for the client */ public PasswordPolicyControl(boolean critical) { this.critical = critical; } - // ~ Methods - // ======================================================================================================== - /** * Retrieves the ASN.1 BER encoded value of the LDAP control. The request value for * this control is always empty. - * * @return always null */ + @Override public byte[] getEncodedValue() { return null; } @@ -78,6 +67,7 @@ public class PasswordPolicyControl implements Control { /** * Returns the OID of the Password Policy Control ("1.3.6.1.4.1.42.2.27.8.5.1"). */ + @Override public String getID() { return OID; } @@ -85,7 +75,9 @@ public class PasswordPolicyControl implements Control { /** * Returns whether the control is critical for the client. */ + @Override public boolean isCritical() { - return critical; + return this.critical; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlExtractor.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlExtractor.java index 715972ecf9..79f007e408 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlExtractor.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlExtractor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; import javax.naming.directory.DirContext; @@ -28,9 +29,12 @@ import org.apache.commons.logging.LogFactory; * @author Luke Taylor * @since 3.0 */ -public class PasswordPolicyControlExtractor { - private static final Log logger = LogFactory - .getLog(PasswordPolicyControlExtractor.class); +public final class PasswordPolicyControlExtractor { + + private static final Log logger = LogFactory.getLog(PasswordPolicyControlExtractor.class); + + private PasswordPolicyControlExtractor() { + } public static PasswordPolicyResponseControl extractControl(DirContext dirCtx) { LdapContext ctx = (LdapContext) dirCtx; @@ -38,16 +42,14 @@ public class PasswordPolicyControlExtractor { try { ctrls = ctx.getResponseControls(); } - catch (javax.naming.NamingException e) { - logger.error("Failed to obtain response controls", e); + catch (javax.naming.NamingException ex) { + logger.error("Failed to obtain response controls", ex); } - for (int i = 0; ctrls != null && i < ctrls.length; i++) { if (ctrls[i] instanceof PasswordPolicyResponseControl) { return (PasswordPolicyResponseControl) ctrls[i]; } } - return null; } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactory.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactory.java index 8584e1e42e..0bb3e274a2 100755 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactory.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactory.java @@ -26,23 +26,20 @@ import javax.naming.ldap.ControlFactory; * @author Luke Taylor */ public class PasswordPolicyControlFactory extends ControlFactory { - // ~ Methods - // ======================================================================================================== /** * Creates an instance of PasswordPolicyResponseControl if the passed control is a * response control of this type. Attributes of the result are filled with the correct * values (e.g. error code). - * * @param ctl the control the check - * * @return a response control of type PasswordPolicyResponseControl, or null */ + @Override public Control getControlInstance(Control ctl) { if (ctl.getID().equals(PasswordPolicyControl.OID)) { return new PasswordPolicyResponseControl(ctl.getEncodedValue()); } - return null; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyData.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyData.java index 0a69477b9f..2098ebeb53 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyData.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyData.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; /** @@ -20,7 +21,9 @@ package org.springframework.security.ldap.ppolicy; * @since 3.0 */ public interface PasswordPolicyData { + int getTimeBeforeExpiration(); int getGraceLoginsRemaining(); + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyErrorStatus.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyErrorStatus.java index 02a289e29a..b81bc34eab 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyErrorStatus.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyErrorStatus.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; /** @@ -39,20 +40,28 @@ package org.springframework.security.ldap.ppolicy; * @since 3.0 */ public enum PasswordPolicyErrorStatus { - PASSWORD_EXPIRED("ppolicy.expired", "Your password has expired"), ACCOUNT_LOCKED( - "ppolicy.locked", "Account is locked"), CHANGE_AFTER_RESET( - "ppolicy.change.after.reset", - "Your password must be changed after being reset"), PASSWORD_MOD_NOT_ALLOWED( - "ppolicy.mod.not.allowed", "Password cannot be changed"), MUST_SUPPLY_OLD_PASSWORD( - "ppolicy.must.supply.old.password", "The old password must be supplied"), INSUFFICIENT_PASSWORD_QUALITY( - "ppolicy.insufficient.password.quality", - "The supplied password is of insufficient quality"), PASSWORD_TOO_SHORT( - "ppolicy.password.too.short", "The supplied password is too short"), PASSWORD_TOO_YOUNG( - "ppolicy.password.too.young", - "Your password was changed too recently to be changed again"), PASSWORD_IN_HISTORY( - "ppolicy.password.in.history", "The supplied password has already been used"); + + PASSWORD_EXPIRED("ppolicy.expired", "Your password has expired"), + + ACCOUNT_LOCKED("ppolicy.locked", "Account is locked"), + + CHANGE_AFTER_RESET("ppolicy.change.after.reset", "Your password must be changed after being reset"), + + PASSWORD_MOD_NOT_ALLOWED("ppolicy.mod.not.allowed", "Password cannot be changed"), + + MUST_SUPPLY_OLD_PASSWORD("ppolicy.must.supply.old.password", "The old password must be supplied"), + + INSUFFICIENT_PASSWORD_QUALITY("ppolicy.insufficient.password.quality", + "The supplied password is of insufficient quality"), + + PASSWORD_TOO_SHORT("ppolicy.password.too.short", "The supplied password is too short"), + + PASSWORD_TOO_YOUNG("ppolicy.password.too.young", "Your password was changed too recently to be changed again"), + + PASSWORD_IN_HISTORY("ppolicy.password.in.history", "The supplied password has already been used"); private final String errorCode; + private final String defaultMessage; PasswordPolicyErrorStatus(String errorCode, String defaultMessage) { @@ -61,10 +70,11 @@ public enum PasswordPolicyErrorStatus { } public String getErrorCode() { - return errorCode; + return this.errorCode; } public String getDefaultMessage() { - return defaultMessage; + return this.defaultMessage; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyException.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyException.java index c276a71b02..73ab142052 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyException.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; /** @@ -25,6 +26,7 @@ package org.springframework.security.ldap.ppolicy; * @since 3.0 */ public class PasswordPolicyException extends RuntimeException { + private final PasswordPolicyErrorStatus status; public PasswordPolicyException(PasswordPolicyErrorStatus status) { @@ -33,6 +35,7 @@ public class PasswordPolicyException extends RuntimeException { } public PasswordPolicyErrorStatus getStatus() { - return status; + return this.status; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControl.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControl.java index f2c92d7b64..bb1c8b0898 100755 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControl.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControl.java @@ -41,33 +41,24 @@ import org.springframework.dao.DataRetrievalFailureException; * graceLoginsRemaining. *

        * - * * @author Stefan Zoerner * @author Luke Taylor - * * @see org.springframework.security.ldap.ppolicy.PasswordPolicyControl - * @see Stefan - * Zoerner's IBM developerworks article on LDAP controls. + * @see Stefan Zoerner's + * IBM developerworks article on LDAP controls. */ public class PasswordPolicyResponseControl extends PasswordPolicyControl { - // ~ Static fields/initializers - // ===================================================================================== - private static final Log logger = LogFactory - .getLog(PasswordPolicyResponseControl.class); - - // ~ Instance fields - // ================================================================================================ + private static final Log logger = LogFactory.getLog(PasswordPolicyResponseControl.class); private final byte[] encodedValue; private PasswordPolicyErrorStatus errorStatus; private int graceLoginsRemaining = Integer.MAX_VALUE; - private int timeBeforeExpiration = Integer.MAX_VALUE; - // ~ Constructors - // =================================================================================================== + private int timeBeforeExpiration = Integer.MAX_VALUE; /** * Decodes the Ber encoded control data. The ASN.1 value of the control data is: @@ -86,21 +77,15 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { */ public PasswordPolicyResponseControl(byte[] encodedValue) { this.encodedValue = encodedValue; - - // PPolicyDecoder decoder = new JLdapDecoder(); PPolicyDecoder decoder = new NetscapeDecoder(); - try { decoder.decode(); } - catch (IOException e) { - throw new DataRetrievalFailureException("Failed to parse control value", e); + catch (IOException ex) { + throw new DataRetrievalFailureException("Failed to parse control value", ex); } } - // ~ Methods - // ======================================================================================================== - /** * Returns the unchanged value of the response control. Returns the unchanged value of * the response control as byte array. @@ -116,7 +101,6 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { /** * Returns the graceLoginsRemaining. - * * @return Returns the graceLoginsRemaining. */ public int getGraceLoginsRemaining() { @@ -125,7 +109,6 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { /** * Returns the timeBeforeExpiration. - * * @return Returns the time before expiration in seconds */ public int getTimeBeforeExpiration() { @@ -134,7 +117,6 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { /** * Checks whether an error is present. - * * @return true, if an error is present */ public boolean hasError() { @@ -143,12 +125,10 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { /** * Checks whether a warning is present. - * * @return true, if a warning is present */ public boolean hasWarning() { - return (this.graceLoginsRemaining != Integer.MAX_VALUE) - || (this.timeBeforeExpiration != Integer.MAX_VALUE); + return (this.graceLoginsRemaining != Integer.MAX_VALUE) || (this.timeBeforeExpiration != Integer.MAX_VALUE); } public boolean isExpired() { @@ -165,7 +145,6 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { /** * Determines whether an account locked error has been returned. - * * @return true if the account is locked. */ public boolean isLocked() { @@ -175,74 +154,53 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { /** * Create a textual representation containing error and warning messages, if any are * present. - * * @return error and warning messages */ @Override public String toString() { StringBuilder sb = new StringBuilder("PasswordPolicyResponseControl"); - if (hasError()) { sb.append(", error: ").append(this.errorStatus.getDefaultMessage()); } - if (this.graceLoginsRemaining != Integer.MAX_VALUE) { - sb.append(", warning: ").append(this.graceLoginsRemaining) - .append(" grace logins remain"); + sb.append(", warning: ").append(this.graceLoginsRemaining).append(" grace logins remain"); } - if (this.timeBeforeExpiration != Integer.MAX_VALUE) { - sb.append(", warning: time before expiration is ") - .append(this.timeBeforeExpiration); + sb.append(", warning: time before expiration is ").append(this.timeBeforeExpiration); } - if (!hasError() && !hasWarning()) { sb.append(" (no error, no warning)"); } - return sb.toString(); } - // ~ Inner Interfaces - // =============================================================================================== - private interface PPolicyDecoder { - void decode() throws IOException; - } - // ~ Inner Classes - // ================================================================================================== + void decode() throws IOException; + + } /** * Decoder based on Netscape ldapsdk library */ private class NetscapeDecoder implements PPolicyDecoder { + + @Override public void decode() throws IOException { int[] bread = { 0 }; - BERSequence seq = (BERSequence) BERElement - .getElement(new SpecificTagDecoder(), - new ByteArrayInputStream( - PasswordPolicyResponseControl.this.encodedValue), - bread); - + BERSequence seq = (BERSequence) BERElement.getElement(new SpecificTagDecoder(), + new ByteArrayInputStream(PasswordPolicyResponseControl.this.encodedValue), bread); int size = seq.size(); - if (logger.isDebugEnabled()) { - logger.debug("PasswordPolicyResponse, ASN.1 sequence has " + size - + " elements"); + logger.debug("PasswordPolicyResponse, ASN.1 sequence has " + size + " elements"); } - for (int i = 0; i < seq.size(); i++) { BERTag elt = (BERTag) seq.elementAt(i); - int tag = elt.getTag() & 0x1F; - if (tag == 0) { BERChoice warning = (BERChoice) elt.getValue(); - BERTag content = (BERTag) warning.getValue(); int value = ((BERInteger) content.getValue()).getValue(); - if ((content.getTag() & 0x1F) == 0) { PasswordPolicyResponseControl.this.timeBeforeExpiration = value; } @@ -252,36 +210,31 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { } else if (tag == 1) { BERIntegral error = (BERIntegral) elt.getValue(); - PasswordPolicyResponseControl.this.errorStatus = PasswordPolicyErrorStatus - .values()[error.getValue()]; + PasswordPolicyResponseControl.this.errorStatus = PasswordPolicyErrorStatus.values()[error + .getValue()]; } } } class SpecificTagDecoder extends BERTagDecoder { + /** Allows us to remember which of the two options we're decoding */ private Boolean inChoice = null; @Override - public BERElement getElement(BERTagDecoder decoder, int tag, - InputStream stream, int[] bytesRead, boolean[] implicit) - throws IOException { + public BERElement getElement(BERTagDecoder decoder, int tag, InputStream stream, int[] bytesRead, + boolean[] implicit) throws IOException { tag &= 0x1F; implicit[0] = false; - if (tag == 0) { // Either the choice or the time before expiry within it if (this.inChoice == null) { setInChoice(true); - // Read the choice length from the stream (ignored) BERElement.readLengthOctets(stream, bytesRead); - int[] componentLength = new int[1]; - BERElement choice = new BERChoice(decoder, stream, - componentLength); + BERElement choice = new BERChoice(decoder, stream, componentLength); bytesRead[0] += componentLength[0]; - // inChoice = null; return choice; } @@ -295,7 +248,6 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { if (this.inChoice == null) { // The enumeration setInChoice(false); - return new BEREnumerated(stream, bytesRead); } else { @@ -305,76 +257,15 @@ public class PasswordPolicyResponseControl extends PasswordPolicyControl { } } } - throw new DataRetrievalFailureException("Unexpected tag " + tag); } private void setInChoice(boolean inChoice) { this.inChoice = inChoice; } + } + } - /** Decoder based on the OpenLDAP/Novell JLDAP library */ - - // private class JLdapDecoder implements PPolicyDecoder { - // - // public void decode() throws IOException { - // - // LBERDecoder decoder = new LBERDecoder(); - // - // ASN1Sequence seq = (ASN1Sequence)decoder.decode(encodedValue); - // - // if(seq == null) { - // - // } - // - // int size = seq.size(); - // - // if(logger.isDebugEnabled()) { - // logger.debug("PasswordPolicyResponse, ASN.1 sequence has " + - // size + " elements"); - // } - // - // for(int i=0; i < size; i++) { - // - // ASN1Tagged taggedObject = (ASN1Tagged)seq.get(i); - // - // int tag = taggedObject.getIdentifier().getTag(); - // - // ASN1OctetString value = (ASN1OctetString)taggedObject.taggedValue(); - // byte[] content = value.byteValue(); - // - // if(tag == 0) { - // parseWarning(content, decoder); - // - // } else if(tag == 1) { - // // Error: set the code to the value - // errorCode = content[0]; - // } - // } - // } - // - // private void parseWarning(byte[] content, LBERDecoder decoder) { - // // It's the warning (choice). Parse the number and set either the - // // expiry time or number of logins remaining. - // ASN1Tagged taggedObject = (ASN1Tagged)decoder.decode(content); - // int contentTag = taggedObject.getIdentifier().getTag(); - // content = ((ASN1OctetString)taggedObject.taggedValue()).byteValue(); - // int number; - // - // try { - // number = ((Long)decoder.decodeNumeric(new ByteArrayInputStream(content), - // content.length)).intValue(); - // } catch(IOException e) { - // throw new LdapDataAccessException("Failed to parse number ", e); - // } - // - // if(contentTag == 0) { - // timeBeforeExpiration = number; - // } else if (contentTag == 1) { - // graceLoginsRemaining = number; - // } - // } - // } } diff --git a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/package-info.java b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/package-info.java index 8bbbc99b7a..06956868f2 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/ppolicy/package-info.java +++ b/ldap/src/main/java/org/springframework/security/ldap/ppolicy/package-info.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Implementation of password policy functionality based on the - * + * Implementation of password policy functionality based on the * Password Policy for LDAP Directories. *

        - * This code will not work with servers such as Active Directory, which do not implement this standard. + * This code will not work with servers such as Active Directory, which do not implement + * this standard. */ package org.springframework.security.ldap.ppolicy; - diff --git a/ldap/src/main/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearch.java b/ldap/src/main/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearch.java index b58b884f00..86c6b4e8e7 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearch.java +++ b/ldap/src/main/java/org/springframework/security/ldap/search/FilterBasedLdapUserSearch.java @@ -16,39 +16,31 @@ package org.springframework.security.ldap.search; -import org.springframework.security.core.userdetails.UsernameNotFoundException; -import org.springframework.security.ldap.SpringSecurityLdapTemplate; +import javax.naming.directory.SearchControls; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.dao.IncorrectResultSizeDataAccessException; - -import org.springframework.util.Assert; - import org.springframework.ldap.core.ContextSource; import org.springframework.ldap.core.DirContextOperations; import org.springframework.ldap.core.support.BaseLdapPathContextSource; - -import javax.naming.directory.SearchControls; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.ldap.SpringSecurityLdapTemplate; +import org.springframework.util.Assert; /** * LdapUserSearch implementation which uses an Ldap filter to locate the user. * * @author Robert Sanders * @author Luke Taylor - * * @see SearchControls */ public class FilterBasedLdapUserSearch implements LdapUserSearch { - // ~ Static fields/initializers - // ===================================================================================== private static final Log logger = LogFactory.getLog(FilterBasedLdapUserSearch.class); - // ~ Instance fields - // ================================================================================================ - private final ContextSource contextSource; /** @@ -57,7 +49,9 @@ public class FilterBasedLdapUserSearch implements LdapUserSearch { */ private final SearchControls searchControls = new SearchControls(); - /** Context name to search in, relative to the base of the configured ContextSource. */ + /** + * Context name to search in, relative to the base of the configured ContextSource. + */ private String searchBase = ""; /** @@ -76,99 +70,70 @@ public class FilterBasedLdapUserSearch implements LdapUserSearch { */ private final String searchFilter; - // ~ Constructors - // =================================================================================================== - - public FilterBasedLdapUserSearch(String searchBase, String searchFilter, - BaseLdapPathContextSource contextSource) { + public FilterBasedLdapUserSearch(String searchBase, String searchFilter, BaseLdapPathContextSource contextSource) { Assert.notNull(contextSource, "contextSource must not be null"); Assert.notNull(searchFilter, "searchFilter must not be null."); - Assert.notNull(searchBase, - "searchBase must not be null (an empty string is acceptable)."); - + Assert.notNull(searchBase, "searchBase must not be null (an empty string is acceptable)."); this.searchFilter = searchFilter; this.contextSource = contextSource; this.searchBase = searchBase; - setSearchSubtree(true); - if (searchBase.length() == 0) { - logger.info("SearchBase not set. Searches will be performed from the root: " - + contextSource.getBaseLdapPath()); + logger.info( + "SearchBase not set. Searches will be performed from the root: " + contextSource.getBaseLdapPath()); } } - // ~ Methods - // ======================================================================================================== - /** * Return the LdapUserDetails containing the user's information - * * @param username the username to search for. - * * @return An LdapUserDetails object containing the details of the located user's * directory entry - * * @throws UsernameNotFoundException if no matching entry is found. */ @Override public DirContextOperations searchForUser(String username) { - if (logger.isDebugEnabled()) { - logger.debug("Searching for user '" + username + "', with user search " - + this); - } - - SpringSecurityLdapTemplate template = new SpringSecurityLdapTemplate( - contextSource); - - template.setSearchControls(searchControls); - + logger.debug(LogMessage.of(() -> "Searching for user '" + username + "', with user search " + this)); + SpringSecurityLdapTemplate template = new SpringSecurityLdapTemplate(this.contextSource); + template.setSearchControls(this.searchControls); try { - - return template.searchForSingleEntry(searchBase, searchFilter, - new String[] { username }); - + return template.searchForSingleEntry(this.searchBase, this.searchFilter, new String[] { username }); } - catch (IncorrectResultSizeDataAccessException notFound) { - if (notFound.getActualSize() == 0) { - throw new UsernameNotFoundException("User " + username - + " not found in directory."); + catch (IncorrectResultSizeDataAccessException ex) { + if (ex.getActualSize() == 0) { + throw new UsernameNotFoundException("User " + username + " not found in directory."); } - // Search should never return multiple results if properly configured, so just - // rethrow - throw notFound; + // Search should never return multiple results if properly configured + throw ex; } } /** * Sets the corresponding property on the {@link SearchControls} instance used in the * search. - * * @param deref the derefLinkFlag value as defined in SearchControls.. */ public void setDerefLinkFlag(boolean deref) { - searchControls.setDerefLinkFlag(deref); + this.searchControls.setDerefLinkFlag(deref); } /** * If true then searches the entire subtree as identified by context, if false (the * default) then only searches the level identified by the context. - * * @param searchSubtree true the underlying search controls should be set to * SearchControls.SUBTREE_SCOPE rather than SearchControls.ONELEVEL_SCOPE. */ public void setSearchSubtree(boolean searchSubtree) { - searchControls.setSearchScope(searchSubtree ? SearchControls.SUBTREE_SCOPE - : SearchControls.ONELEVEL_SCOPE); + this.searchControls + .setSearchScope(searchSubtree ? SearchControls.SUBTREE_SCOPE : SearchControls.ONELEVEL_SCOPE); } /** * The time to wait before the search fails; the default is zero, meaning forever. - * * @param searchTimeLimit the time limit for the search (in milliseconds). */ public void setSearchTimeLimit(int searchTimeLimit) { - searchControls.setTimeLimit(searchTimeLimit); + this.searchControls.setTimeLimit(searchTimeLimit); } /** @@ -176,26 +141,23 @@ public class FilterBasedLdapUserSearch implements LdapUserSearch { *

        * null indicates that all attributes will be returned. An empty array indicates no * attributes are returned. - * * @param attrs An array of attribute names identifying the attributes that will be * returned. Can be null. */ public void setReturningAttributes(String[] attrs) { - searchControls.setReturningAttributes(attrs); + this.searchControls.setReturningAttributes(attrs); } @Override public String toString() { StringBuilder sb = new StringBuilder(); - - sb.append("[ searchFilter: '").append(searchFilter).append("', "); - sb.append("searchBase: '").append(searchBase).append("'"); - sb.append(", scope: ") - .append(searchControls.getSearchScope() == SearchControls.SUBTREE_SCOPE ? "subtree" - : "single-level, "); - sb.append(", searchTimeLimit: ").append(searchControls.getTimeLimit()); - sb.append(", derefLinkFlag: ").append(searchControls.getDerefLinkFlag()) - .append(" ]"); + sb.append("[ searchFilter: '").append(this.searchFilter).append("', "); + sb.append("searchBase: '").append(this.searchBase).append("'"); + sb.append(", scope: ").append( + (this.searchControls.getSearchScope() != SearchControls.SUBTREE_SCOPE) ? "single-level, " : "subtree"); + sb.append(", searchTimeLimit: ").append(this.searchControls.getTimeLimit()); + sb.append(", derefLinkFlag: ").append(this.searchControls.getDerefLinkFlag()).append(" ]"); return sb.toString(); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/search/LdapUserSearch.java b/ldap/src/main/java/org/springframework/security/ldap/search/LdapUserSearch.java index e0d86b8eaf..3852cc62b2 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/search/LdapUserSearch.java +++ b/ldap/src/main/java/org/springframework/security/ldap/search/LdapUserSearch.java @@ -28,19 +28,16 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException; * @author Luke Taylor */ public interface LdapUserSearch { - // ~ Methods - // ======================================================================================================== /** * Locates a single user in the directory and returns the LDAP information for that * user. - * * @param username the login name supplied to the authentication service. - * * @return a DirContextOperations object containing the user's full DN and requested * attributes. * @throws UsernameNotFoundException if no user with the supplied name could be * located by the search. */ DirContextOperations searchForUser(String username) throws UsernameNotFoundException; + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/search/package-info.java b/ldap/src/main/java/org/springframework/security/ldap/search/package-info.java index 0aa23b3e4e..7a8f00287b 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/search/package-info.java +++ b/ldap/src/main/java/org/springframework/security/ldap/search/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * {@code LdapUserSearch} implementations. These may be used to locate the user in the directory. + * {@code LdapUserSearch} implementations. These may be used to locate the user in the + * directory. */ package org.springframework.security.ldap.search; - diff --git a/ldap/src/main/java/org/springframework/security/ldap/server/ApacheDSContainer.java b/ldap/src/main/java/org/springframework/security/ldap/server/ApacheDSContainer.java index f02dd51edc..eb1eb79093 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/server/ApacheDSContainer.java +++ b/ldap/src/main/java/org/springframework/security/ldap/server/ApacheDSContainer.java @@ -13,16 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.server; +import java.io.File; +import java.io.IOException; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.io.File; -import java.io.IOException; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.directory.server.core.DefaultDirectoryService; @@ -41,6 +41,7 @@ import org.apache.directory.server.protocol.shared.transport.TcpTransport; import org.apache.directory.shared.ldap.exception.LdapNameNotFoundException; import org.apache.directory.shared.ldap.name.LdapDN; import org.apache.mina.transport.socket.SocketAcceptor; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; @@ -65,8 +66,8 @@ import org.springframework.util.Assert; * application context is closed to allow the bean to be disposed of and the server * shutdown prior to attempting to start it again. *

        - * This class is intended for testing and internal security namespace use, only, and is not - * considered part of the framework's public API. + * This class is intended for testing and internal security namespace use, only, and is + * not considered part of the framework's public API. * * @author Luke Taylor * @author Rob Winch @@ -76,115 +77,104 @@ import org.springframework.util.Assert; * supported with no GA version to replace it. */ @Deprecated -public class ApacheDSContainer implements InitializingBean, DisposableBean, Lifecycle, - ApplicationContextAware { +public class ApacheDSContainer implements InitializingBean, DisposableBean, Lifecycle, ApplicationContextAware { + private final Log logger = LogFactory.getLog(getClass()); final DefaultDirectoryService service; + LdapServer server; private TcpTransport transport; + private ApplicationContext ctxt; + private File workingDir; private boolean running; + private final String ldifResources; + private final JdbmPartition partition; + private final String root; + private int port = 53389; + private int localPort; private boolean ldapOverSslEnabled; + private File keyStoreFile; + private String certificatePassord; public ApacheDSContainer(String root, String ldifs) throws Exception { this.ldifResources = ldifs; - service = new DefaultDirectoryService(); + this.service = new DefaultDirectoryService(); List list = new ArrayList<>(); - list.add(new NormalizationInterceptor()); list.add(new AuthenticationInterceptor()); list.add(new ReferralInterceptor()); - // list.add( new AciAuthorizationInterceptor() ); - // list.add( new DefaultAuthorizationInterceptor() ); list.add(new ExceptionInterceptor()); - // list.add( new ChangeLogInterceptor() ); list.add(new OperationalAttributeInterceptor()); - // list.add( new SchemaInterceptor() ); list.add(new SubentryInterceptor()); - // list.add( new CollectiveAttributeInterceptor() ); - // list.add( new EventInterceptor() ); - // list.add( new TriggerInterceptor() ); - // list.add( new JournalInterceptor() ); - - service.setInterceptors(list); - partition = new JdbmPartition(); - partition.setId("rootPartition"); - partition.setSuffix(root); + this.service.setInterceptors(list); + this.partition = new JdbmPartition(); + this.partition.setId("rootPartition"); + this.partition.setSuffix(root); this.root = root; - service.addPartition(partition); - service.setExitVmOnShutdown(false); - service.setShutdownHookEnabled(false); - service.getChangeLog().setEnabled(false); - service.setDenormalizeOpAttrsEnabled(true); + this.service.addPartition(this.partition); + this.service.setExitVmOnShutdown(false); + this.service.setShutdownHookEnabled(false); + this.service.getChangeLog().setEnabled(false); + this.service.setDenormalizeOpAttrsEnabled(true); } + @Override public void afterPropertiesSet() throws Exception { - if (workingDir == null) { + if (this.workingDir == null) { String apacheWorkDir = System.getProperty("apacheDSWorkDir"); - if (apacheWorkDir == null) { apacheWorkDir = createTempDirectory("apacheds-spring-security-"); } - setWorkingDirectory(new File(apacheWorkDir)); } - if (this.ldapOverSslEnabled && this.keyStoreFile == null) { - throw new IllegalArgumentException("When LdapOverSsl is enabled, the keyStoreFile property must be set."); - } - - server = new LdapServer(); - server.setDirectoryService(service); + Assert.isTrue(!this.ldapOverSslEnabled || this.keyStoreFile != null, + "When LdapOverSsl is enabled, the keyStoreFile property must be set."); + this.server = new LdapServer(); + this.server.setDirectoryService(this.service); // AbstractLdapIntegrationTests assume IPv4, so we specify the same here - - this.transport = new TcpTransport(port); - if (ldapOverSslEnabled) { - transport.setEnableSSL(true); - server.setKeystoreFile(this.keyStoreFile.getAbsolutePath()); - server.setCertificatePassword(this.certificatePassord); + this.transport = new TcpTransport(this.port); + if (this.ldapOverSslEnabled) { + this.transport.setEnableSSL(true); + this.server.setKeystoreFile(this.keyStoreFile.getAbsolutePath()); + this.server.setCertificatePassword(this.certificatePassord); } - server.setTransports(transport); + this.server.setTransports(this.transport); start(); } + @Override public void destroy() { stop(); } - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - ctxt = applicationContext; + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.ctxt = applicationContext; } public void setWorkingDirectory(File workingDir) { Assert.notNull(workingDir, "workingDir cannot be null"); - - logger.info("Setting working directory for LDAP_PROVIDER: " - + workingDir.getAbsolutePath()); - - if (workingDir.exists()) { - throw new IllegalArgumentException( - "The specified working directory '" - + workingDir.getAbsolutePath() - + "' already exists. Another directory service instance may be using it or it may be from a " - + " previous unclean shutdown. Please confirm and delete it or configure a different " - + "working directory"); - } - + this.logger.info("Setting working directory for LDAP_PROVIDER: " + workingDir.getAbsolutePath()); + Assert.isTrue(!workingDir.exists(), + "The specified working directory '" + workingDir.getAbsolutePath() + + "' already exists. Another directory service instance may be using it or it may be from a " + + " previous unclean shutdown. Please confirm and delete it or configure a different " + + "working directory"); this.workingDir = workingDir; - - service.setWorkingDirectory(workingDir); + this.service.setWorkingDirectory(workingDir); } public void setPort(int port) { @@ -197,7 +187,6 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life /** * Returns the port that is resolved by {@link TcpTransport}. - * * @return the port that is resolved by {@link TcpTransport} */ public int getLocalPort() { @@ -207,7 +196,6 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life /** * If set to {@code true} will enable LDAP over SSL (LDAPs). If set to {@code true} * {@link ApacheDSContainer#setCertificatePassord(String)} must be set as well. - * * @param ldapOverSslEnabled If not set, will default to false */ public void setLdapOverSslEnabled(boolean ldapOverSslEnabled) { @@ -215,7 +203,8 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life } /** - * The keyStore must not be null and must be a valid file. Will set the keyStore file on the underlying {@link LdapServer}. + * The keyStore must not be null and must be a valid file. Will set the keyStore file + * on the underlying {@link LdapServer}. * @param keyStoreFile Mandatory if LDAPs is enabled */ public void setKeyStoreFile(File keyStoreFile) { @@ -226,7 +215,6 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life /** * Will set the certificate password on the underlying {@link LdapServer}. - * * @param certificatePassord May be null */ public void setCertificatePassord(String certificatePassord) { @@ -234,125 +222,107 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life } public DefaultDirectoryService getService() { - return service; + return this.service; } + @Override public void start() { if (isRunning()) { return; } - - if (service.isStarted()) { - throw new IllegalStateException("DirectoryService is already running."); - } - - logger.info("Starting directory server..."); + Assert.state(!this.service.isStarted(), "DirectoryService is already running."); + this.logger.info("Starting directory server..."); try { - service.startup(); - server.start(); + this.service.startup(); + this.server.start(); } - catch (Exception e) { - throw new RuntimeException("Server startup failed", e); + catch (Exception ex) { + throw new RuntimeException("Server startup failed", ex); } - try { - service.getAdminSession().lookup(partition.getSuffixDn()); + this.service.getAdminSession().lookup(this.partition.getSuffixDn()); } - catch (LdapNameNotFoundException e) { - try { - LdapDN dn = new LdapDN(root); - Assert.isTrue(root.startsWith("dc="), "root must start with dc="); - String dc = root.substring(3, root.indexOf(',')); - ServerEntry entry = service.newEntry(dn); - entry.add("objectClass", "top", "domain", "extensibleObject"); - entry.add("dc", dc); - service.getAdminSession().add(entry); - } - catch (Exception e1) { - logger.error("Failed to create dc entry", e1); - } + catch (LdapNameNotFoundException ex) { + handleLdapNameNotFoundException(); } - catch (Exception e) { - logger.error("Lookup failed", e); + catch (Exception ex) { + this.logger.error("Lookup failed", ex); } - SocketAcceptor socketAcceptor = this.server.getSocketAcceptor(this.transport); InetSocketAddress localAddress = socketAcceptor.getLocalAddress(); this.localPort = localAddress.getPort(); - - running = true; - + this.running = true; try { importLdifs(); } - catch (Exception e) { - throw new RuntimeException("Failed to import LDIF file(s)", e); + catch (Exception ex) { + throw new RuntimeException("Failed to import LDIF file(s)", ex); } } + private void handleLdapNameNotFoundException() { + try { + LdapDN dn = new LdapDN(this.root); + Assert.isTrue(this.root.startsWith("dc="), "root must start with dc="); + String dc = this.root.substring(3, this.root.indexOf(',')); + ServerEntry entry = this.service.newEntry(dn); + entry.add("objectClass", "top", "domain", "extensibleObject"); + entry.add("dc", dc); + this.service.getAdminSession().add(entry); + } + catch (Exception ex) { + this.logger.error("Failed to create dc entry", ex); + } + } + + @Override public void stop() { if (!isRunning()) { return; } - - logger.info("Shutting down directory server ..."); + this.logger.info("Shutting down directory server ..."); try { - server.stop(); - service.shutdown(); + this.server.stop(); + this.service.shutdown(); } - catch (Exception e) { - logger.error("Shutdown failed", e); + catch (Exception ex) { + this.logger.error("Shutdown failed", ex); return; } - - running = false; - - if (workingDir.exists()) { - logger.info("Deleting working directory " + workingDir.getAbsolutePath()); - deleteDir(workingDir); + this.running = false; + if (this.workingDir.exists()) { + this.logger.info("Deleting working directory " + this.workingDir.getAbsolutePath()); + deleteDir(this.workingDir); } } private void importLdifs() throws Exception { // Import any ldif files - Resource[] ldifs; - - if (ctxt == null) { - // Not running within an app context - ldifs = new PathMatchingResourcePatternResolver().getResources(ldifResources); - } - else { - ldifs = ctxt.getResources(ldifResources); - } - + Resource[] ldifs = (this.ctxt != null) ? this.ctxt.getResources(this.ldifResources) + : new PathMatchingResourcePatternResolver().getResources(this.ldifResources); // Note that we can't just import using the ServerContext returned // from starting Apache DS, apparently because of the long-running issue // DIRSERVER-169. // We need a standard context. // DirContext dirContext = contextSource.getReadWriteContext(); - if (ldifs == null || ldifs.length == 0) { return; } + Assert.isTrue(ldifs.length == 1, () -> "More than one LDIF resource found with the supplied pattern:" + + this.ldifResources + " Got " + Arrays.toString(ldifs)); + String ldifFile = getLdifFile(ldifs); + this.logger.info("Loading LDIF file: " + ldifFile); + LdifFileLoader loader = new LdifFileLoader(this.service.getAdminSession(), new File(ldifFile), null, + getClass().getClassLoader()); + loader.execute(); + } - if (ldifs.length == 1) { - String ldifFile; - - try { - ldifFile = ldifs[0].getFile().getAbsolutePath(); - } - catch (IOException e) { - ldifFile = ldifs[0].getURI().toString(); - } - logger.info("Loading LDIF file: " + ldifFile); - LdifFileLoader loader = new LdifFileLoader(service.getAdminSession(), - new File(ldifFile), null, getClass().getClassLoader()); - loader.execute(); + private String getLdifFile(Resource[] ldifs) throws IOException { + try { + return ldifs[0].getFile().getAbsolutePath(); } - else { - throw new IllegalArgumentException( - "More than one LDIF resource found with the supplied pattern:" - + ldifResources + " Got " + Arrays.toString(ldifs)); + catch (IOException ex) { + return ldifs[0].getURI().toString(); } } @@ -360,7 +330,6 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life String parentTempDir = System.getProperty("java.io.tmpdir"); String fileNamePrefix = prefix + System.nanoTime(); String fileName = fileNamePrefix; - for (int i = 0; i < 1000; i++) { File tempDir = new File(parentTempDir, fileName); if (!tempDir.exists()) { @@ -368,9 +337,8 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life } fileName = fileNamePrefix + "~" + i; } - - throw new IOException("Failed to create a temporary directory for file at " - + new File(parentTempDir, fileNamePrefix)); + throw new IOException( + "Failed to create a temporary directory for file at " + new File(parentTempDir, fileNamePrefix)); } private boolean deleteDir(File dir) { @@ -383,11 +351,12 @@ public class ApacheDSContainer implements InitializingBean, DisposableBean, Life } } } - return dir.delete(); } + @Override public boolean isRunning() { - return running; + return this.running; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/server/UnboundIdContainer.java b/ldap/src/main/java/org/springframework/security/ldap/server/UnboundIdContainer.java index 9d5794be27..269b8adae1 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/server/UnboundIdContainer.java +++ b/ldap/src/main/java/org/springframework/security/ldap/server/UnboundIdContainer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.server; import java.io.InputStream; @@ -37,8 +38,7 @@ import org.springframework.util.StringUtils; /** * @author Eddú Meléndez */ -public class UnboundIdContainer implements InitializingBean, DisposableBean, Lifecycle, - ApplicationContextAware { +public class UnboundIdContainer implements InitializingBean, DisposableBean, Lifecycle, ApplicationContextAware { private InMemoryDirectoryServer directoryServer; @@ -85,20 +85,16 @@ public class UnboundIdContainer implements InitializingBean, DisposableBean, Lif if (isRunning()) { return; } - try { InMemoryDirectoryServerConfig config = new InMemoryDirectoryServerConfig(this.defaultPartitionSuffix); config.addAdditionalBindCredentials("uid=admin,ou=system", "secret"); - config.setListenerConfigs(InMemoryListenerConfig.createLDAPConfig("LDAP", this.port)); config.setEnforceSingleStructuralObjectClass(false); config.setEnforceAttributeSyntaxCompliance(true); - DN dn = new DN(this.defaultPartitionSuffix); Entry entry = new Entry(dn); entry.addAttribute("objectClass", "top", "domain", "extensibleObject"); entry.addAttribute("dc", dn.getRDN().getAttributeValues()[0]); - InMemoryDirectoryServer directoryServer = new InMemoryDirectoryServer(config); directoryServer.add(entry); importLdif(directoryServer); @@ -106,10 +102,10 @@ public class UnboundIdContainer implements InitializingBean, DisposableBean, Lif this.port = directoryServer.getListenPort(); this.directoryServer = directoryServer; this.running = true; - } catch (LDAPException ex) { + } + catch (LDAPException ex) { throw new RuntimeException("Server startup failed", ex); } - } private void importLdif(InMemoryDirectoryServer directoryServer) { @@ -124,7 +120,8 @@ public class UnboundIdContainer implements InitializingBean, DisposableBean, Lif directoryServer.importFromLDIF(false, new LDIFReader(inputStream)); } } - } catch (Exception ex) { + } + catch (Exception ex) { throw new IllegalStateException("Unable to load LDIF " + this.ldif, ex); } } @@ -139,4 +136,5 @@ public class UnboundIdContainer implements InitializingBean, DisposableBean, Lif public boolean isRunning() { return this.running; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/server/package-info.java b/ldap/src/main/java/org/springframework/security/ldap/server/package-info.java index 8b1c1f1484..eb33468157 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/server/package-info.java +++ b/ldap/src/main/java/org/springframework/security/ldap/server/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Embedded Apache Directory Server implementation, as used by the configuration namespace. + * Embedded Apache Directory Server implementation, as used by the configuration + * namespace. */ package org.springframework.security.ldap.server; - diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulator.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulator.java index 12f311f1c8..8ef21554c4 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/DefaultLdapAuthoritiesPopulator.java @@ -28,6 +28,8 @@ import javax.naming.directory.SearchControls; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.ldap.core.ContextSource; import org.springframework.ldap.core.DirContextOperations; import org.springframework.ldap.core.LdapTemplate; @@ -99,14 +101,8 @@ import org.springframework.util.Assert; * @author Filip Hanik */ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator { - // ~ Static fields/initializers - // ===================================================================================== - private static final Log logger = LogFactory - .getLog(DefaultLdapAuthoritiesPopulator.class); - - // ~ Instance fields - // ================================================================================================ + private static final Log logger = LogFactory.getLog(DefaultLdapAuthoritiesPopulator.class); /** * A default role which will be assigned to all authenticated users if set @@ -154,92 +150,66 @@ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator */ private Function>, GrantedAuthority> authorityMapper; - // ~ Constructors - // =================================================================================================== - /** * Constructor for group search scenarios. userRoleAttributes may still be * set as a property. - * * @param contextSource supplies the contexts used to search for user roles. * @param groupSearchBase if this is an empty string the search will be performed from * the root DN of the context factory. If null, no search will be performed. */ - public DefaultLdapAuthoritiesPopulator(ContextSource contextSource, - String groupSearchBase) { + public DefaultLdapAuthoritiesPopulator(ContextSource contextSource, String groupSearchBase) { Assert.notNull(contextSource, "contextSource must not be null"); this.ldapTemplate = new SpringSecurityLdapTemplate(contextSource); getLdapTemplate().setSearchControls(getSearchControls()); this.groupSearchBase = groupSearchBase; - if (groupSearchBase == null) { logger.info("groupSearchBase is null. No group search will be performed."); } else if (groupSearchBase.length() == 0) { - logger.info( - "groupSearchBase is empty. Searches will be performed from the context source base"); + logger.info("groupSearchBase is empty. Searches will be performed from the context source base"); } - - this.authorityMapper = record -> { + this.authorityMapper = (record) -> { String role = record.get(this.groupRoleAttribute).get(0); - if (this.convertToUpperCase) { role = role.toUpperCase(); } - return new SimpleGrantedAuthority(this.rolePrefix + role); }; } - // ~ Methods - // ======================================================================================================== - /** * This method should be overridden if required to obtain any additional roles for the * given user (on top of those obtained from the standard search implemented by this * class). - * * @param user the context representing the user who's roles are required * @return the extra roles which will be merged with those returned by the group * search */ - protected Set getAdditionalRoles(DirContextOperations user, - String username) { + protected Set getAdditionalRoles(DirContextOperations user, String username) { return null; } /** * Obtains the authorities for the user who's directory entry is represented by the * supplied LdapUserDetails object. - * * @param user the user who's authorities are required * @return the set of roles granted to the user. */ @Override - public final Collection getGrantedAuthorities( - DirContextOperations user, String username) { + public final Collection getGrantedAuthorities(DirContextOperations user, String username) { String userDn = user.getNameInNamespace(); - - if (logger.isDebugEnabled()) { - logger.debug("Getting authorities for user " + userDn); - } - + logger.debug(LogMessage.format("Getting authorities for user %s", userDn)); Set roles = getGroupMembershipRoles(userDn, username); - Set extraRoles = getAdditionalRoles(user, username); - if (extraRoles != null) { roles.addAll(extraRoles); } - if (this.defaultRole != null) { roles.add(this.defaultRole); } - List result = new ArrayList<>(roles.size()); result.addAll(roles); - return result; } @@ -247,29 +217,16 @@ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator if (getGroupSearchBase() == null) { return new HashSet<>(); } - Set authorities = new HashSet<>(); - - if (logger.isDebugEnabled()) { - logger.debug("Searching for roles for user '" + username + "', DN = " + "'" - + userDn + "', with filter " + this.groupSearchFilter - + " in search base '" + getGroupSearchBase() + "'"); - } - - Set>> userRoles = getLdapTemplate() - .searchForMultipleAttributeValues(getGroupSearchBase(), - this.groupSearchFilter, - new String[] { userDn, username }, - new String[] { this.groupRoleAttribute }); - - if (logger.isDebugEnabled()) { - logger.debug("Roles from search: " + userRoles); - } - + logger.debug(LogMessage.of(() -> "Searching for roles for user '" + username + "', DN = " + "'" + userDn + + "', with filter " + this.groupSearchFilter + " in search base '" + getGroupSearchBase() + "'")); + Set>> userRoles = getLdapTemplate().searchForMultipleAttributeValues( + getGroupSearchBase(), this.groupSearchFilter, new String[] { userDn, username }, + new String[] { this.groupRoleAttribute }); + logger.debug(LogMessage.of(() -> "Roles from search: " + userRoles)); for (Map> role : userRoles) { - authorities.add(authorityMapper.apply(role)); + authorities.add(this.authorityMapper.apply(role)); } - return authorities; } @@ -290,7 +247,6 @@ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator /** * The default role which will be assigned to all users. - * * @param defaultRole the role name, including any desired prefix. */ public void setDefaultRole(String defaultRole) { @@ -320,13 +276,11 @@ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator /** * If set to true, a subtree scope search will be performed. If false a single-level * search is used. - * * @param searchSubtree set to true to enable searching of the entire tree below the * groupSearchBase. */ public void setSearchSubtree(boolean searchSubtree) { - int searchScope = searchSubtree ? SearchControls.SUBTREE_SCOPE - : SearchControls.ONELEVEL_SCOPE; + int searchScope = searchSubtree ? SearchControls.SUBTREE_SCOPE : SearchControls.ONELEVEL_SCOPE; this.searchControls.setSearchScope(searchScope); } @@ -341,9 +295,8 @@ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator } /** - * Sets the mapping function which will be used to create instances of {@link GrantedAuthority} - * given the context record. - * + * Sets the mapping function which will be used to create instances of + * {@link GrantedAuthority} given the context record. * @param authorityMapper the mapping function */ public void setAuthorityMapper(Function>, GrantedAuthority> authorityMapper) { @@ -419,4 +372,5 @@ public class DefaultLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator private SearchControls getSearchControls() { return this.searchControls; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPerson.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPerson.java index dc0969115b..2c70b4f425 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPerson.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPerson.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import org.springframework.ldap.core.DirContextAdapter; @@ -26,131 +27,150 @@ import org.springframework.security.core.SpringSecurityCoreVersion; *

        * The username will be mapped from the uid attribute by default. * - * @author Luke + * @author Luke Taylor */ public class InetOrgPerson extends Person { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private String carLicense; + // Person.cn private String destinationIndicator; + private String departmentNumber; + // Person.description private String displayName; + private String employeeNumber; + private String homePhone; + private String homePostalAddress; + private String initials; + private String mail; + private String mobile; + private String o; + private String ou; + private String postalAddress; + private String postalCode; + private String roomNumber; + private String street; + // Person.sn // Person.telephoneNumber private String title; + private String uid; public String getUid() { - return uid; + return this.uid; } public String getMail() { - return mail; + return this.mail; } public String getEmployeeNumber() { - return employeeNumber; + return this.employeeNumber; } public String getInitials() { - return initials; + return this.initials; } public String getDestinationIndicator() { - return destinationIndicator; + return this.destinationIndicator; } public String getO() { - return o; + return this.o; } public String getOu() { - return ou; + return this.ou; } public String getTitle() { - return title; + return this.title; } public String getCarLicense() { - return carLicense; + return this.carLicense; } public String getDepartmentNumber() { - return departmentNumber; + return this.departmentNumber; } public String getDisplayName() { - return displayName; + return this.displayName; } public String getHomePhone() { - return homePhone; + return this.homePhone; } public String getRoomNumber() { - return roomNumber; + return this.roomNumber; } public String getHomePostalAddress() { - return homePostalAddress; + return this.homePostalAddress; } public String getMobile() { - return mobile; + return this.mobile; } public String getPostalAddress() { - return postalAddress; + return this.postalAddress; } public String getPostalCode() { - return postalCode; + return this.postalCode; } public String getStreet() { - return street; + return this.street; } + @Override protected void populateContext(DirContextAdapter adapter) { super.populateContext(adapter); - adapter.setAttributeValue("carLicense", carLicense); - adapter.setAttributeValue("departmentNumber", departmentNumber); - adapter.setAttributeValue("destinationIndicator", destinationIndicator); - adapter.setAttributeValue("displayName", displayName); - adapter.setAttributeValue("employeeNumber", employeeNumber); - adapter.setAttributeValue("homePhone", homePhone); - adapter.setAttributeValue("homePostalAddress", homePostalAddress); - adapter.setAttributeValue("initials", initials); - adapter.setAttributeValue("mail", mail); - adapter.setAttributeValue("mobile", mobile); - adapter.setAttributeValue("postalAddress", postalAddress); - adapter.setAttributeValue("postalCode", postalCode); - adapter.setAttributeValue("ou", ou); - adapter.setAttributeValue("o", o); - adapter.setAttributeValue("roomNumber", roomNumber); - adapter.setAttributeValue("street", street); - adapter.setAttributeValue("uid", uid); - adapter.setAttributeValues("objectclass", new String[] { "top", "person", - "organizationalPerson", "inetOrgPerson" }); + adapter.setAttributeValue("carLicense", this.carLicense); + adapter.setAttributeValue("departmentNumber", this.departmentNumber); + adapter.setAttributeValue("destinationIndicator", this.destinationIndicator); + adapter.setAttributeValue("displayName", this.displayName); + adapter.setAttributeValue("employeeNumber", this.employeeNumber); + adapter.setAttributeValue("homePhone", this.homePhone); + adapter.setAttributeValue("homePostalAddress", this.homePostalAddress); + adapter.setAttributeValue("initials", this.initials); + adapter.setAttributeValue("mail", this.mail); + adapter.setAttributeValue("mobile", this.mobile); + adapter.setAttributeValue("postalAddress", this.postalAddress); + adapter.setAttributeValue("postalCode", this.postalCode); + adapter.setAttributeValue("ou", this.ou); + adapter.setAttributeValue("o", this.o); + adapter.setAttributeValue("roomNumber", this.roomNumber); + adapter.setAttributeValue("street", this.street); + adapter.setAttributeValue("uid", this.uid); + adapter.setAttributeValues("objectclass", + new String[] { "top", "person", "organizationalPerson", "inetOrgPerson" }); } public static class Essence extends Person.Essence { + public Essence() { } @@ -198,84 +218,87 @@ public class InetOrgPerson extends Person { setUid(ctx.getStringAttribute("uid")); } + @Override protected LdapUserDetailsImpl createTarget() { return new InetOrgPerson(); } public void setMail(String email) { - ((InetOrgPerson) instance).mail = email; + ((InetOrgPerson) this.instance).mail = email; } public void setUid(String uid) { - ((InetOrgPerson) instance).uid = uid; + ((InetOrgPerson) this.instance).uid = uid; - if (instance.getUsername() == null) { + if (this.instance.getUsername() == null) { setUsername(uid); } } public void setInitials(String initials) { - ((InetOrgPerson) instance).initials = initials; + ((InetOrgPerson) this.instance).initials = initials; } public void setO(String organization) { - ((InetOrgPerson) instance).o = organization; + ((InetOrgPerson) this.instance).o = organization; } public void setOu(String ou) { - ((InetOrgPerson) instance).ou = ou; + ((InetOrgPerson) this.instance).ou = ou; } public void setRoomNumber(String no) { - ((InetOrgPerson) instance).roomNumber = no; + ((InetOrgPerson) this.instance).roomNumber = no; } public void setTitle(String title) { - ((InetOrgPerson) instance).title = title; + ((InetOrgPerson) this.instance).title = title; } public void setCarLicense(String carLicense) { - ((InetOrgPerson) instance).carLicense = carLicense; + ((InetOrgPerson) this.instance).carLicense = carLicense; } public void setDepartmentNumber(String departmentNumber) { - ((InetOrgPerson) instance).departmentNumber = departmentNumber; + ((InetOrgPerson) this.instance).departmentNumber = departmentNumber; } public void setDisplayName(String displayName) { - ((InetOrgPerson) instance).displayName = displayName; + ((InetOrgPerson) this.instance).displayName = displayName; } public void setEmployeeNumber(String no) { - ((InetOrgPerson) instance).employeeNumber = no; + ((InetOrgPerson) this.instance).employeeNumber = no; } public void setDestinationIndicator(String destination) { - ((InetOrgPerson) instance).destinationIndicator = destination; + ((InetOrgPerson) this.instance).destinationIndicator = destination; } public void setHomePhone(String homePhone) { - ((InetOrgPerson) instance).homePhone = homePhone; + ((InetOrgPerson) this.instance).homePhone = homePhone; } public void setStreet(String street) { - ((InetOrgPerson) instance).street = street; + ((InetOrgPerson) this.instance).street = street; } public void setPostalCode(String postalCode) { - ((InetOrgPerson) instance).postalCode = postalCode; + ((InetOrgPerson) this.instance).postalCode = postalCode; } public void setPostalAddress(String postalAddress) { - ((InetOrgPerson) instance).postalAddress = postalAddress; + ((InetOrgPerson) this.instance).postalAddress = postalAddress; } public void setMobile(String mobile) { - ((InetOrgPerson) instance).mobile = mobile; + ((InetOrgPerson) this.instance).mobile = mobile; } public void setHomePostalAddress(String homePostalAddress) { - ((InetOrgPerson) instance).homePostalAddress = homePostalAddress; + ((InetOrgPerson) this.instance).homePostalAddress = homePostalAddress; } + } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPersonContextMapper.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPersonContextMapper.java index f75cbfca58..a5770becd8 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPersonContextMapper.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/InetOrgPersonContextMapper.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import java.util.Collection; +import org.springframework.ldap.core.DirContextAdapter; +import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.ldap.core.DirContextOperations; -import org.springframework.ldap.core.DirContextAdapter; import org.springframework.util.Assert; /** @@ -28,22 +29,21 @@ import org.springframework.util.Assert; */ public class InetOrgPersonContextMapper implements UserDetailsContextMapper { + @Override public UserDetails mapUserFromContext(DirContextOperations ctx, String username, Collection authorities) { InetOrgPerson.Essence p = new InetOrgPerson.Essence(ctx); - p.setUsername(username); p.setAuthorities(authorities); - return p.createUserDetails(); } + @Override public void mapUserToContext(UserDetails user, DirContextAdapter ctx) { - Assert.isInstanceOf(InetOrgPerson.class, user, - "UserDetails must be an InetOrgPerson instance"); - + Assert.isInstanceOf(InetOrgPerson.class, user, "UserDetails must be an InetOrgPerson instance"); InetOrgPerson p = (InetOrgPerson) user; p.populateContext(ctx); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthoritiesPopulator.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthoritiesPopulator.java index 0a18f2c36a..629c177ba3 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthoritiesPopulator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthoritiesPopulator.java @@ -18,9 +18,8 @@ package org.springframework.security.ldap.userdetails; import java.util.Collection; -import org.springframework.security.core.GrantedAuthority; - import org.springframework.ldap.core.DirContextOperations; +import org.springframework.security.core.GrantedAuthority; /** * Obtains a list of granted authorities for an Ldap user. @@ -32,17 +31,13 @@ import org.springframework.ldap.core.DirContextOperations; * @author Luke Taylor */ public interface LdapAuthoritiesPopulator { - // ~ Methods - // ======================================================================================================== /** * Get the list of authorities for the user. - * * @param userData the context object which was returned by the LDAP authenticator. - * * @return the granted authorities for the given user. * */ - Collection getGrantedAuthorities( - DirContextOperations userData, String username); + Collection getGrantedAuthorities(DirContextOperations userData, String username); + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthority.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthority.java index 362bf0e874..9e89a5d631 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthority.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapAuthority.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap.userdetails; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.util.Assert; +package org.springframework.security.ldap.userdetails; import java.util.Collections; import java.util.List; import java.util.Map; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.util.Assert; + /** * An authority that contains at least a DN and a role name for an LDAP entry but can also * contain other desired attributes to be fetched during an LDAP authority search. @@ -31,12 +32,13 @@ import java.util.Map; public class LdapAuthority implements GrantedAuthority { private String dn; + private String role; + private Map> attributes; /** * Constructs an LdapAuthority that has a role and a DN but no other attributes - * * @param role * @param dn */ @@ -46,7 +48,6 @@ public class LdapAuthority implements GrantedAuthority { /** * Constructs an LdapAuthority with the given role, DN and other LDAP attributes - * * @param role * @param dn * @param attributes @@ -54,7 +55,6 @@ public class LdapAuthority implements GrantedAuthority { public LdapAuthority(String role, String dn, Map> attributes) { Assert.notNull(role, "role can not be null"); Assert.notNull(dn, "dn can not be null"); - this.role = role; this.dn = dn; this.attributes = attributes; @@ -62,93 +62,77 @@ public class LdapAuthority implements GrantedAuthority { /** * Returns the LDAP attributes - * * @return the LDAP attributes, map can be null */ public Map> getAttributes() { - return attributes; + return this.attributes; } /** * Returns the DN for this LDAP authority - * * @return */ public String getDn() { - return dn; + return this.dn; } /** * Returns the values for a specific attribute - * * @param name the attribute name * @return a String array, never null but may be zero length */ public List getAttributeValues(String name) { List result = null; - if (attributes != null) { - result = attributes.get(name); + if (this.attributes != null) { + result = this.attributes.get(name); } - if (result == null) { - result = Collections.emptyList(); - } - return result; + return (result != null) ? result : Collections.emptyList(); } /** * Returns the first attribute value for a specified attribute - * * @param name * @return the first attribute value for a specified attribute, may be null */ public String getFirstAttributeValue(String name) { List result = getAttributeValues(name); - if (result.isEmpty()) { - return null; - } - else { - return result.get(0); - } + return (!result.isEmpty()) ? result.get(0) : null; } - /** - * {@inheritDoc} - */ @Override public String getAuthority() { - return role; + return this.role; } /** * Compares the LdapAuthority based on {@link #getAuthority()} and {@link #getDn()} - * values {@inheritDoc} + * values. */ @Override - public boolean equals(Object o) { - if (this == o) { + public boolean equals(Object obj) { + if (this == obj) { return true; } - if (!(o instanceof LdapAuthority)) { + if (!(obj instanceof LdapAuthority)) { return false; } - - LdapAuthority that = (LdapAuthority) o; - - if (!dn.equals(that.dn)) { + LdapAuthority other = (LdapAuthority) obj; + if (!this.dn.equals(other.dn)) { return false; } - return role.equals(that.role); + return this.role.equals(other.role); } @Override public int hashCode() { - int result = dn.hashCode(); - result = 31 * result + (role != null ? role.hashCode() : 0); + int result = this.dn.hashCode(); + result = 31 * result + ((this.role != null) ? this.role.hashCode() : 0); return result; } @Override public String toString() { - return "LdapAuthority{" + "dn='" + dn + '\'' + ", role='" + role + '\'' + '}'; + return "LdapAuthority{" + "dn='" + this.dn + '\'' + ", role='" + this.role + '\'' + '}'; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetails.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetails.java index 4729184fca..71c7954baa 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetails.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetails.java @@ -25,13 +25,11 @@ import org.springframework.security.core.userdetails.UserDetails; * @author Luke Taylor */ public interface LdapUserDetails extends UserDetails, CredentialsContainer { - // ~ Methods - // ======================================================================================================== /** * The DN of the entry for this user's account. - * * @return the user's DN */ String getDn(); + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImpl.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImpl.java index 831aa3c1ac..29a7323e3d 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImpl.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImpl.java @@ -51,115 +51,112 @@ public class LdapUserDetailsImpl implements LdapUserDetails, PasswordPolicyData private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private String dn; + private String password; + private String username; + private Collection authorities = AuthorityUtils.NO_AUTHORITIES; + private boolean accountNonExpired = true; + private boolean accountNonLocked = true; + private boolean credentialsNonExpired = true; + private boolean enabled = true; + // PPolicy data private int timeBeforeExpiration = Integer.MAX_VALUE; - private int graceLoginsRemaining = Integer.MAX_VALUE; - // ~ Constructors - // =================================================================================================== + private int graceLoginsRemaining = Integer.MAX_VALUE; protected LdapUserDetailsImpl() { } - // ~ Methods - // ======================================================================================================== - @Override public Collection getAuthorities() { - return authorities; + return this.authorities; } @Override public String getDn() { - return dn; + return this.dn; } @Override public String getPassword() { - return password; + return this.password; } @Override public String getUsername() { - return username; + return this.username; } @Override public boolean isAccountNonExpired() { - return accountNonExpired; + return this.accountNonExpired; } @Override public boolean isAccountNonLocked() { - return accountNonLocked; + return this.accountNonLocked; } @Override public boolean isCredentialsNonExpired() { - return credentialsNonExpired; + return this.credentialsNonExpired; } @Override public boolean isEnabled() { - return enabled; + return this.enabled; } @Override public void eraseCredentials() { - password = null; + this.password = null; } @Override public int getTimeBeforeExpiration() { - return timeBeforeExpiration; + return this.timeBeforeExpiration; } @Override public int getGraceLoginsRemaining() { - return graceLoginsRemaining; + return this.graceLoginsRemaining; } @Override public boolean equals(Object obj) { if (obj instanceof LdapUserDetailsImpl) { - return dn.equals(((LdapUserDetailsImpl) obj).dn); + return this.dn.equals(((LdapUserDetailsImpl) obj).dn); } return false; } @Override public int hashCode() { - return dn.hashCode(); + return this.dn.hashCode(); } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()).append(": "); - sb.append("Dn: ").append(dn).append("; "); + sb.append("Dn: ").append(this.dn).append("; "); sb.append("Username: ").append(this.username).append("; "); sb.append("Password: [PROTECTED]; "); sb.append("Enabled: ").append(this.enabled).append("; "); sb.append("AccountNonExpired: ").append(this.accountNonExpired).append("; "); - sb.append("CredentialsNonExpired: ").append(this.credentialsNonExpired) - .append("; "); + sb.append("CredentialsNonExpired: ").append(this.credentialsNonExpired).append("; "); sb.append("AccountNonLocked: ").append(this.accountNonLocked).append("; "); - if (this.getAuthorities() != null && !this.getAuthorities().isEmpty()) { sb.append("Granted Authorities: "); boolean first = true; - for (Object authority : this.getAuthorities()) { if (first) { first = false; @@ -167,25 +164,22 @@ public class LdapUserDetailsImpl implements LdapUserDetails, PasswordPolicyData else { sb.append(", "); } - sb.append(authority.toString()); } } else { sb.append("Not granted any authorities"); } - return sb.toString(); } - // ~ Inner Classes - // ================================================================================================== - /** * Variation of essence pattern. Used to create mutable intermediate object */ public static class Essence { + protected LdapUserDetailsImpl instance = createTarget(); + private List mutableAuthorities = new ArrayList<>(); public Essence() { @@ -216,12 +210,12 @@ public class LdapUserDetailsImpl implements LdapUserDetails, PasswordPolicyData */ public void addAuthority(GrantedAuthority a) { if (!hasAuthority(a)) { - mutableAuthorities.add(a); + this.mutableAuthorities.add(a); } } private boolean hasAuthority(GrantedAuthority a) { - for (GrantedAuthority authority : mutableAuthorities) { + for (GrantedAuthority authority : this.mutableAuthorities) { if (authority.equals(a)) { return true; } @@ -230,67 +224,64 @@ public class LdapUserDetailsImpl implements LdapUserDetails, PasswordPolicyData } public LdapUserDetails createUserDetails() { - Assert.notNull(instance, - "Essence can only be used to create a single instance"); - Assert.notNull(instance.username, "username must not be null"); - Assert.notNull(instance.getDn(), "Distinguished name must not be null"); - - instance.authorities = Collections.unmodifiableList(mutableAuthorities); - - LdapUserDetails newInstance = instance; - - instance = null; - + Assert.notNull(this.instance, "Essence can only be used to create a single instance"); + Assert.notNull(this.instance.username, "username must not be null"); + Assert.notNull(this.instance.getDn(), "Distinguished name must not be null"); + this.instance.authorities = Collections.unmodifiableList(this.mutableAuthorities); + LdapUserDetails newInstance = this.instance; + this.instance = null; return newInstance; } public Collection getGrantedAuthorities() { - return mutableAuthorities; + return this.mutableAuthorities; } public void setAccountNonExpired(boolean accountNonExpired) { - instance.accountNonExpired = accountNonExpired; + this.instance.accountNonExpired = accountNonExpired; } public void setAccountNonLocked(boolean accountNonLocked) { - instance.accountNonLocked = accountNonLocked; + this.instance.accountNonLocked = accountNonLocked; } public void setAuthorities(Collection authorities) { - mutableAuthorities = new ArrayList<>(); - mutableAuthorities.addAll(authorities); + this.mutableAuthorities = new ArrayList<>(); + this.mutableAuthorities.addAll(authorities); } public void setCredentialsNonExpired(boolean credentialsNonExpired) { - instance.credentialsNonExpired = credentialsNonExpired; + this.instance.credentialsNonExpired = credentialsNonExpired; } public void setDn(String dn) { - instance.dn = dn; + this.instance.dn = dn; } public void setDn(Name dn) { - instance.dn = dn.toString(); + this.instance.dn = dn.toString(); } public void setEnabled(boolean enabled) { - instance.enabled = enabled; + this.instance.enabled = enabled; } public void setPassword(String password) { - instance.password = password; + this.instance.password = password; } public void setUsername(String username) { - instance.username = username; + this.instance.username = username; } public void setTimeBeforeExpiration(int timeBeforeExpiration) { - instance.timeBeforeExpiration = timeBeforeExpiration; + this.instance.timeBeforeExpiration = timeBeforeExpiration; } public void setGraceLoginsRemaining(int graceLoginsRemaining) { - instance.graceLoginsRemaining = graceLoginsRemaining; + this.instance.graceLoginsRemaining = graceLoginsRemaining; } + } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManager.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManager.java index af8f214591..b2baff26a3 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManager.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import java.io.ByteArrayOutputStream; @@ -22,6 +23,7 @@ import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.ListIterator; + import javax.naming.Context; import javax.naming.NameNotFoundException; import javax.naming.NamingEnumeration; @@ -38,6 +40,7 @@ import javax.naming.ldap.LdapContext; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.ldap.core.AttributesMapper; import org.springframework.ldap.core.AttributesMapperCallbackHandler; import org.springframework.ldap.core.ContextExecutor; @@ -76,14 +79,14 @@ import org.springframework.util.Assert; * @since 2.0 */ public class LdapUserDetailsManager implements UserDetailsManager { + private final Log logger = LogFactory.getLog(LdapUserDetailsManager.class); /** * The strategy for mapping usernames to LDAP distinguished names. This will be used * when building DNs for creating new users etc. */ - LdapUsernameToDnMapper usernameMapper = new DefaultLdapUsernameToDnMapper("cn=users", - "uid"); + LdapUsernameToDnMapper usernameMapper = new DefaultLdapUsernameToDnMapper("cn=users", "uid"); /** The DN under which groups are stored */ private DistinguishedName groupSearchBase = new DistinguishedName("cn=groups"); @@ -93,6 +96,7 @@ public class LdapUserDetailsManager implements UserDetailsManager { /** The attribute which corresponds to the role name of a group. */ private String groupRoleAttributeName = "cn"; + /** The attribute which contains members of a group */ private String groupMemberAttributeName = "uniquemember"; @@ -100,6 +104,7 @@ public class LdapUserDetailsManager implements UserDetailsManager { /** The pattern to be used for the user search. {0} is the user's DN */ private String groupSearchFilter = "(uniquemember={0})"; + /** * The strategy used to create a UserDetails object from the LDAP context, username * and list of authorities. This should be set to match the required UserDetails @@ -110,15 +115,12 @@ public class LdapUserDetailsManager implements UserDetailsManager { private final LdapTemplate template; /** Default context mapper used to create a set of roles from a list of attributes */ - private AttributesMapper roleMapper = attributes -> { - Attribute roleAttr = attributes.get(groupRoleAttributeName); - + private AttributesMapper roleMapper = (attributes) -> { + Attribute roleAttr = attributes.get(this.groupRoleAttributeName); NamingEnumeration ne = roleAttr.getAll(); - // assert ne.hasMore(); Object group = ne.next(); String role = group.toString(); - - return new SimpleGrantedAuthority(rolePrefix + role.toUpperCase()); + return new SimpleGrantedAuthority(this.rolePrefix + role.toUpperCase()); }; private String[] attributesToRetrieve; @@ -126,30 +128,26 @@ public class LdapUserDetailsManager implements UserDetailsManager { private boolean usePasswordModifyExtensionOperation = false; public LdapUserDetailsManager(ContextSource contextSource) { - template = new LdapTemplate(contextSource); + this.template = new LdapTemplate(contextSource); } + @Override public UserDetails loadUserByUsername(String username) { - DistinguishedName dn = usernameMapper.buildDn(username); + DistinguishedName dn = this.usernameMapper.buildDn(username); List authorities = getUserAuthorities(dn, username); - - logger.debug("Loading user '" + username + "' with DN '" + dn + "'"); - + this.logger.debug(LogMessage.format("Loading user '%s' with DN '%s'", username, dn)); DirContextAdapter userCtx = loadUserAsContext(dn, username); - - return userDetailsMapper.mapUserFromContext(userCtx, username, authorities); + return this.userDetailsMapper.mapUserFromContext(userCtx, username, authorities); } - private DirContextAdapter loadUserAsContext(final DistinguishedName dn, - final String username) { - return (DirContextAdapter) template.executeReadOnly((ContextExecutor) ctx -> { + private DirContextAdapter loadUserAsContext(final DistinguishedName dn, final String username) { + return (DirContextAdapter) this.template.executeReadOnly((ContextExecutor) (ctx) -> { try { - Attributes attrs = ctx.getAttributes(dn, attributesToRetrieve); + Attributes attrs = ctx.getAttributes(dn, this.attributesToRetrieve); return new DirContextAdapter(attrs, LdapUtils.getFullDn(dn, ctx)); } - catch (NameNotFoundException notFound) { - throw new UsernameNotFoundException( - "User " + username + " not found", notFound); + catch (NameNotFoundException ex) { + throw new UsernameNotFoundException("User " + username + " not found", ex); } }); } @@ -163,106 +161,85 @@ public class LdapUserDetailsManager implements UserDetailsManager { * *

        * Configured one way, this method will modify the user's password via the - * - * LDAP Password Modify Extended Operation - * . + * LDAP Password Modify + * Extended Operation . * - * See {@link LdapUserDetailsManager#setUsePasswordModifyExtensionOperation(boolean)} for details. + * See {@link LdapUserDetailsManager#setUsePasswordModifyExtensionOperation(boolean)} + * for details. *

        * *

        - * By default, though, if the old password is supplied, the update will be made by rebinding as the user, - * thus modifying the password using the user's permissions. If + * By default, though, if the old password is supplied, the update will be made by + * rebinding as the user, thus modifying the password using the user's permissions. If * oldPassword is null, the update will be attempted using a standard * read/write context supplied by the context source. *

        - * * @param oldPassword the old password * @param newPassword the new value of the password. */ + @Override public void changePassword(final String oldPassword, final String newPassword) { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); - Assert.notNull( - authentication, + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + Assert.notNull(authentication, "No authentication object found in security context. Can't change current user's password!"); - String username = authentication.getName(); - - logger.debug("Changing password for user '" + username); - - DistinguishedName userDn = usernameMapper.buildDn(username); - - if (usePasswordModifyExtensionOperation) { + this.logger.debug(LogMessage.format("Changing password for user '%s'", username)); + DistinguishedName userDn = this.usernameMapper.buildDn(username); + if (this.usePasswordModifyExtensionOperation) { changePasswordUsingExtensionOperation(userDn, oldPassword, newPassword); - } else { + } + else { changePasswordUsingAttributeModification(userDn, oldPassword, newPassword); } } /** - * * @param dn the distinguished name of the entry - may be either relative to the base * context or a complete DN including the name of the context (either is supported). * @param username the user whose roles are required. * @return the granted authorities returned by the group search */ @SuppressWarnings("unchecked") - List getUserAuthorities(final DistinguishedName dn, - final String username) { - SearchExecutor se = ctx -> { + List getUserAuthorities(final DistinguishedName dn, final String username) { + SearchExecutor se = (ctx) -> { DistinguishedName fullDn = LdapUtils.getFullDn(dn, ctx); SearchControls ctrls = new SearchControls(); - ctrls.setReturningAttributes(new String[] { groupRoleAttributeName }); - - return ctx.search(groupSearchBase, groupSearchFilter, new String[] { - fullDn.toUrl(), username }, ctrls); + ctrls.setReturningAttributes(new String[] { this.groupRoleAttributeName }); + return ctx.search(this.groupSearchBase, this.groupSearchFilter, new String[] { fullDn.toUrl(), username }, + ctrls); }; - - AttributesMapperCallbackHandler roleCollector = new AttributesMapperCallbackHandler( - roleMapper); - - template.search(se, roleCollector); + AttributesMapperCallbackHandler roleCollector = new AttributesMapperCallbackHandler(this.roleMapper); + this.template.search(se, roleCollector); return roleCollector.getList(); } + @Override public void createUser(UserDetails user) { DirContextAdapter ctx = new DirContextAdapter(); copyToContext(user, ctx); - DistinguishedName dn = usernameMapper.buildDn(user.getUsername()); - - logger.debug("Creating new user '" + user.getUsername() + "' with DN '" + dn - + "'"); - - template.bind(dn, ctx, null); - - // Check for any existing authorities which might be set for this DN and remove - // them + DistinguishedName dn = this.usernameMapper.buildDn(user.getUsername()); + this.logger.debug(LogMessage.format("Creating new user '%s' with DN '%s'", user.getUsername(), dn)); + this.template.bind(dn, ctx, null); + // Check for any existing authorities which might be set for this + // DN and remove them List authorities = getUserAuthorities(dn, user.getUsername()); - if (authorities.size() > 0) { removeAuthorities(dn, authorities); } - addAuthorities(dn, user.getAuthorities()); } + @Override public void updateUser(UserDetails user) { - DistinguishedName dn = usernameMapper.buildDn(user.getUsername()); - - logger.debug("Updating user '" + user.getUsername() + "' with DN '" + dn + "'"); - + DistinguishedName dn = this.usernameMapper.buildDn(user.getUsername()); + this.logger.debug(LogMessage.format("Updating new user '%s' with DN '%s'", user.getUsername(), dn)); List authorities = getUserAuthorities(dn, user.getUsername()); - DirContextAdapter ctx = loadUserAsContext(dn, user.getUsername()); ctx.setUpdateMode(true); copyToContext(user, ctx); - // Remove the objectclass attribute from the list of mods (if present). - List mods = new LinkedList<>(Arrays.asList(ctx - .getModificationItems())); + List mods = new LinkedList<>(Arrays.asList(ctx.getModificationItems())); ListIterator modIt = mods.listIterator(); - while (modIt.hasNext()) { ModificationItem mod = modIt.next(); Attribute a = mod.getAttribute(); @@ -270,74 +247,67 @@ public class LdapUserDetailsManager implements UserDetailsManager { modIt.remove(); } } - - template.modifyAttributes(dn, mods.toArray(new ModificationItem[0])); - + this.template.modifyAttributes(dn, mods.toArray(new ModificationItem[0])); // template.rebind(dn, ctx, null); // Remove the old authorities and replace them with the new one removeAuthorities(dn, authorities); addAuthorities(dn, user.getAuthorities()); } + @Override public void deleteUser(String username) { - DistinguishedName dn = usernameMapper.buildDn(username); + DistinguishedName dn = this.usernameMapper.buildDn(username); removeAuthorities(dn, getUserAuthorities(dn, username)); - template.unbind(dn); + this.template.unbind(dn); } + @Override public boolean userExists(String username) { - DistinguishedName dn = usernameMapper.buildDn(username); - + DistinguishedName dn = this.usernameMapper.buildDn(username); try { - Object obj = template.lookup(dn); + Object obj = this.template.lookup(dn); if (obj instanceof Context) { LdapUtils.closeContext((Context) obj); } return true; } - catch (org.springframework.ldap.NameNotFoundException e) { + catch (org.springframework.ldap.NameNotFoundException ex) { return false; } } /** * Creates a DN from a group name. - * * @param group the name of the group * @return the DN of the corresponding group, including the groupSearchBase */ protected DistinguishedName buildGroupDn(String group) { - DistinguishedName dn = new DistinguishedName(groupSearchBase); - dn.add(groupRoleAttributeName, group.toLowerCase()); - + DistinguishedName dn = new DistinguishedName(this.groupSearchBase); + dn.add(this.groupRoleAttributeName, group.toLowerCase()); return dn; } protected void copyToContext(UserDetails user, DirContextAdapter ctx) { - userDetailsMapper.mapUserToContext(user, ctx); + this.userDetailsMapper.mapUserToContext(user, ctx); } - protected void addAuthorities(DistinguishedName userDn, - Collection authorities) { + protected void addAuthorities(DistinguishedName userDn, Collection authorities) { modifyAuthorities(userDn, authorities, DirContext.ADD_ATTRIBUTE); } - protected void removeAuthorities(DistinguishedName userDn, - Collection authorities) { + protected void removeAuthorities(DistinguishedName userDn, Collection authorities) { modifyAuthorities(userDn, authorities, DirContext.REMOVE_ATTRIBUTE); } private void modifyAuthorities(final DistinguishedName userDn, final Collection authorities, final int modType) { - template.executeReadWrite((ContextExecutor) ctx -> { + this.template.executeReadWrite((ContextExecutor) (ctx) -> { for (GrantedAuthority authority : authorities) { String group = convertAuthorityToGroup(authority); DistinguishedName fullDn = LdapUtils.getFullDn(userDn, ctx); ModificationItem addGroup = new ModificationItem(modType, - new BasicAttribute(groupMemberAttributeName, fullDn.toUrl())); - - ctx.modifyAttributes(buildGroupDn(group), - new ModificationItem[] { addGroup }); + new BasicAttribute(this.groupMemberAttributeName, fullDn.toUrl())); + ctx.modifyAttributes(buildGroupDn(group), new ModificationItem[] { addGroup }); } return null; }); @@ -345,11 +315,9 @@ public class LdapUserDetailsManager implements UserDetailsManager { private String convertAuthorityToGroup(GrantedAuthority authority) { String group = authority.getAuthority(); - - if (group.startsWith(rolePrefix)) { - group = group.substring(rolePrefix.length()); + if (group.startsWith(this.rolePrefix)) { + group = group.substring(this.rolePrefix.length()); } - return group; } @@ -384,7 +352,6 @@ public class LdapUserDetailsManager implements UserDetailsManager { *

        * Usually this will be uniquemember (the default value) or member. *

        - * * @param groupMemberAttributeName the name of the attribute used to store group * members. */ @@ -401,17 +368,19 @@ public class LdapUserDetailsManager implements UserDetailsManager { /** * Sets the method by which a user's password gets modified. * - * If set to {@code true}, then {@link LdapUserDetailsManager#changePassword} will modify - * the user's password by way of the - * Password Modify Extension Operation. + * If set to {@code true}, then {@link LdapUserDetailsManager#changePassword} will + * modify the user's password by way of the + * Password Modify + * Extension Operation. * - * If set to {@code false}, then {@link LdapUserDetailsManager#changePassword} will modify - * the user's password by directly modifying attributes on the corresponding entry. + * If set to {@code false}, then {@link LdapUserDetailsManager#changePassword} will + * modify the user's password by directly modifying attributes on the corresponding + * entry. * - * Before using this setting, ensure that the corresponding LDAP server supports this extended operation. + * Before using this setting, ensure that the corresponding LDAP server supports this + * extended operation. * * By default, {@code usePasswordModifyExtensionOperation} is false. - * * @param usePasswordModifyExtensionOperation * @since 4.2.9 */ @@ -419,95 +388,82 @@ public class LdapUserDetailsManager implements UserDetailsManager { this.usePasswordModifyExtensionOperation = usePasswordModifyExtensionOperation; } - private void changePasswordUsingAttributeModification - (DistinguishedName userDn, String oldPassword, String newPassword) { - - final ModificationItem[] passwordChange = new ModificationItem[] { new ModificationItem( - DirContext.REPLACE_ATTRIBUTE, new BasicAttribute(passwordAttributeName, - newPassword)) }; - + private void changePasswordUsingAttributeModification(DistinguishedName userDn, String oldPassword, + String newPassword) { + ModificationItem[] passwordChange = new ModificationItem[] { new ModificationItem(DirContext.REPLACE_ATTRIBUTE, + new BasicAttribute(this.passwordAttributeName, newPassword)) }; if (oldPassword == null) { - template.modifyAttributes(userDn, passwordChange); + this.template.modifyAttributes(userDn, passwordChange); return; } - - template.executeReadWrite(dirCtx -> { + this.template.executeReadWrite((dirCtx) -> { LdapContext ctx = (LdapContext) dirCtx; ctx.removeFromEnvironment("com.sun.jndi.ldap.connect.pool"); - ctx.addToEnvironment(Context.SECURITY_PRINCIPAL, - LdapUtils.getFullDn(userDn, ctx).toString()); + ctx.addToEnvironment(Context.SECURITY_PRINCIPAL, LdapUtils.getFullDn(userDn, ctx).toString()); ctx.addToEnvironment(Context.SECURITY_CREDENTIALS, oldPassword); // TODO: reconnect doesn't appear to actually change the credentials try { ctx.reconnect(null); - } catch (javax.naming.AuthenticationException e) { - throw new BadCredentialsException( - "Authentication for password change failed."); } - + catch (javax.naming.AuthenticationException ex) { + throw new BadCredentialsException("Authentication for password change failed."); + } ctx.modifyAttributes(userDn, passwordChange); - return null; }); - } - private void changePasswordUsingExtensionOperation - (DistinguishedName userDn, String oldPassword, String newPassword) { - - template.executeReadWrite(dirCtx -> { + private void changePasswordUsingExtensionOperation(DistinguishedName userDn, String oldPassword, + String newPassword) { + this.template.executeReadWrite((dirCtx) -> { LdapContext ctx = (LdapContext) dirCtx; - String userIdentity = LdapUtils.getFullDn(userDn, ctx).encode(); - PasswordModifyRequest request = - new PasswordModifyRequest(userIdentity, oldPassword, newPassword); - + PasswordModifyRequest request = new PasswordModifyRequest(userIdentity, oldPassword, newPassword); try { return ctx.extendedOperation(request); - } catch (javax.naming.AuthenticationException e) { - throw new BadCredentialsException( - "Authentication for password change failed."); + } + catch (javax.naming.AuthenticationException ex) { + throw new BadCredentialsException("Authentication for password change failed."); } }); } /** * An implementation of the - * - * LDAP Password Modify Extended Operation - * - * client request. + * LDAP Password Modify + * Extended Operation client request. * - * Can be directed at any LDAP server that supports the Password Modify Extended Operation. + * Can be directed at any LDAP server that supports the Password Modify Extended + * Operation. * * @author Josh Cummings * @since 4.2.9 */ private static class PasswordModifyRequest implements ExtendedRequest { + private static final byte SEQUENCE_TYPE = 48; private static final String PASSWORD_MODIFY_OID = "1.3.6.1.4.1.4203.1.11.1"; + private static final byte USER_IDENTITY_OCTET_TYPE = -128; + private static final byte OLD_PASSWORD_OCTET_TYPE = -127; + private static final byte NEW_PASSWORD_OCTET_TYPE = -126; private final ByteArrayOutputStream value = new ByteArrayOutputStream(); PasswordModifyRequest(String userIdentity, String oldPassword, String newPassword) { ByteArrayOutputStream elements = new ByteArrayOutputStream(); - if (userIdentity != null) { berEncode(USER_IDENTITY_OCTET_TYPE, userIdentity.getBytes(), elements); } - if (oldPassword != null) { berEncode(OLD_PASSWORD_OCTET_TYPE, oldPassword.getBytes(), elements); } - if (newPassword != null) { berEncode(NEW_PASSWORD_OCTET_TYPE, newPassword.getBytes(), elements); } - berEncode(SEQUENCE_TYPE, elements.toByteArray(), this.value); } @@ -527,44 +483,47 @@ public class LdapUserDetailsManager implements UserDetailsManager { } /** - * Only minimal support for - * - * BER encoding - * ; just what is necessary for the Password Modify request. + * Only minimal support for BER + * encoding ; just what is necessary for the Password Modify request. * */ private void berEncode(byte type, byte[] src, ByteArrayOutputStream dest) { int length = src.length; - dest.write(type); - if (length < 128) { dest.write(length); - } else if ((length & 0x0000_00FF) == length) { + } + else if ((length & 0x0000_00FF) == length) { dest.write((byte) 0x81); dest.write((byte) (length & 0xFF)); - } else if ((length & 0x0000_FFFF) == length) { + } + else if ((length & 0x0000_FFFF) == length) { dest.write((byte) 0x82); dest.write((byte) ((length >> 8) & 0xFF)); dest.write((byte) (length & 0xFF)); - } else if ((length & 0x00FF_FFFF) == length) { + } + else if ((length & 0x00FF_FFFF) == length) { dest.write((byte) 0x83); dest.write((byte) ((length >> 16) & 0xFF)); dest.write((byte) ((length >> 8) & 0xFF)); dest.write((byte) (length & 0xFF)); - } else { + } + else { dest.write((byte) 0x84); dest.write((byte) ((length >> 24) & 0xFF)); dest.write((byte) ((length >> 16) & 0xFF)); dest.write((byte) ((length >> 8) & 0xFF)); dest.write((byte) (length & 0xFF)); } - try { dest.write(src); - } catch (IOException e) { + } + catch (IOException ex) { throw new IllegalArgumentException("Failed to BER encode provided value of type: " + type); } } + } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapper.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapper.java index 89df68fca1..56e724a075 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapper.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapper.java @@ -21,6 +21,7 @@ import java.util.Collection; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.core.GrantedAuthority; @@ -38,97 +39,75 @@ import org.springframework.util.Assert; * @author Eddú Meléndez */ public class LdapUserDetailsMapper implements UserDetailsContextMapper { - // ~ Instance fields - // ================================================================================================ private final Log logger = LogFactory.getLog(LdapUserDetailsMapper.class); - private String passwordAttributeName = "userPassword"; - private String rolePrefix = "ROLE_"; - private String[] roleAttributes = null; - private boolean convertToUpperCase = true; - // ~ Methods - // ======================================================================================================== + private String passwordAttributeName = "userPassword"; + + private String rolePrefix = "ROLE_"; + + private String[] roleAttributes = null; + + private boolean convertToUpperCase = true; @Override public UserDetails mapUserFromContext(DirContextOperations ctx, String username, Collection authorities) { String dn = ctx.getNameInNamespace(); - - this.logger.debug("Mapping user details from context with DN: " + dn); - + this.logger.debug(LogMessage.format("Mapping user details from context with DN: %s", dn)); LdapUserDetailsImpl.Essence essence = new LdapUserDetailsImpl.Essence(); essence.setDn(dn); - Object passwordValue = ctx.getObjectAttribute(this.passwordAttributeName); - if (passwordValue != null) { essence.setPassword(mapPassword(passwordValue)); } - essence.setUsername(username); - // Map the roles - for (int i = 0; (this.roleAttributes != null) - && (i < this.roleAttributes.length); i++) { + for (int i = 0; (this.roleAttributes != null) && (i < this.roleAttributes.length); i++) { String[] rolesForAttribute = ctx.getStringAttributes(this.roleAttributes[i]); - if (rolesForAttribute == null) { - this.logger.debug("Couldn't read role attribute '" - + this.roleAttributes[i] + "' for user " + dn); + this.logger.debug( + LogMessage.format("Couldn't read role attribute '%s' for user $s", this.roleAttributes[i], dn)); continue; } - for (String role : rolesForAttribute) { GrantedAuthority authority = createAuthority(role); - if (authority != null) { essence.addAuthority(authority); } } } - // Add the supplied authorities - for (GrantedAuthority authority : authorities) { essence.addAuthority(authority); } - // Check for PPolicy data - PasswordPolicyResponseControl ppolicy = (PasswordPolicyResponseControl) ctx .getObjectAttribute(PasswordPolicyControl.OID); - if (ppolicy != null) { essence.setTimeBeforeExpiration(ppolicy.getTimeBeforeExpiration()); essence.setGraceLoginsRemaining(ppolicy.getGraceLoginsRemaining()); } - return essence.createUserDetails(); - } @Override public void mapUserToContext(UserDetails user, DirContextAdapter ctx) { - throw new UnsupportedOperationException( - "LdapUserDetailsMapper only supports reading from a context. Please" - + " use a subclass if mapUserToContext() is required."); + throw new UnsupportedOperationException("LdapUserDetailsMapper only supports reading from a context. Please" + + " use a subclass if mapUserToContext() is required."); } /** * Extension point to allow customized creation of the user's password from the * attribute stored in the directory. - * * @param passwordValue the value of the password attribute * @return a String representation of the password. */ protected String mapPassword(Object passwordValue) { - if (!(passwordValue instanceof String)) { // Assume it's binary passwordValue = new String((byte[]) passwordValue); } - return (String) passwordValue; } @@ -141,7 +120,6 @@ public class LdapUserDetailsMapper implements UserDetailsContextMapper { * rolePrefix and convertToUpperCase properties. Non-String * attributes are ignored. *

        - * * @param role the attribute returned from * @return the authority to be added to the list of authorities for the user, or null * if this attribute should be ignored. @@ -159,7 +137,6 @@ public class LdapUserDetailsMapper implements UserDetailsContextMapper { /** * Determines whether role field values will be converted to upper case when loaded. * The default is true. - * * @param convertToUpperCase true if the roles should be converted to upper case. */ public void setConvertToUpperCase(boolean convertToUpperCase) { @@ -169,7 +146,6 @@ public class LdapUserDetailsMapper implements UserDetailsContextMapper { /** * The name of the attribute which contains the user's password. Defaults to * "userPassword". - * * @param passwordAttributeName the name of the attribute */ public void setPasswordAttributeName(String passwordAttributeName) { @@ -180,7 +156,6 @@ public class LdapUserDetailsMapper implements UserDetailsContextMapper { * The names of any attributes in the user's entry which represent application roles. * These will be converted to GrantedAuthoritys and added to the list in the * returned LdapUserDetails object. The attribute values must be Strings by default. - * * @param roleAttributes the names of the role attributes. */ public void setRoleAttributes(String[] roleAttributes) { @@ -195,4 +170,5 @@ public class LdapUserDetailsMapper implements UserDetailsContextMapper { public void setRolePrefix(String rolePrefix) { this.rolePrefix = rolePrefix; } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsService.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsService.java index 8a91097353..94a592c68e 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsService.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/LdapUserDetailsService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import java.util.Collection; @@ -35,28 +36,29 @@ import org.springframework.util.Assert; * @author Luke Taylor */ public class LdapUserDetailsService implements UserDetailsService { + private final LdapUserSearch userSearch; + private final LdapAuthoritiesPopulator authoritiesPopulator; + private UserDetailsContextMapper userDetailsMapper = new LdapUserDetailsMapper(); public LdapUserDetailsService(LdapUserSearch userSearch) { this(userSearch, new NullLdapAuthoritiesPopulator()); } - public LdapUserDetailsService(LdapUserSearch userSearch, - LdapAuthoritiesPopulator authoritiesPopulator) { + public LdapUserDetailsService(LdapUserSearch userSearch, LdapAuthoritiesPopulator authoritiesPopulator) { Assert.notNull(userSearch, "userSearch must not be null"); Assert.notNull(authoritiesPopulator, "authoritiesPopulator must not be null"); this.userSearch = userSearch; this.authoritiesPopulator = authoritiesPopulator; } - public UserDetails loadUserByUsername(String username) - throws UsernameNotFoundException { - DirContextOperations userData = userSearch.searchForUser(username); - - return userDetailsMapper.mapUserFromContext(userData, username, - authoritiesPopulator.getGrantedAuthorities(userData, username)); + @Override + public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException { + DirContextOperations userData = this.userSearch.searchForUser(username); + return this.userDetailsMapper.mapUserFromContext(userData, username, + this.authoritiesPopulator.getGrantedAuthorities(userData, username)); } public void setUserDetailsMapper(UserDetailsContextMapper userDetailsMapper) { @@ -64,11 +66,13 @@ public class LdapUserDetailsService implements UserDetailsService { this.userDetailsMapper = userDetailsMapper; } - private static final class NullLdapAuthoritiesPopulator implements - LdapAuthoritiesPopulator { - public Collection getGrantedAuthorities( - DirContextOperations userDetails, String username) { + private static final class NullLdapAuthoritiesPopulator implements LdapAuthoritiesPopulator { + + @Override + public Collection getGrantedAuthorities(DirContextOperations userDetails, String username) { return AuthorityUtils.NO_AUTHORITIES; } + } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulator.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulator.java index 07e37b4cef..33d55d7c5b 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulator.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/NestedLdapAuthoritiesPopulator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import java.util.HashSet; @@ -23,6 +24,7 @@ import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.ldap.core.ContextSource; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.ldap.SpringSecurityLdapTemplate; @@ -119,8 +121,8 @@ import org.springframework.util.StringUtils; */ public class NestedLdapAuthoritiesPopulator extends DefaultLdapAuthoritiesPopulator { - private static final Log logger = LogFactory - .getLog(NestedLdapAuthoritiesPopulator.class); + + private static final Log logger = LogFactory.getLog(NestedLdapAuthoritiesPopulator.class); /** * The attribute names to retrieve for each LDAP group @@ -135,78 +137,52 @@ public class NestedLdapAuthoritiesPopulator extends DefaultLdapAuthoritiesPopula /** * Constructor for group search scenarios. userRoleAttributes may still be * set as a property. - * * @param contextSource supplies the contexts used to search for user roles. * @param groupSearchBase if this is an empty string the search will be performed from * the root DN of the */ - public NestedLdapAuthoritiesPopulator(ContextSource contextSource, - String groupSearchBase) { + public NestedLdapAuthoritiesPopulator(ContextSource contextSource, String groupSearchBase) { super(contextSource, groupSearchBase); } - /** - * {@inheritDoc} - */ @Override public Set getGroupMembershipRoles(String userDn, String username) { if (getGroupSearchBase() == null) { return new HashSet<>(); } - Set authorities = new HashSet<>(); - performNestedSearch(userDn, username, authorities, getMaxSearchDepth()); - return authorities; } /** * Performs the nested group search - * * @param userDn - the userDN to search for, will become the group DN for subsequent * searches * @param username - the username of the user * @param authorities - the authorities set that will be populated, must not be null * @param depth - the depth remaining, when 0 recursion will end */ - private void performNestedSearch(String userDn, String username, - Set authorities, int depth) { + private void performNestedSearch(String userDn, String username, Set authorities, int depth) { if (depth == 0) { // back out of recursion - if (logger.isDebugEnabled()) { - logger.debug("Search aborted, max depth reached," - + " for roles for user '" + username + "', DN = " + "'" + userDn - + "', with filter " + getGroupSearchFilter() + " in search base '" - + getGroupSearchBase() + "'"); - } + logger.debug(LogMessage.of(() -> "Search aborted, max depth reached," + " for roles for user '" + username + + "', DN = " + "'" + userDn + "', with filter " + getGroupSearchFilter() + " in search base '" + + getGroupSearchBase() + "'")); return; } - - if (logger.isDebugEnabled()) { - logger.debug("Searching for roles for user '" + username + "', DN = " + "'" - + userDn + "', with filter " + getGroupSearchFilter() - + " in search base '" + getGroupSearchBase() + "'"); - } - + logger.debug(LogMessage.of(() -> "Searching for roles for user '" + username + "', DN = " + "'" + userDn + + "', with filter " + getGroupSearchFilter() + " in search base '" + getGroupSearchBase() + "'")); if (getAttributeNames() == null) { setAttributeNames(new HashSet<>()); } - if (StringUtils.hasText(getGroupRoleAttribute()) - && !getAttributeNames().contains(getGroupRoleAttribute())) { + if (StringUtils.hasText(getGroupRoleAttribute()) && !getAttributeNames().contains(getGroupRoleAttribute())) { getAttributeNames().add(getGroupRoleAttribute()); } - - Set>> userRoles = getLdapTemplate() - .searchForMultipleAttributeValues(getGroupSearchBase(), - getGroupSearchFilter(), new String[] { userDn, username }, - getAttributeNames() - .toArray(new String[0])); - - if (logger.isDebugEnabled()) { - logger.debug("Roles from search: " + userRoles); - } - + Set>> userRoles = getLdapTemplate().searchForMultipleAttributeValues( + getGroupSearchBase(), getGroupSearchFilter(), new String[] { userDn, username }, + getAttributeNames().toArray(new String[0])); + logger.debug(LogMessage.format("Roles from search: %s", userRoles)); for (Map> record : userRoles) { boolean circular = false; String dn = record.get(SpringSecurityLdapTemplate.DN_KEY).get(0); @@ -222,21 +198,18 @@ public class NestedLdapAuthoritiesPopulator extends DefaultLdapAuthoritiesPopula role = getRolePrefix() + role; // if the group already exist, we will not search for it's parents again. // this prevents a forever loop for a misconfigured ldap directory - circular = circular - | (!authorities.add(new LdapAuthority(role, dn, record))); + circular = circular | (!authorities.add(new LdapAuthority(role, dn, record))); } - String roleName = roles.size() > 0 ? roles.iterator().next() : dn; + String roleName = (roles.size() > 0) ? roles.iterator().next() : dn; if (!circular) { performNestedSearch(dn, roleName, authorities, (depth - 1)); } - } } /** * Returns the attribute names that this populator has been configured to retrieve * Value can be null, represents fetch all attributes - * * @return the attribute names or null for all */ private Set getAttributeNames() { @@ -245,7 +218,6 @@ public class NestedLdapAuthoritiesPopulator extends DefaultLdapAuthoritiesPopula /** * Sets the attribute names to retrieve for each ldap groups. Null means retrieve all - * * @param attributeNames - the names of the LDAP attributes to retrieve */ public void setAttributeNames(Set attributeNames) { @@ -255,7 +227,6 @@ public class NestedLdapAuthoritiesPopulator extends DefaultLdapAuthoritiesPopula /** * How far should a nested search go. Depth is calculated in the number of levels we * search up for parent groups. - * * @return the max search depth, default is 10 */ private int getMaxSearchDepth() { @@ -265,7 +236,6 @@ public class NestedLdapAuthoritiesPopulator extends DefaultLdapAuthoritiesPopula /** * How far should a nested search go. Depth is calculated in the number of levels we * search up for parent groups. - * * @param maxSearchDepth the max search depth */ public void setMaxSearchDepth(int maxSearchDepth) { diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/Person.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/Person.java index f68f5a6736..5dbc580cf8 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/Person.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/Person.java @@ -13,19 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; -import org.springframework.security.core.SpringSecurityCoreVersion; -import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DirContextOperations; - +import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.ldap.LdapUtils; - -import java.util.List; -import java.util.ArrayList; -import java.util.Arrays; +import org.springframework.util.Assert; /** * UserDetails implementation whose properties are based on the LDAP schema for @@ -39,41 +38,44 @@ public class Person extends LdapUserDetailsImpl { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private String givenName; + private String sn; + private String description; + private String telephoneNumber; + private List cn = new ArrayList<>(); protected Person() { } public String getGivenName() { - return givenName; + return this.givenName; } public String getSn() { - return sn; + return this.sn; } public String[] getCn() { - return cn.toArray(new String[0]); + return this.cn.toArray(new String[0]); } public String getDescription() { - return description; + return this.description; } public String getTelephoneNumber() { - return telephoneNumber; + return this.telephoneNumber; } protected void populateContext(DirContextAdapter adapter) { - adapter.setAttributeValue("givenName", givenName); - adapter.setAttributeValue("sn", sn); + adapter.setAttributeValue("givenName", this.givenName); + adapter.setAttributeValue("sn", this.sn); adapter.setAttributeValues("cn", getCn()); adapter.setAttributeValue("description", getDescription()); adapter.setAttributeValue("telephoneNumber", getTelephoneNumber()); - if (getPassword() != null) { adapter.setAttributeValue("userPassword", getPassword()); } @@ -92,11 +94,9 @@ public class Person extends LdapUserDetailsImpl { setSn(ctx.getStringAttribute("sn")); setDescription(ctx.getStringAttribute("description")); setTelephoneNumber(ctx.getStringAttribute("telephoneNumber")); - Object passo = ctx.getObjectAttribute("userPassword"); - - if (passo != null) { - String password = LdapUtils.convertPasswordToString(passo); - setPassword(password); + Object password = ctx.getObjectAttribute("userPassword"); + if (password != null) { + setPassword(LdapUtils.convertPasswordToString(password)); } } @@ -106,37 +106,39 @@ public class Person extends LdapUserDetailsImpl { setSn(copyMe.sn); setDescription(copyMe.getDescription()); setTelephoneNumber(copyMe.getTelephoneNumber()); - ((Person) instance).cn = new ArrayList<>(copyMe.cn); + ((Person) this.instance).cn = new ArrayList<>(copyMe.cn); } + @Override protected LdapUserDetailsImpl createTarget() { return new Person(); } public void setGivenName(String givenName) { - ((Person) instance).givenName = givenName; + ((Person) this.instance).givenName = givenName; } public void setSn(String sn) { - ((Person) instance).sn = sn; + ((Person) this.instance).sn = sn; } public void setCn(String[] cn) { - ((Person) instance).cn = Arrays.asList(cn); + ((Person) this.instance).cn = Arrays.asList(cn); } public void addCn(String value) { - ((Person) instance).cn.add(value); + ((Person) this.instance).cn.add(value); } public void setTelephoneNumber(String tel) { - ((Person) instance).telephoneNumber = tel; + ((Person) this.instance).telephoneNumber = tel; } public void setDescription(String desc) { - ((Person) instance).description = desc; + ((Person) this.instance).description = desc; } + @Override public LdapUserDetails createUserDetails() { Person p = (Person) super.createUserDetails(); Assert.notNull(p.cn, "person.sn cannot be null"); @@ -144,5 +146,7 @@ public class Person extends LdapUserDetailsImpl { // TODO: Check contents for null entries return p; } + } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/PersonContextMapper.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/PersonContextMapper.java index ba09ac79a0..3662112afa 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/PersonContextMapper.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/PersonContextMapper.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import java.util.Collection; +import org.springframework.ldap.core.DirContextAdapter; +import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.ldap.core.DirContextOperations; -import org.springframework.ldap.core.DirContextAdapter; import org.springframework.util.Assert; /** @@ -28,21 +29,20 @@ import org.springframework.util.Assert; */ public class PersonContextMapper implements UserDetailsContextMapper { + @Override public UserDetails mapUserFromContext(DirContextOperations ctx, String username, Collection authorities) { Person.Essence p = new Person.Essence(ctx); - p.setUsername(username); p.setAuthorities(authorities); - return p.createUserDetails(); - } + @Override public void mapUserToContext(UserDetails user, DirContextAdapter ctx) { Assert.isInstanceOf(Person.class, user, "UserDetails must be a Person instance"); - Person p = (Person) user; p.populateContext(ctx); } + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/UserDetailsContextMapper.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/UserDetailsContextMapper.java index 6c59f10ac2..9c4a9e7294 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/UserDetailsContextMapper.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/UserDetailsContextMapper.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import java.util.Collection; +import org.springframework.ldap.core.DirContextAdapter; +import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.ldap.core.DirContextOperations; -import org.springframework.ldap.core.DirContextAdapter; /** * Operations to map a UserDetails object to and from a Spring LDAP @@ -36,7 +37,6 @@ public interface UserDetailsContextMapper { /** * Creates a fully populated UserDetails object for use by the security framework. - * * @param ctx the context object which contains the user information. * @param username the user's supplied login name. * @param authorities @@ -50,4 +50,5 @@ public interface UserDetailsContextMapper { * object. Called when saving a user, for example. */ void mapUserToContext(UserDetails user, DirContextAdapter ctx); + } diff --git a/ldap/src/main/java/org/springframework/security/ldap/userdetails/package-info.java b/ldap/src/main/java/org/springframework/security/ldap/userdetails/package-info.java index e53eeade05..d4beff61c2 100644 --- a/ldap/src/main/java/org/springframework/security/ldap/userdetails/package-info.java +++ b/ldap/src/main/java/org/springframework/security/ldap/userdetails/package-info.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * LDAP-focused {@code UserDetails} implementations which map from a ubset of the data * contained in some of the standard LDAP types (such as {@code InetOrgPerson}). */ package org.springframework.security.ldap.userdetails; - diff --git a/ldap/src/test/java/org/springframework/security/ldap/LdapUtilsTests.java b/ldap/src/test/java/org/springframework/security/ldap/LdapUtilsTests.java index 8c9175a04a..ae84d19654 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/LdapUtilsTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/LdapUtilsTests.java @@ -16,14 +16,16 @@ package org.springframework.security.ldap; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import javax.naming.NamingException; import javax.naming.directory.DirContext; import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; + /** * Tests {@link LdapUtils} * @@ -31,42 +33,34 @@ import org.junit.Test; */ public class LdapUtilsTests { - // ~ Methods - // ======================================================================================================== - @Test public void testCloseContextSwallowsNamingException() throws Exception { final DirContext dirCtx = mock(DirContext.class); - doThrow(new NamingException()).when(dirCtx).close(); - + willThrow(new NamingException()).given(dirCtx).close(); LdapUtils.closeContext(dirCtx); } @Test - public void testGetRelativeNameReturnsEmptyStringForDnEqualToBaseName() - throws Exception { + public void testGetRelativeNameReturnsEmptyStringForDnEqualToBaseName() throws Exception { final DirContext mockCtx = mock(DirContext.class); - - when(mockCtx.getNameInNamespace()).thenReturn("dc=springframework,dc=org"); - + given(mockCtx.getNameInNamespace()).willReturn("dc=springframework,dc=org"); assertThat(LdapUtils.getRelativeName("dc=springframework,dc=org", mockCtx)).isEqualTo(""); } @Test public void testGetRelativeNameReturnsFullDnWithEmptyBaseName() throws Exception { final DirContext mockCtx = mock(DirContext.class); - when(mockCtx.getNameInNamespace()).thenReturn(""); - - assertThat(LdapUtils.getRelativeName("cn=jane,dc=springframework,dc=org", mockCtx)).isEqualTo("cn=jane,dc=springframework,dc=org"); + given(mockCtx.getNameInNamespace()).willReturn(""); + assertThat(LdapUtils.getRelativeName("cn=jane,dc=springframework,dc=org", mockCtx)) + .isEqualTo("cn=jane,dc=springframework,dc=org"); } @Test public void testGetRelativeNameWorksWithArbitrarySpaces() throws Exception { final DirContext mockCtx = mock(DirContext.class); - when(mockCtx.getNameInNamespace()).thenReturn("dc=springsecurity,dc = org"); - - assertThat(LdapUtils.getRelativeName( - "cn=jane smith, dc = springsecurity , dc=org", mockCtx)).isEqualTo("cn=jane smith"); + given(mockCtx.getNameInNamespace()).willReturn("dc=springsecurity,dc = org"); + assertThat(LdapUtils.getRelativeName("cn=jane smith, dc = springsecurity , dc=org", mockCtx)) + .isEqualTo("cn=jane smith"); } @Test @@ -75,19 +69,16 @@ public class LdapUtilsTests { assertThat(LdapUtils.parseRootDnFromUrl("ldap://monkeymachine:11389")).isEqualTo(""); assertThat(LdapUtils.parseRootDnFromUrl("ldap://monkeymachine/")).isEqualTo(""); assertThat(LdapUtils.parseRootDnFromUrl("ldap://monkeymachine.co.uk/")).isEqualTo(""); - assertThat( - LdapUtils - .parseRootDnFromUrl("ldaps://monkeymachine.co.uk/dc=springframework,dc=org")).isEqualTo("dc=springframework,dc=org"); - assertThat( - LdapUtils.parseRootDnFromUrl("ldap:///dc=springframework,dc=org")).isEqualTo("dc=springframework,dc=org"); - assertThat( - LdapUtils - .parseRootDnFromUrl("ldap://monkeymachine/dc=springframework,dc=org")).isEqualTo("dc=springframework,dc=org"); - assertThat( - LdapUtils - .parseRootDnFromUrl("ldap://monkeymachine.co.uk/dc=springframework,dc=org/ou=blah")).isEqualTo("dc=springframework,dc=org/ou=blah"); - assertThat( - LdapUtils - .parseRootDnFromUrl("ldap://monkeymachine.co.uk:389/dc=springframework,dc=org/ou=blah")).isEqualTo("dc=springframework,dc=org/ou=blah"); + assertThat(LdapUtils.parseRootDnFromUrl("ldaps://monkeymachine.co.uk/dc=springframework,dc=org")) + .isEqualTo("dc=springframework,dc=org"); + assertThat(LdapUtils.parseRootDnFromUrl("ldap:///dc=springframework,dc=org")) + .isEqualTo("dc=springframework,dc=org"); + assertThat(LdapUtils.parseRootDnFromUrl("ldap://monkeymachine/dc=springframework,dc=org")) + .isEqualTo("dc=springframework,dc=org"); + assertThat(LdapUtils.parseRootDnFromUrl("ldap://monkeymachine.co.uk/dc=springframework,dc=org/ou=blah")) + .isEqualTo("dc=springframework,dc=org/ou=blah"); + assertThat(LdapUtils.parseRootDnFromUrl("ldap://monkeymachine.co.uk:389/dc=springframework,dc=org/ou=blah")) + .isEqualTo("dc=springframework,dc=org/ou=blah"); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityAuthenticationSourceTests.java b/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityAuthenticationSourceTests.java index 664a05dde7..1b8f25afad 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityAuthenticationSourceTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityAuthenticationSourceTests.java @@ -13,26 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.ldap.core.AuthenticationSource; +import org.springframework.ldap.core.DistinguishedName; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.ldap.authentication.SpringSecurityAuthenticationSource; import org.springframework.security.ldap.userdetails.LdapUserDetailsImpl; -import org.springframework.ldap.core.AuthenticationSource; -import org.springframework.ldap.core.DistinguishedName; -import org.junit.After; import static org.assertj.core.api.Assertions.assertThat; -import org.junit.Before; -import org.junit.Test; /** * @author Luke Taylor */ public class SpringSecurityAuthenticationSourceTests { + @Before @After public void clearContext() { @@ -49,28 +52,22 @@ public class SpringSecurityAuthenticationSourceTests { @Test public void principalIsEmptyForAnonymousUser() { AuthenticationSource source = new SpringSecurityAuthenticationSource(); - SecurityContextHolder.getContext().setAuthentication( - new AnonymousAuthenticationToken("key", "anonUser", AuthorityUtils - .createAuthorityList("ignored"))); + new AnonymousAuthenticationToken("key", "anonUser", AuthorityUtils.createAuthorityList("ignored"))); assertThat(source.getPrincipal()).isEqualTo(""); } @Test(expected = IllegalArgumentException.class) public void getPrincipalRejectsNonLdapUserDetailsObject() { AuthenticationSource source = new SpringSecurityAuthenticationSource(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken(new Object(), "password")); - + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken(new Object(), "password")); source.getPrincipal(); } @Test public void expectedCredentialsAreReturned() { AuthenticationSource source = new SpringSecurityAuthenticationSource(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken(new Object(), "password")); - + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken(new Object(), "password")); assertThat(source.getCredentials()).isEqualTo("password"); } @@ -80,9 +77,9 @@ public class SpringSecurityAuthenticationSourceTests { user.setUsername("joe"); user.setDn(new DistinguishedName("uid=joe,ou=users")); AuthenticationSource source = new SpringSecurityAuthenticationSource(); - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken(user.createUserDetails(), null)); - + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken(user.createUserDetails(), null)); assertThat(source.getPrincipal()).isEqualTo("uid=joe,ou=users"); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateTests.java b/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateTests.java index 5b7e362392..5494ae1490 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/SpringSecurityLdapTemplateTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +package org.springframework.security.ldap; import javax.naming.NamingEnumeration; import javax.naming.directory.DirContext; @@ -29,18 +27,28 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DistinguishedName; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + @RunWith(MockitoJUnitRunner.class) public class SpringSecurityLdapTemplateTests { @Mock private DirContext ctx; + @Captor private ArgumentCaptor searchControls; + @Mock private NamingEnumeration resultsEnum; + @Mock private SearchResult searchResult; @@ -52,18 +60,14 @@ public class SpringSecurityLdapTemplateTests { String searchResultName = "ldap://example.com/dc=springframework,dc=org"; Object[] params = new Object[] {}; DirContextAdapter searchResultObject = mock(DirContextAdapter.class); - - when( - ctx.search(any(DistinguishedName.class), eq(filter), eq(params), - searchControls.capture())).thenReturn(resultsEnum); - when(resultsEnum.hasMore()).thenReturn(true, false); - when(resultsEnum.next()).thenReturn(searchResult); - when(searchResult.getObject()).thenReturn(searchResultObject); - - SpringSecurityLdapTemplate.searchForSingleEntryInternal(ctx, - mock(SearchControls.class), base, filter, params); - - assertThat(searchControls.getValue().getReturningObjFlag()).isTrue(); + given(this.ctx.search(any(DistinguishedName.class), eq(filter), eq(params), this.searchControls.capture())) + .willReturn(this.resultsEnum); + given(this.resultsEnum.hasMore()).willReturn(true, false); + given(this.resultsEnum.next()).willReturn(this.searchResult); + given(this.searchResult.getObject()).willReturn(searchResultObject); + SpringSecurityLdapTemplate.searchForSingleEntryInternal(this.ctx, mock(SearchControls.class), base, filter, + params); + assertThat(this.searchControls.getValue().getReturningObjFlag()).isTrue(); } } diff --git a/ldap/src/test/java/org/springframework/security/ldap/authentication/LdapAuthenticationProviderTests.java b/ldap/src/test/java/org/springframework/security/ldap/authentication/LdapAuthenticationProviderTests.java index 87fd372b8b..cf6b6eefcc 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/authentication/LdapAuthenticationProviderTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/authentication/LdapAuthenticationProviderTests.java @@ -16,12 +16,10 @@ package org.springframework.security.ldap.authentication; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import java.util.*; +import java.util.Collection; import org.junit.Test; + import org.springframework.ldap.CommunicationException; import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DirContextOperations; @@ -37,6 +35,11 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.ldap.userdetails.LdapAuthoritiesPopulator; import org.springframework.security.ldap.userdetails.LdapUserDetailsMapper; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * Tests {@link LdapAuthenticationProvider}. * @@ -46,41 +49,32 @@ import org.springframework.security.ldap.userdetails.LdapUserDetailsMapper; */ public class LdapAuthenticationProviderTests { - // ~ Methods - // ======================================================================================================== - @Test public void testSupportsUsernamePasswordAuthenticationToken() { - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - new MockAuthenticator(), new MockAuthoritiesPopulator()); - + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(new MockAuthenticator(), + new MockAuthoritiesPopulator()); assertThat(ldapProvider.supports(UsernamePasswordAuthenticationToken.class)).isTrue(); } @Test public void testDefaultMapperIsSet() { - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - new MockAuthenticator(), new MockAuthoritiesPopulator()); - + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(new MockAuthenticator(), + new MockAuthoritiesPopulator()); assertThat(ldapProvider.getUserDetailsContextMapper() instanceof LdapUserDetailsMapper).isTrue(); } @Test public void testEmptyOrNullUserNameThrowsException() { - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - new MockAuthenticator(), new MockAuthoritiesPopulator()); - + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(new MockAuthenticator(), + new MockAuthoritiesPopulator()); try { - ldapProvider.authenticate(new UsernamePasswordAuthenticationToken(null, - "password")); + ldapProvider.authenticate(new UsernamePasswordAuthenticationToken(null, "password")); fail("Expected BadCredentialsException for empty username"); } catch (BadCredentialsException expected) { } - try { - ldapProvider.authenticate(new UsernamePasswordAuthenticationToken("", - "bobspassword")); + ldapProvider.authenticate(new UsernamePasswordAuthenticationToken("", "bobspassword")); fail("Expected BadCredentialsException for null username"); } catch (BadCredentialsException expected) { @@ -90,26 +84,18 @@ public class LdapAuthenticationProviderTests { @Test(expected = BadCredentialsException.class) public void usernameNotFoundExceptionIsHiddenByDefault() { final LdapAuthenticator authenticator = mock(LdapAuthenticator.class); - final UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken( - "joe", "password"); - when(authenticator.authenticate(joe)).thenThrow( - new UsernameNotFoundException("nobody")); - - LdapAuthenticationProvider provider = new LdapAuthenticationProvider( - authenticator); + final UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password"); + given(authenticator.authenticate(joe)).willThrow(new UsernameNotFoundException("nobody")); + LdapAuthenticationProvider provider = new LdapAuthenticationProvider(authenticator); provider.authenticate(joe); } @Test(expected = UsernameNotFoundException.class) public void usernameNotFoundExceptionIsNotHiddenIfConfigured() { final LdapAuthenticator authenticator = mock(LdapAuthenticator.class); - final UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken( - "joe", "password"); - when(authenticator.authenticate(joe)).thenThrow( - new UsernameNotFoundException("nobody")); - - LdapAuthenticationProvider provider = new LdapAuthenticationProvider( - authenticator); + final UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password"); + given(authenticator.authenticate(joe)).willThrow(new UsernameNotFoundException("nobody")); + LdapAuthenticationProvider provider = new LdapAuthenticationProvider(authenticator); provider.setHideUserNotFoundExceptions(false); provider.authenticate(joe); } @@ -117,16 +103,13 @@ public class LdapAuthenticationProviderTests { @Test public void normalUsage() { MockAuthoritiesPopulator populator = new MockAuthoritiesPopulator(); - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - new MockAuthenticator(), populator); + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(new MockAuthenticator(), populator); LdapUserDetailsMapper userMapper = new LdapUserDetailsMapper(); userMapper.setRoleAttributes(new String[] { "ou" }); ldapProvider.setUserDetailsContextMapper(userMapper); - assertThat(ldapProvider.getAuthoritiesPopulator()).isNotNull(); - - UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken( - "ben", "benspassword"); + UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken("ben", + "benspassword"); Object authDetails = new Object(); authRequest.setDetails(authDetails); Authentication authResult = ldapProvider.authenticate(authRequest); @@ -137,53 +120,42 @@ public class LdapAuthenticationProviderTests { assertThat(user.getPassword()).isEqualTo("{SHA}nFCebWjxfaLbHHG1Qk5UU4trbvQ="); assertThat(user.getUsername()).isEqualTo("ben"); assertThat(populator.getRequestedUsername()).isEqualTo("ben"); - - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ROLE_FROM_ENTRY"); - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ROLE_FROM_POPULATOR"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_FROM_ENTRY"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_FROM_POPULATOR"); } @Test public void passwordIsSetFromUserDataIfUseAuthenticationRequestCredentialsIsFalse() { - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - new MockAuthenticator(), new MockAuthoritiesPopulator()); + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(new MockAuthenticator(), + new MockAuthoritiesPopulator()); ldapProvider.setUseAuthenticationRequestCredentials(false); - - UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken( - "ben", "benspassword"); + UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken("ben", + "benspassword"); Authentication authResult = ldapProvider.authenticate(authRequest); assertThat(authResult.getCredentials()).isEqualTo("{SHA}nFCebWjxfaLbHHG1Qk5UU4trbvQ="); - } @Test public void useWithNullAuthoritiesPopulatorReturnsCorrectRole() { - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - new MockAuthenticator()); + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(new MockAuthenticator()); LdapUserDetailsMapper userMapper = new LdapUserDetailsMapper(); userMapper.setRoleAttributes(new String[] { "ou" }); ldapProvider.setUserDetailsContextMapper(userMapper); - UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken( - "ben", "benspassword"); - UserDetails user = (UserDetails) ldapProvider.authenticate(authRequest) - .getPrincipal(); + UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken("ben", + "benspassword"); + UserDetails user = (UserDetails) ldapProvider.authenticate(authRequest).getPrincipal(); assertThat(user.getAuthorities()).hasSize(1); - assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())) - .contains("ROLE_FROM_ENTRY"); + assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_FROM_ENTRY"); } @Test public void authenticateWithNamingException() { - UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken( - "ben", "benspassword"); + UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken("ben", + "benspassword"); LdapAuthenticator mockAuthenticator = mock(LdapAuthenticator.class); - CommunicationException expectedCause = new CommunicationException( - new javax.naming.CommunicationException()); - when(mockAuthenticator.authenticate(authRequest)).thenThrow(expectedCause); - - LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider( - mockAuthenticator); + CommunicationException expectedCause = new CommunicationException(new javax.naming.CommunicationException()); + given(mockAuthenticator.authenticate(authRequest)).willThrow(expectedCause); + LdapAuthenticationProvider ldapProvider = new LdapAuthenticationProvider(mockAuthenticator); try { ldapProvider.authenticate(authRequest); fail("Expected Exception"); @@ -193,46 +165,42 @@ public class LdapAuthenticationProviderTests { } } - // ~ Inner Classes - // ================================================================================================== - class MockAuthenticator implements LdapAuthenticator { + @Override public DirContextOperations authenticate(Authentication authentication) { DirContextAdapter ctx = new DirContextAdapter(); ctx.setAttributeValue("ou", "FROM_ENTRY"); String username = authentication.getName(); String password = (String) authentication.getCredentials(); - if (username.equals("ben") && password.equals("benspassword")) { - ctx.setDn(new DistinguishedName( - "cn=ben,ou=people,dc=springframework,dc=org")); + ctx.setDn(new DistinguishedName("cn=ben,ou=people,dc=springframework,dc=org")); ctx.setAttributeValue("userPassword", "{SHA}nFCebWjxfaLbHHG1Qk5UU4trbvQ="); - return ctx; } else if (username.equals("jen") && password.equals("")) { - ctx.setDn(new DistinguishedName( - "cn=jen,ou=people,dc=springframework,dc=org")); - + ctx.setDn(new DistinguishedName("cn=jen,ou=people,dc=springframework,dc=org")); return ctx; } - throw new BadCredentialsException("Authentication failed."); } + } class MockAuthoritiesPopulator implements LdapAuthoritiesPopulator { + String username; - public Collection getGrantedAuthorities( - DirContextOperations userCtx, String username) { + @Override + public Collection getGrantedAuthorities(DirContextOperations userCtx, String username) { this.username = username; return AuthorityUtils.createAuthorityList("ROLE_FROM_POPULATOR"); } String getRequestedUsername() { - return username; + return this.username; } + } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/authentication/MockUserSearch.java b/ldap/src/test/java/org/springframework/security/ldap/authentication/MockUserSearch.java index cc846d9889..7d581a1f72 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/authentication/MockUserSearch.java +++ b/ldap/src/test/java/org/springframework/security/ldap/authentication/MockUserSearch.java @@ -16,24 +16,16 @@ package org.springframework.security.ldap.authentication; +import org.springframework.ldap.core.DirContextOperations; import org.springframework.security.ldap.search.LdapUserSearch; -import org.springframework.ldap.core.DirContextOperations; - /** - * - * * @author Luke Taylor */ public class MockUserSearch implements LdapUserSearch { - // ~ Instance fields - // ================================================================================================ DirContextOperations user; - // ~ Constructors - // =================================================================================================== - public MockUserSearch() { } @@ -41,10 +33,9 @@ public class MockUserSearch implements LdapUserSearch { this.user = user; } - // ~ Methods - // ======================================================================================================== - + @Override public DirContextOperations searchForUser(String username) { - return user; + return this.user; } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorMockTests.java b/ldap/src/test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorMockTests.java index e2aff2a5f3..7a1a8e35bd 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorMockTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/authentication/PasswordComparisonAuthenticatorMockTests.java @@ -16,8 +16,6 @@ package org.springframework.security.ldap.authentication; -import static org.mockito.Mockito.*; - import javax.naming.NamingEnumeration; import javax.naming.directory.BasicAttribute; import javax.naming.directory.BasicAttributes; @@ -25,45 +23,37 @@ import javax.naming.directory.DirContext; import javax.naming.directory.SearchControls; import org.junit.Test; + import org.springframework.ldap.core.support.BaseLdapPathContextSource; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** - * * @author Luke Taylor */ public class PasswordComparisonAuthenticatorMockTests { - // ~ Methods - // ======================================================================================================== - @Test public void ldapCompareOperationIsUsedWhenPasswordIsNotRetrieved() throws Exception { final DirContext dirCtx = mock(DirContext.class); final BaseLdapPathContextSource source = mock(BaseLdapPathContextSource.class); final BasicAttributes attrs = new BasicAttributes(); attrs.put(new BasicAttribute("uid", "bob")); - - PasswordComparisonAuthenticator authenticator = new PasswordComparisonAuthenticator( - source); - + PasswordComparisonAuthenticator authenticator = new PasswordComparisonAuthenticator(source); authenticator.setUserDnPatterns(new String[] { "cn={0},ou=people" }); - // Get the mock to return an empty attribute set - when(source.getReadOnlyContext()).thenReturn(dirCtx); - when(dirCtx.getAttributes(eq("cn=Bob,ou=people"), any(String[].class))) - .thenReturn(attrs); - when(dirCtx.getNameInNamespace()).thenReturn("dc=springframework,dc=org"); - + given(source.getReadOnlyContext()).willReturn(dirCtx); + given(dirCtx.getAttributes(eq("cn=Bob,ou=people"), any(String[].class))).willReturn(attrs); + given(dirCtx.getNameInNamespace()).willReturn("dc=springframework,dc=org"); // Setup a single return value (i.e. success) final NamingEnumeration searchResults = new BasicAttributes("", null).getAll(); - - when( - dirCtx.search(eq("cn=Bob,ou=people"), eq("(userPassword={0})"), - any(Object[].class), any(SearchControls.class))).thenReturn( - searchResults); - - authenticator.authenticate(new UsernamePasswordAuthenticationToken("Bob", - "bobspassword")); + given(dirCtx.search(eq("cn=Bob,ou=people"), eq("(userPassword={0})"), any(Object[].class), + any(SearchControls.class))).willReturn(searchResults); + authenticator.authenticate(new UsernamePasswordAuthenticationToken("Bob", "bobspassword")); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java b/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java index 27226bd631..8272ec247f 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.authentication.ad; import java.util.Collections; import java.util.Hashtable; + import javax.naming.AuthenticationException; import javax.naming.CommunicationException; import javax.naming.Name; @@ -49,363 +51,299 @@ import org.springframework.security.authentication.InternalAuthenticationService import org.springframework.security.authentication.LockedException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.ldap.authentication.ad.ActiveDirectoryLdapAuthenticationProvider.ContextFactory; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.ldap.authentication.ad.ActiveDirectoryLdapAuthenticationProvider.ContextFactory; /** * @author Luke Taylor * @author Rob Winch */ public class ActiveDirectoryLdapAuthenticationProviderTests { + public static final String EXISTING_LDAP_PROVIDER = "ldap://192.168.1.200/"; + public static final String NON_EXISTING_LDAP_PROVIDER = "ldap://192.168.1.201/"; @Rule public ExpectedException thrown = ExpectedException.none(); ActiveDirectoryLdapAuthenticationProvider provider; - UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken( - "joe", "password"); + + UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password"); @Before public void setUp() { - provider = new ActiveDirectoryLdapAuthenticationProvider("mydomain.eu", - "ldap://192.168.1.200/"); + this.provider = new ActiveDirectoryLdapAuthenticationProvider("mydomain.eu", "ldap://192.168.1.200/"); } @Test public void bindPrincipalIsCreatedCorrectly() { - assertThat(provider.createBindPrincipal("joe")).isEqualTo("joe@mydomain.eu"); - assertThat(provider.createBindPrincipal("joe@mydomain.eu")).isEqualTo("joe@mydomain.eu"); + assertThat(this.provider.createBindPrincipal("joe")).isEqualTo("joe@mydomain.eu"); + assertThat(this.provider.createBindPrincipal("joe@mydomain.eu")).isEqualTo("joe@mydomain.eu"); } @Test public void successfulAuthenticationProducesExpectedAuthorities() throws Exception { - checkAuthentication("dc=mydomain,dc=eu", provider); + checkAuthentication("dc=mydomain,dc=eu", this.provider); } // SEC-1915 @Test public void customSearchFilterIsUsedForSuccessfulAuthentication() throws Exception { - // given String customSearchFilter = "(&(objectClass=user)(sAMAccountName={0}))"; - DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - + given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); - SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, - dca.getAttributes()); - when( - ctx.search(any(Name.class), eq(customSearchFilter), any(Object[].class), - any(SearchControls.class))).thenReturn( - new MockNamingEnumeration(sr)); - + SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); + given(ctx.search(any(Name.class), eq(customSearchFilter), any(Object[].class), any(SearchControls.class))) + .willReturn(new MockNamingEnumeration(sr)); ActiveDirectoryLdapAuthenticationProvider customProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", "ldap://192.168.1.200/"); customProvider.contextFactory = createContextFactoryReturning(ctx); - - // when customProvider.setSearchFilter(customSearchFilter); - Authentication result = customProvider.authenticate(joe); - - // then + Authentication result = customProvider.authenticate(this.joe); assertThat(result.isAuthenticated()).isTrue(); } @Test public void defaultSearchFilter() throws Exception { - // given final String defaultSearchFilter = "(&(objectClass=user)(userPrincipalName={0}))"; - DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - + given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); - SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, - dca.getAttributes()); - when( - ctx.search(any(Name.class), eq(defaultSearchFilter), any(Object[].class), - any(SearchControls.class))).thenReturn( - new MockNamingEnumeration(sr)); - + SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); + given(ctx.search(any(Name.class), eq(defaultSearchFilter), any(Object[].class), any(SearchControls.class))) + .willReturn(new MockNamingEnumeration(sr)); ActiveDirectoryLdapAuthenticationProvider customProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", "ldap://192.168.1.200/"); customProvider.contextFactory = createContextFactoryReturning(ctx); - - // when - Authentication result = customProvider.authenticate(joe); - - // then + Authentication result = customProvider.authenticate(this.joe); assertThat(result.isAuthenticated()).isTrue(); - verify(ctx).search(any(DistinguishedName.class), eq(defaultSearchFilter), - any(Object[].class), any(SearchControls.class)); + verify(ctx).search(any(DistinguishedName.class), eq(defaultSearchFilter), any(Object[].class), + any(SearchControls.class)); } // SEC-2897,SEC-2224 @Test public void bindPrincipalAndUsernameUsed() throws Exception { - // given final String defaultSearchFilter = "(&(objectClass=user)(userPrincipalName={0}))"; ArgumentCaptor captor = ArgumentCaptor.forClass(Object[].class); - DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - + given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); - SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, - dca.getAttributes()); - when( - ctx.search(any(Name.class), eq(defaultSearchFilter), captor.capture(), - any(SearchControls.class))).thenReturn( - new MockNamingEnumeration(sr)); - + SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); + given(ctx.search(any(Name.class), eq(defaultSearchFilter), captor.capture(), any(SearchControls.class))) + .willReturn(new MockNamingEnumeration(sr)); ActiveDirectoryLdapAuthenticationProvider customProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", "ldap://192.168.1.200/"); customProvider.contextFactory = createContextFactoryReturning(ctx); - - // when - Authentication result = customProvider.authenticate(joe); - - // then + Authentication result = customProvider.authenticate(this.joe); assertThat(captor.getValue()).containsExactly("joe@mydomain.eu", "joe"); assertThat(result.isAuthenticated()).isTrue(); } @Test(expected = IllegalArgumentException.class) public void setSearchFilterNull() { - provider.setSearchFilter(null); + this.provider.setSearchFilter(null); } @Test(expected = IllegalArgumentException.class) public void setSearchFilterEmpty() { - provider.setSearchFilter(" "); + this.provider.setSearchFilter(" "); } @Test - public void nullDomainIsSupportedIfAuthenticatingWithFullUserPrincipal() - throws Exception { - provider = new ActiveDirectoryLdapAuthenticationProvider(null, - "ldap://192.168.1.200/"); + public void nullDomainIsSupportedIfAuthenticatingWithFullUserPrincipal() throws Exception { + this.provider = new ActiveDirectoryLdapAuthenticationProvider(null, "ldap://192.168.1.200/"); DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - + given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); - SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, - dca.getAttributes()); - when( - ctx.search(eq(new DistinguishedName("DC=mydomain,DC=eu")), - any(String.class), any(Object[].class), any(SearchControls.class))) - .thenReturn(new MockNamingEnumeration(sr)); - provider.contextFactory = createContextFactoryReturning(ctx); - + SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); + given(ctx.search(eq(new DistinguishedName("DC=mydomain,DC=eu")), any(String.class), any(Object[].class), + any(SearchControls.class))).willReturn(new MockNamingEnumeration(sr)); + this.provider.contextFactory = createContextFactoryReturning(ctx); try { - provider.authenticate(joe); + this.provider.authenticate(this.joe); fail("Expected BadCredentialsException for user with no domain information"); } catch (BadCredentialsException expected) { } - - provider.authenticate(new UsernamePasswordAuthenticationToken("joe@mydomain.eu", - "password")); + this.provider.authenticate(new UsernamePasswordAuthenticationToken("joe@mydomain.eu", "password")); } @Test(expected = BadCredentialsException.class) public void failedUserSearchCausesBadCredentials() throws Exception { DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - when( - ctx.search(any(Name.class), any(String.class), any(Object[].class), - any(SearchControls.class))) - .thenThrow(new NameNotFoundException()); - - provider.contextFactory = createContextFactoryReturning(ctx); - - provider.authenticate(joe); + given(ctx.getNameInNamespace()).willReturn(""); + given(ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) + .willThrow(new NameNotFoundException()); + this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.authenticate(this.joe); } // SEC-2017 @Test(expected = BadCredentialsException.class) public void noUserSearchCausesUsernameNotFound() throws Exception { DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - when( - ctx.search(any(Name.class), any(String.class), any(Object[].class), - any(SearchControls.class))).thenReturn( - new EmptyEnumeration<>()); - - provider.contextFactory = createContextFactoryReturning(ctx); - - provider.authenticate(joe); + given(ctx.getNameInNamespace()).willReturn(""); + given(ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) + .willReturn(new EmptyEnumeration<>()); + this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.authenticate(this.joe); } // SEC-2500 @Test(expected = BadCredentialsException.class) public void sec2500PreventAnonymousBind() { - provider.authenticate(new UsernamePasswordAuthenticationToken("rwinch", "")); + this.provider.authenticate(new UsernamePasswordAuthenticationToken("rwinch", "")); } @SuppressWarnings("unchecked") @Test(expected = IncorrectResultSizeDataAccessException.class) public void duplicateUserSearchCausesError() throws Exception { DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); + given(ctx.getNameInNamespace()).willReturn(""); NamingEnumeration searchResults = mock(NamingEnumeration.class); - when(searchResults.hasMore()).thenReturn(true, true, false); + given(searchResults.hasMore()).willReturn(true, true, false); SearchResult searchResult = mock(SearchResult.class); - when(searchResult.getObject()).thenReturn(new DirContextAdapter("ou=1"), - new DirContextAdapter("ou=2")); - when(searchResults.next()).thenReturn(searchResult); - when( - ctx.search(any(Name.class), any(String.class), any(Object[].class), - any(SearchControls.class))).thenReturn(searchResults); - - provider.contextFactory = createContextFactoryReturning(ctx); - - provider.authenticate(joe); + given(searchResult.getObject()).willReturn(new DirContextAdapter("ou=1"), new DirContextAdapter("ou=2")); + given(searchResults.next()).willReturn(searchResult); + given(ctx.search(any(Name.class), any(String.class), any(Object[].class), any(SearchControls.class))) + .willReturn(searchResults); + this.provider.contextFactory = createContextFactoryReturning(ctx); + this.provider.authenticate(this.joe); } static final String msg = "[LDAP: error code 49 - 80858585: LdapErr: DSID-DECAFF0, comment: AcceptSecurityContext error, data "; @Test(expected = BadCredentialsException.class) public void userNotFoundIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "525, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "525, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = BadCredentialsException.class) public void incorrectPasswordIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "52e, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "52e, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = BadCredentialsException.class) public void notPermittedIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "530, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "530, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test public void passwordNeedsResetIsCorrectlyMapped() { final String dataCode = "773"; - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + dataCode + ", xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - - thrown.expect(BadCredentialsException.class); - thrown.expect(new BaseMatcher() { + this.provider.contextFactory = createContextFactoryThrowing( + new AuthenticationException(msg + dataCode + ", xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.thrown.expect(BadCredentialsException.class); + this.thrown.expect(new BaseMatcher() { private Matcher causeInstance = CoreMatchers .instanceOf(ActiveDirectoryAuthenticationException.class); + private Matcher causeDataCode = CoreMatchers.equalTo(dataCode); + @Override public boolean matches(Object that) { Throwable t = (Throwable) that; - ActiveDirectoryAuthenticationException cause = (ActiveDirectoryAuthenticationException) t - .getCause(); - return causeInstance.matches(cause) - && causeDataCode.matches(cause.getDataCode()); + ActiveDirectoryAuthenticationException cause = (ActiveDirectoryAuthenticationException) t.getCause(); + return this.causeInstance.matches(cause) && this.causeDataCode.matches(cause.getDataCode()); } + @Override public void describeTo(Description desc) { desc.appendText("getCause() "); - causeInstance.describeTo(desc); + this.causeInstance.describeTo(desc); desc.appendText("getCause().getDataCode() "); - causeDataCode.describeTo(desc); + this.causeDataCode.describeTo(desc); } }); - - provider.authenticate(joe); + this.provider.authenticate(this.joe); } @Test(expected = CredentialsExpiredException.class) public void expiredPasswordIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "532, xxxx]")); - + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "532, xxxx]")); try { - provider.authenticate(joe); + this.provider.authenticate(this.joe); fail("BadCredentialsException should had been thrown"); } catch (BadCredentialsException expected) { } - - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = DisabledException.class) public void accountDisabledIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "533, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "533, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = AccountExpiredException.class) public void accountExpiredIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "701, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "701, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = LockedException.class) public void accountLockedIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "775, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "775, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = BadCredentialsException.class) public void unknownErrorCodeIsCorrectlyMapped() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg + "999, xxxx]")); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg + "999, xxxx]")); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = BadCredentialsException.class) public void errorWithNoSubcodeIsHandledCleanly() { - provider.contextFactory = createContextFactoryThrowing(new AuthenticationException( - msg)); - provider.setConvertSubErrorCodesToExceptions(true); - provider.authenticate(joe); + this.provider.contextFactory = createContextFactoryThrowing(new AuthenticationException(msg)); + this.provider.setConvertSubErrorCodesToExceptions(true); + this.provider.authenticate(this.joe); } @Test(expected = org.springframework.ldap.CommunicationException.class) public void nonAuthenticationExceptionIsConvertedToSpringLdapException() throws Throwable { try { - provider.contextFactory = createContextFactoryThrowing(new CommunicationException( - msg)); - provider.authenticate(joe); - } catch (InternalAuthenticationServiceException e) { - // Since GH-8418 ldap communication exception is wrapped into InternalAuthenticationServiceException. + this.provider.contextFactory = createContextFactoryThrowing(new CommunicationException(msg)); + this.provider.authenticate(this.joe); + } + catch (InternalAuthenticationServiceException ex) { + // Since GH-8418 ldap communication exception is wrapped into + // InternalAuthenticationServiceException. // This test is about the wrapped exception, so we throw it. - throw e.getCause(); + throw ex.getCause(); } } - @Test(expected = org.springframework.security.authentication.InternalAuthenticationServiceException.class ) + @Test(expected = org.springframework.security.authentication.InternalAuthenticationServiceException.class) public void connectionExceptionIsWrappedInInternalException() throws Exception { ActiveDirectoryLdapAuthenticationProvider noneReachableProvider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", NON_EXISTING_LDAP_PROVIDER, "dc=ad,dc=eu,dc=mydomain"); - noneReachableProvider.setContextEnvironmentProperties( - Collections.singletonMap("com.sun.jndi.ldap.connect.timeout", "5")); - noneReachableProvider.doAuthentication(joe); + noneReachableProvider + .setContextEnvironmentProperties(Collections.singletonMap("com.sun.jndi.ldap.connect.timeout", "5")); + noneReachableProvider.doAuthentication(this.joe); } @Test @@ -413,43 +351,40 @@ public class ActiveDirectoryLdapAuthenticationProviderTests { ActiveDirectoryLdapAuthenticationProvider provider = new ActiveDirectoryLdapAuthenticationProvider( "mydomain.eu", EXISTING_LDAP_PROVIDER, "dc=ad,dc=eu,dc=mydomain"); checkAuthentication("dc=ad,dc=eu,dc=mydomain", provider); - } @Test(expected = IllegalArgumentException.class) public void setContextEnvironmentPropertiesNull() { - provider.setContextEnvironmentProperties(null); + this.provider.setContextEnvironmentProperties(null); } @Test(expected = IllegalArgumentException.class) public void setContextEnvironmentPropertiesEmpty() { - provider.setContextEnvironmentProperties(new Hashtable<>()); + this.provider.setContextEnvironmentProperties(new Hashtable<>()); } @Test public void contextEnvironmentPropertiesUsed() { Hashtable env = new Hashtable<>(); - env.put("java.naming.ldap.factory.socket", "unknown.package.NonExistingSocketFactory"); - provider.setContextEnvironmentProperties(env); - + this.provider.setContextEnvironmentProperties(env); try { - provider.authenticate(joe); + this.provider.authenticate(this.joe); fail("CommunicationException was expected with a root cause of ClassNotFoundException"); } catch (InternalAuthenticationServiceException expected) { assertThat(expected.getCause()).isInstanceOf(org.springframework.ldap.CommunicationException.class); - org.springframework.ldap.CommunicationException cause = - (org.springframework.ldap.CommunicationException) expected.getCause(); + org.springframework.ldap.CommunicationException cause = (org.springframework.ldap.CommunicationException) expected + .getCause(); assertThat(cause.getRootCause()).isInstanceOf(ClassNotFoundException.class); } } - ContextFactory createContextFactoryThrowing(final NamingException e) { + ContextFactory createContextFactoryThrowing(final NamingException ex) { return new ContextFactory() { @Override DirContext createContext(Hashtable env) throws NamingException { - throw e; + throw ex; } }; } @@ -463,60 +398,58 @@ public class ActiveDirectoryLdapAuthenticationProviderTests { }; } - private void checkAuthentication(String rootDn, - ActiveDirectoryLdapAuthenticationProvider provider) throws NamingException { + private void checkAuthentication(String rootDn, ActiveDirectoryLdapAuthenticationProvider provider) + throws NamingException { DirContext ctx = mock(DirContext.class); - when(ctx.getNameInNamespace()).thenReturn(""); - + given(ctx.getNameInNamespace()).willReturn(""); DirContextAdapter dca = new DirContextAdapter(); - SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, - dca.getAttributes()); + SearchResult sr = new SearchResult("CN=Joe Jannsen,CN=Users", dca, dca.getAttributes()); @SuppressWarnings("deprecation") DistinguishedName searchBaseDn = new DistinguishedName(rootDn); - when( - ctx.search(eq(searchBaseDn), any(String.class), any(Object[].class), - any(SearchControls.class))).thenReturn( - new MockNamingEnumeration(sr)).thenReturn(new MockNamingEnumeration(sr)); - + given(ctx.search(eq(searchBaseDn), any(String.class), any(Object[].class), any(SearchControls.class))) + .willReturn(new MockNamingEnumeration(sr)).willReturn(new MockNamingEnumeration(sr)); provider.contextFactory = createContextFactoryReturning(ctx); - - Authentication result = provider.authenticate(joe); - + Authentication result = provider.authenticate(this.joe); assertThat(result.getAuthorities()).isEmpty(); - dca.addAttributeValue("memberOf", "CN=Admin,CN=Users,DC=mydomain,DC=eu"); - - result = provider.authenticate(joe); - + result = provider.authenticate(this.joe); assertThat(result.getAuthorities()).hasSize(1); } static class MockNamingEnumeration implements NamingEnumeration { + private SearchResult sr; MockNamingEnumeration(SearchResult sr) { this.sr = sr; } + @Override public SearchResult next() { - SearchResult result = sr; - sr = null; + SearchResult result = this.sr; + this.sr = null; return result; } + @Override public boolean hasMore() { - return sr != null; + return this.sr != null; } + @Override public void close() { } + @Override public boolean hasMoreElements() { return hasMore(); } + @Override public SearchResult nextElement() { return next(); } + } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/OpenLDAPIntegrationTestSuite.java b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/OpenLDAPIntegrationTestSuite.java index 02ed0e1c8f..4636baa78f 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/OpenLDAPIntegrationTestSuite.java +++ b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/OpenLDAPIntegrationTestSuite.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; /** @@ -27,6 +28,7 @@ package org.springframework.security.ldap.ppolicy; public class OpenLDAPIntegrationTestSuite { PasswordPolicyAwareContextSource cs; + /* * @Before public void createContextSource() throws Exception { cs = new * PasswordPolicyAwareContextSource("ldap://localhost:22389/dc=springsource,dc=com"); @@ -60,4 +62,5 @@ public class OpenLDAPIntegrationTestSuite { * = (LdapUserDetailsImpl) a.getPrincipal(); assertTrue(ud.getTimeBeforeExpiration() < * Integer.MAX_VALUE && ud.getTimeBeforeExpiration() > 0); } */ + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSourceTests.java b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSourceTests.java index fbe36fc031..84cc77f851 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSourceTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyAwareContextSourceTests.java @@ -13,71 +13,72 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import org.junit.*; -import org.springframework.ldap.UncategorizedLdapException; +import java.util.Hashtable; import javax.naming.Context; import javax.naming.NamingException; import javax.naming.directory.DirContext; import javax.naming.ldap.Control; import javax.naming.ldap.LdapContext; -import java.util.*; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.ldap.UncategorizedLdapException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; /** * @author Luke Taylor */ public class PasswordPolicyAwareContextSourceTests { + private PasswordPolicyAwareContextSource ctxSource; + private final LdapContext ctx = mock(LdapContext.class); @Before public void setUp() { - reset(ctx); - ctxSource = new PasswordPolicyAwareContextSource( - "ldap://blah:789/dc=springframework,dc=org") { + reset(this.ctx); + this.ctxSource = new PasswordPolicyAwareContextSource("ldap://blah:789/dc=springframework,dc=org") { @Override protected DirContext createContext(Hashtable env) { if ("manager".equals(env.get(Context.SECURITY_PRINCIPAL))) { - return ctx; + return PasswordPolicyAwareContextSourceTests.this.ctx; } - return null; } }; - ctxSource.setUserDn("manager"); - ctxSource.setPassword("password"); - ctxSource.afterPropertiesSet(); + this.ctxSource.setUserDn("manager"); + this.ctxSource.setPassword("password"); + this.ctxSource.afterPropertiesSet(); } @Test public void contextIsReturnedWhenNoControlsAreSetAndReconnectIsSuccessful() { - assertThat(ctxSource.getContext("user", "ignored")).isNotNull(); + assertThat(this.ctxSource.getContext("user", "ignored")).isNotNull(); } @Test(expected = UncategorizedLdapException.class) - public void standardExceptionIsPropagatedWhenExceptionRaisedAndNoControlsAreSet() - throws Exception { - doThrow(new NamingException("some LDAP exception")).when(ctx).reconnect( - any(Control[].class)); - - ctxSource.getContext("user", "ignored"); + public void standardExceptionIsPropagatedWhenExceptionRaisedAndNoControlsAreSet() throws Exception { + willThrow(new NamingException("some LDAP exception")).given(this.ctx).reconnect(any(Control[].class)); + this.ctxSource.getContext("user", "ignored"); } @Test(expected = PasswordPolicyException.class) - public void lockedPasswordPolicyControlRaisesPasswordPolicyException() - throws Exception { - when(ctx.getResponseControls()).thenReturn( - new Control[] { new PasswordPolicyResponseControl( - PasswordPolicyResponseControlTests.OPENLDAP_LOCKED_CTRL) }); - - doThrow(new NamingException("locked message")).when(ctx).reconnect( - any(Control[].class)); - - ctxSource.getContext("user", "ignored"); + public void lockedPasswordPolicyControlRaisesPasswordPolicyException() throws Exception { + given(this.ctx.getResponseControls()).willReturn(new Control[] { + new PasswordPolicyResponseControl(PasswordPolicyResponseControlTests.OPENLDAP_LOCKED_CTRL) }); + willThrow(new NamingException("locked message")).given(this.ctx).reconnect(any(Control[].class)); + this.ctxSource.getContext("user", "ignored"); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactoryTests.java b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactoryTests.java index c29f13b34f..50babf5437 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactoryTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyControlFactoryTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.ppolicy; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import org.junit.*; - import javax.naming.ldap.Control; +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * @author Luke Taylor */ @@ -31,8 +33,7 @@ public class PasswordPolicyControlFactoryTests { public void returnsNullForUnrecognisedOID() { PasswordPolicyControlFactory ctrlFactory = new PasswordPolicyControlFactory(); Control wrongCtrl = mock(Control.class); - - when(wrongCtrl.getID()).thenReturn("wrongId"); + given(wrongCtrl.getID()).willReturn("wrongId"); assertThat(ctrlFactory.getControlInstance(wrongCtrl)).isNull(); } @@ -40,12 +41,11 @@ public class PasswordPolicyControlFactoryTests { public void returnsControlForCorrectOID() { PasswordPolicyControlFactory ctrlFactory = new PasswordPolicyControlFactory(); Control control = mock(Control.class); - - when(control.getID()).thenReturn(PasswordPolicyControl.OID); - when(control.getEncodedValue()).thenReturn( - PasswordPolicyResponseControlTests.OPENLDAP_LOCKED_CTRL); + given(control.getID()).willReturn(PasswordPolicyControl.OID); + given(control.getEncodedValue()).willReturn(PasswordPolicyResponseControlTests.OPENLDAP_LOCKED_CTRL); Control result = ctrlFactory.getControlInstance(control); assertThat(result).isNotNull(); assertThat(PasswordPolicyResponseControlTests.OPENLDAP_LOCKED_CTRL).isEqualTo(result.getEncodedValue()); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControlTests.java b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControlTests.java index 9ed6c7cd31..0422f10ef7 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControlTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/ppolicy/PasswordPolicyResponseControlTests.java @@ -16,18 +16,16 @@ package org.springframework.security.ldap.ppolicy; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests for PasswordPolicyResponse. * * @author Luke Taylor */ public class PasswordPolicyResponseControlTests { - // ~ Methods - // ======================================================================================================== /** * Useful method for obtaining data from a server for use in tests @@ -68,7 +66,6 @@ public class PasswordPolicyResponseControlTests { // // //com.sun.jndi.ldap.LdapPoolManager.showStats(System.out); // } - // private PasswordPolicyResponseControl getPPolicyResponseCtl(InitialLdapContext ctx) // throws NamingException { // Control[] ctrls = ctx.getResponseControls(); @@ -81,36 +78,27 @@ public class PasswordPolicyResponseControlTests { // // return null; // } - @Test public void openLDAP33SecondsTillPasswordExpiryCtrlIsParsedCorrectly() { byte[] ctrlBytes = { 0x30, 0x05, (byte) 0xA0, 0x03, (byte) 0xA0, 0x1, 0x21 }; - PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl(ctrlBytes); - assertThat(ctrl.hasWarning()).isTrue(); assertThat(ctrl.getTimeBeforeExpiration()).isEqualTo(33); } @Test public void openLDAP496GraceLoginsRemainingCtrlIsParsedCorrectly() { - byte[] ctrlBytes = { 0x30, 0x06, (byte) 0xA0, 0x04, (byte) 0xA1, 0x02, 0x01, - (byte) 0xF0 }; - + byte[] ctrlBytes = { 0x30, 0x06, (byte) 0xA0, 0x04, (byte) 0xA1, 0x02, 0x01, (byte) 0xF0 }; PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl(ctrlBytes); - assertThat(ctrl.hasWarning()).isTrue(); assertThat(ctrl.getGraceLoginsRemaining()).isEqualTo(496); } - static final byte[] OPENLDAP_5_LOGINS_REMAINING_CTRL = { 0x30, 0x05, (byte) 0xA0, - 0x03, (byte) 0xA1, 0x01, 0x05 }; + static final byte[] OPENLDAP_5_LOGINS_REMAINING_CTRL = { 0x30, 0x05, (byte) 0xA0, 0x03, (byte) 0xA1, 0x01, 0x05 }; @Test public void openLDAP5GraceLoginsRemainingCtrlIsParsedCorrectly() { - PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl( - OPENLDAP_5_LOGINS_REMAINING_CTRL); - + PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl(OPENLDAP_5_LOGINS_REMAINING_CTRL); assertThat(ctrl.hasWarning()).isTrue(); assertThat(ctrl.getGraceLoginsRemaining()).isEqualTo(5); } @@ -119,9 +107,7 @@ public class PasswordPolicyResponseControlTests { @Test public void openLDAPAccountLockedCtrlIsParsedCorrectly() { - PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl( - OPENLDAP_LOCKED_CTRL); - + PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl(OPENLDAP_LOCKED_CTRL); assertThat(ctrl.hasError() && ctrl.isLocked()).isTrue(); assertThat(ctrl.hasWarning()).isFalse(); } @@ -129,10 +115,9 @@ public class PasswordPolicyResponseControlTests { @Test public void openLDAPPasswordExpiredCtrlIsParsedCorrectly() { byte[] ctrlBytes = { 0x30, 0x03, (byte) 0xA1, 0x01, 0x00 }; - PasswordPolicyResponseControl ctrl = new PasswordPolicyResponseControl(ctrlBytes); - assertThat(ctrl.hasError() && ctrl.isExpired()).isTrue(); assertThat(ctrl.hasWarning()).isFalse(); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/userdetails/InetOrgPersonTests.java b/ldap/src/test/java/org/springframework/security/ldap/userdetails/InetOrgPersonTests.java index 168f725460..5e007f84ed 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/userdetails/InetOrgPersonTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/userdetails/InetOrgPersonTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap.userdetails; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.ldap.userdetails; import java.util.HashSet; import java.util.Set; import org.junit.Test; + import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DistinguishedName; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Luke Taylor */ @@ -33,7 +35,6 @@ public class InetOrgPersonTests { public void testUsernameIsMappedFromContextUidIfNotSet() { InetOrgPerson.Essence essence = new InetOrgPerson.Essence(createUserContext()); InetOrgPerson p = (InetOrgPerson) essence.createUserDetails(); - assertThat(p.getUsername()).isEqualTo("ghengis"); } @@ -53,7 +54,6 @@ public class InetOrgPersonTests { InetOrgPerson.Essence essence = new InetOrgPerson.Essence(createUserContext()); essence.setUsername("joe"); InetOrgPerson p = (InetOrgPerson) essence.createUserDetails(); - assertThat(p.getUsername()).isEqualTo("joe"); assertThat(p.getUid()).isEqualTo("ghengis"); } @@ -62,7 +62,6 @@ public class InetOrgPersonTests { public void attributesMapCorrectlyFromContext() { InetOrgPerson.Essence essence = new InetOrgPerson.Essence(createUserContext()); InetOrgPerson p = (InetOrgPerson) essence.createUserDetails(); - assertThat(p.getCarLicense()).isEqualTo("HORS1"); assertThat(p.getMail()).isEqualTo("ghengis@mongolia"); assertThat(p.getGivenName()).isEqualTo("Ghengis"); @@ -87,7 +86,6 @@ public class InetOrgPersonTests { public void testPasswordIsSetFromContextUserPassword() { InetOrgPerson.Essence essence = new InetOrgPerson.Essence(createUserContext()); InetOrgPerson p = (InetOrgPerson) essence.createUserDetails(); - assertThat(p.getPassword()).isEqualTo("pillage"); } @@ -95,13 +93,11 @@ public class InetOrgPersonTests { public void mappingBackToContextMatchesOriginalData() { DirContextAdapter ctx1 = createUserContext(); DirContextAdapter ctx2 = new DirContextAdapter(); - ctx1.setAttributeValues("objectclass", new String[] { "top", "person", - "organizationalPerson", "inetOrgPerson" }); + ctx1.setAttributeValues("objectclass", + new String[] { "top", "person", "organizationalPerson", "inetOrgPerson" }); ctx2.setDn(new DistinguishedName("ignored=ignored")); - InetOrgPerson p = (InetOrgPerson) (new InetOrgPerson.Essence(ctx1)) - .createUserDetails(); + InetOrgPerson p = (InetOrgPerson) (new InetOrgPerson.Essence(ctx1)).createUserDetails(); p.populateContext(ctx2); - assertThat(ctx2).isEqualTo(ctx1); } @@ -110,20 +106,16 @@ public class InetOrgPersonTests { DirContextAdapter ctx1 = createUserContext(); DirContextAdapter ctx2 = new DirContextAdapter(); ctx2.setDn(new DistinguishedName("ignored=ignored")); - ctx1.setAttributeValues("objectclass", new String[] { "top", "person", - "organizationalPerson", "inetOrgPerson" }); - InetOrgPerson p = (InetOrgPerson) (new InetOrgPerson.Essence(ctx1)) - .createUserDetails(); - InetOrgPerson p2 = (InetOrgPerson) new InetOrgPerson.Essence(p) - .createUserDetails(); + ctx1.setAttributeValues("objectclass", + new String[] { "top", "person", "organizationalPerson", "inetOrgPerson" }); + InetOrgPerson p = (InetOrgPerson) (new InetOrgPerson.Essence(ctx1)).createUserDetails(); + InetOrgPerson p2 = (InetOrgPerson) new InetOrgPerson.Essence(p).createUserDetails(); p2.populateContext(ctx2); - assertThat(ctx2).isEqualTo(ctx1); } private DirContextAdapter createUserContext() { DirContextAdapter ctx = new DirContextAdapter(); - ctx.setDn(new DistinguishedName("ignored=ignored")); ctx.setAttributeValue("uid", "ghengis"); ctx.setAttributeValue("userPassword", "pillage"); @@ -148,7 +140,6 @@ public class InetOrgPersonTests { ctx.setAttributeValue("sn", "Khan"); ctx.setAttributeValue("street", "Westward Avenue"); ctx.setAttributeValue("telephoneNumber", "+442075436521"); - return ctx; } diff --git a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapAuthorityTests.java b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapAuthorityTests.java index c6727986c7..27b9c8421e 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapAuthorityTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapAuthorityTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap.userdetails; -import org.junit.Before; -import org.junit.Test; -import org.springframework.security.ldap.SpringSecurityLdapTemplate; +package org.springframework.security.ldap.userdetails; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.ldap.SpringSecurityLdapTemplate; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -32,38 +34,39 @@ import static org.assertj.core.api.Assertions.assertThat; public class LdapAuthorityTests { public static final String DN = "cn=filip,ou=Users,dc=test,dc=com"; + LdapAuthority authority; @Before public void setUp() { Map> attributes = new HashMap<>(); attributes.put(SpringSecurityLdapTemplate.DN_KEY, Arrays.asList(DN)); - attributes.put("mail", - Arrays.asList("filip@ldap.test.org", "filip@ldap.test2.org")); - authority = new LdapAuthority("testRole", DN, attributes); + attributes.put("mail", Arrays.asList("filip@ldap.test.org", "filip@ldap.test2.org")); + this.authority = new LdapAuthority("testRole", DN, attributes); } @Test public void testGetDn() { - assertThat(authority.getDn()).isEqualTo(DN); - assertThat(authority.getAttributeValues(SpringSecurityLdapTemplate.DN_KEY)).isNotNull(); - assertThat(authority.getAttributeValues(SpringSecurityLdapTemplate.DN_KEY)).hasSize(1); - assertThat(authority.getFirstAttributeValue(SpringSecurityLdapTemplate.DN_KEY)).isEqualTo(DN); + assertThat(this.authority.getDn()).isEqualTo(DN); + assertThat(this.authority.getAttributeValues(SpringSecurityLdapTemplate.DN_KEY)).isNotNull(); + assertThat(this.authority.getAttributeValues(SpringSecurityLdapTemplate.DN_KEY)).hasSize(1); + assertThat(this.authority.getFirstAttributeValue(SpringSecurityLdapTemplate.DN_KEY)).isEqualTo(DN); } @Test public void testGetAttributes() { - assertThat(authority.getAttributes()).isNotNull(); - assertThat(authority.getAttributeValues("mail")).isNotNull(); - assertThat(authority.getAttributeValues("mail")).hasSize(2); - assertThat(authority.getFirstAttributeValue("mail")).isEqualTo("filip@ldap.test.org"); - assertThat(authority.getAttributeValues("mail").get(0)).isEqualTo("filip@ldap.test.org"); - assertThat(authority.getAttributeValues("mail").get(1)).isEqualTo("filip@ldap.test2.org"); + assertThat(this.authority.getAttributes()).isNotNull(); + assertThat(this.authority.getAttributeValues("mail")).isNotNull(); + assertThat(this.authority.getAttributeValues("mail")).hasSize(2); + assertThat(this.authority.getFirstAttributeValue("mail")).isEqualTo("filip@ldap.test.org"); + assertThat(this.authority.getAttributeValues("mail").get(0)).isEqualTo("filip@ldap.test.org"); + assertThat(this.authority.getAttributeValues("mail").get(1)).isEqualTo("filip@ldap.test2.org"); } @Test public void testGetAuthority() { - assertThat(authority.getAuthority()).isNotNull(); - assertThat(authority.getAuthority()).isEqualTo("testRole"); + assertThat(this.authority.getAuthority()).isNotNull(); + assertThat(this.authority.getAuthority()).isEqualTo("testRole"); } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImplTests.java b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImplTests.java index 30556e3002..e803d89288 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImplTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsImplTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.ldap.userdetails; import org.junit.Test; + import org.springframework.security.core.CredentialsContainer; import static org.assertj.core.api.Assertions.assertThat; @@ -33,7 +35,6 @@ public class LdapUserDetailsImplTests { mutableLdapUserDetails.setDn("uid=username1,ou=people,dc=example,dc=com"); mutableLdapUserDetails.setUsername("username1"); mutableLdapUserDetails.setPassword("password"); - LdapUserDetails ldapUserDetails = mutableLdapUserDetails.createUserDetails(); assertThat(ldapUserDetails).isInstanceOf(CredentialsContainer.class); ldapUserDetails.eraseCredentials(); diff --git a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapperTests.java b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapperTests.java index 1c8ca9e384..e0205051e3 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapperTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsMapperTests.java @@ -40,17 +40,12 @@ public class LdapUserDetailsMapperTests { LdapUserDetailsMapper mapper = new LdapUserDetailsMapper(); mapper.setConvertToUpperCase(false); mapper.setRolePrefix(""); - mapper.setRoleAttributes(new String[] { "userRole" }); - DirContextAdapter ctx = new DirContextAdapter(); - ctx.setAttributeValues("userRole", new String[] { "X", "Y", "Z" }); ctx.setAttributeValue("uid", "ani"); - - LdapUserDetailsImpl user = (LdapUserDetailsImpl) mapper.mapUserFromContext(ctx, - "ani", AuthorityUtils.NO_AUTHORITIES); - + LdapUserDetailsImpl user = (LdapUserDetailsImpl) mapper.mapUserFromContext(ctx, "ani", + AuthorityUtils.NO_AUTHORITIES); assertThat(user.getAuthorities()).hasSize(3); } @@ -60,19 +55,13 @@ public class LdapUserDetailsMapperTests { @Test public void testNonRetrievedRoleAttributeIsIgnored() { LdapUserDetailsMapper mapper = new LdapUserDetailsMapper(); - mapper.setRoleAttributes(new String[] { "userRole", "nonRetrievedAttribute" }); - BasicAttributes attrs = new BasicAttributes(); attrs.put(new BasicAttribute("userRole", "x")); - - DirContextAdapter ctx = new DirContextAdapter(attrs, - new DistinguishedName("cn=someName")); + DirContextAdapter ctx = new DirContextAdapter(attrs, new DistinguishedName("cn=someName")); ctx.setAttributeValue("uid", "ani"); - - LdapUserDetailsImpl user = (LdapUserDetailsImpl) mapper.mapUserFromContext(ctx, - "ani", AuthorityUtils.NO_AUTHORITIES); - + LdapUserDetailsImpl user = (LdapUserDetailsImpl) mapper.mapUserFromContext(ctx, "ani", + AuthorityUtils.NO_AUTHORITIES); assertThat(user.getAuthorities()).hasSize(1); assertThat(AuthorityUtils.authorityListToSet(user.getAuthorities())).contains("ROLE_X"); } @@ -80,18 +69,13 @@ public class LdapUserDetailsMapperTests { @Test public void testPasswordAttributeIsMappedCorrectly() { LdapUserDetailsMapper mapper = new LdapUserDetailsMapper(); - mapper.setPasswordAttributeName("myappsPassword"); BasicAttributes attrs = new BasicAttributes(); attrs.put(new BasicAttribute("myappsPassword", "mypassword".getBytes())); - - DirContextAdapter ctx = new DirContextAdapter(attrs, - new DistinguishedName("cn=someName")); + DirContextAdapter ctx = new DirContextAdapter(attrs, new DistinguishedName("cn=someName")); ctx.setAttributeValue("uid", "ani"); - LdapUserDetails user = (LdapUserDetailsImpl) mapper.mapUserFromContext(ctx, "ani", AuthorityUtils.NO_AUTHORITIES); - assertThat(user.getPassword()).isEqualTo("mypassword"); } diff --git a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsServiceTests.java b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsServiceTests.java index d8ff18881a..3dfd7c53dd 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsServiceTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/userdetails/LdapUserDetailsServiceTests.java @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap.userdetails; -import static org.assertj.core.api.Assertions.*; +package org.springframework.security.ldap.userdetails; import java.util.Collection; import java.util.Set; import org.junit.Test; + import org.springframework.ldap.core.DirContextAdapter; import org.springframework.ldap.core.DirContextOperations; import org.springframework.ldap.core.DistinguishedName; @@ -30,6 +30,8 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.ldap.authentication.MockUserSearch; import org.springframework.security.ldap.authentication.NullLdapAuthoritiesPopulator; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests for {@link LdapUserDetailsService} * @@ -49,36 +51,31 @@ public class LdapUserDetailsServiceTests { @Test public void correctAuthoritiesAreReturned() { - DirContextAdapter userData = new DirContextAdapter(new DistinguishedName( - "uid=joe")); - - LdapUserDetailsService service = new LdapUserDetailsService(new MockUserSearch( - userData), new MockAuthoritiesPopulator()); + DirContextAdapter userData = new DirContextAdapter(new DistinguishedName("uid=joe")); + LdapUserDetailsService service = new LdapUserDetailsService(new MockUserSearch(userData), + new MockAuthoritiesPopulator()); service.setUserDetailsMapper(new LdapUserDetailsMapper()); - UserDetails user = service.loadUserByUsername("doesntmatterwegetjoeanyway"); - - Set authorities = AuthorityUtils - .authorityListToSet(user.getAuthorities()); + Set authorities = AuthorityUtils.authorityListToSet(user.getAuthorities()); assertThat(authorities).hasSize(1); assertThat(authorities.contains("ROLE_FROM_POPULATOR")).isTrue(); } @Test public void nullPopulatorConstructorReturnsEmptyAuthoritiesList() { - DirContextAdapter userData = new DirContextAdapter(new DistinguishedName( - "uid=joe")); - - LdapUserDetailsService service = new LdapUserDetailsService(new MockUserSearch( - userData)); + DirContextAdapter userData = new DirContextAdapter(new DistinguishedName("uid=joe")); + LdapUserDetailsService service = new LdapUserDetailsService(new MockUserSearch(userData)); UserDetails user = service.loadUserByUsername("doesntmatterwegetjoeanyway"); assertThat(user.getAuthorities()).isEmpty(); } class MockAuthoritiesPopulator implements LdapAuthoritiesPopulator { - public Collection getGrantedAuthorities( - DirContextOperations userCtx, String username) { + + @Override + public Collection getGrantedAuthorities(DirContextOperations userCtx, String username) { return AuthorityUtils.createAuthorityList("ROLE_FROM_POPULATOR"); } + } + } diff --git a/ldap/src/test/java/org/springframework/security/ldap/userdetails/UserDetailsServiceLdapAuthoritiesPopulatorTests.java b/ldap/src/test/java/org/springframework/security/ldap/userdetails/UserDetailsServiceLdapAuthoritiesPopulatorTests.java index 083f91c32d..09ea5382e6 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/userdetails/UserDetailsServiceLdapAuthoritiesPopulatorTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/userdetails/UserDetailsServiceLdapAuthoritiesPopulatorTests.java @@ -13,15 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.ldap.userdetails; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.ldap.userdetails; import java.util.Collection; import java.util.List; import org.junit.Test; + import org.springframework.ldap.core.DirContextAdapter; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; @@ -29,6 +28,10 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.ldap.authentication.UserDetailsServiceLdapAuthoritiesPopulator; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** * @author Luke Taylor */ @@ -38,16 +41,13 @@ public class UserDetailsServiceLdapAuthoritiesPopulatorTests { public void delegationToUserDetailsServiceReturnsCorrectRoles() { UserDetailsService uds = mock(UserDetailsService.class); UserDetails user = mock(UserDetails.class); - when(uds.loadUserByUsername("joe")).thenReturn(user); + given(uds.loadUserByUsername("joe")).willReturn(user); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(user.getAuthorities()).thenReturn(authorities); - - UserDetailsServiceLdapAuthoritiesPopulator populator = new UserDetailsServiceLdapAuthoritiesPopulator( - uds); - Collection auths = populator.getGrantedAuthorities( - new DirContextAdapter(), "joe"); - + given(user.getAuthorities()).willReturn(authorities); + UserDetailsServiceLdapAuthoritiesPopulator populator = new UserDetailsServiceLdapAuthoritiesPopulator(uds); + Collection auths = populator.getGrantedAuthorities(new DirContextAdapter(), "joe"); assertThat(auths).hasSize(1); assertThat(AuthorityUtils.authorityListToSet(auths).contains("ROLE_USER")).isTrue(); } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandler.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandler.java index dee415443a..4a0e6e8ffa 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandler.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; import org.springframework.messaging.Message; @@ -29,22 +30,19 @@ import org.springframework.util.Assert; * {@link MessageSecurityExpressionRoot}. * * @param the type for the body of the Message - * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ -public class DefaultMessageSecurityExpressionHandler extends - AbstractSecurityExpressionHandler> { +public class DefaultMessageSecurityExpressionHandler extends AbstractSecurityExpressionHandler> { private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); @Override - protected SecurityExpressionOperations createSecurityExpressionRoot( - Authentication authentication, Message invocation) { - MessageSecurityExpressionRoot root = new MessageSecurityExpressionRoot( - authentication, invocation); + protected SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, + Message invocation) { + MessageSecurityExpressionRoot root = new MessageSecurityExpressionRoot(authentication, invocation); root.setPermissionEvaluator(getPermissionEvaluator()); - root.setTrustResolver(trustResolver); + root.setTrustResolver(this.trustResolver); root.setRoleHierarchy(getRoleHierarchy()); return root; } @@ -53,4 +51,5 @@ public class DefaultMessageSecurityExpressionHandler extends Assert.notNull(trustResolver, "trustResolver cannot be null"); this.trustResolver = trustResolver; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java index 847969bca2..ef7a10f438 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java @@ -13,12 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; import org.springframework.expression.EvaluationContext; -/** - * /** * Allows post processing the {@link EvaluationContext} * @@ -32,15 +31,13 @@ import org.springframework.expression.EvaluationContext; interface EvaluationContextPostProcessor { /** - * Allows post processing of the {@link EvaluationContext}. Implementations - * may return a new instance of {@link EvaluationContext} or modify the - * {@link EvaluationContext} that was passed in. - * - * @param context - * the original {@link EvaluationContext} - * @param invocation - * the security invocation object (i.e. Message) + * Allows post processing of the {@link EvaluationContext}. Implementations may return + * a new instance of {@link EvaluationContext} or modify the {@link EvaluationContext} + * that was passed in. + * @param context the original {@link EvaluationContext} + * @param invocation the security invocation object (i.e. Message) * @return the upated context. */ EvaluationContext postProcess(EvaluationContext context, I invocation); + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java index 69934b64b4..a819ce4cd3 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; import java.util.Arrays; @@ -32,11 +33,14 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher; * A class used to create a {@link MessageSecurityMetadataSource} that uses * {@link MessageMatcher} mapped to Spring Expressions. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { + private ExpressionBasedMessageSecurityMetadataSourceFactory() { + } + /** * Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher} * mapped to Spring Expressions. Each entry is considered in order and only the first @@ -61,7 +65,6 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { * *

        * For a complete listing of expressions see {@link MessageSecurityExpressionRoot} - * * @param matcherToExpression an ordered mapping of {@link MessageMatcher} to Strings * that are turned into an Expression using * {@link DefaultMessageSecurityExpressionHandler#getExpressionParser()} @@ -69,7 +72,8 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { */ public static MessageSecurityMetadataSource createExpressionMessageMetadataSource( LinkedHashMap, String> matcherToExpression) { - return createExpressionMessageMetadataSource(matcherToExpression, new DefaultMessageSecurityExpressionHandler<>()); + return createExpressionMessageMetadataSource(matcherToExpression, + new DefaultMessageSecurityExpressionHandler<>()); } /** @@ -98,7 +102,6 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { *

        * For a complete listing of expressions see {@link MessageSecurityExpressionRoot} *

        - * * @param matcherToExpression an ordered mapping of {@link MessageMatcher} to Strings * that are turned into an Expression using * {@link DefaultMessageSecurityExpressionHandler#getExpressionParser()} @@ -106,21 +109,17 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { * @return the {@link MessageSecurityMetadataSource} to use. Cannot be null. */ public static MessageSecurityMetadataSource createExpressionMessageMetadataSource( - LinkedHashMap, String> matcherToExpression, SecurityExpressionHandler> handler) { - + LinkedHashMap, String> matcherToExpression, + SecurityExpressionHandler> handler) { LinkedHashMap, Collection> matcherToAttrs = new LinkedHashMap<>(); - for (Map.Entry, String> entry : matcherToExpression.entrySet()) { MessageMatcher matcher = entry.getKey(); String rawExpression = entry.getValue(); - Expression expression = handler.getExpressionParser().parseExpression( - rawExpression); + Expression expression = handler.getExpressionParser().parseExpression(rawExpression); ConfigAttribute attribute = new MessageExpressionConfigAttribute(expression, matcher); matcherToAttrs.put(matcher, Arrays.asList(attribute)); } return new DefaultMessageSecurityMetadataSource(matcherToAttrs); } - private ExpressionBasedMessageSecurityMetadataSourceFactory() { - } } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java index 51a599c7f0..e663c4f06a 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; +import java.util.Map; + import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.messaging.Message; @@ -23,8 +26,6 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher; import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; import org.springframework.util.Assert; -import java.util.Map; - /** * Simple expression configuration attribute for use in {@link Message} authorizations. * @@ -34,13 +35,13 @@ import java.util.Map; */ @SuppressWarnings("serial") class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationContextPostProcessor> { - private final Expression authorizeExpression; - private final MessageMatcher matcher; + private final Expression authorizeExpression; + + private final MessageMatcher matcher; /** * Creates a new instance - * * @param authorizeExpression the {@link Expression} to use. Cannot be null * @param matcher the {@link MessageMatcher} used to match the messages. */ @@ -51,28 +52,30 @@ class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationCon this.matcher = matcher; } - Expression getAuthorizeExpression() { - return authorizeExpression; + return this.authorizeExpression; } + @Override public String getAttribute() { return null; } @Override public String toString() { - return authorizeExpression.getExpressionString(); + return this.authorizeExpression.getExpressionString(); } @Override public EvaluationContext postProcess(EvaluationContext ctx, Message message) { - if (matcher instanceof SimpDestinationMessageMatcher) { - final Map variables = ((SimpDestinationMessageMatcher) matcher).extractPathVariables(message); - for (Map.Entry entry : variables.entrySet()){ + if (this.matcher instanceof SimpDestinationMessageMatcher) { + Map variables = ((SimpDestinationMessageMatcher) this.matcher) + .extractPathVariables(message); + for (Map.Entry entry : variables.entrySet()) { ctx.setVariable(entry.getKey(), entry.getValue()); } } return ctx; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java index 0d06873e59..b097df8c1e 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; +import java.util.Collection; + import org.springframework.expression.EvaluationContext; import org.springframework.messaging.Message; import org.springframework.security.access.AccessDecisionVoter; @@ -24,8 +27,6 @@ import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.core.Authentication; import org.springframework.util.Assert; -import java.util.Collection; - /** * Voter which handles {@link Message} authorisation decisions. If a * {@link MessageExpressionConfigAttribute} is found, then its expression is evaluated. If @@ -33,35 +34,29 @@ import java.util.Collection; * If no {@code MessageExpressionConfigAttribute} is found, then {@code ACCESS_ABSTAIN} is * returned. * - * @since 4.0 * @author Rob Winch * @author Daniel Bustamante Ospina + * @since 4.0 */ public class MessageExpressionVoter implements AccessDecisionVoter> { + private SecurityExpressionHandler> expressionHandler = new DefaultMessageSecurityExpressionHandler<>(); - public int vote(Authentication authentication, Message message, - Collection attributes) { - assert authentication != null; - assert message != null; - assert attributes != null; - + @Override + public int vote(Authentication authentication, Message message, Collection attributes) { + Assert.notNull(authentication, "authentication must not be null"); + Assert.notNull(message, "message must not be null"); + Assert.notNull(attributes, "attributes must not be null"); MessageExpressionConfigAttribute attr = findConfigAttribute(attributes); - if (attr == null) { return ACCESS_ABSTAIN; } - - EvaluationContext ctx = expressionHandler.createEvaluationContext(authentication, - message); + EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message); ctx = attr.postProcess(ctx, message); - - return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED - : ACCESS_DENIED; + return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED; } - private MessageExpressionConfigAttribute findConfigAttribute( - Collection attributes) { + private MessageExpressionConfigAttribute findConfigAttribute(Collection attributes) { for (ConfigAttribute attribute : attributes) { if (attribute instanceof MessageExpressionConfigAttribute) { return (MessageExpressionConfigAttribute) attribute; @@ -70,17 +65,19 @@ public class MessageExpressionVoter implements AccessDecisionVoter return null; } + @Override public boolean supports(ConfigAttribute attribute) { return attribute instanceof MessageExpressionConfigAttribute; } + @Override public boolean supports(Class clazz) { return Message.class.isAssignableFrom(clazz); } - public void setExpressionHandler( - SecurityExpressionHandler> expressionHandler) { + public void setExpressionHandler(SecurityExpressionHandler> expressionHandler) { Assert.notNull(expressionHandler, "expressionHandler cannot be null"); this.expressionHandler = expressionHandler; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageSecurityExpressionRoot.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageSecurityExpressionRoot.java index 2a7580f778..710fbeb154 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageSecurityExpressionRoot.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageSecurityExpressionRoot.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; import org.springframework.messaging.Message; @@ -22,8 +23,8 @@ import org.springframework.security.core.Authentication; /** * The {@link SecurityExpressionRoot} used for {@link Message} expressions. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public class MessageSecurityExpressionRoot extends SecurityExpressionRoot { @@ -33,4 +34,5 @@ public class MessageSecurityExpressionRoot extends SecurityExpressionRoot { super(authentication); this.message = message; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptor.java index 8efd76b741..9fe9f7117f 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.intercept; import org.springframework.messaging.Message; @@ -33,18 +34,17 @@ import org.springframework.util.Assert; *

        * Refer to {@link AbstractSecurityInterceptor} for details on the workflow. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ -public final class ChannelSecurityInterceptor extends AbstractSecurityInterceptor - implements ChannelInterceptor { +public final class ChannelSecurityInterceptor extends AbstractSecurityInterceptor implements ChannelInterceptor { + private static final ThreadLocal tokenHolder = new ThreadLocal<>(); private final MessageSecurityMetadataSource metadataSource; /** * Creates a new instance - * * @param metadataSource the MessageSecurityMetadataSource to use. Cannot be null. * * @see DefaultMessageSecurityMetadataSource @@ -62,9 +62,10 @@ public final class ChannelSecurityInterceptor extends AbstractSecurityIntercepto @Override public SecurityMetadataSource obtainSecurityMetadataSource() { - return metadataSource; + return this.metadataSource; } + @Override public Message preSend(Message message, MessageChannel channel) { InterceptorStatusToken token = beforeInvocation(message); if (token != null) { @@ -73,27 +74,30 @@ public final class ChannelSecurityInterceptor extends AbstractSecurityIntercepto return message; } + @Override public void postSend(Message message, MessageChannel channel, boolean sent) { InterceptorStatusToken token = clearToken(); afterInvocation(token, null); } - public void afterSendCompletion(Message message, MessageChannel channel, - boolean sent, Exception ex) { + @Override + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { InterceptorStatusToken token = clearToken(); finallyInvocation(token); } + @Override public boolean preReceive(MessageChannel channel) { return true; } + @Override public Message postReceive(Message message, MessageChannel channel) { return message; } - public void afterReceiveCompletion(Message message, MessageChannel channel, - Exception ex) { + @Override + public void afterReceiveCompletion(Message message, MessageChannel channel, Exception ex) { } private InterceptorStatusToken clearToken() { @@ -101,4 +105,5 @@ public final class ChannelSecurityInterceptor extends AbstractSecurityIntercepto tokenHolder.remove(); return token; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java index 13c661a0c1..6e3eb8ba41 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java @@ -13,15 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.intercept; +import java.util.Collection; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; + import org.springframework.messaging.Message; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory; import org.springframework.security.messaging.util.matcher.MessageMatcher; -import java.util.*; - /** * A default implementation of {@link MessageSecurityMetadataSource} that looks up the * {@link ConfigAttribute} instances using a {@link MessageMatcher}. @@ -31,14 +36,13 @@ import java.util.*; * {@code Collection} is returned. *

        * + * @author Rob Winch + * @since 4.0 * @see ChannelSecurityInterceptor * @see ExpressionBasedMessageSecurityMetadataSourceFactory - * - * @since 4.0 - * @author Rob Winch */ -public final class DefaultMessageSecurityMetadataSource implements - MessageSecurityMetadataSource { +public final class DefaultMessageSecurityMetadataSource implements MessageSecurityMetadataSource { + private final Map, Collection> messageMap; public DefaultMessageSecurityMetadataSource( @@ -46,12 +50,11 @@ public final class DefaultMessageSecurityMetadataSource implements this.messageMap = messageMap; } + @Override @SuppressWarnings({ "rawtypes", "unchecked" }) - public Collection getAttributes(Object object) - throws IllegalArgumentException { + public Collection getAttributes(Object object) throws IllegalArgumentException { final Message message = (Message) object; - for (Map.Entry, Collection> entry : messageMap - .entrySet()) { + for (Map.Entry, Collection> entry : this.messageMap.entrySet()) { if (entry.getKey().matches(message)) { return entry.getValue(); } @@ -59,17 +62,18 @@ public final class DefaultMessageSecurityMetadataSource implements return null; } + @Override public Collection getAllConfigAttributes() { Set allAttributes = new HashSet<>(); - - for (Collection entry : messageMap.values()) { + for (Collection entry : this.messageMap.values()) { allAttributes.addAll(entry); } - return allAttributes; } + @Override public boolean supports(Class clazz) { return Message.class.isAssignableFrom(clazz); } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageSecurityMetadataSource.java b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageSecurityMetadataSource.java index fefca0620d..acf6565c45 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageSecurityMetadataSource.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageSecurityMetadataSource.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.intercept; import org.springframework.messaging.Message; @@ -21,11 +22,11 @@ import org.springframework.security.access.SecurityMetadataSource; /** * A {@link SecurityMetadataSource} that is used for securing {@link Message} * + * @author Rob Winch + * @since 4.0 * @see ChannelSecurityInterceptor * @see DefaultMessageSecurityMetadataSource - * - * @since 4.0 - * @author Rob Winch */ public interface MessageSecurityMetadataSource extends SecurityMetadataSource { + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java index faebd3d2d9..58cd4f720e 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.context; import java.lang.annotation.Annotation; @@ -81,86 +82,60 @@ import org.springframework.util.StringUtils; * @author Rob Winch * @since 4.0 */ -public final class AuthenticationPrincipalArgumentResolver - implements HandlerMethodArgumentResolver { +public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { private ExpressionParser parser = new SpelExpressionParser(); - /* - * (non-Javadoc) - * - * @see - * org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver# - * supportsParameter(org.springframework.core.MethodParameter) - */ + @Override public boolean supportsParameter(MethodParameter parameter) { return findMethodAnnotation(AuthenticationPrincipal.class, parameter) != null; } - /* - * (non-Javadoc) - * - * @see - * org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver# - * resolveArgument(org.springframework.core.MethodParameter, - * org.springframework.messaging.Message) - */ + @Override public Object resolveArgument(MethodParameter parameter, Message message) { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); if (authentication == null) { return null; } Object principal = authentication.getPrincipal(); - - AuthenticationPrincipal authPrincipal = findMethodAnnotation( - AuthenticationPrincipal.class, parameter); - + AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); String expressionToParse = authPrincipal.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(principal); context.setVariable("this", principal); - Expression expression = this.parser.parseExpression(expressionToParse); principal = expression.getValue(context); } - - if (principal != null - && !parameter.getParameterType().isAssignableFrom(principal.getClass())) { + if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) { if (authPrincipal.errorOnInvalidType()) { - throw new ClassCastException(principal + " is not assignable to " - + parameter.getParameterType()); - } - else { - return null; + throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } + return null; } return principal; } /** * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. - * * @param annotationClass the class of the {@link Annotation} to find on the * {@link MethodParameter} * @param parameter the {@link MethodParameter} to search for an {@link Annotation} * @return the {@link Annotation} that was found or null. */ - private T findMethodAnnotation(Class annotationClass, - MethodParameter parameter) { + private T findMethodAnnotation(Class annotationClass, MethodParameter parameter) { T annotation = parameter.getParameterAnnotation(annotationClass); if (annotation != null) { return annotation; } Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); for (Annotation toSearch : annotationsToSearch) { - annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), - annotationClass); + annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), annotationClass); if (annotation != null) { return annotation; } } return null; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java index 50f5113dd9..594cfcacba 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.context; import java.util.Stack; @@ -36,19 +37,20 @@ import org.springframework.util.Assert; * {@link Authentication} from the specified {@link Message#getHeaders()}. *

        * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter implements ExecutorChannelInterceptor { - private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder - .createEmptyContext(); - private static final ThreadLocal> ORIGINAL_CONTEXT = new ThreadLocal<>(); + + private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); + + private static final ThreadLocal> originalContext = new ThreadLocal<>(); private final String authenticationHeaderName; - private Authentication anonymous = new AnonymousAuthenticationToken("key", - "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); /** * Creates a new instance using the header of the name @@ -61,13 +63,11 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA /** * Creates a new instance that uses the specified header to obtain the * {@link Authentication}. - * * @param authenticationHeaderName the header name to obtain the * {@link Authentication}. Cannot be null. */ public SecurityContextChannelInterceptor(String authenticationHeaderName) { - Assert.notNull(authenticationHeaderName, - "authenticationHeaderName cannot be null"); + Assert.notNull(authenticationHeaderName, "authenticationHeaderName cannot be null"); this.authenticationHeaderName = authenticationHeaderName; } @@ -78,7 +78,6 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA * new AnonymousAuthenticationToken("key", "anonymous", * AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); * - * * @param authentication the Authentication used for anonymous authentication. Cannot * be null. */ @@ -94,68 +93,63 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA } @Override - public void afterSendCompletion(Message message, MessageChannel channel, - boolean sent, Exception ex) { + public void afterSendCompletion(Message message, MessageChannel channel, boolean sent, Exception ex) { cleanup(); } - public Message beforeHandle(Message message, MessageChannel channel, - MessageHandler handler) { + @Override + public Message beforeHandle(Message message, MessageChannel channel, MessageHandler handler) { setup(message); return message; } - public void afterMessageHandled(Message message, MessageChannel channel, - MessageHandler handler, Exception ex) { + @Override + public void afterMessageHandled(Message message, MessageChannel channel, MessageHandler handler, Exception ex) { cleanup(); } private void setup(Message message) { SecurityContext currentContext = SecurityContextHolder.getContext(); - - Stack contextStack = ORIGINAL_CONTEXT.get(); + Stack contextStack = originalContext.get(); if (contextStack == null) { contextStack = new Stack<>(); - ORIGINAL_CONTEXT.set(contextStack); + originalContext.set(contextStack); } contextStack.push(currentContext); - - Object user = message.getHeaders().get(authenticationHeaderName); - - Authentication authentication; - if ((user instanceof Authentication)) { - authentication = (Authentication) user; - } - else { - authentication = this.anonymous; - } + Object user = message.getHeaders().get(this.authenticationHeaderName); + Authentication authentication = getAuthentication(user); SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); } - private void cleanup() { - Stack contextStack = ORIGINAL_CONTEXT.get(); + private Authentication getAuthentication(Object user) { + if ((user instanceof Authentication)) { + return (Authentication) user; + } + return this.anonymous; + } + private void cleanup() { + Stack contextStack = originalContext.get(); if (contextStack == null || contextStack.isEmpty()) { SecurityContextHolder.clearContext(); - ORIGINAL_CONTEXT.remove(); + originalContext.remove(); return; } - - SecurityContext originalContext = contextStack.pop(); - + SecurityContext context = contextStack.pop(); try { - if (EMPTY_CONTEXT.equals(originalContext)) { + if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) { SecurityContextHolder.clearContext(); - ORIGINAL_CONTEXT.remove(); + originalContext.remove(); } else { - SecurityContextHolder.setContext(originalContext); + SecurityContextHolder.setContext(context); } } - catch (Throwable t) { + catch (Throwable ex) { SecurityContextHolder.clearContext(); } } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java index c34fb576ac..2c83e97902 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java @@ -16,7 +16,11 @@ package org.springframework.security.messaging.handler.invocation.reactive; +import java.lang.annotation.Annotation; + import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; @@ -36,9 +40,6 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; - -import java.lang.annotation.Annotation; /** * Allows resolving the {@link Authentication#getPrincipal()} using the @@ -86,18 +87,17 @@ import java.lang.annotation.Annotation; * } * } * + * * @author Rob Winch * @since 5.2 */ -public class AuthenticationPrincipalArgumentResolver - implements HandlerMethodArgumentResolver { +public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { private ExpressionParser parser = new SpelExpressionParser(); private BeanResolver beanResolver; - private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry - .getSharedInstance(); + private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance(); /** * Sets the {@link BeanResolver} to be used on the expressions @@ -109,8 +109,8 @@ public class AuthenticationPrincipalArgumentResolver /** * Sets the {@link ReactiveAdapterRegistry} to be used. - * @param adapterRegistry the {@link ReactiveAdapterRegistry} to use. Cannot be null. Default is - * {@link ReactiveAdapterRegistry#getSharedInstance()} + * @param adapterRegistry the {@link ReactiveAdapterRegistry} to use. Cannot be null. + * Default is {@link ReactiveAdapterRegistry#getSharedInstance()} */ public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) { Assert.notNull(adapterRegistry, "adapterRegistry cannot be null"); @@ -122,46 +122,37 @@ public class AuthenticationPrincipalArgumentResolver return findMethodAnnotation(AuthenticationPrincipal.class, parameter) != null; } + @Override public Mono resolveArgument(MethodParameter parameter, Message message) { - ReactiveAdapter adapter = this.adapterRegistry - .getAdapter(parameter.getParameterType()); + ReactiveAdapter adapter = this.adapterRegistry.getAdapter(parameter.getParameterType()); + // @formatter:off return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication).flatMap(a -> { + .map(SecurityContext::getAuthentication) + .flatMap((a) -> { Object p = resolvePrincipal(parameter, a.getPrincipal()); Mono principal = Mono.justOrEmpty(p); - return adapter == null ? - principal : - Mono.just(adapter.fromPublisher(principal)); + return (adapter != null) ? Mono.just(adapter.fromPublisher(principal)) : principal; }); + // @formatter:on } private Object resolvePrincipal(MethodParameter parameter, Object principal) { - AuthenticationPrincipal authPrincipal = findMethodAnnotation( - AuthenticationPrincipal.class, parameter); - + AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); String expressionToParse = authPrincipal.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(principal); context.setVariable("this", principal); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); principal = expression.getValue(context); } - if (isInvalidType(parameter, principal)) { - if (authPrincipal.errorOnInvalidType()) { - throw new ClassCastException( - principal + " is not assignable to " + parameter - .getParameterType()); - } - else { - return null; + throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } + return null; } - return principal; } @@ -170,8 +161,7 @@ public class AuthenticationPrincipalArgumentResolver return false; } Class typeToCheck = parameter.getParameterType(); - boolean isParameterPublisher = Publisher.class - .isAssignableFrom(parameter.getParameterType()); + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); if (isParameterPublisher) { ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); Class genericType = resolvableType.resolveGeneric(0); @@ -185,26 +175,24 @@ public class AuthenticationPrincipalArgumentResolver /** * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. - * * @param annotationClass the class of the {@link Annotation} to find on the - * {@link MethodParameter} - * @param parameter the {@link MethodParameter} to search for an {@link Annotation} + * {@link MethodParameter} + * @param parameter the {@link MethodParameter} to search for an {@link Annotation} * @return the {@link Annotation} that was found or null. */ - private T findMethodAnnotation(Class annotationClass, - MethodParameter parameter) { + private T findMethodAnnotation(Class annotationClass, MethodParameter parameter) { T annotation = parameter.getParameterAnnotation(annotationClass); if (annotation != null) { return annotation; } Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); for (Annotation toSearch : annotationsToSearch) { - annotation = AnnotationUtils - .findAnnotation(toSearch.annotationType(), annotationClass); + annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), annotationClass); if (annotation != null) { return annotation; } } return null; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java index f771a7d03a..3b67e85a43 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java @@ -16,7 +16,11 @@ package org.springframework.security.messaging.handler.invocation.reactive; +import java.lang.annotation.Annotation; + import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; @@ -36,9 +40,6 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; - -import java.lang.annotation.Annotation; /** * Allows resolving the {@link Authentication#getPrincipal()} using the @@ -56,9 +57,9 @@ import java.lang.annotation.Annotation; * * *

        - * Will resolve the SecurityContext argument using the {@link ReactiveSecurityContextHolder}. - * If the {@link SecurityContext} is empty, it will return null. If the types do not - * match, null will be returned unless + * Will resolve the SecurityContext argument using the + * {@link ReactiveSecurityContextHolder}. If the {@link SecurityContext} is empty, it will + * return null. If the types do not match, null will be returned unless * {@link CurrentSecurityContext#errorOnInvalidType()} is true in which case a * {@link ClassCastException} will be thrown. * @@ -85,18 +86,17 @@ import java.lang.annotation.Annotation; * } * } * + * * @author Rob Winch * @since 5.2 */ -public class CurrentSecurityContextArgumentResolver - implements HandlerMethodArgumentResolver { +public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgumentResolver { private ExpressionParser parser = new SpelExpressionParser(); private BeanResolver beanResolver; - private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry - .getSharedInstance(); + private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance(); /** * Sets the {@link BeanResolver} to be used on the expressions @@ -108,8 +108,8 @@ public class CurrentSecurityContextArgumentResolver /** * Sets the {@link ReactiveAdapterRegistry} to be used. - * @param adapterRegistry the {@link ReactiveAdapterRegistry} to use. Cannot be null. Default is - * {@link ReactiveAdapterRegistry#getSharedInstance()} + * @param adapterRegistry the {@link ReactiveAdapterRegistry} to use. Cannot be null. + * Default is {@link ReactiveAdapterRegistry#getSharedInstance()} */ public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) { Assert.notNull(adapterRegistry, "adapterRegistry cannot be null"); @@ -121,46 +121,36 @@ public class CurrentSecurityContextArgumentResolver return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; } + @Override public Mono resolveArgument(MethodParameter parameter, Message message) { - ReactiveAdapter adapter = this.adapterRegistry - .getAdapter(parameter.getParameterType()); + ReactiveAdapter adapter = this.adapterRegistry.getAdapter(parameter.getParameterType()); + // @formatter:off return ReactiveSecurityContextHolder.getContext() - .flatMap(securityContext -> { + .flatMap((securityContext) -> { Object sc = resolveSecurityContext(parameter, securityContext); Mono result = Mono.justOrEmpty(sc); - return adapter == null ? - result : - Mono.just(adapter.fromPublisher(result)); + return (adapter != null) ? Mono.just(adapter.fromPublisher(result)) : result; }); + // @formatter:on } private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) { - CurrentSecurityContext contextAnno = findMethodAnnotation( - CurrentSecurityContext.class, parameter); - + CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter); String expressionToParse = contextAnno.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(securityContext); context.setVariable("this", securityContext); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); securityContext = expression.getValue(context); } - if (isInvalidType(parameter, securityContext)) { - if (contextAnno.errorOnInvalidType()) { - throw new ClassCastException( - securityContext + " is not assignable to " + parameter - .getParameterType()); - } - else { - return null; + throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType()); } + return null; } - return securityContext; } @@ -169,8 +159,7 @@ public class CurrentSecurityContextArgumentResolver return false; } Class typeToCheck = parameter.getParameterType(); - boolean isParameterPublisher = Publisher.class - .isAssignableFrom(parameter.getParameterType()); + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); if (isParameterPublisher) { ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); Class genericType = resolvableType.resolveGeneric(0); @@ -184,26 +173,24 @@ public class CurrentSecurityContextArgumentResolver /** * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. - * * @param annotationClass the class of the {@link Annotation} to find on the - * {@link MethodParameter} - * @param parameter the {@link MethodParameter} to search for an {@link Annotation} + * {@link MethodParameter} + * @param parameter the {@link MethodParameter} to search for an {@link Annotation} * @return the {@link Annotation} that was found or null. */ - private T findMethodAnnotation(Class annotationClass, - MethodParameter parameter) { + private T findMethodAnnotation(Class annotationClass, MethodParameter parameter) { T annotation = parameter.getParameterAnnotation(annotationClass); if (annotation != null) { return annotation; } Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); for (Annotation toSearch : annotationsToSearch) { - annotation = AnnotationUtils - .findAnnotation(toSearch.annotationType(), annotationClass); + annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), annotationClass); if (annotation != null) { return annotation; } } return null; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java index 7f13a639ba..bf899ba1df 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java @@ -13,57 +13,61 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.util.matcher; -import static java.util.Arrays.asList; -import static org.apache.commons.logging.LogFactory.getLog; -import static org.springframework.util.Assert.notEmpty; - +import java.util.Arrays; import java.util.List; import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.util.Assert; /** * Abstract {@link MessageMatcher} containing multiple {@link MessageMatcher} * * @since 4.0 */ -abstract class AbstractMessageMatcherComposite implements MessageMatcher { - protected final Log LOGGER = getLog(getClass()); +public abstract class AbstractMessageMatcherComposite implements MessageMatcher { + + protected final Log logger = LogFactory.getLog(getClass()); + + /** + * @deprecated since 5.4 in favor of {@link #logger} + */ + @Deprecated + protected final Log LOGGER = this.logger; private final List> messageMatchers; /** * Creates a new instance - * * @param messageMatchers the {@link MessageMatcher} instances to try */ AbstractMessageMatcherComposite(List> messageMatchers) { - notEmpty(messageMatchers, "messageMatchers must contain a value"); - if (messageMatchers.contains(null)) { - throw new IllegalArgumentException( - "messageMatchers cannot contain null values"); - } + Assert.notEmpty(messageMatchers, "messageMatchers must contain a value"); + Assert.isTrue(!messageMatchers.contains(null), "messageMatchers cannot contain null values"); this.messageMatchers = messageMatchers; } /** * Creates a new instance - * * @param messageMatchers the {@link MessageMatcher} instances to try */ @SafeVarargs AbstractMessageMatcherComposite(MessageMatcher... messageMatchers) { - this(asList(messageMatchers)); + this(Arrays.asList(messageMatchers)); } public List> getMessageMatchers() { - return messageMatchers; + return this.messageMatchers; } @Override public String toString() { - return getClass().getSimpleName() + "[messageMatchers=" + messageMatchers + "]"; + return getClass().getSimpleName() + "[messageMatchers=" + this.messageMatchers + "]"; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java index 03cbc3c98c..6edc0c4ef4 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.util.matcher; import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.messaging.Message; /** @@ -26,9 +28,9 @@ import org.springframework.messaging.Message; * @since 4.0 */ public final class AndMessageMatcher extends AbstractMessageMatcherComposite { + /** * Creates a new instance - * * @param messageMatchers the {@link MessageMatcher} instances to try */ public AndMessageMatcher(List> messageMatchers) { @@ -37,7 +39,6 @@ public final class AndMessageMatcher extends AbstractMessageMatcherComposite< /** * Creates a new instance - * * @param messageMatchers the {@link MessageMatcher} instances to try */ @SafeVarargs @@ -46,17 +47,17 @@ public final class AndMessageMatcher extends AbstractMessageMatcherComposite< } + @Override public boolean matches(Message message) { for (MessageMatcher matcher : getMessageMatchers()) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Trying to match using " + matcher); - } + this.logger.debug(LogMessage.format("Trying to match using %s", matcher)); if (!matcher.matches(message)) { - LOGGER.debug("Did not match"); + this.logger.debug("Did not match"); return false; } } - LOGGER.debug("All messageMatchers returned true"); + this.logger.debug("All messageMatchers returned true"); return true; } -} \ No newline at end of file + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java index 41e2a36229..ffafb72a6a 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.util.matcher; import org.springframework.messaging.Message; @@ -20,22 +21,16 @@ import org.springframework.messaging.Message; /** * API for determining if a {@link Message} should be matched on. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public interface MessageMatcher { - /** - * Returns true if the {@link Message} matches, else false - * @param message the {@link Message} to match on - * @return true if the {@link Message} matches, else false - */ - boolean matches(Message message); - /** * Matches every {@link Message} */ MessageMatcher ANY_MESSAGE = new MessageMatcher() { + @Override public boolean matches(Message message) { return true; @@ -45,5 +40,14 @@ public interface MessageMatcher { public String toString() { return "ANY_MESSAGE"; } + }; + + /** + * Returns true if the {@link Message} matches, else false + * @param message the {@link Message} to match on + * @return true if the {@link Message} matches, else false + */ + boolean matches(Message message); + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java index b60a8d3d4d..010fe7aecf 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.util.matcher; import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.messaging.Message; /** @@ -26,9 +28,9 @@ import org.springframework.messaging.Message; * @since 4.0 */ public final class OrMessageMatcher extends AbstractMessageMatcherComposite { + /** * Creates a new instance - * * @param messageMatchers the {@link MessageMatcher} instances to try */ public OrMessageMatcher(List> messageMatchers) { @@ -37,7 +39,6 @@ public final class OrMessageMatcher extends AbstractMessageMatcherComposite extends AbstractMessageMatcherComposite message) { for (MessageMatcher matcher : getMessageMatchers()) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Trying to match using " + matcher); - } + this.logger.debug(LogMessage.format("Trying to match using %s", matcher)); if (matcher.matches(message)) { - LOGGER.debug("matched"); + this.logger.debug("matched"); return true; } } - LOGGER.debug("No matches found"); + this.logger.debug("No matches found"); return false; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java index 0db50c1982..d4ae0e15d6 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.util.matcher; +import java.util.Collections; +import java.util.Map; + import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; @@ -22,9 +26,6 @@ import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; import org.springframework.util.PathMatcher; -import java.util.Collections; -import java.util.Map; - /** *

        * MessageMatcher which compares a pre-defined pattern against the destination of a @@ -32,13 +33,13 @@ import java.util.Map; * {@link SimpMessageType}. *

        * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public final class SimpDestinationMessageMatcher implements MessageMatcher { - public static final MessageMatcher NULL_DESTINATION_MATCHER = message -> { - String destination = SimpMessageHeaderAccessor.getDestination(message - .getHeaders()); + + public static final MessageMatcher NULL_DESTINATION_MATCHER = (message) -> { + String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); return destination == null; }; @@ -49,6 +50,7 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher messageTypeMatcher; + private final String pattern; /** @@ -77,7 +79,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher{@code com/**/test} - matches all destinations ending with {@code test} * underneath the {@code com} path * - * * @param pattern the pattern to use */ public SimpDestinationMessageMatcher(String pattern) { @@ -87,7 +88,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher * Creates a new instance with the specified pattern and {@link PathMatcher}. - * * @param pattern the pattern to use * @param pathMatcher the {@link PathMatcher} to use. */ @@ -99,89 +99,70 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher * Creates a new instance with the specified pattern, {@link SimpMessageType}, and * {@link PathMatcher}. - * * @param pattern the pattern to use * @param type the {@link SimpMessageType} to match on or null if any * {@link SimpMessageType} should be matched. * @param pathMatcher the {@link PathMatcher} to use. */ - private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, - PathMatcher pathMatcher) { + private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, PathMatcher pathMatcher) { Assert.notNull(pattern, "pattern cannot be null"); Assert.notNull(pathMatcher, "pathMatcher cannot be null"); - if (!isTypeWithDestination(type)) { - throw new IllegalArgumentException("SimpMessageType " + type - + " does not contain a destination and so cannot be matched on."); - } - + Assert.isTrue(isTypeWithDestination(type), + () -> "SimpMessageType " + type + " does not contain a destination and so cannot be matched on."); this.matcher = pathMatcher; - this.messageTypeMatcher = type == null ? ANY_MESSAGE - : new SimpMessageTypeMatcher(type); + this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE; this.pattern = pattern; } + @Override public boolean matches(Message message) { - if (!messageTypeMatcher.matches(message)) { + if (!this.messageTypeMatcher.matches(message)) { return false; } - - String destination = SimpMessageHeaderAccessor.getDestination(message - .getHeaders()); - return destination != null && matcher.match(pattern, destination); + String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); + return destination != null && this.matcher.match(this.pattern, destination); } - - public Map extractPathVariables(Message message){ - final String destination = SimpMessageHeaderAccessor.getDestination(message - .getHeaders()); - return destination != null ? matcher.extractUriTemplateVariables(pattern, destination) + public Map extractPathVariables(Message message) { + final String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); + return (destination != null) ? this.matcher.extractUriTemplateVariables(this.pattern, destination) : Collections.emptyMap(); } public MessageMatcher getMessageTypeMatcher() { - return messageTypeMatcher; + return this.messageTypeMatcher; } @Override public String toString() { - return "SimpDestinationMessageMatcher [matcher=" + matcher - + ", messageTypeMatcher=" + messageTypeMatcher + ", pattern=" + pattern - + "]"; + return "SimpDestinationMessageMatcher [matcher=" + this.matcher + ", messageTypeMatcher=" + + this.messageTypeMatcher + ", pattern=" + this.pattern + "]"; } private boolean isTypeWithDestination(SimpMessageType type) { - if (type == null) { - return true; - } - return SimpMessageType.MESSAGE.equals(type) - || SimpMessageType.SUBSCRIBE.equals(type); + return type == null || SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type); } /** *

        * Creates a new instance with the specified pattern, * {@code SimpMessageType.SUBSCRIBE}, and {@link PathMatcher}. - * * @param pattern the pattern to use * @param matcher the {@link PathMatcher} to use. */ - public static SimpDestinationMessageMatcher createSubscribeMatcher(String pattern, - PathMatcher matcher) { - return new SimpDestinationMessageMatcher(pattern, SimpMessageType.SUBSCRIBE, - matcher); + public static SimpDestinationMessageMatcher createSubscribeMatcher(String pattern, PathMatcher matcher) { + return new SimpDestinationMessageMatcher(pattern, SimpMessageType.SUBSCRIBE, matcher); } /** *

        * Creates a new instance with the specified pattern, {@code SimpMessageType.MESSAGE}, * and {@link PathMatcher}. - * * @param pattern the pattern to use * @param matcher the {@link PathMatcher} to use. */ - public static SimpDestinationMessageMatcher createMessageMatcher(String pattern, - PathMatcher matcher) { - return new SimpDestinationMessageMatcher(pattern, SimpMessageType.MESSAGE, - matcher); + public static SimpDestinationMessageMatcher createMessageMatcher(String pattern, PathMatcher matcher) { + return new SimpDestinationMessageMatcher(pattern, SimpMessageType.MESSAGE, matcher); } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java index 24a66b8f3e..00aa6e8a64 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.util.matcher; import org.springframework.messaging.Message; @@ -26,16 +27,16 @@ import org.springframework.util.ObjectUtils; * A {@link MessageMatcher} that matches if the provided {@link Message} has a type that * is the same as the {@link SimpMessageType} that was specified in the constructor. * - * @since 4.0 * @author Rob Winch + * @since 4.0 * */ public class SimpMessageTypeMatcher implements MessageMatcher { + private final SimpMessageType typeToMatch; /** * Creates a new instance - * * @param typeToMatch the {@link SimpMessageType} that will result in a match. Cannot * be null. */ @@ -48,8 +49,7 @@ public class SimpMessageTypeMatcher implements MessageMatcher { public boolean matches(Message message) { MessageHeaders headers = message.getHeaders(); SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); - - return typeToMatch == messageType; + return this.typeToMatch == messageType; } @Override @@ -62,7 +62,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher { } SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other; return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch); - } @Override @@ -73,6 +72,7 @@ public class SimpMessageTypeMatcher implements MessageMatcher { @Override public String toString() { - return "SimpMessageTypeMatcher [typeToMatch=" + typeToMatch + "]"; + return "SimpMessageTypeMatcher [typeToMatch=" + this.typeToMatch + "]"; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java index c59858e1a3..059b34bddb 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.web.csrf; import java.util.Map; @@ -37,31 +38,27 @@ import org.springframework.security.web.csrf.MissingCsrfTokenException; * @since 4.0 */ public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter { - private final MessageMatcher matcher = new SimpMessageTypeMatcher( - SimpMessageType.CONNECT); + + private final MessageMatcher matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT); @Override public Message preSend(Message message, MessageChannel channel) { - if (!matcher.matches(message)) { + if (!this.matcher.matches(message)) { return message; } - - Map sessionAttributes = SimpMessageHeaderAccessor - .getSessionAttributes(message.getHeaders()); - CsrfToken expectedToken = sessionAttributes == null ? null - : (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()); - + Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); + CsrfToken expectedToken = (sessionAttributes != null) + ? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null; if (expectedToken == null) { throw new MissingCsrfTokenException(null); } - String actualTokenValue = SimpMessageHeaderAccessor.wrap(message) .getFirstNativeHeader(expectedToken.getHeaderName()); - boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue); - if (csrfCheckPassed) { - return message; + if (!csrfCheckPassed) { + throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); } - throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); + return message; } + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java index 591b87dbb6..aa40975f2f 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.web.socket.server; import java.util.Map; @@ -36,11 +37,10 @@ import org.springframework.web.socket.server.HandshakeInterceptor; */ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor { - public boolean beforeHandshake(ServerHttpRequest request, - ServerHttpResponse response, WebSocketHandler wsHandler, + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) { - HttpServletRequest httpRequest = ((ServletServerHttpRequest) request) - .getServletRequest(); + HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest(); CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName()); if (token == null) { return true; @@ -49,7 +49,9 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor return true; } - public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler wsHandler, Exception exception) { + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + Exception exception) { } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandlerTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandlerTests.java index d5dcdfc2c7..13277bc737 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandlerTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/expression/DefaultMessageSecurityExpressionHandlerTests.java @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.messaging.access.expression; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +package org.springframework.security.messaging.access.expression; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.messaging.Message; @@ -36,10 +35,15 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + @RunWith(MockitoJUnitRunner.class) public class DefaultMessageSecurityExpressionHandlerTests { + @Mock AuthenticationTrustResolver trustResolver; + @Mock PermissionEvaluator permissionEvaluator; @@ -51,65 +55,52 @@ public class DefaultMessageSecurityExpressionHandlerTests { @Before public void setup() { - handler = new DefaultMessageSecurityExpressionHandler<>(); - - message = new GenericMessage<>(""); - authentication = new AnonymousAuthenticationToken("key", "anonymous", + this.handler = new DefaultMessageSecurityExpressionHandler<>(); + this.message = new GenericMessage<>(""); + this.authentication = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); } // SEC-2705 @Test public void trustResolverPopulated() { - EvaluationContext context = handler.createEvaluationContext(authentication, - message); - Expression expression = handler.getExpressionParser().parseExpression( - "authenticated"); - + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.message); + Expression expression = this.handler.getExpressionParser().parseExpression("authenticated"); assertThat(ExpressionUtils.evaluateAsBoolean(expression, context)).isFalse(); } @Test(expected = IllegalArgumentException.class) public void trustResolverNull() { - handler.setTrustResolver(null); + this.handler.setTrustResolver(null); } @Test public void trustResolverCustom() { - handler.setTrustResolver(trustResolver); - EvaluationContext context = handler.createEvaluationContext(authentication, - message); - Expression expression = handler.getExpressionParser().parseExpression( - "authenticated"); - when(trustResolver.isAnonymous(authentication)).thenReturn(false); - + this.handler.setTrustResolver(this.trustResolver); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.message); + Expression expression = this.handler.getExpressionParser().parseExpression("authenticated"); + given(this.trustResolver.isAnonymous(this.authentication)).willReturn(false); assertThat(ExpressionUtils.evaluateAsBoolean(expression, context)).isTrue(); } @Test public void roleHierarchy() { - authentication = new TestingAuthenticationToken("admin", "pass", "ROLE_ADMIN"); + this.authentication = new TestingAuthenticationToken("admin", "pass", "ROLE_ADMIN"); RoleHierarchyImpl roleHierarchy = new RoleHierarchyImpl(); roleHierarchy.setHierarchy("ROLE_ADMIN > ROLE_USER"); - handler.setRoleHierarchy(roleHierarchy); - EvaluationContext context = handler.createEvaluationContext(authentication, - message); - Expression expression = handler.getExpressionParser().parseExpression( - "hasRole('ROLE_USER')"); - + this.handler.setRoleHierarchy(roleHierarchy); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.message); + Expression expression = this.handler.getExpressionParser().parseExpression("hasRole('ROLE_USER')"); assertThat(ExpressionUtils.evaluateAsBoolean(expression, context)).isTrue(); } @Test public void permissionEvaluator() { - handler.setPermissionEvaluator(permissionEvaluator); - EvaluationContext context = handler.createEvaluationContext(authentication, - message); - Expression expression = handler.getExpressionParser().parseExpression( - "hasPermission(message, 'read')"); - when(permissionEvaluator.hasPermission(authentication, message, "read")) - .thenReturn(true); - + this.handler.setPermissionEvaluator(this.permissionEvaluator); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.message); + Expression expression = this.handler.getExpressionParser().parseExpression("hasPermission(message, 'read')"); + given(this.permissionEvaluator.hasPermission(this.authentication, this.message, "read")).willReturn(true); assertThat(ExpressionUtils.evaluateAsBoolean(expression, context)).isTrue(); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java index 3d5f813866..f4a66f8761 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java @@ -13,34 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; -import static org.assertj.core.api.Assertions.assertThat; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory.*; +import java.util.Collection; +import java.util.LinkedHashMap; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.core.Authentication; import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource; import org.springframework.security.messaging.util.matcher.MessageMatcher; -import java.util.Collection; -import java.util.LinkedHashMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; @RunWith(MockitoJUnitRunner.class) public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests { + @Mock MessageMatcher matcher1; + @Mock MessageMatcher matcher2; + @Mock Message message; + @Mock Authentication authentication; @@ -56,49 +61,42 @@ public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests { @Before public void setup() { - expression1 = "permitAll"; - expression2 = "denyAll"; - matcherToExpression = new LinkedHashMap<>(); - matcherToExpression.put(matcher1, expression1); - matcherToExpression.put(matcher2, expression2); - - source = createExpressionMessageMetadataSource(matcherToExpression); - rootObject = new MessageSecurityExpressionRoot(authentication, message); + this.expression1 = "permitAll"; + this.expression2 = "denyAll"; + this.matcherToExpression = new LinkedHashMap<>(); + this.matcherToExpression.put(this.matcher1, this.expression1); + this.matcherToExpression.put(this.matcher2, this.expression2); + this.source = ExpressionBasedMessageSecurityMetadataSourceFactory + .createExpressionMessageMetadataSource(this.matcherToExpression); + this.rootObject = new MessageSecurityExpressionRoot(this.authentication, this.message); } @Test public void createExpressionMessageMetadataSourceNoMatch() { - - Collection attrs = source.getAttributes(message); - + Collection attrs = this.source.getAttributes(this.message); assertThat(attrs).isNull(); } @Test public void createExpressionMessageMetadataSourceMatchFirst() { - when(matcher1.matches(message)).thenReturn(true); - - Collection attrs = source.getAttributes(message); - + given(this.matcher1.matches(this.message)).willReturn(true); + Collection attrs = this.source.getAttributes(this.message); assertThat(attrs).hasSize(1); ConfigAttribute attr = attrs.iterator().next(); assertThat(attr).isInstanceOf(MessageExpressionConfigAttribute.class); - assertThat( - ((MessageExpressionConfigAttribute) attr).getAuthorizeExpression() - .getValue(rootObject)).isEqualTo(true); + assertThat(((MessageExpressionConfigAttribute) attr).getAuthorizeExpression().getValue(this.rootObject)) + .isEqualTo(true); } @Test public void createExpressionMessageMetadataSourceMatchSecond() { - when(matcher2.matches(message)).thenReturn(true); - - Collection attrs = source.getAttributes(message); - + given(this.matcher2.matches(this.message)).willReturn(true); + Collection attrs = this.source.getAttributes(this.message); assertThat(attrs).hasSize(1); ConfigAttribute attr = attrs.iterator().next(); assertThat(attr).isInstanceOf(MessageExpressionConfigAttribute.class); - assertThat( - ((MessageExpressionConfigAttribute) attr).getAuthorizeExpression() - .getValue(rootObject)).isEqualTo(false); + assertThat(((MessageExpressionConfigAttribute) attr).getAuthorizeExpression().getValue(this.rootObject)) + .isEqualTo(false); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttributeTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttributeTests.java index 95d9a9f901..5a12f89ea9 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttributeTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttributeTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; import org.junit.Before; @@ -20,6 +21,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.messaging.Message; @@ -29,10 +31,13 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher; import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @RunWith(MockitoJUnitRunner.class) public class MessageExpressionConfigAttributeTests { + @Mock Expression expression; @@ -43,45 +48,47 @@ public class MessageExpressionConfigAttributeTests { @Before public void setup() { - attribute = new MessageExpressionConfigAttribute(expression, matcher); + this.attribute = new MessageExpressionConfigAttribute(this.expression, this.matcher); } @Test(expected = IllegalArgumentException.class) public void constructorNullExpression() { - new MessageExpressionConfigAttribute(null, matcher); + new MessageExpressionConfigAttribute(null, this.matcher); } @Test(expected = IllegalArgumentException.class) public void constructorNullMatcher() { - new MessageExpressionConfigAttribute(expression, null); + new MessageExpressionConfigAttribute(this.expression, null); } @Test public void getAuthorizeExpression() { - assertThat(attribute.getAuthorizeExpression()).isSameAs(expression); + assertThat(this.attribute.getAuthorizeExpression()).isSameAs(this.expression); } @Test public void getAttribute() { - assertThat(attribute.getAttribute()).isNull(); + assertThat(this.attribute.getAttribute()).isNull(); } @Test public void toStringUsesExpressionString() { - when(expression.getExpressionString()).thenReturn("toString"); - - assertThat(attribute.toString()).isEqualTo(expression.getExpressionString()); + given(this.expression.getExpressionString()).willReturn("toString"); + assertThat(this.attribute.toString()).isEqualTo(this.expression.getExpressionString()); } @Test public void postProcessContext() { SimpDestinationMessageMatcher matcher = new SimpDestinationMessageMatcher("/topics/{topic}/**"); - Message message = MessageBuilder.withPayload("M").setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/topics/someTopic/sub1").build(); + // @formatter:off + Message message = MessageBuilder.withPayload("M") + .setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/topics/someTopic/sub1") + .build(); + // @formatter:on EvaluationContext context = mock(EvaluationContext.class); - - attribute = new MessageExpressionConfigAttribute(expression, matcher); - attribute.postProcess(context, message); - + this.attribute = new MessageExpressionConfigAttribute(this.expression, matcher); + this.attribute.postProcess(context, message); verify(context).setVariable("topic", "someTopic"); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionVoterTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionVoterTests.java index 93104c608a..700e2714cb 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionVoterTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/expression/MessageExpressionVoterTests.java @@ -13,42 +13,55 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.expression; +import java.util.Arrays; +import java.util.Collection; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.messaging.Message; +import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.core.Authentication; import org.springframework.security.messaging.util.matcher.MessageMatcher; -import java.util.Arrays; -import java.util.Collection; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; -import static org.springframework.security.access.AccessDecisionVoter.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @RunWith(MockitoJUnitRunner.class) public class MessageExpressionVoterTests { + @Mock Authentication authentication; + @Mock Message message; + Collection attributes; + @Mock Expression expression; + @Mock MessageMatcher matcher; + @Mock SecurityExpressionHandler expressionHandler; + @Mock EvaluationContext evaluationContext; @@ -56,87 +69,81 @@ public class MessageExpressionVoterTests { @Before public void setup() { - attributes = Arrays - . asList(new MessageExpressionConfigAttribute(expression, matcher)); - - voter = new MessageExpressionVoter(); + this.attributes = Arrays + .asList(new MessageExpressionConfigAttribute(this.expression, this.matcher)); + this.voter = new MessageExpressionVoter(); } @Test public void voteGranted() { - when(expression.getValue(any(EvaluationContext.class), eq(Boolean.class))) - .thenReturn(true); - assertThat(voter.vote(authentication, message, attributes)).isEqualTo( - ACCESS_GRANTED); + given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(true); + assertThat(this.voter.vote(this.authentication, this.message, this.attributes)) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); } @Test public void voteDenied() { - when(expression.getValue(any(EvaluationContext.class), eq(Boolean.class))) - .thenReturn(false); - assertThat(voter.vote(authentication, message, attributes)).isEqualTo( - ACCESS_DENIED); + given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(false); + assertThat(this.voter.vote(this.authentication, this.message, this.attributes)) + .isEqualTo(AccessDecisionVoter.ACCESS_DENIED); } @Test public void voteAbstain() { - attributes = Arrays. asList(new SecurityConfig("ROLE_USER")); - assertThat(voter.vote(authentication, message, attributes)).isEqualTo( - ACCESS_ABSTAIN); + this.attributes = Arrays.asList(new SecurityConfig("ROLE_USER")); + assertThat(this.voter.vote(this.authentication, this.message, this.attributes)) + .isEqualTo(AccessDecisionVoter.ACCESS_ABSTAIN); } @Test public void supportsObjectClassFalse() { - assertThat(voter.supports(Object.class)).isFalse(); + assertThat(this.voter.supports(Object.class)).isFalse(); } @Test public void supportsMessageClassTrue() { - assertThat(voter.supports(Message.class)).isTrue(); + assertThat(this.voter.supports(Message.class)).isTrue(); } @Test public void supportsSecurityConfigFalse() { - assertThat(voter.supports(new SecurityConfig("ROLE_USER"))).isFalse(); + assertThat(this.voter.supports(new SecurityConfig("ROLE_USER"))).isFalse(); } @Test public void supportsMessageExpressionConfigAttributeTrue() { - assertThat(voter.supports(new MessageExpressionConfigAttribute(expression, matcher))) - .isTrue(); + assertThat(this.voter.supports(new MessageExpressionConfigAttribute(this.expression, this.matcher))).isTrue(); } @Test(expected = IllegalArgumentException.class) public void setExpressionHandlerNull() { - voter.setExpressionHandler(null); + this.voter.setExpressionHandler(null); } @Test public void customExpressionHandler() { - voter.setExpressionHandler(expressionHandler); - when(expressionHandler.createEvaluationContext(authentication, message)) - .thenReturn(evaluationContext); - when(expression.getValue(evaluationContext, Boolean.class)).thenReturn(true); - - assertThat(voter.vote(authentication, message, attributes)).isEqualTo( - ACCESS_GRANTED); - - verify(expressionHandler).createEvaluationContext(authentication, message); + this.voter.setExpressionHandler(this.expressionHandler); + given(this.expressionHandler.createEvaluationContext(this.authentication, this.message)) + .willReturn(this.evaluationContext); + given(this.expression.getValue(this.evaluationContext, Boolean.class)).willReturn(true); + assertThat(this.voter.vote(this.authentication, this.message, this.attributes)) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + verify(this.expressionHandler).createEvaluationContext(this.authentication, this.message); } @Test - public void postProcessEvaluationContext(){ + public void postProcessEvaluationContext() { final MessageExpressionConfigAttribute configAttribute = mock(MessageExpressionConfigAttribute.class); - voter.setExpressionHandler(expressionHandler); - when(expressionHandler.createEvaluationContext(authentication, message)).thenReturn(evaluationContext); - when(configAttribute.getAuthorizeExpression()).thenReturn(expression); - attributes = Arrays. asList(configAttribute); - when(configAttribute.postProcess(evaluationContext, message)).thenReturn(evaluationContext); - when(expression.getValue(any(EvaluationContext.class), eq(Boolean.class))) - .thenReturn(true); - - assertThat(voter.vote(authentication, message, attributes)).isEqualTo( - ACCESS_GRANTED); - verify(configAttribute).postProcess(evaluationContext, message); + this.voter.setExpressionHandler(this.expressionHandler); + given(this.expressionHandler.createEvaluationContext(this.authentication, this.message)) + .willReturn(this.evaluationContext); + given(configAttribute.getAuthorizeExpression()).willReturn(this.expression); + this.attributes = Arrays.asList(configAttribute); + given(configAttribute.postProcess(this.evaluationContext, this.message)).willReturn(this.evaluationContext); + given(this.expression.getValue(any(EvaluationContext.class), eq(Boolean.class))).willReturn(true); + assertThat(this.voter.vote(this.authentication, this.message, this.attributes)) + .isEqualTo(AccessDecisionVoter.ACCESS_GRANTED); + verify(configAttribute).postProcess(this.evaluationContext, this.message); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptorTests.java index 991186ea48..94d546c32e 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptorTests.java @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.intercept; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.security.access.AccessDecisionManager; @@ -32,28 +38,30 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; @RunWith(MockitoJUnitRunner.class) public class ChannelSecurityInterceptorTests { + @Mock Message message; + @Mock MessageChannel channel; + @Mock MessageSecurityMetadataSource source; + @Mock AccessDecisionManager accessDecisionManager; + @Mock RunAsManager runAsManager; + @Mock Authentication runAs; @@ -65,13 +73,12 @@ public class ChannelSecurityInterceptorTests { @Before public void setup() { - attrs = Arrays. asList(new SecurityConfig("ROLE_USER")); - interceptor = new ChannelSecurityInterceptor(source); - interceptor.setAccessDecisionManager(accessDecisionManager); - interceptor.setRunAsManager(runAsManager); - - originalAuth = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); - SecurityContextHolder.getContext().setAuthentication(originalAuth); + this.attrs = Arrays.asList(new SecurityConfig("ROLE_USER")); + this.interceptor = new ChannelSecurityInterceptor(this.source); + this.interceptor.setAccessDecisionManager(this.accessDecisionManager); + this.interceptor.setRunAsManager(this.runAsManager); + this.originalAuth = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(this.originalAuth); } @After @@ -86,92 +93,76 @@ public class ChannelSecurityInterceptorTests { @Test public void getSecureObjectClass() { - assertThat(interceptor.getSecureObjectClass()).isEqualTo(Message.class); + assertThat(this.interceptor.getSecureObjectClass()).isEqualTo(Message.class); } @Test public void obtainSecurityMetadataSource() { - assertThat(interceptor.obtainSecurityMetadataSource()).isEqualTo(source); + assertThat(this.interceptor.obtainSecurityMetadataSource()).isEqualTo(this.source); } @Test public void preSendNullAttributes() { - assertThat(interceptor.preSend(message, channel)).isSameAs(message); + assertThat(this.interceptor.preSend(this.message, this.channel)).isSameAs(this.message); } @Test public void preSendGrant() { - when(source.getAttributes(message)).thenReturn(attrs); - - Message result = interceptor.preSend(message, channel); - - assertThat(result).isSameAs(message); + given(this.source.getAttributes(this.message)).willReturn(this.attrs); + Message result = this.interceptor.preSend(this.message, this.channel); + assertThat(result).isSameAs(this.message); } @Test(expected = AccessDeniedException.class) public void preSendDeny() { - when(source.getAttributes(message)).thenReturn(attrs); - doThrow(new AccessDeniedException("")).when(accessDecisionManager).decide( - any(Authentication.class), eq(message), eq(attrs)); - - interceptor.preSend(message, channel); + given(this.source.getAttributes(this.message)).willReturn(this.attrs); + willThrow(new AccessDeniedException("")).given(this.accessDecisionManager).decide(any(Authentication.class), + eq(this.message), eq(this.attrs)); + this.interceptor.preSend(this.message, this.channel); } @SuppressWarnings("unchecked") @Test public void preSendPostSendRunAs() { - when(source.getAttributes(message)).thenReturn(attrs); - when( - runAsManager.buildRunAs(any(Authentication.class), any(), - any(Collection.class))).thenReturn(runAs); - - Message preSend = interceptor.preSend(message, channel); - - assertThat(SecurityContextHolder.getContext().getAuthentication()) - .isSameAs(runAs); - - interceptor.postSend(preSend, channel, true); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - originalAuth); + given(this.source.getAttributes(this.message)).willReturn(this.attrs); + given(this.runAsManager.buildRunAs(any(Authentication.class), any(), any(Collection.class))) + .willReturn(this.runAs); + Message preSend = this.interceptor.preSend(this.message, this.channel); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.runAs); + this.interceptor.postSend(preSend, this.channel, true); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.originalAuth); } @Test public void afterSendCompletionNotTokenMessageNoExceptionThrown() { - interceptor.afterSendCompletion(message, channel, true, null); + this.interceptor.afterSendCompletion(this.message, this.channel, true, null); } @SuppressWarnings("unchecked") @Test public void preSendFinallySendRunAs() { - when(source.getAttributes(message)).thenReturn(attrs); - when( - runAsManager.buildRunAs(any(Authentication.class), any(), - any(Collection.class))).thenReturn(runAs); - - Message preSend = interceptor.preSend(message, channel); - - assertThat(SecurityContextHolder.getContext().getAuthentication()) - .isSameAs(runAs); - - interceptor.afterSendCompletion(preSend, channel, true, new RuntimeException()); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - originalAuth); + given(this.source.getAttributes(this.message)).willReturn(this.attrs); + given(this.runAsManager.buildRunAs(any(Authentication.class), any(), any(Collection.class))) + .willReturn(this.runAs); + Message preSend = this.interceptor.preSend(this.message, this.channel); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.runAs); + this.interceptor.afterSendCompletion(preSend, this.channel, true, new RuntimeException()); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.originalAuth); } @Test public void preReceive() { - assertThat(interceptor.preReceive(channel)).isTrue(); + assertThat(this.interceptor.preReceive(this.channel)).isTrue(); } @Test public void postReceive() { - assertThat(interceptor.postReceive(message, channel)).isSameAs(message); + assertThat(this.interceptor.postReceive(this.message, this.channel)).isSameAs(this.message); } @Test public void afterReceiveCompletionNullExceptionNoExceptionThrown() { - interceptor.afterReceiveCompletion(message, channel, null); + this.interceptor.afterReceiveCompletion(this.message, this.channel, null); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSourceTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSourceTests.java index 6702abe267..4bf00db940 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSourceTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSourceTests.java @@ -13,34 +13,40 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.access.intercept; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedHashMap; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.SecurityConfig; import org.springframework.security.core.Authentication; import org.springframework.security.messaging.util.matcher.MessageMatcher; -import java.util.Arrays; -import java.util.Collection; -import java.util.LinkedHashMap; - import static org.assertj.core.api.Assertions.assertThat; -import static org.powermock.api.mockito.PowerMockito.when; +import static org.mockito.BDDMockito.given; @RunWith(MockitoJUnitRunner.class) public class DefaultMessageSecurityMetadataSourceTests { + @Mock MessageMatcher matcher1; + @Mock MessageMatcher matcher2; + @Mock Message message; + @Mock Authentication authentication; @@ -54,44 +60,42 @@ public class DefaultMessageSecurityMetadataSourceTests { @Before public void setup() { - messageMap = new LinkedHashMap<>(); - messageMap.put(matcher1, Arrays. asList(config1)); - messageMap.put(matcher2, Arrays. asList(config2)); - - source = new DefaultMessageSecurityMetadataSource(messageMap); + this.messageMap = new LinkedHashMap<>(); + this.messageMap.put(this.matcher1, Arrays.asList(this.config1)); + this.messageMap.put(this.matcher2, Arrays.asList(this.config2)); + this.source = new DefaultMessageSecurityMetadataSource(this.messageMap); } @Test public void getAttributesNull() { - assertThat(source.getAttributes(message)).isNull(); + assertThat(this.source.getAttributes(this.message)).isNull(); } @Test public void getAttributesFirst() { - when(matcher1.matches(message)).thenReturn(true); - - assertThat(source.getAttributes(message)).containsOnly(config1); + given(this.matcher1.matches(this.message)).willReturn(true); + assertThat(this.source.getAttributes(this.message)).containsOnly(this.config1); } @Test public void getAttributesSecond() { - when(matcher1.matches(message)).thenReturn(true); - - assertThat(source.getAttributes(message)).containsOnly(config2); + given(this.matcher1.matches(this.message)).willReturn(true); + assertThat(this.source.getAttributes(this.message)).containsOnly(this.config2); } @Test public void getAllConfigAttributes() { - assertThat(source.getAllConfigAttributes()).containsOnly(config1, config2); + assertThat(this.source.getAllConfigAttributes()).containsOnly(this.config1, this.config2); } @Test public void supportsFalse() { - assertThat(source.supports(Object.class)).isFalse(); + assertThat(this.source.supports(Object.class)).isFalse(); } @Test public void supportsTrue() { - assertThat(source.supports(Message.class)).isTrue(); + assertThat(this.source.supports(Message.class)).isTrue(); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolverTests.java index 55e31111b1..fe3e6c2323 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolverTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolverTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.messaging.context; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.messaging.context; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -26,6 +25,7 @@ import java.lang.reflect.Method; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.core.MethodParameter; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.annotation.AuthenticationPrincipal; @@ -35,17 +35,21 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.util.ReflectionUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Rob Winch * */ public class AuthenticationPrincipalArgumentResolverTests { + private Object expectedPrincipal; + private AuthenticationPrincipalArgumentResolver resolver; @Before public void setup() { - resolver = new AuthenticationPrincipalArgumentResolver(); + this.resolver = new AuthenticationPrincipalArgumentResolver(); } @After @@ -55,65 +59,60 @@ public class AuthenticationPrincipalArgumentResolverTests { @Test public void supportsParameterNoAnnotation() { - assertThat(resolver.supportsParameter(showUserNoAnnotation())).isFalse(); + assertThat(this.resolver.supportsParameter(showUserNoAnnotation())).isFalse(); } @Test public void supportsParameterAnnotation() { - assertThat(resolver.supportsParameter(showUserAnnotationObject())).isTrue(); + assertThat(this.resolver.supportsParameter(showUserAnnotationObject())).isTrue(); } @Test public void supportsParameterCustomAnnotation() { - assertThat(resolver.supportsParameter(showUserCustomAnnotation())).isTrue(); + assertThat(this.resolver.supportsParameter(showUserCustomAnnotation())).isTrue(); } @Test public void resolveArgumentNullAuthentication() throws Exception { - assertThat(resolver.resolveArgument(showUserAnnotationString(), null)).isNull(); + assertThat(this.resolver.resolveArgument(showUserAnnotationString(), null)).isNull(); } @Test public void resolveArgumentNullPrincipal() throws Exception { setAuthenticationPrincipal(null); - assertThat(resolver.resolveArgument(showUserAnnotationString(), null)).isNull(); + assertThat(this.resolver.resolveArgument(showUserAnnotationString(), null)).isNull(); } @Test public void resolveArgumentString() throws Exception { setAuthenticationPrincipal("john"); - assertThat(resolver.resolveArgument(showUserAnnotationString(), null)).isEqualTo( - expectedPrincipal); + assertThat(this.resolver.resolveArgument(showUserAnnotationString(), null)).isEqualTo(this.expectedPrincipal); } @Test public void resolveArgumentPrincipalStringOnObject() throws Exception { setAuthenticationPrincipal("john"); - assertThat(resolver.resolveArgument(showUserAnnotationObject(), null)).isEqualTo( - expectedPrincipal); + assertThat(this.resolver.resolveArgument(showUserAnnotationObject(), null)).isEqualTo(this.expectedPrincipal); } @Test public void resolveArgumentUserDetails() throws Exception { - setAuthenticationPrincipal(new User("user", "password", - AuthorityUtils.createAuthorityList("ROLE_USER"))); - assertThat(resolver.resolveArgument(showUserAnnotationUserDetails(), null)) - .isEqualTo(expectedPrincipal); + setAuthenticationPrincipal(new User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))); + assertThat(this.resolver.resolveArgument(showUserAnnotationUserDetails(), null)) + .isEqualTo(this.expectedPrincipal); } @Test public void resolveArgumentCustomUserPrincipal() throws Exception { setAuthenticationPrincipal(new CustomUserPrincipal()); - assertThat( - resolver.resolveArgument(showUserAnnotationCustomUserPrincipal(), null)) - .isEqualTo(expectedPrincipal); + assertThat(this.resolver.resolveArgument(showUserAnnotationCustomUserPrincipal(), null)) + .isEqualTo(this.expectedPrincipal); } @Test public void resolveArgumentCustomAnnotation() throws Exception { setAuthenticationPrincipal(new CustomUserPrincipal()); - assertThat(resolver.resolveArgument(showUserCustomAnnotation(), null)).isEqualTo( - expectedPrincipal); + assertThat(this.resolver.resolveArgument(showUserCustomAnnotation(), null)).isEqualTo(this.expectedPrincipal); } @Test @@ -121,8 +120,7 @@ public class AuthenticationPrincipalArgumentResolverTests { CustomUserPrincipal principal = new CustomUserPrincipal(); setAuthenticationPrincipal(principal); this.expectedPrincipal = principal.property; - assertThat(this.resolver.resolveArgument(showUserSpel(), null)) - .isEqualTo(this.expectedPrincipal); + assertThat(this.resolver.resolveArgument(showUserSpel(), null)).isEqualTo(this.expectedPrincipal); } @Test @@ -137,26 +135,25 @@ public class AuthenticationPrincipalArgumentResolverTests { @Test public void resolveArgumentNullOnInvalidType() throws Exception { setAuthenticationPrincipal(new CustomUserPrincipal()); - assertThat(resolver.resolveArgument(showUserAnnotationString(), null)).isNull(); + assertThat(this.resolver.resolveArgument(showUserAnnotationString(), null)).isNull(); } @Test(expected = ClassCastException.class) public void resolveArgumentErrorOnInvalidType() throws Exception { setAuthenticationPrincipal(new CustomUserPrincipal()); - resolver.resolveArgument(showUserAnnotationErrorOnInvalidType(), null); + this.resolver.resolveArgument(showUserAnnotationErrorOnInvalidType(), null); } @Test(expected = ClassCastException.class) public void resolveArgumentCustomserErrorOnInvalidType() throws Exception { setAuthenticationPrincipal(new CustomUserPrincipal()); - resolver.resolveArgument(showUserAnnotationCurrentUserErrorOnInvalidType(), null); + this.resolver.resolveArgument(showUserAnnotationCurrentUserErrorOnInvalidType(), null); } @Test public void resolveArgumentObject() throws Exception { setAuthenticationPrincipal(new Object()); - assertThat(resolver.resolveArgument(showUserAnnotationObject(), null)).isEqualTo( - expectedPrincipal); + assertThat(this.resolver.resolveArgument(showUserAnnotationObject(), null)).isEqualTo(this.expectedPrincipal); } private MethodParameter showUserNoAnnotation() { @@ -172,8 +169,7 @@ public class AuthenticationPrincipalArgumentResolverTests { } private MethodParameter showUserAnnotationCurrentUserErrorOnInvalidType() { - return getMethodParameter("showUserAnnotationCurrentUserErrorOnInvalidType", - String.class); + return getMethodParameter("showUserAnnotationCurrentUserErrorOnInvalidType", String.class); } private MethodParameter showUserAnnotationUserDetails() { @@ -201,24 +197,32 @@ public class AuthenticationPrincipalArgumentResolverTests { } private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { - Method method = ReflectionUtils.findMethod(TestController.class, methodName, - paramTypes); + Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes); return new MethodParameter(method, 0); } + private void setAuthenticationPrincipal(Object principal) { + this.expectedPrincipal = principal; + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken(this.expectedPrincipal, "password", "ROLE_USER")); + } + @Target({ ElementType.PARAMETER }) @Retention(RetentionPolicy.RUNTIME) @AuthenticationPrincipal static @interface CurrentUser { + } @Target({ ElementType.PARAMETER }) @Retention(RetentionPolicy.RUNTIME) @AuthenticationPrincipal(errorOnInvalidType = true) static @interface CurrentUserErrorOnInvalidType { + } public static class TestController { + public void showUserNoAnnotation(String user) { } @@ -229,8 +233,7 @@ public class AuthenticationPrincipalArgumentResolverTests { @AuthenticationPrincipal(errorOnInvalidType = true) String user) { } - public void showUserAnnotationCurrentUserErrorOnInvalidType( - @CurrentUserErrorOnInvalidType String user) { + public void showUserAnnotationCurrentUserErrorOnInvalidType(@CurrentUserErrorOnInvalidType String user) { } public void showUserAnnotation(@AuthenticationPrincipal UserDetails user) { @@ -245,20 +248,23 @@ public class AuthenticationPrincipalArgumentResolverTests { public void showUserAnnotation(@AuthenticationPrincipal Object user) { } - public void showUserSpel( - @AuthenticationPrincipal(expression = "property") String user) { + public void showUserSpel(@AuthenticationPrincipal(expression = "property") String user) { } - public void showUserSpelCopy( - @AuthenticationPrincipal(expression = "new org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolverTests$CopyUserPrincipal(#this)") CopyUserPrincipal user) { + public void showUserSpelCopy(@AuthenticationPrincipal( + expression = "new org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolverTests$CopyUserPrincipal(#this)") CopyUserPrincipal user) { } + } static class CustomUserPrincipal { + public final String property = "property"; + } public static class CopyUserPrincipal { + public final String property; public CopyUserPrincipal(String property) { @@ -269,15 +275,6 @@ public class AuthenticationPrincipalArgumentResolverTests { this.property = toCopy.property; } - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result - + ((this.property == null) ? 0 : this.property.hashCode()); - return result; - } - @Override public boolean equals(Object obj) { if (this == obj) { @@ -300,13 +297,15 @@ public class AuthenticationPrincipalArgumentResolverTests { } return true; } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.property == null) ? 0 : this.property.hashCode()); + return result; + } + } - private void setAuthenticationPrincipal(Object principal) { - this.expectedPrincipal = principal; - SecurityContextHolder.getContext() - .setAuthentication( - new TestingAuthenticationToken(expectedPrincipal, "password", - "ROLE_USER")); - } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java index 56e784bc6a..c11683c321 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.context; +import java.security.Principal; + import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -30,19 +34,18 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; - -import java.security.Principal; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.core.context.SecurityContextHolder.*; @RunWith(MockitoJUnitRunner.class) public class SecurityContextChannelInterceptorTests { + @Mock MessageChannel channel; + @Mock MessageHandler handler; + @Mock Principal principal; @@ -56,17 +59,16 @@ public class SecurityContextChannelInterceptorTests { @Before public void setup() { - authentication = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); - messageBuilder = MessageBuilder.withPayload("payload"); - expectedAnonymous = new AnonymousAuthenticationToken("key", "anonymous", + this.authentication = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); + this.messageBuilder = MessageBuilder.withPayload("payload"); + this.expectedAnonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); - - interceptor = new SecurityContextChannelInterceptor(); + this.interceptor = new SecurityContextChannelInterceptor(); } @After public void cleanup() { - clearContext(); + SecurityContextHolder.clearContext(); } @Test(expected = IllegalArgumentException.class) @@ -77,144 +79,113 @@ public class SecurityContextChannelInterceptorTests { @Test public void preSendCustomHeader() { String headerName = "header"; - interceptor = new SecurityContextChannelInterceptor(headerName); - messageBuilder.setHeader(headerName, authentication); - - interceptor.preSend(messageBuilder.build(), channel); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - authentication); + this.interceptor = new SecurityContextChannelInterceptor(headerName); + this.messageBuilder.setHeader(headerName, this.authentication); + this.interceptor.preSend(this.messageBuilder.build(), this.channel); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); } @Test public void preSendUserSet() { - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication); - - interceptor.preSend(messageBuilder.build(), channel); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - authentication); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.preSend(this.messageBuilder.build(), this.channel); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); } @Test(expected = IllegalArgumentException.class) public void setAnonymousAuthenticationNull() { - interceptor.setAnonymousAuthentication(null); + this.interceptor.setAnonymousAuthentication(null); } @Test public void preSendUsesCustomAnonymous() { - expectedAnonymous = new AnonymousAuthenticationToken("customKey", - "customAnonymous", AuthorityUtils.createAuthorityList("ROLE_CUSTOM")); - interceptor.setAnonymousAuthentication(expectedAnonymous); - - interceptor.preSend(messageBuilder.build(), channel); - + this.expectedAnonymous = new AnonymousAuthenticationToken("customKey", "customAnonymous", + AuthorityUtils.createAuthorityList("ROLE_CUSTOM")); + this.interceptor.setAnonymousAuthentication(this.expectedAnonymous); + this.interceptor.preSend(this.messageBuilder.build(), this.channel); assertAnonymous(); } // SEC-2845 @Test public void preSendUserNotAuthentication() { - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, principal); - - interceptor.preSend(messageBuilder.build(), channel); - + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.principal); + this.interceptor.preSend(this.messageBuilder.build(), this.channel); assertAnonymous(); } // SEC-2845 @Test public void preSendUserNotSet() { - interceptor.preSend(messageBuilder.build(), channel); - + this.interceptor.preSend(this.messageBuilder.build(), this.channel); assertAnonymous(); } // SEC-2845 @Test public void preSendUserNotSetCustomAnonymous() { - interceptor.preSend(messageBuilder.build(), channel); - + this.interceptor.preSend(this.messageBuilder.build(), this.channel); assertAnonymous(); } @Test public void afterSendCompletion() { - SecurityContextHolder.getContext().setAuthentication(authentication); - - interceptor.afterSendCompletion(messageBuilder.build(), channel, true, null); - + SecurityContextHolder.getContext().setAuthentication(this.authentication); + this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } @Test public void afterSendCompletionNullAuthentication() { - interceptor.afterSendCompletion(messageBuilder.build(), channel, true, null); - + this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } @Test public void beforeHandleUserSet() { - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication); - - interceptor.beforeHandle(messageBuilder.build(), channel, handler); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - authentication); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); } // SEC-2845 @Test public void beforeHandleUserNotAuthentication() { - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, principal); - - interceptor.beforeHandle(messageBuilder.build(), channel, handler); - + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.principal); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); assertAnonymous(); } // SEC-2845 @Test public void beforeHandleUserNotSet() { - interceptor.beforeHandle(messageBuilder.build(), channel, handler); - + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); assertAnonymous(); } @Test public void afterMessageHandledUserNotSet() { - interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); - + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } @Test public void afterMessageHandled() { - SecurityContextHolder.getContext().setAuthentication(authentication); - - interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); - + SecurityContextHolder.getContext().setAuthentication(this.authentication); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } // SEC-2829 @Test public void restoresOriginalContext() { - TestingAuthenticationToken original = new TestingAuthenticationToken("original", - "original", "ROLE_USER"); + TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER"); SecurityContextHolder.getContext().setAuthentication(original); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication); - interceptor.beforeHandle(messageBuilder.build(), channel, handler); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - authentication); - - interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - original); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original); } /** @@ -223,48 +194,31 @@ public class SecurityContextChannelInterceptorTests { */ @Test public void restoresOriginalContextNestedThreeDeep() { - AnonymousAuthenticationToken anonymous = new AnonymousAuthenticationToken("key", - "anonymous", AuthorityUtils.createAuthorityList("ROLE_USER")); - - TestingAuthenticationToken origional = new TestingAuthenticationToken("original", - "origional", "ROLE_USER"); + AnonymousAuthenticationToken anonymous = new AnonymousAuthenticationToken("key", "anonymous", + AuthorityUtils.createAuthorityList("ROLE_USER")); + TestingAuthenticationToken origional = new TestingAuthenticationToken("original", "origional", "ROLE_USER"); SecurityContextHolder.getContext().setAuthentication(origional); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication); - interceptor.beforeHandle(messageBuilder.build(), channel, handler); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - authentication); - + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); // start send websocket - messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, null); - interceptor.beforeHandle(messageBuilder.build(), channel, handler); - - assertThat(SecurityContextHolder.getContext().getAuthentication().getName()) - .isEqualTo(anonymous.getName()); - - interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - authentication); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, null); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo(anonymous.getName()); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); // end send websocket - - interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); - - assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs( - origional); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(origional); } private void assertAnonymous() { - Authentication currentAuthentication = SecurityContextHolder.getContext() - .getAuthentication(); - assertThat(currentAuthentication) - .isInstanceOf(AnonymousAuthenticationToken.class); - + Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); + assertThat(currentAuthentication).isInstanceOf(AnonymousAuthenticationToken.class); AnonymousAuthenticationToken anonymous = (AnonymousAuthenticationToken) currentAuthentication; - assertThat(anonymous.getName()).isEqualTo(expectedAnonymous.getName()); - assertThat(anonymous.getAuthorities()).containsOnlyElementsOf( - expectedAnonymous.getAuthorities()); - assertThat(anonymous.getKeyHash()).isEqualTo(expectedAnonymous.getKeyHash()); + assertThat(anonymous.getName()).isEqualTo(this.expectedAnonymous.getName()); + assertThat(anonymous.getAuthorities()).containsOnlyElementsOf(this.expectedAnonymous.getAuthorities()); + assertThat(anonymous.getKeyHash()).isEqualTo(this.expectedAnonymous.getKeyHash()); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java index 7651e2eb2b..f7e259aeed 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java @@ -27,6 +27,7 @@ import java.util.Set; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.aopalliance.intercept.MethodInterceptor; import org.apache.commons.logging.Log; @@ -54,30 +55,30 @@ import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; -import static java.util.stream.Collectors.joining; - /** - * NOTE: This class is a replica of the same class in spring-web so it can - * be used for tests in spring-messaging. + * NOTE: This class is a replica of the same class in spring-web so it can be used for + * tests in spring-messaging. * - *

        Convenience class to resolve method parameters from hints. + *

        + * Convenience class to resolve method parameters from hints. * *

        Background

        * - *

        When testing annotated methods we create test classes such as - * "TestController" with a diverse range of method signatures representing - * supported annotations and argument types. It becomes challenging to use - * naming strategies to keep track of methods and arguments especially in - * combination with variables for reflection metadata. + *

        + * When testing annotated methods we create test classes such as "TestController" with a + * diverse range of method signatures representing supported annotations and argument + * types. It becomes challenging to use naming strategies to keep track of methods and + * arguments especially in combination with variables for reflection metadata. * - *

        The idea with {@link ResolvableMethod} is NOT to rely on naming techniques - * but to use hints to zero in on method parameters. Such hints can be strongly - * typed and explicit about what is being tested. + *

        + * The idea with {@link ResolvableMethod} is NOT to rely on naming techniques but to use + * hints to zero in on method parameters. Such hints can be strongly typed and explicit + * about what is being tested. * *

        1. Declared Return Type

        * - * When testing return types it's likely to have many methods with a unique - * return type, possibly with or without an annotation. + * When testing return types it's likely to have many methods with a unique return type, + * possibly with or without an annotation. * *
          * import static org.springframework.web.method.ResolvableMethod.on;
        @@ -100,8 +101,8 @@ import static java.util.stream.Collectors.joining;
          *
          * 

        2. Method Arguments

        * - * When testing method arguments it's more likely to have one or a small number - * of methods with a wide array of argument types and parameter annotations. + * When testing method arguments it's more likely to have one or a small number of methods + * with a wide array of argument types and parameter annotations. * *
          * import static org.springframework.web.method.MvcAnnotationPredicates.requestParam;
        @@ -119,13 +120,13 @@ import static java.util.stream.Collectors.joining;
          * Locate a method by invoking it through a proxy of the target handler:
          *
          * 
        - * ResolvableMethod.on(TestController.class).mockCall(o -> o.handle(null)).method();
        + * ResolvableMethod.on(TestController.class).mockCall((o) -> o.handle(null)).method();
          * 
        * * @author Rossen Stoyanchev * @since 5.2 */ -public class ResolvableMethod { +public final class ResolvableMethod { private static final Log logger = LogFactory.getLog(ResolvableMethod.class); @@ -136,16 +137,13 @@ public class ResolvableMethod { // Matches ValueConstants.DEFAULT_NONE (spring-web and spring-messaging) private static final String DEFAULT_VALUE_NONE = "\n\t\t\n\t\t\n\uE000\uE001\uE002\n\t\t\t\t\n"; - private final Method method; - private ResolvableMethod(Method method) { Assert.notNull(method, "'method' is required"); this.method = method; } - /** * Return the resolved method. */ @@ -188,8 +186,8 @@ public class ResolvableMethod { } /** - * Filter on method arguments with annotation. - * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + * Filter on method arguments with annotation. See + * {@link org.springframework.web.method.MvcAnnotationPredicates}. */ @SafeVarargs public final ArgResolver annot(Predicate... filter) { @@ -210,25 +208,22 @@ public class ResolvableMethod { return new ArgResolver().annotNotPresent(annotationTypes); } - @Override public String toString() { return "ResolvableMethod=" + formatMethod(); } - private String formatMethod() { - return (method().getName() + - Arrays.stream(this.method.getParameters()) - .map(this::formatParameter) - .collect(joining(",\n\t", "(\n\t", "\n)"))); + return (method().getName() + Arrays.stream(this.method.getParameters()).map(this::formatParameter) + .collect(Collectors.joining(",\n\t", "(\n\t", "\n)"))); } private String formatParameter(Parameter param) { Annotation[] anns = param.getAnnotations(); - return (anns.length > 0 ? - Arrays.stream(anns).map(this::formatAnnotation).collect(joining(",", "[", "]")) + " " + param : - param.toString()); + return (anns.length > 0) + ? Arrays.stream(anns).map(this::formatAnnotation).collect(Collectors.joining(",", "[", "]")) + " " + + param + : param.toString(); } private String formatAnnotation(Annotation annotation) { @@ -242,8 +237,8 @@ public class ResolvableMethod { } private static ResolvableType toResolvableType(Class type, Class... generics) { - return (ObjectUtils.isEmpty(generics) ? ResolvableType.forClass(type) : - ResolvableType.forClassWithGenerics(type, generics)); + return (ObjectUtils.isEmpty(generics) ? ResolvableType.forClass(type) + : ResolvableType.forClassWithGenerics(type, generics)); } private static ResolvableType toResolvableType(Class type, ResolvableType generic, ResolvableType... generics) { @@ -253,7 +248,6 @@ public class ResolvableMethod { return ResolvableType.forClassWithGenerics(type, genericTypes); } - /** * Create a {@code ResolvableMethod} builder for the given handler class. */ @@ -261,23 +255,61 @@ public class ResolvableMethod { return new Builder<>(objectClass); } + @SuppressWarnings("unchecked") + private static T initProxy(Class type, MethodInvocationInterceptor interceptor) { + Assert.notNull(type, "'type' must not be null"); + if (type.isInterface()) { + ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE); + factory.addInterface(type); + factory.addInterface(Supplier.class); + factory.addAdvice(interceptor); + return (T) factory.getProxy(); + } + else { + Enhancer enhancer = new Enhancer(); + enhancer.setSuperclass(type); + enhancer.setInterfaces(new Class[] { Supplier.class }); + enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE); + enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class); + Class proxyClass = enhancer.createClass(); + Object proxy = null; + if (objenesis.isWorthTrying()) { + try { + proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache()); + } + catch (ObjenesisException ex) { + logger.debug("Objenesis failed, falling back to default constructor", ex); + } + } + if (proxy == null) { + try { + proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance(); + } + catch (Throwable ex) { + throw new IllegalStateException( + "Unable to instantiate proxy " + "via both Objenesis and default constructor fails as well", + ex); + } + } + ((Factory) proxy).setCallbacks(new Callback[] { interceptor }); + return (T) proxy; + } + } /** * Builder for {@code ResolvableMethod}. */ - public static class Builder { + public static final class Builder { private final Class objectClass; private final List> filters = new ArrayList<>(4); - private Builder(Class objectClass) { Assert.notNull(objectClass, "Class must not be null"); this.objectClass = objectClass; } - private void addFilter(String message, Predicate filter) { this.filters.add(new LabeledPredicate<>(message, filter)); } @@ -286,7 +318,7 @@ public class ResolvableMethod { * Filter on methods with the given name. */ public Builder named(String methodName) { - addFilter("methodName=" + methodName, method -> method.getName().equals(methodName)); + addFilter("methodName=" + methodName, (method) -> method.getName().equals(methodName)); return this; } @@ -294,15 +326,14 @@ public class ResolvableMethod { * Filter on methods with the given parameter types. */ public Builder argTypes(Class... argTypes) { - addFilter("argTypes=" + Arrays.toString(argTypes), method -> - ObjectUtils.isEmpty(argTypes) ? method.getParameterCount() == 0 : - Arrays.equals(method.getParameterTypes(), argTypes)); + addFilter("argTypes=" + Arrays.toString(argTypes), (method) -> ObjectUtils.isEmpty(argTypes) + ? method.getParameterCount() == 0 : Arrays.equals(method.getParameterTypes(), argTypes)); return this; } /** - * Filter on annotated methods. - * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + * Filter on annotated methods. See + * {@link org.springframework.web.method.MvcAnnotationPredicates}. */ @SafeVarargs public final Builder annot(Predicate... filters) { @@ -312,15 +343,14 @@ public class ResolvableMethod { /** * Filter on methods annotated with the given annotation type. - * @see #annot(Predicate[]) - * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + * @see #annot(Predicate[]) See + * {@link org.springframework.web.method.MvcAnnotationPredicates}. */ @SafeVarargs public final Builder annotPresent(Class... annotationTypes) { String message = "annotationPresent=" + Arrays.toString(annotationTypes); - addFilter(message, method -> - Arrays.stream(annotationTypes).allMatch(annotType -> - AnnotatedElementUtils.findMergedAnnotation(method, annotType) != null)); + addFilter(message, (candidate) -> Arrays.stream(annotationTypes) + .allMatch((annotType) -> AnnotatedElementUtils.findMergedAnnotation(candidate, annotType) != null)); return this; } @@ -330,13 +360,13 @@ public class ResolvableMethod { @SafeVarargs public final Builder annotNotPresent(Class... annotationTypes) { String message = "annotationNotPresent=" + Arrays.toString(annotationTypes); - addFilter(message, method -> { + addFilter(message, (candidate) -> { if (annotationTypes.length != 0) { - return Arrays.stream(annotationTypes).noneMatch(annotType -> - AnnotatedElementUtils.findMergedAnnotation(method, annotType) != null); + return Arrays.stream(annotationTypes).noneMatch( + (annotType) -> AnnotatedElementUtils.findMergedAnnotation(candidate, annotType) != null); } else { - return method.getAnnotations().length == 0; + return candidate.getAnnotations().length == 0; } }); return this; @@ -368,15 +398,16 @@ public class ResolvableMethod { public Builder returning(ResolvableType returnType) { String expected = returnType.toString(); String message = "returnType=" + expected; - addFilter(message, m -> expected.equals(ResolvableType.forMethodReturnType(m).toString())); + addFilter(message, (m) -> expected.equals(ResolvableType.forMethodReturnType(m).toString())); return this; } /** - * Build a {@code ResolvableMethod} from the provided filters which must - * resolve to a unique, single method. - *

        See additional resolveXxx shortcut methods going directly to - * {@link Method} or return type parameter. + * Build a {@code ResolvableMethod} from the provided filters which must resolve + * to a unique, single method. + *

        + * See additional resolveXxx shortcut methods going directly to {@link Method} or + * return type parameter. * @throws IllegalStateException for no match or multiple matches */ public ResolvableMethod method() { @@ -387,12 +418,12 @@ public class ResolvableMethod { } private boolean isMatch(Method method) { - return this.filters.stream().allMatch(p -> p.test(method)); + return this.filters.stream().allMatch((p) -> p.test(method)); } private String formatMethods(Set methods) { - return "\nMatched:\n" + methods.stream() - .map(Method::toGenericString).collect(joining(",\n\t", "[\n\t", "\n]")); + return "\nMatched:\n" + methods.stream().map(Method::toGenericString) + .collect(Collectors.joining(",\n\t", "[\n\t", "\n]")); } public ResolvableMethod mockCall(Consumer invoker) { @@ -403,20 +434,20 @@ public class ResolvableMethod { return new ResolvableMethod(method); } - // Build & resolve shortcuts... - /** * Resolve and return the {@code Method} equivalent to: - *

        {@code build().method()} + *

        + * {@code build().method()} */ - public final Method resolveMethod() { + public Method resolveMethod() { return method().method(); } /** * Resolve and return the {@code Method} equivalent to: - *

        {@code named(methodName).build().method()} + *

        + * {@code named(methodName).build().method()} */ public Method resolveMethod(String methodName) { return named(methodName).method().method(); @@ -424,15 +455,17 @@ public class ResolvableMethod { /** * Resolve and return the declared return type equivalent to: - *

        {@code build().returnType()} + *

        + * {@code build().returnType()} */ - public final MethodParameter resolveReturnType() { + public MethodParameter resolveReturnType() { return method().returnType(); } /** * Shortcut to the unique return type equivalent to: - *

        {@code returning(returnType).build().returnType()} + *

        + * {@code returning(returnType).build().returnType()} * @param returnType the return type * @param generics optional array of generic types */ @@ -442,14 +475,14 @@ public class ResolvableMethod { /** * Shortcut to the unique return type equivalent to: - *

        {@code returning(returnType).build().returnType()} + *

        + * {@code returning(returnType).build().returnType()} * @param returnType the return type * @param generic at least one generic type * @param generics optional extra generic types */ public MethodParameter resolveReturnType(Class returnType, ResolvableType generic, ResolvableType... generics) { - return returning(returnType, generic, generics).method().returnType(); } @@ -457,37 +490,33 @@ public class ResolvableMethod { return returning(returnType).method().returnType(); } - @Override public String toString() { - return "ResolvableMethod.Builder[\n" + - "\tobjectClass = " + this.objectClass.getName() + ",\n" + - "\tfilters = " + formatFilters() + "\n]"; + return "ResolvableMethod.Builder[\n" + "\tobjectClass = " + this.objectClass.getName() + ",\n" + + "\tfilters = " + formatFilters() + "\n]"; } private String formatFilters() { return this.filters.stream().map(Object::toString) - .collect(joining(",\n\t\t", "[\n\t\t", "\n\t]")); + .collect(Collectors.joining(",\n\t\t", "[\n\t\t", "\n\t]")); } - } + } /** * Predicate with a descriptive label. */ - private static class LabeledPredicate implements Predicate { + private static final class LabeledPredicate implements Predicate { private final String label; private final Predicate delegate; - private LabeledPredicate(String label, Predicate delegate) { this.label = label; this.delegate = delegate; } - @Override public boolean test(T method) { return this.delegate.test(method); @@ -512,25 +541,24 @@ public class ResolvableMethod { public String toString() { return this.label; } - } + } /** * Resolver for method arguments. */ - public class ArgResolver { + public final class ArgResolver { private final List> filters = new ArrayList<>(4); - @SafeVarargs private ArgResolver(Predicate... filter) { this.filters.addAll(Arrays.asList(filter)); } /** - * Filter on method arguments with annotations. - * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + * Filter on method arguments with annotations. See + * {@link org.springframework.web.method.MvcAnnotationPredicates}. */ @SafeVarargs public final ArgResolver annot(Predicate... filters) { @@ -541,12 +569,12 @@ public class ResolvableMethod { /** * Filter on method arguments that have the given annotations. * @param annotationTypes the annotation types - * @see #annot(Predicate[]) - * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + * @see #annot(Predicate[]) See + * {@link org.springframework.web.method.MvcAnnotationPredicates}. */ @SafeVarargs public final ArgResolver annotPresent(Class... annotationTypes) { - this.filters.add(param -> Arrays.stream(annotationTypes).allMatch(param::hasParameterAnnotation)); + this.filters.add((param) -> Arrays.stream(annotationTypes).allMatch(param::hasParameterAnnotation)); return this; } @@ -556,10 +584,9 @@ public class ResolvableMethod { */ @SafeVarargs public final ArgResolver annotNotPresent(Class... annotationTypes) { - this.filters.add(param -> - (annotationTypes.length > 0 ? - Arrays.stream(annotationTypes).noneMatch(param::hasParameterAnnotation) : - param.getParameterAnnotations().length == 0)); + this.filters.add((param) -> (annotationTypes.length > 0) + ? Arrays.stream(annotationTypes).noneMatch(param::hasParameterAnnotation) + : param.getParameterAnnotations().length == 0); return this; } @@ -584,43 +611,40 @@ public class ResolvableMethod { * @param type the expected type */ public MethodParameter arg(ResolvableType type) { - this.filters.add(p -> type.toString().equals(ResolvableType.forMethodParameter(p).toString())); + this.filters.add((p) -> type.toString().equals(ResolvableType.forMethodParameter(p).toString())); return arg(); } /** * Resolve the argument. */ - public final MethodParameter arg() { + public MethodParameter arg() { List matches = applyFilters(); - Assert.state(!matches.isEmpty(), () -> - "No matching arg in method\n" + formatMethod()); - Assert.state(matches.size() == 1, () -> - "Multiple matching args in method\n" + formatMethod() + "\nMatches:\n\t" + matches); + Assert.state(!matches.isEmpty(), () -> "No matching arg in method\n" + formatMethod()); + Assert.state(matches.size() == 1, + () -> "Multiple matching args in method\n" + formatMethod() + "\nMatches:\n\t" + matches); return matches.get(0); } - private List applyFilters() { List matches = new ArrayList<>(); - for (int i = 0; i < method.getParameterCount(); i++) { - MethodParameter param = new SynthesizingMethodParameter(method, i); + for (int i = 0; i < ResolvableMethod.this.method.getParameterCount(); i++) { + MethodParameter param = new SynthesizingMethodParameter(ResolvableMethod.this.method, i); param.initParameterNameDiscovery(nameDiscoverer); - if (this.filters.stream().allMatch(p -> p.test(param))) { + if (this.filters.stream().allMatch((p) -> p.test(param))) { matches.add(param); } } return matches; } - } + } private static class MethodInvocationInterceptor implements org.springframework.cglib.proxy.MethodInterceptor, MethodInterceptor { private Method invokedMethod; - Method getInvokedMethod() { return this.invokedMethod; } @@ -642,51 +666,7 @@ public class ResolvableMethod { public Object invoke(org.aopalliance.intercept.MethodInvocation inv) throws Throwable { return intercept(inv.getThis(), inv.getMethod(), inv.getArguments(), null); } - } - @SuppressWarnings("unchecked") - private static T initProxy(Class type, MethodInvocationInterceptor interceptor) { - Assert.notNull(type, "'type' must not be null"); - if (type.isInterface()) { - ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE); - factory.addInterface(type); - factory.addInterface(Supplier.class); - factory.addAdvice(interceptor); - return (T) factory.getProxy(); - } - - else { - Enhancer enhancer = new Enhancer(); - enhancer.setSuperclass(type); - enhancer.setInterfaces(new Class[] {Supplier.class}); - enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE); - enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class); - - Class proxyClass = enhancer.createClass(); - Object proxy = null; - - if (objenesis.isWorthTrying()) { - try { - proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache()); - } - catch (ObjenesisException ex) { - logger.debug("Objenesis failed, falling back to default constructor", ex); - } - } - - if (proxy == null) { - try { - proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance(); - } - catch (Throwable ex) { - throw new IllegalStateException("Unable to instantiate proxy " + - "via both Objenesis and default constructor fails as well", ex); - } - } - - ((Factory) proxy).setCallbacks(new Callback[] {interceptor}); - return (T) proxy; - } } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java index c2748493cf..ebc7fe9514 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java @@ -16,7 +16,12 @@ package org.springframework.security.messaging.handler.invocation.reactive; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.core.MethodParameter; import org.springframework.core.annotation.SynthesizingMethodParameter; import org.springframework.security.authentication.TestAuthentication; @@ -25,17 +30,14 @@ import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.messaging.handler.invocation.ResolvableMethod; -import reactor.core.publisher.Mono; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch */ public class AuthenticationPrincipalArgumentResolverTests { + private AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver(); @Test @@ -52,9 +54,12 @@ public class AuthenticationPrincipalArgumentResolverTests { @Test public void resolveArgumentWhenAuthenticationPrincipalThenFound() { Authentication authentication = TestAuthentication.authenticatedUser(); - Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalOnMonoUserDetails"), null) + // @formatter:off + Mono result = (Mono) this.resolver + .resolveArgument(arg0("authenticationPrincipalOnMonoUserDetails"), null) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .block(); + // @formatter:on assertThat(result.block()).isEqualTo(authentication.getPrincipal()); } @@ -70,9 +75,12 @@ public class AuthenticationPrincipalArgumentResolverTests { @Test public void resolveArgumentWhenMonoAndAuthenticationPrincipalThenFound() { Authentication authentication = TestAuthentication.authenticatedUser(); - Mono result = (Mono) this.resolver.resolveArgument(arg0("currentUserOnMonoUserDetails"), null) + // @formatter:off + Mono result = (Mono) this.resolver + .resolveArgument(arg0("currentUserOnMonoUserDetails"), null) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .block(); + // @formatter:on assertThat(result.block()).isEqualTo(authentication.getPrincipal()); } @@ -83,14 +91,18 @@ public class AuthenticationPrincipalArgumentResolverTests { @Test public void resolveArgumentWhenExpressionThenFound() { Authentication authentication = TestAuthentication.authenticatedUser(); - Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalExpression"), null) + // @formatter:off + Mono result = (Mono) this.resolver + .resolveArgument(arg0("authenticationPrincipalExpression"), null) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .block(); + // @formatter:on assertThat(result.block()).isEqualTo(authentication.getName()); } @SuppressWarnings("unused") - private void authenticationPrincipalExpression(@AuthenticationPrincipal(expression = "username") Mono username) { + private void authenticationPrincipalExpression( + @AuthenticationPrincipal(expression = "username") Mono username) { } @Test @@ -109,5 +121,8 @@ public class AuthenticationPrincipalArgumentResolverTests { @AuthenticationPrincipal @Retention(RetentionPolicy.RUNTIME) - @interface CurrentUser {} + @interface CurrentUser { + + } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java index dba93aa552..79f86ce422 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java @@ -16,7 +16,12 @@ package org.springframework.security.messaging.handler.invocation.reactive; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.core.MethodParameter; import org.springframework.core.annotation.SynthesizingMethodParameter; import org.springframework.security.authentication.TestAuthentication; @@ -26,10 +31,6 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.messaging.handler.invocation.ResolvableMethod; -import reactor.core.publisher.Mono; - -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; import static org.assertj.core.api.Assertions.assertThat; @@ -37,6 +38,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Rob Winch */ public class CurrentSecurityContextArgumentResolverTests { + private CurrentSecurityContextArgumentResolver resolver = new CurrentSecurityContextArgumentResolver(); @Test @@ -46,16 +48,17 @@ public class CurrentSecurityContextArgumentResolverTests { @Test public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() { - Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null).block(); + Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null) + .block(); assertThat(result).isNull(); } @Test public void resolveArgumentWhenAuthenticationPrincipalThenFound() { Authentication authentication = TestAuthentication.authenticatedUser(); - Mono result = (Mono) this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .block(); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)).block(); assertThat(result.block().getAuthentication()).isEqualTo(authentication); } @@ -71,9 +74,9 @@ public class CurrentSecurityContextArgumentResolverTests { @Test public void resolveArgumentWhenMonoAndAuthenticationPrincipalThenFound() { Authentication authentication = TestAuthentication.authenticatedUser(); - Mono result = (Mono) this.resolver.resolveArgument(arg0("currentUserOnMonoUserDetails"), null) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .block(); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("currentUserOnMonoUserDetails"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)).block(); assertThat(result.block()).isEqualTo(authentication.getPrincipal()); } @@ -84,14 +87,15 @@ public class CurrentSecurityContextArgumentResolverTests { @Test public void resolveArgumentWhenExpressionThenFound() { Authentication authentication = TestAuthentication.authenticatedUser(); - Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalExpression"), null) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .block(); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("authenticationPrincipalExpression"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)).block(); assertThat(result.block()).isEqualTo(authentication.getName()); } @SuppressWarnings("unused") - private void authenticationPrincipalExpression(@CurrentSecurityContext(expression = "authentication?.principal?.username") Mono username) { + private void authenticationPrincipalExpression( + @CurrentSecurityContext(expression = "authentication?.principal?.username") Mono username) { } @Test @@ -110,5 +114,8 @@ public class CurrentSecurityContextArgumentResolverTests { @CurrentSecurityContext(expression = "authentication?.principal") @Retention(RetentionPolicy.RUNTIME) - @interface CurrentUser {} + @interface CurrentUser { + + } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/AndMessageMatcherTest.java b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/AndMessageMatcherTests.java similarity index 66% rename from messaging/src/test/java/org/springframework/security/messaging/util/matcher/AndMessageMatcherTest.java rename to messaging/src/test/java/org/springframework/security/messaging/util/matcher/AndMessageMatcherTests.java index a765f18e77..368b50d839 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/AndMessageMatcherTest.java +++ b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/AndMessageMatcherTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.messaging.util.matcher; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +package org.springframework.security.messaging.util.matcher; import java.util.Arrays; import java.util.Collections; @@ -26,10 +24,15 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + @RunWith(MockitoJUnitRunner.class) -public class AndMessageMatcherTest { +public class AndMessageMatcherTests { + @Mock private MessageMatcher delegate; @@ -54,7 +57,7 @@ public class AndMessageMatcherTest { @SuppressWarnings("unchecked") @Test(expected = IllegalArgumentException.class) public void constructorEmptyArray() { - new AndMessageMatcher<>((MessageMatcher[]) new MessageMatcher[0]); + new AndMessageMatcher<>(new MessageMatcher[0]); } @Test(expected = IllegalArgumentException.class) @@ -74,43 +77,39 @@ public class AndMessageMatcherTest { @Test public void matchesSingleTrue() { - when(delegate.matches(message)).thenReturn(true); - matcher = new AndMessageMatcher<>(delegate); - - assertThat(matcher.matches(message)).isTrue(); + given(this.delegate.matches(this.message)).willReturn(true); + this.matcher = new AndMessageMatcher<>(this.delegate); + assertThat(this.matcher.matches(this.message)).isTrue(); } @Test public void matchesMultiTrue() { - when(delegate.matches(message)).thenReturn(true); - when(delegate2.matches(message)).thenReturn(true); - matcher = new AndMessageMatcher<>(delegate, delegate2); - - assertThat(matcher.matches(message)).isTrue(); + given(this.delegate.matches(this.message)).willReturn(true); + given(this.delegate2.matches(this.message)).willReturn(true); + this.matcher = new AndMessageMatcher<>(this.delegate, this.delegate2); + assertThat(this.matcher.matches(this.message)).isTrue(); } @Test public void matchesSingleFalse() { - when(delegate.matches(message)).thenReturn(false); - matcher = new AndMessageMatcher<>(delegate); - - assertThat(matcher.matches(message)).isFalse(); + given(this.delegate.matches(this.message)).willReturn(false); + this.matcher = new AndMessageMatcher<>(this.delegate); + assertThat(this.matcher.matches(this.message)).isFalse(); } @Test public void matchesMultiBothFalse() { - when(delegate.matches(message)).thenReturn(false); - matcher = new AndMessageMatcher<>(delegate, delegate2); - - assertThat(matcher.matches(message)).isFalse(); + given(this.delegate.matches(this.message)).willReturn(false); + this.matcher = new AndMessageMatcher<>(this.delegate, this.delegate2); + assertThat(this.matcher.matches(this.message)).isFalse(); } @Test public void matchesMultiSingleFalse() { - when(delegate.matches(message)).thenReturn(true); - when(delegate2.matches(message)).thenReturn(false); - matcher = new AndMessageMatcher<>(delegate, delegate2); - - assertThat(matcher.matches(message)).isFalse(); + given(this.delegate.matches(this.message)).willReturn(true); + given(this.delegate2.matches(this.message)).willReturn(false); + this.matcher = new AndMessageMatcher<>(this.delegate, this.delegate2); + assertThat(this.matcher.matches(this.message)).isFalse(); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/OrMessageMatcherTest.java b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/OrMessageMatcherTests.java similarity index 67% rename from messaging/src/test/java/org/springframework/security/messaging/util/matcher/OrMessageMatcherTest.java rename to messaging/src/test/java/org/springframework/security/messaging/util/matcher/OrMessageMatcherTests.java index 0e132cffe5..51aa9f3040 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/OrMessageMatcherTest.java +++ b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/OrMessageMatcherTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.messaging.util.matcher; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +package org.springframework.security.messaging.util.matcher; import java.util.Arrays; import java.util.Collections; @@ -26,10 +24,15 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + @RunWith(MockitoJUnitRunner.class) -public class OrMessageMatcherTest { +public class OrMessageMatcherTests { + @Mock private MessageMatcher delegate; @@ -54,7 +57,7 @@ public class OrMessageMatcherTest { @SuppressWarnings("unchecked") @Test(expected = IllegalArgumentException.class) public void constructorEmptyArray() { - new OrMessageMatcher<>((MessageMatcher[]) new MessageMatcher[0]); + new OrMessageMatcher<>(new MessageMatcher[0]); } @Test(expected = IllegalArgumentException.class) @@ -74,42 +77,38 @@ public class OrMessageMatcherTest { @Test public void matchesSingleTrue() { - when(delegate.matches(message)).thenReturn(true); - matcher = new OrMessageMatcher<>(delegate); - - assertThat(matcher.matches(message)).isTrue(); + given(this.delegate.matches(this.message)).willReturn(true); + this.matcher = new OrMessageMatcher<>(this.delegate); + assertThat(this.matcher.matches(this.message)).isTrue(); } @Test public void matchesMultiTrue() { - when(delegate.matches(message)).thenReturn(true); - matcher = new OrMessageMatcher<>(delegate, delegate2); - - assertThat(matcher.matches(message)).isTrue(); + given(this.delegate.matches(this.message)).willReturn(true); + this.matcher = new OrMessageMatcher<>(this.delegate, this.delegate2); + assertThat(this.matcher.matches(this.message)).isTrue(); } @Test public void matchesSingleFalse() { - when(delegate.matches(message)).thenReturn(false); - matcher = new OrMessageMatcher<>(delegate); - - assertThat(matcher.matches(message)).isFalse(); + given(this.delegate.matches(this.message)).willReturn(false); + this.matcher = new OrMessageMatcher<>(this.delegate); + assertThat(this.matcher.matches(this.message)).isFalse(); } @Test public void matchesMultiBothFalse() { - when(delegate.matches(message)).thenReturn(false); - when(delegate2.matches(message)).thenReturn(false); - matcher = new OrMessageMatcher<>(delegate, delegate2); - - assertThat(matcher.matches(message)).isFalse(); + given(this.delegate.matches(this.message)).willReturn(false); + given(this.delegate2.matches(this.message)).willReturn(false); + this.matcher = new OrMessageMatcher<>(this.delegate, this.delegate2); + assertThat(this.matcher.matches(this.message)).isFalse(); } @Test public void matchesMultiSingleFalse() { - when(delegate.matches(message)).thenReturn(true); - matcher = new OrMessageMatcher<>(delegate, delegate2); - - assertThat(matcher.matches(message)).isTrue(); + given(this.delegate.matches(this.message)).willReturn(true); + this.matcher = new OrMessageMatcher<>(this.delegate, this.delegate2); + assertThat(this.matcher.matches(this.message)).isTrue(); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcherTests.java b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcherTests.java index 4ebe6cf797..9161a95ff8 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcherTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcherTests.java @@ -13,19 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.messaging.util.matcher; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.messaging.util.matcher; import org.junit.Before; import org.junit.Test; + import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.AntPathMatcher; import org.springframework.util.PathMatcher; +import static org.assertj.core.api.Assertions.assertThat; + public class SimpDestinationMessageMatcherTests { + MessageBuilder messageBuilder; SimpDestinationMessageMatcher matcher; @@ -34,9 +37,9 @@ public class SimpDestinationMessageMatcherTests { @Before public void setup() { - messageBuilder = MessageBuilder.withPayload("M"); - matcher = new SimpDestinationMessageMatcher("/**"); - pathMatcher = new AntPathMatcher(); + this.messageBuilder = MessageBuilder.withPayload("M"); + this.matcher = new SimpDestinationMessageMatcher("/**"); + this.pathMatcher = new AntPathMatcher(); } @Test(expected = IllegalArgumentException.class) @@ -50,110 +53,79 @@ public class SimpDestinationMessageMatcherTests { @Test public void matchesDoesNotMatchNullDestination() { - assertThat(matcher.matches(messageBuilder.build())).isFalse(); + assertThat(this.matcher.matches(this.messageBuilder.build())).isFalse(); } @Test public void matchesAllWithDestination() { - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "/destination/1"); - - assertThat(matcher.matches(messageBuilder.build())).isTrue(); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/destination/1"); + assertThat(this.matcher.matches(this.messageBuilder.build())).isTrue(); } @Test public void matchesSpecificWithDestination() { - matcher = new SimpDestinationMessageMatcher("/destination/1"); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "/destination/1"); - - assertThat(matcher.matches(messageBuilder.build())).isTrue(); + this.matcher = new SimpDestinationMessageMatcher("/destination/1"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/destination/1"); + assertThat(this.matcher.matches(this.messageBuilder.build())).isTrue(); } @Test public void matchesFalseWithDestination() { - matcher = new SimpDestinationMessageMatcher("/nomatch"); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, - "/destination/1"); - - assertThat(matcher.matches(messageBuilder.build())).isFalse(); + this.matcher = new SimpDestinationMessageMatcher("/nomatch"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/destination/1"); + assertThat(this.matcher.matches(this.messageBuilder.build())).isFalse(); } @Test public void matchesFalseMessageTypeNotDisconnectType() { - matcher = SimpDestinationMessageMatcher.createMessageMatcher("/match", - pathMatcher); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.DISCONNECT); - - assertThat(matcher.matches(messageBuilder.build())).isFalse(); + this.matcher = SimpDestinationMessageMatcher.createMessageMatcher("/match", this.pathMatcher); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.DISCONNECT); + assertThat(this.matcher.matches(this.messageBuilder.build())).isFalse(); } @Test public void matchesTrueMessageType() { - matcher = SimpDestinationMessageMatcher.createMessageMatcher("/match", - pathMatcher); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/match"); - messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.MESSAGE); - - assertThat(matcher.matches(messageBuilder.build())).isTrue(); + this.matcher = SimpDestinationMessageMatcher.createMessageMatcher("/match", this.pathMatcher); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/match"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.MESSAGE); + assertThat(this.matcher.matches(this.messageBuilder.build())).isTrue(); } @Test public void matchesTrueSubscribeType() { - matcher = SimpDestinationMessageMatcher.createSubscribeMatcher("/match", - pathMatcher); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/match"); - messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.SUBSCRIBE); - - assertThat(matcher.matches(messageBuilder.build())).isTrue(); + this.matcher = SimpDestinationMessageMatcher.createSubscribeMatcher("/match", this.pathMatcher); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/match"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.SUBSCRIBE); + assertThat(this.matcher.matches(this.messageBuilder.build())).isTrue(); } @Test public void matchesNullMessageType() { - matcher = new SimpDestinationMessageMatcher("/match"); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/match"); - messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.MESSAGE); - - assertThat(matcher.matches(messageBuilder.build())).isTrue(); + this.matcher = new SimpDestinationMessageMatcher("/match"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/match"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.MESSAGE); + assertThat(this.matcher.matches(this.messageBuilder.build())).isTrue(); } @Test public void extractPathVariablesFromDestination() { - matcher = new SimpDestinationMessageMatcher("/topics/{topic}/**"); - - messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/topics/someTopic/sub1"); - messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.MESSAGE); - - assertThat(matcher.extractPathVariables(messageBuilder.build()).get("topic")).isEqualTo("someTopic"); + this.matcher = new SimpDestinationMessageMatcher("/topics/{topic}/**"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, "/topics/someTopic/sub1"); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.MESSAGE); + assertThat(this.matcher.extractPathVariables(this.messageBuilder.build()).get("topic")).isEqualTo("someTopic"); } @Test public void extractedVariablesAreEmptyInNullDestination() { - matcher = new SimpDestinationMessageMatcher("/topics/{topic}/**"); - assertThat(matcher.extractPathVariables(messageBuilder.build())).isEmpty(); + this.matcher = new SimpDestinationMessageMatcher("/topics/{topic}/**"); + assertThat(this.matcher.extractPathVariables(this.messageBuilder.build())).isEmpty(); } @Test public void typeConstructorParameterIsTransmitted() { - matcher = SimpDestinationMessageMatcher.createMessageMatcher("/match", - pathMatcher); - - MessageMatcher expectedTypeMatcher = new SimpMessageTypeMatcher( - SimpMessageType.MESSAGE); - - assertThat(matcher.getMessageTypeMatcher()).isEqualTo(expectedTypeMatcher); - + this.matcher = SimpDestinationMessageMatcher.createMessageMatcher("/match", this.pathMatcher); + MessageMatcher expectedTypeMatcher = new SimpMessageTypeMatcher(SimpMessageType.MESSAGE); + assertThat(this.matcher.getMessageTypeMatcher()).isEqualTo(expectedTypeMatcher); } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcherTests.java b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcherTests.java index 0ef2cf05a4..1f7f9c562c 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcherTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcherTests.java @@ -13,23 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.messaging.util.matcher; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.messaging.util.matcher; import org.junit.Before; import org.junit.Test; + import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; +import static org.assertj.core.api.Assertions.assertThat; + public class SimpMessageTypeMatcherTests { + private SimpMessageTypeMatcher matcher; @Before public void setup() { - matcher = new SimpMessageTypeMatcher(SimpMessageType.MESSAGE); + this.matcher = new SimpMessageTypeMatcher(SimpMessageType.MESSAGE); } @Test(expected = IllegalArgumentException.class) @@ -39,28 +42,28 @@ public class SimpMessageTypeMatcherTests { @Test public void matchesMessageMessageTrue() { - Message message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.MESSAGE).build(); - - assertThat(matcher.matches(message)).isTrue(); + // @formatter:off + Message message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.MESSAGE) + .build(); + // @formatter:on + assertThat(this.matcher.matches(message)).isTrue(); } @Test public void matchesMessageConnectFalse() { - Message message = MessageBuilder - .withPayload("Hi") - .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, - SimpMessageType.CONNECT).build(); - - assertThat(matcher.matches(message)).isFalse(); + // @formatter:off + Message message = MessageBuilder.withPayload("Hi") + .setHeader(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.CONNECT) + .build(); + // @formatter:on + assertThat(this.matcher.matches(message)).isFalse(); } @Test public void matchesMessageNullFalse() { Message message = MessageBuilder.withPayload("Hi").build(); - - assertThat(matcher.matches(message)).isFalse(); + assertThat(this.matcher.matches(message)).isFalse(); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java index 7c37cd4c60..f7d1e6d76c 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.web.csrf; import java.util.HashMap; @@ -23,6 +24,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -35,6 +37,7 @@ import org.springframework.security.web.csrf.MissingCsrfTokenException; @RunWith(MockitoJUnitRunner.class) public class CsrfChannelInterceptorTests { + @Mock MessageChannel channel; @@ -46,107 +49,94 @@ public class CsrfChannelInterceptorTests { @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); - interceptor = new CsrfChannelInterceptor(); - - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); - messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken()); - messageHeaders.setSessionAttributes(new HashMap<>()); - messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), token); + this.token = new DefaultCsrfToken("header", "param", "token"); + this.interceptor = new CsrfChannelInterceptor(); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), this.token.getToken()); + this.messageHeaders.setSessionAttributes(new HashMap<>()); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); } @Test public void preSendValidToken() { - interceptor.preSend(message(), channel); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresConnectAck() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresDisconnect() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresDisconnectAck() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresHeartbeat() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresMessage() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresOther() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresSubscribe() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); + this.interceptor.preSend(message(), this.channel); } @Test public void preSendIgnoresUnsubscribe() { - messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); - - interceptor.preSend(message(), channel); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + this.interceptor.preSend(message(), this.channel); } @Test(expected = InvalidCsrfTokenException.class) public void preSendNoToken() { - messageHeaders.removeNativeHeader(token.getHeaderName()); - - interceptor.preSend(message(), channel); + this.messageHeaders.removeNativeHeader(this.token.getHeaderName()); + this.interceptor.preSend(message(), this.channel); } @Test(expected = InvalidCsrfTokenException.class) public void preSendInvalidToken() { - messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken() - + "invalid"); - - interceptor.preSend(message(), channel); + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), this.token.getToken() + "invalid"); + this.interceptor.preSend(message(), this.channel); } @Test(expected = MissingCsrfTokenException.class) public void preSendMissingToken() { - messageHeaders.getSessionAttributes().clear(); - - interceptor.preSend(message(), channel); + this.messageHeaders.getSessionAttributes().clear(); + this.interceptor.preSend(message(), this.channel); } @Test(expected = MissingCsrfTokenException.class) public void preSendMissingTokenNullSessionAttributes() { - messageHeaders.setSessionAttributes(null); - - interceptor.preSend(message(), channel); + this.messageHeaders.setSessionAttributes(null); + this.interceptor.preSend(message(), this.channel); } private Message message() { - Map headersToCopy = messageHeaders.toMap(); + Map headersToCopy = this.messageHeaders.toMap(); return MessageBuilder.withPayload("hi").copyHeaders(headersToCopy).build(); } + } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java index 6a36e5de8c..b92390fecd 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java @@ -13,13 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.messaging.web.socket.server; -import org.junit.Test; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; @@ -28,19 +33,17 @@ import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.web.socket.WebSocketHandler; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; /** - * * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class CsrfTokenHandshakeInterceptorTests { + @Mock WebSocketHandler wsHandler; + @Mock ServerHttpResponse response; @@ -54,29 +57,25 @@ public class CsrfTokenHandshakeInterceptorTests { @Before public void setup() { - httpRequest = new MockHttpServletRequest(); - attributes = new HashMap<>(); - request = new ServletServerHttpRequest(httpRequest); - - interceptor = new CsrfTokenHandshakeInterceptor(); + this.httpRequest = new MockHttpServletRequest(); + this.attributes = new HashMap<>(); + this.request = new ServletServerHttpRequest(this.httpRequest); + this.interceptor = new CsrfTokenHandshakeInterceptor(); } @Test public void beforeHandshakeNoAttribute() throws Exception { - interceptor.beforeHandshake(request, response, wsHandler, attributes); - - assertThat(attributes).isEmpty(); + this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes); + assertThat(this.attributes).isEmpty(); } @Test public void beforeHandshake() throws Exception { CsrfToken token = new DefaultCsrfToken("header", "param", "token"); - httpRequest.setAttribute(CsrfToken.class.getName(), token); - - interceptor.beforeHandshake(request, response, wsHandler, attributes); - - assertThat(attributes.keySet()).containsOnly(CsrfToken.class.getName()); - assertThat(attributes.values()).containsOnly(token); + this.httpRequest.setAttribute(CsrfToken.class.getName(), token); + this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes); + assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName()); + assertThat(this.attributes.values()).containsOnly(token); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java index 7ff23c3ceb..2911fd90de 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.springframework.lang.Nullable; @@ -21,8 +22,8 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; /** - * An implementation of an {@link OAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. + * An implementation of an {@link OAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. * * @author Joe Grandja * @since 5.2 @@ -31,24 +32,27 @@ import org.springframework.util.Assert; public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { /** - * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns {@code null} if authorization is not supported, - * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized. - * + * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() + * client} in the provided {@code context}. Returns {@code null} if authorization is + * not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the + * client is already authorized. * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not + * supported */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && - context.getAuthorizedClient() == null) { - // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectFilter which initiates authorization + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals( + context.getClientRegistration().getAuthorizationGrantType()) && context.getAuthorizedClient() == null) { + // ClientAuthorizationRequiredException is caught by + // OAuth2AuthorizationRequestRedirectFilter which initiates authorization throw new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()); } return null; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java index 002432bd37..ab15fe304a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java @@ -13,41 +13,48 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; /** - * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. * * @author Joe Grandja * @since 5.2 * @see ReactiveOAuth2AuthorizedClientProvider */ -public final class AuthorizationCodeReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { +public final class AuthorizationCodeReactiveOAuth2AuthorizedClientProvider + implements ReactiveOAuth2AuthorizedClientProvider { /** - * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns an empty {@code Mono} if authorization is not supported, - * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized. - * + * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() + * client} in the provided {@code context}. Returns an empty {@code Mono} if + * authorization is not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the + * client is already authorized. * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization is not supported */ @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && - context.getAuthorizedClient() == null) { - // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectWebFilter which initiates authorization - return Mono.error(() -> new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId())); + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals( + context.getClientRegistration().getAuthorizationGrantType()) && context.getAuthorizedClient() == null) { + // ClientAuthorizationRequiredException is caught by + // OAuth2AuthorizationRequestRedirectWebFilter which initiates authorization + return Mono.error(() -> new ClientAuthorizationRequiredException( + context.getClientRegistration().getRegistrationId())); } return Mono.empty(); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java index 179e4f40a8..4c55ac444d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -27,41 +33,36 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - /** - * An implementation of an {@link OAuth2AuthorizedClientManager} - * that is capable of operating outside of the context of a {@code HttpServletRequest}, - * e.g. in a scheduled/background thread and/or in the service-tier. + * An implementation of an {@link OAuth2AuthorizedClientManager} that is capable of + * operating outside of the context of a {@code HttpServletRequest}, e.g. in a + * scheduled/background thread and/or in the service-tier. * *

        - * (When operating within the context of a {@code HttpServletRequest}, - * use {@link DefaultOAuth2AuthorizedClientManager} instead.) + * (When operating within the context of a {@code HttpServletRequest}, use + * {@link DefaultOAuth2AuthorizedClientManager} instead.) * *

        Authorized Client Persistence

        * *

        - * This manager utilizes an {@link OAuth2AuthorizedClientService} - * to persist {@link OAuth2AuthorizedClient}s. + * This manager utilizes an {@link OAuth2AuthorizedClientService} to persist + * {@link OAuth2AuthorizedClient}s. * *

        * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} - * will be saved in the {@link OAuth2AuthorizedClientService}. - * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler} - * via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}. + * will be saved in the {@link OAuth2AuthorizedClientService}. This functionality can be + * changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler} via + * {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}. * *

        * By default, when an authorization attempt fails due to an - * {@value OAuth2ErrorCodes#INVALID_GRANT} error, - * the previously saved {@link OAuth2AuthorizedClient} - * will be removed from the {@link OAuth2AuthorizedClientService}. - * (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur - * when a refresh token that is no longer valid is used to retrieve a new access token.) - * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler} - * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. + * {@value OAuth2ErrorCodes#INVALID_GRANT} error, the previously saved + * {@link OAuth2AuthorizedClient} will be removed from the + * {@link OAuth2AuthorizedClientService}. (The {@value OAuth2ErrorCodes#INVALID_GRANT} + * error can occur when a refresh token that is no longer valid is used to retrieve a new + * access token.) This functionality can be changed by configuring a custom + * {@link OAuth2AuthorizationFailureHandler} via + * {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. * * @author Joe Grandja * @since 5.2 @@ -72,95 +73,113 @@ import java.util.function.Function; * @see OAuth2AuthorizationFailureHandler */ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { - private static final OAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = - OAuth2AuthorizedClientProviderBuilder.builder() - .clientCredentials() - .build(); + + private static final OAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = OAuth2AuthorizedClientProviderBuilder + .builder().clientCredentials().build(); + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private Function> contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; /** - * Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using the provided parameters. - * + * Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using + * the provided parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientService the authorized client service */ - public AuthorizedClientServiceOAuth2AuthorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientService authorizedClientService) { + public AuthorizedClientServiceOAuth2AuthorizedClientManager( + ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientService authorizedClientService) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientService = authorizedClientService; this.authorizedClientProvider = DEFAULT_AUTHORIZED_CLIENT_PROVIDER; this.contextAttributesMapper = new DefaultContextAttributesMapper(); - this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> - authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> authorizedClientService + .saveAuthorizedClient(authorizedClient, principal); this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName())); + (clientRegistrationId, principal, attributes) -> authorizedClientService + .removeAuthorizedClient(clientRegistrationId, principal.getName())); } @Nullable @Override public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); - OAuth2AuthorizationContext.Builder contextBuilder; if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); - } else { - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - authorizedClient = this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()); + } + else { + ClientRegistration clientRegistration = this.clientRegistrationRepository + .findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + authorizedClient = this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, + principal.getName()); if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); - } else { + } + else { contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); } } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .principal(principal) - .attributes(attributes -> { + OAuth2AuthorizationContext authorizationContext = buildAuthorizationContext(authorizeRequest, principal, + contextBuilder); + try { + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + } + catch (OAuth2AuthorizationException ex) { + this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, Collections.emptyMap()); + throw ex; + } + if (authorizedClient != null) { + this.authorizationSuccessHandler.onAuthorizationSuccess(authorizedClient, principal, + Collections.emptyMap()); + } + else { + // In the case of re-authorization, the returned `authorizedClient` may be + // null if re-authorization is not supported. + // For these cases, return the provided + // `authorizationContext.authorizedClient`. + if (authorizationContext.getAuthorizedClient() != null) { + return authorizationContext.getAuthorizedClient(); + } + } + return authorizedClient; + } + + private OAuth2AuthorizationContext buildAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest, + Authentication principal, OAuth2AuthorizationContext.Builder contextBuilder) { + // @formatter:off + return contextBuilder.principal(principal) + .attributes((attributes) -> { Map contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); if (!CollectionUtils.isEmpty(contextAttributes)) { attributes.putAll(contextAttributes); } }) .build(); - - try { - authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - } catch (OAuth2AuthorizationException ex) { - this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, Collections.emptyMap()); - throw ex; - } - - if (authorizedClient != null) { - this.authorizationSuccessHandler.onAuthorizationSuccess( - authorizedClient, principal, Collections.emptyMap()); - } else { - // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. - // For these cases, return the provided `authorizationContext.authorizedClient`. - if (authorizationContext.getAuthorizedClient() != null) { - return authorizationContext.getAuthorizedClient(); - } - } - - return authorizedClient; + // @formatter:on } /** - * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. - * - * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or + * re-authorizing) an OAuth 2.0 Client. + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for + * authorizing (or re-authorizing) an OAuth 2.0 Client */ public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); @@ -168,24 +187,28 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen } /** - * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes - * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. - * - * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes - * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + * Sets the {@code Function} used for mapping attribute(s) from the + * {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes to be associated to + * the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * @param contextAttributesMapper the {@code Function} used for supplying the + * {@code Map} of attributes to the {@link OAuth2AuthorizationContext#getAttributes() + * authorization context} */ - public void setContextAttributesMapper(Function> contextAttributesMapper) { + public void setContextAttributesMapper( + Function> contextAttributesMapper) { Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); this.contextAttributesMapper = contextAttributesMapper; } /** - * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations. + * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful + * authorizations. * *

        - * The default saves {@link OAuth2AuthorizedClient}s in the {@link OAuth2AuthorizedClientService}. - * - * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations + * The default saves {@link OAuth2AuthorizedClient}s in the + * {@link OAuth2AuthorizedClientService}. + * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} + * that handles successful authorizations * @since 5.3 */ public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { @@ -194,14 +217,16 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen } /** - * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures. + * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization + * failures. * *

        - * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default. - * - * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures - * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler + * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by + * default. + * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} + * that handles authorization failures * @since 5.3 + * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler */ public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); @@ -209,9 +234,11 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen } /** - * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + * The default implementation of the {@link #setContextAttributesMapper(Function) + * contextAttributesMapper}. */ - public static class DefaultContextAttributesMapper implements Function> { + public static class DefaultContextAttributesMapper + implements Function> { @Override public Map apply(OAuth2AuthorizeRequest authorizeRequest) { @@ -224,5 +251,7 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen } return contextAttributes; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java index 3a446ec2da..2724d5b5e7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java @@ -13,78 +13,95 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.Map; -import java.util.function.Function; /** - * An implementation of a {@link ReactiveOAuth2AuthorizedClientManager} - * that is capable of operating outside of the context of a {@link ServerWebExchange}, - * e.g. in a scheduled/background thread and/or in the service-tier. + * An implementation of a {@link ReactiveOAuth2AuthorizedClientManager} that is capable of + * operating outside of the context of a {@link ServerWebExchange}, e.g. in a + * scheduled/background thread and/or in the service-tier. * - *

        (When operating within the context of a {@link ServerWebExchange}, - * use {@link DefaultReactiveOAuth2AuthorizedClientManager} instead.)

        + *

        + * (When operating within the context of a {@link ServerWebExchange}, use + * {@link DefaultReactiveOAuth2AuthorizedClientManager} instead.) + *

        * - *

        This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}.

        + *

        + * This is a reactive equivalent of + * {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}. + *

        * *

        Authorized Client Persistence

        * - *

        This client manager utilizes a {@link ReactiveOAuth2AuthorizedClientService} - * to persist {@link OAuth2AuthorizedClient}s.

        + *

        + * This client manager utilizes a {@link ReactiveOAuth2AuthorizedClientService} to persist + * {@link OAuth2AuthorizedClient}s. + *

        * - *

        By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} - * will be saved in the authorized client service. - * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler} - * via {@link #setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler)}.

        + *

        + * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} + * will be saved in the authorized client service. This functionality can be changed by + * configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler} via + * {@link #setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler)}. + *

        * - *

        By default, when an authorization attempt fails due to an + *

        + * By default, when an authorization attempt fails due to an * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error, - * the previously saved {@link OAuth2AuthorizedClient} - * will be removed from the authorized client service. - * (The {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} - * error generally occurs when a refresh token that is no longer valid - * is used to retrieve a new access token.) - * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationFailureHandler} - * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

        + * the previously saved {@link OAuth2AuthorizedClient} will be removed from the authorized + * client service. (The + * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error + * generally occurs when a refresh token that is no longer valid is used to retrieve a new + * access token.) This functionality can be changed by configuring a custom + * {@link ReactiveOAuth2AuthorizationFailureHandler} via + * {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}. + *

        * * @author Ankur Pathak * @author Phil Clay + * @since 5.2.2 * @see ReactiveOAuth2AuthorizedClientManager * @see ReactiveOAuth2AuthorizedClientProvider * @see ReactiveOAuth2AuthorizedClientService * @see ReactiveOAuth2AuthorizationSuccessHandler * @see ReactiveOAuth2AuthorizationFailureHandler - * @since 5.2.2 */ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { - private static final ReactiveOAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .clientCredentials() - .build(); + private static final ReactiveOAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder().clientCredentials().build(); + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = DEFAULT_AUTHORIZED_CLIENT_PROVIDER; + private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; /** - * Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters. - * + * Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} + * using the provided parameters. * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientService the authorized client service + * @param authorizedClientService the authorized client service */ public AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( ReactiveClientRegistrationRepository clientRegistrationRepository, @@ -93,19 +110,18 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientService = authorizedClientService; - this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> - authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> authorizedClientService + .saveAuthorizedClient(authorizedClient, principal); this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName())); + (clientRegistrationId, principal, attributes) -> this.authorizedClientService + .removeAuthorizedClient(clientRegistrationId, principal.getName())); } @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - return createAuthorizationContext(authorizeRequest) - .flatMap(authorizationContext -> authorize(authorizationContext, authorizeRequest.getPrincipal())); + .flatMap((authorizationContext) -> authorize(authorizationContext, authorizeRequest.getPrincipal())); } private Mono createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) { @@ -113,56 +129,57 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager Authentication principal = authorizeRequest.getPrincipal(); return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) .map(OAuth2AuthorizationContext::withAuthorizedClient) - .switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()) + .switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository + .findByRegistrationId(clientRegistrationId) + .flatMap((clientRegistration) -> this.authorizedClientService + .loadAuthorizedClient(clientRegistrationId, principal.getName()) .map(OAuth2AuthorizationContext::withAuthorizedClient) - .switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'"))))) - .flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest) - .defaultIfEmpty(Collections.emptyMap()) - .map(contextAttributes -> { + .switchIfEmpty(Mono.fromSupplier( + () -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))))) + .flatMap((contextBuilder) -> this.contextAttributesMapper.apply(authorizeRequest) + .defaultIfEmpty(Collections.emptyMap()).map((contextAttributes) -> { OAuth2AuthorizationContext.Builder builder = contextBuilder.principal(principal); if (!contextAttributes.isEmpty()) { - builder = builder.attributes(attributes -> attributes.putAll(contextAttributes)); + builder = builder.attributes((attributes) -> attributes.putAll(contextAttributes)); } return builder.build(); })); } /** - * Performs authorization and then delegates to either the {@link #authorizationSuccessHandler} - * or {@link #authorizationFailureHandler}, depending on the authorization result. - * + * Performs authorization and then delegates to either the + * {@link #authorizationSuccessHandler} or {@link #authorizationFailureHandler}, + * depending on the authorization result. * @param authorizationContext the context to authorize * @param principal the principle to authorize - * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds - * and the {@link #authorizationSuccessHandler} has completed, - * or completes with an exception after the authorization attempt fails - * and the {@link #authorizationFailureHandler} has completed + * @return a {@link Mono} that emits the authorized client after the authorization + * attempt succeeds and the {@link #authorizationSuccessHandler} has completed, or + * completes with an exception after the authorization attempt fails and the + * {@link #authorizationFailureHandler} has completed */ - private Mono authorize( - OAuth2AuthorizationContext authorizationContext, + private Mono authorize(OAuth2AuthorizationContext authorizationContext, Authentication principal) { return this.authorizedClientProvider.authorize(authorizationContext) - // Delegate to the authorizationSuccessHandler of the successful authorization - .flatMap(authorizedClient -> this.authorizationSuccessHandler.onAuthorizationSuccess( - authorizedClient, - principal, - Collections.emptyMap()) + // Delegate to the authorizationSuccessHandler of the successful + // authorization + .flatMap((authorizedClient) -> this.authorizationSuccessHandler + .onAuthorizationSuccess(authorizedClient, principal, Collections.emptyMap()) .thenReturn(authorizedClient)) // Delegate to the authorizationFailureHandler of the failed authorization - .onErrorResume(OAuth2AuthorizationException.class, authorizationException -> this.authorizationFailureHandler.onAuthorizationFailure( - authorizationException, - principal, - Collections.emptyMap()) - .then(Mono.error(authorizationException))) + .onErrorResume(OAuth2AuthorizationException.class, + (authorizationException) -> this.authorizationFailureHandler + .onAuthorizationFailure(authorizationException, principal, Collections.emptyMap()) + .then(Mono.error(authorizationException))) .switchIfEmpty(Mono.defer(() -> Mono.justOrEmpty(authorizationContext.getAuthorizedClient()))); } /** - * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. - * - * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or + * re-authorizing) an OAuth 2.0 Client. + * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} + * used for authorizing (or re-authorizing) an OAuth 2.0 Client */ public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); @@ -170,13 +187,15 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager } /** - * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes - * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. - * - * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes - * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + * Sets the {@code Function} used for mapping attribute(s) from the + * {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes to be associated to + * the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * @param contextAttributesMapper the {@code Function} used for supplying the + * {@code Map} of attributes to the {@link OAuth2AuthorizationContext#getAttributes() + * authorization context} */ - public void setContextAttributesMapper(Function>> contextAttributesMapper) { + public void setContextAttributesMapper( + Function>> contextAttributesMapper) { Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); this.contextAttributesMapper = contextAttributesMapper; } @@ -184,9 +203,10 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager /** * Sets the handler that handles successful authorizations. * - * The default saves {@link OAuth2AuthorizedClient}s in the {@link ReactiveOAuth2AuthorizedClientService}. - * - * @param authorizationSuccessHandler the handler that handles successful authorizations. + * The default saves {@link OAuth2AuthorizedClient}s in the + * {@link ReactiveOAuth2AuthorizedClientService}. + * @param authorizationSuccessHandler the handler that handles successful + * authorizations. * @since 5.3 */ public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { @@ -197,12 +217,13 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager /** * Sets the handler that handles authorization failures. * - *

        A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} - * is used by default.

        - * + *

        + * A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} is used + * by default. + *

        * @param authorizationFailureHandler the handler that handles authorization failures. - * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler * @since 5.3 + * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler */ public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); @@ -210,16 +231,19 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager } /** - * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + * The default implementation of the {@link #setContextAttributesMapper(Function) + * contextAttributesMapper}. */ - public static class DefaultContextAttributesMapper implements Function>> { + public static class DefaultContextAttributesMapper + implements Function>> { - private final AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper mapper = - new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper(); + private final AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper mapper = new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper(); @Override public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) { - return Mono.fromCallable(() -> mapper.apply(authorizeRequest)); + return Mono.fromCallable(() -> this.mapper.apply(authorizeRequest)); } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java index 0cbd6ee2c8..8050b74a03 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -20,8 +21,8 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.util.Assert; /** - * This exception is thrown on the client side when an attempt to authenticate - * or authorize an OAuth 2.0 client fails. + * This exception is thrown on the client side when an attempt to authenticate or + * authorize an OAuth 2.0 client fails. * * @author Phil Clay * @since 5.3 @@ -33,16 +34,15 @@ public class ClientAuthorizationException extends OAuth2AuthorizationException { /** * Constructs a {@code ClientAuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param clientRegistrationId the identifier for the client's registration */ public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId) { this(error, clientRegistrationId, error.toString()); } + /** * Constructs a {@code ClientAuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param clientRegistrationId the identifier for the client's registration * @param message the exception message @@ -55,7 +55,6 @@ public class ClientAuthorizationException extends OAuth2AuthorizationException { /** * Constructs a {@code ClientAuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param clientRegistrationId the identifier for the client's registration * @param cause the root cause @@ -66,13 +65,13 @@ public class ClientAuthorizationException extends OAuth2AuthorizationException { /** * Constructs a {@code ClientAuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param clientRegistrationId the identifier for the client's registration * @param message the exception message * @param cause the root cause */ - public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, String message, Throwable cause) { + public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, String message, + Throwable cause) { super(error, message, cause); Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); this.clientRegistrationId = clientRegistrationId; @@ -80,10 +79,10 @@ public class ClientAuthorizationException extends OAuth2AuthorizationException { /** * Returns the identifier for the client's registration. - * * @return the identifier for the client's registration */ public String getClientRegistrationId() { return this.clientRegistrationId; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java index d9b9e7a6a7..ee4c0e4784 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java @@ -13,24 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.springframework.security.oauth2.core.OAuth2Error; /** - * This exception is thrown when an OAuth 2.0 Client is required - * to obtain authorization from the Resource Owner. + * This exception is thrown when an OAuth 2.0 Client is required to obtain authorization + * from the Resource Owner. * * @author Joe Grandja * @since 5.1 * @see OAuth2AuthorizedClient */ public class ClientAuthorizationRequiredException extends ClientAuthorizationException { + private static final String CLIENT_AUTHORIZATION_REQUIRED_ERROR_CODE = "client_authorization_required"; /** - * Constructs a {@code ClientAuthorizationRequiredException} using the provided parameters. - * + * Constructs a {@code ClientAuthorizationRequiredException} using the provided + * parameters. * @param clientRegistrationId the identifier for the client's registration */ public ClientAuthorizationRequiredException(String clientRegistrationId) { @@ -38,4 +40,5 @@ public class ClientAuthorizationRequiredException extends ClientAuthorizationExc "Authorization required for Client Registration Id: " + clientRegistrationId, null), clientRegistrationId); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index ae5975d27f..527b7bfd9f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -26,13 +31,9 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; - /** - * An implementation of an {@link OAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. + * An implementation of an {@link OAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. * * @author Joe Grandja * @since 5.2 @@ -40,55 +41,60 @@ import java.time.Instant; * @see DefaultClientCredentialsTokenResponseClient */ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private OAuth2AccessTokenResponseClient accessTokenResponseClient = - new DefaultClientCredentialsTokenResponseClient(); + + private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + private Clock clock = Clock.systemUTC(); /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns {@code null} if authorization (or re-authorization) is not supported, - * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR - * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. - * + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns {@code null} if authorization (or re-authorization) is not + * supported, e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() + * authorization grant type} is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS + * client_credentials} OR the {@link OAuth2AuthorizedClient#getAccessToken() access + * token} is not expired. * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or + * re-authorization) is not supported */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { return null; } - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { - // If client is already authorized but access token is NOT expired than no need for re-authorization + // If client is already authorized but access token is NOT expired than no + // need for re-authorization return null; } - // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. // // Therefore, renewing an expired access token (re-authorization) // is the same as acquiring a new access token (authorization). + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, clientCredentialsGrantRequest); + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken()); + } - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - - OAuth2AccessTokenResponse tokenResponse; + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { try { - tokenResponse = this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - } catch (OAuth2AuthorizationException ex) { + return this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + } + catch (OAuth2AuthorizationException ex) { throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); } - - return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken()); } private boolean hasTokenExpired(AbstractOAuth2Token token) { @@ -96,23 +102,26 @@ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OA } /** - * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code client_credentials} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code client_credentials} grant */ - public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + public void setAccessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } /** * Sets the maximum acceptable clock skew, which is used when checking the - * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. * *

        - * An access token is considered expired if {@code OAuth2AccessToken#getExpiresAt() - clockSkew} - * is before the current time {@code clock#instant()}. - * + * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. * @param clockSkew the maximum acceptable clock skew */ public void setClockSkew(Duration clockSkew) { @@ -122,12 +131,13 @@ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OA } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. * @param clock the clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java index 71835832cb..e8ec38f7cd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; @@ -23,65 +30,63 @@ import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; /** - * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. * * @author Joe Grandja * @since 5.2 * @see ReactiveOAuth2AuthorizedClientProvider * @see WebClientReactiveClientCredentialsTokenResponseClient */ -public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { - private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = - new WebClientReactiveClientCredentialsTokenResponseClient(); +public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider + implements ReactiveOAuth2AuthorizedClientProvider { + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + private Clock clock = Clock.systemUTC(); /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns an empty {@code Mono} if authorization (or re-authorization) is not supported, - * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR - * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. - * + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns an empty {@code Mono} if authorization (or + * re-authorization) is not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization (or re-authorization) is not supported + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization (or re-authorization) is not supported */ @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { return Mono.empty(); } - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { - // If client is already authorized but access token is NOT expired than no need for re-authorization + // If client is already authorized but access token is NOT expired than no + // need for re-authorization return Mono.empty(); } - // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. // // Therefore, renewing an expired access token (re-authorization) // is the same as acquiring a new access token (authorization). - return Mono.just(new OAuth2ClientCredentialsGrantRequest(clientRegistration)) .flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, - e -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e)) - .map(tokenResponse -> new OAuth2AuthorizedClient( - clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken())); + (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), + ex)) + .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken())); } private boolean hasTokenExpired(AbstractOAuth2Token token) { @@ -89,23 +94,26 @@ public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider imple } /** - * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code client_credentials} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code client_credentials} grant */ - public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + public void setAccessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } /** * Sets the maximum acceptable clock skew, which is used when checking the - * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. * *

        - * An access token is considered expired if {@code OAuth2AccessToken#getExpiresAt() - clockSkew} - * is before the current time {@code clock#instant()}. - * + * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. * @param clockSkew the maximum acceptable clock skew */ public void setClockSkew(Duration clockSkew) { @@ -115,12 +123,13 @@ public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider imple } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. * @param clock the clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java index d57bcf8153..43cddfbb3e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java @@ -13,36 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.client; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + /** - * An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates - * to it's internal {@code List} of {@link OAuth2AuthorizedClientProvider}(s). + * An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates to + * it's internal {@code List} of {@link OAuth2AuthorizedClientProvider}(s). *

        * Each provider is given a chance to * {@link OAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} - * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context - * with the first {@code non-null} {@link OAuth2AuthorizedClient} being returned. + * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * context with the first {@code non-null} {@link OAuth2AuthorizedClient} being returned. * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClientProvider */ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private final List authorizedClientProviders; /** - * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param authorizedClientProviders a list of {@link OAuth2AuthorizedClientProvider}(s) + * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided + * parameters. + * @param authorizedClientProviders a list of + * {@link OAuth2AuthorizedClientProvider}(s) */ public DelegatingOAuth2AuthorizedClientProvider(OAuth2AuthorizedClientProvider... authorizedClientProviders) { Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); @@ -50,9 +53,10 @@ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2Aut } /** - * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param authorizedClientProviders a {@code List} of {@link OAuth2AuthorizedClientProvider}(s) + * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided + * parameters. + * @param authorizedClientProviders a {@code List} of + * {@link OAuth2AuthorizedClientProvider}(s) */ public DelegatingOAuth2AuthorizedClientProvider(List authorizedClientProviders) { Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); @@ -63,7 +67,7 @@ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2Aut @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - for (OAuth2AuthorizedClientProvider authorizedClientProvider : authorizedClientProviders) { + for (OAuth2AuthorizedClientProvider authorizedClientProvider : this.authorizedClientProviders) { OAuth2AuthorizedClient oauth2AuthorizedClient = authorizedClientProvider.authorize(context); if (oauth2AuthorizedClient != null) { return oauth2AuthorizedClient; @@ -71,4 +75,5 @@ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2Aut } return null; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java index 1264d792c5..af24c53289 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProvider.java @@ -13,49 +13,58 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; +package org.springframework.security.oauth2.client; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; + /** - * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} that simply delegates - * to it's internal {@code List} of {@link ReactiveOAuth2AuthorizedClientProvider}(s). + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} that simply + * delegates to it's internal {@code List} of + * {@link ReactiveOAuth2AuthorizedClientProvider}(s). *

        * Each provider is given a chance to - * {@link ReactiveOAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} - * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context - * with the first available {@link OAuth2AuthorizedClient} being returned. + * {@link ReactiveOAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) + * authorize} the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the + * provided context with the first available {@link OAuth2AuthorizedClient} being + * returned. * * @author Joe Grandja * @since 5.2 * @see ReactiveOAuth2AuthorizedClientProvider */ public final class DelegatingReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { + private final List authorizedClientProviders; /** - * Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param authorizedClientProviders a list of {@link ReactiveOAuth2AuthorizedClientProvider}(s) + * Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the + * provided parameters. + * @param authorizedClientProviders a list of + * {@link ReactiveOAuth2AuthorizedClientProvider}(s) */ - public DelegatingReactiveOAuth2AuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider... authorizedClientProviders) { + public DelegatingReactiveOAuth2AuthorizedClientProvider( + ReactiveOAuth2AuthorizedClientProvider... authorizedClientProviders) { Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); this.authorizedClientProviders = Collections.unmodifiableList(Arrays.asList(authorizedClientProviders)); } /** - * Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param authorizedClientProviders a {@code List} of {@link OAuth2AuthorizedClientProvider}(s) + * Constructs a {@code DelegatingReactiveOAuth2AuthorizedClientProvider} using the + * provided parameters. + * @param authorizedClientProviders a {@code List} of + * {@link OAuth2AuthorizedClientProvider}(s) */ - public DelegatingReactiveOAuth2AuthorizedClientProvider(List authorizedClientProviders) { + public DelegatingReactiveOAuth2AuthorizedClientProvider( + List authorizedClientProviders) { Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); this.authorizedClientProviders = Collections.unmodifiableList(new ArrayList<>(authorizedClientProviders)); } @@ -64,7 +73,7 @@ public final class DelegatingReactiveOAuth2AuthorizedClientProvider implements R public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); return Flux.fromIterable(this.authorizedClientProviders) - .concatMap(authorizedClientProvider -> authorizedClientProvider.authorize(context)) - .next(); + .concatMap((authorizedClientProvider) -> authorizedClientProvider.authorize(context)).next(); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java index 2164b4ba1d..3041ce764f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.util.Assert; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - /** - * An {@link OAuth2AuthorizedClientService} that stores - * {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory. + * An {@link OAuth2AuthorizedClientService} that stores {@link OAuth2AuthorizedClient + * Authorized Client(s)} in-memory. * * @author Joe Grandja * @author Vedran Pavic @@ -37,12 +38,14 @@ import java.util.concurrent.ConcurrentHashMap; * @see Authentication */ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService { + private final Map authorizedClients; + private final ClientRegistrationRepository clientRegistrationRepository; /** - * Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters. - * + * Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations */ public InMemoryOAuth2AuthorizedClientService(ClientRegistrationRepository clientRegistrationRepository) { @@ -52,14 +55,15 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author } /** - * Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters. - * - * @since 5.2 + * Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClients the initial {@code Map} of authorized client(s) keyed by {@link OAuth2AuthorizedClientId} + * @param authorizedClients the initial {@code Map} of authorized client(s) keyed by + * {@link OAuth2AuthorizedClientId} + * @since 5.2 */ public InMemoryOAuth2AuthorizedClientService(ClientRegistrationRepository clientRegistrationRepository, - Map authorizedClients) { + Map authorizedClients) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notEmpty(authorizedClients, "authorizedClients cannot be empty"); this.clientRegistrationRepository = clientRegistrationRepository; @@ -68,7 +72,8 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author @Override @SuppressWarnings("unchecked") - public T loadAuthorizedClient(String clientRegistrationId, String principalName) { + public T loadAuthorizedClient(String clientRegistrationId, + String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); @@ -82,8 +87,8 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(principal, "principal cannot be null"); - this.authorizedClients.put(new OAuth2AuthorizedClientId(authorizedClient.getClientRegistration().getRegistrationId(), - principal.getName()), authorizedClient); + this.authorizedClients.put(new OAuth2AuthorizedClientId( + authorizedClient.getClientRegistration().getRegistrationId(), principal.getName()), authorizedClient); } @Override @@ -95,4 +100,5 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author this.authorizedClients.remove(new OAuth2AuthorizedClientId(clientRegistrationId, principalName)); } } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java index 66091f8c90..c4058489b9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; /** - * An {@link OAuth2AuthorizedClientService} that stores - * {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory. + * An {@link OAuth2AuthorizedClientService} that stores {@link OAuth2AuthorizedClient + * Authorized Client(s)} in-memory. * * @author Rob Winch * @author Vedran Pavic @@ -37,27 +39,31 @@ import java.util.concurrent.ConcurrentHashMap; * @see Authentication */ public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService { + private final Map authorizedClients = new ConcurrentHashMap<>(); + private final ReactiveClientRegistrationRepository clientRegistrationRepository; /** - * Constructs an {@code InMemoryReactiveOAuth2AuthorizedClientService} using the provided parameters. - * + * Constructs an {@code InMemoryReactiveOAuth2AuthorizedClientService} using the + * provided parameters. * @param clientRegistrationRepository the repository of client registrations */ - public InMemoryReactiveOAuth2AuthorizedClientService(ReactiveClientRegistrationRepository clientRegistrationRepository) { + public InMemoryReactiveOAuth2AuthorizedClientService( + ReactiveClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; } @Override @SuppressWarnings("unchecked") - public Mono loadAuthorizedClient(String clientRegistrationId, String principalName) { + public Mono loadAuthorizedClient(String clientRegistrationId, + String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); return (Mono) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .map(clientRegistration -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) - .flatMap(identifier -> Mono.justOrEmpty(this.authorizedClients.get(identifier))); + .map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) + .flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier))); } @Override @@ -75,9 +81,12 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac public Mono removeAuthorizedClient(String clientRegistrationId, String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); + // @formatter:off return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .map(clientRegistration -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) + .map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) .doOnNext(this.authorizedClients::remove) .then(Mono.empty()); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java index cd3da393ac..7d88ba7236 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java @@ -13,8 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.nio.charset.StandardCharsets; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.function.Function; + import org.springframework.dao.DataRetrievalFailureException; import org.springframework.dao.DuplicateKeyException; import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; @@ -31,26 +44,15 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.nio.charset.StandardCharsets; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; -import java.sql.Types; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Set; -import java.util.function.Function; - /** - * A JDBC implementation of an {@link OAuth2AuthorizedClientService} - * that uses a {@link JdbcOperations} for {@link OAuth2AuthorizedClient} persistence. + * A JDBC implementation of an {@link OAuth2AuthorizedClientService} that uses a + * {@link JdbcOperations} for {@link OAuth2AuthorizedClient} persistence. * *

        * NOTE: This {@code OAuth2AuthorizedClientService} depends on the table definition - * described in "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" - * and therefore MUST be defined in the database schema. + * described in + * "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" and + * therefore MUST be defined in the database schema. * * @author Joe Grandja * @author Stav Shamir @@ -61,42 +63,58 @@ import java.util.function.Function; * @see RowMapper */ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService { - private static final String COLUMN_NAMES = - "client_registration_id, " + - "principal_name, " + - "access_token_type, " + - "access_token_value, " + - "access_token_issued_at, " + - "access_token_expires_at, " + - "access_token_scopes, " + - "refresh_token_value, " + - "refresh_token_issued_at"; + + // @formatter:off + private static final String COLUMN_NAMES = "client_registration_id, " + + "principal_name, " + + "access_token_type, " + + "access_token_value, " + + "access_token_issued_at, " + + "access_token_expires_at, " + + "access_token_scopes, " + + "refresh_token_value, " + + "refresh_token_issued_at"; + // @formatter:on + private static final String TABLE_NAME = "oauth2_authorized_client"; + private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?"; - private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + - " FROM " + TABLE_NAME + " WHERE " + PK_FILTER; - private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + - " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; - private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + - " WHERE " + PK_FILTER; - private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME + - " SET access_token_type = ?, access_token_value = ?, access_token_issued_at = ?," + - " access_token_expires_at = ?, access_token_scopes = ?," + - " refresh_token_value = ?, refresh_token_issued_at = ?" + - " WHERE " + PK_FILTER; + + // @formatter:off + private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + // @formatter:on + + // @formatter:off + private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + // @formatter:on + + private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + + // @formatter:off + private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME + + " SET access_token_type = ?, access_token_value = ?, access_token_issued_at = ?," + + " access_token_expires_at = ?, access_token_scopes = ?," + + " refresh_token_value = ?, refresh_token_issued_at = ?" + + " WHERE " + PK_FILTER; + // @formatter:on + protected final JdbcOperations jdbcOperations; + protected RowMapper authorizedClientRowMapper; + protected Function> authorizedClientParametersMapper; /** - * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided parameters. - * + * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided + * parameters. * @param jdbcOperations the JDBC operations * @param clientRegistrationRepository the repository of client registrations */ - public JdbcOAuth2AuthorizedClientService( - JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { - + public JdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations, + ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.jdbcOperations = jdbcOperations; @@ -106,19 +124,16 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient @Override @SuppressWarnings("unchecked") - public T loadAuthorizedClient(String clientRegistrationId, String principalName) { + public T loadAuthorizedClient(String clientRegistrationId, + String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); - SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), - new SqlParameterValue(Types.VARCHAR, principalName) - }; + new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); - - List result = this.jdbcOperations.query( - LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); - + List result = this.jdbcOperations.query(LOAD_AUTHORIZED_CLIENT_SQL, pss, + this.authorizedClientRowMapper); return !result.isEmpty() ? (T) result.get(0) : null; } @@ -126,40 +141,36 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(principal, "principal cannot be null"); - boolean existsAuthorizedClient = null != this.loadAuthorizedClient( authorizedClient.getClientRegistration().getRegistrationId(), principal.getName()); - if (existsAuthorizedClient) { updateAuthorizedClient(authorizedClient, principal); - } else { + } + else { try { insertAuthorizedClient(authorizedClient, principal); - } catch (DuplicateKeyException e) { + } + catch (DuplicateKeyException ex) { updateAuthorizedClient(authorizedClient, principal); } } } private void updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { - List parameters = this.authorizedClientParametersMapper.apply( - new OAuth2AuthorizedClientHolder(authorizedClient, principal)); - + List parameters = this.authorizedClientParametersMapper + .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)); SqlParameterValue clientRegistrationIdParameter = parameters.remove(0); SqlParameterValue principalNameParameter = parameters.remove(0); parameters.add(clientRegistrationIdParameter); parameters.add(principalNameParameter); - PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss); } private void insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { - List parameters = this.authorizedClientParametersMapper.apply( - new OAuth2AuthorizedClientHolder(authorizedClient, principal)); + List parameters = this.authorizedClientParametersMapper + .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)); PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); } @@ -167,21 +178,19 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient public void removeAuthorizedClient(String clientRegistrationId, String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); - SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), - new SqlParameterValue(Types.VARCHAR, principalName) - }; + new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); - this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); } /** - * Sets the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. - * The default is {@link OAuth2AuthorizedClientRowMapper}. - * - * @param authorizedClientRowMapper the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient} + * Sets the {@link RowMapper} used for mapping the current row in + * {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. The default is + * {@link OAuth2AuthorizedClientRowMapper}. + * @param authorizedClientRowMapper the {@link RowMapper} used for mapping the current + * row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient} */ public final void setAuthorizedClientRowMapper(RowMapper authorizedClientRowMapper) { Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null"); @@ -189,21 +198,24 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient } /** - * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}. - * The default is {@link OAuth2AuthorizedClientParametersMapper}. - * - * @param authorizedClientParametersMapper the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue} + * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to + * a {@code List} of {@link SqlParameterValue}. The default is + * {@link OAuth2AuthorizedClientParametersMapper}. + * @param authorizedClientParametersMapper the {@code Function} used for mapping + * {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue} */ - public final void setAuthorizedClientParametersMapper(Function> authorizedClientParametersMapper) { + public final void setAuthorizedClientParametersMapper( + Function> authorizedClientParametersMapper) { Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null"); this.authorizedClientParametersMapper = authorizedClientParametersMapper; } /** - * The default {@link RowMapper} that maps the current row - * in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. + * The default {@link RowMapper} that maps the current row in + * {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}. */ public static class OAuth2AuthorizedClientRowMapper implements RowMapper { + protected final ClientRegistrationRepository clientRegistrationRepository; public OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) { @@ -214,17 +226,15 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient @Override public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException { String clientRegistrationId = rs.getString("client_registration_id"); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId( - clientRegistrationId); + ClientRegistration clientRegistration = this.clientRegistrationRepository + .findByRegistrationId(clientRegistrationId); if (clientRegistration == null) { - throw new DataRetrievalFailureException("The ClientRegistration with id '" + - clientRegistrationId + "' exists in the data source, " + - "however, it was not found in the ClientRegistrationRepository."); + throw new DataRetrievalFailureException( + "The ClientRegistration with id '" + clientRegistrationId + "' exists in the data source, " + + "however, it was not found in the ClientRegistrationRepository."); } - OAuth2AccessToken.TokenType tokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( - rs.getString("access_token_type"))) { + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) { tokenType = OAuth2AccessToken.TokenType.BEARER; } String tokenValue = new String(rs.getBytes("access_token_value"), StandardCharsets.UTF_8); @@ -235,9 +245,7 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient if (accessTokenScopes != null) { scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); } - OAuth2AccessToken accessToken = new OAuth2AccessToken( - tokenType, tokenValue, issuedAt, expiresAt, scopes); - + OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, scopes); OAuth2RefreshToken refreshToken = null; byte[] refreshTokenValue = rs.getBytes("refresh_token_value"); if (refreshTokenValue != null) { @@ -249,19 +257,18 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient } refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); } - String principalName = rs.getString("principal_name"); - - return new OAuth2AuthorizedClient( - clientRegistration, principalName, accessToken, refreshToken); + return new OAuth2AuthorizedClient(clientRegistration, principalName, accessToken, refreshToken); } + } /** - * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} - * to a {@code List} of {@link SqlParameterValue}. + * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} to a + * {@code List} of {@link SqlParameterValue}. */ - public static class OAuth2AuthorizedClientParametersMapper implements Function> { + public static class OAuth2AuthorizedClientParametersMapper + implements Function> { @Override public List apply(OAuth2AuthorizedClientHolder authorizedClientHolder) { @@ -270,26 +277,19 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); OAuth2AccessToken accessToken = authorizedClient.getAccessToken(); OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); - List parameters = new ArrayList<>(); - parameters.add(new SqlParameterValue( - Types.VARCHAR, clientRegistration.getRegistrationId())); - parameters.add(new SqlParameterValue( - Types.VARCHAR, principal.getName())); - parameters.add(new SqlParameterValue( - Types.VARCHAR, accessToken.getTokenType().getValue())); - parameters.add(new SqlParameterValue( - Types.BLOB, accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8))); - parameters.add(new SqlParameterValue( - Types.TIMESTAMP, Timestamp.from(accessToken.getIssuedAt()))); - parameters.add(new SqlParameterValue( - Types.TIMESTAMP, Timestamp.from(accessToken.getExpiresAt()))); + parameters.add(new SqlParameterValue(Types.VARCHAR, clientRegistration.getRegistrationId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, principal.getName())); + parameters.add(new SqlParameterValue(Types.VARCHAR, accessToken.getTokenType().getValue())); + parameters.add( + new SqlParameterValue(Types.BLOB, accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8))); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(accessToken.getIssuedAt()))); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(accessToken.getExpiresAt()))); String accessTokenScopes = null; if (!CollectionUtils.isEmpty(accessToken.getScopes())) { accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ","); } - parameters.add(new SqlParameterValue( - Types.VARCHAR, accessTokenScopes)); + parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes)); byte[] refreshTokenValue = null; Timestamp refreshTokenIssuedAt = null; if (refreshToken != null) { @@ -298,25 +298,26 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient refreshTokenIssuedAt = Timestamp.from(refreshToken.getIssuedAt()); } } - parameters.add(new SqlParameterValue( - Types.BLOB, refreshTokenValue)); - parameters.add(new SqlParameterValue( - Types.TIMESTAMP, refreshTokenIssuedAt)); - + parameters.add(new SqlParameterValue(Types.BLOB, refreshTokenValue)); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, refreshTokenIssuedAt)); return parameters; } + } /** - * A holder for an {@link OAuth2AuthorizedClient} and End-User {@link Authentication} (Resource Owner). + * A holder for an {@link OAuth2AuthorizedClient} and End-User {@link Authentication} + * (Resource Owner). */ public static final class OAuth2AuthorizedClientHolder { + private final OAuth2AuthorizedClient authorizedClient; + private final Authentication principal; /** - * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided + * parameters. * @param authorizedClient the authorized client * @param principal the End-User {@link Authentication} (Resource Owner) */ @@ -329,7 +330,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient /** * Returns the {@link OAuth2AuthorizedClient}. - * * @return the {@link OAuth2AuthorizedClient} */ public OAuth2AuthorizedClient getAuthorizedClient() { @@ -338,11 +338,12 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient /** * Returns the End-User {@link Authentication} (Resource Owner). - * * @return the End-User {@link Authentication} (Resource Owner) */ public Authentication getPrincipal() { return this.principal; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java index 8bac099ae7..a74a319d69 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.lang.Nullable; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; +package org.springframework.security.oauth2.client; import java.util.Collections; import java.util.HashMap; @@ -27,34 +22,50 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.function.Consumer; +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + /** - * A context that holds authorization-specific state and is used by an {@link OAuth2AuthorizedClientProvider} - * when attempting to authorize (or re-authorize) an OAuth 2.0 Client. + * A context that holds authorization-specific state and is used by an + * {@link OAuth2AuthorizedClientProvider} when attempting to authorize (or re-authorize) + * an OAuth 2.0 Client. * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClientProvider */ public final class OAuth2AuthorizationContext { - /** - * The name of the {@link #getAttribute(String) attribute} in the context associated to the value for the "request scope(s)". - * The value of the attribute is a {@code String[]} of scope(s) to be requested by the {@link #getClientRegistration() client}. - */ - public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".REQUEST_SCOPE"); /** - * The name of the {@link #getAttribute(String) attribute} in the context associated to the value for the resource owner's username. + * The name of the {@link #getAttribute(String) attribute} in the context associated + * to the value for the "request scope(s)". The value of the attribute is a + * {@code String[]} of scope(s) to be requested by the {@link #getClientRegistration() + * client}. + */ + public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName() + .concat(".REQUEST_SCOPE"); + + /** + * The name of the {@link #getAttribute(String) attribute} in the context associated + * to the value for the resource owner's username. */ public static final String USERNAME_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".USERNAME"); /** - * The name of the {@link #getAttribute(String) attribute} in the context associated to the value for the resource owner's password. + * The name of the {@link #getAttribute(String) attribute} in the context associated + * to the value for the resource owner's password. */ public static final String PASSWORD_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".PASSWORD"); private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; private OAuth2AuthorizationContext() { @@ -62,7 +73,6 @@ public final class OAuth2AuthorizationContext { /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -70,10 +80,11 @@ public final class OAuth2AuthorizationContext { } /** - * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} - * if the {@link #withClientRegistration(ClientRegistration) client registration} was supplied. - * - * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client registration was supplied + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if the + * {@link #withClientRegistration(ClientRegistration) client registration} was + * supplied. + * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client + * registration was supplied */ @Nullable public OAuth2AuthorizedClient getAuthorizedClient() { @@ -82,7 +93,6 @@ public final class OAuth2AuthorizationContext { /** * Returns the {@code Principal} (to be) associated to the authorized client. - * * @return the {@code Principal} (to be) associated to the authorized client */ public Authentication getPrincipal() { @@ -91,7 +101,6 @@ public final class OAuth2AuthorizationContext { /** * Returns the attributes associated to the context. - * * @return a {@code Map} of the attributes associated to the context */ public Map getAttributes() { @@ -99,8 +108,8 @@ public final class OAuth2AuthorizationContext { } /** - * Returns the value of an attribute associated to the context or {@code null} if not available. - * + * Returns the value of an attribute associated to the context or {@code null} if not + * available. * @param name the name of the attribute * @param the type of the attribute * @return the value of the attribute associated to the context @@ -113,7 +122,6 @@ public final class OAuth2AuthorizationContext { /** * Returns a new {@link Builder} initialized with the {@link ClientRegistration}. - * * @param clientRegistration the {@link ClientRegistration client registration} * @return the {@link Builder} */ @@ -123,7 +131,6 @@ public final class OAuth2AuthorizationContext { /** * Returns a new {@link Builder} initialized with the {@link OAuth2AuthorizedClient}. - * * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} * @return the {@link Builder} */ @@ -134,10 +141,14 @@ public final class OAuth2AuthorizationContext { /** * A builder for {@link OAuth2AuthorizationContext}. */ - public static class Builder { + public static final class Builder { + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; private Builder(ClientRegistration clientRegistration) { @@ -152,8 +163,8 @@ public final class OAuth2AuthorizationContext { /** * Sets the {@code Principal} (to be) associated to the authorized client. - * - * @param principal the {@code Principal} (to be) associated to the authorized client + * @param principal the {@code Principal} (to be) associated to the authorized + * client * @return the {@link Builder} */ public Builder principal(Authentication principal) { @@ -163,8 +174,8 @@ public final class OAuth2AuthorizationContext { /** * Provides a {@link Consumer} access to the attributes associated to the context. - * - * @param attributesConsumer a {@link Consumer} of the attributes associated to the context + * @param attributesConsumer a {@link Consumer} of the attributes associated to + * the context * @return the {@link OAuth2AuthorizeRequest.Builder} */ public Builder attributes(Consumer> attributesConsumer) { @@ -177,7 +188,6 @@ public final class OAuth2AuthorizationContext { /** * Sets an attribute associated to the context. - * * @param name the name of the attribute * @param value the value of the attribute * @return the {@link Builder} @@ -192,7 +202,6 @@ public final class OAuth2AuthorizationContext { /** * Builds a new {@link OAuth2AuthorizationContext}. - * * @return a {@link OAuth2AuthorizationContext} */ public OAuth2AuthorizationContext build() { @@ -201,14 +210,16 @@ public final class OAuth2AuthorizationContext { if (this.authorizedClient != null) { context.clientRegistration = this.authorizedClient.getClientRegistration(); context.authorizedClient = this.authorizedClient; - } else { + } + else { context.clientRegistration = this.clientRegistration; } context.principal = this.principal; - context.attributes = Collections.unmodifiableMap( - CollectionUtils.isEmpty(this.attributes) ? - Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); + context.attributes = Collections.unmodifiableMap(CollectionUtils.isEmpty(this.attributes) + ? Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); return context; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java index c24141d8b4..4e6089eab1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Map; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import java.util.Map; - /** - * Handles when an OAuth 2.0 Client fails to authorize (or re-authorize) - * via the Authorization Server or Resource Server. + * Handles when an OAuth 2.0 Client fails to authorize (or re-authorize) via the + * Authorization Server or Resource Server. * * @author Joe Grandja * @since 5.3 @@ -33,16 +34,17 @@ import java.util.Map; public interface OAuth2AuthorizationFailureHandler { /** - * Called when an OAuth 2.0 Client fails to authorize (or re-authorize) - * via the Authorization Server or Resource Server. - * + * Called when an OAuth 2.0 Client fails to authorize (or re-authorize) via the + * Authorization Server or Resource Server. * @param authorizationException the exception that contains details about what failed * @param principal the {@code Principal} associated with the attempted authorization - * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions. - * For example, this might contain a {@code javax.servlet.http.HttpServletRequest} - * and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed - * within the context of a {@code javax.servlet.ServletContext}. + * @param attributes an immutable {@code Map} of (optional) attributes present under + * certain conditions. For example, this might contain a + * {@code javax.servlet.http.HttpServletRequest} and + * {@code javax.servlet.http.HttpServletResponse} if the authorization was performed + * within the context of a {@code javax.servlet.ServletContext}. */ - void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, - Authentication principal, Map attributes); + void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, Authentication principal, + Map attributes); + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java index b350924ab5..7e5e81a385 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.core.Authentication; +package org.springframework.security.oauth2.client; import java.util.Map; +import org.springframework.security.core.Authentication; + /** - * Handles when an OAuth 2.0 Client has been successfully - * authorized (or re-authorized) via the Authorization Server. + * Handles when an OAuth 2.0 Client has been successfully authorized (or re-authorized) + * via the Authorization Server. * * @author Joe Grandja * @since 5.3 @@ -32,16 +33,18 @@ import java.util.Map; public interface OAuth2AuthorizationSuccessHandler { /** - * Called when an OAuth 2.0 Client has been successfully - * authorized (or re-authorized) via the Authorization Server. - * - * @param authorizedClient the client that was successfully authorized (or re-authorized) + * Called when an OAuth 2.0 Client has been successfully authorized (or re-authorized) + * via the Authorization Server. + * @param authorizedClient the client that was successfully authorized (or + * re-authorized) * @param principal the {@code Principal} associated with the authorized client - * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions. - * For example, this might contain a {@code javax.servlet.http.HttpServletRequest} - * and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed - * within the context of a {@code javax.servlet.ServletContext}. + * @param attributes an immutable {@code Map} of (optional) attributes present under + * certain conditions. For example, this might contain a + * {@code javax.servlet.http.HttpServletRequest} and + * {@code javax.servlet.http.HttpServletResponse} if the authorization was performed + * within the context of a {@code javax.servlet.ServletContext}. */ - void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, - Authentication principal, Map attributes); + void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, + Map attributes); + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java index a8aa973e2a..58a4ad3531 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Consumer; + import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.Authentication; @@ -22,25 +29,24 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.function.Consumer; - /** * Represents a request the {@link OAuth2AuthorizedClientManager} uses to - * {@link OAuth2AuthorizedClientManager#authorize(OAuth2AuthorizeRequest) authorize} (or re-authorize) - * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. + * {@link OAuth2AuthorizedClientManager#authorize(OAuth2AuthorizeRequest) authorize} (or + * re-authorize) the {@link ClientRegistration client} identified by the provided + * {@link #getClientRegistrationId() clientRegistrationId}. * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClientManager */ public final class OAuth2AuthorizeRequest { + private String clientRegistrationId; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; private OAuth2AuthorizeRequest() { @@ -48,7 +54,6 @@ public final class OAuth2AuthorizeRequest { /** * Returns the identifier for the {@link ClientRegistration client registration}. - * * @return the identifier for the client registration */ public String getClientRegistrationId() { @@ -56,8 +61,8 @@ public final class OAuth2AuthorizeRequest { } /** - * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. - * + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it + * was not provided. * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided */ @Nullable @@ -67,7 +72,6 @@ public final class OAuth2AuthorizeRequest { /** * Returns the {@code Principal} (to be) associated to the authorized client. - * * @return the {@code Principal} (to be) associated to the authorized client */ public Authentication getPrincipal() { @@ -76,7 +80,6 @@ public final class OAuth2AuthorizeRequest { /** * Returns the attributes associated to the request. - * * @return a {@code Map} of the attributes associated to the request */ public Map getAttributes() { @@ -84,8 +87,8 @@ public final class OAuth2AuthorizeRequest { } /** - * Returns the value of an attribute associated to the request or {@code null} if not available. - * + * Returns the value of an attribute associated to the request or {@code null} if not + * available. * @param name the name of the attribute * @param the type of the attribute * @return the value of the attribute associated to the request @@ -97,9 +100,10 @@ public final class OAuth2AuthorizeRequest { } /** - * Returns a new {@link Builder} initialized with the identifier for the {@link ClientRegistration client registration}. - * - * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} + * Returns a new {@link Builder} initialized with the identifier for the + * {@link ClientRegistration client registration}. + * @param clientRegistrationId the identifier for the {@link ClientRegistration client + * registration} * @return the {@link Builder} */ public static Builder withClientRegistrationId(String clientRegistrationId) { @@ -107,8 +111,8 @@ public final class OAuth2AuthorizeRequest { } /** - * Returns a new {@link Builder} initialized with the {@link OAuth2AuthorizedClient authorized client}. - * + * Returns a new {@link Builder} initialized with the {@link OAuth2AuthorizedClient + * authorized client}. * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} * @return the {@link Builder} */ @@ -119,10 +123,14 @@ public final class OAuth2AuthorizeRequest { /** * A builder for {@link OAuth2AuthorizeRequest}. */ - public static class Builder { + public static final class Builder { + private String clientRegistrationId; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; private Builder(String clientRegistrationId) { @@ -136,11 +144,12 @@ public final class OAuth2AuthorizeRequest { } /** - * Sets the name of the {@code Principal} (to be) associated to the authorized client. - * - * @since 5.3 - * @param principalName the name of the {@code Principal} (to be) associated to the authorized client + * Sets the name of the {@code Principal} (to be) associated to the authorized + * client. + * @param principalName the name of the {@code Principal} (to be) associated to + * the authorized client * @return the {@link Builder} + * @since 5.3 */ public Builder principal(String principalName) { return principal(createAuthentication(principalName)); @@ -148,8 +157,8 @@ public final class OAuth2AuthorizeRequest { private static Authentication createAuthentication(final String principalName) { Assert.hasText(principalName, "principalName cannot be empty"); - return new AbstractAuthenticationToken(null) { + @Override public Object getCredentials() { return ""; @@ -159,13 +168,14 @@ public final class OAuth2AuthorizeRequest { public Object getPrincipal() { return principalName; } + }; } /** * Sets the {@code Principal} (to be) associated to the authorized client. - * - * @param principal the {@code Principal} (to be) associated to the authorized client + * @param principal the {@code Principal} (to be) associated to the authorized + * client * @return the {@link Builder} */ public Builder principal(Authentication principal) { @@ -175,8 +185,8 @@ public final class OAuth2AuthorizeRequest { /** * Provides a {@link Consumer} access to the attributes associated to the request. - * - * @param attributesConsumer a {@link Consumer} of the attributes associated to the request + * @param attributesConsumer a {@link Consumer} of the attributes associated to + * the request * @return the {@link Builder} */ public Builder attributes(Consumer> attributesConsumer) { @@ -189,7 +199,6 @@ public final class OAuth2AuthorizeRequest { /** * Sets an attribute associated to the request. - * * @param name the name of the attribute * @param value the value of the attribute * @return the {@link Builder} @@ -204,23 +213,25 @@ public final class OAuth2AuthorizeRequest { /** * Builds a new {@link OAuth2AuthorizeRequest}. - * * @return a {@link OAuth2AuthorizeRequest} */ public OAuth2AuthorizeRequest build() { Assert.notNull(this.principal, "principal cannot be null"); OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest(); if (this.authorizedClient != null) { - authorizeRequest.clientRegistrationId = this.authorizedClient.getClientRegistration().getRegistrationId(); + authorizeRequest.clientRegistrationId = this.authorizedClient.getClientRegistration() + .getRegistrationId(); authorizeRequest.authorizedClient = this.authorizedClient; - } else { + } + else { authorizeRequest.clientRegistrationId = this.clientRegistrationId; } authorizeRequest.principal = this.principal; - authorizeRequest.attributes = Collections.unmodifiableMap( - CollectionUtils.isEmpty(this.attributes) ? - Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); + authorizeRequest.attributes = Collections.unmodifiableMap(CollectionUtils.isEmpty(this.attributes) + ? Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); return authorizeRequest; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClient.java index e45fd060e6..9d9efe1343 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.io.Serializable; + import org.springframework.lang.Nullable; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -22,17 +25,15 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.util.Assert; -import java.io.Serializable; - /** * A representation of an OAuth 2.0 "Authorized Client". *

        - * A client is considered "authorized" when the End-User (Resource Owner) - * has granted authorization to the client to access it's protected resources. + * A client is considered "authorized" when the End-User (Resource Owner) has + * granted authorization to the client to access it's protected resources. *

        - * This class associates the {@link #getClientRegistration() Client} - * to the {@link #getAccessToken() Access Token} - * granted/authorized by the {@link #getPrincipalName() Resource Owner}. + * This class associates the {@link #getClientRegistration() Client} to the + * {@link #getAccessToken() Access Token} granted/authorized by the + * {@link #getPrincipalName() Resource Owner}. * * @author Joe Grandja * @since 5.0 @@ -41,33 +42,37 @@ import java.io.Serializable; * @see OAuth2RefreshToken */ public class OAuth2AuthorizedClient implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final ClientRegistration clientRegistration; + private final String principalName; + private final OAuth2AccessToken accessToken; + private final OAuth2RefreshToken refreshToken; /** * Constructs an {@code OAuth2AuthorizedClient} using the provided parameters. - * * @param clientRegistration the authorized client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) * @param accessToken the access token credential granted */ - public OAuth2AuthorizedClient(ClientRegistration clientRegistration, String principalName, OAuth2AccessToken accessToken) { + public OAuth2AuthorizedClient(ClientRegistration clientRegistration, String principalName, + OAuth2AccessToken accessToken) { this(clientRegistration, principalName, accessToken, null); } /** * Constructs an {@code OAuth2AuthorizedClient} using the provided parameters. - * * @param clientRegistration the authorized client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) * @param accessToken the access token credential granted * @param refreshToken the refresh token credential granted */ public OAuth2AuthorizedClient(ClientRegistration clientRegistration, String principalName, - OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken) { + OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.hasText(principalName, "principalName cannot be empty"); Assert.notNull(accessToken, "accessToken cannot be null"); @@ -79,7 +84,6 @@ public class OAuth2AuthorizedClient implements Serializable { /** * Returns the authorized client's {@link ClientRegistration registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -88,7 +92,6 @@ public class OAuth2AuthorizedClient implements Serializable { /** * Returns the End-User's {@code Principal} name. - * * @return the End-User's {@code Principal} name */ public String getPrincipalName() { @@ -97,7 +100,6 @@ public class OAuth2AuthorizedClient implements Serializable { /** * Returns the {@link OAuth2AccessToken access token} credential granted. - * * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { @@ -106,11 +108,11 @@ public class OAuth2AuthorizedClient implements Serializable { /** * Returns the {@link OAuth2RefreshToken refresh token} credential granted. - * - * @since 5.1 * @return the {@link OAuth2RefreshToken} + * @since 5.1 */ public @Nullable OAuth2RefreshToken getRefreshToken() { return this.refreshToken; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java index 3c0cb7d40d..4662679fd0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.core.SpringSecurityCoreVersion; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.client; import java.io.Serializable; import java.util.Objects; +import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.util.Assert; + /** * The identifier for {@link OAuth2AuthorizedClient}. * @@ -30,13 +31,15 @@ import java.util.Objects; * @see OAuth2AuthorizedClientService */ public final class OAuth2AuthorizedClientId implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final String clientRegistrationId; + private final String principalName; /** * Constructs an {@code OAuth2AuthorizedClientId} using the provided parameters. - * * @param clientRegistrationId the identifier for the client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) */ @@ -64,4 +67,5 @@ public final class OAuth2AuthorizedClientId implements Serializable { public int hashCode() { return Objects.hash(this.clientRegistrationId, this.principalName); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java index d8c0f1d3a1..0f62e2ccbe 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.springframework.lang.Nullable; @@ -20,16 +21,16 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; /** - * Implementations of this interface are responsible for the overall management - * of {@link OAuth2AuthorizedClient Authorized Client(s)}. + * Implementations of this interface are responsible for the overall management of + * {@link OAuth2AuthorizedClient Authorized Client(s)}. * *

        * The primary responsibilities include: *

          - *
        1. Authorizing (or re-authorizing) an OAuth 2.0 Client - * by leveraging an {@link OAuth2AuthorizedClientProvider}(s).
        2. - *
        3. Delegating the persistence of an {@link OAuth2AuthorizedClient}, - * typically using an {@link OAuth2AuthorizedClientService} OR {@link OAuth2AuthorizedClientRepository}.
        4. + *
        5. Authorizing (or re-authorizing) an OAuth 2.0 Client by leveraging an + * {@link OAuth2AuthorizedClientProvider}(s).
        6. + *
        7. Delegating the persistence of an {@link OAuth2AuthorizedClient}, typically using an + * {@link OAuth2AuthorizedClientService} OR {@link OAuth2AuthorizedClientRepository}.
        8. *
        * * @author Joe Grandja @@ -43,20 +44,23 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepo public interface OAuth2AuthorizedClientManager { /** - * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} - * identified by the provided {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. - * Implementations must return {@code null} if authorization is not supported for the specified client, - * e.g. the associated {@link OAuth2AuthorizedClientProvider}(s) does not support - * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration + * client} identified by the provided + * {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. + * Implementations must return {@code null} if authorization is not supported for the + * specified client, e.g. the associated {@link OAuth2AuthorizedClientProvider}(s) + * does not support the {@link ClientRegistration#getAuthorizationGrantType() + * authorization grant} type configured for the client. * *

        - * In the case of re-authorization, implementations must return the provided {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} - * if re-authorization is not supported for the client OR is not required, - * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR + * In the case of re-authorization, implementations must return the provided + * {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} if + * re-authorization is not supported for the client OR is not required, e.g. a + * {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. - * * @param authorizeRequest the authorize request - * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not + * supported for the specified client */ @Nullable OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java index ad96f4a985..92c61d5b29 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.springframework.lang.Nullable; @@ -20,25 +21,30 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.core.AuthorizationGrantType; /** - * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. - * Implementations will typically implement a specific {@link AuthorizationGrantType authorization grant} type. + * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. Implementations + * will typically implement a specific {@link AuthorizationGrantType authorization grant} + * type. * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClient * @see OAuth2AuthorizationContext - * @see Section 1.3 Authorization Grant + * @see Section + * 1.3 Authorization Grant */ @FunctionalInterface public interface OAuth2AuthorizedClientProvider { /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. - * Implementations must return {@code null} if authorization is not supported for the specified client, - * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. - * + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * context. Implementations must return {@code null} if authorization is not supported + * for the specified client, e.g. the provider doesn't support the + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type + * configured for the client. * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not + * supported for the specified client */ @Nullable OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java index 7d02a4efcc..fa109dd2aa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; -import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; -import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.client; import java.time.Clock; import java.time.Duration; @@ -30,13 +25,19 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.util.Assert; + /** * A builder that builds a {@link DelegatingOAuth2AuthorizedClientProvider} composed of - * one or more {@link OAuth2AuthorizedClientProvider}(s) that implement specific authorization grants. - * The supported authorization grants are {@link #authorizationCode() authorization_code}, - * {@link #refreshToken() refresh_token}, {@link #clientCredentials() client_credentials} - * and {@link #password() password}. - * In addition to the standard authorization grants, an implementation of an extension grant + * one or more {@link OAuth2AuthorizedClientProvider}(s) that implement specific + * authorization grants. The supported authorization grants are + * {@link #authorizationCode() authorization_code}, {@link #refreshToken() refresh_token}, + * {@link #clientCredentials() client_credentials} and {@link #password() password}. In + * addition to the standard authorization grants, an implementation of an extension grant * may be supplied via {@link #provider(OAuth2AuthorizedClientProvider)}. * * @author Joe Grandja @@ -49,14 +50,15 @@ import java.util.function.Consumer; * @see DelegatingOAuth2AuthorizedClientProvider */ public final class OAuth2AuthorizedClientProviderBuilder { + private final Map, Builder> builders = new LinkedHashMap<>(); private OAuth2AuthorizedClientProviderBuilder() { } /** - * Returns a new {@link OAuth2AuthorizedClientProviderBuilder} for configuring the supported authorization grant(s). - * + * Returns a new {@link OAuth2AuthorizedClientProviderBuilder} for configuring the + * supported authorization grant(s). * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public static OAuth2AuthorizedClientProviderBuilder builder() { @@ -64,273 +66,146 @@ public final class OAuth2AuthorizedClientProviderBuilder { } /** - * Configures an {@link OAuth2AuthorizedClientProvider} to be composed with the {@link DelegatingOAuth2AuthorizedClientProvider}. - * This may be used for implementations of extension authorization grants. - * + * Configures an {@link OAuth2AuthorizedClientProvider} to be composed with the + * {@link DelegatingOAuth2AuthorizedClientProvider}. This may be used for + * implementations of extension authorization grants. * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder provider(OAuth2AuthorizedClientProvider provider) { Assert.notNull(provider, "provider cannot be null"); - this.builders.computeIfAbsent(provider.getClass(), k -> () -> provider); + this.builders.computeIfAbsent(provider.getClass(), (k) -> () -> provider); return OAuth2AuthorizedClientProviderBuilder.this; } /** * Configures support for the {@code authorization_code} grant. - * * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder authorizationCode() { - this.builders.computeIfAbsent(AuthorizationCodeOAuth2AuthorizedClientProvider.class, k -> new AuthorizationCodeGrantBuilder()); + this.builders.computeIfAbsent(AuthorizationCodeOAuth2AuthorizedClientProvider.class, + (k) -> new AuthorizationCodeGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } - /** - * A builder for the {@code authorization_code} grant. - */ - public class AuthorizationCodeGrantBuilder implements Builder { - - private AuthorizationCodeGrantBuilder() { - } - - /** - * Builds an instance of {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. - * - * @return the {@link AuthorizationCodeOAuth2AuthorizedClientProvider} - */ - @Override - public OAuth2AuthorizedClientProvider build() { - return new AuthorizationCodeOAuth2AuthorizedClientProvider(); - } - } - /** * Configures support for the {@code refresh_token} grant. - * * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder refreshToken() { - this.builders.computeIfAbsent(RefreshTokenOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); + this.builders.computeIfAbsent(RefreshTokenOAuth2AuthorizedClientProvider.class, + (k) -> new RefreshTokenGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } /** * Configures support for the {@code refresh_token} grant. - * - * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used for further configuration + * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used + * for further configuration * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder refreshToken(Consumer builderConsumer) { RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( - RefreshTokenOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); + RefreshTokenOAuth2AuthorizedClientProvider.class, (k) -> new RefreshTokenGrantBuilder()); builderConsumer.accept(builder); return OAuth2AuthorizedClientProviderBuilder.this; } - /** - * A builder for the {@code refresh_token} grant. - */ - public class RefreshTokenGrantBuilder implements Builder { - private OAuth2AccessTokenResponseClient accessTokenResponseClient; - private Duration clockSkew; - private Clock clock; - - private RefreshTokenGrantBuilder() { - } - - /** - * Sets the client used when requesting an access token credential at the Token Endpoint. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint - * @return the {@link RefreshTokenGrantBuilder} - */ - public RefreshTokenGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { - this.accessTokenResponseClient = accessTokenResponseClient; - return this; - } - - /** - * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. - * An access token is considered expired if it's before {@code Instant.now(this.clock) - clockSkew}. - * - * @param clockSkew the maximum acceptable clock skew - * @return the {@link RefreshTokenGrantBuilder} - */ - public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { - this.clockSkew = clockSkew; - return this; - } - - /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * - * @param clock the clock - * @return the {@link RefreshTokenGrantBuilder} - */ - public RefreshTokenGrantBuilder clock(Clock clock) { - this.clock = clock; - return this; - } - - /** - * Builds an instance of {@link RefreshTokenOAuth2AuthorizedClientProvider}. - * - * @return the {@link RefreshTokenOAuth2AuthorizedClientProvider} - */ - @Override - public OAuth2AuthorizedClientProvider build() { - RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); - if (this.accessTokenResponseClient != null) { - authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); - } - if (this.clockSkew != null) { - authorizedClientProvider.setClockSkew(this.clockSkew); - } - if (this.clock != null) { - authorizedClientProvider.setClock(this.clock); - } - return authorizedClientProvider; - } - } - /** * Configures support for the {@code client_credentials} grant. - * * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder clientCredentials() { - this.builders.computeIfAbsent(ClientCredentialsOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); + this.builders.computeIfAbsent(ClientCredentialsOAuth2AuthorizedClientProvider.class, + (k) -> new ClientCredentialsGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } /** * Configures support for the {@code client_credentials} grant. - * - * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} used for further configuration + * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} + * used for further configuration * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ - public OAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer builderConsumer) { + public OAuth2AuthorizedClientProviderBuilder clientCredentials( + Consumer builderConsumer) { ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( - ClientCredentialsOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); + ClientCredentialsOAuth2AuthorizedClientProvider.class, (k) -> new ClientCredentialsGrantBuilder()); builderConsumer.accept(builder); return OAuth2AuthorizedClientProviderBuilder.this; } - /** - * A builder for the {@code client_credentials} grant. - */ - public class ClientCredentialsGrantBuilder implements Builder { - private OAuth2AccessTokenResponseClient accessTokenResponseClient; - private Duration clockSkew; - private Clock clock; - - private ClientCredentialsGrantBuilder() { - } - - /** - * Sets the client used when requesting an access token credential at the Token Endpoint. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint - * @return the {@link ClientCredentialsGrantBuilder} - */ - public ClientCredentialsGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { - this.accessTokenResponseClient = accessTokenResponseClient; - return this; - } - - /** - * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. - * An access token is considered expired if it's before {@code Instant.now(this.clock) - clockSkew}. - * - * @param clockSkew the maximum acceptable clock skew - * @return the {@link ClientCredentialsGrantBuilder} - */ - public ClientCredentialsGrantBuilder clockSkew(Duration clockSkew) { - this.clockSkew = clockSkew; - return this; - } - - /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * - * @param clock the clock - * @return the {@link ClientCredentialsGrantBuilder} - */ - public ClientCredentialsGrantBuilder clock(Clock clock) { - this.clock = clock; - return this; - } - - /** - * Builds an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider}. - * - * @return the {@link ClientCredentialsOAuth2AuthorizedClientProvider} - */ - @Override - public OAuth2AuthorizedClientProvider build() { - ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); - if (this.accessTokenResponseClient != null) { - authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); - } - if (this.clockSkew != null) { - authorizedClientProvider.setClockSkew(this.clockSkew); - } - if (this.clock != null) { - authorizedClientProvider.setClock(this.clock); - } - return authorizedClientProvider; - } - } - /** * Configures support for the {@code password} grant. - * * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder password() { - this.builders.computeIfAbsent(PasswordOAuth2AuthorizedClientProvider.class, k -> new PasswordGrantBuilder()); + this.builders.computeIfAbsent(PasswordOAuth2AuthorizedClientProvider.class, (k) -> new PasswordGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } /** * Configures support for the {@code password} grant. - * - * @param builderConsumer a {@code Consumer} of {@link PasswordGrantBuilder} used for further configuration + * @param builderConsumer a {@code Consumer} of {@link PasswordGrantBuilder} used for + * further configuration * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder password(Consumer builderConsumer) { - PasswordGrantBuilder builder = (PasswordGrantBuilder) this.builders.computeIfAbsent( - PasswordOAuth2AuthorizedClientProvider.class, k -> new PasswordGrantBuilder()); + PasswordGrantBuilder builder = (PasswordGrantBuilder) this.builders + .computeIfAbsent(PasswordOAuth2AuthorizedClientProvider.class, (k) -> new PasswordGrantBuilder()); builderConsumer.accept(builder); return OAuth2AuthorizedClientProviderBuilder.this; } + /** + * Builds an instance of {@link DelegatingOAuth2AuthorizedClientProvider} composed of + * one or more {@link OAuth2AuthorizedClientProvider}(s). + * @return the {@link DelegatingOAuth2AuthorizedClientProvider} + */ + public OAuth2AuthorizedClientProvider build() { + List authorizedClientProviders = new ArrayList<>(); + for (Builder builder : this.builders.values()) { + authorizedClientProviders.add(builder.build()); + } + return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + interface Builder { + + OAuth2AuthorizedClientProvider build(); + + } + /** * A builder for the {@code password} grant. */ - public class PasswordGrantBuilder implements Builder { + public final class PasswordGrantBuilder implements Builder { + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + private Clock clock; private PasswordGrantBuilder() { } /** - * Sets the client used when requesting an access token credential at the Token Endpoint. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * Sets the client used when requesting an access token credential at the Token + * Endpoint. + * @param accessTokenResponseClient the client used when requesting an access + * token credential at the Token Endpoint * @return the {@link PasswordGrantBuilder} */ - public PasswordGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + public PasswordGrantBuilder accessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { this.accessTokenResponseClient = accessTokenResponseClient; return this; } /** - * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. - * An access token is considered expired if it's before {@code Instant.now(this.clock) - clockSkew}. - * + * Sets the maximum acceptable clock skew, which is used when checking the access + * token expiry. An access token is considered expired if it's before + * {@code Instant.now(this.clock) - clockSkew}. * @param clockSkew the maximum acceptable clock skew * @return the {@link PasswordGrantBuilder} */ @@ -340,8 +215,8 @@ public final class OAuth2AuthorizedClientProviderBuilder { } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the + * access token expiry. * @param clock the clock * @return the {@link PasswordGrantBuilder} */ @@ -352,7 +227,6 @@ public final class OAuth2AuthorizedClientProviderBuilder { /** * Builds an instance of {@link PasswordOAuth2AuthorizedClientProvider}. - * * @return the {@link PasswordOAuth2AuthorizedClientProvider} */ @Override @@ -369,23 +243,168 @@ public final class OAuth2AuthorizedClientProviderBuilder { } return authorizedClientProvider; } + } /** - * Builds an instance of {@link DelegatingOAuth2AuthorizedClientProvider} - * composed of one or more {@link OAuth2AuthorizedClientProvider}(s). - * - * @return the {@link DelegatingOAuth2AuthorizedClientProvider} + * A builder for the {@code client_credentials} grant. */ - public OAuth2AuthorizedClientProvider build() { - List authorizedClientProviders = new ArrayList<>(); - for (Builder builder : this.builders.values()) { - authorizedClientProviders.add(builder.build()); + public final class ClientCredentialsGrantBuilder implements Builder { + + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + + private Duration clockSkew; + + private Clock clock; + + private ClientCredentialsGrantBuilder() { } - return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint. + * @param accessTokenResponseClient the client used when requesting an access + * token credential at the Token Endpoint + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder accessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access + * token expiry. An access token is considered expired if it's before + * {@code Instant.now(this.clock) - clockSkew}. + * @param clockSkew the maximum acceptable clock skew + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the + * access token expiry. + * @param clock the clock + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder clock(Clock clock) { + this.clock = clock; + return this; + } + + /** + * Builds an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider}. + * @return the {@link ClientCredentialsOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + if (this.clock != null) { + authorizedClientProvider.setClock(this.clock); + } + return authorizedClientProvider; + } + } - interface Builder { - OAuth2AuthorizedClientProvider build(); + /** + * A builder for the {@code authorization_code} grant. + */ + public final class AuthorizationCodeGrantBuilder implements Builder { + + private AuthorizationCodeGrantBuilder() { + } + + /** + * Builds an instance of {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. + * @return the {@link AuthorizationCodeOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + return new AuthorizationCodeOAuth2AuthorizedClientProvider(); + } + } + + /** + * A builder for the {@code refresh_token} grant. + */ + public final class RefreshTokenGrantBuilder implements Builder { + + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + + private Duration clockSkew; + + private Clock clock; + + private RefreshTokenGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint. + * @param accessTokenResponseClient the client used when requesting an access + * token credential at the Token Endpoint + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder accessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access + * token expiry. An access token is considered expired if it's before + * {@code Instant.now(this.clock) - clockSkew}. + * @param clockSkew the maximum acceptable clock skew + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the + * access token expiry. + * @param clock the clock + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clock(Clock clock) { + this.clock = clock; + return this; + } + + /** + * Builds an instance of {@link RefreshTokenOAuth2AuthorizedClientProvider}. + * @return the {@link RefreshTokenOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + if (this.clock != null) { + authorizedClientProvider.setClock(this.clock); + } + return authorizedClientProvider; + } + + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientService.java index 4a275287dd..aa07cd3abd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.springframework.security.core.Authentication; @@ -20,12 +21,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.core.OAuth2AccessToken; /** - * Implementations of this interface are responsible for the management - * of {@link OAuth2AuthorizedClient Authorized Client(s)}, which provide the purpose - * of associating an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential + * Implementations of this interface are responsible for the management of + * {@link OAuth2AuthorizedClient Authorized Client(s)}, which provide the purpose of + * associating an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, - * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} - * that originally granted the authorization. + * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} that originally + * granted the authorization. * * @author Joe Grandja * @since 5.0 @@ -37,10 +38,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; public interface OAuth2AuthorizedClientService { /** - * Returns the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User's {@code Principal} name - * or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User's {@code Principal} name or {@code null} if + * not available. * @param clientRegistrationId the identifier for the client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) * @param a type of OAuth2AuthorizedClient @@ -49,18 +49,16 @@ public interface OAuth2AuthorizedClientService { T loadAuthorizedClient(String clientRegistrationId, String principalName); /** - * Saves the {@link OAuth2AuthorizedClient} associating it to - * the provided End-User {@link Authentication} (Resource Owner). - * + * Saves the {@link OAuth2AuthorizedClient} associating it to the provided End-User + * {@link Authentication} (Resource Owner). * @param authorizedClient the authorized client * @param principal the End-User {@link Authentication} (Resource Owner) */ void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal); /** - * Removes the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User's {@code Principal} name. - * + * Removes the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User's {@code Principal} name. * @param clientRegistrationId the identifier for the client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java index e40522a7e5..931e862b79 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -27,13 +32,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; - /** - * An implementation of an {@link OAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#PASSWORD password} grant. + * An implementation of an {@link OAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#PASSWORD password} grant. * * @author Joe Grandja * @since 5.2 @@ -41,96 +42,105 @@ import java.time.Instant; * @see DefaultPasswordTokenResponseClient */ public final class PasswordOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private OAuth2AccessTokenResponseClient accessTokenResponseClient = - new DefaultPasswordTokenResponseClient(); + + private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultPasswordTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + private Clock clock = Clock.systemUTC(); /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns {@code null} if authorization (or re-authorization) is not supported, - * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#PASSWORD password} OR - * the {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME username} and/or - * {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME password} attributes - * are not available in the provided {@code context} OR - * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns {@code null} if authorization (or re-authorization) is not + * supported, e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() + * authorization grant type} is not {@link AuthorizationGrantType#PASSWORD password} + * OR the {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME username} and/or + * {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME password} attributes are + * not available in the provided {@code context} OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * *

        - * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} + * are supported: *

          - *
        1. {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME} (required) - a {@code String} value for the resource owner's username
        2. - *
        3. {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME} (required) - a {@code String} value for the resource owner's password
        4. + *
        5. {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME} (required) - a + * {@code String} value for the resource owner's username
        6. + *
        7. {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME} (required) - a + * {@code String} value for the resource owner's password
        8. *
        - * * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or + * re-authorization) is not supported */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (!AuthorizationGrantType.PASSWORD.equals(clientRegistration.getAuthorizationGrantType())) { return null; } - String username = context.getAttribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME); String password = context.getAttribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME); if (!StringUtils.hasText(username) || !StringUtils.hasText(password)) { return null; } - if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { - // If client is already authorized and access token is NOT expired than no need for re-authorization + // If client is already authorized and access token is NOT expired than no + // need for re-authorization return null; } - - if (authorizedClient != null && hasTokenExpired(authorizedClient.getAccessToken()) && authorizedClient.getRefreshToken() != null) { - // If client is already authorized and access token is expired and a refresh token is available, - // than return and allow RefreshTokenOAuth2AuthorizedClientProvider to handle the refresh + if (authorizedClient != null && hasTokenExpired(authorizedClient.getAccessToken()) + && authorizedClient.getRefreshToken() != null) { + // If client is already authorized and access token is expired and a refresh + // token is available, than return and allow + // RefreshTokenOAuth2AuthorizedClientProvider to handle the refresh return null; } - - OAuth2PasswordGrantRequest passwordGrantRequest = - new OAuth2PasswordGrantRequest(clientRegistration, username, password); - - OAuth2AccessTokenResponse tokenResponse; - try { - tokenResponse = this.accessTokenResponseClient.getTokenResponse(passwordGrantRequest); - } catch (OAuth2AuthorizationException ex) { - throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); - } - + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, username, + password); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, passwordGrantRequest); return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); } + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, + OAuth2PasswordGrantRequest passwordGrantRequest) { + try { + return this.accessTokenResponseClient.getTokenResponse(passwordGrantRequest); + } + catch (OAuth2AuthorizationException ex) { + throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); + } + } + private boolean hasTokenExpired(AbstractOAuth2Token token) { return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); } /** - * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code password} grant. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code password} grant + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code password} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code password} grant */ - public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + public void setAccessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } /** * Sets the maximum acceptable clock skew, which is used when checking the - * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. * *

        - * An access token is considered expired if {@code OAuth2AccessToken#getExpiresAt() - clockSkew} - * is before the current time {@code clock#instant()}. - * + * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. * @param clockSkew the maximum acceptable clock skew */ public void setClockSkew(Duration clockSkew) { @@ -140,12 +150,13 @@ public final class PasswordOAuth2AuthorizedClientProvider implements OAuth2Autho } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. * @param clock the clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java index 971170aef6..7240fef0cc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.WebClientReactivePasswordTokenResponseClient; @@ -24,15 +31,10 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; /** - * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#PASSWORD password} grant. + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#PASSWORD password} grant. * * @author Joe Grandja * @since 5.2 @@ -40,67 +42,71 @@ import java.time.Instant; * @see WebClientReactivePasswordTokenResponseClient */ public final class PasswordReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { - private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = - new WebClientReactivePasswordTokenResponseClient(); + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactivePasswordTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + private Clock clock = Clock.systemUTC(); /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns an empty {@code Mono} if authorization (or re-authorization) is not supported, - * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#PASSWORD password} OR - * the {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME username} and/or - * {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME password} attributes - * are not available in the provided {@code context} OR - * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns an empty {@code Mono} if authorization (or + * re-authorization) is not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#PASSWORD password} OR the + * {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME username} and/or + * {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME password} attributes are + * not available in the provided {@code context} OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * *

        - * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} + * are supported: *

          - *
        1. {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME} (required) - a {@code String} value for the resource owner's username
        2. - *
        3. {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME} (required) - a {@code String} value for the resource owner's password
        4. + *
        5. {@link OAuth2AuthorizationContext#USERNAME_ATTRIBUTE_NAME} (required) - a + * {@code String} value for the resource owner's username
        6. + *
        7. {@link OAuth2AuthorizationContext#PASSWORD_ATTRIBUTE_NAME} (required) - a + * {@code String} value for the resource owner's password
        8. *
        - * * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization (or re-authorization) is not supported + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization (or re-authorization) is not supported */ @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (!AuthorizationGrantType.PASSWORD.equals(clientRegistration.getAuthorizationGrantType())) { return Mono.empty(); } - String username = context.getAttribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME); String password = context.getAttribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME); if (!StringUtils.hasText(username) || !StringUtils.hasText(password)) { return Mono.empty(); } - if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { - // If client is already authorized and access token is NOT expired than no need for re-authorization + // If client is already authorized and access token is NOT expired than no + // need for re-authorization return Mono.empty(); } - - if (authorizedClient != null && hasTokenExpired(authorizedClient.getAccessToken()) && authorizedClient.getRefreshToken() != null) { - // If client is already authorized and access token is expired and a refresh token is available, - // than return and allow RefreshTokenReactiveOAuth2AuthorizedClientProvider to handle the refresh + if (authorizedClient != null && hasTokenExpired(authorizedClient.getAccessToken()) + && authorizedClient.getRefreshToken() != null) { + // If client is already authorized and access token is expired and a refresh + // token is available, + // than return and allow RefreshTokenReactiveOAuth2AuthorizedClientProvider to + // handle the refresh return Mono.empty(); } - - OAuth2PasswordGrantRequest passwordGrantRequest = - new OAuth2PasswordGrantRequest(clientRegistration, username, password); - - return Mono.just(passwordGrantRequest) - .flatMap(this.accessTokenResponseClient::getTokenResponse) + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, username, + password); + return Mono.just(passwordGrantRequest).flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, - e -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e)) - .map(tokenResponse -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), + e)) + .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken())); } @@ -109,23 +115,26 @@ public final class PasswordReactiveOAuth2AuthorizedClientProvider implements Rea } /** - * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code password} grant. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code password} grant + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code password} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code password} grant */ - public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + public void setAccessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } /** * Sets the maximum acceptable clock skew, which is used when checking the - * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. * *

        - * An access token is considered expired if {@code OAuth2AccessToken#getExpiresAt() - clockSkew} - * is before the current time {@code clock#instant()}. - * + * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. * @param clockSkew the maximum acceptable clock skew */ public void setClockSkew(Duration clockSkew) { @@ -135,12 +144,13 @@ public final class PasswordReactiveOAuth2AuthorizedClientProvider implements Rea } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. * @param clock the clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java index ed93e6cbf4..70c23ed0d0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import reactor.core.publisher.Mono; +package org.springframework.security.oauth2.client; import java.util.Map; +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; + /** - * Handles when an OAuth 2.0 Client - * fails to authorize (or re-authorize) - * via the authorization server or resource server. + * Handles when an OAuth 2.0 Client fails to authorize (or re-authorize) via the + * authorization server or resource server. * * @author Phil Clay * @since 5.3 @@ -33,19 +34,18 @@ import java.util.Map; public interface ReactiveOAuth2AuthorizationFailureHandler { /** - * Called when an OAuth 2.0 Client - * fails to authorize (or re-authorize) - * via the authorization server or resource server. - * + * Called when an OAuth 2.0 Client fails to authorize (or re-authorize) via the + * authorization server or resource server. * @param authorizationException the exception that contains details about what failed * @param principal the {@code Principal} that was attempted to be authorized - * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions. - * For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} - * if the authorization was performed within the context of a {@code ServerWebExchange}. - * @return an empty {@link Mono} that completes after this handler has finished handling the event. + * @param attributes an immutable {@code Map} of extra optional attributes present + * under certain conditions. For example, this might contain a + * {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} if the + * authorization was performed within the context of a {@code ServerWebExchange}. + * @return an empty {@link Mono} that completes after this handler has finished + * handling the event. */ - Mono onAuthorizationFailure( - OAuth2AuthorizationException authorizationException, - Authentication principal, + Mono onAuthorizationFailure(OAuth2AuthorizationException authorizationException, Authentication principal, Map attributes); + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java index 1df3258f63..5b43d8d279 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.core.Authentication; -import reactor.core.publisher.Mono; +package org.springframework.security.oauth2.client; import java.util.Map; +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; + /** - * Handles when an OAuth 2.0 Client - * has been successfully authorized (or re-authorized) + * Handles when an OAuth 2.0 Client has been successfully authorized (or re-authorized) * via the authorization server. * * @author Phil Clay @@ -32,20 +33,18 @@ import java.util.Map; public interface ReactiveOAuth2AuthorizationSuccessHandler { /** - * Called when an OAuth 2.0 Client - * has been successfully authorized (or re-authorized) + * Called when an OAuth 2.0 Client has been successfully authorized (or re-authorized) * via the authorization server. - * * @param authorizedClient the client that was successfully authorized * @param principal the {@code Principal} associated with the authorized client - * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions. - * For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} - * if the authorization was performed within the context of a {@code ServerWebExchange}. - * @return an empty {@link Mono} that completes after this handler has finished handling the event. + * @param attributes an immutable {@code Map} of extra optional attributes present + * under certain conditions. For example, this might contain a + * {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} if the + * authorization was performed within the context of a {@code ServerWebExchange}. + * @return an empty {@link Mono} that completes after this handler has finished + * handling the event. */ - Mono onAuthorizationSuccess( - OAuth2AuthorizedClient authorizedClient, - Authentication principal, + Mono onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, Map attributes); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientManager.java index 8730aaefb9..1ee84cbc0f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientManager.java @@ -13,23 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; -import reactor.core.publisher.Mono; /** - * Implementations of this interface are responsible for the overall management - * of {@link OAuth2AuthorizedClient Authorized Client(s)}. + * Implementations of this interface are responsible for the overall management of + * {@link OAuth2AuthorizedClient Authorized Client(s)}. * *

        * The primary responsibilities include: *

          - *
        1. Authorizing (or re-authorizing) an OAuth 2.0 Client - * by leveraging a {@link ReactiveOAuth2AuthorizedClientProvider}(s).
        2. - *
        3. Delegating the persistence of an {@link OAuth2AuthorizedClient}, - * typically using a {@link ReactiveOAuth2AuthorizedClientService} OR {@link ServerOAuth2AuthorizedClientRepository}.
        4. + *
        5. Authorizing (or re-authorizing) an OAuth 2.0 Client by leveraging a + * {@link ReactiveOAuth2AuthorizedClientProvider}(s).
        6. + *
        7. Delegating the persistence of an {@link OAuth2AuthorizedClient}, typically using a + * {@link ReactiveOAuth2AuthorizedClientService} OR + * {@link ServerOAuth2AuthorizedClientRepository}.
        8. *
        * * @author Joe Grandja @@ -43,20 +46,24 @@ import reactor.core.publisher.Mono; public interface ReactiveOAuth2AuthorizedClientManager { /** - * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} - * identified by the provided {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. - * Implementations must return an empty {@code Mono} if authorization is not supported for the specified client, - * e.g. the associated {@link ReactiveOAuth2AuthorizedClientProvider}(s) does not support - * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration + * client} identified by the provided + * {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. + * Implementations must return an empty {@code Mono} if authorization is not supported + * for the specified client, e.g. the associated + * {@link ReactiveOAuth2AuthorizedClientProvider}(s) does not support the + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type + * configured for the client. * *

        - * In the case of re-authorization, implementations must return the provided {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} - * if re-authorization is not supported for the client OR is not required, - * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR + * In the case of re-authorization, implementations must return the provided + * {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} if + * re-authorization is not supported for the client OR is not required, e.g. a + * {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. - * * @param authorizeRequest the authorize request - * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization is not supported for the specified client */ Mono authorize(OAuth2AuthorizeRequest authorizeRequest); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java index f5b775410c..f78b4f4b08 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProvider.java @@ -13,32 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import reactor.core.publisher.Mono; /** - * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. - * Implementations will typically implement a specific {@link AuthorizationGrantType authorization grant} type. + * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. Implementations + * will typically implement a specific {@link AuthorizationGrantType authorization grant} + * type. * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClient * @see OAuth2AuthorizationContext - * @see Section 1.3 Authorization Grant + * @see Section + * 1.3 Authorization Grant */ @FunctionalInterface public interface ReactiveOAuth2AuthorizedClientProvider { /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. - * Implementations must return an empty {@code Mono} if authorization is not supported for the specified client, - * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. - * + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * context. Implementations must return an empty {@code Mono} if authorization is not + * supported for the specified client, e.g. the provider doesn't support the + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type + * configured for the client. * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization is not supported for the specified client */ Mono authorize(OAuth2AuthorizationContext context); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java index 482b5962ec..7b0580571d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; -import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; -import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.client; import java.time.Clock; import java.time.Duration; @@ -30,13 +25,19 @@ import java.util.Map; import java.util.function.Consumer; import java.util.stream.Collectors; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.util.Assert; + /** - * A builder that builds a {@link DelegatingReactiveOAuth2AuthorizedClientProvider} composed of - * one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s) that implement specific authorization grants. - * The supported authorization grants are {@link #authorizationCode() authorization_code}, - * {@link #refreshToken() refresh_token}, {@link #clientCredentials() client_credentials} - * and {@link #password() password}. - * In addition to the standard authorization grants, an implementation of an extension grant + * A builder that builds a {@link DelegatingReactiveOAuth2AuthorizedClientProvider} + * composed of one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s) that + * implement specific authorization grants. The supported authorization grants are + * {@link #authorizationCode() authorization_code}, {@link #refreshToken() refresh_token}, + * {@link #clientCredentials() client_credentials} and {@link #password() password}. In + * addition to the standard authorization grants, an implementation of an extension grant * may be supplied via {@link #provider(ReactiveOAuth2AuthorizedClientProvider)}. * * @author Joe Grandja @@ -49,14 +50,15 @@ import java.util.stream.Collectors; * @see DelegatingReactiveOAuth2AuthorizedClientProvider */ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { + private final Map, Builder> builders = new LinkedHashMap<>(); private ReactiveOAuth2AuthorizedClientProviderBuilder() { } /** - * Returns a new {@link ReactiveOAuth2AuthorizedClientProviderBuilder} for configuring the supported authorization grant(s). - * + * Returns a new {@link ReactiveOAuth2AuthorizedClientProviderBuilder} for configuring + * the supported authorization grant(s). * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} */ public static ReactiveOAuth2AuthorizedClientProviderBuilder builder() { @@ -64,184 +66,167 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { } /** - * Configures a {@link ReactiveOAuth2AuthorizedClientProvider} to be composed with the {@link DelegatingReactiveOAuth2AuthorizedClientProvider}. - * This may be used for implementations of extension authorization grants. - * + * Configures a {@link ReactiveOAuth2AuthorizedClientProvider} to be composed with the + * {@link DelegatingReactiveOAuth2AuthorizedClientProvider}. This may be used for + * implementations of extension authorization grants. * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} */ public ReactiveOAuth2AuthorizedClientProviderBuilder provider(ReactiveOAuth2AuthorizedClientProvider provider) { Assert.notNull(provider, "provider cannot be null"); - this.builders.computeIfAbsent(provider.getClass(), k -> () -> provider); + this.builders.computeIfAbsent(provider.getClass(), (k) -> () -> provider); return ReactiveOAuth2AuthorizedClientProviderBuilder.this; } /** * Configures support for the {@code authorization_code} grant. - * * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} */ public ReactiveOAuth2AuthorizedClientProviderBuilder authorizationCode() { - this.builders.computeIfAbsent(AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class, k -> new AuthorizationCodeGrantBuilder()); + this.builders.computeIfAbsent(AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class, + (k) -> new AuthorizationCodeGrantBuilder()); return ReactiveOAuth2AuthorizedClientProviderBuilder.this; } + /** + * Configures support for the {@code refresh_token} grant. + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken() { + this.builders.computeIfAbsent(RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, + (k) -> new RefreshTokenGrantBuilder()); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code refresh_token} grant. + * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used + * for further configuration + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken( + Consumer builderConsumer) { + RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( + RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, (k) -> new RefreshTokenGrantBuilder()); + builderConsumer.accept(builder); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code client_credentials} grant. + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials() { + this.builders.computeIfAbsent(ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, + (k) -> new ClientCredentialsGrantBuilder()); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code client_credentials} grant. + * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} + * used for further configuration + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials( + Consumer builderConsumer) { + ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( + ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, + (k) -> new ClientCredentialsGrantBuilder()); + builderConsumer.accept(builder); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code password} grant. + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder password() { + this.builders.computeIfAbsent(PasswordReactiveOAuth2AuthorizedClientProvider.class, + (k) -> new PasswordGrantBuilder()); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code password} grant. + * @param builderConsumer a {@code Consumer} of {@link PasswordGrantBuilder} used for + * further configuration + * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} + */ + public ReactiveOAuth2AuthorizedClientProviderBuilder password(Consumer builderConsumer) { + PasswordGrantBuilder builder = (PasswordGrantBuilder) this.builders.computeIfAbsent( + PasswordReactiveOAuth2AuthorizedClientProvider.class, (k) -> new PasswordGrantBuilder()); + builderConsumer.accept(builder); + return ReactiveOAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Builds an instance of {@link DelegatingReactiveOAuth2AuthorizedClientProvider} + * composed of one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s). + * @return the {@link DelegatingReactiveOAuth2AuthorizedClientProvider} + */ + public ReactiveOAuth2AuthorizedClientProvider build() { + List authorizedClientProviders = this.builders.values().stream() + .map(Builder::build).collect(Collectors.toList()); + return new DelegatingReactiveOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + interface Builder { + + ReactiveOAuth2AuthorizedClientProvider build(); + + } + /** * A builder for the {@code authorization_code} grant. */ - public class AuthorizationCodeGrantBuilder implements Builder { + public final class AuthorizationCodeGrantBuilder implements Builder { private AuthorizationCodeGrantBuilder() { } /** - * Builds an instance of {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}. - * + * Builds an instance of + * {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}. * @return the {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider} */ @Override public ReactiveOAuth2AuthorizedClientProvider build() { return new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider(); } - } - /** - * Configures support for the {@code refresh_token} grant. - * - * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} - */ - public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken() { - this.builders.computeIfAbsent(RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); - return ReactiveOAuth2AuthorizedClientProviderBuilder.this; - } - - /** - * Configures support for the {@code refresh_token} grant. - * - * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used for further configuration - * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} - */ - public ReactiveOAuth2AuthorizedClientProviderBuilder refreshToken(Consumer builderConsumer) { - RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( - RefreshTokenReactiveOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); - builderConsumer.accept(builder); - return ReactiveOAuth2AuthorizedClientProviderBuilder.this; - } - - /** - * A builder for the {@code refresh_token} grant. - */ - public class RefreshTokenGrantBuilder implements Builder { - private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; - private Duration clockSkew; - private Clock clock; - - private RefreshTokenGrantBuilder() { - } - - /** - * Sets the client used when requesting an access token credential at the Token Endpoint. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint - * @return the {@link RefreshTokenGrantBuilder} - */ - public RefreshTokenGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { - this.accessTokenResponseClient = accessTokenResponseClient; - return this; - } - - /** - * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. - * An access token is considered expired if it's before {@code Instant.now(this.clock) - clockSkew}. - * - * @param clockSkew the maximum acceptable clock skew - * @return the {@link RefreshTokenGrantBuilder} - */ - public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { - this.clockSkew = clockSkew; - return this; - } - - /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * - * @param clock the clock - * @return the {@link RefreshTokenGrantBuilder} - */ - public RefreshTokenGrantBuilder clock(Clock clock) { - this.clock = clock; - return this; - } - - /** - * Builds an instance of {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. - * - * @return the {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider} - */ - @Override - public ReactiveOAuth2AuthorizedClientProvider build() { - RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider(); - if (this.accessTokenResponseClient != null) { - authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); - } - if (this.clockSkew != null) { - authorizedClientProvider.setClockSkew(this.clockSkew); - } - if (this.clock != null) { - authorizedClientProvider.setClock(this.clock); - } - return authorizedClientProvider; - } - } - - /** - * Configures support for the {@code client_credentials} grant. - * - * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} - */ - public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials() { - this.builders.computeIfAbsent(ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); - return ReactiveOAuth2AuthorizedClientProviderBuilder.this; - } - - /** - * Configures support for the {@code client_credentials} grant. - * - * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} used for further configuration - * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} - */ - public ReactiveOAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer builderConsumer) { - ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( - ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); - builderConsumer.accept(builder); - return ReactiveOAuth2AuthorizedClientProviderBuilder.this; } /** * A builder for the {@code client_credentials} grant. */ - public class ClientCredentialsGrantBuilder implements Builder { + public final class ClientCredentialsGrantBuilder implements Builder { + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + private Clock clock; private ClientCredentialsGrantBuilder() { } /** - * Sets the client used when requesting an access token credential at the Token Endpoint. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * Sets the client used when requesting an access token credential at the Token + * Endpoint. + * @param accessTokenResponseClient the client used when requesting an access + * token credential at the Token Endpoint * @return the {@link ClientCredentialsGrantBuilder} */ - public ClientCredentialsGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + public ClientCredentialsGrantBuilder accessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { this.accessTokenResponseClient = accessTokenResponseClient; return this; } /** - * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. - * An access token is considered expired if it's before {@code Instant.now(this.clock) - clockSkew}. - * + * Sets the maximum acceptable clock skew, which is used when checking the access + * token expiry. An access token is considered expired if it's before + * {@code Instant.now(this.clock) - clockSkew}. * @param clockSkew the maximum acceptable clock skew * @return the {@link ClientCredentialsGrantBuilder} */ @@ -251,8 +236,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the + * access token expiry. * @param clock the clock * @return the {@link ClientCredentialsGrantBuilder} */ @@ -262,8 +247,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { } /** - * Builds an instance of {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}. - * + * Builds an instance of + * {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}. * @return the {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider} */ @Override @@ -280,57 +265,40 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { } return authorizedClientProvider; } - } - /** - * Configures support for the {@code password} grant. - * - * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} - */ - public ReactiveOAuth2AuthorizedClientProviderBuilder password() { - this.builders.computeIfAbsent(PasswordReactiveOAuth2AuthorizedClientProvider.class, k -> new PasswordGrantBuilder()); - return ReactiveOAuth2AuthorizedClientProviderBuilder.this; - } - - /** - * Configures support for the {@code password} grant. - * - * @param builderConsumer a {@code Consumer} of {@link PasswordGrantBuilder} used for further configuration - * @return the {@link ReactiveOAuth2AuthorizedClientProviderBuilder} - */ - public ReactiveOAuth2AuthorizedClientProviderBuilder password(Consumer builderConsumer) { - PasswordGrantBuilder builder = (PasswordGrantBuilder) this.builders.computeIfAbsent( - PasswordReactiveOAuth2AuthorizedClientProvider.class, k -> new PasswordGrantBuilder()); - builderConsumer.accept(builder); - return ReactiveOAuth2AuthorizedClientProviderBuilder.this; } /** * A builder for the {@code password} grant. */ - public class PasswordGrantBuilder implements Builder { + public final class PasswordGrantBuilder implements Builder { + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + private Clock clock; private PasswordGrantBuilder() { } /** - * Sets the client used when requesting an access token credential at the Token Endpoint. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * Sets the client used when requesting an access token credential at the Token + * Endpoint. + * @param accessTokenResponseClient the client used when requesting an access + * token credential at the Token Endpoint * @return the {@link PasswordGrantBuilder} */ - public PasswordGrantBuilder accessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + public PasswordGrantBuilder accessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { this.accessTokenResponseClient = accessTokenResponseClient; return this; } /** - * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. - * An access token is considered expired if it's before {@code Instant.now(this.clock) - clockSkew}. - * + * Sets the maximum acceptable clock skew, which is used when checking the access + * token expiry. An access token is considered expired if it's before + * {@code Instant.now(this.clock) - clockSkew}. * @param clockSkew the maximum acceptable clock skew * @return the {@link PasswordGrantBuilder} */ @@ -340,8 +308,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the + * access token expiry. * @param clock the clock * @return the {@link PasswordGrantBuilder} */ @@ -352,7 +320,6 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { /** * Builds an instance of {@link PasswordReactiveOAuth2AuthorizedClientProvider}. - * * @return the {@link PasswordReactiveOAuth2AuthorizedClientProvider} */ @Override @@ -369,23 +336,79 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { } return authorizedClientProvider; } + } /** - * Builds an instance of {@link DelegatingReactiveOAuth2AuthorizedClientProvider} - * composed of one or more {@link ReactiveOAuth2AuthorizedClientProvider}(s). - * - * @return the {@link DelegatingReactiveOAuth2AuthorizedClientProvider} + * A builder for the {@code refresh_token} grant. */ - public ReactiveOAuth2AuthorizedClientProvider build() { - List authorizedClientProviders = - this.builders.values().stream() - .map(Builder::build) - .collect(Collectors.toList()); - return new DelegatingReactiveOAuth2AuthorizedClientProvider(authorizedClientProviders); + public final class RefreshTokenGrantBuilder implements Builder { + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + + private Duration clockSkew; + + private Clock clock; + + private RefreshTokenGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint. + * @param accessTokenResponseClient the client used when requesting an access + * token credential at the Token Endpoint + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder accessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access + * token expiry. An access token is considered expired if it's before + * {@code Instant.now(this.clock) - clockSkew}. + * @param clockSkew the maximum acceptable clock skew + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the + * access token expiry. + * @param clock the clock + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clock(Clock clock) { + this.clock = clock; + return this; + } + + /** + * Builds an instance of + * {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. + * @return the {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider} + */ + @Override + public ReactiveOAuth2AuthorizedClientProvider build() { + RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + if (this.clock != null) { + authorizedClientProvider.setClock(this.clock); + } + return authorizedClientProvider; + } + } - interface Builder { - ReactiveOAuth2AuthorizedClientProvider build(); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientService.java index 3e189f8b11..49c53f1d22 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientService.java @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import reactor.core.publisher.Mono; - /** - * Implementations of this interface are responsible for the management - * of {@link OAuth2AuthorizedClient Authorized Client(s)}, which provide the purpose - * of associating an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential + * Implementations of this interface are responsible for the management of + * {@link OAuth2AuthorizedClient Authorized Client(s)}, which provide the purpose of + * associating an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, - * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} - * that originally granted the authorization. + * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} that originally + * granted the authorization. * * @author Rob Winch * @since 5.1 @@ -39,32 +40,27 @@ import reactor.core.publisher.Mono; public interface ReactiveOAuth2AuthorizedClientService { /** - * Returns the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User's {@code Principal} name - * or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User's {@code Principal} name or {@code null} if + * not available. * @param clientRegistrationId the identifier for the client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) * @param a type of OAuth2AuthorizedClient * @return the {@link OAuth2AuthorizedClient} or {@code null} if not available */ - Mono loadAuthorizedClient(String clientRegistrationId, - String principalName); + Mono loadAuthorizedClient(String clientRegistrationId, String principalName); /** - * Saves the {@link OAuth2AuthorizedClient} associating it to - * the provided End-User {@link Authentication} (Resource Owner). - * + * Saves the {@link OAuth2AuthorizedClient} associating it to the provided End-User + * {@link Authentication} (Resource Owner). * @param authorizedClient the authorized client * @param principal the End-User {@link Authentication} (Resource Owner) */ - Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, - Authentication principal); + Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal); /** - * Removes the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User's {@code Principal} name. - * + * Removes the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User's {@code Principal} name. * @param clientRegistrationId the identifier for the client's registration * @param principalName the name of the End-User {@code Principal} (Resource Owner) */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index f77a44809b..04962922d9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -25,17 +34,9 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - /** - * An implementation of an {@link OAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * An implementation of an {@link OAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. * * @author Joe Grandja * @since 5.2 @@ -43,84 +44,93 @@ import java.util.Set; * @see DefaultRefreshTokenTokenResponseClient */ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private OAuth2AccessTokenResponseClient accessTokenResponseClient = - new DefaultRefreshTokenTokenResponseClient(); + + private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + private Clock clock = Clock.systemUTC(); /** - * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns {@code null} if re-authorization is not supported, - * e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} - * is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * Attempt to re-authorize the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns {@code null} if re-authorization is not supported, e.g. + * the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() + * refresh token} is not available for the authorized client OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * *

        - * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} + * are supported: *

          - *
        1. {@link OAuth2AuthorizationContext#REQUEST_SCOPE_ATTRIBUTE_NAME} (optional) - a {@code String[]} of scope(s) - * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
        2. + *
        3. {@link OAuth2AuthorizationContext#REQUEST_SCOPE_ATTRIBUTE_NAME} (optional) - a + * {@code String[]} of scope(s) to be requested by the + * {@link OAuth2AuthorizationContext#getClientRegistration() client}
        4. *
        - * * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is not supported + * @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is + * not supported */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (authorizedClient == null || - authorizedClient.getRefreshToken() == null || - !hasTokenExpired(authorizedClient.getAccessToken())) { + if (authorizedClient == null || authorizedClient.getRefreshToken() == null + || !hasTokenExpired(authorizedClient.getAccessToken())) { return null; } - Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = Collections.emptySet(); if (requestScope != null) { - Assert.isInstanceOf(String[].class, requestScope, - "The context attribute must be of type String[] '" + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + Assert.isInstanceOf(String[].class, requestScope, "The context attribute must be of type String[] '" + + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); - - OAuth2AccessTokenResponse tokenResponse; - try { - tokenResponse = this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); - } catch (OAuth2AuthorizationException ex) { - throw new ClientAuthorizationException(ex.getError(), authorizedClient.getClientRegistration().getRegistrationId(), ex); - } - + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest); return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); } + private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient, + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + try { + return this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + } + catch (OAuth2AuthorizationException ex) { + throw new ClientAuthorizationException(ex.getError(), + authorizedClient.getClientRegistration().getRegistrationId(), ex); + } + } + private boolean hasTokenExpired(AbstractOAuth2Token token) { return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); } /** - * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code refresh_token} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code refresh_token} grant */ - public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + public void setAccessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } /** * Sets the maximum acceptable clock skew, which is used when checking the - * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. * *

        - * An access token is considered expired if {@code OAuth2AccessToken#getExpiresAt() - clockSkew} - * is before the current time {@code clock#instant()}. - * + * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. * @param clockSkew the maximum acceptable clock skew */ public void setClockSkew(Duration clockSkew) { @@ -130,12 +140,13 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. * @param clock the clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java index f2526ed225..5f6e16369d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java @@ -13,17 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client; -import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.WebClientReactiveRefreshTokenTokenResponseClient; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AbstractOAuth2Token; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import org.springframework.util.Assert; -import reactor.core.publisher.Mono; +package org.springframework.security.oauth2.client; import java.time.Clock; import java.time.Duration; @@ -33,65 +24,79 @@ import java.util.Collections; import java.util.HashSet; import java.util.Set; +import reactor.core.publisher.Mono; + +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.util.Assert; + /** - * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} - * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. * * @author Joe Grandja * @since 5.2 * @see ReactiveOAuth2AuthorizedClientProvider * @see WebClientReactiveRefreshTokenTokenResponseClient */ -public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { - private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = - new WebClientReactiveRefreshTokenTokenResponseClient(); +public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider + implements ReactiveOAuth2AuthorizedClientProvider { + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); + private Clock clock = Clock.systemUTC(); /** - * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. - * Returns an empty {@code Mono} if re-authorization is not supported, - * e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} - * is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * Attempt to re-authorize the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns an empty {@code Mono} if re-authorization is not + * supported, e.g. the client is not authorized OR the + * {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available for + * the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access + * token} is not expired. * *

        - * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} + * are supported: *

          - *
        1. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s) - * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
        2. + *
        3. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - + * a {@code String[]} of scope(s) to be requested by the + * {@link OAuth2AuthorizationContext#getClientRegistration() client}
        4. *
        - * * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if re-authorization is not supported + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * re-authorization is not supported */ @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (authorizedClient == null || - authorizedClient.getRefreshToken() == null || - !hasTokenExpired(authorizedClient.getAccessToken())) { + if (authorizedClient == null || authorizedClient.getRefreshToken() == null + || !hasTokenExpired(authorizedClient.getAccessToken())) { return Mono.empty(); } - Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = Collections.emptySet(); if (requestScope != null) { - Assert.isInstanceOf(String[].class, requestScope, - "The context attribute must be of type String[] '" + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + Assert.isInstanceOf(String[].class, requestScope, "The context attribute must be of type String[] '" + + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } ClientRegistration clientRegistration = context.getClientRegistration(); - - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - clientRegistration, authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); - - return Mono.just(refreshTokenGrantRequest) - .flatMap(this.accessTokenResponseClient::getTokenResponse) + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); + return Mono.just(refreshTokenGrantRequest).flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, - e -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e)) - .map(tokenResponse -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), + e)) + .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken())); } @@ -100,23 +105,26 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider implements } /** - * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant. - * - * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code refresh_token} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code refresh_token} grant */ - public void setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + public void setAccessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } /** * Sets the maximum acceptable clock skew, which is used when checking the - * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is 60 seconds. + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. * *

        - * An access token is considered expired if {@code OAuth2AccessToken#getExpiresAt() - clockSkew} - * is before the current time {@code clock#instant()}. - * + * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. * @param clockSkew the maximum acceptable clock skew */ public void setClockSkew(Duration clockSkew) { @@ -126,12 +134,13 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider implements } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access token expiry. - * + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. * @param clock the clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java index 1afe7f43d1..3701e8457c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -22,16 +29,10 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.util.Assert; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - /** - * An {@link OAuth2AuthorizationFailureHandler} that removes an {@link OAuth2AuthorizedClient} - * when the {@link OAuth2Error#getErrorCode()} matches - * one of the configured {@link OAuth2ErrorCodes OAuth 2.0 error codes}. + * An {@link OAuth2AuthorizationFailureHandler} that removes an + * {@link OAuth2AuthorizedClient} when the {@link OAuth2Error#getErrorCode()} matches one + * of the configured {@link OAuth2ErrorCodes OAuth 2.0 error codes}. * * @author Joe Grandja * @since 5.3 @@ -42,29 +43,28 @@ import java.util.Set; public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements OAuth2AuthorizationFailureHandler { /** - * The default OAuth 2.0 error codes that will trigger removal of an {@link OAuth2AuthorizedClient}. + * The default OAuth 2.0 error codes that will trigger removal of an + * {@link OAuth2AuthorizedClient}. * @see OAuth2ErrorCodes */ - public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( - /* - * Returned from Resource Servers when an access token provided is expired, revoked, - * malformed, or invalid for other reasons. - * - * Note that this is needed because ServletOAuth2AuthorizedClientExchangeFilterFunction - * delegates this type of failure received from a Resource Server - * to this failure handler. - */ - OAuth2ErrorCodes.INVALID_TOKEN, - - /* - * Returned from Authorization Servers when the authorization grant or refresh token is invalid, expired, revoked, - * does not match the redirection URI used in the authorization request, or was issued to another client. - */ - OAuth2ErrorCodes.INVALID_GRANT - ))); + public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES; + static { + Set codes = new LinkedHashSet<>(); + // Returned from Resource Servers when an access token provided is expired, + // revoked, malformed, or invalid for other reasons. Note that this is needed + // because ServletOAuth2AuthorizedClientExchangeFilterFunction delegates this type + // of failure received from a Resource Server to this failure handler. + codes.add(OAuth2ErrorCodes.INVALID_TOKEN); + // Returned from Authorization Servers when the authorization grant or refresh + // token is invalid, expired, revoked, does not match the redirection URI used in + // the authorization request, or was issued to another client. + codes.add(OAuth2ErrorCodes.INVALID_GRANT); + DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(codes); + } /** - * The OAuth 2.0 error codes which will trigger removal of an {@link OAuth2AuthorizedClient}. + * The OAuth 2.0 error codes which will trigger removal of an + * {@link OAuth2AuthorizedClient}. * @see OAuth2ErrorCodes */ private final Set removeAuthorizedClientErrorCodes; @@ -76,6 +76,59 @@ public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements */ private final OAuth2AuthorizedClientRemover delegate; + /** + * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using + * the provided parameters. + * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for + * removing an {@link OAuth2AuthorizedClient} if the error code is one of the + * {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. + */ + public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( + OAuth2AuthorizedClientRemover authorizedClientRemover) { + this(authorizedClientRemover, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); + } + + /** + * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using + * the provided parameters. + * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for + * removing an {@link OAuth2AuthorizedClient} if the error code is one of the + * {@link #removeAuthorizedClientErrorCodes}. + * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will + * trigger removal of an authorized client. + * @see OAuth2ErrorCodes + */ + public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( + OAuth2AuthorizedClientRemover authorizedClientRemover, Set removeAuthorizedClientErrorCodes) { + Assert.notNull(authorizedClientRemover, "authorizedClientRemover cannot be null"); + Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); + this.removeAuthorizedClientErrorCodes = Collections + .unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); + this.delegate = authorizedClientRemover; + } + + @Override + public void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, Authentication principal, + Map attributes) { + if (authorizationException instanceof ClientAuthorizationException + && hasRemovalErrorCode(authorizationException)) { + ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; + this.delegate.removeAuthorizedClient(clientAuthorizationException.getClientRegistrationId(), principal, + attributes); + } + } + + /** + * Returns true if the given exception has an error code that indicates that the + * authorized client should be removed. + * @param authorizationException the exception that caused the authorization failure + * @return true if the given exception has an error code that indicates that the + * authorized client should be removed. + */ + private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) { + return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode()); + } + /** * Removes an {@link OAuth2AuthorizedClient} from an * {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}. @@ -84,68 +137,19 @@ public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements public interface OAuth2AuthorizedClientRemover { /** - * Removes the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User {@link Authentication} (Resource Owner). - * + * Removes the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User {@link Authentication} (Resource Owner). * @param clientRegistrationId the identifier for the client's registration * @param principal the End-User {@link Authentication} (Resource Owner) - * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions. - * For example, this might contain a {@code javax.servlet.http.HttpServletRequest} - * and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed - * within the context of a {@code javax.servlet.ServletContext}. + * @param attributes an immutable {@code Map} of (optional) attributes present + * under certain conditions. For example, this might contain a + * {@code javax.servlet.http.HttpServletRequest} and + * {@code javax.servlet.http.HttpServletResponse} if the authorization was + * performed within the context of a {@code javax.servlet.ServletContext}. */ - void removeAuthorizedClient(String clientRegistrationId, Authentication principal, Map attributes); + void removeAuthorizedClient(String clientRegistrationId, Authentication principal, + Map attributes); + } - /** - * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters. - * - * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for removing an {@link OAuth2AuthorizedClient} - * if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. - */ - public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientRemover authorizedClientRemover) { - this(authorizedClientRemover, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); - } - - /** - * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters. - * - * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for removing an {@link OAuth2AuthorizedClient} - * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. - * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client. - * @see OAuth2ErrorCodes - */ - public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - OAuth2AuthorizedClientRemover authorizedClientRemover, - Set removeAuthorizedClientErrorCodes) { - Assert.notNull(authorizedClientRemover, "authorizedClientRemover cannot be null"); - Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); - this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); - this.delegate = authorizedClientRemover; - } - - @Override - public void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, - Authentication principal, Map attributes) { - - if (authorizationException instanceof ClientAuthorizationException && - hasRemovalErrorCode(authorizationException)) { - - ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; - this.delegate.removeAuthorizedClient( - clientAuthorizationException.getClientRegistrationId(), principal, attributes); - } - } - - /** - * Returns true if the given exception has an error code that - * indicates that the authorized client should be removed. - * - * @param authorizationException the exception that caused the authorization failure - * @return true if the given exception has an error code that - * indicates that the authorized client should be removed. - */ - private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) { - return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode()); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java index a0106ab40f..0e7edd67d7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java @@ -13,56 +13,60 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; /** - * A {@link ReactiveOAuth2AuthorizationFailureHandler} that removes an {@link OAuth2AuthorizedClient} - * when the {@link OAuth2Error#getErrorCode()} matches - * one of the configured {@link OAuth2ErrorCodes OAuth 2.0 error codes}. + * A {@link ReactiveOAuth2AuthorizationFailureHandler} that removes an + * {@link OAuth2AuthorizedClient} when the {@link OAuth2Error#getErrorCode()} matches one + * of the configured {@link OAuth2ErrorCodes OAuth 2.0 error codes}. * * @author Phil Clay * @since 5.3 */ -public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler implements ReactiveOAuth2AuthorizationFailureHandler { +public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler + implements ReactiveOAuth2AuthorizationFailureHandler { /** - * The default OAuth 2.0 error codes that will trigger removal of the authorized client. + * The default OAuth 2.0 error codes that will trigger removal of the authorized + * client. * @see OAuth2ErrorCodes */ - public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( - /* - * Returned from resource servers when an access token provided is expired, revoked, - * malformed, or invalid for other reasons. - * - * Note that this is needed because the ServerOAuth2AuthorizedClientExchangeFilterFunction - * delegates this type of failure received from a resource server - * to this failure handler. - */ - OAuth2ErrorCodes.INVALID_TOKEN, - /* - * Returned from authorization servers when a refresh token is invalid, expired, revoked, - * does not match the redirection URI used in the authorization request, or was issued to another client. - */ - OAuth2ErrorCodes.INVALID_GRANT))); + public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES; + static { + Set codes = new LinkedHashSet<>(); + // Returned from resource servers when an access token provided is expired, + // revoked, malformed, or invalid for other reasons. Note that this is needed + // because the ServerOAuth2AuthorizedClientExchangeFilterFunction delegates this + // type of failure received from a resource server to this failure handler. + codes.add(OAuth2ErrorCodes.INVALID_TOKEN); + // Returned from authorization servers when a refresh token is invalid, expired, + // revoked, does not match the redirection URI used in the authorization request, + // or was issued to another client. + codes.add(OAuth2ErrorCodes.INVALID_GRANT); + DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(codes); + } /** * A delegate that removes an {@link OAuth2AuthorizedClient} from a - * {@link ServerOAuth2AuthorizedClientRepository} or {@link ReactiveOAuth2AuthorizedClientService} - * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. + * {@link ServerOAuth2AuthorizedClientRepository} or + * {@link ReactiveOAuth2AuthorizedClientService} if the error code is one of the + * {@link #removeAuthorizedClientErrorCodes}. */ private final OAuth2AuthorizedClientRemover delegate; @@ -73,77 +77,85 @@ public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler imp private final Set removeAuthorizedClientErrorCodes; /** - * Removes an {@link OAuth2AuthorizedClient} from a - * {@link ServerOAuth2AuthorizedClientRepository} or {@link ReactiveOAuth2AuthorizedClientService}. + * Constructs a + * {@code RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} using the + * provided parameters. + * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for + * removing an {@link OAuth2AuthorizedClient} if the error code is one of the + * {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. */ - @FunctionalInterface - public interface OAuth2AuthorizedClientRemover { - - /** - * Removes the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User {@link Authentication} (Resource Owner). - * - * @param clientRegistrationId the identifier for the client's registration - * @param principal the End-User {@link Authentication} (Resource Owner) - * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions. - * For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} - * if the authorization was performed within the context of a {@code ServerWebExchange}. - * @return an empty {@link Mono} that completes after this handler has finished handling the event. - */ - Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, Map attributes); - } - - /** - * Constructs a {@code RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} using the provided parameters. - * - * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for removing an {@link OAuth2AuthorizedClient} - * if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. - */ - public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientRemover authorizedClientRemover) { + public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( + OAuth2AuthorizedClientRemover authorizedClientRemover) { this(authorizedClientRemover, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); } /** - * Constructs a {@code RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} using the provided parameters. - * - * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for removing an {@link OAuth2AuthorizedClient} - * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. - * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client. + * Constructs a + * {@code RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} using the + * provided parameters. + * @param authorizedClientRemover the {@link OAuth2AuthorizedClientRemover} used for + * removing an {@link OAuth2AuthorizedClient} if the error code is one of the + * {@link #removeAuthorizedClientErrorCodes}. + * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will + * trigger removal of an authorized client. * @see OAuth2ErrorCodes */ public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( - OAuth2AuthorizedClientRemover authorizedClientRemover, - Set removeAuthorizedClientErrorCodes) { + OAuth2AuthorizedClientRemover authorizedClientRemover, Set removeAuthorizedClientErrorCodes) { Assert.notNull(authorizedClientRemover, "authorizedClientRemover cannot be null"); Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); - this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); + this.removeAuthorizedClientErrorCodes = Collections + .unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); this.delegate = authorizedClientRemover; } @Override public Mono onAuthorizationFailure(OAuth2AuthorizationException authorizationException, Authentication principal, Map attributes) { - if (authorizationException instanceof ClientAuthorizationException && hasRemovalErrorCode(authorizationException)) { - ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; - return this.delegate.removeAuthorizedClient( - clientAuthorizationException.getClientRegistrationId(), principal, attributes); - } else { - return Mono.empty(); + return this.delegate.removeAuthorizedClient(clientAuthorizationException.getClientRegistrationId(), + principal, attributes); } + return Mono.empty(); } /** - * Returns true if the given exception has an error code that - * indicates that the authorized client should be removed. - * + * Returns true if the given exception has an error code that indicates that the + * authorized client should be removed. * @param authorizationException the exception that caused the authorization failure - * @return true if the given exception has an error code that - * indicates that the authorized client should be removed. + * @return true if the given exception has an error code that indicates that the + * authorized client should be removed. */ private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) { return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode()); } + + /** + * Removes an {@link OAuth2AuthorizedClient} from a + * {@link ServerOAuth2AuthorizedClientRepository} or + * {@link ReactiveOAuth2AuthorizedClientService}. + */ + @FunctionalInterface + public interface OAuth2AuthorizedClientRemover { + + /** + * Removes the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User {@link Authentication} (Resource Owner). + * @param clientRegistrationId the identifier for the client's registration + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param attributes an immutable {@code Map} of extra optional attributes present + * under certain conditions. For example, this might contain a + * {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} if + * the authorization was performed within the context of a + * {@code ServerWebExchange}. + * @return an empty {@link Mono} that completes after this handler has finished + * handling the event. + */ + Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, + Map attributes); + + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/annotation/RegisteredOAuth2AuthorizedClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/annotation/RegisteredOAuth2AuthorizedClient.java index 4bad4aff8e..d553b58fb6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/annotation/RegisteredOAuth2AuthorizedClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/annotation/RegisteredOAuth2AuthorizedClient.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.annotation; -import org.springframework.core.annotation.AliasFor; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; +package org.springframework.security.oauth2.client.annotation; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; @@ -25,13 +22,16 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.springframework.core.annotation.AliasFor; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; + /** - * This annotation may be used to resolve a method parameter - * to an argument value of type {@link OAuth2AuthorizedClient}. + * This annotation may be used to resolve a method parameter to an argument value of type + * {@link OAuth2AuthorizedClient}. * *

        - * For example: - *

        + * For example: 
          * @Controller
          * public class MyController {
          *     @GetMapping("/authorized-client")
        @@ -52,18 +52,16 @@ public @interface RegisteredOAuth2AuthorizedClient {
         
         	/**
         	 * Sets the client registration identifier.
        -	 *
         	 * @return the client registration identifier
         	 */
         	@AliasFor("value")
         	String registrationId() default "";
         
         	/**
        -	 * The default attribute for this annotation.
        -	 * This is an alias for {@link #registrationId()}.
        -	 * For example, {@code @RegisteredOAuth2AuthorizedClient("login-client")} is equivalent to
        +	 * The default attribute for this annotation. This is an alias for
        +	 * {@link #registrationId()}. For example,
        +	 * {@code @RegisteredOAuth2AuthorizedClient("login-client")} is equivalent to
         	 * {@code @RegisteredOAuth2AuthorizedClient(registrationId="login-client")}.
        -	 *
         	 * @return the client registration identifier
         	 */
         	@AliasFor("registrationId")
        diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationToken.java
        index 73fd7d5b19..7d5fdba358 100644
        --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationToken.java
        +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationToken.java
        @@ -13,8 +13,11 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.oauth2.client.authentication;
         
        +import java.util.Collection;
        +
         import org.springframework.security.authentication.AbstractAuthenticationToken;
         import org.springframework.security.core.Authentication;
         import org.springframework.security.core.GrantedAuthority;
        @@ -23,16 +26,14 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
         import org.springframework.security.oauth2.core.user.OAuth2User;
         import org.springframework.util.Assert;
         
        -import java.util.Collection;
        -
         /**
        - * An implementation of an {@link AbstractAuthenticationToken}
        - * that represents an OAuth 2.0 {@link Authentication}.
        + * An implementation of an {@link AbstractAuthenticationToken} that represents an OAuth
        + * 2.0 {@link Authentication}.
          * 

        - * The {@link Authentication} associates an {@link OAuth2User} {@code Principal} - * to the identifier of the {@link #getAuthorizedClientRegistrationId() Authorized Client}, - * which the End-User ({@code Principal}) granted authorization to - * so that it can access it's protected resources at the UserInfo Endpoint. + * The {@link Authentication} associates an {@link OAuth2User} {@code Principal} to the + * identifier of the {@link #getAuthorizedClientRegistrationId() Authorized Client}, which + * the End-User ({@code Principal}) granted authorization to so that it can access it's + * protected resources at the UserInfo Endpoint. * * @author Joe Grandja * @since 5.0 @@ -41,20 +42,22 @@ import java.util.Collection; * @see OAuth2AuthorizedClient */ public class OAuth2AuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final OAuth2User principal; + private final String authorizedClientRegistrationId; /** * Constructs an {@code OAuth2AuthenticationToken} using the provided parameters. - * * @param principal the user {@code Principal} registered with the OAuth 2.0 Provider * @param authorities the authorities granted to the user - * @param authorizedClientRegistrationId the registration identifier of the {@link OAuth2AuthorizedClient Authorized Client} + * @param authorizedClientRegistrationId the registration identifier of the + * {@link OAuth2AuthorizedClient Authorized Client} */ - public OAuth2AuthenticationToken(OAuth2User principal, - Collection authorities, - String authorizedClientRegistrationId) { + public OAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { super(authorities); Assert.notNull(principal, "principal cannot be null"); Assert.hasText(authorizedClientRegistrationId, "authorizedClientRegistrationId cannot be empty"); @@ -75,11 +78,12 @@ public class OAuth2AuthenticationToken extends AbstractAuthenticationToken { } /** - * Returns the registration identifier of the {@link OAuth2AuthorizedClient Authorized Client}. - * + * Returns the registration identifier of the {@link OAuth2AuthorizedClient Authorized + * Client}. * @return the registration identifier of the Authorized Client. */ public String getAuthorizedClientRegistrationId() { return this.authorizedClientRegistrationId; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 4ad62b09ce..efcf42a19a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; import org.springframework.security.authentication.AuthenticationProvider; @@ -28,70 +29,67 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.util.Assert; /** - * An implementation of an {@link AuthenticationProvider} for the OAuth 2.0 Authorization Code Grant. + * An implementation of an {@link AuthenticationProvider} for the OAuth 2.0 Authorization + * Code Grant. * *

        - * This {@link AuthenticationProvider} is responsible for authenticating - * an Authorization Code credential with the Authorization Server's Token Endpoint - * and if valid, exchanging it for an Access Token credential. + * This {@link AuthenticationProvider} is responsible for authenticating an Authorization + * Code credential with the Authorization Server's Token Endpoint and if valid, exchanging + * it for an Access Token credential. * * @author Joe Grandja * @since 5.1 * @see OAuth2AuthorizationCodeAuthenticationToken * @see OAuth2AccessTokenResponseClient - * @see Section 4.1 Authorization Code Grant Flow - * @see Section 4.1.3 Access Token Request - * @see Section 4.1.4 Access Token Response + * @see Section + * 4.1 Authorization Code Grant Flow + * @see Section 4.1.3 Access Token + * Request + * @see Section 4.1.4 Access Token + * Response */ public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { + private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; + private final OAuth2AccessTokenResponseClient accessTokenResponseClient; /** - * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. - * - * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the + * provided parameters. + * @param accessTokenResponseClient the client used for requesting the access token + * credential from the Token Endpoint */ public OAuth2AuthorizationCodeAuthenticationProvider( - OAuth2AccessTokenResponseClient accessTokenResponseClient) { - + OAuth2AccessTokenResponseClient accessTokenResponseClient) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = - (OAuth2AuthorizationCodeAuthenticationToken) authentication; - - OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationResponse(); + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = (OAuth2AuthorizationCodeAuthenticationToken) authentication; + OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication.getAuthorizationExchange() + .getAuthorizationResponse(); if (authorizationResponse.statusError()) { throw new OAuth2AuthorizationException(authorizationResponse.getError()); } - - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationRequest(); + OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() + .getAuthorizationRequest(); if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); throw new OAuth2AuthorizationException(oauth2Error); } - - OAuth2AccessTokenResponse accessTokenResponse = - this.accessTokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange())); - - OAuth2AuthorizationCodeAuthenticationToken authenticationResult = - new OAuth2AuthorizationCodeAuthenticationToken( + OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( + new OAuth2AuthorizationCodeGrantRequest(authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange())); + OAuth2AuthorizationCodeAuthenticationToken authenticationResult = new OAuth2AuthorizationCodeAuthenticationToken( authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange(), - accessTokenResponse.getAccessToken(), - accessTokenResponse.getRefreshToken(), - accessTokenResponse.getAdditionalParameters()); + authorizationCodeAuthentication.getAuthorizationExchange(), accessTokenResponse.getAccessToken(), + accessTokenResponse.getRefreshToken(), accessTokenResponse.getAdditionalParameters()); authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); - return authenticationResult; } @@ -99,4 +97,5 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica public boolean supports(Class authentication) { return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java index 3c45b3e1d7..9d5fe681b6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationToken.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.SpringSecurityCoreVersion; @@ -24,10 +29,6 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.util.Assert; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - /** * An {@link AbstractAuthenticationToken} for the OAuth 2.0 Authorization Code Grant. * @@ -37,24 +38,31 @@ import java.util.Map; * @see ClientRegistration * @see OAuth2AuthorizationExchange * @see OAuth2AccessToken - * @see Section 4.1 Authorization Code Grant Flow + * @see Section + * 4.1 Authorization Code Grant Flow */ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private Map additionalParameters = new HashMap<>(); + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; /** - * This constructor should be used when the Authorization Request/Response is complete. - * + * This constructor should be used when the Authorization Request/Response is + * complete. * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange */ public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange) { + OAuth2AuthorizationExchange authorizationExchange) { super(Collections.emptyList()); Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); @@ -65,35 +73,32 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti /** * This constructor should be used when the Access Token Request/Response is complete, * which indicates that the Authorization Code Grant flow has fully completed. - * * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange * @param accessToken the access token credential */ public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange, - OAuth2AccessToken accessToken) { + OAuth2AuthorizationExchange authorizationExchange, OAuth2AccessToken accessToken) { this(clientRegistration, authorizationExchange, accessToken, null); } /** * This constructor should be used when the Access Token Request/Response is complete, * which indicates that the Authorization Code Grant flow has fully completed. - * * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange * @param accessToken the access token credential * @param refreshToken the refresh token credential */ public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange, - OAuth2AccessToken accessToken, - @Nullable OAuth2RefreshToken refreshToken) { + OAuth2AuthorizationExchange authorizationExchange, OAuth2AccessToken accessToken, + @Nullable OAuth2RefreshToken refreshToken) { this(clientRegistration, authorizationExchange, accessToken, refreshToken, Collections.emptyMap()); } - public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, OAuth2AuthorizationExchange authorizationExchange, OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken, - Map additionalParameters) { + public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, + OAuth2AuthorizationExchange authorizationExchange, OAuth2AccessToken accessToken, + OAuth2RefreshToken refreshToken, Map additionalParameters) { this(clientRegistration, authorizationExchange); Assert.notNull(accessToken, "accessToken cannot be null"); this.accessToken = accessToken; @@ -109,14 +114,12 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti @Override public Object getCredentials() { - return this.accessToken != null ? - this.accessToken.getTokenValue() : - this.authorizationExchange.getAuthorizationResponse().getCode(); + return (this.accessToken != null) ? this.accessToken.getTokenValue() + : this.authorizationExchange.getAuthorizationResponse().getCode(); } /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -125,7 +128,6 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti /** * Returns the {@link OAuth2AuthorizationExchange authorization exchange}. - * * @return the {@link OAuth2AuthorizationExchange} */ public OAuth2AuthorizationExchange getAuthorizationExchange() { @@ -134,7 +136,6 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti /** * Returns the {@link OAuth2AccessToken access token}. - * * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { @@ -143,7 +144,6 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti /** * Returns the {@link OAuth2RefreshToken refresh token}. - * * @return the {@link OAuth2RefreshToken} */ public @Nullable OAuth2RefreshToken getRefreshToken() { @@ -152,10 +152,10 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti /** * Returns the additional parameters - * * @return the additional parameters */ public Map getAdditionalParameters() { return this.additionalParameters; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java index 28e4d7bdfe..a2497f5978 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; @@ -31,23 +36,22 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - -import java.util.function.Function; /** - * An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login, - * which leverages the OAuth 2.0 Authorization Code Grant Flow. + * An implementation of an + * {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth + * 2.0 Login, which leverages the OAuth 2.0 Authorization Code Grant Flow. * - * This {@link org.springframework.security.authentication.AuthenticationProvider} is responsible for authenticating - * an Authorization Code credential with the Authorization Server's Token Endpoint - * and if valid, exchanging it for an Access Token credential. + * This {@link org.springframework.security.authentication.AuthenticationProvider} is + * responsible for authenticating an Authorization Code credential with the Authorization + * Server's Token Endpoint and if valid, exchanging it for an Access Token credential. *

        - * It will also obtain the user attributes of the End-User (Resource Owner) - * from the UserInfo Endpoint using an {@link org.springframework.security.oauth2.client.userinfo.OAuth2UserService}, - * which will create a {@code Principal} in the form of an {@link OAuth2User}. - * The {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} - * to complete the authentication. + * It will also obtain the user attributes of the End-User (Resource Owner) from the + * UserInfo Endpoint using an + * {@link org.springframework.security.oauth2.client.userinfo.OAuth2UserService}, which + * will create a {@code Principal} in the form of an {@link OAuth2User}. The + * {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} to + * complete the authentication. * * @author Rob Winch * @since 5.1 @@ -55,12 +59,19 @@ import java.util.function.Function; * @see ReactiveOAuth2AccessTokenResponseClient * @see ReactiveOAuth2UserService * @see OAuth2User - * @see Section 4.1 Authorization Code Grant Flow - * @see Section 4.1.3 Access Token Request - * @see Section 4.1.4 Access Token Response + * @see Section + * 4.1 Authorization Code Grant Flow + * @see Section 4.1.3 Access Token + * Request + * @see Section 4.1.4 Access Token + * Response */ public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements ReactiveAuthenticationManager { + private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; + private final ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; public OAuth2AuthorizationCodeReactiveAuthenticationManager( @@ -73,34 +84,33 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements Rea public Mono authenticate(Authentication authentication) { return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - - OAuth2AuthorizationResponse authorizationResponse = token.getAuthorizationExchange().getAuthorizationResponse(); + OAuth2AuthorizationResponse authorizationResponse = token.getAuthorizationExchange() + .getAuthorizationResponse(); if (authorizationResponse.statusError()) { return Mono.error(new OAuth2AuthorizationException(authorizationResponse.getError())); } - - OAuth2AuthorizationRequest authorizationRequest = token.getAuthorizationExchange().getAuthorizationRequest(); + OAuth2AuthorizationRequest authorizationRequest = token.getAuthorizationExchange() + .getAuthorizationRequest(); if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); return Mono.error(new OAuth2AuthorizationException(oauth2Error)); } - OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( - token.getClientRegistration(), - token.getAuthorizationExchange()); - - return this.accessTokenResponseClient.getTokenResponse(authzRequest) - .map(onSuccess(token)); + token.getClientRegistration(), token.getAuthorizationExchange()); + return this.accessTokenResponseClient.getTokenResponse(authzRequest).map(onSuccess(token)); }); } - private Function onSuccess(OAuth2AuthorizationCodeAuthenticationToken token) { - return accessTokenResponse -> { + private Function onSuccess( + OAuth2AuthorizationCodeAuthenticationToken token) { + return (accessTokenResponse) -> { ClientRegistration registration = token.getClientRegistration(); OAuth2AuthorizationExchange exchange = token.getAuthorizationExchange(); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2RefreshToken refreshToken = accessTokenResponse.getRefreshToken(); - return new OAuth2AuthorizationCodeAuthenticationToken(registration, exchange, accessToken, refreshToken, accessTokenResponse.getAdditionalParameters()); + return new OAuth2AuthorizationCodeAuthenticationToken(registration, exchange, accessToken, refreshToken, + accessTokenResponse.getAdditionalParameters()); }; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java index 8d3b70234d..a82d320e53 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; +import java.util.Collection; +import java.util.Map; + import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -31,22 +35,19 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; -import java.util.Collection; -import java.util.Map; - /** - * An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login, - * which leverages the OAuth 2.0 Authorization Code Grant Flow. + * An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login, which + * leverages the OAuth 2.0 Authorization Code Grant Flow. * - * This {@link AuthenticationProvider} is responsible for authenticating - * an Authorization Code credential with the Authorization Server's Token Endpoint - * and if valid, exchanging it for an Access Token credential. + * This {@link AuthenticationProvider} is responsible for authenticating an Authorization + * Code credential with the Authorization Server's Token Endpoint and if valid, exchanging + * it for an Access Token credential. *

        - * It will also obtain the user attributes of the End-User (Resource Owner) - * from the UserInfo Endpoint using an {@link OAuth2UserService}, - * which will create a {@code Principal} in the form of an {@link OAuth2User}. - * The {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} - * to complete the authentication. + * It will also obtain the user attributes of the End-User (Resource Owner) from the + * UserInfo Endpoint using an {@link OAuth2UserService}, which will create a + * {@code Principal} in the form of an {@link OAuth2User}. The {@code OAuth2User} is then + * associated to the {@link OAuth2LoginAuthenticationToken} to complete the + * authentication. * * @author Joe Grandja * @since 5.0 @@ -54,82 +55,82 @@ import java.util.Map; * @see OAuth2AccessTokenResponseClient * @see OAuth2UserService * @see OAuth2User - * @see Section 4.1 Authorization Code Grant Flow - * @see Section 4.1.3 Access Token Request - * @see Section 4.1.4 Access Token Response + * @see Section + * 4.1 Authorization Code Grant Flow + * @see Section 4.1.3 Access Token + * Request + * @see Section 4.1.4 Access Token + * Response */ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider { + private final OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider; + private final OAuth2UserService userService; - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); + + private GrantedAuthoritiesMapper authoritiesMapper = ((authorities) -> authorities); /** - * Constructs an {@code OAuth2LoginAuthenticationProvider} using the provided parameters. - * - * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint - * @param userService the service used for obtaining the user attributes of the End-User from the UserInfo Endpoint + * Constructs an {@code OAuth2LoginAuthenticationProvider} using the provided + * parameters. + * @param accessTokenResponseClient the client used for requesting the access token + * credential from the Token Endpoint + * @param userService the service used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint */ public OAuth2LoginAuthenticationProvider( - OAuth2AccessTokenResponseClient accessTokenResponseClient, - OAuth2UserService userService) { - + OAuth2AccessTokenResponseClient accessTokenResponseClient, + OAuth2UserService userService) { Assert.notNull(userService, "userService cannot be null"); - this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(accessTokenResponseClient); + this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( + accessTokenResponseClient); this.userService = userService; } @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2LoginAuthenticationToken loginAuthenticationToken = - (OAuth2LoginAuthenticationToken) authentication; - - // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope - // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (loginAuthenticationToken.getAuthorizationExchange() - .getAuthorizationRequest().getScopes().contains("openid")) { + OAuth2LoginAuthenticationToken loginAuthenticationToken = (OAuth2LoginAuthenticationToken) authentication; + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. + if (loginAuthenticationToken.getAuthorizationExchange().getAuthorizationRequest().getScopes() + .contains("openid")) { // This is an OpenID Connect Authentication Request so return null // and let OidcAuthorizationCodeAuthenticationProvider handle it instead return null; } - OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken; try { authorizationCodeAuthenticationToken = (OAuth2AuthorizationCodeAuthenticationToken) this.authorizationCodeAuthenticationProvider .authenticate(new OAuth2AuthorizationCodeAuthenticationToken( loginAuthenticationToken.getClientRegistration(), loginAuthenticationToken.getAuthorizationExchange())); - } catch (OAuth2AuthorizationException ex) { + } + catch (OAuth2AuthorizationException ex) { OAuth2Error oauth2Error = ex.getError(); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - OAuth2AccessToken accessToken = authorizationCodeAuthenticationToken.getAccessToken(); Map additionalParameters = authorizationCodeAuthenticationToken.getAdditionalParameters(); - OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest( loginAuthenticationToken.getClientRegistration(), accessToken, additionalParameters)); - - Collection mappedAuthorities = - this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); - + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oauth2User.getAuthorities()); OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( - loginAuthenticationToken.getClientRegistration(), - loginAuthenticationToken.getAuthorizationExchange(), - oauth2User, - mappedAuthorities, - accessToken, - authorizationCodeAuthenticationToken.getRefreshToken()); + loginAuthenticationToken.getClientRegistration(), loginAuthenticationToken.getAuthorizationExchange(), + oauth2User, mappedAuthorities, accessToken, authorizationCodeAuthenticationToken.getRefreshToken()); authenticationResult.setDetails(loginAuthenticationToken.getDetails()); - return authenticationResult; } /** - * Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OAuth2User#getAuthorities()} - * to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}. - * - * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OAuth2User#getAuthorities()} to a new set of authorities which will be + * associated to the {@link OAuth2LoginAuthenticationToken}. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the + * user's authorities */ public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); @@ -140,4 +141,5 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider public boolean supports(Class authentication) { return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java index a8fc5fd128..afbe15784f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; +import java.util.Collection; +import java.util.Collections; + import org.springframework.lang.Nullable; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.GrantedAuthority; @@ -26,12 +30,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; -import java.util.Collection; -import java.util.Collections; - /** - * An {@link AbstractAuthenticationToken} for OAuth 2.0 Login, - * which leverages the OAuth 2.0 Authorization Code Grant Flow. + * An {@link AbstractAuthenticationToken} for OAuth 2.0 Login, which leverages the OAuth + * 2.0 Authorization Code Grant Flow. * * @author Joe Grandja * @since 5.0 @@ -40,25 +41,31 @@ import java.util.Collections; * @see ClientRegistration * @see OAuth2AuthorizationExchange * @see OAuth2AccessToken - * @see Section 4.1 Authorization Code Grant Flow + * @see Section + * 4.1 Authorization Code Grant Flow */ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private OAuth2User principal; + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; /** - * This constructor should be used when the Authorization Request/Response is complete. - * + * This constructor should be used when the Authorization Request/Response is + * complete. * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange */ public OAuth2LoginAuthenticationToken(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange) { - + OAuth2AuthorizationExchange authorizationExchange) { super(Collections.emptyList()); Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); @@ -69,9 +76,8 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken /** * This constructor should be used when the Access Token Request/Response is complete, - * which indicates that the Authorization Code Grant flow has fully completed - * and OAuth 2.0 Login has been achieved. - * + * which indicates that the Authorization Code Grant flow has fully completed and + * OAuth 2.0 Login has been achieved. * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange * @param principal the user {@code Principal} registered with the OAuth 2.0 Provider @@ -79,18 +85,15 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken * @param accessToken the access token credential */ public OAuth2LoginAuthenticationToken(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange, - OAuth2User principal, - Collection authorities, - OAuth2AccessToken accessToken) { + OAuth2AuthorizationExchange authorizationExchange, OAuth2User principal, + Collection authorities, OAuth2AccessToken accessToken) { this(clientRegistration, authorizationExchange, principal, authorities, accessToken, null); } /** * This constructor should be used when the Access Token Request/Response is complete, - * which indicates that the Authorization Code Grant flow has fully completed - * and OAuth 2.0 Login has been achieved. - * + * which indicates that the Authorization Code Grant flow has fully completed and + * OAuth 2.0 Login has been achieved. * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange * @param principal the user {@code Principal} registered with the OAuth 2.0 Provider @@ -99,11 +102,9 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken * @param refreshToken the refresh token credential */ public OAuth2LoginAuthenticationToken(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange, - OAuth2User principal, - Collection authorities, - OAuth2AccessToken accessToken, - @Nullable OAuth2RefreshToken refreshToken) { + OAuth2AuthorizationExchange authorizationExchange, OAuth2User principal, + Collection authorities, OAuth2AccessToken accessToken, + @Nullable OAuth2RefreshToken refreshToken) { super(authorities); Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); @@ -129,7 +130,6 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -138,7 +138,6 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken /** * Returns the {@link OAuth2AuthorizationExchange authorization exchange}. - * * @return the {@link OAuth2AuthorizationExchange} */ public OAuth2AuthorizationExchange getAuthorizationExchange() { @@ -147,7 +146,6 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken /** * Returns the {@link OAuth2AccessToken access token}. - * * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { @@ -156,11 +154,11 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken /** * Returns the {@link OAuth2RefreshToken refresh token}. - * - * @since 5.1 * @return the {@link OAuth2RefreshToken} + * @since 5.1 */ public @Nullable OAuth2RefreshToken getRefreshToken() { return this.refreshToken; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java index 9fb1820ff5..e4b72951b5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java @@ -13,8 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; +import java.util.Collection; +import java.util.Map; + +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; @@ -28,24 +34,22 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - -import java.util.Collection; -import java.util.Map; /** - * An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login, - * which leverages the OAuth 2.0 Authorization Code Grant Flow. + * An implementation of an + * {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth + * 2.0 Login, which leverages the OAuth 2.0 Authorization Code Grant Flow. * - * This {@link org.springframework.security.authentication.AuthenticationProvider} is responsible for authenticating - * an Authorization Code credential with the Authorization Server's Token Endpoint - * and if valid, exchanging it for an Access Token credential. + * This {@link org.springframework.security.authentication.AuthenticationProvider} is + * responsible for authenticating an Authorization Code credential with the Authorization + * Server's Token Endpoint and if valid, exchanging it for an Access Token credential. *

        - * It will also obtain the user attributes of the End-User (Resource Owner) - * from the UserInfo Endpoint using an {@link org.springframework.security.oauth2.client.userinfo.OAuth2UserService}, - * which will create a {@code Principal} in the form of an {@link OAuth2User}. - * The {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} - * to complete the authentication. + * It will also obtain the user attributes of the End-User (Resource Owner) from the + * UserInfo Endpoint using an + * {@link org.springframework.security.oauth2.client.userinfo.OAuth2UserService}, which + * will create a {@code Principal} in the form of an {@link OAuth2User}. The + * {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} to + * complete the authentication. * * @author Rob Winch * @since 5.1 @@ -53,24 +57,30 @@ import java.util.Map; * @see org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient * @see org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService * @see OAuth2User - * @see Section 4.1 Authorization Code Grant Flow - * @see Section 4.1.3 Access Token Request - * @see Section 4.1.4 Access Token Response + * @see Section + * 4.1 Authorization Code Grant Flow + * @see Section 4.1.3 Access Token + * Request + * @see Section 4.1.4 Access Token + * Response */ -public class OAuth2LoginReactiveAuthenticationManager implements - ReactiveAuthenticationManager { +public class OAuth2LoginReactiveAuthenticationManager implements ReactiveAuthenticationManager { + private final ReactiveAuthenticationManager authorizationCodeManager; private final ReactiveOAuth2UserService userService; - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); + private GrantedAuthoritiesMapper authoritiesMapper = ((authorities) -> authorities); public OAuth2LoginReactiveAuthenticationManager( ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient, ReactiveOAuth2UserService userService) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(userService, "userService cannot be null"); - this.authorizationCodeManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(accessTokenResponseClient); + this.authorizationCodeManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager( + accessTokenResponseClient); this.userService = userService; } @@ -78,29 +88,29 @@ public class OAuth2LoginReactiveAuthenticationManager implements public Mono authenticate(Authentication authentication) { return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - - // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (token.getAuthorizationExchange() - .getAuthorizationRequest().getScopes().contains("openid")) { + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. + if (token.getAuthorizationExchange().getAuthorizationRequest().getScopes().contains("openid")) { // This is an OpenID Connect Authentication Request so return null - // and let OidcAuthorizationCodeReactiveAuthenticationManager handle it instead once one is created + // and let OidcAuthorizationCodeReactiveAuthenticationManager handle it + // instead once one is created return Mono.empty(); } - return this.authorizationCodeManager.authenticate(token) - .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) - .cast(OAuth2AuthorizationCodeAuthenticationToken.class) - .flatMap(this::onSuccess); + .onErrorMap(OAuth2AuthorizationException.class, + (e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) + .cast(OAuth2AuthorizationCodeAuthenticationToken.class).flatMap(this::onSuccess); }); } /** - * Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OAuth2User#getAuthorities()} - * to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}. - * + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OAuth2User#getAuthorities()} to a new set of authorities which will be + * associated to the {@link OAuth2LoginAuthenticationToken}. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the + * user's authorities * @since 5.4 - * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities */ public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); @@ -110,20 +120,16 @@ public class OAuth2LoginReactiveAuthenticationManager implements private Mono onSuccess(OAuth2AuthorizationCodeAuthenticationToken authentication) { OAuth2AccessToken accessToken = authentication.getAccessToken(); Map additionalParameters = authentication.getAdditionalParameters(); - OAuth2UserRequest userRequest = new OAuth2UserRequest(authentication.getClientRegistration(), accessToken, additionalParameters); - return this.userService.loadUser(userRequest) - .map(oauth2User -> { - Collection mappedAuthorities = - this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); - - OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( - authentication.getClientRegistration(), - authentication.getAuthorizationExchange(), - oauth2User, - mappedAuthorities, - accessToken, - authentication.getRefreshToken()); - return authenticationResult; - }); + OAuth2UserRequest userRequest = new OAuth2UserRequest(authentication.getClientRegistration(), accessToken, + additionalParameters); + return this.userService.loadUser(userRequest).map((oauth2User) -> { + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oauth2User.getAuthorities()); + OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( + authentication.getClientRegistration(), authentication.getAuthorizationExchange(), oauth2User, + mappedAuthorities, accessToken, authentication.getRefreshToken()); + return authenticationResult; + }); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/package-info.java index 1fe301ac98..76109f9a3c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Support classes and interfaces for authenticating and authorizing a client - * with an OAuth 2.0 Authorization Server using a specific authorization grant flow. + * Support classes and interfaces for authenticating and authorizing a client with an + * OAuth 2.0 Authorization Server using a specific authorization grant flow. */ package org.springframework.security.oauth2.client.authentication; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java index d1c7f26aa4..a0d5a698d4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractOAuth2AuthorizationGrantRequest.java @@ -13,27 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; /** - * Base implementation of an OAuth 2.0 Authorization Grant request - * that holds an authorization grant credential and is used when - * initiating a request to the Authorization Server's Token Endpoint. + * Base implementation of an OAuth 2.0 Authorization Grant request that holds an + * authorization grant credential and is used when initiating a request to the + * Authorization Server's Token Endpoint. * * @author Joe Grandja * @since 5.0 * @see AuthorizationGrantType - * @see Section 1.3 Authorization Grant + * @see Section + * 1.3 Authorization Grant */ public abstract class AbstractOAuth2AuthorizationGrantRequest { + private final AuthorizationGrantType authorizationGrantType; /** * Sub-class constructor. - * * @param authorizationGrantType the authorization grant type */ protected AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType authorizationGrantType) { @@ -43,10 +45,10 @@ public abstract class AbstractOAuth2AuthorizationGrantRequest { /** * Returns the authorization grant type. - * * @return the authorization grant type */ public AuthorizationGrantType getGrantType() { return this.authorizationGrantType; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java index 2991b855a8..a97bd09c9a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java @@ -13,63 +13,74 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Collections; +import java.util.Set; + +import reactor.core.publisher.Mono; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.Set; - -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; /** - * Abstract base class for all of the {@code WebClientReactive*TokenResponseClient}s - * that communicate to the Authorization Server's Token Endpoint. + * Abstract base class for all of the {@code WebClientReactive*TokenResponseClient}s that + * communicate to the Authorization Server's Token Endpoint. * - *

        Submits a form request body specific to the type of grant request.

        + *

        + * Submits a form request body specific to the type of grant request. + *

        * - *

        Accepts a JSON response body containing an OAuth 2.0 Access token or error.

        + *

        + * Accepts a JSON response body containing an OAuth 2.0 Access token or error. + *

        * + * @param type of grant request * @author Phil Clay * @since 5.3 - * @param type of grant request - * @see RFC-6749 Token Endpoint + * @see RFC-6749 Token + * Endpoint * @see WebClientReactiveAuthorizationCodeTokenResponseClient * @see WebClientReactiveClientCredentialsTokenResponseClient * @see WebClientReactivePasswordTokenResponseClient * @see WebClientReactiveRefreshTokenTokenResponseClient */ -abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient +public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { private WebClient webClient = WebClient.builder().build(); + AbstractWebClientReactiveOAuth2AccessTokenResponseClient() { + } + @Override public Mono getTokenResponse(T grantRequest) { Assert.notNull(grantRequest, "grantRequest cannot be null"); + // @formatter:off return Mono.defer(() -> this.webClient.post() .uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri()) - .headers(headers -> populateTokenRequestHeaders(grantRequest, headers)) + .headers((headers) -> populateTokenRequestHeaders(grantRequest, headers)) .body(createTokenRequestBody(grantRequest)) .exchange() - .flatMap(response -> readTokenResponse(grantRequest, response))); + .flatMap((response) -> readTokenResponse(grantRequest, response)) + ); + // @formatter:on } /** * Returns the {@link ClientRegistration} for the given {@code grantRequest}. - * * @param grantRequest the grant request * @return the {@link ClientRegistration} for the given {@code grantRequest}. */ @@ -77,7 +88,6 @@ abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClientThis method pre-populates the body with some standard properties, - * and then delegates to {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)} - * for subclasses to further populate the body before returning.

        - * + *

        + * This method pre-populates the body with some standard properties, and then + * delegates to + * {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)} + * for subclasses to further populate the body before returning. + *

        * @param grantRequest the grant request * @return the body for the token request. */ private BodyInserters.FormInserter createTokenRequestBody(T grantRequest) { - BodyInserters.FormInserter body = BodyInserters - .fromFormData(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + BodyInserters.FormInserter body = BodyInserters.fromFormData(OAuth2ParameterNames.GRANT_TYPE, + grantRequest.getGrantType().getValue()); return populateTokenRequestBody(grantRequest, body); } /** * Populates the body of the token request. * - *

        By default, populates properties that are common to all grant types. - * Subclasses can extend this method to populate grant type specific properties.

        - * + *

        + * By default, populates properties that are common to all grant types. Subclasses can + * extend this method to populate grant type specific properties. + *

        * @param grantRequest the grant request * @param body the body to populate * @return the populated body */ - BodyInserters.FormInserter populateTokenRequestBody(T grantRequest, BodyInserters.FormInserter body) { + BodyInserters.FormInserter populateTokenRequestBody(T grantRequest, + BodyInserters.FormInserter body) { ClientRegistration clientRegistration = clientRegistration(grantRequest); if (!ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); @@ -126,31 +140,30 @@ abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient scopes = scopes(grantRequest); if (!CollectionUtils.isEmpty(scopes)) { - body.with(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(scopes, " ")); + body.with(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")); } return body; } /** * Returns the scopes to include as a property in the token request. - * * @param grantRequest the grant request * @return the scopes to include as a property in the token request. */ abstract Set scopes(T grantRequest); /** - * Returns the scopes to include in the response if the authorization - * server returned no scopes in the response. - * - *

        As per RFC-6749 Section 5.1 Successful Access Token Response, - * if AccessTokenResponse.scope is empty, then default to the scope - * originally requested by the client in the Token Request.

        + * Returns the scopes to include in the response if the authorization server returned + * no scopes in the response. * + *

        + * As per RFC-6749 Section + * 5.1 Successful Access Token Response, if AccessTokenResponse.scope is empty, + * then default to the scope originally requested by the client in the Token Request. + *

        * @param grantRequest the grant request - * @return the scopes to include in the response if the authorization - * server returned no scopes. + * @return the scopes to include in the response if the authorization server returned + * no scopes. */ Set defaultScopes(T grantRequest) { return scopes(grantRequest); @@ -158,41 +171,45 @@ abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient readTokenResponse(T grantRequest, ClientResponse response) { - return response.body(oauth2AccessTokenResponse()) - .map(tokenResponse -> populateTokenResponse(grantRequest, tokenResponse)); + return response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse()) + .map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse)); } /** - * Populates the given {@link OAuth2AccessTokenResponse} with additional details - * from the grant request. - * + * Populates the given {@link OAuth2AccessTokenResponse} with additional details from + * the grant request. * @param grantRequest the request for which the response was received. * @param tokenResponse the original token response - * @return a token response optionally populated with additional details from the request. + * @return a token response optionally populated with additional details from the + * request. */ OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) { if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { Set defaultScopes = defaultScopes(grantRequest); - tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) + // @formatter:off + tokenResponse = OAuth2AccessTokenResponse + .withResponse(tokenResponse) .scopes(defaultScopes) .build(); + // @formatter:on } return tokenResponse; } /** - * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response. - * - * @param webClient the {@link WebClient} used when requesting the Access Token Response + * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token + * Response. + * @param webClient the {@link WebClient} used when requesting the Access Token + * Response */ public void setWebClient(WebClient webClient) { Assert.notNull(webClient, "webClient cannot be null"); this.webClient = webClient; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java index 174dc75e43..db43cb47bf 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Arrays; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; @@ -33,92 +36,105 @@ import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Arrays; - /** - * The default implementation of an {@link OAuth2AccessTokenResponseClient} - * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. - * This implementation uses a {@link RestOperations} when requesting - * an access token credential at the Authorization Server's Token Endpoint. + * The default implementation of an {@link OAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. This + * implementation uses a {@link RestOperations} when requesting an access token credential + * at the Authorization Server's Token Endpoint. * * @author Joe Grandja * @since 5.1 * @see OAuth2AccessTokenResponseClient * @see OAuth2AuthorizationCodeGrantRequest * @see OAuth2AccessTokenResponse - * @see Section 4.1.3 Access Token Request (Authorization Code Grant) - * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) */ -public final class DefaultAuthorizationCodeTokenResponseClient implements OAuth2AccessTokenResponseClient { +public final class DefaultAuthorizationCodeTokenResponseClient + implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private Converter> requestEntityConverter = - new OAuth2AuthorizationCodeGrantRequestEntityConverter(); + private Converter> requestEntityConverter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); private RestOperations restOperations; public DefaultAuthorizationCodeTokenResponseClient() { - RestTemplate restTemplate = new RestTemplate(Arrays.asList( - new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + RestTemplate restTemplate = new RestTemplate( + Arrays.asList(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); this.restOperations = restTemplate; } @Override - public OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { + public OAuth2AccessTokenResponse getTokenResponse( + OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { Assert.notNull(authorizationCodeGrantRequest, "authorizationCodeGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(authorizationCodeGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 // If AccessTokenResponse.scope is empty, then default to the scope // originally requested by the client in the Token Request + // @formatter:off tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) .scopes(authorizationCodeGrantRequest.getClientRegistration().getScopes()) .build(); + // @formatter:on } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** - * Sets the {@link Converter} used for converting the {@link OAuth2AuthorizationCodeGrantRequest} - * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. - * - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + * Sets the {@link Converter} used for converting the + * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link RequestEntity} + * representation of the OAuth 2.0 Access Token Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the Access Token Request */ - public void setRequestEntityConverter(Converter> requestEntityConverter) { + public void setRequestEntityConverter( + Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); this.requestEntityConverter = requestEntityConverter; } /** - * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token + * Response. * *

        - * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: *

          - *
        1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
        2. - *
        3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        4. + *
        5. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and + * {@link OAuth2AccessTokenResponseHttpMessageConverter}
        6. + *
        7. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        8. *
        - * - * @param restOperations the {@link RestOperations} used when requesting the Access Token Response + * @param restOperations the {@link RestOperations} used when requesting the Access + * Token Response */ public void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java index fdd5eb1e75..168a85a9da 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Arrays; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; @@ -33,92 +36,105 @@ import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Arrays; - /** - * The default implementation of an {@link OAuth2AccessTokenResponseClient} - * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. - * This implementation uses a {@link RestOperations} when requesting - * an access token credential at the Authorization Server's Token Endpoint. + * The default implementation of an {@link OAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. This + * implementation uses a {@link RestOperations} when requesting an access token credential + * at the Authorization Server's Token Endpoint. * * @author Joe Grandja * @since 5.1 * @see OAuth2AccessTokenResponseClient * @see OAuth2ClientCredentialsGrantRequest * @see OAuth2AccessTokenResponse - * @see Section 4.4.2 Access Token Request (Client Credentials Grant) - * @see Section 4.4.3 Access Token Response (Client Credentials Grant) + * @see Section 4.4.2 Access Token Request + * (Client Credentials Grant) + * @see Section 4.4.3 Access Token Response + * (Client Credentials Grant) */ -public final class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient { +public final class DefaultClientCredentialsTokenResponseClient + implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private Converter> requestEntityConverter = - new OAuth2ClientCredentialsGrantRequestEntityConverter(); + private Converter> requestEntityConverter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); private RestOperations restOperations; public DefaultClientCredentialsTokenResponseClient() { - RestTemplate restTemplate = new RestTemplate(Arrays.asList( - new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + RestTemplate restTemplate = new RestTemplate( + Arrays.asList(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); this.restOperations = restTemplate; } @Override - public OAuth2AccessTokenResponse getTokenResponse(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + public OAuth2AccessTokenResponse getTokenResponse( + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(clientCredentialsGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 // If AccessTokenResponse.scope is empty, then default to the scope // originally requested by the client in the Token Request + // @formatter:off tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) .scopes(clientCredentialsGrantRequest.getClientRegistration().getScopes()) .build(); + // @formatter:on } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** - * Sets the {@link Converter} used for converting the {@link OAuth2ClientCredentialsGrantRequest} - * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. - * - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + * Sets the {@link Converter} used for converting the + * {@link OAuth2ClientCredentialsGrantRequest} to a {@link RequestEntity} + * representation of the OAuth 2.0 Access Token Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the Access Token Request */ - public void setRequestEntityConverter(Converter> requestEntityConverter) { + public void setRequestEntityConverter( + Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); this.requestEntityConverter = requestEntityConverter; } /** - * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token + * Response. * *

        - * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: *

          - *
        1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
        2. - *
        3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        4. + *
        5. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and + * {@link OAuth2AccessTokenResponseHttpMessageConverter}
        6. + *
        7. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        8. *
        - * - * @param restOperations the {@link RestOperations} used when requesting the Access Token Response + * @param restOperations the {@link RestOperations} used when requesting the Access + * Token Response */ public void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java index e2f7180d2e..047e787885 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Arrays; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; @@ -33,33 +36,36 @@ import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Arrays; - /** - * The default implementation of an {@link OAuth2AccessTokenResponseClient} - * for the {@link AuthorizationGrantType#PASSWORD password} grant. - * This implementation uses a {@link RestOperations} when requesting - * an access token credential at the Authorization Server's Token Endpoint. + * The default implementation of an {@link OAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#PASSWORD password} grant. This implementation uses a + * {@link RestOperations} when requesting an access token credential at the Authorization + * Server's Token Endpoint. * * @author Joe Grandja * @since 5.2 * @see OAuth2AccessTokenResponseClient * @see OAuth2PasswordGrantRequest * @see OAuth2AccessTokenResponse - * @see Section 4.3.2 Access Token Request (Resource Owner Password Credentials Grant) - * @see Section 4.3.3 Access Token Response (Resource Owner Password Credentials Grant) + * @see Section 4.3.2 Access Token Request + * (Resource Owner Password Credentials Grant) + * @see Section 4.3.3 Access Token Response + * (Resource Owner Password Credentials Grant) */ -public final class DefaultPasswordTokenResponseClient implements OAuth2AccessTokenResponseClient { +public final class DefaultPasswordTokenResponseClient + implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private Converter> requestEntityConverter = - new OAuth2PasswordGrantRequestEntityConverter(); + private Converter> requestEntityConverter = new OAuth2PasswordGrantRequestEntityConverter(); private RestOperations restOperations; public DefaultPasswordTokenResponseClient() { - RestTemplate restTemplate = new RestTemplate(Arrays.asList( - new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + RestTemplate restTemplate = new RestTemplate( + Arrays.asList(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); this.restOperations = restTemplate; } @@ -67,58 +73,64 @@ public final class DefaultPasswordTokenResponseClient implements OAuth2AccessTok @Override public OAuth2AccessTokenResponse getTokenResponse(OAuth2PasswordGrantRequest passwordGrantRequest) { Assert.notNull(passwordGrantRequest, "passwordGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(passwordGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 // If AccessTokenResponse.scope is empty, then default to the scope // originally requested by the client in the Token Request tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) - .scopes(passwordGrantRequest.getClientRegistration().getScopes()) - .build(); + .scopes(passwordGrantRequest.getClientRegistration().getScopes()).build(); } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** - * Sets the {@link Converter} used for converting the {@link OAuth2PasswordGrantRequest} - * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. - * - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + * Sets the {@link Converter} used for converting the + * {@link OAuth2PasswordGrantRequest} to a {@link RequestEntity} representation of the + * OAuth 2.0 Access Token Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the Access Token Request */ - public void setRequestEntityConverter(Converter> requestEntityConverter) { + public void setRequestEntityConverter( + Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); this.requestEntityConverter = requestEntityConverter; } /** - * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token + * Response. * *

        - * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: *

          - *
        1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
        2. - *
        3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        4. + *
        5. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and + * {@link OAuth2AccessTokenResponseHttpMessageConverter}
        6. + *
        7. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        8. *
        - * - * @param restOperations the {@link RestOperations} used when requesting the Access Token Response + * @param restOperations the {@link RestOperations} used when requesting the Access + * Token Response */ public void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java index 0efd37d8eb..8550d077c0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Arrays; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; @@ -33,32 +36,32 @@ import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Arrays; - /** - * The default implementation of an {@link OAuth2AccessTokenResponseClient} - * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. - * This implementation uses a {@link RestOperations} when requesting - * an access token credential at the Authorization Server's Token Endpoint. + * The default implementation of an {@link OAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. This implementation + * uses a {@link RestOperations} when requesting an access token credential at the + * Authorization Server's Token Endpoint. * * @author Joe Grandja * @since 5.2 * @see OAuth2AccessTokenResponseClient * @see OAuth2RefreshTokenGrantRequest * @see OAuth2AccessTokenResponse - * @see Section 6 Refreshing an Access Token + * @see Section 6 + * Refreshing an Access Token */ -public final class DefaultRefreshTokenTokenResponseClient implements OAuth2AccessTokenResponseClient { +public final class DefaultRefreshTokenTokenResponseClient + implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private Converter> requestEntityConverter = - new OAuth2RefreshTokenGrantRequestEntityConverter(); + private Converter> requestEntityConverter = new OAuth2RefreshTokenGrantRequestEntityConverter(); private RestOperations restOperations; public DefaultRefreshTokenTokenResponseClient() { - RestTemplate restTemplate = new RestTemplate(Arrays.asList( - new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + RestTemplate restTemplate = new RestTemplate( + Arrays.asList(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); this.restOperations = restTemplate; } @@ -66,24 +69,13 @@ public final class DefaultRefreshTokenTokenResponseClient implements OAuth2Acces @Override public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(refreshTokenGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes()) || - tokenResponse.getRefreshToken() == null) { - OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(tokenResponse); - + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes()) + || tokenResponse.getRefreshToken() == null) { + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse + .withResponse(tokenResponse); if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 @@ -91,43 +83,59 @@ public final class DefaultRefreshTokenTokenResponseClient implements OAuth2Acces // originally requested by the client in the Token Request tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); } - if (tokenResponse.getRefreshToken() == null) { // Reuse existing refresh token tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); } - tokenResponse = tokenResponseBuilder.build(); } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** - * Sets the {@link Converter} used for converting the {@link OAuth2RefreshTokenGrantRequest} - * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. - * - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + * Sets the {@link Converter} used for converting the + * {@link OAuth2RefreshTokenGrantRequest} to a {@link RequestEntity} representation of + * the OAuth 2.0 Access Token Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the Access Token Request */ - public void setRequestEntityConverter(Converter> requestEntityConverter) { + public void setRequestEntityConverter( + Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); this.requestEntityConverter = requestEntityConverter; } /** - * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token + * Response. * *

        - * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: *

          - *
        1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
        2. - *
        3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        4. + *
        5. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and + * {@link OAuth2AccessTokenResponseHttpMessageConverter}
        6. + *
        7. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        8. *
        - * - * @param restOperations the {@link RestOperations} used when requesting the Access Token Response + * @param restOperations the {@link RestOperations} used when requesting the Access + * Token Response */ public void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java index cac276aa67..3165867cd2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.io.IOException; +import java.net.URI; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import com.nimbusds.oauth2.sdk.AccessTokenResponse; import com.nimbusds.oauth2.sdk.AuthorizationCode; @@ -30,6 +37,7 @@ import com.nimbusds.oauth2.sdk.auth.ClientSecretPost; import com.nimbusds.oauth2.sdk.auth.Secret; import com.nimbusds.oauth2.sdk.http.HTTPRequest; import com.nimbusds.oauth2.sdk.id.ClientID; + import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -40,16 +48,9 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.CollectionUtils; -import java.io.IOException; -import java.net.URI; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; - /** - * An implementation of an {@link OAuth2AccessTokenResponseClient} that "exchanges" - * an authorization code credential for an access token credential + * An implementation of an {@link OAuth2AccessTokenResponseClient} that + * "exchanges" an authorization code credential for an access token credential * at the Authorization Server's Token Endpoint. * *

        @@ -61,36 +62,76 @@ import java.util.Set; * @see OAuth2AccessTokenResponseClient * @see OAuth2AuthorizationCodeGrantRequest * @see OAuth2AccessTokenResponse - * @see Nimbus OAuth 2.0 SDK - * @see Section 4.1.3 Access Token Request (Authorization Code Grant) - * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + * @see Nimbus OAuth 2.0 + * SDK + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) */ @Deprecated -public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessTokenResponseClient { +public class NimbusAuthorizationCodeTokenResponseClient + implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; @Override public OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) { ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); - // Build the authorization code grant request for the token endpoint AuthorizationCode authorizationCode = new AuthorizationCode( - authorizationGrantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()); - URI redirectUri = toURI(authorizationGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getRedirectUri()); + authorizationGrantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()); + URI redirectUri = toURI( + authorizationGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getRedirectUri()); AuthorizationGrant authorizationCodeGrant = new AuthorizationCodeGrant(authorizationCode, redirectUri); URI tokenUri = toURI(clientRegistration.getProviderDetails().getTokenUri()); - // Set the credentials to authenticate the client at the token endpoint ClientID clientId = new ClientID(clientRegistration.getClientId()); Secret clientSecret = new Secret(clientRegistration.getClientSecret()); - ClientAuthentication clientAuthentication; - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - clientAuthentication = new ClientSecretPost(clientId, clientSecret); - } else { - clientAuthentication = new ClientSecretBasic(clientId, clientSecret); + boolean isPost = ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod()); + ClientAuthentication clientAuthentication = isPost ? new ClientSecretPost(clientId, clientSecret) + : new ClientSecretBasic(clientId, clientSecret); + com.nimbusds.oauth2.sdk.TokenResponse tokenResponse = getTokenResponse(authorizationCodeGrant, tokenUri, + clientAuthentication); + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; + ErrorObject errorObject = tokenErrorResponse.getErrorObject(); + throw new OAuth2AuthorizationException(getOAuthError(errorObject)); } + AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse; + String accessToken = accessTokenResponse.getTokens().getAccessToken().getValue(); + OAuth2AccessToken.TokenType accessTokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue() + .equalsIgnoreCase(accessTokenResponse.getTokens().getAccessToken().getType().getValue())) { + accessTokenType = OAuth2AccessToken.TokenType.BEARER; + } + long expiresIn = accessTokenResponse.getTokens().getAccessToken().getLifetime(); + // As per spec, in section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Authorization Request + Set scopes = getScopes(authorizationGrantRequest, accessTokenResponse); + String refreshToken = null; + if (accessTokenResponse.getTokens().getRefreshToken() != null) { + refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); + } + Map additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); + // @formatter:off + return OAuth2AccessTokenResponse.withToken(accessToken) + .tokenType(accessTokenType) + .expiresIn(expiresIn) + .scopes(scopes) + .refreshToken(refreshToken) + .additionalParameters(additionalParameters) + .build(); + // @formatter:on + } - com.nimbusds.oauth2.sdk.TokenResponse tokenResponse; + private com.nimbusds.oauth2.sdk.TokenResponse getTokenResponse(AuthorizationGrant authorizationCodeGrant, + URI tokenUri, ClientAuthentication clientAuthentication) { try { // Send the Access Token request TokenRequest tokenRequest = new TokenRequest(tokenUri, clientAuthentication, authorizationCodeGrant); @@ -98,71 +139,43 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT httpRequest.setAccept(MediaType.APPLICATION_JSON_VALUE); httpRequest.setConnectTimeout(30000); httpRequest.setReadTimeout(30000); - tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send()); - } catch (ParseException | IOException ex) { + return com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send()); + } + catch (ParseException | IOException ex) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); throw new OAuth2AuthorizationException(oauth2Error, ex); } + } - if (!tokenResponse.indicatesSuccess()) { - TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; - ErrorObject errorObject = tokenErrorResponse.getErrorObject(); - OAuth2Error oauth2Error; - if (errorObject == null) { - oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); - } else { - oauth2Error = new OAuth2Error( - errorObject.getCode() != null ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR, - errorObject.getDescription(), - errorObject.getURI() != null ? errorObject.getURI().toString() : null); - } - throw new OAuth2AuthorizationException(oauth2Error); - } - - AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse; - - String accessToken = accessTokenResponse.getTokens().getAccessToken().getValue(); - OAuth2AccessToken.TokenType accessTokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(accessTokenResponse.getTokens().getAccessToken().getType().getValue())) { - accessTokenType = OAuth2AccessToken.TokenType.BEARER; - } - long expiresIn = accessTokenResponse.getTokens().getAccessToken().getLifetime(); - - // As per spec, in section 5.1 Successful Access Token Response - // https://tools.ietf.org/html/rfc6749#section-5.1 - // If AccessTokenResponse.scope is empty, then default to the scope - // originally requested by the client in the Authorization Request - Set scopes; + private Set getScopes(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest, + AccessTokenResponse accessTokenResponse) { if (CollectionUtils.isEmpty(accessTokenResponse.getTokens().getAccessToken().getScope())) { - scopes = new LinkedHashSet<>( - authorizationGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getScopes()); - } else { - scopes = new LinkedHashSet<>( - accessTokenResponse.getTokens().getAccessToken().getScope().toStringList()); + return new LinkedHashSet<>( + authorizationGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getScopes()); } + return new LinkedHashSet<>(accessTokenResponse.getTokens().getAccessToken().getScope().toStringList()); + } - String refreshToken = null; - if (accessTokenResponse.getTokens().getRefreshToken() != null) { - refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); + private OAuth2Error getOAuthError(ErrorObject errorObject) { + if (errorObject == null) { + return new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); } - - Map additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); - - return OAuth2AccessTokenResponse.withToken(accessToken) - .tokenType(accessTokenType) - .expiresIn(expiresIn) - .scopes(scopes) - .refreshToken(refreshToken) - .additionalParameters(additionalParameters) - .build(); + String errorCode = (errorObject.getCode() != null) ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR; + String description = errorObject.getDescription(); + String uri = (errorObject.getURI() != null) ? errorObject.getURI().toString() : null; + return new OAuth2Error(errorCode, description, uri); } private static URI toURI(String uriStr) { try { return new URI(uriStr); - } catch (Exception ex) { + } + catch (Exception ex) { throw new IllegalArgumentException("An error occurred parsing URI: " + uriStr, ex); } } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AccessTokenResponseClient.java index cf86174324..508365898b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AccessTokenResponseClient.java @@ -13,37 +13,45 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; /** - * A strategy for "exchanging" an authorization grant credential - * (e.g. an Authorization Code) for an access token credential - * at the Authorization Server's Token Endpoint. + * A strategy for "exchanging" an authorization grant credential (e.g. an + * Authorization Code) for an access token credential at the Authorization Server's Token + * Endpoint. * * @author Joe Grandja * @since 5.0 * @see AbstractOAuth2AuthorizationGrantRequest * @see OAuth2AccessTokenResponse * @see AuthorizationGrantType - * @see Section 1.3 Authorization Grant - * @see Section 4.1.3 Access Token Request (Authorization Code Grant) - * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + * @see Section + * 1.3 Authorization Grant + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) */ @FunctionalInterface -public interface OAuth2AccessTokenResponseClient { +public interface OAuth2AccessTokenResponseClient { /** - * Exchanges the authorization grant credential, provided in the authorization grant request, - * for an access token credential at the Authorization Server's Token Endpoint. - * - * @param authorizationGrantRequest the authorization grant request that contains the authorization grant credential - * @return an {@link OAuth2AccessTokenResponse} that contains the {@link OAuth2AccessTokenResponse#getAccessToken() access token} credential - * @throws OAuth2AuthorizationException if an error occurs while attempting to exchange for the access token credential + * Exchanges the authorization grant credential, provided in the authorization grant + * request, for an access token credential at the Authorization Server's Token + * Endpoint. + * @param authorizationGrantRequest the authorization grant request that contains the + * authorization grant credential + * @return an {@link OAuth2AccessTokenResponse} that contains the + * {@link OAuth2AccessTokenResponse#getAccessToken() access token} credential + * @throws OAuth2AuthorizationException if an error occurs while attempting to + * exchange for the access token credential */ OAuth2AccessTokenResponse getTokenResponse(T authorizationGrantRequest); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java index 3b4c36c670..feae3d1f37 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -21,28 +22,33 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch import org.springframework.util.Assert; /** - * An OAuth 2.0 Authorization Code Grant request that holds an Authorization Code credential, - * which was granted by the Resource Owner to the {@link #getClientRegistration() Client}. + * An OAuth 2.0 Authorization Code Grant request that holds an Authorization Code + * credential, which was granted by the Resource Owner to the + * {@link #getClientRegistration() Client}. * * @author Joe Grandja * @since 5.0 * @see AbstractOAuth2AuthorizationGrantRequest * @see ClientRegistration * @see OAuth2AuthorizationExchange - * @see Section 1.3.1 Authorization Code Grant + * @see Section 1.3.1 Authorization Code + * Grant */ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final ClientRegistration clientRegistration; + private final OAuth2AuthorizationExchange authorizationExchange; /** - * Constructs an {@code OAuth2AuthorizationCodeGrantRequest} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationCodeGrantRequest} using the provided + * parameters. * @param clientRegistration the client registration * @param authorizationExchange the authorization exchange */ public OAuth2AuthorizationCodeGrantRequest(ClientRegistration clientRegistration, - OAuth2AuthorizationExchange authorizationExchange) { + OAuth2AuthorizationExchange authorizationExchange) { super(AuthorizationGrantType.AUTHORIZATION_CODE); Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); @@ -52,7 +58,6 @@ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2Authoriza /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -61,10 +66,10 @@ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2Authoriza /** * Returns the {@link OAuth2AuthorizationExchange authorization exchange}. - * * @return the {@link OAuth2AuthorizationExchange} */ public OAuth2AuthorizationExchange getAuthorizationExchange() { return this.authorizationExchange; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java index a8a088a77a..a77470c0de 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.net.URI; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -28,12 +31,10 @@ import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; - /** - * A {@link Converter} that converts the provided {@link OAuth2AuthorizationCodeGrantRequest} - * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request - * for the Authorization Code Grant. + * A {@link Converter} that converts the provided + * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link RequestEntity} representation + * of an OAuth 2.0 Access Token Request for the Authorization Code Grant. * * @author Joe Grandja * @since 5.1 @@ -41,42 +42,41 @@ import java.net.URI; * @see OAuth2AuthorizationCodeGrantRequest * @see RequestEntity */ -public class OAuth2AuthorizationCodeGrantRequestEntityConverter implements Converter> { +public class OAuth2AuthorizationCodeGrantRequestEntityConverter + implements Converter> { /** * Returns the {@link RequestEntity} used for the Access Token Request. - * * @param authorizationCodeGrantRequest the authorization code grant request * @return the {@link RequestEntity} used for the Access Token Request */ @Override public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = this.buildFormParameters(authorizationCodeGrantRequest); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) - .build() + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. - * + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body. * @param authorizationCodeGrantRequest the authorization code grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + * @return a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body */ - private MultiValueMap buildFormParameters(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { + private MultiValueMap buildFormParameters( + OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); OAuth2AuthorizationExchange authorizationExchange = authorizationCodeGrantRequest.getAuthorizationExchange(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, authorizationCodeGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); - String codeVerifier = authorizationExchange.getAuthorizationRequest().getAttribute(PkceParameterNames.CODE_VERIFIER); + String codeVerifier = authorizationExchange.getAuthorizationRequest() + .getAttribute(PkceParameterNames.CODE_VERIFIER); if (redirectUri != null) { formParameters.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri); } @@ -89,7 +89,7 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter implements Conve if (codeVerifier != null) { formParameters.add(PkceParameterNames.CODE_VERIFIER, codeVerifier); } - return formParameters; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java index a1ed924307..1ca61bf69e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationGrantRequestEntityUtils.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Collections; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -22,15 +25,11 @@ import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import java.util.Collections; - -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; - /** - * Utility methods used by the {@link Converter}'s that convert - * from an implementation of an {@link AbstractOAuth2AuthorizationGrantRequest} - * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request - * for the specific Authorization Grant. + * Utility methods used by the {@link Converter}'s that convert from an implementation of + * an {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link RequestEntity} + * representation of an OAuth 2.0 Access Token Request for the specific Authorization + * Grant. * * @author Joe Grandja * @since 5.1 @@ -38,8 +37,12 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @see OAuth2ClientCredentialsGrantRequestEntityConverter */ final class OAuth2AuthorizationGrantRequestEntityUtils { + private static HttpHeaders DEFAULT_TOKEN_REQUEST_HEADERS = getDefaultTokenRequestHeaders(); + private OAuth2AuthorizationGrantRequestEntityUtils() { + } + static HttpHeaders getTokenRequestHeaders(ClientRegistration clientRegistration) { HttpHeaders headers = new HttpHeaders(); headers.addAll(DEFAULT_TOKEN_REQUEST_HEADERS); @@ -52,8 +55,9 @@ final class OAuth2AuthorizationGrantRequestEntityUtils { private static HttpHeaders getDefaultTokenRequestHeaders() { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); - final MediaType contentType = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + final MediaType contentType = MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); headers.setContentType(contentType); return headers; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java index 8764d2068b..f91868a0fe 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -20,21 +21,24 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; /** - * An OAuth 2.0 Client Credentials Grant request that holds - * the client's credentials in {@link #getClientRegistration()}. + * An OAuth 2.0 Client Credentials Grant request that holds the client's credentials in + * {@link #getClientRegistration()}. * * @author Joe Grandja * @since 5.1 * @see AbstractOAuth2AuthorizationGrantRequest * @see ClientRegistration - * @see Section 1.3.4 Client Credentials Grant + * @see Section 1.3.4 Client Credentials + * Grant */ public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final ClientRegistration clientRegistration; /** - * Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided parameters. - * + * Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided + * parameters. * @param clientRegistration the client registration */ public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration) { @@ -47,10 +51,10 @@ public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2Authoriza /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { return this.clientRegistration; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java index 75c0398cc7..b555aabd1b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.net.URI; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -28,12 +31,10 @@ import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; - /** - * A {@link Converter} that converts the provided {@link OAuth2ClientCredentialsGrantRequest} - * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request - * for the Client Credentials Grant. + * A {@link Converter} that converts the provided + * {@link OAuth2ClientCredentialsGrantRequest} to a {@link RequestEntity} representation + * of an OAuth 2.0 Access Token Request for the Client Credentials Grant. * * @author Joe Grandja * @since 5.1 @@ -41,36 +42,34 @@ import java.net.URI; * @see OAuth2ClientCredentialsGrantRequest * @see RequestEntity */ -public class OAuth2ClientCredentialsGrantRequestEntityConverter implements Converter> { +public class OAuth2ClientCredentialsGrantRequestEntityConverter + implements Converter> { /** * Returns the {@link RequestEntity} used for the Access Token Request. - * * @param clientCredentialsGrantRequest the client credentials grant request * @return the {@link RequestEntity} used for the Access Token Request */ @Override public RequestEntity convert(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = this.buildFormParameters(clientCredentialsGrantRequest); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) - .build() + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. - * + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body. * @param clientCredentialsGrantRequest the client credentials grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + * @return a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body */ - private MultiValueMap buildFormParameters(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + private MultiValueMap buildFormParameters( + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { @@ -81,7 +80,7 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverter implements Conve formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java index 0898fc32a0..cc82b3f47f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -20,22 +21,26 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; /** - * An OAuth 2.0 Resource Owner Password Credentials Grant request - * that holds the resource owner's credentials. + * An OAuth 2.0 Resource Owner Password Credentials Grant request that holds the resource + * owner's credentials. * * @author Joe Grandja * @since 5.2 * @see AbstractOAuth2AuthorizationGrantRequest - * @see Section 1.3.3 Resource Owner Password Credentials + * @see Section 1.3.3 Resource Owner + * Password Credentials */ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final ClientRegistration clientRegistration; + private final String username; + private final String password; /** * Constructs an {@code OAuth2PasswordGrantRequest} using the provided parameters. - * * @param clientRegistration the client registration * @param username the resource owner's username * @param password the resource owner's password @@ -54,7 +59,6 @@ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrant /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -63,7 +67,6 @@ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrant /** * Returns the resource owner's username. - * * @return the resource owner's username */ public String getUsername() { @@ -72,10 +75,10 @@ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrant /** * Returns the resource owner's password. - * * @return the resource owner's password */ public String getPassword() { return this.password; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java index f2ba2b40a0..1bef0f8404 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.net.URI; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -28,12 +31,10 @@ import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; - /** - * A {@link Converter} that converts the provided {@link OAuth2PasswordGrantRequest} - * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request - * for the Resource Owner Password Credentials Grant. + * A {@link Converter} that converts the provided {@link OAuth2PasswordGrantRequest} to a + * {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the + * Resource Owner Password Credentials Grant. * * @author Joe Grandja * @since 5.2 @@ -41,36 +42,33 @@ import java.net.URI; * @see OAuth2PasswordGrantRequest * @see RequestEntity */ -public class OAuth2PasswordGrantRequestEntityConverter implements Converter> { +public class OAuth2PasswordGrantRequestEntityConverter + implements Converter> { /** * Returns the {@link RequestEntity} used for the Access Token Request. - * * @param passwordGrantRequest the password grant request * @return the {@link RequestEntity} used for the Access Token Request */ @Override public RequestEntity convert(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = buildFormParameters(passwordGrantRequest); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) - .build() + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. - * + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body. * @param passwordGrantRequest the password grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + * @return a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body */ private MultiValueMap buildFormParameters(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername()); @@ -83,7 +81,7 @@ public class OAuth2PasswordGrantRequestEntityConverter implements ConverterSection 6 Refreshing an Access Token + * @see Section 6 + * Refreshing an Access Token */ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final ClientRegistration clientRegistration; + private final OAuth2AccessToken accessToken; + private final OAuth2RefreshToken refreshToken; + private final Set scopes; /** * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. - * * @param clientRegistration the authorized client's registration * @param accessToken the access token credential granted * @param refreshToken the refresh token credential granted */ public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, - OAuth2RefreshToken refreshToken) { + OAuth2RefreshToken refreshToken) { this(clientRegistration, accessToken, refreshToken, Collections.emptySet()); } /** * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. - * * @param clientRegistration the authorized client's registration * @param accessToken the access token credential granted * @param refreshToken the refresh token credential granted * @param scopes the scopes to request */ public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, - OAuth2RefreshToken refreshToken, Set scopes) { + OAuth2RefreshToken refreshToken, Set scopes) { super(AuthorizationGrantType.REFRESH_TOKEN); Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(accessToken, "accessToken cannot be null"); @@ -70,13 +74,12 @@ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationG this.clientRegistration = clientRegistration; this.accessToken = accessToken; this.refreshToken = refreshToken; - this.scopes = Collections.unmodifiableSet(scopes != null ? - new LinkedHashSet<>(scopes) : Collections.emptySet()); + this.scopes = Collections + .unmodifiableSet((scopes != null) ? new LinkedHashSet<>(scopes) : Collections.emptySet()); } /** * Returns the authorized client's {@link ClientRegistration registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -85,7 +88,6 @@ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationG /** * Returns the {@link OAuth2AccessToken access token} credential granted. - * * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { @@ -94,7 +96,6 @@ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationG /** * Returns the {@link OAuth2RefreshToken refresh token} credential granted. - * * @return the {@link OAuth2RefreshToken} */ public OAuth2RefreshToken getRefreshToken() { @@ -103,10 +104,10 @@ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationG /** * Returns the scope(s) to request. - * * @return the scope(s) to request */ public Set getScopes() { return this.scopes; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java index 00cac8beed..bd22022bf8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.net.URI; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -28,12 +31,10 @@ import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; - /** * A {@link Converter} that converts the provided {@link OAuth2RefreshTokenGrantRequest} - * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request - * for the Refresh Token Grant. + * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the + * Refresh Token Grant. * * @author Joe Grandja * @since 5.2 @@ -41,36 +42,33 @@ import java.net.URI; * @see OAuth2RefreshTokenGrantRequest * @see RequestEntity */ -public class OAuth2RefreshTokenGrantRequestEntityConverter implements Converter> { +public class OAuth2RefreshTokenGrantRequestEntityConverter + implements Converter> { /** * Returns the {@link RequestEntity} used for the Access Token Request. - * * @param refreshTokenGrantRequest the refresh token grant request * @return the {@link RequestEntity} used for the Access Token Request */ @Override public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = buildFormParameters(refreshTokenGrantRequest); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) - .build() + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } /** - * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. - * + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body. * @param refreshTokenGrantRequest the refresh token grant request - * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + * @return a {@link MultiValueMap} of the form parameters used for the Access Token + * Request body */ private MultiValueMap buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, @@ -83,7 +81,7 @@ public class OAuth2RefreshTokenGrantRequestEntityConverter implements Converter< formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java index f0196bdcc3..0246a91cd1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java @@ -13,37 +13,47 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import reactor.core.publisher.Mono; /** - * A reactive strategy for "exchanging" an authorization grant credential - * (e.g. an Authorization Code) for an access token credential - * at the Authorization Server's Token Endpoint. + * A reactive strategy for "exchanging" an authorization grant credential (e.g. + * an Authorization Code) for an access token credential at the Authorization Server's + * Token Endpoint. * * @author Rob Winch * @since 5.1 * @see AbstractOAuth2AuthorizationGrantRequest * @see OAuth2AccessTokenResponse * @see AuthorizationGrantType - * @see Section 1.3 Authorization Grant - * @see Section 4.1.3 Access Token Request (Authorization Code Grant) - * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + * @see Section + * 1.3 Authorization Grant + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) */ @FunctionalInterface -public interface ReactiveOAuth2AccessTokenResponseClient { +public interface ReactiveOAuth2AccessTokenResponseClient { /** - * Exchanges the authorization grant credential, provided in the authorization grant request, - * for an access token credential at the Authorization Server's Token Endpoint. - * - * @param authorizationGrantRequest the authorization grant request that contains the authorization grant credential - * @return an {@link OAuth2AccessTokenResponse} that contains the {@link OAuth2AccessTokenResponse#getAccessToken() access token} credential - * @throws OAuth2AuthorizationException if an error occurs while attempting to exchange for the access token credential + * Exchanges the authorization grant credential, provided in the authorization grant + * request, for an access token credential at the Authorization Server's Token + * Endpoint. + * @param authorizationGrantRequest the authorization grant request that contains the + * authorization grant credential + * @return an {@link OAuth2AccessTokenResponse} that contains the + * {@link OAuth2AccessTokenResponse#getAccessToken() access token} credential + * @throws OAuth2AuthorizationException if an error occurs while attempting to + * exchange for the access token credential */ Mono getTokenResponse(T authorizationGrantRequest); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java index 80cbe4cafd..77926000d4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Collections; +import java.util.Set; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; @@ -23,12 +27,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.web.reactive.function.BodyInserters; -import java.util.Collections; -import java.util.Set; - /** - * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" - * an authorization code credential for an access token credential + * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that + * "exchanges" an authorization code credential for an access token credential * at the Authorization Server's Token Endpoint. * *

        @@ -39,13 +40,20 @@ import java.util.Set; * @see ReactiveOAuth2AccessTokenResponseClient * @see OAuth2AuthorizationCodeGrantRequest * @see OAuth2AccessTokenResponse - * @see Nimbus OAuth 2.0 SDK - * @see Section 4.1.3 Access Token Request (Authorization Code Grant) - * @see Section 4.1.4 Access Token Response (Authorization Code Grant) - * @see Section 4.2 Client Creates the Code Challenge + * @see Nimbus OAuth 2.0 + * SDK + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) + * @see Section + * 4.2 Client Creates the Code Challenge */ -public class WebClientReactiveAuthorizationCodeTokenResponseClient extends - AbstractWebClientReactiveOAuth2AccessTokenResponseClient { +public class WebClientReactiveAuthorizationCodeTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override ClientRegistration clientRegistration(OAuth2AuthorizationCodeGrantRequest grantRequest) { @@ -63,8 +71,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClient extends } @Override - BodyInserters.FormInserter populateTokenRequestBody( - OAuth2AuthorizationCodeGrantRequest grantRequest, + BodyInserters.FormInserter populateTokenRequestBody(OAuth2AuthorizationCodeGrantRequest grantRequest, BodyInserters.FormInserter body) { super.populateTokenRequestBody(grantRequest, body); OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange(); @@ -74,10 +81,12 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClient extends if (redirectUri != null) { body.with(OAuth2ParameterNames.REDIRECT_URI, redirectUri); } - String codeVerifier = authorizationExchange.getAuthorizationRequest().getAttribute(PkceParameterNames.CODE_VERIFIER); + String codeVerifier = authorizationExchange.getAuthorizationRequest() + .getAttribute(PkceParameterNames.CODE_VERIFIER); if (codeVerifier != null) { body.with(PkceParameterNames.CODE_VERIFIER, codeVerifier); } return body; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java index b5f4142ef6..7ad39faf0c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java @@ -13,29 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Set; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import java.util.Set; - /** - * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" - * a client credential for an access token credential - * at the Authorization Server's Token Endpoint. + * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that + * "exchanges" a client credential for an access token credential at the + * Authorization Server's Token Endpoint. * * @author Rob Winch * @since 5.1 * @see ReactiveOAuth2AccessTokenResponseClient * @see OAuth2AuthorizationCodeGrantRequest * @see OAuth2AccessTokenResponse - * @see Nimbus OAuth 2.0 SDK - * @see Section 4.1.3 Access Token Request (Authorization Code Grant) - * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + * @see Nimbus OAuth 2.0 + * SDK + * @see Section 4.1.3 Access Token Request + * (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response + * (Authorization Code Grant) */ -public class WebClientReactiveClientCredentialsTokenResponseClient extends - AbstractWebClientReactiveOAuth2AccessTokenResponseClient { +public class WebClientReactiveClientCredentialsTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override ClientRegistration clientRegistration(OAuth2ClientCredentialsGrantRequest grantRequest) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java index 442e2543fc..bac8801ea6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Set; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -22,24 +25,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; -import java.util.Set; - /** - * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} - * for the {@link AuthorizationGrantType#PASSWORD password} grant. - * This implementation uses {@link WebClient} when requesting - * an access token credential at the Authorization Server's Token Endpoint. + * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#PASSWORD password} grant. This implementation uses + * {@link WebClient} when requesting an access token credential at the Authorization + * Server's Token Endpoint. * * @author Joe Grandja * @since 5.2 * @see ReactiveOAuth2AccessTokenResponseClient * @see OAuth2PasswordGrantRequest * @see OAuth2AccessTokenResponse - * @see Section 4.3.2 Access Token Request (Resource Owner Password Credentials Grant) - * @see Section 4.3.3 Access Token Response (Resource Owner Password Credentials Grant) + * @see Section 4.3.2 Access Token Request + * (Resource Owner Password Credentials Grant) + * @see Section 4.3.3 Access Token Response + * (Resource Owner Password Credentials Grant) */ -public final class WebClientReactivePasswordTokenResponseClient extends - AbstractWebClientReactiveOAuth2AccessTokenResponseClient { +public final class WebClientReactivePasswordTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override ClientRegistration clientRegistration(OAuth2PasswordGrantRequest grantRequest) { @@ -52,12 +57,11 @@ public final class WebClientReactivePasswordTokenResponseClient extends } @Override - BodyInserters.FormInserter populateTokenRequestBody( - OAuth2PasswordGrantRequest grantRequest, + BodyInserters.FormInserter populateTokenRequestBody(OAuth2PasswordGrantRequest grantRequest, BodyInserters.FormInserter body) { return super.populateTokenRequestBody(grantRequest, body) - .with(OAuth2ParameterNames.USERNAME, grantRequest.getUsername()) - .with(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword()); + .with(OAuth2ParameterNames.USERNAME, grantRequest.getUsername()) + .with(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword()); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java index 9ad787af7f..ee09608f20 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Set; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -23,23 +26,22 @@ import org.springframework.util.CollectionUtils; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; -import java.util.Set; - /** - * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} - * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. - * This implementation uses {@link WebClient} when requesting - * an access token credential at the Authorization Server's Token Endpoint. + * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. This implementation + * uses {@link WebClient} when requesting an access token credential at the Authorization + * Server's Token Endpoint. * * @author Joe Grandja * @since 5.2 * @see ReactiveOAuth2AccessTokenResponseClient * @see OAuth2RefreshTokenGrantRequest * @see OAuth2AccessTokenResponse - * @see Section 6 Refreshing an Access Token + * @see Section 6 + * Refreshing an Access Token */ -public final class WebClientReactiveRefreshTokenTokenResponseClient extends - AbstractWebClientReactiveOAuth2AccessTokenResponseClient { +public final class WebClientReactiveRefreshTokenTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override ClientRegistration clientRegistration(OAuth2RefreshTokenGrantRequest grantRequest) { @@ -57,24 +59,21 @@ public final class WebClientReactiveRefreshTokenTokenResponseClient extends } @Override - BodyInserters.FormInserter populateTokenRequestBody( - OAuth2RefreshTokenGrantRequest grantRequest, + BodyInserters.FormInserter populateTokenRequestBody(OAuth2RefreshTokenGrantRequest grantRequest, BodyInserters.FormInserter body) { - return super.populateTokenRequestBody(grantRequest, body) - .with(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue()); + return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.REFRESH_TOKEN, + grantRequest.getRefreshToken().getTokenValue()); } @Override - OAuth2AccessTokenResponse populateTokenResponse( - OAuth2RefreshTokenGrantRequest grantRequest, + OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest, OAuth2AccessTokenResponse accessTokenResponse) { - - if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) && - accessTokenResponse.getRefreshToken() != null) { + if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) + && accessTokenResponse.getRefreshToken() != null) { return accessTokenResponse; } - - OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(accessTokenResponse); + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse + .withResponse(accessTokenResponse); if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) { tokenResponseBuilder.scopes(defaultScopes(grantRequest)); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/package-info.java index a443ff76b9..bd3ed6e516 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Classes and interfaces providing support to the client - * for initiating requests to the Authorization Server's Protocol Endpoints. + * Classes and interfaces providing support to the client for initiating requests to the + * Authorization Server's Protocol Endpoints. */ package org.springframework.security.oauth2.client.endpoint; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java index 3f865a5897..2b50b967ac 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.http; +import java.io.IOException; + import com.nimbusds.oauth2.sdk.token.BearerTokenError; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.client.ClientHttpResponse; @@ -27,18 +31,18 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.ResponseErrorHandler; -import java.io.IOException; - /** * A {@link ResponseErrorHandler} that handles an {@link OAuth2Error OAuth 2.0 Error}. * - * @see ResponseErrorHandler - * @see OAuth2Error * @author Joe Grandja * @since 5.1 + * @see ResponseErrorHandler + * @see OAuth2Error */ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler { + private final OAuth2ErrorHttpMessageConverter oauth2ErrorConverter = new OAuth2ErrorHttpMessageConverter(); + private final ResponseErrorHandler defaultErrorHandler = new DefaultResponseErrorHandler(); @Override @@ -51,14 +55,12 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler { if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) { this.defaultErrorHandler.handleError(response); } - // A Bearer Token Error may be in the WWW-Authenticate response header // See https://tools.ietf.org/html/rfc6750#section-3 OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders()); if (oauth2Error == null) { oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response); } - throw new OAuth2AuthorizationException(oauth2Error); } @@ -67,20 +69,21 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler { if (!StringUtils.hasText(wwwAuthenticateHeader)) { return null; } - - BearerTokenError bearerTokenError; - try { - bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader); - } catch (Exception ex) { - return null; - } - - String errorCode = bearerTokenError.getCode() != null ? - bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR; + BearerTokenError bearerTokenError = getBearerToken(wwwAuthenticateHeader); + String errorCode = (bearerTokenError.getCode() != null) ? bearerTokenError.getCode() + : OAuth2ErrorCodes.SERVER_ERROR; String errorDescription = bearerTokenError.getDescription(); - String errorUri = bearerTokenError.getURI() != null ? - bearerTokenError.getURI().toString() : null; - + String errorUri = (bearerTokenError.getURI() != null) ? bearerTokenError.getURI().toString() : null; return new OAuth2Error(errorCode, errorDescription, errorUri); } + + private BearerTokenError getBearerToken(String wwwAuthenticateHeader) { + try { + return BearerTokenError.parse(wwwAuthenticateHeader); + } + catch (Exception ex) { + return null; + } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java index 9e12424f74..d8cfde2efc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java @@ -13,27 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.io.IOException; + import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.util.StdConverter; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import java.io.IOException; - -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.MAP_TYPE_REFERENCE; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.SET_TYPE_REFERENCE; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findObjectNode; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findStringValue; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findValue; - /** * A {@code JsonDeserializer} for {@link ClientRegistration}. * @@ -43,43 +39,41 @@ import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils. * @see ClientRegistrationMixin */ final class ClientRegistrationDeserializer extends JsonDeserializer { - private static final StdConverter CLIENT_AUTHENTICATION_METHOD_CONVERTER = - new StdConverters.ClientAuthenticationMethodConverter(); - private static final StdConverter AUTHORIZATION_GRANT_TYPE_CONVERTER = - new StdConverters.AuthorizationGrantTypeConverter(); - private static final StdConverter AUTHENTICATION_METHOD_CONVERTER = - new StdConverters.AuthenticationMethodConverter(); + + private static final StdConverter CLIENT_AUTHENTICATION_METHOD_CONVERTER = new StdConverters.ClientAuthenticationMethodConverter(); + + private static final StdConverter AUTHORIZATION_GRANT_TYPE_CONVERTER = new StdConverters.AuthorizationGrantTypeConverter(); + + private static final StdConverter AUTHENTICATION_METHOD_CONVERTER = new StdConverters.AuthenticationMethodConverter(); @Override public ClientRegistration deserialize(JsonParser parser, DeserializationContext context) throws IOException { ObjectMapper mapper = (ObjectMapper) parser.getCodec(); JsonNode clientRegistrationNode = mapper.readTree(parser); - JsonNode providerDetailsNode = findObjectNode(clientRegistrationNode, "providerDetails"); - JsonNode userInfoEndpointNode = findObjectNode(providerDetailsNode, "userInfoEndpoint"); - + JsonNode providerDetailsNode = JsonNodeUtils.findObjectNode(clientRegistrationNode, "providerDetails"); + JsonNode userInfoEndpointNode = JsonNodeUtils.findObjectNode(providerDetailsNode, "userInfoEndpoint"); return ClientRegistration - .withRegistrationId(findStringValue(clientRegistrationNode, "registrationId")) - .clientId(findStringValue(clientRegistrationNode, "clientId")) - .clientSecret(findStringValue(clientRegistrationNode, "clientSecret")) - .clientAuthenticationMethod( - CLIENT_AUTHENTICATION_METHOD_CONVERTER.convert( - findObjectNode(clientRegistrationNode, "clientAuthenticationMethod"))) - .authorizationGrantType( - AUTHORIZATION_GRANT_TYPE_CONVERTER.convert( - findObjectNode(clientRegistrationNode, "authorizationGrantType"))) - .redirectUri(findStringValue(clientRegistrationNode, "redirectUri")) - .scope(findValue(clientRegistrationNode, "scopes", SET_TYPE_REFERENCE, mapper)) - .clientName(findStringValue(clientRegistrationNode, "clientName")) - .authorizationUri(findStringValue(providerDetailsNode, "authorizationUri")) - .tokenUri(findStringValue(providerDetailsNode, "tokenUri")) - .userInfoUri(findStringValue(userInfoEndpointNode, "uri")) - .userInfoAuthenticationMethod( - AUTHENTICATION_METHOD_CONVERTER.convert( - findObjectNode(userInfoEndpointNode, "authenticationMethod"))) - .userNameAttributeName(findStringValue(userInfoEndpointNode, "userNameAttributeName")) - .jwkSetUri(findStringValue(providerDetailsNode, "jwkSetUri")) - .issuerUri(findStringValue(providerDetailsNode, "issuerUri")) - .providerConfigurationMetadata(findValue(providerDetailsNode, "configurationMetadata", MAP_TYPE_REFERENCE, mapper)) + .withRegistrationId(JsonNodeUtils.findStringValue(clientRegistrationNode, "registrationId")) + .clientId(JsonNodeUtils.findStringValue(clientRegistrationNode, "clientId")) + .clientSecret(JsonNodeUtils.findStringValue(clientRegistrationNode, "clientSecret")) + .clientAuthenticationMethod(CLIENT_AUTHENTICATION_METHOD_CONVERTER + .convert(JsonNodeUtils.findObjectNode(clientRegistrationNode, "clientAuthenticationMethod"))) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE_CONVERTER + .convert(JsonNodeUtils.findObjectNode(clientRegistrationNode, "authorizationGrantType"))) + .redirectUri(JsonNodeUtils.findStringValue(clientRegistrationNode, "redirectUri")) + .scope(JsonNodeUtils.findValue(clientRegistrationNode, "scopes", JsonNodeUtils.STRING_SET, mapper)) + .clientName(JsonNodeUtils.findStringValue(clientRegistrationNode, "clientName")) + .authorizationUri(JsonNodeUtils.findStringValue(providerDetailsNode, "authorizationUri")) + .tokenUri(JsonNodeUtils.findStringValue(providerDetailsNode, "tokenUri")) + .userInfoUri(JsonNodeUtils.findStringValue(userInfoEndpointNode, "uri")) + .userInfoAuthenticationMethod(AUTHENTICATION_METHOD_CONVERTER + .convert(JsonNodeUtils.findObjectNode(userInfoEndpointNode, "authenticationMethod"))) + .userNameAttributeName(JsonNodeUtils.findStringValue(userInfoEndpointNode, "userNameAttributeName")) + .jwkSetUri(JsonNodeUtils.findStringValue(providerDetailsNode, "jwkSetUri")) + .issuerUri(JsonNodeUtils.findStringValue(providerDetailsNode, "issuerUri")) + .providerConfigurationMetadata(JsonNodeUtils.findValue(providerDetailsNode, "configurationMetadata", + JsonNodeUtils.STRING_OBJECT_MAP, mapper)) .build(); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationMixin.java index a60708495c..e56d1c039c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationMixin.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + import org.springframework.security.oauth2.client.registration.ClientRegistration; /** - * This mixin class is used to serialize/deserialize {@link ClientRegistration}. - * It also registers a custom deserializer {@link ClientRegistrationDeserializer}. + * This mixin class is used to serialize/deserialize {@link ClientRegistration}. It also + * registers a custom deserializer {@link ClientRegistrationDeserializer}. * * @author Joe Grandja * @since 5.3 @@ -37,4 +39,5 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonIgnoreProperties(ignoreUnknown = true) abstract class ClientRegistrationMixin { + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOAuth2UserMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOAuth2UserMixin.java index 056b27d59e..917062c905 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOAuth2UserMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOAuth2UserMixin.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Collection; +import java.util.Map; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; -import java.util.Collection; -import java.util.Map; - /** * This mixin class is used to serialize/deserialize {@link DefaultOAuth2User}. * @@ -41,9 +43,9 @@ import java.util.Map; abstract class DefaultOAuth2UserMixin { @JsonCreator - DefaultOAuth2UserMixin( - @JsonProperty("authorities") Collection authorities, + DefaultOAuth2UserMixin(@JsonProperty("authorities") Collection authorities, @JsonProperty("attributes") Map attributes, @JsonProperty("nameAttributeKey") String nameAttributeKey) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOidcUserMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOidcUserMixin.java index 003b6e707e..5b46dc9396 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOidcUserMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/DefaultOidcUserMixin.java @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Collection; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; -import java.util.Collection; - /** * This mixin class is used to serialize/deserialize {@link DefaultOidcUser}. * @@ -38,14 +40,13 @@ import java.util.Collection; @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, isGetterVisibility = JsonAutoDetect.Visibility.NONE) -@JsonIgnoreProperties(value = {"attributes"}, ignoreUnknown = true) +@JsonIgnoreProperties(value = { "attributes" }, ignoreUnknown = true) abstract class DefaultOidcUserMixin { @JsonCreator - DefaultOidcUserMixin( - @JsonProperty("authorities") Collection authorities, - @JsonProperty("idToken") OidcIdToken idToken, - @JsonProperty("userInfo") OidcUserInfo userInfo, + DefaultOidcUserMixin(@JsonProperty("authorities") Collection authorities, + @JsonProperty("idToken") OidcIdToken idToken, @JsonProperty("userInfo") OidcUserInfo userInfo, @JsonProperty("nameAttributeKey") String nameAttributeKey) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/JsonNodeUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/JsonNodeUtils.java index d9320227b7..351295efe6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/JsonNodeUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/JsonNodeUtils.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Map; +import java.util.Set; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import java.util.Map; -import java.util.Set; - /** * Utility class for {@code JsonNode}. * @@ -29,40 +30,36 @@ import java.util.Set; * @since 5.3 */ abstract class JsonNodeUtils { - static final TypeReference> SET_TYPE_REFERENCE = new TypeReference>() {}; - static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() {}; + static final TypeReference> STRING_SET = new TypeReference>() { + }; + + static final TypeReference> STRING_OBJECT_MAP = new TypeReference>() { + }; static String findStringValue(JsonNode jsonNode, String fieldName) { if (jsonNode == null) { return null; } - JsonNode nodeValue = jsonNode.findValue(fieldName); - if (nodeValue != null && nodeValue.isTextual()) { - return nodeValue.asText(); - } - return null; + JsonNode value = jsonNode.findValue(fieldName); + return (value != null && value.isTextual()) ? value.asText() : null; } - static T findValue(JsonNode jsonNode, String fieldName, TypeReference valueTypeReference, ObjectMapper mapper) { + static T findValue(JsonNode jsonNode, String fieldName, TypeReference valueTypeReference, + ObjectMapper mapper) { if (jsonNode == null) { return null; } - JsonNode nodeValue = jsonNode.findValue(fieldName); - if (nodeValue != null && nodeValue.isContainerNode()) { - return (T) mapper.convertValue(nodeValue, valueTypeReference); - } - return null; + JsonNode value = jsonNode.findValue(fieldName); + return (value != null && value.isContainerNode()) ? mapper.convertValue(value, valueTypeReference) : null; } static JsonNode findObjectNode(JsonNode jsonNode, String fieldName) { if (jsonNode == null) { return null; } - JsonNode nodeValue = jsonNode.findValue(fieldName); - if (nodeValue != null && nodeValue.isObject()) { - return nodeValue; - } - return null; + JsonNode value = jsonNode.findValue(fieldName); + return (value != null && value.isObject()) ? value : null; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AccessTokenMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AccessTokenMixin.java index f94eb87178..1d9e682829 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AccessTokenMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AccessTokenMixin.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.time.Instant; +import java.util.Set; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import java.time.Instant; -import java.util.Set; +import org.springframework.security.oauth2.core.OAuth2AccessToken; /** * This mixin class is used to serialize/deserialize {@link OAuth2AccessToken}. @@ -42,10 +44,10 @@ abstract class OAuth2AccessTokenMixin { @JsonCreator OAuth2AccessTokenMixin( - @JsonProperty("tokenType") @JsonDeserialize(converter = StdConverters.AccessTokenTypeConverter.class) OAuth2AccessToken.TokenType tokenType, - @JsonProperty("tokenValue") String tokenValue, - @JsonProperty("issuedAt") Instant issuedAt, - @JsonProperty("expiresAt") Instant expiresAt, - @JsonProperty("scopes") Set scopes) { + @JsonProperty("tokenType") @JsonDeserialize( + converter = StdConverters.AccessTokenTypeConverter.class) OAuth2AccessToken.TokenType tokenType, + @JsonProperty("tokenValue") String tokenValue, @JsonProperty("issuedAt") Instant issuedAt, + @JsonProperty("expiresAt") Instant expiresAt, @JsonProperty("scopes") Set scopes) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixin.java index 187eb16d8e..9fa810c1ac 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixin.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; @@ -20,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -34,13 +36,13 @@ import org.springframework.security.oauth2.core.OAuth2Error; */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, - isGetterVisibility = JsonAutoDetect.Visibility.NONE) -@JsonIgnoreProperties(ignoreUnknown = true, value = {"cause", "stackTrace", "suppressedExceptions"}) + isGetterVisibility = JsonAutoDetect.Visibility.NONE) +@JsonIgnoreProperties(ignoreUnknown = true, value = { "cause", "stackTrace", "suppressedExceptions" }) abstract class OAuth2AuthenticationExceptionMixin { @JsonCreator - OAuth2AuthenticationExceptionMixin( - @JsonProperty("error") OAuth2Error error, + OAuth2AuthenticationExceptionMixin(@JsonProperty("error") OAuth2Error error, @JsonProperty("detailMessage") String message) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixin.java index ebd3c1b77c..30ac991032 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixin.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Collection; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.core.user.OAuth2User; -import java.util.Collection; - /** * This mixin class is used to serialize/deserialize {@link OAuth2AuthenticationToken}. * @@ -37,13 +39,13 @@ import java.util.Collection; @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, isGetterVisibility = JsonAutoDetect.Visibility.NONE) -@JsonIgnoreProperties(value = {"authenticated"}, ignoreUnknown = true) +@JsonIgnoreProperties(value = { "authenticated" }, ignoreUnknown = true) abstract class OAuth2AuthenticationTokenMixin { @JsonCreator - OAuth2AuthenticationTokenMixin( - @JsonProperty("principal") OAuth2User principal, + OAuth2AuthenticationTokenMixin(@JsonProperty("principal") OAuth2User principal, @JsonProperty("authorities") Collection authorities, @JsonProperty("authorizedClientRegistrationId") String authorizedClientRegistrationId) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java index ae3d64ae38..00e717bb8f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.io.IOException; + import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; @@ -22,16 +25,10 @@ import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.util.StdConverter; + import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; - -import java.io.IOException; - -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.MAP_TYPE_REFERENCE; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.SET_TYPE_REFERENCE; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findObjectNode; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findStringValue; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findValue; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder; /** * A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}. @@ -42,35 +39,43 @@ import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils. * @see OAuth2AuthorizationRequestMixin */ final class OAuth2AuthorizationRequestDeserializer extends JsonDeserializer { - private static final StdConverter AUTHORIZATION_GRANT_TYPE_CONVERTER = - new StdConverters.AuthorizationGrantTypeConverter(); + + private static final StdConverter AUTHORIZATION_GRANT_TYPE_CONVERTER = new StdConverters.AuthorizationGrantTypeConverter(); @Override - public OAuth2AuthorizationRequest deserialize(JsonParser parser, DeserializationContext context) throws IOException { + public OAuth2AuthorizationRequest deserialize(JsonParser parser, DeserializationContext context) + throws IOException { ObjectMapper mapper = (ObjectMapper) parser.getCodec(); - JsonNode authorizationRequestNode = mapper.readTree(parser); - - AuthorizationGrantType authorizationGrantType = AUTHORIZATION_GRANT_TYPE_CONVERTER.convert( - findObjectNode(authorizationRequestNode, "authorizationGrantType")); - - OAuth2AuthorizationRequest.Builder builder; - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) { - builder = OAuth2AuthorizationRequest.authorizationCode(); - } else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) { - builder = OAuth2AuthorizationRequest.implicit(); - } else { - throw new JsonParseException(parser, "Invalid authorizationGrantType"); - } - - return builder - .authorizationUri(findStringValue(authorizationRequestNode, "authorizationUri")) - .clientId(findStringValue(authorizationRequestNode, "clientId")) - .redirectUri(findStringValue(authorizationRequestNode, "redirectUri")) - .scopes(findValue(authorizationRequestNode, "scopes", SET_TYPE_REFERENCE, mapper)) - .state(findStringValue(authorizationRequestNode, "state")) - .additionalParameters(findValue(authorizationRequestNode, "additionalParameters", MAP_TYPE_REFERENCE, mapper)) - .authorizationRequestUri(findStringValue(authorizationRequestNode, "authorizationRequestUri")) - .attributes(findValue(authorizationRequestNode, "attributes", MAP_TYPE_REFERENCE, mapper)) - .build(); + JsonNode root = mapper.readTree(parser); + return deserialize(parser, mapper, root); } + + private OAuth2AuthorizationRequest deserialize(JsonParser parser, ObjectMapper mapper, JsonNode root) + throws JsonParseException { + AuthorizationGrantType authorizationGrantType = AUTHORIZATION_GRANT_TYPE_CONVERTER + .convert(JsonNodeUtils.findObjectNode(root, "authorizationGrantType")); + Builder builder = getBuilder(parser, authorizationGrantType); + builder.authorizationUri(JsonNodeUtils.findStringValue(root, "authorizationUri")); + builder.clientId(JsonNodeUtils.findStringValue(root, "clientId")); + builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri")); + builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, mapper)); + builder.state(JsonNodeUtils.findStringValue(root, "state")); + builder.additionalParameters( + JsonNodeUtils.findValue(root, "additionalParameters", JsonNodeUtils.STRING_OBJECT_MAP, mapper)); + builder.authorizationRequestUri(JsonNodeUtils.findStringValue(root, "authorizationRequestUri")); + builder.attributes(JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP, mapper)); + return builder.build(); + } + + private OAuth2AuthorizationRequest.Builder getBuilder(JsonParser parser, + AuthorizationGrantType authorizationGrantType) throws JsonParseException { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) { + return OAuth2AuthorizationRequest.authorizationCode(); + } + if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) { + return OAuth2AuthorizationRequest.implicit(); + } + throw new JsonParseException(parser, "Invalid authorizationGrantType"); + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixin.java index 4e85728a1f..49f7b40903 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixin.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; /** @@ -37,4 +39,5 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonIgnoreProperties(ignoreUnknown = true) abstract class OAuth2AuthorizationRequestMixin { + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixin.java index 3ca81104eb..03cd18f2e0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixin.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; @@ -20,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -40,10 +42,10 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; abstract class OAuth2AuthorizedClientMixin { @JsonCreator - OAuth2AuthorizedClientMixin( - @JsonProperty("clientRegistration") ClientRegistration clientRegistration, + OAuth2AuthorizedClientMixin(@JsonProperty("clientRegistration") ClientRegistration clientRegistration, @JsonProperty("principalName") String principalName, @JsonProperty("accessToken") OAuth2AccessToken accessToken, @JsonProperty("refreshToken") OAuth2RefreshToken refreshToken) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ClientJackson2Module.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ClientJackson2Module.java index 4aa70fc847..4c8158247a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ClientJackson2Module.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ClientJackson2Module.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Collections; + import com.fasterxml.jackson.core.Version; import com.fasterxml.jackson.databind.module.SimpleModule; + import org.springframework.security.jackson2.SecurityJackson2Modules; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; @@ -33,39 +37,38 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; -import java.util.Collections; - /** - * Jackson {@code Module} for {@code spring-security-oauth2-client}, - * that registers the following mix-in annotations: + * Jackson {@code Module} for {@code spring-security-oauth2-client}, that registers the + * following mix-in annotations: * *

          - *
        • {@link OAuth2AuthorizationRequestMixin}
        • - *
        • {@link ClientRegistrationMixin}
        • - *
        • {@link OAuth2AccessTokenMixin}
        • - *
        • {@link OAuth2RefreshTokenMixin}
        • - *
        • {@link OAuth2AuthorizedClientMixin}
        • - *
        • {@link OAuth2UserAuthorityMixin}
        • - *
        • {@link DefaultOAuth2UserMixin}
        • - *
        • {@link OidcIdTokenMixin}
        • - *
        • {@link OidcUserInfoMixin}
        • - *
        • {@link OidcUserAuthorityMixin}
        • - *
        • {@link DefaultOidcUserMixin}
        • - *
        • {@link OAuth2AuthenticationTokenMixin}
        • - *
        • {@link OAuth2AuthenticationExceptionMixin}
        • - *
        • {@link OAuth2ErrorMixin}
        • + *
        • {@link OAuth2AuthorizationRequestMixin}
        • + *
        • {@link ClientRegistrationMixin}
        • + *
        • {@link OAuth2AccessTokenMixin}
        • + *
        • {@link OAuth2RefreshTokenMixin}
        • + *
        • {@link OAuth2AuthorizedClientMixin}
        • + *
        • {@link OAuth2UserAuthorityMixin}
        • + *
        • {@link DefaultOAuth2UserMixin}
        • + *
        • {@link OidcIdTokenMixin}
        • + *
        • {@link OidcUserInfoMixin}
        • + *
        • {@link OidcUserAuthorityMixin}
        • + *
        • {@link DefaultOidcUserMixin}
        • + *
        • {@link OAuth2AuthenticationTokenMixin}
        • + *
        • {@link OAuth2AuthenticationExceptionMixin}
        • + *
        • {@link OAuth2ErrorMixin}
        • *
        * - * If not already enabled, default typing will be automatically enabled - * as type info is required to properly serialize/deserialize objects. - * In order to use this module just add it to your {@code ObjectMapper} configuration. + * If not already enabled, default typing will be automatically enabled as type info is + * required to properly serialize/deserialize objects. In order to use this module just + * add it to your {@code ObjectMapper} configuration. * *
          *     ObjectMapper mapper = new ObjectMapper();
          *     mapper.registerModule(new OAuth2ClientJackson2Module());
          * 
        * - * NOTE: Use {@link SecurityJackson2Modules#getModules(ClassLoader)} to get a list of all security modules. + * NOTE: Use {@link SecurityJackson2Modules#getModules(ClassLoader)} to get a list + * of all security modules. * * @author Joe Grandja * @since 5.3 @@ -94,7 +97,8 @@ public class OAuth2ClientJackson2Module extends SimpleModule { @Override public void setupModule(SetupContext context) { SecurityJackson2Modules.enableDefaultTyping(context.getOwner()); - context.setMixInAnnotations(Collections.unmodifiableMap(Collections.emptyMap()).getClass(), UnmodifiableMapMixin.class); + context.setMixInAnnotations(Collections.unmodifiableMap(Collections.emptyMap()).getClass(), + UnmodifiableMapMixin.class); context.setMixInAnnotations(OAuth2AuthorizationRequest.class, OAuth2AuthorizationRequestMixin.class); context.setMixInAnnotations(ClientRegistration.class, ClientRegistrationMixin.class); context.setMixInAnnotations(OAuth2AccessToken.class, OAuth2AccessTokenMixin.class); @@ -110,4 +114,5 @@ public class OAuth2ClientJackson2Module extends SimpleModule { context.setMixInAnnotations(OAuth2AuthenticationException.class, OAuth2AuthenticationExceptionMixin.class); context.setMixInAnnotations(OAuth2Error.class, OAuth2ErrorMixin.class); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ErrorMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ErrorMixin.java index aba0c5b61f..cc8cf872cd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ErrorMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2ErrorMixin.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; @@ -20,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.oauth2.core.OAuth2Error; /** @@ -34,14 +36,13 @@ import org.springframework.security.oauth2.core.OAuth2Error; */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, - isGetterVisibility = JsonAutoDetect.Visibility.NONE) + isGetterVisibility = JsonAutoDetect.Visibility.NONE) @JsonIgnoreProperties(ignoreUnknown = true) abstract class OAuth2ErrorMixin { @JsonCreator - OAuth2ErrorMixin( - @JsonProperty("errorCode") String errorCode, - @JsonProperty("description") String description, + OAuth2ErrorMixin(@JsonProperty("errorCode") String errorCode, @JsonProperty("description") String description, @JsonProperty("uri") String uri) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2RefreshTokenMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2RefreshTokenMixin.java index 191fb3e2d1..f017ca6feb 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2RefreshTokenMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2RefreshTokenMixin.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.time.Instant; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import java.time.Instant; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; /** * This mixin class is used to serialize/deserialize {@link OAuth2RefreshToken}. @@ -39,8 +41,7 @@ import java.time.Instant; abstract class OAuth2RefreshTokenMixin { @JsonCreator - OAuth2RefreshTokenMixin( - @JsonProperty("tokenValue") String tokenValue, - @JsonProperty("issuedAt") Instant issuedAt) { + OAuth2RefreshTokenMixin(@JsonProperty("tokenValue") String tokenValue, @JsonProperty("issuedAt") Instant issuedAt) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2UserAuthorityMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2UserAuthorityMixin.java index 02509a3575..e870fb087f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2UserAuthorityMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2UserAuthorityMixin.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Map; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; -import java.util.Map; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; /** * This mixin class is used to serialize/deserialize {@link OAuth2UserAuthority}. @@ -39,8 +41,8 @@ import java.util.Map; abstract class OAuth2UserAuthorityMixin { @JsonCreator - OAuth2UserAuthorityMixin( - @JsonProperty("authority") String authority, + OAuth2UserAuthorityMixin(@JsonProperty("authority") String authority, @JsonProperty("attributes") Map attributes) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcIdTokenMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcIdTokenMixin.java index d795e157b7..cdd0d0f592 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcIdTokenMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcIdTokenMixin.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.time.Instant; +import java.util.Map; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import java.time.Instant; -import java.util.Map; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; /** * This mixin class is used to serialize/deserialize {@link OidcIdToken}. @@ -40,10 +42,8 @@ import java.util.Map; abstract class OidcIdTokenMixin { @JsonCreator - OidcIdTokenMixin( - @JsonProperty("tokenValue") String tokenValue, - @JsonProperty("issuedAt") Instant issuedAt, - @JsonProperty("expiresAt") Instant expiresAt, - @JsonProperty("claims") Map claims) { + OidcIdTokenMixin(@JsonProperty("tokenValue") String tokenValue, @JsonProperty("issuedAt") Instant issuedAt, + @JsonProperty("expiresAt") Instant expiresAt, @JsonProperty("claims") Map claims) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserAuthorityMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserAuthorityMixin.java index 1ab4de1ac2..5e5e0eda0a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserAuthorityMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserAuthorityMixin.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; @@ -20,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; @@ -35,13 +37,12 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, isGetterVisibility = JsonAutoDetect.Visibility.NONE) -@JsonIgnoreProperties(value = {"attributes"}, ignoreUnknown = true) +@JsonIgnoreProperties(value = { "attributes" }, ignoreUnknown = true) abstract class OidcUserAuthorityMixin { @JsonCreator - OidcUserAuthorityMixin( - @JsonProperty("authority") String authority, - @JsonProperty("idToken") OidcIdToken idToken, + OidcUserAuthorityMixin(@JsonProperty("authority") String authority, @JsonProperty("idToken") OidcIdToken idToken, @JsonProperty("userInfo") OidcUserInfo userInfo) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserInfoMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserInfoMixin.java index 89b131fbdb..8dbdd4e4dc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserInfoMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OidcUserInfoMixin.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Map; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import org.springframework.security.oauth2.core.oidc.OidcUserInfo; -import java.util.Map; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; /** * This mixin class is used to serialize/deserialize {@link OidcUserInfo}. @@ -41,4 +43,5 @@ abstract class OidcUserInfoMixin { @JsonCreator OidcUserInfoMixin(@JsonProperty("claims") Map claims) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/StdConverters.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/StdConverters.java index 10510e5baf..9d5afd26c1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/StdConverters.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/StdConverters.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.util.StdConverter; + import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils.findStringValue; - /** * {@code StdConverter} implementations. * @@ -33,60 +33,76 @@ import static org.springframework.security.oauth2.client.jackson2.JsonNodeUtils. abstract class StdConverters { static final class AccessTokenTypeConverter extends StdConverter { + @Override public OAuth2AccessToken.TokenType convert(JsonNode jsonNode) { - String value = findStringValue(jsonNode, "value"); + String value = JsonNodeUtils.findStringValue(jsonNode, "value"); if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(value)) { return OAuth2AccessToken.TokenType.BEARER; } return null; } + } static final class ClientAuthenticationMethodConverter extends StdConverter { + @Override public ClientAuthenticationMethod convert(JsonNode jsonNode) { - String value = findStringValue(jsonNode, "value"); + String value = JsonNodeUtils.findStringValue(jsonNode, "value"); if (ClientAuthenticationMethod.BASIC.getValue().equalsIgnoreCase(value)) { return ClientAuthenticationMethod.BASIC; - } else if (ClientAuthenticationMethod.POST.getValue().equalsIgnoreCase(value)) { + } + if (ClientAuthenticationMethod.POST.getValue().equalsIgnoreCase(value)) { return ClientAuthenticationMethod.POST; - } else if (ClientAuthenticationMethod.NONE.getValue().equalsIgnoreCase(value)) { + } + if (ClientAuthenticationMethod.NONE.getValue().equalsIgnoreCase(value)) { return ClientAuthenticationMethod.NONE; } return null; } + } static final class AuthorizationGrantTypeConverter extends StdConverter { + @Override public AuthorizationGrantType convert(JsonNode jsonNode) { - String value = findStringValue(jsonNode, "value"); + String value = JsonNodeUtils.findStringValue(jsonNode, "value"); if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) { return AuthorizationGrantType.AUTHORIZATION_CODE; - } else if (AuthorizationGrantType.IMPLICIT.getValue().equalsIgnoreCase(value)) { + } + if (AuthorizationGrantType.IMPLICIT.getValue().equalsIgnoreCase(value)) { return AuthorizationGrantType.IMPLICIT; - } else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equalsIgnoreCase(value)) { + } + if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equalsIgnoreCase(value)) { return AuthorizationGrantType.CLIENT_CREDENTIALS; - } else if (AuthorizationGrantType.PASSWORD.getValue().equalsIgnoreCase(value)) { + } + if (AuthorizationGrantType.PASSWORD.getValue().equalsIgnoreCase(value)) { return AuthorizationGrantType.PASSWORD; } return null; } + } static final class AuthenticationMethodConverter extends StdConverter { + @Override public AuthenticationMethod convert(JsonNode jsonNode) { - String value = findStringValue(jsonNode, "value"); + String value = JsonNodeUtils.findStringValue(jsonNode, "value"); if (AuthenticationMethod.HEADER.getValue().equalsIgnoreCase(value)) { return AuthenticationMethod.HEADER; - } else if (AuthenticationMethod.FORM.getValue().equalsIgnoreCase(value)) { + } + if (AuthenticationMethod.FORM.getValue().equalsIgnoreCase(value)) { return AuthenticationMethod.FORM; - } else if (AuthenticationMethod.QUERY.getValue().equalsIgnoreCase(value)) { + } + if (AuthenticationMethod.QUERY.getValue().equalsIgnoreCase(value)) { return AuthenticationMethod.QUERY; } return null; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapDeserializer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapDeserializer.java index e7f97a1b87..d44b9c7278 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapDeserializer.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapDeserializer.java @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.IOException; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; - /** * A {@code JsonDeserializer} for {@link Collections#unmodifiableMap(Map)}. * @@ -49,4 +50,5 @@ final class UnmodifiableMapDeserializer extends JsonDeserializer> { } return Collections.unmodifiableMap(result); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapMixin.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapMixin.java index 18753d4154..95941cc7e0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapMixin.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/UnmodifiableMapMixin.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.Collections; +import java.util.Map; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import java.util.Collections; -import java.util.Map; - /** - * This mixin class is used to serialize/deserialize {@link Collections#unmodifiableMap(Map)}. - * It also registers a custom deserializer {@link UnmodifiableMapDeserializer}. + * This mixin class is used to serialize/deserialize + * {@link Collections#unmodifiableMap(Map)}. It also registers a custom deserializer + * {@link UnmodifiableMapDeserializer}. * * @author Joe Grandja * @since 5.3 @@ -39,4 +41,5 @@ abstract class UnmodifiableMapMixin { @JsonCreator UnmodifiableMapMixin(Map map) { } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/DefaultOidcIdTokenValidatorFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/DefaultOidcIdTokenValidatorFactory.java index 70c5a6749c..c1eb285814 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/DefaultOidcIdTokenValidatorFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/DefaultOidcIdTokenValidatorFactory.java @@ -13,18 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; +import java.util.function.Function; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtTimestampValidator; -import java.util.function.Function; - /** - * * @author Joe Grandja * @since 5.2 */ @@ -32,7 +32,8 @@ class DefaultOidcIdTokenValidatorFactory implements Function apply(ClientRegistration clientRegistration) { - return new DelegatingOAuth2TokenValidator<>( - new JwtTimestampValidator(), new OidcIdTokenValidator(clientRegistration)); + return new DelegatingOAuth2TokenValidator<>(new JwtTimestampValidator(), + new OidcIdTokenValidator(clientRegistration)); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java index 9246502d74..9c956fcfcc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Collection; +import java.util.Map; + import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -43,26 +51,19 @@ import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.util.Assert; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Base64; -import java.util.Collection; -import java.util.Map; - /** - * An implementation of an {@link AuthenticationProvider} - * for the OpenID Connect Core 1.0 Authorization Code Grant Flow. + * An implementation of an {@link AuthenticationProvider} for the OpenID Connect Core 1.0 + * Authorization Code Grant Flow. *

        - * This {@link AuthenticationProvider} is responsible for authenticating - * an Authorization Code credential with the Authorization Server's Token Endpoint - * and if valid, exchanging it for an Access Token credential. + * This {@link AuthenticationProvider} is responsible for authenticating an Authorization + * Code credential with the Authorization Server's Token Endpoint and if valid, exchanging + * it for an Access Token credential. *

        - * It will also obtain the user attributes of the End-User (Resource Owner) - * from the UserInfo Endpoint using an {@link OAuth2UserService}, - * which will create a {@code Principal} in the form of an {@link OidcUser}. - * The {@code OidcUser} is then associated to the {@link OAuth2LoginAuthenticationToken} - * to complete the authentication. + * It will also obtain the user attributes of the End-User (Resource Owner) from the + * UserInfo Endpoint using an {@link OAuth2UserService}, which will create a + * {@code Principal} in the form of an {@link OidcUser}. The {@code OidcUser} is then + * associated to the {@link OAuth2LoginAuthenticationToken} to complete the + * authentication. * * @author Joe Grandja * @author Mark Heckler @@ -72,29 +73,43 @@ import java.util.Map; * @see OidcUserService * @see OidcUser * @see OidcIdTokenDecoderFactory - * @see Section 3.1 Authorization Code Grant Flow - * @see Section 3.1.3.1 Token Request - * @see Section 3.1.3.3 Token Response + * @see Section 3.1 + * Authorization Code Grant Flow + * @see Section 3.1.3.1 + * Token Request + * @see Section 3.1.3.3 + * Token Response */ public class OidcAuthorizationCodeAuthenticationProvider implements AuthenticationProvider { + private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce"; + private final OAuth2AccessTokenResponseClient accessTokenResponseClient; + private final OAuth2UserService userService; + private JwtDecoderFactory jwtDecoderFactory = new OidcIdTokenDecoderFactory(); - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); + + private GrantedAuthoritiesMapper authoritiesMapper = ((authorities) -> authorities); /** - * Constructs an {@code OidcAuthorizationCodeAuthenticationProvider} using the provided parameters. - * - * @param accessTokenResponseClient the client used for requesting the access token credential from the Token Endpoint - * @param userService the service used for obtaining the user attributes of the End-User from the UserInfo Endpoint + * Constructs an {@code OidcAuthorizationCodeAuthenticationProvider} using the + * provided parameters. + * @param accessTokenResponseClient the client used for requesting the access token + * credential from the Token Endpoint + * @param userService the service used for obtaining the user attributes of the + * End-User from the UserInfo Endpoint */ public OidcAuthorizationCodeAuthenticationProvider( - OAuth2AccessTokenResponseClient accessTokenResponseClient, - OAuth2UserService userService) { - + OAuth2AccessTokenResponseClient accessTokenResponseClient, + OAuth2UserService userService) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(userService, "userService cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; @@ -103,97 +118,95 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2LoginAuthenticationToken authorizationCodeAuthentication = - (OAuth2LoginAuthenticationToken) authentication; - - // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + OAuth2LoginAuthenticationToken authorizationCodeAuthentication = (OAuth2LoginAuthenticationToken) authentication; + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // scope - // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (!authorizationCodeAuthentication.getAuthorizationExchange() - .getAuthorizationRequest().getScopes().contains(OidcScopes.OPENID)) { + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. + if (!authorizationCodeAuthentication.getAuthorizationExchange().getAuthorizationRequest().getScopes() + .contains(OidcScopes.OPENID)) { // This is NOT an OpenID Connect Authentication Request so return null // and let OAuth2LoginAuthenticationProvider handle it instead return null; } - - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationRequest(); - OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationResponse(); - + OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() + .getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication.getAuthorizationExchange() + .getAuthorizationResponse(); if (authorizationResponse.statusError()) { - throw new OAuth2AuthenticationException( - authorizationResponse.getError(), authorizationResponse.getError().toString()); + throw new OAuth2AuthenticationException(authorizationResponse.getError(), + authorizationResponse.getError().toString()); } - if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - - OAuth2AccessTokenResponse accessTokenResponse; - try { - accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange())); - } catch (OAuth2AuthorizationException ex) { - OAuth2Error oauth2Error = ex.getError(); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - + OAuth2AccessTokenResponse accessTokenResponse = getResponse(authorizationCodeAuthentication); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); - Map additionalParameters = accessTokenResponse.getAdditionalParameters(); if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) { - OAuth2Error invalidIdTokenError = new OAuth2Error( - INVALID_ID_TOKEN_ERROR_CODE, - "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), - null); + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, + "Missing (required) ID Token in Token Response for Client Registration: " + + clientRegistration.getRegistrationId(), + null); throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); } OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse); - - // Validate nonce - String requestNonce = authorizationRequest.getAttribute(OidcParameterNames.NONCE); - if (requestNonce != null) { - String nonceHash; - try { - nonceHash = createHash(requestNonce); - } catch (NoSuchAlgorithmException e) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - String nonceHashClaim = idToken.getNonce(); - if (nonceHashClaim == null || !nonceHashClaim.equals(nonceHash)) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - } - - OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest( - clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters)); - Collection mappedAuthorities = - this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities()); - + validateNonce(authorizationRequest, idToken); + OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(clientRegistration, + accessTokenResponse.getAccessToken(), idToken, additionalParameters)); + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oidcUser.getAuthorities()); OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( - authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange(), - oidcUser, - mappedAuthorities, - accessTokenResponse.getAccessToken(), - accessTokenResponse.getRefreshToken()); + authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange(), oidcUser, mappedAuthorities, + accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()); authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); - return authenticationResult; } + private OAuth2AccessTokenResponse getResponse(OAuth2LoginAuthenticationToken authorizationCodeAuthentication) { + try { + return this.accessTokenResponseClient.getTokenResponse( + new OAuth2AuthorizationCodeGrantRequest(authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange())); + } + catch (OAuth2AuthorizationException ex) { + OAuth2Error oauth2Error = ex.getError(); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateNonce(OAuth2AuthorizationRequest authorizationRequest, OidcIdToken idToken) { + String requestNonce = authorizationRequest.getAttribute(OidcParameterNames.NONCE); + if (requestNonce == null) { + return; + } + String nonceHash = getNonceHash(requestNonce); + String nonceHashClaim = idToken.getNonce(); + if (nonceHashClaim == null || !nonceHashClaim.equals(nonceHash)) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private String getNonceHash(String requestNonce) { + try { + return createHash(requestNonce); + } + catch (NoSuchAlgorithmException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + /** - * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature verification. - * The factory returns a {@link JwtDecoder} associated to the provided {@link ClientRegistration}. - * + * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature + * verification. The factory returns a {@link JwtDecoder} associated to the provided + * {@link ClientRegistration}. + * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} + * signature verification * @since 5.2 - * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature verification */ public final void setJwtDecoderFactory(JwtDecoderFactory jwtDecoderFactory) { Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); @@ -201,10 +214,11 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati } /** - * Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OidcUser#getAuthorities()}} - * to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}. - * - * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OidcUser#getAuthorities()}} to a new set of authorities which will be + * associated to the {@link OAuth2LoginAuthenticationToken}. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the + * user's authorities */ public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); @@ -216,17 +230,24 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati return OAuth2LoginAuthenticationToken.class.isAssignableFrom(authentication); } - private OidcIdToken createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) { + private OidcIdToken createOidcToken(ClientRegistration clientRegistration, + OAuth2AccessTokenResponse accessTokenResponse) { JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); - Jwt jwt; + Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); + OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), + jwt.getClaims()); + return idToken; + } + + private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) { try { - jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); - } catch (JwtException ex) { + Map parameters = accessTokenResponse.getAdditionalParameters(); + return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN)); + } + catch (JwtException ex) { OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); } - OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()); - return idToken; } static String createHash(String nonce) throws NoSuchAlgorithmException { @@ -234,4 +255,5 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati byte[] digest = md.digest(nonce.getBytes(StandardCharsets.US_ASCII)); return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index a413dca145..4d536afb5a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Collection; +import java.util.Map; + +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; @@ -41,28 +51,22 @@ import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; - -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Base64; -import java.util.Collection; -import java.util.Map; /** - * An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login, - * which leverages the OAuth 2.0 Authorization Code Grant Flow. + * An implementation of an + * {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth + * 2.0 Login, which leverages the OAuth 2.0 Authorization Code Grant Flow. *

        - * This {@link org.springframework.security.authentication.AuthenticationProvider} is responsible for authenticating - * an Authorization Code credential with the Authorization Server's Token Endpoint - * and if valid, exchanging it for an Access Token credential. + * This {@link org.springframework.security.authentication.AuthenticationProvider} is + * responsible for authenticating an Authorization Code credential with the Authorization + * Server's Token Endpoint and if valid, exchanging it for an Access Token credential. *

        - * It will also obtain the user attributes of the End-User (Resource Owner) - * from the UserInfo Endpoint using an {@link org.springframework.security.oauth2.client.userinfo.OAuth2UserService}, - * which will create a {@code Principal} in the form of an {@link OAuth2User}. - * The {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} - * to complete the authentication. + * It will also obtain the user attributes of the End-User (Resource Owner) from the + * UserInfo Endpoint using an + * {@link org.springframework.security.oauth2.client.userinfo.OAuth2UserService}, which + * will create a {@code Principal} in the form of an {@link OAuth2User}. The + * {@code OAuth2User} is then associated to the {@link OAuth2LoginAuthenticationToken} to + * complete the authentication. * * @author Rob Winch * @author Mark Heckler @@ -72,22 +76,28 @@ import java.util.Map; * @see ReactiveOAuth2UserService * @see OAuth2User * @see ReactiveOidcIdTokenDecoderFactory - * @see Section 4.1 Authorization Code Grant Flow - * @see Section 4.1.3 Access Token Request - * @see Section 4.1.4 Access Token Response + * @see Section + * 4.1 Authorization Code Grant Flow + * @see Section 4.1.3 Access Token + * Request + * @see Section 4.1.4 Access Token + * Response */ -public class OidcAuthorizationCodeReactiveAuthenticationManager implements - ReactiveAuthenticationManager { +public class OidcAuthorizationCodeReactiveAuthenticationManager implements ReactiveAuthenticationManager { private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter"; + private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token"; + private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce"; private final ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; private final ReactiveOAuth2UserService userService; - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); + private GrantedAuthoritiesMapper authoritiesMapper = ((authorities) -> authorities); private ReactiveJwtDecoderFactory jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(); @@ -104,53 +114,51 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements public Mono authenticate(Authentication authentication) { return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - - // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - if (!authorizationCodeAuthentication.getAuthorizationExchange() - .getAuthorizationRequest().getScopes().contains("openid")) { + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope + // value. + if (!authorizationCodeAuthentication.getAuthorizationExchange().getAuthorizationRequest().getScopes() + .contains("openid")) { // This is an OpenID Connect Authentication Request so return empty // and let OAuth2LoginReactiveAuthenticationManager handle it instead return Mono.empty(); } - - - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication - .getAuthorizationExchange().getAuthorizationRequest(); + OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() + .getAuthorizationRequest(); OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication .getAuthorizationExchange().getAuthorizationResponse(); - if (authorizationResponse.statusError()) { - return Mono.error(new OAuth2AuthenticationException( - authorizationResponse.getError(), authorizationResponse.getError().toString())); + return Mono.error(new OAuth2AuthenticationException(authorizationResponse.getError(), + authorizationResponse.getError().toString())); } - if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); - return Mono.error(new OAuth2AuthenticationException( - oauth2Error, oauth2Error.toString())); + return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString())); } - OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange()); - - return this.accessTokenResponseClient.getTokenResponse(authzRequest) - .flatMap(accessTokenResponse -> authenticationResult(authorizationCodeAuthentication, accessTokenResponse)) - .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) - .onErrorMap(JwtException.class, e -> { - OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null); - return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e); + return this.accessTokenResponseClient.getTokenResponse(authzRequest).flatMap( + (accessTokenResponse) -> authenticationResult(authorizationCodeAuthentication, accessTokenResponse)) + .onErrorMap(OAuth2AuthorizationException.class, + (e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) + .onErrorMap(JwtException.class, (e) -> { + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), + null); + return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), + e); }); }); } /** - * Sets the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature verification. - * The factory returns a {@link ReactiveJwtDecoder} associated to the provided {@link ClientRegistration}. - * + * Sets the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature + * verification. The factory returns a {@link ReactiveJwtDecoder} associated to the + * provided {@link ClientRegistration}. + * @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} used for + * {@link OidcIdToken} signature verification * @since 5.2 - * @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature verification */ public final void setJwtDecoderFactory(ReactiveJwtDecoderFactory jwtDecoderFactory) { Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null"); @@ -158,66 +166,64 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements } /** - * Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OidcUser#getAuthorities()} - * to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}. - * + * Sets the {@link GrantedAuthoritiesMapper} used for mapping + * {@link OidcUser#getAuthorities()} to a new set of authorities which will be + * associated to the {@link OAuth2LoginAuthenticationToken}. + * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the + * user's authorities * @since 5.4 - * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities */ public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null"); this.authoritiesMapper = authoritiesMapper; } - private Mono authenticationResult(OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { + private Mono authenticationResult( + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, + OAuth2AccessTokenResponse accessTokenResponse) { OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); Map additionalParameters = accessTokenResponse.getAdditionalParameters(); - if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) { - OAuth2Error invalidIdTokenError = new OAuth2Error( - INVALID_ID_TOKEN_ERROR_CODE, - "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), + OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, + "Missing (required) ID Token in Token Response for Client Registration: " + + clientRegistration.getRegistrationId(), null); return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString())); } - + // @formatter:off return createOidcToken(clientRegistration, accessTokenResponse) - .doOnNext(idToken -> validateNonce(authorizationCodeAuthentication, idToken)) - .map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters)) + .doOnNext((idToken) -> validateNonce(authorizationCodeAuthentication, idToken)) + .map((idToken) -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters)) .flatMap(this.userService::loadUser) - .map(oauth2User -> { - Collection mappedAuthorities = - this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); - - return new OAuth2LoginAuthenticationToken( - authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange(), - oauth2User, - mappedAuthorities, - accessToken, - accessTokenResponse.getRefreshToken()); + .map((oauth2User) -> { + Collection mappedAuthorities = this.authoritiesMapper + .mapAuthorities(oauth2User.getAuthorities()); + return new OAuth2LoginAuthenticationToken(authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange(), oauth2User, mappedAuthorities, + accessToken, accessTokenResponse.getRefreshToken()); }); + // @formatter:on } - private Mono createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) { + private Mono createOidcToken(ClientRegistration clientRegistration, + OAuth2AccessTokenResponse accessTokenResponse) { ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); String rawIdToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN); + // @formatter:off return jwtDecoder.decode(rawIdToken) - .map(jwt -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims())); + .map((jwt) -> + new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()) + ); + // @formatter:on } - private static Mono validateNonce(OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OidcIdToken idToken) { - String requestNonce = authorizationCodeAuthentication.getAuthorizationExchange() - .getAuthorizationRequest().getAttribute(OidcParameterNames.NONCE); + private static Mono validateNonce( + OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OidcIdToken idToken) { + String requestNonce = authorizationCodeAuthentication.getAuthorizationExchange().getAuthorizationRequest() + .getAttribute(OidcParameterNames.NONCE); if (requestNonce != null) { - String nonceHash; - try { - nonceHash = createHash(requestNonce); - } catch (NoSuchAlgorithmException e) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } + String nonceHash = getNonceHash(requestNonce); String nonceHashClaim = idToken.getNonce(); if (nonceHashClaim == null || !nonceHashClaim.equals(nonceHash)) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); @@ -228,9 +234,20 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements return Mono.just(idToken); } + private static String getNonceHash(String requestNonce) { + try { + return createHash(requestNonce); + } + catch (NoSuchAlgorithmException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + static String createHash(String nonce) throws NoSuchAlgorithmException { MessageDigest md = MessageDigest.getInstance("SHA-256"); byte[] digest = md.digest(nonce.getBytes(StandardCharsets.US_ASCII)); return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index da51f71d62..a349ddbe0e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; import java.net.URL; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; + import javax.crypto.spec.SecretKeySpec; import org.springframework.core.convert.TypeDescriptor; @@ -47,13 +50,10 @@ import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey; - /** - * A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} - * used for {@link OidcIdToken} signature verification. - * The provided {@link JwtDecoder} is associated to a specific {@link ClientRegistration}. + * A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for + * {@link OidcIdToken} signature verification. The provided {@link JwtDecoder} is + * associated to a specific {@link ClientRegistration}. * * @author Joe Grandja * @author Rafael Dominguez @@ -64,26 +64,36 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecre * @see OidcIdToken */ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory { + private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier"; - private static Map jcaAlgorithmMappings = new HashMap() { - { - put(MacAlgorithm.HS256, "HmacSHA256"); - put(MacAlgorithm.HS384, "HmacSHA384"); - put(MacAlgorithm.HS512, "HmacSHA512"); - } + + private static final Map JCA_ALGORITHM_MAPPINGS; + static { + Map mappings = new HashMap<>(); + mappings.put(MacAlgorithm.HS256, "HmacSHA256"); + mappings.put(MacAlgorithm.HS384, "HmacSHA384"); + mappings.put(MacAlgorithm.HS512, "HmacSHA512"); + JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); }; - private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = - new ClaimTypeConverter(createDefaultClaimTypeConverters()); + + private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( + createDefaultClaimTypeConverters()); + private final Map jwtDecoders = new ConcurrentHashMap<>(); + private Function> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); - private Function jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256; - private Function, Map>> claimTypeConverterFactory = - clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; + + private Function jwsAlgorithmResolver = ( + clientRegistration) -> SignatureAlgorithm.RS256; + + private Function, Map>> claimTypeConverterFactory = ( + clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; /** - * Returns the default {@link Converter}'s used for type conversion of claim values for an {@link OidcIdToken}. - * - * @return a {@link Map} of {@link Converter}'s keyed by {@link IdTokenClaimNames claim name} + * Returns the default {@link Converter}'s used for type conversion of claim values + * for an {@link OidcIdToken}. + * @return a {@link Map} of {@link Converter}'s keyed by {@link IdTokenClaimNames + * claim name} */ public static Map> createDefaultClaimTypeConverters() { Converter booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); @@ -92,34 +102,34 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory stringConverter = getConverter(TypeDescriptor.valueOf(String.class)); Converter collectionStringConverter = getConverter( TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class))); - - Map> claimTypeConverters = new HashMap<>(); - claimTypeConverters.put(IdTokenClaimNames.ISS, urlConverter); - claimTypeConverters.put(IdTokenClaimNames.AUD, collectionStringConverter); - claimTypeConverters.put(IdTokenClaimNames.NONCE, stringConverter); - claimTypeConverters.put(IdTokenClaimNames.EXP, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.IAT, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AMR, collectionStringConverter); - claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.UPDATED_AT, instantConverter); - return claimTypeConverters; + Map> converters = new HashMap<>(); + converters.put(IdTokenClaimNames.ISS, urlConverter); + converters.put(IdTokenClaimNames.AUD, collectionStringConverter); + converters.put(IdTokenClaimNames.NONCE, stringConverter); + converters.put(IdTokenClaimNames.EXP, instantConverter); + converters.put(IdTokenClaimNames.IAT, instantConverter); + converters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); + converters.put(IdTokenClaimNames.AMR, collectionStringConverter); + converters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.UPDATED_AT, instantConverter); + return converters; } private static Converter getConverter(TypeDescriptor targetDescriptor) { - final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - return source -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); + TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); + return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, + targetDescriptor); } @Override public JwtDecoder createDecoder(ClientRegistration clientRegistration) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> { + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration); jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); - Converter, Map> claimTypeConverter = - this.claimTypeConverterFactory.apply(clientRegistration); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); if (claimTypeConverter != null) { jwtDecoder.setClaimSetConverter(claimTypeConverter); } @@ -134,68 +144,66 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory> jwtValidatorFactory) { Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null"); @@ -204,11 +212,11 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory jwsAlgorithmResolver) { Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null"); @@ -216,14 +224,17 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory, Map>> claimTypeConverterFactory) { + public void setClaimTypeConverterFactory( + Function, Map>> claimTypeConverterFactory) { Assert.notNull(claimTypeConverterFactory, "claimTypeConverterFactory cannot be null"); this.claimTypeConverterFactory = claimTypeConverterFactory; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java index 03ae739a7b..2d53babd6d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; +import java.net.URL; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; @@ -26,30 +36,27 @@ import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.net.URL; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; - /** - * An {@link OAuth2TokenValidator} responsible for - * validating the claims in an {@link OidcIdToken ID Token}. + * An {@link OAuth2TokenValidator} responsible for validating the claims in an + * {@link OidcIdToken ID Token}. * * @author Rob Winch * @author Joe Grandja * @since 5.1 * @see OAuth2TokenValidator * @see Jwt - * @see ID Token Validation + * @see ID Token + * Validation */ public final class OidcIdTokenValidator implements OAuth2TokenValidator { + private static final Duration DEFAULT_CLOCK_SKEW = Duration.ofSeconds(60); + private final ClientRegistration clientRegistration; + private Duration clockSkew = DEFAULT_CLOCK_SKEW; + private Clock clock = Clock.systemUTC(); public OidcIdTokenValidator(ClientRegistration clientRegistration) { @@ -59,75 +66,68 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { @Override public OAuth2TokenValidatorResult validate(Jwt idToken) { - // 3.1.3.7 ID Token Validation + // 3.1.3.7 ID Token Validation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation - Map invalidClaims = validateRequiredClaims(idToken); if (!invalidClaims.isEmpty()) { return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims)); } - - // 2. The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery) + // 2. The Issuer Identifier for the OpenID Provider (which is typically obtained + // during Discovery) // MUST exactly match the value of the iss (issuer) Claim. String metadataIssuer = this.clientRegistration.getProviderDetails().getIssuerUri(); - if (metadataIssuer != null && !Objects.equals(metadataIssuer, idToken.getIssuer().toExternalForm())) { invalidClaims.put(IdTokenClaimNames.ISS, idToken.getIssuer()); } - - // 3. The Client MUST validate that the aud (audience) Claim contains its client_id value + // 3. The Client MUST validate that the aud (audience) Claim contains its + // client_id value // registered at the Issuer identified by the iss (issuer) Claim as an audience. // The aud (audience) Claim MAY contain an array with more than one element. - // The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience, + // The ID Token MUST be rejected if the ID Token does not list the Client as a + // valid audience, // or if it contains additional audiences not trusted by the Client. if (!idToken.getAudience().contains(this.clientRegistration.getClientId())) { invalidClaims.put(IdTokenClaimNames.AUD, idToken.getAudience()); } - // 4. If the ID Token contains multiple audiences, // the Client SHOULD verify that an azp Claim is present. String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP); if (idToken.getAudience().size() > 1 && authorizedParty == null) { invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty); } - // 5. If an azp (authorized party) Claim is present, // the Client SHOULD verify that its client_id is the Claim Value. if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) { invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty); } - - // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client + // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the + // Client // in the id_token_signed_response_alg parameter during Registration. // TODO Depends on gh-4413 - // 9. The current time MUST be before the time represented by the exp Claim. Instant now = Instant.now(this.clock); if (now.minus(this.clockSkew).isAfter(idToken.getExpiresAt())) { invalidClaims.put(IdTokenClaimNames.EXP, idToken.getExpiresAt()); } - - // 10. The iat Claim can be used to reject tokens that were issued too far away from the current time, + // 10. The iat Claim can be used to reject tokens that were issued too far away + // from the current time, // limiting the amount of time that nonces need to be stored to prevent attacks. // The acceptable range is Client specific. if (now.plus(this.clockSkew).isBefore(idToken.getIssuedAt())) { invalidClaims.put(IdTokenClaimNames.IAT, idToken.getIssuedAt()); } - if (!invalidClaims.isEmpty()) { return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims)); } - return OAuth2TokenValidatorResult.success(); } /** - * Sets the maximum acceptable clock skew. The default is 60 seconds. - * The clock skew is used when validating the {@link JwtClaimNames#EXP exp} - * and {@link JwtClaimNames#IAT iat} claims. - * - * @since 5.2 + * Sets the maximum acceptable clock skew. The default is 60 seconds. The clock skew + * is used when validating the {@link JwtClaimNames#EXP exp} and + * {@link JwtClaimNames#IAT iat} claims. * @param clockSkew the maximum acceptable clock skew + * @since 5.2 */ public void setClockSkew(Duration clockSkew) { Assert.notNull(clockSkew, "clockSkew cannot be null"); @@ -136,12 +136,10 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { } /** - * Sets the {@link Clock} used in {@link Instant#now(Clock)} - * when validating the {@link JwtClaimNames#EXP exp} - * and {@link JwtClaimNames#IAT iat} claims. - * - * @since 5.3 + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when validating the + * {@link JwtClaimNames#EXP exp} and {@link JwtClaimNames#IAT iat} claims. * @param clock the clock + * @since 5.3 */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); @@ -149,14 +147,12 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { } private static OAuth2Error invalidIdToken(Map invalidClaims) { - return new OAuth2Error("invalid_id_token", - "The ID Token contains invalid claims: " + invalidClaims, + return new OAuth2Error("invalid_id_token", "The ID Token contains invalid claims: " + invalidClaims, "https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation"); } private static Map validateRequiredClaims(Jwt idToken) { Map requiredClaims = new HashMap<>(); - URL issuer = idToken.getIssuer(); if (issuer == null) { requiredClaims.put(IdTokenClaimNames.ISS, issuer); @@ -177,7 +173,7 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { if (issuedAt == null) { requiredClaims.put(IdTokenClaimNames.IAT, issuedAt); } - return requiredClaims; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java index 6661efb7ed..ec4b0bcbfa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; import java.net.URL; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; + import javax.crypto.spec.SecretKeySpec; import org.springframework.core.convert.TypeDescriptor; @@ -47,13 +50,10 @@ import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withSecretKey; - /** * A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder} - * used for {@link OidcIdToken} signature verification. - * The provided {@link ReactiveJwtDecoder} is associated to a specific {@link ClientRegistration}. + * used for {@link OidcIdToken} signature verification. The provided + * {@link ReactiveJwtDecoder} is associated to a specific {@link ClientRegistration}. * * @author Joe Grandja * @author Rafael Dominguez @@ -64,26 +64,36 @@ import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.w * @see OidcIdToken */ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecoderFactory { + private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier"; - private static Map jcaAlgorithmMappings = new HashMap() { - { - put(MacAlgorithm.HS256, "HmacSHA256"); - put(MacAlgorithm.HS384, "HmacSHA384"); - put(MacAlgorithm.HS512, "HmacSHA512"); - } - }; - private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = - new ClaimTypeConverter(createDefaultClaimTypeConverters()); + + private static final Map JCA_ALGORITHM_MAPPINGS; + static { + Map mappings = new HashMap(); + mappings.put(MacAlgorithm.HS256, "HmacSHA256"); + mappings.put(MacAlgorithm.HS384, "HmacSHA384"); + mappings.put(MacAlgorithm.HS512, "HmacSHA512"); + JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); + } + + private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( + createDefaultClaimTypeConverters()); + private final Map jwtDecoders = new ConcurrentHashMap<>(); + private Function> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); - private Function jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256; - private Function, Map>> claimTypeConverterFactory = - clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; + + private Function jwsAlgorithmResolver = ( + clientRegistration) -> SignatureAlgorithm.RS256; + + private Function, Map>> claimTypeConverterFactory = ( + clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; /** - * Returns the default {@link Converter}'s used for type conversion of claim values for an {@link OidcIdToken}. - * - * @return a {@link Map} of {@link Converter}'s keyed by {@link IdTokenClaimNames claim name} + * Returns the default {@link Converter}'s used for type conversion of claim values + * for an {@link OidcIdToken}. + * @return a {@link Map} of {@link Converter}'s keyed by {@link IdTokenClaimNames + * claim name} */ public static Map> createDefaultClaimTypeConverters() { Converter booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); @@ -92,34 +102,34 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod Converter stringConverter = getConverter(TypeDescriptor.valueOf(String.class)); Converter collectionStringConverter = getConverter( TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class))); - - Map> claimTypeConverters = new HashMap<>(); - claimTypeConverters.put(IdTokenClaimNames.ISS, urlConverter); - claimTypeConverters.put(IdTokenClaimNames.AUD, collectionStringConverter); - claimTypeConverters.put(IdTokenClaimNames.NONCE, stringConverter); - claimTypeConverters.put(IdTokenClaimNames.EXP, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.IAT, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AMR, collectionStringConverter); - claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.UPDATED_AT, instantConverter); - return claimTypeConverters; + Map> converters = new HashMap<>(); + converters.put(IdTokenClaimNames.ISS, urlConverter); + converters.put(IdTokenClaimNames.AUD, collectionStringConverter); + converters.put(IdTokenClaimNames.NONCE, stringConverter); + converters.put(IdTokenClaimNames.EXP, instantConverter); + converters.put(IdTokenClaimNames.IAT, instantConverter); + converters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); + converters.put(IdTokenClaimNames.AMR, collectionStringConverter); + converters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.UPDATED_AT, instantConverter); + return converters; } private static Converter getConverter(TypeDescriptor targetDescriptor) { final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - return source -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); + return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, + targetDescriptor); } @Override public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> { + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration); jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); - Converter, Map> claimTypeConverter = - this.claimTypeConverterFactory.apply(clientRegistration); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); if (claimTypeConverter != null) { jwtDecoder.setClaimSetConverter(claimTypeConverter); } @@ -134,68 +144,68 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod // // 6. If the ID Token is received via direct communication between the Client // and the Token Endpoint (which it is in this flow), - // the TLS server validation MAY be used to validate the issuer in place of checking the token signature. - // The Client MUST validate the signature of all other ID Tokens according to JWS [JWS] + // the TLS server validation MAY be used to validate the issuer in place of + // checking the token signature. + // The Client MUST validate the signature of all other ID Tokens according to + // JWS [JWS] // using the algorithm specified in the JWT alg Header Parameter. // The Client MUST use the keys provided by the Issuer. // - // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client + // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by + // the Client // in the id_token_signed_response_alg parameter during Registration. - String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); if (!StringUtils.hasText(jwkSetUri)) { - OAuth2Error oauth2Error = new OAuth2Error( - MISSING_SIGNATURE_VERIFIER_ERROR_CODE, - "Failed to find a Signature Verifier for Client Registration: '" + - clientRegistration.getRegistrationId() + - "'. Check to ensure you have configured the JwkSet URI.", - null - ); + OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, + "Failed to find a Signature Verifier for Client Registration: '" + + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured the JwkSet URI.", + null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - return withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); - } else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { + return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) + .build(); + } + if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // - // 8. If the JWT alg Header Parameter uses a MAC based algorithm such as HS256, HS384, or HS512, + // 8. If the JWT alg Header Parameter uses a MAC based algorithm such as + // HS256, HS384, or HS512, // the octets of the UTF-8 representation of the client_secret // corresponding to the client_id contained in the aud (audience) Claim // are used as the key to validate the signature. - // For MAC based algorithms, the behavior is unspecified if the aud is multi-valued or + // For MAC based algorithms, the behavior is unspecified if the aud is + // multi-valued or // if an azp value is present that is different than the aud value. String clientSecret = clientRegistration.getClientSecret(); if (!StringUtils.hasText(clientSecret)) { - OAuth2Error oauth2Error = new OAuth2Error( - MISSING_SIGNATURE_VERIFIER_ERROR_CODE, - "Failed to find a Signature Verifier for Client Registration: '" + - clientRegistration.getRegistrationId() + - "'. Check to ensure you have configured the client secret.", - null - ); + OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, + "Failed to find a Signature Verifier for Client Registration: '" + + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured the client secret.", + null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - SecretKeySpec secretKeySpec = new SecretKeySpec( - clientSecret.getBytes(StandardCharsets.UTF_8), jcaAlgorithmMappings.get(jwsAlgorithm)); - return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build(); + SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), + JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm)); + return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm) + .build(); } - - OAuth2Error oauth2Error = new OAuth2Error( - MISSING_SIGNATURE_VERIFIER_ERROR_CODE, - "Failed to find a Signature Verifier for Client Registration: '" + - clientRegistration.getRegistrationId() + - "'. Check to ensure you have configured a valid JWS Algorithm: '" + - jwsAlgorithm + "'", - null - ); + OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, + "Failed to find a Signature Verifier for Client Registration: '" + + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'", + null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } /** - * Sets the factory that provides an {@link OAuth2TokenValidator}, which is used by the {@link ReactiveJwtDecoder}. - * The default composes {@link JwtTimestampValidator} and {@link OidcIdTokenValidator}. - * - * @param jwtValidatorFactory the factory that provides an {@link OAuth2TokenValidator} + * Sets the factory that provides an {@link OAuth2TokenValidator}, which is used by + * the {@link ReactiveJwtDecoder}. The default composes {@link JwtTimestampValidator} + * and {@link OidcIdTokenValidator}. + * @param jwtValidatorFactory the factory that provides an + * {@link OAuth2TokenValidator} */ public void setJwtValidatorFactory(Function> jwtValidatorFactory) { Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null"); @@ -204,11 +214,11 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod /** * Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm} - * used for the signature or MAC on the {@link OidcIdToken ID Token}. - * The default resolves to {@link SignatureAlgorithm#RS256 RS256} for all {@link ClientRegistration clients}. - * - * @param jwsAlgorithmResolver the resolver that provides the expected {@link JwsAlgorithm JWS algorithm} - * for a specific {@link ClientRegistration client} + * used for the signature or MAC on the {@link OidcIdToken ID Token}. The default + * resolves to {@link SignatureAlgorithm#RS256 RS256} for all + * {@link ClientRegistration clients}. + * @param jwsAlgorithmResolver the resolver that provides the expected + * {@link JwsAlgorithm JWS algorithm} for a specific {@link ClientRegistration client} */ public void setJwsAlgorithmResolver(Function jwsAlgorithmResolver) { Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null"); @@ -216,14 +226,17 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod } /** - * Sets the factory that provides a {@link Converter} used for type conversion of claim values for an {@link OidcIdToken}. - * The default is {@link ClaimTypeConverter} for all {@link ClientRegistration clients}. - * - * @param claimTypeConverterFactory the factory that provides a {@link Converter} used for type conversion - * of claim values for a specific {@link ClientRegistration client} + * Sets the factory that provides a {@link Converter} used for type conversion of + * claim values for an {@link OidcIdToken}. The default is {@link ClaimTypeConverter} + * for all {@link ClientRegistration clients}. + * @param claimTypeConverterFactory the factory that provides a {@link Converter} used + * for type conversion of claim values for a specific {@link ClientRegistration + * client} */ - public void setClaimTypeConverterFactory(Function, Map>> claimTypeConverterFactory) { + public void setClaimTypeConverterFactory( + Function, Map>> claimTypeConverterFactory) { Assert.notNull(claimTypeConverterFactory, "claimTypeConverterFactory cannot be null"); this.claimTypeConverterFactory = claimTypeConverterFactory; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/package-info.java index 6e60be59e7..e487281a4a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Support classes and interfaces for authenticating and authorizing a client - * with an OpenID Connect 1.0 Provider using a specific authorization grant flow. + * Support classes and interfaces for authenticating and authorizing a client with an + * OpenID Connect 1.0 Provider using a specific authorization grant flow. */ package org.springframework.security.oauth2.client.oidc.authentication; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index 5893d799d8..1845ff0d11 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Instant; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; @@ -36,17 +46,10 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; - -import java.time.Instant; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.function.Function; /** - * An implementation of an {@link ReactiveOAuth2UserService} that supports OpenID Connect 1.0 Provider's. + * An implementation of an {@link ReactiveOAuth2UserService} that supports OpenID Connect + * 1.0 Provider's. * * @author Rob Winch * @since 5.1 @@ -56,29 +59,28 @@ import java.util.function.Function; * @see DefaultOidcUser * @see OidcUserInfo */ -public class OidcReactiveOAuth2UserService implements - ReactiveOAuth2UserService { +public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService { private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; - private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = - new ClaimTypeConverter(createDefaultClaimTypeConverters()); + private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( + createDefaultClaimTypeConverters()); private ReactiveOAuth2UserService oauth2UserService = new DefaultReactiveOAuth2UserService(); - private Function, Map>> claimTypeConverterFactory = - clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; + private Function, Map>> claimTypeConverterFactory = ( + clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; /** - * Returns the default {@link Converter}'s used for type conversion of claim values for an {@link OidcUserInfo}. - + * Returns the default {@link Converter}'s used for type conversion of claim values + * for an {@link OidcUserInfo}. + * @return a {@link Map} of {@link Converter}'s keyed by {@link StandardClaimNames + * claim name} * @since 5.2 - * @return a {@link Map} of {@link Converter}'s keyed by {@link StandardClaimNames claim name} */ public static Map> createDefaultClaimTypeConverters() { Converter booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); Converter instantConverter = getConverter(TypeDescriptor.valueOf(Instant.class)); - Map> claimTypeConverters = new HashMap<>(); claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); @@ -88,57 +90,63 @@ public class OidcReactiveOAuth2UserService implements private static Converter getConverter(TypeDescriptor targetDescriptor) { final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - return source -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); + return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, + targetDescriptor); } @Override public Mono loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { Assert.notNull(userRequest, "userRequest cannot be null"); + // @formatter:off return getUserInfo(userRequest) - .map(userInfo -> new OidcUserAuthority(userRequest.getIdToken(), userInfo)) - .defaultIfEmpty(new OidcUserAuthority(userRequest.getIdToken(), null)) - .map(authority -> { - OidcUserInfo userInfo = authority.getUserInfo(); - Set authorities = new HashSet<>(); - authorities.add(authority); - OAuth2AccessToken token = userRequest.getAccessToken(); - for (String scope : token.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); - } - String userNameAttributeName = userRequest.getClientRegistration() - .getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName(); - if (StringUtils.hasText(userNameAttributeName)) { - return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); - } else { + .map((userInfo) -> + new OidcUserAuthority(userRequest.getIdToken(), userInfo) + ) + .defaultIfEmpty(new OidcUserAuthority(userRequest.getIdToken(), null)) + .map((authority) -> { + OidcUserInfo userInfo = authority.getUserInfo(); + Set authorities = new HashSet<>(); + authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } + String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails() + .getUserInfoEndpoint().getUserNameAttributeName(); + if (StringUtils.hasText(userNameAttributeName)) { + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, + userNameAttributeName); + } return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); - } - }); + }); + // @formatter:on } private Mono getUserInfo(OidcUserRequest userRequest) { if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) { return Mono.empty(); } - - return this.oauth2UserService.loadUser(userRequest) - .map(OAuth2User::getAttributes) - .map(claims -> convertClaims(claims, userRequest.getClientRegistration())) - .map(OidcUserInfo::new) - .doOnNext(userInfo -> { - String subject = userInfo.getSubject(); - if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - }); + // @formatter:off + return this.oauth2UserService + .loadUser(userRequest) + .map(OAuth2User::getAttributes) + .map((claims) -> convertClaims(claims, userRequest.getClientRegistration())) + .map(OidcUserInfo::new) + .doOnNext((userInfo) -> { + String subject = userInfo.getSubject(); + if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + }); + // @formatter:on } private Map convertClaims(Map claims, ClientRegistration clientRegistration) { - Converter, Map> claimTypeConverter = - this.claimTypeConverterFactory.apply(clientRegistration); - return claimTypeConverter != null ? - claimTypeConverter.convert(claims) : - DEFAULT_CLAIM_TYPE_CONVERTER.convert(claims); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); + return (claimTypeConverter != null) ? claimTypeConverter.convert(claims) + : DEFAULT_CLAIM_TYPE_CONVERTER.convert(claims); } public void setOauth2UserService(ReactiveOAuth2UserService oauth2UserService) { @@ -147,15 +155,18 @@ public class OidcReactiveOAuth2UserService implements } /** - * Sets the factory that provides a {@link Converter} used for type conversion of claim values for an {@link OidcUserInfo}. - * The default is {@link ClaimTypeConverter} for all {@link ClientRegistration clients}. - * + * Sets the factory that provides a {@link Converter} used for type conversion of + * claim values for an {@link OidcUserInfo}. The default is {@link ClaimTypeConverter} + * for all {@link ClientRegistration clients}. + * @param claimTypeConverterFactory the factory that provides a {@link Converter} used + * for type conversion of claim values for a specific {@link ClientRegistration + * client} * @since 5.2 - * @param claimTypeConverterFactory the factory that provides a {@link Converter} used for type conversion - * of claim values for a specific {@link ClientRegistration client} */ - public final void setClaimTypeConverterFactory(Function, Map>> claimTypeConverterFactory) { + public final void setClaimTypeConverterFactory( + Function, Map>> claimTypeConverterFactory) { Assert.notNull(claimTypeConverterFactory, "claimTypeConverterFactory cannot be null"); this.claimTypeConverterFactory = claimTypeConverterFactory; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java index ce8c52d8c4..19f78658d6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java @@ -13,20 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.userinfo; +import java.util.Collections; +import java.util.Map; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.util.Assert; -import java.util.Collections; -import java.util.Map; - /** - * Represents a request the {@link OidcUserService} uses - * when initiating a request to the UserInfo Endpoint. + * Represents a request the {@link OidcUserService} uses when initiating a request to the + * UserInfo Endpoint. * * @author Joe Grandja * @since 5.0 @@ -36,33 +37,29 @@ import java.util.Map; * @see OidcUserService */ public class OidcUserRequest extends OAuth2UserRequest { + private final OidcIdToken idToken; /** * Constructs an {@code OidcUserRequest} using the provided parameters. - * * @param clientRegistration the client registration * @param accessToken the access token credential * @param idToken the ID Token */ - public OidcUserRequest(ClientRegistration clientRegistration, - OAuth2AccessToken accessToken, OidcIdToken idToken) { - + public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, OidcIdToken idToken) { this(clientRegistration, accessToken, idToken, Collections.emptyMap()); } /** * Constructs an {@code OidcUserRequest} using the provided parameters. - * - * @since 5.1 * @param clientRegistration the client registration * @param accessToken the access token credential * @param idToken the ID Token * @param additionalParameters the additional parameters, may be empty + * @since 5.1 */ - public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, - OidcIdToken idToken, Map additionalParameters) { - + public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, OidcIdToken idToken, + Map additionalParameters) { super(clientRegistration, accessToken, additionalParameters); Assert.notNull(idToken, "idToken cannot be null"); this.idToken = idToken; @@ -70,10 +67,10 @@ public class OidcUserRequest extends OAuth2UserRequest { /** * Returns the {@link OidcIdToken ID Token} containing claims about the user. - * * @return the {@link OidcIdToken} containing claims about the user. */ public OidcIdToken getIdToken() { return this.idToken; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java index 5a7787d351..e8e6362479 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java @@ -30,13 +30,14 @@ import org.springframework.util.StringUtils; final class OidcUserRequestUtils { /** - * Determines if an {@link OidcUserRequest} should attempt to retrieve the user info endpoint. Will return true if - * all of the following are true: + * Determines if an {@link OidcUserRequest} should attempt to retrieve the user info + * endpoint. Will return true if all of the following are true: * *

          - *
        • The user info endpoint is defined on the ClientRegistration
        • - *
        • The Client Registration uses the {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the - * access token are defined in the {@link ClientRegistration}
        • + *
        • The user info endpoint is defined on the ClientRegistration
        • + *
        • The Client Registration uses the + * {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the access token + * are defined in the {@link ClientRegistration}
        • *
        * @param userRequest * @return @@ -44,27 +45,28 @@ final class OidcUserRequestUtils { static boolean shouldRetrieveUserInfo(OidcUserRequest userRequest) { // Auto-disabled if UserInfo Endpoint URI is not provided ClientRegistration clientRegistration = userRequest.getClientRegistration(); - if (StringUtils.isEmpty(clientRegistration.getProviderDetails() - .getUserInfoEndpoint().getUri())) { - + if (StringUtils.isEmpty(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri())) { return false; } - // The Claims requested by the profile, email, address, and phone scope values // are returned from the UserInfo Endpoint (as described in Section 5.3.2), - // when a response_type value is used that results in an Access Token being issued. - // However, when no Access Token is issued, which is the case for the response_type=id_token, + // when a response_type value is used that results in an Access Token being + // issued. + // However, when no Access Token is issued, which is the case for the + // response_type=id_token, // the resulting Claims are returned in the ID Token. - // The Authorization Code Grant Flow, which is response_type=code, results in an Access Token being issued. + // The Authorization Code Grant Flow, which is response_type=code, results in an + // Access Token being issued. if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - - // Return true if there is at least one match between the authorized scope(s) and UserInfo scope(s) - return CollectionUtils - .containsAny(userRequest.getAccessToken().getScopes(), userRequest.getClientRegistration().getScopes()); + // Return true if there is at least one match between the authorized scope(s) + // and UserInfo scope(s) + return CollectionUtils.containsAny(userRequest.getAccessToken().getScopes(), + userRequest.getClientRegistration().getScopes()); } - return false; } - private OidcUserRequestUtils() {} + private OidcUserRequestUtils() { + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index 3fb35f5ada..31f181213a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; @@ -29,6 +30,7 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; @@ -50,7 +52,8 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; /** - * An implementation of an {@link OAuth2UserService} that supports OpenID Connect 1.0 Provider's. + * An implementation of an {@link OAuth2UserService} that supports OpenID Connect 1.0 + * Provider's. * * @author Joe Grandja * @since 5.0 @@ -61,25 +64,30 @@ import org.springframework.util.StringUtils; * @see OidcUserInfo */ public class OidcUserService implements OAuth2UserService { + private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; - private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = - new ClaimTypeConverter(createDefaultClaimTypeConverters()); - private Set accessibleScopes = new HashSet<>(Arrays.asList( - OidcScopes.PROFILE, OidcScopes.EMAIL, OidcScopes.ADDRESS, OidcScopes.PHONE)); + + private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( + createDefaultClaimTypeConverters()); + + private Set accessibleScopes = new HashSet<>( + Arrays.asList(OidcScopes.PROFILE, OidcScopes.EMAIL, OidcScopes.ADDRESS, OidcScopes.PHONE)); + private OAuth2UserService oauth2UserService = new DefaultOAuth2UserService(); - private Function, Map>> claimTypeConverterFactory = - clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; + + private Function, Map>> claimTypeConverterFactory = ( + clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; /** - * Returns the default {@link Converter}'s used for type conversion of claim values for an {@link OidcUserInfo}. - + * Returns the default {@link Converter}'s used for type conversion of claim values + * for an {@link OidcUserInfo}. + * @return a {@link Map} of {@link Converter}'s keyed by {@link StandardClaimNames + * claim name} * @since 5.2 - * @return a {@link Map} of {@link Converter}'s keyed by {@link StandardClaimNames claim name} */ public static Map> createDefaultClaimTypeConverters() { Converter booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); Converter instantConverter = getConverter(TypeDescriptor.valueOf(Instant.class)); - Map> claimTypeConverters = new HashMap<>(); claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); @@ -88,8 +96,9 @@ public class OidcUserService implements OAuth2UserService getConverter(TypeDescriptor targetDescriptor) { - final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - return source -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); + TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); + return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, + targetDescriptor); } @Override @@ -98,26 +107,16 @@ public class OidcUserService implements OAuth2UserService claims; - Converter, Map> claimTypeConverter = - this.claimTypeConverterFactory.apply(userRequest.getClientRegistration()); - if (claimTypeConverter != null) { - claims = claimTypeConverter.convert(oauth2User.getAttributes()); - } else { - claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes()); - } + Map claims = getClaims(userRequest, oauth2User); userInfo = new OidcUserInfo(claims); - // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse - // 1) The sub (subject) Claim MUST always be returned in the UserInfo Response if (userInfo.getSubject() == null) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - - // 2) Due to the possibility of token substitution attacks (see Section 16.11), + // 2) Due to the possibility of token substitution attacks (see Section + // 16.11), // the UserInfo Response is not guaranteed to be about the End-User // identified by the sub (subject) element of the ID Token. // The sub Claim in the UserInfo Response MUST be verified to exactly match @@ -128,57 +127,63 @@ public class OidcUserService implements OAuth2UserService authorities = new LinkedHashSet<>(); authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); OAuth2AccessToken token = userRequest.getAccessToken(); for (String authority : token.getScopes()) { authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); } + return getUser(userRequest, userInfo, authorities); + } - OidcUser user; - - String userNameAttributeName = userRequest.getClientRegistration() - .getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName(); - if (StringUtils.hasText(userNameAttributeName)) { - user = new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); - } else { - user = new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); + private Map getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) { + Converter, Map> converter = this.claimTypeConverterFactory + .apply(userRequest.getClientRegistration()); + if (converter != null) { + return converter.convert(oauth2User.getAttributes()); } + return DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes()); + } - return user; + private OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo, Set authorities) { + ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); + String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName(); + if (StringUtils.hasText(userNameAttributeName)) { + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); + } + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); } private boolean shouldRetrieveUserInfo(OidcUserRequest userRequest) { // Auto-disabled if UserInfo Endpoint URI is not provided - if (StringUtils.isEmpty(userRequest.getClientRegistration().getProviderDetails() - .getUserInfoEndpoint().getUri())) { - + ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); + if (StringUtils.isEmpty(providerDetails.getUserInfoEndpoint().getUri())) { return false; } - // The Claims requested by the profile, email, address, and phone scope values // are returned from the UserInfo Endpoint (as described in Section 5.3.2), - // when a response_type value is used that results in an Access Token being issued. - // However, when no Access Token is issued, which is the case for the response_type=id_token, + // when a response_type value is used that results in an Access Token being + // issued. + // However, when no Access Token is issued, which is the case for the + // response_type=id_token, // the resulting Claims are returned in the ID Token. - // The Authorization Code Grant Flow, which is response_type=code, results in an Access Token being issued. - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals( - userRequest.getClientRegistration().getAuthorizationGrantType())) { - - // Return true if there is at least one match between the authorized scope(s) and accessible scope(s) - return this.accessibleScopes.isEmpty() || - CollectionUtils.containsAny(userRequest.getAccessToken().getScopes(), this.accessibleScopes); + // The Authorization Code Grant Flow, which is response_type=code, results in an + // Access Token being issued. + if (AuthorizationGrantType.AUTHORIZATION_CODE + .equals(userRequest.getClientRegistration().getAuthorizationGrantType())) { + // Return true if there is at least one match between the authorized scope(s) + // and accessible scope(s) + return this.accessibleScopes.isEmpty() + || CollectionUtils.containsAny(userRequest.getAccessToken().getScopes(), this.accessibleScopes); } - return false; } /** * Sets the {@link OAuth2UserService} used when requesting the user info resource. - * + * @param oauth2UserService the {@link OAuth2UserService} used when requesting the + * user info resource. * @since 5.1 - * @param oauth2UserService the {@link OAuth2UserService} used when requesting the user info resource. */ public final void setOauth2UserService(OAuth2UserService oauth2UserService) { Assert.notNull(oauth2UserService, "oauth2UserService cannot be null"); @@ -186,30 +191,34 @@ public class OidcUserService implements OAuth2UserService, Map>> claimTypeConverterFactory) { + public final void setClaimTypeConverterFactory( + Function, Map>> claimTypeConverterFactory) { Assert.notNull(claimTypeConverterFactory, "claimTypeConverterFactory cannot be null"); this.claimTypeConverterFactory = claimTypeConverterFactory; } /** - * Sets the scope(s) that allow access to the user info resource. - * The default is {@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email}, {@link OidcScopes#ADDRESS address} and {@link OidcScopes#PHONE phone}. - * The scope(s) are checked against the "granted" scope(s) associated to the {@link OidcUserRequest#getAccessToken() access token} - * to determine if the user info resource is accessible or not. - * If there is at least one match, the user info resource will be requested, otherwise it will not. - * - * @since 5.2 + * Sets the scope(s) that allow access to the user info resource. The default is + * {@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email}, + * {@link OidcScopes#ADDRESS address} and {@link OidcScopes#PHONE phone}. The scope(s) + * are checked against the "granted" scope(s) associated to the + * {@link OidcUserRequest#getAccessToken() access token} to determine if the user info + * resource is accessible or not. If there is at least one match, the user info + * resource will be requested, otherwise it will not. * @param accessibleScopes the scope(s) that allow access to the user info resource + * @since 5.2 */ public final void setAccessibleScopes(Set accessibleScopes) { Assert.notNull(accessibleScopes, "accessibleScopes cannot be null"); this.accessibleScopes = accessibleScopes; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/package-info.java index 22e306921d..9db24f17d5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Classes and interfaces providing support to the client for initiating requests - * to the OpenID Connect 1.0 Provider's UserInfo Endpoint. + * Classes and interfaces providing support to the client for initiating requests to the + * OpenID Connect 1.0 Provider's UserInfo Endpoint. */ package org.springframework.security.oauth2.client.oidc.userinfo; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandler.java index 2e329cfa8b..66abfe60d4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandler.java @@ -19,12 +19,14 @@ package org.springframework.security.oauth2.client.oidc.web.logout; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.Collections; + import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler; @@ -38,10 +40,13 @@ import org.springframework.web.util.UriComponentsBuilder; * * @author Josh Cummings * @since 5.2 - * @see RP-Initiated Logout + * @see RP-Initiated + * Logout * @see org.springframework.security.web.authentication.logout.LogoutSuccessHandler */ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler { + private final ClientRegistrationRepository clientRegistrationRepository; private String postLogoutRedirectUri; @@ -52,39 +57,32 @@ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogo } @Override - protected String determineTargetUrl(HttpServletRequest request, - HttpServletResponse response, Authentication authentication) { + protected String determineTargetUrl(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) { String targetUrl = null; - URI endSessionEndpoint; if (authentication instanceof OAuth2AuthenticationToken && authentication.getPrincipal() instanceof OidcUser) { String registrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); ClientRegistration clientRegistration = this.clientRegistrationRepository .findByRegistrationId(registrationId); - endSessionEndpoint = this.endSessionEndpoint(clientRegistration); + URI endSessionEndpoint = this.endSessionEndpoint(clientRegistration); if (endSessionEndpoint != null) { String idToken = idToken(authentication); URI postLogoutRedirectUri = postLogoutRedirectUri(request); targetUrl = endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri); } } - if (targetUrl == null) { - targetUrl = super.determineTargetUrl(request, response); - } - - return targetUrl; + return (targetUrl != null) ? targetUrl : super.determineTargetUrl(request, response); } private URI endSessionEndpoint(ClientRegistration clientRegistration) { - URI result = null; if (clientRegistration != null) { - Object endSessionEndpoint = clientRegistration.getProviderDetails().getConfigurationMetadata() - .get("end_session_endpoint"); + ProviderDetails providerDetails = clientRegistration.getProviderDetails(); + Object endSessionEndpoint = providerDetails.getConfigurationMetadata().get("end_session_endpoint"); if (endSessionEndpoint != null) { - result = URI.create(endSessionEndpoint.toString()); + return URI.create(endSessionEndpoint.toString()); } } - - return result; + return null; } private String idToken(Authentication authentication) { @@ -95,7 +93,9 @@ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogo if (this.postLogoutRedirectUri == null) { return null; } - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + // @formatter:off + UriComponents uriComponents = UriComponentsBuilder + .fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) .replacePath(request.getContextPath()) .replaceQuery(null) .fragment(null) @@ -103,22 +103,26 @@ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogo return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri) .buildAndExpand(Collections.singletonMap("baseUrl", uriComponents.toUriString())) .toUri(); + // @formatter:on } - private String endpointUri(URI endSessionEndpoint, String idToken, URI postLogoutRedirectUri) { UriComponentsBuilder builder = UriComponentsBuilder.fromUri(endSessionEndpoint); builder.queryParam("id_token_hint", idToken); if (postLogoutRedirectUri != null) { builder.queryParam("post_logout_redirect_uri", postLogoutRedirectUri); } - return builder.encode(StandardCharsets.UTF_8).build().toUriString(); + // @formatter:off + return builder.encode(StandardCharsets.UTF_8) + .build() + .toUriString(); + // @formatter:on } /** * Set the post logout redirect uri to use - * - * @param postLogoutRedirectUri - A valid URL to which the OP should redirect after logging out the user + * @param postLogoutRedirectUri - A valid URL to which the OP should redirect after + * logging out the user * @deprecated {@link #setPostLogoutRedirectUri(String)} */ @Deprecated @@ -135,15 +139,15 @@ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogo * handler.setPostLogoutRedirectUri("{baseUrl}"); *
        * - * will make so that {@code post_logout_redirect_uri} will be set to the base url for the client - * application. - * - * @param postLogoutRedirectUri - A template for creating the {@code post_logout_redirect_uri} - * query parameter + * will make so that {@code post_logout_redirect_uri} will be set to the base url for + * the client application. + * @param postLogoutRedirectUri - A template for creating the + * {@code post_logout_redirect_uri} query parameter * @since 5.3 */ public void setPostLogoutRedirectUri(String postLogoutRedirectUri) { Assert.notNull(postLogoutRedirectUri, "postLogoutRedirectUri cannot be null"); this.postLogoutRedirectUri = postLogoutRedirectUri; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java index ce4a5dac77..903bf7ea88 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java @@ -42,44 +42,44 @@ import org.springframework.web.util.UriComponentsBuilder; * * @author Josh Cummings * @since 5.2 - * @see RP-Initiated Logout + * @see RP-Initiated + * Logout * @see org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler */ -public class OidcClientInitiatedServerLogoutSuccessHandler - implements ServerLogoutSuccessHandler { +public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogoutSuccessHandler { private final ServerRedirectStrategy redirectStrategy = new DefaultServerRedirectStrategy(); - private final RedirectServerLogoutSuccessHandler serverLogoutSuccessHandler - = new RedirectServerLogoutSuccessHandler(); + + private final RedirectServerLogoutSuccessHandler serverLogoutSuccessHandler = new RedirectServerLogoutSuccessHandler(); + private final ReactiveClientRegistrationRepository clientRegistrationRepository; private String postLogoutRedirectUri; /** - * Constructs an {@link OidcClientInitiatedServerLogoutSuccessHandler} with the provided parameters - * - * @param clientRegistrationRepository The {@link ReactiveClientRegistrationRepository} to use to derive - * the end_session_endpoint value + * Constructs an {@link OidcClientInitiatedServerLogoutSuccessHandler} with the + * provided parameters + * @param clientRegistrationRepository The + * {@link ReactiveClientRegistrationRepository} to use to derive the + * end_session_endpoint value */ - public OidcClientInitiatedServerLogoutSuccessHandler - (ReactiveClientRegistrationRepository clientRegistrationRepository) { - + public OidcClientInitiatedServerLogoutSuccessHandler( + ReactiveClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; } - /** - * {@inheritDoc} - */ @Override public Mono onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) { + // @formatter:off return Mono.just(authentication) .filter(OAuth2AuthenticationToken.class::isInstance) - .filter(token -> authentication.getPrincipal() instanceof OidcUser) + .filter((token) -> authentication.getPrincipal() instanceof OidcUser) .map(OAuth2AuthenticationToken.class::cast) .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId) .flatMap(this.clientRegistrationRepository::findByRegistrationId) - .flatMap(clientRegistration -> { + .flatMap((clientRegistration) -> { URI endSessionEndpoint = endSessionEndpoint(clientRegistration); if (endSessionEndpoint == null) { return Mono.empty(); @@ -88,22 +88,22 @@ public class OidcClientInitiatedServerLogoutSuccessHandler URI postLogoutRedirectUri = postLogoutRedirectUri(exchange.getExchange().getRequest()); return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri)); }) - .switchIfEmpty(this.serverLogoutSuccessHandler - .onLogoutSuccess(exchange, authentication).then(Mono.empty())) - .flatMap(endpointUri -> this.redirectStrategy.sendRedirect(exchange.getExchange(), endpointUri)); + .switchIfEmpty( + this.serverLogoutSuccessHandler.onLogoutSuccess(exchange, authentication).then(Mono.empty()) + ) + .flatMap((endpointUri) -> this.redirectStrategy.sendRedirect(exchange.getExchange(), endpointUri)); + // @formatter:on } private URI endSessionEndpoint(ClientRegistration clientRegistration) { - URI result = null; if (clientRegistration != null) { Object endSessionEndpoint = clientRegistration.getProviderDetails().getConfigurationMetadata() .get("end_session_endpoint"); if (endSessionEndpoint != null) { - result = URI.create(endSessionEndpoint.toString()); + return URI.create(endSessionEndpoint.toString()); } } - - return result; + return null; } private URI endpointUri(URI endSessionEndpoint, String idToken, URI postLogoutRedirectUri) { @@ -123,6 +123,7 @@ public class OidcClientInitiatedServerLogoutSuccessHandler if (this.postLogoutRedirectUri == null) { return null; } + // @formatter:off UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) .replacePath(request.getPath().contextPath().value()) .replaceQuery(null) @@ -131,12 +132,13 @@ public class OidcClientInitiatedServerLogoutSuccessHandler return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri) .buildAndExpand(Collections.singletonMap("baseUrl", uriComponents.toUriString())) .toUri(); + // @formatter:on } /** * Set the post logout redirect uri to use - * - * @param postLogoutRedirectUri - A valid URL to which the OP should redirect after logging out the user + * @param postLogoutRedirectUri - A valid URL to which the OP should redirect after + * logging out the user * @deprecated {@link #setPostLogoutRedirectUri(String)} */ @Deprecated @@ -153,11 +155,10 @@ public class OidcClientInitiatedServerLogoutSuccessHandler * handler.setPostLogoutRedirectUri("{baseUrl}"); *
        * - * will make so that {@code post_logout_redirect_uri} will be set to the base url for the client - * application. - * - * @param postLogoutRedirectUri - A template for creating the {@code post_logout_redirect_uri} - * query parameter + * will make so that {@code post_logout_redirect_uri} will be set to the base url for + * the client application. + * @param postLogoutRedirectUri - A template for creating the + * {@code post_logout_redirect_uri} query parameter * @since 5.3 */ public void setPostLogoutRedirectUri(String postLogoutRedirectUri) { @@ -166,12 +167,13 @@ public class OidcClientInitiatedServerLogoutSuccessHandler } /** - * The URL to redirect to after successfully logging out when not originally an OIDC login - * + * The URL to redirect to after successfully logging out when not originally an OIDC + * login * @param logoutSuccessUrl the url to redirect to. Default is "/login?logout". */ public void setLogoutSuccessUrl(URI logoutSuccessUrl) { Assert.notNull(logoutSuccessUrl, "logoutSuccessUrl cannot be null"); this.serverLogoutSuccessHandler.setLogoutSuccessUrl(logoutSuccessUrl); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/package-info.java index 0403e6c217..f8ab1aea4e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Core classes and interfaces providing support for OAuth 2.0 Client. */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index 939778b2ee..97292842f9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.registration; import java.io.Serializable; @@ -33,25 +34,35 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import static java.util.Collections.EMPTY_MAP; - /** - * A representation of a client registration with an OAuth 2.0 or OpenID Connect 1.0 Provider. + * A representation of a client registration with an OAuth 2.0 or OpenID Connect 1.0 + * Provider. * * @author Joe Grandja * @since 5.0 - * @see Section 2 Client Registration + * @see Section 2 + * Client Registration */ public final class ClientRegistration implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private String registrationId; + private String clientId; + private String clientSecret; + private ClientAuthenticationMethod clientAuthenticationMethod; + private AuthorizationGrantType authorizationGrantType; + private String redirectUri; + private Set scopes = Collections.emptySet(); + private ProviderDetails providerDetails = new ProviderDetails(); + private String clientName; private ClientRegistration() { @@ -59,7 +70,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the identifier for the registration. - * * @return the identifier for the registration */ public String getRegistrationId() { @@ -68,7 +78,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the client identifier. - * * @return the client identifier */ public String getClientId() { @@ -77,7 +86,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the client secret. - * * @return the client secret */ public String getClientSecret() { @@ -85,9 +93,8 @@ public final class ClientRegistration implements Serializable { } /** - * Returns the {@link ClientAuthenticationMethod authentication method} used - * when authenticating the client with the authorization server. - * + * Returns the {@link ClientAuthenticationMethod authentication method} used when + * authenticating the client with the authorization server. * @return the {@link ClientAuthenticationMethod} */ public ClientAuthenticationMethod getClientAuthenticationMethod() { @@ -95,8 +102,8 @@ public final class ClientRegistration implements Serializable { } /** - * Returns the {@link AuthorizationGrantType authorization grant type} used for the client. - * + * Returns the {@link AuthorizationGrantType authorization grant type} used for the + * client. * @return the {@link AuthorizationGrantType} */ public AuthorizationGrantType getAuthorizationGrantType() { @@ -105,7 +112,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the uri (or uri template) for the redirection endpoint. - * * @deprecated Use {@link #getRedirectUri()} instead * @return the uri (or uri template) for the redirection endpoint */ @@ -118,17 +124,19 @@ public final class ClientRegistration implements Serializable { * Returns the uri (or uri template) for the redirection endpoint. * *
        - * The supported uri template variables are: {baseScheme}, {baseHost}, {basePort}, {basePath} and {registrationId}. + * The supported uri template variables are: {baseScheme}, {baseHost}, {basePort}, + * {basePath} and {registrationId}. * *
        - * NOTE: {baseUrl} is also supported, which is the same as {baseScheme}://{baseHost}{basePort}{basePath}. + * NOTE: {baseUrl} is also supported, which is the same as + * {baseScheme}://{baseHost}{basePort}{basePath}. * *
        - * Configuring uri template variables is especially useful when the client is running behind a Proxy Server. - * This ensures that the X-Forwarded-* headers are used when expanding the redirect-uri. - * - * @since 5.4 + * Configuring uri template variables is especially useful when the client is running + * behind a Proxy Server. This ensures that the X-Forwarded-* headers are used when + * expanding the redirect-uri. * @return the uri (or uri template) for the redirection endpoint + * @since 5.4 */ public String getRedirectUri() { return this.redirectUri; @@ -136,7 +144,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the scope(s) used for the client. - * * @return the {@code Set} of scope(s) */ public Set getScopes() { @@ -145,7 +152,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the details of the provider. - * * @return the {@link ProviderDetails} */ public ProviderDetails getProviderDetails() { @@ -154,7 +160,6 @@ public final class ClientRegistration implements Serializable { /** * Returns the logical name of the client or registration. - * * @return the client or registration name */ public String getClientName() { @@ -163,136 +168,24 @@ public final class ClientRegistration implements Serializable { @Override public String toString() { + // @formatter:off return "ClientRegistration{" - + "registrationId='" + this.registrationId + '\'' - + ", clientId='" + this.clientId + '\'' - + ", clientSecret='" + this.clientSecret + '\'' - + ", clientAuthenticationMethod=" + this.clientAuthenticationMethod - + ", authorizationGrantType=" + this.authorizationGrantType - + ", redirectUri='" + this.redirectUri + '\'' - + ", scopes=" + this.scopes - + ", providerDetails=" + this.providerDetails - + ", clientName='" + this.clientName - + '\'' + '}'; + + "registrationId='" + this.registrationId + '\'' + + ", clientId='" + this.clientId + '\'' + + ", clientSecret='" + this.clientSecret + '\'' + + ", clientAuthenticationMethod=" + this.clientAuthenticationMethod + + ", authorizationGrantType=" + this.authorizationGrantType + + ", redirectUri='" + this.redirectUri + + '\'' + ", scopes=" + this.scopes + + ", providerDetails=" + this.providerDetails + + ", clientName='" + this.clientName + '\'' + + '}'; + // @formatter:on } /** - * Details of the Provider. - */ - public class ProviderDetails implements Serializable { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - private String authorizationUri; - private String tokenUri; - private UserInfoEndpoint userInfoEndpoint = new UserInfoEndpoint(); - private String jwkSetUri; - private String issuerUri; - private Map configurationMetadata = Collections.emptyMap(); - - private ProviderDetails() { - } - - /** - * Returns the uri for the authorization endpoint. - * - * @return the uri for the authorization endpoint - */ - public String getAuthorizationUri() { - return this.authorizationUri; - } - - /** - * Returns the uri for the token endpoint. - * - * @return the uri for the token endpoint - */ - public String getTokenUri() { - return this.tokenUri; - } - - /** - * Returns the details of the {@link UserInfoEndpoint UserInfo Endpoint}. - * - * @return the {@link UserInfoEndpoint} - */ - public UserInfoEndpoint getUserInfoEndpoint() { - return this.userInfoEndpoint; - } - - /** - * Returns the uri for the JSON Web Key (JWK) Set endpoint. - * - * @return the uri for the JSON Web Key (JWK) Set endpoint - */ - public String getJwkSetUri() { - return this.jwkSetUri; - } - - /** - * Returns the issuer identifier uri for the OpenID Connect 1.0 provider - * or the OAuth 2.0 Authorization Server. - * - * @since 5.4 - * @return the issuer identifier uri for the OpenID Connect 1.0 provider or the OAuth 2.0 Authorization Server - */ - public String getIssuerUri() { - return this.issuerUri; - } - - /** - * Returns a {@code Map} of the metadata describing the provider's configuration. - * - * @since 5.1 - * @return a {@code Map} of the metadata describing the provider's configuration - */ - public Map getConfigurationMetadata() { - return this.configurationMetadata; - } - - /** - * Details of the UserInfo Endpoint. - */ - public class UserInfoEndpoint implements Serializable { - private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - private String uri; - private AuthenticationMethod authenticationMethod = AuthenticationMethod.HEADER; - private String userNameAttributeName; - - private UserInfoEndpoint() { - } - - /** - * Returns the uri for the user info endpoint. - * - * @return the uri for the user info endpoint - */ - public String getUri() { - return this.uri; - } - - /** - * Returns the authentication method for the user info endpoint. - * - * @since 5.1 - * @return the {@link AuthenticationMethod} for the user info endpoint. - */ - public AuthenticationMethod getAuthenticationMethod() { - return this.authenticationMethod; - } - - /** - * Returns the attribute name used to access the user's name from the user info response. - * - * @return the attribute name used to access the user's name from the user info response - */ - public String getUserNameAttributeName() { - return this.userNameAttributeName; - } - } - } - - /** - * Returns a new {@link Builder}, initialized with the provided registration identifier. - * + * Returns a new {@link Builder}, initialized with the provided registration + * identifier. * @param registrationId the identifier for the registration * @return the {@link Builder} */ @@ -302,8 +195,8 @@ public final class ClientRegistration implements Serializable { } /** - * Returns a new {@link Builder}, initialized with the provided {@link ClientRegistration}. - * + * Returns a new {@link Builder}, initialized with the provided + * {@link ClientRegistration}. * @param clientRegistration the {@link ClientRegistration} to copy from * @return the {@link Builder} */ @@ -312,26 +205,164 @@ public final class ClientRegistration implements Serializable { return new Builder(clientRegistration); } + /** + * Details of the Provider. + */ + public class ProviderDetails implements Serializable { + + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + + private String authorizationUri; + + private String tokenUri; + + private UserInfoEndpoint userInfoEndpoint = new UserInfoEndpoint(); + + private String jwkSetUri; + + private String issuerUri; + + private Map configurationMetadata = Collections.emptyMap(); + + ProviderDetails() { + } + + /** + * Returns the uri for the authorization endpoint. + * @return the uri for the authorization endpoint + */ + public String getAuthorizationUri() { + return this.authorizationUri; + } + + /** + * Returns the uri for the token endpoint. + * @return the uri for the token endpoint + */ + public String getTokenUri() { + return this.tokenUri; + } + + /** + * Returns the details of the {@link UserInfoEndpoint UserInfo Endpoint}. + * @return the {@link UserInfoEndpoint} + */ + public UserInfoEndpoint getUserInfoEndpoint() { + return this.userInfoEndpoint; + } + + /** + * Returns the uri for the JSON Web Key (JWK) Set endpoint. + * @return the uri for the JSON Web Key (JWK) Set endpoint + */ + public String getJwkSetUri() { + return this.jwkSetUri; + } + + /** + * Returns the issuer identifier uri for the OpenID Connect 1.0 provider or the + * OAuth 2.0 Authorization Server. + * @return the issuer identifier uri for the OpenID Connect 1.0 provider or the + * OAuth 2.0 Authorization Server + * @since 5.4 + */ + public String getIssuerUri() { + return this.issuerUri; + } + + /** + * Returns a {@code Map} of the metadata describing the provider's configuration. + * @return a {@code Map} of the metadata describing the provider's configuration + * @since 5.1 + */ + public Map getConfigurationMetadata() { + return this.configurationMetadata; + } + + /** + * Details of the UserInfo Endpoint. + */ + public class UserInfoEndpoint implements Serializable { + + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + + private String uri; + + private AuthenticationMethod authenticationMethod = AuthenticationMethod.HEADER; + + private String userNameAttributeName; + + UserInfoEndpoint() { + } + + /** + * Returns the uri for the user info endpoint. + * @return the uri for the user info endpoint + */ + public String getUri() { + return this.uri; + } + + /** + * Returns the authentication method for the user info endpoint. + * @return the {@link AuthenticationMethod} for the user info endpoint. + * @since 5.1 + */ + public AuthenticationMethod getAuthenticationMethod() { + return this.authenticationMethod; + } + + /** + * Returns the attribute name used to access the user's name from the user + * info response. + * @return the attribute name used to access the user's name from the user + * info response + */ + public String getUserNameAttributeName() { + return this.userNameAttributeName; + } + + } + + } + /** * A builder for {@link ClientRegistration}. */ - public static class Builder implements Serializable { + public static final class Builder implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private String registrationId; + private String clientId; + private String clientSecret; + private ClientAuthenticationMethod clientAuthenticationMethod; + private AuthorizationGrantType authorizationGrantType; + private String redirectUri; + private Set scopes; + private String authorizationUri; + private String tokenUri; + private String userInfoUri; + private AuthenticationMethod userInfoAuthenticationMethod = AuthenticationMethod.HEADER; + private String userNameAttributeName; + private String jwkSetUri; + private String issuerUri; + private Map configurationMetadata = Collections.emptyMap(); + private String clientName; private Builder(String registrationId) { @@ -345,7 +376,7 @@ public final class ClientRegistration implements Serializable { this.clientAuthenticationMethod = clientRegistration.clientAuthenticationMethod; this.authorizationGrantType = clientRegistration.authorizationGrantType; this.redirectUri = clientRegistration.redirectUri; - this.scopes = clientRegistration.scopes == null ? null : new HashSet<>(clientRegistration.scopes); + this.scopes = (clientRegistration.scopes != null) ? new HashSet<>(clientRegistration.scopes) : null; this.authorizationUri = clientRegistration.providerDetails.authorizationUri; this.tokenUri = clientRegistration.providerDetails.tokenUri; this.userInfoUri = clientRegistration.providerDetails.userInfoEndpoint.uri; @@ -354,7 +385,7 @@ public final class ClientRegistration implements Serializable { this.jwkSetUri = clientRegistration.providerDetails.jwkSetUri; this.issuerUri = clientRegistration.providerDetails.issuerUri; Map configurationMetadata = clientRegistration.providerDetails.configurationMetadata; - if (configurationMetadata != EMPTY_MAP) { + if (configurationMetadata != Collections.EMPTY_MAP) { this.configurationMetadata = new HashMap<>(configurationMetadata); } this.clientName = clientRegistration.clientName; @@ -362,7 +393,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the registration id. - * * @param registrationId the registration id * @return the {@link Builder} */ @@ -373,7 +403,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the client identifier. - * * @param clientId the client identifier * @return the {@link Builder} */ @@ -384,7 +413,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the client secret. - * * @param clientSecret the client secret * @return the {@link Builder} */ @@ -394,9 +422,8 @@ public final class ClientRegistration implements Serializable { } /** - * Sets the {@link ClientAuthenticationMethod authentication method} used - * when authenticating the client with the authorization server. - * + * Sets the {@link ClientAuthenticationMethod authentication method} used when + * authenticating the client with the authorization server. * @param clientAuthenticationMethod the authentication method used for the client * @return the {@link Builder} */ @@ -406,8 +433,8 @@ public final class ClientRegistration implements Serializable { } /** - * Sets the {@link AuthorizationGrantType authorization grant type} used for the client. - * + * Sets the {@link AuthorizationGrantType authorization grant type} used for the + * client. * @param authorizationGrantType the authorization grant type used for the client * @return the {@link Builder} */ @@ -418,9 +445,9 @@ public final class ClientRegistration implements Serializable { /** * Sets the uri (or uri template) for the redirection endpoint. - * * @deprecated Use {@link #redirectUri(String)} instead - * @param redirectUriTemplate the uri (or uri template) for the redirection endpoint + * @param redirectUriTemplate the uri (or uri template) for the redirection + * endpoint * @return the {@link Builder} */ @Deprecated @@ -432,18 +459,20 @@ public final class ClientRegistration implements Serializable { * Sets the uri (or uri template) for the redirection endpoint. * *
        - * The supported uri template variables are: {baseScheme}, {baseHost}, {basePort}, {basePath} and {registrationId}. + * The supported uri template variables are: {baseScheme}, {baseHost}, {basePort}, + * {basePath} and {registrationId}. * *
        - * NOTE: {baseUrl} is also supported, which is the same as {baseScheme}://{baseHost}{basePort}{basePath}. + * NOTE: {baseUrl} is also supported, which is the same as + * {baseScheme}://{baseHost}{basePort}{basePath}. * *
        - * Configuring uri template variables is especially useful when the client is running behind a Proxy Server. - * This ensures that the X-Forwarded-* headers are used when expanding the redirect-uri. - * - * @since 5.4 + * Configuring uri template variables is especially useful when the client is + * running behind a Proxy Server. This ensures that the X-Forwarded-* headers are + * used when expanding the redirect-uri. * @param redirectUri the uri (or uri template) for the redirection endpoint * @return the {@link Builder} + * @since 5.4 */ public Builder redirectUri(String redirectUri) { this.redirectUri = redirectUri; @@ -452,35 +481,30 @@ public final class ClientRegistration implements Serializable { /** * Sets the scope(s) used for the client. - * * @param scope the scope(s) used for the client * @return the {@link Builder} */ public Builder scope(String... scope) { if (scope != null && scope.length > 0) { - this.scopes = Collections.unmodifiableSet( - new LinkedHashSet<>(Arrays.asList(scope))); + this.scopes = Collections.unmodifiableSet(new LinkedHashSet<>(Arrays.asList(scope))); } return this; } /** * Sets the scope(s) used for the client. - * * @param scope the scope(s) used for the client * @return the {@link Builder} */ public Builder scope(Collection scope) { if (scope != null && !scope.isEmpty()) { - this.scopes = Collections.unmodifiableSet( - new LinkedHashSet<>(scope)); + this.scopes = Collections.unmodifiableSet(new LinkedHashSet<>(scope)); } return this; } /** * Sets the uri for the authorization endpoint. - * * @param authorizationUri the uri for the authorization endpoint * @return the {@link Builder} */ @@ -491,7 +515,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the uri for the token endpoint. - * * @param tokenUri the uri for the token endpoint * @return the {@link Builder} */ @@ -502,7 +525,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the uri for the user info endpoint. - * * @param userInfoUri the uri for the user info endpoint * @return the {@link Builder} */ @@ -513,10 +535,10 @@ public final class ClientRegistration implements Serializable { /** * Sets the authentication method for the user info endpoint. - * - * @since 5.1 - * @param userInfoAuthenticationMethod the authentication method for the user info endpoint + * @param userInfoAuthenticationMethod the authentication method for the user info + * endpoint * @return the {@link Builder} + * @since 5.1 */ public Builder userInfoAuthenticationMethod(AuthenticationMethod userInfoAuthenticationMethod) { this.userInfoAuthenticationMethod = userInfoAuthenticationMethod; @@ -524,9 +546,10 @@ public final class ClientRegistration implements Serializable { } /** - * Sets the attribute name used to access the user's name from the user info response. - * - * @param userNameAttributeName the attribute name used to access the user's name from the user info response + * Sets the attribute name used to access the user's name from the user info + * response. + * @param userNameAttributeName the attribute name used to access the user's name + * from the user info response * @return the {@link Builder} */ public Builder userNameAttributeName(String userNameAttributeName) { @@ -536,7 +559,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the uri for the JSON Web Key (JWK) Set endpoint. - * * @param jwkSetUri the uri for the JSON Web Key (JWK) Set endpoint * @return the {@link Builder} */ @@ -546,12 +568,12 @@ public final class ClientRegistration implements Serializable { } /** - * Sets the issuer identifier uri for the OpenID Connect 1.0 provider - * or the OAuth 2.0 Authorization Server. - * - * @since 5.4 - * @param issuerUri the issuer identifier uri for the OpenID Connect 1.0 provider or the OAuth 2.0 Authorization Server + * Sets the issuer identifier uri for the OpenID Connect 1.0 provider or the OAuth + * 2.0 Authorization Server. + * @param issuerUri the issuer identifier uri for the OpenID Connect 1.0 provider + * or the OAuth 2.0 Authorization Server * @return the {@link Builder} + * @since 5.4 */ public Builder issuerUri(String issuerUri) { this.issuerUri = issuerUri; @@ -560,10 +582,10 @@ public final class ClientRegistration implements Serializable { /** * Sets the metadata describing the provider's configuration. - * - * @since 5.1 - * @param configurationMetadata the metadata describing the provider's configuration + * @param configurationMetadata the metadata describing the provider's + * configuration * @return the {@link Builder} + * @since 5.1 */ public Builder providerConfigurationMetadata(Map configurationMetadata) { if (configurationMetadata != null) { @@ -574,7 +596,6 @@ public final class ClientRegistration implements Serializable { /** * Sets the logical name of the client or registration. - * * @param clientName the client or registration name * @return the {@link Builder} */ @@ -585,18 +606,20 @@ public final class ClientRegistration implements Serializable { /** * Builds a new {@link ClientRegistration}. - * * @return a {@link ClientRegistration} */ public ClientRegistration build() { Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null"); if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType)) { this.validateClientCredentialsGrantType(); - } else if (AuthorizationGrantType.PASSWORD.equals(this.authorizationGrantType)) { + } + else if (AuthorizationGrantType.PASSWORD.equals(this.authorizationGrantType)) { this.validatePasswordGrantType(); - } else if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) { + } + else if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) { this.validateImplicitGrantType(); - } else if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType)) { + } + else if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType)) { this.validateAuthorizationCodeGrantType(); } this.validateScopes(); @@ -605,24 +628,29 @@ public final class ClientRegistration implements Serializable { private ClientRegistration create() { ClientRegistration clientRegistration = new ClientRegistration(); - clientRegistration.registrationId = this.registrationId; clientRegistration.clientId = this.clientId; clientRegistration.clientSecret = StringUtils.hasText(this.clientSecret) ? this.clientSecret : ""; - if (this.clientAuthenticationMethod != null) { - clientRegistration.clientAuthenticationMethod = this.clientAuthenticationMethod; - } else { - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType) && - !StringUtils.hasText(this.clientSecret)) { - clientRegistration.clientAuthenticationMethod = ClientAuthenticationMethod.NONE; - } else { - clientRegistration.clientAuthenticationMethod = ClientAuthenticationMethod.BASIC; - } - } + clientRegistration.clientAuthenticationMethod = (this.clientAuthenticationMethod != null) + ? this.clientAuthenticationMethod : deduceClientAuthenticationMethod(clientRegistration); clientRegistration.authorizationGrantType = this.authorizationGrantType; clientRegistration.redirectUri = this.redirectUri; clientRegistration.scopes = this.scopes; + clientRegistration.providerDetails = createProviderDetails(clientRegistration); + clientRegistration.clientName = StringUtils.hasText(this.clientName) ? this.clientName + : this.registrationId; + return clientRegistration; + } + private ClientAuthenticationMethod deduceClientAuthenticationMethod(ClientRegistration clientRegistration) { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType) + && !StringUtils.hasText(this.clientSecret)) { + return ClientAuthenticationMethod.NONE; + } + return ClientAuthenticationMethod.BASIC; + } + + private ProviderDetails createProviderDetails(ClientRegistration clientRegistration) { ProviderDetails providerDetails = clientRegistration.new ProviderDetails(); providerDetails.authorizationUri = this.authorizationUri; providerDetails.tokenUri = this.tokenUri; @@ -632,12 +660,7 @@ public final class ClientRegistration implements Serializable { providerDetails.jwkSetUri = this.jwkSetUri; providerDetails.issuerUri = this.issuerUri; providerDetails.configurationMetadata = Collections.unmodifiableMap(this.configurationMetadata); - clientRegistration.providerDetails = providerDetails; - - clientRegistration.clientName = StringUtils.hasText(this.clientName) ? - this.clientName : this.registrationId; - - return clientRegistration; + return providerDetails; } private void validateAuthorizationCodeGrantType() { @@ -679,22 +702,20 @@ public final class ClientRegistration implements Serializable { if (this.scopes == null) { return; } - for (String scope : this.scopes) { Assert.isTrue(validateScope(scope), "scope \"" + scope + "\" contains invalid characters"); } } private static boolean validateScope(String scope) { - return scope == null || - scope.chars().allMatch(c -> - withinTheRangeOf(c, 0x21, 0x21) || - withinTheRangeOf(c, 0x23, 0x5B) || - withinTheRangeOf(c, 0x5D, 0x7E)); + return scope == null || scope.chars().allMatch((c) -> withinTheRangeOf(c, 0x21, 0x21) + || withinTheRangeOf(c, 0x23, 0x5B) || withinTheRangeOf(c, 0x5D, 0x7E)); } private static boolean withinTheRangeOf(int c, int min, int max) { return c >= min && c <= max; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrationRepository.java index 10d27ec1ef..9292e5f490 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrationRepository.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.registration; /** * A repository for OAuth 2.0 / OpenID Connect 1.0 {@link ClientRegistration}(s). * *

        - * NOTE: Client registration information is ultimately stored and owned - * by the associated Authorization Server. - * Therefore, this repository provides the capability to store a sub-set copy - * of the primary client registration information - * externally from the Authorization Server. + * NOTE: Client registration information is ultimately stored and owned by the + * associated Authorization Server. Therefore, this repository provides the capability to + * store a sub-set copy of the primary client registration information externally + * from the Authorization Server. * * @author Joe Grandja * @since 5.0 @@ -32,8 +32,8 @@ package org.springframework.security.oauth2.client.registration; public interface ClientRegistrationRepository { /** - * Returns the client registration identified by the provided {@code registrationId}, or {@code null} if not found. - * + * Returns the client registration identified by the provided {@code registrationId}, + * or {@code null} if not found. * @param registrationId the registration identifier * @return the {@link ClientRegistration} if found, otherwise {@code null} */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java index 997b5ab763..762676b630 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java @@ -40,10 +40,11 @@ import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; /** - * Allows creating a {@link ClientRegistration.Builder} from an - * OpenID Provider Configuration - * or Authorization Server Metadata based on - * provided issuer. + * Allows creating a {@link ClientRegistration.Builder} from an OpenID + * Provider Configuration or + * Authorization Server + * Metadata based on provided issuer. * * @author Rob Winch * @author Josh Cummings @@ -51,24 +52,34 @@ import org.springframework.web.util.UriComponentsBuilder; * @since 5.1 */ public final class ClientRegistrations { + private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration"; + private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server"; + private static final RestTemplate rest = new RestTemplate(); - private static final ParameterizedTypeReference> typeReference = - new ParameterizedTypeReference>() {}; + + private static final ParameterizedTypeReference> typeReference = new ParameterizedTypeReference>() { + }; + + private ClientRegistrations() { + } /** - * Creates a {@link ClientRegistration.Builder} using the provided - * Issuer by making an - * OpenID Provider - * Configuration Request and using the values in the - * OpenID - * Provider Configuration Response to initialize the {@link ClientRegistration.Builder}. + * Creates a {@link ClientRegistration.Builder} using the provided Issuer + * by making an OpenID + * Provider Configuration Request and using the values in the OpenID + * Provider Configuration Response to initialize the + * {@link ClientRegistration.Builder}. * *

        - * For example, if the issuer provided is "https://example.com", then an "OpenID Provider Configuration Request" will - * be made to "https://example.com/.well-known/openid-configuration". The result is expected to be an "OpenID - * Provider Configuration Response". + * For example, if the issuer provided is "https://example.com", then an "OpenID + * Provider Configuration Request" will be made to + * "https://example.com/.well-known/openid-configuration". The result is expected to + * be an "OpenID Provider Configuration Response". *

        * *

        @@ -80,8 +91,10 @@ public final class ClientRegistrations { * .clientSecret("client-secret") * .build(); * - * @param issuer the Issuer - * @return a {@link ClientRegistration.Builder} that was initialized by the OpenID Provider Configuration. + * @param issuer the Issuer + * @return a {@link ClientRegistration.Builder} that was initialized by the OpenID + * Provider Configuration. */ public static ClientRegistration.Builder fromOidcIssuerLocation(String issuer) { Assert.hasText(issuer, "issuer cannot be empty"); @@ -89,29 +102,25 @@ public final class ClientRegistrations { } /** - * Creates a {@link ClientRegistration.Builder} using the provided - * Issuer by querying - * three different discovery endpoints serially, using the values in the first successful response to - * initialize. If an endpoint returns anything other than a 200 or a 4xx, the method will exit without - * attempting subsequent endpoints. + * Creates a {@link ClientRegistration.Builder} using the provided Issuer + * by querying three different discovery endpoints serially, using the values in the + * first successful response to initialize. If an endpoint returns anything other than + * a 200 or a 4xx, the method will exit without attempting subsequent endpoints. * - * The three endpoints are computed as follows, given that the {@code issuer} is composed of a {@code host} - * and a {@code path}: + * The three endpoints are computed as follows, given that the {@code issuer} is + * composed of a {@code host} and a {@code path}: * *

          - *
        1. - * {@code host/.well-known/openid-configuration/path}, as defined in - * RFC 8414's Compatibility Notes. - *
        2. - *
        3. - * {@code issuer/.well-known/openid-configuration}, as defined in - * - * OpenID Provider Configuration. - *
        4. - *
        5. - * {@code host/.well-known/oauth-authorization-server/path}, as defined in - * Authorization Server Metadata Request. - *
        6. + *
        7. {@code host/.well-known/openid-configuration/path}, as defined in + * RFC 8414's Compatibility + * Notes.
        8. + *
        9. {@code issuer/.well-known/openid-configuration}, as defined in + * OpenID Provider Configuration.
        10. + *
        11. {@code host/.well-known/oauth-authorization-server/path}, as defined in + * Authorization Server + * Metadata Request.
        12. *
        * * Note that the second endpoint is the equivalent of calling @@ -126,9 +135,9 @@ public final class ClientRegistrations { * .clientSecret("client-secret") * .build(); * - * * @param issuer - * @return a {@link ClientRegistration.Builder} that was initialized by one of the described endpoints + * @return a {@link ClientRegistration.Builder} that was initialized by one of the + * described endpoints */ public static ClientRegistration.Builder fromIssuerLocation(String issuer) { Assert.hasText(issuer, "issuer cannot be empty"); @@ -137,9 +146,11 @@ public final class ClientRegistrations { } private static Supplier oidc(URI issuer) { + // @formatter:off URI uri = UriComponentsBuilder.fromUri(issuer) - .replacePath(issuer.getPath() + OIDC_METADATA_PATH).build(Collections.emptyMap()); - + .replacePath(issuer.getPath() + OIDC_METADATA_PATH) + .build(Collections.emptyMap()); + // @formatter:on return () -> { RequestEntity request = RequestEntity.get(uri).build(); Map configuration = rest.exchange(request, typeReference).getBody(); @@ -154,14 +165,20 @@ public final class ClientRegistrations { } private static Supplier oidcRfc8414(URI issuer) { + // @formatter:off URI uri = UriComponentsBuilder.fromUri(issuer) - .replacePath(OIDC_METADATA_PATH + issuer.getPath()).build(Collections.emptyMap()); + .replacePath(OIDC_METADATA_PATH + issuer.getPath()) + .build(Collections.emptyMap()); + // @formatter:on return getRfc8414Builder(issuer, uri); } private static Supplier oauth(URI issuer) { + // @formatter:off URI uri = UriComponentsBuilder.fromUri(issuer) - .replacePath(OAUTH_METADATA_PATH + issuer.getPath()).build(Collections.emptyMap()); + .replacePath(OAUTH_METADATA_PATH + issuer.getPath()) + .build(Collections.emptyMap()); + // @formatter:on return getRfc8414Builder(issuer, uri); } @@ -171,12 +188,10 @@ public final class ClientRegistrations { Map configuration = rest.exchange(request, typeReference).getBody(); AuthorizationServerMetadata metadata = parse(configuration, AuthorizationServerMetadata::parse); ClientRegistration.Builder builder = withProviderConfiguration(metadata, issuer.toASCIIString()); - URI jwkSetUri = metadata.getJWKSetURI(); if (jwkSetUri != null) { builder.jwkSetUri(jwkSetUri.toASCIIString()); } - String userinfoEndpoint = (String) configuration.get("userinfo_endpoint"); if (userinfoEndpoint != null) { builder.userInfoUri(userinfoEndpoint); @@ -186,56 +201,56 @@ public final class ClientRegistrations { } @SafeVarargs - private static ClientRegistration.Builder getBuilder(String issuer, Supplier... suppliers) { + private static ClientRegistration.Builder getBuilder(String issuer, + Supplier... suppliers) { String errorMessage = "Unable to resolve Configuration with the provided Issuer of \"" + issuer + "\""; for (Supplier supplier : suppliers) { try { return supplier.get(); - } catch (HttpClientErrorException e) { - if (!e.getStatusCode().is4xxClientError()) { - throw e; + } + catch (HttpClientErrorException ex) { + if (!ex.getStatusCode().is4xxClientError()) { + throw ex; } // else try another endpoint - } catch (IllegalArgumentException | IllegalStateException e) { - throw e; - } catch (RuntimeException e) { - throw new IllegalArgumentException(errorMessage, e); + } + catch (IllegalArgumentException | IllegalStateException ex) { + throw ex; + } + catch (RuntimeException ex) { + throw new IllegalArgumentException(errorMessage, ex); } } throw new IllegalArgumentException(errorMessage); } - private static T parse(Map body, - ThrowingFunction parser) { - + private static T parse(Map body, ThrowingFunction parser) { try { return parser.apply(new JSONObject(body)); - } catch (ParseException e) { - throw new RuntimeException(e); + } + catch (ParseException ex) { + throw new RuntimeException(ex); } } - private interface ThrowingFunction { - T apply(S src) throws E; - } - - private static ClientRegistration.Builder withProviderConfiguration(AuthorizationServerMetadata metadata, String issuer) { + private static ClientRegistration.Builder withProviderConfiguration(AuthorizationServerMetadata metadata, + String issuer) { String metadataIssuer = metadata.getIssuer().getValue(); - if (!issuer.equals(metadataIssuer)) { - throw new IllegalStateException("The Issuer \"" + metadataIssuer + "\" provided in the configuration metadata did " - + "not match the requested issuer \"" + issuer + "\""); - } - + Assert.state(issuer.equals(metadataIssuer), + () -> "The Issuer \"" + metadataIssuer + "\" provided in the configuration metadata did " + + "not match the requested issuer \"" + issuer + "\""); String name = URI.create(issuer).getHost(); - ClientAuthenticationMethod method = getClientAuthenticationMethod(issuer, metadata.getTokenEndpointAuthMethods()); + ClientAuthenticationMethod method = getClientAuthenticationMethod(issuer, + metadata.getTokenEndpointAuthMethods()); List grantTypes = metadata.getGrantTypes(); // If null, the default includes authorization_code if (grantTypes != null && !grantTypes.contains(GrantType.AUTHORIZATION_CODE)) { - throw new IllegalArgumentException("Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + issuer + - "\" returned a configuration of " + grantTypes); + throw new IllegalArgumentException( + "Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + issuer + + "\" returned a configuration of " + grantTypes); } Map configurationMetadata = new LinkedHashMap<>(metadata.toJSONObject()); - + // @formatter:off return ClientRegistration.withRegistrationId(name) .userNameAttributeName(IdTokenClaimNames.SUB) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) @@ -246,11 +261,13 @@ public final class ClientRegistrations { .tokenUri(metadata.getTokenEndpointURI().toASCIIString()) .issuerUri(issuer) .clientName(issuer); + // @formatter:on } private static ClientAuthenticationMethod getClientAuthenticationMethod(String issuer, List metadataAuthMethods) { - if (metadataAuthMethods == null || metadataAuthMethods.contains(com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod.CLIENT_SECRET_BASIC)) { + if (metadataAuthMethods == null || metadataAuthMethods + .contains(com.nimbusds.oauth2.sdk.auth.ClientAuthenticationMethod.CLIENT_SECRET_BASIC)) { // If null, the default includes client_secret_basic return ClientAuthenticationMethod.BASIC; } @@ -261,9 +278,14 @@ public final class ClientRegistrations { return ClientAuthenticationMethod.NONE; } throw new IllegalArgumentException("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and " - + "ClientAuthenticationMethod.NONE are supported. The issuer \"" + issuer + "\" returned a configuration of " + metadataAuthMethods); + + "ClientAuthenticationMethod.NONE are supported. The issuer \"" + issuer + + "\" returned a configuration of " + metadataAuthMethods); } - private ClientRegistrations() {} + private interface ThrowingFunction { + + T apply(S src) throws E; + + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java index 7378d5fe4f..f0092368f3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.registration; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.client.registration; import java.util.Arrays; import java.util.Collections; @@ -24,8 +23,11 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.springframework.util.Assert; + /** - * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory. + * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) + * in-memory. * * @author Joe Grandja * @author Rob Winch @@ -34,12 +36,14 @@ import java.util.concurrent.ConcurrentHashMap; * @see ClientRegistrationRepository * @see ClientRegistration */ -public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository, Iterable { +public final class InMemoryClientRegistrationRepository + implements ClientRegistrationRepository, Iterable { + private final Map registrations; /** - * Constructs an {@code InMemoryClientRegistrationRepository} using the provided parameters. - * + * Constructs an {@code InMemoryClientRegistrationRepository} using the provided + * parameters. * @param registrations the client registration(s) */ public InMemoryClientRegistrationRepository(ClientRegistration... registrations) { @@ -47,8 +51,8 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr } /** - * Constructs an {@code InMemoryClientRegistrationRepository} using the provided parameters. - * + * Constructs an {@code InMemoryClientRegistrationRepository} using the provided + * parameters. * @param registrations the client registration(s) */ public InMemoryClientRegistrationRepository(List registrations) { @@ -63,21 +67,19 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr private static Map toUnmodifiableConcurrentMap(List registrations) { ConcurrentHashMap result = new ConcurrentHashMap<>(); for (ClientRegistration registration : registrations) { - if (result.containsKey(registration.getRegistrationId())) { - throw new IllegalStateException(String.format("Duplicate key %s", - registration.getRegistrationId())); - } + Assert.state(!result.containsKey(registration.getRegistrationId()), + () -> String.format("Duplicate key %s", registration.getRegistrationId())); result.put(registration.getRegistrationId(), registration); } return Collections.unmodifiableMap(result); } /** - * Constructs an {@code InMemoryClientRegistrationRepository} using the provided {@code Map} - * of {@link ClientRegistration#getRegistrationId() registration id} to {@link ClientRegistration}. - * - * @since 5.2 + * Constructs an {@code InMemoryClientRegistrationRepository} using the provided + * {@code Map} of {@link ClientRegistration#getRegistrationId() registration id} to + * {@link ClientRegistration}. * @param registrations the {@code Map} of client registration(s) + * @since 5.2 */ public InMemoryClientRegistrationRepository(Map registrations) { Assert.notNull(registrations, "registrations cannot be null"); @@ -92,11 +94,11 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr /** * Returns an {@code Iterator} of {@link ClientRegistration}. - * * @return an {@code Iterator} */ @Override public Iterator iterator() { return this.registrations.values().iterator(); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java index 98ba597c41..9a49d789a6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.registration; import java.util.Arrays; @@ -27,7 +28,8 @@ import reactor.core.publisher.Mono; import org.springframework.util.Assert; /** - * A Reactive {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory. + * A Reactive {@link ClientRegistrationRepository} that stores + * {@link ClientRegistration}(s) in-memory. * * @author Rob Winch * @author Ebert Toribio @@ -41,8 +43,8 @@ public final class InMemoryReactiveClientRegistrationRepository private final Map clientIdToClientRegistration; /** - * Constructs an {@code InMemoryReactiveClientRegistrationRepository} using the provided parameters. - * + * Constructs an {@code InMemoryReactiveClientRegistrationRepository} using the + * provided parameters. * @param registrations the client registration(s) */ public InMemoryReactiveClientRegistrationRepository(ClientRegistration... registrations) { @@ -55,8 +57,8 @@ public final class InMemoryReactiveClientRegistrationRepository } /** - * Constructs an {@code InMemoryReactiveClientRegistrationRepository} using the provided parameters. - * + * Constructs an {@code InMemoryReactiveClientRegistrationRepository} using the + * provided parameters. * @param registrations the client registration(s) */ public InMemoryReactiveClientRegistrationRepository(List registrations) { @@ -70,7 +72,6 @@ public final class InMemoryReactiveClientRegistrationRepository /** * Returns an {@code Iterator} of {@link ClientRegistration}. - * * @return an {@code Iterator} */ @Override @@ -84,11 +85,11 @@ public final class InMemoryReactiveClientRegistrationRepository for (ClientRegistration registration : registrations) { Assert.notNull(registration, "no registration can be null"); if (result.containsKey(registration.getRegistrationId())) { - throw new IllegalStateException(String.format("Duplicate key %s", - registration.getRegistrationId())); + throw new IllegalStateException(String.format("Duplicate key %s", registration.getRegistrationId())); } result.put(registration.getRegistrationId(), registration); } return Collections.unmodifiableMap(result); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ReactiveClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ReactiveClientRegistrationRepository.java index 731e414141..3d9c82aaf4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ReactiveClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ReactiveClientRegistrationRepository.java @@ -22,11 +22,10 @@ import reactor.core.publisher.Mono; * A reactive repository for OAuth 2.0 / OpenID Connect 1.0 {@link ClientRegistration}(s). * *

        - * NOTE: Client registration information is ultimately stored and owned - * by the associated Authorization Server. - * Therefore, this repository provides the capability to store a sub-set copy - * of the primary client registration information - * externally from the Authorization Server. + * NOTE: Client registration information is ultimately stored and owned by the + * associated Authorization Server. Therefore, this repository provides the capability to + * store a sub-set copy of the primary client registration information externally + * from the Authorization Server. * * @author Rob Winch * @since 5.1 @@ -35,8 +34,8 @@ import reactor.core.publisher.Mono; public interface ReactiveClientRegistrationRepository { /** - * Returns the client registration identified by the provided {@code registrationId}, or {@code null} if not found. - * + * Returns the client registration identified by the provided {@code registrationId}, + * or {@code null} if not found. * @param registrationId the registration identifier * @return the {@link ClientRegistration} if found, otherwise {@code null} */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/package-info.java index 9f133f1165..5e46a72fb4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/package-info.java @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Classes and interfaces that provide support for {@link org.springframework.security.oauth2.client.registration.ClientRegistration}. + * Classes and interfaces that provide support for + * {@link org.springframework.security.oauth2.client.registration.ClientRegistration}. */ package org.springframework.security.oauth2.client.registration; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java index df77620f91..d5b8d0ab85 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; @@ -29,21 +34,20 @@ import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; - /** - * An implementation of an {@link OAuth2UserService} that supports custom {@link OAuth2User} types. + * An implementation of an {@link OAuth2UserService} that supports custom + * {@link OAuth2User} types. *

        - * The custom user type(s) is supplied via the constructor, - * using a {@code Map} of {@link OAuth2User} type(s) keyed by {@code String}, - * which represents the {@link ClientRegistration#getRegistrationId() Registration Id} of the Client. - * - * @deprecated It is recommended to use a delegation-based strategy of an {@link OAuth2UserService} to support custom {@link OAuth2User} types, - * as it provides much greater flexibility compared to this implementation. - * See the reference manual for details on how to implement. + * The custom user type(s) is supplied via the constructor, using a {@code Map} of + * {@link OAuth2User} type(s) keyed by {@code String}, which represents the + * {@link ClientRegistration#getRegistrationId() Registration Id} of the Client. * + * @deprecated It is recommended to use a delegation-based strategy of an + * {@link OAuth2UserService} to support custom {@link OAuth2User} types, as it provides + * much greater flexibility compared to this implementation. See the + * reference + * manual for details on how to implement. * @author Joe Grandja * @since 5.0 * @see OAuth2UserService @@ -53,6 +57,7 @@ import java.util.Map; */ @Deprecated public class CustomUserTypesOAuth2UserService implements OAuth2UserService { + private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; private final Map> customUserTypes; @@ -62,9 +67,10 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService> customUserTypes) { Assert.notEmpty(customUserTypes, "customUserTypes cannot be empty"); @@ -78,33 +84,34 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService customUserType; - if ((customUserType = this.customUserTypes.get(registrationId)) == null) { + Class customUserType = this.customUserTypes.get(registrationId); + if (customUserType == null) { return null; } - RequestEntity request = this.requestEntityConverter.convert(userRequest); + ResponseEntity response = getResponse(customUserType, request); + OAuth2User oauth2User = response.getBody(); + return oauth2User; + } - ResponseEntity response; + private ResponseEntity getResponse(Class customUserType, + RequestEntity request) { try { - response = this.restOperations.exchange(request, customUserType); - } catch (RestClientException ex) { + return this.restOperations.exchange(request, customUserType); + } + catch (RestClientException ex) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, "An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); } - - OAuth2User oauth2User = response.getBody(); - - return oauth2User; } /** - * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} - * to a {@link RequestEntity} representation of the UserInfo Request. - * + * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} to a + * {@link RequestEntity} representation of the UserInfo Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the UserInfo Request * @since 5.1 - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request */ public final void setRequestEntityConverter(Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); @@ -115,16 +122,18 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService - * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: *

          - *
        1. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        2. + *
        3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        4. *
        - * + * @param restOperations the {@link RestOperations} used when requesting the UserInfo + * resource * @since 5.1 - * @param restOperations the {@link RestOperations} used when requesting the UserInfo resource */ public final void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index b3d00a077c..78d797d38f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; import java.util.LinkedHashSet; @@ -43,14 +44,17 @@ import org.springframework.web.client.RestTemplate; import org.springframework.web.client.UnknownContentTypeException; /** - * An implementation of an {@link OAuth2UserService} that supports standard OAuth 2.0 Provider's. + * An implementation of an {@link OAuth2UserService} that supports standard OAuth 2.0 + * Provider's. *

        * For standard OAuth 2.0 Provider's, the attribute name used to access the user's name * from the UserInfo response is required and therefore must be available via - * {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}. + * {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() + * UserInfoEndpoint.getUserNameAttributeName()}. *

        - * NOTE: Attribute names are not standardized between providers and therefore will vary. - * Please consult the provider's API documentation for the set of supported user attribute names. + * NOTE: Attribute names are not standardized between providers and + * therefore will vary. Please consult the provider's API documentation for the set of + * supported user attribute names. * * @author Joe Grandja * @since 5.0 @@ -60,14 +64,15 @@ import org.springframework.web.client.UnknownContentTypeException; * @see DefaultOAuth2User */ public class DefaultOAuth2UserService implements OAuth2UserService { + private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri"; private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute"; private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; - private static final ParameterizedTypeReference> PARAMETERIZED_RESPONSE_TYPE = - new ParameterizedTypeReference>() {}; + private static final ParameterizedTypeReference> PARAMETERIZED_RESPONSE_TYPE = new ParameterizedTypeReference>() { + }; private Converter> requestEntityConverter = new OAuth2UserRequestEntityConverter(); @@ -82,64 +87,25 @@ public class DefaultOAuth2UserService implements OAuth2UserService request = this.requestEntityConverter.convert(userRequest); - - ResponseEntity> response; - try { - response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE); - } catch (OAuth2AuthorizationException ex) { - OAuth2Error oauth2Error = ex.getError(); - StringBuilder errorDetails = new StringBuilder(); - errorDetails.append("Error details: ["); - errorDetails.append("UserInfo Uri: ").append( - userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri()); - errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode()); - if (oauth2Error.getDescription() != null) { - errorDetails.append(", Error Description: ").append(oauth2Error.getDescription()); - } - errorDetails.append("]"); - oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } catch (UnknownContentTypeException ex) { - String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '" + - userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri() + - "': response contains invalid content type '" + ex.getContentType().toString() + "'. " + - "The UserInfo Response should return a JSON object (content type 'application/json') " + - "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + - "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" + - userRequest.getClientRegistration().getRegistrationId() + "' conforms to the UserInfo Endpoint, " + - "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"; - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage, null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } - + ResponseEntity> response = getResponse(userRequest, request); Map userAttributes = response.getBody(); Set authorities = new LinkedHashSet<>(); authorities.add(new OAuth2UserAuthority(userAttributes)); @@ -147,16 +113,54 @@ public class DefaultOAuth2UserService implements OAuth2UserService> getResponse(OAuth2UserRequest userRequest, RequestEntity request) { + try { + return this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE); + } + catch (OAuth2AuthorizationException ex) { + OAuth2Error oauth2Error = ex.getError(); + StringBuilder errorDetails = new StringBuilder(); + errorDetails.append("Error details: ["); + errorDetails.append("UserInfo Uri: ") + .append(userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri()); + errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode()); + if (oauth2Error.getDescription() != null) { + errorDetails.append(", Error Description: ").append(oauth2Error.getDescription()); + } + errorDetails.append("]"); + oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + catch (UnknownContentTypeException ex) { + String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '" + + userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri() + + "': response contains invalid content type '" + ex.getContentType().toString() + "'. " + + "The UserInfo Response should return a JSON object (content type 'application/json') " + + "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + + "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" + + userRequest.getClientRegistration().getRegistrationId() + "' conforms to the UserInfo Endpoint, " + + "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"; + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage, null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + } + /** - * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} - * to a {@link RequestEntity} representation of the UserInfo Request. - * + * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} to a + * {@link RequestEntity} representation of the UserInfo Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the UserInfo Request * @since 5.1 - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request */ public final void setRequestEntityConverter(Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); @@ -167,16 +171,18 @@ public class DefaultOAuth2UserService implements OAuth2UserService - * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: *

          - *
        1. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        2. + *
        3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
        4. *
        - * + * @param restOperations the {@link RestOperations} used when requesting the UserInfo + * resource * @since 5.1 - * @param restOperations the {@link RestOperations} used when requesting the UserInfo resource */ public final void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java index 8e547be7f9..6494135063 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java @@ -16,12 +16,16 @@ package org.springframework.security.oauth2.client.userinfo; - import java.io.IOException; import java.util.HashSet; import java.util.Map; import java.util.Set; +import com.nimbusds.oauth2.sdk.ErrorObject; +import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse; +import net.minidev.json.JSONObject; +import reactor.core.publisher.Mono; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; @@ -42,21 +46,18 @@ import org.springframework.web.reactive.function.UnsupportedMediaTypeException; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; -import com.nimbusds.oauth2.sdk.ErrorObject; -import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse; - -import net.minidev.json.JSONObject; -import reactor.core.publisher.Mono; - /** - * An implementation of an {@link ReactiveOAuth2UserService} that supports standard OAuth 2.0 Provider's. + * An implementation of an {@link ReactiveOAuth2UserService} that supports standard OAuth + * 2.0 Provider's. *

        * For standard OAuth 2.0 Provider's, the attribute name used to access the user's name * from the UserInfo response is required and therefore must be available via - * {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}. + * {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() + * UserInfoEndpoint.getUserNameAttributeName()}. *

        - * NOTE: Attribute names are not standardized between providers and therefore will vary. - * Please consult the provider's API documentation for the set of supported user attribute names. + * NOTE: Attribute names are not standardized between providers and + * therefore will vary. Please consult the provider's API documentation for the set of + * supported user attribute names. * * @author Rob Winch * @since 5.1 @@ -66,71 +67,60 @@ import reactor.core.publisher.Mono; * @see DefaultOAuth2User */ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserService { + private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; + private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri"; + private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute"; + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; + + private static final ParameterizedTypeReference> STRING_STRING_MAP = new ParameterizedTypeReference>() { + }; + private WebClient webClient = WebClient.create(); @Override - public Mono loadUser(OAuth2UserRequest userRequest) - throws OAuth2AuthenticationException { + public Mono loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException { return Mono.defer(() -> { Assert.notNull(userRequest, "userRequest cannot be null"); - - String userInfoUri = userRequest.getClientRegistration().getProviderDetails() - .getUserInfoEndpoint().getUri(); - if (!StringUtils.hasText( - userInfoUri)) { - OAuth2Error oauth2Error = new OAuth2Error( - MISSING_USER_INFO_URI_ERROR_CODE, + String userInfoUri = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() + .getUri(); + if (!StringUtils.hasText(userInfoUri)) { + OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_INFO_URI_ERROR_CODE, "Missing required UserInfo Uri in UserInfoEndpoint for Client Registration: " + userRequest.getClientRegistration().getRegistrationId(), null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() - .getUserNameAttributeName(); + String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails() + .getUserInfoEndpoint().getUserNameAttributeName(); if (!StringUtils.hasText(userNameAttributeName)) { - OAuth2Error oauth2Error = new OAuth2Error( - MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE, + OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE, "Missing required \"user name\" attribute name in UserInfoEndpoint for Client Registration: " + userRequest.getClientRegistration().getRegistrationId(), null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - - ParameterizedTypeReference> typeReference = new ParameterizedTypeReference>() { - }; - AuthenticationMethod authenticationMethod = userRequest.getClientRegistration().getProviderDetails() .getUserInfoEndpoint().getAuthenticationMethod(); - WebClient.RequestHeadersSpec requestHeadersSpec; - if (AuthenticationMethod.FORM.equals(authenticationMethod)) { - requestHeadersSpec = this.webClient.post() - .uri(userInfoUri) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) - .syncBody("access_token=" + userRequest.getAccessToken().getTokenValue()); - } else { - requestHeadersSpec = this.webClient.get() - .uri(userInfoUri) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .headers(headers -> headers.setBearerAuth(userRequest.getAccessToken().getTokenValue())); - } - Mono> userAttributes = requestHeadersSpec - .retrieve() - .onStatus(s -> s != HttpStatus.OK, response -> parse(response).map(userInfoErrorResponse -> { - String description = userInfoErrorResponse.getErrorObject().getDescription(); - OAuth2Error oauth2Error = new OAuth2Error( - INVALID_USER_INFO_RESPONSE_ERROR_CODE, description, - null); - throw new OAuth2AuthenticationException(oauth2Error, - oauth2Error.toString()); - })) - .bodyToMono(typeReference); - - return userAttributes.map(attrs -> { + WebClient.RequestHeadersSpec requestHeadersSpec = getRequestHeaderSpec(userRequest, userInfoUri, + authenticationMethod); + // @formatter:off + Mono> userAttributes = requestHeadersSpec.retrieve() + .onStatus((s) -> s != HttpStatus.OK, (response) -> + parse(response) + .map((userInfoErrorResponse) -> { + String description = userInfoErrorResponse.getErrorObject().getDescription(); + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, description, + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + }) + ) + .bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP); + return userAttributes.map((attrs) -> { GrantedAuthority authority = new OAuth2UserAuthority(attrs); Set authorities = new HashSet<>(); authorities.add(authority); @@ -141,24 +131,53 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi return new DefaultOAuth2User(authorities, attrs, userNameAttributeName); }) - .onErrorMap(IOException.class, e -> new AuthenticationServiceException("Unable to access the userInfoEndpoint " + userInfoUri, e)) - .onErrorMap(UnsupportedMediaTypeException.class, e -> { - String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '" + - userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri() + - "': response contains invalid content type '" + e.getContentType().toString() + "'. " + - "The UserInfo Response should return a JSON object (content type 'application/json') " + - "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + - "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" + - userRequest.getClientRegistration().getRegistrationId() + "' conforms to the UserInfo Endpoint, " + - "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"; - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage, null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), e); + .onErrorMap(IOException.class, + (ex) -> new AuthenticationServiceException("Unable to access the userInfoEndpoint " + userInfoUri, + ex) + ) + .onErrorMap(UnsupportedMediaTypeException.class, (ex) -> { + String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '" + + userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() + .getUri() + + "': response contains invalid content type '" + ex.getContentType().toString() + "'. " + + "The UserInfo Response should return a JSON object (content type 'application/json') " + + "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + + "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" + + userRequest.getClientRegistration().getRegistrationId() + + "' conforms to the UserInfo Endpoint, " + + "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"; + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage, + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); }) - .onErrorMap(t -> !(t instanceof AuthenticationServiceException), t -> { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, "An error occurred reading the UserInfo Success response: " + t.getMessage(), null); + .onErrorMap((t) -> !(t instanceof AuthenticationServiceException), (t) -> { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, + "An error occurred reading the UserInfo Success response: " + t.getMessage(), null); return new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), t); }); }); + // @formatter:on + } + + private WebClient.RequestHeadersSpec getRequestHeaderSpec(OAuth2UserRequest userRequest, String userInfoUri, + AuthenticationMethod authenticationMethod) { + if (AuthenticationMethod.FORM.equals(authenticationMethod)) { + // @formatter:off + return this.webClient.post() + .uri(userInfoUri) + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .bodyValue("access_token=" + userRequest.getAccessToken().getTokenValue()); + // @formatter:on + } + // @formatter:off + return this.webClient.get() + .uri(userInfoUri) + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .headers((headers) -> headers + .setBearerAuth(userRequest.getAccessToken().getTokenValue()) + ); + // @formatter:on } /** @@ -171,19 +190,14 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi } private static Mono parse(ClientResponse httpResponse) { - String wwwAuth = httpResponse.headers().asHttpHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE); - if (!StringUtils.isEmpty(wwwAuth)) { // Bearer token error? return Mono.fromCallable(() -> UserInfoErrorResponse.parse(wwwAuth)); } - - ParameterizedTypeReference> typeReference = - new ParameterizedTypeReference>() {}; // Other error? - return httpResponse - .bodyToMono(typeReference) - .map(body -> new UserInfoErrorResponse(ErrorObject.parse(new JSONObject(body)))); + return httpResponse.bodyToMono(STRING_STRING_MAP) + .map((body) -> new UserInfoErrorResponse(ErrorObject.parse(new JSONObject(body)))); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserService.java index a44058f2c2..32e9ab81de 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserService.java @@ -13,40 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.userinfo; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.user.OAuth2User; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.client.userinfo; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.util.Assert; + /** - * An implementation of an {@link OAuth2UserService} that simply delegates - * to it's internal {@code List} of {@link OAuth2UserService}(s). + * An implementation of an {@link OAuth2UserService} that simply delegates to it's + * internal {@code List} of {@link OAuth2UserService}(s). *

        * Each {@link OAuth2UserService} is given a chance to - * {@link OAuth2UserService#loadUser(OAuth2UserRequest) load} an {@link OAuth2User} - * with the first {@code non-null} {@link OAuth2User} being returned. + * {@link OAuth2UserService#loadUser(OAuth2UserRequest) load} an {@link OAuth2User} with + * the first {@code non-null} {@link OAuth2User} being returned. * + * @param The type of OAuth 2.0 User Request + * @param The type of OAuth 2.0 User * @author Joe Grandja * @since 5.0 * @see OAuth2UserService * @see OAuth2UserRequest * @see OAuth2User - * - * @param The type of OAuth 2.0 User Request - * @param The type of OAuth 2.0 User */ -public class DelegatingOAuth2UserService implements OAuth2UserService { +public class DelegatingOAuth2UserService + implements OAuth2UserService { + private final List> userServices; /** * Constructs a {@code DelegatingOAuth2UserService} using the provided parameters. - * * @param userServices a {@code List} of {@link OAuth2UserService}(s) */ public DelegatingOAuth2UserService(List> userServices) { @@ -57,10 +58,13 @@ public class DelegatingOAuth2UserService userService.loadUser(userRequest)) - .filter(Objects::nonNull) - .findFirst() - .orElse(null); + .map((userService) -> userService.loadUser(userRequest)) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java index 34dc78f7b5..e0582e9732 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java @@ -13,20 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; - /** - * Represents a request the {@link OAuth2UserService} uses - * when initiating a request to the UserInfo Endpoint. + * Represents a request the {@link OAuth2UserService} uses when initiating a request to + * the UserInfo Endpoint. * * @author Joe Grandja * @since 5.0 @@ -35,13 +36,15 @@ import java.util.Map; * @see OAuth2UserService */ public class OAuth2UserRequest { + private final ClientRegistration clientRegistration; + private final OAuth2AccessToken accessToken; + private final Map additionalParameters; /** * Constructs an {@code OAuth2UserRequest} using the provided parameters. - * * @param clientRegistration the client registration * @param accessToken the access token */ @@ -51,26 +54,23 @@ public class OAuth2UserRequest { /** * Constructs an {@code OAuth2UserRequest} using the provided parameters. - * - * @since 5.1 * @param clientRegistration the client registration * @param accessToken the access token * @param additionalParameters the additional parameters, may be empty + * @since 5.1 */ public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, - Map additionalParameters) { + Map additionalParameters) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(accessToken, "accessToken cannot be null"); this.clientRegistration = clientRegistration; this.accessToken = accessToken; - this.additionalParameters = Collections.unmodifiableMap( - CollectionUtils.isEmpty(additionalParameters) ? - Collections.emptyMap() : new LinkedHashMap<>(additionalParameters)); + this.additionalParameters = Collections.unmodifiableMap(CollectionUtils.isEmpty(additionalParameters) + ? Collections.emptyMap() : new LinkedHashMap<>(additionalParameters)); } /** * Returns the {@link ClientRegistration client registration}. - * * @return the {@link ClientRegistration} */ public ClientRegistration getClientRegistration() { @@ -79,7 +79,6 @@ public class OAuth2UserRequest { /** * Returns the {@link OAuth2AccessToken access token}. - * * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { @@ -88,11 +87,11 @@ public class OAuth2UserRequest { /** * Returns the additional parameters that may be used in the request. - * - * @since 5.1 * @return a {@code Map} of the additional parameters, may be empty. + * @since 5.1 */ public Map getAdditionalParameters() { return this.additionalParameters; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java index f373a21012..9a7a3c8dd8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; +import java.net.URI; +import java.util.Collections; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -27,14 +31,9 @@ import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; -import java.util.Collections; - -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; - /** - * A {@link Converter} that converts the provided {@link OAuth2UserRequest} - * to a {@link RequestEntity} representation of a request for the UserInfo Endpoint. + * A {@link Converter} that converts the provided {@link OAuth2UserRequest} to a + * {@link RequestEntity} representation of a request for the UserInfo Endpoint. * * @author Joe Grandja * @since 5.1 @@ -43,27 +42,23 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @see RequestEntity */ public class OAuth2UserRequestEntityConverter implements Converter> { - private static final MediaType DEFAULT_CONTENT_TYPE = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + + private static final MediaType DEFAULT_CONTENT_TYPE = MediaType + .valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); /** * Returns the {@link RequestEntity} used for the UserInfo Request. - * * @param userRequest the user request * @return the {@link RequestEntity} used for the UserInfo Request */ @Override public RequestEntity convert(OAuth2UserRequest userRequest) { ClientRegistration clientRegistration = userRequest.getClientRegistration(); - - HttpMethod httpMethod = HttpMethod.GET; - if (AuthenticationMethod.FORM.equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) { - httpMethod = HttpMethod.POST; - } + HttpMethod httpMethod = getHttpMethod(clientRegistration); HttpHeaders headers = new HttpHeaders(); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); - URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()) - .build() - .toUri(); + URI uri = UriComponentsBuilder + .fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).build().toUri(); RequestEntity request; if (HttpMethod.POST.equals(httpMethod)) { @@ -71,11 +66,21 @@ public class OAuth2UserRequestEntityConverter implements Converter formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.ACCESS_TOKEN, userRequest.getAccessToken().getTokenValue()); request = new RequestEntity<>(formParameters, headers, httpMethod, uri); - } else { + } + else { headers.setBearerAuth(userRequest.getAccessToken().getTokenValue()); request = new RequestEntity<>(headers, httpMethod, uri); } return request; } + + private HttpMethod getHttpMethod(ClientRegistration clientRegistration) { + if (AuthenticationMethod.FORM + .equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) { + return HttpMethod.POST; + } + return HttpMethod.GET; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserService.java index 2aae74a0d9..4f0efc2548 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; import org.springframework.security.core.AuthenticatedPrincipal; @@ -20,18 +21,17 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; /** - * Implementations of this interface are responsible for obtaining the user attributes - * of the End-User (Resource Owner) from the UserInfo Endpoint - * using the {@link OAuth2UserRequest#getAccessToken() Access Token} - * granted to the {@link OAuth2UserRequest#getClientRegistration() Client} - * and returning an {@link AuthenticatedPrincipal} in the form of an {@link OAuth2User}. + * Implementations of this interface are responsible for obtaining the user attributes of + * the End-User (Resource Owner) from the UserInfo Endpoint using the + * {@link OAuth2UserRequest#getAccessToken() Access Token} granted to the + * {@link OAuth2UserRequest#getClientRegistration() Client} and returning an + * {@link AuthenticatedPrincipal} in the form of an {@link OAuth2User}. * * @author Joe Grandja * @since 5.0 * @see OAuth2UserRequest * @see OAuth2User * @see AuthenticatedPrincipal - * * @param The type of OAuth 2.0 User Request * @param The type of OAuth 2.0 User */ @@ -39,11 +39,12 @@ import org.springframework.security.oauth2.core.user.OAuth2User; public interface OAuth2UserService { /** - * Returns an {@link OAuth2User} after obtaining the user attributes of the End-User from the UserInfo Endpoint. - * + * Returns an {@link OAuth2User} after obtaining the user attributes of the End-User + * from the UserInfo Endpoint. * @param userRequest the user request * @return an {@link OAuth2User} - * @throws OAuth2AuthenticationException if an error occurs while attempting to obtain the user attributes from the UserInfo Endpoint + * @throws OAuth2AuthenticationException if an error occurs while attempting to obtain + * the user attributes from the UserInfo Endpoint */ U loadUser(R userRequest) throws OAuth2AuthenticationException; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/ReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/ReactiveOAuth2UserService.java index f166c27326..61e2a99bed 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/ReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/ReactiveOAuth2UserService.java @@ -13,26 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; +import reactor.core.publisher.Mono; + import org.springframework.security.core.AuthenticatedPrincipal; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; -import reactor.core.publisher.Mono; /** - * Implementations of this interface are responsible for obtaining the user attributes - * of the End-User (Resource Owner) from the UserInfo Endpoint - * using the {@link OAuth2UserRequest#getAccessToken() Access Token} - * granted to the {@link OAuth2UserRequest#getClientRegistration() Client} - * and returning an {@link AuthenticatedPrincipal} in the form of an {@link OAuth2User}. + * Implementations of this interface are responsible for obtaining the user attributes of + * the End-User (Resource Owner) from the UserInfo Endpoint using the + * {@link OAuth2UserRequest#getAccessToken() Access Token} granted to the + * {@link OAuth2UserRequest#getClientRegistration() Client} and returning an + * {@link AuthenticatedPrincipal} in the form of an {@link OAuth2User}. * * @author Rob Winch * @since 5.1 * @see OAuth2UserRequest * @see OAuth2User * @see AuthenticatedPrincipal - * * @param The type of OAuth 2.0 User Request * @param The type of OAuth 2.0 User */ @@ -40,11 +41,12 @@ import reactor.core.publisher.Mono; public interface ReactiveOAuth2UserService { /** - * Returns an {@link OAuth2User} after obtaining the user attributes of the End-User from the UserInfo Endpoint. - * + * Returns an {@link OAuth2User} after obtaining the user attributes of the End-User + * from the UserInfo Endpoint. * @param userRequest the user request * @return an {@link OAuth2User} - * @throws OAuth2AuthenticationException if an error occurs while attempting to obtain the user attributes from the UserInfo Endpoint + * @throws OAuth2AuthenticationException if an error occurs while attempting to obtain + * the user attributes from the UserInfo Endpoint */ Mono loadUser(R userRequest) throws OAuth2AuthenticationException; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/package-info.java index 2d698369b9..e016d4b9a9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/package-info.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Classes and interfaces providing support to the client for initiating requests - * to the OAuth 2.0 Authorization Server's UserInfo Endpoint. + * Classes and interfaces providing support to the client for initiating requests to the + * OAuth 2.0 Authorization Server's UserInfo Endpoint. */ package org.springframework.security.oauth2.client.userinfo; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java index fef8ffac71..246f729aa9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.core.Authentication; @@ -22,16 +26,13 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; import org.springframework.util.Assert; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** - * An implementation of an {@link OAuth2AuthorizedClientRepository} that - * delegates to the provided {@link OAuth2AuthorizedClientService} if the current - * {@code Principal} is authenticated, otherwise, - * to the default (or provided) {@link OAuth2AuthorizedClientRepository} - * if the current request is unauthenticated (or anonymous). - * The default {@code OAuth2AuthorizedClientRepository} is {@link HttpSessionOAuth2AuthorizedClientRepository}. + * An implementation of an {@link OAuth2AuthorizedClientRepository} that delegates to the + * provided {@link OAuth2AuthorizedClientService} if the current {@code Principal} is + * authenticated, otherwise, to the default (or provided) + * {@link OAuth2AuthorizedClientRepository} if the current request is unauthenticated (or + * anonymous). The default {@code OAuth2AuthorizedClientRepository} is + * {@link HttpSessionOAuth2AuthorizedClientRepository}. * * @author Joe Grandja * @since 5.1 @@ -41,64 +42,73 @@ import javax.servlet.http.HttpServletResponse; * @see HttpSessionOAuth2AuthorizedClientRepository */ public final class AuthenticatedPrincipalOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { + private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); + private final OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository = new HttpSessionOAuth2AuthorizedClientRepository(); /** - * Constructs a {@code AuthenticatedPrincipalOAuth2AuthorizedClientRepository} using the provided parameters. - * + * Constructs a {@code AuthenticatedPrincipalOAuth2AuthorizedClientRepository} using + * the provided parameters. * @param authorizedClientService the authorized client service */ - public AuthenticatedPrincipalOAuth2AuthorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) { + public AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + OAuth2AuthorizedClientService authorizedClientService) { Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); this.authorizedClientService = authorizedClientService; } /** - * Sets the {@link OAuth2AuthorizedClientRepository} used for requests that are unauthenticated (or anonymous). - * The default is {@link HttpSessionOAuth2AuthorizedClientRepository}. - * - * @param anonymousAuthorizedClientRepository the repository used for requests that are unauthenticated (or anonymous) + * Sets the {@link OAuth2AuthorizedClientRepository} used for requests that are + * unauthenticated (or anonymous). The default is + * {@link HttpSessionOAuth2AuthorizedClientRepository}. + * @param anonymousAuthorizedClientRepository the repository used for requests that + * are unauthenticated (or anonymous) */ - public void setAnonymousAuthorizedClientRepository(OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) { + public void setAnonymousAuthorizedClientRepository( + OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) { Assert.notNull(anonymousAuthorizedClientRepository, "anonymousAuthorizedClientRepository cannot be null"); this.anonymousAuthorizedClientRepository = anonymousAuthorizedClientRepository; } @Override - public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, - HttpServletRequest request) { + public T loadAuthorizedClient(String clientRegistrationId, + Authentication principal, HttpServletRequest request) { if (this.isPrincipalAuthenticated(principal)) { return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()); - } else { - return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, request); } + return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, request); } @Override public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + HttpServletRequest request, HttpServletResponse response) { if (this.isPrincipalAuthenticated(principal)) { this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); - } else { - this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response); + } + else { + this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, + response); } } @Override public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + HttpServletRequest request, HttpServletResponse response) { if (this.isPrincipalAuthenticated(principal)) { this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName()); - } else { - this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + else { + this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, + response); } } private boolean isPrincipalAuthenticated(Authentication authentication) { - return authentication != null && - !this.authenticationTrustResolver.isAnonymous(authentication) && - authentication.isAuthenticated(); + return authentication != null && !this.authenticationTrustResolver.isAnonymous(authentication) + && authentication.isAuthenticated(); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationRequestRepository.java index c2114a0aa8..4ccd7e5f46 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationRequestRepository.java @@ -13,73 +13,73 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.web; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +package org.springframework.security.oauth2.client.web; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; + /** - * Implementations of this interface are responsible for the persistence - * of {@link OAuth2AuthorizationRequest} between requests. + * Implementations of this interface are responsible for the persistence of + * {@link OAuth2AuthorizationRequest} between requests. * *

        - * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the Authorization Request - * before it initiates the authorization code grant flow. - * As well, used by the {@link OAuth2LoginAuthenticationFilter} for resolving - * the associated Authorization Request when handling the callback of the Authorization Response. + * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the + * Authorization Request before it initiates the authorization code grant flow. As well, + * used by the {@link OAuth2LoginAuthenticationFilter} for resolving the associated + * Authorization Request when handling the callback of the Authorization Response. * + * @param The type of OAuth 2.0 Authorization Request * @author Joe Grandja * @since 5.0 * @see OAuth2AuthorizationRequest * @see HttpSessionOAuth2AuthorizationRequestRepository - * - * @param The type of OAuth 2.0 Authorization Request */ public interface AuthorizationRequestRepository { /** - * Returns the {@link OAuth2AuthorizationRequest} associated to the provided {@code HttpServletRequest} - * or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizationRequest} associated to the provided + * {@code HttpServletRequest} or {@code null} if not available. * @param request the {@code HttpServletRequest} * @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available */ T loadAuthorizationRequest(HttpServletRequest request); /** - * Persists the {@link OAuth2AuthorizationRequest} associating it to - * the provided {@code HttpServletRequest} and/or {@code HttpServletResponse}. - * + * Persists the {@link OAuth2AuthorizationRequest} associating it to the provided + * {@code HttpServletRequest} and/or {@code HttpServletResponse}. * @param authorizationRequest the {@link OAuth2AuthorizationRequest} * @param request the {@code HttpServletRequest} * @param response the {@code HttpServletResponse} */ - void saveAuthorizationRequest(T authorizationRequest, HttpServletRequest request, - HttpServletResponse response); + void saveAuthorizationRequest(T authorizationRequest, HttpServletRequest request, HttpServletResponse response); /** * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the * provided {@code HttpServletRequest} or if not available returns {@code null}. - * - * @deprecated Use {@link #removeAuthorizationRequest(HttpServletRequest, HttpServletResponse)} instead + * @deprecated Use + * {@link #removeAuthorizationRequest(HttpServletRequest, HttpServletResponse)} + * instead * @param request the {@code HttpServletRequest} - * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not + * available */ @Deprecated T removeAuthorizationRequest(HttpServletRequest request); /** * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the - * provided {@code HttpServletRequest} and {@code HttpServletResponse} or if not available returns {@code null}. - * - * @since 5.1 + * provided {@code HttpServletRequest} and {@code HttpServletResponse} or if not + * available returns {@code null}. * @param request the {@code HttpServletRequest} * @param response the {@code HttpServletResponse} * @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @since 5.1 */ default T removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) { return removeAuthorizationRequest(request); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java index d3c495dd3b..7e606eae6b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java @@ -13,8 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import javax.servlet.http.HttpServletRequest; + import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -34,23 +45,16 @@ import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.http.HttpServletRequest; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Base64; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Consumer; - /** * An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to - * resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest} - * using the default request {@code URI} pattern {@code /oauth2/authorization/{registrationId}}. + * resolve an {@link OAuth2AuthorizationRequest} from the provided + * {@code HttpServletRequest} using the default request {@code URI} pattern + * {@code /oauth2/authorization/{registrationId}}. * *

        - * NOTE: The default base {@code URI} {@code /oauth2/authorization} may be overridden - * via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}. + * NOTE: The default base {@code URI} {@code /oauth2/authorization} may be + * overridden via it's constructor + * {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}. * * @author Joe Grandja * @author Rob Winch @@ -61,22 +65,32 @@ import java.util.function.Consumer; * @see OAuth2AuthorizationRequestRedirectFilter */ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver { + private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId"; + private static final char PATH_DELIMITER = '/'; + private final ClientRegistrationRepository clientRegistrationRepository; + private final AntPathRequestMatcher authorizationRequestMatcher; + private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); - private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); - private Consumer authorizationRequestCustomizer = customizer -> {}; + + private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 96); + + private Consumer authorizationRequestCustomizer = (customizer) -> { + }; /** - * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters. - * + * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations - * @param authorizationRequestBaseUri the base {@code URI} used for resolving authorization requests + * @param authorizationRequestBaseUri the base {@code URI} used for resolving + * authorization requests */ public DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository clientRegistrationRepository, - String authorizationRequestBaseUri) { + String authorizationRequestBaseUri) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); this.clientRegistrationRepository = clientRegistrationRepository; @@ -104,13 +118,14 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au } /** - * Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder} - * allowing for further customizations. - * + * Sets the {@code Consumer} to be provided the + * {@link OAuth2AuthorizationRequest.Builder} allowing for further customizations. + * @param authorizationRequestCustomizer the {@code Consumer} to be provided the + * {@link OAuth2AuthorizationRequest.Builder} * @since 5.3 - * @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder} */ - public void setAuthorizationRequestCustomizer(Consumer authorizationRequestCustomizer) { + public void setAuthorizationRequestCustomizer( + Consumer authorizationRequestCustomizer) { Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null"); this.authorizationRequestCustomizer = authorizationRequestCustomizer; } @@ -123,67 +138,73 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au return action; } - private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId, String redirectUriAction) { + private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId, + String redirectUriAction) { if (registrationId == null) { return null; } - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); if (clientRegistration == null) { throw new IllegalArgumentException("Invalid Client Registration with Id: " + registrationId); } - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); - - OAuth2AuthorizationRequest.Builder builder; - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - builder = OAuth2AuthorizationRequest.authorizationCode(); - Map additionalParameters = new HashMap<>(); - if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) && - clientRegistration.getScopes().contains(OidcScopes.OPENID)) { - // Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope - // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. - addNonceParameters(attributes, additionalParameters); - } - if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) { - addPkceParameters(attributes, additionalParameters); - } - builder.additionalParameters(additionalParameters); - } else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { - builder = OAuth2AuthorizationRequest.implicit(); - } else { - throw new IllegalArgumentException("Invalid Authorization Grant Type (" + - clientRegistration.getAuthorizationGrantType().getValue() + - ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); - } + OAuth2AuthorizationRequest.Builder builder = getBuilder(clientRegistration, attributes); String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction); - builder - .clientId(clientRegistration.getClientId()) + // @formatter:off + builder.clientId(clientRegistration.getClientId()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .redirectUri(redirectUriStr) .scopes(clientRegistration.getScopes()) .state(this.stateGenerator.generateKey()) .attributes(attributes); + // @formatter:on this.authorizationRequestCustomizer.accept(builder); return builder.build(); } + private OAuth2AuthorizationRequest.Builder getBuilder(ClientRegistration clientRegistration, + Map attributes) { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + OAuth2AuthorizationRequest.Builder builder = OAuth2AuthorizationRequest.authorizationCode(); + Map additionalParameters = new HashMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) + && clientRegistration.getScopes().contains(OidcScopes.OPENID)) { + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope + // value. + addNonceParameters(attributes, additionalParameters); + } + if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) { + addPkceParameters(attributes, additionalParameters); + } + builder.additionalParameters(additionalParameters); + return builder; + } + if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { + return OAuth2AuthorizationRequest.implicit(); + } + throw new IllegalArgumentException( + "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() + + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); + } + private String resolveRegistrationId(HttpServletRequest request) { if (this.authorizationRequestMatcher.matches(request)) { - return this.authorizationRequestMatcher - .matcher(request).getVariables().get(REGISTRATION_ID_URI_VARIABLE_NAME); + return this.authorizationRequestMatcher.matcher(request).getVariables() + .get(REGISTRATION_ID_URI_VARIABLE_NAME); } return null; } /** - * Expands the {@link ClientRegistration#getRedirectUri()} with following provided variables:
        + * Expands the {@link ClientRegistration#getRedirectUri()} with following provided + * variables:
        * - baseUrl (e.g. https://localhost/app)
        * - baseScheme (e.g. https)
        * - baseHost (e.g. localhost)
        @@ -194,50 +215,52 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au *

        * Null variables are provided as empty strings. *

        - * Default redirectUri is: {@code org.springframework.security.config.oauth2.client.CommonOAuth2Provider#DEFAULT_REDIRECT_URL} - * + * Default redirectUri is: + * {@code org.springframework.security.config.oauth2.client.CommonOAuth2Provider#DEFAULT_REDIRECT_URL} * @return expanded URI */ - private static String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) { + private static String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, + String action) { Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - + // @formatter:off UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) .replacePath(request.getContextPath()) .replaceQuery(null) .fragment(null) .build(); + // @formatter:on String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", scheme == null ? "" : scheme); + uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); String host = uriComponents.getHost(); - uriVariables.put("baseHost", host == null ? "" : host); + uriVariables.put("baseHost", (host != null) ? host : ""); // following logic is based on HierarchicalUriComponents#toUriString() int port = uriComponents.getPort(); - uriVariables.put("basePort", port == -1 ? "" : ":" + port); + uriVariables.put("basePort", (port == -1) ? "" : ":" + port); String path = uriComponents.getPath(); if (StringUtils.hasLength(path)) { if (path.charAt(0) != PATH_DELIMITER) { path = PATH_DELIMITER + path; } } - uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("basePath", (path != null) ? path : ""); uriVariables.put("baseUrl", uriComponents.toUriString()); - - uriVariables.put("action", action == null ? "" : action); - - return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()) - .buildAndExpand(uriVariables) + uriVariables.put("action", (action != null) ? action : ""); + return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()).buildAndExpand(uriVariables) .toUriString(); } /** * Creates nonce and its hash for use in OpenID Connect 1.0 Authentication Requests. - * - * @param attributes where the {@link OidcParameterNames#NONCE} is stored for the authentication request - * @param additionalParameters where the {@link OidcParameterNames#NONCE} hash is added for the authentication request + * @param attributes where the {@link OidcParameterNames#NONCE} is stored for the + * authentication request + * @param additionalParameters where the {@link OidcParameterNames#NONCE} hash is + * added for the authentication request * * @since 5.2 - * @see 3.1.2.1. Authentication Request + * @see 3.1.2.1. + * Authentication Request */ private void addNonceParameters(Map attributes, Map additionalParameters) { try { @@ -245,20 +268,27 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au String nonceHash = createHash(nonce); attributes.put(OidcParameterNames.NONCE, nonce); additionalParameters.put(OidcParameterNames.NONCE, nonceHash); - } catch (NoSuchAlgorithmException e) { } + } + catch (NoSuchAlgorithmException ex) { + } } /** - * Creates and adds additional PKCE parameters for use in the OAuth 2.0 Authorization and Access Token Requests - * - * @param attributes where {@link PkceParameterNames#CODE_VERIFIER} is stored for the token request - * @param additionalParameters where {@link PkceParameterNames#CODE_CHALLENGE} and, usually, - * {@link PkceParameterNames#CODE_CHALLENGE_METHOD} are added to be used in the authorization request. + * Creates and adds additional PKCE parameters for use in the OAuth 2.0 Authorization + * and Access Token Requests + * @param attributes where {@link PkceParameterNames#CODE_VERIFIER} is stored for the + * token request + * @param additionalParameters where {@link PkceParameterNames#CODE_CHALLENGE} and, + * usually, {@link PkceParameterNames#CODE_CHALLENGE_METHOD} are added to be used in + * the authorization request. * * @since 5.2 - * @see 1.1. Protocol Flow - * @see 4.1. Client Creates a Code Verifier - * @see 4.2. Client Creates the Code Challenge + * @see 1.1. + * Protocol Flow + * @see 4.1. + * Client Creates a Code Verifier + * @see 4.2. + * Client Creates the Code Challenge */ private void addPkceParameters(Map attributes, Map additionalParameters) { String codeVerifier = this.secureKeyGenerator.generateKey(); @@ -267,7 +297,8 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au String codeChallenge = createHash(codeVerifier); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, codeChallenge); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - } catch (NoSuchAlgorithmException e) { + } + catch (NoSuchAlgorithmException ex) { additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, codeVerifier); } } @@ -277,4 +308,5 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII)); return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 708db52020..02e41ab33f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager; @@ -39,42 +48,35 @@ import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - /** - * The default implementation of an {@link OAuth2AuthorizedClientManager} - * for use within the context of a {@code HttpServletRequest}. + * The default implementation of an {@link OAuth2AuthorizedClientManager} for use within + * the context of a {@code HttpServletRequest}. * *

        - * (When operating outside of the context of a {@code HttpServletRequest}, - * use {@link AuthorizedClientServiceOAuth2AuthorizedClientManager} instead.) + * (When operating outside of the context of a {@code HttpServletRequest}, use + * {@link AuthorizedClientServiceOAuth2AuthorizedClientManager} instead.) * *

        Authorized Client Persistence

        * *

        - * This manager utilizes an {@link OAuth2AuthorizedClientRepository} - * to persist {@link OAuth2AuthorizedClient}s. + * This manager utilizes an {@link OAuth2AuthorizedClientRepository} to persist + * {@link OAuth2AuthorizedClient}s. * *

        * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} - * will be saved in the {@link OAuth2AuthorizedClientRepository}. - * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler} - * via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}. + * will be saved in the {@link OAuth2AuthorizedClientRepository}. This functionality can + * be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler} via + * {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}. * *

        * By default, when an authorization attempt fails due to an - * {@value OAuth2ErrorCodes#INVALID_GRANT} error, - * the previously saved {@link OAuth2AuthorizedClient} - * will be removed from the {@link OAuth2AuthorizedClientRepository}. - * (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur - * when a refresh token that is no longer valid is used to retrieve a new access token.) - * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler} - * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. + * {@value OAuth2ErrorCodes#INVALID_GRANT} error, the previously saved + * {@link OAuth2AuthorizedClient} will be removed from the + * {@link OAuth2AuthorizedClientRepository}. (The {@value OAuth2ErrorCodes#INVALID_GRANT} + * error can occur when a refresh token that is no longer valid is used to retrieve a new + * access token.) This functionality can be changed by configuring a custom + * {@link OAuth2AuthorizationFailureHandler} via + * {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. * * @author Joe Grandja * @since 5.2 @@ -84,106 +86,118 @@ import java.util.function.Function; * @see OAuth2AuthorizationFailureHandler */ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { - private static final OAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build(); + + // @formatter:off + private static final OAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .password() + .build(); + // @formatter:on + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private Function> contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; /** - * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters. - * + * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ public DefaultOAuth2AuthorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { + OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; this.authorizedClientProvider = DEFAULT_AUTHORIZED_CLIENT_PROVIDER; this.contextAttributesMapper = new DefaultContextAttributesMapper(); - this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> - authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, + this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> authorizedClientRepository + .saveAuthorizedClient(authorizedClient, principal, (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), (HttpServletResponse) attributes.get(HttpServletResponse.class.getName())); this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, - (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), - (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()))); + (clientRegistrationId, principal, attributes) -> authorizedClientRepository.removeAuthorizedClient( + clientRegistrationId, principal, + (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), + (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()))); } @Nullable @Override public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); - HttpServletRequest servletRequest = getHttpServletRequestOrDefault(authorizeRequest.getAttributes()); Assert.notNull(servletRequest, "servletRequest cannot be null"); HttpServletResponse servletResponse = getHttpServletResponseOrDefault(authorizeRequest.getAttributes()); Assert.notNull(servletResponse, "servletResponse cannot be null"); - OAuth2AuthorizationContext.Builder contextBuilder; if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); - } else { - authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, principal, servletRequest); + } + else { + authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, + servletRequest); if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); - } else { - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + } + else { + ClientRegistration clientRegistration = this.clientRegistrationRepository + .findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); } } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .principal(principal) - .attributes(attributes -> { + // @formatter:off + OAuth2AuthorizationContext authorizationContext = contextBuilder.principal(principal) + .attributes((attributes) -> { Map contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); if (!CollectionUtils.isEmpty(contextAttributes)) { attributes.putAll(contextAttributes); } }) .build(); - + // @formatter:on try { authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - } catch (OAuth2AuthorizationException ex) { - this.authorizationFailureHandler.onAuthorizationFailure( - ex, principal, createAttributes(servletRequest, servletResponse)); + } + catch (OAuth2AuthorizationException ex) { + this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, + createAttributes(servletRequest, servletResponse)); throw ex; } - if (authorizedClient != null) { - this.authorizationSuccessHandler.onAuthorizationSuccess( - authorizedClient, principal, createAttributes(servletRequest, servletResponse)); - } else { - // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. - // For these cases, return the provided `authorizationContext.authorizedClient`. + this.authorizationSuccessHandler.onAuthorizationSuccess(authorizedClient, principal, + createAttributes(servletRequest, servletResponse)); + } + else { + // In the case of re-authorization, the returned `authorizedClient` may be + // null if re-authorization is not supported. + // For these cases, return the provided + // `authorizationContext.authorizedClient`. if (authorizationContext.getAuthorizedClient() != null) { return authorizationContext.getAuthorizedClient(); } } - return authorizedClient; } - private static Map createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + private static Map createAttributes(HttpServletRequest servletRequest, + HttpServletResponse servletResponse) { Map attributes = new HashMap<>(); attributes.put(HttpServletRequest.class.getName(), servletRequest); attributes.put(HttpServletResponse.class.getName(), servletResponse); @@ -206,16 +220,17 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori if (servletResponse == null) { RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); if (requestAttributes instanceof ServletRequestAttributes) { - servletResponse = ((ServletRequestAttributes) requestAttributes).getResponse(); + servletResponse = ((ServletRequestAttributes) requestAttributes).getResponse(); } } return servletResponse; } /** - * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. - * - * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or + * re-authorizing) an OAuth 2.0 Client. + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for + * authorizing (or re-authorizing) an OAuth 2.0 Client */ public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); @@ -223,24 +238,28 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori } /** - * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes - * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. - * - * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes - * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + * Sets the {@code Function} used for mapping attribute(s) from the + * {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes to be associated to + * the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * @param contextAttributesMapper the {@code Function} used for supplying the + * {@code Map} of attributes to the {@link OAuth2AuthorizationContext#getAttributes() + * authorization context} */ - public void setContextAttributesMapper(Function> contextAttributesMapper) { + public void setContextAttributesMapper( + Function> contextAttributesMapper) { Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); this.contextAttributesMapper = contextAttributesMapper; } /** - * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations. + * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful + * authorizations. * *

        - * The default saves {@link OAuth2AuthorizedClient}s in the {@link OAuth2AuthorizedClientRepository}. - * - * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations + * The default saves {@link OAuth2AuthorizedClient}s in the + * {@link OAuth2AuthorizedClientRepository}. + * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} + * that handles successful authorizations * @since 5.3 */ public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { @@ -249,14 +268,16 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori } /** - * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures. + * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization + * failures. * *

        - * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default. - * - * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures - * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler + * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by + * default. + * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} + * that handles authorization failures * @since 5.3 + * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler */ public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); @@ -264,9 +285,11 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori } /** - * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + * The default implementation of the {@link #setContextAttributesMapper(Function) + * contextAttributesMapper}. */ - public static class DefaultContextAttributesMapper implements Function> { + public static class DefaultContextAttributesMapper + implements Function> { @Override public Map apply(OAuth2AuthorizeRequest authorizeRequest) { @@ -280,5 +303,7 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori } return contextAttributes; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index fbeef271d4..82980c6544 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -13,8 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; @@ -34,41 +42,46 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; /** - * The default implementation of a {@link ReactiveOAuth2AuthorizedClientManager} - * for use within the context of a {@link ServerWebExchange}. + * The default implementation of a {@link ReactiveOAuth2AuthorizedClientManager} for use + * within the context of a {@link ServerWebExchange}. * - *

        (When operating outside of the context of a {@link ServerWebExchange}, - * use {@link org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} instead.)

        + *

        + * (When operating outside of the context of a {@link ServerWebExchange}, use + * {@link org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager + * AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} instead.) + *

        * - *

        This is a reactive equivalent of {@link DefaultOAuth2AuthorizedClientManager}.

        + *

        + * This is a reactive equivalent of {@link DefaultOAuth2AuthorizedClientManager}. + *

        * *

        Authorized Client Persistence

        * - *

        This client manager utilizes a {@link ServerOAuth2AuthorizedClientRepository} - * to persist {@link OAuth2AuthorizedClient}s.

        + *

        + * This client manager utilizes a {@link ServerOAuth2AuthorizedClientRepository} to + * persist {@link OAuth2AuthorizedClient}s. + *

        * - *

        By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} - * will be saved in the authorized client repository. - * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler} - * via {@link #setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler)}.

        + *

        + * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} + * will be saved in the authorized client repository. This functionality can be changed by + * configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler} via + * {@link #setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler)}. + *

        * - *

        By default, when an authorization attempt fails due to an + *

        + * By default, when an authorization attempt fails due to an * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error, - * the previously saved {@link OAuth2AuthorizedClient} - * will be removed from the authorized client repository. - * (The {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} - * error generally occurs when a refresh token that is no longer valid - * is used to retrieve a new access token.) - * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationFailureHandler} - * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

        + * the previously saved {@link OAuth2AuthorizedClient} will be removed from the authorized + * client repository. (The + * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error + * generally occurs when a refresh token that is no longer valid is used to retrieve a new + * access token.) This functionality can be changed by configuring a custom + * {@link ReactiveOAuth2AuthorizationFailureHandler} via + * {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}. + *

        * * @author Joe Grandja * @author Phil Clay @@ -79,111 +92,124 @@ import java.util.function.Function; * @see ReactiveOAuth2AuthorizationFailureHandler */ public final class DefaultReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { - private static final ReactiveOAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build(); + // @formatter:off + private static final ReactiveOAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .password() + .build(); + // @formatter:on + + // @formatter:off private static final Mono currentServerWebExchangeMono = Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); + .filter((c) -> c.hasKey(ServerWebExchange.class)) + .map((c) -> c.get(ServerWebExchange.class)); + // @formatter:on private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = DEFAULT_AUTHORIZED_CLIENT_PROVIDER; + private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; /** - * Constructs a {@code DefaultReactiveOAuth2AuthorizedClientManager} using the provided parameters. - * + * Constructs a {@code DefaultReactiveOAuth2AuthorizedClientManager} using the + * provided parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ - public DefaultReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository, - ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + public DefaultReactiveOAuth2AuthorizedClientManager( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; - this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> - authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, + this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> authorizedClientRepository + .saveAuthorizedClient(authorizedClient, principal, (ServerWebExchange) attributes.get(ServerWebExchange.class.getName())); this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, - (ServerWebExchange) attributes.get(ServerWebExchange.class.getName()))); + (clientRegistrationId, principal, attributes) -> authorizedClientRepository.removeAuthorizedClient( + clientRegistrationId, principal, + (ServerWebExchange) attributes.get(ServerWebExchange.class.getName()))); } @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); - + // @formatter:off return Mono.justOrEmpty(authorizeRequest.getAttribute(ServerWebExchange.class.getName())) .switchIfEmpty(currentServerWebExchangeMono) .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null"))) - .flatMap(serverWebExchange -> Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) + .flatMap((serverWebExchange) -> Mono + .justOrEmpty(authorizeRequest.getAuthorizedClient()) .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) - .flatMap(authorizedClient -> { - // Re-authorize - return authorizationContext(authorizeRequest, authorizedClient) - .flatMap(authorizationContext -> authorize(authorizationContext, principal, serverWebExchange)) - // Default to the existing authorizedClient if the client was not re-authorized - .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? - authorizeRequest.getAuthorizedClient() : authorizedClient); - }) + .flatMap((authorizedClient) -> // Re-authorize + authorizationContext(authorizeRequest, authorizedClient) + .flatMap((authorizationContext) -> authorize(authorizationContext, principal, serverWebExchange)) + // Default to the existing authorizedClient if the + // client was not re-authorized + .defaultIfEmpty((authorizeRequest.getAuthorizedClient() != null) + ? authorizeRequest.getAuthorizedClient() : authorizedClient) + ) .switchIfEmpty(Mono.defer(() -> - // Authorize - this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( - "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) - .flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration)) - .flatMap(authorizationContext -> authorize(authorizationContext, principal, serverWebExchange)) - ) - )); + // Authorize + this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) + .flatMap((clientRegistration) -> authorizationContext(authorizeRequest, + clientRegistration)) + .flatMap((authorizationContext) -> authorize(authorizationContext, principal, + serverWebExchange)))) + ); + // @formatter:on } - private Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) { + private Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange serverWebExchange) { return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange); } /** - * Performs authorization and then delegates to either the {@link #authorizationSuccessHandler} - * or {@link #authorizationFailureHandler}, depending on the authorization result. - * + * Performs authorization and then delegates to either the + * {@link #authorizationSuccessHandler} or {@link #authorizationFailureHandler}, + * depending on the authorization result. * @param authorizationContext the context to authorize * @param principal the principle to authorize * @param serverWebExchange the currently active exchange - * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds - * and the {@link #authorizationSuccessHandler} has completed, - * or completes with an exception after the authorization attempt fails - * and the {@link #authorizationFailureHandler} has completed + * @return a {@link Mono} that emits the authorized client after the authorization + * attempt succeeds and the {@link #authorizationSuccessHandler} has completed, or + * completes with an exception after the authorization attempt fails and the + * {@link #authorizationFailureHandler} has completed */ - private Mono authorize( - OAuth2AuthorizationContext authorizationContext, - Authentication principal, - ServerWebExchange serverWebExchange) { - + private Mono authorize(OAuth2AuthorizationContext authorizationContext, + Authentication principal, ServerWebExchange serverWebExchange) { + // @formatter:off return this.authorizedClientProvider.authorize(authorizationContext) - // Delegate to the authorizationSuccessHandler of the successful authorization - .flatMap(authorizedClient -> this.authorizationSuccessHandler.onAuthorizationSuccess( - authorizedClient, - principal, - createAttributes(serverWebExchange)) - .thenReturn(authorizedClient)) + // Delegate to the authorizationSuccessHandler of the successful + // authorization + .flatMap((authorizedClient) -> + this.authorizationSuccessHandler + .onAuthorizationSuccess(authorizedClient, principal, createAttributes(serverWebExchange)) + .thenReturn(authorizedClient) + ) // Delegate to the authorizationFailureHandler of the failed authorization - .onErrorResume(OAuth2AuthorizationException.class, authorizationException -> this.authorizationFailureHandler.onAuthorizationFailure( - authorizationException, - principal, - createAttributes(serverWebExchange)) - .then(Mono.error(authorizationException))); + .onErrorResume(OAuth2AuthorizationException.class, (authorizationException) -> + this.authorizationFailureHandler + .onAuthorizationFailure(authorizationException, principal, createAttributes(serverWebExchange)) + .then(Mono.error(authorizationException)) + ); + // @formatter:on } private Map createAttributes(ServerWebExchange serverWebExchange) { @@ -191,37 +217,44 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React } private Mono authorizationContext(OAuth2AuthorizeRequest authorizeRequest, - OAuth2AuthorizedClient authorizedClient) { + OAuth2AuthorizedClient authorizedClient) { + // @formatter:off return Mono.just(authorizeRequest) .flatMap(this.contextAttributesMapper) - .map(attrs -> OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) + .map((attrs) -> OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) .principal(authorizeRequest.getPrincipal()) - .attributes(attributes -> { + .attributes((attributes) -> { if (!CollectionUtils.isEmpty(attrs)) { attributes.putAll(attrs); } }) .build()); + // @formatter:on } private Mono authorizationContext(OAuth2AuthorizeRequest authorizeRequest, - ClientRegistration clientRegistration) { + ClientRegistration clientRegistration) { + // @formatter:off return Mono.just(authorizeRequest) .flatMap(this.contextAttributesMapper) - .map(attrs -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration) + .map((attrs) -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(authorizeRequest.getPrincipal()) - .attributes(attributes -> { + .attributes((attributes) -> { if (!CollectionUtils.isEmpty(attrs)) { attributes.putAll(attrs); } }) - .build()); + .build() + ); + // @formatter:on } /** - * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. - * - * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or + * re-authorizing) an OAuth 2.0 Client. + * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} + * used for authorizing (or re-authorizing) an OAuth 2.0 Client */ public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); @@ -229,13 +262,15 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React } /** - * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes - * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. - * - * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes - * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + * Sets the {@code Function} used for mapping attribute(s) from the + * {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes to be associated to + * the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * @param contextAttributesMapper the {@code Function} used for supplying the + * {@code Map} of attributes to the {@link OAuth2AuthorizationContext#getAttributes() + * authorization context} */ - public void setContextAttributesMapper(Function>> contextAttributesMapper) { + public void setContextAttributesMapper( + Function>> contextAttributesMapper) { Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); this.contextAttributesMapper = contextAttributesMapper; } @@ -243,9 +278,10 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React /** * Sets the handler that handles successful authorizations. * - * The default saves {@link OAuth2AuthorizedClient}s in the {@link ServerOAuth2AuthorizedClientRepository}. - * - * @param authorizationSuccessHandler the handler that handles successful authorizations. + * The default saves {@link OAuth2AuthorizedClient}s in the + * {@link ServerOAuth2AuthorizedClientRepository}. + * @param authorizationSuccessHandler the handler that handles successful + * authorizations. * @since 5.3 */ public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { @@ -256,12 +292,13 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React /** * Sets the handler that handles authorization failures. * - *

        A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} - * is used by default.

        - * + *

        + * A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} is used + * by default. + *

        * @param authorizationFailureHandler the handler that handles authorization failures. - * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler * @since 5.3 + * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler */ public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); @@ -269,16 +306,19 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React } /** - * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + * The default implementation of the {@link #setContextAttributesMapper(Function) + * contextAttributesMapper}. */ - public static class DefaultContextAttributesMapper implements Function>> { + public static class DefaultContextAttributesMapper + implements Function>> { @Override public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) { ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); + // @formatter:off return Mono.justOrEmpty(serverWebExchange) .switchIfEmpty(currentServerWebExchangeMono) - .flatMap(exchange -> { + .flatMap((exchange) -> { Map contextAttributes = Collections.emptyMap(); String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope)) { @@ -289,6 +329,9 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React return Mono.just(contextAttributes); }) .defaultIfEmpty(Collections.emptyMap()); + // @formatter:on } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index 58ea54f53e..df26460b4c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; +import java.util.HashMap; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; -import java.util.HashMap; -import java.util.Map; + +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; /** * An implementation of an {@link AuthorizationRequestRepository} that stores @@ -35,9 +37,11 @@ import java.util.Map; * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest */ -public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository { - private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = - HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST"; +public final class HttpSessionOAuth2AuthorizationRequestRepository + implements AuthorizationRequestRepository { + + private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = HttpSessionOAuth2AuthorizationRequestRepository.class + .getName() + ".AUTHORIZATION_REQUEST"; private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; @@ -54,7 +58,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au @Override public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, HttpServletRequest request, - HttpServletResponse response) { + HttpServletResponse response) { Assert.notNull(request, "request cannot be null"); Assert.notNull(response, "response cannot be null"); if (authorizationRequest == null) { @@ -79,14 +83,16 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter); if (!authorizationRequests.isEmpty()) { request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); - } else { + } + else { request.getSession().removeAttribute(this.sessionAttributeName); } return originalRequest; } @Override - public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) { + public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request, + HttpServletResponse response) { Assert.notNull(response, "response cannot be null"); return this.removeAuthorizationRequest(request); } @@ -101,17 +107,20 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au } /** - * Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest} + * Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to + * an {@link OAuth2AuthorizationRequest} * @param request - * @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}. + * @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} + * to an {@link OAuth2AuthorizationRequest}. */ private Map getAuthorizationRequests(HttpServletRequest request) { HttpSession session = request.getSession(false); - Map authorizationRequests = session == null ? null : - (Map) session.getAttribute(this.sessionAttributeName); + Map authorizationRequests = (session != null) + ? (Map) session.getAttribute(this.sessionAttributeName) : null; if (authorizationRequests == null) { return new HashMap<>(); } return authorizationRequests; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java index 6d608cf066..e0b65a6398 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.util.Assert; +import java.util.HashMap; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; -import java.util.HashMap; -import java.util.Map; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.util.Assert; /** * An implementation of an {@link OAuth2AuthorizedClientRepository} that stores @@ -35,14 +37,16 @@ import java.util.Map; * @see OAuth2AuthorizedClient */ public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { - private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME = - HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"; + + private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME = HttpSessionOAuth2AuthorizedClientRepository.class + .getName() + ".AUTHORIZED_CLIENTS"; + private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME; @SuppressWarnings("unchecked") @Override - public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, - HttpServletRequest request) { + public T loadAuthorizedClient(String clientRegistrationId, + Authentication principal, HttpServletRequest request) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.notNull(request, "request cannot be null"); return (T) this.getAuthorizedClients(request).get(clientRegistrationId); @@ -50,7 +54,7 @@ public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2 @Override public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + HttpServletRequest request, HttpServletResponse response) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(request, "request cannot be null"); Assert.notNull(response, "response cannot be null"); @@ -61,7 +65,7 @@ public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2 @Override public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + HttpServletRequest request, HttpServletResponse response) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.notNull(request, "request cannot be null"); Map authorizedClients = this.getAuthorizedClients(request); @@ -69,7 +73,8 @@ public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2 if (authorizedClients.remove(clientRegistrationId) != null) { if (!authorizedClients.isEmpty()) { request.getSession().setAttribute(this.sessionAttributeName, authorizedClients); - } else { + } + else { request.getSession().removeAttribute(this.sessionAttributeName); } } @@ -79,11 +84,12 @@ public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2 @SuppressWarnings("unchecked") private Map getAuthorizedClients(HttpServletRequest request) { HttpSession session = request.getSession(false); - Map authorizedClients = session == null ? null : - (Map) session.getAttribute(this.sessionAttributeName); + Map authorizedClients = (session != null) + ? (Map) session.getAttribute(this.sessionAttributeName) : null; if (authorizedClients == null) { authorizedClients = new HashMap<>(); } return authorizedClients; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index 8d2f157c45..f8bc3b13b2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -13,8 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.io.IOException; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; @@ -44,42 +57,30 @@ import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; - /** - * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, - * which handles the processing of the OAuth 2.0 Authorization Response. + * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, which handles the + * processing of the OAuth 2.0 Authorization Response. * *

        * The OAuth 2.0 Authorization Response is processed as follows: * *

          - *
        • - * Assuming the End-User (Resource Owner) has granted access to the Client, the Authorization Server will append the - * {@link OAuth2ParameterNames#CODE code} and {@link OAuth2ParameterNames#STATE state} parameters - * to the {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization Request) - * and redirect the End-User's user-agent back to this {@code Filter} (the Client). - *
        • - *
        • - * This {@code Filter} will then create an {@link OAuth2AuthorizationCodeAuthenticationToken} with - * the {@link OAuth2ParameterNames#CODE code} received and - * delegate it to the {@link AuthenticationManager} to authenticate. - *
        • - *
        • - * Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the - * {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the - * {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal} - * and saving it via the {@link OAuth2AuthorizedClientRepository}. - *
        • + *
        • Assuming the End-User (Resource Owner) has granted access to the Client, the + * Authorization Server will append the {@link OAuth2ParameterNames#CODE code} and + * {@link OAuth2ParameterNames#STATE state} parameters to the + * {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization + * Request) and redirect the End-User's user-agent back to this {@code Filter} (the + * Client).
        • + *
        • This {@code Filter} will then create an + * {@link OAuth2AuthorizationCodeAuthenticationToken} with the + * {@link OAuth2ParameterNames#CODE code} received and delegate it to the + * {@link AuthenticationManager} to authenticate.
        • + *
        • Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized + * Client} is created by associating the + * {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to + * the {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} + * and current {@code Principal} and saving it via the + * {@link OAuth2AuthorizedClientRepository}.
        • *
        * * @author Joe Grandja @@ -94,29 +95,37 @@ import java.util.Set; * @see ClientRegistrationRepository * @see OAuth2AuthorizedClient * @see OAuth2AuthorizedClientRepository - * @see Section 4.1 Authorization Code Grant - * @see Section 4.1.2 Authorization Response + * @see Section + * 4.1 Authorization Code Grant + * @see Section 4.1.2 Authorization + * Response */ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private final AuthenticationManager authenticationManager; - private AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); + + private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + private final AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); + private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + private RequestCache requestCache = new HttpSessionRequestCache(); /** - * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the authorized client repository * @param authenticationManager the authentication manager */ public OAuth2AuthorizationCodeGrantFilter(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository, - AuthenticationManager authenticationManager) { + OAuth2AuthorizedClientRepository authorizedClientRepository, AuthenticationManager authenticationManager) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); Assert.notNull(authenticationManager, "authenticationManager cannot be null"); @@ -127,20 +136,22 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { /** * Sets the repository for stored {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestRepository the repository for stored {@link OAuth2AuthorizationRequest}'s + * @param authorizationRequestRepository the repository for stored + * {@link OAuth2AuthorizationRequest}'s */ - public final void setAuthorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) { + public final void setAuthorizationRequestRepository( + AuthorizationRequestRepository authorizationRequestRepository) { Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); this.authorizationRequestRepository = authorizationRequestRepository; } /** - * Sets the {@link RequestCache} used for loading a previously saved request (if available) - * and replaying it after completing the processing of the OAuth 2.0 Authorization Response. - * + * Sets the {@link RequestCache} used for loading a previously saved request (if + * available) and replaying it after completing the processing of the OAuth 2.0 + * Authorization Response. + * @param requestCache the cache used for loading a previously saved request (if + * available) * @since 5.4 - * @param requestCache the cache used for loading a previously saved request (if available) */ public final void setRequestCache(RequestCache requestCache) { Assert.notNull(requestCache, "requestCache cannot be null"); @@ -149,13 +160,11 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { - + throws ServletException, IOException { if (matchesAuthorizationResponse(request)) { processAuthorizationResponse(request, response); return; } - filterChain.doFilter(request, response); } @@ -164,58 +173,56 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) { return false; } - OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request); + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); if (authorizationRequest == null) { return false; } - // Compare redirect_uri UriComponents requestUri = UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request)).build(); UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()).build(); - Set>> requestUriParameters = new LinkedHashSet<>(requestUri.getQueryParams().entrySet()); - Set>> redirectUriParameters = new LinkedHashSet<>(redirectUri.getQueryParams().entrySet()); - // Remove the additional request parameters (if any) from the authorization response (request) - // before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any) + Set>> requestUriParameters = new LinkedHashSet<>( + requestUri.getQueryParams().entrySet()); + Set>> redirectUriParameters = new LinkedHashSet<>( + redirectUri.getQueryParams().entrySet()); + // Remove the additional request parameters (if any) from the authorization + // response (request) + // before doing an exact comparison with the authorizationRequest.getRedirectUri() + // parameters (if any) requestUriParameters.retainAll(redirectUriParameters); - - if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) && - Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) && - Objects.equals(requestUri.getHost(), redirectUri.getHost()) && - Objects.equals(requestUri.getPort(), redirectUri.getPort()) && - Objects.equals(requestUri.getPath(), redirectUri.getPath()) && - Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) { + if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) + && Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) + && Objects.equals(requestUri.getHost(), redirectUri.getHost()) + && Objects.equals(requestUri.getPort(), redirectUri.getPort()) + && Objects.equals(requestUri.getPath(), redirectUri.getPath()) + && Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) { return true; } return false; } private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response) - throws IOException { - - OAuth2AuthorizationRequest authorizationRequest = - this.authorizationRequestRepository.removeAuthorizationRequest(request, response); - + throws IOException { + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository + .removeAuthorizationRequest(request, response); String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); - MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); String redirectUri = UrlUtils.buildFullRequestUrl(request); - OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri); - + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, + redirectUri); OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken( - clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); + clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); - OAuth2AuthorizationCodeAuthenticationToken authenticationResult; - try { - authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) - this.authenticationManager.authenticate(authenticationRequest); - } catch (OAuth2AuthorizationException ex) { + authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationManager + .authenticate(authenticationRequest); + } + catch (OAuth2AuthorizationException ex) { OAuth2Error error = ex.getError(); - UriComponentsBuilder uriBuilder = UriComponentsBuilder - .fromUriString(authorizationRequest.getRedirectUri()) - .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode()); + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()) + .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode()); if (!StringUtils.isEmpty(error.getDescription())) { uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription()); } @@ -225,25 +232,20 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString()); return; } - Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); - String principalName = currentAuthentication != null ? currentAuthentication.getName() : "anonymousUser"; - + String principalName = (currentAuthentication != null) ? currentAuthentication.getName() : "anonymousUser"; OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - authenticationResult.getClientRegistration(), - principalName, - authenticationResult.getAccessToken(), - authenticationResult.getRefreshToken()); - - this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response); - + authenticationResult.getClientRegistration(), principalName, authenticationResult.getAccessToken(), + authenticationResult.getRefreshToken()); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, + response); String redirectUrl = authorizationRequest.getRedirectUri(); SavedRequest savedRequest = this.requestCache.getRequest(request, response); if (savedRequest != null) { redirectUrl = savedRequest.getRedirectUrl(); this.requestCache.removeRequest(request, response); } - this.redirectStrategy.sendRedirect(request, response, redirectUrl); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 0953200255..2e9c62c993 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -13,8 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -29,38 +38,35 @@ import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; - /** - * This {@code Filter} initiates the authorization code grant or implicit grant flow - * by redirecting the End-User's user-agent to the Authorization Server's Authorization Endpoint. + * This {@code Filter} initiates the authorization code grant or implicit grant flow by + * redirecting the End-User's user-agent to the Authorization Server's Authorization + * Endpoint. * *

        - * It builds the OAuth 2.0 Authorization Request, - * which is used as the redirect {@code URI} to the Authorization Endpoint. - * The redirect {@code URI} will include the client identifier, requested scope(s), state, - * response type, and a redirection URI which the authorization server will send the user-agent back to - * once access is granted (or denied) by the End-User (Resource Owner). + * It builds the OAuth 2.0 Authorization Request, which is used as the redirect + * {@code URI} to the Authorization Endpoint. The redirect {@code URI} will include the + * client identifier, requested scope(s), state, response type, and a redirection URI + * which the authorization server will send the user-agent back to once access is granted + * (or denied) by the End-User (Resource Owner). * *

        - * By default, this {@code Filter} responds to authorization requests - * at the {@code URI} {@code /oauth2/authorization/{registrationId}} - * using the default {@link OAuth2AuthorizationRequestResolver}. - * The {@code URI} template variable {@code {registrationId}} represents the - * {@link ClientRegistration#getRegistrationId() registration identifier} of the client - * that is used for initiating the OAuth 2.0 Authorization Request. + * By default, this {@code Filter} responds to authorization requests at the {@code URI} + * {@code /oauth2/authorization/{registrationId}} using the default + * {@link OAuth2AuthorizationRequestResolver}. The {@code URI} template variable + * {@code {registrationId}} represents the {@link ClientRegistration#getRegistrationId() + * registration identifier} of the client that is used for initiating the OAuth 2.0 + * Authorization Request. * *

        - * The default base {@code URI} {@code /oauth2/authorization} may be overridden - * via the constructor {@link #OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository, String)}, - * or alternatively, an {@code OAuth2AuthorizationRequestResolver} may be provided to the constructor + * The default base {@code URI} {@code /oauth2/authorization} may be overridden via the + * constructor + * {@link #OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository, String)}, + * or alternatively, an {@code OAuth2AuthorizationRequestResolver} may be provided to the + * constructor * {@link #OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolver)} * to override the resolving of authorization requests. - + * * @author Joe Grandja * @author Rob Winch * @since 5.0 @@ -69,26 +75,37 @@ import java.io.IOException; * @see AuthorizationRequestRepository * @see ClientRegistration * @see ClientRegistrationRepository - * @see Section 4.1 Authorization Code Grant - * @see Section 4.1.1 Authorization Request (Authorization Code) - * @see Section 4.2 Implicit Grant - * @see Section 4.2.1 Authorization Request (Implicit) + * @see Section + * 4.1 Authorization Code Grant + * @see Section 4.1.1 Authorization Request + * (Authorization Code) + * @see Section + * 4.2 Implicit Grant + * @see Section 4.2.1 Authorization Request + * (Implicit) */ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilter { + /** * The default base {@code URI} used for authorization requests. */ public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization"; + private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); + private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); + private OAuth2AuthorizationRequestResolver authorizationRequestResolver; - private AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); + + private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + private RequestCache requestCache = new HttpSessionRequestCache(); /** - * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations */ public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository clientRegistrationRepository) { @@ -96,24 +113,26 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt } /** - * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations - * @param authorizationRequestBaseUri the base {@code URI} used for authorization requests + * @param authorizationRequestBaseUri the base {@code URI} used for authorization + * requests */ public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository clientRegistrationRepository, - String authorizationRequestBaseUri) { + String authorizationRequestBaseUri) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); - this.authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( - clientRegistrationRepository, authorizationRequestBaseUri); + this.authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, + authorizationRequestBaseUri); } /** - * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided + * parameters. + * @param authorizationRequestResolver the resolver used for resolving authorization + * requests * @since 5.1 - * @param authorizationRequestResolver the resolver used for resolving authorization requests */ public OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolver authorizationRequestResolver) { Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null"); @@ -122,18 +141,18 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s + * @param authorizationRequestRepository the repository used for storing + * {@link OAuth2AuthorizationRequest}'s */ - public final void setAuthorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) { + public final void setAuthorizationRequestRepository( + AuthorizationRequestRepository authorizationRequestRepository) { Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); this.authorizationRequestRepository = authorizationRequestRepository; } /** - * Sets the {@link RequestCache} used for storing the current request - * before redirecting the OAuth 2.0 Authorization Request. - * + * Sets the {@link RequestCache} used for storing the current request before + * redirecting the OAuth 2.0 Authorization Request. * @param requestCache the cache used for storing the current request */ public final void setRequestCache(RequestCache requestCache) { @@ -144,76 +163,80 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - try { OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request); if (authorizationRequest != null) { this.sendRedirectForAuthorization(request, response, authorizationRequest); return; } - } catch (Exception failed) { - this.unsuccessfulRedirectForAuthorization(request, response, failed); + } + catch (Exception ex) { + this.unsuccessfulRedirectForAuthorization(request, response, ex); return; } - try { filterChain.doFilter(request, response); - } catch (IOException ex) { + } + catch (IOException ex) { throw ex; - } catch (Exception ex) { + } + catch (Exception ex) { // Check to see if we need to handle ClientAuthorizationRequiredException Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex); ClientAuthorizationRequiredException authzEx = (ClientAuthorizationRequiredException) this.throwableAnalyzer - .getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain); + .getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain); if (authzEx != null) { try { - OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request, authzEx.getClientRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request, + authzEx.getClientRegistrationId()); if (authorizationRequest == null) { throw authzEx; } this.sendRedirectForAuthorization(request, response, authorizationRequest); this.requestCache.saveRequest(request, response); - } catch (Exception failed) { + } + catch (Exception failed) { this.unsuccessfulRedirectForAuthorization(request, response, failed); } return; } - if (ex instanceof ServletException) { throw (ServletException) ex; - } else if (ex instanceof RuntimeException) { - throw (RuntimeException) ex; - } else { - throw new RuntimeException(ex); } + if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } + throw new RuntimeException(ex); } } private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response, - OAuth2AuthorizationRequest authorizationRequest) throws IOException { - + OAuth2AuthorizationRequest authorizationRequest) throws IOException { if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationRequest.getGrantType())) { this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); } - this.authorizationRedirectStrategy.sendRedirect(request, response, authorizationRequest.getAuthorizationRequestUri()); + this.authorizationRedirectStrategy.sendRedirect(request, response, + authorizationRequest.getAuthorizationRequestUri()); } private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response, - Exception failed) throws IOException { - - if (logger.isErrorEnabled()) { - logger.error("Authorization Request failed: " + failed.toString(), failed); - } - response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); + Exception ex) throws IOException { + this.logger.error(LogMessage.format("Authorization Request failed: %s", ex, ex)); + response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), + HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); } private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer { + + @Override protected void initExtractorMap() { super.initExtractorMap(); - registerExtractor(ServletException.class, throwable -> { + registerExtractor(ServletException.class, (throwable) -> { ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class); return ((ServletException) throwable).getRootCause(); }); } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java index 10940c2373..d4c6fda21f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.web; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +package org.springframework.security.oauth2.client.web; import javax.servlet.http.HttpServletRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; + /** - * Implementations of this interface are capable of resolving - * an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}. - * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests. + * Implementations of this interface are capable of resolving an + * {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}. Used + * by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization + * Requests. * * @author Joe Grandja * @author Rob Winch @@ -33,21 +35,21 @@ import javax.servlet.http.HttpServletRequest; public interface OAuth2AuthorizationRequestResolver { /** - * Returns the {@link OAuth2AuthorizationRequest} resolved from - * the provided {@code HttpServletRequest} or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizationRequest} resolved from the provided + * {@code HttpServletRequest} or {@code null} if not available. * @param request the {@code HttpServletRequest} - * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not + * available */ OAuth2AuthorizationRequest resolve(HttpServletRequest request); /** - * Returns the {@link OAuth2AuthorizationRequest} resolved from - * the provided {@code HttpServletRequest} or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizationRequest} resolved from the provided + * {@code HttpServletRequest} or {@code null} if not available. * @param request the {@code HttpServletRequest} * @param clientRegistrationId the clientRegistrationId to use - * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not + * available */ OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java index 5443c895fb..6a0eec1321 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; import java.util.Map; @@ -52,34 +53,32 @@ final class OAuth2AuthorizationResponseUtils { } static boolean isAuthorizationResponseSuccess(MultiValueMap request) { - return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.CODE)) && - StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); + return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.CODE)) + && StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); } static boolean isAuthorizationResponseError(MultiValueMap request) { - return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.ERROR)) && - StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); + return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.ERROR)) + && StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); } static OAuth2AuthorizationResponse convert(MultiValueMap request, String redirectUri) { String code = request.getFirst(OAuth2ParameterNames.CODE); String errorCode = request.getFirst(OAuth2ParameterNames.ERROR); String state = request.getFirst(OAuth2ParameterNames.STATE); - if (StringUtils.hasText(code)) { - return OAuth2AuthorizationResponse.success(code) - .redirectUri(redirectUri) - .state(state) - .build(); - } else { - String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); - String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); - return OAuth2AuthorizationResponse.error(errorCode) + return OAuth2AuthorizationResponse.success(code).redirectUri(redirectUri).state(state).build(); + } + String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); + String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); + // @formatter:off + return OAuth2AuthorizationResponse.error(errorCode) .redirectUri(redirectUri) .errorDescription(errorDescription) .errorUri(errorUri) .state(state) .build(); - } + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java index b509f614dc..f52993a3d5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java @@ -13,26 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** - * Implementations of this interface are responsible for the persistence - * of {@link OAuth2AuthorizedClient Authorized Client(s)} between requests. + * Implementations of this interface are responsible for the persistence of + * {@link OAuth2AuthorizedClient Authorized Client(s)} between requests. * *

        - * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client} - * is to associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential - * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, - * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} - * that originally granted the authorization. + * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client} is to + * associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential to + * a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, who + * is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} that originally + * granted the authorization. * * @author Joe Grandja * @since 5.1 @@ -44,10 +45,9 @@ import javax.servlet.http.HttpServletResponse; public interface OAuth2AuthorizedClientRepository { /** - * Returns the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User {@link Authentication} (Resource Owner) - * or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User {@link Authentication} (Resource Owner) or + * {@code null} if not available. * @param clientRegistrationId the identifier for the client's registration * @param principal the End-User {@link Authentication} (Resource Owner) * @param request the {@code HttpServletRequest} @@ -55,30 +55,28 @@ public interface OAuth2AuthorizedClientRepository { * @return the {@link OAuth2AuthorizedClient} or {@code null} if not available */ T loadAuthorizedClient(String clientRegistrationId, Authentication principal, - HttpServletRequest request); + HttpServletRequest request); /** - * Saves the {@link OAuth2AuthorizedClient} associating it to - * the provided End-User {@link Authentication} (Resource Owner). - * + * Saves the {@link OAuth2AuthorizedClient} associating it to the provided End-User + * {@link Authentication} (Resource Owner). * @param authorizedClient the authorized client * @param principal the End-User {@link Authentication} (Resource Owner) * @param request the {@code HttpServletRequest} * @param response the {@code HttpServletResponse} */ void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest request, HttpServletResponse response); + HttpServletRequest request, HttpServletResponse response); /** - * Removes the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User {@link Authentication} (Resource Owner). - * + * Removes the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User {@link Authentication} (Resource Owner). * @param clientRegistrationId the identifier for the client's registration * @param principal the End-User {@link Authentication} (Resource Owner) * @param request the {@code HttpServletRequest} * @param response the {@code HttpServletResponse} */ - void removeAuthorizedClient(String clientRegistrationId, Authentication principal, - HttpServletRequest request, HttpServletResponse response); + void removeAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request, + HttpServletResponse response); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index d98922fe3a..6943215ded 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -39,40 +43,35 @@ import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** - * An implementation of an {@link AbstractAuthenticationProcessingFilter} for OAuth 2.0 Login. + * An implementation of an {@link AbstractAuthenticationProcessingFilter} for OAuth 2.0 + * Login. * *

        - * This authentication {@code Filter} handles the processing of an OAuth 2.0 Authorization Response - * for the authorization code grant flow and delegates an {@link OAuth2LoginAuthenticationToken} - * to the {@link AuthenticationManager} to log in the End-User. + * This authentication {@code Filter} handles the processing of an OAuth 2.0 Authorization + * Response for the authorization code grant flow and delegates an + * {@link OAuth2LoginAuthenticationToken} to the {@link AuthenticationManager} to log in + * the End-User. * *

        * The OAuth 2.0 Authorization Response is processed as follows: * *

          - *
        • - * Assuming the End-User (Resource Owner) has granted access to the Client, the Authorization Server will append the - * {@link OAuth2ParameterNames#CODE code} and {@link OAuth2ParameterNames#STATE state} parameters - * to the {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization Request) - * and redirect the End-User's user-agent back to this {@code Filter} (the Client). - *
        • - *
        • - * This {@code Filter} will then create an {@link OAuth2LoginAuthenticationToken} with - * the {@link OAuth2ParameterNames#CODE code} received and - * delegate it to the {@link AuthenticationManager} to authenticate. - *
        • - *
        • - * Upon a successful authentication, an {@link OAuth2AuthenticationToken} is created (representing the End-User {@code Principal}) - * and associated to the {@link OAuth2AuthorizedClient Authorized Client} using the {@link OAuth2AuthorizedClientRepository}. - *
        • - *
        • - * Finally, the {@link OAuth2AuthenticationToken} is returned and ultimately stored - * in the {@link SecurityContextRepository} to complete the authentication processing. - *
        • + *
        • Assuming the End-User (Resource Owner) has granted access to the Client, the + * Authorization Server will append the {@link OAuth2ParameterNames#CODE code} and + * {@link OAuth2ParameterNames#STATE state} parameters to the + * {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization + * Request) and redirect the End-User's user-agent back to this {@code Filter} (the + * Client).
        • + *
        • This {@code Filter} will then create an {@link OAuth2LoginAuthenticationToken} with + * the {@link OAuth2ParameterNames#CODE code} received and delegate it to the + * {@link AuthenticationManager} to authenticate.
        • + *
        • Upon a successful authentication, an {@link OAuth2AuthenticationToken} is created + * (representing the End-User {@code Principal}) and associated to the + * {@link OAuth2AuthorizedClient Authorized Client} using the + * {@link OAuth2AuthorizedClientRepository}.
        • + *
        • Finally, the {@link OAuth2AuthenticationToken} is returned and ultimately stored in + * the {@link SecurityContextRepository} to complete the authentication processing.
        • *
        * * @author Joe Grandja @@ -88,57 +87,67 @@ import javax.servlet.http.HttpServletResponse; * @see ClientRegistrationRepository * @see OAuth2AuthorizedClient * @see OAuth2AuthorizedClientRepository - * @see Section 4.1 Authorization Code Grant - * @see Section 4.1.2 Authorization Response + * @see Section + * 4.1 Authorization Code Grant + * @see Section 4.1.2 Authorization + * Response */ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProcessingFilter { - /** - * The default {@code URI} where this {@code Filter} processes authentication requests. - */ - public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/oauth2/code/*"; - private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; - private static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found"; - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; - private AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); /** - * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided parameters. - * + * The default {@code URI} where this {@code Filter} processes authentication + * requests. + */ + public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/oauth2/code/*"; + + private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; + + private static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found"; + + private ClientRegistrationRepository clientRegistrationRepository; + + private OAuth2AuthorizedClientRepository authorizedClientRepository; + + private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + + /** + * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientService the authorized client service */ public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientService authorizedClientService) { + OAuth2AuthorizedClientService authorizedClientService) { this(clientRegistrationRepository, authorizedClientService, DEFAULT_FILTER_PROCESSES_URI); } /** - * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided parameters. - * + * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientService the authorized client service - * @param filterProcessesUrl the {@code URI} where this {@code Filter} will process the authentication requests + * @param filterProcessesUrl the {@code URI} where this {@code Filter} will process + * the authentication requests */ public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientService authorizedClientService, - String filterProcessesUrl) { + OAuth2AuthorizedClientService authorizedClientService, String filterProcessesUrl) { this(clientRegistrationRepository, - new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService), filterProcessesUrl); + new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService), + filterProcessesUrl); } /** - * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided parameters. - * - * @since 5.1 + * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the authorized client repository - * @param filterProcessesUrl the {@code URI} where this {@code Filter} will process the authentication requests + * @param filterProcessesUrl the {@code URI} where this {@code Filter} will process + * the authentication requests + * @since 5.1 */ public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository, - String filterProcessesUrl) { + OAuth2AuthorizedClientRepository authorizedClientRepository, String filterProcessesUrl) { super(filterProcessesUrl); Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); @@ -149,20 +158,17 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce @Override public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException { - MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) { OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - - OAuth2AuthorizationRequest authorizationRequest = - this.authorizationRequestRepository.removeAuthorizationRequest(request, response); + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository + .removeAuthorizationRequest(request, response); if (authorizationRequest == null) { OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); if (clientRegistration == null) { @@ -170,44 +176,41 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce "Client Registration not found with Id: " + registrationId, null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } + // @formatter:off String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) .replaceQuery(null) .build() .toUriString(); - OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri); - + // @formatter:on + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, + redirectUri); Object authenticationDetails = this.authenticationDetailsSource.buildDetails(request); - OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( - clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); + OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken(clientRegistration, + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); authenticationRequest.setDetails(authenticationDetails); - - OAuth2LoginAuthenticationToken authenticationResult = - (OAuth2LoginAuthenticationToken) this.getAuthenticationManager().authenticate(authenticationRequest); - + OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this + .getAuthenticationManager().authenticate(authenticationRequest); OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken( - authenticationResult.getPrincipal(), - authenticationResult.getAuthorities(), - authenticationResult.getClientRegistration().getRegistrationId()); + authenticationResult.getPrincipal(), authenticationResult.getAuthorities(), + authenticationResult.getClientRegistration().getRegistrationId()); oauth2Authentication.setDetails(authenticationDetails); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - authenticationResult.getClientRegistration(), - oauth2Authentication.getName(), - authenticationResult.getAccessToken(), - authenticationResult.getRefreshToken()); + authenticationResult.getClientRegistration(), oauth2Authentication.getName(), + authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, oauth2Authentication, request, response); - return oauth2Authentication; } /** * Sets the repository for stored {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestRepository the repository for stored {@link OAuth2AuthorizationRequest}'s + * @param authorizationRequestRepository the repository for stored + * {@link OAuth2AuthorizationRequest}'s */ - public final void setAuthorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) { + public final void setAuthorizationRequestRepository( + AuthorizationRequestRepository authorizationRequestRepository) { Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); this.authorizationRequestRepository = authorizationRequestRepository; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 8241f1bce1..a50632a629 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.method.annotation; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.lang.NonNull; @@ -24,7 +28,9 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; @@ -33,8 +39,6 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResp import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -43,16 +47,13 @@ import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.method.support.ModelAndViewContainer; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** - * An implementation of a {@link HandlerMethodArgumentResolver} that is capable - * of resolving a method parameter to an argument value of type {@link OAuth2AuthorizedClient}. + * An implementation of a {@link HandlerMethodArgumentResolver} that is capable of + * resolving a method parameter to an argument value of type + * {@link OAuth2AuthorizedClient}. * *

        - * For example: - *

        + * For example: 
          * @Controller
          * public class MyController {
          *     @GetMapping("/authorized-client")
        @@ -67,16 +68,20 @@ import javax.servlet.http.HttpServletResponse;
          * @see RegisteredOAuth2AuthorizedClient
          */
         public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
        -	private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken(
        -			"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
        +
        +	private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
        +			"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
        +
         	private OAuth2AuthorizedClientManager authorizedClientManager;
        +
         	private boolean defaultAuthorizedClientManager;
         
         	/**
        -	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
        -	 *
        +	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided
        +	 * parameters.
        +	 * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
        +	 * manages the authorized client(s)
         	 * @since 5.2
        -	 * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s)
         	 */
         	public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager authorizedClientManager) {
         		Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
        @@ -84,105 +89,110 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
         	}
         
         	/**
        -	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
        -	 *
        +	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided
        +	 * parameters.
         	 * @param clientRegistrationRepository the repository of client registrations
         	 * @param authorizedClientRepository the repository of authorized clients
         	 */
         	public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clientRegistrationRepository,
        -													OAuth2AuthorizedClientRepository authorizedClientRepository) {
        +			OAuth2AuthorizedClientRepository authorizedClientRepository) {
         		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
         		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
        -		this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
        -				clientRegistrationRepository, authorizedClientRepository);
        +		this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(clientRegistrationRepository,
        +				authorizedClientRepository);
         		this.defaultAuthorizedClientManager = true;
         	}
         
         	@Override
         	public boolean supportsParameter(MethodParameter parameter) {
         		Class parameterType = parameter.getParameterType();
        -		return (OAuth2AuthorizedClient.class.isAssignableFrom(parameterType) &&
        -				(AnnotatedElementUtils.findMergedAnnotation(
        -						parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class) != null));
        +		return (OAuth2AuthorizedClient.class.isAssignableFrom(parameterType) && (AnnotatedElementUtils
        +				.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class) != null));
         	}
         
         	@NonNull
         	@Override
        -	public Object resolveArgument(MethodParameter parameter,
        -									@Nullable ModelAndViewContainer mavContainer,
        -									NativeWebRequest webRequest,
        -									@Nullable WebDataBinderFactory binderFactory) {
        -
        +	public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
        +			NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) {
         		String clientRegistrationId = this.resolveClientRegistrationId(parameter);
         		if (StringUtils.isEmpty(clientRegistrationId)) {
        -			throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " +
        -					"It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or " +
        -					"@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
        +			throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. "
        +					+ "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or "
        +					+ "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
         		}
        -
         		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
         		if (principal == null) {
         			principal = ANONYMOUS_AUTHENTICATION;
         		}
         		HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
         		HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
        -
        -		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId)
        +		// @formatter:off
        +		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
        +				.withClientRegistrationId(clientRegistrationId)
         				.principal(principal)
         				.attribute(HttpServletRequest.class.getName(), servletRequest)
         				.attribute(HttpServletResponse.class.getName(), servletResponse)
         				.build();
        -
        +		// @formatter:on
         		return this.authorizedClientManager.authorize(authorizeRequest);
         	}
         
         	private String resolveClientRegistrationId(MethodParameter parameter) {
        -		RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils.findMergedAnnotation(
        -				parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
        -
        +		RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
        +				.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
         		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
        -
        -		String clientRegistrationId = null;
         		if (!StringUtils.isEmpty(authorizedClientAnnotation.registrationId())) {
        -			clientRegistrationId = authorizedClientAnnotation.registrationId();
        -		} else if (!StringUtils.isEmpty(authorizedClientAnnotation.value())) {
        -			clientRegistrationId = authorizedClientAnnotation.value();
        -		} else if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) {
        -			clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId();
        +			return authorizedClientAnnotation.registrationId();
         		}
        -
        -		return clientRegistrationId;
        +		if (!StringUtils.isEmpty(authorizedClientAnnotation.value())) {
        +			return authorizedClientAnnotation.value();
        +		}
        +		if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) {
        +			return ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId();
        +		}
        +		return null;
         	}
         
         	/**
        -	 * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant.
        -	 *
        -	 * @deprecated Use {@link #OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)} instead.
        -	 * 				Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a
        -	 * 				{@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient}
        -	 * 				(or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}.
        -	 *
        -	 * @param clientCredentialsTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant
        +	 * Sets the client used when requesting an access token credential at the Token
        +	 * Endpoint for the {@code client_credentials} grant.
        +	 * @deprecated Use
        +	 * {@link #OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)}
        +	 * instead. Create an instance of
        +	 * {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a
        +	 * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient)
        +	 * DefaultClientCredentialsTokenResponseClient} (or a custom one) and than supply it
        +	 * to
        +	 * {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider)
        +	 * DefaultOAuth2AuthorizedClientManager}.
        +	 * @param clientCredentialsTokenResponseClient the client used when requesting an
        +	 * access token credential at the Token Endpoint for the {@code client_credentials}
        +	 * grant
         	 */
         	@Deprecated
         	public void setClientCredentialsTokenResponseClient(
         			OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) {
         		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
        -		Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " +
        -				"Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
        +		Assert.state(this.defaultAuthorizedClientManager,
        +				"The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". "
        +						+ "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
         		updateDefaultAuthorizedClientManager(clientCredentialsTokenResponseClient);
         	}
         
         	private void updateDefaultAuthorizedClientManager(
         			OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) {
        -
        -		OAuth2AuthorizedClientProvider authorizedClientProvider =
        -				OAuth2AuthorizedClientProviderBuilder.builder()
        -						.authorizationCode()
        -						.refreshToken()
        -						.clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient))
        -						.password()
        -						.build();
        -		((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
        +		// @formatter:off
        +		OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder()
        +				.authorizationCode()
        +				.refreshToken()
        +				.clientCredentials((configurer) ->
        +						configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)
        +				)
        +				.password()
        +				.build();
        +		// @formatter:on
        +		((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager)
        +				.setAuthorizedClientProvider(authorizedClientProvider);
         	}
        +
         }
        diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/package-info.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/package-info.java
        index cf4e21261f..f464d0e204 100644
        --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/package-info.java
        +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/package-info.java
        @@ -13,6 +13,7 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         /**
          * OAuth 2.0 Client {@code Filter}'s and supporting classes and interfaces.
          */
        diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java
        index 8291e26704..9b07d7e6a3 100644
        --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java
        +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java
        @@ -16,6 +16,17 @@
         
         package org.springframework.security.oauth2.client.web.reactive.function.client;
         
        +import java.time.Duration;
        +import java.util.Collections;
        +import java.util.HashMap;
        +import java.util.Map;
        +import java.util.Optional;
        +import java.util.function.Consumer;
        +import java.util.stream.Collectors;
        +import java.util.stream.Stream;
        +
        +import reactor.core.publisher.Mono;
        +
         import org.springframework.http.HttpHeaders;
         import org.springframework.http.HttpStatus;
         import org.springframework.security.authentication.AnonymousAuthenticationToken;
        @@ -55,38 +66,39 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
         import org.springframework.web.reactive.function.client.ExchangeFunction;
         import org.springframework.web.reactive.function.client.WebClientResponseException;
         import org.springframework.web.server.ServerWebExchange;
        -import reactor.core.publisher.Mono;
        -
        -import java.time.Duration;
        -import java.util.Collections;
        -import java.util.HashMap;
        -import java.util.Map;
        -import java.util.Optional;
        -import java.util.function.Consumer;
        -import java.util.stream.Collectors;
        -import java.util.stream.Stream;
         
         /**
        - * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
        - * token as a Bearer Token.
        + * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2
        + * requests by including the token as a Bearer Token.
          *
          * 

        Authentication and Authorization Failures

        * - *

        Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized) - * and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server to a - * {@link ReactiveOAuth2AuthorizationFailureHandler}. - * A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} can be used - * to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result - * in a new token being retrieved from an Authorization Server, and sent to the Resource Server.

        + *

        + * Since 5.3, this filter function has the ability to forward authentication (HTTP 401 + * Unauthorized) and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 + * Resource Server to a {@link ReactiveOAuth2AuthorizationFailureHandler}. A + * {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} can be used to + * remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result + * in a new token being retrieved from an Authorization Server, and sent to the Resource + * Server. + *

        * - *

        If the {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository, ServerOAuth2AuthorizedClientRepository)} - * constructor is used, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} - * will be configured automatically.

        + *

        + * If the + * {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository, ServerOAuth2AuthorizedClientRepository)} + * constructor is used, a + * {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} will be + * configured automatically. + *

        * - *

        If the {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)} - * constructor is used, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} - * will NOT be configured automatically. - * It is recommended that you configure one via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

        + *

        + * If the + * {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)} + * constructor is used, a + * {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} will + * NOT be configured automatically. It is recommended that you configure one via + * {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}. + *

        * * @author Rob Winch * @author Joe Grandja @@ -94,36 +106,43 @@ import java.util.stream.Stream; * @since 5.1 */ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { + /** * The request attribute name used to locate the {@link OAuth2AuthorizedClient}. */ private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); /** - * The client request attribute name used to locate the {@link ClientRegistration#getRegistrationId()} + * The client request attribute name used to locate the + * {@link ClientRegistration#getRegistrationId()} */ - private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID"); + private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName() + .concat(".CLIENT_REGISTRATION_ID"); /** - * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}. + * The request attribute name used to locate the + * {@link org.springframework.web.server.ServerWebExchange}. */ private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName(); - private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", - AuthorityUtils.createAuthorityList("ROLE_USER")); + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER")); private final Mono currentAuthenticationMono = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + .map(SecurityContext::getAuthentication).defaultIfEmpty(ANONYMOUS_USER_TOKEN); - private final Mono clientRegistrationIdMono = currentAuthenticationMono - .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + // @formatter:off + private final Mono clientRegistrationIdMono = this.currentAuthenticationMono + .filter((t) -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) .cast(OAuth2AuthenticationToken.class) .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + // @formatter:on + // @formatter:off private final Mono currentServerWebExchangeMono = Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); + .filter((c) -> c.hasKey(ServerWebExchange.class)) + .map((c) -> c.get(ServerWebExchange.class)); + // @formatter:on private final ReactiveOAuth2AuthorizedClientManager authorizedClientManager; @@ -141,64 +160,64 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private ClientResponseHandler clientResponseHandler; - @FunctionalInterface - private interface ClientResponseHandler { - Mono handleResponse(ClientRequest request, Mono response); - } - /** - * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the + * provided parameters. * - *

        When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403) - * failures returned from a OAuth 2.0 Resource Server will NOT be forwarded to a - * {@link ReactiveOAuth2AuthorizationFailureHandler}. - * Therefore, future requests to the Resource Server will most likely use the same (most likely invalid) token, - * resulting in the same errors returned from the Resource Server. - * It is recommended to configure a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} - * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)} + *

        + * When this constructor is used, authentication (HTTP 401) and authorization (HTTP + * 403) failures returned from a OAuth 2.0 Resource Server will NOT be + * forwarded to a {@link ReactiveOAuth2AuthorizationFailureHandler}. Therefore, future + * requests to the Resource Server will most likely use the same (most likely invalid) + * token, resulting in the same errors returned from the Resource Server. It is + * recommended to configure a + * {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} via + * {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)} * so that authentication and authorization failures returned from a Resource Server - * will result in removing the authorized client, so that a new token is retrieved for future requests.

        - * + * will result in removing the authorized client, so that a new token is retrieved for + * future requests. + *

        + * @param authorizedClientManager the {@link ReactiveOAuth2AuthorizedClientManager} + * which manages the authorized client(s) * @since 5.2 - * @param authorizedClientManager the {@link ReactiveOAuth2AuthorizedClientManager} which manages the authorized client(s) */ - public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager authorizedClientManager) { + public ServerOAuth2AuthorizedClientExchangeFilterFunction( + ReactiveOAuth2AuthorizedClientManager authorizedClientManager) { Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); this.authorizedClientManager = authorizedClientManager; - this.clientResponseHandler = (request, responseMono) -> responseMono; + this.clientResponseHandler = (request, responseMono) -> responseMono; } /** - * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. - * - *

        Since 5.3, when this constructor is used, authentication (HTTP 401) - * and authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server - * will be forwarded to a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}, - * which will potentially remove the {@link OAuth2AuthorizedClient} from the given - * {@link ServerOAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code returned. - * Authentication failures returned from an OAuth 2.0 Resource Server typically indicate - * that the token is invalid, and should not be used in future requests. - * Removing the authorized client from the repository will ensure that the existing - * token will not be sent for future requests to the Resource Server, - * and a new token is retrieved from Authorization Server and used for - * future requests to the Resource Server.

        + * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the + * provided parameters. * + *

        + * Since 5.3, when this constructor is used, authentication (HTTP 401) and + * authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server will + * be forwarded to a + * {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}, which will + * potentially remove the {@link OAuth2AuthorizedClient} from the given + * {@link ServerOAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error + * code returned. Authentication failures returned from an OAuth 2.0 Resource Server + * typically indicate that the token is invalid, and should not be used in future + * requests. Removing the authorized client from the repository will ensure that the + * existing token will not be sent for future requests to the Resource Server, and a + * new token is retrieved from Authorization Server and used for future requests to + * the Resource Server. + *

        * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ - public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, - ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - - ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler = - new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, - (ServerWebExchange) attributes.get(ServerWebExchange.class.getName()))); - - this.authorizedClientManager = createDefaultAuthorizedClientManager( - clientRegistrationRepository, - authorizedClientRepository, - authorizationFailureHandler); + public ServerOAuth2AuthorizedClientExchangeFilterFunction( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( + (clientRegistrationId, principal, attributes) -> authorizedClientRepository.removeAuthorizedClient( + clientRegistrationId, principal, + (ServerWebExchange) attributes.get(ServerWebExchange.class.getName()))); + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, + authorizedClientRepository, authorizationFailureHandler); this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); this.defaultAuthorizedClientManager = true; } @@ -207,41 +226,33 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository, ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { - // gh-7544 if (authorizedClientRepository instanceof UnAuthenticatedServerOAuth2AuthorizedClientRepository) { - UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager = - new UnAuthenticatedReactiveOAuth2AuthorizedClientManager( - clientRegistrationRepository, - (UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository, - authorizationFailureHandler); - unauthenticatedAuthorizedClientManager.setAuthorizedClientProvider( - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build()); + UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager = new UnAuthenticatedReactiveOAuth2AuthorizedClientManager( + clientRegistrationRepository, + (UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository, + authorizationFailureHandler); + unauthenticatedAuthorizedClientManager + .setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode().refreshToken().clientCredentials().password().build()); return unauthenticatedAuthorizedClientManager; } - - DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = - new DefaultReactiveOAuth2AuthorizedClientManager( - clientRegistrationRepository, authorizedClientRepository); + DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler); - return authorizedClientManager; } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for - * providing the Bearer Token. Example usage: + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link OAuth2AuthorizedClient} to be used for providing the Bearer Token. Example + * usage: * *
         	 * WebClient webClient = WebClient.builder()
         	 *    .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager))
         	 *    .build();
        -	 * Mono response = webClient
        +	 * Mono<String> response = webClient
         	 *    .get()
         	 *    .uri(uri)
         	 *    .attributes(oauth2AuthorizedClient(authorizedClient))
        @@ -257,16 +268,15 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
         	 * 
      4. A refresh token is present on the OAuth2AuthorizedClient
      5. *
      6. The access token will be expired in * {@link #setAccessTokenExpiresSkew(Duration)}
      7. - *
      8. The {@link ReactiveSecurityContextHolder} will be used to attempt to save - * the token. If it is empty, then the principal name on the OAuth2AuthorizedClient - * will be used to create an Authentication for saving.
      9. + *
      10. The {@link ReactiveSecurityContextHolder} will be used to attempt to save the + * token. If it is empty, then the principal name on the OAuth2AuthorizedClient will + * be used to create an Authentication for saving.
      11. * - * * @param authorizedClient the {@link OAuth2AuthorizedClient} to use. * @return the {@link Consumer} to populate the */ public static Consumer> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) { - return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); + return (attributes) -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); } private static OAuth2AuthorizedClient oauth2AuthorizedClient(ClientRequest request) { @@ -274,14 +284,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link ServerWebExchange} to be used for - * providing the Bearer Token. Example usage: + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link ServerWebExchange} to be used for providing the Bearer Token. Example usage: * *
         	 * WebClient webClient = WebClient.builder()
         	 *    .filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager))
         	 *    .build();
        -	 * Mono response = webClient
        +	 * Mono<String> response = webClient
         	 *    .get()
         	 *    .uri(uri)
         	 *    .attributes(serverWebExchange(serverWebExchange))
        @@ -293,7 +303,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
         	 * @return the {@link Consumer} to populate the client request attributes
         	 */
         	public static Consumer> serverWebExchange(ServerWebExchange serverWebExchange) {
        -		return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
        +		return (attributes) -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
         	}
         
         	private static ServerWebExchange serverWebExchange(ClientRequest request) {
        @@ -301,15 +311,15 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
         	}
         
         	/**
        -	 * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
        -	 * be used to look up the {@link OAuth2AuthorizedClient}.
        -	 *
        +	 * Modifies the {@link ClientRequest#attributes()} to include the
        +	 * {@link ClientRegistration#getRegistrationId()} to be used to look up the
        +	 * {@link OAuth2AuthorizedClient}.
         	 * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
         	 * be used to look up the {@link OAuth2AuthorizedClient}.
         	 * @return the {@link Consumer} to populate the attributes
         	 */
         	public static Consumer> clientRegistrationId(String clientRegistrationId) {
        -		return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
        +		return (attributes) -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
         	}
         
         	private static String clientRegistrationId(ClientRequest request) {
        @@ -321,19 +331,21 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
         	}
         
         	/**
        -	 * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
        -	 * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
        -	 * resolved from the current Authentication.
        -	 * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
        -	 *                                      Default is false.
        +	 * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the
        +	 * current Authentication. It is recommended to be cautious with this feature since
        +	 * all HTTP requests will receive the access token if it can be resolved from the
        +	 * current Authentication.
        +	 * @param defaultOAuth2AuthorizedClient true if a default
        +	 * {@link OAuth2AuthorizedClient} should be used, else false. Default is false.
         	 */
         	public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
         		this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
         	}
         
         	/**
        -	 * If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is
        -	 * recommended to be cautious with this feature since all HTTP requests will receive the access token.
        +	 * If set, will be used as the default {@link ClientRegistration#getRegistrationId()}.
        +	 * It is recommended to be cautious with this feature since all HTTP requests will
        +	 * receive the access token.
         	 * @param clientRegistrationId the id to use
         	 */
         	public void setDefaultClientRegistrationId(String clientRegistrationId) {
        @@ -341,41 +353,51 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
         	}
         
         	/**
        -	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant.
        -	 *
        -	 * @deprecated Use {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)} instead.
        -	 * 				Create an instance of {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider} configured with a
        -	 * 				{@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient) WebClientReactiveClientCredentialsTokenResponseClient}
        -	 * 				(or a custom one) and than supply it to {@link DefaultReactiveOAuth2AuthorizedClientManager#setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider) DefaultReactiveOAuth2AuthorizedClientManager}.
        -	 *
        +	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} used for getting an
        +	 * {@link OAuth2AuthorizedClient} for the client_credentials grant.
        +	 * @deprecated Use
        +	 * {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)}
        +	 * instead. Create an instance of
        +	 * {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider} configured with a
        +	 * {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient)
        +	 * WebClientReactiveClientCredentialsTokenResponseClient} (or a custom one) and than
        +	 * supply it to
        +	 * {@link DefaultReactiveOAuth2AuthorizedClientManager#setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider)
        +	 * DefaultReactiveOAuth2AuthorizedClientManager}.
         	 * @param clientCredentialsTokenResponseClient the client to use
         	 */
         	@Deprecated
         	public void setClientCredentialsTokenResponseClient(
         			ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) {
         		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
        -		Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". " +
        -				"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
        +		Assert.state(this.defaultAuthorizedClientManager,
        +				"The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". "
        +						+ "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
         		this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
         		updateDefaultAuthorizedClientManager();
         	}
         
         	private void updateDefaultAuthorizedClientManager() {
        -		ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
        -				ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
        -						.authorizationCode()
        -						.refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
        -						.clientCredentials(this::updateClientCredentialsProvider)
        -						.password(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
        -						.build();
        +		// @formatter:off
        +		ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
        +				.authorizationCode()
        +				.refreshToken((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew))
        +				.clientCredentials(this::updateClientCredentialsProvider)
        +				.password((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew))
        +				.build();
        +		// @formatter:on
         		if (this.authorizedClientManager instanceof UnAuthenticatedReactiveOAuth2AuthorizedClientManager) {
        -			((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
        -		} else {
        -			((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
        +			((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager)
        +					.setAuthorizedClientProvider(authorizedClientProvider);
        +		}
        +		else {
        +			((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager)
        +					.setAuthorizedClientProvider(authorizedClientProvider);
         		}
         	}
         
        -	private void updateClientCredentialsProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
        +	private void updateClientCredentialsProvider(
        +			ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
         		if (this.clientCredentialsTokenResponseClient != null) {
         			builder.accessTokenResponseClient(this.clientCredentialsTokenResponseClient);
         		}
        @@ -385,119 +407,143 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
         	/**
         	 * An access token will be considered expired by comparing its expiration to now +
         	 * this skewed Duration. The default is 1 minute.
        -	 *
        -	 * @deprecated The {@code accessTokenExpiresSkew} should be configured with the specific {@link ReactiveOAuth2AuthorizedClientProvider} implementation,
        -	 * 				e.g. {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration) ClientCredentialsReactiveOAuth2AuthorizedClientProvider} or
        -	 * 				{@link RefreshTokenReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration) RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
        -	 *
        +	 * @deprecated The {@code accessTokenExpiresSkew} should be configured with the
        +	 * specific {@link ReactiveOAuth2AuthorizedClientProvider} implementation, e.g.
        +	 * {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration)
        +	 * ClientCredentialsReactiveOAuth2AuthorizedClientProvider} or
        +	 * {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider#setClockSkew(Duration)
        +	 * RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
         	 * @param accessTokenExpiresSkew the Duration to use.
         	 */
         	@Deprecated
         	public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
         		Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
        -		Assert.state(this.defaultAuthorizedClientManager, "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". " +
        -				"Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
        +		Assert.state(this.defaultAuthorizedClientManager,
        +				"The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". "
        +						+ "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\".");
         		this.accessTokenExpiresSkew = accessTokenExpiresSkew;
         		updateDefaultAuthorizedClientManager();
         	}
         
         	@Override
         	public Mono filter(ClientRequest request, ExchangeFunction next) {
        +		// @formatter:off
         		return authorizedClient(request)
        -				.map(authorizedClient -> bearer(request, authorizedClient))
        -				.flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next))
        +				.map((authorizedClient) -> bearer(request, authorizedClient))
        +				.flatMap((requestWithBearer) -> exchangeAndHandleResponse(requestWithBearer, next))
         				.switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next)));
        +		// @formatter:on
         	}
         
         	private Mono exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
         		return next.exchange(request)
        -				.transform(responseMono -> this.clientResponseHandler.handleResponse(request, responseMono));
        +				.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono));
         	}
         
         	private Mono authorizedClient(ClientRequest request) {
         		OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
        +		// @formatter:off
         		return Mono.justOrEmpty(authorizedClientFromAttrs)
        -				.switchIfEmpty(Mono.defer(() ->
        -						authorizeRequest(request).flatMap(this.authorizedClientManager::authorize)))
        -				.flatMap(authorizedClient ->
        -						reauthorizeRequest(request, authorizedClient).flatMap(this.authorizedClientManager::authorize));
        +				.switchIfEmpty(Mono.defer(() -> authorizeRequest(request)
        +						.flatMap(this.authorizedClientManager::authorize))
        +				)
        +				.flatMap((authorizedClient) -> reauthorizeRequest(request, authorizedClient)
        +						.flatMap(this.authorizedClientManager::authorize)
        +				);
        +		// @formatter:on
         	}
         
         	private Mono authorizeRequest(ClientRequest request) {
         		Mono clientRegistrationId = effectiveClientRegistrationId(request);
        -
         		Mono> serverWebExchange = effectiveServerWebExchange(request);
        -
        +		// @formatter:off
         		return Mono.zip(clientRegistrationId, this.currentAuthenticationMono, serverWebExchange)
        -				.map(t3 -> {
        -					OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()).principal(t3.getT2());
        -					t3.getT3().ifPresent(exchange -> builder.attribute(ServerWebExchange.class.getName(), exchange));
        +				.map((t3) -> {
        +					OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest
        +							.withClientRegistrationId(t3.getT1())
        +							.principal(t3.getT2());
        +					t3.getT3().ifPresent((exchange) -> builder.attribute(ServerWebExchange.class.getName(), exchange));
         					return builder.build();
         				});
        +		// @formatter:on
         	}
         
         	/**
        -	 * Returns a {@link Mono} the emits the {@code clientRegistrationId}
        -	 * that is active for the given request.
        -	 *
        +	 * Returns a {@link Mono} the emits the {@code clientRegistrationId} that is active
        +	 * for the given request.
         	 * @param request the request for which to retrieve the {@code clientRegistrationId}
        -	 * @return a mono that emits the {@code clientRegistrationId}
        -	 * 	       that is active for the given request.
        +	 * @return a mono that emits the {@code clientRegistrationId} that is active for the
        +	 * given request.
         	 */
         	private Mono effectiveClientRegistrationId(ClientRequest request) {
        +		// @formatter:off
         		return Mono.justOrEmpty(clientRegistrationId(request))
         				.switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId))
        -				.switchIfEmpty(clientRegistrationIdMono);
        +				.switchIfEmpty(this.clientRegistrationIdMono);
        +		// @formatter:on
         	}
         
         	/**
        -	 * Returns a {@link Mono} that emits an {@link Optional} for the {@link ServerWebExchange}
        -	 * that is active for the given request.
        -	 *
        -	 * 

        The returned {@link Mono} will never complete empty. - * Instead, it will emit an empty {@link Optional} if no exchange is active.

        + * Returns a {@link Mono} that emits an {@link Optional} for the + * {@link ServerWebExchange} that is active for the given request. * + *

        + * The returned {@link Mono} will never complete empty. Instead, it will emit an empty + * {@link Optional} if no exchange is active. + *

        * @param request the request for which to retrieve the exchange - * @return a {@link Mono} that emits an {@link Optional} for the {@link ServerWebExchange} - * that is active for the given request. + * @return a {@link Mono} that emits an {@link Optional} for the + * {@link ServerWebExchange} that is active for the given request. */ private Mono> effectiveServerWebExchange(ClientRequest request) { + // @formatter:off return Mono.justOrEmpty(serverWebExchange(request)) - .switchIfEmpty(currentServerWebExchangeMono) + .switchIfEmpty(this.currentServerWebExchangeMono) .map(Optional::of) .defaultIfEmpty(Optional.empty()); + // @formatter:on } - private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { + private Mono reauthorizeRequest(ClientRequest request, + OAuth2AuthorizedClient authorizedClient) { Mono> serverWebExchange = effectiveServerWebExchange(request); - + // @formatter:off return Mono.zip(this.currentAuthenticationMono, serverWebExchange) - .map(t2 -> { - OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient).principal(t2.getT1()); - t2.getT2().ifPresent(exchange -> builder.attribute(ServerWebExchange.class.getName(), exchange)); + .map((t2) -> { + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient) + .principal(t2.getT1()); + t2.getT2().ifPresent((exchange) -> builder.attribute(ServerWebExchange.class.getName(), exchange)); return builder.build(); }); + // @formatter:on } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { + // @formatter:off return ClientRequest.from(request) - .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) - .build(); + .headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) + .build(); + // @formatter:on } /** - * Sets the handler that handles authentication and authorization failures when communicating - * to the OAuth 2.0 Resource Server. + * Sets the handler that handles authentication and authorization failures when + * communicating to the OAuth 2.0 Resource Server. * - *

        For example, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} - * is typically used to remove the cached {@link OAuth2AuthorizedClient}, - * so that the same token is no longer used in future requests to the Resource Server.

        + *

        + * For example, a + * {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} is + * typically used to remove the cached {@link OAuth2AuthorizedClient}, so that the + * same token is no longer used in future requests to the Resource Server. + *

        * - *

        The failure handler used by default depends on which constructor was used - * to construct this {@link ServerOAuth2AuthorizedClientExchangeFilterFunction}. - * See the constructors for more details.

        - * - * @param authorizationFailureHandler the handler that handles authentication and authorization failures. + *

        + * The failure handler used by default depends on which constructor was used to + * construct this {@link ServerOAuth2AuthorizedClientExchangeFilterFunction}. See the + * constructors for more details. + *

        + * @param authorizationFailureHandler the handler that handles authentication and + * authorization failures. * @since 5.3 */ public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { @@ -505,11 +551,24 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); } - private static class UnAuthenticatedReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { + @FunctionalInterface + private interface ClientResponseHandler { + + Mono handleResponse(ClientRequest request, Mono response); + + } + + private static final class UnAuthenticatedReactiveOAuth2AuthorizedClientManager + implements ReactiveOAuth2AuthorizedClientManager { + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private final ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private final ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; private UnAuthenticatedReactiveOAuth2AuthorizedClientManager( @@ -518,72 +577,93 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; - this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> - authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null); + this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> authorizedClientRepository + .saveAuthorizedClient(authorizedClient, principal, null); this.authorizationFailureHandler = authorizationFailureHandler; } @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); - + // @formatter:off return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) - .switchIfEmpty(Mono.defer(() -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null))) - .flatMap(authorizedClient -> { - // Re-authorize - return Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal).build()) - .flatMap(authorizationContext -> authorize(authorizationContext, principal)) - // Default to the existing authorizedClient if the client was not re-authorized - .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? - authorizeRequest.getAuthorizedClient() : authorizedClient); - }) - .switchIfEmpty(Mono.defer(() -> - // Authorize - this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( - "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) - .flatMap(clientRegistration -> Mono.just(OAuth2AuthorizationContext.withClientRegistration(clientRegistration).principal(principal).build())) - .flatMap(authorizationContext -> authorize(authorizationContext, principal)) - )); + .switchIfEmpty(loadAuthorizedClient(clientRegistrationId, principal)) + .flatMap((authorizedClient) -> reauthorize(authorizedClient, authorizeRequest, principal)) + .switchIfEmpty(findAndAuthorize(clientRegistrationId, principal)); + // @formatter:on + } + + private Mono loadAuthorizedClient(String clientRegistrationId, + Authentication principal) { + return Mono.defer( + () -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null)); + } + + private Mono reauthorize(OAuth2AuthorizedClient authorizedClient, + OAuth2AuthorizeRequest authorizeRequest, Authentication principal) { + return Mono + .just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal) + .build()) + .flatMap((authorizationContext) -> authorize(authorizationContext, principal)) + // Default to the existing authorizedClient if the client was not + // re-authorized + .defaultIfEmpty((authorizeRequest.getAuthorizedClient() != null) + ? authorizeRequest.getAuthorizedClient() : authorizedClient); + } + + private Mono findAndAuthorize(String clientRegistrationId, Authentication principal) { + // @formatter:off + return Mono.defer(() -> + this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> + new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'")) + ) + .flatMap((clientRegistration) -> Mono.just(OAuth2AuthorizationContext + .withClientRegistration(clientRegistration).principal(principal).build()) + ) + .flatMap((authorizationContext) -> authorize(authorizationContext, principal)) + ); + // @formatter:on } /** - * Performs authorization and then delegates to either the {@link #authorizationSuccessHandler} - * or {@link #authorizationFailureHandler}, depending on the authorization result. - * + * Performs authorization and then delegates to either the + * {@link #authorizationSuccessHandler} or {@link #authorizationFailureHandler}, + * depending on the authorization result. * @param authorizationContext the context to authorize * @param principal the principle to authorize - * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds - * and the {@link #authorizationSuccessHandler} has completed, - * or completes with an exception after the authorization attempt fails - * and the {@link #authorizationFailureHandler} has completed + * @return a {@link Mono} that emits the authorized client after the authorization + * attempt succeeds and the {@link #authorizationSuccessHandler} has completed, or + * completes with an exception after the authorization attempt fails and the + * {@link #authorizationFailureHandler} has completed */ - private Mono authorize( - OAuth2AuthorizationContext authorizationContext, + private Mono authorize(OAuth2AuthorizationContext authorizationContext, Authentication principal) { - + // @formatter:off return this.authorizedClientProvider.authorize(authorizationContext) - // Delegates to the authorizationSuccessHandler of the successful authorization - .flatMap(authorizedClient -> this.authorizationSuccessHandler.onAuthorizationSuccess( - authorizedClient, - principal, - Collections.emptyMap()) - .thenReturn(authorizedClient)) - // Delegates to the authorizationFailureHandler of the failed authorization - .onErrorResume(OAuth2AuthorizationException.class, authorizationException -> this.authorizationFailureHandler.onAuthorizationFailure( - authorizationException, - principal, - Collections.emptyMap()) - .then(Mono.error(authorizationException))); + // Delegates to the authorizationSuccessHandler of the successful + // authorization + .flatMap((authorizedClient) -> this.authorizationSuccessHandler + .onAuthorizationSuccess(authorizedClient, principal, Collections.emptyMap()) + .thenReturn(authorizedClient) + ) + // Delegates to the authorizationFailureHandler of the failed + // authorization + .onErrorResume(OAuth2AuthorizationException.class, (authorizationException) -> + this.authorizationFailureHandler + .onAuthorizationFailure(authorizationException, principal, Collections.emptyMap()) + .then(Mono.error(authorizationException)) + ); + // @formatter:on } private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); this.authorizedClientProvider = authorizedClientProvider; } + } /** @@ -592,25 +672,23 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * * @since 5.3 */ - private class AuthorizationFailureForwarder implements ClientResponseHandler { + private final class AuthorizationFailureForwarder implements ClientResponseHandler { /** - * A map of HTTP Status Code to OAuth 2.0 Error codes for - * HTTP status codes that should be interpreted as - * authentication or authorization failures. + * A map of HTTP Status Code to OAuth 2.0 Error codes for HTTP status codes that + * should be interpreted as authentication or authorization failures. */ private final Map httpStatusToOAuth2ErrorCodeMap; /** - * The {@link ReactiveOAuth2AuthorizationFailureHandler} to notify - * when an authentication/authorization failure occurs. + * The {@link ReactiveOAuth2AuthorizationFailureHandler} to notify when an + * authentication/authorization failure occurs. */ private final ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; private AuthorizationFailureForwarder(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); this.authorizationFailureHandler = authorizationFailureHandler; - Map httpStatusToOAuth2Error = new HashMap<>(); httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN); httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE); @@ -619,30 +697,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Override public Mono handleResponse(ClientRequest request, Mono responseMono) { + // @formatter:off return responseMono - .flatMap(response -> handleResponse(request, response) - .thenReturn(response)) - .onErrorResume(WebClientResponseException.class, e -> handleWebClientResponseException(request, e) - .then(Mono.error(e))) - .onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e) - .then(Mono.error(e))); + .flatMap((response) -> handleResponse(request, response).thenReturn(response)) + .onErrorResume(WebClientResponseException.class, + (e) -> handleWebClientResponseException(request, e).then(Mono.error(e)) + ) + .onErrorResume(OAuth2AuthorizationException.class, + (e) -> handleAuthorizationException(request, e).then(Mono.error(e))); + // @formatter:on } private Mono handleResponse(ClientRequest request, ClientResponse response) { + // @formatter:off return Mono.justOrEmpty(resolveErrorIfPossible(response)) - .flatMap(oauth2Error -> { + .flatMap((oauth2Error) -> { Mono> serverWebExchange = effectiveServerWebExchange(request); - Mono clientRegistrationId = effectiveClientRegistrationId(request); - - return Mono.zip(currentAuthenticationMono, serverWebExchange, clientRegistrationId) - .flatMap(tuple3 -> handleAuthorizationFailure( - tuple3.getT1(), // Authentication principal - tuple3.getT2().orElse(null), // ServerWebExchange exchange - new ClientAuthorizationException( - oauth2Error, - tuple3.getT3()))); // String clientRegistrationId + return Mono + .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, + serverWebExchange, clientRegistrationId) + .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), + new ClientAuthorizationException(oauth2Error, zipped.getT3()))); }); + // @formatter:on } private OAuth2Error resolveErrorIfPossible(ClientResponse response) { @@ -651,8 +729,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements String wwwAuthenticateHeader = response.headers().header(HttpHeaders.WWW_AUTHENTICATE).get(0); Map authParameters = parseAuthParameters(wwwAuthenticateHeader); if (authParameters.containsKey(OAuth2ParameterNames.ERROR)) { - return new OAuth2Error( - authParameters.get(OAuth2ParameterNames.ERROR), + return new OAuth2Error(authParameters.get(OAuth2ParameterNames.ERROR), authParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION), authParameters.get(OAuth2ParameterNames.ERROR_URI)); } @@ -662,90 +739,78 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private OAuth2Error resolveErrorIfPossible(int statusCode) { if (this.httpStatusToOAuth2ErrorCodeMap.containsKey(statusCode)) { - return new OAuth2Error( - this.httpStatusToOAuth2ErrorCodeMap.get(statusCode), - null, + return new OAuth2Error(this.httpStatusToOAuth2ErrorCodeMap.get(statusCode), null, "https://tools.ietf.org/html/rfc6750#section-3.1"); } return null; } private Map parseAuthParameters(String wwwAuthenticateHeader) { + // @formatter:off return Stream.of(wwwAuthenticateHeader) - .filter(header -> !StringUtils.isEmpty(header)) - .filter(header -> header.toLowerCase().startsWith("bearer")) - .map(header -> header.substring("bearer".length())) - .map(header -> header.split(",")) + .filter((header) -> !StringUtils.isEmpty(header)) + .filter((header) -> header.toLowerCase().startsWith("bearer")) + .map((header) -> header.substring("bearer".length())) + .map((header) -> header.split(",")) .flatMap(Stream::of) - .map(parameter -> parameter.split("=")) - .filter(parameter -> parameter.length > 1) - .collect(Collectors.toMap( - parameters -> parameters[0].trim(), - parameters -> parameters[1].trim().replace("\"", ""))); + .map((parameter) -> parameter.split("=")) + .filter((parameter) -> parameter.length > 1) + .collect(Collectors.toMap((parameters) -> parameters[0].trim(), + (parameters) -> parameters[1].trim().replace("\"", "")) + ); + // @formatter:on } /** - * Handles the given http status code returned from a resource server - * by notifying the authorization failure handler if the http status - * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}. - * + * Handles the given http status code returned from a resource server by notifying + * the authorization failure handler if the http status code is in the + * {@link #httpStatusToOAuth2ErrorCodeMap}. * @param request the request being processed * @param exception The root cause exception for the failure - * @return a {@link Mono} that completes empty after the authorization failure handler completes. + * @return a {@link Mono} that completes empty after the authorization failure + * handler completes. */ - private Mono handleWebClientResponseException(ClientRequest request, WebClientResponseException exception) { - return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())) - .flatMap(oauth2Error -> { - Mono> serverWebExchange = effectiveServerWebExchange(request); - - Mono clientRegistrationId = effectiveClientRegistrationId(request); - - return Mono.zip(currentAuthenticationMono, serverWebExchange, clientRegistrationId) - .flatMap(tuple3 -> handleAuthorizationFailure( - tuple3.getT1(), // Authentication principal - tuple3.getT2().orElse(null), // ServerWebExchange exchange - new ClientAuthorizationException( - oauth2Error, - tuple3.getT3(), // String clientRegistrationId - exception))); - }); + private Mono handleWebClientResponseException(ClientRequest request, + WebClientResponseException exception) { + return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())).flatMap((oauth2Error) -> { + Mono> serverWebExchange = effectiveServerWebExchange(request); + Mono clientRegistrationId = effectiveClientRegistrationId(request); + return Mono + .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, + serverWebExchange, clientRegistrationId) + .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), + new ClientAuthorizationException(oauth2Error, zipped.getT3(), exception))); + }); } /** - * Handles the given OAuth2AuthorizationException that occurred downstream - * by notifying the authorization failure handler. - * + * Handles the given OAuth2AuthorizationException that occurred downstream by + * notifying the authorization failure handler. * @param request the request being processed * @param exception the authorization exception to include in the failure event. - * @return a {@link Mono} that completes empty after the authorization failure handler completes. + * @return a {@link Mono} that completes empty after the authorization failure + * handler completes. */ private Mono handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException exception) { Mono> serverWebExchange = effectiveServerWebExchange(request); - - return Mono.zip(currentAuthenticationMono, serverWebExchange) - .flatMap(tuple2 -> handleAuthorizationFailure( - tuple2.getT1(), // Authentication principal - tuple2.getT2().orElse(null), // ServerWebExchange exchange - exception)); + return Mono + .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, + serverWebExchange) + .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), exception)); } /** * Delegates to the authorization failure handler of the failed authorization. - * * @param principal the principal associated with the failed authorization attempt * @param exchange the currently active exchange * @param exception the authorization exception to include in the failure event. - * @return a {@link Mono} that completes empty after the authorization failure handler completes. + * @return a {@link Mono} that completes empty after the authorization failure + * handler completes. */ - private Mono handleAuthorizationFailure( - Authentication principal, - ServerWebExchange exchange, + private Mono handleAuthorizationFailure(Authentication principal, Optional exchange, OAuth2AuthorizationException exception) { - - return this.authorizationFailureHandler.onAuthorizationFailure( - exception, - principal, - createAttributes(exchange)); + return this.authorizationFailureHandler.onAuthorizationFailure(exception, principal, + createAttributes(exchange.orElse(null))); } private Map createAttributes(ServerWebExchange exchange) { @@ -754,5 +819,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements } return Collections.singletonMap(ServerWebExchange.class.getName(), exchange); } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 8694121c8e..e97728e4ea 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -13,8 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.reactive.function.client; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.context.Context; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.security.authentication.AbstractAuthenticationToken; @@ -54,23 +70,11 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; -import reactor.util.context.Context; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.time.Duration; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import java.util.stream.Stream; /** - * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth 2.0 requests - * by including the {@link OAuth2AuthorizedClient#getAccessToken() access token} as a bearer token. + * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth + * 2.0 requests by including the {@link OAuth2AuthorizedClient#getAccessToken() access + * token} as a bearer token. * *

        * NOTE:This class is intended to be used in a {@code Servlet} environment. @@ -83,7 +87,7 @@ import java.util.stream.Stream; * WebClient webClient = WebClient.builder() * .apply(oauth2.oauth2Configuration()) * .build(); - * Mono response = webClient + * Mono<String> response = webClient * .get() * .uri(uri) * .attributes(oauth2AuthorizedClient(authorizedClient)) @@ -95,23 +99,25 @@ import java.util.stream.Stream; *

        Authentication and Authorization Failures

        * *

        - * Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized) - * and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server - * to a {@link OAuth2AuthorizationFailureHandler}. - * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} can be used - * to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result - * in a new token being retrieved from an Authorization Server, and sent to the Resource Server. + * Since 5.3, this filter function has the ability to forward authentication (HTTP 401 + * Unauthorized) and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 + * Resource Server to a {@link OAuth2AuthorizationFailureHandler}. A + * {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} can be used to remove + * the cached {@link OAuth2AuthorizedClient}, so that future requests will result in a new + * token being retrieved from an Authorization Server, and sent to the Resource Server. * *

        - * If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)} + * If the + * {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)} * constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} * will be configured automatically. * *

        - * If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} + * If the + * {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} * constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} - * will NOT be configured automatically. - * It is recommended that you configure one via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. + * will NOT be configured automatically. It is recommended that you configure one + * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. * * @author Rob Winch * @author Joe Grandja @@ -124,7 +130,8 @@ import java.util.stream.Stream; */ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { - // Same key as in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES + // Same key as in + // SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES"; /** @@ -132,13 +139,17 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement */ private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); - private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID"); + private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName() + .concat(".CLIENT_REGISTRATION_ID"); + private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); + private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); - private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( - "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous", + "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); @Deprecated private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); @@ -156,107 +167,115 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private ClientResponseHandler clientResponseHandler; - @FunctionalInterface - private interface ClientResponseHandler { - Mono handleResponse(ClientRequest request, Mono response); - } - public ServletOAuth2AuthorizedClientExchangeFilterFunction() { } /** - * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the + * provided parameters. * *

        - * When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403) - * failures returned from an OAuth 2.0 Resource Server will NOT be forwarded to an - * {@link OAuth2AuthorizationFailureHandler}. - * Therefore, future requests to the Resource Server will most likely use the same (likely invalid) token, - * resulting in the same errors returned from the Resource Server. - * It is recommended to configure a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} - * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)} - * so that authentication and authorization failures returned from a Resource Server - * will result in removing the authorized client, so that a new token is retrieved for future requests. - * + * When this constructor is used, authentication (HTTP 401) and authorization (HTTP + * 403) failures returned from an OAuth 2.0 Resource Server will NOT be + * forwarded to an {@link OAuth2AuthorizationFailureHandler}. Therefore, future + * requests to the Resource Server will most likely use the same (likely invalid) + * token, resulting in the same errors returned from the Resource Server. It is + * recommended to configure a + * {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} via + * {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)} so that + * authentication and authorization failures returned from a Resource Server will + * result in removing the authorized client, so that a new token is retrieved for + * future requests. + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which + * manages the authorized client(s) * @since 5.2 - * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) */ public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager authorizedClientManager) { Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); this.authorizedClientManager = authorizedClientManager; - this.clientResponseHandler = (request, responseMono) -> responseMono; + this.clientResponseHandler = (request, responseMono) -> responseMono; } /** - * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the + * provided parameters. * *

        - * Since 5.3, when this constructor is used, authentication (HTTP 401) - * and authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server - * will be forwarded to a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}, + * Since 5.3, when this constructor is used, authentication (HTTP 401) and + * authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server will + * be forwarded to a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}, * which will potentially remove the {@link OAuth2AuthorizedClient} from the given - * {@link OAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code returned. - * Authentication failures returned from an OAuth 2.0 Resource Server typically indicate - * that the token is invalid, and should not be used in future requests. - * Removing the authorized client from the repository will ensure that the existing - * token will not be sent for future requests to the Resource Server, - * and a new token is retrieved from the Authorization Server and used for - * future requests to the Resource Server. - * + * {@link OAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code + * returned. Authentication failures returned from an OAuth 2.0 Resource Server + * typically indicate that the token is invalid, and should not be used in future + * requests. Removing the authorized client from the repository will ensure that the + * existing token will not be sent for future requests to the Resource Server, and a + * new token is retrieved from the Authorization Server and used for future requests + * to the Resource Server. * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - - OAuth2AuthorizationFailureHandler authorizationFailureHandler = - new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, - (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), - (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()))); - DefaultOAuth2AuthorizedClientManager defaultAuthorizedClientManager = - new DefaultOAuth2AuthorizedClientManager( - clientRegistrationRepository, authorizedClientRepository); + OAuth2AuthorizationFailureHandler authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( + (clientRegistrationId, principal, attributes) -> removeAuthorizedClient(authorizedClientRepository, + clientRegistrationId, principal, attributes)); + DefaultOAuth2AuthorizedClientManager defaultAuthorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); defaultAuthorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler); this.authorizedClientManager = defaultAuthorizedClientManager; this.defaultAuthorizedClientManager = true; this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); } + private void removeAuthorizedClient(OAuth2AuthorizedClientRepository authorizedClientRepository, + String clientRegistrationId, Authentication principal, Map attributes) { + HttpServletRequest request = getRequest(attributes); + HttpServletResponse response = getResponse(attributes); + authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + /** - * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant. - * - * @deprecated Use {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} instead. - * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a - * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} - * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. - * + * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an + * {@link OAuth2AuthorizedClient} for the client_credentials grant. + * @deprecated Use + * {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} + * instead. Create an instance of + * {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a + * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) + * DefaultClientCredentialsTokenResponseClient} (or a custom one) and than supply it + * to + * {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) + * DefaultOAuth2AuthorizedClientManager}. * @param clientCredentialsTokenResponseClient the client to use */ @Deprecated public void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + Assert.state(this.defaultAuthorizedClientManager, + "The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; updateDefaultAuthorizedClientManager(); } private void updateDefaultAuthorizedClientManager() { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) - .clientCredentials(this::updateClientCredentialsProvider) - .password(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) - .build(); - ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); + // @formatter:off + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew)) + .clientCredentials(this::updateClientCredentialsProvider) + .password((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew)) + .build(); + // @formatter:on + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager) + .setAuthorizedClientProvider(authorizedClientProvider); } - private void updateClientCredentialsProvider(OAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) { + private void updateClientCredentialsProvider( + OAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) { if (this.clientCredentialsTokenResponseClient != null) { builder.accessTokenResponseClient(this.clientCredentialsTokenResponseClient); } @@ -264,19 +283,21 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } /** - * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is - * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be - * resolved from the current Authentication. - * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false. - * Default is false. + * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the + * current Authentication. It is recommended to be cautious with this feature since + * all HTTP requests will receive the access token if it can be resolved from the + * current Authentication. + * @param defaultOAuth2AuthorizedClient true if a default + * {@link OAuth2AuthorizedClient} should be used, else false. Default is false. */ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; } /** - * If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is - * recommended to be cautious with this feature since all HTTP requests will receive the access token. + * If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. + * It is recommended to be cautious with this feature since all HTTP requests will + * receive the access token. * @param clientRegistrationId the id to use */ public void setDefaultClientRegistrationId(String clientRegistrationId) { @@ -284,127 +305,131 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } /** - * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction} + * Configures the builder with {@link #defaultRequest()} and adds this as a + * {@link ExchangeFilterFunction} * @return the {@link Consumer} to configure the builder */ public Consumer oauth2Configuration() { - return builder -> builder.defaultRequest(defaultRequest()).filter(this); + return (builder) -> builder.defaultRequest(defaultRequest()).filter(this); } /** - * Provides defaults for the {@link HttpServletRequest} and the {@link HttpServletResponse} using - * {@link RequestContextHolder}. It also provides defaults for the {@link Authentication} using - * {@link SecurityContextHolder}. It also can default the {@link OAuth2AuthorizedClient} using the - * {@link #clientRegistrationId(String)} or the {@link #authentication(Authentication)}. + * Provides defaults for the {@link HttpServletRequest} and the + * {@link HttpServletResponse} using {@link RequestContextHolder}. It also provides + * defaults for the {@link Authentication} using {@link SecurityContextHolder}. It + * also can default the {@link OAuth2AuthorizedClient} using the + * {@link #clientRegistrationId(String)} or the + * {@link #authentication(Authentication)}. * @return the {@link Consumer} to populate the attributes */ public Consumer> defaultRequest() { - return spec -> spec.attributes(attrs -> { + return (spec) -> spec.attributes((attrs) -> { populateDefaultRequestResponse(attrs); populateDefaultAuthentication(attrs); }); } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for - * providing the Bearer Token. - * + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link OAuth2AuthorizedClient} to be used for providing the Bearer Token. * @param authorizedClient the {@link OAuth2AuthorizedClient} to use. * @return the {@link Consumer} to populate the attributes */ public static Consumer> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) { - return attributes -> { + return (attributes) -> { if (authorizedClient == null) { attributes.remove(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); - } else { + } + else { attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); } }; } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to - * be used to look up the {@link OAuth2AuthorizedClient}. - * + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link ClientRegistration#getRegistrationId()} to be used to look up the + * {@link OAuth2AuthorizedClient}. * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to * be used to look up the {@link OAuth2AuthorizedClient}. * @return the {@link Consumer} to populate the attributes */ public static Consumer> clientRegistrationId(String clientRegistrationId) { - return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); + return (attributes) -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to - * look up and save the {@link OAuth2AuthorizedClient}. The value is defaulted in + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link Authentication} used to look up and save the {@link OAuth2AuthorizedClient}. + * The value is defaulted in * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()} - * * @param authentication the {@link Authentication} to use. * @return the {@link Consumer} to populate the attributes */ public static Consumer> authentication(Authentication authentication) { - return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); + return (attributes) -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link HttpServletRequest} used to - * look up and save the {@link OAuth2AuthorizedClient}. The value is defaulted in + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link HttpServletRequest} used to look up and save the + * {@link OAuth2AuthorizedClient}. The value is defaulted in * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()} - * * @param request the {@link HttpServletRequest} to use. * @return the {@link Consumer} to populate the attributes */ public static Consumer> httpServletRequest(HttpServletRequest request) { - return attributes -> attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request); + return (attributes) -> attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request); } /** - * Modifies the {@link ClientRequest#attributes()} to include the {@link HttpServletResponse} used to - * save the {@link OAuth2AuthorizedClient}. The value is defaulted in + * Modifies the {@link ClientRequest#attributes()} to include the + * {@link HttpServletResponse} used to save the {@link OAuth2AuthorizedClient}. The + * value is defaulted in * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()} - * * @param response the {@link HttpServletResponse} to use. * @return the {@link Consumer} to populate the attributes */ public static Consumer> httpServletResponse(HttpServletResponse response) { - return attributes -> attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); + return (attributes) -> attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); } /** * An access token will be considered expired by comparing its expiration to now + * this skewed Duration. The default is 1 minute. - * - * @deprecated The {@code accessTokenExpiresSkew} should be configured with the specific {@link OAuth2AuthorizedClientProvider} implementation, - * e.g. {@link ClientCredentialsOAuth2AuthorizedClientProvider#setClockSkew(Duration) ClientCredentialsOAuth2AuthorizedClientProvider} or - * {@link RefreshTokenOAuth2AuthorizedClientProvider#setClockSkew(Duration) RefreshTokenOAuth2AuthorizedClientProvider}. - * + * @deprecated The {@code accessTokenExpiresSkew} should be configured with the + * specific {@link OAuth2AuthorizedClientProvider} implementation, e.g. + * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setClockSkew(Duration) + * ClientCredentialsOAuth2AuthorizedClientProvider} or + * {@link RefreshTokenOAuth2AuthorizedClientProvider#setClockSkew(Duration) + * RefreshTokenOAuth2AuthorizedClientProvider}. * @param accessTokenExpiresSkew the Duration to use. */ @Deprecated public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); - Assert.state(this.defaultAuthorizedClientManager, "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + Assert.state(this.defaultAuthorizedClientManager, + "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); this.accessTokenExpiresSkew = accessTokenExpiresSkew; updateDefaultAuthorizedClientManager(); } /** - * Sets the {@link OAuth2AuthorizationFailureHandler} that handles - * authentication and authorization failures when communicating - * to the OAuth 2.0 Resource Server. + * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authentication and + * authorization failures when communicating to the OAuth 2.0 Resource Server. * *

        - * For example, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} - * is typically used to remove the cached {@link OAuth2AuthorizedClient}, - * so that the same token is no longer used in future requests to the Resource Server. + * For example, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is + * typically used to remove the cached {@link OAuth2AuthorizedClient}, so that the + * same token is no longer used in future requests to the Resource Server. * *

        - * The failure handler used by default depends on which constructor was used - * to construct this {@link ServletOAuth2AuthorizedClientExchangeFilterFunction}. - * See the constructors for more details. - * - * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authentication and authorization failures + * The failure handler used by default depends on which constructor was used to + * construct this {@link ServletOAuth2AuthorizedClientExchangeFilterFunction}. See the + * constructors for more details. + * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} + * that handles authentication and authorization failures * @since 5.3 */ public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { @@ -414,43 +439,47 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @Override public Mono filter(ClientRequest request, ExchangeFunction next) { + // @formatter:off return mergeRequestAttributesIfNecessary(request) - .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .flatMap(req -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req)) - .switchIfEmpty(Mono.defer(() -> - mergeRequestAttributesIfNecessary(request) - .filter(req -> resolveClientRegistrationId(req) != null) - .flatMap(req -> authorizeClient(resolveClientRegistrationId(req), req)) - )) - .map(authorizedClient -> bearer(request, authorizedClient)) - .flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next)) + .filter((req) -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) + .flatMap((req) -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req)) + .switchIfEmpty( + Mono.defer(() -> + mergeRequestAttributesIfNecessary(request) + .filter((req) -> resolveClientRegistrationId(req) != null) + .flatMap((req) -> authorizeClient(resolveClientRegistrationId(req), req)) + ) + ) + .map((authorizedClient) -> bearer(request, authorizedClient)) + .flatMap((requestWithBearer) -> exchangeAndHandleResponse(requestWithBearer, next)) .switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next))); + // @formatter:on } private Mono exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { return next.exchange(request) - .transform(responseMono -> this.clientResponseHandler.handleResponse(request, responseMono)); + .transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono)); } private Mono mergeRequestAttributesIfNecessary(ClientRequest request) { - if (!request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent() || - !request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent() || - !request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { + if (!request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent() + || !request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent() + || !request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { return mergeRequestAttributesFromContext(request); - } else { - return Mono.just(request); } + return Mono.just(request); } private Mono mergeRequestAttributesFromContext(ClientRequest request) { ClientRequest.Builder builder = ClientRequest.from(request); return Mono.subscriberContext() - .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))) + .map((ctx) -> builder.attributes((attrs) -> populateRequestAttributes(attrs, ctx))) .map(ClientRequest.Builder::build); } private void populateRequestAttributes(Map attrs, Context ctx) { - // NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds this key + // NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds + // this key if (!ctx.hasKey(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY)) { return; } @@ -470,13 +499,12 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } private void populateDefaultRequestResponse(Map attrs) { - if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && - attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { + if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { return; } RequestAttributes context = RequestContextHolder.getRequestAttributes(); if (context instanceof ServletRequestAttributes) { - attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ((ServletRequestAttributes) context).getRequest()); + attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ((ServletRequestAttributes) context).getRequest()); attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ((ServletRequestAttributes) context).getResponse()); } } @@ -496,8 +524,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement clientRegistrationId = this.defaultClientRegistrationId; } Authentication authentication = getAuthentication(attrs); - if (clientRegistrationId == null - && this.defaultOAuth2AuthorizedClient + if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient && authentication instanceof OAuth2AuthenticationToken) { clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); } @@ -515,26 +542,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - - OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId).principal(authentication); - builder.attributes(attributes -> { - if (servletRequest != null) { - attributes.put(HttpServletRequest.class.getName(), servletRequest); - } - if (servletResponse != null) { - attributes.put(HttpServletResponse.class.getName(), servletResponse); - } - }); + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) + .principal(authentication); + builder.attributes((attributes) -> addToAttributes(attributes, servletRequest, servletResponse)); OAuth2AuthorizeRequest authorizeRequest = builder.build(); - - // NOTE: - // 'authorizedClientManager.authorize()' needs to be executed - // on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) - // since it performs a blocking I/O operation using RestTemplate internally - return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)).subscribeOn(Schedulers.boundedElastic()); + // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated + // thread via subscribeOn(Schedulers.boundedElastic()) since it performs a + // blocking I/O operation using RestTemplate internally + return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .subscribeOn(Schedulers.boundedElastic()); } - private Mono reauthorizeClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) { + private Mono reauthorizeClient(OAuth2AuthorizedClient authorizedClient, + ClientRequest request) { if (this.authorizedClientManager == null) { return Mono.just(authorizedClient); } @@ -545,30 +565,34 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - - OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient).principal(authentication); - builder.attributes(attributes -> { - if (servletRequest != null) { - attributes.put(HttpServletRequest.class.getName(), servletRequest); - } - if (servletResponse != null) { - attributes.put(HttpServletResponse.class.getName(), servletResponse); - } - }); + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient) + .principal(authentication); + builder.attributes((attributes) -> addToAttributes(attributes, servletRequest, servletResponse)); OAuth2AuthorizeRequest reauthorizeRequest = builder.build(); + // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated + // thread via subscribeOn(Schedulers.boundedElastic()) since it performs a + // blocking I/O operation using RestTemplate internally + return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + .subscribeOn(Schedulers.boundedElastic()); + } - // NOTE: - // 'authorizedClientManager.authorize()' needs to be executed - // on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) - // since it performs a blocking I/O operation using RestTemplate internally - return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)).subscribeOn(Schedulers.boundedElastic()); + private void addToAttributes(Map attributes, HttpServletRequest servletRequest, + HttpServletResponse servletResponse) { + if (servletRequest != null) { + attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest); + } + if (servletResponse != null) { + attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse); + } } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { + // @formatter:off return ClientRequest.from(request) - .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .build(); + .headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + // @formatter:on } static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map attrs) { @@ -593,8 +617,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private static Authentication createAuthentication(final String principalName) { Assert.hasText(principalName, "principalName cannot be empty"); - return new AbstractAuthenticationToken(null) { + @Override public Object getCredentials() { return ""; @@ -607,31 +631,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement }; } + @FunctionalInterface + private interface ClientResponseHandler { + + Mono handleResponse(ClientRequest request, Mono response); + + } + /** * Forwards authentication and authorization failures to an * {@link OAuth2AuthorizationFailureHandler}. * * @since 5.3 */ - private static class AuthorizationFailureForwarder implements ClientResponseHandler { + private static final class AuthorizationFailureForwarder implements ClientResponseHandler { /** - * A map of HTTP status code to OAuth 2.0 error code for - * HTTP status codes that should be interpreted as - * authentication or authorization failures. + * A map of HTTP status code to OAuth 2.0 error code for HTTP status codes that + * should be interpreted as authentication or authorization failures. */ private final Map httpStatusToOAuth2ErrorCodeMap; /** - * The {@link OAuth2AuthorizationFailureHandler} to notify - * when an authentication/authorization failure occurs. + * The {@link OAuth2AuthorizationFailureHandler} to notify when an + * authentication/authorization failure occurs. */ private final OAuth2AuthorizationFailureHandler authorizationFailureHandler; private AuthorizationFailureForwarder(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); this.authorizationFailureHandler = authorizationFailureHandler; - Map httpStatusToOAuth2Error = new HashMap<>(); httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN); httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE); @@ -640,33 +669,30 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @Override public Mono handleResponse(ClientRequest request, Mono responseMono) { - return responseMono - .flatMap(response -> handleResponse(request, response) - .thenReturn(response)) - .onErrorResume(WebClientResponseException.class, e -> handleWebClientResponseException(request, e) - .then(Mono.error(e))) - .onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e) - .then(Mono.error(e))); + return responseMono.flatMap((response) -> handleResponse(request, response).thenReturn(response)) + .onErrorResume(WebClientResponseException.class, + (e) -> handleWebClientResponseException(request, e).then(Mono.error(e))) + .onErrorResume(OAuth2AuthorizationException.class, + (e) -> handleAuthorizationException(request, e).then(Mono.error(e))); } private Mono handleResponse(ClientRequest request, ClientResponse response) { + // @formatter:off return Mono.justOrEmpty(resolveErrorIfPossible(response)) - .flatMap(oauth2Error -> { + .flatMap((oauth2Error) -> { Map attrs = request.attributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); if (authorizedClient == null) { return Mono.empty(); } - - ClientAuthorizationException authorizationException = new ClientAuthorizationException( - oauth2Error, authorizedClient.getClientRegistration().getRegistrationId()); - + ClientAuthorizationException authorizationException = new ClientAuthorizationException(oauth2Error, + authorizedClient.getClientRegistration().getRegistrationId()); Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); }); + // @formatter:on } private OAuth2Error resolveErrorIfPossible(ClientResponse response) { @@ -675,8 +701,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement String wwwAuthenticateHeader = response.headers().header(HttpHeaders.WWW_AUTHENTICATE).get(0); Map authParameters = parseAuthParameters(wwwAuthenticateHeader); if (authParameters.containsKey(OAuth2ParameterNames.ERROR)) { - return new OAuth2Error( - authParameters.get(OAuth2ParameterNames.ERROR), + return new OAuth2Error(authParameters.get(OAuth2ParameterNames.ERROR), authParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION), authParameters.get(OAuth2ParameterNames.ERROR_URI)); } @@ -686,103 +711,107 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private OAuth2Error resolveErrorIfPossible(int statusCode) { if (this.httpStatusToOAuth2ErrorCodeMap.containsKey(statusCode)) { - return new OAuth2Error( - this.httpStatusToOAuth2ErrorCodeMap.get(statusCode), - null, + return new OAuth2Error(this.httpStatusToOAuth2ErrorCodeMap.get(statusCode), null, "https://tools.ietf.org/html/rfc6750#section-3.1"); } return null; } private Map parseAuthParameters(String wwwAuthenticateHeader) { - return Stream.of(wwwAuthenticateHeader) - .filter(header -> !StringUtils.isEmpty(header)) - .filter(header -> header.toLowerCase().startsWith("bearer")) - .map(header -> header.substring("bearer".length())) - .map(header -> header.split(",")) + // @formatter:off + return Stream.of(wwwAuthenticateHeader).filter((header) -> !StringUtils.isEmpty(header)) + .filter((header) -> header.toLowerCase().startsWith("bearer")) + .map((header) -> header.substring("bearer".length())) + .map((header) -> header.split(",")) .flatMap(Stream::of) - .map(parameter -> parameter.split("=")) - .filter(parameter -> parameter.length > 1) - .collect(Collectors.toMap( - parameters -> parameters[0].trim(), - parameters -> parameters[1].trim().replace("\"", ""))); + .map((parameter) -> parameter.split("=")) + .filter((parameter) -> parameter.length > 1) + .collect(Collectors.toMap((parameters) -> parameters[0].trim(), + (parameters) -> parameters[1].trim().replace("\"", "")) + ); + // @formatter:on } /** - * Handles the given http status code returned from a resource server - * by notifying the authorization failure handler if the http status - * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}. - * + * Handles the given http status code returned from a resource server by notifying + * the authorization failure handler if the http status code is in the + * {@link #httpStatusToOAuth2ErrorCodeMap}. * @param request the request being processed * @param exception The root cause exception for the failure - * @return a {@link Mono} that completes empty after the authorization failure handler completes + * @return a {@link Mono} that completes empty after the authorization failure + * handler completes */ - private Mono handleWebClientResponseException(ClientRequest request, WebClientResponseException exception) { - return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())) - .flatMap(oauth2Error -> { - Map attrs = request.attributes(); - OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); - if (authorizedClient == null) { - return Mono.empty(); - } - - ClientAuthorizationException authorizationException = new ClientAuthorizationException( - oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception); - - Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); - HttpServletRequest servletRequest = getRequest(attrs); - HttpServletResponse servletResponse = getResponse(attrs); - - return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); - }); + private Mono handleWebClientResponseException(ClientRequest request, + WebClientResponseException exception) { + return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())).flatMap((oauth2Error) -> { + Map attrs = request.attributes(); + OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); + if (authorizedClient == null) { + return Mono.empty(); + } + ClientAuthorizationException authorizationException = new ClientAuthorizationException(oauth2Error, + authorizedClient.getClientRegistration().getRegistrationId(), exception); + Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); + }); } /** * Handles the given {@link OAuth2AuthorizationException} that occurred downstream * by notifying the authorization failure handler. - * * @param request the request being processed - * @param authorizationException the authorization exception to include in the failure event - * @return a {@link Mono} that completes empty after the authorization failure handler completes + * @param authorizationException the authorization exception to include in the + * failure event + * @return a {@link Mono} that completes empty after the authorization failure + * handler completes */ - private Mono handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException authorizationException) { - return Mono.justOrEmpty(request) - .flatMap(req -> { - Map attrs = req.attributes(); - OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); - if (authorizedClient == null) { - return Mono.empty(); - } - - Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); - HttpServletRequest servletRequest = getRequest(attrs); - HttpServletResponse servletResponse = getResponse(attrs); - - return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); - }); + private Mono handleAuthorizationException(ClientRequest request, + OAuth2AuthorizationException authorizationException) { + return Mono.justOrEmpty(request).flatMap((req) -> { + Map attrs = req.attributes(); + OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); + if (authorizedClient == null) { + return Mono.empty(); + } + Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); + }); } /** - * Delegates the failed authorization to the {@link OAuth2AuthorizationFailureHandler}. - * - * @param exception the {@link OAuth2AuthorizationException} to include in the failure event + * Delegates the failed authorization to the + * {@link OAuth2AuthorizationFailureHandler}. + * @param exception the {@link OAuth2AuthorizationException} to include in the + * failure event * @param principal the principal associated with the failed authorization attempt * @param servletRequest the currently active {@code HttpServletRequest} * @param servletResponse the currently active {@code HttpServletResponse} - * @return a {@link Mono} that completes empty after the {@link OAuth2AuthorizationFailureHandler} completes + * @return a {@link Mono} that completes empty after the + * {@link OAuth2AuthorizationFailureHandler} completes */ - private Mono handleAuthorizationFailure(OAuth2AuthorizationException exception, - Authentication principal, HttpServletRequest servletRequest, HttpServletResponse servletResponse) { - Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure( - exception, principal, createAttributes(servletRequest, servletResponse)); - return Mono.fromRunnable(runnable).subscribeOn(Schedulers.boundedElastic()).then(); + private Mono handleAuthorizationFailure(OAuth2AuthorizationException exception, Authentication principal, + HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure(exception, principal, + createAttributes(servletRequest, servletResponse)); + // @formatter:off + return Mono.fromRunnable(runnable) + .subscribeOn(Schedulers.boundedElastic()) + .then(); + // @formatter:on } - private static Map createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + private static Map createAttributes(HttpServletRequest servletRequest, + HttpServletResponse servletResponse) { Map attributes = new HashMap<>(); attributes.put(HttpServletRequest.class.getName(), servletRequest); attributes.put(HttpServletResponse.class.getName(), servletResponse); return attributes; } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 798fafc30c..f828c82360 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -16,6 +16,8 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.annotation; +import reactor.core.publisher.Mono; + import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -36,15 +38,14 @@ import org.springframework.util.StringUtils; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; /** - * An implementation of a {@link HandlerMethodArgumentResolver} that is capable - * of resolving a method parameter to an argument value of type {@link OAuth2AuthorizedClient}. + * An implementation of a {@link HandlerMethodArgumentResolver} that is capable of + * resolving a method parameter to an argument value of type + * {@link OAuth2AuthorizedClient}. * *

        - * For example: - *

        + * For example: 
          * @Controller
          * public class MyController {
          *     @GetMapping("/authorized-client")
        @@ -60,15 +61,18 @@ import reactor.core.publisher.Mono;
          * @see RegisteredOAuth2AuthorizedClient
          */
         public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
        +
         	private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken(
         			"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER"));
        +
         	private ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
         
         	/**
        -	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
        -	 *
        +	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided
        +	 * parameters.
        +	 * @param authorizedClientManager the {@link ReactiveOAuth2AuthorizedClientManager}
        +	 * which manages the authorized client(s)
         	 * @since 5.2
        -	 * @param authorizedClientManager the {@link ReactiveOAuth2AuthorizedClientManager} which manages the authorized client(s)
         	 */
         	public OAuth2AuthorizedClientArgumentResolver(ReactiveOAuth2AuthorizedClientManager authorizedClientManager) {
         		Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
        @@ -76,72 +80,71 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
         	}
         
         	/**
        -	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
        -	 *
        +	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided
        +	 * parameters.
         	 * @param clientRegistrationRepository the repository of client registrations
         	 * @param authorizedClientRepository the repository of authorized clients
         	 */
         	public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository,
        -													ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
        +			ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
         		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
         		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
        -		this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
        -				clientRegistrationRepository, authorizedClientRepository);
        +		this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(clientRegistrationRepository,
        +				authorizedClientRepository);
         	}
         
         	@Override
         	public boolean supportsParameter(MethodParameter parameter) {
        -		return AnnotatedElementUtils.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class) != null;
        +		return AnnotatedElementUtils.findMergedAnnotation(parameter.getParameter(),
        +				RegisteredOAuth2AuthorizedClient.class) != null;
         	}
         
         	@Override
        -	public Mono resolveArgument(MethodParameter parameter, BindingContext bindingContext, ServerWebExchange exchange) {
        +	public Mono resolveArgument(MethodParameter parameter, BindingContext bindingContext,
        +			ServerWebExchange exchange) {
         		return Mono.defer(() -> {
         			RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
         					.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
        -
        -			String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ?
        -					authorizedClientAnnotation.registrationId() : null;
        -
        -			return authorizeRequest(clientRegistrationId, exchange)
        -					.flatMap(this.authorizedClientManager::authorize);
        +			String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId())
        +					? authorizedClientAnnotation.registrationId() : null;
        +			return authorizeRequest(clientRegistrationId, exchange).flatMap(this.authorizedClientManager::authorize);
         		});
         	}
         
         	private Mono authorizeRequest(String registrationId, ServerWebExchange exchange) {
         		Mono defaultedAuthentication = currentAuthentication();
        -
         		Mono defaultedRegistrationId = Mono.justOrEmpty(registrationId)
         				.switchIfEmpty(clientRegistrationId(defaultedAuthentication))
        -				.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one")));
        -
        +				.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
        +						"The clientRegistrationId could not be resolved. Please provide one")));
         		Mono defaultedExchange = Mono.justOrEmpty(exchange)
         				.switchIfEmpty(currentServerWebExchange());
        -
         		return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange)
        -				.map(t3 -> OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1())
        -						.principal(t3.getT2())
        -						.attribute(ServerWebExchange.class.getName(), t3.getT3())
        -						.build()
        -				);
        +				.map((zipped) -> OAuth2AuthorizeRequest.withClientRegistrationId(zipped.getT1())
        +						.principal(zipped.getT2()).attribute(ServerWebExchange.class.getName(), zipped.getT3())
        +						.build());
         	}
         
         	private Mono currentAuthentication() {
        +		// @formatter:off
         		return ReactiveSecurityContextHolder.getContext()
         				.map(SecurityContext::getAuthentication)
         				.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
        +		// @formatter:on
         	}
         
         	private Mono clientRegistrationId(Mono authentication) {
        -		return authentication
        -				.filter(t -> t instanceof OAuth2AuthenticationToken)
        +		return authentication.filter((t) -> t instanceof OAuth2AuthenticationToken)
         				.cast(OAuth2AuthenticationToken.class)
         				.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
         	}
         
         	private Mono currentServerWebExchange() {
        +		// @formatter:off
         		return Mono.subscriberContext()
        -				.filter(c -> c.hasKey(ServerWebExchange.class))
        -				.map(c -> c.get(ServerWebExchange.class));
        +				.filter((c) -> c.hasKey(ServerWebExchange.class))
        +				.map((c) -> c.get(ServerWebExchange.class));
        +		// @formatter:on
         	}
        +
         }
        diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java
        index 6357dc38c9..56c180e0f0 100644
        --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java
        +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java
        @@ -13,8 +13,11 @@
          * See the License for the specific language governing permissions and
          * limitations under the License.
          */
        +
         package org.springframework.security.oauth2.client.web.server;
         
        +import reactor.core.publisher.Mono;
        +
         import org.springframework.security.authentication.AuthenticationTrustResolver;
         import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
         import org.springframework.security.core.Authentication;
        @@ -25,15 +28,14 @@ import org.springframework.security.oauth2.client.web.HttpSessionOAuth2Authorize
         import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
         import org.springframework.util.Assert;
         import org.springframework.web.server.ServerWebExchange;
        -import reactor.core.publisher.Mono;
         
         /**
        - * An implementation of an {@link ServerOAuth2AuthorizedClientRepository} that
        - * delegates to the provided {@link ServerOAuth2AuthorizedClientRepository} if the current
        - * {@code Principal} is authenticated, otherwise,
        - * to the default (or provided) {@link ServerOAuth2AuthorizedClientRepository}
        - * if the current request is unauthenticated (or anonymous).
        - * The default {@code ReactiveOAuth2AuthorizedClientRepository} is
        + * An implementation of an {@link ServerOAuth2AuthorizedClientRepository} that delegates
        + * to the provided {@link ServerOAuth2AuthorizedClientRepository} if the current
        + * {@code Principal} is authenticated, otherwise, to the default (or provided)
        + * {@link ServerOAuth2AuthorizedClientRepository} if the current request is
        + * unauthenticated (or anonymous). The default
        + * {@code ReactiveOAuth2AuthorizedClientRepository} is
          * {@link WebSessionServerOAuth2AuthorizedClientRepository}.
          *
          * @author Rob Winch
        @@ -45,25 +47,29 @@ import reactor.core.publisher.Mono;
          */
         public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository
         		implements ServerOAuth2AuthorizedClientRepository {
        +
         	private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
        +
         	private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
        +
         	private ServerOAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository();
         
         	/**
         	 * Creates an instance
        -	 *
         	 * @param authorizedClientService the authorized client service
         	 */
        -	public AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
        +	public AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(
        +			ReactiveOAuth2AuthorizedClientService authorizedClientService) {
         		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
         		this.authorizedClientService = authorizedClientService;
         	}
         
         	/**
        -	 * Sets the {@link ServerOAuth2AuthorizedClientRepository} used for requests that are unauthenticated (or anonymous).
        -	 * The default is {@link WebSessionServerOAuth2AuthorizedClientRepository}.
        -	 *
        -	 * @param anonymousAuthorizedClientRepository the repository used for requests that are unauthenticated (or anonymous)
        +	 * Sets the {@link ServerOAuth2AuthorizedClientRepository} used for requests that are
        +	 * unauthenticated (or anonymous). The default is
        +	 * {@link WebSessionServerOAuth2AuthorizedClientRepository}.
        +	 * @param anonymousAuthorizedClientRepository the repository used for requests that
        +	 * are unauthenticated (or anonymous)
         	 */
         	public void setAnonymousAuthorizedClientRepository(
         			ServerOAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) {
        @@ -72,13 +78,12 @@ public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository
         	}
         
         	@Override
        -	public  Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal,
        -																		ServerWebExchange exchange) {
        +	public  Mono loadAuthorizedClient(String clientRegistrationId,
        +			Authentication principal, ServerWebExchange exchange) {
         		if (this.isPrincipalAuthenticated(principal)) {
         			return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName());
        -		} else {
        -			return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange);
         		}
        +		return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange);
         	}
         
         	@Override
        @@ -86,9 +91,8 @@ public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository
         			ServerWebExchange exchange) {
         		if (this.isPrincipalAuthenticated(principal)) {
         			return this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
        -		} else {
        -			return this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
         		}
        +		return this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
         	}
         
         	@Override
        @@ -96,14 +100,14 @@ public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository
         			ServerWebExchange exchange) {
         		if (this.isPrincipalAuthenticated(principal)) {
         			return this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName());
        -		} else {
        -			return this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, exchange);
         		}
        +		return this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal,
        +				exchange);
         	}
         
         	private boolean isPrincipalAuthenticated(Authentication authentication) {
        -		return authentication != null &&
        -				!this.authenticationTrustResolver.isAnonymous(authentication) &&
        -				authentication.isAuthenticated();
        +		return authentication != null && !this.authenticationTrustResolver.isAnonymous(authentication)
        +				&& authentication.isAuthenticated();
         	}
        +
         }
        diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java
        index 3aee38b289..a3a0169189 100644
        --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java
        +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java
        @@ -16,6 +16,16 @@
         
         package org.springframework.security.oauth2.client.web.server;
         
        +import java.nio.charset.StandardCharsets;
        +import java.security.MessageDigest;
        +import java.security.NoSuchAlgorithmException;
        +import java.util.Base64;
        +import java.util.HashMap;
        +import java.util.Map;
        +import java.util.function.Consumer;
        +
        +import reactor.core.publisher.Mono;
        +
         import org.springframework.http.HttpStatus;
         import org.springframework.http.server.reactive.ServerHttpRequest;
         import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
        @@ -38,39 +48,33 @@ import org.springframework.web.server.ResponseStatusException;
         import org.springframework.web.server.ServerWebExchange;
         import org.springframework.web.util.UriComponents;
         import org.springframework.web.util.UriComponentsBuilder;
        -import reactor.core.publisher.Mono;
        -
        -import java.nio.charset.StandardCharsets;
        -import java.security.MessageDigest;
        -import java.security.NoSuchAlgorithmException;
        -import java.util.Base64;
        -import java.util.HashMap;
        -import java.util.Map;
        -import java.util.function.Consumer;
         
         /**
          * The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}.
          *
        - * The {@link ClientRegistration#getRegistrationId()} is extracted from the request using the
        - * {@link #DEFAULT_AUTHORIZATION_REQUEST_PATTERN}. The injected {@link ReactiveClientRegistrationRepository} is then
        - * used to resolve the {@link ClientRegistration} and create the {@link OAuth2AuthorizationRequest}.
        + * The {@link ClientRegistration#getRegistrationId()} is extracted from the request using
        + * the {@link #DEFAULT_AUTHORIZATION_REQUEST_PATTERN}. The injected
        + * {@link ReactiveClientRegistrationRepository} is then used to resolve the
        + * {@link ClientRegistration} and create the {@link OAuth2AuthorizationRequest}.
          *
          * @author Rob Winch
          * @author Mark Heckler
          * @since 5.1
          */
        -public class DefaultServerOAuth2AuthorizationRequestResolver
        -		implements ServerOAuth2AuthorizationRequestResolver {
        +public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOAuth2AuthorizationRequestResolver {
         
         	/**
        -	 * The name of the path variable that contains the {@link ClientRegistration#getRegistrationId()}
        +	 * The name of the path variable that contains the
        +	 * {@link ClientRegistration#getRegistrationId()}
         	 */
         	public static final String DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
         
         	/**
        -	 * The default pattern used to resolve the {@link ClientRegistration#getRegistrationId()}
        +	 * The default pattern used to resolve the
        +	 * {@link ClientRegistration#getRegistrationId()}
         	 */
        -	public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME + "}";
        +	public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{"
        +			+ DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME + "}";
         
         	private static final char PATH_DELIMITER = '/';
         
        @@ -80,26 +84,33 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
         
         	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
         
        -	private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
        +	private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(
        +			Base64.getUrlEncoder().withoutPadding(), 96);
         
        -	private Consumer authorizationRequestCustomizer = customizer -> {};
        +	private Consumer authorizationRequestCustomizer = (customizer) -> {
        +	};
         
         	/**
         	 * Creates a new instance
        -	 * @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
        +	 * @param clientRegistrationRepository the repository to resolve the
        +	 * {@link ClientRegistration}
         	 */
        -	public DefaultServerOAuth2AuthorizationRequestResolver(ReactiveClientRegistrationRepository clientRegistrationRepository) {
        -		this(clientRegistrationRepository, new PathPatternParserServerWebExchangeMatcher(
        -				DEFAULT_AUTHORIZATION_REQUEST_PATTERN));
        +	public DefaultServerOAuth2AuthorizationRequestResolver(
        +			ReactiveClientRegistrationRepository clientRegistrationRepository) {
        +		this(clientRegistrationRepository,
        +				new PathPatternParserServerWebExchangeMatcher(DEFAULT_AUTHORIZATION_REQUEST_PATTERN));
         	}
         
         	/**
         	 * Creates a new instance
        -	 * @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
        -	 * @param authorizationRequestMatcher the matcher that determines if the request is a match and extracts the
        -	 * {@link #DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME} from the path variables.
        +	 * @param clientRegistrationRepository the repository to resolve the
        +	 * {@link ClientRegistration}
        +	 * @param authorizationRequestMatcher the matcher that determines if the request is a
        +	 * match and extracts the {@link #DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME} from the
        +	 * path variables.
         	 */
        -	public DefaultServerOAuth2AuthorizationRequestResolver(ReactiveClientRegistrationRepository clientRegistrationRepository,
        +	public DefaultServerOAuth2AuthorizationRequestResolver(
        +			ReactiveClientRegistrationRepository clientRegistrationRepository,
         			ServerWebExchangeMatcher authorizationRequestMatcher) {
         		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
         		Assert.notNull(authorizationRequestMatcher, "authorizationRequestMatcher cannot be null");
        @@ -109,82 +120,94 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
         
         	@Override
         	public Mono resolve(ServerWebExchange exchange) {
        -		return this.authorizationRequestMatcher.matches(exchange)
        -				.filter(matchResult -> matchResult.isMatch())
        +		// @formatter:off
        +		return this.authorizationRequestMatcher
        +				.matches(exchange)
        +				.filter((matchResult) -> matchResult.isMatch())
         				.map(ServerWebExchangeMatcher.MatchResult::getVariables)
        -				.map(variables -> variables.get(DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME))
        +				.map((variables) -> variables.get(DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME))
         				.cast(String.class)
        -				.flatMap(clientRegistrationId -> resolve(exchange, clientRegistrationId));
        +				.flatMap((clientRegistrationId) -> resolve(exchange, clientRegistrationId));
        +		// @formatter:on
         	}
         
         	@Override
        -	public Mono resolve(ServerWebExchange exchange,
        -			String clientRegistrationId) {
        +	public Mono resolve(ServerWebExchange exchange, String clientRegistrationId) {
         		return this.findByRegistrationId(exchange, clientRegistrationId)
        -			.map(clientRegistration -> authorizationRequest(exchange, clientRegistration));
        +				.map((clientRegistration) -> authorizationRequest(exchange, clientRegistration));
         	}
         
         	/**
        -	 * Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
        -	 * allowing for further customizations.
        -	 *
        +	 * Sets the {@code Consumer} to be provided the
        +	 * {@link OAuth2AuthorizationRequest.Builder} allowing for further customizations.
        +	 * @param authorizationRequestCustomizer the {@code Consumer} to be provided the
        +	 * {@link OAuth2AuthorizationRequest.Builder}
         	 * @since 5.3
        -	 * @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
         	 */
        -	public final void setAuthorizationRequestCustomizer(Consumer authorizationRequestCustomizer) {
        +	public final void setAuthorizationRequestCustomizer(
        +			Consumer authorizationRequestCustomizer) {
         		Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
         		this.authorizationRequestCustomizer = authorizationRequestCustomizer;
         	}
         
         	private Mono findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
        +		// @formatter:off
         		return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
         				.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
        +		// @formatter:on
         	}
         
         	private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange,
         			ClientRegistration clientRegistration) {
         		String redirectUriStr = expandRedirectUri(exchange.getRequest(), clientRegistration);
        -
         		Map attributes = new HashMap<>();
         		attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
        -
        -		OAuth2AuthorizationRequest.Builder builder;
        -		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
        -			builder = OAuth2AuthorizationRequest.authorizationCode();
        -			Map additionalParameters = new HashMap<>();
        -			if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) &&
        -					clientRegistration.getScopes().contains(OidcScopes.OPENID)) {
        -				// Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
        -				// scope
        -				// 		REQUIRED. OpenID Connect requests MUST contain the "openid" scope value.
        -				addNonceParameters(attributes, additionalParameters);
        -			}
        -			if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) {
        -				addPkceParameters(attributes, additionalParameters);
        -			}
        -			builder.additionalParameters(additionalParameters);
        -		} else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
        -			builder = OAuth2AuthorizationRequest.implicit();
        -		} else {
        -			throw new IllegalArgumentException(
        -					"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
        -							+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
        -		}
        -		builder
        -				.clientId(clientRegistration.getClientId())
        +		OAuth2AuthorizationRequest.Builder builder = getBuilder(clientRegistration, attributes);
        +		// @formatter:off
        +		builder.clientId(clientRegistration.getClientId())
         				.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
         				.redirectUri(redirectUriStr)
         				.scopes(clientRegistration.getScopes())
         				.state(this.stateGenerator.generateKey())
         				.attributes(attributes);
        +		// @formatter:on
         
         		this.authorizationRequestCustomizer.accept(builder);
         
         		return builder.build();
         	}
         
        +	private OAuth2AuthorizationRequest.Builder getBuilder(ClientRegistration clientRegistration,
        +			Map attributes) {
        +		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
        +			OAuth2AuthorizationRequest.Builder builder = OAuth2AuthorizationRequest.authorizationCode();
        +			Map additionalParameters = new HashMap<>();
        +			if (!CollectionUtils.isEmpty(clientRegistration.getScopes())
        +					&& clientRegistration.getScopes().contains(OidcScopes.OPENID)) {
        +				// Section 3.1.2.1 Authentication Request -
        +				// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
        +				// scope
        +				// REQUIRED. OpenID Connect requests MUST contain the "openid" scope
        +				// value.
        +				addNonceParameters(attributes, additionalParameters);
        +			}
        +			if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) {
        +				addPkceParameters(attributes, additionalParameters);
        +			}
        +			builder.additionalParameters(additionalParameters);
        +			return builder;
        +		}
        +		if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
        +			return OAuth2AuthorizationRequest.implicit();
        +		}
        +		throw new IllegalArgumentException(
        +				"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
        +						+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
        +	}
        +
         	/**
        -	 * Expands the {@link ClientRegistration#getRedirectUri()} with following provided variables:
        + * Expands the {@link ClientRegistration#getRedirectUri()} with following provided + * variables:
        * - baseUrl (e.g. https://localhost/app)
        * - baseScheme (e.g. https)
        * - baseHost (e.g. localhost)
        @@ -195,54 +218,58 @@ public class DefaultServerOAuth2AuthorizationRequestResolver *

        * Null variables are provided as empty strings. *

        - * Default redirectUri is: {@code org.springframework.security.config.oauth2.client.CommonOAuth2Provider#DEFAULT_REDIRECT_URL} - * + * Default redirectUri is: + * {@code org.springframework.security.config.oauth2.client.CommonOAuth2Provider#DEFAULT_REDIRECT_URL} * @return expanded URI */ private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - + // @formatter:off UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) .replacePath(request.getPath().contextPath().value()) .replaceQuery(null) .fragment(null) .build(); + // @formatter:on String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", scheme == null ? "" : scheme); + uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); String host = uriComponents.getHost(); - uriVariables.put("baseHost", host == null ? "" : host); + uriVariables.put("baseHost", (host != null) ? host : ""); // following logic is based on HierarchicalUriComponents#toUriString() int port = uriComponents.getPort(); - uriVariables.put("basePort", port == -1 ? "" : ":" + port); + uriVariables.put("basePort", (port == -1) ? "" : ":" + port); String path = uriComponents.getPath(); if (StringUtils.hasLength(path)) { if (path.charAt(0) != PATH_DELIMITER) { path = PATH_DELIMITER + path; } } - uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("basePath", (path != null) ? path : ""); uriVariables.put("baseUrl", uriComponents.toUriString()); - String action = ""; if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { action = "login"; } uriVariables.put("action", action); - + // @formatter:off return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()) .buildAndExpand(uriVariables) .toUriString(); + // @formatter:on } /** * Creates nonce and its hash for use in OpenID Connect 1.0 Authentication Requests. - * - * @param attributes where the {@link OidcParameterNames#NONCE} is stored for the authentication request - * @param additionalParameters where the {@link OidcParameterNames#NONCE} hash is added for the authentication request + * @param attributes where the {@link OidcParameterNames#NONCE} is stored for the + * authentication request + * @param additionalParameters where the {@link OidcParameterNames#NONCE} hash is + * added for the authentication request * * @since 5.2 - * @see 3.1.2.1. Authentication Request + * @see 3.1.2.1. + * Authentication Request */ private void addNonceParameters(Map attributes, Map additionalParameters) { try { @@ -250,20 +277,27 @@ public class DefaultServerOAuth2AuthorizationRequestResolver String nonceHash = createHash(nonce); attributes.put(OidcParameterNames.NONCE, nonce); additionalParameters.put(OidcParameterNames.NONCE, nonceHash); - } catch (NoSuchAlgorithmException e) { } + } + catch (NoSuchAlgorithmException ex) { + } } /** - * Creates and adds additional PKCE parameters for use in the OAuth 2.0 Authorization and Access Token Requests - * - * @param attributes where {@link PkceParameterNames#CODE_VERIFIER} is stored for the token request - * @param additionalParameters where {@link PkceParameterNames#CODE_CHALLENGE} and, usually, - * {@link PkceParameterNames#CODE_CHALLENGE_METHOD} are added to be used in the authorization request. + * Creates and adds additional PKCE parameters for use in the OAuth 2.0 Authorization + * and Access Token Requests + * @param attributes where {@link PkceParameterNames#CODE_VERIFIER} is stored for the + * token request + * @param additionalParameters where {@link PkceParameterNames#CODE_CHALLENGE} and, + * usually, {@link PkceParameterNames#CODE_CHALLENGE_METHOD} are added to be used in + * the authorization request. * * @since 5.2 - * @see 1.1. Protocol Flow - * @see 4.1. Client Creates a Code Verifier - * @see 4.2. Client Creates the Code Challenge + * @see 1.1. + * Protocol Flow + * @see 4.1. + * Client Creates a Code Verifier + * @see 4.2. + * Client Creates the Code Challenge */ private void addPkceParameters(Map attributes, Map additionalParameters) { String codeVerifier = this.secureKeyGenerator.generateKey(); @@ -272,7 +306,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver String codeChallenge = createHash(codeVerifier); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, codeChallenge); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - } catch (NoSuchAlgorithmException e) { + } + catch (NoSuchAlgorithmException ex) { additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, codeVerifier); } } @@ -282,4 +317,5 @@ public class DefaultServerOAuth2AuthorizationRequestResolver byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII)); return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index a8d10bff0e..a9659b1bde 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -16,6 +16,15 @@ package org.springframework.security.oauth2.client.web.server; +import java.net.URI; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; @@ -46,40 +55,31 @@ import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; -import reactor.core.publisher.Mono; - -import java.net.URI; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; /** - * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, - * which handles the processing of the OAuth 2.0 Authorization Response. + * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, which handles the + * processing of the OAuth 2.0 Authorization Response. * *

        * The OAuth 2.0 Authorization Response is processed as follows: * *

          - *
        • - * Assuming the End-User (Resource Owner) has granted access to the Client, the Authorization Server will append the - * {@link OAuth2ParameterNames#CODE code} and {@link OAuth2ParameterNames#STATE state} parameters - * to the {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization Request) - * and redirect the End-User's user-agent back to this {@code Filter} (the Client). - *
        • - *
        • - * This {@code Filter} will then create an {@link OAuth2AuthorizationCodeAuthenticationToken} with - * the {@link OAuth2ParameterNames#CODE code} received and - * delegate it to the {@link ReactiveAuthenticationManager} to authenticate. - *
        • - *
        • - * Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the - * {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the - * {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal} - * and saving it via the {@link ServerOAuth2AuthorizedClientRepository}. - *
        • + *
        • Assuming the End-User (Resource Owner) has granted access to the Client, the + * Authorization Server will append the {@link OAuth2ParameterNames#CODE code} and + * {@link OAuth2ParameterNames#STATE state} parameters to the + * {@link OAuth2ParameterNames#REDIRECT_URI redirect_uri} (provided in the Authorization + * Request) and redirect the End-User's user-agent back to this {@code Filter} (the + * Client).
        • + *
        • This {@code Filter} will then create an + * {@link OAuth2AuthorizationCodeAuthenticationToken} with the + * {@link OAuth2ParameterNames#CODE code} received and delegate it to the + * {@link ReactiveAuthenticationManager} to authenticate.
        • + *
        • Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized + * Client} is created by associating the + * {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to + * the {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} + * and current {@code Principal} and saving it via the + * {@link ServerOAuth2AuthorizedClientRepository}.
        • *
        * * @author Rob Winch @@ -95,16 +95,19 @@ import java.util.Set; * @see ReactiveClientRegistrationRepository * @see OAuth2AuthorizedClient * @see ServerOAuth2AuthorizedClientRepository - * @see Section 4.1 Authorization Code Grant - * @see Section 4.1.2 Authorization Response + * @see Section + * 4.1 Authorization Code Grant + * @see Section 4.1.2 Authorization + * Response */ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { + private final ReactiveAuthenticationManager authenticationManager; private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - private ServerAuthorizationRequestRepository authorizationRequestRepository = - new WebSessionOAuth2ServerAuthorizationRequestRepository(); + private ServerAuthorizationRequestRepository authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); private ServerAuthenticationSuccessHandler authenticationSuccessHandler; @@ -119,10 +122,9 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { private ServerRequestCache requestCache = new WebSessionServerRequestCache(); private AnonymousAuthenticationToken anonymousToken = new AnonymousAuthenticationToken("key", "anonymous", - AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); - public OAuth2AuthorizationCodeGrantWebFilter( - ReactiveAuthenticationManager authenticationManager, + public OAuth2AuthorizationCodeGrantWebFilter(ReactiveAuthenticationManager authenticationManager, ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); @@ -131,20 +133,18 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { this.authenticationManager = authenticationManager; this.authorizedClientRepository = authorizedClientRepository; this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; - ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = - new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter( + clientRegistrationRepository); authenticationConverter.setAuthorizationRequestRepository(this.authorizationRequestRepository); this.authenticationConverter = authenticationConverter; this.defaultAuthenticationConverter = true; - RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = - new RedirectServerAuthenticationSuccessHandler(); + RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); authenticationSuccessHandler.setRequestCache(this.requestCache); this.authenticationSuccessHandler = authenticationSuccessHandler; this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); } - public OAuth2AuthorizationCodeGrantWebFilter( - ReactiveAuthenticationManager authenticationManager, + public OAuth2AuthorizationCodeGrantWebFilter(ReactiveAuthenticationManager authenticationManager, ServerAuthenticationConverter authenticationConverter, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); @@ -154,19 +154,18 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { this.authorizedClientRepository = authorizedClientRepository; this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; this.authenticationConverter = authenticationConverter; - RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = - new RedirectServerAuthenticationSuccessHandler(); + RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); authenticationSuccessHandler.setRequestCache(this.requestCache); this.authenticationSuccessHandler = authenticationSuccessHandler; this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); } /** - * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. - * The default is {@link WebSessionOAuth2ServerAuthorizationRequestRepository}. - * + * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. The + * default is {@link WebSessionOAuth2ServerAuthorizationRequestRepository}. + * @param authorizationRequestRepository the repository used for storing + * {@link OAuth2AuthorizationRequest}'s * @since 5.2 - * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s */ public final void setAuthorizationRequestRepository( ServerAuthorizationRequestRepository authorizationRequestRepository) { @@ -183,11 +182,12 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { } /** - * Sets the {@link ServerRequestCache} used for loading a previously saved request (if available) - * and replaying it after completing the processing of the OAuth 2.0 Authorization Response. - * + * Sets the {@link ServerRequestCache} used for loading a previously saved request (if + * available) and replaying it after completing the processing of the OAuth 2.0 + * Authorization Response. + * @param requestCache the cache used for loading a previously saved request (if + * available) * @since 5.4 - * @param requestCache the cache used for loading a previously saved request (if available) */ public final void setRequestCache(ServerRequestCache requestCache) { Assert.notNull(requestCache, "requestCache cannot be null"); @@ -196,79 +196,93 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { } private void updateDefaultAuthenticationSuccessHandler() { - ((RedirectServerAuthenticationSuccessHandler) this.authenticationSuccessHandler).setRequestCache(this.requestCache); + ((RedirectServerAuthenticationSuccessHandler) this.authenticationSuccessHandler) + .setRequestCache(this.requestCache); } @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + // @formatter:off return this.requiresAuthenticationMatcher.matches(exchange) .filter(ServerWebExchangeMatcher.MatchResult::isMatch) - .flatMap(matchResult -> - this.authenticationConverter.convert(exchange) - .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException( - e.getError(), e.getError().toString()))) + .flatMap((matchResult) -> this.authenticationConverter.convert(exchange) + .onErrorMap(OAuth2AuthorizationException.class, + (ex) -> new OAuth2AuthenticationException(ex.getError(), ex.getError().toString()) + ) + ) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) - .flatMap(token -> authenticate(exchange, chain, token)) - .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler - .onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)); + .flatMap((token) -> authenticate(exchange, chain, token)) + .onErrorResume(AuthenticationException.class, (e) -> + this.authenticationFailureHandler.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e) + ); + // @formatter:on } private Mono authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) { WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); return this.authenticationManager.authenticate(token) - .onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException( - e.getError(), e.getError().toString())) - .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) - .flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange)) - .onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler - .onAuthenticationFailure(webFilterExchange, e)); + .onErrorMap(OAuth2AuthorizationException.class, + (ex) -> new OAuth2AuthenticationException(ex.getError(), ex.getError().toString())) + .switchIfEmpty(Mono.defer( + () -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) + .flatMap((authentication) -> onAuthenticationSuccess(authentication, webFilterExchange)) + .onErrorResume(AuthenticationException.class, + (e) -> this.authenticationFailureHandler.onAuthenticationFailure(webFilterExchange, e)); } private Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { OAuth2AuthorizationCodeAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) authentication; OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - authenticationResult.getClientRegistration(), - authenticationResult.getName(), - authenticationResult.getAccessToken(), - authenticationResult.getRefreshToken()); - return this.authenticationSuccessHandler - .onAuthenticationSuccess(webFilterExchange, authentication) - .then(ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(this.anonymousToken) - .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange())) - ); + authenticationResult.getClientRegistration(), authenticationResult.getName(), + authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); + // @formatter:off + return this.authenticationSuccessHandler.onAuthenticationSuccess(webFilterExchange, authentication) + .then(ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(this.anonymousToken) + .flatMap((principal) -> this.authorizedClientRepository + .saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange()) + ) + ); + // @formatter:on } private Mono matchesAuthorizationResponse(ServerWebExchange exchange) { + // @formatter:off return Mono.just(exchange) - .filter(exch -> OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams())) - .flatMap(exch -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange) - .flatMap(authorizationRequest -> - matchesRedirectUri(exch.getRequest().getURI(), authorizationRequest.getRedirectUri()))) + .filter((exch) -> + OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams()) + ) + .flatMap((exch) -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange) + .flatMap((authorizationRequest) -> matchesRedirectUri(exch.getRequest().getURI(), + authorizationRequest.getRedirectUri())) + ) .switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch()); + // @formatter:on } - private static Mono matchesRedirectUri( - URI authorizationResponseUri, String authorizationRequestRedirectUri) { + private static Mono matchesRedirectUri(URI authorizationResponseUri, + String authorizationRequestRedirectUri) { UriComponents requestUri = UriComponentsBuilder.fromUri(authorizationResponseUri).build(); UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequestRedirectUri).build(); - Set>> requestUriParameters = - new LinkedHashSet<>(requestUri.getQueryParams().entrySet()); - Set>> redirectUriParameters = - new LinkedHashSet<>(redirectUri.getQueryParams().entrySet()); - // Remove the additional request parameters (if any) from the authorization response (request) - // before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any) + Set>> requestUriParameters = new LinkedHashSet<>( + requestUri.getQueryParams().entrySet()); + Set>> redirectUriParameters = new LinkedHashSet<>( + redirectUri.getQueryParams().entrySet()); + // Remove the additional request parameters (if any) from the authorization + // response (request) + // before doing an exact comparison with the authorizationRequest.getRedirectUri() + // parameters (if any) requestUriParameters.retainAll(redirectUriParameters); - - if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) && - Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) && - Objects.equals(requestUri.getHost(), redirectUri.getHost()) && - Objects.equals(requestUri.getPort(), redirectUri.getPort()) && - Objects.equals(requestUri.getPath(), redirectUri.getPath()) && - Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) { + if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) + && Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) + && Objects.equals(requestUri.getHost(), redirectUri.getHost()) + && Objects.equals(requestUri.getPort(), redirectUri.getPort()) + && Objects.equals(requestUri.getPath(), redirectUri.getPath()) + && Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) { return ServerWebExchangeMatcher.MatchResult.match(); } return ServerWebExchangeMatcher.MatchResult.notMatch(); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java index 907ebdfffb..1b20821a56 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; +import java.net.URI; + +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -31,27 +36,25 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import org.springframework.web.util.UriComponentsBuilder; -import reactor.core.publisher.Mono; - -import java.net.URI; /** - * This {@code WebFilter} initiates the authorization code grant or implicit grant flow - * by redirecting the End-User's user-agent to the Authorization Server's Authorization Endpoint. + * This {@code WebFilter} initiates the authorization code grant or implicit grant flow by + * redirecting the End-User's user-agent to the Authorization Server's Authorization + * Endpoint. * *

        - * It builds the OAuth 2.0 Authorization Request, - * which is used as the redirect {@code URI} to the Authorization Endpoint. - * The redirect {@code URI} will include the client identifier, requested scope(s), state, - * response type, and a redirection URI which the authorization server will send the user-agent back to - * once access is granted (or denied) by the End-User (Resource Owner). + * It builds the OAuth 2.0 Authorization Request, which is used as the redirect + * {@code URI} to the Authorization Endpoint. The redirect {@code URI} will include the + * client identifier, requested scope(s), state, response type, and a redirection URI + * which the authorization server will send the user-agent back to once access is granted + * (or denied) by the End-User (Resource Owner). * *

        - * By default, this {@code Filter} responds to authorization requests - * at the {@code URI} {@code /oauth2/authorization/{registrationId}}. - * The {@code URI} template variable {@code {registrationId}} represents the - * {@link ClientRegistration#getRegistrationId() registration identifier} of the client - * that is used for initiating the OAuth 2.0 Authorization Request. + * By default, this {@code Filter} responds to authorization requests at the {@code URI} + * {@code /oauth2/authorization/{registrationId}}. The {@code URI} template variable + * {@code {registrationId}} represents the {@link ClientRegistration#getRegistrationId() + * registration identifier} of the client that is used for initiating the OAuth 2.0 + * Authorization Request. * * @author Rob Winch * @since 5.1 @@ -59,41 +62,53 @@ import java.net.URI; * @see AuthorizationRequestRepository * @see ClientRegistration * @see ClientRegistrationRepository - * @see Section 4.1 Authorization Code Grant - * @see Section 4.1.1 Authorization Request (Authorization Code) - * @see Section 4.2 Implicit Grant - * @see Section 4.2.1 Authorization Request (Implicit) + * @see Section + * 4.1 Authorization Code Grant + * @see Section 4.1.1 Authorization Request + * (Authorization Code) + * @see Section + * 4.2 Implicit Grant + * @see Section 4.2.1 Authorization Request + * (Implicit) */ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { + private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + private final ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; - private ServerAuthorizationRequestRepository authorizationRequestRepository = - new WebSessionOAuth2ServerAuthorizationRequestRepository(); + + private ServerAuthorizationRequestRepository authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); /** - * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided + * parameters. * @param clientRegistrationRepository the repository of client registrations */ - public OAuth2AuthorizationRequestRedirectWebFilter(ReactiveClientRegistrationRepository clientRegistrationRepository) { - this.authorizationRequestResolver = new DefaultServerOAuth2AuthorizationRequestResolver(clientRegistrationRepository); + public OAuth2AuthorizationRequestRedirectWebFilter( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + this.authorizationRequestResolver = new DefaultServerOAuth2AuthorizationRequestResolver( + clientRegistrationRepository); } /** - * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. - * + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided + * parameters. * @param authorizationRequestResolver the resolver to use */ - public OAuth2AuthorizationRequestRedirectWebFilter(ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver) { + public OAuth2AuthorizationRequestRedirectWebFilter( + ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver) { Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null"); this.authorizationRequestResolver = authorizationRequestResolver; } /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. - * - * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s + * @param authorizationRequestRepository the repository used for storing + * {@link OAuth2AuthorizationRequest}'s */ public final void setAuthorizationRequestRepository( ServerAuthorizationRequestRepository authorizationRequestRepository) { @@ -112,11 +127,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + // @formatter:off return this.authorizationRequestResolver.resolve(exchange) - .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) - .onErrorResume(ClientAuthorizationRequiredException.class, e -> this.requestCache.saveRequest(exchange) - .then(this.authorizationRequestResolver.resolve(exchange, e.getClientRegistrationId()))) - .flatMap(clientRegistration -> sendRedirectForAuthorization(exchange, clientRegistration)); + .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) + .onErrorResume(ClientAuthorizationRequiredException.class, + (ex) -> this.requestCache.saveRequest(exchange).then( + this.authorizationRequestResolver.resolve(exchange, ex.getClientRegistrationId())) + ) + .flatMap((clientRegistration) -> sendRedirectForAuthorization(exchange, clientRegistration)); + // @formatter:on } private Mono sendRedirectForAuthorization(ServerWebExchange exchange, @@ -127,12 +146,14 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { saveAuthorizationRequest = this.authorizationRequestRepository .saveAuthorizationRequest(authorizationRequest, exchange); } - - URI redirectUri = UriComponentsBuilder - .fromUriString(authorizationRequest.getAuthorizationRequestUri()) - .build(true).toUri(); + // @formatter:off + URI redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getAuthorizationRequestUri()) + .build(true) + .toUri(); + // @formatter:on return saveAuthorizationRequest .then(this.authorizationRedirectStrategy.sendRedirect(exchange, redirectUri)); }); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java index 423c785302..d47567026c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; +import java.util.Map; + import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import java.util.Map; - /** * Utility methods for an OAuth 2.0 Authorization Response. * @@ -52,34 +53,32 @@ final class OAuth2AuthorizationResponseUtils { } static boolean isAuthorizationResponseSuccess(MultiValueMap request) { - return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.CODE)) && - StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); + return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.CODE)) + && StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); } static boolean isAuthorizationResponseError(MultiValueMap request) { - return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.ERROR)) && - StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); + return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.ERROR)) + && StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE)); } static OAuth2AuthorizationResponse convert(MultiValueMap request, String redirectUri) { String code = request.getFirst(OAuth2ParameterNames.CODE); String errorCode = request.getFirst(OAuth2ParameterNames.ERROR); String state = request.getFirst(OAuth2ParameterNames.STATE); - if (StringUtils.hasText(code)) { - return OAuth2AuthorizationResponse.success(code) - .redirectUri(redirectUri) - .state(state) - .build(); - } else { - String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); - String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); - return OAuth2AuthorizationResponse.error(errorCode) + return OAuth2AuthorizationResponse.success(code).redirectUri(redirectUri).state(state).build(); + } + String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); + String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); + // @formatter:off + return OAuth2AuthorizationResponse.error(errorCode) .redirectUri(redirectUri) .errorDescription(errorDescription) .errorUri(errorUri) .state(state) .build(); - } + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerAuthorizationRequestRepository.java index ee180a8faf..6bc5404463 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerAuthorizationRequestRepository.java @@ -16,57 +16,55 @@ package org.springframework.security.oauth2.client.web.server; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter; import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - /** - * Implementations of this interface are responsible for the persistence - * of {@link OAuth2AuthorizationRequest} between requests. + * Implementations of this interface are responsible for the persistence of + * {@link OAuth2AuthorizationRequest} between requests. * *

        - * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the Authorization Request - * before it initiates the authorization code grant flow. - * As well, used by the {@link OAuth2LoginAuthenticationFilter} for resolving - * the associated Authorization Request when handling the callback of the Authorization Response. + * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the + * Authorization Request before it initiates the authorization code grant flow. As well, + * used by the {@link OAuth2LoginAuthenticationFilter} for resolving the associated + * Authorization Request when handling the callback of the Authorization Response. * + * @param The type of OAuth 2.0 Authorization Request * @author Rob Winch * @since 5.1 * @see OAuth2AuthorizationRequest * @see HttpSessionOAuth2AuthorizationRequestRepository - * - * @param The type of OAuth 2.0 Authorization Request */ public interface ServerAuthorizationRequestRepository { /** - * Returns the {@link OAuth2AuthorizationRequest} associated to the provided {@code HttpServletRequest} - * or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizationRequest} associated to the provided + * {@code HttpServletRequest} or {@code null} if not available. * @param exchange the {@code ServerWebExchange} * @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available */ Mono loadAuthorizationRequest(ServerWebExchange exchange); /** - * Persists the {@link OAuth2AuthorizationRequest} associating it to - * the provided {@code HttpServletRequest} and/or {@code HttpServletResponse}. - * + * Persists the {@link OAuth2AuthorizationRequest} associating it to the provided + * {@code HttpServletRequest} and/or {@code HttpServletResponse}. * @param authorizationRequest the {@link OAuth2AuthorizationRequest} - * @param exchange the {@code ServerWebExchange} + * @param exchange the {@code ServerWebExchange} */ Mono saveAuthorizationRequest(T authorizationRequest, ServerWebExchange exchange); /** * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the * provided {@code HttpServletRequest} or if not available returns {@code null}. - * * @param exchange the {@code ServerWebExchange} - * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not + * available */ Mono removeAuthorizationRequest(ServerWebExchange exchange); + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java index 6ff9d34fd1..b36ef33941 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java @@ -16,6 +16,8 @@ package org.springframework.security.oauth2.client.web.server; +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; @@ -29,24 +31,23 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponentsBuilder; -import reactor.core.publisher.Mono; /** - * Converts from a {@link ServerWebExchange} to an {@link OAuth2AuthorizationCodeAuthenticationToken} that can be authenticated. The + * Converts from a {@link ServerWebExchange} to an + * {@link OAuth2AuthorizationCodeAuthenticationToken} that can be authenticated. The * converter does not validate any errors it only performs a conversion. + * * @author Rob Winch * @since 5.1 * @see org.springframework.security.web.server.authentication.AuthenticationWebFilter#setServerAuthenticationConverter(ServerAuthenticationConverter) */ -public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter - implements ServerAuthenticationConverter { +public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter implements ServerAuthenticationConverter { static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found"; - private ServerAuthorizationRequestRepository authorizationRequestRepository = - new WebSessionOAuth2ServerAuthorizationRequestRepository(); + private ServerAuthorizationRequestRepository authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); private final ReactiveClientRegistrationRepository clientRegistrationRepository; @@ -69,9 +70,11 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter @Override public Mono convert(ServerWebExchange serverWebExchange) { + // @formatter:off return this.authorizationRequestRepository.removeAuthorizationRequest(serverWebExchange) - .switchIfEmpty(oauth2AuthorizationException(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE)) - .flatMap(authorizationRequest -> authenticationRequest(serverWebExchange, authorizationRequest)); + .switchIfEmpty(oauth2AuthorizationException(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE)) + .flatMap((authorizationRequest) -> authenticationRequest(serverWebExchange, authorizationRequest)); + // @formatter:on } private Mono oauth2AuthorizationException(String errorCode) { @@ -81,10 +84,11 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter }); } - private Mono authenticationRequest(ServerWebExchange exchange, OAuth2AuthorizationRequest authorizationRequest) { + private Mono authenticationRequest(ServerWebExchange exchange, + OAuth2AuthorizationRequest authorizationRequest) { + // @formatter:off return Mono.just(authorizationRequest) - .map(OAuth2AuthorizationRequest::getAttributes) - .flatMap(attributes -> { + .map(OAuth2AuthorizationRequest::getAttributes).flatMap((attributes) -> { String id = (String) attributes.get(OAuth2ParameterNames.REGISTRATION_ID); if (id == null) { return oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); @@ -92,19 +96,19 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter return this.clientRegistrationRepository.findByRegistrationId(id); }) .switchIfEmpty(oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE)) - .map(clientRegistration -> { + .map((clientRegistration) -> { OAuth2AuthorizationResponse authorizationResponse = convertResponse(exchange); OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken( - clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); + clientRegistration, + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); return authenticationRequest; }); + // @formatter:on } private static OAuth2AuthorizationResponse convertResponse(ServerWebExchange exchange) { - String redirectUri = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()) - .build() - .toUriString(); - return OAuth2AuthorizationResponseUtils - .convert(exchange.getRequest().getQueryParams(), redirectUri); + String redirectUri = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()).build().toUriString(); + return OAuth2AuthorizationResponseUtils.convert(exchange.getRequest().getQueryParams(), redirectUri); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationRequestResolver.java index e33d419b6e..8037996269 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationRequestResolver.java @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; /** - * Implementations of this interface are capable of resolving - * an {@link OAuth2AuthorizationRequest} from the provided {@code ServerWebExchange}. - * Used by the {@link OAuth2AuthorizationRequestRedirectWebFilter} for resolving Authorization Requests. + * Implementations of this interface are capable of resolving an + * {@link OAuth2AuthorizationRequest} from the provided {@code ServerWebExchange}. Used by + * the {@link OAuth2AuthorizationRequestRedirectWebFilter} for resolving Authorization + * Requests. * * @author Rob Winch * @since 5.1 @@ -32,21 +35,21 @@ import reactor.core.publisher.Mono; public interface ServerOAuth2AuthorizationRequestResolver { /** - * Returns the {@link OAuth2AuthorizationRequest} resolved from - * the provided {@code HttpServletRequest} or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizationRequest} resolved from the provided + * {@code HttpServletRequest} or {@code null} if not available. * @param exchange the {@code ServerWebExchange} - * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not + * available */ Mono resolve(ServerWebExchange exchange); /** - * Returns the {@link OAuth2AuthorizationRequest} resolved from - * the provided {@code HttpServletRequest} or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizationRequest} resolved from the provided + * {@code HttpServletRequest} or {@code null} if not available. * @param exchange the {@code ServerWebExchange} * @param clientRegistrationId the client registration id - * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available + * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not + * available */ Mono resolve(ServerWebExchange exchange, String clientRegistrationId); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java index 9491f71361..d7db2e45cc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java @@ -13,25 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; /** - * Implementations of this interface are responsible for the persistence - * of {@link OAuth2AuthorizedClient Authorized Client(s)} between requests. + * Implementations of this interface are responsible for the persistence of + * {@link OAuth2AuthorizedClient Authorized Client(s)} between requests. * *

        - * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client} - * is to associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential - * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, - * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} - * that originally granted the authorization. + * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client} is to + * associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential to + * a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, who + * is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} that originally + * granted the authorization. * * @author Rob Winch * @since 5.1 @@ -43,10 +45,9 @@ import reactor.core.publisher.Mono; public interface ServerOAuth2AuthorizedClientRepository { /** - * Returns the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User {@link Authentication} (Resource Owner) - * or {@code null} if not available. - * + * Returns the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User {@link Authentication} (Resource Owner) or + * {@code null} if not available. * @param clientRegistrationId the identifier for the client's registration * @param principal the End-User {@link Authentication} (Resource Owner) * @param exchange the {@code ServerWebExchange} @@ -57,20 +58,18 @@ public interface ServerOAuth2AuthorizedClientRepository { Authentication principal, ServerWebExchange exchange); /** - * Saves the {@link OAuth2AuthorizedClient} associating it to - * the provided End-User {@link Authentication} (Resource Owner). - * + * Saves the {@link OAuth2AuthorizedClient} associating it to the provided End-User + * {@link Authentication} (Resource Owner). * @param authorizedClient the authorized client * @param principal the End-User {@link Authentication} (Resource Owner) * @param exchange the {@code ServerWebExchange} */ - Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, - Authentication principal, ServerWebExchange exchange); + Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + ServerWebExchange exchange); /** - * Removes the {@link OAuth2AuthorizedClient} associated to the - * provided client registration identifier and End-User {@link Authentication} (Resource Owner). - * + * Removes the {@link OAuth2AuthorizedClient} associated to the provided client + * registration identifier and End-User {@link Authentication} (Resource Owner). * @param clientRegistrationId the identifier for the client's registration * @param principal the End-User {@link Authentication} (Resource Owner) * @param exchange the {@code ServerWebExchange} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java index 05356df521..1073e32fa7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java @@ -16,6 +16,11 @@ package org.springframework.security.oauth2.client.web.server; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.core.Authentication; @@ -23,23 +28,21 @@ import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiv import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; /** - * Provides support for an unauthenticated user. This is useful when running as a process with no - * user associated to it. The implementation ensures that {@link ServerWebExchange} is null and that the - * {@link Authentication} is either null or anonymous to prevent using it incorrectly. - * - * @deprecated Use {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} instead + * Provides support for an unauthenticated user. This is useful when running as a process + * with no user associated to it. The implementation ensures that + * {@link ServerWebExchange} is null and that the {@link Authentication} is either null or + * anonymous to prevent using it incorrectly. * + * @deprecated Use {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} + * instead * @author Rob Winch * @since 5.1 */ @Deprecated public class UnAuthenticatedServerOAuth2AuthorizedClientRepository implements ServerOAuth2AuthorizedClientRepository { + private final AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); private final Map clientRegistrationIdToAuthorizedClient = new ConcurrentHashMap<>(); @@ -50,14 +53,12 @@ public class UnAuthenticatedServerOAuth2AuthorizedClientRepository implements Se Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null"); Assert.isNull(serverWebExchange, "serverWebExchange must be null"); Assert.isTrue(isUnauthenticated(authentication), "The user " + authentication + " should not be authenticated"); - return Mono.fromSupplier(() -> (T) this.clientRegistrationIdToAuthorizedClient.get(clientRegistrationId)); } @Override - public Mono saveAuthorizedClient( - OAuth2AuthorizedClient authorizedClient, - Authentication authentication, ServerWebExchange serverWebExchange) { + public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication authentication, + ServerWebExchange serverWebExchange) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.isNull(serverWebExchange, "serverWebExchange must be null"); Assert.isTrue(isUnauthenticated(authentication), "The user " + authentication + " should not be authenticated"); @@ -79,4 +80,5 @@ public class UnAuthenticatedServerOAuth2AuthorizedClientRepository implements Se private boolean isUnauthenticated(Authentication authentication) { return authentication == null || this.trustResolver.isAnonymous(authentication); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java index 2ad3026ca4..f56953da55 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java @@ -19,6 +19,8 @@ package org.springframework.security.oauth2.client.web.server; import java.util.HashMap; import java.util.Map; +import reactor.core.publisher.Mono; + import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -27,8 +29,6 @@ import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; -import reactor.core.publisher.Mono; - /** * An implementation of an {@link ServerAuthorizationRequestRepository} that stores * {@link OAuth2AuthorizationRequest} in the {@code WebSession}. @@ -41,60 +41,69 @@ import reactor.core.publisher.Mono; public final class WebSessionOAuth2ServerAuthorizationRequestRepository implements ServerAuthorizationRequestRepository { - private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = - WebSessionOAuth2ServerAuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST"; + private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = WebSessionOAuth2ServerAuthorizationRequestRepository.class + .getName() + ".AUTHORIZATION_REQUEST"; private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; @Override - public Mono loadAuthorizationRequest( - ServerWebExchange exchange) { + public Mono loadAuthorizationRequest(ServerWebExchange exchange) { String state = getStateParameter(exchange); if (state == null) { return Mono.empty(); } + // @formatter:off return getStateToAuthorizationRequest(exchange) - .filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state)) - .map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state)); + .filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state)) + .map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state)); + // @formatter:on } @Override - public Mono saveAuthorizationRequest( - OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { - Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); - return saveStateToAuthorizationRequest(exchange) - .doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)) - .then(); - } - - @Override - public Mono removeAuthorizationRequest( + public Mono saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { + Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); + // @formatter:off + return saveStateToAuthorizationRequest(exchange) + .doOnNext((stateToAuthorizationRequest) -> stateToAuthorizationRequest + .put(authorizationRequest.getState(), authorizationRequest)) + .then(); + // @formatter:on + } + + @Override + public Mono removeAuthorizationRequest(ServerWebExchange exchange) { String state = getStateParameter(exchange); if (state == null) { return Mono.empty(); } + // @formatter:off return exchange.getSession() - .map(WebSession::getAttributes) - .handle((sessionAttrs, sink) -> { - Map stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs); - if (stateToAuthzRequest == null) { - sink.complete(); - return; - } - OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); - if (stateToAuthzRequest.isEmpty()) { - sessionAttrs.remove(this.sessionAttributeName); - } else if (removedValue != null) { - // gh-7327 Overwrite the existing Map to ensure the state is saved for distributed sessions - sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); - } - if (removedValue == null) { - sink.complete(); - } else { - sink.next(removedValue); - } - }); + .map(WebSession::getAttributes) + .handle((sessionAttrs, sink) -> { + Map stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest( + sessionAttrs); + if (stateToAuthzRequest == null) { + sink.complete(); + return; + } + OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); + if (stateToAuthzRequest.isEmpty()) { + sessionAttrs.remove(this.sessionAttributeName); + } + else if (removedValue != null) { + // gh-7327 Overwrite the existing Map to ensure the state is saved for + // distributed sessions + sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); + } + if (removedValue == null) { + sink.complete(); + } + else { + sink.next(removedValue); + } + }); + // @formatter:on } /** @@ -114,28 +123,33 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository private Mono> getStateToAuthorizationRequest(ServerWebExchange exchange) { Assert.notNull(exchange, "exchange cannot be null"); + // @formatter:off return getSessionAttributes(exchange) - .flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + .flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + // @formatter:on } private Mono> saveStateToAuthorizationRequest(ServerWebExchange exchange) { Assert.notNull(exchange, "exchange cannot be null"); - + // @formatter:off return getSessionAttributes(exchange) - .doOnNext(sessionAttrs -> { - Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); - - if (stateToAuthzRequest == null) { - stateToAuthzRequest = new HashMap(); - } - - // No matter stateToAuthzRequest was in session or not, we should always put it into session again - // in case of redis or hazelcast session. #6215 - sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); - }).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + .doOnNext((sessionAttrs) -> { + Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); + if (stateToAuthzRequest == null) { + stateToAuthzRequest = new HashMap(); + } + // No matter stateToAuthzRequest was in session or not, we should always put + // it into session again + // in case of redis or hazelcast session. #6215 + sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); + }) + .flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + // @formatter:on } - private Map sessionAttrsMapStateToAuthorizationRequest(Map sessionAttrs) { + private Map sessionAttrsMapStateToAuthorizationRequest( + Map sessionAttrs) { return (Map) sessionAttrs.get(this.sessionAttributeName); } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java index 2c3538e852..e8eb93f04c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; +import java.util.HashMap; +import java.util.Map; + +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; -import reactor.core.publisher.Mono; - -import java.util.HashMap; -import java.util.Map; /** * An implementation of an {@link OAuth2AuthorizedClientRepository} that stores @@ -35,21 +37,24 @@ import java.util.Map; * @see OAuth2AuthorizedClientRepository * @see OAuth2AuthorizedClient */ -public final class WebSessionServerOAuth2AuthorizedClientRepository - implements ServerOAuth2AuthorizedClientRepository { - private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME = - WebSessionServerOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"; +public final class WebSessionServerOAuth2AuthorizedClientRepository implements ServerOAuth2AuthorizedClientRepository { + + private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME = WebSessionServerOAuth2AuthorizedClientRepository.class + .getName() + ".AUTHORIZED_CLIENTS"; + private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME; @Override @SuppressWarnings("unchecked") - public Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, - ServerWebExchange exchange) { + public Mono loadAuthorizedClient(String clientRegistrationId, + Authentication principal, ServerWebExchange exchange) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.notNull(exchange, "exchange cannot be null"); + // @formatter:off return exchange.getSession() - .map(this::getAuthorizedClients) - .flatMap(clients -> Mono.justOrEmpty((T) clients.get(clientRegistrationId))); + .map(this::getAuthorizedClients) + .flatMap((clients) -> Mono.justOrEmpty((T) clients.get(clientRegistrationId))); + // @formatter:on } @Override @@ -57,13 +62,15 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository ServerWebExchange exchange) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(exchange, "exchange cannot be null"); + // @formatter:off return exchange.getSession() - .doOnSuccess(session -> { - Map authorizedClients = getAuthorizedClients(session); - authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient); - session.getAttributes().put(this.sessionAttributeName, authorizedClients); - }) - .then(Mono.empty()); + .doOnSuccess((session) -> { + Map authorizedClients = getAuthorizedClients(session); + authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient); + session.getAttributes().put(this.sessionAttributeName, authorizedClients); + }) + .then(Mono.empty()); + // @formatter:on } @Override @@ -71,26 +78,30 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository ServerWebExchange exchange) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.notNull(exchange, "exchange cannot be null"); + // @formatter:off return exchange.getSession() - .doOnSuccess(session -> { - Map authorizedClients = getAuthorizedClients(session); - authorizedClients.remove(clientRegistrationId); - if (authorizedClients.isEmpty()) { - session.getAttributes().remove(this.sessionAttributeName); - } else { - session.getAttributes().put(this.sessionAttributeName, authorizedClients); - } - }) - .then(Mono.empty()); + .doOnSuccess((session) -> { + Map authorizedClients = getAuthorizedClients(session); + authorizedClients.remove(clientRegistrationId); + if (authorizedClients.isEmpty()) { + session.getAttributes().remove(this.sessionAttributeName); + } + else { + session.getAttributes().put(this.sessionAttributeName, authorizedClients); + } + }) + .then(Mono.empty()); + // @formatter:on } @SuppressWarnings("unchecked") private Map getAuthorizedClients(WebSession session) { - Map authorizedClients = session == null ? null : - (Map) session.getAttribute(this.sessionAttributeName); + Map authorizedClients = (session != null) + ? (Map) session.getAttribute(this.sessionAttributeName) : null; if (authorizedClients == null) { authorizedClients = new HashMap<>(); } return authorizedClients; } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java index 9c3a2ba8f5..19d566bcf0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server.authentication; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; @@ -24,11 +27,11 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori import org.springframework.security.web.server.WebFilterExchange; import org.springframework.security.web.server.authentication.AuthenticationWebFilter; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; /** - * A specialized {@link AuthenticationWebFilter} that converts from an {@link OAuth2LoginAuthenticationToken} to an - * {@link OAuth2AuthenticationToken} and saves the {@link OAuth2AuthorizedClient} + * A specialized {@link AuthenticationWebFilter} that converts from an + * {@link OAuth2LoginAuthenticationToken} to an {@link OAuth2AuthenticationToken} and + * saves the {@link OAuth2AuthorizedClient} * * @author Rob Winch * @since 5.1 @@ -39,12 +42,10 @@ public class OAuth2LoginAuthenticationWebFilter extends AuthenticationWebFilter /** * Creates an instance - * * @param authenticationManager the authentication manager to use * @param authorizedClientRepository */ - public OAuth2LoginAuthenticationWebFilter( - ReactiveAuthenticationManager authenticationManager, + public OAuth2LoginAuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { super(authenticationManager); Assert.notNull(authorizedClientRepository, "authorizedClientService cannot be null"); @@ -52,19 +53,19 @@ public class OAuth2LoginAuthenticationWebFilter extends AuthenticationWebFilter } @Override - protected Mono onAuthenticationSuccess(Authentication authentication, - WebFilterExchange webFilterExchange) { + protected Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) authentication; OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - authenticationResult.getClientRegistration(), - authenticationResult.getName(), - authenticationResult.getAccessToken(), - authenticationResult.getRefreshToken()); - OAuth2AuthenticationToken result = new OAuth2AuthenticationToken( - authenticationResult.getPrincipal(), + authenticationResult.getClientRegistration(), authenticationResult.getName(), + authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); + OAuth2AuthenticationToken result = new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(), authenticationResult.getClientRegistration().getRegistrationId()); - return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authenticationResult, webFilterExchange.getExchange()) + // @formatter:off + return this.authorizedClientRepository + .saveAuthorizedClient(authorizedClient, authenticationResult, webFilterExchange.getExchange()) .then(super.onAuthenticationSuccess(result, webFilterExchange)); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java index c393b9f323..e22e3d3491 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -24,7 +26,8 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. @@ -32,53 +35,59 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class AuthorizationCodeOAuth2AuthorizedClientProviderTests { + private AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; @Before public void setup() { this.authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider(); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); - this.authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, "principal", + TestOAuth2AccessTokens.scopes("read", "write")); this.principal = new TestingAuthenticationToken("principal", "password"); } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null)); } @Test public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientCredentialsClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientCredentialsClient).principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient).principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(ClientAuthorizationRequiredException.class); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration).principal(this.principal) + .build(); + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java index 97bf724011..15b7d5740c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -24,7 +26,8 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link AuthorizationCodeReactiveOAuth2AuthorizedClientProvider}. @@ -32,53 +35,62 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests { + private AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; @Before public void setup() { this.authorizedClientProvider = new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider(); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); - this.authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, "principal", + TestOAuth2AccessTokens.scopes("read", "write")); this.principal = new TestingAuthenticationToken("principal", "password"); } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()); } @Test public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientCredentialsClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientCredentialsClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) - .isInstanceOf(ClientAuthorizationRequiredException.class); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java index 9ca0f88929..e99ba59846 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -29,20 +34,17 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import java.util.Map; -import java.util.function.Function; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; /** * Tests for {@link AuthorizedClientServiceOAuth2AuthorizedClientManager}. @@ -50,16 +52,27 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; + private AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private ArgumentCaptor authorizationContextCaptor; @SuppressWarnings("unchecked") @@ -71,13 +84,15 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { this.contextAttributesMapper = mock(Function.class); this.authorizationSuccessHandler = spy(new OAuth2AuthorizationSuccessHandler() { @Override - public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, Map attributes) { - authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, + Map attributes) { + AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.this.authorizedClientService + .saveAuthorizedClient(authorizedClient, principal); } }); this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName()))); + (clientRegistrationId, principal, attributes) -> this.authorizedClientService + .removeAuthorizedClient(clientRegistrationId, principal.getName()))); this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientService); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); @@ -93,82 +108,95 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(null, this.authorizedClientService)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(null, this.authorizedClientService)) + .withMessage("clientRegistrationRepository cannot be null"); + // @formatter:on } @Test public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientService cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .withMessage("authorizedClientService cannot be null"); + // @formatter:on } @Test public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientProvider cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .withMessage("authorizedClientProvider cannot be null"); + // @formatter:on } @Test public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("contextAttributesMapper cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .withMessage("contextAttributesMapper cannot be null"); + // @formatter:on } @Test public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationSuccessHandler cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .withMessage("authorizationSuccessHandler cannot be null"); + // @formatter:on } @Test public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationFailureHandler cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .withMessage("authorizationFailureHandler cannot be null"); + // @formatter:on } @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizeRequest cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.authorize(null)) + .withMessage("authorizeRequest cannot be null"); + // @formatter:on } @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("invalid-registration-id") + + .principal(this.principal).build(); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .withMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + // @formatter:on } @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isNull(); verifyNoInteractions(this.authorizationSuccessHandler); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); @@ -177,81 +205,71 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(this.authorizedClient); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(this.authorizedClient), eq(this.principal), any()); - verify(this.authorizedClientService).saveAuthorizedClient( - eq(this.authorizedClient), eq(this.principal)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(this.authorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientService).saveAuthorizedClient(eq(this.authorizedClient), eq(this.principal)); } @SuppressWarnings("unchecked") @Test public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - when(this.authorizedClientService.loadAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(this.authorizedClient); - - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + given(this.authorizedClientService.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal.getName()))).willReturn(this.authorizedClient); + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(reauthorizedClient); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(reauthorizedClient), eq(this.principal), any()); - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(reauthorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); } @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + .principal(this.principal).build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); verifyNoInteractions(this.authorizationSuccessHandler); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); @@ -260,66 +278,52 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenSupportedProviderThenReauthorized() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(reauthorizedClient); OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + .principal(this.principal).build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(reauthorizedClient), eq(this.principal), any()); - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(reauthorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); } @SuppressWarnings("unchecked") @Test public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(reauthorizedClient); // Override the mock with the default this.authorizedClientManager.setContextAttributesMapper( new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); - OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attribute(OAuth2ParameterNames.SCOPE, "read write") - .build(); + .principal(this.principal).attribute(OAuth2ParameterNames.SCOPE, "read write").build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); - String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(authorizationContext.getAttributes()) + .containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext + .getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(reauthorizedClient), eq(this.principal), any()); - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(reauthorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); } @Test @@ -327,41 +331,37 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { ClientAuthorizationException authorizationException = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) - .thenThrow(authorizationException); - + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willThrow(authorizationException); OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); - - assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + .principal(this.principal).build(); + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isEqualTo(authorizationException); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - eq(authorizationException), eq(this.principal), any()); - verify(this.authorizedClientService).removeAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName())); + verify(this.authorizationFailureHandler).onAuthorizationFailure(eq(authorizationException), eq(this.principal), + any()); + verify(this.authorizedClientService).removeAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal.getName())); } @Test public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() { ClientAuthorizationException authorizationException = new ClientAuthorizationException( - new OAuth2Error("non-matching-error-code", null, null), - this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) - .thenThrow(authorizationException); - - OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + new OAuth2Error("non-matching-error-code", null, null), this.clientRegistration.getRegistrationId()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willThrow(authorizationException); + // @formatter:off + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest + .withAuthorizedClient(this.authorizedClient) .principal(this.principal) .build(); - - assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isEqualTo(authorizationException); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - eq(authorizationException), eq(this.principal), any()); + // @formatter:on + verify(this.authorizationFailureHandler).onAuthorizationFailure(eq(authorizationException), eq(this.principal), + any()); verifyNoInteractions(this.authorizedClientService); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java index 01d38442d2..0ce1f0c675 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -29,22 +37,16 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import reactor.test.publisher.PublisherProbe; - -import java.util.Map; -import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * Tests for {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}. @@ -53,16 +55,27 @@ import static org.mockito.Mockito.when; * @author Phil Clay */ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private Function>> contextAttributesMapper; + private AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private ArgumentCaptor authorizationContextCaptor; + private PublisherProbe saveAuthorizedClientProbe; + private PublisherProbe removeAuthorizedClientProbe; @SuppressWarnings("unchecked") @@ -71,12 +84,14 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); this.saveAuthorizedClientProbe = PublisherProbe.empty(); - when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono()); + given(this.authorizedClientService.saveAuthorizedClient(any(), any())) + .willReturn(this.saveAuthorizedClientProbe.mono()); this.removeAuthorizedClientProbe = PublisherProbe.empty(); - when(this.authorizedClientService.removeAuthorizedClient(any(), any())).thenReturn(this.removeAuthorizedClientProbe.mono()); + given(this.authorizedClientService.removeAuthorizedClient(any(), any())) + .willReturn(this.removeAuthorizedClientProbe.mono()); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); this.contextAttributesMapper = mock(Function.class); - when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty()); + given(this.contextAttributesMapper.apply(any())).willReturn(Mono.empty()); this.authorizedClientManager = new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientService); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); @@ -90,123 +105,109 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, + this.authorizedClientService)) + .withMessage("clientRegistrationRepository cannot be null"); } @Test public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientService cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, null)) + .withMessage("authorizedClientService cannot be null"); } @Test public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientProvider cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .withMessage("authorizedClientProvider cannot be null"); } @Test public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("contextAttributesMapper cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .withMessage("contextAttributesMapper cannot be null"); } @Test public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationSuccessHandler cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .withMessage("authorizationSuccessHandler cannot be null"); } @Test public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationFailureHandler cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .withMessage("authorizationFailureHandler cannot be null"); } @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizeRequest cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(null)) + .withMessage("authorizeRequest cannot be null"); } @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { String clientRegistrationId = "invalid-registration-id"; OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) - .principal(this.principal) - .build(); - when(this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(Mono.empty()); + .principal(this.principal).build(); + given(this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).willReturn(Mono.empty()); StepVerifier.create(this.authorizedClientManager.authorize(authorizeRequest)) .verifyError(IllegalArgumentException.class); - } @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - when(authorizedClientProvider.authorize(any())).thenReturn(Mono.empty()); - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + given(this.authorizedClientProvider.authorize(any())).willReturn(Mono.empty()); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); + // @formatter:on Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - StepVerifier.create(authorizedClient).verifyComplete(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal)); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(OAuth2AuthorizedClient.class), + eq(this.principal)); } @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.authorizedClient)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - - StepVerifier.create(authorizedClient) - .expectNext(this.authorizedClient) - .verifyComplete(); - + StepVerifier.create(authorizedClient).expectNext(this.authorizedClient).verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService).saveAuthorizedClient( - eq(this.authorizedClient), eq(this.principal)); + verify(this.authorizedClientService).saveAuthorizedClient(eq(this.authorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); } @@ -214,34 +215,28 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndSupportedProviderAndCustomSuccessHandlerThenInvokeCustomSuccessHandler() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.authorizedClient)); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); + // @formatter:on PublisherProbe authorizationSuccessHandlerProbe = PublisherProbe.empty(); - this.authorizedClientManager.setAuthorizationSuccessHandler((client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); - + this.authorizedClientManager.setAuthorizationSuccessHandler( + (client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - - StepVerifier.create(authorizedClient) - .expectNext(this.authorizedClient) - .verifyComplete(); - + StepVerifier.create(authorizedClient).expectNext(this.authorizedClient).verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - authorizationSuccessHandlerProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); @@ -249,170 +244,150 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { @Test public void authorizeWhenInvalidTokenThenRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); - + // @formatter:on ClientAuthorizationException exception = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService).removeAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName())); + verify(this.authorizedClientService).removeAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal.getName())); this.removeAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @Test public void authorizeWhenInvalidGrantThenRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); - + // @formatter:on ClientAuthorizationException exception = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService).removeAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName())); + verify(this.authorizedClientService).removeAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal.getName())); this.removeAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @Test public void authorizeWhenServerErrorThenDoNotRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); - + // @formatter:on ClientAuthorizationException exception = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @Test public void authorizeWhenOAuth2AuthorizationExceptionThenDoNotRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); - + // @formatter:on OAuth2AuthorizationException exception = new OAuth2AuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @Test public void authorizeWhenOAuth2AuthorizationExceptionAndCustomFailureHandlerThenInvokeCustomFailureHandler() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientService.loadAuthorizedClient( - any(), any())).thenReturn(Mono.empty()); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) .build(); - + // @formatter:on OAuth2AuthorizationException exception = new OAuth2AuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); PublisherProbe authorizationFailureHandlerProbe = PublisherProbe.empty(); - this.authorizedClientManager.setAuthorizationFailureHandler((client, principal, attributes) -> authorizationFailureHandlerProbe.mono()); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + this.authorizedClientManager.setAuthorizationFailureHandler( + (client, principal, attributes) -> authorizationFailureHandlerProbe.mono()); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - authorizationFailureHandlerProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); @@ -421,35 +396,30 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - when(this.authorizedClientService.loadAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(Mono.just(this.authorizedClient)); - - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientService.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal.getName()))).willReturn(Mono.just(this.authorizedClient)); + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(reauthorizedClient)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - + // @formatter:off StepVerifier.create(authorizedClient) .expectNext(reauthorizedClient) .verifyComplete(); + // @formatter:on verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); + verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); } @@ -457,55 +427,45 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).willReturn(Mono.empty()); + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .build(); + // @formatter:on Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - - StepVerifier.create(authorizedClient) - .expectNext(this.authorizedClient) - .verifyComplete(); + StepVerifier.create(authorizedClient).expectNext(this.authorizedClient).verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal)); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(OAuth2AuthorizedClient.class), + eq(this.principal)); } @SuppressWarnings("unchecked") @Test public void reauthorizeWhenSupportedProviderThenReauthorized() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(reauthorizedClient)); + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .build(); + // @formatter:on Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - - StepVerifier.create(authorizedClient) - .expectNext(reauthorizedClient) - .verifyComplete(); - + StepVerifier.create(authorizedClient).expectNext(reauthorizedClient).verifyComplete(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); + verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); } @@ -513,37 +473,33 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(reauthorizedClient)); OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attribute(OAuth2ParameterNames.SCOPE, "read write") - .build(); - - this.authorizedClientManager.setContextAttributesMapper(new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); + .principal(this.principal).attribute(OAuth2ParameterNames.SCOPE, "read write").build(); + this.authorizedClientManager.setContextAttributesMapper( + new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - + // @formatter:off StepVerifier.create(authorizedClient) .expectNext(reauthorizedClient) .verifyComplete(); - verify(this.authorizedClientService).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal)); + // @formatter:on + verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); - String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(authorizationContext.getAttributes()) + .containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext + .getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); - } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index c2b85c7a30..c69a285ad8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -28,14 +33,11 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import java.time.Duration; -import java.time.Instant; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link ClientCredentialsOAuth2AuthorizedClientProvider}. @@ -43,9 +45,13 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { + private ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; @Before @@ -59,61 +65,72 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessTokenResponseClient cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class).withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew must be >= 0"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClock(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clock cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .withMessage("context cannot be null"); + // @formatter:on } @Test public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -123,20 +140,19 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), accessToken); - + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -144,13 +160,14 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @@ -160,26 +177,25 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { Instant now = Instant.now(); Instant issuedAt = now.minus(Duration.ofMinutes(60)); Instant expiresAt = now.minus(Duration.ofMinutes(1)); - OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), expiresInOneMinAccessToken); - - // Shorten the lifespan of the access token by 90 seconds, which will ultimately force it to expire on the client + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java index af33a645ac..2c11bb375f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests.java @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; @@ -27,16 +33,12 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link ClientCredentialsReactiveOAuth2AuthorizedClientProvider}. @@ -44,9 +46,13 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests { + private ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; @Before @@ -60,61 +66,72 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests { @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessTokenResponseClient cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew must be >= 0"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClock(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clock cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .withMessage("context cannot be null"); + // @formatter:on } @Test public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -124,20 +141,19 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests { public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), accessToken); - + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -145,13 +161,14 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @@ -161,26 +178,26 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests { Instant now = Instant.now(); Instant issuedAt = now.minus(Duration.ofMinutes(60)); Instant expiresAt = now.minus(Duration.ofMinutes(1)); - OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), expiresInOneMinAccessToken); - - // Shorten the lifespan of the access token by 90 seconds, which will ultimately force it to expire on the client + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); - - OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java index f930233aa8..f595fa571c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -13,22 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; + import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; -import java.util.Collections; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link DelegatingOAuth2AuthorizedClientProvider}. @@ -39,36 +41,33 @@ public class DelegatingOAuth2AuthorizedClientProviderTests { @Test public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(new OAuth2AuthorizedClientProvider[0])) - .isInstanceOf(IllegalArgumentException.class); - assertThatThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(Collections.emptyList())) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(new OAuth2AuthorizedClientProvider[0])); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(Collections.emptyList())); } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( mock(OAuth2AuthorizedClientProvider.class)); - assertThatThrownBy(() -> delegate.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> delegate.authorize(null)) + .withMessage("context cannot be null"); } @Test public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { Authentication principal = new TestingAuthenticationToken("principal", "password"); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, principal.getName(), TestOAuth2AccessTokens.noScopes()); - + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + TestOAuth2AccessTokens.noScopes()); OAuth2AuthorizedClientProvider authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); - when(authorizedClientProvider.authorize(any())).thenReturn(authorizedClient); - + given(authorizedClientProvider.authorize(any())).willReturn(authorizedClient); DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); + mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), + authorizedClientProvider); OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(principal) - .build(); + .principal(principal).build(); OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); assertThat(reauthorizedClient).isSameAs(authorizedClient); } @@ -77,11 +76,10 @@ public class DelegatingOAuth2AuthorizedClientProviderTests { public void authorizeWhenProviderCantAuthorizeThenReturnNull() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(new TestingAuthenticationToken("principal", "password")) - .build(); - + .principal(new TestingAuthenticationToken("principal", "password")).build(); DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); assertThat(delegate.authorize(context)).isNull(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java index 45ddf6e528..5619fd4be2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingReactiveOAuth2AuthorizedClientProviderTests.java @@ -13,23 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; + import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; -import reactor.core.publisher.Mono; - -import java.util.Collections; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link DelegatingReactiveOAuth2AuthorizedClientProvider}. @@ -40,40 +42,39 @@ public class DelegatingReactiveOAuth2AuthorizedClientProviderTests { @Test public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(new ReactiveOAuth2AuthorizedClientProvider[0])) - .isInstanceOf(IllegalArgumentException.class); - assertThatThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(Collections.emptyList())) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider( + new ReactiveOAuth2AuthorizedClientProvider[0])); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DelegatingReactiveOAuth2AuthorizedClientProvider(Collections.emptyList())); } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider( mock(ReactiveOAuth2AuthorizedClientProvider.class)); - assertThatThrownBy(() -> delegate.authorize(null).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> delegate.authorize(null).block()) + .withMessage("context cannot be null"); } @Test public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { Authentication principal = new TestingAuthenticationToken("principal", "password"); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, principal.getName(), TestOAuth2AccessTokens.noScopes()); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(authorizedClientProvider1.authorize(any())).thenReturn(Mono.empty()); - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(authorizedClientProvider2.authorize(any())).thenReturn(Mono.empty()); - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider3 = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(authorizedClientProvider3.authorize(any())).thenReturn(Mono.just(authorizedClient)); - + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(), + TestOAuth2AccessTokens.noScopes()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock( + ReactiveOAuth2AuthorizedClientProvider.class); + given(authorizedClientProvider1.authorize(any())).willReturn(Mono.empty()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock( + ReactiveOAuth2AuthorizedClientProvider.class); + given(authorizedClientProvider2.authorize(any())).willReturn(Mono.empty()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider3 = mock( + ReactiveOAuth2AuthorizedClientProvider.class); + given(authorizedClientProvider3.authorize(any())).willReturn(Mono.just(authorizedClient)); DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider( authorizedClientProvider1, authorizedClientProvider2, authorizedClientProvider3); OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(principal) - .build(); + .principal(principal).build(); OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context).block(); assertThat(reauthorizedClient).isSameAs(authorizedClient); } @@ -82,16 +83,16 @@ public class DelegatingReactiveOAuth2AuthorizedClientProviderTests { public void authorizeWhenProviderCantAuthorizeThenReturnNull() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(new TestingAuthenticationToken("principal", "password")) - .build(); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(authorizedClientProvider1.authorize(any())).thenReturn(Mono.empty()); - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(authorizedClientProvider2.authorize(any())).thenReturn(Mono.empty()); - + .principal(new TestingAuthenticationToken("principal", "password")).build(); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider1 = mock( + ReactiveOAuth2AuthorizedClientProvider.class); + given(authorizedClientProvider1.authorize(any())).willReturn(Mono.empty()); + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider2 = mock( + ReactiveOAuth2AuthorizedClientProvider.class); + given(authorizedClientProvider2.authorize(any())).willReturn(Mono.empty()); DelegatingReactiveOAuth2AuthorizedClientProvider delegate = new DelegatingReactiveOAuth2AuthorizedClientProvider( authorizedClientProvider1, authorizedClientProvider2); assertThat(delegate.authorize(context).block()).isNull(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java index dbb00dbfcc..6dbeffd490 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java @@ -13,9 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.util.Collections; +import java.util.Map; + import org.junit.Test; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -23,14 +28,12 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import java.util.Collections; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatObject; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link InMemoryOAuth2AuthorizedClientService}. @@ -39,24 +42,23 @@ import static org.mockito.Mockito.when; * @author Vedran Pavic */ public class InMemoryOAuth2AuthorizedClientServiceTests { + private String principalName1 = "principal-1"; + private String principalName2 = "principal-2"; private ClientRegistration registration1 = TestClientRegistrations.clientRegistration().build(); private ClientRegistration registration2 = TestClientRegistrations.clientRegistration2().build(); - private ClientRegistration registration3 = TestClientRegistrations.clientRegistration() - .clientId("client-3") - .registrationId("registration-3") - .build(); + private ClientRegistration registration3 = TestClientRegistrations.clientRegistration().clientId("client-3") + .registrationId("registration-3").build(); - private ClientRegistrationRepository clientRegistrationRepository = - new InMemoryClientRegistrationRepository(this.registration1, this.registration2, this.registration3); - - private InMemoryOAuth2AuthorizedClientService authorizedClientService = - new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); + private ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + this.registration1, this.registration2, this.registration3); + private InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( + this.clientRegistrationRepository); @Test(expected = IllegalArgumentException.class) public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { @@ -65,25 +67,24 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { @Test public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClients cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository, null)) + .withMessage("authorizedClients cannot be empty"); + // @formatter:on } @Test public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() { String registrationId = this.registration3.getRegistrationId(); - Map authorizedClients = Collections.singletonMap( new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1), mock(OAuth2AuthorizedClient.class)); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); - when(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).thenReturn(this.registration3); - + given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3); InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( clientRegistrationRepository, authorizedClients); - assertThat((OAuth2AuthorizedClient) authorizedClientService.loadAuthorizedClient( - registrationId, this.principalName1)).isNotNull(); + assertThatObject(authorizedClientService.loadAuthorizedClient(registrationId, this.principalName1)).isNotNull(); } @Test(expected = IllegalArgumentException.class) @@ -98,29 +99,27 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { @Test public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() { - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - "registration-not-found", this.principalName1); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient("registration-not-found", this.principalName1); assertThat(authorizedClient).isNull(); } @Test public void loadAuthorizedClientWhenClientRegistrationFoundButNotAssociatedToPrincipalThenReturnNull() { - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.registration1.getRegistrationId(), "principal-not-found"); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration1.getRegistrationId(), "principal-not-found"); assertThat(authorizedClient).isNull(); } @Test public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrincipalThenReturnAuthorizedClient() { Authentication authentication = mock(Authentication.class); - when(authentication.getName()).thenReturn(this.principalName1); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + given(authentication.getName()).willReturn(this.principalName1); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); - - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.registration1.getRegistrationId(), this.principalName1); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); } @@ -137,14 +136,12 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { @Test public void saveAuthorizedClientWhenSavedThenCanLoad() { Authentication authentication = mock(Authentication.class); - when(authentication.getName()).thenReturn(this.principalName2); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration3, this.principalName2, mock(OAuth2AccessToken.class)); + given(authentication.getName()).willReturn(this.principalName2); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName2, + mock(OAuth2AccessToken.class)); this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); - - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.registration3.getRegistrationId(), this.principalName2); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2); assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); } @@ -161,20 +158,18 @@ public class InMemoryOAuth2AuthorizedClientServiceTests { @Test public void removeAuthorizedClientWhenSavedThenRemoved() { Authentication authentication = mock(Authentication.class); - when(authentication.getName()).thenReturn(this.principalName2); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName2, mock(OAuth2AccessToken.class)); + given(authentication.getName()).willReturn(this.principalName2); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName2, + mock(OAuth2AccessToken.class)); this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); - - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.registration2.getRegistrationId(), this.principalName2); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2); assertThat(loadedAuthorizedClient).isNotNull(); - - this.authorizedClientService.removeAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2); - - loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.registration2.getRegistrationId(), this.principalName2); + this.authorizedClientService.removeAuthorizedClient(this.registration2.getRegistrationId(), + this.principalName2); + loadedAuthorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2); assertThat(loadedAuthorizedClient).isNull(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java index 2d05534822..314aaa4fe3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java @@ -16,11 +16,17 @@ package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -28,14 +34,9 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import java.time.Duration; -import java.time.Instant; - -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.when; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -43,6 +44,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { + @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; @@ -54,11 +56,10 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { private Authentication principal = new TestingAuthenticationToken(this.principalName, "notused"); - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", - Instant.now(), + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), Instant.now().plus(Duration.ofDays(1))); + // @formatter:off private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) @@ -72,6 +73,7 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { .clientId("clientId") .clientSecret("clientSecret") .build(); + // @formatter:on @Before public void setup() { @@ -82,130 +84,166 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { @Test public void constructorNullClientRegistrationRepositoryThenThrowsIllegalArgumentException() { this.clientRegistrationRepository = null; - assertThatThrownBy(() -> new InMemoryReactiveOAuth2AuthorizedClientService(this.clientRegistrationRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new InMemoryReactiveOAuth2AuthorizedClientService(this.clientRegistrationRepository)); } @Test public void loadAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() { this.clientRegistrationId = null; - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void loadAuthorizedClientWhenClientRegistrationIdEmptyThenIllegalArgumentException() { this.clientRegistrationId = ""; - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void loadAuthorizedClientWhenPrincipalNameNullThenIllegalArgumentException() { this.principalName = null; - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void loadAuthorizedClientWhenPrincipalNameEmptyThenIllegalArgumentException() { this.principalName = ""; - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void loadAuthorizedClientWhenClientRegistrationIdNotFoundThenEmpty() { - when(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) - .thenReturn(Mono.empty()); - StepVerifier - .create(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .verifyComplete(); + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.empty()); + StepVerifier.create( + this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) + .verifyComplete(); } @Test public void loadAuthorizedClientWhenClientRegistrationFoundAndNotAuthorizedClientThenEmpty() { - when(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)).thenReturn(Mono.just(this.clientRegistration)); - StepVerifier - .create(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.just(this.clientRegistration)); + StepVerifier.create( + this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) .verifyComplete(); } @Test public void loadAuthorizedClientWhenClientRegistrationFoundThenFound() { - when(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)).thenReturn(Mono.just(this.clientRegistration)); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); - Mono saveAndLoad = this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken); + // @formatter:off + Mono saveAndLoad = this.authorizedClientService + .saveAuthorizedClient(authorizedClient, this.principal) .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); - StepVerifier.create(saveAndLoad) - .expectNext(authorizedClient) - .verifyComplete(); + .expectNext(authorizedClient) + .verifyComplete(); + // @formatter:on } @Test public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() { OAuth2AuthorizedClient authorizedClient = null; - assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)); + // @formatter:on } @Test public void saveAuthorizedClientWhenPrincipalNullThenIllegalArgumentException() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken); this.principal = null; - assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)); + // @formatter:on } @Test public void removeAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() { this.clientRegistrationId = null; - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void removeAuthorizedClientWhenClientRegistrationIdEmptyThenIllegalArgumentException() { this.clientRegistrationId = ""; - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void removeAuthorizedClientWhenPrincipalNameNullThenIllegalArgumentException() { this.principalName = null; - assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void removeAuthorizedClientWhenPrincipalNameEmptyThenIllegalArgumentException() { this.principalName = ""; - assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName)); + // @formatter:on } @Test public void removeAuthorizedClientWhenClientIdThenNoException() { - when(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)).thenReturn(Mono.empty()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.empty()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken); + // @formatter:off Mono saveAndDeleteAndLoad = this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal) - .then(this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName)); - + .then(this.authorizedClientService + .removeAuthorizedClient(this.clientRegistrationId, this.principalName) + ); StepVerifier.create(saveAndDeleteAndLoad) .verifyComplete(); + // @formatter:on } @Test public void removeAuthorizedClientWhenClientRegistrationFoundRemovedThenNotFound() { - when(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)).thenReturn(Mono.just(this.clientRegistration)); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken); + // @formatter:off Mono saveAndDeleteAndLoad = this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal) - .then(this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName)) + .then(this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, + this.principalName)) .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); - StepVerifier.create(saveAndDeleteAndLoad) .verifyComplete(); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java index 78fe9c30fb..cea664e166 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java @@ -13,11 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.nio.charset.StandardCharsets; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.List; +import java.util.Set; + import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.dao.DataRetrievalFailureException; import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; import org.springframework.jdbc.core.JdbcOperations; @@ -40,26 +53,16 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.nio.charset.StandardCharsets; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; -import java.sql.Types; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Collections; -import java.util.List; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.within; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * Tests for {@link JdbcOAuth2AuthorizedClientService}. @@ -68,23 +71,30 @@ import static org.mockito.Mockito.when; * @author Stav Shamir */ public class JdbcOAuth2AuthorizedClientServiceTests { + private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql"; + private static int principalId = 1000; + private ClientRegistration clientRegistration; + private ClientRegistrationRepository clientRegistrationRepository; + private EmbeddedDatabase db; + private JdbcOperations jdbcOperations; + private JdbcOAuth2AuthorizedClientService authorizedClientService; @Before public void setUp() { this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.clientRegistration); + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(this.clientRegistration); this.db = createDb(); this.jdbcOperations = new JdbcTemplate(this.db); - this.authorizedClientService = new JdbcOAuth2AuthorizedClientService( - this.jdbcOperations, this.clientRegistrationRepository); + this.authorizedClientService = new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, + this.clientRegistrationRepository); } @After @@ -94,50 +104,61 @@ public class JdbcOAuth2AuthorizedClientServiceTests { @Test public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jdbcOperations cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) + .withMessage("jdbcOperations cannot be null"); + // @formatter:on } @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, null)) + .withMessage("clientRegistrationRepository cannot be null"); } @Test public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRowMapper cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) + .withMessage("authorizedClientRowMapper cannot be null"); + // @formatter:on } @Test public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientParametersMapper cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) + .withMessage("authorizedClientParametersMapper cannot be null"); + // @formatter:on } @Test public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) + .withMessage("clientRegistrationId cannot be empty"); + // @formatter:on } @Test public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principalName cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) + .withMessage("principalName cannot be empty"); + // @formatter:on } @Test public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() { - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - "registration-not-found", "principalName"); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient("registration-not-found", "principalName"); assertThat(authorizedClient).isNull(); } @@ -145,94 +166,97 @@ public class JdbcOAuth2AuthorizedClientServiceTests { public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() { Authentication principal = createPrincipal(); OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); - this.authorizedClientService.saveAuthorizedClient(expected, principal); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); - assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); - assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); - assertThat(authorizedClient.getAccessToken().getIssuedAt()).isCloseTo(expected.getAccessToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); - assertThat(authorizedClient.getAccessToken().getExpiresAt()).isCloseTo(expected.getAccessToken().getExpiresAt(), within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getAccessToken().getTokenType()) + .isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()).isCloseTo(expected.getAccessToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getAccessToken().getExpiresAt()).isCloseTo(expected.getAccessToken().getExpiresAt(), + within(1, ChronoUnit.MILLIS)); assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); - assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue()); - assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isCloseTo(expected.getRefreshToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getRefreshToken().getTokenValue()) + .isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isCloseTo(expected.getRefreshToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); } @Test public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(null); + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(null); Authentication principal = createPrincipal(); OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); - this.authorizedClientService.saveAuthorizedClient(expected, principal); - - assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())) - .isInstanceOf(DataRetrievalFailureException.class) - .hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + - "' exists in the data source, however, it was not found in the ClientRegistrationRepository."); + assertThatExceptionOfType(DataRetrievalFailureException.class) + .isThrownBy(() -> this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())) + .withMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + + "' exists in the data source, however, it was not found in the ClientRegistrationRepository."); } @Test public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { Authentication principal = createPrincipal(); - - assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal)) + .withMessage("authorizedClient cannot be null"); } @Test public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { Authentication principal = createPrincipal(); OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); - - assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null)) + .withMessage("principal cannot be null"); } @Test public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() { Authentication principal = createPrincipal(); OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); - this.authorizedClientService.saveAuthorizedClient(expected, principal); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); - assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); - assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); - assertThat(authorizedClient.getAccessToken().getIssuedAt()).isCloseTo(expected.getAccessToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); - assertThat(authorizedClient.getAccessToken().getExpiresAt()).isCloseTo(expected.getAccessToken().getExpiresAt(), within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getAccessToken().getTokenType()) + .isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()).isCloseTo(expected.getAccessToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getAccessToken().getExpiresAt()).isCloseTo(expected.getAccessToken().getExpiresAt(), + within(1, ChronoUnit.MILLIS)); assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); - assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue()); - assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isCloseTo(expected.getRefreshToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); - + assertThat(authorizedClient.getRefreshToken().getTokenValue()) + .isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isCloseTo(expected.getRefreshToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); // Test save/load of NOT NULL attributes only principal = createPrincipal(); expected = createAuthorizedClient(principal, this.clientRegistration, true); - this.authorizedClientService.saveAuthorizedClient(expected, principal); - - authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - + authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); - assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); - assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); - assertThat(authorizedClient.getAccessToken().getIssuedAt()).isCloseTo(expected.getAccessToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); - assertThat(authorizedClient.getAccessToken().getExpiresAt()).isCloseTo(expected.getAccessToken().getExpiresAt(), within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getAccessToken().getTokenType()) + .isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()).isCloseTo(expected.getAccessToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); + assertThat(authorizedClient.getAccessToken().getExpiresAt()).isCloseTo(expected.getAccessToken().getExpiresAt(), + within(1, ChronoUnit.MILLIS)); assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty(); assertThat(authorizedClient.getRefreshToken()).isNull(); } @@ -243,101 +267,92 @@ public class JdbcOAuth2AuthorizedClientServiceTests { Authentication principal = createPrincipal(); OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); - // When a client with the same principal and registration id is saved OAuth2AuthorizedClient updatedClient = createAuthorizedClient(principal, this.clientRegistration); this.authorizedClientService.saveAuthorizedClient(updatedClient, principal); - // Then the saved client is updated - OAuth2AuthorizedClient savedClient = this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - + OAuth2AuthorizedClient savedClient = this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(savedClient).isNotNull(); assertThat(savedClient.getClientRegistration()).isEqualTo(updatedClient.getClientRegistration()); assertThat(savedClient.getPrincipalName()).isEqualTo(updatedClient.getPrincipalName()); - assertThat(savedClient.getAccessToken().getTokenType()).isEqualTo(updatedClient.getAccessToken().getTokenType()); - assertThat(savedClient.getAccessToken().getTokenValue()).isEqualTo(updatedClient.getAccessToken().getTokenValue()); - assertThat(savedClient.getAccessToken().getIssuedAt()).isCloseTo(updatedClient.getAccessToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); - assertThat(savedClient.getAccessToken().getExpiresAt()).isCloseTo(updatedClient.getAccessToken().getExpiresAt(), within(1, ChronoUnit.MILLIS)); + assertThat(savedClient.getAccessToken().getTokenType()) + .isEqualTo(updatedClient.getAccessToken().getTokenType()); + assertThat(savedClient.getAccessToken().getTokenValue()) + .isEqualTo(updatedClient.getAccessToken().getTokenValue()); + assertThat(savedClient.getAccessToken().getIssuedAt()).isCloseTo(updatedClient.getAccessToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); + assertThat(savedClient.getAccessToken().getExpiresAt()).isCloseTo(updatedClient.getAccessToken().getExpiresAt(), + within(1, ChronoUnit.MILLIS)); assertThat(savedClient.getAccessToken().getScopes()).isEqualTo(updatedClient.getAccessToken().getScopes()); - assertThat(savedClient.getRefreshToken().getTokenValue()).isEqualTo(updatedClient.getRefreshToken().getTokenValue()); - assertThat(savedClient.getRefreshToken().getIssuedAt()).isCloseTo(updatedClient.getRefreshToken().getIssuedAt(), within(1, ChronoUnit.MILLIS)); + assertThat(savedClient.getRefreshToken().getTokenValue()) + .isEqualTo(updatedClient.getRefreshToken().getTokenValue()); + assertThat(savedClient.getRefreshToken().getIssuedAt()).isCloseTo(updatedClient.getRefreshToken().getIssuedAt(), + within(1, ChronoUnit.MILLIS)); } @Test public void saveLoadAuthorizedClientWhenCustomStrategiesSetThenCalled() throws Exception { - JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper authorizedClientRowMapper = - spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper(this.clientRegistrationRepository)); + JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper authorizedClientRowMapper = spy( + new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper( + this.clientRegistrationRepository)); this.authorizedClientService.setAuthorizedClientRowMapper(authorizedClientRowMapper); - JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper authorizedClientParametersMapper = - spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper()); + JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper authorizedClientParametersMapper = spy( + new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper()); this.authorizedClientService.setAuthorizedClientParametersMapper(authorizedClientParametersMapper); - Authentication principal = createPrincipal(); OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); - this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); - this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - + this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), + principal.getName()); verify(authorizedClientRowMapper).mapRow(any(), anyInt()); verify(authorizedClientParametersMapper).apply(any()); } @Test public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName")) + .withMessage("clientRegistrationId cannot be empty"); } @Test public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principalName cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientService + .removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) + .withMessage("principalName cannot be empty"); } @Test public void removeAuthorizedClientWhenExistsThenRemoved() { Authentication principal = createPrincipal(); OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); - this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); - - authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); + authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNotNull(); - - this.authorizedClientService.removeAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - - authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); + this.authorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), + principal.getName()); + authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNull(); } @Test public void tableDefinitionWhenCustomThenAbleToOverride() { - CustomTableDefinitionJdbcOAuth2AuthorizedClientService customAuthorizedClientService = - new CustomTableDefinitionJdbcOAuth2AuthorizedClientService( - new JdbcTemplate(createDb("custom-oauth2-client-schema.sql")), - this.clientRegistrationRepository); - + CustomTableDefinitionJdbcOAuth2AuthorizedClientService customAuthorizedClientService = new CustomTableDefinitionJdbcOAuth2AuthorizedClientService( + new JdbcTemplate(createDb("custom-oauth2-client-schema.sql")), this.clientRegistrationRepository); Authentication principal = createPrincipal(); OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); - customAuthorizedClientService.saveAuthorizedClient(authorizedClient, principal); - - authorizedClient = customAuthorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); + authorizedClient = customAuthorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNotNull(); - - customAuthorizedClientService.removeAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); - - authorizedClient = customAuthorizedClientService.loadAuthorizedClient( - this.clientRegistration.getRegistrationId(), principal.getName()); + customAuthorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), + principal.getName()); + authorizedClient = customAuthorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()); assertThat(authorizedClient).isNull(); } @@ -346,19 +361,22 @@ public class JdbcOAuth2AuthorizedClientServiceTests { } private static EmbeddedDatabase createDb(String schema) { + // @formatter:off return new EmbeddedDatabaseBuilder() .generateUniqueName(true) .setType(EmbeddedDatabaseType.HSQL) .setScriptEncoding("UTF-8") .addScript(schema) .build(); + // @formatter:on } private static Authentication createPrincipal() { return new TestingAuthenticationToken("principal-" + principalId++, "password"); } - private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, ClientRegistration clientRegistration) { + private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, + ClientRegistration clientRegistration) { return createAuthorizedClient(principal, clientRegistration, false); } @@ -367,60 +385,59 @@ public class JdbcOAuth2AuthorizedClientServiceTests { OAuth2AccessToken accessToken; if (!requiredAttributesOnly) { accessToken = TestOAuth2AccessTokens.scopes("read", "write"); - } else { + } + else { accessToken = TestOAuth2AccessTokens.noScopes(); } OAuth2RefreshToken refreshToken = null; if (!requiredAttributesOnly) { refreshToken = TestOAuth2RefreshTokens.refreshToken(); } - return new OAuth2AuthorizedClient( - clientRegistration, principal.getName(), accessToken, refreshToken); + return new OAuth2AuthorizedClient(clientRegistration, principal.getName(), accessToken, refreshToken); } - private static class CustomTableDefinitionJdbcOAuth2AuthorizedClientService extends JdbcOAuth2AuthorizedClientService { - private static final String COLUMN_NAMES = - "clientRegistrationId, " + - "principalName, " + - "accessTokenType, " + - "accessTokenValue, " + - "accessTokenIssuedAt, " + - "accessTokenExpiresAt, " + - "accessTokenScopes, " + - "refreshTokenValue, " + - "refreshTokenIssuedAt"; - private static final String TABLE_NAME = "oauth2AuthorizedClient"; - private static final String PK_FILTER = "clientRegistrationId = ? AND principalName = ?"; - private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + - " FROM " + TABLE_NAME + " WHERE " + PK_FILTER; - private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + - " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; - private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + - " WHERE " + PK_FILTER; + private static final class CustomTableDefinitionJdbcOAuth2AuthorizedClientService + extends JdbcOAuth2AuthorizedClientService { - private CustomTableDefinitionJdbcOAuth2AuthorizedClientService( - JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { + private static final String COLUMN_NAMES = "clientRegistrationId, " + "principalName, " + "accessTokenType, " + + "accessTokenValue, " + "accessTokenIssuedAt, " + "accessTokenExpiresAt, " + "accessTokenScopes, " + + "refreshTokenValue, " + "refreshTokenIssuedAt"; + + private static final String TABLE_NAME = "oauth2AuthorizedClient"; + + private static final String PK_FILTER = "clientRegistrationId = ? AND principalName = ?"; + + private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + + private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + " (" + COLUMN_NAMES + + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + + private CustomTableDefinitionJdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations, + ClientRegistrationRepository clientRegistrationRepository) { super(jdbcOperations, clientRegistrationRepository); setAuthorizedClientRowMapper(new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository)); } @Override @SuppressWarnings("unchecked") - public T loadAuthorizedClient(String clientRegistrationId, String principalName) { + public T loadAuthorizedClient(String clientRegistrationId, + String principalName) { SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), - new SqlParameterValue(Types.VARCHAR, principalName) - }; + new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); - List result = this.jdbcOperations.query( - LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); + List result = this.jdbcOperations.query(LOAD_AUTHORIZED_CLIENT_SQL, pss, + this.authorizedClientRowMapper); return !result.isEmpty() ? (T) result.get(0) : null; } @Override public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { - List parameters = this.authorizedClientParametersMapper.apply( - new OAuth2AuthorizedClientHolder(authorizedClient, principal)); + List parameters = this.authorizedClientParametersMapper + .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)); PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); } @@ -429,13 +446,13 @@ public class JdbcOAuth2AuthorizedClientServiceTests { public void removeAuthorizedClient(String clientRegistrationId, String principalName) { SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), - new SqlParameterValue(Types.VARCHAR, principalName) - }; + new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); } - private static class OAuth2AuthorizedClientRowMapper implements RowMapper { + private static final class OAuth2AuthorizedClientRowMapper implements RowMapper { + private final ClientRegistrationRepository clientRegistrationRepository; private OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) { @@ -446,17 +463,15 @@ public class JdbcOAuth2AuthorizedClientServiceTests { @Override public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException { String clientRegistrationId = rs.getString("clientRegistrationId"); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId( - clientRegistrationId); + ClientRegistration clientRegistration = this.clientRegistrationRepository + .findByRegistrationId(clientRegistrationId); if (clientRegistration == null) { - throw new DataRetrievalFailureException("The ClientRegistration with id '" + - clientRegistrationId + "' exists in the data source, " + - "however, it was not found in the ClientRegistrationRepository."); + throw new DataRetrievalFailureException( + "The ClientRegistration with id '" + clientRegistrationId + "' exists in the data source, " + + "however, it was not found in the ClientRegistrationRepository."); } - OAuth2AccessToken.TokenType tokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( - rs.getString("accessTokenType"))) { + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("accessTokenType"))) { tokenType = OAuth2AccessToken.TokenType.BEARER; } String tokenValue = new String(rs.getBytes("accessTokenValue"), StandardCharsets.UTF_8); @@ -467,9 +482,8 @@ public class JdbcOAuth2AuthorizedClientServiceTests { if (accessTokenScopes != null) { scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); } - OAuth2AccessToken accessToken = new OAuth2AccessToken( - tokenType, tokenValue, issuedAt, expiresAt, scopes); - + OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, + scopes); OAuth2RefreshToken refreshToken = null; byte[] refreshTokenValue = rs.getBytes("refreshTokenValue"); if (refreshTokenValue != null) { @@ -481,12 +495,12 @@ public class JdbcOAuth2AuthorizedClientServiceTests { } refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); } - String principalName = rs.getString("principalName"); - - return new OAuth2AuthorizedClient( - clientRegistration, principalName, accessToken, refreshToken); + return new OAuth2AuthorizedClient(clientRegistration, principalName, accessToken, refreshToken); } + } + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index 9749b6ded7..0a385dd806 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -13,17 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.entry; /** * Tests for {@link OAuth2AuthorizationContext}. @@ -31,52 +35,59 @@ import static org.assertj.core.api.Assertions.*; * @author Joe Grandja */ public class OAuth2AuthorizationContextTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; @Before public void setup() { this.clientRegistration = TestClientRegistrations.clientRegistration().build(); - this.authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, "principal", + TestOAuth2AccessTokens.scopes("read", "write")); this.principal = new TestingAuthenticationToken("principal", "password"); } @Test public void withClientRegistrationWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistration cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(null).build()) + .withMessage("clientRegistration cannot be null"); } @Test public void withAuthorizedClientWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient(null).build()) + .withMessage("authorizedClient cannot be null"); } @Test public void withClientRegistrationWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build()) + .withMessage("principal cannot be null"); } @Test public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attributes(attributes -> { + .attributes((attributes) -> { attributes.put("attribute1", "value1"); attributes.put("attribute2", "value2"); }) .build(); + // @formatter:on assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); - assertThat(authorizationContext.getAttributes()).contains( - entry("attribute1", "value1"), entry("attribute2", "value2")); + assertThat(authorizationContext.getAttributes()).contains(entry("attribute1", "value1"), + entry("attribute2", "value2")); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java index 37c7e89eef..588f6283f5 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -24,7 +26,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.entry; /** @@ -33,50 +35,50 @@ import static org.assertj.core.api.Assertions.entry; * @author Joe Grandja */ public class OAuth2AuthorizeRequestTests { + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + private Authentication principal = new TestingAuthenticationToken("principal", "password"); - private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + + private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write"), + TestOAuth2RefreshTokens.refreshToken()); @Test public void withClientRegistrationIdWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); + assertThatIllegalArgumentException().isThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(null)) + .withMessage("clientRegistrationId cannot be empty"); } @Test public void withAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizeRequest.withAuthorizedClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> OAuth2AuthorizeRequest.withAuthorizedClient(null)) + .withMessage("authorizedClient cannot be null"); } @Test public void withClientRegistrationIdWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).build()) + .withMessage("principal cannot be null"); } @Test public void withClientRegistrationIdWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal((String) null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principalName cannot be empty"); + assertThatIllegalArgumentException().isThrownBy(() -> OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal((String) null).build()) + .withMessage("principalName cannot be empty"); } @Test public void withClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attributes(attrs -> { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) + .attributes((attrs) -> { attrs.put("name1", "value1"); attrs.put("name2", "value2"); - }) - .build(); - + }).build(); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); assertThat(authorizeRequest.getAuthorizedClient()).isNull(); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); @@ -86,14 +88,12 @@ public class OAuth2AuthorizeRequestTests { @Test public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attributes(attrs -> { + .principal(this.principal).attributes((attrs) -> { attrs.put("name1", "value1"); attrs.put("name2", "value2"); - }) - .build(); - - assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); + }).build(); + assertThat(authorizeRequest.getClientRegistrationId()) + .isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); assertThat(authorizeRequest.getAttributes()).contains(entry("name1", "value1"), entry("name2", "value2")); @@ -101,12 +101,12 @@ public class OAuth2AuthorizeRequestTests { @Test public void withClientRegistrationIdWhenPrincipalNameProvidedThenPrincipalCreated() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal("principalName") + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal("principalName") .build(); - assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); assertThat(authorizeRequest.getAuthorizedClient()).isNull(); assertThat(authorizeRequest.getPrincipal().getName()).isEqualTo("principalName"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java index 338f929eaf..f8a9cdf4de 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2AuthorizedClientId}. @@ -29,16 +30,14 @@ public class OAuth2AuthorizedClientIdTests { @Test public void constructorWhenRegistrationIdNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientId(null, "test-principal")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizedClientId(null, "test-principal")) + .withMessage("clientRegistrationId cannot be empty"); } @Test public void constructorWhenPrincipalNameNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientId("test-client", null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principalName cannot be empty"); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizedClientId("test-client", null)) + .withMessage("principalName cannot be empty"); } @Test @@ -82,4 +81,5 @@ public class OAuth2AuthorizedClientIdTests { OAuth2AuthorizedClientId id2 = new OAuth2AuthorizedClientId("test-client", "test-principal2"); assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java index d2ea51f1b9..3275de9195 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpStatus; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; @@ -33,12 +38,15 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.web.client.RestOperations; -import java.time.Duration; -import java.time.Instant; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2AuthorizedClientProviderBuilder}. @@ -46,10 +54,15 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class OAuth2AuthorizedClientProviderBuilderTests { + private RestOperations accessTokenClient; + private DefaultClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient; + private DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient; + private DefaultPasswordTokenResponseClient passwordTokenResponseClient; + private Authentication principal; @SuppressWarnings("unchecked") @@ -57,8 +70,8 @@ public class OAuth2AuthorizedClientProviderBuilderTests { public void setup() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); this.accessTokenClient = mock(RestOperations.class); - when(this.accessTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) - .thenReturn(new ResponseEntity(accessTokenResponse, HttpStatus.OK)); + given(this.accessTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .willReturn(new ResponseEntity(accessTokenResponse, HttpStatus.OK)); this.refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); this.refreshTokenTokenResponseClient.setRestOperations(this.accessTokenClient); this.clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); @@ -70,161 +83,148 @@ public class OAuth2AuthorizedClientProviderBuilderTests { @Test public void providerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizedClientProviderBuilder.builder().provider(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizedClientProviderBuilder.builder().provider(null)); } @Test public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientRegistration().build()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(ClientAuthorizationRequiredException.class); + // @formatter:off + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .build(); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(TestClientRegistrations.clientRegistration().build()) + .principal(this.principal) + .build(); + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> authorizedClientProvider.authorize(authorizationContext)); } @Test public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) - .build(); - + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .refreshToken( + (configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .build(); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - TestClientRegistrations.clientRegistration().build(), - this.principal.getName(), - expiredAccessToken(), + TestClientRegistrations.clientRegistration().build(), this.principal.getName(), expiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext); - assertThat(reauthorizedClient).isNotNull(); verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); } @Test public void buildWhenClientCredentialsProviderThenProviderAuthorizes() { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientCredentials().build()) - .principal(this.principal) - .build(); + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .clientCredentials( + (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(TestClientRegistrations.clientCredentials().build()) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient).isNotNull(); verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); } @Test public void buildWhenPasswordProviderThenProviderAuthorizes() { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .password(configurer -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.password().build()) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + // @formatter:off + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) + .build(); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(TestClientRegistrations.password().build()) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient).isNotNull(); verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); } @Test public void buildWhenAllProvidersThenProvidersAuthorize() { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) - .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) - .password(configurer -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) - .build(); - + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken( + (configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials( + (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) + .build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - - // authorization_code - OAuth2AuthorizationContext authorizationCodeContext = - OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext)) - .isInstanceOf(ClientAuthorizationRequiredException.class); - - + // @formatter:off + OAuth2AuthorizationContext authorizationCodeContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext)); // refresh_token - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, - this.principal.getName(), - expiredAccessToken(), - TestOAuth2RefreshTokens.refreshToken()); - - OAuth2AuthorizationContext refreshTokenContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + this.principal.getName(), expiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); + OAuth2AuthorizationContext refreshTokenContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient).principal(this.principal).build(); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext); - assertThat(reauthorizedClient).isNotNull(); - verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); - - + verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class), + eq(OAuth2AccessTokenResponse.class)); // client_credentials - OAuth2AuthorizationContext clientCredentialsContext = - OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientCredentials().build()) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext + .withClientRegistration(TestClientRegistrations.clientCredentials().build()) + .principal(this.principal) + .build(); + // @formatter:on authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext); - assertThat(authorizedClient).isNotNull(); - verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); - + verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class), + eq(OAuth2AccessTokenResponse.class)); // password - OAuth2AuthorizationContext passwordContext = - OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.password().build()) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + // @formatter:off + OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext + .withClientRegistration(TestClientRegistrations.password().build()) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on authorizedClient = authorizedClientProvider.authorize(passwordContext); - assertThat(authorizedClient).isNotNull(); - verify(this.accessTokenClient, times(3)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + verify(this.accessTokenClient, times(3)).exchange(any(RequestEntity.class), + eq(OAuth2AccessTokenResponse.class)); } @Test public void buildWhenCustomProviderThenProviderCalled() { OAuth2AuthorizedClientProvider customProvider = mock(OAuth2AuthorizedClientProvider.class); - - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .provider(customProvider) - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientRegistration().build()) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .provider(customProvider) + .build(); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(TestClientRegistrations.clientRegistration().build()) + .principal(this.principal) + .build(); + // @formatter:on authorizedClientProvider.authorize(authorizationContext); - verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); } @@ -233,4 +233,5 @@ public class OAuth2AuthorizedClientProviderBuilderTests { Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java index 64925ee25d..a91d541770 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; /** * Tests for {@link OAuth2AuthorizedClient}. @@ -31,15 +32,18 @@ import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.no * @author Joe Grandja */ public class OAuth2AuthorizedClientTests { + private ClientRegistration clientRegistration; + private String principalName; + private OAuth2AccessToken accessToken; @Before public void setUp() { - this.clientRegistration = clientRegistration().build(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.principalName = "principal"; - this.accessToken = noScopes(); + this.accessToken = TestOAuth2AccessTokens.noScopes(); } @Test(expected = IllegalArgumentException.class) @@ -59,11 +63,11 @@ public class OAuth2AuthorizedClientTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principalName, this.accessToken); - + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName); assertThat(authorizedClient.getAccessToken()).isEqualTo(this.accessToken); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProviderTests.java index 8e2dcb1182..10d375e4b6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProviderTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -28,14 +33,11 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import java.time.Duration; -import java.time.Instant; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link PasswordOAuth2AuthorizedClientProvider}. @@ -43,9 +45,13 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class PasswordOAuth2AuthorizedClientProviderTests { + private PasswordOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; @Before @@ -59,85 +65,98 @@ public class PasswordOAuth2AuthorizedClientProviderTests { @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessTokenResponseClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); } @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew must be >= 0"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClock(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clock cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .withMessage("context cannot be null"); + // @formatter:on } @Test public void authorizeWhenNotPasswordThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenPasswordAndNotAuthorizedAndEmptyUsernameThenUnableToAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null) - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null) + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenPasswordAndNotAuthorizedAndEmptyPasswordThenUnableToAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenPasswordAndNotAuthorizedThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -147,44 +166,44 @@ public class PasswordOAuth2AuthorizedClientProviderTests { public void authorizeWhenPasswordAndAuthorizedWithoutRefreshTokenAndTokenExpiredThenReauthorize() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-expired", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), accessToken); // without refresh token - + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-expired", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); // without refresh token OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .principal(this.principal) - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .principal(this.principal) + .build(); + // @formatter:on authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - } @Test public void authorizeWhenPasswordAndAuthorizedWithRefreshTokenAndTokenExpiredThenNotReauthorize() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-expired", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - accessToken, TestOAuth2RefreshTokens.refreshToken()); // with refresh token - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .principal(this.principal) - .build(); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-expired", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken, TestOAuth2RefreshTokens.refreshToken()); // with + // refresh + // token + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @@ -194,28 +213,28 @@ public class PasswordOAuth2AuthorizedClientProviderTests { Instant now = Instant.now(); Instant issuedAt = now.minus(Duration.ofMinutes(60)); Instant expiresAt = now.minus(Duration.ofMinutes(1)); - OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), expiresInOneMinAccessToken); // without refresh token - - // Shorten the lifespan of the access token by 90 seconds, which will ultimately force it to expire on the client + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); // without refresh + // token + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .principal(this.principal) - .build(); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProviderTests.java index 9fefec4b1f..184ce64851 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProviderTests.java @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; @@ -27,16 +33,12 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link PasswordReactiveOAuth2AuthorizedClientProvider}. @@ -44,9 +46,13 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class PasswordReactiveOAuth2AuthorizedClientProviderTests { + private PasswordReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; @Before @@ -60,85 +66,98 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests { @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessTokenResponseClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); } @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew must be >= 0"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClock(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clock cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .withMessage("context cannot be null"); + // @formatter:on } @Test public void authorizeWhenNotPasswordThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal). + build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenPasswordAndNotAuthorizedAndEmptyUsernameThenUnableToAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null) - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null) + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenPasswordAndNotAuthorizedAndEmptyPasswordThenUnableToAuthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenPasswordAndNotAuthorizedThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -148,44 +167,44 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests { public void authorizeWhenPasswordAndAuthorizedWithoutRefreshTokenAndTokenExpiredThenReauthorize() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-expired", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), accessToken); // without refresh token - + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-expired", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); // without refresh token OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .principal(this.principal) - .build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .principal(this.principal) + .build(); + // @formatter:on authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - } @Test public void authorizeWhenPasswordAndAuthorizedWithRefreshTokenAndTokenExpiredThenNotReauthorize() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-expired", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - accessToken, TestOAuth2RefreshTokens.refreshToken()); // with refresh token - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .principal(this.principal) - .build(); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-expired", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken, TestOAuth2RefreshTokens.refreshToken()); // with + // refresh + // token + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @@ -195,28 +214,29 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests { Instant now = Instant.now(); Instant issuedAt = now.minus(Duration.ofMinutes(60)); Instant expiresAt = now.minus(Duration.ofMinutes(1)); - OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), expiresInOneMinAccessToken); // without refresh token - - // Shorten the lifespan of the access token by 90 seconds, which will ultimately force it to expire on the client + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); // without refresh + // token + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .principal(this.principal) - .build(); - - OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java index be877f2ce0..71280097c5 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilderTests.java @@ -13,14 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -30,14 +36,14 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link ReactiveOAuth2AuthorizedClientProviderBuilder}. @@ -45,8 +51,11 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { + private ClientRegistration.Builder clientRegistrationBuilder; + private Authentication principal; + private MockWebServer server; @Before @@ -65,55 +74,48 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { @Test public void providerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> ReactiveOAuth2AuthorizedClientProviderBuilder.builder().provider(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> ReactiveOAuth2AuthorizedClientProviderBuilder.builder().provider(null)); } @Test public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext).block()) - .isInstanceOf(ClientAuthorizationRequiredException.class); + // @formatter:off + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder() + .authorizationCode() + .build(); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistrationBuilder.build()) + .principal(this.principal) + .build(); + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> authorizedClientProvider.authorize(authorizationContext).block()); } @Test public void buildWhenRefreshTokenProviderThenProviderReauthorizes() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .refreshToken() - .build(); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistrationBuilder.build(), - this.principal.getName(), - expiredAccessToken(), - TestOAuth2RefreshTokens.refreshToken()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + // @formatter:off + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder() + .refreshToken() + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), + this.principal.getName(), expiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); - assertThat(reauthorizedClient).isNotNull(); - assertThat(this.server.getRequestCount()).isEqualTo(1); - RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=refresh_token"); @@ -121,28 +123,23 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { @Test public void buildWhenClientCredentialsProviderThenProviderAuthorizes() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .clientCredentials() - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) - .principal(this.principal) - .build(); + // @formatter:off + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder() + .clientCredentials() + .build(); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistrationBuilder + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); - assertThat(authorizedClient).isNotNull(); - assertThat(this.server.getRequestCount()).isEqualTo(1); - RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=client_credentials"); @@ -150,30 +147,24 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { @Test public void buildWhenPasswordProviderThenProviderAuthorizes() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .password() - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build()) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); - OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); - + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder().password().build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration( + this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build()) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext) + .block(); + // @formatter:on assertThat(authorizedClient).isNotNull(); - assertThat(this.server.getRequestCount()).isEqualTo(1); - RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=password"); @@ -181,82 +172,60 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { @Test public void buildWhenAllProvidersThenProvidersAuthorize() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .password() - .build(); - + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder().authorizationCode().refreshToken().clientCredentials().password().build(); // authorization_code - OAuth2AuthorizationContext authorizationCodeContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext).block()) - .isInstanceOf(ClientAuthorizationRequiredException.class); - - + // @formatter:off + OAuth2AuthorizationContext authorizationCodeContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistrationBuilder.build()) + .principal(this.principal) + .build(); + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext).block()); // refresh_token - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistrationBuilder.build(), - this.principal.getName(), - expiredAccessToken(), - TestOAuth2RefreshTokens.refreshToken()); - - OAuth2AuthorizationContext refreshTokenContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), + this.principal.getName(), expiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); + OAuth2AuthorizationContext refreshTokenContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient).principal(this.principal).build(); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext).block(); - assertThat(reauthorizedClient).isNotNull(); - assertThat(this.server.getRequestCount()).isEqualTo(1); - RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=refresh_token"); - - // client_credentials - OAuth2AuthorizationContext clientCredentialsContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistrationBuilder + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) + .principal(this.principal) + .build(); + // @formatter:on authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext).block(); - assertThat(authorizedClient).isNotNull(); - assertThat(this.server.getRequestCount()).isEqualTo(2); - recordedRequest = this.server.takeRequest(); formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=client_credentials"); - // password - OAuth2AuthorizationContext passwordContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build()) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") - .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") - .build(); + // @formatter:off + OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext + .withClientRegistration( + this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build()) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") + .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password") + .build(); + // @formatter:on authorizedClient = authorizedClientProvider.authorize(passwordContext).block(); - assertThat(authorizedClient).isNotNull(); - assertThat(this.server.getRequestCount()).isEqualTo(3); - recordedRequest = this.server.takeRequest(); formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=password"); @@ -265,19 +234,18 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { @Test public void buildWhenCustomProviderThenProviderCalled() { ReactiveOAuth2AuthorizedClientProvider customProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(customProvider.authorize(any())).thenReturn(Mono.empty()); - - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .provider(customProvider) - .build(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistrationBuilder.build()) - .principal(this.principal) - .build(); + given(customProvider.authorize(any())).willReturn(Mono.empty()); + // @formatter:off + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder() + .provider(customProvider) + .build(); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistrationBuilder.build()) + .principal(this.principal) + .build(); + // @formatter:on authorizedClientProvider.authorize(authorizationContext).block(); - verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); } @@ -288,8 +256,7 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests { } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 3b373000b3..e1a63fc2d2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -13,11 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -30,14 +37,12 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashSet; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link RefreshTokenOAuth2AuthorizedClientProvider}. @@ -45,10 +50,15 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class RefreshTokenOAuth2AuthorizedClientProviderTests { + private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; @Before @@ -60,65 +70,78 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { this.principal = new TestingAuthenticationToken("principal", "password"); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), expiredAccessToken, TestOAuth2RefreshTokens.refreshToken()); } @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessTokenResponseClient cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew must be >= 0"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClock(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clock cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .withMessage("context cannot be null"); + // @formatter:on } @Test public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), this.authorizedClient.getAccessToken()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @@ -126,11 +149,12 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @@ -138,28 +162,25 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenAuthorizedAndAccessTokenNotExpiredButClockSkewForcesExpiryThenReauthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() - .refreshToken("new-refresh-token") - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + .refreshToken("new-refresh-token").build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); Instant now = Instant.now(); Instant issuedAt = now.minus(Duration.ofMinutes(60)); Instant expiresAt = now.minus(Duration.ofMinutes(1)); - OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), - expiresInOneMinAccessToken, this.authorizedClient.getRefreshToken()); - - // Shorten the lifespan of the access token by 90 seconds, which will ultimately force it to expire on the client + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken, this.authorizedClient.getRefreshToken()); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); - + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -168,18 +189,20 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse() .refreshToken("new-refresh-token") .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); - + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -188,38 +211,43 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse() .refreshToken("new-refresh-token") .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); String[] requestScope = new String[] { "read", "write" }; - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) - .build(); - + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) + .build(); + // @formatter:on this.authorizedClientProvider.authorize(authorizationContext); - - ArgumentCaptor refreshTokenGrantRequestArgCaptor = - ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); + ArgumentCaptor refreshTokenGrantRequestArgCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); - assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(new HashSet<>(Arrays.asList(requestScope))); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()) + .isEqualTo(new HashSet<>(Arrays.asList(requestScope))); } @Test public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { String invalidRequestScope = "read write"; - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) - .build(); - - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageStartingWith("The context attribute must be of type String[] '" + - OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) + .build(); + // @formatter:on + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .withMessageStartingWith("The context attribute must be of type String[] '" + + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java index 6ba450f15c..504c9eac2e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; @@ -29,16 +37,13 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashSet; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}. @@ -46,10 +51,15 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { + private RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; @Before @@ -61,65 +71,76 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { this.principal = new TestingAuthenticationToken("principal", "password"); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); - OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), expiredAccessToken, TestOAuth2RefreshTokens.refreshToken()); } @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessTokenResponseClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); } @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clockSkew must be >= 0"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.setClock(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clock cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on } @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .withMessage("context cannot be null"); + // @formatter:on } @Test public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @Test public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), this.authorizedClient.getAccessToken()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @@ -127,11 +148,12 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); } @@ -139,28 +161,22 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenAuthorizedAndAccessTokenNotExpiredButClockSkewForcesExpiryThenReauthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() - .refreshToken("new-refresh-token") - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - + .refreshToken("new-refresh-token").build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); Instant now = Instant.now(); Instant issuedAt = now.minus(Duration.ofMinutes(60)); Instant expiresAt = now.minus(Duration.ofMinutes(1)); - OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), - expiresInOneMinAccessToken, this.authorizedClient.getRefreshToken()); - - // Shorten the lifespan of the access token by 90 seconds, which will ultimately force it to expire on the client + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken, this.authorizedClient.getRefreshToken()); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(this.principal) - .build(); - - OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient).principal(this.principal).build(); + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -170,17 +186,12 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() - .refreshToken("new-refresh-token") - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); - - OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); - + .refreshToken("new-refresh-token").build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient).principal(this.principal).build(); + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); @@ -190,37 +201,38 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { @Test public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() - .refreshToken("new-refresh-token") - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - + .refreshToken("new-refresh-token").build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); String[] requestScope = new String[] { "read", "write" }; - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) - .build(); - + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) + .build(); + // @formatter:on this.authorizedClientProvider.authorize(authorizationContext).block(); - - ArgumentCaptor refreshTokenGrantRequestArgCaptor = - ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); + ArgumentCaptor refreshTokenGrantRequestArgCaptor = ArgumentCaptor + .forClass(OAuth2RefreshTokenGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); - assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(new HashSet<>(Arrays.asList(requestScope))); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()) + .isEqualTo(new HashSet<>(Arrays.asList(requestScope))); } @Test public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { String invalidRequestScope = "read write"; - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) - .build(); - - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageStartingWith("The context attribute must be of type String[] '" + - OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) + .build(); + // @formatter:on + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) + .withMessageStartingWith("The context attribute must be of type String[] '" + + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationTokenTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationTokenTests.java index 8e0823fe76..c250c84f06 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationTokenTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationTokenTests.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.authentication; -import org.junit.Before; -import org.junit.Test; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.oauth2.core.user.OAuth2User; +package org.springframework.security.oauth2.client.authentication; import java.util.Collection; import java.util.Collections; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.oauth2.core.user.OAuth2User; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; @@ -32,8 +34,11 @@ import static org.mockito.Mockito.mock; * @author Joe Grandja */ public class OAuth2AuthenticationTokenTests { + private OAuth2User principal; + private Collection authorities; + private String authorizedClientRegistrationId; @Before @@ -65,13 +70,13 @@ public class OAuth2AuthenticationTokenTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( - this.principal, this.authorities, this.authorizedClientRegistrationId); - + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(this.principal, this.authorities, + this.authorizedClientRegistrationId); assertThat(authentication.getPrincipal()).isEqualTo(this.principal); assertThat(authentication.getCredentials()).isEqualTo(""); assertThat(authentication.getAuthorities()).isEqualTo(this.authorities); assertThat(authentication.getAuthorizedClientRegistrationId()).isEqualTo(this.authorizedClientRegistrationId); assertThat(authentication.isAuthenticated()).isEqualTo(true); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 41ebe4a1e6..2ff9683ef4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; import java.util.Collections; @@ -25,23 +26,23 @@ import org.junit.Test; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}. @@ -49,24 +50,27 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class OAuth2AuthorizationCodeAuthenticationProviderTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationRequest authorizationRequest; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; @Before @SuppressWarnings("unchecked") public void setUp() { - this.clientRegistration = clientRegistration().build(); - this.authorizationRequest = request().build(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizationRequest = TestOAuth2AuthorizationRequests.request().build(); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient); } @Test public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null)); } @Test @@ -76,41 +80,38 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthorizationException() { - OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( - this.authorizationRequest, authorizationResponse); - - assertThatThrownBy(() -> { - this.authenticationProvider.authenticate( - new OAuth2AuthorizationCodeAuthenticationToken( - this.clientRegistration, authorizationExchange)); - }).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() + .errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + authorizationResponse); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange))) + .withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); } @Test public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthorizationException() { - OAuth2AuthorizationResponse authorizationResponse = success().state("67890").build(); - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( - this.authorizationRequest, authorizationResponse); - - assertThatThrownBy(() -> { - this.authenticationProvider.authenticate( - new OAuth2AuthorizationCodeAuthenticationToken( - this.clientRegistration, authorizationExchange)); - }).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining("invalid_state_parameter"); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890") + .build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + authorizationResponse); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange))) + .withMessageContaining("invalid_state_parameter"); } @Test public void authenticateWhenAuthorizationSuccessResponseThenExchangedForAccessToken() { - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().refreshToken("refresh").build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( - this.authorizationRequest, success().build()); - OAuth2AuthorizationCodeAuthenticationToken authenticationResult = - (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange)); - + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("refresh").build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + TestOAuth2AuthorizationResponses.success().build()); + OAuth2AuthorizationCodeAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider + .authenticate( + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange)); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); assertThat(authenticationResult.getCredentials()) @@ -128,19 +129,16 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - - OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().additionalParameters(additionalParameters) - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .additionalParameters(additionalParameters).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, - success().build()); - + TestOAuth2AuthorizationResponses.success().build()); OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider .authenticate( new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange)); - assertThat(authentication.getAdditionalParameters()) .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java index e153e63e56..b96992ea85 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; import java.util.Collections; @@ -21,15 +22,15 @@ import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}. @@ -37,37 +38,40 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class OAuth2AuthorizationCodeAuthenticationTokenTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessToken accessToken; @Before public void setUp() { - this.clientRegistration = clientRegistration().build(); - this.authorizationExchange = new OAuth2AuthorizationExchange(request().build(), - success().code("code").build()); - this.accessToken = noScopes(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizationExchange = new OAuth2AuthorizationExchange(TestOAuth2AuthorizationRequests.request().build(), + TestOAuth2AuthorizationResponses.success().code("code").build()); + this.accessToken = TestOAuth2AccessTokens.noScopes(); } @Test public void constructorAuthorizationRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.authorizationExchange)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.authorizationExchange)); } @Test public void constructorAuthorizationRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, null)); } @Test public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() { - OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange); - + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + this.clientRegistration, this.authorizationExchange); assertThat(authentication.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); - assertThat(authentication.getCredentials()).isEqualTo(this.authorizationExchange.getAuthorizationResponse().getCode()); + assertThat(authentication.getCredentials()) + .isEqualTo(this.authorizationExchange.getAuthorizationResponse().getCode()); assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList()); assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); @@ -77,27 +81,27 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { @Test public void constructorTokenRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.authorizationExchange, this.accessToken)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, + this.authorizationExchange, this.accessToken)); } @Test public void constructorTokenRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, null, this.accessToken)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, null, this.accessToken)); } @Test public void constructorTokenRequestResponseWhenAccessTokenIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, + this.authorizationExchange, null)); } @Test public void constructorTokenRequestResponseWhenAllParametersProvidedAndValidThenCreated() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( - this.clientRegistration, this.authorizationExchange, this.accessToken); - + this.clientRegistration, this.authorizationExchange, this.accessToken); assertThat(authentication.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); assertThat(authentication.getCredentials()).isEqualTo(this.accessToken.getTokenValue()); assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList()); @@ -106,4 +110,5 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken); assertThat(authentication.isAuthenticated()).isEqualTo(true); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManagerTests.java index f375342212..91731b0642 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManagerTests.java @@ -38,9 +38,9 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -48,6 +48,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizationCodeReactiveAuthenticationManagerTests { + @Mock private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; @@ -59,8 +60,7 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManagerTests { private OAuth2AuthorizationResponse.Builder authorizationResponse = TestOAuth2AuthorizationResponses.success(); - private OAuth2AccessTokenResponse.Builder tokenResponse = TestOAuth2AccessTokenResponses - .accessTokenResponse(); + private OAuth2AccessTokenResponse.Builder tokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse(); @Before public void setup() { @@ -70,48 +70,42 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManagerTests { @Test public void authenticateWhenErrorThenOAuth2AuthorizationException() { this.authorizationResponse = TestOAuth2AuthorizationResponses.error(); - assertThatCode(() -> authenticate()) - .isInstanceOf(OAuth2AuthorizationException.class); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> authenticate()); } @Test public void authenticateWhenStateNotEqualThenOAuth2AuthorizationException() { this.authorizationRequest.state("notequal"); - assertThatCode(() -> authenticate()) - .isInstanceOf(OAuth2AuthorizationException.class); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> authenticate()); } @Test public void authenticateWhenValidThenSuccess() { - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(this.tokenResponse.build())); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(this.tokenResponse.build())); OAuth2AuthorizationCodeAuthenticationToken result = authenticate(); - assertThat(result).isNotNull(); } @Test public void authenticateWhenEmptyThenEmpty() { - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.empty()); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.empty()); OAuth2AuthorizationCodeAuthenticationToken result = authenticate(); - assertThat(result).isNull(); } @Test public void authenticateWhenOAuth2AuthorizationExceptionThenOAuth2AuthorizationException() { - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.error(() -> new OAuth2AuthorizationException(new OAuth2Error("error")))); - - assertThatCode(() -> authenticate()) - .isInstanceOf(OAuth2AuthorizationException.class); + given(this.accessTokenResponseClient.getTokenResponse(any())) + .willReturn(Mono.error(() -> new OAuth2AuthorizationException(new OAuth2Error("error")))); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> authenticate()); } private OAuth2AuthorizationCodeAuthenticationToken authenticate() { - OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange( - this.authorizationRequest.build(), this.authorizationResponse.build()); + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(this.authorizationRequest.build(), + this.authorizationResponse.build()); OAuth2AuthorizationCodeAuthenticationToken token = new OAuth2AuthorizationCodeAuthenticationToken( this.registration.build(), exchange); return (OAuth2AuthorizationCodeAuthenticationToken) this.manager.authenticate(token).block(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java index 8a2aaa10d1..cbe7c6b05c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; import java.time.Instant; @@ -36,6 +37,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -45,18 +47,16 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.oauth2.core.user.OAuth2User; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2LoginAuthenticationProvider}. @@ -64,12 +64,19 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class OAuth2LoginAuthenticationProviderTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationRequest authorizationRequest; + private OAuth2AuthorizationResponse authorizationResponse; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private OAuth2UserService userService; + private OAuth2LoginAuthenticationProvider authenticationProvider; @Rule @@ -78,13 +85,15 @@ public class OAuth2LoginAuthenticationProviderTests { @Before @SuppressWarnings("unchecked") public void setUp() { - this.clientRegistration = clientRegistration().build(); - this.authorizationRequest = request().scope("scope1", "scope2").build(); - this.authorizationResponse = success().build(); - this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizationRequest = TestOAuth2AuthorizationRequests.request().scope("scope1", "scope2").build(); + this.authorizationResponse = TestOAuth2AuthorizationResponses.success().build(); + this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + this.authorizationResponse); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.userService = mock(OAuth2UserService.class); - this.authenticationProvider = new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, this.userService); + this.authenticationProvider = new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, + this.userService); } @Test @@ -112,14 +121,12 @@ public class OAuth2LoginAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() { - OAuth2AuthorizationRequest authorizationRequest = request().scope("openid").build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse); - - OAuth2LoginAuthenticationToken authentication = - (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().scope("openid") + .build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + this.authorizationResponse); + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); assertThat(authentication).isNull(); } @@ -127,45 +134,36 @@ public class OAuth2LoginAuthenticationProviderTests { public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST)); - - OAuth2AuthorizationResponse authorizationResponse = - error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() + .errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + authorizationResponse); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_state_parameter")); - - OAuth2AuthorizationResponse authorizationResponse = - success().state("67890").build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890") + .build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + authorizationResponse); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test public void authenticateWhenLoginSuccessThenReturnAuthentication() { OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User principal = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(principal.getAuthorities()).thenAnswer( - (Answer>) invocation -> authorities); - when(this.userService.loadUser(any())).thenReturn(principal); - - OAuth2LoginAuthenticationToken authentication = - (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - + given(principal.getAuthorities()).willAnswer((Answer>) (invocation) -> authorities); + given(this.userService.loadUser(any())).willReturn(principal); + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); assertThat(authentication.isAuthenticated()).isTrue(); assertThat(authentication.getPrincipal()).isEqualTo(principal); assertThat(authentication.getCredentials()).isEqualTo(""); @@ -179,24 +177,18 @@ public class OAuth2LoginAuthenticationProviderTests { @Test public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User principal = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(principal.getAuthorities()).thenAnswer( - (Answer>) invocation -> authorities); - when(this.userService.loadUser(any())).thenReturn(principal); - + given(principal.getAuthorities()).willAnswer((Answer>) (invocation) -> authorities); + given(this.userService.loadUser(any())).willReturn(principal); List mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OAUTH2_USER"); GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class); - when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer( - (Answer>) invocation -> mappedAuthorities); + given(authoritiesMapper.mapAuthorities(anyCollection())) + .willAnswer((Answer>) (invocation) -> mappedAuthorities); this.authenticationProvider.setAuthoritiesMapper(authoritiesMapper); - - OAuth2LoginAuthenticationToken authentication = - (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities); } @@ -204,20 +196,16 @@ public class OAuth2LoginAuthenticationProviderTests { @Test public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() { OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); OAuth2User principal = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(principal.getAuthorities()).thenAnswer( - (Answer>) invocation -> authorities); + given(principal.getAuthorities()).willAnswer((Answer>) (invocation) -> authorities); ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class); - when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - - assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf( - accessTokenResponse.getAdditionalParameters()); + given(this.userService.loadUser(userRequestArgCaptor.capture())).willReturn(principal); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()) + .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); } private OAuth2AccessTokenResponse accessTokenSuccessResponse() { @@ -226,15 +214,15 @@ public class OAuth2LoginAuthenticationProviderTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - - return OAuth2AccessTokenResponse - .withToken("access-token-1234") + // @formatter:off + return OAuth2AccessTokenResponse.withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(expiresAt.getEpochSecond()) .scopes(scopes) .refreshToken("refresh-token-1234") .additionalParameters(additionalParameters) .build(); - + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java index 66feaa0838..cb83eef68b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.authentication; import java.util.Collection; @@ -23,16 +24,16 @@ import org.junit.Test; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.oauth2.core.user.OAuth2User; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2LoginAuthenticationToken}. @@ -40,20 +41,25 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class OAuth2LoginAuthenticationTokenTests { + private OAuth2User principal; + private Collection authorities; + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessToken accessToken; @Before public void setUp() { this.principal = mock(OAuth2User.class); this.authorities = Collections.emptyList(); - this.clientRegistration = clientRegistration().build(); - this.authorizationExchange = new OAuth2AuthorizationExchange( - request().build(), success().code("code").build()); - this.accessToken = noScopes(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizationExchange = new OAuth2AuthorizationExchange(TestOAuth2AuthorizationRequests.request().build(), + TestOAuth2AuthorizationResponses.success().code("code").build()); + this.accessToken = TestOAuth2AccessTokens.noScopes(); } @Test(expected = IllegalArgumentException.class) @@ -68,9 +74,8 @@ public class OAuth2LoginAuthenticationTokenTests { @Test public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() { - OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken( - this.clientRegistration, this.authorizationExchange); - + OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken(this.clientRegistration, + this.authorizationExchange); assertThat(authentication.getPrincipal()).isNull(); assertThat(authentication.getCredentials()).isEqualTo(""); assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList()); @@ -82,45 +87,44 @@ public class OAuth2LoginAuthenticationTokenTests { @Test(expected = IllegalArgumentException.class) public void constructorTokenRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - new OAuth2LoginAuthenticationToken(null, this.authorizationExchange, this.principal, - this.authorities, this.accessToken); + new OAuth2LoginAuthenticationToken(null, this.authorizationExchange, this.principal, this.authorities, + this.accessToken); } @Test(expected = IllegalArgumentException.class) public void constructorTokenRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() { - new OAuth2LoginAuthenticationToken(this.clientRegistration, null, this.principal, - this.authorities, this.accessToken); + new OAuth2LoginAuthenticationToken(this.clientRegistration, null, this.principal, this.authorities, + this.accessToken); } @Test(expected = IllegalArgumentException.class) public void constructorTokenRequestResponseWhenPrincipalIsNullThenThrowIllegalArgumentException() { - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, null, - this.authorities, this.accessToken); + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, null, this.authorities, + this.accessToken); } @Test public void constructorTokenRequestResponseWhenAuthoritiesIsNullThenCreated() { - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, - this.principal, null, this.accessToken); + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, this.principal, null, + this.accessToken); } @Test public void constructorTokenRequestResponseWhenAuthoritiesIsEmptyThenCreated() { - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, - this.principal, Collections.emptyList(), this.accessToken); + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, this.principal, + Collections.emptyList(), this.accessToken); } @Test(expected = IllegalArgumentException.class) public void constructorTokenRequestResponseWhenAccessTokenIsNullThenThrowIllegalArgumentException() { new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, this.principal, - this.authorities, null); + this.authorities, null); } @Test public void constructorTokenRequestResponseWhenAllParametersProvidedAndValidThenCreated() { - OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken( - this.clientRegistration, this.authorizationExchange, this.principal, this.authorities, this.accessToken); - + OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken(this.clientRegistration, + this.authorizationExchange, this.principal, this.authorities, this.accessToken); assertThat(authentication.getPrincipal()).isEqualTo(this.principal); assertThat(authentication.getCredentials()).isEqualTo(""); assertThat(authentication.getAuthorities()).isEqualTo(this.authorities); @@ -129,4 +133,5 @@ public class OAuth2LoginAuthenticationTokenTests { assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken); assertThat(authentication.isAuthenticated()).isEqualTo(true); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java index dd27077582..06eb250723 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java @@ -16,14 +16,6 @@ package org.springframework.security.oauth2.client.authentication; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyCollection; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -37,6 +29,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; @@ -57,8 +51,13 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; -import reactor.core.publisher.Mono; - +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; /** * @author Rob Winch @@ -66,6 +65,7 @@ import reactor.core.publisher.Mono; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2LoginReactiveAuthenticationManagerTests { + @Mock private ReactiveOAuth2UserService userService; @@ -77,8 +77,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); - OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse - .success("code") + OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse.success("code") .state("state"); private OAuth2LoginReactiveAuthenticationManager manager; @@ -91,33 +90,29 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { @Test public void constructorWhenNullAccessTokenResponseClientThenIllegalArgumentException() { this.accessTokenResponseClient = null; - assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)); } @Test public void constructorWhenNullUserServiceThenIllegalArgumentException() { this.userService = null; - assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)); } @Test public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.manager.setAuthoritiesMapper(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.manager.setAuthoritiesMapper(null)); } @Test public void authenticateWhenNoSubscriptionThenDoesNothing() { - // we didn't do anything because it should cause a ClassCastException (as verified below) + // we didn't do anything because it should cause a ClassCastException (as verified + // below) TestingAuthenticationToken token = new TestingAuthenticationToken("a", "b"); - - assertThatCode(()-> this.manager.authenticate(token)) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.manager.authenticate(token).block()) - .isInstanceOf(Throwable.class); + this.manager.authenticate(token); + assertThatExceptionOfType(Throwable.class).isThrownBy(() -> this.manager.authenticate(token).block()); } @Test @@ -129,41 +124,41 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { @Test public void authenticationWhenErrorThenOAuth2AuthenticationException() { + // @formatter:off this.authorizationResponseBldr = OAuth2AuthorizationResponse .error("error") .state("state"); - assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + // @formatter:on + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(loginToken()).block()); } @Test public void authenticationWhenStateDoesNotMatchThenOAuth2AuthenticationException() { this.authorizationResponseBldr.state("notmatch"); - assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(loginToken()).block()); } @Test public void authenticationWhenOAuth2UserNotFoundThenEmpty() { OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - when(this.userService.loadUser(any())).thenReturn(Mono.empty()); + .tokenType(OAuth2AccessToken.TokenType.BEARER).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + given(this.userService.loadUser(any())).willReturn(Mono.empty()); assertThat(this.manager.authenticate(loginToken()).block()).isNull(); } @Test public void authenticationWhenOAuth2UserFoundThenSuccess() { OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user"); - when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); - - OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block(); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); + given(this.userService.loadUser(any())).willReturn(Mono.just(user)); + OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()) + .block(); assertThat(result.getPrincipal()).isEqualTo(user); assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities()); assertThat(result.isAuthenticated()).isTrue(); @@ -176,16 +171,13 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(additionalParameters) - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user"); + .tokenType(OAuth2AccessToken.TokenType.BEARER).additionalParameters(additionalParameters).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class); - when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user)); - + given(this.userService.loadUser(userRequestArgCaptor.capture())).willReturn(Mono.just(user)); this.manager.authenticate(loginToken()).block(); - assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()) .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); } @@ -193,36 +185,32 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { @Test public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user"); - when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); + .tokenType(OAuth2AccessToken.TokenType.BEARER).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); + given(this.userService.loadUser(any())).willReturn(Mono.just(user)); List mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OAUTH_USER"); GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class); - when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer((Answer>) invocation -> mappedAuthorities); - manager.setAuthoritiesMapper(authoritiesMapper); - - OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block(); - + given(authoritiesMapper.mapAuthorities(anyCollection())) + .willAnswer((Answer>) (invocation) -> mappedAuthorities); + this.manager.setAuthoritiesMapper(authoritiesMapper); + OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()) + .block(); assertThat(result.getAuthorities()).isEqualTo(mappedAuthorities); } private OAuth2AuthorizationCodeAuthenticationToken loginToken() { ClientRegistration clientRegistration = this.registration.build(); - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest - .authorizationCode() - .state("state") + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode().state("state") .clientId(clientRegistration.getClientId()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) - .redirectUri(clientRegistration.getRedirectUri()) - .scopes(clientRegistration.getScopes()) - .build(); + .redirectUri(clientRegistration.getRedirectUri()).scopes(clientRegistration.getScopes()).build(); OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr - .redirectUri(clientRegistration.getRedirectUri()) - .build(); + .redirectUri(clientRegistration.getRedirectUri()).build(); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); return new OAuth2AuthorizationCodeAuthenticationToken(clientRegistration, authorizationExchange); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthenticationTokens.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthenticationTokens.java index e3b3815287..846ad1abec 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthenticationTokens.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthenticationTokens.java @@ -25,7 +25,10 @@ import org.springframework.security.oauth2.core.user.TestOAuth2Users; * @author Josh Cummings * @since 5.2 */ -public class TestOAuth2AuthenticationTokens { +public final class TestOAuth2AuthenticationTokens { + + private TestOAuth2AuthenticationTokens() { + } public static OAuth2AuthenticationToken authenticated() { DefaultOAuth2User principal = TestOAuth2Users.create(); @@ -38,4 +41,5 @@ public class TestOAuth2AuthenticationTokens { String registrationId = "registration-id"; return new OAuth2AuthenticationToken(principal, principal.getAuthorities(), registrationId); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthorizationCodeAuthenticationTokens.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthorizationCodeAuthenticationTokens.java index a17ec5008f..bde27fc388 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthorizationCodeAuthenticationTokens.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/TestOAuth2AuthorizationCodeAuthenticationTokens.java @@ -29,7 +29,10 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization * @author Rob Winch * @since 5.1 */ -public class TestOAuth2AuthorizationCodeAuthenticationTokens { +public final class TestOAuth2AuthorizationCodeAuthenticationTokens { + + private TestOAuth2AuthorizationCodeAuthenticationTokens() { + } public static OAuth2AuthorizationCodeAuthenticationToken unauthenticated() { ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); @@ -44,4 +47,5 @@ public class TestOAuth2AuthorizationCodeAuthenticationTokens { OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); return new OAuth2AuthorizationCodeAuthenticationToken(registration, exchange, accessToken, refreshToken); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java index 2ac1750edb..7bd0f9ba04 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClientTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import java.time.Instant; @@ -38,7 +39,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultAuthorizationCodeTokenResponseClient}. @@ -46,9 +48,11 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class DefaultAuthorizationCodeTokenResponseClientTests { - private DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = - new DefaultAuthorizationCodeTokenResponseClient(); + + private DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient(); + private ClientRegistration clientRegistration; + private MockWebServer server; @Before @@ -56,7 +60,9 @@ public class DefaultAuthorizationCodeTokenResponseClientTests { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") + // @formatter:off + this.clientRegistration = ClientRegistration + .withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) @@ -69,6 +75,7 @@ public class DefaultAuthorizationCodeTokenResponseClientTests { .userNameAttributeName("id") .clientName("client-1") .build(); + // @formatter:on } @After @@ -78,52 +85,46 @@ public class DefaultAuthorizationCodeTokenResponseClientTests { @Test public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)); } @Test public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)); } @Test public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - - OAuth2AccessTokenResponse accessTokenResponse = - this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(this.authorizationCodeGrantRequest()); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); - + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=authorization_code"); assertThat(formParameters).contains("code=code-1234"); assertThat(formParameters).contains("redirect_uri=https%3A%2F%2Fclient.com%2Fcallback%2Fclient-1"); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); @@ -136,37 +137,34 @@ public class DefaultAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); } @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.from(this.clientRegistration) - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); - + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration)); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("client_id=client-1"); assertThat(formParameters).contains("client_secret=secret"); @@ -174,147 +172,139 @@ public class DefaultAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("tokenType cannot be null"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .withMessageContaining("tokenType cannot be null"); } @Test public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("tokenType cannot be null"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .withMessageContaining("tokenType cannot be null"); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"scope\": \"read\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2AccessTokenResponse accessTokenResponse = - this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(this.authorizationCodeGrantRequest()); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"refresh_token\": \"refresh-token-1234\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"refresh_token\": \"refresh-token-1234\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2AccessTokenResponse accessTokenResponse = - this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(this.authorizationCodeGrantRequest()); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); } @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { String invalidTokenUri = "https://invalid-provider.com/oauth2/token"; - ClientRegistration clientRegistration = this.from(this.clientRegistration) - .tokenUri(invalidTokenUri) - .build(); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration))) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + ClientRegistration clientRegistration = this.from(this.clientRegistration).tokenUri(invalidTokenUri).build(); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy( + () -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration))) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @Test public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n"; -// "}\n"; // Make the JSON invalid/malformed + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; + String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[unauthorized_client]"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .withMessageContaining("[unauthorized_client]"); } @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve " + + "the OAuth 2.0 Access Token Response"); } private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() { return this.authorizationCodeGrantRequest(this.clientRegistration); } - private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest( - ClientRegistration clientRegistration) { - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest - .authorizationCode() - .clientId(clientRegistration.getClientId()) - .state("state-1234") + private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest(ClientRegistration clientRegistration) { + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .clientId(clientRegistration.getClientId()).state("state-1234") .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) - .redirectUri(clientRegistration.getRedirectUri()) - .scopes(clientRegistration.getScopes()) - .build(); - OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse - .success("code-1234") - .state("state-1234") - .redirectUri(clientRegistration.getRedirectUri()) - .build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + .redirectUri(clientRegistration.getRedirectUri()).scopes(clientRegistration.getScopes()).build(); + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success("code-1234") + .state("state-1234").redirectUri(clientRegistration.getRedirectUri()).build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); return new OAuth2AuthorizationCodeGrantRequest(clientRegistration, authorizationExchange); } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } private ClientRegistration.Builder from(ClientRegistration registration) { + // @formatter:off return ClientRegistration.withRegistrationId(registration.getRegistrationId()) .clientId(registration.getClientId()) .clientSecret(registration.getClientSecret()) @@ -325,7 +315,10 @@ public class DefaultAuthorizationCodeTokenResponseClientTests { .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) .tokenUri(registration.getProviderDetails().getTokenUri()) .userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri()) - .userNameAttributeName(registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) + .userNameAttributeName( + registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) .clientName(registration.getClientName()); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java index b24de2720b..6c7dee1f87 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import java.time.Instant; @@ -35,7 +36,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultClientCredentialsTokenResponseClient}. @@ -43,8 +45,11 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class DefaultClientCredentialsTokenResponseClientTests { + private DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + private ClientRegistration clientRegistration; + private MockWebServer server; @Before @@ -52,6 +57,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") @@ -60,6 +66,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { .scope("read", "write") .tokenUri(tokenUri) .build(); + // @formatter:on } @After @@ -69,52 +76,52 @@ public class DefaultClientCredentialsTokenResponseClientTests { @Test public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)); + // @formatter:on } @Test public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)); + // @formatter:on } @Test public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(clientCredentialsGrantRequest); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); - + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=client_credentials"); assertThat(formParameters).contains("scope=read+write"); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); @@ -127,43 +134,38 @@ public class DefaultClientCredentialsTokenResponseClientTests { @Test public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); } @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.from(this.clientRegistration) - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("client_id=client-1"); assertThat(formParameters).contains("client_secret=secret"); @@ -171,142 +173,136 @@ public class DefaultClientCredentialsTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("tokenType cannot be null"); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .withMessageContaining("tokenType cannot be null"); } @Test public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\"\n" + - "}\n"; + String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("tokenType cannot be null"); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .withMessageContaining("tokenType cannot be null"); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(clientCredentialsGrantRequest); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(clientCredentialsGrantRequest); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); } @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { String invalidTokenUri = "https://invalid-provider.com/oauth2/token"; - ClientRegistration clientRegistration = this.from(this.clientRegistration) - .tokenUri(invalidTokenUri) - .build(); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + ClientRegistration clientRegistration = this.from(this.clientRegistration).tokenUri(invalidTokenUri).build(); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @Test public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n"; -// "}\n"; // Make the JSON invalid/malformed + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; + // @formatter:off + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[unauthorized_client]"); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .withMessageContaining("[unauthorized_client]"); } @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest)) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } private ClientRegistration.Builder from(ClientRegistration registration) { + // @formatter:off return ClientRegistration.withRegistrationId(registration.getRegistrationId()) .clientId(registration.getClientId()) .clientSecret(registration.getClientSecret()) @@ -314,5 +310,7 @@ public class DefaultClientCredentialsTokenResponseClientTests { .authorizationGrantType(registration.getAuthorizationGrantType()) .scope(registration.getScopes()) .tokenUri(registration.getProviderDetails().getTokenUri()); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java index 6390ae3959..bbe593a627 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClientTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import java.time.Instant; @@ -36,7 +37,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultPasswordTokenResponseClient}. @@ -44,10 +46,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class DefaultPasswordTokenResponseClientTests { + private DefaultPasswordTokenResponseClient tokenResponseClient = new DefaultPasswordTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private String username = "user1"; + private String password = "password"; + private MockWebServer server; @Before @@ -56,9 +63,7 @@ public class DefaultPasswordTokenResponseClientTests { this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .scope("read", "write") - .tokenUri(tokenUri); + .authorizationGrantType(AuthorizationGrantType.PASSWORD).scope("read", "write").tokenUri(tokenUri); } @After @@ -68,79 +73,70 @@ public class DefaultPasswordTokenResponseClientTests { @Test public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)); } @Test public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)); } @Test public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - clientRegistration, this.username, this.password); - + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest); - Instant expiresAtAfter = Instant.now().plusSeconds(3600); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); - + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=password"); assertThat(formParameters).contains("username=user1"); assertThat(formParameters).contains("password=password"); assertThat(formParameters).contains("scope=read+write"); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(clientRegistration.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getAccessToken().getScopes()) + .containsExactly(clientRegistration.getScopes().toArray(new String[0])); assertThat(accessTokenResponse.getRefreshToken()).isNull(); } @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - clientRegistration, this.username, this.password); - + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); this.tokenResponseClient.getTokenResponse(passwordGrantRequest); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("client_id=client-id"); assertThat(formParameters).contains("client_secret=client-secret"); @@ -148,74 +144,67 @@ public class DefaultPasswordTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("tokenType cannot be null"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) + .withMessageContaining( + "[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .withMessageContaining("tokenType cannot be null"); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest); - RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("scope=read"); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; + String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[unauthorized_client]"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) + .withMessageContaining("[unauthorized_client]"); } @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest)) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to " + + "retrieve the OAuth 2.0 Access Token Response"); } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java index 44c455a944..4909fffe49 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import java.time.Instant; @@ -39,7 +40,8 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultRefreshTokenTokenResponseClient}. @@ -47,10 +49,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class DefaultRefreshTokenTokenResponseClientTests { + private DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private MockWebServer server; @Before @@ -70,78 +77,64 @@ public class DefaultRefreshTokenTokenResponseClientTests { @Test public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)); } @Test public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)); } @Test public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(refreshTokenGrantRequest); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=refresh_token"); assertThat(formParameters).contains("refresh_token=refresh-token"); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.accessToken.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getAccessToken().getScopes()) + .containsExactly(this.accessToken.getScopes().toArray(new String[0])); assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue()); } @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); - - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); - + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("client_id=client-id"); assertThat(formParameters).contains("client_secret=client-secret"); @@ -149,74 +142,69 @@ public class DefaultRefreshTokenTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("tokenType cannot be null"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to " + + "retrieve the OAuth 2.0 Access Token Response") + .withMessageContaining("tokenType cannot be null"); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, Collections.singleton("read")); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); - + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, + Collections.singleton("read")); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(refreshTokenGrantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("scope=read"); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; + String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[unauthorized_client]"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .withMessageContaining("[unauthorized_client]"); } @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to " + + "retrieve the OAuth 2.0 Access Token Response"); } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java index e26cb0b18d..17fc44c706 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import java.time.Instant; @@ -27,6 +28,7 @@ import org.junit.rules.ExpectedException; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -34,12 +36,11 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}. @@ -47,10 +48,15 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class NimbusAuthorizationCodeTokenResponseClientTests { + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AuthorizationRequest authorizationRequest; + private OAuth2AuthorizationResponse authorizationResponse; + private OAuth2AuthorizationExchange authorizationExchange; + private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient(); @Rule @@ -58,44 +64,39 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Before public void setUp() { - this.clientRegistrationBuilder = clientRegistration() + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC); - this.authorizationRequest = request().build(); - this.authorizationResponse = success().build(); - this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); + this.authorizationRequest = TestOAuth2AuthorizationRequests.request().build(); + this.authorizationResponse = TestOAuth2AuthorizationResponses.success().build(); + this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + this.authorizationResponse); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { MockWebServer server = new MockWebServer(); - - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"openid profile\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n" + - "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(accessTokenSuccessResponse)); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + // @formatter:on + server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(accessTokenSuccessResponse)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), + this.authorizationExchange)); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - server.shutdown(); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); @@ -109,58 +110,49 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); - String redirectUri = "http:\\example.com"; - OAuth2AuthorizationRequest authorizationRequest = request().redirectUri(redirectUri).build(); - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( - authorizationRequest, this.authorizationResponse); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri(redirectUri).build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + this.authorizationResponse); this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), authorizationExchange)); } @Test public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); - String tokenUri = "http:\\provider.com\\oauth2\\token"; this.clientRegistrationBuilder.tokenUri(tokenUri); - - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } @Test public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() throws Exception { this.exception.expect(OAuth2AuthorizationException.class); this.exception.expectMessage(containsString("invalid_token_response")); - MockWebServer server = new MockWebServer(); - - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"openid profile\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n"; -// "}\n"; // Make the JSON invalid/malformed - - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(accessTokenSuccessResponse)); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on + server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(accessTokenSuccessResponse)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - try { - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); - } finally { + this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); + } + finally { server.shutdown(); } } @@ -168,39 +160,32 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { this.exception.expect(OAuth2AuthorizationException.class); - String tokenUri = "https://invalid-provider.com/oauth2/token"; this.clientRegistrationBuilder.tokenUri(tokenUri); - - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() throws Exception { this.exception.expect(OAuth2AuthorizationException.class); this.exception.expectMessage(containsString("unauthorized_client")); - MockWebServer server = new MockWebServer(); - - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setResponseCode(500) - .setBody(accessTokenErrorResponse)); + // @formatter:off + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + // @formatter:on + server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setResponseCode(500).setBody(accessTokenErrorResponse)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - try { - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); - } finally { + this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); + } + finally { server.shutdown(); } } @@ -210,114 +195,99 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() throws Exception { this.exception.expect(OAuth2AuthorizationException.class); this.exception.expectMessage(containsString("server_error")); - MockWebServer server = new MockWebServer(); - server.enqueue(new MockResponse().setResponseCode(500)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - try { - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); - } finally { + this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); + } + finally { server.shutdown(); } } @Test - public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() throws Exception { + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() + throws Exception { this.exception.expect(OAuth2AuthorizationException.class); this.exception.expectMessage(containsString("invalid_token_response")); - MockWebServer server = new MockWebServer(); - - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; - - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(accessTokenSuccessResponse)); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(accessTokenSuccessResponse)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - try { - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); - } finally { + this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); + } + finally { server.shutdown(); } } @Test - public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() throws Exception { + public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() + throws Exception { MockWebServer server = new MockWebServer(); - - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"openid profile\"\n" + - "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(accessTokenSuccessResponse)); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\"\n" + + "}\n"; + // @formatter:on + server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(accessTokenSuccessResponse)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - - OAuth2AuthorizationRequest authorizationRequest = - request().scope("openid", "profile", "email", "address").build(); - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( - authorizationRequest, this.authorizationResponse); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .scope("openid", "profile", "email", "address").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + this.authorizationResponse); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), authorizationExchange)); - + new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), authorizationExchange)); server.shutdown(); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile"); } @Test - public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() throws Exception { + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() + throws Exception { MockWebServer server = new MockWebServer(); - - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(accessTokenSuccessResponse)); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(accessTokenSuccessResponse)); server.start(); - String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); - - OAuth2AuthorizationRequest authorizationRequest = - request().scope("openid", "profile", "email", "address").build(); - OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( - authorizationRequest, this.authorizationResponse); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .scope("openid", "profile", "email", "address").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + this.authorizationResponse); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), authorizationExchange)); - + new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), authorizationExchange)); server.shutdown(); - - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", "address"); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", + "address"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java index ab3003f98e..53f6d724ad 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverterTests.java @@ -13,9 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -30,13 +37,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.util.MultiValueMap; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; /** * Tests for {@link OAuth2AuthorizationCodeGrantRequestEntityConverter}. @@ -44,109 +45,104 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @author Joe Grandja */ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests { + private OAuth2AuthorizationCodeGrantRequestEntityConverter converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); + + // @formatter:off private ClientRegistration.Builder clientRegistrationBuilder = ClientRegistration - .withRegistrationId("registration-1") - .clientId("client-1") - .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri("https://client.com/callback/client-1") - .scope("read", "write") - .authorizationUri("https://provider.com/oauth2/authorize") - .tokenUri("https://provider.com/oauth2/token") - .userInfoUri("https://provider.com/user") - .userNameAttributeName("id") - .clientName("client-1"); + .withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri("https://client.com/callback/client-1") + .scope("read", "write") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/user") + .userNameAttributeName("id") + .clientName("client-1"); + // @formatter:on + + // @formatter:off private OAuth2AuthorizationRequest.Builder authorizationRequestBuilder = OAuth2AuthorizationRequest - .authorizationCode() - .clientId("client-1") - .state("state-1234") - .authorizationUri("https://provider.com/oauth2/authorize") - .redirectUri("https://client.com/callback/client-1") - .scopes(new HashSet(Arrays.asList("read", "write"))); + .authorizationCode() + .clientId("client-1") + .state("state-1234") + .authorizationUri("https://provider.com/oauth2/authorize") + .redirectUri("https://client.com/callback/client-1") + .scopes(new HashSet(Arrays.asList("read", "write"))); + // @formatter:on + + // @formatter:off private OAuth2AuthorizationResponse.Builder authorizationResponseBuilder = OAuth2AuthorizationResponse - .success("code-1234") - .state("state-1234") - .redirectUri("https://client.com/callback/client-1"); + .success("code-1234") + .state("state-1234") + .redirectUri("https://client.com/callback/client-1"); + // @formatter:on @SuppressWarnings("unchecked") @Test public void convertWhenGrantRequestValidThenConverts() { - ClientRegistration clientRegistration = clientRegistrationBuilder.build(); - OAuth2AuthorizationRequest authorizationRequest = authorizationRequestBuilder.build(); - OAuth2AuthorizationResponse authorizationResponse = authorizationResponseBuilder.build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.build(); + OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBuilder.build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( clientRegistration, authorizationExchange); - RequestEntity requestEntity = this.converter.convert(authorizationCodeGrantRequest); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getTokenUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); - assertThat(headers.getContentType()).isEqualTo( - MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); - MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo("code-1234"); assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isNull(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)).isEqualTo( - clientRegistration.getRedirectUri()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)) + .isEqualTo(clientRegistration.getRedirectUri()); } @SuppressWarnings("unchecked") @Test public void convertWhenPkceGrantRequestValidThenConverts() { - ClientRegistration clientRegistration = clientRegistrationBuilder - .clientAuthenticationMethod(null) - .clientSecret(null) - .build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.clientAuthenticationMethod(null) + .clientSecret(null).build(); Map attributes = new HashMap<>(); attributes.put(PkceParameterNames.CODE_VERIFIER, "code-verifier-1234"); - Map additionalParameters = new HashMap<>(); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, "code-challenge-1234"); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - - OAuth2AuthorizationRequest authorizationRequest = authorizationRequestBuilder - .attributes(attributes) - .additionalParameters(additionalParameters) - .build(); - - OAuth2AuthorizationResponse authorizationResponse = authorizationResponseBuilder.build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.attributes(attributes) + .additionalParameters(additionalParameters).build(); + OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBuilder.build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( clientRegistration, authorizationExchange); - RequestEntity requestEntity = this.converter.convert(authorizationCodeGrantRequest); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getTokenUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); - assertThat(headers.getContentType()).isEqualTo( - MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isNull(); - MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo("code-1234"); - assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)).isEqualTo( - clientRegistration.getRedirectUri()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI)) + .isEqualTo(clientRegistration.getRedirectUri()); assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isEqualTo("client-1"); assertThat(formParameters.getFirst(PkceParameterNames.CODE_VERIFIER)).isEqualTo("code-verifier-1234"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java index 991ec80c0a..28c625f841 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success; /** * Tests for {@link OAuth2AuthorizationCodeGrantRequest}. @@ -32,13 +33,15 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class OAuth2AuthorizationCodeGrantRequestTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationExchange authorizationExchange; @Before public void setUp() { - this.clientRegistration = clientRegistration().build(); - this.authorizationExchange = success(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizationExchange = TestOAuth2AuthorizationExchanges.success(); } @Test(expected = IllegalArgumentException.class) @@ -53,11 +56,11 @@ public class OAuth2AuthorizationCodeGrantRequestTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange); - + OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistration, this.authorizationExchange); assertThat(authorizationCodeGrantRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationCodeGrantRequest.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); assertThat(authorizationCodeGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java index 28233c9a16..9f1a9b11fc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverterTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -28,7 +30,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; /** * Tests for {@link OAuth2ClientCredentialsGrantRequestEntityConverter}. @@ -36,11 +37,14 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @author Joe Grandja */ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests { + private OAuth2ClientCredentialsGrantRequestEntityConverter converter = new OAuth2ClientCredentialsGrantRequestEntityConverter(); + private OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest; @Before public void setup() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") @@ -49,6 +53,7 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests { .scope("read", "write") .tokenUri("https://provider.com/oauth2/token") .build(); + // @formatter:on this.clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); } @@ -56,22 +61,19 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests { @Test public void convertWhenGrantRequestValidThenConverts() { RequestEntity requestEntity = this.converter.convert(this.clientCredentialsGrantRequest); - ClientRegistration clientRegistration = this.clientCredentialsGrantRequest.getClientRegistration(); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getTokenUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); - assertThat(headers.getContentType()).isEqualTo( - MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); - MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java index 965a552404..095ab160d2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.junit.Before; import org.junit.Test; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Java6Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2ClientCredentialsGrantRequest}. @@ -30,10 +32,12 @@ import static org.assertj.core.api.Java6Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2ClientCredentialsGrantRequestTests { + private ClientRegistration clientRegistration; @Before public void setup() { + // @formatter:off this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") @@ -42,16 +46,17 @@ public class OAuth2ClientCredentialsGrantRequestTests { .scope("read", "write") .tokenUri("https://provider.com/oauth2/token") .build(); + // @formatter:on } @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(null)); } @Test public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .authorizationGrantType(AuthorizationGrantType.IMPLICIT) @@ -59,18 +64,18 @@ public class OAuth2ClientCredentialsGrantRequestTests { .authorizationUri("https://provider.com/oauth2/auth") .clientName("Client 1") .build(); - - assertThatThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(clientRegistration)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); + // @formatter:on + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(clientRegistration)).withMessage( + "clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); } @Test public void constructorWhenValidParametersProvidedThenCreated() { - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(this.clientRegistration); - + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration); assertThat(clientCredentialsGrantRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(clientCredentialsGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java index 0fd6eb4a2c..7e85dcc497 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverterTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -28,7 +30,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; /** * Tests for {@link OAuth2PasswordGrantRequestEntityConverter}. @@ -36,15 +37,19 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @author Joe Grandja */ public class OAuth2PasswordGrantRequestEntityConverterTests { + private OAuth2PasswordGrantRequestEntityConverter converter = new OAuth2PasswordGrantRequestEntityConverter(); + private OAuth2PasswordGrantRequest passwordGrantRequest; @Before public void setup() { + // @formatter:off ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() .authorizationGrantType(AuthorizationGrantType.PASSWORD) .scope("read", "write") .build(); + // @formatter:on this.passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", "password"); } @@ -52,22 +57,18 @@ public class OAuth2PasswordGrantRequestEntityConverterTests { @Test public void convertWhenGrantRequestValidThenConverts() { RequestEntity requestEntity = this.converter.convert(this.passwordGrantRequest); - ClientRegistration clientRegistration = this.passwordGrantRequest.getClientRegistration(); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getTokenUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); - assertThat(headers.getContentType()).isEqualTo( - MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); - MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( - AuthorizationGrantType.PASSWORD.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.PASSWORD.getValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.USERNAME)).isEqualTo("user1"); assertThat(formParameters.getFirst(OAuth2ParameterNames.PASSWORD)).isEqualTo("password"); assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write"); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestTests.java index 7c821f9897..dfcfe69841 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; import org.junit.Test; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2PasswordGrantRequest}. @@ -29,53 +31,57 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2PasswordGrantRequestTests { + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() .authorizationGrantType(AuthorizationGrantType.PASSWORD).build(); + private String username = "user1"; + private String password = "password"; @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2PasswordGrantRequest(null, this.username, this.password)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistration cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2PasswordGrantRequest(null, this.username, this.password)) + .withMessage("clientRegistration cannot be null"); } @Test public void constructorWhenUsernameIsEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, null, this.password)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("username cannot be empty"); - assertThatThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, "", this.password)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("username cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, null, this.password)) + .withMessage("username cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, "", this.password)) + .withMessage("username cannot be empty"); } @Test public void constructorWhenPasswordIsEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, this.username, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("password cannot be empty"); - assertThatThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, this.username, "")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("password cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, this.username, null)) + .withMessage("password cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2PasswordGrantRequest(this.clientRegistration, this.username, "")) + .withMessage("password cannot be empty"); } @Test public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() { ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); - assertThatThrownBy(() -> new OAuth2PasswordGrantRequest(registration, this.username, this.password)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistration.authorizationGrantType must be AuthorizationGrantType.PASSWORD"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2PasswordGrantRequest(registration, this.username, this.password)) + .withMessage("clientRegistration.authorizationGrantType must be AuthorizationGrantType.PASSWORD"); } @Test public void constructorWhenValidParametersProvidedThenCreated() { - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - this.clientRegistration, this.username, this.password); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(this.clientRegistration, + this.username, this.password); assertThat(passwordGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.PASSWORD); assertThat(passwordGrantRequest.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(passwordGrantRequest.getUsername()).isEqualTo(this.username); assertThat(passwordGrantRequest.getPassword()).isEqualTo(this.password); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java index 2f73174039..60c53bdeda 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -30,10 +34,7 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.MultiValueMap; -import java.util.Collections; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; /** * Tests for {@link OAuth2RefreshTokenGrantRequestEntityConverter}. @@ -41,41 +42,37 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @author Joe Grandja */ public class OAuth2RefreshTokenGrantRequestEntityConverterTests { + private OAuth2RefreshTokenGrantRequestEntityConverter converter = new OAuth2RefreshTokenGrantRequestEntityConverter(); + private OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest; @Before public void setup() { this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - TestClientRegistrations.clientRegistration().build(), - TestOAuth2AccessTokens.scopes("read", "write"), - TestOAuth2RefreshTokens.refreshToken(), - Collections.singleton("read")); + TestClientRegistrations.clientRegistration().build(), TestOAuth2AccessTokens.scopes("read", "write"), + TestOAuth2RefreshTokens.refreshToken(), Collections.singleton("read")); } @SuppressWarnings("unchecked") @Test public void convertWhenGrantRequestValidThenConverts() { RequestEntity requestEntity = this.converter.convert(this.refreshTokenGrantRequest); - ClientRegistration clientRegistration = this.refreshTokenGrantRequest.getClientRegistration(); OAuth2RefreshToken refreshToken = this.refreshTokenGrantRequest.getRefreshToken(); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getTokenUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); - assertThat(headers.getContentType()).isEqualTo( - MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); - MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( - AuthorizationGrantType.REFRESH_TOKEN.getValue()); - assertThat(formParameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN)).isEqualTo( - refreshToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.REFRESH_TOKEN.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN)).isEqualTo(refreshToken.getTokenValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java index dc90a388a5..8b51a22e04 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + import org.junit.Before; import org.junit.Test; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -24,12 +30,8 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2RefreshTokenGrantRequest}. @@ -37,8 +39,11 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2RefreshTokenGrantRequestTests { + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; @Before @@ -50,23 +55,23 @@ public class OAuth2RefreshTokenGrantRequestTests { @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(null, this.accessToken, this.refreshToken)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistration cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2RefreshTokenGrantRequest(null, this.accessToken, this.refreshToken)) + .withMessage("clientRegistration cannot be null"); } @Test public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, null, this.refreshToken)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("accessToken cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, null, this.refreshToken)) + .withMessage("accessToken cannot be null"); } @Test public void constructorWhenRefreshTokenIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, this.accessToken, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("refreshToken cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, this.accessToken, null)) + .withMessage("refreshToken cannot be null"); } @Test @@ -79,4 +84,5 @@ public class OAuth2RefreshTokenGrantRequestTests { assertThat(refreshTokenGrantRequest.getRefreshToken()).isSameAs(this.refreshToken); assertThat(refreshTokenGrantRequest.getScopes()).isEqualTo(scopes); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java index 64186e10cb..8617afe600 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java @@ -16,11 +16,16 @@ package org.springframework.security.oauth2.client.endpoint; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -35,22 +40,19 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.web.reactive.function.client.WebClient; -import java.time.Instant; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch * @since 5.1 */ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests { + private ClientRegistration.Builder clientRegistration; private WebClientReactiveAuthorizationCodeTokenResponseClient tokenResponseClient = new WebClientReactiveAuthorizationCodeTokenResponseClient(); @@ -61,11 +63,8 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - String tokenUri = this.server.url("/oauth2/token").toString(); - - this.clientRegistration = TestClientRegistrations.clientRegistration() - .tokenUri(tokenUri); + this.clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); } @After @@ -75,30 +74,27 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"openid profile\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest()).block(); String body = this.server.takeRequest().getBody().readUtf8(); - - assertThat(body).isEqualTo("grant_type=authorization_code&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D"); - + assertThat(body).isEqualTo( + "grant_type=authorization_code&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D"); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); - assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo( - OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile"); assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token-1234"); @@ -107,227 +103,223 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests { assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2"); } -// @Test -// public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws Exception { -// this.exception.expect(IllegalArgumentException.class); -// -// String redirectUri = "http:\\example.com"; -// when(this.clientRegistration.getRedirectUri()).thenReturn(redirectUri); -// -// this.tokenResponseClient.getTokenResponse( -// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); -// } -// -// @Test -// public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() throws Exception { -// this.exception.expect(IllegalArgumentException.class); -// -// String tokenUri = "http:\\provider.com\\oauth2\\token"; -// when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); -// -// this.tokenResponseClient.getTokenResponse( -// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); -// } -// -// @Test -// public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() throws Exception { -// this.exception.expect(OAuth2AuthorizationException.class); -// this.exception.expectMessage(containsString("invalid_token_response")); -// -// MockWebServer server = new MockWebServer(); -// -// String accessTokenSuccessResponse = "{\n" + -// " \"access_token\": \"access-token-1234\",\n" + -// " \"token_type\": \"bearer\",\n" + -// " \"expires_in\": \"3600\",\n" + -// " \"scope\": \"openid profile\",\n" + -// " \"custom_parameter_1\": \"custom-value-1\",\n" + -// " \"custom_parameter_2\": \"custom-value-2\"\n"; -// // "}\n"; // Make the JSON invalid/malformed -// -// server.enqueue(new MockResponse() -// .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) -// .setBody(accessTokenSuccessResponse)); -// server.start(); -// -// String tokenUri = server.url("/oauth2/token").toString(); -// when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); -// -// try { -// this.tokenResponseClient.getTokenResponse( -// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); -// } finally { -// server.shutdown(); -// } -// } -// -// @Test -// public void getTokenResponseWhenTokenUriInvalidThenThrowAuthenticationServiceException() throws Exception { -// this.exception.expect(AuthenticationServiceException.class); -// -// String tokenUri = "https://invalid-provider.com/oauth2/token"; -// when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); -// -// this.tokenResponseClient.getTokenResponse( -// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); -// } -// + // @Test + // public void + // getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws + // Exception { + // this.exception.expect(IllegalArgumentException.class); + // + // String redirectUri = "http:\\example.com"; + // when(this.clientRegistration.getRedirectUri()).thenReturn(redirectUri); + // + // this.tokenResponseClient.getTokenResponse( + // new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, + // this.authorizationExchange)); + // } + // + // @Test + // public void + // getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() throws + // Exception { + // this.exception.expect(IllegalArgumentException.class); + // + // String tokenUri = "http:\\provider.com\\oauth2\\token"; + // when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + // + // this.tokenResponseClient.getTokenResponse( + // new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, + // this.authorizationExchange)); + // } + // + // @Test + // public void + // getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() + // throws Exception { + // this.exception.expect(OAuth2AuthorizationException.class); + // this.exception.expectMessage(containsString("invalid_token_response")); + // + // MockWebServer server = new MockWebServer(); + // + // String accessTokenSuccessResponse = "{\n" + + // " \"access_token\": \"access-token-1234\",\n" + + // " \"token_type\": \"bearer\",\n" + + // " \"expires_in\": \"3600\",\n" + + // " \"scope\": \"openid profile\",\n" + + // " \"custom_parameter_1\": \"custom-value-1\",\n" + + // " \"custom_parameter_2\": \"custom-value-2\"\n"; + // // "}\n"; // Make the JSON invalid/malformed + // + // server.enqueue(new MockResponse() + // .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + // .setBody(accessTokenSuccessResponse)); + // server.start(); + // + // String tokenUri = server.url("/oauth2/token").toString(); + // when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + // + // try { + // this.tokenResponseClient.getTokenResponse( + // new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, + // this.authorizationExchange)); + // } finally { + // server.shutdown(); + // } + // } + // + // @Test + // public void + // getTokenResponseWhenTokenUriInvalidThenThrowAuthenticationServiceException() throws + // Exception { + // this.exception.expect(AuthenticationServiceException.class); + // + // String tokenUri = "https://invalid-provider.com/oauth2/token"; + // when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + // + // this.tokenResponseClient.getTokenResponse( + // new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, + // this.authorizationExchange)); + // } + // @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; - - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client")) - .hasMessageContaining("unauthorized_client"); + String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; + this.server.enqueue( + jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("unauthorized_client")) + .withMessageContaining("unauthorized_client"); } // gh-5594 @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { String accessTokenErrorResponse = "{}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("server_error"); + this.server.enqueue( + jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) + .withMessageContaining("server_error"); } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; - + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + "\"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("invalid_token_response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) + .withMessageContaining("invalid_token_response"); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"openid profile\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + "\"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - this.clientRegistration.scope("openid", "profile", "email", "address"); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest()).block(); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile"); } @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - - this.clientRegistration.scope("openid", "profile", "email", "address"); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); - - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", "address"); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(authorizationCodeGrantRequest()).block(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", + "address"); } private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() { ClientRegistration registration = this.clientRegistration.build(); - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest - .authorizationCode() - .clientId(registration.getClientId()) - .state("state") + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .clientId(registration.getClientId()).state("state") .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) - .redirectUri(registration.getRedirectUri()) - .scopes(registration.getScopes()) - .build(); - OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse - .success("code") - .state("state") - .redirectUri(registration.getRedirectUri()) - .build(); + .redirectUri(registration.getRedirectUri()).scopes(registration.getScopes()).build(); + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success("code").state("state") + .redirectUri(registration.getRedirectUri()).build(); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange); } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } - @Test(expected=IllegalArgumentException.class) - public void setWebClientNullThenIllegalArgumentException(){ - tokenResponseClient.setWebClient(null); + @Test(expected = IllegalArgumentException.class) + public void setWebClientNullThenIllegalArgumentException() { + this.tokenResponseClient.setWebClient(null); } @Test public void setCustomWebClientThenCustomWebClientIsUsed() { WebClient customClient = mock(WebClient.class); - when(customClient.post()).thenReturn(WebClient.builder().build().post()); - - tokenResponseClient.setWebClient(customClient); - - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"openid profile\"\n" + - "}\n"; + given(customClient.post()).willReturn(WebClient.builder().build().post()); + this.tokenResponseClient.setWebClient(customClient); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - this.clientRegistration.scope("openid", "profile", "email", "address"); - - OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); - + OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()) + .block(); verify(customClient, atLeastOnce()).post(); } @Test - public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParametersThenTokenRequestBodyShouldContainCodeVerifier() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParametersThenTokenRequestBodyShouldContainCodeVerifier() + throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block(); String body = this.server.takeRequest().getBody().readUtf8(); - - assertThat(body).isEqualTo("grant_type=authorization_code&client_id=client-id&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D&code_verifier=code-verifier-1234"); + assertThat(body).isEqualTo( + "grant_type=authorization_code&client_id=client-id&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D&code_verifier=code-verifier-1234"); } private OAuth2AuthorizationCodeGrantRequest pkceAuthorizationCodeGrantRequest() { - ClientRegistration registration = this.clientRegistration - .clientAuthenticationMethod(null) - .clientSecret(null) + ClientRegistration registration = this.clientRegistration.clientAuthenticationMethod(null).clientSecret(null) .build(); - Map attributes = new HashMap<>(); attributes.put(PkceParameterNames.CODE_VERIFIER, "code-verifier-1234"); - Map additionalParameters = new HashMap<>(); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, "code-challenge-1234"); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest - .authorizationCode() + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .clientId(registration.getClientId()) .state("state") .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) @@ -341,8 +333,10 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests { .state("state") .redirectUri(registration.getRedirectUri()) .build(); + // @formatter:on OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java index 5128f7b779..a07e59979d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -22,6 +22,7 @@ import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -33,12 +34,12 @@ import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.validateMockitoUsage; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -55,9 +56,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - - this.clientRegistration = TestClientRegistrations - .clientCredentials() + this.clientRegistration = TestClientRegistrations.clientCredentials() .tokenUri(this.server.url("/oauth2/token").uri().toASCIIString()); } @@ -69,87 +68,86 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests { @Test public void getTokenResponseWhenHeaderThenSuccess() throws Exception { + // @formatter:off enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" - + " \"scope\":\"create\"\n" - + "}"); - OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(this.clientRegistration - .build()); - + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" + + " \"scope\":\"create\"\n" + + "}"); + // @formatter:on + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration.build()); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); RecordedRequest actualRequest = this.server.takeRequest(); String body = actualRequest.getUtf8Body(); - assertThat(response.getAccessToken()).isNotNull(); - assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser"); } @Test public void getTokenResponseWhenPostThenSuccess() throws Exception { ClientRegistration registration = this.clientRegistration - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); + // @formatter:off enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" - + " \"scope\":\"create\"\n" - + "}"); - + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" + + " \"scope\":\"create\"\n" + + "}"); + // @formatter:on OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); - OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); RecordedRequest actualRequest = this.server.takeRequest(); String body = actualRequest.getUtf8Body(); - assertThat(response.getAccessToken()).isNotNull(); assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - assertThat(body).isEqualTo("grant_type=client_credentials&client_id=client-id&client_secret=client-secret&scope=read%3Auser"); + assertThat(body).isEqualTo( + "grant_type=client_credentials&client_id=client-id&client_secret=client-secret&scope=read%3Auser"); } @Test public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() { ClientRegistration registration = this.clientRegistration.build(); + // @formatter:off enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"); + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"); + // @formatter:on OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); - OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); - assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes()); } - @Test(expected=IllegalArgumentException.class) - public void setWebClientNullThenIllegalArgumentException(){ - client.setWebClient(null); + @Test(expected = IllegalArgumentException.class) + public void setWebClientNullThenIllegalArgumentException() { + this.client.setWebClient(null); } @Test public void setWebClientCustomThenCustomClientIsUsed() { WebClient customClient = mock(WebClient.class); - when(customClient.post()).thenReturn(WebClient.builder().build().post()); - + given(customClient.post()).willReturn(WebClient.builder().build().post()); this.client.setWebClient(customClient); ClientRegistration registration = this.clientRegistration.build(); + // @formatter:off enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"); + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"); + // @formatter:on OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); - OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); - verify(customClient, atLeastOnce()).post(); } @@ -157,27 +155,27 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests { public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException { ClientRegistration registration = this.clientRegistration.build(); enqueueUnexpectedResponse(); - OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); - - assertThatThrownBy(() -> this.client.getTokenResponse(request).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) - .hasMessageContaining("[invalid_token_response]") - .hasMessageContaining("Empty OAuth 2.0 Access Token Response"); - + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.client.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("Empty OAuth 2.0 Access Token Response"); } - private void enqueueUnexpectedResponse(){ + private void enqueueUnexpectedResponse() { + // @formatter:off MockResponse response = new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setResponseCode(301); + // @formatter:on this.server.enqueue(response); } private void enqueueJson(String body) { - MockResponse response = new MockResponse() - .setBody(body) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + MockResponse response = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, + MediaType.APPLICATION_JSON_VALUE); this.server.enqueue(response); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java index 2cef538c1e..3f00b3b43f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java @@ -13,14 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.time.Instant; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -31,10 +35,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import java.time.Instant; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link WebClientReactivePasswordTokenResponseClient}. @@ -42,10 +45,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class WebClientReactivePasswordTokenResponseClientTests { + private WebClientReactivePasswordTokenResponseClient tokenResponseClient = new WebClientReactivePasswordTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private String username = "user1"; + private String password = "password"; + private MockWebServer server; @Before @@ -63,73 +71,66 @@ public class WebClientReactivePasswordTokenResponseClientTests { @Test public void setWebClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setWebClient(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setWebClient(null)); } @Test public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null).block()); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - clientRegistration, this.username, this.password); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block(); - + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest) + .block(); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); - + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=password"); assertThat(formParameters).contains("username=user1"); assertThat(formParameters).contains("password=password"); assertThat(formParameters).contains("scope=read+write"); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(clientRegistration.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getAccessToken().getScopes()) + .containsExactly(clientRegistration.getScopes().toArray(new String[0])); assertThat(accessTokenResponse.getRefreshToken()).isNull(); } @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( - clientRegistration, this.username, this.password); - + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block(); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("client_id=client-id"); assertThat(formParameters).contains("client_secret=client-secret"); @@ -137,76 +138,79 @@ public class WebClientReactivePasswordTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) - .hasMessageContaining("[invalid_token_response]") - .hasMessageContaining("An error occurred parsing the Access Token response") - .hasCauseInstanceOf(Throwable.class); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("An error occurred parsing the Access Token response") + .withCauseInstanceOf(Throwable.class); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block(); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest) + .block(); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("scope=read"); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; + // @formatter:off + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client")) - .hasMessageContaining("[unauthorized_client]"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("unauthorized_client")) + .withMessageContaining("[unauthorized_client]"); } @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( this.clientRegistrationBuilder.build(), this.username, this.password); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) - .hasMessageContaining("[invalid_token_response]") - .hasMessageContaining("Empty OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("Empty OAuth 2.0 Access Token Response"); } private MockResponse jsonResponse(String json) { + // @formatter:off return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setBody(json); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java index 272e175b9e..1aee3057fc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.endpoint; +import java.time.Instant; +import java.util.Collections; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -34,11 +39,9 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import java.time.Instant; -import java.util.Collections; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link WebClientReactiveRefreshTokenTokenResponseClient}. @@ -46,10 +49,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class WebClientReactiveRefreshTokenTokenResponseClientTests { + private WebClientReactiveRefreshTokenTokenResponseClient tokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private MockWebServer server; @Before @@ -69,72 +77,64 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests { @Test public void setWebClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.setWebClient(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setWebClient(null)); } @Test public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null).block()); } @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - Instant expiresAtBefore = Instant.now().plusSeconds(3600); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); - + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(refreshTokenGrantRequest).block(); Instant expiresAtAfter = Instant.now().plusSeconds(3600); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); - assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("grant_type=refresh_token"); assertThat(formParameters).contains("refresh_token=refresh-token"); - assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.accessToken.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getAccessToken().getScopes()) + .containsExactly(this.accessToken.getScopes().toArray(new String[0])); assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue()); } @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .clientAuthenticationMethod(ClientAuthenticationMethod.POST) - .build(); - - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - clientRegistration, this.accessToken, this.refreshToken); - + .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); - RecordedRequest recordedRequest = this.server.takeRequest(); assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("client_id=client-id"); assertThat(formParameters).contains("client_secret=client-secret"); @@ -142,76 +142,79 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"not-bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response]") - .hasMessageContaining("An error occurred parsing the Access Token response") - .hasCauseInstanceOf(Throwable.class); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("An error occurred parsing the Access Token response") + .withCauseInstanceOf(Throwable.class); } @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { - String accessTokenSuccessResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read\"\n" + - "}\n"; + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( - this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, Collections.singleton("read")); - - OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); - + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, + Collections.singleton("read")); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient + .getTokenResponse(refreshTokenGrantRequest).block(); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); assertThat(formParameters).contains("scope=read"); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\n" + - " \"error\": \"unauthorized_client\"\n" + - "}\n"; + // @formatter:off + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client")) - .hasMessageContaining("[unauthorized_client]"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("unauthorized_client")) + .withMessageContaining("[unauthorized_client]"); } @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - - assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) - .hasMessageContaining("[invalid_token_response]") - .hasMessageContaining("Empty OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("Empty OAuth 2.0 Access Token Response"); } private MockResponse jsonResponse(String json) { + // @formatter:off return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setBody(json); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java index 98e2b72bb2..7f33e8e745 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.http; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link OAuth2ErrorResponseErrorHandler}. @@ -29,33 +31,31 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2ErrorResponseErrorHandlerTests { + private OAuth2ErrorResponseErrorHandler errorHandler = new OAuth2ErrorResponseErrorHandler(); @Test public void handleErrorWhenErrorResponseBodyThenHandled() { - String errorResponse = "{\n" + - " \"error\": \"unauthorized_client\",\n" + - " \"error_description\": \"The client is not authorized\"\n" + - "}\n"; - - MockClientHttpResponse response = new MockClientHttpResponse( - errorResponse.getBytes(), HttpStatus.BAD_REQUEST); - - assertThatThrownBy(() -> this.errorHandler.handleError(response)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessage("[unauthorized_client] The client is not authorized"); + // @formatter:off + String errorResponse = "{\n" + + " \"error\": \"unauthorized_client\",\n" + + " \"error_description\": \"The client is not authorized\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(errorResponse.getBytes(), HttpStatus.BAD_REQUEST); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.errorHandler.handleError(response)) + .withMessage("[unauthorized_client] The client is not authorized"); } @Test public void handleErrorWhenErrorResponseWwwAuthenticateHeaderThenHandled() { String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\""; - - MockClientHttpResponse response = new MockClientHttpResponse( - new byte[0], HttpStatus.BAD_REQUEST); + MockClientHttpResponse response = new MockClientHttpResponse(new byte[0], HttpStatus.BAD_REQUEST); response.getHeaders().add(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader); - - assertThatThrownBy(() -> this.errorHandler.handleError(response)) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessage("[insufficient_scope] The access token expired"); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.errorHandler.handleError(response)) + .withMessage("[insufficient_scope] The access token expired"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixinTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixinTests.java index 537c90026d..5052f98a07 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixinTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationExceptionMixinTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; import com.fasterxml.jackson.core.JsonProcessingException; @@ -20,12 +21,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.Before; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; + import org.springframework.security.jackson2.SecurityJackson2Modules; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link OAuth2AuthenticationExceptionMixin}. @@ -46,12 +48,9 @@ public class OAuth2AuthenticationExceptionMixinTests { @Test public void serializeWhenMixinRegisteredThenSerializes() throws Exception { - OAuth2AuthenticationException exception = new OAuth2AuthenticationException(new OAuth2Error( - "[authorization_request_not_found]", - "Authorization Request Not Found", - "/foo/bar" - ), "Authorization Request Not Found"); - + OAuth2AuthenticationException exception = new OAuth2AuthenticationException( + new OAuth2Error("[authorization_request_not_found]", "Authorization Request Not Found", "/foo/bar"), + "Authorization Request Not Found"); String serializedJson = this.mapper.writeValueAsString(exception); String expected = asJson(exception); JSONAssert.assertEquals(expected, serializedJson, true); @@ -60,9 +59,7 @@ public class OAuth2AuthenticationExceptionMixinTests { @Test public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception { OAuth2AuthenticationException exception = new OAuth2AuthenticationException( - new OAuth2Error("[authorization_request_not_found]") - ); - + new OAuth2Error("[authorization_request_not_found]")); String serializedJson = this.mapper.writeValueAsString(exception); String expected = asJson(exception); JSONAssert.assertEquals(expected, serializedJson, true); @@ -70,26 +67,21 @@ public class OAuth2AuthenticationExceptionMixinTests { @Test public void deserializeWhenMixinNotRegisteredThenThrowJsonProcessingException() { - String json = asJson(new OAuth2AuthenticationException( - new OAuth2Error("[authorization_request_not_found]") - )); - assertThatThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthenticationException.class)) - .isInstanceOf(JsonProcessingException.class); + String json = asJson(new OAuth2AuthenticationException(new OAuth2Error("[authorization_request_not_found]"))); + assertThatExceptionOfType(JsonProcessingException.class) + .isThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthenticationException.class)); } @Test public void deserializeWhenMixinRegisteredThenDeserializes() throws Exception { - OAuth2AuthenticationException expected = new OAuth2AuthenticationException(new OAuth2Error( - "[authorization_request_not_found]", - "Authorization Request Not Found", - "/foo/bar" - ), "Authorization Request Not Found"); - - OAuth2AuthenticationException exception = this.mapper.readValue(asJson(expected), OAuth2AuthenticationException.class); + OAuth2AuthenticationException expected = new OAuth2AuthenticationException( + new OAuth2Error("[authorization_request_not_found]", "Authorization Request Not Found", "/foo/bar"), + "Authorization Request Not Found"); + OAuth2AuthenticationException exception = this.mapper.readValue(asJson(expected), + OAuth2AuthenticationException.class); assertThat(exception).isNotNull(); assertThat(exception.getCause()).isNull(); assertThat(exception.getMessage()).isEqualTo(expected.getMessage()); - OAuth2Error oauth2Error = exception.getError(); assertThat(oauth2Error).isNotNull(); assertThat(oauth2Error.getErrorCode()).isEqualTo(expected.getError().getErrorCode()); @@ -100,14 +92,12 @@ public class OAuth2AuthenticationExceptionMixinTests { @Test public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception { OAuth2AuthenticationException expected = new OAuth2AuthenticationException( - new OAuth2Error("[authorization_request_not_found]") - ); - - OAuth2AuthenticationException exception = this.mapper.readValue(asJson(expected), OAuth2AuthenticationException.class); + new OAuth2Error("[authorization_request_not_found]")); + OAuth2AuthenticationException exception = this.mapper.readValue(asJson(expected), + OAuth2AuthenticationException.class); assertThat(exception).isNotNull(); assertThat(exception.getCause()).isNull(); assertThat(exception.getMessage()).isNull(); - OAuth2Error oauth2Error = exception.getError(); assertThat(oauth2Error).isNotNull(); assertThat(oauth2Error.getErrorCode()).isEqualTo(expected.getError().getErrorCode()); @@ -133,8 +123,7 @@ public class OAuth2AuthenticationExceptionMixinTests { } private String jsonStringOrNull(String input) { - return input != null - ? "\"" + input + "\"" - : "null"; + return (input != null) ? "\"" + input + "\"" : "null"; } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixinTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixinTests.java index 8ac3721ecc..e394254109 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixinTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthenticationTokenMixinTests.java @@ -13,29 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.jackson2; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jsr310.DecimalUtils; -import org.junit.Before; -import org.junit.Test; -import org.skyscreamer.jsonassert.JSONAssert; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; -import org.springframework.security.jackson2.SecurityJackson2Modules; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthenticationTokens; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import org.springframework.security.oauth2.core.oidc.OidcUserInfo; -import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; -import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; -import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; -import org.springframework.security.oauth2.core.user.DefaultOAuth2User; -import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; +package org.springframework.security.oauth2.client.jackson2; import java.time.Instant; import java.util.ArrayList; @@ -44,9 +23,32 @@ import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.DecimalUtils; +import org.junit.Before; +import org.junit.Test; +import org.skyscreamer.jsonassert.JSONAssert; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.jackson2.SecurityJackson2Modules; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthenticationTokens; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.NAME; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link OAuth2AuthenticationTokenMixin}. @@ -54,6 +56,7 @@ import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.N * @author Joe Grandja */ public class OAuth2AuthenticationTokenMixinTests { + private ObjectMapper mapper; @Before @@ -70,7 +73,6 @@ public class OAuth2AuthenticationTokenMixinTests { String expectedJson = asJson(authentication); String json = this.mapper.writeValueAsString(authentication); JSONAssert.assertEquals(expectedJson, json, true); - // OAuth2User authentication = TestOAuth2AuthenticationTokens.authenticated(); expectedJson = asJson(authentication); @@ -82,8 +84,8 @@ public class OAuth2AuthenticationTokenMixinTests { public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception { DefaultOidcUser principal = TestOidcUsers.create(); principal = new DefaultOidcUser(principal.getAuthorities(), principal.getIdToken()); - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( - principal, Collections.emptyList(), "registration-id"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(principal, Collections.emptyList(), + "registration-id"); String expectedJson = asJson(authentication); String json = this.mapper.writeValueAsString(authentication); JSONAssert.assertEquals(expectedJson, json, true); @@ -93,8 +95,8 @@ public class OAuth2AuthenticationTokenMixinTests { public void deserializeWhenMixinNotRegisteredThenThrowJsonProcessingException() { OAuth2AuthenticationToken authentication = TestOAuth2AuthenticationTokens.oidcAuthenticated(); String json = asJson(authentication); - assertThatThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthenticationToken.class)) - .isInstanceOf(JsonProcessingException.class); + assertThatExceptionOfType(JsonProcessingException.class) + .isThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthenticationToken.class)); } @Test @@ -103,95 +105,71 @@ public class OAuth2AuthenticationTokenMixinTests { OAuth2AuthenticationToken expectedAuthentication = TestOAuth2AuthenticationTokens.oidcAuthenticated(); String json = asJson(expectedAuthentication); OAuth2AuthenticationToken authentication = this.mapper.readValue(json, OAuth2AuthenticationToken.class); - assertThat(authentication.getAuthorities()) - .containsExactlyElementsOf(expectedAuthentication.getAuthorities()); - assertThat(authentication.getDetails()) - .isEqualTo(expectedAuthentication.getDetails()); - assertThat(authentication.isAuthenticated()) - .isEqualTo(expectedAuthentication.isAuthenticated()); + assertThat(authentication.getAuthorities()).containsExactlyElementsOf(expectedAuthentication.getAuthorities()); + assertThat(authentication.getDetails()).isEqualTo(expectedAuthentication.getDetails()); + assertThat(authentication.isAuthenticated()).isEqualTo(expectedAuthentication.isAuthenticated()); assertThat(authentication.getAuthorizedClientRegistrationId()) .isEqualTo(expectedAuthentication.getAuthorizedClientRegistrationId()); DefaultOidcUser expectedOidcUser = (DefaultOidcUser) expectedAuthentication.getPrincipal(); DefaultOidcUser oidcUser = (DefaultOidcUser) authentication.getPrincipal(); assertThat(oidcUser.getAuthorities().containsAll(expectedOidcUser.getAuthorities())).isTrue(); - assertThat(oidcUser.getAttributes()) - .containsExactlyEntriesOf(expectedOidcUser.getAttributes()); - assertThat(oidcUser.getName()) - .isEqualTo(expectedOidcUser.getName()); + assertThat(oidcUser.getAttributes()).containsExactlyEntriesOf(expectedOidcUser.getAttributes()); + assertThat(oidcUser.getName()).isEqualTo(expectedOidcUser.getName()); OidcIdToken expectedIdToken = expectedOidcUser.getIdToken(); OidcIdToken idToken = oidcUser.getIdToken(); - assertThat(idToken.getTokenValue()) - .isEqualTo(expectedIdToken.getTokenValue()); - assertThat(idToken.getIssuedAt()) - .isEqualTo(expectedIdToken.getIssuedAt()); - assertThat(idToken.getExpiresAt()) - .isEqualTo(expectedIdToken.getExpiresAt()); - assertThat(idToken.getClaims()) - .containsExactlyEntriesOf(expectedIdToken.getClaims()); + assertThat(idToken.getTokenValue()).isEqualTo(expectedIdToken.getTokenValue()); + assertThat(idToken.getIssuedAt()).isEqualTo(expectedIdToken.getIssuedAt()); + assertThat(idToken.getExpiresAt()).isEqualTo(expectedIdToken.getExpiresAt()); + assertThat(idToken.getClaims()).containsExactlyEntriesOf(expectedIdToken.getClaims()); OidcUserInfo expectedUserInfo = expectedOidcUser.getUserInfo(); OidcUserInfo userInfo = oidcUser.getUserInfo(); - assertThat(userInfo.getClaims()) - .containsExactlyEntriesOf(expectedUserInfo.getClaims()); - + assertThat(userInfo.getClaims()).containsExactlyEntriesOf(expectedUserInfo.getClaims()); // OAuth2User expectedAuthentication = TestOAuth2AuthenticationTokens.authenticated(); json = asJson(expectedAuthentication); authentication = this.mapper.readValue(json, OAuth2AuthenticationToken.class); - assertThat(authentication.getAuthorities()) - .containsExactlyElementsOf(expectedAuthentication.getAuthorities()); - assertThat(authentication.getDetails()) - .isEqualTo(expectedAuthentication.getDetails()); - assertThat(authentication.isAuthenticated()) - .isEqualTo(expectedAuthentication.isAuthenticated()); + assertThat(authentication.getAuthorities()).containsExactlyElementsOf(expectedAuthentication.getAuthorities()); + assertThat(authentication.getDetails()).isEqualTo(expectedAuthentication.getDetails()); + assertThat(authentication.isAuthenticated()).isEqualTo(expectedAuthentication.isAuthenticated()); assertThat(authentication.getAuthorizedClientRegistrationId()) .isEqualTo(expectedAuthentication.getAuthorizedClientRegistrationId()); DefaultOAuth2User expectedOauth2User = (DefaultOAuth2User) expectedAuthentication.getPrincipal(); DefaultOAuth2User oauth2User = (DefaultOAuth2User) authentication.getPrincipal(); assertThat(oauth2User.getAuthorities().containsAll(expectedOauth2User.getAuthorities())).isTrue(); - assertThat(oauth2User.getAttributes()) - .containsExactlyEntriesOf(expectedOauth2User.getAttributes()); - assertThat(oauth2User.getName()) - .isEqualTo(expectedOauth2User.getName()); + assertThat(oauth2User.getAttributes()).containsExactlyEntriesOf(expectedOauth2User.getAttributes()); + assertThat(oauth2User.getName()).isEqualTo(expectedOauth2User.getName()); } @Test public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception { DefaultOidcUser expectedPrincipal = TestOidcUsers.create(); expectedPrincipal = new DefaultOidcUser(expectedPrincipal.getAuthorities(), expectedPrincipal.getIdToken()); - OAuth2AuthenticationToken expectedAuthentication = new OAuth2AuthenticationToken( - expectedPrincipal, Collections.emptyList(), "registration-id"); + OAuth2AuthenticationToken expectedAuthentication = new OAuth2AuthenticationToken(expectedPrincipal, + Collections.emptyList(), "registration-id"); String json = asJson(expectedAuthentication); OAuth2AuthenticationToken authentication = this.mapper.readValue(json, OAuth2AuthenticationToken.class); assertThat(authentication.getAuthorities()).isEmpty(); - assertThat(authentication.getDetails()) - .isEqualTo(expectedAuthentication.getDetails()); - assertThat(authentication.isAuthenticated()) - .isEqualTo(expectedAuthentication.isAuthenticated()); + assertThat(authentication.getDetails()).isEqualTo(expectedAuthentication.getDetails()); + assertThat(authentication.isAuthenticated()).isEqualTo(expectedAuthentication.isAuthenticated()); assertThat(authentication.getAuthorizedClientRegistrationId()) .isEqualTo(expectedAuthentication.getAuthorizedClientRegistrationId()); DefaultOidcUser principal = (DefaultOidcUser) authentication.getPrincipal(); assertThat(principal.getAuthorities().containsAll(expectedPrincipal.getAuthorities())).isTrue(); - assertThat(principal.getAttributes()) - .containsExactlyEntriesOf(expectedPrincipal.getAttributes()); - assertThat(principal.getName()) - .isEqualTo(expectedPrincipal.getName()); + assertThat(principal.getAttributes()).containsExactlyEntriesOf(expectedPrincipal.getAttributes()); + assertThat(principal.getName()).isEqualTo(expectedPrincipal.getName()); OidcIdToken expectedIdToken = expectedPrincipal.getIdToken(); OidcIdToken idToken = principal.getIdToken(); - assertThat(idToken.getTokenValue()) - .isEqualTo(expectedIdToken.getTokenValue()); - assertThat(idToken.getIssuedAt()) - .isEqualTo(expectedIdToken.getIssuedAt()); - assertThat(idToken.getExpiresAt()) - .isEqualTo(expectedIdToken.getExpiresAt()); - assertThat(idToken.getClaims()) - .containsExactlyEntriesOf(expectedIdToken.getClaims()); + assertThat(idToken.getTokenValue()).isEqualTo(expectedIdToken.getTokenValue()); + assertThat(idToken.getIssuedAt()).isEqualTo(expectedIdToken.getIssuedAt()); + assertThat(idToken.getExpiresAt()).isEqualTo(expectedIdToken.getExpiresAt()); + assertThat(idToken.getClaims()).containsExactlyEntriesOf(expectedIdToken.getClaims()); assertThat(principal.getUserInfo()).isNull(); } private static String asJson(OAuth2AuthenticationToken authentication) { - String principalJson = authentication.getPrincipal() instanceof DefaultOidcUser ? - asJson((DefaultOidcUser) authentication.getPrincipal()) : - asJson((DefaultOAuth2User) authentication.getPrincipal()); + String principalJson = (authentication.getPrincipal() instanceof DefaultOidcUser) + ? asJson((DefaultOidcUser) authentication.getPrincipal()) + : asJson((DefaultOAuth2User) authentication.getPrincipal()); // @formatter:off return "{\n" + " \"@class\": \"org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken\",\n" + @@ -236,17 +214,16 @@ public class OAuth2AuthenticationTokenMixinTests { for (GrantedAuthority authority : authorities) { if (authority instanceof OidcUserAuthority) { oidcUserAuthority = (OidcUserAuthority) authority; - } else if (authority instanceof OAuth2UserAuthority) { + } + else if (authority instanceof OAuth2UserAuthority) { oauth2UserAuthority = (OAuth2UserAuthority) authority; - } else if (authority instanceof SimpleGrantedAuthority) { + } + else if (authority instanceof SimpleGrantedAuthority) { simpleAuthorities.add((SimpleGrantedAuthority) authority); } } - String authoritiesJson = oidcUserAuthority != null ? - asJson(oidcUserAuthority) : - oauth2UserAuthority != null ? - asJson(oauth2UserAuthority) : - ""; + String authoritiesJson = (oidcUserAuthority != null) ? asJson(oidcUserAuthority) + : (oauth2UserAuthority != null) ? asJson(oauth2UserAuthority) : ""; if (!simpleAuthorities.isEmpty()) { if (!StringUtils.isEmpty(authoritiesJson)) { authoritiesJson += ","; @@ -288,7 +265,7 @@ public class OAuth2AuthenticationTokenMixinTests { private static String asJson(List simpleAuthorities) { // @formatter:off return simpleAuthorities.stream() - .map(authority -> "{\n" + + .map((authority) -> "{\n" + " \"@class\": \"org.springframework.security.core.authority.SimpleGrantedAuthority\",\n" + " \"authority\": \"" + authority.getAuthority() + "\"\n" + " }") @@ -339,7 +316,7 @@ public class OAuth2AuthenticationTokenMixinTests { " \"claims\": {\n" + " \"@class\": \"java.util.Collections$UnmodifiableMap\",\n" + " \"sub\": \"" + userInfo.getSubject() + "\",\n" + - " \"name\": \"" + userInfo.getClaim(NAME) + "\"\n" + + " \"name\": \"" + userInfo.getClaim(StandardClaimNames.NAME) + "\"\n" + " }\n" + " }"; // @formatter:on @@ -351,4 +328,5 @@ public class OAuth2AuthenticationTokenMixinTests { } return DecimalUtils.toBigDecimal(instant.getEpochSecond(), instant.getNano()).toString(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java index 2630efeabc..66c1d149a9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java @@ -13,26 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; + import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.Before; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; + import org.springframework.security.jackson2.SecurityJackson2Modules; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link OAuth2AuthorizationRequestMixin}. @@ -40,7 +42,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2AuthorizationRequestMixinTests { + private ObjectMapper mapper; + private OAuth2AuthorizationRequest.Builder authorizationRequestBuilder; @Before @@ -51,9 +55,11 @@ public class OAuth2AuthorizationRequestMixinTests { Map additionalParameters = new LinkedHashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); + // @formatter:off this.authorizationRequestBuilder = TestOAuth2AuthorizationRequests.request() .scope("read", "write") .additionalParameters(additionalParameters); + // @formatter:on } @Test @@ -66,13 +72,14 @@ public class OAuth2AuthorizationRequestMixinTests { @Test public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception { - OAuth2AuthorizationRequest authorizationRequest = - this.authorizationRequestBuilder - .scopes(null) - .state(null) - .additionalParameters(Map::clear) - .attributes(Map::clear) - .build(); + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder + .scopes(null) + .state(null) + .additionalParameters(Map::clear) + .attributes(Map::clear) + .build(); + // @formatter:on String expectedJson = asJson(authorizationRequest); String json = this.mapper.writeValueAsString(authorizationRequest); JSONAssert.assertEquals(expectedJson, json, true); @@ -81,8 +88,8 @@ public class OAuth2AuthorizationRequestMixinTests { @Test public void deserializeWhenMixinNotRegisteredThenThrowJsonProcessingException() { String json = asJson(this.authorizationRequestBuilder.build()); - assertThatThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthorizationRequest.class)) - .isInstanceOf(JsonProcessingException.class); + assertThatExceptionOfType(JsonProcessingException.class) + .isThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthorizationRequest.class)); } @Test @@ -92,18 +99,12 @@ public class OAuth2AuthorizationRequestMixinTests { OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class); assertThat(authorizationRequest.getAuthorizationUri()) .isEqualTo(expectedAuthorizationRequest.getAuthorizationUri()); - assertThat(authorizationRequest.getGrantType()) - .isEqualTo(expectedAuthorizationRequest.getGrantType()); - assertThat(authorizationRequest.getResponseType()) - .isEqualTo(expectedAuthorizationRequest.getResponseType()); - assertThat(authorizationRequest.getClientId()) - .isEqualTo(expectedAuthorizationRequest.getClientId()); - assertThat(authorizationRequest.getRedirectUri()) - .isEqualTo(expectedAuthorizationRequest.getRedirectUri()); - assertThat(authorizationRequest.getScopes()) - .isEqualTo(expectedAuthorizationRequest.getScopes()); - assertThat(authorizationRequest.getState()) - .isEqualTo(expectedAuthorizationRequest.getState()); + assertThat(authorizationRequest.getGrantType()).isEqualTo(expectedAuthorizationRequest.getGrantType()); + assertThat(authorizationRequest.getResponseType()).isEqualTo(expectedAuthorizationRequest.getResponseType()); + assertThat(authorizationRequest.getClientId()).isEqualTo(expectedAuthorizationRequest.getClientId()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedAuthorizationRequest.getRedirectUri()); + assertThat(authorizationRequest.getScopes()).isEqualTo(expectedAuthorizationRequest.getScopes()); + assertThat(authorizationRequest.getState()).isEqualTo(expectedAuthorizationRequest.getState()); assertThat(authorizationRequest.getAdditionalParameters()) .containsExactlyEntriesOf(expectedAuthorizationRequest.getAdditionalParameters()); assertThat(authorizationRequest.getAuthorizationRequestUri()) @@ -114,25 +115,21 @@ public class OAuth2AuthorizationRequestMixinTests { @Test public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception { - OAuth2AuthorizationRequest expectedAuthorizationRequest = - this.authorizationRequestBuilder - .scopes(null) - .state(null) - .additionalParameters(Map::clear) - .attributes(Map::clear) - .build(); + // @formatter:off + OAuth2AuthorizationRequest expectedAuthorizationRequest = this.authorizationRequestBuilder.scopes(null) + .state(null) + .additionalParameters(Map::clear) + .attributes(Map::clear) + .build(); + // @formatter:on String json = asJson(expectedAuthorizationRequest); OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class); assertThat(authorizationRequest.getAuthorizationUri()) .isEqualTo(expectedAuthorizationRequest.getAuthorizationUri()); - assertThat(authorizationRequest.getGrantType()) - .isEqualTo(expectedAuthorizationRequest.getGrantType()); - assertThat(authorizationRequest.getResponseType()) - .isEqualTo(expectedAuthorizationRequest.getResponseType()); - assertThat(authorizationRequest.getClientId()) - .isEqualTo(expectedAuthorizationRequest.getClientId()); - assertThat(authorizationRequest.getRedirectUri()) - .isEqualTo(expectedAuthorizationRequest.getRedirectUri()); + assertThat(authorizationRequest.getGrantType()).isEqualTo(expectedAuthorizationRequest.getGrantType()); + assertThat(authorizationRequest.getResponseType()).isEqualTo(expectedAuthorizationRequest.getResponseType()); + assertThat(authorizationRequest.getClientId()).isEqualTo(expectedAuthorizationRequest.getClientId()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedAuthorizationRequest.getRedirectUri()); assertThat(authorizationRequest.getScopes()).isEmpty(); assertThat(authorizationRequest.getState()).isNull(); assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); @@ -145,9 +142,9 @@ public class OAuth2AuthorizationRequestMixinTests { public void deserializeWhenInvalidAuthorizationGrantTypeThenThrowJsonParseException() { OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.build(); String json = asJson(authorizationRequest).replace("authorization_code", "client_credentials"); - assertThatThrownBy(() -> this.mapper.readValue(json, OAuth2AuthorizationRequest.class)) - .isInstanceOf(JsonParseException.class) - .hasMessageContaining("Invalid authorizationGrantType"); + assertThatExceptionOfType(JsonParseException.class) + .isThrownBy(() -> this.mapper.readValue(json, OAuth2AuthorizationRequest.class)) + .withMessageContaining("Invalid authorizationGrantType"); } private static String asJson(OAuth2AuthorizationRequest authorizationRequest) { @@ -157,14 +154,14 @@ public class OAuth2AuthorizationRequestMixinTests { } String additionalParameters = "\"@class\": \"java.util.Collections$UnmodifiableMap\""; if (!CollectionUtils.isEmpty(authorizationRequest.getAdditionalParameters())) { - additionalParameters += "," + authorizationRequest.getAdditionalParameters().keySet().stream() - .map(key -> "\"" + key + "\": \"" + authorizationRequest.getAdditionalParameters().get(key) + "\"") + additionalParameters += "," + authorizationRequest.getAdditionalParameters().keySet().stream().map( + (key) -> "\"" + key + "\": \"" + authorizationRequest.getAdditionalParameters().get(key) + "\"") .collect(Collectors.joining(",")); } String attributes = "\"@class\": \"java.util.Collections$UnmodifiableMap\""; if (!CollectionUtils.isEmpty(authorizationRequest.getAttributes())) { attributes += "," + authorizationRequest.getAttributes().keySet().stream() - .map(key -> "\"" + key + "\": \"" + authorizationRequest.getAttributes().get(key) + "\"") + .map((key) -> "\"" + key + "\": \"" + authorizationRequest.getAttributes().get(key) + "\"") .collect(Collectors.joining(",")); } // @formatter:off @@ -183,7 +180,7 @@ public class OAuth2AuthorizationRequestMixinTests { " \"java.util.Collections$UnmodifiableSet\",\n" + " [" + scopes + "]\n" + " ],\n" + - " \"state\": " + (authorizationRequest.getState() != null ? "\"" + authorizationRequest.getState() + "\"" : "null") + ",\n" + + " \"state\": " + ((authorizationRequest.getState() != null) ? "\"" + authorizationRequest.getState() + "\"" : "null") + ",\n" + " \"additionalParameters\": {\n" + " " + additionalParameters + "\n" + " },\n" + @@ -194,4 +191,5 @@ public class OAuth2AuthorizationRequestMixinTests { "}"; // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixinTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixinTests.java index 893e99d5a0..13bfd8ed5b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixinTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizedClientMixinTests.java @@ -13,14 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.jackson2; +import java.time.Instant; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.DecimalUtils; import org.junit.Before; import org.junit.Test; import org.skyscreamer.jsonassert.JSONAssert; + import org.springframework.security.jackson2.SecurityJackson2Modules; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -32,13 +39,8 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.time.Instant; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link OAuth2AuthorizedClientMixin}. @@ -46,10 +48,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2AuthorizedClientMixinTests { + private ObjectMapper mapper; + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private String principalName; @Before @@ -60,9 +67,11 @@ public class OAuth2AuthorizedClientMixinTests { Map providerConfigurationMetadata = new LinkedHashMap<>(); providerConfigurationMetadata.put("config1", "value1"); providerConfigurationMetadata.put("config2", "value2"); + // @formatter:off this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() .scope("read", "write") .providerConfigurationMetadata(providerConfigurationMetadata); + // @formatter:on this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); this.principalName = "principal-name"; @@ -70,8 +79,8 @@ public class OAuth2AuthorizedClientMixinTests { @Test public void serializeWhenMixinRegisteredThenSerializes() throws Exception { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistrationBuilder.build(), this.principalName, this.accessToken, this.refreshToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), + this.principalName, this.accessToken, this.refreshToken); String expectedJson = asJson(authorizedClient); String json = this.mapper.writeValueAsString(authorizedClient); JSONAssert.assertEquals(expectedJson, json, true); @@ -79,17 +88,18 @@ public class OAuth2AuthorizedClientMixinTests { @Test public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception { - ClientRegistration clientRegistration = - TestClientRegistrations.clientRegistration() - .clientSecret(null) - .clientName(null) - .userInfoUri(null) - .userNameAttributeName(null) - .jwkSetUri(null) - .issuerUri(null) - .build(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, this.principalName, TestOAuth2AccessTokens.noScopes()); + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .clientSecret(null) + .clientName(null) + .userInfoUri(null) + .userNameAttributeName(null) + .jwkSetUri(null) + .issuerUri(null) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, this.principalName, + TestOAuth2AccessTokens.noScopes()); String expectedJson = asJson(authorizedClient); String json = this.mapper.writeValueAsString(authorizedClient); JSONAssert.assertEquals(expectedJson, json, true); @@ -97,11 +107,11 @@ public class OAuth2AuthorizedClientMixinTests { @Test public void deserializeWhenMixinNotRegisteredThenThrowJsonProcessingException() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistrationBuilder.build(), this.principalName, this.accessToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), + this.principalName, this.accessToken); String json = asJson(authorizedClient); - assertThatThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthorizedClient.class)) - .isInstanceOf(JsonProcessingException.class); + assertThatExceptionOfType(JsonProcessingException.class) + .isThrownBy(() -> new ObjectMapper().readValue(json, OAuth2AuthorizedClient.class)); } @Test @@ -109,120 +119,96 @@ public class OAuth2AuthorizedClientMixinTests { ClientRegistration expectedClientRegistration = this.clientRegistrationBuilder.build(); OAuth2AccessToken expectedAccessToken = this.accessToken; OAuth2RefreshToken expectedRefreshToken = this.refreshToken; - OAuth2AuthorizedClient expectedAuthorizedClient = new OAuth2AuthorizedClient( - expectedClientRegistration, this.principalName, expectedAccessToken, expectedRefreshToken); + OAuth2AuthorizedClient expectedAuthorizedClient = new OAuth2AuthorizedClient(expectedClientRegistration, + this.principalName, expectedAccessToken, expectedRefreshToken); String json = asJson(expectedAuthorizedClient); OAuth2AuthorizedClient authorizedClient = this.mapper.readValue(json, OAuth2AuthorizedClient.class); ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); - assertThat(clientRegistration.getRegistrationId()) - .isEqualTo(expectedClientRegistration.getRegistrationId()); - assertThat(clientRegistration.getClientId()) - .isEqualTo(expectedClientRegistration.getClientId()); - assertThat(clientRegistration.getClientSecret()) - .isEqualTo(expectedClientRegistration.getClientSecret()); + assertThat(clientRegistration.getRegistrationId()).isEqualTo(expectedClientRegistration.getRegistrationId()); + assertThat(clientRegistration.getClientId()).isEqualTo(expectedClientRegistration.getClientId()); + assertThat(clientRegistration.getClientSecret()).isEqualTo(expectedClientRegistration.getClientSecret()); assertThat(clientRegistration.getClientAuthenticationMethod()) .isEqualTo(expectedClientRegistration.getClientAuthenticationMethod()); assertThat(clientRegistration.getAuthorizationGrantType()) .isEqualTo(expectedClientRegistration.getAuthorizationGrantType()); - assertThat(clientRegistration.getRedirectUri()) - .isEqualTo(expectedClientRegistration.getRedirectUri()); - assertThat(clientRegistration.getScopes()) - .isEqualTo(expectedClientRegistration.getScopes()); + assertThat(clientRegistration.getRedirectUri()).isEqualTo(expectedClientRegistration.getRedirectUri()); + assertThat(clientRegistration.getScopes()).isEqualTo(expectedClientRegistration.getScopes()); assertThat(clientRegistration.getProviderDetails().getAuthorizationUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getAuthorizationUri()); assertThat(clientRegistration.getProviderDetails().getTokenUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getTokenUri()); assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); - assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()) - .isEqualTo(expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()); - assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) - .isEqualTo(expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()); + assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()).isEqualTo( + expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()); + assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo( + expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()); assertThat(clientRegistration.getProviderDetails().getJwkSetUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getJwkSetUri()); assertThat(clientRegistration.getProviderDetails().getIssuerUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getIssuerUri()); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()) .containsExactlyEntriesOf(clientRegistration.getProviderDetails().getConfigurationMetadata()); - assertThat(clientRegistration.getClientName()) - .isEqualTo(expectedClientRegistration.getClientName()); - assertThat(authorizedClient.getPrincipalName()) - .isEqualTo(expectedAuthorizedClient.getPrincipalName()); + assertThat(clientRegistration.getClientName()).isEqualTo(expectedClientRegistration.getClientName()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expectedAuthorizedClient.getPrincipalName()); OAuth2AccessToken accessToken = authorizedClient.getAccessToken(); - assertThat(accessToken.getTokenType()) - .isEqualTo(expectedAccessToken.getTokenType()); - assertThat(accessToken.getScopes()) - .isEqualTo(expectedAccessToken.getScopes()); - assertThat(accessToken.getTokenValue()) - .isEqualTo(expectedAccessToken.getTokenValue()); - assertThat(accessToken.getIssuedAt()) - .isEqualTo(expectedAccessToken.getIssuedAt()); - assertThat(accessToken.getExpiresAt()) - .isEqualTo(expectedAccessToken.getExpiresAt()); + assertThat(accessToken.getTokenType()).isEqualTo(expectedAccessToken.getTokenType()); + assertThat(accessToken.getScopes()).isEqualTo(expectedAccessToken.getScopes()); + assertThat(accessToken.getTokenValue()).isEqualTo(expectedAccessToken.getTokenValue()); + assertThat(accessToken.getIssuedAt()).isEqualTo(expectedAccessToken.getIssuedAt()); + assertThat(accessToken.getExpiresAt()).isEqualTo(expectedAccessToken.getExpiresAt()); OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); - assertThat(refreshToken.getTokenValue()) - .isEqualTo(expectedRefreshToken.getTokenValue()); - assertThat(refreshToken.getIssuedAt()) - .isEqualTo(expectedRefreshToken.getIssuedAt()); - assertThat(refreshToken.getExpiresAt()) - .isEqualTo(expectedRefreshToken.getExpiresAt()); + assertThat(refreshToken.getTokenValue()).isEqualTo(expectedRefreshToken.getTokenValue()); + assertThat(refreshToken.getIssuedAt()).isEqualTo(expectedRefreshToken.getIssuedAt()); + assertThat(refreshToken.getExpiresAt()).isEqualTo(expectedRefreshToken.getExpiresAt()); } @Test public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception { - ClientRegistration expectedClientRegistration = - TestClientRegistrations.clientRegistration() - .clientSecret(null) - .clientName(null) - .userInfoUri(null) - .userNameAttributeName(null) - .jwkSetUri(null) - .issuerUri(null) - .build(); + // @formatter:off + ClientRegistration expectedClientRegistration = TestClientRegistrations.clientRegistration() + .clientSecret(null) + .clientName(null) + .userInfoUri(null) + .userNameAttributeName(null) + .jwkSetUri(null) + .issuerUri(null) + .build(); + // @formatter:on OAuth2AccessToken expectedAccessToken = TestOAuth2AccessTokens.noScopes(); - OAuth2AuthorizedClient expectedAuthorizedClient = new OAuth2AuthorizedClient( - expectedClientRegistration, this.principalName, expectedAccessToken); + OAuth2AuthorizedClient expectedAuthorizedClient = new OAuth2AuthorizedClient(expectedClientRegistration, + this.principalName, expectedAccessToken); String json = asJson(expectedAuthorizedClient); OAuth2AuthorizedClient authorizedClient = this.mapper.readValue(json, OAuth2AuthorizedClient.class); ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); - assertThat(clientRegistration.getRegistrationId()) - .isEqualTo(expectedClientRegistration.getRegistrationId()); - assertThat(clientRegistration.getClientId()) - .isEqualTo(expectedClientRegistration.getClientId()); + assertThat(clientRegistration.getRegistrationId()).isEqualTo(expectedClientRegistration.getRegistrationId()); + assertThat(clientRegistration.getClientId()).isEqualTo(expectedClientRegistration.getClientId()); assertThat(clientRegistration.getClientSecret()).isEmpty(); assertThat(clientRegistration.getClientAuthenticationMethod()) .isEqualTo(expectedClientRegistration.getClientAuthenticationMethod()); assertThat(clientRegistration.getAuthorizationGrantType()) .isEqualTo(expectedClientRegistration.getAuthorizationGrantType()); - assertThat(clientRegistration.getRedirectUri()) - .isEqualTo(expectedClientRegistration.getRedirectUri()); - assertThat(clientRegistration.getScopes()) - .isEqualTo(expectedClientRegistration.getScopes()); + assertThat(clientRegistration.getRedirectUri()).isEqualTo(expectedClientRegistration.getRedirectUri()); + assertThat(clientRegistration.getScopes()).isEqualTo(expectedClientRegistration.getScopes()); assertThat(clientRegistration.getProviderDetails().getAuthorizationUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getAuthorizationUri()); assertThat(clientRegistration.getProviderDetails().getTokenUri()) .isEqualTo(expectedClientRegistration.getProviderDetails().getTokenUri()); assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).isNull(); - assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()) - .isEqualTo(expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()); + assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()).isEqualTo( + expectedClientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()); assertThat(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()).isNull(); assertThat(clientRegistration.getProviderDetails().getJwkSetUri()).isNull(); assertThat(clientRegistration.getProviderDetails().getIssuerUri()).isNull(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty(); - assertThat(clientRegistration.getClientName()) - .isEqualTo(clientRegistration.getRegistrationId()); - assertThat(authorizedClient.getPrincipalName()) - .isEqualTo(expectedAuthorizedClient.getPrincipalName()); + assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expectedAuthorizedClient.getPrincipalName()); OAuth2AccessToken accessToken = authorizedClient.getAccessToken(); - assertThat(accessToken.getTokenType()) - .isEqualTo(expectedAccessToken.getTokenType()); + assertThat(accessToken.getTokenType()).isEqualTo(expectedAccessToken.getTokenType()); assertThat(accessToken.getScopes()).isEmpty(); - assertThat(accessToken.getTokenValue()) - .isEqualTo(expectedAccessToken.getTokenValue()); - assertThat(accessToken.getIssuedAt()) - .isEqualTo(expectedAccessToken.getIssuedAt()); - assertThat(accessToken.getExpiresAt()) - .isEqualTo(expectedAccessToken.getExpiresAt()); + assertThat(accessToken.getTokenValue()).isEqualTo(expectedAccessToken.getTokenValue()); + assertThat(accessToken.getIssuedAt()).isEqualTo(expectedAccessToken.getIssuedAt()); + assertThat(accessToken.getExpiresAt()).isEqualTo(expectedAccessToken.getExpiresAt()); assertThat(authorizedClient.getRefreshToken()).isNull(); } @@ -248,7 +234,7 @@ public class OAuth2AuthorizedClientMixinTests { String configurationMetadata = "\"@class\": \"java.util.Collections$UnmodifiableMap\""; if (!CollectionUtils.isEmpty(providerDetails.getConfigurationMetadata())) { configurationMetadata += "," + providerDetails.getConfigurationMetadata().keySet().stream() - .map(key -> "\"" + key + "\": \"" + providerDetails.getConfigurationMetadata().get(key) + "\"") + .map((key) -> "\"" + key + "\": \"" + providerDetails.getConfigurationMetadata().get(key) + "\"") .collect(Collectors.joining(",")); } // @formatter:off @@ -274,14 +260,14 @@ public class OAuth2AuthorizedClientMixinTests { " \"tokenUri\": \"" + providerDetails.getTokenUri() + "\",\n" + " \"userInfoEndpoint\": {\n" + " \"@class\": \"org.springframework.security.oauth2.client.registration.ClientRegistration$ProviderDetails$UserInfoEndpoint\",\n" + - " \"uri\": " + (userInfoEndpoint.getUri() != null ? "\"" + userInfoEndpoint.getUri() + "\"" : null) + ",\n" + + " \"uri\": " + ((userInfoEndpoint.getUri() != null) ? "\"" + userInfoEndpoint.getUri() + "\"" : null) + ",\n" + " \"authenticationMethod\": {\n" + " \"value\": \"" + userInfoEndpoint.getAuthenticationMethod().getValue() + "\"\n" + " },\n" + - " \"userNameAttributeName\": " + (userInfoEndpoint.getUserNameAttributeName() != null ? "\"" + userInfoEndpoint.getUserNameAttributeName() + "\"" : null) + "\n" + + " \"userNameAttributeName\": " + ((userInfoEndpoint.getUserNameAttributeName() != null) ? "\"" + userInfoEndpoint.getUserNameAttributeName() + "\"" : null) + "\n" + " },\n" + - " \"jwkSetUri\": " + (providerDetails.getJwkSetUri() != null ? "\"" + providerDetails.getJwkSetUri() + "\"" : null) + ",\n" + - " \"issuerUri\": " + (providerDetails.getIssuerUri() != null ? "\"" + providerDetails.getIssuerUri() + "\"" : null) + ",\n" + + " \"jwkSetUri\": " + ((providerDetails.getJwkSetUri() != null) ? "\"" + providerDetails.getJwkSetUri() + "\"" : null) + ",\n" + + " \"issuerUri\": " + ((providerDetails.getIssuerUri() != null) ? "\"" + providerDetails.getIssuerUri() + "\"" : null) + ",\n" + " \"configurationMetadata\": {\n" + " " + configurationMetadata + "\n" + " }\n" + @@ -333,4 +319,5 @@ public class OAuth2AuthorizedClientMixinTests { } return DecimalUtils.toBigDecimal(instant.getEpochSecond(), instant.getNano()).toString(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java index 4c9fc9acbe..2b0e03fa6a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java @@ -13,38 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.authentication; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.mockito.ArgumentCaptor; -import org.mockito.stubbing.Answer; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; -import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; -import org.springframework.security.crypto.keygen.StringKeyGenerator; -import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; -import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import org.springframework.security.oauth2.core.oidc.user.OidcUser; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.JwtException; +package org.springframework.security.oauth2.client.oidc.authentication; import java.security.NoSuchAlgorithmException; import java.time.Instant; @@ -57,17 +27,49 @@ import java.util.List; import java.util.Map; import java.util.Set; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.TestJwts; + import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider.createHash; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link OidcAuthorizationCodeAuthenticationProvider}. @@ -76,15 +78,26 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * @author Mark Heckler */ public class OidcAuthorizationCodeAuthenticationProviderTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizationRequest authorizationRequest; + private OAuth2AuthorizationResponse authorizationResponse; + private OAuth2AuthorizationExchange authorizationExchange; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private OAuth2AccessTokenResponse accessTokenResponse; + private OAuth2UserService userService; + private OidcAuthorizationCodeAuthenticationProvider authenticationProvider; - private StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + + private StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 96); + private String nonceHash; @Rule @@ -93,29 +106,34 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Before @SuppressWarnings("unchecked") public void setUp() { - this.clientRegistration = clientRegistration().clientId("client1").build(); + this.clientRegistration = TestClientRegistrations.clientRegistration().clientId("client1").build(); Map attributes = new HashMap<>(); Map additionalParameters = new HashMap<>(); try { String nonce = this.secureKeyGenerator.generateKey(); - this.nonceHash = createHash(nonce); + this.nonceHash = OidcAuthorizationCodeAuthenticationProvider.createHash(nonce); attributes.put(OidcParameterNames.NONCE, nonce); additionalParameters.put(OidcParameterNames.NONCE, this.nonceHash); - } catch (NoSuchAlgorithmException e) { } - this.authorizationRequest = request() + } + catch (NoSuchAlgorithmException ex) { + } + // @formatter:off + this.authorizationRequest = TestOAuth2AuthorizationRequests.request() .scope("openid", "profile", "email") .attributes(attributes) .additionalParameters(additionalParameters) .build(); - this.authorizationResponse = success().build(); - this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); + this.authorizationResponse = TestOAuth2AuthorizationResponses.success() + .build(); + // @formatter:on + this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + this.authorizationResponse); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.accessTokenResponse = this.accessTokenSuccessResponse(); this.userService = mock(OAuth2UserService.class); - this.authenticationProvider = - new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService); - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse); + this.authenticationProvider = new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, + this.userService); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(this.accessTokenResponse); } @Test @@ -149,14 +167,15 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() { - OAuth2AuthorizationRequest authorizationRequest = request().scope("scope1").build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse); - - OAuth2LoginAuthenticationToken authentication = - (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); - + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .scope("scope1") + .build(); + // @formatter:on + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + this.authorizationResponse); + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); assertThat(authentication).isNull(); } @@ -164,72 +183,75 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE)); - - OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_SCOPE).build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + // @formatter:off + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() + .errorCode(OAuth2ErrorCodes.INVALID_SCOPE) + .build(); + // @formatter:on + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + authorizationResponse); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_state_parameter")); - - OAuth2AuthorizationResponse authorizationResponse = success().state("89012").build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + // @formatter:off + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .state("89012") + .build(); + // @formatter:on + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, + authorizationResponse); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_id_token")); - - OAuth2AccessTokenResponse accessTokenResponse = - OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse()) - .additionalParameters(Collections.emptyMap()) - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withResponse(this.accessTokenSuccessResponse()) + .additionalParameters(Collections.emptyMap()) + .build(); + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); } @Test public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("missing_signature_verifier")); - - ClientRegistration clientRegistration = clientRegistration().jwkSetUri(null).build(); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)); + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .jwkSetUri(null) + .build(); + // @formatter:on + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)); } @Test public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("[invalid_id_token] ID Token Validation Error")); - JwtDecoder jwtDecoder = mock(JwtDecoder.class); - when(jwtDecoder.decode(anyString())).thenThrow(new JwtException("ID Token Validation Error")); - this.authenticationProvider.setJwtDecoderFactory(registration -> jwtDecoder); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + given(jwtDecoder.decode(anyString())).willThrow(new JwtException("ID Token Validation Error")); + this.authenticationProvider.setJwtDecoderFactory((registration) -> jwtDecoder); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); } @Test public void authenticateWhenIdTokenInvalidNonceThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("[invalid_nonce]")); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://provider.com"); claims.put(IdTokenClaimNames.SUB, "subject1"); @@ -237,9 +259,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { claims.put(IdTokenClaimNames.AZP, "client1"); claims.put(IdTokenClaimNames.NONCE, "invalid-nonce-hash"); this.setUpIdToken(claims); - - this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); } @Test @@ -251,17 +272,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { claims.put(IdTokenClaimNames.AZP, "client1"); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); this.setUpIdToken(claims); - OidcUser principal = mock(OidcUser.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(principal.getAuthorities()).thenAnswer( - (Answer>) invocation -> authorities); - when(this.userService.loadUser(any())).thenReturn(principal); - - OAuth2LoginAuthenticationToken authentication = - (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - + given(principal.getAuthorities()).willAnswer((Answer>) (invocation) -> authorities); + given(this.userService.loadUser(any())).willReturn(principal); + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); assertThat(authentication.isAuthenticated()).isTrue(); assertThat(authentication.getPrincipal()).isEqualTo(principal); assertThat(authentication.getCredentials()).isEqualTo(""); @@ -281,23 +297,17 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { claims.put(IdTokenClaimNames.AZP, "client1"); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); this.setUpIdToken(claims); - OidcUser principal = mock(OidcUser.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(principal.getAuthorities()).thenAnswer( - (Answer>) invocation -> authorities); - when(this.userService.loadUser(any())).thenReturn(principal); - + given(principal.getAuthorities()).willAnswer((Answer>) (invocation) -> authorities); + given(this.userService.loadUser(any())).willReturn(principal); List mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OIDC_USER"); GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class); - when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer( - (Answer>) invocation -> mappedAuthorities); + given(authoritiesMapper.mapAuthorities(anyCollection())) + .willAnswer((Answer>) (invocation) -> mappedAuthorities); this.authenticationProvider.setAuthoritiesMapper(authoritiesMapper); - - OAuth2LoginAuthenticationToken authentication = - (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); - + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities); } @@ -311,26 +321,22 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { claims.put(IdTokenClaimNames.AZP, "client1"); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); this.setUpIdToken(claims); - OidcUser principal = mock(OidcUser.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - when(principal.getAuthorities()).thenAnswer( - (Answer>) invocation -> authorities); + given(principal.getAuthorities()).willAnswer((Answer>) (invocation) -> authorities); ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); - when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal); - - this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken( - this.clientRegistration, this.authorizationExchange)); - - assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf( - this.accessTokenResponse.getAdditionalParameters()); + given(this.userService.loadUser(userRequestArgCaptor.capture())).willReturn(principal); + this.authenticationProvider + .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()) + .containsAllEntriesOf(this.accessTokenResponse.getAdditionalParameters()); } private void setUpIdToken(Map claims) { - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); JwtDecoder jwtDecoder = mock(JwtDecoder.class); - when(jwtDecoder.decode(anyString())).thenReturn(idToken); - this.authenticationProvider.setJwtDecoderFactory(registration -> jwtDecoder); + given(jwtDecoder.decode(anyString())).willReturn(idToken); + this.authenticationProvider.setJwtDecoderFactory((registration) -> jwtDecoder); } private OAuth2AccessTokenResponse accessTokenSuccessResponse() { @@ -340,15 +346,15 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); - - return OAuth2AccessTokenResponse - .withToken("access-token-1234") + // @formatter:off + return OAuth2AccessTokenResponse.withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(expiresAt.getEpochSecond()) .scopes(scopes) .refreshToken("refresh-token-1234") .additionalParameters(additionalParameters) .build(); - + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java index 4faec435a1..2c1dafdc86 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java @@ -31,13 +31,13 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import reactor.core.publisher.Mono; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; @@ -63,16 +63,15 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager.createHash; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * @author Rob Winch @@ -81,6 +80,7 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; */ @RunWith(MockitoJUnitRunner.class) public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { + @Mock private ReactiveOAuth2UserService userService; @@ -90,62 +90,64 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { @Mock private ReactiveJwtDecoder jwtDecoder; + // @formatter:off private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration() .scope("openid"); + // @formatter:on - private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse - .success("code") + // @formatter:off + private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse.success("code") .state("state"); + // @formatter:on private OidcIdToken idToken = TestOidcIdTokens.idToken().build(); private OidcAuthorizationCodeReactiveAuthenticationManager manager; - private StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); + private StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 96); private String nonceHash; @Before public void setup() { - this.manager = new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService); + this.manager = new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, + this.userService); } @Test public void constructorWhenNullAccessTokenResponseClientThenIllegalArgumentException() { this.accessTokenResponseClient = null; - assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, + this.userService)); } @Test public void constructorWhenNullUserServiceThenIllegalArgumentException() { this.userService = null; - assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, + this.userService)); } @Test public void setJwtDecoderFactoryWhenNullThenIllegalArgumentException() { - assertThatThrownBy(() -> this.manager.setJwtDecoderFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.manager.setJwtDecoderFactory(null)); } @Test public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.manager.setAuthoritiesMapper(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.manager.setAuthoritiesMapper(null)); } @Test public void authenticateWhenNoSubscriptionThenDoesNothing() { - // we didn't do anything because it should cause a ClassCastException (as verified below) + // we didn't do anything because it should cause a ClassCastException (as verified + // below) TestingAuthenticationToken token = new TestingAuthenticationToken("a", "b"); - - assertThatCode(()-> this.manager.authenticate(token)) - .doesNotThrowAnyException(); - - assertThatThrownBy(() -> this.manager.authenticate(token).block()) - .isInstanceOf(Throwable.class); + this.manager.authenticate(token); + assertThatExceptionOfType(Throwable.class).isThrownBy(() -> this.manager.authenticate(token).block()); } @Test @@ -156,108 +158,108 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { @Test public void authenticationWhenErrorThenOAuth2AuthenticationException() { - this.authorizationResponseBldr = OAuth2AuthorizationResponse - .error("error") + // @formatter:off + this.authorizationResponseBldr = OAuth2AuthorizationResponse.error("error") .state("state"); - assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + // @formatter:on + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(loginToken()).block()); } @Test public void authenticationWhenStateDoesNotMatchThenOAuth2AuthenticationException() { this.authorizationResponseBldr.state("notmatch"); - assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(loginToken()).block()); } @Test public void authenticateWhenIdTokenValidationErrorThenOAuth2AuthenticationException() { + // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - when(this.jwtDecoder.decode(any())).thenThrow(new JwtException("ID Token Validation Error")); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); - - assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_id_token] ID Token Validation Error"); + // @formatter:on + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + given(this.jwtDecoder.decode(any())).willThrow(new JwtException("ID Token Validation Error")); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(loginToken()).block()) + .withMessageContaining("[invalid_id_token] ID Token Validation Error"); } @Test public void authenticateWhenIdTokenInvalidNonceThenOAuth2AuthenticationException() { + // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) + .additionalParameters( + Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) .build(); - + // @formatter:on OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.SUB, "sub"); claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id")); claims.put(IdTokenClaimNames.NONCE, "invalid-nonce-hash"); - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); - - assertThatThrownBy(() -> this.manager.authenticate(authorizationCodeAuthentication).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_nonce]"); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + given(this.jwtDecoder.decode(any())).willReturn(Mono.just(idToken)); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(authorizationCodeAuthentication).block()) + .withMessageContaining("[invalid_nonce]"); } @Test public void authenticationWhenOAuth2UserNotFoundThenEmpty() { + // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.")) + .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.")) .build(); - + // @formatter:on OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.SUB, "rob"); claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id")); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - when(this.userService.loadUser(any())).thenReturn(Mono.empty()); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + given(this.userService.loadUser(any())).willReturn(Mono.empty()); + given(this.jwtDecoder.decode(any())).willReturn(Mono.just(idToken)); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); assertThat(this.manager.authenticate(authorizationCodeAuthentication).block()).isNull(); } @Test public void authenticationWhenOAuth2UserFoundThenSuccess() { - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) + .additionalParameters( + Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) .build(); - + // @formatter:on OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.SUB, "rob"); claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id")); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); - when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); - - OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(authorizationCodeAuthentication).block(); - + given(this.userService.loadUser(any())).willReturn(Mono.just(user)); + given(this.jwtDecoder.decode(any())).willReturn(Mono.just(idToken)); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); + OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager + .authenticate(authorizationCodeAuthentication).block(); assertThat(result.getPrincipal()).isEqualTo(user); assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities()); assertThat(result.isAuthenticated()).isTrue(); @@ -265,29 +267,29 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { @Test public void authenticationWhenRefreshTokenThenRefreshTokenInAuthorizedClient() { - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) + .additionalParameters( + Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) .refreshToken("refresh-token") .build(); - + // @formatter:on OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.SUB, "rob"); claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id")); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); - when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); - - OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(authorizationCodeAuthentication).block(); - + given(this.userService.loadUser(any())).willReturn(Mono.just(user)); + given(this.jwtDecoder.decode(any())).willReturn(Mono.just(idToken)); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); + OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager + .authenticate(authorizationCodeAuthentication).block(); assertThat(result.getPrincipal()).isEqualTo(user); assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities()); assertThat(result.isAuthenticated()).isTrue(); @@ -302,29 +304,27 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) .additionalParameters(additionalParameters) .build(); - + // @formatter:on OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.SUB, "rob"); claims.put(IdTokenClaimNames.AUD, Arrays.asList(clientRegistration.getClientId())); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); - when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user)); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); - + given(this.userService.loadUser(userRequestArgCaptor.capture())).willReturn(Mono.just(user)); + given(this.jwtDecoder.decode(any())).willReturn(Mono.just(idToken)); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); this.manager.authenticate(authorizationCodeAuthentication).block(); - assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()) .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); } @@ -332,36 +332,32 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { @Test public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { ClientRegistration clientRegistration = this.registration.build(); + // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) - .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) + .additionalParameters( + Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) .build(); - + // @formatter:on OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); - Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.SUB, "rob"); claims.put(IdTokenClaimNames.AUD, Collections.singletonList(clientRegistration.getClientId())); claims.put(IdTokenClaimNames.NONCE, this.nonceHash); - Jwt idToken = jwt().claims(c -> c.putAll(claims)).build(); - - - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + Jwt idToken = TestJwts.jwt().claims((c) -> c.putAll(claims)).build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); - when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user)); - + given(this.userService.loadUser(userRequestArgCaptor.capture())).willReturn(Mono.just(user)); List mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OIDC_USER"); GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class); - when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer( - (Answer>) invocation -> mappedAuthorities); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); - this.manager.setJwtDecoderFactory(c -> this.jwtDecoder); + given(authoritiesMapper.mapAuthorities(anyCollection())) + .willAnswer((Answer>) (invocation) -> mappedAuthorities); + given(this.jwtDecoder.decode(any())).willReturn(Mono.just(idToken)); + this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); this.manager.setAuthoritiesMapper(authoritiesMapper); - Authentication result = this.manager.authenticate(authorizationCodeAuthentication).block(); - assertThat(result.getAuthorities()).isEqualTo(mappedAuthorities); } @@ -371,25 +367,28 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { Map additionalParameters = new HashMap<>(); try { String nonce = this.secureKeyGenerator.generateKey(); - this.nonceHash = createHash(nonce); + this.nonceHash = OidcAuthorizationCodeReactiveAuthenticationManager.createHash(nonce); attributes.put(OidcParameterNames.NONCE, nonce); additionalParameters.put(OidcParameterNames.NONCE, this.nonceHash); - } catch (NoSuchAlgorithmException e) { } - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest - .authorizationCode() + } + catch (NoSuchAlgorithmException ex) { + } + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .state("state") .clientId(clientRegistration.getClientId()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .redirectUri(clientRegistration.getRedirectUri()) .scopes(clientRegistration.getScopes()) .additionalParameters(additionalParameters) - .attributes(attributes) - .build(); + .attributes(attributes).build(); OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr .redirectUri(clientRegistration.getRedirectUri()) .build(); + // @formatter:on OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); return new OAuth2AuthorizationCodeAuthenticationToken(clientRegistration, authorizationExchange); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java index 0b26a0d03f..ef3e1df791 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; + import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -30,12 +35,13 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; -import java.util.Map; -import java.util.function.Function; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * @author Joe Grandja @@ -44,8 +50,11 @@ import static org.mockito.Mockito.*; */ public class OidcIdTokenDecoderFactoryTests { - private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration() + // @formatter:off + private ClientRegistration.Builder registration = TestClientRegistrations + .clientRegistration() .scope("openid"); + // @formatter:on private OidcIdTokenDecoderFactory idTokenDecoderFactory; @@ -56,7 +65,8 @@ public class OidcIdTokenDecoderFactoryTests { @Test public void createDefaultClaimTypeConvertersWhenCalledThenDefaultsAreCorrect() { - Map> claimTypeConverters = OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters(); + Map> claimTypeConverters = OidcIdTokenDecoderFactory + .createDefaultClaimTypeConverters(); assertThat(claimTypeConverters).containsKey(IdTokenClaimNames.ISS); assertThat(claimTypeConverters).containsKey(IdTokenClaimNames.AUD); assertThat(claimTypeConverters).containsKey(IdTokenClaimNames.NONCE); @@ -71,85 +81,78 @@ public class OidcIdTokenDecoderFactoryTests { @Test public void setJwtValidatorFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.setJwtValidatorFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.setJwtValidatorFactory(null)); } @Test public void setJwsAlgorithmResolverWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.setJwsAlgorithmResolver(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.setJwsAlgorithmResolver(null)); } @Test public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)); } @Test public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)); } @Test public void createDecoderWhenJwsAlgorithmDefaultAndJwkSetUriEmptyThenThrowOAuth2AuthenticationException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured the JwkSet URI."); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured the JwkSet URI."); } @Test public void createDecoderWhenJwsAlgorithmEcAndJwkSetUriEmptyThenThrowOAuth2AuthenticationException() { - this.idTokenDecoderFactory.setJwsAlgorithmResolver(clientRegistration -> SignatureAlgorithm.ES256); - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured the JwkSet URI."); + this.idTokenDecoderFactory.setJwsAlgorithmResolver((clientRegistration) -> SignatureAlgorithm.ES256); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured the JwkSet URI."); } @Test public void createDecoderWhenJwsAlgorithmHmacAndClientSecretNullThenThrowOAuth2AuthenticationException() { - this.idTokenDecoderFactory.setJwsAlgorithmResolver(clientRegistration -> MacAlgorithm.HS256); - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.clientSecret(null).build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured the client secret."); + this.idTokenDecoderFactory.setJwsAlgorithmResolver((clientRegistration) -> MacAlgorithm.HS256); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.idTokenDecoderFactory.createDecoder(this.registration.clientSecret(null).build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured the client secret."); } @Test public void createDecoderWhenJwsAlgorithmNullThenThrowOAuth2AuthenticationException() { - this.idTokenDecoderFactory.setJwsAlgorithmResolver(clientRegistration -> null); - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured a valid JWS Algorithm: 'null'"); + this.idTokenDecoderFactory.setJwsAlgorithmResolver((clientRegistration) -> null); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured a valid JWS Algorithm: 'null'"); } @Test public void createDecoderWhenClientRegistrationValidThenReturnDecoder() { - assertThat(this.idTokenDecoderFactory.createDecoder(this.registration.build())) - .isNotNull(); + assertThat(this.idTokenDecoderFactory.createDecoder(this.registration.build())).isNotNull(); } @Test public void createDecoderWhenCustomJwtValidatorFactorySetThenApplied() { Function> customJwtValidatorFactory = mock(Function.class); this.idTokenDecoderFactory.setJwtValidatorFactory(customJwtValidatorFactory); - ClientRegistration clientRegistration = this.registration.build(); - - when(customJwtValidatorFactory.apply(same(clientRegistration))) - .thenReturn(new OidcIdTokenValidator(clientRegistration)); - + given(customJwtValidatorFactory.apply(same(clientRegistration))) + .willReturn(new OidcIdTokenValidator(clientRegistration)); this.idTokenDecoderFactory.createDecoder(clientRegistration); - verify(customJwtValidatorFactory).apply(same(clientRegistration)); } @@ -157,29 +160,22 @@ public class OidcIdTokenDecoderFactoryTests { public void createDecoderWhenCustomJwsAlgorithmResolverSetThenApplied() { Function customJwsAlgorithmResolver = mock(Function.class); this.idTokenDecoderFactory.setJwsAlgorithmResolver(customJwsAlgorithmResolver); - ClientRegistration clientRegistration = this.registration.build(); - - when(customJwsAlgorithmResolver.apply(same(clientRegistration))) - .thenReturn(MacAlgorithm.HS256); - + given(customJwsAlgorithmResolver.apply(same(clientRegistration))).willReturn(MacAlgorithm.HS256); this.idTokenDecoderFactory.createDecoder(clientRegistration); - verify(customJwsAlgorithmResolver).apply(same(clientRegistration)); } @Test public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { - Function, Map>> customClaimTypeConverterFactory = mock(Function.class); + Function, Map>> customClaimTypeConverterFactory = mock( + Function.class); this.idTokenDecoderFactory.setClaimTypeConverterFactory(customClaimTypeConverterFactory); - ClientRegistration clientRegistration = this.registration.build(); - - when(customClaimTypeConverterFactory.apply(same(clientRegistration))) - .thenReturn(new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters())); - + given(customClaimTypeConverterFactory.apply(same(clientRegistration))) + .willReturn(new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters())); this.idTokenDecoderFactory.createDecoder(clientRegistration); - verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java index 098fe8d1f2..99b32caac6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java @@ -13,16 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.authentication; -import org.junit.Before; -import org.junit.Test; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; -import org.springframework.security.oauth2.jwt.Jwt; +package org.springframework.security.oauth2.client.oidc.authentication; import java.time.Duration; import java.time.Instant; @@ -32,8 +24,18 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; +import org.springframework.security.oauth2.jwt.Jwt; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * @author Rob Winch @@ -41,11 +43,17 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @since 5.1 */ public class OidcIdTokenValidatorTests { + private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); + private Map headers = new HashMap<>(); + private Map claims = new HashMap<>(); + private Instant issuedAt = Instant.now(); + private Instant expiresAt = this.issuedAt.plusSeconds(3600); + private Duration clockSkew = Duration.ofSeconds(60); @Before @@ -61,114 +69,129 @@ public class OidcIdTokenValidatorTests { assertThat(this.validateIdToken()).isEmpty(); } - @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build()); - assertThatThrownBy(() -> idTokenValidator.setClockSkew(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> idTokenValidator.setClockSkew(null)); + // @formatter:on } @Test public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build()); - assertThatThrownBy(() -> idTokenValidator.setClockSkew(Duration.ofSeconds(-1))) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> idTokenValidator.setClockSkew(Duration.ofSeconds(-1))); + // @formatter:on } @Test public void setClockWhenNullThenThrowIllegalArgumentException() { OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build()); - assertThatThrownBy(() -> idTokenValidator.setClock(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> idTokenValidator.setClock(null)); + // @formatter:on } @Test public void validateWhenIssuerNullThenHasErrors() { this.claims.remove(IdTokenClaimNames.ISS); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.ISS)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.ISS)); + // @formatter:on } @Test public void validateWhenMetadataIssuerMismatchThenHasErrors() { /* - * When the issuer is set in the provider metadata, and it does not match the issuer in the ID Token, - * the validation must fail + * When the issuer is set in the provider metadata, and it does not match the + * issuer in the ID Token, the validation must fail */ this.registration = this.registration.issuerUri("https://somethingelse.com"); - + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.ISS)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.ISS)); + // @formatter:on } @Test public void validateWhenMetadataIssuerMatchThenNoErrors() { /* - * When the issuer is set in the provider metadata, and it does match the issuer in the ID Token, - * the validation must succeed + * When the issuer is set in the provider metadata, and it does match the issuer + * in the ID Token, the validation must succeed */ this.registration = this.registration.issuerUri("https://example.com"); - assertThat(this.validateIdToken()).isEmpty(); } @Test public void validateWhenSubNullThenHasErrors() { this.claims.remove(IdTokenClaimNames.SUB); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.SUB)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.SUB)); + // @formatter:on } @Test public void validateWhenAudNullThenHasErrors() { this.claims.remove(IdTokenClaimNames.AUD); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD)); + // @formatter:on } @Test public void validateWhenIssuedAtNullThenHasErrors() { this.issuedAt = null; + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT)); + // @formatter:on } @Test public void validateWhenExpiresAtNullThenHasErrors() { this.expiresAt = null; - assertThat(this.validateIdToken()) - .hasSize(1) - .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); + assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) + .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); } @Test public void validateWhenAudMultipleAndAzpNullThenHasErrors() { this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP)); + // @formatter:on } @Test public void validateWhenAzpNotClientIdThenHasErrors() { this.claims.put(IdTokenClaimNames.AZP, "other"); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP)); + // @formatter:on } @Test @@ -182,19 +205,23 @@ public class OidcIdTokenValidatorTests { public void validateWhenMultipleAudAzpNotClientIdThenHasErrors() { this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id-1", "client-id-2")); this.claims.put(IdTokenClaimNames.AZP, "other-client"); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP)); + // @formatter:on } @Test public void validateWhenAudNotClientIdThenHasErrors() { this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client")); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD)); + // @formatter:on } @Test @@ -210,10 +237,12 @@ public class OidcIdTokenValidatorTests { this.issuedAt = Instant.now().minus(Duration.ofSeconds(60)); this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(30)); this.clockSkew = Duration.ofSeconds(0); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); + // @formatter:on } @Test @@ -229,10 +258,12 @@ public class OidcIdTokenValidatorTests { this.issuedAt = Instant.now().plus(Duration.ofMinutes(1)); this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(60)); this.clockSkew = Duration.ofMinutes(0); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT)); + // @formatter:on } @Test @@ -240,10 +271,12 @@ public class OidcIdTokenValidatorTests { this.issuedAt = Instant.now().minus(Duration.ofSeconds(10)); this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(5)); this.clockSkew = Duration.ofSeconds(0); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); + // @formatter:on } @Test @@ -252,34 +285,41 @@ public class OidcIdTokenValidatorTests { this.claims.remove(IdTokenClaimNames.AUD); this.issuedAt = null; this.expiresAt = null; + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.contains(IdTokenClaimNames.SUB)) - .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD)) - .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT)) - .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); + .allMatch((msg) -> msg.contains(IdTokenClaimNames.SUB)) + .allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD)) + .allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT)) + .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); + // @formatter:on } @Test public void validateFormatError() { this.claims.remove(IdTokenClaimNames.SUB); this.claims.remove(IdTokenClaimNames.AUD); + // @formatter:off assertThat(this.validateIdToken()) .hasSize(1) .extracting(OAuth2Error::getDescription) - .allMatch(msg -> msg.equals("The ID Token contains invalid claims: {sub=null, aud=null}")); + .allMatch((msg) -> msg.equals("The ID Token contains invalid claims: {sub=null, aud=null}")); + // @formatter:on } private Collection validateIdToken() { + // @formatter:off Jwt idToken = Jwt.withTokenValue("token") - .issuedAt(this.issuedAt) - .expiresAt(this.expiresAt) - .headers(h -> h.putAll(this.headers)) - .claims(c -> c.putAll(this.claims)) - .build(); + .issuedAt(this.issuedAt) + .expiresAt(this.expiresAt) + .headers((h) -> h.putAll(this.headers)) + .claims((c) -> c.putAll(this.claims)) + .build(); + // @formatter:on OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build()); validator.setClockSkew(this.clockSkew); return validator.validate(idToken).getErrors(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java index bb2122a5cc..f5d694b549 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java @@ -13,10 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.authentication; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; + import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -30,12 +35,13 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; -import java.util.Map; -import java.util.function.Function; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * @author Joe Grandja @@ -44,8 +50,10 @@ import static org.mockito.Mockito.*; */ public class ReactiveOidcIdTokenDecoderFactoryTests { + // @formatter:off private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration() .scope("openid"); + // @formatter:on private ReactiveOidcIdTokenDecoderFactory idTokenDecoderFactory; @@ -56,7 +64,8 @@ public class ReactiveOidcIdTokenDecoderFactoryTests { @Test public void createDefaultClaimTypeConvertersWhenCalledThenDefaultsAreCorrect() { - Map> claimTypeConverters = ReactiveOidcIdTokenDecoderFactory.createDefaultClaimTypeConverters(); + Map> claimTypeConverters = ReactiveOidcIdTokenDecoderFactory + .createDefaultClaimTypeConverters(); assertThat(claimTypeConverters).containsKey(IdTokenClaimNames.ISS); assertThat(claimTypeConverters).containsKey(IdTokenClaimNames.AUD); assertThat(claimTypeConverters).containsKey(IdTokenClaimNames.NONCE); @@ -71,85 +80,78 @@ public class ReactiveOidcIdTokenDecoderFactoryTests { @Test public void setJwtValidatorFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.setJwtValidatorFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.setJwtValidatorFactory(null)); } @Test public void setJwsAlgorithmResolverWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.setJwsAlgorithmResolver(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.setJwsAlgorithmResolver(null)); } @Test public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)); } @Test public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)); } @Test public void createDecoderWhenJwsAlgorithmDefaultAndJwkSetUriEmptyThenThrowOAuth2AuthenticationException() { - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured the JwkSet URI."); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured the JwkSet URI."); } @Test public void createDecoderWhenJwsAlgorithmEcAndJwkSetUriEmptyThenThrowOAuth2AuthenticationException() { - this.idTokenDecoderFactory.setJwsAlgorithmResolver(clientRegistration -> SignatureAlgorithm.ES256); - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured the JwkSet URI."); + this.idTokenDecoderFactory.setJwsAlgorithmResolver((clientRegistration) -> SignatureAlgorithm.ES256); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.jwkSetUri(null).build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured the JwkSet URI."); } @Test public void createDecoderWhenJwsAlgorithmHmacAndClientSecretNullThenThrowOAuth2AuthenticationException() { - this.idTokenDecoderFactory.setJwsAlgorithmResolver(clientRegistration -> MacAlgorithm.HS256); - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.clientSecret(null).build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured the client secret."); + this.idTokenDecoderFactory.setJwsAlgorithmResolver((clientRegistration) -> MacAlgorithm.HS256); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.idTokenDecoderFactory.createDecoder(this.registration.clientSecret(null).build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured the client secret."); } @Test public void createDecoderWhenJwsAlgorithmNullThenThrowOAuth2AuthenticationException() { - this.idTokenDecoderFactory.setJwsAlgorithmResolver(clientRegistration -> null); - assertThatThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.build())) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + - "for Client Registration: 'registration-id'. " + - "Check to ensure you have configured a valid JWS Algorithm: 'null'"); + this.idTokenDecoderFactory.setJwsAlgorithmResolver((clientRegistration) -> null); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(this.registration.build())) + .withMessage("[missing_signature_verifier] Failed to find a Signature Verifier " + + "for Client Registration: 'registration-id'. " + + "Check to ensure you have configured a valid JWS Algorithm: 'null'"); } @Test public void createDecoderWhenClientRegistrationValidThenReturnDecoder() { - assertThat(this.idTokenDecoderFactory.createDecoder(this.registration.build())) - .isNotNull(); + assertThat(this.idTokenDecoderFactory.createDecoder(this.registration.build())).isNotNull(); } @Test public void createDecoderWhenCustomJwtValidatorFactorySetThenApplied() { Function> customJwtValidatorFactory = mock(Function.class); this.idTokenDecoderFactory.setJwtValidatorFactory(customJwtValidatorFactory); - ClientRegistration clientRegistration = this.registration.build(); - - when(customJwtValidatorFactory.apply(same(clientRegistration))) - .thenReturn(new OidcIdTokenValidator(clientRegistration)); - + given(customJwtValidatorFactory.apply(same(clientRegistration))) + .willReturn(new OidcIdTokenValidator(clientRegistration)); this.idTokenDecoderFactory.createDecoder(clientRegistration); - verify(customJwtValidatorFactory).apply(same(clientRegistration)); } @@ -157,29 +159,22 @@ public class ReactiveOidcIdTokenDecoderFactoryTests { public void createDecoderWhenCustomJwsAlgorithmResolverSetThenApplied() { Function customJwsAlgorithmResolver = mock(Function.class); this.idTokenDecoderFactory.setJwsAlgorithmResolver(customJwsAlgorithmResolver); - ClientRegistration clientRegistration = this.registration.build(); - - when(customJwsAlgorithmResolver.apply(same(clientRegistration))) - .thenReturn(MacAlgorithm.HS256); - + given(customJwsAlgorithmResolver.apply(same(clientRegistration))).willReturn(MacAlgorithm.HS256); this.idTokenDecoderFactory.createDecoder(clientRegistration); - verify(customJwsAlgorithmResolver).apply(same(clientRegistration)); } @Test public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { - Function, Map>> customClaimTypeConverterFactory = mock(Function.class); + Function, Map>> customClaimTypeConverterFactory = mock( + Function.class); this.idTokenDecoderFactory.setClaimTypeConverterFactory(customClaimTypeConverterFactory); - ClientRegistration clientRegistration = this.registration.build(); - - when(customClaimTypeConverterFactory.apply(same(clientRegistration))) - .thenReturn(new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters())); - + given(customClaimTypeConverterFactory.apply(same(clientRegistration))) + .willReturn(new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters())); this.idTokenDecoderFactory.createDecoder(clientRegistration); - verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index e84e9f38dd..e41eafe700 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -41,27 +41,25 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.any; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.same; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; -import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * @author Rob Winch @@ -69,19 +67,17 @@ import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idT */ @RunWith(MockitoJUnitRunner.class) public class OidcReactiveOAuth2UserServiceTests { + @Mock private ReactiveOAuth2UserService oauth2UserService; private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration() .userNameAttributeName(IdTokenClaimNames.SUB); - private OidcIdToken idToken = idToken().build(); + private OidcIdToken idToken = TestOidcIdTokens.idToken().build(); - private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", - Instant.now(), - Instant.now().plus(Duration.ofDays(1)), - Collections.singleton("read:user")); + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofDays(1)), Collections.singleton("read:user")); private OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); @@ -92,7 +88,8 @@ public class OidcReactiveOAuth2UserServiceTests { @Test public void createDefaultClaimTypeConvertersWhenCalledThenDefaultsAreCorrect() { - Map> claimTypeConverters = OidcReactiveOAuth2UserService.createDefaultClaimTypeConverters(); + Map> claimTypeConverters = OidcReactiveOAuth2UserService + .createDefaultClaimTypeConverters(); assertThat(claimTypeConverters).containsKey(StandardClaimNames.EMAIL_VERIFIED); assertThat(claimTypeConverters).containsKey(StandardClaimNames.PHONE_NUMBER_VERIFIED); assertThat(claimTypeConverters).containsKey(StandardClaimNames.UPDATED_AT); @@ -100,35 +97,30 @@ public class OidcReactiveOAuth2UserServiceTests { @Test public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.userService.setClaimTypeConverterFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null)); } @Test public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() { this.registration.userInfoUri(null); - OidcUser user = this.userService.loadUser(userRequest()).block(); - assertThat(user.getUserInfo()).isNull(); } @Test public void loadUserWhenOAuth2UserEmptyThenNullUserInfo() { - when(this.oauth2UserService.loadUser(any())).thenReturn(Mono.empty()); - + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.empty()); OidcUser user = this.userService.loadUser(userRequest()).block(); - assertThat(user.getUserInfo()).isNull(); } @Test public void loadUserWhenOAuth2UserSubjectNullThenOAuth2AuthenticationException() { - OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user"); - when(this.oauth2UserService.loadUser(any())).thenReturn(Mono.just(oauth2User)); - - assertThatCode(() -> this.userService.loadUser(userRequest()).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService.loadUser(userRequest()).block()); } @Test @@ -136,12 +128,11 @@ public class OidcReactiveOAuth2UserServiceTests { Map attributes = new HashMap<>(); attributes.put(StandardClaimNames.SUB, "not-equal"); attributes.put("user", "rob"); - OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), - attributes, "user"); - when(this.oauth2UserService.loadUser(any())).thenReturn(Mono.just(oauth2User)); - - assertThatCode(() -> this.userService.loadUser(userRequest()).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes, + "user"); + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService.loadUser(userRequest()).block()); } @Test @@ -149,10 +140,9 @@ public class OidcReactiveOAuth2UserServiceTests { Map attributes = new HashMap<>(); attributes.put(StandardClaimNames.SUB, "subject"); attributes.put("user", "rob"); - OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), - attributes, "user"); - when(this.oauth2UserService.loadUser(any())).thenReturn(Mono.just(oauth2User)); - + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes, + "user"); + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User)); assertThat(this.userService.loadUser(userRequest()).block().getUserInfo()).isNotNull(); } @@ -162,10 +152,9 @@ public class OidcReactiveOAuth2UserServiceTests { Map attributes = new HashMap<>(); attributes.put(StandardClaimNames.SUB, "subject"); attributes.put("user", "rob"); - OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), - attributes, "user"); - when(this.oauth2UserService.loadUser(any())).thenReturn(Mono.just(oauth2User)); - + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes, + "user"); + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User)); assertThat(this.userService.loadUser(userRequest()).block().getName()).isEqualTo("rob"); } @@ -174,30 +163,25 @@ public class OidcReactiveOAuth2UserServiceTests { Map attributes = new HashMap<>(); attributes.put(StandardClaimNames.SUB, "subject"); attributes.put("user", "rob"); - OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), - attributes, "user"); - when(this.oauth2UserService.loadUser(any())).thenReturn(Mono.just(oauth2User)); - + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes, + "user"); + given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User)); OidcUserRequest userRequest = userRequest(); - - Function, Map>> customClaimTypeConverterFactory = mock(Function.class); + Function, Map>> customClaimTypeConverterFactory = mock( + Function.class); this.userService.setClaimTypeConverterFactory(customClaimTypeConverterFactory); - - when(customClaimTypeConverterFactory.apply(same(userRequest.getClientRegistration()))) - .thenReturn(new ClaimTypeConverter(OidcReactiveOAuth2UserService.createDefaultClaimTypeConverters())); - + given(customClaimTypeConverterFactory.apply(same(userRequest.getClientRegistration()))) + .willReturn(new ClaimTypeConverter(OidcReactiveOAuth2UserService.createDefaultClaimTypeConverters())); this.userService.loadUser(userRequest).block().getUserInfo(); - verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration())); } @Test public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); - OidcUserRequest request = new OidcUserRequest( - clientRegistration().build(), scopes("message:read", "message:write"), idToken().build()); + OidcUserRequest request = new OidcUserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.scopes("message:read", "message:write"), TestOidcIdTokens.idToken().build()); OidcUser user = userService.loadUser(request).block(); - assertThat(user.getAuthorities()).hasSize(3); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); @@ -208,10 +192,9 @@ public class OidcReactiveOAuth2UserServiceTests { @Test public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); - OidcUserRequest request = new OidcUserRequest( - clientRegistration().build(), noScopes(), idToken().build()); + OidcUserRequest request = new OidcUserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.noScopes(), TestOidcIdTokens.idToken().build()); OidcUser user = userService.loadUser(request).block(); - assertThat(user.getAuthorities()).hasSize(1); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); @@ -220,4 +203,5 @@ public class OidcReactiveOAuth2UserServiceTests { private OidcUserRequest userRequest() { return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java index b1e2b5481b..af20f2c911 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; @@ -25,13 +26,13 @@ import org.junit.Before; import org.junit.Test; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OidcUserRequest}. @@ -39,18 +40,21 @@ import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idT * @author Joe Grandja */ public class OidcUserRequestTests { + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; + private OidcIdToken idToken; + private Map additionalParameters; @Before public void setUp() { - this.clientRegistration = clientRegistration().build(); - this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "access-token-1234", Instant.now(), Instant.now().plusSeconds(60), - new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); - this.idToken = idToken().authorizedParty(this.clientRegistration.getClientId()).build(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), + Instant.now().plusSeconds(60), new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); + this.idToken = TestOidcIdTokens.idToken().authorizedParty(this.clientRegistration.getClientId()).build(); this.additionalParameters = new HashMap<>(); this.additionalParameters.put("param1", "value1"); this.additionalParameters.put("param2", "value2"); @@ -58,30 +62,30 @@ public class OidcUserRequestTests { @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken)); } @Test public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken)); } @Test public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null)); } @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OidcUserRequest userRequest = new OidcUserRequest( - this.clientRegistration, this.accessToken, this.idToken, this.additionalParameters); - + OidcUserRequest userRequest = new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken, + this.additionalParameters); assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken); assertThat(userRequest.getIdToken()).isEqualTo(this.idToken); assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtilsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtilsTests.java index f5270813f7..045ac4a3f5 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtilsTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtilsTests.java @@ -36,15 +36,13 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 5.1 */ public class OidcUserRequestUtilsTests { + private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); OidcIdToken idToken = TestOidcIdTokens.idToken().build(); - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", - Instant.now(), - Instant.now().plus(Duration.ofDays(1)), - Collections.singleton("read:user")); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), + Instant.now().plus(Duration.ofDays(1)), Collections.singleton("read:user")); @Test public void shouldRetrieveUserInfoWhenEndpointDefinedAndScopesOverlapThenTrue() { @@ -54,25 +52,23 @@ public class OidcUserRequestUtilsTests { @Test public void shouldRetrieveUserInfoWhenNoUserInfoUriThenFalse() { this.registration.userInfoUri(null); - assertThat(OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest())).isFalse(); } @Test public void shouldRetrieveUserInfoWhenDifferentScopesThenFalse() { this.registration.scope("notintoken"); - assertThat(OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest())).isFalse(); } @Test public void shouldRetrieveUserInfoWhenNotAuthorizationCodeThenFalse() { this.registration.authorizationGrantType(AuthorizationGrantType.IMPLICIT); - assertThat(OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest())).isFalse(); } private OidcUserRequest userRequest() { return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 8a5a495d7a..3693d635be 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; @@ -39,29 +40,28 @@ import org.springframework.http.MediaType; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.same; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; -import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * Tests for {@link OidcUserService}. @@ -69,10 +69,15 @@ import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idT * @author Joe Grandja */ public class OidcUserServiceTests { + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private OidcIdToken idToken; + private OidcUserService userService = new OidcUserService(); + private MockWebServer server; @Rule @@ -82,18 +87,14 @@ public class OidcUserServiceTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.clientRegistrationBuilder = clientRegistration() - .userInfoUri(null) + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().userInfoUri(null) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) .userNameAttributeName(StandardClaimNames.SUB); - - this.accessToken = scopes(OidcScopes.OPENID, OidcScopes.PROFILE); - + this.accessToken = TestOAuth2AccessTokens.scopes(OidcScopes.OPENID, OidcScopes.PROFILE); Map idTokenClaims = new HashMap<>(); idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com"); idTokenClaims.put(IdTokenClaimNames.SUB, "subject1"); this.idToken = new OidcIdToken("access-token", Instant.MIN, Instant.MAX, idTokenClaims); - this.userService.setOauth2UserService(new DefaultOAuth2UserService()); } @@ -112,20 +113,17 @@ public class OidcUserServiceTests { @Test public void setOauth2UserServiceWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.userService.setOauth2UserService(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setOauth2UserService(null)); } @Test public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.userService.setClaimTypeConverterFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null)); } @Test public void setAccessibleScopesWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.userService.setAccessibleScopes(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setAccessibleScopes(null)); } @Test @@ -141,117 +139,105 @@ public class OidcUserServiceTests { @Test public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() { - OidcUser user = this.userService.loadUser( - new OidcUserRequest(this.clientRegistrationBuilder.build(), this.accessToken, this.idToken)); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(this.clientRegistrationBuilder.build(), this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNull(); } @Test public void loadUserWhenNonStandardScopesAuthorizedThenUserInfoEndpointNotRequested() { - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri("https://provider.com/user").build(); - this.accessToken = scopes("scope1", "scope2"); - - OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri("https://provider.com/user") + .build(); + this.accessToken = TestOAuth2AccessTokens.scopes("scope1", "scope2"); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNull(); } // gh-6886 @Test public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesMatchThenUserInfoEndpointRequested() { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - - this.accessToken = scopes("scope1", "scope2"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + this.accessToken = TestOAuth2AccessTokens.scopes("scope1", "scope2"); this.userService.setAccessibleScopes(Collections.singleton("scope2")); - - OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNotNull(); } // gh-6886 @Test public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesEmptyThenUserInfoEndpointRequested() { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - - this.accessToken = scopes("scope1", "scope2"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + this.accessToken = TestOAuth2AccessTokens.scopes("scope1", "scope2"); this.userService.setAccessibleScopes(Collections.emptySet()); - - OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNotNull(); } // gh-6886 @Test public void loadUserWhenStandardScopesAuthorizedThenUserInfoEndpointRequested() { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - - OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNotNull(); } @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - - OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getIdToken()).isNotNull(); assertThat(user.getUserInfo()).isNotNull(); assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6); @@ -263,7 +249,6 @@ public class OidcUserServiceTests { assertThat(user.getUserInfo().getFamilyName()).isEqualTo("last"); assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1"); assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(3); assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class); OidcUserAuthority userAuthority = (OidcUserAuthority) user.getAuthorities().iterator().next(); @@ -277,19 +262,16 @@ public class OidcUserServiceTests { public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_user_info_response")); - - String userInfoResponse = "{\n" + - " \"email\": \"full_name@provider.com\",\n" + - " \"name\": \"full name\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"email\": \"full_name@provider.com\",\n" + + " \"name\": \"full name\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userNameAttributeName(StandardClaimNames.EMAIL).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @@ -297,113 +279,92 @@ public class OidcUserServiceTests { public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_user_info_response")); - - String userInfoResponse = "{\n" + - " \"sub\": \"other-subject\"\n" + - "}\n"; + String userInfoResponse = "{\n" + " \"sub\": \"other-subject\"\n" + "}\n"; this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n"; -// "}\n"; // Make the JSON invalid/malformed + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); this.server.enqueue(new MockResponse().setResponseCode(500)); - - String userInfoUri = server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + String userInfoUri = this.server.url("/user").toString(); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "https://invalid-provider.com/user"; - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userNameAttributeName(StandardClaimNames.EMAIL).build(); - - OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); - + OidcUser user = this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getName()).isEqualTo("user1@example.com"); } // gh-5294 @Test public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) .isEqualTo(MediaType.APPLICATION_JSON_VALUE); @@ -412,47 +373,44 @@ public class OidcUserServiceTests { // gh-5500 @Test public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.FORM).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); @@ -463,40 +421,34 @@ public class OidcUserServiceTests { @Test public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { - String userInfoResponse = "{\n" + - " \"sub\": \"subject1\",\n" + - " \"name\": \"first last\",\n" + - " \"given_name\": \"first\",\n" + - " \"family_name\": \"last\",\n" + - " \"preferred_username\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .build(); - - Function, Map>> customClaimTypeConverterFactory = mock(Function.class); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); + Function, Map>> customClaimTypeConverterFactory = mock( + Function.class); this.userService.setClaimTypeConverterFactory(customClaimTypeConverterFactory); - - when(customClaimTypeConverterFactory.apply(same(clientRegistration))) - .thenReturn(new ClaimTypeConverter(OidcUserService.createDefaultClaimTypeConverters())); - + given(customClaimTypeConverterFactory.apply(same(clientRegistration))) + .willReturn(new ClaimTypeConverter(OidcUserService.createDefaultClaimTypeConverters())); this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); - verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } @Test public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { OidcUserService userService = new OidcUserService(); - OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), - scopes("message:read", "message:write"), idToken().build()); + OidcUserRequest request = new OidcUserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.scopes("message:read", "message:write"), TestOidcIdTokens.idToken().build()); OidcUser user = userService.loadUser(request); - assertThat(user.getAuthorities()).hasSize(3); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); @@ -507,18 +459,20 @@ public class OidcUserServiceTests { @Test public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { OidcUserService userService = new OidcUserService(); - OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), - noScopes(), idToken().build()); + OidcUserRequest request = new OidcUserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.noScopes(), TestOidcIdTokens.idToken().build()); OidcUser user = userService.loadUser(request); - assertThat(user.getAuthorities()).hasSize(1); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); } private MockResponse jsonResponse(String json) { + // @formatter:off return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setBody(json); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandlerTests.java index 59a7c08150..757ebd03b2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandlerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcClientInitiatedLogoutSuccessHandlerTests.java @@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.oidc.web.logout; import java.io.IOException; import java.net.URI; import java.util.Collections; + import javax.servlet.ServletException; import org.junit.Before; @@ -39,7 +40,7 @@ import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; import org.springframework.security.oauth2.core.user.TestOAuth2Users; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; /** @@ -47,14 +48,18 @@ import static org.mockito.Mockito.mock; */ @RunWith(MockitoJUnitRunner.class) public class OidcClientInitiatedLogoutSuccessHandlerTests { + + // @formatter:off ClientRegistration registration = TestClientRegistrations .clientRegistration() - .providerConfigurationMetadata( - Collections.singletonMap("end_session_endpoint", "https://endpoint")) + .providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://endpoint")) .build(); - ClientRegistrationRepository repository = new InMemoryClientRegistrationRepository(registration); + // @formatter:on + + ClientRegistrationRepository repository = new InMemoryClientRegistrationRepository(this.registration); MockHttpServletRequest request; + MockHttpServletResponse response; OidcClientInitiatedLogoutSuccessHandler handler; @@ -67,113 +72,80 @@ public class OidcClientInitiatedLogoutSuccessHandlerTests { } @Test - public void logoutWhenOidcRedirectUrlConfiguredThenRedirects() - throws IOException, ServletException { - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - + public void logoutWhenOidcRedirectUrlConfiguredThenRedirects() throws IOException, ServletException { + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); this.request.setUserPrincipal(token); this.handler.onLogoutSuccess(this.request, this.response, token); - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://endpoint?id_token_hint=id-token"); } @Test - public void logoutWhenNotOAuth2AuthenticationThenDefaults() - throws IOException, ServletException { + public void logoutWhenNotOAuth2AuthenticationThenDefaults() throws IOException, ServletException { Authentication token = mock(Authentication.class); - this.request.setUserPrincipal(token); this.handler.setDefaultTargetUrl("https://default"); this.handler.onLogoutSuccess(this.request, this.response, token); - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://default"); } @Test - public void logoutWhenNotOidcUserThenDefaults() - throws IOException, ServletException { - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOAuth2Users.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - + public void logoutWhenNotOidcUserThenDefaults() throws IOException, ServletException { + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOAuth2Users.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); this.request.setUserPrincipal(token); this.handler.setDefaultTargetUrl("https://default"); this.handler.onLogoutSuccess(this.request, this.response, token); - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://default"); } @Test - public void logoutWhenClientRegistrationHasNoEndSessionEndpointThenDefaults() - throws Exception { - + public void logoutWhenClientRegistrationHasNoEndSessionEndpointThenDefaults() throws Exception { ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); ClientRegistrationRepository repository = new InMemoryClientRegistrationRepository(registration); OidcClientInitiatedLogoutSuccessHandler handler = new OidcClientInitiatedLogoutSuccessHandler(repository); - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - registration.getRegistrationId()); - + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, registration.getRegistrationId()); this.request.setUserPrincipal(token); handler.setDefaultTargetUrl("https://default"); handler.onLogoutSuccess(this.request, this.response, token); - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://default"); } @Test - public void logoutWhenUsingPostLogoutRedirectUriThenIncludesItInRedirect() - throws IOException, ServletException { - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - + public void logoutWhenUsingPostLogoutRedirectUriThenIncludesItInRedirect() throws IOException, ServletException { + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); this.handler.setPostLogoutRedirectUri(URI.create("https://postlogout?encodedparam=value")); this.request.setUserPrincipal(token); this.handler.onLogoutSuccess(this.request, this.response, token); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://endpoint?" + - "id_token_hint=id-token&" + - "post_logout_redirect_uri=https://postlogout?encodedparam%3Dvalue"); + assertThat(this.response.getRedirectedUrl()).isEqualTo("https://endpoint?" + "id_token_hint=id-token&" + + "post_logout_redirect_uri=https://postlogout?encodedparam%3Dvalue"); } @Test public void logoutWhenUsingPostLogoutRedirectUriTemplateThenBuildsItForRedirect() throws IOException, ServletException { - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); this.handler.setPostLogoutRedirectUri("{baseUrl}"); this.request.setScheme("https"); this.request.setServerPort(443); this.request.setServerName("rp.example.org"); this.request.setUserPrincipal(token); this.handler.onLogoutSuccess(this.request, this.response, token); - - assertThat(this.response.getRedirectedUrl()).isEqualTo("https://endpoint?" + - "id_token_hint=id-token&" + - "post_logout_redirect_uri=https://rp.example.org"); + assertThat(this.response.getRedirectedUrl()).isEqualTo( + "https://endpoint?" + "id_token_hint=id-token&" + "post_logout_redirect_uri=https://rp.example.org"); } @Test public void setPostLogoutRedirectUriWhenGivenNullThenThrowsException() { - assertThatThrownBy(() -> this.handler.setPostLogoutRedirectUri((URI) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setPostLogoutRedirectUri((URI) null)); } @Test public void setPostLogoutRedirectUriTemplateWhenGivenNullThenThrowsException() { - assertThatThrownBy(() -> this.handler.setPostLogoutRedirectUri((String) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setPostLogoutRedirectUri((String) null)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java index 0dd7ad7209..5714c5f4c1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java @@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.oidc.web.server.logout; import java.io.IOException; import java.net.URI; import java.util.Collections; + import javax.servlet.ServletException; import org.junit.Before; @@ -41,22 +42,27 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link OidcClientInitiatedServerLogoutSuccessHandler} */ public class OidcClientInitiatedServerLogoutSuccessHandlerTests { + + // @formatter:off ClientRegistration registration = TestClientRegistrations .clientRegistration() - .providerConfigurationMetadata( - Collections.singletonMap("end_session_endpoint", "https://endpoint")) + .providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://endpoint")) .build(); - ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository(registration); + // @formatter:on + + ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository( + this.registration); ServerWebExchange exchange; + WebFilterChain chain; OidcClientInitiatedServerLogoutSuccessHandler handler; @@ -64,133 +70,98 @@ public class OidcClientInitiatedServerLogoutSuccessHandlerTests { @Before public void setup() { this.exchange = mock(ServerWebExchange.class); - when(this.exchange.getResponse()).thenReturn(new MockServerHttpResponse()); - when(this.exchange.getRequest()).thenReturn(MockServerHttpRequest.get("/").build()); + given(this.exchange.getResponse()).willReturn(new MockServerHttpResponse()); + given(this.exchange.getRequest()).willReturn(MockServerHttpRequest.get("/").build()); this.chain = mock(WebFilterChain.class); this.handler = new OidcClientInitiatedServerLogoutSuccessHandler(this.repository); } @Test public void logoutWhenOidcRedirectUrlConfiguredThenRedirects() { - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - - when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); - WebFilterExchange f = new WebFilterExchange(exchange, this.chain); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); + given(this.exchange.getPrincipal()).willReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(this.exchange, this.chain); this.handler.onLogoutSuccess(f, token).block(); - assertThat(redirectedUrl(this.exchange)).isEqualTo("https://endpoint?id_token_hint=id-token"); } @Test public void logoutWhenNotOAuth2AuthenticationThenDefaults() { Authentication token = mock(Authentication.class); - - when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); - WebFilterExchange f = new WebFilterExchange(exchange, this.chain); - + given(this.exchange.getPrincipal()).willReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(this.exchange, this.chain); this.handler.setLogoutSuccessUrl(URI.create("https://default")); this.handler.onLogoutSuccess(f, token).block(); - assertThat(redirectedUrl(this.exchange)).isEqualTo("https://default"); } @Test public void logoutWhenNotOidcUserThenDefaults() { - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOAuth2Users.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - - when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); - WebFilterExchange f = new WebFilterExchange(exchange, this.chain); - + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOAuth2Users.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); + given(this.exchange.getPrincipal()).willReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(this.exchange, this.chain); this.handler.setLogoutSuccessUrl(URI.create("https://default")); this.handler.onLogoutSuccess(f, token).block(); - assertThat(redirectedUrl(this.exchange)).isEqualTo("https://default"); } @Test public void logoutWhenClientRegistrationHasNoEndSessionEndpointThenDefaults() { - ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - ReactiveClientRegistrationRepository repository = - new InMemoryReactiveClientRegistrationRepository(registration); - OidcClientInitiatedServerLogoutSuccessHandler handler = - new OidcClientInitiatedServerLogoutSuccessHandler(repository); - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - registration.getRegistrationId()); - - when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); - WebFilterExchange f = new WebFilterExchange(exchange, this.chain); - + ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository( + registration); + OidcClientInitiatedServerLogoutSuccessHandler handler = new OidcClientInitiatedServerLogoutSuccessHandler( + repository); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, registration.getRegistrationId()); + given(this.exchange.getPrincipal()).willReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(this.exchange, this.chain); handler.setLogoutSuccessUrl(URI.create("https://default")); handler.onLogoutSuccess(f, token).block(); - assertThat(redirectedUrl(this.exchange)).isEqualTo("https://default"); } @Test public void logoutWhenUsingPostLogoutRedirectUriThenIncludesItInRedirect() { - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - - when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); - WebFilterExchange f = new WebFilterExchange(exchange, this.chain); - + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); + given(this.exchange.getPrincipal()).willReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(this.exchange, this.chain); this.handler.setPostLogoutRedirectUri(URI.create("https://postlogout?encodedparam=value")); this.handler.onLogoutSuccess(f, token).block(); - - assertThat(redirectedUrl(this.exchange)) - .isEqualTo("https://endpoint?" + - "id_token_hint=id-token&" + - "post_logout_redirect_uri=https://postlogout?encodedparam%3Dvalue"); + assertThat(redirectedUrl(this.exchange)).isEqualTo("https://endpoint?" + "id_token_hint=id-token&" + + "post_logout_redirect_uri=https://postlogout?encodedparam%3Dvalue"); } @Test public void logoutWhenUsingPostLogoutRedirectUriTemplateThenBuildsItForRedirect() throws IOException, ServletException { - - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( - TestOidcUsers.create(), - AuthorityUtils.NO_AUTHORITIES, - this.registration.getRegistrationId()); - when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); + given(this.exchange.getPrincipal()).willReturn(Mono.just(token)); MockServerHttpRequest request = MockServerHttpRequest.get("https://rp.example.org/").build(); - when(this.exchange.getRequest()).thenReturn(request); - WebFilterExchange f = new WebFilterExchange(exchange, this.chain); - + given(this.exchange.getRequest()).willReturn(request); + WebFilterExchange f = new WebFilterExchange(this.exchange, this.chain); this.handler.setPostLogoutRedirectUri("{baseUrl}"); this.handler.onLogoutSuccess(f, token).block(); - - assertThat(redirectedUrl(this.exchange)) - .isEqualTo("https://endpoint?" + - "id_token_hint=id-token&" + - "post_logout_redirect_uri=https://rp.example.org"); + assertThat(redirectedUrl(this.exchange)).isEqualTo( + "https://endpoint?" + "id_token_hint=id-token&" + "post_logout_redirect_uri=https://rp.example.org"); } @Test public void setPostLogoutRedirectUriWhenGivenNullThenThrowsException() { - assertThatThrownBy(() -> this.handler.setPostLogoutRedirectUri((URI) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setPostLogoutRedirectUri((URI) null)); } @Test public void setPostLogoutRedirectUriTemplateWhenGivenNullThenThrowsException() { - assertThatThrownBy(() -> this.handler.setPostLogoutRedirectUri((String) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setPostLogoutRedirectUri((String) null)); } private String redirectedUrl(ServerWebExchange exchange) { return exchange.getResponse().getHeaders().getFirst("Location"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java index f934e1fb34..fba7dfd338 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.registration; import java.util.Collections; @@ -29,9 +30,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.springframework.security.oauth2.client.registration.ClientRegistration.withClientRegistration; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link ClientRegistration}. @@ -39,19 +38,30 @@ import static org.springframework.security.oauth2.client.registration.TestClient * @author Joe Grandja */ public class ClientRegistrationTests { + private static final String REGISTRATION_ID = "registration-1"; + private static final String CLIENT_ID = "client-1"; + private static final String CLIENT_SECRET = "secret"; + private static final String REDIRECT_URI = "https://example.com"; - private static final Set SCOPES = Collections.unmodifiableSet( - Stream.of("openid", "profile", "email").collect(Collectors.toSet())); + + private static final Set SCOPES = Collections + .unmodifiableSet(Stream.of("openid", "profile", "email").collect(Collectors.toSet())); + private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorization"; + private static final String TOKEN_URI = "https://provider.com/oauth2/token"; + private static final String JWK_SET_URI = "https://provider.com/oauth2/keys"; + private static final String ISSUER_URI = "https://provider.com"; + private static final String CLIENT_NAME = "Client 1"; - private static final Map PROVIDER_CONFIGURATION_METADATA = - Collections.unmodifiableMap(createProviderConfigurationMetadata()); + + private static final Map PROVIDER_CONFIGURATION_METADATA = Collections + .unmodifiableMap(createProviderConfigurationMetadata()); private static Map createProviderConfigurationMetadata() { Map configurationMetadata = new LinkedHashMap<>(); @@ -62,39 +72,42 @@ public class ClientRegistrationTests { @Test(expected = IllegalArgumentException.class) public void buildWhenAuthorizationGrantTypeIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(null) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(null) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test public void buildWhenAuthorizationCodeGrantAllAttributesProvidedThenAllAttributesAreSet() { + // @formatter:off ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .issuerUri(ISSUER_URI) - .providerConfigurationMetadata(PROVIDER_CONFIGURATION_METADATA) - .clientName(CLIENT_NAME) - .build(); - + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .issuerUri(ISSUER_URI) + .providerConfigurationMetadata(PROVIDER_CONFIGURATION_METADATA) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); @@ -104,49 +117,56 @@ public class ClientRegistrationTests { assertThat(registration.getScopes()).isEqualTo(SCOPES); assertThat(registration.getProviderDetails().getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI); assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI); - assertThat(registration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()).isEqualTo(AuthenticationMethod.FORM); + assertThat(registration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()) + .isEqualTo(AuthenticationMethod.FORM); assertThat(registration.getProviderDetails().getJwkSetUri()).isEqualTo(JWK_SET_URI); assertThat(registration.getProviderDetails().getIssuerUri()).isEqualTo(ISSUER_URI); - assertThat(registration.getProviderDetails().getConfigurationMetadata()).isEqualTo(PROVIDER_CONFIGURATION_METADATA); + assertThat(registration.getProviderDetails().getConfigurationMetadata()) + .isEqualTo(PROVIDER_CONFIGURATION_METADATA); assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME); } @Test(expected = IllegalArgumentException.class) public void buildWhenAuthorizationCodeGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(null) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenAuthorizationCodeGrantClientIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(null) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(null) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test public void buildWhenAuthorizationCodeGrantClientSecretIsNullThenDefaultToEmpty() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(null) @@ -160,11 +180,13 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on assertThat(clientRegistration.getClientSecret()).isEqualTo(""); } @Test public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -177,11 +199,13 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); } @Test public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedAndClientSecretNullThenDefaultToNone() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(null) @@ -194,11 +218,13 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE); } @Test public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedAndClientSecretBlankThenDefaultToNone() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(" ") @@ -211,81 +237,91 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE); assertThat(clientRegistration.getClientSecret()).isEqualTo(""); } @Test(expected = IllegalArgumentException.class) public void buildWhenAuthorizationCodeGrantRedirectUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(null) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(null) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } // gh-5494 @Test public void buildWhenAuthorizationCodeGrantScopeIsNullThenScopeNotRequired() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope((String[]) null) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope((String[]) null) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenAuthorizationCodeGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(null) - .tokenUri(TOKEN_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(null) + .tokenUri(TOKEN_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenAuthorizationCodeGrantTokenUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .tokenUri(null) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .jwkSetUri(JWK_SET_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .tokenUri(null) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .jwkSetUri(JWK_SET_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test public void buildWhenAuthorizationCodeGrantClientNameNotProvidedThenDefaultToRegistrationId() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -298,28 +334,32 @@ public class ClientRegistrationTests { .userInfoAuthenticationMethod(AuthenticationMethod.FORM) .jwkSetUri(JWK_SET_URI) .build(); + // @formatter:on assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId()); } @Test public void buildWhenAuthorizationCodeGrantScopeDoesNotContainOpenidThenJwkSetUriNotRequired() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUri(REDIRECT_URI) - .scope("scope1") - .authorizationUri(AUTHORIZATION_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .tokenUri(TOKEN_URI) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri(REDIRECT_URI) + .scope("scope1") + .authorizationUri(AUTHORIZATION_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .tokenUri(TOKEN_URI) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } // gh-5494 @Test public void buildWhenAuthorizationCodeGrantScopeIsNullThenJwkSetUriNotRequired() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -330,10 +370,12 @@ public class ClientRegistrationTests { .tokenUri(TOKEN_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on } @Test public void buildWhenAuthorizationCodeGrantProviderConfigurationMetadataIsNullThenDefaultToEmpty() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -348,12 +390,14 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isNotNull(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty(); } @Test public void buildWhenAuthorizationCodeGrantProviderConfigurationMetadataEmptyThenIsEmpty() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -368,100 +412,114 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); + // @formatter:on assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isNotNull(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty(); } @Test public void buildWhenImplicitGrantAllAttributesProvidedThenAllAttributesAreSet() { + // @formatter:off ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .clientName(CLIENT_NAME) - .build(); - + .clientId(CLIENT_ID) + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT); assertThat(registration.getRedirectUri()).isEqualTo(REDIRECT_URI); assertThat(registration.getScopes()).isEqualTo(SCOPES); assertThat(registration.getProviderDetails().getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI); - assertThat(registration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()).isEqualTo(AuthenticationMethod.FORM); + assertThat(registration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod()) + .isEqualTo(AuthenticationMethod.FORM); assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME); } @Test(expected = IllegalArgumentException.class) public void buildWhenImplicitGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(null) - .clientId(CLIENT_ID) - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenImplicitGrantClientIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(null) - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .clientName(CLIENT_NAME) - .build(); + .clientId(null) + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenImplicitGrantRedirectUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri(null) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(AUTHORIZATION_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri(null) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(AUTHORIZATION_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } // gh-5494 @Test public void buildWhenImplicitGrantScopeIsNullThenScopeNotRequired() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri(REDIRECT_URI) - .scope((String[]) null) - .authorizationUri(AUTHORIZATION_URI) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri(REDIRECT_URI) + .scope((String[]) null) + .authorizationUri(AUTHORIZATION_URI) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenImplicitGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri(REDIRECT_URI) - .scope(SCOPES.toArray(new String[0])) - .authorizationUri(null) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .clientName(CLIENT_NAME) - .build(); + .clientId(CLIENT_ID) + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri(REDIRECT_URI) + .scope(SCOPES.toArray(new String[0])) + .authorizationUri(null) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .clientName(CLIENT_NAME) + .build(); + // @formatter:on } @Test public void buildWhenImplicitGrantClientNameNotProvidedThenDefaultToRegistrationId() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .authorizationGrantType(AuthorizationGrantType.IMPLICIT) @@ -470,12 +528,14 @@ public class ClientRegistrationTests { .authorizationUri(AUTHORIZATION_URI) .userInfoAuthenticationMethod(AuthenticationMethod.FORM) .build(); + // @formatter:on assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId()); } @Test public void buildWhenOverrideRegistrationIdThenOverridden() { String overriddenId = "override"; + // @formatter:off ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .registrationId(overriddenId) .clientId(CLIENT_ID) @@ -489,12 +549,13 @@ public class ClientRegistrationTests { .jwkSetUri(JWK_SET_URI) .clientName(CLIENT_NAME) .build(); - + // @formatter:on assertThat(registration.getRegistrationId()).isEqualTo(overriddenId); } @Test public void buildWhenClientCredentialsGrantAllAttributesProvidedThenAllAttributesAreSet() { + // @formatter:off ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -504,7 +565,7 @@ public class ClientRegistrationTests { .tokenUri(TOKEN_URI) .clientName(CLIENT_NAME) .build(); - + // @formatter:on assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); @@ -517,32 +578,22 @@ public class ClientRegistrationTests { @Test public void buildWhenClientCredentialsGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - ClientRegistration.withRegistrationId(null) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .tokenUri(TOKEN_URI) - .build() - ).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistration.withRegistrationId(null).clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).tokenUri(TOKEN_URI).build()); } @Test public void buildWhenClientCredentialsGrantClientIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(null) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .tokenUri(TOKEN_URI) - .build() - ).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> ClientRegistration.withRegistrationId(REGISTRATION_ID) + .clientId(null).clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).tokenUri(TOKEN_URI).build()); } @Test public void buildWhenClientCredentialsGrantClientSecretIsNullThenDefaultToEmpty() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(null) @@ -550,54 +601,47 @@ public class ClientRegistrationTests { .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .tokenUri(TOKEN_URI) .build(); + // @formatter:on assertThat(clientRegistration.getClientSecret()).isEqualTo(""); } @Test public void buildWhenClientCredentialsGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .tokenUri(TOKEN_URI) .build(); + // @formatter:on assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); } @Test public void buildWhenClientCredentialsGrantTokenUriIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(CLIENT_ID) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .tokenUri(null) - .build() - ).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).tokenUri(null).build()); } // gh-6256 @Test public void buildWhenScopesContainASpaceThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - TestClientRegistrations.clientCredentials() - .scope("openid profile email") - .build() - ).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> TestClientRegistrations.clientCredentials().scope("openid profile email").build()); } @Test public void buildWhenScopesContainAnInvalidCharacterThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - TestClientRegistrations.clientCredentials() - .scope("an\"invalid\"scope") - .build() - ).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> TestClientRegistrations.clientCredentials().scope("an\"invalid\"scope").build()); } @Test public void buildWhenPasswordGrantAllAttributesProvidedThenAllAttributesAreSet() { + // @formatter:off ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) @@ -607,7 +651,7 @@ public class ClientRegistrationTests { .tokenUri(TOKEN_URI) .clientName(CLIENT_NAME) .build(); - + // @formatter:on assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); @@ -620,32 +664,37 @@ public class ClientRegistrationTests { @Test public void buildWhenPasswordGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - ClientRegistration.withRegistrationId(null) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistration.withRegistrationId(null) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .authorizationGrantType(AuthorizationGrantType.PASSWORD) .tokenUri(TOKEN_URI) .build() - ).isInstanceOf(IllegalArgumentException.class); + ); + // @formatter:on } @Test public void buildWhenPasswordGrantClientIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - ClientRegistration.withRegistrationId(REGISTRATION_ID) - .clientId(null) - .clientSecret(CLIENT_SECRET) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .tokenUri(TOKEN_URI) - .build() - ).isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException().isThrownBy(() -> ClientRegistration + .withRegistrationId(REGISTRATION_ID) + .clientId(null) + .clientSecret(CLIENT_SECRET) + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.PASSWORD) + .tokenUri(TOKEN_URI) + .build() + ); + // @formatter:on } @Test public void buildWhenPasswordGrantClientSecretIsNullThenDefaultToEmpty() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(null) @@ -653,37 +702,44 @@ public class ClientRegistrationTests { .authorizationGrantType(AuthorizationGrantType.PASSWORD) .tokenUri(TOKEN_URI) .build(); + // @formatter:on assertThat(clientRegistration.getClientSecret()).isEqualTo(""); } @Test public void buildWhenPasswordGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) .authorizationGrantType(AuthorizationGrantType.PASSWORD) .tokenUri(TOKEN_URI) .build(); + // @formatter:on assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); } @Test public void buildWhenPasswordGrantTokenUriIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - ClientRegistration.withRegistrationId(REGISTRATION_ID) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .authorizationGrantType(AuthorizationGrantType.PASSWORD) .tokenUri(null) .build() - ).isInstanceOf(IllegalArgumentException.class); + ); + // @formatter:on } @Test public void buildWhenCustomGrantAllAttributesProvidedThenAllAttributesAreSet() { AuthorizationGrantType customGrantType = new AuthorizationGrantType("CUSTOM"); - ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) + // @formatter:off + ClientRegistration registration = ClientRegistration + .withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .clientSecret(CLIENT_SECRET) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) @@ -692,7 +748,7 @@ public class ClientRegistrationTests { .tokenUri(TOKEN_URI) .clientName(CLIENT_NAME) .build(); - + // @formatter:on assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); @@ -705,8 +761,8 @@ public class ClientRegistrationTests { @Test public void buildWhenClientRegistrationProvidedThenMakesACopy() { - ClientRegistration clientRegistration = clientRegistration().build(); - ClientRegistration updated = withClientRegistration(clientRegistration).build(); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + ClientRegistration updated = ClientRegistration.withClientRegistration(clientRegistration).build(); assertThat(clientRegistration.getScopes()).isEqualTo(updated.getScopes()); assertThat(clientRegistration.getScopes()).isNotSameAs(updated.getScopes()); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()) @@ -717,65 +773,60 @@ public class ClientRegistrationTests { @Test public void buildWhenClientRegistrationProvidedThenEachPropertyMatches() { - ClientRegistration clientRegistration = clientRegistration().build(); - ClientRegistration updated = withClientRegistration(clientRegistration).build(); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + ClientRegistration updated = ClientRegistration.withClientRegistration(clientRegistration).build(); assertThat(clientRegistration.getRegistrationId()).isEqualTo(updated.getRegistrationId()); assertThat(clientRegistration.getClientId()).isEqualTo(updated.getClientId()); assertThat(clientRegistration.getClientSecret()).isEqualTo(updated.getClientSecret()); assertThat(clientRegistration.getClientAuthenticationMethod()) .isEqualTo(updated.getClientAuthenticationMethod()); - assertThat(clientRegistration.getAuthorizationGrantType()) - .isEqualTo(updated.getAuthorizationGrantType()); - assertThat(clientRegistration.getRedirectUri()) - .isEqualTo(updated.getRedirectUri()); + assertThat(clientRegistration.getAuthorizationGrantType()).isEqualTo(updated.getAuthorizationGrantType()); + assertThat(clientRegistration.getRedirectUri()).isEqualTo(updated.getRedirectUri()); assertThat(clientRegistration.getScopes()).isEqualTo(updated.getScopes()); - ClientRegistration.ProviderDetails providerDetails = clientRegistration.getProviderDetails(); ClientRegistration.ProviderDetails updatedProviderDetails = updated.getProviderDetails(); - assertThat(providerDetails.getAuthorizationUri()) - .isEqualTo(updatedProviderDetails.getAuthorizationUri()); - assertThat(providerDetails.getTokenUri()) - .isEqualTo(updatedProviderDetails.getTokenUri()); - + assertThat(providerDetails.getAuthorizationUri()).isEqualTo(updatedProviderDetails.getAuthorizationUri()); + assertThat(providerDetails.getTokenUri()).isEqualTo(updatedProviderDetails.getTokenUri()); ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint = providerDetails.getUserInfoEndpoint(); - ClientRegistration.ProviderDetails.UserInfoEndpoint updatedUserInfoEndpoint = updatedProviderDetails.getUserInfoEndpoint(); + ClientRegistration.ProviderDetails.UserInfoEndpoint updatedUserInfoEndpoint = updatedProviderDetails + .getUserInfoEndpoint(); assertThat(userInfoEndpoint.getUri()).isEqualTo(updatedUserInfoEndpoint.getUri()); assertThat(userInfoEndpoint.getAuthenticationMethod()) .isEqualTo(updatedUserInfoEndpoint.getAuthenticationMethod()); assertThat(userInfoEndpoint.getUserNameAttributeName()) .isEqualTo(updatedUserInfoEndpoint.getUserNameAttributeName()); - assertThat(providerDetails.getJwkSetUri()).isEqualTo(updatedProviderDetails.getJwkSetUri()); assertThat(providerDetails.getIssuerUri()).isEqualTo(updatedProviderDetails.getIssuerUri()); assertThat(providerDetails.getConfigurationMetadata()) .isEqualTo(updatedProviderDetails.getConfigurationMetadata()); - assertThat(clientRegistration.getClientName()).isEqualTo(updated.getClientName()); } @Test public void buildWhenClientRegistrationValuesOverriddenThenPropagated() { - ClientRegistration clientRegistration = clientRegistration().build(); - ClientRegistration updated = withClientRegistration(clientRegistration) + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + // @formatter:off + ClientRegistration updated = ClientRegistration.withClientRegistration(clientRegistration) .clientSecret("a-new-secret") .scope("a-new-scope") .providerConfigurationMetadata(Collections.singletonMap("a-new-config", "a-new-value")) .build(); - + // @formatter:on assertThat(clientRegistration.getClientSecret()).isNotEqualTo(updated.getClientSecret()); assertThat(updated.getClientSecret()).isEqualTo("a-new-secret"); assertThat(clientRegistration.getScopes()).doesNotContain("a-new-scope"); assertThat(updated.getScopes()).containsExactly("a-new-scope"); - assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()) - .doesNotContainKey("a-new-config").doesNotContainValue("a-new-value"); - assertThat(updated.getProviderDetails().getConfigurationMetadata()) - .containsOnlyKeys("a-new-config").containsValue("a-new-value"); + assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).doesNotContainKey("a-new-config") + .doesNotContainValue("a-new-value"); + assertThat(updated.getProviderDetails().getConfigurationMetadata()).containsOnlyKeys("a-new-config") + .containsValue("a-new-value"); } // gh-8903 @Test public void buildWhenCustomClientAuthenticationMethodProvidedThenSet() { ClientAuthenticationMethod clientAuthenticationMethod = new ClientAuthenticationMethod("tls_client_auth"); + // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) .clientId(CLIENT_ID) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) @@ -784,6 +835,8 @@ public class ClientRegistrationTests { .authorizationUri(AUTHORIZATION_URI) .tokenUri(TOKEN_URI) .build(); + // @formatter:on assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(clientAuthenticationMethod); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTests.java similarity index 73% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTests.java index 9e52579c8d..58227755fb 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTest.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationsTests.java @@ -35,20 +35,21 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * @author Rob Winch * @author Rafiullah Hamedy * @since 5.1 */ -public class ClientRegistrationsTest { +public class ClientRegistrationsTests { /** * Contains all optional parameters that are found in ClientRegistration */ - private static final String DEFAULT_RESPONSE = - "{\n" + // @formatter:off + private static final String DEFAULT_RESPONSE = "{\n" + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" + " \"claims_supported\": [\n" + " \"aud\", \n" @@ -101,6 +102,7 @@ public class ClientRegistrationsTest { + " ], \n" + " \"userinfo_endpoint\": \"https://example.com/oauth2/v3/userinfo\"\n" + "}"; + // @formatter:on private MockWebServer server; @@ -114,7 +116,8 @@ public class ClientRegistrationsTest { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.response = this.mapper.readValue(DEFAULT_RESPONSE, new TypeReference>(){}); + this.response = this.mapper.readValue(DEFAULT_RESPONSE, new TypeReference>() { + }); } @After @@ -132,10 +135,11 @@ public class ClientRegistrationsTest { /** * - * Test compatibility with OpenID v1 discovery endpoint by making a - * OpenID Provider - * Configuration Request as highlighted - * Compatibility Notes of RFC 8414 specification. + * Test compatibility with OpenID v1 discovery endpoint by making a OpenID + * Provider Configuration Request as highlighted + * Compatibility Notes of + * RFC 8414 specification. */ @Test public void issuerWhenOidcFallbackAllInformationThenSuccess() throws Exception { @@ -152,8 +156,7 @@ public class ClientRegistrationsTest { assertIssuerMetadata(registration, provider); } - private void assertIssuerMetadata(ClientRegistration registration, - ClientRegistration.ProviderDetails provider) { + private void assertIssuerMetadata(ClientRegistration registration, ClientRegistration.ProviderDetails provider) { assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(registration.getRegistrationId()).isEqualTo(this.server.getHostName()); @@ -166,25 +169,24 @@ public class ClientRegistrationsTest { assertThat(provider.getConfigurationMetadata()).containsKeys("authorization_endpoint", "claims_supported", "code_challenge_methods_supported", "id_token_signing_alg_values_supported", "issuer", "jwks_uri", "response_types_supported", "revocation_endpoint", "scopes_supported", "subject_types_supported", - "grant_types_supported", "token_endpoint", "token_endpoint_auth_methods_supported", "userinfo_endpoint"); + "grant_types_supported", "token_endpoint", "token_endpoint_auth_methods_supported", + "userinfo_endpoint"); } // gh-7512 @Test public void issuerWhenResponseMissingJwksUriThenThrowsIllegalArgumentException() throws Exception { this.response.remove("jwks_uri"); - assertThatThrownBy(() -> registration("").build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The public JWK set URI must not be null"); + assertThatIllegalArgumentException().isThrownBy(() -> registration("").build()) + .withMessageContaining("The public JWK set URI must not be null"); } // gh-7512 @Test public void issuerWhenOidcFallbackResponseMissingJwksUriThenThrowsIllegalArgumentException() throws Exception { this.response.remove("jwks_uri"); - assertThatThrownBy(() -> registrationOidcFallback("issuer1", null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The public JWK set URI must not be null"); + assertThatIllegalArgumentException().isThrownBy(() -> registrationOidcFallback("issuer1", null).build()) + .withMessageContaining("The public JWK set URI must not be null"); } // gh-7512 @@ -225,232 +227,233 @@ public class ClientRegistrationsTest { @Test public void issuerWhenGrantTypesSupportedNullThenDefaulted() throws Exception { this.response.remove("grant_types_supported"); - ClientRegistration registration = registration("").build(); - assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); } @Test public void issuerWhenOAuth2GrantTypesSupportedNullThenDefaulted() throws Exception { this.response.remove("grant_types_supported"); - ClientRegistration registration = registrationOAuth2("", null).build(); - assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); } /** - * We currently only support authorization_code, so verify we have a meaningful error until we add support. + * We currently only support authorization_code, so verify we have a meaningful error + * until we add support. */ @Test public void issuerWhenGrantTypesSupportedInvalidThenException() { this.response.put("grant_types_supported", Arrays.asList("implicit")); - - assertThatThrownBy(() -> registration("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + this.issuer + "\" returned a configuration of [implicit]"); + assertThatIllegalArgumentException().isThrownBy(() -> registration("")) + .withMessageContaining("Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + + this.issuer + "\" returned a configuration of [implicit]"); } @Test public void issuerWhenOAuth2GrantTypesSupportedInvalidThenException() { this.response.put("grant_types_supported", Arrays.asList("implicit")); - - assertThatThrownBy(() -> registrationOAuth2("", null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + this.issuer + "\" returned a configuration of [implicit]"); + assertThatIllegalArgumentException().isThrownBy(() -> registrationOAuth2("", null)) + .withMessageContaining("Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + + this.issuer + "\" returned a configuration of [implicit]"); } @Test public void issuerWhenTokenEndpointAuthMethodsNullThenDefaulted() throws Exception { this.response.remove("token_endpoint_auth_methods_supported"); - ClientRegistration registration = registration("").build(); - assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); } @Test public void issuerWhenOAuth2TokenEndpointAuthMethodsNullThenDefaulted() throws Exception { this.response.remove("token_endpoint_auth_methods_supported"); - ClientRegistration registration = registrationOAuth2("", null).build(); - assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); } @Test public void issuerWhenTokenEndpointAuthMethodsPostThenMethodIsPost() throws Exception { this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("client_secret_post")); - ClientRegistration registration = registration("").build(); - assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.POST); } @Test public void issuerWhenOAuth2TokenEndpointAuthMethodsPostThenMethodIsPost() throws Exception { this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("client_secret_post")); - ClientRegistration registration = registrationOAuth2("", null).build(); - assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.POST); } @Test public void issuerWhenTokenEndpointAuthMethodsNoneThenMethodIsNone() throws Exception { this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("none")); - ClientRegistration registration = registration("").build(); - assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE); } @Test public void issuerWhenOAuth2TokenEndpointAuthMethodsNoneThenMethodIsNone() throws Exception { this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("none")); - ClientRegistration registration = registrationOAuth2("", null).build(); - assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE); } /** - * We currently only support client_secret_basic, so verify we have a meaningful error until we add support. + * We currently only support client_secret_basic, so verify we have a meaningful error + * until we add support. */ @Test public void issuerWhenTokenEndpointAuthMethodsInvalidThenException() { this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("tls_client_auth")); - - assertThatThrownBy(() -> registration("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer + "\" returned a configuration of [tls_client_auth]"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> registration("")) + .withMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and " + + "ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer + + "\" returned a configuration of [tls_client_auth]"); + // @formatter:on } @Test public void issuerWhenOAuth2TokenEndpointAuthMethodsInvalidThenException() { this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("tls_client_auth")); - - assertThatThrownBy(() -> registrationOAuth2("", null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer + "\" returned a configuration of [tls_client_auth]"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> registrationOAuth2("", null)) + .withMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and " + + "ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer + + "\" returned a configuration of [tls_client_auth]"); + // @formatter:on } @Test public void issuerWhenOAuth2EmptyStringThenMeaningfulErrorMessage() { - assertThatThrownBy(() -> ClientRegistrations.fromIssuerLocation("")) - .hasMessageContaining("issuer cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistrations.fromIssuerLocation("")) + .withMessageContaining("issuer cannot be empty"); + // @formatter:on } @Test public void issuerWhenEmptyStringThenMeaningfulErrorMessage() { - assertThatThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation("")) - .hasMessageContaining("issuer cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation("")) + .withMessageContaining("issuer cannot be empty"); + // @formatter:on } @Test - public void issuerWhenOpenIdConfigurationDoesNotMatchThenMeaningfulErrorMessage() throws Exception { + public void issuerWhenOpenIdConfigurationDoesNotMatchThenMeaningfulErrorMessage() throws Exception { this.issuer = createIssuerFromServer(""); String body = this.mapper.writeValueAsString(this.response); - MockResponse mockResponse = new MockResponse() - .setBody(body) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, + MediaType.APPLICATION_JSON_VALUE); this.server.enqueue(mockResponse); - assertThatThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation(this.issuer)) - .hasMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata did not match the requested issuer \"" + this.issuer + "\""); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation(this.issuer)) + .withMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata did " + + "not match the requested issuer \"" + this.issuer + "\""); + // @formatter:on } @Test - public void issuerWhenOAuth2ConfigurationDoesNotMatchThenMeaningfulErrorMessage() throws Exception { + public void issuerWhenOAuth2ConfigurationDoesNotMatchThenMeaningfulErrorMessage() throws Exception { this.issuer = createIssuerFromServer(""); String body = this.mapper.writeValueAsString(this.response); - MockResponse mockResponse = new MockResponse() - .setBody(body) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, + MediaType.APPLICATION_JSON_VALUE); this.server.enqueue(mockResponse); - assertThatThrownBy(() -> ClientRegistrations.fromIssuerLocation(this.issuer)) - .hasMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata did not match the requested issuer \"" + this.issuer + "\""); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> ClientRegistrations.fromIssuerLocation(this.issuer)) + .withMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata " + + "did not match the requested issuer \"" + this.issuer + "\""); + // @formatter:on } private ClientRegistration.Builder registration(String path) throws Exception { this.issuer = createIssuerFromServer(path); this.response.put("issuer", this.issuer); String body = this.mapper.writeValueAsString(this.response); + // @formatter:off MockResponse mockResponse = new MockResponse() .setBody(body) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + .setHeader(HttpHeaders.CONTENT_TYPE, + MediaType.APPLICATION_JSON_VALUE); this.server.enqueue(mockResponse); - return ClientRegistrations.fromOidcIssuerLocation(this.issuer) - .clientId("client-id") - .clientSecret("client-secret"); + .clientId("client-id") + .clientSecret("client-secret"); + // @formatter:on } private ClientRegistration.Builder registrationOAuth2(String path, String body) throws Exception { this.issuer = createIssuerFromServer(path); this.response.put("issuer", this.issuer); this.issuer = this.server.url(path).toString(); - final String responseBody = body != null ? body : this.mapper.writeValueAsString(this.response); - + final String responseBody = (body != null) ? body : this.mapper.writeValueAsString(this.response); final Dispatcher dispatcher = new Dispatcher() { @Override public MockResponse dispatch(RecordedRequest request) { - switch(request.getPath()) { - case "/.well-known/oauth-authorization-server/issuer1": - case "/.well-known/oauth-authorization-server/": - return buildSuccessMockResponse(responseBody); + switch (request.getPath()) { + case "/.well-known/oauth-authorization-server/issuer1": + case "/.well-known/oauth-authorization-server/": + return buildSuccessMockResponse(responseBody); } return new MockResponse().setResponseCode(404); } }; - this.server.setDispatcher(dispatcher); - + // @formatter:off return ClientRegistrations.fromIssuerLocation(this.issuer) .clientId("client-id") .clientSecret("client-secret"); + // @formatter:on } - private String createIssuerFromServer(String path) { return this.server.url(path).toString(); } /** - * Simulates a situation when the ClientRegistration is used with a legacy application where the OIDC - * Discovery Endpoint is "/issuer1/.well-known/openid-configuration" instead of - * "/.well-known/openid-configuration/issuer1" in which case the first attempt results in HTTP 404 and - * the subsequent call results in 200 OK. + * Simulates a situation when the ClientRegistration is used with a legacy application + * where the OIDC Discovery Endpoint is "/issuer1/.well-known/openid-configuration" + * instead of "/.well-known/openid-configuration/issuer1" in which case the first + * attempt results in HTTP 404 and the subsequent call results in 200 OK. * - * @see Section 5 for more details. + * @see Section 5 for more + * details. */ private ClientRegistration.Builder registrationOidcFallback(String path, String body) throws Exception { this.issuer = createIssuerFromServer(path); this.response.put("issuer", this.issuer); - - String responseBody = body != null ? body : this.mapper.writeValueAsString(this.response); - + String responseBody = (body != null) ? body : this.mapper.writeValueAsString(this.response); final Dispatcher dispatcher = new Dispatcher() { @Override public MockResponse dispatch(RecordedRequest request) { - switch(request.getPath()) { - case "/issuer1/.well-known/openid-configuration": - case "/.well-known/openid-configuration/": - return buildSuccessMockResponse(responseBody); + switch (request.getPath()) { + case "/issuer1/.well-known/openid-configuration": + case "/.well-known/openid-configuration/": + return buildSuccessMockResponse(responseBody); } return new MockResponse().setResponseCode(404); } }; this.server.setDispatcher(dispatcher); - - return ClientRegistrations.fromIssuerLocation(this.issuer) - .clientId("client-id") - .clientSecret("client-secret"); + return ClientRegistrations.fromIssuerLocation(this.issuer).clientId("client-id").clientSecret("client-secret"); } private MockResponse buildSuccessMockResponse(String body) { + // @formatter:off return new MockResponse().setResponseCode(200) .setBody(body) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java index 0778afafec..e222dd63fd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java @@ -16,14 +16,14 @@ package org.springframework.security.oauth2.client.registration; -import org.junit.Test; - import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -34,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 5.0 */ public class InMemoryClientRegistrationRepositoryTests { + private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); private InMemoryClientRegistrationRepository clients = new InMemoryClientRegistrationRepository(this.registration); @@ -94,4 +95,5 @@ public class InMemoryClientRegistrationRepositoryTests { public void iteratorWhenGetThenContainsAll() { assertThat(this.clients).containsOnly(this.registration); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java index 5ca19c509b..43dbefe782 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java @@ -24,7 +24,7 @@ import org.junit.Test; import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * @author Rob Winch @@ -43,22 +43,21 @@ public class InMemoryReactiveClientRegistrationRepositoryTests { @Test public void constructorWhenZeroVarArgsThenIllegalArgumentException() { - assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new InMemoryReactiveClientRegistrationRepository()); } @Test public void constructorWhenClientRegistrationArrayThenIllegalArgumentException() { ClientRegistration[] registrations = null; - assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registrations)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registrations)); } @Test public void constructorWhenClientRegistrationListThenIllegalArgumentException() { List registrations = null; - assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registrations)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registrations)); } @Test(expected = IllegalStateException.class) @@ -70,15 +69,17 @@ public class InMemoryReactiveClientRegistrationRepositoryTests { @Test public void constructorWhenClientRegistrationIsNullThenIllegalArgumentException() { ClientRegistration registration = null; - assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registration)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registration)); } @Test public void findByRegistrationIdWhenValidIdThenFound() { + // @formatter:off StepVerifier.create(this.repository.findByRegistrationId(this.registration.getRegistrationId())) .expectNext(this.registration) .verifyComplete(); + // @formatter:on } @Test @@ -91,4 +92,5 @@ public class InMemoryReactiveClientRegistrationRepositoryTests { public void iteratorWhenContainsGithubThenContains() { assertThat(this.repository).containsOnly(this.registration); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java index 37f65b6487..a219b779cb 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java @@ -23,25 +23,32 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; * @author Rob Winch * @since 5.1 */ -public class TestClientRegistrations { +public final class TestClientRegistrations { + + private TestClientRegistrations() { + } + public static ClientRegistration.Builder clientRegistration() { + // @formatter:off return ClientRegistration.withRegistrationId("registration-id") - .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .scope("read:user") - .authorizationUri("https://example.com/login/oauth/authorize") - .tokenUri("https://example.com/login/oauth/access_token") - .jwkSetUri("https://example.com/oauth2/jwk") - .issuerUri("https://example.com") - .userInfoUri("https://api.example.com/user") - .userNameAttributeName("id") - .clientName("Client Name") - .clientId("client-id") - .clientSecret("client-secret"); + .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://example.com/login/oauth/authorize") + .tokenUri("https://example.com/login/oauth/access_token") + .jwkSetUri("https://example.com/oauth2/jwk") + .issuerUri("https://example.com") + .userInfoUri("https://api.example.com/user") + .userNameAttributeName("id") + .clientName("Client Name") + .clientId("client-id") + .clientSecret("client-secret"); + // @formatter:on } public static ClientRegistration.Builder clientRegistration2() { + // @formatter:off return ClientRegistration.withRegistrationId("registration-id-2") .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) @@ -54,16 +61,20 @@ public class TestClientRegistrations { .clientName("Client Name") .clientId("client-id-2") .clientSecret("client-secret"); + // @formatter:on } public static ClientRegistration.Builder clientCredentials() { + // @formatter:off return clientRegistration() .registrationId("client-credentials") .clientId("client-id") .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS); + // @formatter:on } public static ClientRegistration.Builder password() { + // @formatter:off return ClientRegistration.withRegistrationId("password") .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .authorizationGrantType(AuthorizationGrantType.PASSWORD) @@ -72,5 +83,7 @@ public class TestClientRegistrations { .clientName("Client Name") .clientId("client-id") .clientSecret("client-secret"); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java index 3305a1ebd9..c1279db529 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; import java.util.Collection; @@ -34,15 +35,15 @@ import org.springframework.http.MediaType; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.user.OAuth2User; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; /** * Tests for {@link CustomUserTypesOAuth2UserService}. @@ -51,9 +52,13 @@ import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.no * @author Eddú Meléndez */ public class CustomUserTypesOAuth2UserServiceTests { + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private CustomUserTypesOAuth2UserService userService; + private MockWebServer server; @Rule @@ -64,9 +69,11 @@ public class CustomUserTypesOAuth2UserServiceTests { this.server = new MockWebServer(); this.server.start(); String registrationId = "client-registration-id-1"; - this.clientRegistrationBuilder = clientRegistration().registrationId(registrationId); - this.accessToken = noScopes(); - + // @formatter:off + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() + .registrationId(registrationId); + // @formatter:on + this.accessToken = TestOAuth2AccessTokens.noScopes(); Map> customUserTypes = new HashMap<>(); customUserTypes.put(registrationId, CustomOAuth2User.class); this.userService = new CustomUserTypesOAuth2UserService(customUserTypes); @@ -109,37 +116,35 @@ public class CustomUserTypesOAuth2UserServiceTests { @Test public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() { - ClientRegistration clientRegistration = - clientRegistration().registrationId("other-client-registration-id-1").build(); - + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .registrationId("other-client-registration-id-1") + .build(); + // @formatter:on OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThat(user).isNull(); } @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { - String userInfoResponse = "{\n" + - " \"id\": \"12345\",\n" + - " \"name\": \"first last\",\n" + - " \"login\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"id\": \"12345\",\n" + + " \"name\": \"first last\",\n" + + " \"login\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); - assertThat(user.getName()).isEqualTo("first last"); assertThat(user.getAttributes().size()).isEqualTo(4); assertThat((String) user.getAttribute("id")).isEqualTo("12345"); assertThat((String) user.getAttribute("name")).isEqualTo("first last"); assertThat((String) user.getAttribute("login")).isEqualTo("user1"); assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(1); assertThat(user.getAuthorities().iterator().next().getAuthority()).isEqualTo("ROLE_USER"); } @@ -147,71 +152,67 @@ public class CustomUserTypesOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + // @formatter:off + String userInfoResponse = "{\n" + + " \"id\": \"12345\",\n" + + " \"name\": \"first last\",\n" - String userInfoResponse = "{\n" + - " \"id\": \"12345\",\n" + - " \"name\": \"first last\",\n" + - " \"login\": \"user1\",\n" + - " \"email\": \"user1@example.com\"\n"; -// "}\n"; // Make the JSON invalid/malformed + + " \"login\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); this.server.enqueue(new MockResponse().setResponseCode(500)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "https://invalid-provider.com/user"; - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri).build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } private ClientRegistration.Builder withRegistrationId(String registrationId) { - return ClientRegistration - .withRegistrationId(registrationId) + // @formatter:off + return ClientRegistration.withRegistrationId(registrationId) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .clientId("client") .tokenUri("/token"); + // @formatter:on } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } public static class CustomOAuth2User implements OAuth2User { + private List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + private String id; + private String name; + private String login; + private String email; public CustomOAuth2User() { @@ -264,5 +265,7 @@ public class CustomUserTypesOAuth2UserServiceTests { public void setEmail(String email) { this.email = email; } + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 6324322bea..cb6da9fa93 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; import java.util.HashMap; @@ -40,9 +41,11 @@ import org.springframework.http.ResponseEntity; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.web.client.RestOperations; @@ -51,11 +54,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * Tests for {@link DefaultOAuth2UserService}. @@ -64,9 +64,13 @@ import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.sc * @author Eddú Meléndez */ public class DefaultOAuth2UserServiceTests { + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AccessToken accessToken; + private DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + private MockWebServer server; @Rule @@ -76,10 +80,12 @@ public class DefaultOAuth2UserServiceTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.clientRegistrationBuilder = clientRegistration() + // @formatter:off + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() .userInfoUri(null) .userNameAttributeName(null); - this.accessToken = noScopes(); + // @formatter:on + this.accessToken = TestOAuth2AccessTokens.noScopes(); } @After @@ -109,7 +115,6 @@ public class DefaultOAuth2UserServiceTests { public void loadUserWhenUserInfoUriIsNullThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("missing_user_info_uri")); - ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @@ -118,33 +123,31 @@ public class DefaultOAuth2UserServiceTests { public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("missing_user_name_attribute")); - + // @formatter:off ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri("https://provider.com/user").build(); + .userInfoUri("https://provider.com/user") + .build(); + // @formatter:on this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { - String userInfoResponse = "{\n" + - " \"user-name\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"user-name\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); - assertThat(user.getName()).isEqualTo("user1"); assertThat(user.getAttributes().size()).isEqualTo(6); assertThat((String) user.getAttribute("user-name")).isEqualTo("user1"); @@ -153,7 +156,6 @@ public class DefaultOAuth2UserServiceTests { assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle"); assertThat((String) user.getAttribute("address")).isEqualTo("address"); assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(1); assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class); OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next(); @@ -164,124 +166,101 @@ public class DefaultOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - - String userInfoResponse = "{\n" + - " \"user-name\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n"; -// "}\n"; // Make the JSON invalid/malformed + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + // @formatter:off + String userInfoResponse = "{\n" + + " \"user-name\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - this.exception.expectMessage(containsString("Error Code: insufficient_scope, Error Description: The access token expired")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + this.exception.expectMessage( + containsString("Error Code: insufficient_scope, Error Description: The access token expired")); String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\""; - MockResponse response = new MockResponse(); response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader); response.setResponseCode(400); this.server.enqueue(response); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); this.exception.expectMessage(containsString("Error Code: invalid_token")); - - String userInfoErrorResponse = "{\n" + - " \"error\": \"invalid_token\"\n" + - "}\n"; + // @formatter:off + String userInfoErrorResponse = "{\n" + + " \"error\": \"invalid_token\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); this.server.enqueue(new MockResponse().setResponseCode(500)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - + this.exception.expectMessage(containsString( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "https://invalid-provider.com/user"; - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } // gh-5294 @Test public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception { - String userInfoResponse = "{\n" + - " \"user-name\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"user-name\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) .isEqualTo(MediaType.APPLICATION_JSON_VALUE); @@ -290,50 +269,45 @@ public class DefaultOAuth2UserServiceTests { // gh-5500 @Test public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { - String userInfoResponse = "{\n" + - " \"user-name\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"user-name\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { - String userInfoResponse = "{\n" + - " \"user-name\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"user-name\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = this.server.url("/user").toString(); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); @@ -347,10 +321,9 @@ public class DefaultOAuth2UserServiceTests { Map body = new HashMap<>(); body.put("id", "id"); DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest( - clientRegistration().build(), scopes("message:read", "message:write")); + OAuth2UserRequest request = new OAuth2UserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.scopes("message:read", "message:write")); OAuth2User user = userService.loadUser(request); - assertThat(user.getAuthorities()).hasSize(3); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); @@ -363,10 +336,9 @@ public class DefaultOAuth2UserServiceTests { Map body = new HashMap<>(); body.put("id", "id"); DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest( - clientRegistration().build(), noScopes()); + OAuth2UserRequest request = new OAuth2UserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.noScopes()); OAuth2User user = userService.loadUser(request); - assertThat(user.getAuthorities()).hasSize(1); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); @@ -376,22 +348,16 @@ public class DefaultOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2AuthenticationException() { String userInfoUri = this.server.url("/user").toString(); - this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource " + - "from '" + userInfoUri + "': response contains invalid content type 'text/plain'.")); - + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource " + + "from '" + userInfoUri + "': response contains invalid content type 'text/plain'.")); MockResponse response = new MockResponse(); response.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE); response.setBody("invalid content type"); this.server.enqueue(response); - - ClientRegistration clientRegistration = this.clientRegistrationBuilder - .userInfoUri(userInfoUri) - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("user-name").build(); - + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @@ -399,8 +365,8 @@ public class DefaultOAuth2UserServiceTests { ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); Converter> requestEntityConverter = mock(Converter.class); RestOperations rest = mock(RestOperations.class); - when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) - .thenReturn(responseEntity); + given(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) + .willReturn(responseEntity); DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); userService.setRequestEntityConverter(requestEntityConverter); userService.setRestOperations(rest); @@ -408,8 +374,7 @@ public class DefaultOAuth2UserServiceTests { } private MockResponse jsonResponse(String json) { - return new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java index 476033286a..56822e60b9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java @@ -45,19 +45,17 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * @author Rob Winch @@ -65,12 +63,13 @@ import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.sc * @since 5.1 */ public class DefaultReactiveOAuth2UserServiceTests { + private ClientRegistration.Builder clientRegistration; private DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); - private OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plus(Duration.ofDays(1))); + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", + Instant.now(), Instant.now().plus(Duration.ofDays(1))); private MockWebServer server; @@ -78,11 +77,11 @@ public class DefaultReactiveOAuth2UserServiceTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - String userInfoUri = this.server.url("/user").toString(); - + // @formatter:off this.clientRegistration = TestClientRegistrations.clientRegistration() .userInfoUri(userInfoUri); + // @formatter:on } @After @@ -93,49 +92,44 @@ public class DefaultReactiveOAuth2UserServiceTests { @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { OAuth2UserRequest request = null; - StepVerifier.create(this.userService.loadUser(request)) - .expectError(IllegalArgumentException.class) - .verify(); + StepVerifier.create(this.userService.loadUser(request)).expectError(IllegalArgumentException.class).verify(); } @Test public void loadUserWhenUserInfoUriIsNullThenThrowOAuth2AuthenticationException() { this.clientRegistration.userInfoUri(null); - - StepVerifier.create(this.userService.loadUser(oauth2UserRequest())) - .expectErrorSatisfies(t -> assertThat(t) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("missing_user_info_uri") - ) + StepVerifier.create(this.userService.loadUser(oauth2UserRequest())).expectErrorSatisfies((ex) -> assertThat(ex) + .isInstanceOf(OAuth2AuthenticationException.class).hasMessageContaining("missing_user_info_uri")) .verify(); } @Test public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() { this.clientRegistration.userNameAttributeName(null); - + // @formatter:off StepVerifier.create(this.userService.loadUser(oauth2UserRequest())) - .expectErrorSatisfies(t -> assertThat(t) + .expectErrorSatisfies((ex) -> assertThat(ex) .isInstanceOf(OAuth2AuthenticationException.class) .hasMessageContaining("missing_user_name_attribute") ) .verify(); + // @formatter:on } @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { - String userInfoResponse = "{\n" + - " \"id\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"id\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on enqueueApplicationJsonBody(userInfoResponse); - OAuth2User user = this.userService.loadUser(oauth2UserRequest()).block(); - assertThat(user.getName()).isEqualTo("user1"); assertThat(user.getAttributes().size()).isEqualTo(6); assertThat((String) user.getAttribute("id")).isEqualTo("user1"); @@ -144,7 +138,6 @@ public class DefaultReactiveOAuth2UserServiceTests { assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle"); assertThat((String) user.getAttribute("address")).isEqualTo("address"); assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(1); assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class); OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next(); @@ -156,40 +149,41 @@ public class DefaultReactiveOAuth2UserServiceTests { @Test public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { this.clientRegistration.userInfoAuthenticationMethod(AuthenticationMethod.HEADER); - String userInfoResponse = "{\n" + - " \"id\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + // @formatter:off + String userInfoResponse = "{\n" + + " \"id\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on enqueueApplicationJsonBody(userInfoResponse); - this.userService.loadUser(oauth2UserRequest()).block(); - RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); - assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { - this.clientRegistration.userInfoAuthenticationMethod( AuthenticationMethod.FORM); - String userInfoResponse = "{\n" + - " \"id\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n" + - "}\n"; + this.clientRegistration.userInfoAuthenticationMethod(AuthenticationMethod.FORM); + // @formatter:off + String userInfoResponse = "{\n" + + " \"id\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on enqueueApplicationJsonBody(userInfoResponse); - this.userService.loadUser(oauth2UserRequest()).block(); - RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); @@ -199,35 +193,36 @@ public class DefaultReactiveOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { - String userInfoResponse = "{\n" + - " \"id\": \"user1\",\n" + - " \"first-name\": \"first\",\n" + - " \"last-name\": \"last\",\n" + - " \"middle-name\": \"middle\",\n" + - " \"address\": \"address\",\n" + - " \"email\": \"user1@example.com\"\n"; - // "}\n"; // Make the JSON invalid/malformed + // @formatter:off + String userInfoResponse = "{\n" + + " \"id\": \"user1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n"; + // "}\n"; // Make the JSON invalid/malformed + // @formatter:on enqueueApplicationJsonBody(userInfoResponse); - - assertThatThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("invalid_user_info_response"); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()) + .withMessageContaining("invalid_user_info_response"); } @Test public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() { - this.server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setResponseCode(500).setBody("{}")); - - assertThatThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("invalid_user_info_response"); + this.server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setResponseCode(500).setBody("{}")); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()) + .withMessageContaining("invalid_user_info_response"); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() { this.clientRegistration.userInfoUri("https://invalid-provider.com/user"); - assertThatThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()) - .isInstanceOf(AuthenticationServiceException.class); + assertThatExceptionOfType(AuthenticationServiceException.class) + .isThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()); } @Test @@ -235,10 +230,9 @@ public class DefaultReactiveOAuth2UserServiceTests { Map body = new HashMap<>(); body.put("id", "id"); DefaultReactiveOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest( - clientRegistration().build(), scopes("message:read", "message:write")); + OAuth2UserRequest request = new OAuth2UserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.scopes("message:read", "message:write")); OAuth2User user = userService.loadUser(request).block(); - assertThat(user.getAuthorities()).hasSize(3); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); @@ -251,10 +245,9 @@ public class DefaultReactiveOAuth2UserServiceTests { Map body = new HashMap<>(); body.put("id", "id"); DefaultReactiveOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest( - clientRegistration().build(), noScopes()); + OAuth2UserRequest request = new OAuth2UserRequest(TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.noScopes()); OAuth2User user = userService.loadUser(request).block(); - assertThat(user.getAuthorities()).hasSize(1); Iterator authorities = user.getAuthorities().iterator(); assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); @@ -267,14 +260,13 @@ public class DefaultReactiveOAuth2UserServiceTests { response.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE); response.setBody("invalid content type"); this.server.enqueue(response); - OAuth2UserRequest userRequest = oauth2UserRequest(); - - assertThatThrownBy(() -> this.userService.loadUser(userRequest).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource from '" + - userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri() + "': " + - "response contains invalid content type 'text/plain'"); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService.loadUser(userRequest).block()).withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to " + + "retrieve the UserInfo Resource from '" + userRequest.getClientRegistration() + .getProviderDetails().getUserInfoEndpoint().getUri() + + "': " + "response contains invalid content type 'text/plain'"); } private DefaultReactiveOAuth2UserService withMockResponse(Map body) { @@ -282,13 +274,10 @@ public class DefaultReactiveOAuth2UserServiceTests { WebClient.RequestHeadersUriSpec spec = spy(real.post()); WebClient rest = spy(WebClient.class); WebClient.ResponseSpec clientResponse = mock(WebClient.ResponseSpec.class); - when(rest.get()).thenReturn(spec); - when(spec.retrieve()).thenReturn(clientResponse); - when(clientResponse.onStatus(any(Predicate.class), any(Function.class))) - .thenReturn(clientResponse); - when(clientResponse.bodyToMono(any(ParameterizedTypeReference.class))) - .thenReturn(Mono.just(body)); - + given(rest.get()).willReturn(spec); + given(spec.retrieve()).willReturn(clientResponse); + given(clientResponse.onStatus(any(Predicate.class), any(Function.class))).willReturn(clientResponse); + given(clientResponse.bodyToMono(any(ParameterizedTypeReference.class))).willReturn(Mono.just(body)); DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); userService.setWebClient(rest); return userService; @@ -299,9 +288,8 @@ public class DefaultReactiveOAuth2UserServiceTests { } private void enqueueApplicationJsonBody(String json) { - - this.server.enqueue(new MockResponse() - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .setBody(json)); + this.server.enqueue( + new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserServiceTests.java index f43b18a0ae..f259c60ddf 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserServiceTests.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.userinfo; -import org.junit.Test; -import org.springframework.security.oauth2.core.user.OAuth2User; +package org.springframework.security.oauth2.client.userinfo; import java.util.Arrays; import java.util.Collections; +import org.junit.Test; + +import org.springframework.security.oauth2.core.user.OAuth2User; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link DelegatingOAuth2UserService}. @@ -46,9 +48,9 @@ public class DelegatingOAuth2UserServiceTests { @Test(expected = IllegalArgumentException.class) @SuppressWarnings("unchecked") public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { - DelegatingOAuth2UserService delegatingUserService = - new DelegatingOAuth2UserService<>( - Arrays.asList(mock(OAuth2UserService.class), mock(OAuth2UserService.class))); + OAuth2UserService userService = mock(OAuth2UserService.class); + DelegatingOAuth2UserService delegatingUserService = new DelegatingOAuth2UserService<>( + Arrays.asList(userService, userService)); delegatingUserService.loadUser(null); } @@ -59,11 +61,9 @@ public class DelegatingOAuth2UserServiceTests { OAuth2UserService userService2 = mock(OAuth2UserService.class); OAuth2UserService userService3 = mock(OAuth2UserService.class); OAuth2User mockUser = mock(OAuth2User.class); - when(userService3.loadUser(any(OAuth2UserRequest.class))).thenReturn(mockUser); - - DelegatingOAuth2UserService delegatingUserService = - new DelegatingOAuth2UserService<>(Arrays.asList(userService1, userService2, userService3)); - + given(userService3.loadUser(any(OAuth2UserRequest.class))).willReturn(mockUser); + DelegatingOAuth2UserService delegatingUserService = new DelegatingOAuth2UserService<>( + Arrays.asList(userService1, userService2, userService3)); OAuth2User loadedUser = delegatingUserService.loadUser(mock(OAuth2UserRequest.class)); assertThat(loadedUser).isEqualTo(mockUser); } @@ -74,11 +74,10 @@ public class DelegatingOAuth2UserServiceTests { OAuth2UserService userService1 = mock(OAuth2UserService.class); OAuth2UserService userService2 = mock(OAuth2UserService.class); OAuth2UserService userService3 = mock(OAuth2UserService.class); - - DelegatingOAuth2UserService delegatingUserService = - new DelegatingOAuth2UserService<>(Arrays.asList(userService1, userService2, userService3)); - + DelegatingOAuth2UserService delegatingUserService = new DelegatingOAuth2UserService<>( + Arrays.asList(userService1, userService2, userService3)); OAuth2User loadedUser = delegatingUserService.loadUser(mock(OAuth2UserRequest.class)); assertThat(loadedUser).isNull(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java index 4ad9d4e4aa..a4f975736f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java @@ -13,9 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.userinfo; +import java.time.Instant; +import java.util.Arrays; +import java.util.LinkedHashSet; + import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -27,12 +33,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.MultiValueMap; -import java.time.Instant; -import java.util.Arrays; -import java.util.LinkedHashSet; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; /** * Tests for {@link OAuth2UserRequestEntityConverter}. @@ -40,56 +41,47 @@ import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VAL * @author Joe Grandja */ public class OAuth2UserRequestEntityConverterTests { + private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter(); @SuppressWarnings("unchecked") @Test public void convertWhenAuthenticationMethodHeaderThenGetRequest() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2UserRequest userRequest = new OAuth2UserRequest( - clientRegistration, this.createAccessToken()); - + OAuth2UserRequest userRequest = new OAuth2UserRequest(clientRegistration, this.createAccessToken()); RequestEntity requestEntity = this.converter.convert(userRequest); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.GET); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); - assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo( - "Bearer " + userRequest.getAccessToken().getTokenValue()); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + userRequest.getAccessToken().getTokenValue()); } @SuppressWarnings("unchecked") @Test public void convertWhenAuthenticationMethodFormThenPostRequest() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() - .userInfoAuthenticationMethod(AuthenticationMethod.FORM) - .build(); - OAuth2UserRequest userRequest = new OAuth2UserRequest( - clientRegistration, this.createAccessToken()); - + .userInfoAuthenticationMethod(AuthenticationMethod.FORM).build(); + OAuth2UserRequest userRequest = new OAuth2UserRequest(clientRegistration, this.createAccessToken()); RequestEntity requestEntity = this.converter.convert(userRequest); - assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); - assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); - + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON); - assertThat(headers.getContentType()).isEqualTo( - MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); - + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); - assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)).isEqualTo( - userRequest.getAccessToken().getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)) + .isEqualTo(userRequest.getAccessToken().getTokenValue()); } private OAuth2AccessToken createAccessToken() { - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), - Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write"))); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + Instant.now(), Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write"))); return accessToken; } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java index 48dcd5eea4..c341e72a2b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java @@ -13,14 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.userinfo; -import org.junit.Before; -import org.junit.Test; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AccessToken; +package org.springframework.security.oauth2.client.userinfo; import java.time.Instant; import java.util.Arrays; @@ -28,8 +22,16 @@ import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2UserRequest}. @@ -37,12 +39,16 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2UserRequestTests { + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; + private Map additionalParameters; @Before public void setUp() { + // @formatter:off this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") .clientId("client-1") .clientSecret("secret") @@ -54,9 +60,9 @@ public class OAuth2UserRequestTests { .tokenUri("https://provider.com/oauth2/token") .clientName("Client 1") .build(); - this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "access-token-1234", Instant.now(), Instant.now().plusSeconds(60), - new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); + // @formatter:on + this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), + Instant.now().plusSeconds(60), new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); this.additionalParameters = new HashMap<>(); this.additionalParameters.put("param1", "value1"); this.additionalParameters.put("param2", "value2"); @@ -64,23 +70,21 @@ public class OAuth2UserRequestTests { @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2UserRequest(null, this.accessToken)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2UserRequest(null, this.accessToken)); } @Test public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null)); } @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2UserRequest userRequest = new OAuth2UserRequest( - this.clientRegistration, this.accessToken, this.additionalParameters); - + OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken, + this.additionalParameters); assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken); assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java index e625e362b7..88e7c57ade 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -26,7 +28,7 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -36,34 +38,43 @@ import static org.mockito.Mockito.verify; * @author Joe Grandja */ public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests { + private String registrationId = "registrationId"; + private String principalName = "principalName"; + private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository; + private AuthenticatedPrincipalOAuth2AuthorizedClientRepository authorizedClientRepository; + private MockHttpServletRequest request; + private MockHttpServletResponse response; @Before public void setup() { this.authorizedClientService = mock(OAuth2AuthorizedClientService.class); this.anonymousAuthorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService); - this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository); + this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + this.authorizedClientService); + this.authorizedClientRepository + .setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository); this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); } @Test public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null)); } @Test public void setAuthorizedClientRepositoryWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null)); } @Test @@ -77,14 +88,16 @@ public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests { public void loadAuthorizedClientWhenAnonymousPrincipalThenLoadFromAnonymousRepository() { Authentication authentication = this.createAnonymousPrincipal(); this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request); - verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, this.request); + verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, + this.request); } @Test public void saveAuthorizedClientWhenAuthenticatedPrincipalThenSaveToService() { Authentication authentication = this.createAuthenticatedPrincipal(); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); - this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, + this.response); verify(this.authorizedClientService).saveAuthorizedClient(authorizedClient, authentication); } @@ -92,22 +105,27 @@ public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests { public void saveAuthorizedClientWhenAnonymousPrincipalThenSaveToAnonymousRepository() { Authentication authentication = this.createAnonymousPrincipal(); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); - this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response); - verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, this.request, this.response); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, + this.response); + verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, + this.request, this.response); } @Test public void removeAuthorizedClientWhenAuthenticatedPrincipalThenRemoveFromService() { Authentication authentication = this.createAuthenticatedPrincipal(); - this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, + this.response); verify(this.authorizedClientService).removeAuthorizedClient(this.registrationId, this.principalName); } @Test public void removeAuthorizedClientWhenAnonymousPrincipalThenRemoveFromAnonymousRepository() { Authentication authentication = this.createAnonymousPrincipal(); - this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response); - verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, this.request, this.response); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, + this.response); + verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, + this.request, this.response); } private Authentication createAuthenticatedPrincipal() { @@ -117,6 +135,8 @@ public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests { } private Authentication createAnonymousPrincipal() { - return new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + return new AnonymousAuthenticationToken("key-1234", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index 2a861520a5..2c875115fd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; import java.io.IOException; import java.nio.charset.StandardCharsets; + import javax.servlet.http.HttpServletRequest; + import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentMatchers; import org.mockito.Mockito; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -36,7 +41,7 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.entry; /** @@ -45,13 +50,21 @@ import static org.assertj.core.api.Assertions.entry; * @author Joe Grandja */ public class DefaultOAuth2AuthorizationRequestResolverTests { + private ClientRegistration registration1; + private ClientRegistration registration2; + private ClientRegistration fineRedirectUriTemplateRegistration; + private ClientRegistration pkceRegistration; + private ClientRegistration oidcRegistration; + private ClientRegistrationRepository clientRegistrationRepository; + private final String authorizationRequestBaseUri = "/oauth2/authorization"; + private DefaultOAuth2AuthorizationRequestResolver resolver; @Before @@ -59,6 +72,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build(); this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build(); + // @formatter:off this.pkceRegistration = TestClientRegistrations.clientRegistration() .registrationId("pkce-client-registration-id") .clientId("pkce-client-id") @@ -67,30 +81,31 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { .build(); this.oidcRegistration = TestClientRegistrations.clientRegistration() .registrationId("oidc-registration-id") - .scope(OidcScopes.OPENID).build(); - this.clientRegistrationRepository = new InMemoryClientRegistrationRepository( - this.registration1, this.registration2, this.fineRedirectUriTemplateRegistration, - this.pkceRegistration, this.oidcRegistration); - this.resolver = new DefaultOAuth2AuthorizationRequestResolver( - this.clientRegistrationRepository, this.authorizationRequestBaseUri); + .scope(OidcScopes.OPENID) + .build(); + // @formatter:on + this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, + this.registration2, this.fineRedirectUriTemplateRegistration, this.pkceRegistration, + this.oidcRegistration); + this.resolver = new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, + this.authorizationRequestBaseUri); } @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2AuthorizationRequestResolver(null, this.authorizationRequestBaseUri)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new DefaultOAuth2AuthorizationRequestResolver(null, this.authorizationRequestBaseUri)); } @Test public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, null)); } @Test public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null)); } @Test @@ -98,7 +113,6 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNull(); } @@ -111,27 +125,27 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setContent("foo".getBytes(StandardCharsets.UTF_8)); request.setCharacterEncoding(StandardCharsets.UTF_8.name()); HttpServletRequest spyRequest = Mockito.spy(request); - this.resolver.resolve(spyRequest); - Mockito.verify(spyRequest, Mockito.never()).getReader(); Mockito.verify(spyRequest, Mockito.never()).getInputStream(); - Mockito.verify(spyRequest, Mockito.never()).getParameter(Mockito.anyString()); + Mockito.verify(spyRequest, Mockito.never()).getParameter(ArgumentMatchers.anyString()); Mockito.verify(spyRequest, Mockito.never()).getParameterMap(); Mockito.verify(spyRequest, Mockito.never()).getParameterNames(); - Mockito.verify(spyRequest, Mockito.never()).getParameterValues(Mockito.anyString()); + Mockito.verify(spyRequest, Mockito.never()).getParameterValues(ArgumentMatchers.anyString()); } @Test public void resolveWhenAuthorizationRequestWithInvalidClientThenThrowIllegalArgumentException() { ClientRegistration clientRegistration = this.registration1; - String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() + "-invalid"; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() + + "-invalid"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - - assertThatThrownBy(() -> this.resolver.resolve(request)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Invalid Client Registration with Id: " + clientRegistration.getRegistrationId() + "-invalid"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.resolver.resolve(request)) + .withMessage("Invalid Client Registration with Id: " + clientRegistration.getRegistrationId() + "-invalid"); + // @formatter:on } @Test @@ -140,11 +154,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); - assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo( - clientRegistration.getProviderDetails().getAuthorizationUri()); + assertThat(authorizationRequest.getAuthorizationUri()) + .isEqualTo(clientRegistration.getProviderDetails().getAuthorizationUri()); assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); assertThat(authorizationRequest.getClientId()).isEqualTo(clientRegistration.getClientId()); @@ -152,14 +165,14 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { .isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes()); assertThat(authorizationRequest.getState()).isNotNull(); - assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID); + assertThat(authorizationRequest.getAdditionalParameters()) + .doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAttributes()) .containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId())); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); } @Test @@ -168,8 +181,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, + clientRegistration.getRegistrationId()); assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest.getAttributes()) .containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId())); @@ -181,12 +194,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); - assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo( - clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); } @Test @@ -196,11 +207,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServerPort(8080); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "http://localhost:8080/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("http://localhost:8080/login/oauth2/code/" + clientRegistration.getRegistrationId()); } @Test @@ -211,11 +221,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setScheme("https"); request.setServerPort(8081); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "https://localhost:8081/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("https://localhost:8081/login/oauth2/code/" + clientRegistration.getRegistrationId()); } @Test @@ -226,11 +235,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setScheme("http"); request.setServerPort(80); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); } @Test @@ -241,11 +249,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setScheme("https"); request.setServerPort(443); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); } @Test @@ -256,11 +263,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setScheme("https"); request.setServerPort(-1); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); } // gh-5520 @@ -271,12 +277,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setQueryString("foo=bar"); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); - assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo( - clientRegistration.getRedirectUri()); - assertThat(authorizationRequest.getRedirectUri()).isEqualTo( - "http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUri()); + assertThat(authorizationRequest.getRedirectUri()) + .isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); } @Test @@ -288,13 +292,11 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setServerName("localhost"); request.setServerPort(80); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); } @Test @@ -306,13 +308,11 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { request.setServerName("example.com"); request.setServerPort(443); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=https://example.com/login/oauth2/code/registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=https://example.com/login/oauth2/code/registration-id"); } @Test @@ -321,13 +321,12 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, + clientRegistration.getRegistrationId()); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/authorize/oauth2/code/registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/authorize/oauth2/code/registration-id"); } @Test @@ -336,13 +335,11 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id-2&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id-2"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id-2"); } @Test @@ -352,13 +349,11 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.addParameter("action", "authorize"); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/authorize/oauth2/code/registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/authorize/oauth2/code/registration-id"); } @Test @@ -368,13 +363,11 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.addParameter("action", "login"); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id-2&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id-2"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id-2&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id-2"); } @Test @@ -383,11 +376,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); - assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo( - clientRegistration.getProviderDetails().getAuthorizationUri()); + assertThat(authorizationRequest.getAuthorizationUri()) + .isEqualTo(clientRegistration.getProviderDetails().getAuthorizationUri()); assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); assertThat(authorizationRequest.getClientId()).isEqualTo(clientRegistration.getClientId()); @@ -395,22 +387,21 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { .isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes()); assertThat(authorizationRequest.getState()).isNotNull(); - assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID); + assertThat(authorizationRequest.getAdditionalParameters()) + .doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAdditionalParameters()).containsKey(PkceParameterNames.CODE_CHALLENGE); assertThat(authorizationRequest.getAdditionalParameters()) .contains(entry(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")); assertThat(authorizationRequest.getAttributes()) .contains(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId())); - assertThat(authorizationRequest.getAttributes()) - .containsKey(PkceParameterNames.CODE_VERIFIER); - assertThat((String) authorizationRequest.getAttribute(PkceParameterNames.CODE_VERIFIER)).matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); + assertThat(authorizationRequest.getAttributes()).containsKey(PkceParameterNames.CODE_VERIFIER); + assertThat((String) authorizationRequest.getAttribute(PkceParameterNames.CODE_VERIFIER)) + .matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=pkce-client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/pkce-client-registration-id&" + - "code_challenge_method=S256&" + - "code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=pkce-client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/pkce-client-registration-id&" + + "code_challenge_method=S256&" + "code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } @Test @@ -419,11 +410,10 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest).isNotNull(); - assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo( - clientRegistration.getProviderDetails().getAuthorizationUri()); + assertThat(authorizationRequest.getAuthorizationUri()) + .isEqualTo(clientRegistration.getProviderDetails().getAuthorizationUri()); assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); assertThat(authorizationRequest.getClientId()).isEqualTo(clientRegistration.getClientId()); @@ -431,18 +421,19 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { .isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes()); assertThat(authorizationRequest.getState()).isNotNull(); - assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID); + assertThat(authorizationRequest.getAdditionalParameters()) + .doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAdditionalParameters()).containsKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()) .contains(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId())); assertThat(authorizationRequest.getAttributes()).containsKey(OidcParameterNames.NONCE); - assertThat((String) authorizationRequest.getAttribute(OidcParameterNames.NONCE)).matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); + assertThat((String) authorizationRequest.getAttribute(OidcParameterNames.NONCE)) + .matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + - "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } // gh-7696 @@ -452,20 +443,17 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - - this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer - .additionalParameters(params -> params.remove(OidcParameterNames.NONCE)) - .attributes(attrs -> attrs.remove(OidcParameterNames.NONCE))); - + this.resolver.setAuthorizationRequestCustomizer( + (customizer) -> customizer.additionalParameters((params) -> params.remove(OidcParameterNames.NONCE)) + .attributes((attrs) -> attrs.remove(OidcParameterNames.NONCE))); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id"); } @Test @@ -474,22 +462,17 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - - this.resolver.setAuthorizationRequestCustomizer(customizer -> - customizer.authorizationRequestUri(uriBuilder -> { + this.resolver + .setAuthorizationRequestCustomizer((customizer) -> customizer.authorizationRequestUri((uriBuilder) -> { uriBuilder.queryParam("param1", "value1"); return uriBuilder.build(); - }) - ); - + })); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + - "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + - "param1=value1"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + "param1=value1"); } @Test @@ -498,25 +481,19 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - - this.resolver.setAuthorizationRequestCustomizer(customizer -> - customizer.parameters(params -> { - params.put("appid", params.get("client_id")); - params.remove("client_id"); - }) - ); - + this.resolver.setAuthorizationRequestCustomizer((customizer) -> customizer.parameters((params) -> { + params.put("appid", params.get("client_id")); + params.remove("client_id"); + })); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + - "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + - "appid=client-id"); + assertThat(authorizationRequest.getAuthorizationRequestUri()).matches( + "https://example.com/login/oauth/authorize\\?" + "response_type=code&" + "scope=openid&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + "appid=client-id"); } private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() { + // @formatter:off return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration") .redirectUri("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}") .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) @@ -529,5 +506,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { .clientName("Fine Redirect Uri Template Client") .clientId("fine-redirect-uri-template-client") .clientSecret("client-secret"); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java index aa13f0c879..74ec72f23b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -13,11 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -40,23 +49,17 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.StringUtils; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; /** * Tests for {@link DefaultOAuth2AuthorizedClientManager}. @@ -64,18 +67,31 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class DefaultOAuth2AuthorizedClientManagerTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; + private DefaultOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private ArgumentCaptor authorizationContextCaptor; @SuppressWarnings("unchecked") @@ -87,19 +103,21 @@ public class DefaultOAuth2AuthorizedClientManagerTests { this.contextAttributesMapper = mock(Function.class); this.authorizationSuccessHandler = spy(new OAuth2AuthorizationSuccessHandler() { @Override - public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, Map attributes) { - authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, + public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal, + Map attributes) { + DefaultOAuth2AuthorizedClientManagerTests.this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, principal, (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), (HttpServletResponse) attributes.get(HttpServletResponse.class.getName())); } }); - this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> - authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, - (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), + this.authorizationFailureHandler = spy( + new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler((clientRegistrationId, principal, + attributes) -> this.authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, + principal, (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), (HttpServletResponse) attributes.get(HttpServletResponse.class.getName())))); - this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(this.clientRegistrationRepository, + this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler); @@ -115,111 +133,108 @@ public class DefaultOAuth2AuthorizedClientManagerTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) + .withMessage("clientRegistrationRepository cannot be null"); } @Test public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .withMessage("authorizedClientRepository cannot be null"); } @Test public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientProvider cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .withMessage("authorizedClientProvider cannot be null"); } @Test public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("contextAttributesMapper cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .withMessage("contextAttributesMapper cannot be null"); } @Test public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationSuccessHandler cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .withMessage("authorizationSuccessHandler cannot be null"); } @Test public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationFailureHandler cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .withMessage("authorizationFailureHandler cannot be null"); } @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizeRequest cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(null)) + .withMessage("authorizeRequest cannot be null"); } @Test public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("servletRequest cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .withMessage("servletRequest cannot be null"); } @Test public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), this.request) - .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("servletResponse cannot be null"); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request).build(); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .withMessage("servletResponse cannot be null"); } @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("invalid-registration-id") .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + // @formatter:on + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .withMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); } @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isNull(); verifyNoInteractions(this.authorizationSuccessHandler); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any(), any()); @@ -228,82 +243,77 @@ public class DefaultOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(this.authorizedClient); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(this.authorizedClient), eq(this.principal), any()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(this.authorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(this.authorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(this.authorizedClient), eq(this.principal), + eq(this.request), eq(this.response)); } @SuppressWarnings("unchecked") @Test public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request))).thenReturn(this.authorizedClient); - - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + given(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), eq(this.request))).willReturn(this.authorizedClient); + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(reauthorizedClient); + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(any()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(reauthorizedClient), eq(this.principal), any()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(reauthorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal), + eq(this.request), eq(this.response)); } @Test public void authorizeWhenRequestParameterUsernamePasswordThenMappedToContext() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); - + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(this.clientRegistration); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(this.authorizedClient); // Set custom contextAttributesMapper - this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> { + this.authorizedClientManager.setContextAttributesMapper((authorizeRequest) -> { Map contextAttributes = new HashMap<>(); HttpServletRequest servletRequest = authorizeRequest.getAttribute(HttpServletRequest.class.getName()); String username = servletRequest.getParameter(OAuth2ParameterNames.USERNAME); @@ -314,21 +324,20 @@ public class DefaultOAuth2AuthorizedClientManagerTests { } return contextAttributes; }); - this.request.addParameter(OAuth2ParameterNames.USERNAME, "username"); this.request.addParameter(OAuth2ParameterNames.PASSWORD, "password"); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + // @formatter:off + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on this.authorizedClientManager.authorize(authorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); String username = authorizationContext.getAttribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME); assertThat(username).isEqualTo("username"); @@ -339,89 +348,82 @@ public class DefaultOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); verifyNoInteractions(this.authorizationSuccessHandler); - verify(this.authorizedClientRepository, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(OAuth2AuthorizedClient.class), + eq(this.principal), eq(this.request), eq(this.response)); } @SuppressWarnings("unchecked") @Test public void reauthorizeWhenSupportedProviderThenReauthorized() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(reauthorizedClient); + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizationSuccessHandler).onAuthorizationSuccess( - eq(reauthorizedClient), eq(this.principal), any()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess(eq(reauthorizedClient), eq(this.principal), + any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal), + eq(this.request), eq(this.response)); } @Test public void reauthorizeWhenRequestParameterScopeThenMappedToContext() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(reauthorizedClient); // Override the mock with the default - this.authorizedClientManager.setContextAttributesMapper( - new DefaultOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); - + this.authorizedClientManager + .setContextAttributesMapper(new DefaultOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write"); - + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); + // @formatter:on this.authorizedClientManager.authorize(reauthorizeRequest); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); - String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext + .getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); } @@ -430,49 +432,47 @@ public class DefaultOAuth2AuthorizedClientManagerTests { ClientAuthorizationException authorizationException = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) - .thenThrow(authorizationException); - + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willThrow(authorizationException); + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); - - assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isEqualTo(authorizationException); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - eq(authorizationException), eq(this.principal), any()); - verify(this.authorizedClientRepository).removeAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request), eq(this.response)); + verify(this.authorizationFailureHandler).onAuthorizationFailure(eq(authorizationException), eq(this.principal), + any()); + verify(this.authorizedClientRepository).removeAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), eq(this.request), eq(this.response)); } @Test public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() { ClientAuthorizationException authorizationException = new ClientAuthorizationException( - new OAuth2Error("non-matching-error-code", null, null), - this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) - .thenThrow(authorizationException); - + new OAuth2Error("non-matching-error-code", null, null), this.clientRegistration.getRegistrationId()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willThrow(authorizationException); + // @formatter:off OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attributes(attrs -> { + .attributes((attrs) -> { attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletResponse.class.getName(), this.response); }) .build(); - - assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + // @formatter:on + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isEqualTo(authorizationException); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - eq(authorizationException), eq(this.principal), any()); + verify(this.authorizationFailureHandler).onAuthorizationFailure(eq(authorizationException), eq(this.principal), + any()); verifyNoInteractions(this.authorizedClientRepository); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java index de74b188fc..b41e196ba8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java @@ -13,11 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.publisher.PublisherProbe; +import reactor.util.context.Context; + import org.springframework.http.MediaType; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; @@ -39,19 +49,17 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; -import reactor.test.publisher.PublisherProbe; -import reactor.util.context.Context; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; /** * Tests for {@link DefaultReactiveOAuth2AuthorizedClientManager}. @@ -59,41 +67,55 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private MockServerWebExchange serverWebExchange; + private Context context; + private ArgumentCaptor authorizationContextCaptor; + private PublisherProbe loadAuthorizedClientProbe; + private PublisherProbe saveAuthorizedClientProbe; + private PublisherProbe removeAuthorizedClientProbe; @SuppressWarnings("unchecked") @Before public void setup() { this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); - when(this.clientRegistrationRepository.findByRegistrationId( - anyString())).thenReturn(Mono.empty()); + given(this.clientRegistrationRepository.findByRegistrationId(anyString())).willReturn(Mono.empty()); this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class); this.loadAuthorizedClientProbe = PublisherProbe.empty(); - when(this.authorizedClientRepository.loadAuthorizedClient( - anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.loadAuthorizedClientProbe.mono()); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), + any(ServerWebExchange.class))).willReturn(this.loadAuthorizedClientProbe.mono()); this.saveAuthorizedClientProbe = PublisherProbe.empty(); - when(this.authorizedClientRepository.saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.saveAuthorizedClientProbe.mono()); + given(this.authorizedClientRepository.saveAuthorizedClient(any(OAuth2AuthorizedClient.class), + any(Authentication.class), any(ServerWebExchange.class))) + .willReturn(this.saveAuthorizedClientProbe.mono()); this.removeAuthorizedClientProbe = PublisherProbe.empty(); - when(this.authorizedClientRepository.removeAuthorizedClient( - any(String.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.removeAuthorizedClientProbe.mono()); + given(this.authorizedClientRepository.removeAuthorizedClient(any(String.class), any(Authentication.class), + any(ServerWebExchange.class))).willReturn(this.removeAuthorizedClientProbe.mono()); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).willReturn(Mono.empty()); this.contextAttributesMapper = mock(Function.class); - when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.just(Collections.emptyMap())); + given(this.contextAttributesMapper.apply(any())).willReturn(Mono.just(Collections.emptyMap())); this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); @@ -109,93 +131,89 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy( + () -> new DefaultReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) + .withMessage("clientRegistrationRepository cannot be null"); } @Test public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy( + () -> new DefaultReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .withMessage("authorizedClientRepository cannot be null"); } @Test public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientProvider cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .withMessage("authorizedClientProvider cannot be null"); } @Test public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationSuccessHandler cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .withMessage("authorizationSuccessHandler cannot be null"); } @Test public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationFailureHandler cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .withMessage("authorizationFailureHandler cannot be null"); } @Test public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("contextAttributesMapper cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .withMessage("contextAttributesMapper cannot be null"); } @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizeRequest cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(null).block()) + .withMessage("authorizeRequest cannot be null"); } @Test public void authorizeWhenExchangeIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("serverWebExchange cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .withMessage("serverWebExchange cannot be null"); } @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId("invalid-registration-id").principal(this.principal).build(); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) + .withMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); } @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) .subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isNull(); this.loadAuthorizedClientProbe.assertWasSubscribed(); this.saveAuthorizedClientProbe.assertWasNotSubscribed(); @@ -204,28 +222,24 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.authorizedClient)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) .subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(this.authorizedClient), eq(this.principal), + eq(this.serverWebExchange)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); } @@ -233,29 +247,24 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenNotAuthorizedAndSupportedProviderAndCustomSuccessHandlerThenInvokeCustomSuccessHandler() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.authorizedClient)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - PublisherProbe authorizationSuccessHandlerProbe = PublisherProbe.empty(); - this.authorizedClientManager.setAuthorizationSuccessHandler((client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); - + this.authorizedClientManager.setAuthorizationSuccessHandler( + (client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) .subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); authorizationSuccessHandlerProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); @@ -265,34 +274,27 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenInvalidTokenThenRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - ClientAuthorizationException exception = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) - .subscriberContext(this.context).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(ClientAuthorizationException.class).isThrownBy( + () -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientRepository).removeAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository).removeAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), eq(this.serverWebExchange)); this.removeAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } @@ -300,34 +302,27 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenInvalidGrantThenRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - ClientAuthorizationException exception = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) - .subscriberContext(this.context).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(ClientAuthorizationException.class).isThrownBy( + () -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - - verify(this.authorizedClientRepository).removeAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository).removeAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), eq(this.serverWebExchange)); this.removeAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } @@ -335,32 +330,25 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenServerErrorThenDoNotRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - ClientAuthorizationException exception = new ClientAuthorizationException( new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null), this.clientRegistration.getRegistrationId()); - - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) - .subscriberContext(this.context).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(ClientAuthorizationException.class).isThrownBy( + () -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } @@ -368,31 +356,24 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenOAuth2AuthorizationExceptionThenDoNotRemoveAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - OAuth2AuthorizationException exception = new OAuth2AuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); - - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) - .subscriberContext(this.context).block()) + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy( + () -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } @@ -400,110 +381,87 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void authorizeWhenOAuth2AuthorizationExceptionAndCustomFailureHandlerThenInvokeCustomFailureHandler() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); - OAuth2AuthorizationException exception = new OAuth2AuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); - - when(this.authorizedClientProvider.authorize( - any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); - + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.error(exception)); PublisherProbe authorizationFailureHandlerProbe = PublisherProbe.empty(); - this.authorizedClientManager.setAuthorizationFailureHandler((client, principal, attributes) -> authorizationFailureHandlerProbe.mono()); - - assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) - .subscriberContext(this.context).block()) + this.authorizedClientManager.setAuthorizationFailureHandler( + (client, principal, attributes) -> authorizationFailureHandlerProbe.mono()); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy( + () -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) .isEqualTo(exception); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - authorizationFailureHandlerProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } + @SuppressWarnings("unchecked") @Test public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); this.loadAuthorizedClientProbe = PublisherProbe.of(Mono.just(this.authorizedClient)); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange))).thenReturn(this.loadAuthorizedClientProbe.mono()); - - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + given(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), eq(this.serverWebExchange))).willReturn(this.loadAuthorizedClientProbe.mono()); + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(reauthorizedClient)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) .subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(any()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal), + eq(this.serverWebExchange)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); } @Test public void authorizeWhenRequestFormParameterUsernamePasswordThenMappedToContext() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); - + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.authorizedClient)); // Set custom contextAttributesMapper capable of mapping the form parameters - this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> - currentServerWebExchange() - .flatMap(ServerWebExchange::getFormData) - .map(formData -> { - Map contextAttributes = new HashMap<>(); - String username = formData.getFirst(OAuth2ParameterNames.USERNAME); - contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); - String password = formData.getFirst(OAuth2ParameterNames.PASSWORD); - contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); - return contextAttributes; - }) - ); - - this.serverWebExchange = MockServerWebExchange.builder( - MockServerHttpRequest - .post("/") - .contentType(MediaType.APPLICATION_FORM_URLENCODED) - .body("username=username&password=password")) + this.authorizedClientManager.setContextAttributesMapper((authorizeRequest) -> currentServerWebExchange() + .flatMap(ServerWebExchange::getFormData).map((formData) -> { + Map contextAttributes = new HashMap<>(); + String username = formData.getFirst(OAuth2ParameterNames.USERNAME); + contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); + String password = formData.getFirst(OAuth2ParameterNames.PASSWORD); + contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); + return contextAttributes; + })); + this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/") + .contentType(MediaType.APPLICATION_FORM_URLENCODED).body("username=username&password=password")) .build(); this.context = Context.of(ServerWebExchange.class, this.serverWebExchange); - - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .build(); this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); String username = authorizationContext.getAttribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME); assertThat(username).isEqualTo("username"); @@ -515,19 +473,15 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + .principal(this.principal).build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest) .subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(this.authorizedClient); this.saveAuthorizedClientProbe.assertWasNotSubscribed(); } @@ -535,67 +489,52 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenSupportedProviderThenReauthorized() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(reauthorizedClient)); OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + .principal(this.principal).build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest) .subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizedClient).isSameAs(reauthorizedClient); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal), + eq(this.serverWebExchange)); this.saveAuthorizedClientProbe.assertWasSubscribed(); verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); } @Test public void reauthorizeWhenRequestParameterScopeThenMappedToContext() { - OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); - - when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(reauthorizedClient)); // Override the mock with the default this.authorizedClientManager.setContextAttributesMapper( new DefaultReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); - - this.serverWebExchange = MockServerWebExchange.builder( - MockServerHttpRequest - .get("/") - .queryParam(OAuth2ParameterNames.SCOPE, "read write")) - .build(); + this.serverWebExchange = MockServerWebExchange + .builder(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.SCOPE, "read write")).build(); this.context = Context.of(ServerWebExchange.class, this.serverWebExchange); - OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) - .principal(this.principal) - .build(); + .principal(this.principal).build(); this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block(); - verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); - String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext + .getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); } private Mono currentServerWebExchange() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); + return Mono.subscriberContext().filter((c) -> c.hasKey(ServerWebExchange.class)) + .map((c) -> c.get(ServerWebExchange.class)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index fe8fc5514a..ba329f42e9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -13,22 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.HashMap; +import java.util.Map; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}. @@ -37,8 +39,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; */ @RunWith(MockitoJUnitRunner.class) public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { - private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); + + private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); @Test(expected = IllegalArgumentException.class) public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { @@ -49,9 +51,8 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void loadAuthorizationRequestWhenNotSavedThenReturnNull() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addParameter(OAuth2ParameterNames.STATE, "state-1234"); - OAuth2AuthorizationRequest authorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(authorizationRequest).isNull(); } @@ -59,14 +60,11 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - OAuth2AuthorizationRequest loadedAuthorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + OAuth2AuthorizationRequest loadedAuthorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); } @@ -75,88 +73,71 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - String state1 = "state-1122"; OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); - String state2 = "state-3344"; OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); - String state3 = "state-5566"; OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); - request.addParameter(OAuth2ParameterNames.STATE, state1); - OAuth2AuthorizationRequest loadedAuthorizationRequest1 = - this.authorizationRequestRepository.loadAuthorizationRequest(request); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1); - request.removeParameter(OAuth2ParameterNames.STATE); request.addParameter(OAuth2ParameterNames.STATE, state2); - OAuth2AuthorizationRequest loadedAuthorizationRequest2 = - this.authorizationRequestRepository.loadAuthorizationRequest(request); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); - request.removeParameter(OAuth2ParameterNames.STATE); request.addParameter(OAuth2ParameterNames.STATE, state3); - OAuth2AuthorizationRequest loadedAuthorizationRequest3 = - this.authorizationRequestRepository.loadAuthorizationRequest(request); + OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); } @Test public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenReturnNull() { MockHttpServletRequest request = new MockHttpServletRequest(); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, request, new MockHttpServletResponse()); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, + new MockHttpServletResponse()); assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull(); } @Test public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - - assertThatThrownBy(() -> this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, null, new MockHttpServletResponse())) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizationRequestRepository + .saveAuthorizationRequest(authorizationRequest, null, new MockHttpServletResponse())); } @Test public void saveAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - - assertThatThrownBy(() -> this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, new MockHttpServletRequest(), null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizationRequestRepository + .saveAuthorizationRequest(authorizationRequest, new MockHttpServletRequest(), null)); } @Test public void saveAuthorizationRequestWhenStateNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest() - .state(null) - .build(); - assertThatThrownBy(() -> this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, new MockHttpServletRequest(), new MockHttpServletResponse())) - .isInstanceOf(IllegalArgumentException.class); + OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().state(null).build(); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, + new MockHttpServletRequest(), new MockHttpServletResponse())); } @Test public void saveAuthorizationRequestWhenNotNullThenSaved() { MockHttpServletRequest request = new MockHttpServletRequest(); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, request, new MockHttpServletResponse()); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, + new MockHttpServletResponse()); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - OAuth2AuthorizationRequest loadedAuthorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + OAuth2AuthorizationRequest loadedAuthorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); } @@ -164,15 +145,12 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void saveAuthorizationRequestWhenNoExistingSessionAndDistributedSessionThenSaved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.setSession(new MockDistributedHttpSession()); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, request, new MockHttpServletResponse()); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, + new MockHttpServletResponse()); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - OAuth2AuthorizationRequest loadedAuthorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + OAuth2AuthorizationRequest loadedAuthorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); } @@ -180,19 +158,15 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void saveAuthorizationRequestWhenExistingSessionAndDistributedSessionThenSaved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.setSession(new MockDistributedHttpSession()); - OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().build(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest1, request, new MockHttpServletResponse()); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, + new MockHttpServletResponse()); OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().build(); - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest2, request, new MockHttpServletResponse()); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, + new MockHttpServletResponse()); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest2.getState()); - OAuth2AuthorizationRequest loadedAuthorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + OAuth2AuthorizationRequest loadedAuthorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest2); } @@ -200,51 +174,38 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void saveAuthorizationRequestWhenNullThenRemoved() { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - - this.authorizationRequestRepository.saveAuthorizationRequest( // Save - authorizationRequest, request, response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - this.authorizationRequestRepository.saveAuthorizationRequest( // Null value removes - null, request, response); - - OAuth2AuthorizationRequest loadedAuthorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + this.authorizationRequestRepository.saveAuthorizationRequest(null, request, response); + OAuth2AuthorizationRequest loadedAuthorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(loadedAuthorizationRequest).isNull(); } @Test public void removeAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizationRequestRepository.removeAuthorizationRequest( - null, new MockHttpServletResponse())).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizationRequestRepository + .removeAuthorizationRequest(null, new MockHttpServletResponse())); } @Test public void removeAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizationRequestRepository.removeAuthorizationRequest( - new MockHttpServletRequest(), null)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizationRequestRepository + .removeAuthorizationRequest(new MockHttpServletRequest(), null)); } @Test public void removeAuthorizationRequestWhenSavedThenRemoved() { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, request, response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - OAuth2AuthorizationRequest removedAuthorizationRequest = - this.authorizationRequestRepository.removeAuthorizationRequest(request, response); - OAuth2AuthorizationRequest loadedAuthorizationRequest = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - + OAuth2AuthorizationRequest removedAuthorizationRequest = this.authorizationRequestRepository + .removeAuthorizationRequest(request, response); + OAuth2AuthorizationRequest loadedAuthorizationRequest = this.authorizationRequestRepository + .loadAuthorizationRequest(request); assertThat(removedAuthorizationRequest).isNotNull(); assertThat(loadedAuthorizationRequest).isNull(); } @@ -254,19 +215,13 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void removeAuthorizationRequestWhenSavedThenRemovedFromSession() { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); - - this.authorizationRequestRepository.saveAuthorizationRequest( - authorizationRequest, request, response); - + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); - OAuth2AuthorizationRequest removedAuthorizationRequest = - this.authorizationRequestRepository.removeAuthorizationRequest(request, response); - - String sessionAttributeName = HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + - ".AUTHORIZATION_REQUEST"; - + OAuth2AuthorizationRequest removedAuthorizationRequest = this.authorizationRequestRepository + .removeAuthorizationRequest(request, response); + String sessionAttributeName = HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + + ".AUTHORIZATION_REQUEST"; assertThat(removedAuthorizationRequest).isNotNull(); assertThat(request.getSession().getAttribute(sessionAttributeName)).isNull(); } @@ -275,23 +230,19 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addParameter(OAuth2ParameterNames.STATE, "state-1234"); - MockHttpServletResponse response = new MockHttpServletResponse(); - - OAuth2AuthorizationRequest removedAuthorizationRequest = - this.authorizationRequestRepository.removeAuthorizationRequest(request, response); - + OAuth2AuthorizationRequest removedAuthorizationRequest = this.authorizationRequestRepository + .removeAuthorizationRequest(request, response); assertThat(removedAuthorizationRequest).isNull(); } private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() { - return OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id-1234") - .state("state-1234"); + return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id-1234").state("state-1234"); } static class MockDistributedHttpSession extends MockHttpSession { + @Override public Object getAttribute(String name) { return wrap(super.getAttribute(name)); @@ -308,5 +259,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { } return object; } + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java index 9646ce8c1e..42db859732 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.Map; + +import javax.servlet.http.HttpSession; + import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; @@ -24,11 +30,8 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import javax.servlet.http.HttpSession; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; /** @@ -37,6 +40,7 @@ import static org.mockito.Mockito.mock; * @author Joe Grandja */ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { + private String principalName1 = "principalName-1"; private ClientRegistration registration1 = TestClientRegistrations.clientRegistration().build(); @@ -47,8 +51,7 @@ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { private String registrationId2 = this.registration2.getRegistrationId(); - private HttpSessionOAuth2AuthorizedClientRepository authorizedClientRepository = - new HttpSessionOAuth2AuthorizedClientRepository(); + private HttpSessionOAuth2AuthorizedClientRepository authorizedClientRepository = new HttpSessionOAuth2AuthorizedClientRepository(); private MockHttpServletRequest request; @@ -62,8 +65,8 @@ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { @Test public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.request)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.request)); } @Test @@ -73,69 +76,66 @@ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { @Test public void loadAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null)); } @Test public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() { - OAuth2AuthorizedClient authorizedClient = - this.authorizedClientRepository.loadAuthorizedClient("registration-not-found", null, this.request); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository + .loadAuthorizedClient("registration-not-found", null, this.request); assertThat(authorizedClient).isNull(); } @Test public void loadAuthorizedClientWhenSavedThenReturnAuthorizedClient() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); - - OAuth2AuthorizedClient loadedAuthorizedClient = - this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId1, null, this.request); assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); } @Test public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.request, this.response)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.request, this.response)); } @Test public void saveAuthorizedClientWhenAuthenticationIsNullThenExceptionNotThrown() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); } @Test public void saveAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); - assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null, this.response)) - .isInstanceOf(IllegalArgumentException.class); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientRepository + .saveAuthorizedClient(authorizedClient, null, null, this.response)); } @Test public void saveAuthorizedClientWhenResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); - assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, null)) - .isInstanceOf(IllegalArgumentException.class); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, null)); } @Test public void saveAuthorizedClientWhenSavedThenSavedToSession() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); - HttpSession session = this.request.getSession(false); assertThat(session).isNotNull(); - @SuppressWarnings("unchecked") - Map authorizedClients = (Map) - session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); + Map authorizedClients = (Map) session + .getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); assertThat(authorizedClients).isNotEmpty(); assertThat(authorizedClients).hasSize(1); assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient); @@ -143,8 +143,8 @@ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { @Test public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( - null, null, this.request, this.response)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.removeAuthorizedClient(null, null, this.request, this.response)); } @Test @@ -154,8 +154,8 @@ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { @Test public void removeAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId1, null, null, this.response)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientRepository + .removeAuthorizedClient(this.registrationId1, null, null, this.response)); } @Test @@ -165,74 +165,66 @@ public class HttpSessionOAuth2AuthorizedClientRepositoryTests { @Test public void removeAuthorizedClientWhenNotSavedThenSessionNotCreated() { - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId2, null, this.request, this.response); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId2, null, this.request, this.response); assertThat(this.request.getSession(false)).isNull(); } @Test public void removeAuthorizedClientWhenClient1SavedAndClient2RemovedThenClient1NotRemoved() { - OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response); - // Remove registrationId2 (never added so is not removed either) - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId2, null, this.request, this.response); - - OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId1, null, this.request); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId2, null, this.request, this.response); + OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId1, null, this.request); assertThat(loadedAuthorizedClient1).isNotNull(); assertThat(loadedAuthorizedClient1).isSameAs(authorizedClient1); } @Test public void removeAuthorizedClientWhenSavedThenRemoved() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId2, null, this.request); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId2, null, this.request); assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId2, null, this.request, this.response); - loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId2, null, this.request); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId2, null, this.request, this.response); + loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(this.registrationId2, null, + this.request); assertThat(loadedAuthorizedClient).isNull(); } @Test public void removeAuthorizedClientWhenSavedThenRemovedFromSession() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response); - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId1, null, this.request); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId1, null, this.request); assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId1, null, this.request, this.response); - + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, this.response); HttpSession session = this.request.getSession(false); assertThat(session).isNotNull(); - assertThat(session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS")).isNull(); + assertThat(session + .getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS")) + .isNull(); } @Test public void removeAuthorizedClientWhenClient1Client2SavedAndClient1RemovedThenClient2NotRemoved() { - OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response); - - OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient2, null, this.request, this.response); - - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId1, null, this.request, this.response); - - OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId2, null, this.request); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, this.response); + OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId2, null, this.request); assertThat(loadedAuthorizedClient2).isNotNull(); assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index d1dba0d899..20666fa15e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -13,11 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -38,35 +50,26 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.util.UrlUtils; import org.springframework.util.CollectionUtils; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.any; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; /** * Tests for {@link OAuth2AuthorizationCodeGrantFilter}. @@ -75,13 +78,21 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Parikshit Dutta */ public class OAuth2AuthorizationCodeGrantFilterTests { + private ClientRegistration registration1; + private String principalName1 = "principal-1"; + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private AuthenticationManager authenticationManager; + private AuthorizationRequestRepository authorizationRequestRepository; + private OAuth2AuthorizationCodeGrantFilter filter; @Before @@ -89,11 +100,12 @@ public class OAuth2AuthorizationCodeGrantFilterTests { this.registration1 = TestClientRegistrations.clientRegistration().build(); this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1); this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); - this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService); + this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + this.authorizedClientService); this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); this.authenticationManager = mock(AuthenticationManager.class); - this.filter = spy(new OAuth2AuthorizationCodeGrantFilter( - this.clientRegistrationRepository, this.authorizedClientRepository, this.authenticationManager)); + this.filter = spy(new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, + this.authorizedClientRepository, this.authenticationManager)); this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName1, "password"); authentication.setAuthenticated(true); @@ -109,32 +121,32 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientRepository, this.authenticationManager)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, + this.authorizedClientRepository, this.authenticationManager)); } @Test public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, + this.authenticationManager)); } @Test public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, + this.authorizedClientRepository, null)); } @Test public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)); } @Test public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.filter.setRequestCache(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); } @Test @@ -142,12 +154,11 @@ public class OAuth2AuthorizationCodeGrantFilterTests { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); - // NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter. + // NOTE: A valid Authorization Response contains either a 'code' or 'error' + // parameter. MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @@ -157,9 +168,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(authorizationResponse, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @@ -172,9 +181,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); authorizationResponse.setRequestURI(requestUri + "-no-match"); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(authorizationResponse, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @@ -194,8 +201,8 @@ public class OAuth2AuthorizationCodeGrantFilterTests { MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); this.filter.doFilter(authorizationResponse, response, filterChain); verifyNoInteractions(filterChain); - - // 2) redirect_uri with query parameters AND authorization response additional parameters + // 2) redirect_uri with query parameters AND authorization response additional + // parameters Map additionalParameters = new LinkedHashMap<>(); additionalParameters.put("auth-param1", "value1"); additionalParameters.put("auth-param2", "value2"); @@ -218,7 +225,6 @@ public class OAuth2AuthorizationCodeGrantFilterTests { this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); FilterChain filterChain = mock(FilterChain.class); - // 1) Parameter value Map parametersNotMatch = new LinkedHashMap<>(parameters); parametersNotMatch.put("param2", "value8"); @@ -227,22 +233,18 @@ public class OAuth2AuthorizationCodeGrantFilterTests { authorizationResponse.setSession(authorizationRequest.getSession()); this.filter.doFilter(authorizationResponse, response, filterChain); verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - // 2) Parameter order parametersNotMatch = new LinkedHashMap<>(); parametersNotMatch.put("param2", "value2"); parametersNotMatch.put("param1", "value1"); - authorizationResponse = createAuthorizationResponse( - createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse = createAuthorizationResponse(createAuthorizationRequest(requestUri, parametersNotMatch)); authorizationResponse.setSession(authorizationRequest.getSession()); this.filter.doFilter(authorizationResponse, response, filterChain); verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - // 3) Parameter missing parametersNotMatch = new LinkedHashMap<>(parameters); parametersNotMatch.remove("param2"); - authorizationResponse = createAuthorizationResponse( - createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse = createAuthorizationResponse(createAuthorizationRequest(requestUri, parametersNotMatch)); authorizationResponse.setSession(authorizationRequest.getSession()); this.filter.doFilter(authorizationResponse, response, filterChain); verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); @@ -256,9 +258,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(authorizationResponse, response, filterChain); - assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull(); } @@ -269,13 +269,10 @@ public class OAuth2AuthorizationCodeGrantFilterTests { MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); - OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT); - when(this.authenticationManager.authenticate(any(Authentication.class))) - .thenThrow(new OAuth2AuthorizationException(error)); - + given(this.authenticationManager.authenticate(any(Authentication.class))) + .willThrow(new OAuth2AuthorizationException(error)); this.filter.doFilter(authorizationResponse, response, filterChain); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant"); } @@ -287,11 +284,9 @@ public class OAuth2AuthorizationCodeGrantFilterTests { FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(authorizationResponse, response, filterChain); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( - this.registration1.getRegistrationId(), this.principalName1); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientService + .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1); @@ -307,9 +302,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(authorizationResponse, response, filterChain); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1"); } @@ -327,9 +320,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(request, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); } @@ -338,85 +329,72 @@ public class OAuth2AuthorizationCodeGrantFilterTests { MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - RequestCache requestCache = spy(HttpSessionRequestCache.class); this.filter.setRequestCache(requestCache); - authorizationRequest.setRequestURI("/saved-request"); requestCache.saveRequest(authorizationRequest, response); - this.filter.doFilter(authorizationResponse, response, filterChain); - verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); } @Test - public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception { - AnonymousAuthenticationToken anonymousPrincipal = - new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() + throws Exception { + AnonymousAuthenticationToken anonymousPrincipal = new AnonymousAuthenticationToken("key-1234", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication(anonymousPrincipal); SecurityContextHolder.setContext(securityContext); - MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(authorizationResponse, response, filterChain); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName()); assertThat(authorizedClient.getAccessToken()).isNotNull(); - HttpSession session = authorizationResponse.getSession(false); assertThat(session).isNotNull(); - @SuppressWarnings("unchecked") - Map authorizedClients = (Map) - session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); + Map authorizedClients = (Map) session + .getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); assertThat(authorizedClients).isNotEmpty(); assertThat(authorizedClients).hasSize(1); assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient); } @Test - public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception { + public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() + throws Exception { SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - SecurityContextHolder.setContext(securityContext); // null Authentication - + SecurityContextHolder.setContext(securityContext); // null Authentication MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(authorizationResponse, response, filterChain); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registration1.getRegistrationId(), null, authorizationResponse); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registration1.getRegistrationId(), null, authorizationResponse); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); assertThat(authorizedClient.getAccessToken()).isNotNull(); - HttpSession session = authorizationResponse.getSession(false); assertThat(session).isNotNull(); - @SuppressWarnings("unchecked") - Map authorizedClients = (Map) - session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); + Map authorizedClients = (Map) session + .getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"); assertThat(authorizedClients).isNotEmpty(); assertThat(authorizedClients).hasSize(1); assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient); @@ -426,15 +404,14 @@ public class OAuth2AuthorizationCodeGrantFilterTests { return createAuthorizationRequest(requestUri, new LinkedHashMap<>()); } - private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map parameters) { + private static MockHttpServletRequest createAuthorizationRequest(String requestUri, + Map parameters) { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); if (!CollectionUtils.isEmpty(parameters)) { parameters.forEach(request::addParameter); - request.setQueryString( - parameters.entrySet().stream() - .map(e -> e.getKey() + "=" + e.getValue()) - .collect(Collectors.joining("&"))); + request.setQueryString(parameters.entrySet().stream().map((e) -> e.getKey() + "=" + e.getValue()) + .collect(Collectors.joining("&"))); } return request; } @@ -443,36 +420,35 @@ public class OAuth2AuthorizationCodeGrantFilterTests { return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>()); } - private static MockHttpServletRequest createAuthorizationResponse( - MockHttpServletRequest authorizationRequest, Map additionalParameters) { - MockHttpServletRequest authorizationResponse = new MockHttpServletRequest( - authorizationRequest.getMethod(), authorizationRequest.getRequestURI()); + private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest, + Map additionalParameters) { + MockHttpServletRequest authorizationResponse = new MockHttpServletRequest(authorizationRequest.getMethod(), + authorizationRequest.getRequestURI()); authorizationResponse.setServletPath(authorizationRequest.getRequestURI()); authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter); authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code"); authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state"); additionalParameters.forEach(authorizationResponse::addParameter); - authorizationResponse.setQueryString( - authorizationResponse.getParameterMap().entrySet().stream() - .map(e -> e.getKey() + "=" + e.getValue()[0]) - .collect(Collectors.joining("&"))); + authorizationResponse.setQueryString(authorizationResponse.getParameterMap().entrySet().stream() + .map((e) -> e.getKey() + "=" + e.getValue()[0]).collect(Collectors.joining("&"))); authorizationResponse.setSession(authorizationRequest.getSession()); return authorizationResponse; } private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, - ClientRegistration registration) { + ClientRegistration registration) { Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); - OAuth2AuthorizationRequest authorizationRequest = request() - .attributes(attributes) - .redirectUri(UrlUtils.buildFullRequestUrl(request)).build(); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .attributes(attributes).redirectUri(UrlUtils.buildFullRequestUrl(request)).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); } private void setUpAuthenticationResult(ClientRegistration registration) { - OAuth2AuthorizationCodeAuthenticationToken authentication = - new OAuth2AuthorizationCodeAuthenticationToken(registration, success(), noScopes(), refreshToken()); - when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication); + OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken( + registration, TestOAuth2AuthorizationExchanges.success(), TestOAuth2AccessTokens.noScopes(), + TestOAuth2RefreshTokens.refreshToken()); + given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(authentication); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 7cc320e770..4d3dedacb4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -13,10 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; +import java.lang.reflect.Constructor; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.FilterChain; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpStatus; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -31,19 +44,15 @@ import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.util.ClassUtils; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.FilterChain; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.lang.reflect.Constructor; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}. @@ -51,24 +60,32 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class OAuth2AuthorizationRequestRedirectFilterTests { + private ClientRegistration registration1; + private ClientRegistration registration2; + private ClientRegistration registration3; + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizationRequestRedirectFilter filter; + private RequestCache requestCache; @Before public void setUp() { this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build(); + // @formatter:off this.registration3 = TestClientRegistrations.clientRegistration() - .registrationId("registration-3") - .authorizationGrantType(AuthorizationGrantType.IMPLICIT) - .redirectUri("{baseUrl}/authorize/oauth2/implicit/{registrationId}") - .build(); - this.clientRegistrationRepository = new InMemoryClientRegistrationRepository( - this.registration1, this.registration2, this.registration3); + .registrationId("registration-3") + .authorizationGrantType(AuthorizationGrantType.IMPLICIT) + .redirectUri("{baseUrl}/authorize/oauth2/implicit/{registrationId}") + .build(); + // @formatter:on + this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, + this.registration2, this.registration3); this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository); this.requestCache = mock(RequestCache.class); this.filter.setRequestCache(this.requestCache); @@ -78,34 +95,30 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { Constructor constructor = ClassUtils.getConstructorIfAvailable( OAuth2AuthorizationRequestRedirectFilter.class, ClientRegistrationRepository.class); - assertThatThrownBy(() -> constructor.newInstance(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> constructor.newInstance(null)); } @Test public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, null)); } @Test public void constructorWhenAuthorizationRequestResolverIsNullThenThrowIllegalArgumentException() { Constructor constructor = ClassUtils.getConstructorIfAvailable( OAuth2AuthorizationRequestRedirectFilter.class, OAuth2AuthorizationRequestResolver.class); - assertThatThrownBy(() -> constructor.newInstance(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> constructor.newInstance(null)); } @Test public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)); } @Test public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.filter.setRequestCache(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); } @Test @@ -115,250 +128,209 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalServerError() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration1.getRegistrationId() + "-invalid"; + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId() + "-invalid"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); } @Test public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration1.getRegistrationId(); + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - - assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id"); + assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); } @Test public void doFilterWhenAuthorizationRequestOAuth2LoginThenAuthorizationRequestSaved() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration2.getRegistrationId(); + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration2.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - AuthorizationRequestRepository authorizationRequestRepository = - mock(AuthorizationRequestRepository.class); + AuthorizationRequestRepository authorizationRequestRepository = mock( + AuthorizationRequestRepository.class); this.filter.setAuthorizationRequestRepository(authorizationRequestRepository); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - verify(authorizationRequestRepository).saveAuthorizationRequest( - any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(authorizationRequestRepository).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void doFilterWhenAuthorizationRequestImplicitGrantThenRedirectForAuthorization() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration3.getRegistrationId(); + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration3.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - - assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=token&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/authorize/oauth2/implicit/registration-3"); + assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=token&client_id=client-id&" + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/authorize/oauth2/implicit/registration-3"); } @Test public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSaved() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration3.getRegistrationId(); + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration3.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - AuthorizationRequestRepository authorizationRequestRepository = - mock(AuthorizationRequestRepository.class); + AuthorizationRequestRepository authorizationRequestRepository = mock( + AuthorizationRequestRepository.class); this.filter.setAuthorizationRequestRepository(authorizationRequestRepository); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - verify(authorizationRequestRepository, times(0)).saveAuthorizationRequest( - any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(authorizationRequestRepository, times(0)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void doFilterWhenCustomAuthorizationRequestBaseUriThenRedirectForAuthorization() throws Exception { String authorizationRequestBaseUri = "/custom/authorization"; - this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, authorizationRequestBaseUri); - + this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository, + authorizationRequestBaseUri); String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - - assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id"); + assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); } @Test - public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectForAuthorization() throws Exception { + public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectForAuthorization() + throws Exception { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())) - .when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); - + willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) + .doFilter(any(ServletRequest.class), any(ServletResponse.class)); this.filter.doFilter(request, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - - assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/authorize/oauth2/code/registration-id"); + assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/authorize/oauth2/code/registration-id"); verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() throws Exception { + public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() + throws Exception { String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())) - .when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); - + willThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())).given(filterChain) + .doFilter(any(ServletRequest.class), any(ServletResponse.class)); OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); - filter.doFilter(request, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - verifyZeroInteractions(filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); } // gh-4911 @Test - public void doFilterWhenAuthorizationRequestAndAdditionalParametersProvidedThenAuthorizationRequestIncludesAdditionalParameters() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration1.getRegistrationId(); + public void doFilterWhenAuthorizationRequestAndAdditionalParametersProvidedThenAuthorizationRequestIncludesAdditionalParameters() + throws Exception { + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.addParameter("idp", "https://other.provider.com"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( - this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); - + this.clientRegistrationRepository, + OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class); OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest .from(defaultAuthorizationRequestResolver.resolve(request)) - .additionalParameters( - Collections.singletonMap("idp", request.getParameter("idp"))) - .build(); - when(resolver.resolve(any())).thenReturn(result); + .additionalParameters(Collections.singletonMap("idp", request.getParameter("idp"))).build(); + given(resolver.resolve(any())).willReturn(result); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); - filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - - assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id&" + - "idp=https://other.provider.com"); + assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id&" + + "idp=https://other.provider.com"); } // gh-4911, gh-5244 @Test - public void doFilterWhenAuthorizationRequestAndCustomAuthorizationRequestUriSetThenCustomAuthorizationRequestUriUsed() throws Exception { - String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + - "/" + this.registration1.getRegistrationId(); + public void doFilterWhenAuthorizationRequestAndCustomAuthorizationRequestUriSetThenCustomAuthorizationRequestUriUsed() + throws Exception { + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); String loginHintParamName = "login_hint"; request.addParameter(loginHintParamName, "user@provider.com"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( - this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); - + this.clientRegistrationRepository, + OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class); - OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request); Map additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters()); additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName)); + // @formatter:off String customAuthorizationRequestUri = UriComponentsBuilder .fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri()) .queryParam(loginHintParamName, additionalParameters.get(loginHintParamName)) - .build(true).toUriString(); + .build(true) + .toUriString(); OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest .from(defaultAuthorizationRequestResolver.resolve(request)) - .additionalParameters( - Collections.singletonMap("idp", request.getParameter("idp"))) + .additionalParameters(Collections.singletonMap("idp", request.getParameter("idp"))) .authorizationRequestUri(customAuthorizationRequestUri) .build(); - when(resolver.resolve(any())).thenReturn(result); - + // @formatter:on + given(resolver.resolve(any())).willReturn(result); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); - filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); - - assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.{15,}&" + - "redirect_uri=http://localhost/login/oauth2/code/registration-id&" + - "login_hint=user@provider\\.com"); + assertThat(response.getRedirectedUrl()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id&" + + "login_hint=user@provider\\.com"); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 07f8ebed92..7d50b6dda8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web; import java.util.HashMap; import java.util.Map; + import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -49,6 +51,7 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.WebAuthenticationDetails; @@ -56,15 +59,14 @@ import org.springframework.security.web.util.UrlUtils; import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.any; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success; /** * Tests for {@link OAuth2LoginAuthenticationFilter}. @@ -72,27 +74,40 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author * @author Joe Grandja */ public class OAuth2LoginAuthenticationFilterTests { + private ClientRegistration registration1; + private ClientRegistration registration2; + private String principalName1 = "principal-1"; + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientService authorizedClientService; + private AuthorizationRequestRepository authorizationRequestRepository; + private AuthenticationFailureHandler failureHandler; + private AuthenticationManager authenticationManager; + private AuthenticationDetailsSource authenticationDetailsSource; + private OAuth2LoginAuthenticationToken loginAuthentication; + private OAuth2LoginAuthenticationFilter filter; @Before public void setUp() { this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build(); - this.clientRegistrationRepository = new InMemoryClientRegistrationRepository( - this.registration1, this.registration2); + this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, + this.registration2); this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository); - this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService); + this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + this.authorizedClientService); this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); this.failureHandler = mock(AuthenticationFailureHandler.class); this.authenticationManager = mock(AuthenticationManager.class); @@ -107,33 +122,34 @@ public class OAuth2LoginAuthenticationFilterTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(null, this.authorizedClientService)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2LoginAuthenticationFilter(null, this.authorizedClientService)); } @Test public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, null)); } @Test public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, - (OAuth2AuthorizedClientRepository) null, OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, + (OAuth2AuthorizedClientRepository) null, + OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI)); } @Test public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, + this.authorizedClientRepository, null)); } @Test public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)); } @Test @@ -143,11 +159,10 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - verify(this.filter, never()).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(this.filter, never()).attemptAuthentication(any(HttpServletRequest.class), + any(HttpServletResponse.class)); } @Test @@ -158,56 +173,53 @@ public class OAuth2LoginAuthenticationFilterTests { // NOTE: // A valid Authorization Response contains either a 'code' or 'error' parameter. // Don't set it to force an invalid Authorization Response. - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - - ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class); - verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class), - authenticationExceptionArgCaptor.capture()); - + ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor + .forClass(AuthenticationException.class); + verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), + any(HttpServletResponse.class), authenticationExceptionArgCaptor.capture()); assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class); - OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue(); + OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor + .getValue(); assertThat(authenticationException.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); } @Test - public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() throws Exception { + public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() + throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, "state"); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.filter.doFilter(request, response, filterChain); - - ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class); - verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class), - authenticationExceptionArgCaptor.capture()); - + ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor + .forClass(AuthenticationException.class); + verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), + any(HttpServletResponse.class), authenticationExceptionArgCaptor.capture()); assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class); - OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue(); + OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor + .getValue(); assertThat(authenticationException.getError().getErrorCode()).isEqualTo("authorization_request_not_found"); } // gh-5251 @Test - public void doFilterWhenAuthorizationResponseClientRegistrationNotFoundThenClientRegistrationNotFoundError() throws Exception { + public void doFilterWhenAuthorizationResponseClientRegistrationNotFoundThenClientRegistrationNotFoundError() + throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, "state"); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - + // @formatter:off ClientRegistration registrationNotFound = ClientRegistration.withRegistrationId("registration-not-found") .clientId("client-1") .clientSecret("secret") @@ -221,16 +233,16 @@ public class OAuth2LoginAuthenticationFilterTests { .userNameAttributeName("id") .clientName("client-1") .build(); + // @formatter:on this.setUpAuthorizationRequest(request, response, registrationNotFound, state); - this.filter.doFilter(request, response, filterChain); - - ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class); - verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class), - authenticationExceptionArgCaptor.capture()); - + ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor + .forClass(AuthenticationException.class); + verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), + any(HttpServletResponse.class), authenticationExceptionArgCaptor.capture()); assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class); - OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue(); + OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor + .getValue(); assertThat(authenticationException.getError().getErrorCode()).isEqualTo("client_registration_not_found"); } @@ -242,15 +254,11 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, state); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthenticationResult(this.registration2); - this.filter.doFilter(request, response, filterChain); - assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull(); } @@ -262,17 +270,13 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, state); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration1, state); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registration1.getRegistrationId(), this.loginAuthentication, request); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registration1.getRegistrationId(), this.loginAuthentication, request); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1); @@ -283,32 +287,28 @@ public class OAuth2LoginAuthenticationFilterTests { @Test public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception { String filterProcessesUrl = "/login/oauth2/custom/*"; - this.filter = spy(new OAuth2LoginAuthenticationFilter( - this.clientRegistrationRepository, this.authorizedClientRepository, filterProcessesUrl)); + this.filter = spy(new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, + this.authorizedClientRepository, filterProcessesUrl)); this.filter.setAuthenticationManager(this.authenticationManager); - String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId(); String state = "state"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, state); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthenticationResult(this.registration2); - this.filter.doFilter(request, response, filterChain); - verifyZeroInteractions(filterChain); verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class)); } // gh-5890 @Test - public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort() throws Exception { + public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort() + throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); @@ -318,22 +318,19 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, "state"); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthenticationResult(this.registration2); - this.filter.doFilter(request, response, filterChain); - ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture()); - - OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue(); - OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest(); - OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse(); - + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor + .getValue(); + OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange() + .getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange() + .getAuthorizationResponse(); String expectedRedirectUri = "http://localhost/login/oauth2/code/registration-id-2"; assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri); assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); @@ -341,7 +338,8 @@ public class OAuth2LoginAuthenticationFilterTests { // gh-5890 @Test - public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort() throws Exception { + public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort() + throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); @@ -351,22 +349,19 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, "state"); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthenticationResult(this.registration2); - this.filter.doFilter(request, response, filterChain); - ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture()); - - OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue(); - OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest(); - OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse(); - + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor + .getValue(); + OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange() + .getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange() + .getAuthorizationResponse(); String expectedRedirectUri = "https://example.com/login/oauth2/code/registration-id-2"; assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri); assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); @@ -374,7 +369,8 @@ public class OAuth2LoginAuthenticationFilterTests { // gh-5890 @Test - public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort() throws Exception { + public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort() + throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); String state = "state"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); @@ -384,22 +380,19 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, "state"); - MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthenticationResult(this.registration2); - this.filter.doFilter(request, response, filterChain); - ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture()); - - OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue(); - OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest(); - OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse(); - + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor + .getValue(); + OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange() + .getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange() + .getAuthorizationResponse(); String expectedRedirectUri = "https://example.com:9090/login/oauth2/code/registration-id-2"; assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri); assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); @@ -414,64 +407,51 @@ public class OAuth2LoginAuthenticationFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, state); - WebAuthenticationDetails webAuthenticationDetails = mock(WebAuthenticationDetails.class); - when(authenticationDetailsSource.buildDetails(any())).thenReturn(webAuthenticationDetails); - + given(this.authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails); MockHttpServletResponse response = new MockHttpServletResponse(); - this.setUpAuthorizationRequest(request, response, this.registration2, state); this.setUpAuthenticationResult(this.registration2); - Authentication result = this.filter.attemptAuthentication(request, response); - assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails); } private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, - ClientRegistration registration, String state) { + ClientRegistration registration, String state) { Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) - .clientId(registration.getClientId()) - .redirectUri(expandRedirectUri(request, registration)) - .scopes(registration.getScopes()) - .state(state) - .attributes(attributes) - .build(); + .clientId(registration.getClientId()).redirectUri(expandRedirectUri(request, registration)) + .scopes(registration.getScopes()).state(state).attributes(attributes).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); } private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) { - String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) - .replaceQuery(null) - .replacePath(request.getContextPath()) - .build() - .toUriString(); - + String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)).replaceQuery(null) + .replacePath(request.getContextPath()).build().toUriString(); Map uriVariables = new HashMap<>(); uriVariables.put("baseUrl", baseUrl); uriVariables.put("action", "login"); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - - return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()) - .buildAndExpand(uriVariables) + return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()).buildAndExpand(uriVariables) .toUriString(); } private void setUpAuthenticationResult(ClientRegistration registration) { OAuth2User user = mock(OAuth2User.class); - when(user.getName()).thenReturn(this.principalName1); + given(user.getName()).willReturn(this.principalName1); this.loginAuthentication = mock(OAuth2LoginAuthenticationToken.class); - when(this.loginAuthentication.getPrincipal()).thenReturn(user); - when(this.loginAuthentication.getName()).thenReturn(this.principalName1); - when(this.loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER")); - when(this.loginAuthentication.getClientRegistration()).thenReturn(registration); - when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(success()); - when(this.loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class)); - when(this.loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class)); - when(this.loginAuthentication.isAuthenticated()).thenReturn(true); - when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(this.loginAuthentication); + given(this.loginAuthentication.getPrincipal()).willReturn(user); + given(this.loginAuthentication.getName()).willReturn(this.principalName1); + given(this.loginAuthentication.getAuthorities()).willReturn(AuthorityUtils.createAuthorityList("ROLE_USER")); + given(this.loginAuthentication.getClientRegistration()).willReturn(registration); + given(this.loginAuthentication.getAuthorizationExchange()) + .willReturn(TestOAuth2AuthorizationExchanges.success()); + given(this.loginAuthentication.getAccessToken()).willReturn(mock(OAuth2AccessToken.class)); + given(this.loginAuthentication.getRefreshToken()).willReturn(mock(OAuth2RefreshToken.class)); + given(this.loginAuthentication.isAuthenticated()).willReturn(true); + given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index ba563f4399..6cf546a2b6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -13,11 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.method.annotation; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.core.MethodParameter; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -53,15 +62,16 @@ import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.context.request.ServletWebRequest; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.lang.reflect.Method; -import java.util.HashMap; -import java.util.Map; - -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; -import static org.mockito.Mockito.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link OAuth2AuthorizedClientArgumentResolver}. @@ -69,17 +79,29 @@ import static org.mockito.Mockito.*; * @author Joe Grandja */ public class OAuth2AuthorizedClientArgumentResolverTests { + private TestingAuthenticationToken authentication; + private String principalName = "principal-1"; + private ClientRegistration registration1; + private ClientRegistration registration2; + private ClientRegistration registration3; + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClient authorizedClient1; + private OAuth2AuthorizedClient authorizedClient2; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientArgumentResolver argumentResolver; + private MockHttpServletRequest request; + private MockHttpServletResponse response; @Before @@ -88,7 +110,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests { SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication(this.authentication); SecurityContextHolder.setContext(securityContext); - + // @formatter:off this.registration1 = ClientRegistration.withRegistrationId("client1") .clientId("client-1") .clientSecret("secret") @@ -110,28 +132,27 @@ public class OAuth2AuthorizedClientArgumentResolverTests { .scope("read", "write") .tokenUri("https://provider.com/oauth2/token") .build(); - this.registration3 = TestClientRegistrations.password().registrationId("client3").build(); - this.clientRegistrationRepository = new InMemoryClientRegistrationRepository( - this.registration1, this.registration2, this.registration3); + this.registration3 = TestClientRegistrations.password() + .registrationId("client3") + .build(); + // @formatter:on + this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, + this.registration2, this.registration3); this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .build(); + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode().refreshToken().clientCredentials().build(); DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); - this.authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName, mock(OAuth2AccessToken.class)); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) - .thenReturn(this.authorizedClient1); - this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class)); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) - .thenReturn(this.authorizedClient2); + this.authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName, + mock(OAuth2AccessToken.class)); + given(this.authorizedClientRepository.loadAuthorizedClient(eq(this.registration1.getRegistrationId()), + any(Authentication.class), any(HttpServletRequest.class))).willReturn(this.authorizedClient1); + this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, + mock(OAuth2AccessToken.class)); + given(this.authorizedClientRepository.loadAuthorizedClient(eq(this.registration2.getRegistrationId()), + any(Authentication.class), any(HttpServletRequest.class))).willReturn(this.authorizedClient2); this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); } @@ -143,46 +164,50 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository)); } @Test public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)); } @Test public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)); } @Test public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null)) + .withMessage("clientCredentialsTokenResponseClient cannot be null"); } @Test public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { - assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) - .isInstanceOf(IllegalStateException.class) - .hasMessage("The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + assertThatIllegalStateException() + .isThrownBy(() -> this.argumentResolver + .setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) + .withMessage("The client cannot be set when the constructor used is " + + "\"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, " + + "OAuth2AuthorizedClientRepository)\"."); } @Test public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue(); } @Test public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientWithoutAnnotationThenFalse() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", OAuth2AuthorizedClient.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", + OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); } @@ -194,104 +219,103 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void supportsParameterWhenParameterTypeUnsupportedWithoutAnnotationThenFalse() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", String.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", + String.class); assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); } @Test public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() { MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null)) + .withMessage("Unable to resolve the Client Registration Identifier. It must be provided via " + + "@RegisteredOAuth2AuthorizedClient(\"client1\") or " + + "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); } @Test public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() throws Exception { OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class); - when(authentication.getAuthorizedClientRegistrationId()).thenReturn("client1"); + given(authentication.getAuthorizedClientRegistrationId()).willReturn("client1"); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication(authentication); SecurityContextHolder.setContext(securityContext); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); - assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); + assertThat(this.argumentResolver.resolveArgument(methodParameter, null, + new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } @Test public void resolveArgumentWhenAuthorizedClientFoundThenResolves() throws Exception { - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); + assertThat(this.argumentResolver.resolveArgument(methodParameter, null, + new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } @Test public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() { - MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id 'invalid'"); + MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", + OAuth2AuthorizedClient.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, + new ServletWebRequest(this.request, this.response), null)) + .withMessage("Could not find ClientRegistration with id 'invalid'"); } @Test public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClientThenThrowClientAuthorizationRequiredException() { - when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) - .thenReturn(null); - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)) - .isInstanceOf(ClientAuthorizationRequiredException.class); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) + .willReturn(null); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); + assertThatExceptionOfType(ClientAuthorizationRequiredException.class).isThrownBy(() -> this.argumentResolver + .resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)); } @SuppressWarnings("unchecked") @Test - public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClientThenResolvesFromTokenResponseClient() throws Exception { - OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - mock(OAuth2AccessTokenResponseClient.class); - ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(); + public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClientThenResolvesFromTokenResponseClient() + throws Exception { + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = mock( + OAuth2AccessTokenResponseClient.class); + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken("access-token-1234") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .build(); - when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) - .thenReturn(null); - MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class); - - OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request, this.response), null); - + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).build(); + given(clientCredentialsTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) + .willReturn(null); + MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", + OAuth2AuthorizedClient.class); + OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver + .resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName); assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken()); - - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(authorizedClient), eq(this.authentication), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @SuppressWarnings("unchecked") @Test - public void resolveArgumentWhenAuthorizedClientNotFoundForPasswordClientThenResolvesFromTokenResponseClient() throws Exception { - OAuth2AccessTokenResponseClient passwordTokenResponseClient = - mock(OAuth2AccessTokenResponseClient.class); - PasswordOAuth2AuthorizedClientProvider passwordAuthorizedClientProvider = - new PasswordOAuth2AuthorizedClientProvider(); + public void resolveArgumentWhenAuthorizedClientNotFoundForPasswordClientThenResolvesFromTokenResponseClient() + throws Exception { + OAuth2AccessTokenResponseClient passwordTokenResponseClient = mock( + OAuth2AccessTokenResponseClient.class); + PasswordOAuth2AuthorizedClientProvider passwordAuthorizedClientProvider = new PasswordOAuth2AuthorizedClientProvider(); passwordAuthorizedClientProvider.setAccessTokenResponseClient(passwordTokenResponseClient); DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(passwordAuthorizedClientProvider); - // Set custom contextAttributesMapper - authorizedClientManager.setContextAttributesMapper(authorizeRequest -> { + authorizedClientManager.setContextAttributesMapper((authorizeRequest) -> { Map contextAttributes = new HashMap<>(); HttpServletRequest servletRequest = authorizeRequest.getAttribute(HttpServletRequest.class.getName()); String username = servletRequest.getParameter(OAuth2ParameterNames.USERNAME); @@ -302,33 +326,23 @@ public class OAuth2AuthorizedClientArgumentResolverTests { } return contextAttributes; }); - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken("access-token-1234") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .build(); - when(passwordTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) - .thenReturn(null); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).build(); + given(passwordTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) + .willReturn(null); MethodParameter methodParameter = this.getMethodParameter("passwordClient", OAuth2AuthorizedClient.class); - this.request.setParameter(OAuth2ParameterNames.USERNAME, "username"); this.request.setParameter(OAuth2ParameterNames.PASSWORD, "password"); - - OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request, this.response), null); - + OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver + .resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration3); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName); assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken()); - - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), any(HttpServletResponse.class)); + verify(this.authorizedClientRepository).saveAuthorizedClient(eq(authorizedClient), eq(this.authentication), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { @@ -337,7 +351,9 @@ public class OAuth2AuthorizedClientArgumentResolverTests { } static class TestController { - void paramTypeAuthorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { + + void paramTypeAuthorizedClient( + @RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { } void paramTypeAuthorizedClientWithoutAnnotation(OAuth2AuthorizedClient authorizedClient) { @@ -352,13 +368,17 @@ public class OAuth2AuthorizedClientArgumentResolverTests { void registrationIdEmpty(@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { } - void registrationIdInvalid(@RegisteredOAuth2AuthorizedClient("invalid") OAuth2AuthorizedClient authorizedClient) { + void registrationIdInvalid( + @RegisteredOAuth2AuthorizedClient("invalid") OAuth2AuthorizedClient authorizedClient) { } - void clientCredentialsClient(@RegisteredOAuth2AuthorizedClient("client2") OAuth2AuthorizedClient authorizedClient) { + void clientCredentialsClient( + @RegisteredOAuth2AuthorizedClient("client2") OAuth2AuthorizedClient authorizedClient) { } void passwordClient(@RegisteredOAuth2AuthorizedClient("client3") OAuth2AuthorizedClient authorizedClient) { } + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java index ac9f31b6ff..46e6f20d0f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java @@ -16,22 +16,23 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; -import static org.mockito.Mockito.mock; +import java.util.ArrayList; +import java.util.List; + +import reactor.core.publisher.Mono; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFunction; -import reactor.core.publisher.Mono; - -import java.util.ArrayList; -import java.util.List; +import static org.mockito.Mockito.mock; /** * @author Rob Winch * @since 5.1 */ public class MockExchangeFunction implements ExchangeFunction { + private List requests = new ArrayList<>(); private ClientResponse response = mock(ClientResponse.class); @@ -55,4 +56,5 @@ public class MockExchangeFunction implements ExchangeFunction { return Mono.just(this.response); }); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java index f656d981a1..9b501a6e63 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -13,14 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.reactive.function.client; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -43,26 +52,18 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; -import reactor.util.context.Context; - -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashSet; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; /** * @author Phil Clay @@ -70,12 +71,19 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ServerOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter; + private MockWebServer server; + private String serverUrl; + private WebClient webClient; + private Authentication authentication; + private MockServerWebExchange exchange; @Before @@ -84,37 +92,34 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { final ServerOAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( new InMemoryReactiveOAuth2AuthorizedClientService(this.clientRegistrationRepository)); this.authorizedClientRepository = spy(new ServerOAuth2AuthorizedClientRepository() { - @Override - public Mono loadAuthorizedClient( - String clientRegistrationId, + public Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange exchange) { return delegate.loadAuthorizedClient(clientRegistrationId, principal, exchange); } @Override - public Mono saveAuthorizedClient( - OAuth2AuthorizedClient authorizedClient, - Authentication principal, ServerWebExchange exchange) { + public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + ServerWebExchange exchange) { return delegate.saveAuthorizedClient(authorizedClient, principal, exchange); } @Override - public Mono removeAuthorizedClient( - String clientRegistrationId, - Authentication principal, ServerWebExchange exchange) { + public Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange) { return delegate.removeAuthorizedClient(clientRegistrationId, principal, exchange); } - }); this.authorizedClientFilter = new ServerOAuth2AuthorizedClientExchangeFilterFunction( this.clientRegistrationRepository, this.authorizedClientRepository); this.server = new MockWebServer(); this.server.start(); this.serverUrl = this.server.url("/").toString(); + // @formatter:off this.webClient = WebClient.builder() .filter(this.authorizedClientFilter) .build(); + // @formatter:on this.authentication = new TestingAuthenticationToken("principal", "password"); this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/").build()).build(); } @@ -126,83 +131,82 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { @Test public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration)); - - this.webClient - .get() + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl) + .build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(clientRegistration)); + // @formatter:off + this.webClient.get() .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration.getRegistrationId())) .retrieve() .bodyToMono(String.class) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) .block(); - + // @formatter:on assertThat(this.server.getRequestCount()).isEqualTo(2); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.exchange)); assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); } @Test public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"refreshed-access-token\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; - + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"refreshed-access-token\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration)); - + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl) + .build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(clientRegistration)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofHours(1)); OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write"))); OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, this.authentication.getName(), accessToken, refreshToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + this.authentication.getName(), accessToken, refreshToken); doReturn(Mono.just(authorizedClient)).when(this.authorizedClientRepository).loadAuthorizedClient( eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange)); - - this.webClient - .get() - .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) - .retrieve() - .bodyToMono(String.class) + this.webClient.get().uri(this.serverUrl) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve().bodyToMono(String.class) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) - .block(); - + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)).block(); assertThat(this.server.getRequestCount()).isEqualTo(2); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.exchange)); OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue(); assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration); assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); @@ -210,124 +214,127 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { @Test public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; - + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on // Client 1 this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials() - .registrationId("client-1").tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(Mono.just(clientRegistration1)); - + ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials().registrationId("client-1") + .tokenUri(this.serverUrl).build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))) + .willReturn(Mono.just(clientRegistration1)); // Client 2 this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials() - .registrationId("client-2").tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(Mono.just(clientRegistration2)); - - this.webClient - .get() + ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials().registrationId("client-2") + .tokenUri(this.serverUrl).build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))) + .willReturn(Mono.just(clientRegistration2)); + // @formatter:off + this.webClient.get() .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration1.getRegistrationId())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration1.getRegistrationId())) .retrieve() .bodyToMono(String.class) - .flatMap(response -> this.webClient - .get() + .flatMap((response) -> this.webClient.get() .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration2.getRegistrationId())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration2.getRegistrationId())) .retrieve() - .bodyToMono(String.class)) + .bodyToMono(String.class) + ) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) .block(); - + // @formatter:on assertThat(this.server.getRequestCount()).isEqualTo(4); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.exchange)); assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1); assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2); } /** - * When a non-expired {@link OAuth2AuthorizedClient} exists - * but the resource server returns 401, - * then remove the {@link OAuth2AuthorizedClient} from the repository. + * When a non-expired {@link OAuth2AuthorizedClient} exists but the resource server + * returns 401, then remove the {@link OAuth2AuthorizedClient} from the repository. */ @Test public void requestWhenUnauthorizedThenReAuthorize() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(new MockResponse().setResponseCode(HttpStatus.UNAUTHORIZED.value())); this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration)); - + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl) + .build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(clientRegistration)); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, this.authentication.getName(), accessToken, refreshToken); - doReturn(Mono.just(authorizedClient)) - .doReturn(Mono.empty()) - .when(this.authorizedClientRepository).loadAuthorizedClient( - eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange)); - - Mono requestMono = this.webClient - .get() + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + this.authentication.getName(), accessToken, refreshToken); + doReturn(Mono.just(authorizedClient)).doReturn(Mono.empty()).when(this.authorizedClientRepository) + .loadAuthorizedClient(eq(clientRegistration.getRegistrationId()), eq(this.authentication), + eq(this.exchange)); + // @formatter:off + Mono requestMono = this.webClient.get() .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration.getRegistrationId())) .retrieve() .bodyToMono(String.class) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)); - + // @formatter:on // first try should fail, and remove the cached authorized client - assertThatCode(requestMono::block) - .isInstanceOfSatisfying(WebClientResponseException.class, e -> assertThat(e.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED)); - + // @formatter:off + assertThatExceptionOfType(WebClientResponseException.class) + .isThrownBy(requestMono::block) + .satisfies((ex) -> assertThat(ex.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED)); + // @formatter:on assertThat(this.server.getRequestCount()).isEqualTo(1); - verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); - verify(this.authorizedClientRepository).removeAuthorizedClient( - eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange)); - + verify(this.authorizedClientRepository).removeAuthorizedClient(eq(clientRegistration.getRegistrationId()), + eq(this.authentication), eq(this.exchange)); // second try should retrieve the authorized client and succeed requestMono.block(); - assertThat(this.server.getRequestCount()).isEqualTo(3); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.exchange)); assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); } private MockResponse jsonResponse(String json) { + // @formatter:off return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setBody(json); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 15a603953a..7bc1877cf3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -13,8 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.reactive.function.client; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -22,6 +34,10 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.test.publisher.PublisherProbe; +import reactor.util.context.Context; + import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; @@ -78,36 +94,20 @@ import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; -import reactor.test.publisher.PublisherProbe; -import reactor.util.context.Context; - -import java.net.URI; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.entry; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.http.HttpMethod.GET; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; /** * @author Rob Winch @@ -115,6 +115,7 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c */ @RunWith(MockitoJUnitRunner.class) public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { + @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; @@ -151,136 +152,132 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { private MockExchangeFunction exchange = new MockExchangeFunction(); - private ClientRegistration registration = TestClientRegistrations.clientRegistration() - .build(); + private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token-0", - Instant.now(), - Instant.now().plus(Duration.ofDays(1))); + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token-0", + Instant.now(), Instant.now().plus(Duration.ofDays(1))); private DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager; @Before public void setup() { - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) - .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) - .password(configurer -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) - .build(); + // @formatter:off + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder() + .authorizationCode() + .refreshToken( + (configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials( + (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) + .build(); + // @formatter:on this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); - when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class)); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager); + given(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + given(this.exchange.getResponse().headers()).willReturn(mock(ClientResponse.Headers.class)); } @Test public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null)); } @Test public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) + .withMessage("clientCredentialsTokenResponseClient cannot be null"); } @Test public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { - assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(new WebClientReactiveClientCredentialsTokenResponseClient())) - .isInstanceOf(IllegalStateException.class) - .hasMessage("The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + assertThatIllegalStateException() + .isThrownBy(() -> this.function.setClientCredentialsTokenResponseClient( + new WebClientReactiveClientCredentialsTokenResponseClient())) + .withMessage( + "The client cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); } @Test public void setAccessTokenExpiresSkewWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { - assertThatThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))) - .isInstanceOf(IllegalStateException.class) - .hasMessage("The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + assertThatIllegalStateException() + .isThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))).withMessage( + "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServerOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); } @Test public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); this.function.filter(request, this.exchange).block(); - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); } @Test public void filterWhenAuthorizedClientThenAuthorizationHeader() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) + // @formatter:off + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()) .block(); - - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); + // @formatter:on + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } @Test public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") - .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - + // @formatter:on + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } @Test public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse + .withToken("new-token") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(360) .build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - + // @formatter:on + given(this.clientCredentialsTokenResponseClient.getTokenResponse(any())) + .willReturn(Mono.just(accessTokenResponse)); ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, - "principalName", accessToken, null); - + this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", accessToken, + null); TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(serverWebExchange()) .block(); - + // @formatter:on verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); ClientRequest request1 = requests.get(0); @@ -294,20 +291,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, - "principalName", this.accessToken, null); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", + this.accessToken, null); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(serverWebExchange()) .block(); - + // @formatter:on verify(this.clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); ClientRequest request1 = requests.get(0); @@ -320,42 +315,35 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefresh() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .refreshToken("refresh-1") - .build(); - when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(response)); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).refreshToken("refresh-1").build(); + given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), + issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - + // @formatter:on TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + // @formatter:off this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(serverWebExchange()) .block(); - + // @formatter:on verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); - verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any()); - - OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); + verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), + eq(authentication), any()); + OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -366,36 +354,27 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .refreshToken("refresh-1") - .build(); - when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(response)); + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).refreshToken("refresh-1").build(); + given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response)); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); - Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); - + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), + issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange) .subscriberContext(serverWebExchange()) .block(); - + // @formatter:on verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -405,19 +384,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange) .subscriberContext(serverWebExchange()) .block(); - + // @formatter:on List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -428,19 +406,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - this.function.filter(request, this.exchange) .subscriberContext(serverWebExchange()) .block(); - + // @formatter:on List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -450,336 +427,264 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenUnauthorizedThenInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); - when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); - + given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) + .willReturn(publisherProbe.mono()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.UNAUTHORIZED.value()); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - + // @formatter:on + given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.UNAUTHORIZED.value()); + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); assertThat(publisherProbe.wasSubscribed()).isTrue(); - - verify(authorizationFailureHandler).onAuthorizationFailure( - authorizationExceptionCaptor.capture(), - authenticationCaptor.capture(), - attributesCaptor.capture()); - - assertThat(authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token"); - assertThat(e).hasNoCause(); - assertThat(e).hasMessageContaining("[invalid_token]"); + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token"); + assertThat(ex).hasNoCause(); + assertThat(ex).hasMessageContaining("[invalid_token]"); }); - assertThat(authenticationCaptor.getValue()) - .isInstanceOf(AnonymousAuthenticationToken.class); - assertThat(attributesCaptor.getValue()) + assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(this.attributesCaptor.getValue()) .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); } @Test public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); - when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); - + given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) + .willReturn(publisherProbe.mono()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - WebClientResponseException exception = WebClientResponseException.create( - HttpStatus.UNAUTHORIZED.value(), - HttpStatus.UNAUTHORIZED.getReasonPhrase(), - HttpHeaders.EMPTY, - new byte[0], - StandardCharsets.UTF_8); - - ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); - - assertThatCode(() -> this.function.filter(request, throwingExchangeFunction) - .subscriberContext(serverWebExchange()) - .block()) + // @formatter:on + WebClientResponseException exception = WebClientResponseException.create(HttpStatus.UNAUTHORIZED.value(), + HttpStatus.UNAUTHORIZED.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8); + ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception); + // @formatter:off + assertThatExceptionOfType(WebClientResponseException.class) + .isThrownBy(() -> this.function + .filter(request, throwingExchangeFunction) + .subscriberContext(serverWebExchange()) + .block() + ) .isEqualTo(exception); - + // @formatter:on assertThat(publisherProbe.wasSubscribed()).isTrue(); - - verify(authorizationFailureHandler).onAuthorizationFailure( - authorizationExceptionCaptor.capture(), - authenticationCaptor.capture(), - attributesCaptor.capture()); - - assertThat(authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token"); - assertThat(e).hasCause(exception); - assertThat(e).hasMessageContaining("[invalid_token]"); + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); + // @formatter:off + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token"); + assertThat(ex).hasCause(exception); + assertThat(ex).hasMessageContaining("[invalid_token]"); }); - assertThat(authenticationCaptor.getValue()) - .isInstanceOf(AnonymousAuthenticationToken.class); - assertThat(attributesCaptor.getValue()) + // @formatter:on + assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(this.attributesCaptor.getValue()) .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); } @Test public void filterWhenForbiddenThenInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); - when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); - + given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) + .willReturn(publisherProbe.mono()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.FORBIDDEN.value()); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - + // @formatter:on + given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.FORBIDDEN.value()); + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); assertThat(publisherProbe.wasSubscribed()).isTrue(); - - verify(authorizationFailureHandler).onAuthorizationFailure( - authorizationExceptionCaptor.capture(), - authenticationCaptor.capture(), - attributesCaptor.capture()); - - assertThat(authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo("insufficient_scope"); - assertThat(e).hasNoCause(); - assertThat(e).hasMessageContaining("[insufficient_scope]"); + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo("insufficient_scope"); + assertThat(ex).hasNoCause(); + assertThat(ex).hasMessageContaining("[insufficient_scope]"); }); - assertThat(authenticationCaptor.getValue()) - .isInstanceOf(AnonymousAuthenticationToken.class); - assertThat(attributesCaptor.getValue()) + assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(this.attributesCaptor.getValue()) .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); } @Test public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); - when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); - + given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) + .willReturn(publisherProbe.mono()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - WebClientResponseException exception = WebClientResponseException.create( - HttpStatus.FORBIDDEN.value(), - HttpStatus.FORBIDDEN.getReasonPhrase(), - HttpHeaders.EMPTY, - new byte[0], - StandardCharsets.UTF_8); - - ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); - - assertThatCode(() -> this.function.filter(request, throwingExchangeFunction) - .subscriberContext(serverWebExchange()) - .block()) + WebClientResponseException exception = WebClientResponseException.create(HttpStatus.FORBIDDEN.value(), + HttpStatus.FORBIDDEN.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8); + ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception); + // @formatter:off + assertThatExceptionOfType(WebClientResponseException.class) + .isThrownBy(() -> this.function + .filter(request, throwingExchangeFunction) + .subscriberContext(serverWebExchange()) + .block() + ) .isEqualTo(exception); - + // @formatter:on assertThat(publisherProbe.wasSubscribed()).isTrue(); - - verify(authorizationFailureHandler).onAuthorizationFailure( - authorizationExceptionCaptor.capture(), - authenticationCaptor.capture(), - attributesCaptor.capture()); - - assertThat(authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo("insufficient_scope"); - assertThat(e).hasCause(exception); - assertThat(e).hasMessageContaining("[insufficient_scope]"); + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo("insufficient_scope"); + assertThat(ex).hasCause(exception); + assertThat(ex).hasMessageContaining("[insufficient_scope]"); }); - assertThat(authenticationCaptor.getValue()) - .isInstanceOf(AnonymousAuthenticationToken.class); - assertThat(attributesCaptor.getValue()) + assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(this.attributesCaptor.getValue()) .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); } @Test public void filterWhenWWWAuthenticateHeaderIncludesErrorThenInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); - when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); - + given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) + .willReturn(publisherProbe.mono()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " + - "error_description=\"The request requires higher privileges than provided by the access token.\", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""; + String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " + + "error_description=\"The request requires higher privileges than provided by the access token.\", " + + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""; ClientResponse.Headers headers = mock(ClientResponse.Headers.class); - when(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE))) - .thenReturn(Collections.singletonList(wwwAuthenticateHeader)); - when(this.exchange.getResponse().headers()).thenReturn(headers); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - + given(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE))) + .willReturn(Collections.singletonList(wwwAuthenticateHeader)); + given(this.exchange.getResponse().headers()).willReturn(headers); + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); assertThat(publisherProbe.wasSubscribed()).isTrue(); - - verify(authorizationFailureHandler).onAuthorizationFailure( - authorizationExceptionCaptor.capture(), - authenticationCaptor.capture(), - attributesCaptor.capture()); - - assertThat(authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); - assertThat(e.getError().getDescription()).isEqualTo("The request requires higher privileges than provided by the access token."); - assertThat(e.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); - assertThat(e).hasNoCause(); - assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + assertThat(ex.getError().getDescription()) + .isEqualTo("The request requires higher privileges than provided by the access token."); + assertThat(ex.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); + assertThat(ex).hasNoCause(); + assertThat(ex).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); }); - assertThat(authenticationCaptor.getValue()) - .isInstanceOf(AnonymousAuthenticationToken.class); - assertThat(attributesCaptor.getValue()) + assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(this.attributesCaptor.getValue()) .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); } @Test public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); PublisherProbe publisherProbe = PublisherProbe.empty(); - when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); - + given(this.authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())) + .willReturn(publisherProbe.mono()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - OAuth2AuthorizationException exception = new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null)); - - ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); - - assertThatCode(() -> this.function.filter(request, throwingExchangeFunction) - .subscriberContext(serverWebExchange()) - .block()) + OAuth2AuthorizationException exception = new OAuth2AuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null)); + ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> this.function + .filter(request, throwingExchangeFunction).subscriberContext(serverWebExchange()).block()) .isEqualTo(exception); - assertThat(publisherProbe.wasSubscribed()).isTrue(); - - verify(authorizationFailureHandler).onAuthorizationFailure( - authorizationExceptionCaptor.capture(), - authenticationCaptor.capture(), - attributesCaptor.capture()); - - assertThat(authorizationExceptionCaptor.getValue()) - .isSameAs(exception); - assertThat(authenticationCaptor.getValue()) - .isInstanceOf(AnonymousAuthenticationToken.class); - assertThat(attributesCaptor.getValue()) + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); + assertThat(this.authorizationExceptionCaptor.getValue()).isSameAs(exception); + assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(this.attributesCaptor.getValue()) .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); } @Test public void filterWhenOtherHttpStatusShouldNotInvokeFailureHandler() { - function.setAuthorizationFailureHandler(authorizationFailureHandler); - + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .build(); - - when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value()); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - - verify(authorizationFailureHandler, never()).onAuthorizationFailure(any(), any(), any()); + given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.BAD_REQUEST.value()); + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); + verify(this.authorizationFailureHandler, never()).onAuthorizationFailure(any(), any(), any()); } @Test public void filterWhenPasswordClientNotAuthorizedThenGetNewToken() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); ClientRegistration registration = TestClientRegistrations.password().build(); - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(360) - .build(); - when(this.passwordTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - - when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(Mono.just(registration)); - when(this.authorizedClientRepository.loadAuthorizedClient(eq(registration.getRegistrationId()), eq(authentication), any())).thenReturn(Mono.empty()); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(360).build(); + given(this.passwordTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))) + .willReturn(Mono.just(registration)); + given(this.authorizedClientRepository.loadAuthorizedClient(eq(registration.getRegistrationId()), + eq(authentication), any())).willReturn(Mono.empty()); // Set custom contextAttributesMapper capable of mapping the form parameters - this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> { + this.authorizedClientManager.setContextAttributesMapper((authorizeRequest) -> { ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); - return Mono.just(serverWebExchange) - .flatMap(ServerWebExchange::getFormData) - .map(formData -> { - Map contextAttributes = new HashMap<>(); - String username = formData.getFirst(OAuth2ParameterNames.USERNAME); - String password = formData.getFirst(OAuth2ParameterNames.PASSWORD); - if (StringUtils.hasText(username) && StringUtils.hasText(password)) { - contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); - contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); - } - return contextAttributes; - }); + return Mono.just(serverWebExchange).flatMap(ServerWebExchange::getFormData).map((formData) -> { + Map contextAttributes = new HashMap<>(); + String username = formData.getFirst(OAuth2ParameterNames.USERNAME); + String password = formData.getFirst(OAuth2ParameterNames.PASSWORD); + if (StringUtils.hasText(username) && StringUtils.hasText(password)) { + contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); + contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); + } + return contextAttributes; + }); }); - - this.serverWebExchange = MockServerWebExchange.builder( - MockServerHttpRequest - .post("/") - .contentType(MediaType.APPLICATION_FORM_URLENCODED) - .body("username=username&password=password")) + this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.post("/") + .contentType(MediaType.APPLICATION_FORM_URLENCODED).body("username=username&password=password")) .build(); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(clientRegistrationId(registration.getRegistrationId())) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(registration.getRegistrationId())) .build(); - this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .subscriberContext(serverWebExchange()) - .block(); - + .subscriberContext(serverWebExchange()).block(); verify(this.passwordTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); ClientRequest request1 = requests.get(0); @@ -792,20 +697,17 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(clientRegistrationId(this.registration.getRegistrationId())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .willReturn(Mono.just(authorizedClient)); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(this.registration.getRegistrationId())) .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -817,19 +719,14 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { public void filterWhenDefaultClientRegistrationIdThenAuthorizedClientResolved() { this.function.setDefaultClientRegistrationId(this.registration.getRegistrationId()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(serverWebExchange()) - .block(); - + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .willReturn(Mono.just(authorizedClient)); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); + this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -840,26 +737,21 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() { this.function.setDefaultOAuth2AuthorizedClient(true); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - - OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections - .singletonMap("user", "rob"), "user"); - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id"); - this.function - .filter(request, this.exchange) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .willReturn(Mono.just(authorizedClient)); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), + "client-id"); + this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .subscriberContext(serverWebExchange()) - .block(); - + .subscriberContext(serverWebExchange()).block(); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -869,69 +761,66 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - - OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections - .singletonMap("user", "rob"), "user"); - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id"); - - this.function - .filter(request, this.exchange) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), + "client-id"); + // @formatter:off + this.function.filter(request, this.exchange) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .block(); - + // @formatter:on List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository); } @Test public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(clientRegistrationId(this.registration.getRegistrationId())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .willReturn(Mono.just(authorizedClient)); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId(this.registration.getRegistrationId())) .build(); - this.function.filter(request, this.exchange) .subscriberContext(serverWebExchange()) .block(); - - verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange)); + // @formatter:on + verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), + eq(this.serverWebExchange)); } // gh-7544 @Test public void filterWhenClientCredentialsClientNotAuthorizedAndOutsideRequestContextThenGetNewToken() { - // Use UnAuthenticatedServerOAuth2AuthorizedClientRepository when operating outside of a request context - ServerOAuth2AuthorizedClientRepository unauthenticatedAuthorizedClientRepository = spy(new UnAuthenticatedServerOAuth2AuthorizedClientRepository()); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, unauthenticatedAuthorizedClientRepository); + // Use UnAuthenticatedServerOAuth2AuthorizedClientRepository when operating + // outside of a request context + ServerOAuth2AuthorizedClientRepository unauthenticatedAuthorizedClientRepository = spy( + new UnAuthenticatedServerOAuth2AuthorizedClientRepository()); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + unauthenticatedAuthorizedClientRepository); this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(360) - .build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(360).build(); + given(this.clientCredentialsTokenResponseClient.getTokenResponse(any())) + .willReturn(Mono.just(accessTokenResponse)); ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(Mono.just(registration)); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(clientRegistrationId(registration.getRegistrationId())) + given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))) + .willReturn(Mono.just(registration)); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId(registration.getRegistrationId())) .build(); - + // @formatter:on this.function.filter(request, this.exchange).block(); - verify(unauthenticatedAuthorizedClientRepository).loadAuthorizedClient(any(), any(), any()); verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); verify(unauthenticatedAuthorizedClientRepository).saveAuthorizedClient(any(), any(), any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); ClientRequest request1 = requests.get(0); @@ -956,7 +845,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { messageWriters.add(new FormHttpMessageWriter()); messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes())); messageWriters.add(new MultipartHttpMessageWriter(messageWriters)); - BodyInserter.Context context = new BodyInserter.Context() { @Override public List> messageWriters() { @@ -973,9 +861,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { return new HashMap<>(); } }; - MockClientHttpRequest body = new MockClientHttpRequest(HttpMethod.GET, "/"); request.body().insert(body, context).block(); return body.getBodyAsString().block(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java index 4fc245d31b..c2cb2070e8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -13,8 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.reactive.function.client; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.After; @@ -22,6 +33,9 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; +import reactor.blockhound.BlockHound; +import reactor.util.context.Context; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; @@ -42,47 +56,54 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.client.WebClient; -import reactor.blockhound.BlockHound; -import reactor.util.context.Context; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * @author Joe Grandja */ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private ServletOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter; + private MockWebServer server; + private String serverUrl; + private WebClient webClient; + private Authentication authentication; + private MockHttpServletRequest request; + private MockHttpServletResponse response; @BeforeClass public static void setUpBlockingChecks() { // IMPORTANT: - // Before enabling BlockHound, we need to white-list `java.lang.Class.getPackage()`. + // Before enabling BlockHound, we need to white-list + // `java.lang.Class.getPackage()`. // When the JVM loads `java.lang.Package.getSystemPackage()`, it attempts to - // `java.lang.Package.loadManifest()` which is blocking I/O and triggers BlockHound to error. - // NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine w/o this white-list. + // `java.lang.Package.loadManifest()` which is blocking I/O and triggers + // BlockHound to error. + // NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine + // w/o this white-list. + // @formatter:off BlockHound.builder() .allowBlockingCallsInside(Class.class.getName(), "getPackage") .install(); + // @formatter:on } @Before @@ -92,17 +113,20 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository)); this.authorizedClientRepository = spy(new OAuth2AuthorizedClientRepository() { @Override - public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request) { + public T loadAuthorizedClient(String clientRegistrationId, + Authentication principal, HttpServletRequest request) { return delegate.loadAuthorizedClient(clientRegistrationId, principal, request); } @Override - public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, HttpServletRequest request, HttpServletResponse response) { + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { delegate.saveAuthorizedClient(authorizedClient, principal, request, response); } @Override - public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request, HttpServletResponse response) { + public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { delegate.removeAuthorizedClient(clientRegistrationId, principal, request, response); } }); @@ -111,9 +135,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { this.server = new MockWebServer(); this.server.start(); this.serverUrl = this.server.url("/").toString(); - this.webClient = WebClient.builder() - .apply(this.authorizedClientFilter.oauth2Configuration()) - .build(); + this.webClient = WebClient.builder().apply(this.authorizedClientFilter.oauth2Configuration()).build(); this.authentication = new TestingAuthenticationToken("principal", "password"); SecurityContextHolder.getContext().setAuthentication(this.authentication); this.request = new MockHttpServletRequest(); @@ -130,80 +152,73 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { @Test public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; - + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); - - this.webClient - .get() - .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) - .retrieve() - .bodyToMono(String.class) - .block(); - + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl) + .build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) + .willReturn(clientRegistration); + this.webClient.get().uri(this.serverUrl) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve().bodyToMono(String.class).block(); assertThat(this.server.getRequestCount()).isEqualTo(2); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.request), eq(this.response)); assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); } @Test public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"refreshed-access-token\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; - + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"refreshed-access-token\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); - + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl) + .build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) + .willReturn(clientRegistration); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofHours(1)); OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write"))); OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, this.authentication.getName(), accessToken, refreshToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + this.authentication.getName(), accessToken, refreshToken); doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient( eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request)); - - this.webClient - .get() - .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) - .retrieve() - .bodyToMono(String.class) - .block(); - + this.webClient.get().uri(this.serverUrl) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve().bodyToMono(String.class).block(); assertThat(this.server.getRequestCount()).isEqualTo(2); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.request), eq(this.response)); OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue(); assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration); assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); @@ -211,53 +226,53 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { @Test public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { - String accessTokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\"\n" + - "}\n"; - String clientResponse = "{\n" + - " \"attribute1\": \"value1\",\n" + - " \"attribute2\": \"value2\"\n" + - "}\n"; - + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + // @formatter:on // Client 1 this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials() - .registrationId("client-1").tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(clientRegistration1); - + ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials().registrationId("client-1") + .tokenUri(this.serverUrl).build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))) + .willReturn(clientRegistration1); // Client 2 this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(clientResponse)); - - ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials() - .registrationId("client-2").tokenUri(this.serverUrl).build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(clientRegistration2); - - this.webClient - .get() + ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials().registrationId("client-2") + .tokenUri(this.serverUrl).build(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))) + .willReturn(clientRegistration2); + // @formatter:off + this.webClient.get() .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration1.getRegistrationId())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId(clientRegistration1.getRegistrationId())) .retrieve() .bodyToMono(String.class) - .flatMap(response -> this.webClient - .get() - .uri(this.serverUrl) - .attributes(clientRegistrationId(clientRegistration2.getRegistrationId())) - .retrieve() - .bodyToMono(String.class)) + .flatMap((response) -> this.webClient + .get() + .uri(this.serverUrl) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId(clientRegistration2.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + ) .subscriberContext(context()) .block(); - + // @formatter:on assertThat(this.server.getRequestCount()).isEqualTo(4); - - ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); - verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient( - authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor + .forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient(authorizedClientCaptor.capture(), + eq(this.authentication), eq(this.request), eq(this.response)); assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1); assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2); } @@ -267,12 +282,16 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { contextAttributes.put(HttpServletRequest.class, this.request); contextAttributes.put(HttpServletResponse.class, this.response); contextAttributes.put(Authentication.class, this.authentication); - return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes); + return Context.of(ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, + contextAttributes); } private MockResponse jsonResponse(String json) { + // @formatter:off return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setBody(json); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 2a0c9cf841..2e3a793462 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -13,8 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.reactive.function.client; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -23,6 +39,9 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; @@ -84,44 +103,19 @@ import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; -import reactor.core.publisher.Mono; -import reactor.util.context.Context; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.net.URI; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.entry; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.http.HttpMethod.GET; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; /** * @author Rob Winch @@ -129,28 +123,40 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c */ @RunWith(MockitoJUnitRunner.class) public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { + @Mock private OAuth2AuthorizedClientRepository authorizedClientRepository; + @Mock private ClientRegistrationRepository clientRegistrationRepository; + @Mock private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; + @Mock private OAuth2AccessTokenResponseClient refreshTokenTokenResponseClient; + @Mock private OAuth2AccessTokenResponseClient passwordTokenResponseClient; + @Mock private OAuth2AuthorizationFailureHandler authorizationFailureHandler; + @Captor private ArgumentCaptor authorizationExceptionCaptor; + @Captor private ArgumentCaptor authenticationCaptor; + @Captor private ArgumentCaptor> attributesCaptor; + @Mock private WebClient.RequestHeadersSpec spec; + @Captor private ArgumentCaptor>> attrs; + @Captor private ArgumentCaptor authorizedClientCaptor; @@ -169,23 +175,22 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token-0", - Instant.now(), - Instant.now().plus(Duration.ofDays(1))); + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token-0", + Instant.now(), Instant.now().plus(Duration.ofDays(1))); @Before public void setup() { this.authentication = new TestingAuthenticationToken("test", "this"); - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) - .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) - .password(configurer -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) - .build(); - this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() + .authorizationCode() + .refreshToken( + (configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials( + (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) + .build(); + this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(this.clientRegistrationRepository, + this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager); } @@ -198,38 +203,44 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null)); } @Test public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) + .withMessage("clientCredentialsTokenResponseClient cannot be null"); } @Test public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { - assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) - .isInstanceOf(IllegalStateException.class) - .hasMessage("The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + assertThatIllegalStateException() + .isThrownBy(() -> this.function + .setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) + .withMessage("The client cannot be set when the constructor used is " + + "\"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, " + + "OAuth2AuthorizedClientRepository)\"."); } @Test public void setAccessTokenExpiresSkewWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { - assertThatThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))) + assertThatIllegalStateException() + .isThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))) .isInstanceOf(IllegalStateException.class) - .hasMessage("The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + - "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + .withMessage("The accessTokenExpiresSkew cannot be set when the constructor used is " + + "\"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, " + + "OAuth2AuthorizedClientRepository)\"."); } @Test public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() { Map attrs = getDefaultRequestAttributes(); - assertThat(getRequest(attrs)).isNull(); - assertThat(getResponse(attrs)).isNull(); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest(attrs)).isNull(); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse(attrs)).isNull(); } @Test @@ -238,73 +249,70 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { MockHttpServletResponse response = new MockHttpServletResponse(); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); Map attrs = getDefaultRequestAttributes(); - assertThat(getRequest(attrs)).isEqualTo(request); - assertThat(getResponse(attrs)).isEqualTo(response); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest(attrs)).isEqualTo(request); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse(attrs)).isEqualTo(response); } @Test public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticationNull() { Map attrs = getDefaultRequestAttributes(); - assertThat(getAuthentication(attrs)).isNull(); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs)).isNull(); } @Test public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() { SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); - assertThat(getAuthentication(attrs)).isEqualTo(this.authentication); + assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs)) + .isEqualTo(this.authentication); verifyNoInteractions(this.authorizedClientRepository); } private Map getDefaultRequestAttributes() { this.function.defaultRequest().accept(this.spec); verify(this.spec).attributes(this.attrs.capture()); - this.attrs.getValue().accept(this.result); - return this.result; } @Test public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); this.function.filter(request, this.exchange).block(); - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); } @Test public void filterWhenAuthorizedClientThenAuthorizationHeader() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } @Test public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } @@ -312,42 +320,33 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefresh() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .refreshToken("refresh-1") - .build(); - when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).refreshToken("refresh-1").build(); + given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(response); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), + issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication(this.authentication)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); - - OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); + verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), + eq(this.authentication), any(), any()); + OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -358,53 +357,45 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefreshToken() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) -// .refreshToken(xxx) // No refreshToken in response + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600) + // .refreshToken(xxx) // No refreshToken in response .build(); - RestOperations refreshTokenClient = mock(RestOperations.class); - when(refreshTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) - .thenReturn(new ResponseEntity(response, HttpStatus.OK)); + given(refreshTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .willReturn(new ResponseEntity(response, HttpStatus.OK)); DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); - RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); - Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), + issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication(this.authentication)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - verify(refreshTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); - verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); - - OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); + verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), + eq(this.authentication), any(), any()); + OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); assertThat(newAuthorizedClient.getRefreshToken().getTokenValue()).isEqualTo(refreshToken.getTokenValue()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -415,25 +406,23 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { this.registration = TestClientRegistrations.clientCredentials().build(); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, null); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, null); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication(this.authentication)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - - verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); - - verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); - + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), + any()); + verify(this.clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request1 = requests.get(0); assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); @@ -444,38 +433,28 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { this.registration = TestClientRegistrations.clientCredentials().build(); - - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses - .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); - + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, null); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), + issuedAt, accessTokenExpiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, null); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication(this.authentication)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); - verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request1 = requests.get(0); assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token"); assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); @@ -486,16 +465,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenPasswordClientNotAuthorizedThenGetNewToken() { OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(360) - .build(); - when(this.passwordTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(360).build(); + given(this.passwordTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); ClientRegistration registration = TestClientRegistrations.password().build(); - when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(registration); - + given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))) + .willReturn(registration); // Set custom contextAttributesMapper - this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> { + this.authorizedClientManager.setContextAttributesMapper((authorizeRequest) -> { Map contextAttributes = new HashMap<>(); HttpServletRequest servletRequest = authorizeRequest.getAttribute(HttpServletRequest.class.getName()); String username = servletRequest.getParameter(OAuth2ParameterNames.USERNAME); @@ -506,24 +482,20 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { } return contextAttributes; }); - MockHttpServletRequest servletRequest = new MockHttpServletRequest(); servletRequest.setParameter(OAuth2ParameterNames.USERNAME, "username"); servletRequest.setParameter(OAuth2ParameterNames.PASSWORD, "password"); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(clientRegistrationId(registration.getRegistrationId())) - .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(servletRequest)) - .attributes(httpServletResponse(servletResponse)) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction + .clientRegistrationId(registration.getRegistrationId())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication(this.authentication)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest(servletRequest)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse(servletResponse)) .build(); - this.function.filter(request, this.exchange).block(); - verify(this.passwordTokenResponseClient).getTokenResponse(any()); - verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any(), any()); - + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); ClientRequest request1 = requests.get(0); @@ -536,34 +508,28 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(3600) - .refreshToken("refresh-1") - .build(); - when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); - + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(3600).refreshToken("refresh-1").build(); + given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(response); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), + issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any()); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -573,20 +539,19 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -597,20 +562,19 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken, refreshToken); - - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletRequest(new MockHttpServletRequest())) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction + .httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange).block(); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); - ClientRequest request0 = requests.get(0); assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); @@ -622,44 +586,33 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { this.function.setDefaultOAuth2AuthorizedClient(true); - MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( - user, authorities, this.registration.getRegistrationId()); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration, "principalName", this.accessToken); - - when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()), - eq(authentication), eq(servletRequest))).thenReturn(authorizedClient); - + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities, + this.registration.getRegistrationId()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + given(this.authorizedClientRepository.loadAuthorizedClient( + eq(authentication.getAuthorizedClientRegistrationId()), eq(authentication), eq(servletRequest))) + .willReturn(authorizedClient); // Default request attributes set - final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com")) - .attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build(); - + final ClientRequest request1 = ClientRequest.create(HttpMethod.GET, URI.create("https://example1.com")) + .attributes((attrs) -> attrs.putAll(getDefaultRequestAttributes())).build(); // Default request attributes NOT set - final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build(); - + final ClientRequest request2 = ClientRequest.create(HttpMethod.GET, URI.create("https://example2.com")).build(); Context context = context(servletRequest, servletResponse, authentication); - this.function.filter(request1, this.exchange) - .flatMap(response -> this.function.filter(request2, this.exchange)) - .subscriberContext(context) + .flatMap((response) -> this.function.filter(request2, this.exchange)).subscriberContext(context) .block(); - List requests = this.exchange.getRequests(); assertThat(requests).hasSize(2); - ClientRequest request = requests.get(0); assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com"); assertThat(request.method()).isEqualTo(HttpMethod.GET); assertThat(getBody(request)).isEmpty(); - request = requests.get(1); assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com"); @@ -678,210 +631,180 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { } private void assertHttpStatusInvokesFailureHandler(HttpStatus httpStatus, String expectedErrorCode) { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration, "principalName", this.accessToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(servletRequest)) - .attributes(httpServletResponse(servletResponse)) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest(servletRequest)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse(servletResponse)) .build(); - - when(this.exchange.getResponse().rawStatusCode()).thenReturn(httpStatus.value()); - when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class)); + given(this.exchange.getResponse().rawStatusCode()).willReturn(httpStatus.value()); + given(this.exchange.getResponse().headers()).willReturn(mock(ClientResponse.Headers.class)); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); - this.function.filter(request, this.exchange).block(); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - this.authorizationExceptionCaptor.capture(), - this.authenticationCaptor.capture(), - this.attributesCaptor.capture()); - + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); assertThat(this.authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode); - assertThat(e).hasNoCause(); - assertThat(e).hasMessageContaining(expectedErrorCode); + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo(expectedErrorCode); + assertThat(ex).hasNoCause(); + assertThat(ex).hasMessageContaining(expectedErrorCode); }); - assertThat(this.authenticationCaptor.getValue().getName()) - .isEqualTo(authorizedClient.getPrincipalName()); - assertThat(this.attributesCaptor.getValue()) - .containsExactly( - entry(HttpServletRequest.class.getName(), servletRequest), - entry(HttpServletResponse.class.getName(), servletResponse)); + assertThat(this.authenticationCaptor.getValue().getName()).isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()).containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); } @Test public void filterWhenWWWAuthenticateHeaderIncludesErrorThenInvokeFailureHandler() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration, "principalName", this.accessToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(servletRequest)) - .attributes(httpServletResponse(servletResponse)) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest(servletRequest)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse(servletResponse)) .build(); - - String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " + - "error_description=\"The request requires higher privileges than provided by the access token.\", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""; + String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " + + "error_description=\"The request requires higher privileges than provided by the access token.\", " + + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""; ClientResponse.Headers headers = mock(ClientResponse.Headers.class); - when(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE))) - .thenReturn(Collections.singletonList(wwwAuthenticateHeader)); - when(this.exchange.getResponse().headers()).thenReturn(headers); + given(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE))) + .willReturn(Collections.singletonList(wwwAuthenticateHeader)); + given(this.exchange.getResponse().headers()).willReturn(headers); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); - this.function.filter(request, this.exchange).block(); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - this.authorizationExceptionCaptor.capture(), - this.authenticationCaptor.capture(), - this.attributesCaptor.capture()); - + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); assertThat(this.authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); - assertThat(e.getError().getDescription()).isEqualTo("The request requires higher privileges than provided by the access token."); - assertThat(e.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); - assertThat(e).hasNoCause(); - assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + assertThat(ex.getError().getDescription()) + .isEqualTo("The request requires higher privileges than provided by the access token."); + assertThat(ex.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); + assertThat(ex).hasNoCause(); + assertThat(ex).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); }); - assertThat(this.authenticationCaptor.getValue().getName()) - .isEqualTo(authorizedClient.getPrincipalName()); - assertThat(this.attributesCaptor.getValue()) - .containsExactly( - entry(HttpServletRequest.class.getName(), servletRequest), - entry(HttpServletResponse.class.getName(), servletResponse)); + assertThat(this.authenticationCaptor.getValue().getName()).isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()).containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); } @Test public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() { - assertHttpStatusWithWebClientExceptionInvokesFailureHandler( - HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN); + assertHttpStatusWithWebClientExceptionInvokesFailureHandler(HttpStatus.UNAUTHORIZED, + OAuth2ErrorCodes.INVALID_TOKEN); } @Test public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() { - assertHttpStatusWithWebClientExceptionInvokesFailureHandler( - HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + assertHttpStatusWithWebClientExceptionInvokesFailureHandler(HttpStatus.FORBIDDEN, + OAuth2ErrorCodes.INSUFFICIENT_SCOPE); } - private void assertHttpStatusWithWebClientExceptionInvokesFailureHandler( - HttpStatus httpStatus, String expectedErrorCode) { - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration, "principalName", this.accessToken); + private void assertHttpStatusWithWebClientExceptionInvokesFailureHandler(HttpStatus httpStatus, + String expectedErrorCode) { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(servletRequest)) - .attributes(httpServletResponse(servletResponse)) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest(servletRequest)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse(servletResponse)) .build(); - - WebClientResponseException exception = WebClientResponseException.create( - httpStatus.value(), - httpStatus.getReasonPhrase(), - HttpHeaders.EMPTY, - new byte[0], - StandardCharsets.UTF_8); - ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); + WebClientResponseException exception = WebClientResponseException.create(httpStatus.value(), + httpStatus.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8); + ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); - - assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block()) - .isEqualTo(exception); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - this.authorizationExceptionCaptor.capture(), - this.authenticationCaptor.capture(), - this.attributesCaptor.capture()); - + assertThatExceptionOfType(WebClientResponseException.class) + .isThrownBy(() -> this.function.filter(request, throwingExchangeFunction).block()).isEqualTo(exception); + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); assertThat(this.authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { - assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); - assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode); - assertThat(e).hasCause(exception); - assertThat(e).hasMessageContaining(expectedErrorCode); + .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { + assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(ex.getError().getErrorCode()).isEqualTo(expectedErrorCode); + assertThat(ex).hasCause(exception); + assertThat(ex).hasMessageContaining(expectedErrorCode); }); - assertThat(this.authenticationCaptor.getValue().getName()) - .isEqualTo(authorizedClient.getPrincipalName()); - assertThat(this.attributesCaptor.getValue()) - .containsExactly( - entry(HttpServletRequest.class.getName(), servletRequest), - entry(HttpServletResponse.class.getName(), servletResponse)); + assertThat(this.authenticationCaptor.getValue().getName()).isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()).containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); } @Test public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration, "principalName", this.accessToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(servletRequest)) - .attributes(httpServletResponse(servletResponse)) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest(servletRequest)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse(servletResponse)) .build(); - OAuth2AuthorizationException authorizationException = new OAuth2AuthorizationException( new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); - ExchangeFunction throwingExchangeFunction = r -> Mono.error(authorizationException); + ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(authorizationException); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); - - assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block()) + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.function.filter(request, throwingExchangeFunction).block()) .isEqualTo(authorizationException); - - verify(this.authorizationFailureHandler).onAuthorizationFailure( - this.authorizationExceptionCaptor.capture(), - this.authenticationCaptor.capture(), - this.attributesCaptor.capture()); - + verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), this.attributesCaptor.capture()); assertThat(this.authorizationExceptionCaptor.getValue()) - .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> { - assertThat(e.getError().getErrorCode()).isEqualTo(authorizationException.getError().getErrorCode()); - assertThat(e).hasNoCause(); - assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INVALID_TOKEN); + .isInstanceOfSatisfying(OAuth2AuthorizationException.class, (ex) -> { + assertThat(ex.getError().getErrorCode()) + .isEqualTo(authorizationException.getError().getErrorCode()); + assertThat(ex).hasNoCause(); + assertThat(ex).hasMessageContaining(OAuth2ErrorCodes.INVALID_TOKEN); }); - assertThat(this.authenticationCaptor.getValue().getName()) - .isEqualTo(authorizedClient.getPrincipalName()); - assertThat(this.attributesCaptor.getValue()) - .containsExactly( - entry(HttpServletRequest.class.getName(), servletRequest), - entry(HttpServletResponse.class.getName(), servletResponse)); + assertThat(this.authenticationCaptor.getValue().getName()).isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()).containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); } @Test public void filterWhenOtherHttpStatusThenDoesNotInvokeFailureHandler() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration, "principalName", this.accessToken); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(servletRequest)) - .attributes(httpServletResponse(servletResponse)) + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes( + ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest(servletRequest)) + .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse(servletResponse)) .build(); - - when(this.exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value()); - when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class)); + given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.BAD_REQUEST.value()); + given(this.exchange.getResponse().headers()).willReturn(mock(ClientResponse.Headers.class)); this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); - this.function.filter(request, this.exchange).block(); - verifyNoInteractions(this.authorizationFailureHandler); } - private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) { + private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, + Authentication authentication) { Map contextAttributes = new HashMap<>(); contextAttributes.put(HttpServletRequest.class, servletRequest); contextAttributes.put(HttpServletResponse.class, servletResponse); contextAttributes.put(Authentication.class, authentication); - return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes); + return Context.of(ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, + contextAttributes); } private static String getBody(ClientRequest request) { @@ -895,7 +818,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { messageWriters.add(new FormHttpMessageWriter()); messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes())); messageWriters.add(new MultipartHttpMessageWriter(messageWriters)); - BodyInserter.Context context = new BodyInserter.Context() { @Override public List> messageWriters() { @@ -912,9 +834,9 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { return new HashMap<>(); } }; - MockClientHttpRequest body = new MockClientHttpRequest(HttpMethod.GET, "/"); request.body().insert(body, context).block(); return body.getBodyAsString().block(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 0c66fb9f44..cae90564be 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -16,11 +16,16 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.annotation; +import java.lang.reflect.Method; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + import org.springframework.core.MethodParameter; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; @@ -41,17 +46,14 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.util.ReflectionUtils; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; -import reactor.util.context.Context; - -import java.lang.reflect.Method; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -59,84 +61,94 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizedClientArgumentResolverTests { + @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; + @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); private OAuth2AuthorizedClientArgumentResolver argumentResolver; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication authentication = new TestingAuthenticationToken("test", "this"); @Before public void setUp() { - ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = - ReactiveOAuth2AuthorizedClientProviderBuilder.builder() - .authorizationCode() - .refreshToken() - .clientCredentials() - .build(); + // @formatter:off + ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder + .builder() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + // @formatter:on DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); - this.authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.authentication.getName(), TestOAuth2AccessTokens.noScopes()); - when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient)); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.authentication.getName(), + TestOAuth2AccessTokens.noScopes()); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())) + .willReturn(Mono.just(this.authorizedClient)); } @Test public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository)); } @Test public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)); } @Test public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)); } @Test public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientThenTrue() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.supportsParameter(methodParameter)).isTrue(); } @Test public void supportsParameterWhenParameterTypeOAuth2AuthorizedClientWithoutAnnotationThenFalse() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", OAuth2AuthorizedClient.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClientWithoutAnnotation", + OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); } @Test public void supportsParameterWhenParameterTypeUnsupportedWithoutAnnotationThenFalse() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", String.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeUnsupportedWithoutAnnotation", + String.class); assertThat(this.argumentResolver.supportsParameter(methodParameter)).isFalse(); } @Test public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() { MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> resolveArgument(methodParameter)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The clientRegistrationId could not be resolved. Please provide one"); + assertThatIllegalArgumentException().isThrownBy(() -> resolveArgument(methodParameter)) + .withMessage("The clientRegistrationId could not be resolved. Please provide one"); } @Test public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { this.authentication = mock(OAuth2AuthenticationToken.class); - when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1"); + given(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()) + .willReturn("client1"); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); resolveArgument(methodParameter); } @@ -144,30 +156,34 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() { this.authentication = null; - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); - when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty()); - MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> resolveArgument(methodParameter)) - .isInstanceOf(ClientAuthorizationRequiredException.class); + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).willReturn(Mono.empty()); + MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", + OAuth2AuthorizedClient.class); + assertThatExceptionOfType(ClientAuthorizationRequiredException.class) + .isThrownBy(() -> resolveArgument(methodParameter)); } private Object resolveArgument(MethodParameter methodParameter) { return this.argumentResolver.resolveArgument(methodParameter, null, null) - .subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication)) - .subscriberContext(serverWebExchange()) - .block(); + .subscriberContext((this.authentication != null) + ? ReactiveSecurityContextHolder.withAuthentication(this.authentication) : Context.empty()) + .subscriberContext(serverWebExchange()).block(); } private Context serverWebExchange() { @@ -175,13 +191,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests { } private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { - Method method = ReflectionUtils.findMethod( - TestController.class, methodName, paramTypes); + Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes); return new MethodParameter(method, 0); } static class TestController { - void paramTypeAuthorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { + + void paramTypeAuthorizedClient( + @RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) { } void paramTypeAuthorizedClientWithoutAnnotation(OAuth2AuthorizedClient authorizedClient) { @@ -195,5 +212,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests { void registrationIdEmpty(@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { } + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java index 0021ebd6bf..9693d71749 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -26,25 +29,26 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; -import reactor.core.publisher.Mono; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** - * * @author Rob Winch */ public class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests { + private String registrationId = "registrationId"; + private String principalName = "principalName"; + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ServerOAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository; + private AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository authorizedClientRepository; private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); @@ -52,43 +56,48 @@ public class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests { @Before public void setup() { this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); - this.anonymousAuthorizedClientRepository = mock( - ServerOAuth2AuthorizedClientRepository.class); - this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(this.authorizedClientService); - this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository); + this.anonymousAuthorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class); + this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( + this.authorizedClientService); + this.authorizedClientRepository + .setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository); } @Test public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null)); } @Test public void setAuthorizedClientRepositoryWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null)); } @Test public void loadAuthorizedClientWhenAuthenticatedPrincipalThenLoadFromService() { - when(this.authorizedClientService.loadAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); Authentication authentication = this.createAuthenticatedPrincipal(); - this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.exchange).block(); + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.exchange) + .block(); verify(this.authorizedClientService).loadAuthorizedClient(this.registrationId, this.principalName); } @Test public void loadAuthorizedClientWhenAnonymousPrincipalThenLoadFromAnonymousRepository() { - when(this.anonymousAuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + given(this.anonymousAuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())) + .willReturn(Mono.empty()); Authentication authentication = this.createAnonymousPrincipal(); - this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.exchange).block(); - verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, this.exchange); + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.exchange) + .block(); + verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, + this.exchange); } @Test public void saveAuthorizedClientWhenAuthenticatedPrincipalThenSaveToService() { - when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + given(this.authorizedClientService.saveAuthorizedClient(any(), any())).willReturn(Mono.empty()); Authentication authentication = this.createAuthenticatedPrincipal(); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.exchange).block(); @@ -97,27 +106,33 @@ public class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests { @Test public void saveAuthorizedClientWhenAnonymousPrincipalThenSaveToAnonymousRepository() { - when(this.anonymousAuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + given(this.anonymousAuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())) + .willReturn(Mono.empty()); Authentication authentication = this.createAnonymousPrincipal(); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.exchange).block(); - verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, this.exchange); + verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, + this.exchange); } @Test public void removeAuthorizedClientWhenAuthenticatedPrincipalThenRemoveFromService() { - when(this.authorizedClientService.removeAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + given(this.authorizedClientService.removeAuthorizedClient(any(), any())).willReturn(Mono.empty()); Authentication authentication = this.createAuthenticatedPrincipal(); - this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.exchange).block(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.exchange) + .block(); verify(this.authorizedClientService).removeAuthorizedClient(this.registrationId, this.principalName); } @Test public void removeAuthorizedClientWhenAnonymousPrincipalThenRemoveFromAnonymousRepository() { - when(this.anonymousAuthorizedClientRepository.removeAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + given(this.anonymousAuthorizedClientRepository.removeAuthorizedClient(any(), any(), any())) + .willReturn(Mono.empty()); Authentication authentication = this.createAnonymousPrincipal(); - this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.exchange).block(); - verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, this.exchange); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.exchange) + .block(); + verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, + this.exchange); } private Authentication createAuthenticatedPrincipal() { @@ -127,6 +142,8 @@ public class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests { } private Authentication createAnonymousPrincipal() { - return new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + return new AnonymousAuthenticationToken("key-1234", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java index 958799b014..ecccefbb43 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java @@ -21,6 +21,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpStatus; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; @@ -35,13 +37,12 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.catchThrowableOfType; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -49,6 +50,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class DefaultServerOAuth2AuthorizationRequestResolverTests { + @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; @@ -63,8 +65,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests { @Test public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null)); } @Test @@ -74,155 +75,112 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests { @Test public void resolveWhenClientRegistrationNotFoundMatchThenBadRequest() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.empty()); - - ResponseStatusException expected = catchThrowableOfType(() -> resolve("/oauth2/authorization/not-found-id"), ResponseStatusException.class); - - assertThat(expected.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST); + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.empty()); + assertThatExceptionOfType(ResponseStatusException.class) + .isThrownBy(() -> resolve("/oauth2/authorization/not-found-id")) + .satisfies((ex) -> assertThat(ex.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST)); } @Test public void resolveWhenClientRegistrationFoundThenWorks() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(this.registration)); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration)); OAuth2AuthorizationRequest request = resolve("/oauth2/authorization/not-found-id"); - - assertThat(request.getAuthorizationRequestUri()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.*?&" + - "redirect_uri=/login/oauth2/code/registration-id"); + assertThat(request.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.*?&" + "redirect_uri=/login/oauth2/code/registration-id"); } @Test public void resolveWhenForwardedHeadersClientRegistrationFoundThenWorks() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(this.registration)); - ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/oauth2/authorization/id").header("X-Forwarded-Host", "evil.com")); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration)); + // @formatter:off + MockServerHttpRequest.BaseBuilder httpRequest = MockServerHttpRequest + .get("/oauth2/authorization/id") + .header("X-Forwarded-Host", "evil.com"); + // @formatter:on + ServerWebExchange exchange = MockServerWebExchange.from(httpRequest); OAuth2AuthorizationRequest request = this.resolver.resolve(exchange).block(); - - assertThat(request.getAuthorizationRequestUri()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.*?&" + - "redirect_uri=/login/oauth2/code/registration-id"); + assertThat(request.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.*?&" + "redirect_uri=/login/oauth2/code/registration-id"); } @Test public void resolveWhenAuthorizationRequestWithValidPkceClientThenResolves() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(TestClientRegistrations.clientRegistration() - .clientAuthenticationMethod(ClientAuthenticationMethod.NONE) - .clientSecret(null) - .build())); - + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration() + .clientAuthenticationMethod(ClientAuthenticationMethod.NONE).clientSecret(null).build())); OAuth2AuthorizationRequest request = resolve("/oauth2/authorization/registration-id"); - - assertThat((String) request.getAttribute(PkceParameterNames.CODE_VERIFIER)).matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); - - assertThat(request.getAuthorizationRequestUri()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=read:user&state=.*?&" + - "redirect_uri=/login/oauth2/code/registration-id&" + - "code_challenge_method=S256&" + - "code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); + assertThat((String) request.getAttribute(PkceParameterNames.CODE_VERIFIER)) + .matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); + assertThat(request.getAuthorizationRequestUri()) + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.*?&" + "redirect_uri=/login/oauth2/code/registration-id&" + + "code_challenge_method=S256&" + "code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } @Test public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(TestClientRegistrations.clientRegistration() - .scope(OidcScopes.OPENID) - .build())); - + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().scope(OidcScopes.OPENID).build())); OAuth2AuthorizationRequest request = resolve("/oauth2/authorization/registration-id"); - assertThat((String) request.getAttribute(OidcParameterNames.NONCE)).matches("^([a-zA-Z0-9\\-\\.\\_\\~]){128}$"); - - assertThat(request.getAuthorizationRequestUri()).matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=openid&state=.*?&" + - "redirect_uri=/login/oauth2/code/registration-id&" + - "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); + assertThat(request.getAuthorizationRequestUri()).matches("https://example.com/login/oauth/authorize\\?" + + "response_type=code&client_id=client-id&" + "scope=openid&state=.*?&" + + "redirect_uri=/login/oauth2/code/registration-id&" + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } // gh-7696 @Test public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(TestClientRegistrations.clientRegistration() - .scope(OidcScopes.OPENID) - .build())); - - this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer - .additionalParameters(params -> params.remove(OidcParameterNames.NONCE)) - .attributes(attrs -> attrs.remove(OidcParameterNames.NONCE))); - + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().scope(OidcScopes.OPENID).build())); + this.resolver.setAuthorizationRequestCustomizer( + (customizer) -> customizer.additionalParameters((params) -> params.remove(OidcParameterNames.NONCE)) + .attributes((attrs) -> attrs.remove(OidcParameterNames.NONCE))); OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id"); - assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE); assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID); assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=/login/oauth2/code/registration-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + "redirect_uri=/login/oauth2/code/registration-id"); } @Test public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(TestClientRegistrations.clientRegistration() - .scope(OidcScopes.OPENID) - .build())); - - this.resolver.setAuthorizationRequestCustomizer(customizer -> - customizer.authorizationRequestUri(uriBuilder -> { + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().scope(OidcScopes.OPENID).build())); + this.resolver + .setAuthorizationRequestCustomizer((customizer) -> customizer.authorizationRequestUri((uriBuilder) -> { uriBuilder.queryParam("param1", "value1"); return uriBuilder.build(); - }) - ); - + })); OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id"); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&client_id=client-id&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=/login/oauth2/code/registration-id&" + - "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + - "param1=value1"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=openid&state=.{15,}&" + "redirect_uri=/login/oauth2/code/registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + "param1=value1"); } @Test public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn( - Mono.just(TestClientRegistrations.clientRegistration() - .scope(OidcScopes.OPENID) - .build())); - - this.resolver.setAuthorizationRequestCustomizer(customizer -> - customizer.parameters(params -> { - params.put("appid", params.get("client_id")); - params.remove("client_id"); - }) - ); - + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().scope(OidcScopes.OPENID).build())); + this.resolver.setAuthorizationRequestCustomizer((customizer) -> customizer.parameters((params) -> { + params.put("appid", params.get("client_id")); + params.remove("client_id"); + })); OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id"); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .matches("https://example.com/login/oauth/authorize\\?" + - "response_type=code&" + - "scope=openid&state=.{15,}&" + - "redirect_uri=/login/oauth2/code/registration-id&" + - "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + - "appid=client-id"); + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&" + + "scope=openid&state=.{15,}&" + "redirect_uri=/login/oauth2/code/registration-id&" + + "nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" + "appid=client-id"); } private OAuth2AuthorizationRequest resolve(String path) { ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path)); return this.resolver.resolve(exchange).block(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java index 1efd84389e..b474ce8a9f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java @@ -16,11 +16,19 @@ package org.springframework.security.oauth2.client.web.server; +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -34,28 +42,21 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.util.CollectionUtils; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.handler.DefaultWebFilterChain; -import reactor.core.publisher.Mono; - -import java.net.URI; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; /** * @author Rob Winch @@ -64,138 +65,121 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author */ @RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizationCodeGrantWebFilterTests { + private OAuth2AuthorizationCodeGrantWebFilter filter; + @Mock private ReactiveAuthenticationManager authenticationManager; + @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; + @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + @Mock private ServerAuthorizationRequestRepository authorizationRequestRepository; @Before public void setup() { - this.filter = new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, this.clientRegistrationRepository, - this.authorizedClientRepository); + this.filter = new OAuth2AuthorizationCodeGrantWebFilter(this.authenticationManager, + this.clientRegistrationRepository, this.authorizedClientRepository); this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } @Test public void constructorWhenAuthenticationManagerNullThenIllegalArgumentException() { this.authenticationManager = null; - assertThatCode(() -> new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, this.clientRegistrationRepository, - this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeGrantWebFilter(this.authenticationManager, + this.clientRegistrationRepository, this.authorizedClientRepository)); } @Test public void constructorWhenClientRegistrationRepositoryNullThenIllegalArgumentException() { this.clientRegistrationRepository = null; - assertThatCode(() -> new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, this.clientRegistrationRepository, - this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeGrantWebFilter(this.authenticationManager, + this.clientRegistrationRepository, this.authorizedClientRepository)); } @Test public void constructorWhenAuthorizedClientRepositoryNullThenIllegalArgumentException() { this.authorizedClientRepository = null; - assertThatCode(() -> new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, this.clientRegistrationRepository, - this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationCodeGrantWebFilter(this.authenticationManager, + this.clientRegistrationRepository, this.authorizedClientRepository)); } @Test public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { - assertThatCode(() -> this.filter.setRequestCache(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); } @Test public void filterWhenNotMatchThenAuthenticationManagerNotCalled() { - MockServerWebExchange exchange = MockServerWebExchange - .from(MockServerHttpRequest.get("/")); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); this.filter.filter(exchange, chain).block(); - verifyNoInteractions(this.authenticationManager); } @Test public void filterWhenMatchThenAuthorizedClientSaved() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())) - .thenReturn(Mono.just(clientRegistration)); - - MockServerHttpRequest authorizationRequest = - createAuthorizationRequest("/authorization/callback"); - OAuth2AuthorizationRequest oauth2AuthorizationRequest = - createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) - .thenReturn(Mono.empty()); - when(this.authenticationManager.authenticate(any())) - .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(clientRegistration)); + MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = createOAuth2AuthorizationRequest(authorizationRequest, + clientRegistration); + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); this.filter.filter(exchange, chain).block(); - - verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), + any()); } // gh-7966 @Test public void filterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())) - .thenReturn(Mono.just(clientRegistration)); - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) - .thenReturn(Mono.empty()); - when(this.authenticationManager.authenticate(any())) - .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(clientRegistration)); + given(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); // 1) redirect_uri with query parameters Map parameters = new LinkedHashMap<>(); parameters.put("param1", "value1"); parameters.put("param2", "value2"); - MockServerHttpRequest authorizationRequest = - createAuthorizationRequest("/authorization/callback", parameters); - OAuth2AuthorizationRequest oauth2AuthorizationRequest = - createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - + MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback", parameters); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = createOAuth2AuthorizationRequest(authorizationRequest, + clientRegistration); + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); this.filter.filter(exchange, chain).block(); verify(this.authenticationManager, times(1)).authenticate(any()); - - // 2) redirect_uri with query parameters AND authorization response additional parameters + // 2) redirect_uri with query parameters AND authorization response additional + // parameters Map additionalParameters = new LinkedHashMap<>(); additionalParameters.put("auth-param1", "value1"); additionalParameters.put("auth-param2", "value2"); authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters); exchange = MockServerWebExchange.from(authorizationResponse); - this.filter.filter(exchange, chain).block(); verify(this.authenticationManager, times(2)).authenticate(any()); } @@ -207,44 +191,35 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { Map parameters = new LinkedHashMap<>(); parameters.put("param1", "value1"); parameters.put("param2", "value2"); - MockServerHttpRequest authorizationRequest = - createAuthorizationRequest(requestUri, parameters); + MockServerHttpRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizationRequest oauth2AuthorizationRequest = - createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - + OAuth2AuthorizationRequest oauth2AuthorizationRequest = createOAuth2AuthorizationRequest(authorizationRequest, + clientRegistration); + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); // 1) Parameter value Map parametersNotMatch = new LinkedHashMap<>(parameters); parametersNotMatch.put("param2", "value8"); MockServerHttpRequest authorizationResponse = createAuthorizationResponse( createAuthorizationRequest(requestUri, parametersNotMatch)); MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); this.filter.filter(exchange, chain).block(); verifyNoInteractions(this.authenticationManager); - // 2) Parameter order parametersNotMatch = new LinkedHashMap<>(); parametersNotMatch.put("param2", "value2"); parametersNotMatch.put("param1", "value1"); - authorizationResponse = createAuthorizationResponse( - createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse = createAuthorizationResponse(createAuthorizationRequest(requestUri, parametersNotMatch)); exchange = MockServerWebExchange.from(authorizationResponse); - this.filter.filter(exchange, chain).block(); verifyNoInteractions(this.authenticationManager); - // 3) Parameter missing parametersNotMatch = new LinkedHashMap<>(parameters); parametersNotMatch.remove("param2"); - authorizationResponse = createAuthorizationResponse( - createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse = createAuthorizationResponse(createAuthorizationRequest(requestUri, parametersNotMatch)); exchange = MockServerWebExchange.from(authorizationResponse); - this.filter.filter(exchange, chain).block(); verifyNoInteractions(this.authenticationManager); } @@ -252,33 +227,26 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Test public void filterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestCacheUsed() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())) - .thenReturn(Mono.just(clientRegistration)); - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) - .thenReturn(Mono.empty()); - when(this.authenticationManager.authenticate(any())) - .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(clientRegistration)); + given(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback"); - OAuth2AuthorizationRequest oauth2AuthorizationRequest = - createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - + OAuth2AuthorizationRequest oauth2AuthorizationRequest = createOAuth2AuthorizationRequest(authorizationRequest, + clientRegistration); + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); ServerRequestCache requestCache = mock(ServerRequestCache.class); - when(requestCache.getRedirectUri(any(ServerWebExchange.class))).thenReturn(Mono.just(URI.create("/saved-request"))); - + given(requestCache.getRedirectUri(any(ServerWebExchange.class))) + .willReturn(Mono.just(URI.create("/saved-request"))); this.filter.setRequestCache(requestCache); - this.filter.filter(exchange, chain).block(); - verify(requestCache).getRedirectUri(exchange); assertThat(exchange.getResponse().getHeaders().getLocation().toString()).isEqualTo("/saved-request"); } @@ -287,27 +255,22 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Test public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); - - MockServerHttpRequest authorizationRequest = - createAuthorizationRequest("/authorization/callback"); - OAuth2AuthorizationRequest oauth2AuthorizationRequest = - createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.empty()); + MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = createOAuth2AuthorizationRequest(authorizationRequest, + clientRegistration); + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - - assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo("client_registration_not_found"); + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.filter.filter(exchange, chain).block()) + .satisfies((ex) -> assertThat(ex.getError()).extracting("errorCode") + .isEqualTo("client_registration_not_found")); verifyNoInteractions(this.authenticationManager); } @@ -315,41 +278,35 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Test public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())) - .thenReturn(Mono.just(clientRegistration)); - - MockServerHttpRequest authorizationRequest = - createAuthorizationRequest("/authorization/callback"); - OAuth2AuthorizationRequest oauth2AuthorizationRequest = - createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); - when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) - .thenReturn(Mono.just(oauth2AuthorizationRequest)); - - when(this.authenticationManager.authenticate(any())) - .thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error")))); - + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(clientRegistration)); + MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = createOAuth2AuthorizationRequest(authorizationRequest, + clientRegistration); + given(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(oauth2AuthorizationRequest)); + given(this.authenticationManager.authenticate(any())) + .willReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error")))); MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete(), Collections.emptyList()); - - assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo("authorization_error"); + DefaultWebFilterChain chain = new DefaultWebFilterChain((e) -> e.getResponse().setComplete(), + Collections.emptyList()); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.filter.filter(exchange, chain).block()) + .satisfies((ex) -> assertThat(ex.getError()).extracting("errorCode").isEqualTo("authorization_error")); } private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest( MockServerHttpRequest authorizationRequest, ClientRegistration registration) { Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); - return request() + // @formatter:off + return TestOAuth2AuthorizationRequests.request() .attributes(attributes) .redirectUri(authorizationRequest.getURI().toString()) .build(); + // @formatter:on } private static MockServerHttpRequest createAuthorizationRequest(String requestUri) { @@ -357,8 +314,7 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { } private static MockServerHttpRequest createAuthorizationRequest(String requestUri, Map parameters) { - MockServerHttpRequest.BaseBuilder builder = MockServerHttpRequest - .get(requestUri); + MockServerHttpRequest.BaseBuilder builder = MockServerHttpRequest.get(requestUri); if (!CollectionUtils.isEmpty(parameters)) { parameters.forEach(builder::queryParam); } @@ -369,8 +325,8 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>()); } - private static MockServerHttpRequest createAuthorizationResponse( - MockServerHttpRequest authorizationRequest, Map additionalParameters) { + private static MockServerHttpRequest createAuthorizationResponse(MockServerHttpRequest authorizationRequest, + Map additionalParameters) { MockServerHttpRequest.BaseBuilder builder = MockServerHttpRequest .get(authorizationRequest.getURI().toString()); builder.queryParam(OAuth2ParameterNames.CODE, "code"); @@ -379,4 +335,5 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { builder.cookies(authorizationRequest.getCookies()); return builder.build(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java index dd5d8443db..448425169e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java @@ -16,11 +16,16 @@ package org.springframework.security.oauth2.client.web.server; +import java.net.URI; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; @@ -30,17 +35,13 @@ import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.server.handler.FilteringWebHandler; -import reactor.core.publisher.Mono; - -import java.net.URI; -import java.util.Arrays; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verifyNoInteractions; /** * @author Rob Winch @@ -48,6 +49,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizationRequestRedirectWebFilterTests { + @Mock private ReactiveClientRegistrationRepository clientRepository; @@ -67,46 +69,45 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { public void setup() { this.filter = new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository); this.filter.setAuthorizationRequestRepository(this.authzRequestRepository); - FilteringWebHandler webHandler = new FilteringWebHandler(e -> e.getResponse().setComplete(), Arrays.asList(this.filter)); - + FilteringWebHandler webHandler = new FilteringWebHandler((e) -> e.getResponse().setComplete(), + Arrays.asList(this.filter)); this.client = WebTestClient.bindToWebHandler(webHandler).build(); - when(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())).thenReturn( - Mono.just(this.registration)); - when(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).thenReturn( - Mono.empty()); + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); } @Test public void constructorWhenClientRegistrationRepositoryNullThenIllegalArgumentException() { this.clientRepository = null; - assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository)); } @Test public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() { + // @formatter:off this.client.get() .exchange() .expectStatus().isOk(); - - verifyZeroInteractions(this.clientRepository, this.authzRequestRepository); + // @formatter:on + verifyNoInteractions(this.clientRepository, this.authzRequestRepository); } @Test public void filterWhenDoesMatchThenClientRegistrationRepositoryNotSubscribed() { + // @formatter:off FluxExchangeResult result = this.client.get() - .uri("https://example.com/oauth2/authorization/registration-id").exchange() - .expectStatus().is3xxRedirection().returnResult(String.class); + .uri("https://example.com/oauth2/authorization/registration-id") + .exchange() + .expectStatus().is3xxRedirection() + .returnResult(String.class); + // @formatter:on result.assertWithDiagnostics(() -> { URI location = result.getResponseHeaders().getLocation(); - assertThat(location) - .hasScheme("https") - .hasHost("example.com") - .hasPath("/login/oauth/authorize") - .hasParameter("response_type", "code") - .hasParameter("client_id", "client-id") - .hasParameter("scope", "read:user") - .hasParameter("state") + assertThat(location).hasScheme("https").hasHost("example.com").hasPath("/login/oauth/authorize") + .hasParameter("response_type", "code").hasParameter("client_id", "client-id") + .hasParameter("scope", "read:user").hasParameter("state") .hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id"); }); verify(this.authzRequestRepository).saveAuthorizationRequest(any(), any()); @@ -115,9 +116,10 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { // gh-5520 @Test public void filterWhenDoesMatchThenResolveRedirectUriExpandedExcludesQueryString() { + // @formatter:off FluxExchangeResult result = this.client.get() - .uri("https://example.com/oauth2/authorization/registration-id?foo=bar").exchange() - .expectStatus().is3xxRedirection().returnResult(String.class); + .uri("https://example.com/oauth2/authorization/registration-id?foo=bar").exchange().expectStatus() + .is3xxRedirection().returnResult(String.class); result.assertWithDiagnostics(() -> { URI location = result.getResponseHeaders().getLocation(); assertThat(location) @@ -130,46 +132,55 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { .hasParameter("state") .hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id"); }); + // @formatter:on } @Test public void filterWhenExceptionThenRedirected() { - FilteringWebHandler webHandler = new FilteringWebHandler(e -> Mono.error(new ClientAuthorizationRequiredException(this.registration - .getRegistrationId())), Arrays.asList(this.filter)); - this.client = WebTestClient.bindToWebHandler(webHandler).build(); + FilteringWebHandler webHandler = new FilteringWebHandler( + (e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())), + Arrays.asList(this.filter)); + // @formatter:off + this.client = WebTestClient.bindToWebHandler(webHandler) + .build(); FluxExchangeResult result = this.client.get() - .uri("https://example.com/foo").exchange() - .expectStatus() - .is3xxRedirection() + .uri("https://example.com/foo") + .exchange() + .expectStatus().is3xxRedirection() .returnResult(String.class); + // @formatter:on } @Test public void filterWhenExceptionThenSaveRequestSessionAttribute() { this.filter.setRequestCache(this.requestCache); - when(this.requestCache.saveRequest(any())).thenReturn(Mono.empty()); + given(this.requestCache.saveRequest(any())).willReturn(Mono.empty()); FilteringWebHandler webHandler = new FilteringWebHandler( - e -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())), + (e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())), Arrays.asList(this.filter)); - this.client = WebTestClient.bindToWebHandler(webHandler).build(); + // @formatter:off + this.client = WebTestClient.bindToWebHandler(webHandler) + .build(); this.client.get() .uri("https://example.com/foo") .exchange() - .expectStatus() - .is3xxRedirection() + .expectStatus().is3xxRedirection() .returnResult(String.class); + // @formatter:on verify(this.requestCache).saveRequest(any()); } @Test public void filterWhenPathMatchesThenRequestSessionAttributeNotSaved() { this.filter.setRequestCache(this.requestCache); + // @formatter:off this.client.get() .uri("https://example.com/oauth2/authorization/registration-id") .exchange() - .expectStatus() - .is3xxRedirection() + .expectStatus().is3xxRedirection() .returnResult(String.class); - verifyZeroInteractions(this.requestCache); + // @formatter:on + verifyNoInteractions(this.requestCache); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTests.java similarity index 71% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTests.java index 9f921392d8..fad1584328 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTests.java @@ -16,11 +16,16 @@ package org.springframework.security.oauth2.client.web.server; +import java.util.Collections; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; @@ -32,22 +37,19 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch * @since 5.1 */ @RunWith(MockitoJUnitRunner.class) -public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { +public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTests { + @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; @@ -56,6 +58,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { private String clientRegistrationId = "github"; + // @formatter:off private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) @@ -69,13 +72,16 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { .clientId("clientId") .clientSecret("clientSecret") .build(); + // @formatter:on + // @formatter:off private OAuth2AuthorizationRequest.Builder authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") .state("state") .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId)); + // @formatter:on private final MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/"); @@ -83,44 +89,44 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { @Before public void setup() { - this.converter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(this.clientRegistrationRepository); + this.converter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter( + this.clientRegistrationRepository); this.converter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } @Test public void applyWhenAuthorizationRequestEmptyThenOAuth2AuthorizationException() { - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.empty()); - - assertThatThrownBy(() -> applyConverter()) - .isInstanceOf(OAuth2AuthorizationException.class); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())).willReturn(Mono.empty()); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> applyConverter()); } @Test public void applyWhenAttributesMissingThenOAuth2AuthorizationException() { this.authorizationRequest.attributes(Map::clear); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); - - assertThatThrownBy(() -> applyConverter()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining(ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(this.authorizationRequest.build())); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> applyConverter()) + .withMessageContaining( + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); } @Test public void applyWhenClientRegistrationMissingThenOAuth2AuthorizationException() { - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); - - assertThatThrownBy(() -> applyConverter()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining(ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(this.authorizationRequest.build())); + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.empty()); + assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(() -> applyConverter()) + .withMessageContaining( + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); } @Test public void applyWhenCodeParameterNotFoundThenErrorCode() { this.request.queryParam(OAuth2ParameterNames.ERROR, "error"); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(this.authorizationRequest.build())); + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(this.clientRegistration)); assertThat(applyConverter().getAuthorizationExchange().getAuthorizationResponse().getError().getErrorCode()) .isEqualTo("error"); } @@ -128,13 +134,12 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { @Test public void applyWhenCodeParameterFoundThenCode() { this.request.queryParam(OAuth2ParameterNames.CODE, "code"); - when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); - + given(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .willReturn(Mono.just(this.authorizationRequest.build())); + given(this.clientRegistrationRepository.findByRegistrationId(any())) + .willReturn(Mono.just(this.clientRegistration)); OAuth2AuthorizationCodeAuthenticationToken result = applyConverter(); - - OAuth2AuthorizationResponse exchange = result - .getAuthorizationExchange().getAuthorizationResponse(); + OAuth2AuthorizationResponse exchange = result.getAuthorizationExchange().getAuthorizationResponse(); assertThat(exchange.getError()).isNull(); assertThat(exchange.getCode()).isEqualTo("code"); } @@ -143,4 +148,5 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest { MockServerWebExchange exchange = MockServerWebExchange.from(this.request); return (OAuth2AuthorizationCodeAuthenticationToken) this.converter.convert(exchange).block(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests.java index 67486bebf1..308c5f03db 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.web.server; import org.junit.Before; import org.junit.Test; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -31,14 +32,15 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.web.server.ServerWebExchange; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * @author Rob Winch */ public class UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests { - private UnAuthenticatedServerOAuth2AuthorizedClientRepository repository = - new UnAuthenticatedServerOAuth2AuthorizedClientRepository(); + + private UnAuthenticatedServerOAuth2AuthorizedClientRepository repository = new UnAuthenticatedServerOAuth2AuthorizedClientRepository(); private ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); @@ -46,7 +48,8 @@ public class UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests { private ServerWebExchange exchange; - private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); private Authentication authentication; @@ -59,112 +62,110 @@ public class UnAuthenticatedServerOAuth2AuthorizedClientRepositoryTests { } // loadAuthorizedClient - @Test public void loadAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() { this.clientRegistrationId = null; - assertThatThrownBy(() -> this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()); } @Test public void loadAuthorizedClientWhenAuthenticationNotNullThenIllegalArgumentException() { this.authentication = new TestingAuthenticationToken("a", "b", "ROLE_USER"); - assertThatThrownBy(() -> this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()); } @Test public void loadAuthorizedClientWhenServerWebExchangeNotNullThenIllegalArgumentException() { this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); - assertThatThrownBy(() -> this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()); } @Test public void loadAuthorizedClientWhenNotFoundThenEmpty() { - assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()).isNull(); + assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange) + .block()).isNull(); } @Test public void loadAuthorizedClientWhenFoundThenFound() { this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block(); - - assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()).isEqualTo(this.authorizedClient); + assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange) + .block()).isEqualTo(this.authorizedClient); } @Test public void loadAuthorizedClientWhenMultipleThenFound() { ClientRegistration otherClientRegistration = TestClientRegistrations.clientRegistration() - .registrationId("other-client-registration") - .build(); - OAuth2AuthorizedClient otherAuthorizedClient = new OAuth2AuthorizedClient(otherClientRegistration, "anonymousUser", this.authorizedClient.getAccessToken()); - + .registrationId("other-client-registration").build(); + OAuth2AuthorizedClient otherAuthorizedClient = new OAuth2AuthorizedClient(otherClientRegistration, + "anonymousUser", this.authorizedClient.getAccessToken()); this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block(); this.repository.saveAuthorizedClient(otherAuthorizedClient, this.authentication, this.exchange).block(); - - assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()).isEqualTo(this.authorizedClient); + assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange) + .block()).isEqualTo(this.authorizedClient); } @Test public void loadAuthorizedClientWhenAnonymousThenFound() { this.authentication = this.anonymous; this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block(); - - assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()).isEqualTo(this.authorizedClient); + assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange) + .block()).isEqualTo(this.authorizedClient); } // saveAuthorizedClient - @Test public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() { this.authorizedClient = null; - assertThatThrownBy(() -> this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block()); } @Test public void saveAuthorizedClientWhenAuthenticationNotNullThenIllegalArgumentException() { this.authentication = new TestingAuthenticationToken("a", "b", "ROLE_USER"); - assertThatThrownBy(() -> this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block()); } @Test public void saveAuthorizedClientWhenServerWebExchangeNotNullThenIllegalArgumentException() { this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); - assertThatThrownBy(() -> this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block()); } // removeAuthorizedClient - @Test public void removeAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() { this.clientRegistrationId = null; - assertThatThrownBy(() -> this.repository.removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()); } @Test public void removeAuthorizedClientWhenAuthenticationNotNullThenIllegalArgumentException() { this.authentication = new TestingAuthenticationToken("a", "b", "ROLE_USER"); - assertThatThrownBy(() -> this.repository.removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()); } @Test public void removeAuthorizedClientWhenServerWebExchangeNotNullThenIllegalArgumentException() { this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build()); - assertThatThrownBy(() -> this.repository.removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository + .removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()); } @Test public void removeAuthorizedClientWhenFoundThenFound() { this.repository.saveAuthorizedClient(this.authorizedClient, this.authentication, this.exchange).block(); this.repository.removeAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block(); - - assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange).block()).isNull(); + assertThat(this.repository.loadAuthorizedClient(this.clientRegistrationId, this.authentication, this.exchange) + .block()).isNull(); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java index b4e11c05be..adae7ac311 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java @@ -16,18 +16,13 @@ package org.springframework.security.oauth2.client.web.server; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import java.util.HashMap; import java.util.Map; import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpResponse; @@ -40,8 +35,13 @@ import org.springframework.web.server.adapter.DefaultServerWebExchange; import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; import org.springframework.web.server.session.WebSessionManager; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * @author Rob Winch @@ -49,146 +49,154 @@ import reactor.test.StepVerifier; */ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { - private WebSessionOAuth2ServerAuthorizationRequestRepository repository = - new WebSessionOAuth2ServerAuthorizationRequestRepository(); + private WebSessionOAuth2ServerAuthorizationRequestRepository repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + // @formatter:off private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") .state("state") .build(); + // @formatter:on - private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, "state")); + private ServerWebExchange exchange = MockServerWebExchange + .from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state")); @Test public void loadAuthorizationRequestWhenNullExchangeThenIllegalArgumentException() { this.exchange = null; - assertThatThrownBy(() -> this.repository.loadAuthorizationRequest(this.exchange)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.repository.loadAuthorizationRequest(this.exchange)); } @Test public void loadAuthorizationRequestWhenNoSessionThenEmpty() { + // @formatter:off StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) .verifyComplete(); - + // @formatter:on assertSessionStartedIs(false); } @Test public void loadAuthorizationRequestWhenSessionAndNoRequestThenEmpty() { + // @formatter:off Mono setAttrThenLoad = this.exchange.getSession() - .map(WebSession::getAttributes).doOnNext(attrs -> attrs.put("foo", "bar")) + .map(WebSession::getAttributes) + .doOnNext((attrs) -> attrs.put("foo", "bar")) .then(this.repository.loadAuthorizationRequest(this.exchange)); - StepVerifier.create(setAttrThenLoad) .verifyComplete(); + // @formatter:on } @Test public void loadAuthorizationRequestWhenNoStateParamThenEmpty() { this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); - Mono saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange) + // @formatter:off + Mono saveAndLoad = this.repository + .saveAuthorizationRequest(this.authorizationRequest, this.exchange) .then(this.repository.loadAuthorizationRequest(this.exchange)); - StepVerifier.create(saveAndLoad) .verifyComplete(); + // @formatter:on } @Test public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() { - Mono saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange) + // @formatter:off + Mono saveAndLoad = this.repository + .saveAuthorizationRequest(this.authorizationRequest, this.exchange) .then(this.repository.loadAuthorizationRequest(this.exchange)); StepVerifier.create(saveAndLoad) .expectNext(this.authorizationRequest) .verifyComplete(); + // @formatter:on } @Test public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { String oldState = "state0"; + // @formatter:off MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") .state(oldState) .build(); - - WebSessionManager sessionManager = e -> this.exchange.getSession(); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + // @formatter:on + WebSessionManager sessionManager = (e) -> this.exchange.getSession(); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndLoad = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) .then(this.repository.loadAuthorizationRequest(oldExchange)); - StepVerifier.create(saveAndSaveAndLoad) .expectNext(oldAuthorizationRequest) .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) .expectNext(this.authorizationRequest) .verifyComplete(); + // @formatter:on } @Test public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() { this.authorizationRequest = null; - assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)); assertSessionStartedIs(false); - } @Test public void saveAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() { this.exchange = null; - assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .isInstanceOf(IllegalArgumentException.class); - + assertThatIllegalArgumentException() + .isThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)); } @Test public void removeAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() { this.exchange = null; - assertThatThrownBy(() -> this.repository.removeAuthorizationRequest(this.exchange)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.repository.removeAuthorizationRequest(this.exchange)); } @Test public void removeAuthorizationRequestWhenNotPresentThenThrowsIllegalArgumentException() { - StepVerifier.create(this.repository.removeAuthorizationRequest(this.exchange)) - .verifyComplete(); + StepVerifier.create(this.repository.removeAuthorizationRequest(this.exchange)).verifyComplete(); assertSessionStartedIs(false); } @Test public void removeAuthorizationRequestWhenPresentThenFoundAndRemoved() { + // @formatter:off Mono saveAndRemove = this.repository .saveAuthorizationRequest(this.authorizationRequest, this.exchange) .then(this.repository.removeAuthorizationRequest(this.exchange)); - - StepVerifier.create(saveAndRemove).expectNext(this.authorizationRequest) - .verifyComplete(); - - StepVerifier.create(this.exchange.getSession() - .map(WebSession::getAttributes) - .map(Map::isEmpty)) - .expectNext(true) + StepVerifier.create(saveAndRemove) + .expectNext(this.authorizationRequest) .verifyComplete(); + StepVerifier.create(this.exchange + .getSession() + .map(WebSession::getAttributes) + .map(Map::isEmpty) + ) + .expectNext(true).verifyComplete(); + // @formatter:on } // gh-5599 @Test public void removeAuthorizationRequestWhenStateMissingThenNoErrors() { + // @formatter:off MockServerHttpRequest otherState = MockServerHttpRequest.get("/") .queryParam(OAuth2ParameterNames.STATE, "other") .build(); @@ -198,89 +206,89 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { Mono saveAndRemove = this.repository .saveAuthorizationRequest(this.authorizationRequest, this.exchange) .then(this.repository.removeAuthorizationRequest(otherStateExchange)); - StepVerifier.create(saveAndRemove) .verifyComplete(); + // @formatter:on } @Test public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { String oldState = "state0"; + // @formatter:off MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") .state(oldState) .build(); - - WebSessionManager sessionManager = e -> this.exchange.getSession(); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + // @formatter:on + WebSessionManager sessionManager = (e) -> this.exchange.getSession(); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) .then(this.repository.removeAuthorizationRequest(this.exchange)); - - StepVerifier.create(saveAndSaveAndRemove) - .expectNext(this.authorizationRequest) + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)) .expectNext(oldAuthorizationRequest) .verifyComplete(); + // @formatter:on } // gh-7327 @Test public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { String oldState = "state0"; + // @formatter:off MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") .state(oldState) .build(); - + // @formatter:on Map sessionAttrs = spy(new HashMap<>()); WebSession session = mock(WebSession.class); - when(session.getAttributes()).thenReturn(sessionAttrs); - WebSessionManager sessionManager = e -> Mono.just(session); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + given(session.getAttributes()).willReturn(sessionAttrs); + WebSessionManager sessionManager = (e) -> Mono.just(session); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) .then(this.repository.removeAuthorizationRequest(this.exchange)); - - StepVerifier.create(saveAndSaveAndRemove) - .expectNext(this.authorizationRequest) + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) .verifyComplete(); - + // @formatter:on verify(sessionAttrs, times(3)).put(any(), any()); } private void assertSessionStartedIs(boolean expected) { - Mono isStarted = this.exchange.getSession().map(WebSession::isStarted); + // @formatter:off + Mono isStarted = this.exchange.getSession() + .map(WebSession::isStarted); StepVerifier.create(isStarted) - .expectNext(expected) - .verifyComplete(); + .expectNext(expected) + .verifyComplete(); + // @formatter:on } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java index bc72b2a30a..d0bd471407 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.client.web.server; import org.junit.Test; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.web.server.WebSession; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; /** @@ -33,8 +35,8 @@ import static org.mockito.Mockito.mock; * @since 5.1 */ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { - private WebSessionServerOAuth2AuthorizedClientRepository authorizedClientRepository = - new WebSessionServerOAuth2AuthorizedClientRepository(); + + private WebSessionServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); @@ -43,14 +45,15 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { private ClientRegistration registration2 = TestClientRegistrations.clientRegistration2().build(); private String registrationId1 = this.registration1.getRegistrationId(); - private String registrationId2 = this.registration2.getRegistrationId(); - private String principalName1 = "principalName-1"; + private String registrationId2 = this.registration2.getRegistrationId(); + + private String principalName1 = "principalName-1"; @Test public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.exchange).block()); } @Test @@ -60,65 +63,62 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { @Test public void loadAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null).block()); } @Test public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() { - OAuth2AuthorizedClient authorizedClient = - this.authorizedClientRepository.loadAuthorizedClient("registration-not-found", null, this.exchange).block(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository + .loadAuthorizedClient("registration-not-found", null, this.exchange).block(); assertThat(authorizedClient).isNull(); } @Test public void loadAuthorizedClientWhenSavedThenReturnAuthorizedClient() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); - - OAuth2AuthorizedClient loadedAuthorizedClient = - this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.exchange).block(); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId1, null, this.exchange).block(); assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); } @Test public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.exchange).block()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.exchange).block()); } @Test public void saveAuthorizedClientWhenAuthenticationIsNullThenExceptionNotThrown() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); } @Test public void saveAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); - assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null).block()) - .isInstanceOf(IllegalArgumentException.class); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null).block()); } @Test public void saveAuthorizedClientWhenSavedThenSavedToSession() { - OAuth2AuthorizedClient expected = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient expected = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(expected, null, this.exchange).block(); - OAuth2AuthorizedClient result = this.authorizedClientRepository .loadAuthorizedClient(this.registrationId2, null, this.exchange).block(); - assertThat(result).isEqualTo(expected); } @Test public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( - null, null, this.exchange)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(null, null, this.exchange)); } @Test @@ -128,60 +128,52 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { @Test public void removeAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId1, null, null)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, null)); } - @Test public void removeAuthorizedClientWhenNotSavedThenSessionNotCreated() { - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId2, null, this.exchange); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId2, null, this.exchange); assertThat(this.exchange.getSession().block().isStarted()).isFalse(); } @Test public void removeAuthorizedClientWhenClient1SavedAndClient2RemovedThenClient1NotRemoved() { - OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.exchange).block(); - // Remove registrationId2 (never added so is not removed either) - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId2, null, this.exchange); - - OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId1, null, this.exchange).block(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId2, null, this.exchange); + OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId1, null, this.exchange).block(); assertThat(loadedAuthorizedClient1).isNotNull(); assertThat(loadedAuthorizedClient1).isSameAs(authorizedClient1); } @Test public void removeAuthorizedClientWhenSavedThenRemoved() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId2, null, this.exchange).block(); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId2, null, this.exchange).block(); assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId2, null, this.exchange).block(); - loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId2, null, this.exchange).block(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId2, null, this.exchange).block(); + loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId2, null, this.exchange).block(); assertThat(loadedAuthorizedClient).isNull(); } @Test public void removeAuthorizedClientWhenSavedThenRemovedFromSession() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); - OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId1, null, this.exchange).block(); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId1, null, this.exchange).block(); assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId1, null, this.exchange).block(); - + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.exchange).block(); WebSession session = this.exchange.getSession().block(); assertThat(session).isNotNull(); assertThat(session.getAttributes()).isEmpty(); @@ -189,20 +181,17 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { @Test public void removeAuthorizedClientWhenClient1Client2SavedAndClient1RemovedThenClient2NotRemoved() { - OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( - this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.exchange).block(); - - OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient( - this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName1, + mock(OAuth2AccessToken.class)); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient2, null, this.exchange).block(); - - this.authorizedClientRepository.removeAuthorizedClient( - this.registrationId1, null, this.exchange).block(); - - OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository.loadAuthorizedClient( - this.registrationId2, null, this.exchange).block(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.exchange).block(); + OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId2, null, this.exchange).block(); assertThat(loadedAuthorizedClient2).isNotNull(); assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java index ea3c8dd359..6c89588554 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java @@ -16,11 +16,17 @@ package org.springframework.security.oauth2.client.web.server.authentication; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.ReactiveAuthenticationManager; @@ -36,15 +42,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.web.server.WebFilterExchange; import org.springframework.web.server.handler.DefaultWebFilterChain; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; -import java.util.Collections; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -52,59 +53,60 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2LoginAuthenticationWebFilterTests { + @Mock private ReactiveAuthenticationManager authenticationManager; + @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2LoginAuthenticationWebFilter filter; - private WebFilterExchange webFilterExchange; + private WebFilterExchange webFilterExchange; private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); - - private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse - .success("code") + private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse.success("code") .state("state"); @Before public void setup() { - this.filter = new OAuth2LoginAuthenticationWebFilter(this.authenticationManager, this.authorizedClientRepository); - this.webFilterExchange = new WebFilterExchange(MockServerWebExchange.from(MockServerHttpRequest.get("/")), new DefaultWebFilterChain(exchange -> exchange.getResponse().setComplete())); - when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) - .thenReturn(Mono.empty()); + this.filter = new OAuth2LoginAuthenticationWebFilter(this.authenticationManager, + this.authorizedClientRepository); + this.webFilterExchange = new WebFilterExchange(MockServerWebExchange.from(MockServerHttpRequest.get("/")), + new DefaultWebFilterChain((exchange) -> exchange.getResponse().setComplete())); + given(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).willReturn(Mono.empty()); } @Test public void onAuthenticationSuccessWhenOAuth2LoginAuthenticationTokenThenSavesAuthorizedClient() { this.filter.onAuthenticationSuccess(loginToken(), this.webFilterExchange).block(); - verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); } private OAuth2LoginAuthenticationToken loginToken() { - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", - Instant.now(), - Instant.now().plus(Duration.ofDays(1)), - Collections.singleton("user")); - DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections - .singletonMap("user", "rob"), "user"); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofDays(1)), Collections.singleton("user")); + DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); ClientRegistration clientRegistration = this.registration.build(); - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest - .authorizationCode() + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .state("state") .clientId(clientRegistration.getClientId()) - .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .authorizationUri(clientRegistration.getProviderDetails() + .getAuthorizationUri()) .redirectUri(clientRegistration.getRedirectUri()) .scopes(clientRegistration.getScopes()) .build(); OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr .redirectUri(clientRegistration.getRedirectUri()) .build(); + // @formatter:on OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); - return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange, user, user.getAuthorities(), accessToken); + return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange, user, + user.getAuthorities(), accessToken); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AbstractOAuth2Token.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AbstractOAuth2Token.java index e1dc4d5903..036347bd8e 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AbstractOAuth2Token.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AbstractOAuth2Token.java @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; +import java.io.Serializable; +import java.time.Instant; + import org.springframework.lang.Nullable; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.io.Serializable; -import java.time.Instant; - /** * Base class for OAuth 2.0 Token implementations. * @@ -30,14 +31,17 @@ import java.time.Instant; * @see OAuth2AccessToken */ public abstract class AbstractOAuth2Token implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final String tokenValue; + private final Instant issuedAt; + private final Instant expiresAt; /** * Sub-class constructor. - * * @param tokenValue the token value */ protected AbstractOAuth2Token(String tokenValue) { @@ -46,10 +50,10 @@ public abstract class AbstractOAuth2Token implements Serializable { /** * Sub-class constructor. - * * @param tokenValue the token value * @param issuedAt the time at which the token was issued, may be null - * @param expiresAt the expiration time on or after which the token MUST NOT be accepted, may be null + * @param expiresAt the expiration time on or after which the token MUST NOT be + * accepted, may be null */ protected AbstractOAuth2Token(String tokenValue, @Nullable Instant issuedAt, @Nullable Instant expiresAt) { Assert.hasText(tokenValue, "tokenValue cannot be empty"); @@ -63,7 +67,6 @@ public abstract class AbstractOAuth2Token implements Serializable { /** * Returns the token value. - * * @return the token value */ public String getTokenValue() { @@ -72,7 +75,6 @@ public abstract class AbstractOAuth2Token implements Serializable { /** * Returns the time at which the token was issued. - * * @return the time the token was issued or null */ public @Nullable Instant getIssuedAt() { @@ -81,7 +83,6 @@ public abstract class AbstractOAuth2Token implements Serializable { /** * Returns the expiration time on or after which the token MUST NOT be accepted. - * * @return the expiration time of the token or null */ public @Nullable Instant getExpiresAt() { @@ -96,23 +97,24 @@ public abstract class AbstractOAuth2Token implements Serializable { if (obj == null || this.getClass() != obj.getClass()) { return false; } - - AbstractOAuth2Token that = (AbstractOAuth2Token) obj; - - if (!this.getTokenValue().equals(that.getTokenValue())) { + AbstractOAuth2Token other = (AbstractOAuth2Token) obj; + if (!this.getTokenValue().equals(other.getTokenValue())) { return false; } - if (this.getIssuedAt() != null ? !this.getIssuedAt().equals(that.getIssuedAt()) : that.getIssuedAt() != null) { + if ((this.getIssuedAt() != null) ? !this.getIssuedAt().equals(other.getIssuedAt()) + : other.getIssuedAt() != null) { return false; } - return this.getExpiresAt() != null ? this.getExpiresAt().equals(that.getExpiresAt()) : that.getExpiresAt() == null; + return (this.getExpiresAt() != null) ? this.getExpiresAt().equals(other.getExpiresAt()) + : other.getExpiresAt() == null; } @Override public int hashCode() { int result = this.getTokenValue().hashCode(); - result = 31 * result + (this.getIssuedAt() != null ? this.getIssuedAt().hashCode() : 0); - result = 31 * result + (this.getExpiresAt() != null ? this.getExpiresAt().hashCode() : 0); + result = 31 * result + ((this.getIssuedAt() != null) ? this.getIssuedAt().hashCode() : 0); + result = 31 * result + ((this.getExpiresAt() != null) ? this.getExpiresAt().hashCode() : 0); return result; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthenticationMethod.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthenticationMethod.java index 16f7a96c34..f01c4f4283 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthenticationMethod.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthenticationMethod.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import java.io.Serializable; @@ -21,22 +22,28 @@ import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; /** - * The authentication method used when sending bearer access tokens in resource requests to resource servers. + * The authentication method used when sending bearer access tokens in resource requests + * to resource servers. * * @author MyeongHyeon Lee * @since 5.1 - * @see Section 2 Authenticated Requests + * @see Section 2 + * Authenticated Requests */ public final class AuthenticationMethod implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + public static final AuthenticationMethod HEADER = new AuthenticationMethod("header"); + public static final AuthenticationMethod FORM = new AuthenticationMethod("form"); + public static final AuthenticationMethod QUERY = new AuthenticationMethod("query"); + private final String value; /** * Constructs an {@code AuthenticationMethod} using the provided value. - * * @param value the value of the authentication method type */ public AuthenticationMethod(String value) { @@ -46,7 +53,6 @@ public final class AuthenticationMethod implements Serializable { /** * Returns the value of the authentication method type. - * * @return the value of the authentication method type */ public String getValue() { @@ -69,4 +75,5 @@ public final class AuthenticationMethod implements Serializable { public int hashCode() { return this.getValue().hashCode(); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java index db0b8eb014..8285358b0d 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java @@ -13,48 +13,56 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; +import java.io.Serializable; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.io.Serializable; - /** * An authorization grant is a credential representing the resource owner's authorization - * (to access it's protected resources) to the client and used by the client to obtain an access token. + * (to access it's protected resources) to the client and used by the client to obtain an + * access token. * *

        - * The OAuth 2.0 Authorization Framework defines four standard grant types: - * authorization code, implicit, resource owner password credentials, and client credentials. - * It also provides an extensibility mechanism for defining additional grant types. + * The OAuth 2.0 Authorization Framework defines four standard grant types: authorization + * code, implicit, resource owner password credentials, and client credentials. It also + * provides an extensibility mechanism for defining additional grant types. * * @author Joe Grandja * @since 5.0 - * @see Section 1.3 Authorization Grant + * @see Section + * 1.3 Authorization Grant */ public final class AuthorizationGrantType implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + public static final AuthorizationGrantType AUTHORIZATION_CODE = new AuthorizationGrantType("authorization_code"); /** - * It is not recommended to use the implicit flow - * due to the inherent risks of returning access tokens in an HTTP redirect - * without any confirmation that it has been received by the client. + * It is not recommended to use the implicit flow due to the inherent risks of + * returning access tokens in an HTTP redirect without any confirmation that it has + * been received by the client. * - * @see OAuth 2.0 Implicit Grant + * @see OAuth 2.0 + * Implicit Grant */ @Deprecated public static final AuthorizationGrantType IMPLICIT = new AuthorizationGrantType("implicit"); public static final AuthorizationGrantType REFRESH_TOKEN = new AuthorizationGrantType("refresh_token"); + public static final AuthorizationGrantType CLIENT_CREDENTIALS = new AuthorizationGrantType("client_credentials"); + public static final AuthorizationGrantType PASSWORD = new AuthorizationGrantType("password"); + private final String value; /** * Constructs an {@code AuthorizationGrantType} using the provided value. - * * @param value the value of the authorization grant type */ public AuthorizationGrantType(String value) { @@ -64,7 +72,6 @@ public final class AuthorizationGrantType implements Serializable { /** * Returns the value of the authorization grant type. - * * @return the value of the authorization grant type */ public String getValue() { @@ -87,4 +94,5 @@ public final class AuthorizationGrantType implements Serializable { public int hashCode() { return this.getValue().hashCode(); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClaimAccessor.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClaimAccessor.java index 33c8e29ced..70998b06e4 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClaimAccessor.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClaimAccessor.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.security.oauth2.core.converter.ClaimConversionService; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.core; import java.net.URL; import java.time.Instant; import java.util.List; import java.util.Map; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.security.oauth2.core.converter.ClaimConversionService; +import org.springframework.util.Assert; + /** * An "accessor" for a set of claims that may be used for assertions. * @@ -34,14 +35,13 @@ public interface ClaimAccessor { /** * Returns a set of claims that may be used for assertions. - * * @return a {@code Map} of claims */ Map getClaims(); /** - * Returns the claim value as a {@code T} type. - * The claim value is expected to be of type {@code T}. + * Returns the claim value as a {@code T} type. The claim value is expected to be of + * type {@code T}. * * @since 5.2 * @param claim the name of the claim @@ -54,8 +54,8 @@ public interface ClaimAccessor { } /** - * Returns {@code true} if the claim exists in {@link #getClaims()}, otherwise {@code false}. - * + * Returns {@code true} if the claim exists in {@link #getClaims()}, otherwise + * {@code false}. * @param claim the name of the claim * @return {@code true} if the claim exists, otherwise {@code false} */ @@ -65,30 +65,29 @@ public interface ClaimAccessor { } /** - * Returns the claim value as a {@code String} or {@code null} if it does not exist or is equal to {@code null}. - * + * Returns the claim value as a {@code String} or {@code null} if it does not exist or + * is equal to {@code null}. * @param claim the name of the claim - * @return the claim value or {@code null} if it does not exist or is equal to {@code null} + * @return the claim value or {@code null} if it does not exist or is equal to + * {@code null} */ default String getClaimAsString(String claim) { - return !containsClaim(claim) ? null : - ClaimConversionService.getSharedInstance().convert(getClaims().get(claim), String.class); + return !containsClaim(claim) ? null + : ClaimConversionService.getSharedInstance().convert(getClaims().get(claim), String.class); } /** * Returns the claim value as a {@code Boolean} or {@code null} if it does not exist. - * * @param claim the name of the claim * @return the claim value or {@code null} if it does not exist */ default Boolean getClaimAsBoolean(String claim) { - return !containsClaim(claim) ? null : - ClaimConversionService.getSharedInstance().convert(getClaims().get(claim), Boolean.class); + return !containsClaim(claim) ? null + : ClaimConversionService.getSharedInstance().convert(getClaims().get(claim), Boolean.class); } /** * Returns the claim value as an {@code Instant} or {@code null} if it does not exist. - * * @param claim the name of the claim * @return the claim value or {@code null} if it does not exist */ @@ -98,16 +97,13 @@ public interface ClaimAccessor { } Object claimValue = getClaims().get(claim); Instant convertedValue = ClaimConversionService.getSharedInstance().convert(claimValue, Instant.class); - if (convertedValue == null) { - throw new IllegalArgumentException("Unable to convert claim '" + claim + - "' of type '" + claimValue.getClass() + "' to Instant."); - } + Assert.isTrue(convertedValue != null, + () -> "Unable to convert claim '" + claim + "' of type '" + claimValue.getClass() + "' to Instant."); return convertedValue; } /** * Returns the claim value as an {@code URL} or {@code null} if it does not exist. - * * @param claim the name of the claim * @return the claim value or {@code null} if it does not exist */ @@ -117,19 +113,17 @@ public interface ClaimAccessor { } Object claimValue = getClaims().get(claim); URL convertedValue = ClaimConversionService.getSharedInstance().convert(claimValue, URL.class); - if (convertedValue == null) { - throw new IllegalArgumentException("Unable to convert claim '" + claim + - "' of type '" + claimValue.getClass() + "' to URL."); - } + Assert.isTrue(convertedValue != null, + () -> "Unable to convert claim '" + claim + "' of type '" + claimValue.getClass() + "' to URL."); return convertedValue; } /** - * Returns the claim value as a {@code Map} - * or {@code null} if it does not exist or cannot be assigned to a {@code Map}. - * + * Returns the claim value as a {@code Map} or {@code null} if it does + * not exist or cannot be assigned to a {@code Map}. * @param claim the name of the claim - * @return the claim value or {@code null} if it does not exist or cannot be assigned to a {@code Map} + * @return the claim value or {@code null} if it does not exist or cannot be assigned + * to a {@code Map} */ @SuppressWarnings("unchecked") default Map getClaimAsMap(String claim) { @@ -137,24 +131,22 @@ public interface ClaimAccessor { return null; } final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - final TypeDescriptor targetDescriptor = TypeDescriptor.map( - Map.class, TypeDescriptor.valueOf(String.class), TypeDescriptor.valueOf(Object.class)); + final TypeDescriptor targetDescriptor = TypeDescriptor.map(Map.class, TypeDescriptor.valueOf(String.class), + TypeDescriptor.valueOf(Object.class)); Object claimValue = getClaims().get(claim); - Map convertedValue = (Map) ClaimConversionService.getSharedInstance().convert( - claimValue, sourceDescriptor, targetDescriptor); - if (convertedValue == null) { - throw new IllegalArgumentException("Unable to convert claim '" + claim + - "' of type '" + claimValue.getClass() + "' to Map."); - } + Map convertedValue = (Map) ClaimConversionService.getSharedInstance() + .convert(claimValue, sourceDescriptor, targetDescriptor); + Assert.isTrue(convertedValue != null, + () -> "Unable to convert claim '" + claim + "' of type '" + claimValue.getClass() + "' to Map."); return convertedValue; } /** - * Returns the claim value as a {@code List} - * or {@code null} if it does not exist or cannot be assigned to a {@code List}. - * + * Returns the claim value as a {@code List} or {@code null} if it does not + * exist or cannot be assigned to a {@code List}. * @param claim the name of the claim - * @return the claim value or {@code null} if it does not exist or cannot be assigned to a {@code List} + * @return the claim value or {@code null} if it does not exist or cannot be assigned + * to a {@code List} */ @SuppressWarnings("unchecked") default List getClaimAsStringList(String claim) { @@ -162,15 +154,14 @@ public interface ClaimAccessor { return null; } final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - final TypeDescriptor targetDescriptor = TypeDescriptor.collection( - List.class, TypeDescriptor.valueOf(String.class)); + final TypeDescriptor targetDescriptor = TypeDescriptor.collection(List.class, + TypeDescriptor.valueOf(String.class)); Object claimValue = getClaims().get(claim); - List convertedValue = (List) ClaimConversionService.getSharedInstance().convert( - claimValue, sourceDescriptor, targetDescriptor); - if (convertedValue == null) { - throw new IllegalArgumentException("Unable to convert claim '" + claim + - "' of type '" + claimValue.getClass() + "' to List."); - } + List convertedValue = (List) ClaimConversionService.getSharedInstance().convert(claimValue, + sourceDescriptor, targetDescriptor); + Assert.isTrue(convertedValue != null, + () -> "Unable to convert claim '" + claim + "' of type '" + claimValue.getClass() + "' to List."); return convertedValue; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java index 48ecaaa0ca..c1e3510513 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/ClientAuthenticationMethod.java @@ -13,23 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; +import java.io.Serializable; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.io.Serializable; - /** - * The authentication method used when authenticating the client with the authorization server. + * The authentication method used when authenticating the client with the authorization + * server. * * @author Joe Grandja * @since 5.0 - * @see Section 2.3 Client Authentication + * @see Section + * 2.3 Client Authentication */ public final class ClientAuthenticationMethod implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + public static final ClientAuthenticationMethod BASIC = new ClientAuthenticationMethod("basic"); + public static final ClientAuthenticationMethod POST = new ClientAuthenticationMethod("post"); /** @@ -41,7 +47,6 @@ public final class ClientAuthenticationMethod implements Serializable { /** * Constructs a {@code ClientAuthenticationMethod} using the provided value. - * * @param value the value of the client authentication method */ public ClientAuthenticationMethod(String value) { @@ -51,7 +56,6 @@ public final class ClientAuthenticationMethod implements Serializable { /** * Returns the value of the client authentication method. - * * @return the value of the client authentication method */ public String getValue() { @@ -74,4 +78,5 @@ public final class ClientAuthenticationMethod implements Serializable { public int hashCode() { return this.getValue().hashCode(); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipal.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipal.java index 8b8dd8c677..aaacad14a6 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipal.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipal.java @@ -22,10 +22,9 @@ import java.util.Collections; import java.util.Map; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.util.Assert; -import static org.springframework.security.core.authority.AuthorityUtils.NO_AUTHORITIES; - /** * A domain object that wraps the attributes of an OAuth 2.0 token. * @@ -34,61 +33,57 @@ import static org.springframework.security.core.authority.AuthorityUtils.NO_AUTH * @since 5.2 */ public final class DefaultOAuth2AuthenticatedPrincipal implements OAuth2AuthenticatedPrincipal, Serializable { + private final Map attributes; + private final Collection authorities; + private final String name; /** - * Constructs an {@code DefaultOAuth2AuthenticatedPrincipal} using the provided parameters. - * + * Constructs an {@code DefaultOAuth2AuthenticatedPrincipal} using the provided + * parameters. * @param attributes the attributes of the OAuth 2.0 token * @param authorities the authorities of the OAuth 2.0 token */ public DefaultOAuth2AuthenticatedPrincipal(Map attributes, Collection authorities) { - this(null, attributes, authorities); } /** - * Constructs an {@code DefaultOAuth2AuthenticatedPrincipal} using the provided parameters. - * + * Constructs an {@code DefaultOAuth2AuthenticatedPrincipal} using the provided + * parameters. * @param name the name attached to the OAuth 2.0 token * @param attributes the attributes of the OAuth 2.0 token * @param authorities the authorities of the OAuth 2.0 token */ public DefaultOAuth2AuthenticatedPrincipal(String name, Map attributes, Collection authorities) { - Assert.notEmpty(attributes, "attributes cannot be empty"); this.attributes = Collections.unmodifiableMap(attributes); - this.authorities = authorities == null ? - NO_AUTHORITIES : Collections.unmodifiableCollection(authorities); - this.name = name == null ? (String) this.attributes.get("sub") : name; + this.authorities = (authorities != null) ? Collections.unmodifiableCollection(authorities) + : AuthorityUtils.NO_AUTHORITIES; + this.name = (name != null) ? name : (String) this.attributes.get("sub"); } /** * Gets the attributes of the OAuth 2.0 token in map form. - * * @return a {@link Map} of the attribute's objects keyed by the attribute's names */ + @Override public Map getAttributes() { return this.attributes; } - /** - * {@inheritDoc} - */ @Override public Collection getAuthorities() { return this.authorities; } - /** - * {@inheritDoc} - */ @Override public String getName() { return this.name; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidator.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidator.java index 785d3cf248..e16cf7fef3 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidator.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidator.java @@ -26,29 +26,25 @@ import org.springframework.util.Assert; * A composite validator * * @param the type of {@link AbstractOAuth2Token} this validator validates - * * @author Josh Cummings * @since 5.1 */ -public final class DelegatingOAuth2TokenValidator - implements OAuth2TokenValidator { +public final class DelegatingOAuth2TokenValidator implements OAuth2TokenValidator { private final Collection> tokenValidators; /** * Constructs a {@code DelegatingOAuth2TokenValidator} using the provided validators. - * - * @param tokenValidators the {@link Collection} of {@link OAuth2TokenValidator}s to use + * @param tokenValidators the {@link Collection} of {@link OAuth2TokenValidator}s to + * use */ public DelegatingOAuth2TokenValidator(Collection> tokenValidators) { Assert.notNull(tokenValidators, "tokenValidators cannot be null"); - this.tokenValidators = new ArrayList<>(tokenValidators); } /** * Constructs a {@code DelegatingOAuth2TokenValidator} using the provided validators. - * * @param tokenValidators the collection of {@link OAuth2TokenValidator}s to use */ @SafeVarargs @@ -56,17 +52,13 @@ public final class DelegatingOAuth2TokenValidator this(Arrays.asList(tokenValidators)); } - /** - * {@inheritDoc} - */ @Override public OAuth2TokenValidatorResult validate(T token) { Collection errors = new ArrayList<>(); - - for ( OAuth2TokenValidator validator : this.tokenValidators) { + for (OAuth2TokenValidator validator : this.tokenValidators) { errors.addAll(validator.validate(token).getErrors()); } - return OAuth2TokenValidatorResult.failure(errors); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AccessToken.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AccessToken.java index b2ff587f0d..8ec26b6023 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AccessToken.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AccessToken.java @@ -13,40 +13,45 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core; -import org.springframework.security.core.SpringSecurityCoreVersion; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.core; import java.io.Serializable; import java.time.Instant; import java.util.Collections; import java.util.Set; +import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.util.Assert; + /** - * An implementation of an {@link AbstractOAuth2Token} representing an OAuth 2.0 Access Token. + * An implementation of an {@link AbstractOAuth2Token} representing an OAuth 2.0 Access + * Token. * *

        - * An access token is a credential that represents an authorization - * granted by the resource owner to the client. - * It is primarily used by the client to access protected resources on either a - * resource server or the authorization server that originally issued the access token. + * An access token is a credential that represents an authorization granted by the + * resource owner to the client. It is primarily used by the client to access protected + * resources on either a resource server or the authorization server that originally + * issued the access token. * * @author Joe Grandja * @since 5.0 - * @see Section 1.4 Access Token + * @see Section + * 1.4 Access Token */ public class OAuth2AccessToken extends AbstractOAuth2Token { + private final TokenType tokenType; + private final Set scopes; /** * Constructs an {@code OAuth2AccessToken} using the provided parameters. - * * @param tokenType the token type * @param tokenValue the token value * @param issuedAt the time at which the token was issued - * @param expiresAt the expiration time on or after which the token MUST NOT be accepted + * @param expiresAt the expiration time on or after which the token MUST NOT be + * accepted */ public OAuth2AccessToken(TokenType tokenType, String tokenValue, Instant issuedAt, Instant expiresAt) { this(tokenType, tokenValue, issuedAt, expiresAt, Collections.emptySet()); @@ -54,24 +59,23 @@ public class OAuth2AccessToken extends AbstractOAuth2Token { /** * Constructs an {@code OAuth2AccessToken} using the provided parameters. - * * @param tokenType the token type * @param tokenValue the token value * @param issuedAt the time at which the token was issued - * @param expiresAt the expiration time on or after which the token MUST NOT be accepted + * @param expiresAt the expiration time on or after which the token MUST NOT be + * accepted * @param scopes the scope(s) associated to the token */ - public OAuth2AccessToken(TokenType tokenType, String tokenValue, Instant issuedAt, Instant expiresAt, Set scopes) { + public OAuth2AccessToken(TokenType tokenType, String tokenValue, Instant issuedAt, Instant expiresAt, + Set scopes) { super(tokenValue, issuedAt, expiresAt); Assert.notNull(tokenType, "tokenType cannot be null"); this.tokenType = tokenType; - this.scopes = Collections.unmodifiableSet( - scopes != null ? scopes : Collections.emptySet()); + this.scopes = Collections.unmodifiableSet((scopes != null) ? scopes : Collections.emptySet()); } /** * Returns the {@link TokenType token type}. - * * @return the {@link TokenType} */ public TokenType getTokenType() { @@ -80,7 +84,6 @@ public class OAuth2AccessToken extends AbstractOAuth2Token { /** * Returns the scope(s) associated to the token. - * * @return the scope(s) associated to the token */ public Set getScopes() { @@ -90,11 +93,16 @@ public class OAuth2AccessToken extends AbstractOAuth2Token { /** * Access Token Types. * - * @see Section 7.1 Access Token Types + * @see Section 7.1 Access Token + * Types */ public static final class TokenType implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + public static final TokenType BEARER = new TokenType("Bearer"); + private final String value; private TokenType(String value) { @@ -104,7 +112,6 @@ public class OAuth2AccessToken extends AbstractOAuth2Token { /** * Returns the value of the token type. - * * @return the value of the token type */ public String getValue() { @@ -127,5 +134,7 @@ public class OAuth2AccessToken extends AbstractOAuth2Token { public int hashCode() { return this.getValue().hashCode(); } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticatedPrincipal.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticatedPrincipal.java index b3329d2f77..f056aba43d 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticatedPrincipal.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticatedPrincipal.java @@ -24,16 +24,16 @@ import org.springframework.security.core.AuthenticatedPrincipal; import org.springframework.security.core.GrantedAuthority; /** - * An {@link AuthenticatedPrincipal} that represents the principal - * associated with an OAuth 2.0 token. + * An {@link AuthenticatedPrincipal} that represents the principal associated with an + * OAuth 2.0 token. * * @author Josh Cummings * @since 5.2 */ public interface OAuth2AuthenticatedPrincipal extends AuthenticatedPrincipal { + /** * Get the OAuth 2.0 token attribute by name - * * @param name the name of the attribute * @param the type of the attribute * @return the attribute or {@code null} otherwise @@ -45,15 +45,13 @@ public interface OAuth2AuthenticatedPrincipal extends AuthenticatedPrincipal { /** * Get the OAuth 2.0 token attributes - * * @return the OAuth 2.0 token attributes */ Map getAttributes(); /** - * Get the {@link Collection} of {@link GrantedAuthority}s associated - * with this OAuth 2.0 token - * + * Get the {@link Collection} of {@link GrantedAuthority}s associated with this OAuth + * 2.0 token * @return the OAuth 2.0 token authorities */ Collection getAuthorities(); diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticationException.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticationException.java index 0f9383181b..01cedf36e1 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticationException.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthenticationException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import org.springframework.security.core.Authentication; @@ -25,24 +26,25 @@ import org.springframework.util.Assert; *

        * There are a number of scenarios where an error may occur, for example: *

        * * @author Joe Grandja * @since 5.0 */ public class OAuth2AuthenticationException extends AuthenticationException { - private OAuth2Error error; + + private final OAuth2Error error; /** * Constructs an {@code OAuth2AuthenticationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} */ public OAuth2AuthenticationException(OAuth2Error error) { @@ -51,7 +53,6 @@ public class OAuth2AuthenticationException extends AuthenticationException { /** * Constructs an {@code OAuth2AuthenticationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param cause the root cause */ @@ -61,38 +62,31 @@ public class OAuth2AuthenticationException extends AuthenticationException { /** * Constructs an {@code OAuth2AuthenticationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param message the detail message */ public OAuth2AuthenticationException(OAuth2Error error, String message) { - super(message); - this.setError(error); + this(error, message, null); } /** * Constructs an {@code OAuth2AuthenticationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param message the detail message * @param cause the root cause */ public OAuth2AuthenticationException(OAuth2Error error, String message, Throwable cause) { super(message, cause); - this.setError(error); + Assert.notNull(error, "error cannot be null"); + this.error = error; } /** * Returns the {@link OAuth2Error OAuth 2.0 Error}. - * * @return the {@link OAuth2Error} */ public OAuth2Error getError() { return this.error; } - private void setError(OAuth2Error error) { - Assert.notNull(error, "error cannot be null"); - this.error = error; - } } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java index a894c6d6eb..dbfdf98e5f 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import org.springframework.util.Assert; @@ -24,11 +25,11 @@ import org.springframework.util.Assert; * @since 5.1 */ public class OAuth2AuthorizationException extends RuntimeException { - private OAuth2Error error; + + private final OAuth2Error error; /** * Constructs an {@code OAuth2AuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} */ public OAuth2AuthorizationException(OAuth2Error error) { @@ -37,7 +38,6 @@ public class OAuth2AuthorizationException extends RuntimeException { /** * Constructs an {@code OAuth2AuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param message the exception message * @since 5.3 @@ -50,7 +50,6 @@ public class OAuth2AuthorizationException extends RuntimeException { /** * Constructs an {@code OAuth2AuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param cause the root cause */ @@ -60,7 +59,6 @@ public class OAuth2AuthorizationException extends RuntimeException { /** * Constructs an {@code OAuth2AuthorizationException} using the provided parameters. - * * @param error the {@link OAuth2Error OAuth 2.0 Error} * @param message the exception message * @param cause the root cause @@ -74,10 +72,10 @@ public class OAuth2AuthorizationException extends RuntimeException { /** * Returns the {@link OAuth2Error OAuth 2.0 Error}. - * * @return the {@link OAuth2Error} */ public OAuth2Error getError() { return this.error; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2Error.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2Error.java index 22d306a01b..f1816b723d 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2Error.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2Error.java @@ -13,36 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; +import java.io.Serializable; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.io.Serializable; - /** * A representation of an OAuth 2.0 Error. * *

        - * At a minimum, an error response will contain an error code. - * The error code may be one of the standard codes defined by the specification, - * or a new code defined in the OAuth Extensions Error Registry, - * for cases where protocol extensions require additional error code(s) above the standard codes. + * At a minimum, an error response will contain an error code. The error code may be one + * of the standard codes defined by the specification, or a new code defined in the OAuth + * Extensions Error Registry, for cases where protocol extensions require additional error + * code(s) above the standard codes. * * @author Joe Grandja * @since 5.0 * @see OAuth2ErrorCodes - * @see Section 11.4 OAuth Extensions Error Registry + * @see Section + * 11.4 OAuth Extensions Error Registry */ public class OAuth2Error implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final String errorCode; + private final String description; + private final String uri; /** * Constructs an {@code OAuth2Error} using the provided parameters. - * * @param errorCode the error code */ public OAuth2Error(String errorCode) { @@ -51,7 +56,6 @@ public class OAuth2Error implements Serializable { /** * Constructs an {@code OAuth2Error} using the provided parameters. - * * @param errorCode the error code * @param description the error description * @param uri the error uri @@ -65,7 +69,6 @@ public class OAuth2Error implements Serializable { /** * Returns the error code. - * * @return the error code */ public final String getErrorCode() { @@ -74,7 +77,6 @@ public class OAuth2Error implements Serializable { /** * Returns the error description. - * * @return the error description */ public final String getDescription() { @@ -83,7 +85,6 @@ public class OAuth2Error implements Serializable { /** * Returns the error uri. - * * @return the error uri */ public final String getUri() { @@ -92,7 +93,7 @@ public class OAuth2Error implements Serializable { @Override public String toString() { - return "[" + this.getErrorCode() + "] " + - (this.getDescription() != null ? this.getDescription() : ""); + return "[" + this.getErrorCode() + "] " + ((this.getDescription() != null) ? this.getDescription() : ""); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java index 7a8f7a1677..edcc5f2d6b 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; /** @@ -24,20 +25,21 @@ package org.springframework.security.oauth2.core; public interface OAuth2ErrorCodes { /** - * {@code invalid_request} - The request is missing a required parameter, - * includes an invalid parameter value, - * includes a parameter more than once, or is otherwise malformed. + * {@code invalid_request} - The request is missing a required parameter, includes an + * invalid parameter value, includes a parameter more than once, or is otherwise + * malformed. */ String INVALID_REQUEST = "invalid_request"; /** - * {@code unauthorized_client} - The client is not authorized to request - * an authorization code or access token using this method. + * {@code unauthorized_client} - The client is not authorized to request an + * authorization code or access token using this method. */ String UNAUTHORIZED_CLIENT = "unauthorized_client"; /** - * {@code access_denied} - The resource owner or authorization server denied the request. + * {@code access_denied} - The resource owner or authorization server denied the + * request. */ String ACCESS_DENIED = "access_denied"; @@ -54,62 +56,66 @@ public interface OAuth2ErrorCodes { String INVALID_SCOPE = "invalid_scope"; /** - * {@code insufficient_scope} - The request requires higher privileges than - * provided by the access token. - * The resource server SHOULD respond with the HTTP 403 (Forbidden) - * status code and MAY include the "scope" attribute with the scope - * necessary to access the protected resource. + * {@code insufficient_scope} - The request requires higher privileges than provided + * by the access token. The resource server SHOULD respond with the HTTP 403 + * (Forbidden) status code and MAY include the "scope" attribute with the scope + * necessary to access the protected resource. * - * @see RFC-6750 - Section 3.1 - Error Codes + * @see RFC-6750 - Section + * 3.1 - Error Codes */ String INSUFFICIENT_SCOPE = "insufficient_scope"; /** - * {@code invalid_token} - The access token provided is expired, revoked, - * malformed, or invalid for other reasons. - * The resource SHOULD respond with the HTTP 401 (Unauthorized) status code. - * The client MAY request a new access token and retry the protected resource request. + * {@code invalid_token} - The access token provided is expired, revoked, malformed, + * or invalid for other reasons. The resource SHOULD respond with the HTTP 401 + * (Unauthorized) status code. The client MAY request a new access token and retry the + * protected resource request. * - * @see RFC-6750 - Section 3.1 - Error Codes + * @see RFC-6750 - Section + * 3.1 - Error Codes */ String INVALID_TOKEN = "invalid_token"; /** - * {@code server_error} - The authorization server encountered an - * unexpected condition that prevented it from fulfilling the request. - * (This error code is needed because a 500 Internal Server Error HTTP status code - * cannot be returned to the client via a HTTP redirect.) + * {@code server_error} - The authorization server encountered an unexpected condition + * that prevented it from fulfilling the request. (This error code is needed because a + * 500 Internal Server Error HTTP status code cannot be returned to the client via a + * HTTP redirect.) */ String SERVER_ERROR = "server_error"; /** - * {@code temporarily_unavailable} - The authorization server is currently unable - * to handle the request due to a temporary overloading or maintenance of the server. + * {@code temporarily_unavailable} - The authorization server is currently unable to + * handle the request due to a temporary overloading or maintenance of the server. * (This error code is needed because a 503 Service Unavailable HTTP status code * cannot be returned to the client via an HTTP redirect.) */ String TEMPORARILY_UNAVAILABLE = "temporarily_unavailable"; /** - * {@code invalid_client} - Client authentication failed (e.g., unknown client, - * no client authentication included, or unsupported authentication method). - * The authorization server MAY return a HTTP 401 (Unauthorized) status code - * to indicate which HTTP authentication schemes are supported. - * If the client attempted to authenticate via the "Authorization" request header field, - * the authorization server MUST respond with a HTTP 401 (Unauthorized) status code and - * include the "WWW-Authenticate" response header field matching the authentication scheme used by the client. + * {@code invalid_client} - Client authentication failed (e.g., unknown client, no + * client authentication included, or unsupported authentication method). The + * authorization server MAY return a HTTP 401 (Unauthorized) status code to indicate + * which HTTP authentication schemes are supported. If the client attempted to + * authenticate via the "Authorization" request header field, the + * authorization server MUST respond with a HTTP 401 (Unauthorized) status code and + * include the "WWW-Authenticate" response header field matching the + * authentication scheme used by the client. */ String INVALID_CLIENT = "invalid_client"; /** - * {@code invalid_grant} - The provided authorization grant - * (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, - * does not match the redirection URI used in the authorization request, or was issued to another client. + * {@code invalid_grant} - The provided authorization grant (e.g., authorization code, + * resource owner credentials) or refresh token is invalid, expired, revoked, does not + * match the redirection URI used in the authorization request, or was issued to + * another client. */ String INVALID_GRANT = "invalid_grant"; /** - * {@code unsupported_grant_type} - The authorization grant type is not supported by the authorization server. + * {@code unsupported_grant_type} - The authorization grant type is not supported by + * the authorization server. */ String UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type"; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2RefreshToken.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2RefreshToken.java index e52f364399..26814bc31f 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2RefreshToken.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2RefreshToken.java @@ -13,33 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import java.time.Instant; /** - * An implementation of an {@link AbstractOAuth2Token} representing an OAuth 2.0 Refresh Token. + * An implementation of an {@link AbstractOAuth2Token} representing an OAuth 2.0 Refresh + * Token. * *

        - * A refresh token is a credential that represents an authorization - * granted by the resource owner to the client. - * It is used by the client to obtain a new access token when the current access token - * becomes invalid or expires, or to obtain additional access tokens with identical or narrower scope. + * A refresh token is a credential that represents an authorization granted by the + * resource owner to the client. It is used by the client to obtain a new access token + * when the current access token becomes invalid or expires, or to obtain additional + * access tokens with identical or narrower scope. * * @author Joe Grandja * @since 5.1 * @see OAuth2AccessToken - * @see Section 1.5 Refresh Token + * @see Section + * 1.5 Refresh Token */ public class OAuth2RefreshToken extends AbstractOAuth2Token { /** * Constructs an {@code OAuth2RefreshToken} using the provided parameters. - * * @param tokenValue the token value * @param issuedAt the time at which the token was issued */ public OAuth2RefreshToken(String tokenValue, Instant issuedAt) { super(tokenValue, issuedAt, null); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidator.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidator.java index 95cd5153df..25cb8f78a3 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidator.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidator.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; /** - * Implementations of this interface are responsible for "verifying" - * the validity and/or constraints of the attributes contained in an OAuth 2.0 Token. + * Implementations of this interface are responsible for "verifying" the + * validity and/or constraints of the attributes contained in an OAuth 2.0 Token. * * @author Joe Grandja * @author Josh Cummings @@ -28,9 +29,9 @@ public interface OAuth2TokenValidator { /** * Verify the validity and/or constraints of the provided OAuth 2.0 Token. - * * @param token an OAuth 2.0 token * @return OAuth2TokenValidationResult the success or failure detail of the validation */ OAuth2TokenValidatorResult validate(T token); + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResult.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResult.java index 0922360aef..a8baab9c4f 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResult.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResult.java @@ -30,6 +30,7 @@ import org.springframework.util.Assert; * @since 5.1 */ public final class OAuth2TokenValidatorResult { + static final OAuth2TokenValidatorResult NO_ERRORS = new OAuth2TokenValidatorResult(Collections.emptyList()); private final Collection errors; @@ -41,7 +42,6 @@ public final class OAuth2TokenValidatorResult { /** * Say whether this result indicates success - * * @return whether this result has errors */ public boolean hasErrors() { @@ -50,8 +50,8 @@ public final class OAuth2TokenValidatorResult { /** * Return error details regarding the validation attempt - * - * @return the collection of results in this result, if any; returns an empty list otherwise + * @return the collection of results in this result, if any; returns an empty list + * otherwise */ public Collection getErrors() { return this.errors; @@ -59,7 +59,6 @@ public final class OAuth2TokenValidatorResult { /** * Construct a successful {@link OAuth2TokenValidatorResult} - * * @return an {@link OAuth2TokenValidatorResult} with no errors */ public static OAuth2TokenValidatorResult success() { @@ -68,7 +67,6 @@ public final class OAuth2TokenValidatorResult { /** * Construct a failure {@link OAuth2TokenValidatorResult} with the provided detail - * * @param errors the list of errors * @return an {@link OAuth2TokenValidatorResult} with the errors specified */ @@ -78,15 +76,11 @@ public final class OAuth2TokenValidatorResult { /** * Construct a failure {@link OAuth2TokenValidatorResult} with the provided detail - * * @param errors the list of errors * @return an {@link OAuth2TokenValidatorResult} with the errors specified */ public static OAuth2TokenValidatorResult failure(Collection errors) { - if (errors.isEmpty()) { - return NO_ERRORS; - } - - return new OAuth2TokenValidatorResult(errors); + return (errors.isEmpty()) ? NO_ERRORS : new OAuth2TokenValidatorResult(errors); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimConversionService.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimConversionService.java index a871537534..9ff17edad5 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimConversionService.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimConversionService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.converter; import org.springframework.core.convert.ConversionService; @@ -21,8 +22,8 @@ import org.springframework.core.convert.support.GenericConversionService; import org.springframework.security.oauth2.core.ClaimAccessor; /** - * A {@link ConversionService} configured with converters - * that provide type conversion for claim values. + * A {@link ConversionService} configured with converters that provide type conversion for + * claim values. * * @author Joe Grandja * @since 5.2 @@ -30,6 +31,7 @@ import org.springframework.security.oauth2.core.ClaimAccessor; * @see ClaimAccessor */ public final class ClaimConversionService extends GenericConversionService { + private static volatile ClaimConversionService sharedInstance; private ClaimConversionService() { @@ -38,7 +40,6 @@ public final class ClaimConversionService extends GenericConversionService { /** * Returns a shared instance of {@code ClaimConversionService}. - * * @return a shared instance of {@code ClaimConversionService} */ public static ClaimConversionService getSharedInstance() { @@ -56,9 +57,8 @@ public final class ClaimConversionService extends GenericConversionService { } /** - * Adds the converters that provide type conversion for claim values - * to the provided {@link ConverterRegistry}. - * + * Adds the converters that provide type conversion for claim values to the provided + * {@link ConverterRegistry}. * @param converterRegistry the registry of converters to add to */ public static void addConverters(ConverterRegistry converterRegistry) { @@ -69,4 +69,5 @@ public final class ClaimConversionService extends GenericConversionService { converterRegistry.addConverter(new ObjectToListStringConverter()); converterRegistry.addConverter(new ObjectToMapStringObjectConverter()); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverter.java index 098cb86a01..1eb661f2a4 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverter.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.converter.Converter; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; +package org.springframework.security.oauth2.core.converter; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; +import org.springframework.core.convert.converter.Converter; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + /** * A {@link Converter} that provides type conversion for claim values. * @@ -32,12 +33,13 @@ import java.util.Map; * @see Converter */ public final class ClaimTypeConverter implements Converter, Map> { + private final Map> claimTypeConverters; /** * Constructs a {@code ClaimTypeConverter} using the provided parameters. - * - * @param claimTypeConverters a {@link Map} of {@link Converter}(s) keyed by claim name + * @param claimTypeConverters a {@link Map} of {@link Converter}(s) keyed by claim + * name */ public ClaimTypeConverter(Map> claimTypeConverters) { Assert.notEmpty(claimTypeConverters, "claimTypeConverters cannot be empty"); @@ -50,7 +52,6 @@ public final class ClaimTypeConverter implements Converter, if (CollectionUtils.isEmpty(claims)) { return claims; } - Map result = new HashMap<>(claims); this.claimTypeConverters.forEach((claimName, typeConverter) -> { if (claims.containsKey(claimName)) { @@ -61,7 +62,7 @@ public final class ClaimTypeConverter implements Converter, } } }); - return result; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToBooleanConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToBooleanConverter.java index 82c2243308..8dd12c1e21 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToBooleanConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToBooleanConverter.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.GenericConverter; +package org.springframework.security.oauth2.core.converter; import java.util.Collections; import java.util.Set; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.GenericConverter; + /** * @author Joe Grandja * @since 5.2 @@ -42,4 +43,5 @@ final class ObjectToBooleanConverter implements GenericConverter { } return Boolean.valueOf(source.toString()); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToInstantConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToInstantConverter.java index 65ddaae9d5..bf0909ee82 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToInstantConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToInstantConverter.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.GenericConverter; +package org.springframework.security.oauth2.core.converter; import java.time.Instant; import java.util.Collections; import java.util.Date; import java.util.Set; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.GenericConverter; + /** * @author Joe Grandja * @since 5.2 @@ -50,14 +51,17 @@ final class ObjectToInstantConverter implements GenericConverter { } try { return Instant.ofEpochSecond(Long.parseLong(source.toString())); - } catch (Exception ex) { + } + catch (Exception ex) { // Ignore } try { return Instant.parse(source.toString()); - } catch (Exception ex) { + } + catch (Exception ex) { // Ignore } return null; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToListStringConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToListStringConverter.java index daba913f25..0597f561fe 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToListStringConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToListStringConverter.java @@ -13,18 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.ConditionalGenericConverter; -import org.springframework.util.ClassUtils; - +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; -import java.util.ArrayList; + +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.ConditionalGenericConverter; +import org.springframework.util.ClassUtils; /** * @author Joe Grandja @@ -42,10 +43,9 @@ final class ObjectToListStringConverter implements ConditionalGenericConverter { @Override public boolean matches(TypeDescriptor sourceType, TypeDescriptor targetType) { - if (targetType.getElementTypeDescriptor() == null || - targetType.getElementTypeDescriptor().getType().equals(String.class) || - sourceType == null || - ClassUtils.isAssignable(sourceType.getType(), targetType.getElementTypeDescriptor().getType())) { + if (targetType.getElementTypeDescriptor() == null + || targetType.getElementTypeDescriptor().getType().equals(String.class) || sourceType == null + || ClassUtils.isAssignable(sourceType.getType(), targetType.getElementTypeDescriptor().getType())) { return true; } return false; @@ -73,4 +73,5 @@ final class ObjectToListStringConverter implements ConditionalGenericConverter { } return Collections.singletonList(source.toString()); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToMapStringObjectConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToMapStringObjectConverter.java index 6db09f9cf1..08d4f6ef97 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToMapStringObjectConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToMapStringObjectConverter.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.ConditionalGenericConverter; +package org.springframework.security.oauth2.core.converter; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Set; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.ConditionalGenericConverter; + /** * @author Joe Grandja * @since 5.2 @@ -36,8 +37,8 @@ final class ObjectToMapStringObjectConverter implements ConditionalGenericConver @Override public boolean matches(TypeDescriptor sourceType, TypeDescriptor targetType) { - if (targetType.getElementTypeDescriptor() == null || - targetType.getMapKeyTypeDescriptor().getType().equals(String.class)) { + if (targetType.getElementTypeDescriptor() == null + || targetType.getMapKeyTypeDescriptor().getType().equals(String.class)) { return true; } return false; @@ -59,4 +60,5 @@ final class ObjectToMapStringObjectConverter implements ConditionalGenericConver sourceMap.forEach((k, v) -> result.put(k.toString(), v)); return result; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToStringConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToStringConverter.java index 8a73b71e74..a82e0341b3 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToStringConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToStringConverter.java @@ -13,14 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.GenericConverter; +package org.springframework.security.oauth2.core.converter; import java.util.Collections; import java.util.Set; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.GenericConverter; + /** * @author Joe Grandja * @since 5.2 @@ -34,6 +35,7 @@ final class ObjectToStringConverter implements GenericConverter { @Override public Object convert(Object source, TypeDescriptor sourceType, TypeDescriptor targetType) { - return source == null ? null : source.toString(); + return (source != null) ? source.toString() : null; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToURLConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToURLConverter.java index f24020a378..209afb3ae6 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToURLConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/converter/ObjectToURLConverter.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.GenericConverter; +package org.springframework.security.oauth2.core.converter; import java.net.URI; import java.net.URL; import java.util.Collections; import java.util.Set; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.GenericConverter; + /** * @author Joe Grandja * @since 5.2 @@ -44,9 +45,11 @@ final class ObjectToURLConverter implements GenericConverter { } try { return new URI(source.toString()).toURL(); - } catch (Exception ex) { + } + catch (Exception ex) { // Ignore } return null; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverter.java index 3cc3aa0ede..ef5e138232 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import java.util.Arrays; @@ -27,55 +28,34 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.util.StringUtils; /** - * A {@link Converter} that converts the provided - * OAuth 2.0 Access Token Response parameters to an {@link OAuth2AccessTokenResponse}. + * A {@link Converter} that converts the provided OAuth 2.0 Access Token Response + * parameters to an {@link OAuth2AccessTokenResponse}. * * @author Joe Grandja * @author Nikita Konev * @since 5.3 */ -public final class MapOAuth2AccessTokenResponseConverter implements Converter, OAuth2AccessTokenResponse> { - private static final Set TOKEN_RESPONSE_PARAMETER_NAMES = new HashSet<>(Arrays.asList( - OAuth2ParameterNames.ACCESS_TOKEN, - OAuth2ParameterNames.EXPIRES_IN, - OAuth2ParameterNames.REFRESH_TOKEN, - OAuth2ParameterNames.SCOPE, - OAuth2ParameterNames.TOKEN_TYPE - )); +public final class MapOAuth2AccessTokenResponseConverter + implements Converter, OAuth2AccessTokenResponse> { + + private static final Set TOKEN_RESPONSE_PARAMETER_NAMES = new HashSet<>( + Arrays.asList(OAuth2ParameterNames.ACCESS_TOKEN, OAuth2ParameterNames.EXPIRES_IN, + OAuth2ParameterNames.REFRESH_TOKEN, OAuth2ParameterNames.SCOPE, OAuth2ParameterNames.TOKEN_TYPE)); @Override public OAuth2AccessTokenResponse convert(Map tokenResponseParameters) { String accessToken = tokenResponseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN); - - OAuth2AccessToken.TokenType accessTokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( - tokenResponseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) { - accessTokenType = OAuth2AccessToken.TokenType.BEARER; - } - - long expiresIn = 0; - if (tokenResponseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) { - try { - expiresIn = Long.parseLong(tokenResponseParameters.get(OAuth2ParameterNames.EXPIRES_IN)); - } catch (NumberFormatException ex) { - } - } - - Set scopes = Collections.emptySet(); - if (tokenResponseParameters.containsKey(OAuth2ParameterNames.SCOPE)) { - String scope = tokenResponseParameters.get(OAuth2ParameterNames.SCOPE); - scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); - } - + OAuth2AccessToken.TokenType accessTokenType = getAccessTokenType(tokenResponseParameters); + long expiresIn = getExpiresIn(tokenResponseParameters); + Set scopes = getScopes(tokenResponseParameters); String refreshToken = tokenResponseParameters.get(OAuth2ParameterNames.REFRESH_TOKEN); - Map additionalParameters = new LinkedHashMap<>(); for (Map.Entry entry : tokenResponseParameters.entrySet()) { if (!TOKEN_RESPONSE_PARAMETER_NAMES.contains(entry.getKey())) { additionalParameters.put(entry.getKey(), entry.getValue()); } } - + // @formatter:off return OAuth2AccessTokenResponse.withToken(accessToken) .tokenType(accessTokenType) .expiresIn(expiresIn) @@ -83,5 +63,34 @@ public final class MapOAuth2AccessTokenResponseConverter implements Converter tokenResponseParameters) { + if (OAuth2AccessToken.TokenType.BEARER.getValue() + .equalsIgnoreCase(tokenResponseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) { + return OAuth2AccessToken.TokenType.BEARER; + } + return null; + } + + private long getExpiresIn(Map tokenResponseParameters) { + if (tokenResponseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) { + try { + return Long.parseLong(tokenResponseParameters.get(OAuth2ParameterNames.EXPIRES_IN)); + } + catch (NumberFormatException ex) { + } + } + return 0; + } + + private Set getScopes(Map tokenResponseParameters) { + if (tokenResponseParameters.containsKey(OAuth2ParameterNames.SCOPE)) { + String scope = tokenResponseParameters.get(OAuth2ParameterNames.SCOPE); + return new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " "))); + } + return Collections.emptySet(); + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponse.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponse.java index 5bd0df5696..c09f36e909 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponse.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponse.java @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + import org.springframework.lang.Nullable; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.time.Instant; -import java.util.Collections; -import java.util.Map; -import java.util.Set; - /** * A representation of an OAuth 2.0 Access Token Response. * @@ -33,11 +34,15 @@ import java.util.Set; * @since 5.0 * @see OAuth2AccessToken * @see OAuth2RefreshToken - * @see Section 5.1 Access Token Response + * @see Section + * 5.1 Access Token Response */ public final class OAuth2AccessTokenResponse { + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; + private Map additionalParameters; private OAuth2AccessTokenResponse() { @@ -45,7 +50,6 @@ public final class OAuth2AccessTokenResponse { /** * Returns the {@link OAuth2AccessToken Access Token}. - * * @return the {@link OAuth2AccessToken} */ public OAuth2AccessToken getAccessToken() { @@ -54,9 +58,8 @@ public final class OAuth2AccessTokenResponse { /** * Returns the {@link OAuth2RefreshToken Refresh Token}. - * - * @since 5.1 * @return the {@link OAuth2RefreshToken} + * @since 5.1 */ public @Nullable OAuth2RefreshToken getRefreshToken() { return this.refreshToken; @@ -64,8 +67,8 @@ public final class OAuth2AccessTokenResponse { /** * Returns the additional parameters returned in the response. - * - * @return a {@code Map} of the additional parameters returned in the response, may be empty. + * @return a {@code Map} of the additional parameters returned in the response, may be + * empty. */ public Map getAdditionalParameters() { return this.additionalParameters; @@ -73,7 +76,6 @@ public final class OAuth2AccessTokenResponse { /** * Returns a new {@link Builder}, initialized with the provided access token value. - * * @param tokenValue the value of the access token * @return the {@link Builder} */ @@ -83,7 +85,6 @@ public final class OAuth2AccessTokenResponse { /** * Returns a new {@link Builder}, initialized with the provided response. - * * @param response the response to initialize the builder with * @return the {@link Builder} */ @@ -94,14 +95,22 @@ public final class OAuth2AccessTokenResponse { /** * A builder for {@link OAuth2AccessTokenResponse}. */ - public static class Builder { + public static final class Builder { + private String tokenValue; + private OAuth2AccessToken.TokenType tokenType; + private Instant issuedAt; + private Instant expiresAt; + private long expiresIn; + private Set scopes; + private String refreshToken; + private Map additionalParameters; private Builder(OAuth2AccessTokenResponse response) { @@ -111,8 +120,8 @@ public final class OAuth2AccessTokenResponse { this.issuedAt = accessToken.getIssuedAt(); this.expiresAt = accessToken.getExpiresAt(); this.scopes = accessToken.getScopes(); - this.refreshToken = response.getRefreshToken() == null ? - null : response.getRefreshToken().getTokenValue(); + this.refreshToken = (response.getRefreshToken() != null) ? response.getRefreshToken().getTokenValue() + : null; this.additionalParameters = response.getAdditionalParameters(); } @@ -122,7 +131,6 @@ public final class OAuth2AccessTokenResponse { /** * Sets the {@link OAuth2AccessToken.TokenType token type}. - * * @param tokenType the type of token issued * @return the {@link Builder} */ @@ -133,7 +141,6 @@ public final class OAuth2AccessTokenResponse { /** * Sets the lifetime (in seconds) of the access token. - * * @param expiresIn the lifetime of the access token, in seconds. * @return the {@link Builder} */ @@ -145,7 +152,6 @@ public final class OAuth2AccessTokenResponse { /** * Sets the scope(s) associated to the access token. - * * @param scopes the scope(s) associated to the access token. * @return the {@link Builder} */ @@ -156,7 +162,6 @@ public final class OAuth2AccessTokenResponse { /** * Sets the refresh token associated to the access token. - * * @param refreshToken the refresh token associated to the access token. * @return the {@link Builder} */ @@ -167,7 +172,6 @@ public final class OAuth2AccessTokenResponse { /** * Sets the additional parameters returned in the response. - * * @param additionalParameters the additional parameters returned in the response * @return the {@link Builder} */ @@ -178,21 +182,20 @@ public final class OAuth2AccessTokenResponse { /** * Builds a new {@link OAuth2AccessTokenResponse}. - * * @return a {@link OAuth2AccessTokenResponse} */ public OAuth2AccessTokenResponse build() { Instant issuedAt = getIssuedAt(); Instant expiresAt = getExpiresAt(); - OAuth2AccessTokenResponse accessTokenResponse = new OAuth2AccessTokenResponse(); - accessTokenResponse.accessToken = new OAuth2AccessToken( - this.tokenType, this.tokenValue, issuedAt, expiresAt, this.scopes); + accessTokenResponse.accessToken = new OAuth2AccessToken(this.tokenType, this.tokenValue, issuedAt, + expiresAt, this.scopes); if (StringUtils.hasText(this.refreshToken)) { accessTokenResponse.refreshToken = new OAuth2RefreshToken(this.refreshToken, issuedAt); } - accessTokenResponse.additionalParameters = Collections.unmodifiableMap( - CollectionUtils.isEmpty(this.additionalParameters) ? Collections.emptyMap() : this.additionalParameters); + accessTokenResponse.additionalParameters = Collections + .unmodifiableMap(CollectionUtils.isEmpty(this.additionalParameters) ? Collections.emptyMap() + : this.additionalParameters); return accessTokenResponse; } @@ -204,19 +207,21 @@ public final class OAuth2AccessTokenResponse { } /** - * expires_in is RECOMMENDED, as per spec https://tools.ietf.org/html/rfc6749#section-5.1 - * Therefore, expires_in may not be returned in the Access Token response which would result in the default value of 0. - * For these instances, default the expiresAt to +1 second from issuedAt time. + * expires_in is RECOMMENDED, as per spec + * https://tools.ietf.org/html/rfc6749#section-5.1 Therefore, expires_in may not + * be returned in the Access Token response which would result in the default + * value of 0. For these instances, default the expiresAt to +1 second from + * issuedAt time. * @return */ private Instant getExpiresAt() { if (this.expiresAt == null) { Instant issuedAt = getIssuedAt(); - this.expiresAt = this.expiresIn > 0 ? - issuedAt.plusSeconds(this.expiresIn) : - issuedAt.plusSeconds(1); + this.expiresAt = (this.expiresIn > 0) ? issuedAt.plusSeconds(this.expiresIn) : issuedAt.plusSeconds(1); } return this.expiresAt; } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverter.java index cc023ee4ce..443f03ccee 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import java.time.Instant; @@ -25,27 +26,22 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; /** - * A {@link Converter} that converts the provided {@link OAuth2AccessTokenResponse} - * to a {@code Map} representation of the OAuth 2.0 Access Token Response parameters. + * A {@link Converter} that converts the provided {@link OAuth2AccessTokenResponse} to a + * {@code Map} representation of the OAuth 2.0 Access Token Response parameters. * * @author Joe Grandja * @author Nikita Konev * @since 5.3 */ -public final class OAuth2AccessTokenResponseMapConverter implements Converter> { +public final class OAuth2AccessTokenResponseMapConverter + implements Converter> { @Override public Map convert(OAuth2AccessTokenResponse tokenResponse) { Map parameters = new HashMap<>(); - - long expiresIn = -1; - if (tokenResponse.getAccessToken().getExpiresAt() != null) { - expiresIn = ChronoUnit.SECONDS.between(Instant.now(), tokenResponse.getAccessToken().getExpiresAt()); - } - parameters.put(OAuth2ParameterNames.ACCESS_TOKEN, tokenResponse.getAccessToken().getTokenValue()); parameters.put(OAuth2ParameterNames.TOKEN_TYPE, tokenResponse.getAccessToken().getTokenType().getValue()); - parameters.put(OAuth2ParameterNames.EXPIRES_IN, String.valueOf(expiresIn)); + parameters.put(OAuth2ParameterNames.EXPIRES_IN, String.valueOf(getExpiresIn(tokenResponse))); if (!CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { parameters.put(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(tokenResponse.getAccessToken().getScopes(), " ")); @@ -58,7 +54,14 @@ public final class OAuth2AccessTokenResponseMapConverter implements ConverterSection 4.1.1 Authorization Code Grant Request - * @see Section 4.2.1 Implicit Grant Request + * @see Section 4.1.1 Authorization Code + * Grant Request + * @see Section 4.2.1 Implicit Grant + * Request */ public final class OAuth2AuthorizationRequest implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private String authorizationUri; + private AuthorizationGrantType authorizationGrantType; + private OAuth2AuthorizationResponseType responseType; + private String clientId; + private String redirectUri; + private Set scopes; + private String state; + private Map additionalParameters; + private String authorizationRequestUri; + private Map attributes; private OAuth2AuthorizationRequest() { @@ -67,7 +83,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the uri for the authorization endpoint. - * * @return the uri for the authorization endpoint */ public String getAuthorizationUri() { @@ -76,7 +91,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the {@link AuthorizationGrantType grant type}. - * * @return the {@link AuthorizationGrantType} */ public AuthorizationGrantType getGrantType() { @@ -85,7 +99,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the {@link OAuth2AuthorizationResponseType response type}. - * * @return the {@link OAuth2AuthorizationResponseType} */ public OAuth2AuthorizationResponseType getResponseType() { @@ -94,7 +107,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the client identifier. - * * @return the client identifier */ public String getClientId() { @@ -103,7 +115,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the uri for the redirection endpoint. - * * @return the uri for the redirection endpoint */ public String getRedirectUri() { @@ -112,7 +123,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the scope(s). - * * @return the scope(s), or an empty {@code Set} if not available */ public Set getScopes() { @@ -121,7 +131,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the state. - * * @return the state */ public String getState() { @@ -130,8 +139,8 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the additional parameter(s) used in the request. - * - * @return a {@code Map} of the additional parameter(s), or an empty {@code Map} if not available + * @return a {@code Map} of the additional parameter(s), or an empty {@code Map} if + * not available */ public Map getAdditionalParameters() { return this.additionalParameters; @@ -139,9 +148,8 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the attribute(s) associated to the request. - * - * @since 5.2 * @return a {@code Map} of the attribute(s), or an empty {@code Map} if not available + * @since 5.2 */ public Map getAttributes() { return this.attributes; @@ -149,11 +157,11 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns the value of an attribute associated to the request. - * - * @since 5.2 - * @param name the name of the attribute * @param the type of the attribute - * @return the value of the attribute associated to the request, or {@code null} if not available + * @param name the name of the attribute + * @return the value of the attribute associated to the request, or {@code null} if + * not available + * @since 5.2 */ @SuppressWarnings("unchecked") public T getAttribute(String name) { @@ -161,14 +169,15 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * Returns the {@code URI} string representation of the OAuth 2.0 Authorization Request. + * Returns the {@code URI} string representation of the OAuth 2.0 Authorization + * Request. * *

        * NOTE: The {@code URI} string is encoded in the * {@code application/x-www-form-urlencoded} MIME format. - * + * @return the {@code URI} string representation of the OAuth 2.0 Authorization + * Request * @since 5.1 - * @return the {@code URI} string representation of the OAuth 2.0 Authorization Request */ public String getAuthorizationRequestUri() { return this.authorizationRequestUri; @@ -176,7 +185,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns a new {@link Builder}, initialized with the authorization code grant type. - * * @return the {@link Builder} */ public static Builder authorizationCode() { @@ -185,12 +193,12 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Returns a new {@link Builder}, initialized with the implicit grant type. - * - * @deprecated It is not recommended to use the implicit flow - * due to the inherent risks of returning access tokens in an HTTP redirect - * without any confirmation that it has been received by the client. - * @see OAuth 2.0 Implicit Grant * @return the {@link Builder} + * @deprecated It is not recommended to use the implicit flow due to the inherent + * risks of returning access tokens in an HTTP redirect without any confirmation that + * it has been received by the client. + * @see OAuth 2.0 + * Implicit Grant */ @Deprecated public static Builder implicit() { @@ -198,16 +206,16 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * Returns a new {@link Builder}, initialized with the values - * from the provided {@code authorizationRequest}. - * - * @since 5.1 - * @param authorizationRequest the authorization request used for initializing the {@link Builder} + * Returns a new {@link Builder}, initialized with the values from the provided + * {@code authorizationRequest}. + * @param authorizationRequest the authorization request used for initializing the + * {@link Builder} * @return the {@link Builder} + * @since 5.1 */ public static Builder from(OAuth2AuthorizationRequest authorizationRequest) { Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); - + // @formatter:off return new Builder(authorizationRequest.getGrantType()) .authorizationUri(authorizationRequest.getAuthorizationUri()) .clientId(authorizationRequest.getClientId()) @@ -216,24 +224,39 @@ public final class OAuth2AuthorizationRequest implements Serializable { .state(authorizationRequest.getState()) .additionalParameters(authorizationRequest.getAdditionalParameters()) .attributes(authorizationRequest.getAttributes()); + // @formatter:on } /** * A builder for {@link OAuth2AuthorizationRequest}. */ - public static class Builder { + public static final class Builder { + private String authorizationUri; + private AuthorizationGrantType authorizationGrantType; + private OAuth2AuthorizationResponseType responseType; + private String clientId; + private String redirectUri; + private Set scopes; + private String state; + private Map additionalParameters = new LinkedHashMap<>(); - private Consumer> parametersConsumer = params -> {}; + + private Consumer> parametersConsumer = (params) -> { + }; + private Map attributes = new LinkedHashMap<>(); + private String authorizationRequestUri; - private Function authorizationRequestUriFunction = builder -> builder.build(); + + private Function authorizationRequestUriFunction = (builder) -> builder.build(); + private final DefaultUriBuilderFactory uriBuilderFactory; private Builder(AuthorizationGrantType authorizationGrantType) { @@ -241,18 +264,19 @@ public final class OAuth2AuthorizationRequest implements Serializable { this.authorizationGrantType = authorizationGrantType; if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) { this.responseType = OAuth2AuthorizationResponseType.CODE; - } else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) { + } + else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) { this.responseType = OAuth2AuthorizationResponseType.TOKEN; } this.uriBuilderFactory = new DefaultUriBuilderFactory(); // The supplied authorizationUri may contain encoded parameters - // so disable encoding in UriBuilder and instead apply encoding within this builder + // so disable encoding in UriBuilder and instead apply encoding within this + // builder this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE); } /** * Sets the uri for the authorization endpoint. - * * @param authorizationUri the uri for the authorization endpoint * @return the {@link Builder} */ @@ -263,7 +287,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the client identifier. - * * @param clientId the client identifier * @return the {@link Builder} */ @@ -274,7 +297,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the uri for the redirection endpoint. - * * @param redirectUri the uri for the redirection endpoint * @return the {@link Builder} */ @@ -285,7 +307,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the scope(s). - * * @param scope the scope(s) * @return the {@link Builder} */ @@ -298,7 +319,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the scope(s). - * * @param scopes the scope(s) * @return the {@link Builder} */ @@ -309,7 +329,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the state. - * * @param state the state * @return the {@link Builder} */ @@ -320,7 +339,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the additional parameter(s) used in the request. - * * @param additionalParameters the additional parameter(s) used in the request * @return the {@link Builder} */ @@ -334,9 +352,9 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * A {@code Consumer} to be provided access to the additional parameter(s) * allowing the ability to add, replace, or remove. - * + * @param additionalParametersConsumer a {@code Consumer} of the additional + * parameters * @since 5.3 - * @param additionalParametersConsumer a {@code Consumer} of the additional parameters */ public Builder additionalParameters(Consumer> additionalParametersConsumer) { if (additionalParametersConsumer != null) { @@ -346,11 +364,10 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * A {@code Consumer} to be provided access to all the parameters - * allowing the ability to add, replace, or remove. - * - * @since 5.3 + * A {@code Consumer} to be provided access to all the parameters allowing the + * ability to add, replace, or remove. * @param parametersConsumer a {@code Consumer} of all the parameters + * @since 5.3 */ public Builder parameters(Consumer> parametersConsumer) { if (parametersConsumer != null) { @@ -361,10 +378,9 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Sets the attributes associated to the request. - * - * @since 5.2 * @param attributes the attributes associated to the request * @return the {@link Builder} + * @since 5.2 */ public Builder attributes(Map attributes) { if (!CollectionUtils.isEmpty(attributes)) { @@ -374,11 +390,10 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * A {@code Consumer} to be provided access to the attribute(s) - * allowing the ability to add, replace, or remove. - * - * @since 5.3 + * A {@code Consumer} to be provided access to the attribute(s) allowing the + * ability to add, replace, or remove. * @param attributesConsumer a {@code Consumer} of the attribute(s) + * @since 5.3 */ public Builder attributes(Consumer> attributesConsumer) { if (attributesConsumer != null) { @@ -388,15 +403,16 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * Sets the {@code URI} string representation of the OAuth 2.0 Authorization Request. + * Sets the {@code URI} string representation of the OAuth 2.0 Authorization + * Request. * *

        * NOTE: The {@code URI} string is required to be encoded in the * {@code application/x-www-form-urlencoded} MIME format. - * - * @since 5.1 - * @param authorizationRequestUri the {@code URI} string representation of the OAuth 2.0 Authorization Request + * @param authorizationRequestUri the {@code URI} string representation of the + * OAuth 2.0 Authorization Request * @return the {@link Builder} + * @since 5.1 */ public Builder authorizationRequestUri(String authorizationRequestUri) { this.authorizationRequestUri = authorizationRequestUri; @@ -404,11 +420,11 @@ public final class OAuth2AuthorizationRequest implements Serializable { } /** - * A {@code Function} to be provided a {@code UriBuilder} representation - * of the OAuth 2.0 Authorization Request allowing for further customizations. - * + * A {@code Function} to be provided a {@code UriBuilder} representation of the + * OAuth 2.0 Authorization Request allowing for further customizations. + * @param authorizationRequestUriFunction a {@code Function} to be provided a + * {@code UriBuilder} representation of the OAuth 2.0 Authorization Request * @since 5.3 - * @param authorizationRequestUriFunction a {@code Function} to be provided a {@code UriBuilder} representation of the OAuth 2.0 Authorization Request */ public Builder authorizationRequestUri(Function authorizationRequestUriFunction) { if (authorizationRequestUriFunction != null) { @@ -419,7 +435,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * Builds a new {@link OAuth2AuthorizationRequest}. - * * @return a {@link OAuth2AuthorizationRequest} */ public OAuth2AuthorizationRequest build() { @@ -428,7 +443,6 @@ public final class OAuth2AuthorizationRequest implements Serializable { if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) { Assert.hasText(this.redirectUri, "redirectUri cannot be empty"); } - OAuth2AuthorizationRequest authorizationRequest = new OAuth2AuthorizationRequest(); authorizationRequest.authorizationUri = this.authorizationUri; authorizationRequest.authorizationGrantType = this.authorizationGrantType; @@ -437,25 +451,20 @@ public final class OAuth2AuthorizationRequest implements Serializable { authorizationRequest.redirectUri = this.redirectUri; authorizationRequest.state = this.state; authorizationRequest.scopes = Collections.unmodifiableSet( - CollectionUtils.isEmpty(this.scopes) ? - Collections.emptySet() : new LinkedHashSet<>(this.scopes)); + CollectionUtils.isEmpty(this.scopes) ? Collections.emptySet() : new LinkedHashSet<>(this.scopes)); authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters); authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes); - authorizationRequest.authorizationRequestUri = - StringUtils.hasText(this.authorizationRequestUri) ? - this.authorizationRequestUri : this.buildAuthorizationRequestUri(); - + authorizationRequest.authorizationRequestUri = StringUtils.hasText(this.authorizationRequestUri) + ? this.authorizationRequestUri : this.buildAuthorizationRequestUri(); return authorizationRequest; } private String buildAuthorizationRequestUri() { - Map parameters = getParameters(); // Not encoded + Map parameters = getParameters(); // Not encoded this.parametersConsumer.accept(parameters); MultiValueMap queryParams = new LinkedMultiValueMap<>(); - parameters.forEach((k, v) -> queryParams.set( - encodeQueryParam(k), encodeQueryParam(String.valueOf(v)))); // Encoded - UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri) - .queryParams(queryParams); + parameters.forEach((k, v) -> queryParams.set(encodeQueryParam(k), encodeQueryParam(String.valueOf(v)))); // Encoded + UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri).queryParams(queryParams); return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); } @@ -464,8 +473,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue()); parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId); if (!CollectionUtils.isEmpty(this.scopes)) { - parameters.put(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(this.scopes, " ")); + parameters.put(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.scopes, " ")); } if (this.state != null) { parameters.put(OAuth2ParameterNames.STATE, this.state); @@ -481,5 +489,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { private static String encodeQueryParam(String value) { return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8); } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponse.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponse.java index 1e34de0ac4..d0142d046d 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponse.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponse.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import org.springframework.security.oauth2.core.OAuth2Error; @@ -20,17 +21,24 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; /** - * A representation of an OAuth 2.0 Authorization Response for the authorization code grant type. + * A representation of an OAuth 2.0 Authorization Response for the authorization code + * grant type. * * @author Joe Grandja * @since 5.0 * @see OAuth2Error - * @see Section 4.1.2 Authorization Response + * @see Section 4.1.2 Authorization + * Response */ public final class OAuth2AuthorizationResponse { + private String redirectUri; + private String state; + private String code; + private OAuth2Error error; private OAuth2AuthorizationResponse() { @@ -38,7 +46,6 @@ public final class OAuth2AuthorizationResponse { /** * Returns the uri where the response was redirected to. - * * @return the uri where the response was redirected to */ public String getRedirectUri() { @@ -47,7 +54,6 @@ public final class OAuth2AuthorizationResponse { /** * Returns the state. - * * @return the state */ public String getState() { @@ -56,7 +62,6 @@ public final class OAuth2AuthorizationResponse { /** * Returns the authorization code. - * * @return the authorization code */ public String getCode() { @@ -64,18 +69,20 @@ public final class OAuth2AuthorizationResponse { } /** - * Returns the {@link OAuth2Error OAuth 2.0 Error} if the Authorization Request failed, otherwise {@code null}. - * - * @return the {@link OAuth2Error} if the Authorization Request failed, otherwise {@code null} + * Returns the {@link OAuth2Error OAuth 2.0 Error} if the Authorization Request + * failed, otherwise {@code null}. + * @return the {@link OAuth2Error} if the Authorization Request failed, otherwise + * {@code null} */ public OAuth2Error getError() { return this.error; } /** - * Returns {@code true} if the Authorization Request succeeded, otherwise {@code false}. - * - * @return {@code true} if the Authorization Request succeeded, otherwise {@code false} + * Returns {@code true} if the Authorization Request succeeded, otherwise + * {@code false}. + * @return {@code true} if the Authorization Request succeeded, otherwise + * {@code false} */ public boolean statusOk() { return !this.statusError(); @@ -83,7 +90,6 @@ public final class OAuth2AuthorizationResponse { /** * Returns {@code true} if the Authorization Request failed, otherwise {@code false}. - * * @return {@code true} if the Authorization Request failed, otherwise {@code false} */ public boolean statusError() { @@ -92,7 +98,6 @@ public final class OAuth2AuthorizationResponse { /** * Returns a new {@link Builder}, initialized with the authorization code. - * * @param code the authorization code * @return the {@link Builder} */ @@ -103,7 +108,6 @@ public final class OAuth2AuthorizationResponse { /** * Returns a new {@link Builder}, initialized with the error code. - * * @param errorCode the error code * @return the {@link Builder} */ @@ -115,12 +119,18 @@ public final class OAuth2AuthorizationResponse { /** * A builder for {@link OAuth2AuthorizationResponse}. */ - public static class Builder { + public static final class Builder { + private String redirectUri; + private String state; + private String code; + private String errorCode; + private String errorDescription; + private String errorUri; private Builder() { @@ -128,7 +138,6 @@ public final class OAuth2AuthorizationResponse { /** * Sets the uri where the response was redirected to. - * * @param redirectUri the uri where the response was redirected to * @return the {@link Builder} */ @@ -139,7 +148,6 @@ public final class OAuth2AuthorizationResponse { /** * Sets the state. - * * @param state the state * @return the {@link Builder} */ @@ -150,7 +158,6 @@ public final class OAuth2AuthorizationResponse { /** * Sets the authorization code. - * * @param code the authorization code * @return the {@link Builder} */ @@ -161,7 +168,6 @@ public final class OAuth2AuthorizationResponse { /** * Sets the error code. - * * @param errorCode the error code * @return the {@link Builder} */ @@ -172,7 +178,6 @@ public final class OAuth2AuthorizationResponse { /** * Sets the error description. - * * @param errorDescription the error description * @return the {@link Builder} */ @@ -183,7 +188,6 @@ public final class OAuth2AuthorizationResponse { /** * Sets the error uri. - * * @param errorUri the error uri * @return the {@link Builder} */ @@ -194,7 +198,6 @@ public final class OAuth2AuthorizationResponse { /** * Builds a new {@link OAuth2AuthorizationResponse}. - * * @return a {@link OAuth2AuthorizationResponse} */ public OAuth2AuthorizationResponse build() { @@ -202,17 +205,18 @@ public final class OAuth2AuthorizationResponse { throw new IllegalArgumentException("code and errorCode cannot both be set"); } Assert.hasText(this.redirectUri, "redirectUri cannot be empty"); - OAuth2AuthorizationResponse authorizationResponse = new OAuth2AuthorizationResponse(); authorizationResponse.redirectUri = this.redirectUri; authorizationResponse.state = this.state; if (StringUtils.hasText(this.code)) { authorizationResponse.code = this.code; - } else { - authorizationResponse.error = new OAuth2Error( - this.errorCode, this.errorDescription, this.errorUri); + } + else { + authorizationResponse.error = new OAuth2Error(this.errorCode, this.errorDescription, this.errorUri); } return authorizationResponse; } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseType.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseType.java index ffcdb4329c..4415429058 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseType.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseType.java @@ -13,30 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; +import java.io.Serializable; + import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.io.Serializable; - /** - * The {@code response_type} parameter is consumed by the authorization endpoint which - * is used by the authorization code grant type and implicit grant type. - * The client sets the {@code response_type} parameter with the desired grant type before initiating the authorization request. + * The {@code response_type} parameter is consumed by the authorization endpoint which is + * used by the authorization code grant type and implicit grant type. The client sets the + * {@code response_type} parameter with the desired grant type before initiating the + * authorization request. * *

        - * The {@code response_type} parameter value may be one of "code" for requesting an authorization code or - * "token" for requesting an access token (implicit grant). - + * The {@code response_type} parameter value may be one of "code" for requesting + * an authorization code or "token" for requesting an access token (implicit + * grant). + * * @author Joe Grandja * @since 5.0 - * @see Section 3.1.1 Response Type + * @see Section 3.1.1 Response Type */ public final class OAuth2AuthorizationResponseType implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + public static final OAuth2AuthorizationResponseType CODE = new OAuth2AuthorizationResponseType("code"); + public static final OAuth2AuthorizationResponseType TOKEN = new OAuth2AuthorizationResponseType("token"); + private final String value; private OAuth2AuthorizationResponseType(String value) { @@ -46,7 +54,6 @@ public final class OAuth2AuthorizationResponseType implements Serializable { /** * Returns the value of the authorization response type. - * * @return the value of the authorization response type */ public String getValue() { @@ -69,4 +76,5 @@ public final class OAuth2AuthorizationResponseType implements Serializable { public int hashCode() { return this.getValue().hashCode(); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java index 3bb5b2e910..6368b33fd4 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; /** - * Standard and custom (non-standard) parameter names defined in the OAuth Parameters Registry - * and used by the authorization endpoint and token endpoint. + * Standard and custom (non-standard) parameter names defined in the OAuth Parameters + * Registry and used by the authorization endpoint and token endpoint. * * @author Joe Grandja * @since 5.0 - * @see 11.2 OAuth Parameters Registry + * @see 11.2 + * OAuth Parameters Registry */ public interface OAuth2ParameterNames { @@ -51,7 +53,8 @@ public interface OAuth2ParameterNames { String REDIRECT_URI = "redirect_uri"; /** - * {@code scope} - used in Authorization Request, Authorization Response, Access Token Request and Access Token Response. + * {@code scope} - used in Authorization Request, Authorization Response, Access Token + * Request and Access Token Response. */ String SCOPE = "scope"; @@ -101,7 +104,8 @@ public interface OAuth2ParameterNames { String ERROR = "error"; /** - * {@code error_description} - used in Authorization Response and Access Token Response. + * {@code error_description} - used in Authorization Response and Access Token + * Response. */ String ERROR_DESCRIPTION = "error_description"; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/PkceParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/PkceParameterNames.java index d258a55874..873e7efe54 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/PkceParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/PkceParameterNames.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; /** - * Standard parameter names defined in the OAuth Parameters Registry - * and used by the authorization endpoint and token endpoint. + * Standard parameter names defined in the OAuth Parameters Registry and used by the + * authorization endpoint and token endpoint. * * @author Stephen Doxsee * @author Kevin Bolduc * @since 5.2 - * @see 6.1 OAuth Parameters Registry + * @see 6.1 + * OAuth Parameters Registry */ public interface PkceParameterNames { @@ -40,4 +42,5 @@ public interface PkceParameterNames { * {@code code_verifier} - used in Token Request. */ String CODE_VERIFIER = "code_verifier"; + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/package-info.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/package-info.java index 6c214c76ff..0980dfb4db 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/package-info.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/package-info.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Support classes that model the OAuth 2.0 Request and Response messages - * from the Authorization Endpoint and Token Endpoint. + * Support classes that model the OAuth 2.0 Request and Response messages from the + * Authorization Endpoint and Token Endpoint. */ package org.springframework.security.oauth2.core.endpoint; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/HttpMessageConverters.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/HttpMessageConverters.java index 799599a125..b95d1e8a63 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/HttpMessageConverters.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/HttpMessageConverters.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.http.converter; import org.springframework.http.converter.GenericHttpMessageConverter; @@ -29,26 +30,35 @@ import org.springframework.util.ClassUtils; * @since 5.1 */ final class HttpMessageConverters { + private static final boolean jackson2Present; + private static final boolean gsonPresent; + private static final boolean jsonbPresent; static { ClassLoader classLoader = HttpMessageConverters.class.getClassLoader(); - jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader) && - ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader); + jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader) + && ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader); gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader); jsonbPresent = ClassUtils.isPresent("javax.json.bind.Jsonb", classLoader); } + private HttpMessageConverters() { + } + static GenericHttpMessageConverter getJsonMessageConverter() { if (jackson2Present) { return new MappingJackson2HttpMessageConverter(); - } else if (gsonPresent) { + } + if (gsonPresent) { return new GsonHttpMessageConverter(); - } else if (jsonbPresent) { + } + if (jsonbPresent) { return new JsonbHttpMessageConverter(); } return null; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverter.java index 71e79119ee..513a14fc82 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.http.converter; import java.nio.charset.Charset; @@ -36,26 +37,27 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon import org.springframework.util.Assert; /** - * A {@link HttpMessageConverter} for an {@link OAuth2AccessTokenResponse OAuth 2.0 Access Token Response}. + * A {@link HttpMessageConverter} for an {@link OAuth2AccessTokenResponse OAuth 2.0 Access + * Token Response}. * - * @see AbstractHttpMessageConverter - * @see OAuth2AccessTokenResponse * @author Joe Grandja * @since 5.1 + * @see AbstractHttpMessageConverter + * @see OAuth2AccessTokenResponse */ -public class OAuth2AccessTokenResponseHttpMessageConverter extends AbstractHttpMessageConverter { +public class OAuth2AccessTokenResponseHttpMessageConverter + extends AbstractHttpMessageConverter { + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; - private static final ParameterizedTypeReference> PARAMETERIZED_RESPONSE_TYPE = - new ParameterizedTypeReference>() {}; + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; private GenericHttpMessageConverter jsonMessageConverter = HttpMessageConverters.getJsonMessageConverter(); - protected Converter, OAuth2AccessTokenResponse> tokenResponseConverter = - new MapOAuth2AccessTokenResponseConverter(); + protected Converter, OAuth2AccessTokenResponse> tokenResponseConverter = new MapOAuth2AccessTokenResponseConverter(); - protected Converter> tokenResponseParametersConverter = - new OAuth2AccessTokenResponseMapConverter(); + protected Converter> tokenResponseParametersConverter = new OAuth2AccessTokenResponseMapConverter(); public OAuth2AccessTokenResponseHttpMessageConverter() { super(DEFAULT_CHARSET, MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); @@ -67,57 +69,63 @@ public class OAuth2AccessTokenResponseHttpMessageConverter extends AbstractHttpM } @Override - protected OAuth2AccessTokenResponse readInternal(Class clazz, HttpInputMessage inputMessage) - throws HttpMessageNotReadableException { - + @SuppressWarnings("unchecked") + protected OAuth2AccessTokenResponse readInternal(Class clazz, + HttpInputMessage inputMessage) throws HttpMessageNotReadableException { try { - // gh-6463 - // Parse parameter values as Object in order to handle potential JSON Object and then convert values to String - @SuppressWarnings("unchecked") - Map tokenResponseParameters = (Map) this.jsonMessageConverter.read( - PARAMETERIZED_RESPONSE_TYPE.getType(), null, inputMessage); - return this.tokenResponseConverter.convert( - tokenResponseParameters.entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - entry -> String.valueOf(entry.getValue())))); - } catch (Exception ex) { - throw new HttpMessageNotReadableException("An error occurred reading the OAuth 2.0 Access Token Response: " + - ex.getMessage(), ex, inputMessage); + // gh-6463: Parse parameter values as Object in order to handle potential JSON + // Object and then convert values to String + Map tokenResponseParameters = (Map) this.jsonMessageConverter + .read(STRING_OBJECT_MAP.getType(), null, inputMessage); + // @formatter:off + return this.tokenResponseConverter.convert(tokenResponseParameters + .entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, (entry) -> String.valueOf(entry.getValue())))); + // @formatter:on + } + catch (Exception ex) { + throw new HttpMessageNotReadableException( + "An error occurred reading the OAuth 2.0 Access Token Response: " + ex.getMessage(), ex, + inputMessage); } } @Override protected void writeInternal(OAuth2AccessTokenResponse tokenResponse, HttpOutputMessage outputMessage) throws HttpMessageNotWritableException { - try { Map tokenResponseParameters = this.tokenResponseParametersConverter.convert(tokenResponse); - this.jsonMessageConverter.write( - tokenResponseParameters, PARAMETERIZED_RESPONSE_TYPE.getType(), MediaType.APPLICATION_JSON, outputMessage); - } catch (Exception ex) { - throw new HttpMessageNotWritableException("An error occurred writing the OAuth 2.0 Access Token Response: " + ex.getMessage(), ex); + this.jsonMessageConverter.write(tokenResponseParameters, STRING_OBJECT_MAP.getType(), + MediaType.APPLICATION_JSON, outputMessage); + } + catch (Exception ex) { + throw new HttpMessageNotWritableException( + "An error occurred writing the OAuth 2.0 Access Token Response: " + ex.getMessage(), ex); } } /** - * Sets the {@link Converter} used for converting the OAuth 2.0 Access Token Response parameters - * to an {@link OAuth2AccessTokenResponse}. - * - * @param tokenResponseConverter the {@link Converter} used for converting to an {@link OAuth2AccessTokenResponse} + * Sets the {@link Converter} used for converting the OAuth 2.0 Access Token Response + * parameters to an {@link OAuth2AccessTokenResponse}. + * @param tokenResponseConverter the {@link Converter} used for converting to an + * {@link OAuth2AccessTokenResponse} */ - public final void setTokenResponseConverter(Converter, OAuth2AccessTokenResponse> tokenResponseConverter) { + public final void setTokenResponseConverter( + Converter, OAuth2AccessTokenResponse> tokenResponseConverter) { Assert.notNull(tokenResponseConverter, "tokenResponseConverter cannot be null"); this.tokenResponseConverter = tokenResponseConverter; } /** - * Sets the {@link Converter} used for converting the {@link OAuth2AccessTokenResponse} - * to a {@code Map} representation of the OAuth 2.0 Access Token Response parameters. - * - * @param tokenResponseParametersConverter the {@link Converter} used for converting to a {@code Map} representation of the Access Token Response parameters + * Sets the {@link Converter} used for converting the + * {@link OAuth2AccessTokenResponse} to a {@code Map} representation of the OAuth 2.0 + * Access Token Response parameters. + * @param tokenResponseParametersConverter the {@link Converter} used for converting + * to a {@code Map} representation of the Access Token Response parameters */ - public final void setTokenResponseParametersConverter(Converter> tokenResponseParametersConverter) { + public final void setTokenResponseParametersConverter( + Converter> tokenResponseParametersConverter) { Assert.notNull(tokenResponseParametersConverter, "tokenResponseParametersConverter cannot be null"); this.tokenResponseParametersConverter = tokenResponseParametersConverter; } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverter.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverter.java index a9901c97d9..aa82f778f9 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverter.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverter.java @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.http.converter; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpInputMessage; @@ -30,25 +37,20 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; - /** * A {@link HttpMessageConverter} for an {@link OAuth2Error OAuth 2.0 Error}. * - * @see AbstractHttpMessageConverter - * @see OAuth2Error * @author Joe Grandja * @since 5.1 + * @see AbstractHttpMessageConverter + * @see OAuth2Error */ public class OAuth2ErrorHttpMessageConverter extends AbstractHttpMessageConverter { + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; - private static final ParameterizedTypeReference> PARAMETERIZED_RESPONSE_TYPE = - new ParameterizedTypeReference>() {}; + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; private GenericHttpMessageConverter jsonMessageConverter = HttpMessageConverters.getJsonMessageConverter(); @@ -66,44 +68,42 @@ public class OAuth2ErrorHttpMessageConverter extends AbstractHttpMessageConverte } @Override + @SuppressWarnings("unchecked") protected OAuth2Error readInternal(Class clazz, HttpInputMessage inputMessage) throws HttpMessageNotReadableException { - try { - // gh-8157 - // Parse parameter values as Object in order to handle potential JSON Object and then convert values to String - @SuppressWarnings("unchecked") - Map errorParameters = (Map) this.jsonMessageConverter.read( - PARAMETERIZED_RESPONSE_TYPE.getType(), null, inputMessage); - return this.errorConverter.convert( - errorParameters.entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - entry -> String.valueOf(entry.getValue())))); - } catch (Exception ex) { - throw new HttpMessageNotReadableException("An error occurred reading the OAuth 2.0 Error: " + - ex.getMessage(), ex, inputMessage); + // gh-8157: Parse parameter values as Object in order to handle potential JSON + // Object and then convert values to String + Map errorParameters = (Map) this.jsonMessageConverter + .read(STRING_OBJECT_MAP.getType(), null, inputMessage); + return this.errorConverter.convert(errorParameters.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, (entry) -> String.valueOf(entry.getValue())))); + } + catch (Exception ex) { + throw new HttpMessageNotReadableException( + "An error occurred reading the OAuth 2.0 Error: " + ex.getMessage(), ex, inputMessage); } } @Override protected void writeInternal(OAuth2Error oauth2Error, HttpOutputMessage outputMessage) throws HttpMessageNotWritableException { - try { Map errorParameters = this.errorParametersConverter.convert(oauth2Error); - this.jsonMessageConverter.write( - errorParameters, PARAMETERIZED_RESPONSE_TYPE.getType(), MediaType.APPLICATION_JSON, outputMessage); - } catch (Exception ex) { - throw new HttpMessageNotWritableException("An error occurred writing the OAuth 2.0 Error: " + ex.getMessage(), ex); + this.jsonMessageConverter.write(errorParameters, STRING_OBJECT_MAP.getType(), MediaType.APPLICATION_JSON, + outputMessage); + } + catch (Exception ex) { + throw new HttpMessageNotWritableException( + "An error occurred writing the OAuth 2.0 Error: " + ex.getMessage(), ex); } } /** - * Sets the {@link Converter} used for converting the OAuth 2.0 Error parameters - * to an {@link OAuth2Error}. - * - * @param errorConverter the {@link Converter} used for converting to an {@link OAuth2Error} + * Sets the {@link Converter} used for converting the OAuth 2.0 Error parameters to an + * {@link OAuth2Error}. + * @param errorConverter the {@link Converter} used for converting to an + * {@link OAuth2Error} */ public final void setErrorConverter(Converter, OAuth2Error> errorConverter) { Assert.notNull(errorConverter, "errorConverter cannot be null"); @@ -111,19 +111,20 @@ public class OAuth2ErrorHttpMessageConverter extends AbstractHttpMessageConverte } /** - * Sets the {@link Converter} used for converting the {@link OAuth2Error} - * to a {@code Map} representation of the OAuth 2.0 Error parameters. - * - * @param errorParametersConverter the {@link Converter} used for converting to a {@code Map} representation of the Error parameters + * Sets the {@link Converter} used for converting the {@link OAuth2Error} to a + * {@code Map} representation of the OAuth 2.0 Error parameters. + * @param errorParametersConverter the {@link Converter} used for converting to a + * {@code Map} representation of the Error parameters */ - public final void setErrorParametersConverter(Converter> errorParametersConverter) { + public final void setErrorParametersConverter( + Converter> errorParametersConverter) { Assert.notNull(errorParametersConverter, "errorParametersConverter cannot be null"); this.errorParametersConverter = errorParametersConverter; } /** - * A {@link Converter} that converts the provided - * OAuth 2.0 Error parameters to an {@link OAuth2Error}. + * A {@link Converter} that converts the provided OAuth 2.0 Error parameters to an + * {@link OAuth2Error}. */ private static class OAuth2ErrorConverter implements Converter, OAuth2Error> { @@ -132,21 +133,20 @@ public class OAuth2ErrorHttpMessageConverter extends AbstractHttpMessageConverte String errorCode = parameters.get(OAuth2ParameterNames.ERROR); String errorDescription = parameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION); String errorUri = parameters.get(OAuth2ParameterNames.ERROR_URI); - return new OAuth2Error(errorCode, errorDescription, errorUri); } + } /** - * A {@link Converter} that converts the provided {@link OAuth2Error} - * to a {@code Map} representation of OAuth 2.0 Error parameters. + * A {@link Converter} that converts the provided {@link OAuth2Error} to a {@code Map} + * representation of OAuth 2.0 Error parameters. */ private static class OAuth2ErrorParametersConverter implements Converter> { @Override public Map convert(OAuth2Error oauth2Error) { Map parameters = new HashMap<>(); - parameters.put(OAuth2ParameterNames.ERROR, oauth2Error.getErrorCode()); if (StringUtils.hasText(oauth2Error.getDescription())) { parameters.put(OAuth2ParameterNames.ERROR_DESCRIPTION, oauth2Error.getDescription()); @@ -154,8 +154,9 @@ public class OAuth2ErrorHttpMessageConverter extends AbstractHttpMessageConverte if (StringUtils.hasText(oauth2Error.getUri())) { parameters.put(OAuth2ParameterNames.ERROR_URI, oauth2Error.getUri()); } - return parameters; } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/AddressStandardClaim.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/AddressStandardClaim.java index 47037a3127..ebce9e3b9a 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/AddressStandardClaim.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/AddressStandardClaim.java @@ -13,58 +13,59 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; /** - * The Address Claim represents a physical mailing address defined by the OpenID Connect Core 1.0 specification - * that can be returned either in the UserInfo Response or the ID Token. + * The Address Claim represents a physical mailing address defined by the OpenID Connect + * Core 1.0 specification that can be returned either in the UserInfo Response or the ID + * Token. * * @author Joe Grandja * @since 5.0 - * @see Address Claim - * @see UserInfo Response - * @see ID Token + * @see Address Claim + * @see UserInfo + * Response + * @see ID Token */ public interface AddressStandardClaim { /** * Returns the full mailing address, formatted for display. - * * @return the full mailing address */ String getFormatted(); /** - * Returns the full street address, which may include house number, street name, P.O. Box, etc. - * + * Returns the full street address, which may include house number, street name, P.O. + * Box, etc. * @return the full street address */ String getStreetAddress(); /** * Returns the city or locality. - * * @return the city or locality */ String getLocality(); /** * Returns the state, province, prefecture, or region. - * * @return the state, province, prefecture, or region */ String getRegion(); /** * Returns the zip code or postal code. - * * @return the zip code or postal code */ String getPostalCode(); /** * Returns the country. - * * @return the country */ String getCountry(); diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaim.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaim.java index 48e2fc8eb5..5d6a59cf97 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaim.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaim.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; import java.util.Map; @@ -25,11 +26,17 @@ import java.util.Map; * @see AddressStandardClaim */ public final class DefaultAddressStandardClaim implements AddressStandardClaim { + private String formatted; + private String streetAddress; + private String locality; + private String region; + private String postalCode; + private String country; private DefaultAddressStandardClaim() { @@ -73,35 +80,37 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { if (obj == null || !AddressStandardClaim.class.isAssignableFrom(obj.getClass())) { return false; } - - AddressStandardClaim that = (AddressStandardClaim) obj; - - if (this.getFormatted() != null ? !this.getFormatted().equals(that.getFormatted()) : that.getFormatted() != null) { + AddressStandardClaim other = (AddressStandardClaim) obj; + if ((this.getFormatted() != null) ? !this.getFormatted().equals(other.getFormatted()) + : other.getFormatted() != null) { return false; } - if (this.getStreetAddress() != null ? !this.getStreetAddress().equals(that.getStreetAddress()) : that.getStreetAddress() != null) { + if ((this.getStreetAddress() != null) ? !this.getStreetAddress().equals(other.getStreetAddress()) + : other.getStreetAddress() != null) { return false; } - if (this.getLocality() != null ? !this.getLocality().equals(that.getLocality()) : that.getLocality() != null) { + if ((this.getLocality() != null) ? !this.getLocality().equals(other.getLocality()) + : other.getLocality() != null) { return false; } - if (this.getRegion() != null ? !this.getRegion().equals(that.getRegion()) : that.getRegion() != null) { + if ((this.getRegion() != null) ? !this.getRegion().equals(other.getRegion()) : other.getRegion() != null) { return false; } - if (this.getPostalCode() != null ? !this.getPostalCode().equals(that.getPostalCode()) : that.getPostalCode() != null) { + if ((this.getPostalCode() != null) ? !this.getPostalCode().equals(other.getPostalCode()) + : other.getPostalCode() != null) { return false; } - return this.getCountry() != null ? this.getCountry().equals(that.getCountry()) : that.getCountry() == null; + return (this.getCountry() != null) ? this.getCountry().equals(other.getCountry()) : other.getCountry() == null; } @Override public int hashCode() { - int result = this.getFormatted() != null ? this.getFormatted().hashCode() : 0; - result = 31 * result + (this.getStreetAddress() != null ? this.getStreetAddress().hashCode() : 0); - result = 31 * result + (this.getLocality() != null ? this.getLocality().hashCode() : 0); - result = 31 * result + (this.getRegion() != null ? this.getRegion().hashCode() : 0); - result = 31 * result + (this.getPostalCode() != null ? this.getPostalCode().hashCode() : 0); - result = 31 * result + (this.getCountry() != null ? this.getCountry().hashCode() : 0); + int result = (this.getFormatted() != null) ? this.getFormatted().hashCode() : 0; + result = 31 * result + ((this.getStreetAddress() != null) ? this.getStreetAddress().hashCode() : 0); + result = 31 * result + ((this.getLocality() != null) ? this.getLocality().hashCode() : 0); + result = 31 * result + ((this.getRegion() != null) ? this.getRegion().hashCode() : 0); + result = 31 * result + ((this.getPostalCode() != null) ? this.getPostalCode().hashCode() : 0); + result = 31 * result + ((this.getCountry() != null) ? this.getCountry().hashCode() : 0); return result; } @@ -109,17 +118,29 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { * A builder for {@link DefaultAddressStandardClaim}. */ public static class Builder { + private static final String FORMATTED_FIELD_NAME = "formatted"; + private static final String STREET_ADDRESS_FIELD_NAME = "street_address"; + private static final String LOCALITY_FIELD_NAME = "locality"; + private static final String REGION_FIELD_NAME = "region"; + private static final String POSTAL_CODE_FIELD_NAME = "postal_code"; + private static final String COUNTRY_FIELD_NAME = "country"; + private String formatted; + private String streetAddress; + private String locality; + private String region; + private String postalCode; + private String country; /** @@ -129,8 +150,8 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { } /** - * Constructs and initializes the address attributes using the provided {@code addressFields}. - * + * Constructs and initializes the address attributes using the provided + * {@code addressFields}. * @param addressFields the fields used to initialize the address attributes */ public Builder(Map addressFields) { @@ -144,7 +165,6 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { /** * Sets the full mailing address, formatted for display. - * * @param formatted the full mailing address * @return the {@link Builder} */ @@ -154,8 +174,8 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { } /** - * Sets the full street address, which may include house number, street name, P.O. Box, etc. - * + * Sets the full street address, which may include house number, street name, P.O. + * Box, etc. * @param streetAddress the full street address * @return the {@link Builder} */ @@ -166,7 +186,6 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { /** * Sets the city or locality. - * * @param locality the city or locality * @return the {@link Builder} */ @@ -177,7 +196,6 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { /** * Sets the state, province, prefecture, or region. - * * @param region the state, province, prefecture, or region * @return the {@link Builder} */ @@ -188,7 +206,6 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { /** * Sets the zip code or postal code. - * * @param postalCode the zip code or postal code * @return the {@link Builder} */ @@ -199,7 +216,6 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { /** * Sets the country. - * * @param country the country * @return the {@link Builder} */ @@ -210,7 +226,6 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { /** * Builds a new {@link DefaultAddressStandardClaim}. - * * @return a {@link AddressStandardClaim} */ public AddressStandardClaim build() { @@ -221,8 +236,9 @@ public final class DefaultAddressStandardClaim implements AddressStandardClaim { address.region = this.region; address.postalCode = this.postalCode; address.country = this.country; - return address; } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimAccessor.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimAccessor.java index 0170f933ec..f229848460 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimAccessor.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimAccessor.java @@ -13,33 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc; -import org.springframework.security.oauth2.core.ClaimAccessor; +package org.springframework.security.oauth2.core.oidc; import java.net.URL; import java.time.Instant; import java.util.List; +import org.springframework.security.oauth2.core.ClaimAccessor; + /** - * A {@link ClaimAccessor} for the "claims" that can be returned in the ID Token, - * which provides information about the authentication of an End-User by an Authorization Server. + * A {@link ClaimAccessor} for the "claims" that can be returned in the ID + * Token, which provides information about the authentication of an End-User by an + * Authorization Server. * + * @author Joe Grandja + * @since 5.0 * @see ClaimAccessor * @see StandardClaimAccessor * @see StandardClaimNames * @see IdTokenClaimNames * @see OidcIdToken - * @see ID Token - * @see Standard Claims - * @author Joe Grandja - * @since 5.0 + * @see ID Token + * @see Standard + * Claims */ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Issuer identifier {@code (iss)}. - * * @return the Issuer identifier */ default URL getIssuer() { @@ -48,16 +52,15 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Subject identifier {@code (sub)}. - * * @return the Subject identifier */ + @Override default String getSubject() { return this.getClaimAsString(IdTokenClaimNames.SUB); } /** * Returns the Audience(s) {@code (aud)} that this ID Token is intended for. - * * @return the Audience(s) that this ID Token is intended for */ default List getAudience() { @@ -65,8 +68,8 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { } /** - * Returns the Expiration time {@code (exp)} on or after which the ID Token MUST NOT be accepted. - * + * Returns the Expiration time {@code (exp)} on or after which the ID Token MUST NOT + * be accepted. * @return the Expiration time on or after which the ID Token MUST NOT be accepted */ default Instant getExpiresAt() { @@ -75,7 +78,6 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the time at which the ID Token was issued {@code (iat)}. - * * @return the time at which the ID Token was issued */ default Instant getIssuedAt() { @@ -84,7 +86,6 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the time when the End-User authentication occurred {@code (auth_time)}. - * * @return the time when the End-User authentication occurred */ default Instant getAuthenticatedAt() { @@ -92,9 +93,8 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { } /** - * Returns a {@code String} value {@code (nonce)} used to associate a Client session with an ID Token, - * and to mitigate replay attacks. - * + * Returns a {@code String} value {@code (nonce)} used to associate a Client session + * with an ID Token, and to mitigate replay attacks. * @return the nonce used to associate a Client session with an ID Token */ default String getNonce() { @@ -103,7 +103,6 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Authentication Context Class Reference {@code (acr)}. - * * @return the Authentication Context Class Reference */ default String getAuthenticationContextClass() { @@ -112,7 +111,6 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Authentication Methods References {@code (amr)}. - * * @return the Authentication Methods References */ default List getAuthenticationMethods() { @@ -121,7 +119,6 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Authorized party {@code (azp)} to which the ID Token was issued. - * * @return the Authorized party to which the ID Token was issued */ default String getAuthorizedParty() { @@ -130,7 +127,6 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Access Token hash value {@code (at_hash)}. - * * @return the Access Token hash value */ default String getAccessTokenHash() { @@ -139,10 +135,10 @@ public interface IdTokenClaimAccessor extends StandardClaimAccessor { /** * Returns the Authorization Code hash value {@code (c_hash)}. - * * @return the Authorization Code hash value */ default String getAuthorizationCodeHash() { return this.getClaimAsString(IdTokenClaimNames.C_HASH); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimNames.java index c73b604f3d..76e9765ab7 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/IdTokenClaimNames.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; /** - * The names of the "claims" defined by the OpenID Connect Core 1.0 specification - * that can be returned in the ID Token. + * The names of the "claims" defined by the OpenID Connect Core 1.0 + * specification that can be returned in the ID Token. * * @author Joe Grandja * @since 5.0 * @see OidcIdToken - * @see ID Token + * @see ID Token */ public interface IdTokenClaimNames { @@ -43,7 +45,8 @@ public interface IdTokenClaimNames { String AUD = "aud"; /** - * {@code exp} - the Expiration time on or after which the ID Token MUST NOT be accepted + * {@code exp} - the Expiration time on or after which the ID Token MUST NOT be + * accepted */ String EXP = "exp"; @@ -58,8 +61,8 @@ public interface IdTokenClaimNames { String AUTH_TIME = "auth_time"; /** - * {@code nonce} - a {@code String} value used to associate a Client session with an ID Token, - * and to mitigate replay attacks. + * {@code nonce} - a {@code String} value used to associate a Client session with an + * ID Token, and to mitigate replay attacks. */ String NONCE = "nonce"; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcIdToken.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcIdToken.java index 5f3740dc43..87f72cd353 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcIdToken.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcIdToken.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; import java.time.Instant; @@ -26,43 +27,35 @@ import java.util.function.Consumer; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.ACR; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.AMR; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.AT_HASH; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.AUD; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.AUTH_TIME; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.AZP; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.C_HASH; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.EXP; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.IAT; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.ISS; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.NONCE; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.SUB; - /** - * An implementation of an {@link AbstractOAuth2Token} representing an OpenID Connect Core 1.0 ID Token. + * An implementation of an {@link AbstractOAuth2Token} representing an OpenID Connect Core + * 1.0 ID Token. * *

        - * The {@code OidcIdToken} is a security token that contains "claims" - * about the authentication of an End-User by an Authorization Server. + * The {@code OidcIdToken} is a security token that contains "claims" about the + * authentication of an End-User by an Authorization Server. * * @author Joe Grandja * @since 5.0 * @see AbstractOAuth2Token * @see IdTokenClaimAccessor * @see StandardClaimAccessor - * @see ID Token - * @see Standard Claims + * @see ID Token + * @see Standard + * Claims */ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAccessor { + private final Map claims; /** * Constructs a {@code OidcIdToken} using the provided parameters. - * * @param tokenValue the ID Token value * @param issuedAt the time at which the ID Token was issued {@code (iat)} - * @param expiresAt the expiration time {@code (exp)} on or after which the ID Token MUST NOT be accepted + * @param expiresAt the expiration time {@code (exp)} on or after which the ID Token + * MUST NOT be accepted * @param claims the claims about the authentication of the End-User */ public OidcIdToken(String tokenValue, Instant issuedAt, Instant expiresAt, Map claims) { @@ -78,7 +71,6 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce /** * Create a {@link Builder} based on the given token value - * * @param tokenValue the token value to use * @return the {@link Builder} for further configuration * @since 5.3 @@ -94,7 +86,9 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce * @since 5.3 */ public static final class Builder { + private String tokenValue; + private final Map claims = new LinkedHashMap<>(); private Builder(String tokenValue) { @@ -103,7 +97,6 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce /** * Use this token value in the resulting {@link OidcIdToken} - * * @param tokenValue The token value to use * @return the {@link Builder} for further configurations */ @@ -114,7 +107,6 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce /** * Use this claim in the resulting {@link OidcIdToken} - * * @param name The claim name * @param value The claim value * @return the {@link Builder} for further configurations @@ -125,8 +117,8 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce } /** - * Provides access to every {@link #claim(String, Object)} - * declared so far with the possibility to add, replace, or remove. + * Provides access to every {@link #claim(String, Object)} declared so far with + * the possibility to add, replace, or remove. * @param claimsConsumer the consumer * @return the {@link Builder} for further configurations */ @@ -137,132 +129,121 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce /** * Use this access token hash in the resulting {@link OidcIdToken} - * * @param accessTokenHash The access token hash to use * @return the {@link Builder} for further configurations */ public Builder accessTokenHash(String accessTokenHash) { - return claim(AT_HASH, accessTokenHash); + return claim(IdTokenClaimNames.AT_HASH, accessTokenHash); } /** * Use this audience in the resulting {@link OidcIdToken} - * * @param audience The audience(s) to use * @return the {@link Builder} for further configurations */ public Builder audience(Collection audience) { - return claim(AUD, audience); + return claim(IdTokenClaimNames.AUD, audience); } /** * Use this authentication {@link Instant} in the resulting {@link OidcIdToken} - * * @param authenticatedAt The authentication {@link Instant} to use * @return the {@link Builder} for further configurations */ public Builder authTime(Instant authenticatedAt) { - return claim(AUTH_TIME, authenticatedAt); + return claim(IdTokenClaimNames.AUTH_TIME, authenticatedAt); } /** - * Use this authentication context class reference in the resulting {@link OidcIdToken} - * - * @param authenticationContextClass The authentication context class reference to use + * Use this authentication context class reference in the resulting + * {@link OidcIdToken} + * @param authenticationContextClass The authentication context class reference to + * use * @return the {@link Builder} for further configurations */ public Builder authenticationContextClass(String authenticationContextClass) { - return claim(ACR, authenticationContextClass); + return claim(IdTokenClaimNames.ACR, authenticationContextClass); } /** * Use these authentication methods in the resulting {@link OidcIdToken} - * * @param authenticationMethods The authentication methods to use * @return the {@link Builder} for further configurations */ public Builder authenticationMethods(List authenticationMethods) { - return claim(AMR, authenticationMethods); + return claim(IdTokenClaimNames.AMR, authenticationMethods); } /** * Use this authorization code hash in the resulting {@link OidcIdToken} - * * @param authorizationCodeHash The authorization code hash to use * @return the {@link Builder} for further configurations */ public Builder authorizationCodeHash(String authorizationCodeHash) { - return claim(C_HASH, authorizationCodeHash); + return claim(IdTokenClaimNames.C_HASH, authorizationCodeHash); } /** * Use this authorized party in the resulting {@link OidcIdToken} - * * @param authorizedParty The authorized party to use * @return the {@link Builder} for further configurations */ public Builder authorizedParty(String authorizedParty) { - return claim(AZP, authorizedParty); + return claim(IdTokenClaimNames.AZP, authorizedParty); } /** * Use this expiration in the resulting {@link OidcIdToken} - * * @param expiresAt The expiration to use * @return the {@link Builder} for further configurations */ public Builder expiresAt(Instant expiresAt) { - return this.claim(EXP, expiresAt); + return this.claim(IdTokenClaimNames.EXP, expiresAt); } /** * Use this issued-at timestamp in the resulting {@link OidcIdToken} - * * @param issuedAt The issued-at timestamp to use * @return the {@link Builder} for further configurations */ public Builder issuedAt(Instant issuedAt) { - return this.claim(IAT, issuedAt); + return this.claim(IdTokenClaimNames.IAT, issuedAt); } /** * Use this issuer in the resulting {@link OidcIdToken} - * * @param issuer The issuer to use * @return the {@link Builder} for further configurations */ public Builder issuer(String issuer) { - return this.claim(ISS, issuer); + return this.claim(IdTokenClaimNames.ISS, issuer); } /** * Use this nonce in the resulting {@link OidcIdToken} - * * @param nonce The nonce to use * @return the {@link Builder} for further configurations */ public Builder nonce(String nonce) { - return this.claim(NONCE, nonce); + return this.claim(IdTokenClaimNames.NONCE, nonce); } /** * Use this subject in the resulting {@link OidcIdToken} - * * @param subject The subject to use * @return the {@link Builder} for further configurations */ public Builder subject(String subject) { - return this.claim(SUB, subject); + return this.claim(IdTokenClaimNames.SUB, subject); } /** * Build the {@link OidcIdToken} - * * @return The constructed {@link OidcIdToken} */ public OidcIdToken build() { - Instant iat = toInstant(this.claims.get(IAT)); - Instant exp = toInstant(this.claims.get(EXP)); + Instant iat = toInstant(this.claims.get(IdTokenClaimNames.IAT)); + Instant exp = toInstant(this.claims.get(IdTokenClaimNames.EXP)); return new OidcIdToken(this.tokenValue, iat, exp, this.claims); } @@ -272,5 +253,7 @@ public class OidcIdToken extends AbstractOAuth2Token implements IdTokenClaimAcce } return (Instant) timestamp; } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcScopes.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcScopes.java index e8b70c757a..083266e591 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcScopes.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcScopes.java @@ -13,22 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; import org.springframework.security.oauth2.core.OAuth2AccessToken; /** - * The scope values defined by the OpenID Connect Core 1.0 specification - * that can be used to request {@link StandardClaimNames claims}. + * The scope values defined by the OpenID Connect Core 1.0 specification that can be used + * to request {@link StandardClaimNames claims}. *

        - * The scope(s) associated to an {@link OAuth2AccessToken} determine what claims (resources) - * will be available when they are used to access OAuth 2.0 Protected Endpoints, - * such as the UserInfo Endpoint. + * The scope(s) associated to an {@link OAuth2AccessToken} determine what claims + * (resources) will be available when they are used to access OAuth 2.0 Protected + * Endpoints, such as the UserInfo Endpoint. * * @author Joe Grandja * @since 5.0 * @see StandardClaimNames - * @see Requesting Claims using Scope Values + * @see Requesting Claims + * using Scope Values */ public interface OidcScopes { @@ -45,7 +48,8 @@ public interface OidcScopes { String PROFILE = "profile"; /** - * The {@code email} scope requests access to the {@code email} and {@code email_verified} claims. + * The {@code email} scope requests access to the {@code email} and + * {@code email_verified} claims. */ String EMAIL = "email"; @@ -55,7 +59,8 @@ public interface OidcScopes { String ADDRESS = "address"; /** - * The {@code phone} scope requests access to the {@code phone_number} and {@code phone_number_verified} claims. + * The {@code phone} scope requests access to the {@code phone_number} and + * {@code phone_number_verified} claims. */ String PHONE = "phone"; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcUserInfo.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcUserInfo.java index 5de4899ff6..812a03ddd7 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcUserInfo.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/OidcUserInfo.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; import java.io.Serializable; @@ -25,48 +26,34 @@ import java.util.function.Consumer; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.ADDRESS; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.BIRTHDATE; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.EMAIL; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.EMAIL_VERIFIED; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.FAMILY_NAME; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.GENDER; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.GIVEN_NAME; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.LOCALE; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.MIDDLE_NAME; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.NAME; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.NICKNAME; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.PHONE_NUMBER; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.PHONE_NUMBER_VERIFIED; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.PICTURE; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.PREFERRED_USERNAME; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.PROFILE; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.SUB; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.UPDATED_AT; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.WEBSITE; -import static org.springframework.security.oauth2.core.oidc.StandardClaimNames.ZONEINFO; - /** - * A representation of a UserInfo Response that is returned - * from the OAuth 2.0 Protected Resource UserInfo Endpoint. + * A representation of a UserInfo Response that is returned from the OAuth 2.0 Protected + * Resource UserInfo Endpoint. * *

        - * The {@code OidcUserInfo} contains a set of "Standard Claims" about the authentication of an End-User. + * The {@code OidcUserInfo} contains a set of "Standard Claims" about the + * authentication of an End-User. * * @author Joe Grandja * @since 5.0 * @see StandardClaimAccessor - * @see UserInfo Response - * @see UserInfo Endpoint - * @see Standard Claims + * @see UserInfo + * Response + * @see UserInfo Endpoint + * @see Standard + * Claims */ public class OidcUserInfo implements StandardClaimAccessor, Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final Map claims; /** * Constructs a {@code OidcUserInfo} using the provided parameters. - * * @param claims the claims about the authentication of the End-User */ public OidcUserInfo(Map claims) { @@ -87,9 +74,7 @@ public class OidcUserInfo implements StandardClaimAccessor, Serializable { if (obj == null || this.getClass() != obj.getClass()) { return false; } - OidcUserInfo that = (OidcUserInfo) obj; - return this.getClaims().equals(that.getClaims()); } @@ -100,7 +85,6 @@ public class OidcUserInfo implements StandardClaimAccessor, Serializable { /** * Create a {@link Builder} - * * @return the {@link Builder} for further configuration * @since 5.3 */ @@ -115,13 +99,14 @@ public class OidcUserInfo implements StandardClaimAccessor, Serializable { * @since 5.3 */ public static final class Builder { + private final Map claims = new LinkedHashMap<>(); - private Builder() {} + private Builder() { + } /** * Use this claim in the resulting {@link OidcUserInfo} - * * @param name The claim name * @param value The claim value * @return the {@link Builder} for further configurations @@ -132,8 +117,8 @@ public class OidcUserInfo implements StandardClaimAccessor, Serializable { } /** - * Provides access to every {@link #claim(String, Object)} - * declared so far with the possibility to add, replace, or remove. + * Provides access to every {@link #claim(String, Object)} declared so far with + * the possibility to add, replace, or remove. * @param claimsConsumer the consumer * @return the {@link Builder} for further configurations */ @@ -144,211 +129,192 @@ public class OidcUserInfo implements StandardClaimAccessor, Serializable { /** * Use this address in the resulting {@link OidcUserInfo} - * * @param address The address to use * @return the {@link Builder} for further configurations */ public Builder address(String address) { - return this.claim(ADDRESS, address); + return this.claim(StandardClaimNames.ADDRESS, address); } /** * Use this birthdate in the resulting {@link OidcUserInfo} - * * @param birthdate The birthdate to use * @return the {@link Builder} for further configurations */ public Builder birthdate(String birthdate) { - return this.claim(BIRTHDATE, birthdate); + return this.claim(StandardClaimNames.BIRTHDATE, birthdate); } /** * Use this email in the resulting {@link OidcUserInfo} - * * @param email The email to use * @return the {@link Builder} for further configurations */ public Builder email(String email) { - return this.claim(EMAIL, email); + return this.claim(StandardClaimNames.EMAIL, email); } /** * Use this verified-email indicator in the resulting {@link OidcUserInfo} - * * @param emailVerified The verified-email indicator to use * @return the {@link Builder} for further configurations */ public Builder emailVerified(Boolean emailVerified) { - return this.claim(EMAIL_VERIFIED, emailVerified); + return this.claim(StandardClaimNames.EMAIL_VERIFIED, emailVerified); } /** * Use this family name in the resulting {@link OidcUserInfo} - * * @param familyName The family name to use * @return the {@link Builder} for further configurations */ public Builder familyName(String familyName) { - return claim(FAMILY_NAME, familyName); + return claim(StandardClaimNames.FAMILY_NAME, familyName); } /** * Use this gender in the resulting {@link OidcUserInfo} - * * @param gender The gender to use * @return the {@link Builder} for further configurations */ public Builder gender(String gender) { - return this.claim(GENDER, gender); + return this.claim(StandardClaimNames.GENDER, gender); } /** * Use this given name in the resulting {@link OidcUserInfo} - * * @param givenName The given name to use * @return the {@link Builder} for further configurations */ public Builder givenName(String givenName) { - return claim(GIVEN_NAME, givenName); + return claim(StandardClaimNames.GIVEN_NAME, givenName); } /** * Use this locale in the resulting {@link OidcUserInfo} - * * @param locale The locale to use * @return the {@link Builder} for further configurations */ public Builder locale(String locale) { - return this.claim(LOCALE, locale); + return this.claim(StandardClaimNames.LOCALE, locale); } /** * Use this middle name in the resulting {@link OidcUserInfo} - * * @param middleName The middle name to use * @return the {@link Builder} for further configurations */ public Builder middleName(String middleName) { - return claim(MIDDLE_NAME, middleName); + return claim(StandardClaimNames.MIDDLE_NAME, middleName); } /** * Use this name in the resulting {@link OidcUserInfo} - * * @param name The name to use * @return the {@link Builder} for further configurations */ public Builder name(String name) { - return claim(NAME, name); + return claim(StandardClaimNames.NAME, name); } /** * Use this nickname in the resulting {@link OidcUserInfo} - * * @param nickname The nickname to use * @return the {@link Builder} for further configurations */ public Builder nickname(String nickname) { - return claim(NICKNAME, nickname); + return claim(StandardClaimNames.NICKNAME, nickname); } /** * Use this picture in the resulting {@link OidcUserInfo} - * * @param picture The picture to use * @return the {@link Builder} for further configurations */ public Builder picture(String picture) { - return this.claim(PICTURE, picture); + return this.claim(StandardClaimNames.PICTURE, picture); } /** * Use this phone number in the resulting {@link OidcUserInfo} - * * @param phoneNumber The phone number to use * @return the {@link Builder} for further configurations */ public Builder phoneNumber(String phoneNumber) { - return this.claim(PHONE_NUMBER, phoneNumber); + return this.claim(StandardClaimNames.PHONE_NUMBER, phoneNumber); } /** * Use this verified-phone-number indicator in the resulting {@link OidcUserInfo} - * * @param phoneNumberVerified The verified-phone-number indicator to use * @return the {@link Builder} for further configurations */ public Builder phoneNumberVerified(String phoneNumberVerified) { - return this.claim(PHONE_NUMBER_VERIFIED, phoneNumberVerified); + return this.claim(StandardClaimNames.PHONE_NUMBER_VERIFIED, phoneNumberVerified); } /** * Use this preferred username in the resulting {@link OidcUserInfo} - * * @param preferredUsername The preferred username to use * @return the {@link Builder} for further configurations */ public Builder preferredUsername(String preferredUsername) { - return claim(PREFERRED_USERNAME, preferredUsername); + return claim(StandardClaimNames.PREFERRED_USERNAME, preferredUsername); } /** * Use this profile in the resulting {@link OidcUserInfo} - * * @param profile The profile to use * @return the {@link Builder} for further configurations */ public Builder profile(String profile) { - return claim(PROFILE, profile); + return claim(StandardClaimNames.PROFILE, profile); } /** * Use this subject in the resulting {@link OidcUserInfo} - * * @param subject The subject to use * @return the {@link Builder} for further configurations */ public Builder subject(String subject) { - return this.claim(SUB, subject); + return this.claim(StandardClaimNames.SUB, subject); } /** * Use this updated-at {@link Instant} in the resulting {@link OidcUserInfo} - * * @param updatedAt The updated-at {@link Instant} to use * @return the {@link Builder} for further configurations */ public Builder updatedAt(String updatedAt) { - return this.claim(UPDATED_AT, updatedAt); + return this.claim(StandardClaimNames.UPDATED_AT, updatedAt); } /** * Use this website in the resulting {@link OidcUserInfo} - * * @param website The website to use * @return the {@link Builder} for further configurations */ public Builder website(String website) { - return this.claim(WEBSITE, website); + return this.claim(StandardClaimNames.WEBSITE, website); } /** * Use this zoneinfo in the resulting {@link OidcUserInfo} - * * @param zoneinfo The zoneinfo to use * @return the {@link Builder} for further configurations */ public Builder zoneinfo(String zoneinfo) { - return this.claim(ZONEINFO, zoneinfo); + return this.claim(StandardClaimNames.ZONEINFO, zoneinfo); } /** * Build the {@link OidcUserInfo} - * * @return The constructed {@link OidcUserInfo} */ public OidcUserInfo build() { return new OidcUserInfo(this.claims); } + } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimAccessor.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimAccessor.java index 455e0f8f36..03abf96893 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimAccessor.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimAccessor.java @@ -13,31 +13,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc; -import org.springframework.security.oauth2.core.ClaimAccessor; -import org.springframework.util.CollectionUtils; +package org.springframework.security.oauth2.core.oidc; import java.time.Instant; import java.util.Map; +import org.springframework.security.oauth2.core.ClaimAccessor; +import org.springframework.util.CollectionUtils; + /** - * A {@link ClaimAccessor} for the "Standard Claims" that can be returned - * either in the UserInfo Response or the ID Token. + * A {@link ClaimAccessor} for the "Standard Claims" that can be returned either + * in the UserInfo Response or the ID Token. * + * @author Joe Grandja + * @since 5.0 * @see ClaimAccessor * @see StandardClaimNames * @see OidcUserInfo - * @see UserInfo Response - * @see Standard Claims - * @author Joe Grandja - * @since 5.0 + * @see UserInfo + * Response + * @see Standard + * Claims */ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the Subject identifier {@code (sub)}. - * * @return the Subject identifier */ default String getSubject() { @@ -46,7 +50,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's full name {@code (name)} in displayable form. - * * @return the user's full name */ default String getFullName() { @@ -55,7 +58,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's given name(s) or first name(s) {@code (given_name)}. - * * @return the user's given name(s) */ default String getGivenName() { @@ -64,7 +66,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's surname(s) or last name(s) {@code (family_name)}. - * * @return the user's family names(s) */ default String getFamilyName() { @@ -73,7 +74,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's middle name(s) {@code (middle_name)}. - * * @return the user's middle name(s) */ default String getMiddleName() { @@ -81,8 +81,8 @@ public interface StandardClaimAccessor extends ClaimAccessor { } /** - * Returns the user's nick name {@code (nickname)} that may or may not be the same as the {@code (given_name)}. - * + * Returns the user's nick name {@code (nickname)} that may or may not be the same as + * the {@code (given_name)}. * @return the user's nick name */ default String getNickName() { @@ -90,8 +90,8 @@ public interface StandardClaimAccessor extends ClaimAccessor { } /** - * Returns the preferred username {@code (preferred_username)} that the user wishes to be referred to. - * + * Returns the preferred username {@code (preferred_username)} that the user wishes to + * be referred to. * @return the user's preferred user name */ default String getPreferredUsername() { @@ -100,7 +100,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the URL of the user's profile page {@code (profile)}. - * * @return the URL of the user's profile page */ default String getProfile() { @@ -109,7 +108,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the URL of the user's profile picture {@code (picture)}. - * * @return the URL of the user's profile picture */ default String getPicture() { @@ -118,7 +116,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the URL of the user's web page or blog {@code (website)}. - * * @return the URL of the user's web page or blog */ default String getWebsite() { @@ -127,7 +124,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's preferred e-mail address {@code (email)}. - * * @return the user's preferred e-mail address */ default String getEmail() { @@ -135,9 +131,10 @@ public interface StandardClaimAccessor extends ClaimAccessor { } /** - * Returns {@code true} if the user's e-mail address has been verified {@code (email_verified)}, otherwise {@code false}. - * - * @return {@code true} if the user's e-mail address has been verified, otherwise {@code false} + * Returns {@code true} if the user's e-mail address has been verified + * {@code (email_verified)}, otherwise {@code false}. + * @return {@code true} if the user's e-mail address has been verified, otherwise + * {@code false} */ default Boolean getEmailVerified() { return this.getClaimAsBoolean(StandardClaimNames.EMAIL_VERIFIED); @@ -145,7 +142,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's gender {@code (gender)}. - * * @return the user's gender */ default String getGender() { @@ -154,7 +150,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's birth date {@code (birthdate)}. - * * @return the user's birth date */ default String getBirthdate() { @@ -163,7 +158,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's time zone {@code (zoneinfo)}. - * * @return the user's time zone */ default String getZoneInfo() { @@ -172,7 +166,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's locale {@code (locale)}. - * * @return the user's locale */ default String getLocale() { @@ -181,7 +174,6 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's preferred phone number {@code (phone_number)}. - * * @return the user's preferred phone number */ default String getPhoneNumber() { @@ -189,9 +181,10 @@ public interface StandardClaimAccessor extends ClaimAccessor { } /** - * Returns {@code true} if the user's phone number has been verified {@code (phone_number_verified)}, otherwise {@code false}. - * - * @return {@code true} if the user's phone number has been verified, otherwise {@code false} + * Returns {@code true} if the user's phone number has been verified + * {@code (phone_number_verified)}, otherwise {@code false}. + * @return {@code true} if the user's phone number has been verified, otherwise + * {@code false} */ default Boolean getPhoneNumberVerified() { return this.getClaimAsBoolean(StandardClaimNames.PHONE_NUMBER_VERIFIED); @@ -199,22 +192,20 @@ public interface StandardClaimAccessor extends ClaimAccessor { /** * Returns the user's preferred postal address {@code (address)}. - * * @return the user's preferred postal address */ default AddressStandardClaim getAddress() { Map addressFields = this.getClaimAsMap(StandardClaimNames.ADDRESS); - return (!CollectionUtils.isEmpty(addressFields) ? - new DefaultAddressStandardClaim.Builder(addressFields).build() : - new DefaultAddressStandardClaim.Builder().build()); + return (!CollectionUtils.isEmpty(addressFields) ? new DefaultAddressStandardClaim.Builder(addressFields).build() + : new DefaultAddressStandardClaim.Builder().build()); } /** * Returns the time the user's information was last updated {@code (updated_at)}. - * * @return the time the user's information was last updated */ default Instant getUpdatedAt() { return this.getClaimAsInstant(StandardClaimNames.UPDATED_AT); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimNames.java index e57b4df7a0..93c4d806da 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/StandardClaimNames.java @@ -13,17 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc; /** - * The names of the "Standard Claims" defined by the OpenID Connect Core 1.0 specification - * that can be returned either in the UserInfo Response or the ID Token. + * The names of the "Standard Claims" defined by the OpenID Connect Core 1.0 + * specification that can be returned either in the UserInfo Response or the ID Token. * * @author Joe Grandja * @since 5.0 - * @see Standard Claims - * @see UserInfo Response - * @see ID Token + * @see Standard + * Claims + * @see UserInfo + * Response + * @see ID Token */ public interface StandardClaimNames { @@ -53,12 +59,14 @@ public interface StandardClaimNames { String MIDDLE_NAME = "middle_name"; /** - * {@code nickname} - the user's nick name that may or may not be the same as the {@code given_name} + * {@code nickname} - the user's nick name that may or may not be the same as the + * {@code given_name} */ String NICKNAME = "nickname"; /** - * {@code preferred_username} - the preferred username that the user wishes to be referred to + * {@code preferred_username} - the preferred username that the user wishes to be + * referred to */ String PREFERRED_USERNAME = "preferred_username"; @@ -83,7 +91,8 @@ public interface StandardClaimNames { String EMAIL = "email"; /** - * {@code email_verified} - {@code true} if the user's e-mail address has been verified, otherwise {@code false} + * {@code email_verified} - {@code true} if the user's e-mail address has been + * verified, otherwise {@code false} */ String EMAIL_VERIFIED = "email_verified"; @@ -113,7 +122,8 @@ public interface StandardClaimNames { String PHONE_NUMBER = "phone_number"; /** - * {@code phone_number_verified} - {@code true} if the user's phone number has been verified, otherwise {@code false} + * {@code phone_number_verified} - {@code true} if the user's phone number has been + * verified, otherwise {@code false} */ String PHONE_NUMBER_VERIFIED = "phone_number_verified"; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/OidcParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/OidcParameterNames.java index 827c4557b2..44ca88ac60 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/OidcParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/OidcParameterNames.java @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc.endpoint; /** - * Standard parameter names defined in the OAuth Parameters Registry - * and used by the authorization endpoint and token endpoint. + * Standard parameter names defined in the OAuth Parameters Registry and used by the + * authorization endpoint and token endpoint. * * @author Joe Grandja * @author Mark Heckler * @since 5.0 - * @see 18.2 OAuth Parameters Registration + * @see 18.2 + * OAuth Parameters Registration */ public interface OidcParameterNames { diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/package-info.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/package-info.java index f96f4788b7..afcacdf7bc 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/package-info.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/endpoint/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Support classes that model the OpenID Connect Core 1.0 Request and Response messages * from the Authorization Endpoint and Token Endpoint. diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/package-info.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/package-info.java index 001ba47283..76f861e9a2 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/package-info.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Core classes and interfaces providing support for OpenID Connect Core 1.0. */ diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUser.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUser.java index b162ba85ae..2266fcf0e1 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUser.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUser.java @@ -16,21 +16,21 @@ package org.springframework.security.oauth2.core.oidc.user; +import java.util.Collection; +import java.util.Map; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; -import java.util.Collection; -import java.util.Map; - /** * The default implementation of an {@link OidcUser}. * *

        - * The default claim used for accessing the "name" of the - * user {@code Principal} from {@link #getClaims()} is {@link IdTokenClaimNames#SUB}. + * The default claim used for accessing the "name" of the user {@code Principal} + * from {@link #getClaims()} is {@link IdTokenClaimNames#SUB}. * * @author Joe Grandja * @author Vedran Pavic @@ -41,12 +41,13 @@ import java.util.Map; * @see OidcUserInfo */ public class DefaultOidcUser extends DefaultOAuth2User implements OidcUser { + private final OidcIdToken idToken; + private final OidcUserInfo userInfo; /** * Constructs a {@code DefaultOidcUser} using the provided parameters. - * * @param authorities the authorities granted to the user * @param idToken the {@link OidcIdToken ID Token} containing claims about the user */ @@ -56,36 +57,39 @@ public class DefaultOidcUser extends DefaultOAuth2User implements OidcUser { /** * Constructs a {@code DefaultOidcUser} using the provided parameters. - * * @param authorities the authorities granted to the user * @param idToken the {@link OidcIdToken ID Token} containing claims about the user - * @param nameAttributeKey the key used to access the user's "name" from {@link #getAttributes()} + * @param nameAttributeKey the key used to access the user's "name" from + * {@link #getAttributes()} */ - public DefaultOidcUser(Collection authorities, OidcIdToken idToken, String nameAttributeKey) { + public DefaultOidcUser(Collection authorities, OidcIdToken idToken, + String nameAttributeKey) { this(authorities, idToken, null, nameAttributeKey); } /** * Constructs a {@code DefaultOidcUser} using the provided parameters. - * * @param authorities the authorities granted to the user * @param idToken the {@link OidcIdToken ID Token} containing claims about the user - * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, may be {@code null} + * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, + * may be {@code null} */ - public DefaultOidcUser(Collection authorities, OidcIdToken idToken, OidcUserInfo userInfo) { + public DefaultOidcUser(Collection authorities, OidcIdToken idToken, + OidcUserInfo userInfo) { this(authorities, idToken, userInfo, IdTokenClaimNames.SUB); } /** * Constructs a {@code DefaultOidcUser} using the provided parameters. - * * @param authorities the authorities granted to the user * @param idToken the {@link OidcIdToken ID Token} containing claims about the user - * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, may be {@code null} - * @param nameAttributeKey the key used to access the user's "name" from {@link #getAttributes()} + * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, + * may be {@code null} + * @param nameAttributeKey the key used to access the user's "name" from + * {@link #getAttributes()} */ public DefaultOidcUser(Collection authorities, OidcIdToken idToken, - OidcUserInfo userInfo, String nameAttributeKey) { + OidcUserInfo userInfo, String nameAttributeKey) { super(authorities, OidcUserAuthority.collectClaims(idToken, userInfo), nameAttributeKey); this.idToken = idToken; this.userInfo = userInfo; @@ -105,4 +109,5 @@ public class DefaultOidcUser extends DefaultOAuth2User implements OidcUser { public OidcUserInfo getUserInfo() { return this.userInfo; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUser.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUser.java index 15b3014061..30bf20891a 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUser.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUser.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc.user; +import java.util.Map; + import org.springframework.security.core.AuthenticatedPrincipal; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.oidc.IdTokenClaimAccessor; @@ -23,20 +26,19 @@ import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimAccessor; import org.springframework.security.oauth2.core.user.OAuth2User; -import java.util.Map; - /** - * A representation of a user {@code Principal} - * that is registered with an OpenID Connect 1.0 Provider. + * A representation of a user {@code Principal} that is registered with an OpenID Connect + * 1.0 Provider. * *

        - * An {@code OidcUser} contains "claims" about the authentication of the End-User. - * The claims are aggregated from the {@link OidcIdToken} and the {@link OidcUserInfo} (if available). + * An {@code OidcUser} contains "claims" about the authentication of the + * End-User. The claims are aggregated from the {@link OidcIdToken} and the + * {@link OidcUserInfo} (if available). * *

        * Implementation instances of this interface represent an {@link AuthenticatedPrincipal} - * which is associated to an {@link Authentication} object - * and may be accessed via {@link Authentication#getPrincipal()}. + * which is associated to an {@link Authentication} object and may be accessed via + * {@link Authentication#getPrincipal()}. * * @author Joe Grandja * @since 5.0 @@ -46,30 +48,32 @@ import java.util.Map; * @see OidcUserInfo * @see IdTokenClaimAccessor * @see StandardClaimAccessor - * @see ID Token - * @see Standard Claims + * @see ID Token + * @see Standard + * Claims */ public interface OidcUser extends OAuth2User, IdTokenClaimAccessor { /** - * Returns the claims about the user. - * The claims are aggregated from {@link #getIdToken()} and {@link #getUserInfo()} (if available). - * + * Returns the claims about the user. The claims are aggregated from + * {@link #getIdToken()} and {@link #getUserInfo()} (if available). * @return a {@code Map} of claims about the user */ + @Override Map getClaims(); /** * Returns the {@link OidcUserInfo UserInfo} containing claims about the user. - * * @return the {@link OidcUserInfo} containing claims about the user. */ OidcUserInfo getUserInfo(); /** * Returns the {@link OidcIdToken ID Token} containing claims about the user. - * * @return the {@link OidcIdToken} containing claims about the user. */ OidcIdToken getIdToken(); + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthority.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthority.java index 2d215b1326..73bcdf624d 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthority.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthority.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.oidc.user; +import java.util.HashMap; +import java.util.Map; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.util.Assert; -import java.util.HashMap; -import java.util.Map; - /** * A {@link GrantedAuthority} that may be associated to an {@link OidcUser}. * @@ -32,12 +33,13 @@ import java.util.Map; * @see OidcUser */ public class OidcUserAuthority extends OAuth2UserAuthority { + private final OidcIdToken idToken; + private final OidcUserInfo userInfo; /** * Constructs a {@code OidcUserAuthority} using the provided parameters. - * * @param idToken the {@link OidcIdToken ID Token} containing claims about the user */ public OidcUserAuthority(OidcIdToken idToken) { @@ -45,11 +47,11 @@ public class OidcUserAuthority extends OAuth2UserAuthority { } /** - * Constructs a {@code OidcUserAuthority} using the provided parameters - * and defaults {@link #getAuthority()} to {@code ROLE_USER}. - * + * Constructs a {@code OidcUserAuthority} using the provided parameters and defaults + * {@link #getAuthority()} to {@code ROLE_USER}. * @param idToken the {@link OidcIdToken ID Token} containing claims about the user - * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, may be {@code null} + * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, + * may be {@code null} */ public OidcUserAuthority(OidcIdToken idToken, OidcUserInfo userInfo) { this("ROLE_USER", idToken, userInfo); @@ -57,10 +59,10 @@ public class OidcUserAuthority extends OAuth2UserAuthority { /** * Constructs a {@code OidcUserAuthority} using the provided parameters. - * * @param authority the authority granted to the user * @param idToken the {@link OidcIdToken ID Token} containing claims about the user - * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, may be {@code null} + * @param userInfo the {@link OidcUserInfo UserInfo} containing claims about the user, + * may be {@code null} */ public OidcUserAuthority(String authority, OidcIdToken idToken, OidcUserInfo userInfo) { super(authority, collectClaims(idToken, userInfo)); @@ -70,7 +72,6 @@ public class OidcUserAuthority extends OAuth2UserAuthority { /** * Returns the {@link OidcIdToken ID Token} containing claims about the user. - * * @return the {@link OidcIdToken} containing claims about the user. */ public OidcIdToken getIdToken() { @@ -78,8 +79,8 @@ public class OidcUserAuthority extends OAuth2UserAuthority { } /** - * Returns the {@link OidcUserInfo UserInfo} containing claims about the user, may be {@code null}. - * + * Returns the {@link OidcUserInfo UserInfo} containing claims about the user, may be + * {@code null}. * @return the {@link OidcUserInfo} containing claims about the user, or {@code null} */ public OidcUserInfo getUserInfo() { @@ -97,22 +98,19 @@ public class OidcUserAuthority extends OAuth2UserAuthority { if (!super.equals(obj)) { return false; } - OidcUserAuthority that = (OidcUserAuthority) obj; - if (!this.getIdToken().equals(that.getIdToken())) { return false; } - return this.getUserInfo() != null ? - this.getUserInfo().equals(that.getUserInfo()) : - that.getUserInfo() == null; + return (this.getUserInfo() != null) ? this.getUserInfo().equals(that.getUserInfo()) + : that.getUserInfo() == null; } @Override public int hashCode() { int result = super.hashCode(); result = 31 * result + this.getIdToken().hashCode(); - result = 31 * result + (this.getUserInfo() != null ? this.getUserInfo().hashCode() : 0); + result = 31 * result + ((this.getUserInfo() != null) ? this.getUserInfo().hashCode() : 0); return result; } @@ -125,4 +123,5 @@ public class OidcUserAuthority extends OAuth2UserAuthority { claims.putAll(idToken.getClaims()); return claims; } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/package-info.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/package-info.java index 3af0286c26..29bab7b734 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/package-info.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/oidc/user/package-info.java @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Provides a model for an OpenID Connect Core 1.0 representation of a user {@code Principal}. + * Provides a model for an OpenID Connect Core 1.0 representation of a user + * {@code Principal}. */ package org.springframework.security.oauth2.core.oidc.user; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/package-info.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/package-info.java index 77190dd447..4801f42b4e 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/package-info.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/package-info.java @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Core classes and interfaces providing support for the OAuth 2.0 Authorization Framework. + * Core classes and interfaces providing support for the OAuth 2.0 Authorization + * Framework. */ package org.springframework.security.oauth2.core; diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/DefaultOAuth2User.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/DefaultOAuth2User.java index 7657d6b1b7..31fb080f50 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/DefaultOAuth2User.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/DefaultOAuth2User.java @@ -13,53 +13,58 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.user; +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; -import java.io.Serializable; -import java.util.Map; -import java.util.Set; -import java.util.Collections; -import java.util.Collection; -import java.util.LinkedHashMap; -import java.util.TreeSet; -import java.util.SortedSet; -import java.util.Comparator; -import java.util.LinkedHashSet; - /** * The default implementation of an {@link OAuth2User}. * *

        - * User attribute names are not standardized between providers - * and therefore it is required to supply the key - * for the user's "name" attribute to one of the constructors. - * The key will be used for accessing the "name" of the - * {@code Principal} (user) via {@link #getAttributes()} - * and returning it from {@link #getName()}. + * User attribute names are not standardized between providers and therefore it is + * required to supply the key for the user's "name" attribute to one of + * the constructors. The key will be used for accessing the "name" of the + * {@code Principal} (user) via {@link #getAttributes()} and returning it from + * {@link #getName()}. * * @author Joe Grandja * @author Eddú Meléndez - * @see OAuth2User * @since 5.0 + * @see OAuth2User */ public class DefaultOAuth2User implements OAuth2User, Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final Set authorities; + private final Map attributes; + private final String nameAttributeKey; /** * Constructs a {@code DefaultOAuth2User} using the provided parameters. - * - * @param authorities the authorities granted to the user - * @param attributes the attributes about the user - * @param nameAttributeKey the key used to access the user's "name" from {@link #getAttributes()} + * @param authorities the authorities granted to the user + * @param attributes the attributes about the user + * @param nameAttributeKey the key used to access the user's "name" from + * {@link #getAttributes()} */ - public DefaultOAuth2User(Collection authorities, Map attributes, String nameAttributeKey) { + public DefaultOAuth2User(Collection authorities, Map attributes, + String nameAttributeKey) { Assert.notEmpty(authorities, "authorities cannot be empty"); Assert.notEmpty(attributes, "attributes cannot be empty"); Assert.hasText(nameAttributeKey, "nameAttributeKey cannot be empty"); @@ -87,8 +92,8 @@ public class DefaultOAuth2User implements OAuth2User, Serializable { } private Set sortAuthorities(Collection authorities) { - SortedSet sortedAuthorities = - new TreeSet<>(Comparator.comparing(GrantedAuthority::getAuthority)); + SortedSet sortedAuthorities = new TreeSet<>( + Comparator.comparing(GrantedAuthority::getAuthority)); sortedAuthorities.addAll(authorities); return sortedAuthorities; } @@ -101,9 +106,7 @@ public class DefaultOAuth2User implements OAuth2User, Serializable { if (obj == null || this.getClass() != obj.getClass()) { return false; } - DefaultOAuth2User that = (DefaultOAuth2User) obj; - if (!this.getName().equals(that.getName())) { return false; } @@ -133,4 +136,5 @@ public class DefaultOAuth2User implements OAuth2User, Serializable { sb.append("]"); return sb.toString(); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2User.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2User.java index f25db96c01..f8eda9ad34 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2User.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2User.java @@ -13,29 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.user; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; /** - * A representation of a user {@code Principal} - * that is registered with an OAuth 2.0 Provider. + * A representation of a user {@code Principal} that is registered with an OAuth 2.0 + * Provider. * *

        - * An OAuth 2.0 user is composed of one or more attributes, for example, - * first name, middle name, last name, email, phone number, address, etc. - * Each user attribute has a "name" and "value" and - * is keyed by the "name" in {@link #getAttributes()}. + * An OAuth 2.0 user is composed of one or more attributes, for example, first name, + * middle name, last name, email, phone number, address, etc. Each user attribute has a + * "name" and "value" and is keyed by the "name" in + * {@link #getAttributes()}. * *

        - * NOTE: Attribute names are not standardized between providers and therefore will vary. - * Please consult the provider's API documentation for the set of supported user attribute names. + * NOTE: Attribute names are not standardized between providers and + * therefore will vary. Please consult the provider's API documentation for the set of + * supported user attribute names. * *

        - * Implementation instances of this interface represent an {@link OAuth2AuthenticatedPrincipal} - * which is associated to an {@link Authentication} object - * and may be accessed via {@link Authentication#getPrincipal()}. + * Implementation instances of this interface represent an + * {@link OAuth2AuthenticatedPrincipal} which is associated to an {@link Authentication} + * object and may be accessed via {@link Authentication#getPrincipal()}. * * @author Joe Grandja * @author Eddú Meléndez diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthority.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthority.java index 161fc1f64c..ead74cbaff 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthority.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthority.java @@ -13,16 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.user; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.SpringSecurityCoreVersion; -import org.springframework.util.Assert; +package org.springframework.security.oauth2.core.user; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.SpringSecurityCoreVersion; +import org.springframework.util.Assert; + /** * A {@link GrantedAuthority} that may be associated to an {@link OAuth2User}. * @@ -31,14 +32,16 @@ import java.util.Map; * @see OAuth2User */ public class OAuth2UserAuthority implements GrantedAuthority { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + private final String authority; + private final Map attributes; /** - * Constructs a {@code OAuth2UserAuthority} using the provided parameters - * and defaults {@link #getAuthority()} to {@code ROLE_USER}. - * + * Constructs a {@code OAuth2UserAuthority} using the provided parameters and defaults + * {@link #getAuthority()} to {@code ROLE_USER}. * @param attributes the attributes about the user */ public OAuth2UserAuthority(Map attributes) { @@ -47,7 +50,6 @@ public class OAuth2UserAuthority implements GrantedAuthority { /** * Constructs a {@code OAuth2UserAuthority} using the provided parameters. - * * @param authority the authority granted to the user * @param attributes the attributes about the user */ @@ -65,7 +67,6 @@ public class OAuth2UserAuthority implements GrantedAuthority { /** * Returns the attributes about the user. - * * @return a {@code Map} of attributes about the user */ public Map getAttributes() { @@ -80,9 +81,7 @@ public class OAuth2UserAuthority implements GrantedAuthority { if (obj == null || this.getClass() != obj.getClass()) { return false; } - OAuth2UserAuthority that = (OAuth2UserAuthority) obj; - if (!this.getAuthority().equals(that.getAuthority())) { return false; } @@ -100,4 +99,5 @@ public class OAuth2UserAuthority implements GrantedAuthority { public String toString() { return this.getAuthority(); } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/package-info.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/package-info.java index 2f429402a1..6985015692 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/package-info.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/user/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Provides a model for an OAuth 2.0 representation of a user {@code Principal}. */ diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java index b6de896d32..6c3d93c1c0 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java @@ -16,6 +16,12 @@ package org.springframework.security.oauth2.core.web.reactive.function; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + import com.nimbusds.oauth2.sdk.AccessTokenResponse; import com.nimbusds.oauth2.sdk.ErrorObject; import com.nimbusds.oauth2.sdk.ParseException; @@ -23,6 +29,8 @@ import com.nimbusds.oauth2.sdk.TokenErrorResponse; import com.nimbusds.oauth2.sdk.TokenResponse; import com.nimbusds.oauth2.sdk.token.AccessToken; import net.minidev.json.JSONObject; +import reactor.core.publisher.Mono; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -32,16 +40,11 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.BodyExtractors; -import reactor.core.publisher.Mono; - -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; /** - * Provides a way to create an {@link OAuth2AccessTokenResponse} from a {@link ReactiveHttpInputMessage} + * Provides a way to create an {@link OAuth2AccessTokenResponse} from a + * {@link ReactiveHttpInputMessage} + * * @author Rob Winch * @since 5.1 */ @@ -50,16 +53,20 @@ class OAuth2AccessTokenResponseBodyExtractor private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - OAuth2AccessTokenResponseBodyExtractor() {} + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; + + OAuth2AccessTokenResponseBodyExtractor() { + } @Override - public Mono extract(ReactiveHttpInputMessage inputMessage, - Context context) { - ParameterizedTypeReference> type = new ParameterizedTypeReference>() {}; - BodyExtractor>, ReactiveHttpInputMessage> delegate = BodyExtractors.toMono(type); + public Mono extract(ReactiveHttpInputMessage inputMessage, Context context) { + BodyExtractor>, ReactiveHttpInputMessage> delegate = BodyExtractors + .toMono(STRING_OBJECT_MAP); return delegate.extract(inputMessage, context) - .onErrorMap(e -> new OAuth2AuthorizationException( - invalidTokenResponse("An error occurred parsing the Access Token response: " + e.getMessage()), e)) + .onErrorMap((ex) -> new OAuth2AuthorizationException( + invalidTokenResponse("An error occurred parsing the Access Token response: " + ex.getMessage()), + ex)) .switchIfEmpty(Mono.error(() -> new OAuth2AuthorizationException( invalidTokenResponse("Empty OAuth 2.0 Access Token Response")))) .map(OAuth2AccessTokenResponseBodyExtractor::parse) @@ -71,58 +78,52 @@ class OAuth2AccessTokenResponseBodyExtractor try { return TokenResponse.parse(new JSONObject(json)); } - catch (ParseException pe) { + catch (ParseException ex) { OAuth2Error oauth2Error = invalidTokenResponse( - "An error occurred parsing the Access Token response: " + pe.getMessage()); - throw new OAuth2AuthorizationException(oauth2Error, pe); + "An error occurred parsing the Access Token response: " + ex.getMessage()); + throw new OAuth2AuthorizationException(oauth2Error, ex); } } private static OAuth2Error invalidTokenResponse(String message) { - return new OAuth2Error( - INVALID_TOKEN_RESPONSE_ERROR_CODE, - message, - null); + return new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, message, null); } private static Mono oauth2AccessTokenResponse(TokenResponse tokenResponse) { if (tokenResponse.indicatesSuccess()) { - return Mono.just(tokenResponse) - .cast(AccessTokenResponse.class); + return Mono.just(tokenResponse).cast(AccessTokenResponse.class); } TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; ErrorObject errorObject = tokenErrorResponse.getErrorObject(); - OAuth2Error oauth2Error; - if (errorObject == null) { - oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); - } else { - oauth2Error = new OAuth2Error( - errorObject.getCode() != null ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR, - errorObject.getDescription(), - errorObject.getURI() != null ? errorObject.getURI().toString() : null); - } + OAuth2Error oauth2Error = getOAuth2Error(errorObject); return Mono.error(new OAuth2AuthorizationException(oauth2Error)); } + private static OAuth2Error getOAuth2Error(ErrorObject errorObject) { + if (errorObject == null) { + return new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); + } + String code = (errorObject.getCode() != null) ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR; + String description = errorObject.getDescription(); + String uri = (errorObject.getURI() != null) ? errorObject.getURI().toString() : null; + return new OAuth2Error(code, description, uri); + } + private static OAuth2AccessTokenResponse oauth2AccessTokenResponse(AccessTokenResponse accessTokenResponse) { AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken(); OAuth2AccessToken.TokenType accessTokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue() - .equalsIgnoreCase(accessToken.getType().getValue())) { + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(accessToken.getType().getValue())) { accessTokenType = OAuth2AccessToken.TokenType.BEARER; } long expiresIn = accessToken.getLifetime(); - - Set scopes = accessToken.getScope() == null ? - Collections.emptySet() : new LinkedHashSet<>(accessToken.getScope().toStringList()); - + Set scopes = (accessToken.getScope() != null) + ? new LinkedHashSet<>(accessToken.getScope().toStringList()) : Collections.emptySet(); String refreshToken = null; if (accessTokenResponse.getTokens().getRefreshToken() != null) { refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); } - Map additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); - + // @formatter:off return OAuth2AccessTokenResponse.withToken(accessToken.getValue()) .tokenType(accessTokenType) .expiresIn(expiresIn) @@ -130,5 +131,7 @@ class OAuth2AccessTokenResponseBodyExtractor .refreshToken(refreshToken) .additionalParameters(additionalParameters) .build(); + // @formatter:on } + } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java index 4818f7a064..77ed941a0d 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java @@ -16,13 +16,15 @@ package org.springframework.security.oauth2.core.web.reactive.function; +import reactor.core.publisher.Mono; + import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.web.reactive.function.BodyExtractor; -import reactor.core.publisher.Mono; /** * Static factory methods for OAuth2 {@link BodyExtractor} implementations. + * * @author Rob Winch * @since 5.1 */ @@ -36,5 +38,7 @@ public abstract class OAuth2BodyExtractors { return new OAuth2AccessTokenResponseBodyExtractor(); } - private OAuth2BodyExtractors() {} + private OAuth2BodyExtractors() { + } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthenticationMethodTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthenticationMethodTests.java index 3a0f698659..2b4af49903 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthenticationMethodTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthenticationMethodTests.java @@ -13,12 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + /** * Tests for {@link AuthenticationMethod}. * @@ -28,7 +30,8 @@ public class AuthenticationMethodTests { @Test public void constructorWhenValueIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthenticationMethod(null)).hasMessage("value cannot be empty"); + assertThatIllegalArgumentException().isThrownBy(() -> new AuthenticationMethod(null)) + .withMessage("value cannot be empty"); } @Test @@ -45,4 +48,5 @@ public class AuthenticationMethodTests { public void getValueWhenFormAuthenticationTypeThenReturnQuery() { assertThat(AuthenticationMethod.QUERY.getValue()).isEqualTo("query"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthorizationGrantTypeTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthorizationGrantTypeTests.java index 8205e162ca..d8b5aa6cd1 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthorizationGrantTypeTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/AuthorizationGrantTypeTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import org.junit.Test; @@ -50,4 +51,5 @@ public class AuthorizationGrantTypeTests { public void getValueWhenPasswordGrantTypeThenReturnPassword() { assertThat(AuthorizationGrantType.PASSWORD.getValue()).isEqualTo("password"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClaimAccessorTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClaimAccessorTests.java index 29fa8154bd..e0c1524dc7 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClaimAccessorTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClaimAccessorTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core; -import org.junit.Before; -import org.junit.Test; +package org.springframework.security.oauth2.core; import java.time.Instant; import java.util.Arrays; @@ -25,8 +23,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Before; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.catchThrowable; +import static org.assertj.core.api.Assertions.assertThatObject; /** * Tests for {@link ClaimAccessor}. @@ -34,7 +35,9 @@ import static org.assertj.core.api.Assertions.catchThrowable; * @author Joe Grandja */ public class ClaimAccessorTests { + private Map claims = new HashMap<>(); + private ClaimAccessor claimAccessor = (() -> this.claims); @Before @@ -48,9 +51,8 @@ public class ClaimAccessorTests { Instant expectedClaimValue = Instant.now(); String claimName = "date"; this.claims.put(claimName, Date.from(expectedClaimValue)); - - assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween( - expectedClaimValue.minusSeconds(1), expectedClaimValue.plusSeconds(1)); + assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween(expectedClaimValue.minusSeconds(1), + expectedClaimValue.plusSeconds(1)); } // gh-5191 @@ -59,9 +61,8 @@ public class ClaimAccessorTests { Instant expectedClaimValue = Instant.now(); String claimName = "longSeconds"; this.claims.put(claimName, expectedClaimValue.getEpochSecond()); - - assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween( - expectedClaimValue.minusSeconds(1), expectedClaimValue.plusSeconds(1)); + assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween(expectedClaimValue.minusSeconds(1), + expectedClaimValue.plusSeconds(1)); } @Test @@ -69,9 +70,8 @@ public class ClaimAccessorTests { Instant expectedClaimValue = Instant.now(); String claimName = "instant"; this.claims.put(claimName, expectedClaimValue); - - assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween( - expectedClaimValue.minusSeconds(1), expectedClaimValue.plusSeconds(1)); + assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween(expectedClaimValue.minusSeconds(1), + expectedClaimValue.plusSeconds(1)); } // gh-5250 @@ -80,9 +80,8 @@ public class ClaimAccessorTests { Instant expectedClaimValue = Instant.now(); String claimName = "integerSeconds"; this.claims.put(claimName, Long.valueOf(expectedClaimValue.getEpochSecond()).intValue()); - - assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween( - expectedClaimValue.minusSeconds(1), expectedClaimValue.plusSeconds(1)); + assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween(expectedClaimValue.minusSeconds(1), + expectedClaimValue.plusSeconds(1)); } // gh-5250 @@ -91,9 +90,8 @@ public class ClaimAccessorTests { Instant expectedClaimValue = Instant.now(); String claimName = "doubleSeconds"; this.claims.put(claimName, Long.valueOf(expectedClaimValue.getEpochSecond()).doubleValue()); - - assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween( - expectedClaimValue.minusSeconds(1), expectedClaimValue.plusSeconds(1)); + assertThat(this.claimAccessor.getClaimAsInstant(claimName)).isBetween(expectedClaimValue.minusSeconds(1), + expectedClaimValue.plusSeconds(1)); } // gh-5608 @@ -101,7 +99,6 @@ public class ClaimAccessorTests { public void getClaimAsStringWhenValueIsNullThenReturnNull() { String claimName = "claim-with-null-value"; this.claims.put(claimName, null); - assertThat(this.claimAccessor.getClaimAsString(claimName)).isNull(); } @@ -117,9 +114,7 @@ public class ClaimAccessorTests { List expectedClaimValue = Arrays.asList("item1", "item2"); String claimName = "list"; this.claims.put(claimName, expectedClaimValue); - List actualClaimValue = this.claimAccessor.getClaim(claimName); - assertThat(actualClaimValue).containsOnlyElementsOf(expectedClaimValue); } @@ -128,9 +123,7 @@ public class ClaimAccessorTests { boolean expectedClaimValue = true; String claimName = "boolean"; this.claims.put(claimName, expectedClaimValue); - boolean actualClaimValue = this.claimAccessor.getClaim(claimName); - assertThat(actualClaimValue).isEqualTo(expectedClaimValue); } @@ -139,9 +132,7 @@ public class ClaimAccessorTests { String expectedClaimValue = "true"; String claimName = "boolean"; this.claims.put(claimName, expectedClaimValue); - - Throwable thrown = catchThrowable(() -> { boolean actualClaimValue = this.claimAccessor.getClaim(claimName); }); - - assertThat(thrown).isInstanceOf(ClassCastException.class); + assertThatObject(this.claimAccessor.getClaim(claimName)).isNotInstanceOf(Boolean.class); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java index 174388b7ac..e514bf4c2e 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/ClientAuthenticationMethodTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import org.junit.Test; @@ -45,4 +46,5 @@ public class ClientAuthenticationMethodTests { public void getValueWhenAuthenticationMethodNoneThenReturnNone() { assertThat(ClientAuthenticationMethod.NONE.getValue()).isEqualTo("none"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipalTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipalTests.java index 8d794e7a09..5a1940c945 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipalTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DefaultOAuth2AuthenticatedPrincipalTests.java @@ -26,7 +26,7 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultOAuth2AuthenticatedPrincipal} @@ -34,48 +34,50 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Josh Cummings */ public class DefaultOAuth2AuthenticatedPrincipalTests { + String name = "test-subject"; + Map attributes = Collections.singletonMap("sub", this.name); + Collection authorities = AuthorityUtils.createAuthorityList("SCOPE_read"); @Test public void constructorWhenAttributesIsNullOrEmptyThenIllegalArgumentException() { - assertThatCode(() -> new DefaultOAuth2AuthenticatedPrincipal(null, this.authorities)) - .isInstanceOf(IllegalArgumentException.class); - - assertThatCode(() -> new DefaultOAuth2AuthenticatedPrincipal(Collections.emptyMap(), this.authorities)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DefaultOAuth2AuthenticatedPrincipal(null, this.authorities)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DefaultOAuth2AuthenticatedPrincipal(Collections.emptyMap(), this.authorities)); } @Test public void constructorWhenAuthoritiesIsNullOrEmptyThenNoAuthorities() { - Collection authorities = - new DefaultOAuth2AuthenticatedPrincipal(this.attributes, null).getAuthorities(); + Collection authorities = new DefaultOAuth2AuthenticatedPrincipal(this.attributes, + null).getAuthorities(); assertThat(authorities).isEmpty(); - - authorities = new DefaultOAuth2AuthenticatedPrincipal(this.attributes, - Collections.emptyList()).getAuthorities(); + authorities = new DefaultOAuth2AuthenticatedPrincipal(this.attributes, Collections.emptyList()) + .getAuthorities(); assertThat(authorities).isEmpty(); } @Test public void constructorWhenNameIsNullThenFallsbackToSubAttribute() { - OAuth2AuthenticatedPrincipal principal = - new DefaultOAuth2AuthenticatedPrincipal(null, this.attributes, this.authorities); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(null, this.attributes, + this.authorities); assertThat(principal.getName()).isEqualTo(this.attributes.get("sub")); } @Test public void getNameWhenInConstructorThenReturns() { - OAuth2AuthenticatedPrincipal principal = - new DefaultOAuth2AuthenticatedPrincipal("other-subject", this.attributes, this.authorities); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal("other-subject", + this.attributes, this.authorities); assertThat(principal.getName()).isEqualTo("other-subject"); } @Test public void getAttributeWhenGivenKeyThenReturnsValue() { - OAuth2AuthenticatedPrincipal principal = - new DefaultOAuth2AuthenticatedPrincipal(this.attributes, this.authorities); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(this.attributes, + this.authorities); assertThat((String) principal.getAttribute("sub")).isEqualTo("test-subject"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidatorTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidatorTests.java index eb9bc26bbc..ce4fa6b365 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidatorTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/DelegatingOAuth2TokenValidatorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import java.util.Arrays; @@ -21,12 +22,12 @@ import java.util.Collection; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * Tests for verifying {@link DelegatingOAuth2TokenValidator} @@ -34,15 +35,13 @@ import static org.mockito.Mockito.when; * @author Josh Cummings */ public class DelegatingOAuth2TokenValidatorTests { - private static final OAuth2Error DETAIL = new OAuth2Error( - "error", "description", "uri"); + + private static final OAuth2Error DETAIL = new OAuth2Error("error", "description", "uri"); @Test public void validateWhenNoValidatorsConfiguredThenReturnsSuccessfulResult() { - DelegatingOAuth2TokenValidator tokenValidator = - new DelegatingOAuth2TokenValidator<>(); + DelegatingOAuth2TokenValidator tokenValidator = new DelegatingOAuth2TokenValidator<>(); AbstractOAuth2Token token = mock(AbstractOAuth2Token.class); - assertThat(tokenValidator.validate(token).hasErrors()).isFalse(); } @@ -50,19 +49,12 @@ public class DelegatingOAuth2TokenValidatorTests { public void validateWhenAnyValidatorFailsThenReturnsFailureResultContainingDetailFromFailingValidator() { OAuth2TokenValidator success = mock(OAuth2TokenValidator.class); OAuth2TokenValidator failure = mock(OAuth2TokenValidator.class); - - when(success.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.success()); - when(failure.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.failure(DETAIL)); - - DelegatingOAuth2TokenValidator tokenValidator = - new DelegatingOAuth2TokenValidator<>(Arrays.asList(success, failure)); + given(success.validate(any(AbstractOAuth2Token.class))).willReturn(OAuth2TokenValidatorResult.success()); + given(failure.validate(any(AbstractOAuth2Token.class))).willReturn(OAuth2TokenValidatorResult.failure(DETAIL)); + DelegatingOAuth2TokenValidator tokenValidator = new DelegatingOAuth2TokenValidator<>( + Arrays.asList(success, failure)); AbstractOAuth2Token token = mock(AbstractOAuth2Token.class); - - OAuth2TokenValidatorResult result = - tokenValidator.validate(token); - + OAuth2TokenValidatorResult result = tokenValidator.validate(token); assertThat(result.hasErrors()).isTrue(); assertThat(result.getErrors()).containsExactly(DETAIL); } @@ -71,21 +63,15 @@ public class DelegatingOAuth2TokenValidatorTests { public void validateWhenMultipleValidatorsFailThenReturnsFailureResultContainingAllDetails() { OAuth2TokenValidator firstFailure = mock(OAuth2TokenValidator.class); OAuth2TokenValidator secondFailure = mock(OAuth2TokenValidator.class); - OAuth2Error otherDetail = new OAuth2Error("another-error"); - - when(firstFailure.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.failure(DETAIL)); - when(secondFailure.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.failure(otherDetail)); - - DelegatingOAuth2TokenValidator tokenValidator = - new DelegatingOAuth2TokenValidator<>(firstFailure, secondFailure); + given(firstFailure.validate(any(AbstractOAuth2Token.class))) + .willReturn(OAuth2TokenValidatorResult.failure(DETAIL)); + given(secondFailure.validate(any(AbstractOAuth2Token.class))) + .willReturn(OAuth2TokenValidatorResult.failure(otherDetail)); + DelegatingOAuth2TokenValidator tokenValidator = new DelegatingOAuth2TokenValidator<>( + firstFailure, secondFailure); AbstractOAuth2Token token = mock(AbstractOAuth2Token.class); - - OAuth2TokenValidatorResult result = - tokenValidator.validate(token); - + OAuth2TokenValidatorResult result = tokenValidator.validate(token); assertThat(result.hasErrors()).isTrue(); assertThat(result.getErrors()).containsExactly(DETAIL, otherDetail); } @@ -94,51 +80,37 @@ public class DelegatingOAuth2TokenValidatorTests { public void validateWhenAllValidatorsSucceedThenReturnsSuccessfulResult() { OAuth2TokenValidator firstSuccess = mock(OAuth2TokenValidator.class); OAuth2TokenValidator secondSuccess = mock(OAuth2TokenValidator.class); - - when(firstSuccess.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.success()); - when(secondSuccess.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.success()); - - DelegatingOAuth2TokenValidator tokenValidator = - new DelegatingOAuth2TokenValidator<>(Arrays.asList(firstSuccess, secondSuccess)); + given(firstSuccess.validate(any(AbstractOAuth2Token.class))).willReturn(OAuth2TokenValidatorResult.success()); + given(secondSuccess.validate(any(AbstractOAuth2Token.class))).willReturn(OAuth2TokenValidatorResult.success()); + DelegatingOAuth2TokenValidator tokenValidator = new DelegatingOAuth2TokenValidator<>( + Arrays.asList(firstSuccess, secondSuccess)); AbstractOAuth2Token token = mock(AbstractOAuth2Token.class); - - OAuth2TokenValidatorResult result = - tokenValidator.validate(token); - + OAuth2TokenValidatorResult result = tokenValidator.validate(token); assertThat(result.hasErrors()).isFalse(); assertThat(result.getErrors()).isEmpty(); } @Test public void constructorWhenInvokedWithNullValidatorListThenThrowsIllegalArgumentException() { - assertThatCode(() -> new DelegatingOAuth2TokenValidator<> - ((Collection>) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new DelegatingOAuth2TokenValidator<>( + (Collection>) null)); } @Test public void constructorsWhenInvokedWithSameInputsThenResultInSameOutputs() { OAuth2TokenValidator firstSuccess = mock(OAuth2TokenValidator.class); OAuth2TokenValidator secondSuccess = mock(OAuth2TokenValidator.class); - - when(firstSuccess.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.success()); - when(secondSuccess.validate(any(AbstractOAuth2Token.class))) - .thenReturn(OAuth2TokenValidatorResult.success()); - - DelegatingOAuth2TokenValidator firstValidator = - new DelegatingOAuth2TokenValidator<>(Arrays.asList(firstSuccess, secondSuccess)); - DelegatingOAuth2TokenValidator secondValidator = - new DelegatingOAuth2TokenValidator<>(firstSuccess, secondSuccess); - + given(firstSuccess.validate(any(AbstractOAuth2Token.class))).willReturn(OAuth2TokenValidatorResult.success()); + given(secondSuccess.validate(any(AbstractOAuth2Token.class))).willReturn(OAuth2TokenValidatorResult.success()); + DelegatingOAuth2TokenValidator firstValidator = new DelegatingOAuth2TokenValidator<>( + Arrays.asList(firstSuccess, secondSuccess)); + DelegatingOAuth2TokenValidator secondValidator = new DelegatingOAuth2TokenValidator<>( + firstSuccess, secondSuccess); AbstractOAuth2Token token = mock(AbstractOAuth2Token.class); - firstValidator.validate(token); secondValidator.validate(token); - verify(firstSuccess, times(2)).validate(token); verify(secondSuccess, times(2)).validate(token); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2AccessTokenTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2AccessTokenTests.java index ccede95337..984c1b7804 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2AccessTokenTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2AccessTokenTests.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core; -import org.junit.Test; -import org.springframework.util.SerializationUtils; +package org.springframework.security.oauth2.core; import java.time.Instant; import java.util.Arrays; import java.util.LinkedHashSet; import java.util.Set; +import org.junit.Test; + +import org.springframework.util.SerializationUtils; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -31,10 +33,15 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class OAuth2AccessTokenTests { + private static final OAuth2AccessToken.TokenType TOKEN_TYPE = OAuth2AccessToken.TokenType.BEARER; + private static final String TOKEN_VALUE = "access-token"; + private static final Instant ISSUED_AT = Instant.now(); + private static final Instant EXPIRES_AT = Instant.from(ISSUED_AT).plusSeconds(60); + private static final Set SCOPES = new LinkedHashSet<>(Arrays.asList("scope1", "scope2")); @Test @@ -64,9 +71,7 @@ public class OAuth2AccessTokenTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2AccessToken accessToken = new OAuth2AccessToken( - TOKEN_TYPE, TOKEN_VALUE, ISSUED_AT, EXPIRES_AT, SCOPES); - + OAuth2AccessToken accessToken = new OAuth2AccessToken(TOKEN_TYPE, TOKEN_VALUE, ISSUED_AT, EXPIRES_AT, SCOPES); assertThat(accessToken.getTokenType()).isEqualTo(TOKEN_TYPE); assertThat(accessToken.getTokenValue()).isEqualTo(TOKEN_VALUE); assertThat(accessToken.getIssuedAt()).isEqualTo(ISSUED_AT); @@ -77,11 +82,9 @@ public class OAuth2AccessTokenTests { // gh-5492 @Test public void constructorWhenCreatedThenIsSerializableAndDeserializable() { - OAuth2AccessToken accessToken = new OAuth2AccessToken( - TOKEN_TYPE, TOKEN_VALUE, ISSUED_AT, EXPIRES_AT, SCOPES); + OAuth2AccessToken accessToken = new OAuth2AccessToken(TOKEN_TYPE, TOKEN_VALUE, ISSUED_AT, EXPIRES_AT, SCOPES); byte[] serialized = SerializationUtils.serialize(accessToken); accessToken = (OAuth2AccessToken) SerializationUtils.deserialize(serialized); - assertThat(serialized).isNotNull(); assertThat(accessToken.getTokenType()).isEqualTo(TOKEN_TYPE); assertThat(accessToken.getTokenValue()).isEqualTo(TOKEN_VALUE); @@ -89,4 +92,5 @@ public class OAuth2AccessTokenTests { assertThat(accessToken.getExpiresAt()).isEqualTo(EXPIRES_AT); assertThat(accessToken.getScopes()).isEqualTo(SCOPES); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2ErrorTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2ErrorTests.java index 9cc9b12299..fab4afdc2c 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2ErrorTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2ErrorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import org.junit.Test; @@ -25,8 +26,11 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class OAuth2ErrorTests { + private static final String ERROR_CODE = "error-code"; + private static final String ERROR_DESCRIPTION = "error-description"; + private static final String ERROR_URI = "error-uri"; @Test(expected = IllegalArgumentException.class) @@ -37,9 +41,9 @@ public class OAuth2ErrorTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { OAuth2Error error = new OAuth2Error(ERROR_CODE, ERROR_DESCRIPTION, ERROR_URI); - assertThat(error.getErrorCode()).isEqualTo(ERROR_CODE); assertThat(error.getDescription()).isEqualTo(ERROR_DESCRIPTION); assertThat(error.getUri()).isEqualTo(ERROR_URI); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResultTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResultTests.java index 22c56c1576..e1aae08a90 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResultTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2TokenValidatorResultTests.java @@ -13,13 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core; import org.junit.Test; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -28,8 +26,8 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Josh Cummings */ public class OAuth2TokenValidatorResultTests { - private static final OAuth2Error DETAIL = new OAuth2Error( - "error", "description", "uri"); + + private static final OAuth2Error DETAIL = new OAuth2Error("error", "description", "uri"); @Test public void successWhenInvokedThenReturnsSuccessfulResult() { @@ -40,7 +38,6 @@ public class OAuth2TokenValidatorResultTests { @Test public void failureWhenInvokedWithDetailReturnsFailureResultIncludingDetail() { OAuth2TokenValidatorResult failure = OAuth2TokenValidatorResult.failure(DETAIL); - assertThat(failure.hasErrors()).isTrue(); assertThat(failure.getErrors()).containsExactly(DETAIL); } @@ -48,8 +45,8 @@ public class OAuth2TokenValidatorResultTests { @Test public void failureWhenInvokedWithMultipleDetailsReturnsFailureResultIncludingAll() { OAuth2TokenValidatorResult failure = OAuth2TokenValidatorResult.failure(DETAIL, DETAIL); - assertThat(failure.hasErrors()).isTrue(); assertThat(failure.getErrors()).containsExactly(DETAIL, DETAIL); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AccessTokens.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AccessTokens.java index 62ca87ff48..4d8a130169 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AccessTokens.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AccessTokens.java @@ -25,19 +25,19 @@ import java.util.HashSet; * @author Rob Winch * @since 5.1 */ -public class TestOAuth2AccessTokens { +public final class TestOAuth2AccessTokens { + + private TestOAuth2AccessTokens() { + } + public static OAuth2AccessToken noScopes() { - return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "no-scopes", - Instant.now(), + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "no-scopes", Instant.now(), Instant.now().plus(Duration.ofDays(1))); } public static OAuth2AccessToken scopes(String... scopes) { - return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "scopes", - Instant.now(), - Instant.now().plus(Duration.ofDays(1)), - new HashSet<>(Arrays.asList(scopes))); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "scopes", Instant.now(), + Instant.now().plus(Duration.ofDays(1)), new HashSet<>(Arrays.asList(scopes))); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2RefreshTokens.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2RefreshTokens.java index 8face452f5..3cbaa3dbc4 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2RefreshTokens.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/TestOAuth2RefreshTokens.java @@ -22,8 +22,13 @@ import java.time.Instant; * @author Rob Winch * @since 5.1 */ -public class TestOAuth2RefreshTokens { +public final class TestOAuth2RefreshTokens { + + private TestOAuth2RefreshTokens() { + } + public static OAuth2RefreshToken refreshToken() { return new OAuth2RefreshToken("refresh-token", Instant.now()); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimConversionServiceTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimConversionServiceTests.java index c7b8cc3571..6a07f8f37f 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimConversionServiceTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimConversionServiceTests.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.assertj.core.util.Lists; -import org.junit.Test; -import org.springframework.core.convert.ConversionService; +package org.springframework.security.oauth2.core.converter; import java.net.URL; import java.time.Instant; @@ -29,6 +26,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.assertj.core.util.Lists; +import org.junit.Test; + +import org.springframework.core.convert.ConversionService; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -38,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 5.2 */ public class ClaimConversionServiceTests { + private final ConversionService conversionService = ClaimConversionService.getSharedInstance(); @Test @@ -104,7 +107,8 @@ public class ClaimConversionServiceTests { Instant instant = Instant.now(); assertThat(this.conversionService.convert(String.valueOf(instant.getEpochSecond()), Instant.class)) .isEqualTo(instant.truncatedTo(ChronoUnit.SECONDS)); - assertThat(this.conversionService.convert(String.valueOf(instant.toString()), Instant.class)).isEqualTo(instant); + assertThat(this.conversionService.convert(String.valueOf(instant.toString()), Instant.class)) + .isEqualTo(instant); } @Test @@ -179,8 +183,7 @@ public class ClaimConversionServiceTests { @Test public void convertListStringWhenNotConvertibleThenReturnSingletonList() { String string = "not-convertible-list"; - assertThat(this.conversionService.convert(string, List.class)) - .isEqualTo(Collections.singletonList(string)); + assertThat(this.conversionService.convert(string, List.class)).isEqualTo(Collections.singletonList(string)); } @Test @@ -224,4 +227,5 @@ public class ClaimConversionServiceTests { List notConvertibleList = Lists.list("1", "2", "3", "4"); assertThat(this.conversionService.convert(notConvertibleList, Map.class)).isNull(); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverterTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverterTests.java index fee193f8e4..0b0d9f0acd 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverterTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/converter/ClaimTypeConverterTests.java @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.converter; -import org.assertj.core.util.Lists; -import org.junit.Before; -import org.junit.Test; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.Converter; +package org.springframework.security.oauth2.core.converter; import java.net.URL; import java.time.Instant; @@ -28,8 +23,15 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.assertj.core.util.Lists; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.Converter; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link ClaimTypeConverter}. @@ -38,13 +40,21 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @since 5.2 */ public class ClaimTypeConverterTests { + private static final String STRING_CLAIM = "string-claim"; + private static final String BOOLEAN_CLAIM = "boolean-claim"; + private static final String INSTANT_CLAIM = "instant-claim"; + private static final String URL_CLAIM = "url-claim"; + private static final String COLLECTION_STRING_CLAIM = "collection-string-claim"; + private static final String LIST_STRING_CLAIM = "list-string-claim"; + private static final String MAP_STRING_OBJECT_CLAIM = "map-string-object-claim"; + private ClaimTypeConverter claimTypeConverter; @Before @@ -58,9 +68,8 @@ public class ClaimTypeConverterTests { TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class))); Converter listStringConverter = getConverter( TypeDescriptor.collection(List.class, TypeDescriptor.valueOf(String.class))); - Converter mapStringObjectConverter = getConverter( - TypeDescriptor.map(Map.class, TypeDescriptor.valueOf(String.class), TypeDescriptor.valueOf(Object.class))); - + Converter mapStringObjectConverter = getConverter(TypeDescriptor.map(Map.class, + TypeDescriptor.valueOf(String.class), TypeDescriptor.valueOf(Object.class))); Map> claimTypeConverters = new HashMap<>(); claimTypeConverters.put(STRING_CLAIM, stringConverter); claimTypeConverters.put(BOOLEAN_CLAIM, booleanConverter); @@ -74,21 +83,20 @@ public class ClaimTypeConverterTests { private static Converter getConverter(TypeDescriptor targetDescriptor) { final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); - return source -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); + return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, + targetDescriptor); } @Test public void constructorWhenConvertersNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ClaimTypeConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new ClaimTypeConverter(null)); } @Test public void constructorWhenConvertersHasNullConverterThenThrowIllegalArgumentException() { Map> claimTypeConverters = new HashMap<>(); claimTypeConverters.put("claim1", null); - assertThatThrownBy(() -> new ClaimTypeConverter(claimTypeConverters)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new ClaimTypeConverter(claimTypeConverters)); } @Test @@ -107,7 +115,6 @@ public class ClaimTypeConverterTests { mapIntegerObject.put(1, "value1"); Map mapStringObject = new HashMap<>(); mapStringObject.put("1", "value1"); - Map claims = new HashMap<>(); claims.put(STRING_CLAIM, Boolean.TRUE); claims.put(BOOLEAN_CLAIM, "true"); @@ -116,9 +123,7 @@ public class ClaimTypeConverterTests { claims.put(COLLECTION_STRING_CLAIM, listNumber); claims.put(LIST_STRING_CLAIM, listNumber); claims.put(MAP_STRING_OBJECT_CLAIM, mapIntegerObject); - claims = this.claimTypeConverter.convert(claims); - assertThat(claims.get(STRING_CLAIM)).isEqualTo("true"); assertThat(claims.get(BOOLEAN_CLAIM)).isEqualTo(Boolean.TRUE); assertThat(claims.get(INSTANT_CLAIM)).isEqualTo(instant); @@ -137,7 +142,6 @@ public class ClaimTypeConverterTests { List listString = Lists.list("1", "2", "3", "4"); Map mapStringObject = new HashMap<>(); mapStringObject.put("1", "value1"); - Map claims = new HashMap<>(); claims.put(STRING_CLAIM, string); claims.put(BOOLEAN_CLAIM, bool); @@ -146,9 +150,7 @@ public class ClaimTypeConverterTests { claims.put(COLLECTION_STRING_CLAIM, listString); claims.put(LIST_STRING_CLAIM, listString); claims.put(MAP_STRING_OBJECT_CLAIM, mapStringObject); - claims = this.claimTypeConverter.convert(claims); - assertThat(claims.get(STRING_CLAIM)).isSameAs(string); assertThat(claims.get(BOOLEAN_CLAIM)).isSameAs(bool); assertThat(claims.get(INSTANT_CLAIM)).isSameAs(instant); @@ -162,9 +164,8 @@ public class ClaimTypeConverterTests { public void convertWhenConverterNotAvailableThenDoesNotConvert() { Map claims = new HashMap<>(); claims.put("claim1", "value1"); - claims = this.claimTypeConverter.convert(claims); - assertThat(claims.get("claim1")).isSameAs("value1"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverterTest.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverterTests.java similarity index 94% rename from oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverterTest.java rename to oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverterTests.java index 56dd03f83f..715f4efa77 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverterTest.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/MapOAuth2AccessTokenResponseConverterTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import java.time.Duration; @@ -32,7 +33,7 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; * * @author Nikita Konev */ -public class MapOAuth2AccessTokenResponseConverterTest { +public class MapOAuth2AccessTokenResponseConverterTests { private MapOAuth2AccessTokenResponseConverter messageConverter; @@ -41,7 +42,6 @@ public class MapOAuth2AccessTokenResponseConverterTest { this.messageConverter = new MapOAuth2AccessTokenResponseConverter(); } - @Test public void shouldConvertFull() { Map map = new HashMap<>(); @@ -52,7 +52,7 @@ public class MapOAuth2AccessTokenResponseConverterTest { map.put("refresh_token", "refresh-token-1234"); map.put("custom_parameter_1", "custom-value-1"); map.put("custom_parameter_2", "custom-value-2"); - OAuth2AccessTokenResponse converted = messageConverter.convert(map); + OAuth2AccessTokenResponse converted = this.messageConverter.convert(map); OAuth2AccessToken accessToken = converted.getAccessToken(); Assert.assertNotNull(accessToken); Assert.assertEquals("access-token-1234", accessToken.getTokenValue()); @@ -63,11 +63,9 @@ public class MapOAuth2AccessTokenResponseConverterTest { Assert.assertTrue(scopes.contains("read")); Assert.assertTrue(scopes.contains("write")); Assert.assertEquals(3600, Duration.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()).getSeconds()); - OAuth2RefreshToken refreshToken = converted.getRefreshToken(); Assert.assertNotNull(refreshToken); Assert.assertEquals("refresh-token-1234", refreshToken.getTokenValue()); - Map additionalParameters = converted.getAdditionalParameters(); Assert.assertNotNull(additionalParameters); Assert.assertEquals(2, additionalParameters.size()); @@ -80,7 +78,7 @@ public class MapOAuth2AccessTokenResponseConverterTest { Map map = new HashMap<>(); map.put("access_token", "access-token-1234"); map.put("token_type", "bearer"); - OAuth2AccessTokenResponse converted = messageConverter.convert(map); + OAuth2AccessTokenResponse converted = this.messageConverter.convert(map); OAuth2AccessToken accessToken = converted.getAccessToken(); Assert.assertNotNull(accessToken); Assert.assertEquals("access-token-1234", accessToken.getTokenValue()); @@ -88,12 +86,9 @@ public class MapOAuth2AccessTokenResponseConverterTest { Set scopes = accessToken.getScopes(); Assert.assertNotNull(scopes); Assert.assertEquals(0, scopes.size()); - Assert.assertEquals(1, Duration.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()).getSeconds()); - OAuth2RefreshToken refreshToken = converted.getRefreshToken(); Assert.assertNull(refreshToken); - Map additionalParameters = converted.getAdditionalParameters(); Assert.assertNotNull(additionalParameters); Assert.assertEquals(0, additionalParameters.size()); @@ -105,7 +100,7 @@ public class MapOAuth2AccessTokenResponseConverterTest { map.put("access_token", "access-token-1234"); map.put("token_type", "bearer"); map.put("expires_in", "2100-01-01-abc"); - OAuth2AccessTokenResponse converted = messageConverter.convert(map); + OAuth2AccessTokenResponse converted = this.messageConverter.convert(map); OAuth2AccessToken accessToken = converted.getAccessToken(); Assert.assertNotNull(accessToken); Assert.assertEquals("access-token-1234", accessToken.getTokenValue()); @@ -113,14 +108,12 @@ public class MapOAuth2AccessTokenResponseConverterTest { Set scopes = accessToken.getScopes(); Assert.assertNotNull(scopes); Assert.assertEquals(0, scopes.size()); - Assert.assertEquals(1, Duration.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()).getSeconds()); - OAuth2RefreshToken refreshToken = converted.getRefreshToken(); Assert.assertNull(refreshToken); - Map additionalParameters = converted.getAdditionalParameters(); Assert.assertNotNull(additionalParameters); Assert.assertEquals(0, additionalParameters.size()); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverterTest.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverterTests.java similarity index 84% rename from oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverterTest.java rename to oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverterTests.java index cdeb3dd73d..ae4f4117b4 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverterTest.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseMapConverterTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import java.util.HashMap; @@ -31,7 +32,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; * * @author Nikita Konev */ -public class OAuth2AccessTokenResponseMapConverterTest { +public class OAuth2AccessTokenResponseMapConverterTests { private OAuth2AccessTokenResponseMapConverter messageConverter; @@ -40,28 +41,25 @@ public class OAuth2AccessTokenResponseMapConverterTest { this.messageConverter = new OAuth2AccessTokenResponseMapConverter(); } - @Test public void convertFull() { Map additionalParameters = new HashMap<>(); additionalParameters.put("custom_parameter_1", "custom-value-1"); additionalParameters.put("custom_parameter_2", "custom-value-2"); - Set scopes = new HashSet<>(); scopes.add("read"); scopes.add("write"); - - OAuth2AccessTokenResponse build = OAuth2AccessTokenResponse - .withToken("access-token-value-1234") + // @formatter:off + OAuth2AccessTokenResponse build = OAuth2AccessTokenResponse.withToken("access-token-value-1234") .expiresIn(3699) .additionalParameters(additionalParameters) .refreshToken("refresh-token-value-1234") .scopes(scopes) .tokenType(OAuth2AccessToken.TokenType.BEARER) .build(); - Map result = messageConverter.convert(build); + // @formatter:on + Map result = this.messageConverter.convert(build); Assert.assertEquals(7, result.size()); - Assert.assertEquals("access-token-value-1234", result.get("access_token")); Assert.assertEquals("refresh-token-value-1234", result.get("refresh_token")); Assert.assertEquals("read write", result.get("scope")); @@ -73,15 +71,16 @@ public class OAuth2AccessTokenResponseMapConverterTest { @Test public void convertMinimal() { - OAuth2AccessTokenResponse build = OAuth2AccessTokenResponse - .withToken("access-token-value-1234") + // @formatter:off + OAuth2AccessTokenResponse build = OAuth2AccessTokenResponse.withToken("access-token-value-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) .build(); - Map result = messageConverter.convert(build); + // @formatter:on + Map result = this.messageConverter.convert(build); Assert.assertEquals(3, result.size()); - Assert.assertEquals("access-token-value-1234", result.get("access_token")); Assert.assertEquals("Bearer", result.get("token_type")); Assert.assertNotNull(result.get("expires_in")); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseTests.java index 8f2d3aa1cb..1d1974f8e4 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AccessTokenResponseTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.endpoint; -import org.junit.Test; -import org.springframework.security.oauth2.core.OAuth2AccessToken; +package org.springframework.security.oauth2.core.endpoint; import java.time.Instant; import java.util.Arrays; @@ -25,6 +23,10 @@ import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import org.junit.Test; + +import org.springframework.security.oauth2.core.OAuth2AccessToken; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -34,46 +36,55 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class OAuth2AccessTokenResponseTests { + private static final String TOKEN_VALUE = "access-token"; + private static final String REFRESH_TOKEN_VALUE = "refresh-token"; + private static final long EXPIRES_IN = Instant.now().plusSeconds(5).toEpochMilli(); @Test(expected = IllegalArgumentException.class) public void buildWhenTokenValueIsNullThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AccessTokenResponse.withToken(null) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(EXPIRES_IN) - .build(); + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(EXPIRES_IN) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildWhenTokenTypeIsNullThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) - .tokenType(null) - .expiresIn(EXPIRES_IN) - .build(); + .tokenType(null) + .expiresIn(EXPIRES_IN) + .build(); + // @formatter:on } @Test public void buildWhenExpiresInIsZeroThenExpiresAtOneSecondAfterIssueAt() { - OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse - .withToken(TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(0) - .build(); - assertThat(tokenResponse.getAccessToken().getExpiresAt()).isEqualTo( - tokenResponse.getAccessToken().getIssuedAt().plusSeconds(1)); + // @formatter:off + OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(0) + .build(); + // @formatter:on + assertThat(tokenResponse.getAccessToken().getExpiresAt()) + .isEqualTo(tokenResponse.getAccessToken().getIssuedAt().plusSeconds(1)); } @Test public void buildWhenExpiresInIsNegativeThenExpiresAtOneSecondAfterIssueAt() { - OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse - .withToken(TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(-1L) - .build(); - assertThat(tokenResponse.getAccessToken().getExpiresAt()).isEqualTo( - tokenResponse.getAccessToken().getIssuedAt().plusSeconds(1)); + // @formatter:off + OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(-1L) + .build(); + // @formatter:on + assertThat(tokenResponse.getAccessToken().getExpiresAt()) + .isEqualTo(tokenResponse.getAccessToken().getIssuedAt().plusSeconds(1)); } @Test @@ -83,16 +94,15 @@ public class OAuth2AccessTokenResponseTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - - OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse - .withToken(TOKEN_VALUE) - .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(expiresAt.toEpochMilli()) - .scopes(scopes) - .refreshToken(REFRESH_TOKEN_VALUE) - .additionalParameters(additionalParameters) - .build(); - + // @formatter:off + OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(expiresAt.toEpochMilli()) + .scopes(scopes) + .refreshToken(REFRESH_TOKEN_VALUE) + .additionalParameters(additionalParameters) + .build(); + // @formatter:on assertThat(tokenResponse.getAccessToken()).isNotNull(); assertThat(tokenResponse.getAccessToken().getTokenValue()).isEqualTo(TOKEN_VALUE); assertThat(tokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); @@ -110,25 +120,25 @@ public class OAuth2AccessTokenResponseTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - - OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse - .withToken(TOKEN_VALUE) + // @formatter:off + OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(expiresAt.toEpochMilli()) .scopes(scopes) .refreshToken(REFRESH_TOKEN_VALUE) .additionalParameters(additionalParameters) .build(); - - OAuth2AccessTokenResponse withResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) - .build(); - - assertThat(withResponse.getAccessToken().getTokenValue()).isEqualTo(tokenResponse.getAccessToken().getTokenValue()); + // @formatter:on + OAuth2AccessTokenResponse withResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse).build(); + assertThat(withResponse.getAccessToken().getTokenValue()) + .isEqualTo(tokenResponse.getAccessToken().getTokenValue()); assertThat(withResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(withResponse.getAccessToken().getIssuedAt()).isEqualTo(tokenResponse.getAccessToken().getIssuedAt()); - assertThat(withResponse.getAccessToken().getExpiresAt()).isEqualTo(tokenResponse.getAccessToken().getExpiresAt()); + assertThat(withResponse.getAccessToken().getExpiresAt()) + .isEqualTo(tokenResponse.getAccessToken().getExpiresAt()); assertThat(withResponse.getAccessToken().getScopes()).isEqualTo(tokenResponse.getAccessToken().getScopes()); - assertThat(withResponse.getRefreshToken().getTokenValue()).isEqualTo(tokenResponse.getRefreshToken().getTokenValue()); + assertThat(withResponse.getRefreshToken().getTokenValue()) + .isEqualTo(tokenResponse.getRefreshToken().getTokenValue()); assertThat(withResponse.getAdditionalParameters()).isEqualTo(tokenResponse.getAdditionalParameters()); } @@ -139,34 +149,30 @@ public class OAuth2AccessTokenResponseTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - - OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse - .withToken(TOKEN_VALUE) + // @formatter:off + OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(expiresAt.toEpochMilli()) .scopes(scopes) .additionalParameters(additionalParameters) .build(); - - OAuth2AccessTokenResponse withResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) - .build(); - + // @formatter:on + OAuth2AccessTokenResponse withResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse).build(); assertThat(withResponse.getRefreshToken()).isNull(); } @Test public void buildWhenResponseAndExpiresInThenExpiresAtEqualToIssuedAtPlusExpiresIn() { - OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse - .withToken(TOKEN_VALUE) + // @formatter:off + OAuth2AccessTokenResponse tokenResponse = OAuth2AccessTokenResponse.withToken(TOKEN_VALUE) .tokenType(OAuth2AccessToken.TokenType.BEARER) .build(); - + // @formatter:on long expiresIn = 30; OAuth2AccessTokenResponse withResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) - .expiresIn(expiresIn) - .build(); - - assertThat(withResponse.getAccessToken().getExpiresAt()).isEqualTo( - withResponse.getAccessToken().getIssuedAt().plusSeconds(expiresIn)); + .expiresIn(expiresIn).build(); + assertThat(withResponse.getAccessToken().getExpiresAt()) + .isEqualTo(withResponse.getAccessToken().getIssuedAt().plusSeconds(expiresIn)); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationExchangeTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationExchangeTests.java index 6986e19496..bd134c245b 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationExchangeTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationExchangeTests.java @@ -13,13 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; -import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2AuthorizationExchange}. @@ -30,21 +29,22 @@ public class OAuth2AuthorizationExchangeTests { @Test(expected = IllegalArgumentException.class) public void constructorWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() { - new OAuth2AuthorizationExchange(null, success().build()); + new OAuth2AuthorizationExchange(null, TestOAuth2AuthorizationResponses.success().build()); } @Test(expected = IllegalArgumentException.class) public void constructorWhenAuthorizationResponseIsNullThenThrowIllegalArgumentException() { - new OAuth2AuthorizationExchange(request().build(), null); + new OAuth2AuthorizationExchange(TestOAuth2AuthorizationRequests.request().build(), null); } @Test public void constructorWhenRequiredArgsProvidedThenCreated() { - OAuth2AuthorizationRequest authorizationRequest = request().build(); - OAuth2AuthorizationResponse authorizationResponse = success().build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); assertThat(authorizationExchange.getAuthorizationRequest()).isEqualTo(authorizationRequest); assertThat(authorizationExchange.getAuthorizationResponse()).isEqualTo(authorizationResponse); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java index 8f0745d4f2..ddb69f6527 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.endpoint; -import org.junit.Test; -import org.springframework.security.oauth2.core.AuthorizationGrantType; +package org.springframework.security.oauth2.core.endpoint; import java.net.URI; import java.util.Arrays; @@ -25,9 +23,12 @@ import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import org.junit.Test; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OAuth2AuthorizationRequest}. @@ -36,106 +37,118 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Joe Grandja */ public class OAuth2AuthorizationRequestTests { + private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize"; + private static final String CLIENT_ID = "client-id"; + private static final String REDIRECT_URI = "https://example.com"; + private static final Set SCOPES = new LinkedHashSet<>(Arrays.asList("scope1", "scope2")); + private static final String STATE = "state"; @Test public void buildWhenAuthorizationUriIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(null) - .clientId(CLIENT_ID) - .redirectUri(REDIRECT_URI) - .scopes(SCOPES) - .state(STATE) - .build() - ).isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizationRequest + .authorizationCode() + .authorizationUri(null) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .scopes(SCOPES) + .state(STATE) + .build() + ); + // @formatter:on } @Test public void buildWhenClientIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(null) - .redirectUri(REDIRECT_URI) - .scopes(SCOPES) - .state(STATE) - .build() - ).isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(AUTHORIZATION_URI) + .clientId(null) + .redirectUri(REDIRECT_URI) + .scopes(SCOPES) + .state(STATE) + .build() + ); + // @formatter:on } @Test public void buildWhenRedirectUriIsNullForImplicitThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> - OAuth2AuthorizationRequest.implicit() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .redirectUri(null) - .scopes(SCOPES) - .state(STATE) - .build() - ).isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> OAuth2AuthorizationRequest.implicit() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(null) + .scopes(SCOPES) + .state(STATE).build() + ); + // @formatter:on } @Test public void buildWhenRedirectUriIsNullForAuthorizationCodeThenDoesNotThrowAnyException() { - assertThatCode(() -> - OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .redirectUri(null) - .scopes(SCOPES) - .state(STATE) - .build()) - .doesNotThrowAnyException(); + // @formatter:off + OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(null) + .scopes(SCOPES) + .state(STATE) + .build(); + // @formatter:on } @Test public void buildWhenScopesIsNullThenDoesNotThrowAnyException() { - assertThatCode(() -> - OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .redirectUri(REDIRECT_URI) - .scopes(null) - .state(STATE) - .build()) - .doesNotThrowAnyException(); + // @formatter:off + OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .scopes(null) + .state(STATE) + .build(); + // @formatter:on } @Test public void buildWhenStateIsNullThenDoesNotThrowAnyException() { - assertThatCode(() -> - OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .redirectUri(REDIRECT_URI) - .scopes(SCOPES) - .state(null) - .build()) - .doesNotThrowAnyException(); + // @formatter:off + OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .scopes(SCOPES) + .state(null) + .build(); + // @formatter:on } @Test public void buildWhenAdditionalParametersEmptyThenDoesNotThrowAnyException() { - assertThatCode(() -> - OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .redirectUri(REDIRECT_URI) - .scopes(SCOPES) - .state(STATE) - .additionalParameters(Map::clear) - .build()) - .doesNotThrowAnyException(); + // @formatter:off + OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .scopes(SCOPES) + .state(STATE) + .additionalParameters(Map::clear) + .build(); + // @formatter:on } @Test public void buildWhenImplicitThenGrantTypeResponseTypeIsSet() { + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.implicit() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) @@ -143,12 +156,14 @@ public class OAuth2AuthorizationRequestTests { .scopes(SCOPES) .state(STATE) .build(); + // @formatter:on assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.TOKEN); } @Test public void buildWhenAuthorizationCodeThenGrantTypeResponseTypeIsSet() { + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) @@ -156,6 +171,7 @@ public class OAuth2AuthorizationRequestTests { .scopes(SCOPES) .state(STATE) .build(); + // @formatter:on assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); } @@ -165,11 +181,10 @@ public class OAuth2AuthorizationRequestTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - Map attributes = new HashMap<>(); attributes.put("attribute1", "value1"); attributes.put("attribute2", "value2"); - + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) @@ -180,7 +195,7 @@ public class OAuth2AuthorizationRequestTests { .attributes(attributes) .authorizationRequestUri(AUTHORIZATION_URI) .build(); - + // @formatter:on assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI); assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); @@ -195,6 +210,7 @@ public class OAuth2AuthorizationRequestTests { @Test public void buildWhenScopesMultiThenSeparatedByEncodedSpace() { + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.implicit() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) @@ -202,16 +218,15 @@ public class OAuth2AuthorizationRequestTests { .scopes(SCOPES) .state(STATE) .build(); - + // @formatter:on assertThat(authorizationRequest.getAuthorizationRequestUri()) - .isEqualTo("https://provider.com/oauth2/authorize?" + - "response_type=token&client_id=client-id&" + - "scope=scope1%20scope2&state=state&" + - "redirect_uri=https://example.com"); + .isEqualTo("https://provider.com/oauth2/authorize?" + "response_type=token&client_id=client-id&" + + "scope=scope1%20scope2&state=state&" + "redirect_uri=https://example.com"); } @Test public void buildWhenAuthorizationRequestUriSetThenOverridesDefault() { + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) @@ -220,19 +235,22 @@ public class OAuth2AuthorizationRequestTests { .state(STATE) .authorizationRequestUri(AUTHORIZATION_URI) .build(); + // @formatter:on assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI); } @Test public void buildWhenAuthorizationRequestUriFunctionSetThenOverridesDefault() { + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) .redirectUri(REDIRECT_URI) .scopes(SCOPES) .state(STATE) - .authorizationRequestUri(uriBuilder -> URI.create(AUTHORIZATION_URI)) + .authorizationRequestUri((uriBuilder) -> URI.create(AUTHORIZATION_URI)) .build(); + // @formatter:on assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI); } @@ -241,37 +259,26 @@ public class OAuth2AuthorizationRequestTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .redirectUri(REDIRECT_URI) - .scopes(SCOPES) - .state(STATE) - .additionalParameters(additionalParameters) - .build(); - + .authorizationUri(AUTHORIZATION_URI).clientId(CLIENT_ID).redirectUri(REDIRECT_URI).scopes(SCOPES) + .state(STATE).additionalParameters(additionalParameters).build(); assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .isEqualTo("https://provider.com/oauth2/authorize?" + - "response_type=code&client_id=client-id&" + - "scope=scope1%20scope2&state=state&" + - "redirect_uri=https://example.com¶m1=value1¶m2=value2"); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?" + + "response_type=code&client_id=client-id&" + "scope=scope1%20scope2&state=state&" + + "redirect_uri=https://example.com¶m1=value1¶m2=value2"); } @Test public void buildWhenRequiredParametersSetThenAuthorizationRequestUriIncludesRequiredParametersOnly() { OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri(AUTHORIZATION_URI) - .clientId(CLIENT_ID) - .build(); - - assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id"); + .authorizationUri(AUTHORIZATION_URI).clientId(CLIENT_ID).build(); + assertThat(authorizationRequest.getAuthorizationRequestUri()) + .isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id"); } @Test public void fromWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationRequest.from(null)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> OAuth2AuthorizationRequest.from(null)); } @Test @@ -279,11 +286,10 @@ public class OAuth2AuthorizationRequestTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("param1", "value1"); additionalParameters.put("param2", "value2"); - Map attributes = new HashMap<>(); attributes.put("attribute1", "value1"); attributes.put("attribute2", "value2"); - + // @formatter:off OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID) @@ -293,50 +299,45 @@ public class OAuth2AuthorizationRequestTests { .additionalParameters(additionalParameters) .attributes(attributes) .build(); - - OAuth2AuthorizationRequest authorizationRequestCopy = - OAuth2AuthorizationRequest.from(authorizationRequest).build(); - - assertThat(authorizationRequestCopy.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri()); + OAuth2AuthorizationRequest authorizationRequestCopy = OAuth2AuthorizationRequest.from(authorizationRequest) + .build(); + // @formatter:on + assertThat(authorizationRequestCopy.getAuthorizationUri()) + .isEqualTo(authorizationRequest.getAuthorizationUri()); assertThat(authorizationRequestCopy.getGrantType()).isEqualTo(authorizationRequest.getGrantType()); assertThat(authorizationRequestCopy.getResponseType()).isEqualTo(authorizationRequest.getResponseType()); assertThat(authorizationRequestCopy.getClientId()).isEqualTo(authorizationRequest.getClientId()); assertThat(authorizationRequestCopy.getRedirectUri()).isEqualTo(authorizationRequest.getRedirectUri()); assertThat(authorizationRequestCopy.getScopes()).isEqualTo(authorizationRequest.getScopes()); assertThat(authorizationRequestCopy.getState()).isEqualTo(authorizationRequest.getState()); - assertThat(authorizationRequestCopy.getAdditionalParameters()).isEqualTo(authorizationRequest.getAdditionalParameters()); + assertThat(authorizationRequestCopy.getAdditionalParameters()) + .isEqualTo(authorizationRequest.getAdditionalParameters()); assertThat(authorizationRequestCopy.getAttributes()).isEqualTo(authorizationRequest.getAttributes()); - assertThat(authorizationRequestCopy.getAuthorizationRequestUri()).isEqualTo(authorizationRequest.getAuthorizationRequestUri()); + assertThat(authorizationRequestCopy.getAuthorizationRequestUri()) + .isEqualTo(authorizationRequest.getAuthorizationRequestUri()); } @Test public void buildWhenAuthorizationUriIncludesQueryParameterThenAuthorizationRequestUrlIncludesIt() { - OAuth2AuthorizationRequest authorizationRequest = - TestOAuth2AuthorizationRequests.request() - .authorizationUri(AUTHORIZATION_URI + - "?param1=value1¶m2=value2").build(); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .authorizationUri(AUTHORIZATION_URI + "?param1=value1¶m2=value2").build(); assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .isEqualTo("https://provider.com/oauth2/authorize?" + - "param1=value1¶m2=value2&" + - "response_type=code&client_id=client-id&state=state&" + - "redirect_uri=https://example.com/authorize/oauth2/code/registration-id"); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?" + + "param1=value1¶m2=value2&" + "response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id"); } @Test public void buildWhenAuthorizationUriIncludesEscapedQueryParameterThenAuthorizationRequestUrlIncludesIt() { - OAuth2AuthorizationRequest authorizationRequest = - TestOAuth2AuthorizationRequests.request() - .authorizationUri(AUTHORIZATION_URI + - "?claims=%7B%22userinfo%22%3A%7B%22email_verified%22%3A%7B%22essential%22%3Atrue%7D%7D%7D").build(); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .authorizationUri(AUTHORIZATION_URI + + "?claims=%7B%22userinfo%22%3A%7B%22email_verified%22%3A%7B%22essential%22%3Atrue%7D%7D%7D") + .build(); assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .isEqualTo("https://provider.com/oauth2/authorize?" + - "claims=%7B%22userinfo%22%3A%7B%22email_verified%22%3A%7B%22essential%22%3Atrue%7D%7D%7D&" + - "response_type=code&client_id=client-id&state=state&" + - "redirect_uri=https://example.com/authorize/oauth2/code/registration-id"); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?" + + "claims=%7B%22userinfo%22%3A%7B%22email_verified%22%3A%7B%22essential%22%3Atrue%7D%7D%7D&" + + "response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id"); } @Test @@ -345,16 +346,13 @@ public class OAuth2AuthorizationRequestTests { additionalParameters.put("item amount", "19.95" + '\u20ac'); additionalParameters.put("item name", "H" + '\u00c5' + "M" + '\u00d6'); additionalParameters.put('\u00e2' + "ge", "4" + '\u00bd'); - OAuth2AuthorizationRequest authorizationRequest = - TestOAuth2AuthorizationRequests.request() - .additionalParameters(additionalParameters) - .build(); - + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .additionalParameters(additionalParameters).build(); assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull(); - assertThat(authorizationRequest.getAuthorizationRequestUri()) - .isEqualTo("https://example.com/login/oauth/authorize?" + - "response_type=code&client_id=client-id&state=state&" + - "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" + - "item%20amount=19.95%E2%82%AC&%C3%A2ge=4%C2%BD&item%20name=H%C3%85M%C3%96"); + assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo( + "https://example.com/login/oauth/authorize?" + "response_type=code&client_id=client-id&state=state&" + + "redirect_uri=https://example.com/authorize/oauth2/code/registration-id&" + + "item%20amount=19.95%E2%82%AC&%C3%A2ge=4%C2%BD&item%20name=H%C3%85M%C3%96"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTests.java index 00bd042b70..413fb236cd 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTests.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; /** * Tests for {@link OAuth2AuthorizationResponse}. @@ -26,102 +26,137 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Joe Grandja */ public class OAuth2AuthorizationResponseTests { + private static final String AUTH_CODE = "auth-code"; + private static final String REDIRECT_URI = "https://example.com"; + private static final String STATE = "state"; + private static final String ERROR_CODE = "error-code"; + private static final String ERROR_DESCRIPTION = "error-description"; + private static final String ERROR_URI = "error-uri"; @Test(expected = IllegalArgumentException.class) public void buildSuccessResponseWhenAuthCodeIsNullThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AuthorizationResponse.success(null) - .redirectUri(REDIRECT_URI) - .state(STATE) - .build(); + .redirectUri(REDIRECT_URI) + .state(STATE) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildSuccessResponseWhenRedirectUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AuthorizationResponse.success(AUTH_CODE) - .redirectUri(null) - .state(STATE) - .build(); + .redirectUri(null) + .state(STATE) + .build(); + // @formatter:on } @Test public void buildSuccessResponseWhenStateIsNullThenDoesNotThrowAnyException() { - assertThatCode(() -> OAuth2AuthorizationResponse.success(AUTH_CODE) - .redirectUri(REDIRECT_URI) - .state(null) - .build()).doesNotThrowAnyException(); + // @formatter:off + OAuth2AuthorizationResponse.success(AUTH_CODE) + .redirectUri(REDIRECT_URI) + .state(null) + .build(); + // @formatter:on } @Test public void buildSuccessResponseWhenAllAttributesProvidedThenAllAttributesAreSet() { + // @formatter:off OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success(AUTH_CODE) - .redirectUri(REDIRECT_URI) - .state(STATE) - .build(); - assertThat(authorizationResponse.getCode()).isEqualTo(AUTH_CODE); - assertThat(authorizationResponse.getRedirectUri()).isEqualTo(REDIRECT_URI); - assertThat(authorizationResponse.getState()).isEqualTo(STATE); + .redirectUri(REDIRECT_URI) + .state(STATE) + .build(); + assertThat(authorizationResponse.getCode()) + .isEqualTo(AUTH_CODE); + assertThat(authorizationResponse.getRedirectUri()) + .isEqualTo(REDIRECT_URI); + assertThat(authorizationResponse.getState()) + .isEqualTo(STATE); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildSuccessResponseWhenErrorCodeIsSetThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AuthorizationResponse.success(AUTH_CODE) - .redirectUri(REDIRECT_URI) - .state(STATE) - .errorCode(ERROR_CODE) - .build(); + .redirectUri(REDIRECT_URI) + .state(STATE) + .errorCode(ERROR_CODE) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildErrorResponseWhenErrorCodeIsNullThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AuthorizationResponse.error(null) - .redirectUri(REDIRECT_URI) - .state(STATE) - .build(); + .redirectUri(REDIRECT_URI) + .state(STATE) + .build(); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildErrorResponseWhenRedirectUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AuthorizationResponse.error(ERROR_CODE) - .redirectUri(null) - .state(STATE) - .build(); + .redirectUri(null) + .state(STATE) + .build(); + // @formatter:on } @Test public void buildErrorResponseWhenStateIsNullThenDoesNotThrowAnyException() { - assertThatCode(() -> OAuth2AuthorizationResponse.error(ERROR_CODE) - .redirectUri(REDIRECT_URI) - .state(null) - .build()).doesNotThrowAnyException(); + // @formatter:off + OAuth2AuthorizationResponse.error(ERROR_CODE) + .redirectUri(REDIRECT_URI) + .state(null) + .build(); + // @formatter:on } @Test public void buildErrorResponseWhenAllAttributesProvidedThenAllAttributesAreSet() { + // @formatter:off OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.error(ERROR_CODE) - .errorDescription(ERROR_DESCRIPTION) - .errorUri(ERROR_URI) - .redirectUri(REDIRECT_URI) - .state(STATE) - .build(); - assertThat(authorizationResponse.getError().getErrorCode()).isEqualTo(ERROR_CODE); - assertThat(authorizationResponse.getError().getDescription()).isEqualTo(ERROR_DESCRIPTION); - assertThat(authorizationResponse.getError().getUri()).isEqualTo(ERROR_URI); - assertThat(authorizationResponse.getRedirectUri()).isEqualTo(REDIRECT_URI); - assertThat(authorizationResponse.getState()).isEqualTo(STATE); + .errorDescription(ERROR_DESCRIPTION) + .errorUri(ERROR_URI) + .redirectUri(REDIRECT_URI) + .state(STATE) + .build(); + assertThat(authorizationResponse.getError().getErrorCode()) + .isEqualTo(ERROR_CODE); + assertThat(authorizationResponse.getError().getDescription()) + .isEqualTo(ERROR_DESCRIPTION); + assertThat(authorizationResponse.getError().getUri()) + .isEqualTo(ERROR_URI); + assertThat(authorizationResponse.getRedirectUri()) + .isEqualTo(REDIRECT_URI); + assertThat(authorizationResponse.getState()) + .isEqualTo(STATE); + // @formatter:on } @Test(expected = IllegalArgumentException.class) public void buildErrorResponseWhenAuthCodeIsSetThenThrowIllegalArgumentException() { + // @formatter:off OAuth2AuthorizationResponse.error(ERROR_CODE) - .redirectUri(REDIRECT_URI) - .state(STATE) - .code(AUTH_CODE) - .build(); + .redirectUri(REDIRECT_URI) + .state(STATE) + .code(AUTH_CODE) + .build(); + // @formatter:on } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTypeTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTypeTests.java index dd963fa0a1..238ddc6e3b 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTypeTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationResponseTypeTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.endpoint; import org.junit.Test; @@ -35,4 +36,5 @@ public class OAuth2AuthorizationResponseTypeTests { public void getValueWhenResponseTypeTokenThenReturnToken() { assertThat(OAuth2AuthorizationResponseType.TOKEN.getValue()).isEqualTo("token"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AccessTokenResponses.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AccessTokenResponses.java index 3cee5f9ef1..f952ff4bd5 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AccessTokenResponses.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AccessTokenResponses.java @@ -16,26 +16,33 @@ package org.springframework.security.oauth2.core.endpoint; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; - import java.util.HashMap; import java.util.Map; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; + /** * @author Rob Winch * @since 5.1 */ -public class TestOAuth2AccessTokenResponses { +public final class TestOAuth2AccessTokenResponses { + + private TestOAuth2AccessTokenResponses() { + } + public static OAuth2AccessTokenResponse.Builder accessTokenResponse() { - return OAuth2AccessTokenResponse.withToken("token") + // @formatter:off + return OAuth2AccessTokenResponse + .withToken("token") .tokenType(OAuth2AccessToken.TokenType.BEARER); + // @formatter:on } public static OAuth2AccessTokenResponse.Builder oidcAccessTokenResponse() { Map additionalParameters = new HashMap<>(); additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); - return accessTokenResponse() - .additionalParameters(additionalParameters); + return accessTokenResponse().additionalParameters(additionalParameters); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationExchanges.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationExchanges.java index 98fdbec1dc..761ce7f205 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationExchanges.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationExchanges.java @@ -21,7 +21,10 @@ package org.springframework.security.oauth2.core.endpoint; * @author Eddú Meléndez * @since 5.1 */ -public class TestOAuth2AuthorizationExchanges { +public final class TestOAuth2AuthorizationExchanges { + + private TestOAuth2AuthorizationExchanges() { + } public static OAuth2AuthorizationExchange success() { OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().build(); @@ -34,4 +37,5 @@ public class TestOAuth2AuthorizationExchanges { OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.error().build(); return new OAuth2AuthorizationExchange(request, response); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java index b160cdbf32..eaf559a28d 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationRequests.java @@ -23,21 +23,28 @@ import java.util.Map; * @author Rob Winch * @since 5.1 */ -public class TestOAuth2AuthorizationRequests { +public final class TestOAuth2AuthorizationRequests { + + private TestOAuth2AuthorizationRequests() { + } + public static OAuth2AuthorizationRequest.Builder request() { String registrationId = "registration-id"; String clientId = "client-id"; Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registrationId); + // @formatter:off return OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/login/oauth/authorize") .clientId(clientId) .redirectUri("https://example.com/authorize/oauth2/code/registration-id") .state("state") .attributes(attributes); + // @formatter:on } public static OAuth2AuthorizationRequest.Builder oidcRequest() { return request().scope("openid"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationResponses.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationResponses.java index 7ad7085c71..b4c6db04bb 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationResponses.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOAuth2AuthorizationResponses.java @@ -20,17 +20,25 @@ package org.springframework.security.oauth2.core.endpoint; * @author Rob Winch * @since 5.1 */ -public class TestOAuth2AuthorizationResponses { +public final class TestOAuth2AuthorizationResponses { + + private TestOAuth2AuthorizationResponses() { + } public static OAuth2AuthorizationResponse.Builder success() { + // @formatter:off return OAuth2AuthorizationResponse.success("authorization-code") .state("state") .redirectUri("https://example.com/authorize/oauth2/code/registration-id"); + // @formatter:on } public static OAuth2AuthorizationResponse.Builder error() { + // @formatter:off return OAuth2AuthorizationResponse.error("error") .redirectUri("https://example.com/authorize/oauth2/code/registration-id") .errorUri("https://example.com/error"); + // @formatter:on } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverterTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverterTests.java index e3ef18ee84..438f221169 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverterTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2AccessTokenResponseHttpMessageConverterTests.java @@ -13,10 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.http.converter; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + import org.junit.Before; import org.junit.Test; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageNotReadableException; @@ -26,17 +35,13 @@ import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.entry; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link OAuth2AccessTokenResponseHttpMessageConverter}. @@ -44,6 +49,7 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class OAuth2AccessTokenResponseHttpMessageConverterTests { + private OAuth2AccessTokenResponseHttpMessageConverter messageConverter; @Before @@ -58,97 +64,90 @@ public class OAuth2AccessTokenResponseHttpMessageConverterTests { @Test public void setTokenResponseConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.messageConverter.setTokenResponseConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.messageConverter.setTokenResponseConverter(null)); } @Test public void setTokenResponseParametersConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.messageConverter.setTokenResponseParametersConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.messageConverter.setTokenResponseParametersConverter(null)); } @Test public void readInternalWhenSuccessfulTokenResponseThenReadOAuth2AccessTokenResponse() throws Exception { - String tokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": \"3600\",\n" + - " \"scope\": \"read write\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n" + - "}\n"; - - MockClientHttpResponse response = new MockClientHttpResponse( - tokenResponse.getBytes(), HttpStatus.OK); - - OAuth2AccessTokenResponse accessTokenResponse = this.messageConverter.readInternal( - OAuth2AccessTokenResponse.class, response); - + // @formatter:off + String tokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(tokenResponse.getBytes(), HttpStatus.OK); + OAuth2AccessTokenResponse accessTokenResponse = this.messageConverter + .readInternal(OAuth2AccessTokenResponse.class, response); assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); - assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBeforeOrEqualTo(Instant.now().plusSeconds(3600)); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()) + .isBeforeOrEqualTo(Instant.now().plusSeconds(3600)); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token-1234"); assertThat(accessTokenResponse.getAdditionalParameters()).containsExactly( entry("custom_parameter_1", "custom-value-1"), entry("custom_parameter_2", "custom-value-2")); - } // gh-6463 @Test public void readInternalWhenSuccessfulTokenResponseWithObjectThenReadOAuth2AccessTokenResponse() { - String tokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": 3600,\n" + - " \"scope\": \"read write\",\n" + - " \"refresh_token\": \"refresh-token-1234\",\n" + - " \"custom_object_1\": {\"name1\": \"value1\"},\n" + - " \"custom_object_2\": [\"value1\", \"value2\"],\n" + - " \"custom_parameter_1\": \"custom-value-1\",\n" + - " \"custom_parameter_2\": \"custom-value-2\"\n" + - "}\n"; - - MockClientHttpResponse response = new MockClientHttpResponse( - tokenResponse.getBytes(), HttpStatus.OK); - - OAuth2AccessTokenResponse accessTokenResponse = this.messageConverter.readInternal( - OAuth2AccessTokenResponse.class, response); - + // @formatter:off + String tokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": 3600,\n" + + " \"scope\": \"read write\",\n" + + " \"refresh_token\": \"refresh-token-1234\",\n" + + " \"custom_object_1\": {\"name1\": \"value1\"},\n" + + " \"custom_object_2\": [\"value1\", \"value2\"],\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(tokenResponse.getBytes(), HttpStatus.OK); + OAuth2AccessTokenResponse accessTokenResponse = this.messageConverter + .readInternal(OAuth2AccessTokenResponse.class, response); assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); - assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBeforeOrEqualTo(Instant.now().plusSeconds(3600)); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()) + .isBeforeOrEqualTo(Instant.now().plusSeconds(3600)); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write"); assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token-1234"); assertThat(accessTokenResponse.getAdditionalParameters()).containsExactly( - entry("custom_object_1", "{name1=value1}"), - entry("custom_object_2", "[value1, value2]"), - entry("custom_parameter_1", "custom-value-1"), - entry("custom_parameter_2", "custom-value-2")); + entry("custom_object_1", "{name1=value1}"), entry("custom_object_2", "[value1, value2]"), + entry("custom_parameter_1", "custom-value-1"), entry("custom_parameter_2", "custom-value-2")); } // gh-8108 @Test public void readInternalWhenSuccessfulTokenResponseWithNullValueThenReadOAuth2AccessTokenResponse() { - String tokenResponse = "{\n" + - " \"access_token\": \"access-token-1234\",\n" + - " \"token_type\": \"bearer\",\n" + - " \"expires_in\": 3600,\n" + - " \"scope\": null,\n" + - " \"refresh_token\": \"refresh-token-1234\"\n" + - "}\n"; - - MockClientHttpResponse response = new MockClientHttpResponse( - tokenResponse.getBytes(), HttpStatus.OK); - - OAuth2AccessTokenResponse accessTokenResponse = this.messageConverter.readInternal( - OAuth2AccessTokenResponse.class, response); - + // @formatter:off + String tokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": 3600,\n" + + " \"scope\": null,\n" + + " \"refresh_token\": \"refresh-token-1234\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(tokenResponse.getBytes(), HttpStatus.OK); + OAuth2AccessTokenResponse accessTokenResponse = this.messageConverter + .readInternal(OAuth2AccessTokenResponse.class, response); assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); - assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBeforeOrEqualTo(Instant.now().plusSeconds(3600)); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()) + .isBeforeOrEqualTo(Instant.now().plusSeconds(3600)); assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("null"); assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token-1234"); } @@ -156,17 +155,13 @@ public class OAuth2AccessTokenResponseHttpMessageConverterTests { @Test public void readInternalWhenConversionFailsThenThrowHttpMessageNotReadableException() { Converter tokenResponseConverter = mock(Converter.class); - when(tokenResponseConverter.convert(any())).thenThrow(RuntimeException.class); + given(tokenResponseConverter.convert(any())).willThrow(RuntimeException.class); this.messageConverter.setTokenResponseConverter(tokenResponseConverter); - String tokenResponse = "{}"; - - MockClientHttpResponse response = new MockClientHttpResponse( - tokenResponse.getBytes(), HttpStatus.OK); - - assertThatThrownBy(() -> this.messageConverter.readInternal(OAuth2AccessTokenResponse.class, response)) - .isInstanceOf(HttpMessageNotReadableException.class) - .hasMessageContaining("An error occurred reading the OAuth 2.0 Access Token Response"); + MockClientHttpResponse response = new MockClientHttpResponse(tokenResponse.getBytes(), HttpStatus.OK); + assertThatExceptionOfType(HttpMessageNotReadableException.class) + .isThrownBy(() -> this.messageConverter.readInternal(OAuth2AccessTokenResponse.class, response)) + .withMessageContaining("An error occurred reading the OAuth 2.0 Access Token Response"); } @Test @@ -176,20 +171,18 @@ public class OAuth2AccessTokenResponseHttpMessageConverterTests { Map additionalParameters = new HashMap<>(); additionalParameters.put("custom_parameter_1", "custom-value-1"); additionalParameters.put("custom_parameter_2", "custom-value-2"); - - OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse - .withToken("access-token-1234") + // @formatter:off + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(expiresAt.toEpochMilli()) .scopes(scopes) .refreshToken("refresh-token-1234") .additionalParameters(additionalParameters) .build(); - + // @formatter:on MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); this.messageConverter.writeInternal(accessTokenResponse, outputMessage); String tokenResponse = outputMessage.getBodyAsString(); - assertThat(tokenResponse).contains("\"access_token\":\"access-token-1234\""); assertThat(tokenResponse).contains("\"token_type\":\"Bearer\""); assertThat(tokenResponse).contains("\"expires_in\""); @@ -202,19 +195,20 @@ public class OAuth2AccessTokenResponseHttpMessageConverterTests { @Test public void writeInternalWhenConversionFailsThenThrowHttpMessageNotWritableException() { Converter tokenResponseParametersConverter = mock(Converter.class); - when(tokenResponseParametersConverter.convert(any())).thenThrow(RuntimeException.class); + given(tokenResponseParametersConverter.convert(any())).willThrow(RuntimeException.class); this.messageConverter.setTokenResponseParametersConverter(tokenResponseParametersConverter); - + // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) - .expiresIn(Instant.now().plusSeconds(3600).toEpochMilli()) + .expiresIn(Instant.now().plusSeconds(3600) + .toEpochMilli()) .build(); - + // @formatter:on MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - - assertThatThrownBy(() -> this.messageConverter.writeInternal(accessTokenResponse, outputMessage)) - .isInstanceOf(HttpMessageNotWritableException.class) - .hasMessageContaining("An error occurred writing the OAuth 2.0 Access Token Response"); + assertThatExceptionOfType(HttpMessageNotWritableException.class) + .isThrownBy(() -> this.messageConverter.writeInternal(accessTokenResponse, outputMessage)) + .withMessageContaining("An error occurred writing the OAuth 2.0 Access Token Response"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverterTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverterTests.java index 11211aad56..35e697fdf6 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverterTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/http/converter/OAuth2ErrorHttpMessageConverterTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.core.http.converter; import org.junit.Before; import org.junit.Test; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageNotReadableException; @@ -26,10 +28,11 @@ import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.security.oauth2.core.OAuth2Error; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link OAuth2ErrorHttpMessageConverter}. @@ -37,6 +40,7 @@ import static org.mockito.Mockito.when; * @author Joe Grandja */ public class OAuth2ErrorHttpMessageConverterTests { + private OAuth2ErrorHttpMessageConverter messageConverter; @Before @@ -51,27 +55,24 @@ public class OAuth2ErrorHttpMessageConverterTests { @Test public void setErrorConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.messageConverter.setErrorConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.messageConverter.setErrorConverter(null)); } @Test public void setErrorParametersConverterWhenConverterIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.messageConverter.setErrorParametersConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.messageConverter.setErrorParametersConverter(null)); } @Test public void readInternalWhenErrorResponseThenReadOAuth2Error() throws Exception { - String errorResponse = "{\n" + - " \"error\": \"unauthorized_client\",\n" + - " \"error_description\": \"The client is not authorized\",\n" + - " \"error_uri\": \"https://tools.ietf.org/html/rfc6749#section-5.2\"\n" + - "}\n"; - - MockClientHttpResponse response = new MockClientHttpResponse( - errorResponse.getBytes(), HttpStatus.BAD_REQUEST); - + // @formatter:off + String errorResponse = "{\n" + + " \"error\": \"unauthorized_client\",\n" + + " \"error_description\": \"The client is not authorized\",\n" + + " \"error_uri\": \"https://tools.ietf.org/html/rfc6749#section-5.2\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(errorResponse.getBytes(), HttpStatus.BAD_REQUEST); OAuth2Error oauth2Error = this.messageConverter.readInternal(OAuth2Error.class, response); assertThat(oauth2Error.getErrorCode()).isEqualTo("unauthorized_client"); assertThat(oauth2Error.getDescription()).isEqualTo("The client is not authorized"); @@ -81,16 +82,15 @@ public class OAuth2ErrorHttpMessageConverterTests { // gh-8157 @Test public void readInternalWhenErrorResponseWithObjectThenReadOAuth2Error() throws Exception { - String errorResponse = "{\n" + - " \"error\": \"unauthorized_client\",\n" + - " \"error_description\": \"The client is not authorized\",\n" + - " \"error_codes\": [65001],\n" + - " \"error_uri\": \"https://tools.ietf.org/html/rfc6749#section-5.2\"\n" + - "}\n"; - - MockClientHttpResponse response = new MockClientHttpResponse( - errorResponse.getBytes(), HttpStatus.BAD_REQUEST); - + // @formatter:off + String errorResponse = "{\n" + + " \"error\": \"unauthorized_client\",\n" + + " \"error_description\": \"The client is not authorized\",\n" + + " \"error_codes\": [65001],\n" + + " \"error_uri\": \"https://tools.ietf.org/html/rfc6749#section-5.2\"\n" + + "}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse(errorResponse.getBytes(), HttpStatus.BAD_REQUEST); OAuth2Error oauth2Error = this.messageConverter.readInternal(OAuth2Error.class, response); assertThat(oauth2Error.getErrorCode()).isEqualTo("unauthorized_client"); assertThat(oauth2Error.getDescription()).isEqualTo("The client is not authorized"); @@ -100,28 +100,22 @@ public class OAuth2ErrorHttpMessageConverterTests { @Test public void readInternalWhenConversionFailsThenThrowHttpMessageNotReadableException() { Converter errorConverter = mock(Converter.class); - when(errorConverter.convert(any())).thenThrow(RuntimeException.class); + given(errorConverter.convert(any())).willThrow(RuntimeException.class); this.messageConverter.setErrorConverter(errorConverter); - String errorResponse = "{}"; - - MockClientHttpResponse response = new MockClientHttpResponse( - errorResponse.getBytes(), HttpStatus.BAD_REQUEST); - - assertThatThrownBy(() -> this.messageConverter.readInternal(OAuth2Error.class, response)) - .isInstanceOf(HttpMessageNotReadableException.class) - .hasMessageContaining("An error occurred reading the OAuth 2.0 Error"); + MockClientHttpResponse response = new MockClientHttpResponse(errorResponse.getBytes(), HttpStatus.BAD_REQUEST); + assertThatExceptionOfType(HttpMessageNotReadableException.class) + .isThrownBy(() -> this.messageConverter.readInternal(OAuth2Error.class, response)) + .withMessageContaining("An error occurred reading the OAuth 2.0 Error"); } @Test public void writeInternalWhenOAuth2ErrorThenWriteErrorResponse() throws Exception { - OAuth2Error oauth2Error = new OAuth2Error("unauthorized_client", - "The client is not authorized", "https://tools.ietf.org/html/rfc6749#section-5.2"); - + OAuth2Error oauth2Error = new OAuth2Error("unauthorized_client", "The client is not authorized", + "https://tools.ietf.org/html/rfc6749#section-5.2"); MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); this.messageConverter.writeInternal(oauth2Error, outputMessage); String errorResponse = outputMessage.getBodyAsString(); - assertThat(errorResponse).contains("\"error\":\"unauthorized_client\""); assertThat(errorResponse).contains("\"error_description\":\"The client is not authorized\""); assertThat(errorResponse).contains("\"error_uri\":\"https://tools.ietf.org/html/rfc6749#section-5.2\""); @@ -130,16 +124,14 @@ public class OAuth2ErrorHttpMessageConverterTests { @Test public void writeInternalWhenConversionFailsThenThrowHttpMessageNotWritableException() { Converter errorParametersConverter = mock(Converter.class); - when(errorParametersConverter.convert(any())).thenThrow(RuntimeException.class); + given(errorParametersConverter.convert(any())).willThrow(RuntimeException.class); this.messageConverter.setErrorParametersConverter(errorParametersConverter); - - OAuth2Error oauth2Error = new OAuth2Error("unauthorized_client", - "The client is not authorized", "https://tools.ietf.org/html/rfc6749#section-5.2"); - + OAuth2Error oauth2Error = new OAuth2Error("unauthorized_client", "The client is not authorized", + "https://tools.ietf.org/html/rfc6749#section-5.2"); MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - - assertThatThrownBy(() -> this.messageConverter.writeInternal(oauth2Error, outputMessage)) - .isInstanceOf(HttpMessageNotWritableException.class) - .hasMessageContaining("An error occurred writing the OAuth 2.0 Error"); + assertThatExceptionOfType(HttpMessageNotWritableException.class) + .isThrownBy(() -> this.messageConverter.writeInternal(oauth2Error, outputMessage)) + .withMessageContaining("An error occurred writing the OAuth 2.0 Error"); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaimTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaimTests.java index f22653985e..d85fd854af 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaimTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/DefaultAddressStandardClaimTests.java @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc; -import org.junit.Test; +package org.springframework.security.oauth2.core.oidc; import java.util.HashMap; import java.util.Map; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -28,6 +29,7 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class DefaultAddressStandardClaimTests { + static final String FORMATTED_FIELD_NAME = "formatted"; static final String STREET_ADDRESS_FIELD_NAME = "street_address"; static final String LOCALITY_FIELD_NAME = "locality"; @@ -43,16 +45,16 @@ public class DefaultAddressStandardClaimTests { @Test public void buildWhenAllAttributesProvidedThenAllAttributesAreSet() { - AddressStandardClaim addressStandardClaim = - new DefaultAddressStandardClaim.Builder() - .formatted(FORMATTED) - .streetAddress(STREET_ADDRESS) - .locality(LOCALITY) - .region(REGION) - .postalCode(POSTAL_CODE) - .country(COUNTRY) - .build(); - + // @formatter:off + AddressStandardClaim addressStandardClaim = new DefaultAddressStandardClaim.Builder() + .formatted(FORMATTED) + .streetAddress(STREET_ADDRESS) + .locality(LOCALITY) + .region(REGION) + .postalCode(POSTAL_CODE) + .country(COUNTRY) + .build(); + // @formatter:on assertThat(addressStandardClaim.getFormatted()).isEqualTo(FORMATTED); assertThat(addressStandardClaim.getStreetAddress()).isEqualTo(STREET_ADDRESS); assertThat(addressStandardClaim.getLocality()).isEqualTo(LOCALITY); @@ -70,11 +72,7 @@ public class DefaultAddressStandardClaimTests { addressFields.put(REGION_FIELD_NAME, REGION); addressFields.put(POSTAL_CODE_FIELD_NAME, POSTAL_CODE); addressFields.put(COUNTRY_FIELD_NAME, COUNTRY); - - AddressStandardClaim addressStandardClaim = - new DefaultAddressStandardClaim.Builder(addressFields) - .build(); - + AddressStandardClaim addressStandardClaim = new DefaultAddressStandardClaim.Builder(addressFields).build(); assertThat(addressStandardClaim.getFormatted()).isEqualTo(FORMATTED); assertThat(addressStandardClaim.getStreetAddress()).isEqualTo(STREET_ADDRESS); assertThat(addressStandardClaim.getLocality()).isEqualTo(LOCALITY); @@ -82,4 +80,5 @@ public class DefaultAddressStandardClaimTests { assertThat(addressStandardClaim.getPostalCode()).isEqualTo(POSTAL_CODE); assertThat(addressStandardClaim.getCountry()).isEqualTo(COUNTRY); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenBuilderTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenBuilderTests.java index c3b049eb06..56d8f7e48c 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenBuilderTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenBuilderTests.java @@ -21,34 +21,28 @@ import java.time.Instant; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.EXP; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.IAT; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.SUB; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link OidcUserInfo} */ public class OidcIdTokenBuilderTests { + @Test public void buildWhenCalledTwiceThenGeneratesTwoOidcIdTokens() { OidcIdToken.Builder idTokenBuilder = OidcIdToken.withTokenValue("token"); - - OidcIdToken first = idTokenBuilder - .tokenValue("V1") + // @formatter:off + OidcIdToken first = idTokenBuilder.tokenValue("V1") .claim("TEST_CLAIM_1", "C1") .build(); - - OidcIdToken second = idTokenBuilder - .tokenValue("V2") + OidcIdToken second = idTokenBuilder.tokenValue("V2") .claim("TEST_CLAIM_1", "C2") .claim("TEST_CLAIM_2", "C3") .build(); - + // @formatter:on assertThat(first.getClaims()).hasSize(1); assertThat(first.getClaims().get("TEST_CLAIM_1")).isEqualTo("C1"); assertThat(first.getTokenValue()).isEqualTo("V1"); - assertThat(second.getClaims()).hasSize(2); assertThat(second.getClaims().get("TEST_CLAIM_1")).isEqualTo("C2"); assertThat(second.getClaims().get("TEST_CLAIM_2")).isEqualTo("C3"); @@ -58,82 +52,67 @@ public class OidcIdTokenBuilderTests { @Test public void expiresAtWhenUsingGenericOrNamedClaimMethodRequiresInstant() { OidcIdToken.Builder idTokenBuilder = OidcIdToken.withTokenValue("token"); - Instant now = Instant.now(); - - OidcIdToken idToken = idTokenBuilder - .expiresAt(now).build(); + OidcIdToken idToken = idTokenBuilder.expiresAt(now).build(); assertThat(idToken.getExpiresAt()).isSameAs(now); - - idToken = idTokenBuilder - .expiresAt(now).build(); + idToken = idTokenBuilder.expiresAt(now).build(); assertThat(idToken.getExpiresAt()).isSameAs(now); - - assertThatCode(() -> idTokenBuilder - .claim(EXP, "not an instant").build()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> idTokenBuilder.claim(IdTokenClaimNames.EXP, "not an instant").build()); } @Test public void issuedAtWhenUsingGenericOrNamedClaimMethodRequiresInstant() { OidcIdToken.Builder idTokenBuilder = OidcIdToken.withTokenValue("token"); - Instant now = Instant.now(); - - OidcIdToken idToken = idTokenBuilder - .issuedAt(now).build(); + OidcIdToken idToken = idTokenBuilder.issuedAt(now).build(); assertThat(idToken.getIssuedAt()).isSameAs(now); - - idToken = idTokenBuilder - .issuedAt(now).build(); + idToken = idTokenBuilder.issuedAt(now).build(); assertThat(idToken.getIssuedAt()).isSameAs(now); - - assertThatCode(() -> idTokenBuilder - .claim(IAT, "not an instant").build()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> idTokenBuilder.claim(IdTokenClaimNames.IAT, "not an instant").build()); } @Test public void subjectWhenUsingGenericOrNamedClaimMethodThenLastOneWins() { OidcIdToken.Builder idTokenBuilder = OidcIdToken.withTokenValue("token"); - String generic = new String("sub"); String named = new String("sub"); - + // @formatter:off OidcIdToken idToken = idTokenBuilder .subject(named) - .claim(SUB, generic).build(); + .claim(IdTokenClaimNames.SUB, generic) + .build(); + // @formatter:on assertThat(idToken.getSubject()).isSameAs(generic); - - idToken = idTokenBuilder - .claim(SUB, generic) - .subject(named).build(); + idToken = idTokenBuilder.claim(IdTokenClaimNames.SUB, generic).subject(named).build(); assertThat(idToken.getSubject()).isSameAs(named); } @Test public void claimsWhenRemovingAClaimThenIsNotPresent() { + // @formatter:off OidcIdToken.Builder idTokenBuilder = OidcIdToken.withTokenValue("token") .claim("needs", "a claim"); - - OidcIdToken idToken = idTokenBuilder - .subject("sub") - .claims(claims -> claims.remove(SUB)) + OidcIdToken idToken = idTokenBuilder.subject("sub") + .claims((claims) -> claims.remove(IdTokenClaimNames.SUB)) .build(); + // @formatter:on assertThat(idToken.getSubject()).isNull(); } @Test public void claimsWhenAddingAClaimThenIsPresent() { OidcIdToken.Builder idTokenBuilder = OidcIdToken.withTokenValue("token"); - String name = new String("name"); String value = new String("value"); + // @formatter:off OidcIdToken idToken = idTokenBuilder - .claims(claims -> claims.put(name, value)) + .claims((claims) -> claims.put(name, value)) .build(); - + // @formatter:on assertThat(idToken.getClaims()).hasSize(1); assertThat(idToken.getClaims().get(name)).isSameAs(value); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenTests.java index 15a04e6b0a..4f795c6f95 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc; -import org.junit.Test; +package org.springframework.security.oauth2.core.oidc; import java.time.Instant; import java.util.Arrays; @@ -24,6 +23,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -32,35 +33,58 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class OidcIdTokenTests { + private static final String ISS_CLAIM = "iss"; + private static final String SUB_CLAIM = "sub"; + private static final String AUD_CLAIM = "aud"; + private static final String IAT_CLAIM = "iat"; + private static final String EXP_CLAIM = "exp"; + private static final String AUTH_TIME_CLAIM = "auth_time"; + private static final String NONCE_CLAIM = "nonce"; + private static final String ACR_CLAIM = "acr"; + private static final String AMR_CLAIM = "amr"; + private static final String AZP_CLAIM = "azp"; + private static final String AT_HASH_CLAIM = "at_hash"; + private static final String C_HASH_CLAIM = "c_hash"; private static final String ISS_VALUE = "https://provider.com"; + private static final String SUB_VALUE = "subject1"; + private static final List AUD_VALUE = Arrays.asList("aud1", "aud2"); + private static final long IAT_VALUE = Instant.now().toEpochMilli(); + private static final long EXP_VALUE = Instant.now().plusSeconds(60).toEpochMilli(); + private static final long AUTH_TIME_VALUE = Instant.now().minusSeconds(5).toEpochMilli(); + private static final String NONCE_VALUE = "nonce"; + private static final String ACR_VALUE = "acr"; + private static final List AMR_VALUE = Arrays.asList("amr1", "amr2"); + private static final String AZP_VALUE = "azp"; + private static final String AT_HASH_VALUE = "at_hash"; + private static final String C_HASH_VALUE = "c_hash"; private static final Map CLAIMS; - private static final String ID_TOKEN_VALUE = "id-token-value"; + private static final String ID_TOKEN_VALUE = "id-token-value"; static { CLAIMS = new HashMap<>(); CLAIMS.put(ISS_CLAIM, ISS_VALUE); @@ -84,15 +108,14 @@ public class OidcIdTokenTests { @Test(expected = IllegalArgumentException.class) public void constructorWhenClaimsIsEmptyThenThrowIllegalArgumentException() { - new OidcIdToken(ID_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), - Instant.ofEpochMilli(EXP_VALUE), Collections.emptyMap()); + new OidcIdToken(ID_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), Instant.ofEpochMilli(EXP_VALUE), + Collections.emptyMap()); } @Test public void constructorWhenParametersProvidedAndValidThenCreated() { OidcIdToken idToken = new OidcIdToken(ID_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), - Instant.ofEpochMilli(EXP_VALUE), CLAIMS); - + Instant.ofEpochMilli(EXP_VALUE), CLAIMS); assertThat(idToken.getClaims()).isEqualTo(CLAIMS); assertThat(idToken.getTokenValue()).isEqualTo(ID_TOKEN_VALUE); assertThat(idToken.getIssuer().toString()).isEqualTo(ISS_VALUE); @@ -108,4 +131,5 @@ public class OidcIdTokenTests { assertThat(idToken.getAccessTokenHash()).isEqualTo(AT_HASH_VALUE); assertThat(idToken.getAuthorizationCodeHash()).isEqualTo(C_HASH_VALUE); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoBuilderTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoBuilderTests.java index 9b1c057016..b7e02d0ed8 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoBuilderTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoBuilderTests.java @@ -19,28 +19,26 @@ package org.springframework.security.oauth2.core.oidc; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.core.oidc.IdTokenClaimNames.SUB; /** * Tests for {@link OidcUserInfo} */ public class OidcUserInfoBuilderTests { + @Test public void buildWhenCalledTwiceThenGeneratesTwoOidcUserInfos() { OidcUserInfo.Builder userInfoBuilder = OidcUserInfo.builder(); - + // @formatter:off OidcUserInfo first = userInfoBuilder .claim("TEST_CLAIM_1", "C1") .build(); - OidcUserInfo second = userInfoBuilder .claim("TEST_CLAIM_1", "C2") .claim("TEST_CLAIM_2", "C3") .build(); - + // @formatter:on assertThat(first.getClaims()).hasSize(1); assertThat(first.getClaims().get("TEST_CLAIM_1")).isEqualTo("C1"); - assertThat(second.getClaims()).hasSize(2); assertThat(second.getClaims().get("TEST_CLAIM_1")).isEqualTo("C2"); assertThat(second.getClaims().get("TEST_CLAIM_2")).isEqualTo("C3"); @@ -49,44 +47,48 @@ public class OidcUserInfoBuilderTests { @Test public void subjectWhenUsingGenericOrNamedClaimMethodThenLastOneWins() { OidcUserInfo.Builder userInfoBuilder = OidcUserInfo.builder(); - String generic = new String("sub"); String named = new String("sub"); - + // @formatter:off OidcUserInfo userInfo = userInfoBuilder .subject(named) - .claim(SUB, generic).build(); + .claim(IdTokenClaimNames.SUB, generic) + .build(); + // @formatter:on assertThat(userInfo.getSubject()).isSameAs(generic); - + // @formatter:off userInfo = userInfoBuilder - .claim(SUB, generic) - .subject(named).build(); + .claim(IdTokenClaimNames.SUB, generic) + .subject(named) + .build(); + // @formatter:on assertThat(userInfo.getSubject()).isSameAs(named); } @Test public void claimsWhenRemovingAClaimThenIsNotPresent() { + // @formatter:off OidcUserInfo.Builder userInfoBuilder = OidcUserInfo.builder() .claim("needs", "a claim"); - - OidcUserInfo userInfo = userInfoBuilder - .subject("sub") - .claims(claims -> claims.remove(SUB)) + OidcUserInfo userInfo = userInfoBuilder.subject("sub") + .claims((claims) -> claims.remove(IdTokenClaimNames.SUB)) .build(); + // @formatter:on assertThat(userInfo.getSubject()).isNull(); } @Test public void claimsWhenAddingAClaimThenIsPresent() { OidcUserInfo.Builder userInfoBuilder = OidcUserInfo.builder(); - String name = new String("name"); String value = new String("value"); + // @formatter:off OidcUserInfo userInfo = userInfoBuilder - .claims(claims -> claims.put(name, value)) + .claims((claims) -> claims.put(name, value)) .build(); - + // @formatter:on assertThat(userInfo.getClaims()).hasSize(1); assertThat(userInfo.getClaims().get(name)).isSameAs(value); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoTests.java index c94e798d84..53fe17d28d 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcUserInfoTests.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc; -import org.junit.Test; +package org.springframework.security.oauth2.core.oidc; import java.time.Instant; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.core.oidc.DefaultAddressStandardClaimTests.*; /** * Tests for {@link OidcUserInfo}. @@ -31,50 +31,88 @@ import static org.springframework.security.oauth2.core.oidc.DefaultAddressStanda * @author Joe Grandja */ public class OidcUserInfoTests { + private static final String SUB_CLAIM = "sub"; + private static final String NAME_CLAIM = "name"; + private static final String GIVEN_NAME_CLAIM = "given_name"; + private static final String FAMILY_NAME_CLAIM = "family_name"; + private static final String MIDDLE_NAME_CLAIM = "middle_name"; + private static final String NICKNAME_CLAIM = "nickname"; + private static final String PREFERRED_USERNAME_CLAIM = "preferred_username"; + private static final String PROFILE_CLAIM = "profile"; + private static final String PICTURE_CLAIM = "picture"; + private static final String WEBSITE_CLAIM = "website"; + private static final String EMAIL_CLAIM = "email"; + private static final String EMAIL_VERIFIED_CLAIM = "email_verified"; + private static final String GENDER_CLAIM = "gender"; + private static final String BIRTHDATE_CLAIM = "birthdate"; + private static final String ZONEINFO_CLAIM = "zoneinfo"; + private static final String LOCALE_CLAIM = "locale"; + private static final String PHONE_NUMBER_CLAIM = "phone_number"; + private static final String PHONE_NUMBER_VERIFIED_CLAIM = "phone_number_verified"; + private static final String ADDRESS_CLAIM = "address"; + private static final String UPDATED_AT_CLAIM = "updated_at"; private static final String SUB_VALUE = "subject1"; + private static final String NAME_VALUE = "full_name"; + private static final String GIVEN_NAME_VALUE = "given_name"; + private static final String FAMILY_NAME_VALUE = "family_name"; + private static final String MIDDLE_NAME_VALUE = "middle_name"; + private static final String NICKNAME_VALUE = "nickname"; + private static final String PREFERRED_USERNAME_VALUE = "preferred_username"; + private static final String PROFILE_VALUE = "profile"; + private static final String PICTURE_VALUE = "picture"; + private static final String WEBSITE_VALUE = "website"; + private static final String EMAIL_VALUE = "email"; + private static final Boolean EMAIL_VERIFIED_VALUE = true; + private static final String GENDER_VALUE = "gender"; + private static final String BIRTHDATE_VALUE = "birthdate"; + private static final String ZONEINFO_VALUE = "zoneinfo"; + private static final String LOCALE_VALUE = "locale"; + private static final String PHONE_NUMBER_VALUE = "phone_number"; + private static final Boolean PHONE_NUMBER_VERIFIED_VALUE = true; + private static final Map ADDRESS_VALUE; + private static final long UPDATED_AT_VALUE = Instant.now().minusSeconds(60).toEpochMilli(); private static final Map CLAIMS; - static { CLAIMS = new HashMap<>(); CLAIMS.put(SUB_CLAIM, SUB_VALUE); @@ -95,16 +133,19 @@ public class OidcUserInfoTests { CLAIMS.put(LOCALE_CLAIM, LOCALE_VALUE); CLAIMS.put(PHONE_NUMBER_CLAIM, PHONE_NUMBER_VALUE); CLAIMS.put(PHONE_NUMBER_VERIFIED_CLAIM, PHONE_NUMBER_VERIFIED_VALUE); - ADDRESS_VALUE = new HashMap<>(); - ADDRESS_VALUE.put(FORMATTED_FIELD_NAME, FORMATTED); - ADDRESS_VALUE.put(STREET_ADDRESS_FIELD_NAME, STREET_ADDRESS); - ADDRESS_VALUE.put(LOCALITY_FIELD_NAME, LOCALITY); - ADDRESS_VALUE.put(REGION_FIELD_NAME, REGION); - ADDRESS_VALUE.put(POSTAL_CODE_FIELD_NAME, POSTAL_CODE); - ADDRESS_VALUE.put(COUNTRY_FIELD_NAME, COUNTRY); + ADDRESS_VALUE.put(DefaultAddressStandardClaimTests.FORMATTED_FIELD_NAME, + DefaultAddressStandardClaimTests.FORMATTED); + ADDRESS_VALUE.put(DefaultAddressStandardClaimTests.STREET_ADDRESS_FIELD_NAME, + DefaultAddressStandardClaimTests.STREET_ADDRESS); + ADDRESS_VALUE.put(DefaultAddressStandardClaimTests.LOCALITY_FIELD_NAME, + DefaultAddressStandardClaimTests.LOCALITY); + ADDRESS_VALUE.put(DefaultAddressStandardClaimTests.REGION_FIELD_NAME, DefaultAddressStandardClaimTests.REGION); + ADDRESS_VALUE.put(DefaultAddressStandardClaimTests.POSTAL_CODE_FIELD_NAME, + DefaultAddressStandardClaimTests.POSTAL_CODE); + ADDRESS_VALUE.put(DefaultAddressStandardClaimTests.COUNTRY_FIELD_NAME, + DefaultAddressStandardClaimTests.COUNTRY); CLAIMS.put(ADDRESS_CLAIM, ADDRESS_VALUE); - CLAIMS.put(UPDATED_AT_CLAIM, UPDATED_AT_VALUE); } @@ -116,7 +157,6 @@ public class OidcUserInfoTests { @Test public void constructorWhenParametersProvidedAndValidThenCreated() { OidcUserInfo userInfo = new OidcUserInfo(CLAIMS); - assertThat(userInfo.getClaims()).isEqualTo(CLAIMS); assertThat(userInfo.getSubject()).isEqualTo(SUB_VALUE); assertThat(userInfo.getFullName()).isEqualTo(NAME_VALUE); @@ -139,4 +179,5 @@ public class OidcUserInfoTests { assertThat(userInfo.getAddress()).isEqualTo(new DefaultAddressStandardClaim.Builder(ADDRESS_VALUE).build()); assertThat(userInfo.getUpdatedAt().getEpochSecond()).isEqualTo(UPDATED_AT_VALUE); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java index a866554f54..ca859473d1 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java @@ -18,20 +18,26 @@ package org.springframework.security.oauth2.core.oidc; import java.time.Instant; -import static org.springframework.security.oauth2.core.oidc.OidcIdToken.withTokenValue; - /** * Test {@link OidcIdToken}s * * @author Josh Cummings */ -public class TestOidcIdTokens { +public final class TestOidcIdTokens { + + private TestOidcIdTokens() { + } + public static OidcIdToken.Builder idToken() { - return withTokenValue("id-token") + // @formatter:off + return OidcIdToken.withTokenValue("id-token") .issuer("https://example.com") .subject("subject") .issuedAt(Instant.now()) - .expiresAt(Instant.now().plusSeconds(86400)) + .expiresAt(Instant.now() + .plusSeconds(86400)) .claim("id", "id"); + // @formatter:on } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUserTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUserTests.java index 2fad69684f..78dd8b494f 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUserTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/DefaultOidcUserTests.java @@ -16,7 +16,14 @@ package org.springframework.security.oauth2.core.oidc.user; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; @@ -24,12 +31,6 @@ import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.StandardClaimNames; -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -39,22 +40,29 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class DefaultOidcUserTests { - private static final SimpleGrantedAuthority AUTHORITY = new SimpleGrantedAuthority("ROLE_USER"); - private static final Set AUTHORITIES = Collections.singleton(AUTHORITY); - private static final String SUBJECT = "test-subject"; - private static final String EMAIL = "test-subject@example.com"; - private static final String NAME = "test-name"; - private static final Map ID_TOKEN_CLAIMS = new HashMap<>(); - private static final Map USER_INFO_CLAIMS = new HashMap<>(); + private static final SimpleGrantedAuthority AUTHORITY = new SimpleGrantedAuthority("ROLE_USER"); + + private static final Set AUTHORITIES = Collections.singleton(AUTHORITY); + + private static final String SUBJECT = "test-subject"; + + private static final String EMAIL = "test-subject@example.com"; + + private static final String NAME = "test-name"; + + private static final Map ID_TOKEN_CLAIMS = new HashMap<>(); + + private static final Map USER_INFO_CLAIMS = new HashMap<>(); static { ID_TOKEN_CLAIMS.put(IdTokenClaimNames.ISS, "https://example.com"); ID_TOKEN_CLAIMS.put(IdTokenClaimNames.SUB, SUBJECT); USER_INFO_CLAIMS.put(StandardClaimNames.NAME, NAME); USER_INFO_CLAIMS.put(StandardClaimNames.EMAIL, EMAIL); } + private static final OidcIdToken ID_TOKEN = new OidcIdToken("id-token-value", Instant.EPOCH, Instant.MAX, + ID_TOKEN_CLAIMS); - private static final OidcIdToken ID_TOKEN = new OidcIdToken("id-token-value", Instant.EPOCH, Instant.MAX, ID_TOKEN_CLAIMS); private static final OidcUserInfo USER_INFO = new OidcUserInfo(USER_INFO_CLAIMS); @Test(expected = IllegalArgumentException.class) @@ -75,7 +83,6 @@ public class DefaultOidcUserTests { @Test public void constructorWhenAuthoritiesIdTokenProvidedThenCreated() { DefaultOidcUser user = new DefaultOidcUser(AUTHORITIES, ID_TOKEN); - assertThat(user.getClaims()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB); assertThat(user.getIdToken()).isEqualTo(ID_TOKEN); assertThat(user.getName()).isEqualTo(SUBJECT); @@ -87,7 +94,6 @@ public class DefaultOidcUserTests { @Test public void constructorWhenAuthoritiesIdTokenNameAttributeKeyProvidedThenCreated() { DefaultOidcUser user = new DefaultOidcUser(AUTHORITIES, ID_TOKEN, IdTokenClaimNames.SUB); - assertThat(user.getClaims()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB); assertThat(user.getIdToken()).isEqualTo(ID_TOKEN); assertThat(user.getName()).isEqualTo(SUBJECT); @@ -99,30 +105,29 @@ public class DefaultOidcUserTests { @Test public void constructorWhenAuthoritiesIdTokenUserInfoProvidedThenCreated() { DefaultOidcUser user = new DefaultOidcUser(AUTHORITIES, ID_TOKEN, USER_INFO); - - assertThat(user.getClaims()).containsOnlyKeys( - IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, StandardClaimNames.NAME, StandardClaimNames.EMAIL); + assertThat(user.getClaims()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, + StandardClaimNames.NAME, StandardClaimNames.EMAIL); assertThat(user.getIdToken()).isEqualTo(ID_TOKEN); assertThat(user.getUserInfo()).isEqualTo(USER_INFO); assertThat(user.getName()).isEqualTo(SUBJECT); assertThat(user.getAuthorities()).hasSize(1); assertThat(user.getAuthorities().iterator().next()).isEqualTo(AUTHORITY); - assertThat(user.getAttributes()).containsOnlyKeys( - IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, StandardClaimNames.NAME, StandardClaimNames.EMAIL); + assertThat(user.getAttributes()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, + StandardClaimNames.NAME, StandardClaimNames.EMAIL); } @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { DefaultOidcUser user = new DefaultOidcUser(AUTHORITIES, ID_TOKEN, USER_INFO, StandardClaimNames.EMAIL); - - assertThat(user.getClaims()).containsOnlyKeys( - IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, StandardClaimNames.NAME, StandardClaimNames.EMAIL); + assertThat(user.getClaims()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, + StandardClaimNames.NAME, StandardClaimNames.EMAIL); assertThat(user.getIdToken()).isEqualTo(ID_TOKEN); assertThat(user.getUserInfo()).isEqualTo(USER_INFO); assertThat(user.getName()).isEqualTo(EMAIL); assertThat(user.getAuthorities()).hasSize(1); assertThat(user.getAuthorities().iterator().next()).isEqualTo(AUTHORITY); - assertThat(user.getAttributes()).containsOnlyKeys( - IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, StandardClaimNames.NAME, StandardClaimNames.EMAIL); + assertThat(user.getAttributes()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, + StandardClaimNames.NAME, StandardClaimNames.EMAIL); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthorityTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthorityTests.java index d859b74d56..f139e6c167 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthorityTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/OidcUserAuthorityTests.java @@ -13,20 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc.user; -import org.junit.Test; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import org.springframework.security.oauth2.core.oidc.OidcUserInfo; -import org.springframework.security.oauth2.core.oidc.StandardClaimNames; +package org.springframework.security.oauth2.core.oidc.user; import java.time.Instant; import java.util.HashMap; import java.util.Map; +import org.junit.Test; + +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; /** * Tests for {@link OidcUserAuthority}. @@ -34,21 +35,27 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Joe Grandja */ public class OidcUserAuthorityTests { - private static final String AUTHORITY = "ROLE_USER"; - private static final String SUBJECT = "test-subject"; - private static final String EMAIL = "test-subject@example.com"; - private static final String NAME = "test-name"; - private static final Map ID_TOKEN_CLAIMS = new HashMap<>(); - private static final Map USER_INFO_CLAIMS = new HashMap<>(); + private static final String AUTHORITY = "ROLE_USER"; + + private static final String SUBJECT = "test-subject"; + + private static final String EMAIL = "test-subject@example.com"; + + private static final String NAME = "test-name"; + + private static final Map ID_TOKEN_CLAIMS = new HashMap<>(); + + private static final Map USER_INFO_CLAIMS = new HashMap<>(); static { ID_TOKEN_CLAIMS.put(IdTokenClaimNames.ISS, "https://example.com"); ID_TOKEN_CLAIMS.put(IdTokenClaimNames.SUB, SUBJECT); USER_INFO_CLAIMS.put(StandardClaimNames.NAME, NAME); USER_INFO_CLAIMS.put(StandardClaimNames.EMAIL, EMAIL); } + private static final OidcIdToken ID_TOKEN = new OidcIdToken("id-token-value", Instant.EPOCH, Instant.MAX, + ID_TOKEN_CLAIMS); - private static final OidcIdToken ID_TOKEN = new OidcIdToken("id-token-value", Instant.EPOCH, Instant.MAX, ID_TOKEN_CLAIMS); private static final OidcUserInfo USER_INFO = new OidcUserInfo(USER_INFO_CLAIMS); @Test(expected = IllegalArgumentException.class) @@ -58,7 +65,7 @@ public class OidcUserAuthorityTests { @Test public void constructorWhenUserInfoIsNullThenDoesNotThrowAnyException() { - assertThatCode(() -> new OidcUserAuthority(ID_TOKEN, null)).doesNotThrowAnyException(); + new OidcUserAuthority(ID_TOKEN, null); } @Test(expected = IllegalArgumentException.class) @@ -69,11 +76,11 @@ public class OidcUserAuthorityTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { OidcUserAuthority userAuthority = new OidcUserAuthority(AUTHORITY, ID_TOKEN, USER_INFO); - assertThat(userAuthority.getIdToken()).isEqualTo(ID_TOKEN); assertThat(userAuthority.getUserInfo()).isEqualTo(USER_INFO); assertThat(userAuthority.getAuthority()).isEqualTo(AUTHORITY); - assertThat(userAuthority.getAttributes()).containsOnlyKeys( - IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, StandardClaimNames.NAME, StandardClaimNames.EMAIL); + assertThat(userAuthority.getAttributes()).containsOnlyKeys(IdTokenClaimNames.ISS, IdTokenClaimNames.SUB, + StandardClaimNames.NAME, StandardClaimNames.EMAIL); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java index d53cbf49f4..3bda7ec32d 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.oidc.user; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +package org.springframework.security.oauth2.core.oidc.user; import java.time.Instant; import java.util.Arrays; @@ -26,21 +22,29 @@ import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; + /** * @author Joe Grandja */ -public class TestOidcUsers { +public final class TestOidcUsers { + + private TestOidcUsers() { + } public static DefaultOidcUser create() { OidcIdToken idToken = idToken(); OidcUserInfo userInfo = userInfo(); - return new DefaultOidcUser( - authorities(idToken, userInfo), idToken, userInfo); + return new DefaultOidcUser(authorities(idToken, userInfo), idToken, userInfo); } private static OidcIdToken idToken() { Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plusSeconds(3600); + // @formatter:off return OidcIdToken.withTokenValue("id-token") .issuedAt(issuedAt) .expiresAt(expiresAt) @@ -49,20 +53,16 @@ public class TestOidcUsers { .audience(Collections.unmodifiableSet(new LinkedHashSet<>(Collections.singletonList("client")))) .authorizedParty("client") .build(); + // @formatter:on } private static OidcUserInfo userInfo() { - return OidcUserInfo.builder() - .subject("subject") - .name("full name") - .build(); + return OidcUserInfo.builder().subject("subject").name("full name").build(); } private static Collection authorities(OidcIdToken idToken, OidcUserInfo userInfo) { - return new LinkedHashSet<>( - Arrays.asList( - new OidcUserAuthority(idToken, userInfo), - new SimpleGrantedAuthority("SCOPE_read"), - new SimpleGrantedAuthority("SCOPE_write"))); + return new LinkedHashSet<>(Arrays.asList(new OidcUserAuthority(idToken, userInfo), + new SimpleGrantedAuthority("SCOPE_read"), new SimpleGrantedAuthority("SCOPE_write"))); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/DefaultOAuth2UserTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/DefaultOAuth2UserTests.java index 3642b54abb..c1643e86d5 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/DefaultOAuth2UserTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/DefaultOAuth2UserTests.java @@ -16,15 +16,16 @@ package org.springframework.security.oauth2.core.user; -import org.junit.Test; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; -import org.springframework.util.SerializationUtils; - import java.util.Collections; import java.util.Map; import java.util.Set; +import org.junit.Test; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.util.SerializationUtils; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -34,12 +35,16 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class DefaultOAuth2UserTests { + private static final SimpleGrantedAuthority AUTHORITY = new SimpleGrantedAuthority("ROLE_USER"); + private static final Set AUTHORITIES = Collections.singleton(AUTHORITY); + private static final String ATTRIBUTE_NAME_KEY = "username"; + private static final String USERNAME = "test"; - private static final Map ATTRIBUTES = Collections.singletonMap( - ATTRIBUTE_NAME_KEY, USERNAME); + + private static final Map ATTRIBUTES = Collections.singletonMap(ATTRIBUTE_NAME_KEY, USERNAME); @Test(expected = IllegalArgumentException.class) public void constructorWhenAuthoritiesIsNullThenThrowIllegalArgumentException() { @@ -74,7 +79,6 @@ public class DefaultOAuth2UserTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { DefaultOAuth2User user = new DefaultOAuth2User(AUTHORITIES, ATTRIBUTES, ATTRIBUTE_NAME_KEY); - assertThat(user.getName()).isEqualTo(USERNAME); assertThat(user.getAuthorities()).hasSize(1); assertThat(user.getAuthorities().iterator().next()).isEqualTo(AUTHORITY); @@ -87,4 +91,5 @@ public class DefaultOAuth2UserTests { DefaultOAuth2User user = new DefaultOAuth2User(AUTHORITIES, ATTRIBUTES, ATTRIBUTE_NAME_KEY); SerializationUtils.serialize(user); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthorityTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthorityTests.java index ef37d2ab01..b7b22d5541 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthorityTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/OAuth2UserAuthorityTests.java @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.core.user; -import org.junit.Test; +package org.springframework.security.oauth2.core.user; import java.util.Collections; import java.util.Map; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -28,7 +29,9 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class OAuth2UserAuthorityTests { + private static final String AUTHORITY = "ROLE_USER"; + private static final Map ATTRIBUTES = Collections.singletonMap("username", "test"); @Test(expected = IllegalArgumentException.class) @@ -49,8 +52,8 @@ public class OAuth2UserAuthorityTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { OAuth2UserAuthority userAuthority = new OAuth2UserAuthority(AUTHORITY, ATTRIBUTES); - assertThat(userAuthority.getAuthority()).isEqualTo(AUTHORITY); assertThat(userAuthority.getAttributes()).isEqualTo(ATTRIBUTES); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/TestOAuth2Users.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/TestOAuth2Users.java index 456f6376aa..8c50d677b5 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/TestOAuth2Users.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/TestOAuth2Users.java @@ -16,19 +16,22 @@ package org.springframework.security.oauth2.core.user; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; - import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; + /** * @author Rob Winch */ -public class TestOAuth2Users { +public final class TestOAuth2Users { + + private TestOAuth2Users() { + } public static DefaultOAuth2User create() { String nameAttributeKey = "username"; @@ -39,10 +42,8 @@ public class TestOAuth2Users { } private static Collection authorities(Map attributes) { - return new LinkedHashSet<>( - Arrays.asList( - new OAuth2UserAuthority(attributes), - new SimpleGrantedAuthority("SCOPE_read"), - new SimpleGrantedAuthority("SCOPE_write"))); + return new LinkedHashSet<>(Arrays.asList(new OAuth2UserAuthority(attributes), + new SimpleGrantedAuthority("SCOPE_read"), new SimpleGrantedAuthority("SCOPE_write"))); } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java index 48f7cd19da..baf82a016b 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java @@ -16,8 +16,17 @@ package org.springframework.security.oauth2.core.web.reactive.function; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.core.codec.ByteBufferDecoder; import org.springframework.core.codec.StringDecoder; import org.springframework.http.HttpStatus; @@ -33,17 +42,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.web.reactive.function.BodyExtractor; -import reactor.core.publisher.Mono; - -import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rob Winch @@ -62,7 +63,6 @@ public class OAuth2BodyExtractorsTests { messageReaders.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); messageReaders.add(new DecoderHttpMessageReader<>(new Jackson2JsonDecoder())); messageReaders.add(new FormHttpMessageReader()); - this.hints = new HashMap<>(); this.context = new BodyExtractor.Context() { @Override @@ -86,50 +86,48 @@ public class OAuth2BodyExtractorsTests { public void oauth2AccessTokenResponseWhenInvalidJsonThenException() { BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors .oauth2AccessTokenResponse(); - MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); response.getHeaders().setContentType(MediaType.APPLICATION_JSON); response.setBody("{"); - Mono result = extractor.extract(response, this.context); - - assertThatCode(result::block) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("An error occurred parsing the Access Token response"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(result::block) + .withMessageContaining("An error occurred parsing the Access Token response"); + // @formatter:on } @Test public void oauth2AccessTokenResponseWhenEmptyThenException() { BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors .oauth2AccessTokenResponse(); - MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); - Mono result = extractor.extract(response, this.context); - - assertThatCode(result::block) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("Empty OAuth 2.0 Access Token Response"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(result::block) + .withMessageContaining("Empty OAuth 2.0 Access Token Response"); + // @formatter:on } @Test public void oauth2AccessTokenResponseWhenValidThenCreated() { BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors .oauth2AccessTokenResponse(); - MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); response.getHeaders().setContentType(MediaType.APPLICATION_JSON); - response.setBody("{\n" + // @formatter:off + response.setBody( + "{\n" + " \"access_token\":\"2YotnFZFEjr1zCsicMWpAA\",\n" + " \"token_type\":\"Bearer\",\n" + " \"expires_in\":3600,\n" + " \"refresh_token\":\"tGzv3JOkF0XG5Qx2TlKWIA\",\n" + " \"example_parameter\":\"example_value\"\n" + " }"); - + // @formatter:on Instant now = Instant.now(); OAuth2AccessTokenResponse result = extractor.extract(response, this.context).block(); - assertThat(result.getAccessToken().getTokenValue()).isEqualTo("2YotnFZFEjr1zCsicMWpAA"); assertThat(result.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(result.getAccessToken().getExpiresAt()).isBetween(now.plusSeconds(3600), now.plusSeconds(3600 + 2)); @@ -137,27 +135,26 @@ public class OAuth2BodyExtractorsTests { assertThat(result.getAdditionalParameters()).containsEntry("example_parameter", "example_value"); } - @Test // gh-6087 public void oauth2AccessTokenResponseWhenMultipleAttributeTypesThenCreated() { BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors .oauth2AccessTokenResponse(); - MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); response.getHeaders().setContentType(MediaType.APPLICATION_JSON); - response.setBody("{\n" - + " \"access_token\":\"2YotnFZFEjr1zCsicMWpAA\",\n" - + " \"token_type\":\"Bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"tGzv3JOkF0XG5Qx2TlKWIA\",\n" - + " \"subjson\":{}, \n" - + " \"list\":[] \n" - + " }"); - + // @formatter:off + response.setBody( + "{\n" + + " \"access_token\":\"2YotnFZFEjr1zCsicMWpAA\",\n" + + " \"token_type\":\"Bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"tGzv3JOkF0XG5Qx2TlKWIA\",\n" + + " \"subjson\":{}, \n" + + " \"list\":[] \n" + + " }"); + // @formatter:on Instant now = Instant.now(); OAuth2AccessTokenResponse result = extractor.extract(response, this.context).block(); - assertThat(result.getAccessToken().getTokenValue()).isEqualTo("2YotnFZFEjr1zCsicMWpAA"); assertThat(result.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(result.getAccessToken().getExpiresAt()).isBetween(now.plusSeconds(3600), now.plusSeconds(3600 + 2)); @@ -165,4 +162,5 @@ public class OAuth2BodyExtractorsTests { assertThat(result.getAdditionalParameters().get("subjson")).isInstanceOfAny(Map.class); assertThat(result.getAdditionalParameters().get("list")).isInstanceOfAny(List.class); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithm.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithm.java index d815208940..04d58314a2 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithm.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithm.java @@ -13,18 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose.jws; /** - * Super interface for cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification - * and used by JSON Web Signature (JWS) to digitally sign or create a MAC - * of the contents of the JWS Protected Header and JWS Payload. + * Super interface for cryptographic algorithms defined by the JSON Web Algorithms (JWA) + * specification and used by JSON Web Signature (JWS) to digitally sign or create a MAC of + * the contents of the JWS Protected Header and JWS Payload. * * @author Joe Grandja * @since 5.2 - * @see JSON Web Algorithms (JWA) - * @see JSON Web Signature (JWS) - * @see Cryptographic Algorithms for Digital Signatures and MACs + * @see JSON Web Algorithms + * (JWA) + * @see JSON Web Signature + * (JWS) + * @see Cryptographic Algorithms for Digital + * Signatures and MACs */ public interface JwsAlgorithm { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithms.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithms.java index b3ef013e6b..c8b159592c 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithms.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/JwsAlgorithms.java @@ -13,18 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose.jws; /** - * The cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification - * and used by JSON Web Signature (JWS) to digitally sign or create a MAC - * of the contents of the JWS Protected Header and JWS Payload. + * The cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification and + * used by JSON Web Signature (JWS) to digitally sign or create a MAC of the contents of + * the JWS Protected Header and JWS Payload. * * @author Joe Grandja * @since 5.0 - * @see JSON Web Algorithms (JWA) - * @see JSON Web Signature (JWS) - * @see Cryptographic Algorithms for Digital Signatures and MACs + * @see JSON Web Algorithms + * (JWA) + * @see JSON Web Signature + * (JWS) + * @see Cryptographic Algorithms for Digital + * Signatures and MACs */ public interface JwsAlgorithms { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/MacAlgorithm.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/MacAlgorithm.java index 2f525b9a5c..110d21896f 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/MacAlgorithm.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/MacAlgorithm.java @@ -13,19 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose.jws; - /** - * An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification - * and used by JSON Web Signature (JWS) to create a MAC of the contents of the JWS Protected Header and JWS Payload. + * An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) + * specification and used by JSON Web Signature (JWS) to create a MAC of the contents of + * the JWS Protected Header and JWS Payload. * * @author Joe Grandja * @since 5.2 * @see JwsAlgorithm - * @see JSON Web Algorithms (JWA) - * @see JSON Web Signature (JWS) - * @see Cryptographic Algorithms for Digital Signatures and MACs + * @see JSON Web Algorithms + * (JWA) + * @see JSON Web Signature + * (JWS) + * @see Cryptographic Algorithms for Digital + * Signatures and MACs */ public enum MacAlgorithm implements JwsAlgorithm { @@ -44,16 +49,23 @@ public enum MacAlgorithm implements JwsAlgorithm { */ HS512(JwsAlgorithms.HS512); - private final String name; MacAlgorithm(String name) { this.name = name; } + /** + * Returns the algorithm name. + * @return the algorithm name + */ + @Override + public String getName() { + return this.name; + } + /** * Attempt to resolve the provided algorithm name to a {@code MacAlgorithm}. - * * @param name the algorithm name * @return the resolved {@code MacAlgorithm}, or {@code null} if not found */ @@ -66,13 +78,4 @@ public enum MacAlgorithm implements JwsAlgorithm { return null; } - /** - * Returns the algorithm name. - * - * @return the algorithm name - */ - @Override - public String getName() { - return this.name; - } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithm.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithm.java index 8ea0d884b7..f27ea53f16 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithm.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithm.java @@ -13,18 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose.jws; /** - * An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) specification - * and used by JSON Web Signature (JWS) to digitally sign the contents of the JWS Protected Header and JWS Payload. + * An enumeration of the cryptographic algorithms defined by the JSON Web Algorithms (JWA) + * specification and used by JSON Web Signature (JWS) to digitally sign the contents of + * the JWS Protected Header and JWS Payload. * * @author Joe Grandja * @since 5.2 * @see JwsAlgorithm - * @see JSON Web Algorithms (JWA) - * @see JSON Web Signature (JWS) - * @see Cryptographic Algorithms for Digital Signatures and MACs + * @see JSON Web Algorithms + * (JWA) + * @see JSON Web Signature + * (JWS) + * @see Cryptographic Algorithms for Digital + * Signatures and MACs */ public enum SignatureAlgorithm implements JwsAlgorithm { @@ -73,16 +79,23 @@ public enum SignatureAlgorithm implements JwsAlgorithm { */ PS512(JwsAlgorithms.PS512); - private final String name; SignatureAlgorithm(String name) { this.name = name; } + /** + * Returns the algorithm name. + * @return the algorithm name + */ + @Override + public String getName() { + return this.name; + } + /** * Attempt to resolve the provided algorithm name to a {@code SignatureAlgorithm}. - * * @param name the algorithm name * @return the resolved {@code SignatureAlgorithm}, or {@code null} if not found */ @@ -95,13 +108,4 @@ public enum SignatureAlgorithm implements JwsAlgorithm { return null; } - /** - * Returns the algorithm name. - * - * @return the algorithm name - */ - @Override - public String getName() { - return this.name; - } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/package-info.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/package-info.java index 7801235a99..f3f95405a4 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/package-info.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jose/jws/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Core classes and interfaces providing support for JSON Web Signature (JWS). */ diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java index 11aa05b9f9..3a30545179 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java @@ -17,13 +17,15 @@ package org.springframework.security.oauth2.jwt; /** - * An exception similar to {@link org.springframework.security.authentication.BadCredentialsException} - * that indicates a {@link Jwt} that is invalid in some way. + * An exception similar to + * {@link org.springframework.security.authentication.BadCredentialsException} that + * indicates a {@link Jwt} that is invalid in some way. * * @author Josh Cummings * @since 5.3 */ public class BadJwtException extends JwtException { + public BadJwtException(String message) { super(message); } @@ -31,4 +33,5 @@ public class BadJwtException extends JwtException { public BadJwtException(String message, Throwable cause) { super(message, cause); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/Jwt.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/Jwt.java index e5e09f0117..829f731249 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/Jwt.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/Jwt.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.time.Instant; @@ -25,38 +26,35 @@ import java.util.function.Consumer; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.AUD; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.EXP; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.IAT; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.JTI; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.NBF; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; - /** - * An implementation of an {@link AbstractOAuth2Token} representing a JSON Web Token (JWT). + * An implementation of an {@link AbstractOAuth2Token} representing a JSON Web Token + * (JWT). * *

        * JWTs represent a set of "claims" as a JSON object that may be encoded in a - * JSON Web Signature (JWS) and/or JSON Web Encryption (JWE) structure. - * The JSON object, also known as the JWT Claims Set, consists of one or more claim name/value pairs. - * The claim name is a {@code String} and the claim value is an arbitrary JSON object. + * JSON Web Signature (JWS) and/or JSON Web Encryption (JWE) structure. The JSON object, + * also known as the JWT Claims Set, consists of one or more claim name/value pairs. The + * claim name is a {@code String} and the claim value is an arbitrary JSON object. * * @author Joe Grandja * @since 5.0 * @see AbstractOAuth2Token * @see JwtClaimAccessor - * @see JSON Web Token (JWT) - * @see JSON Web Signature (JWS) - * @see JSON Web Encryption (JWE) + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Encryption + * (JWE) */ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { + private final Map headers; + private final Map claims; /** * Constructs a {@code Jwt} using the provided parameters. - * * @param tokenValue the token value * @param issuedAt the time at which the JWT was issued * @param expiresAt the expiration time on or after which the JWT MUST NOT be accepted @@ -64,8 +62,8 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { * @param claims the JWT Claims Set * */ - public Jwt(String tokenValue, Instant issuedAt, Instant expiresAt, - Map headers, Map claims) { + public Jwt(String tokenValue, Instant issuedAt, Instant expiresAt, Map headers, + Map claims) { super(tokenValue, issuedAt, expiresAt); Assert.notEmpty(headers, "headers cannot be empty"); Assert.notEmpty(claims, "claims cannot be empty"); @@ -75,7 +73,6 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Returns the JOSE header(s). - * * @return a {@code Map} of the JOSE header(s) */ public Map getHeaders() { @@ -84,7 +81,6 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Returns the JWT Claims Set. - * * @return a {@code Map} of the JWT Claims Set */ @Override @@ -94,7 +90,6 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Return a {@link Jwt.Builder} - * * @return A {@link Jwt.Builder} */ public static Builder withTokenValue(String tokenValue) { @@ -108,9 +103,12 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { * @author Josh Cummings * @since 5.2 */ - public final static class Builder { + public static final class Builder { + private String tokenValue; + private final Map claims = new LinkedHashMap<>(); + private final Map headers = new LinkedHashMap<>(); private Builder(String tokenValue) { @@ -119,7 +117,6 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Use this token value in the resulting {@link Jwt} - * * @param tokenValue The token value to use * @return the {@link Builder} for further configurations */ @@ -130,7 +127,6 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Use this claim in the resulting {@link Jwt} - * * @param name The claim name * @param value The claim value * @return the {@link Builder} for further configurations @@ -141,8 +137,8 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { } /** - * Provides access to every {@link #claim(String, Object)} - * declared so far with the possibility to add, replace, or remove. + * Provides access to every {@link #claim(String, Object)} declared so far with + * the possibility to add, replace, or remove. * @param claimsConsumer the consumer * @return the {@link Builder} for further configurations */ @@ -153,7 +149,6 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Use this header in the resulting {@link Jwt} - * * @param name The header name * @param value The header value * @return the {@link Builder} for further configurations @@ -164,8 +159,8 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { } /** - * Provides access to every {@link #header(String, Object)} - * declared so far with the possibility to add, replace, or remove. + * Provides access to every {@link #header(String, Object)} declared so far with + * the possibility to add, replace, or remove. * @param headersConsumer the consumer * @return the {@link Builder} for further configurations */ @@ -176,88 +171,80 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { /** * Use this audience in the resulting {@link Jwt} - * * @param audience The audience(s) to use * @return the {@link Builder} for further configurations */ public Builder audience(Collection audience) { - return claim(AUD, audience); + return claim(JwtClaimNames.AUD, audience); } /** * Use this expiration in the resulting {@link Jwt} - * * @param expiresAt The expiration to use * @return the {@link Builder} for further configurations */ public Builder expiresAt(Instant expiresAt) { - this.claim(EXP, expiresAt); + this.claim(JwtClaimNames.EXP, expiresAt); return this; } /** * Use this identifier in the resulting {@link Jwt} - * * @param jti The identifier to use * @return the {@link Builder} for further configurations */ public Builder jti(String jti) { - this.claim(JTI, jti); + this.claim(JwtClaimNames.JTI, jti); return this; } /** * Use this issued-at timestamp in the resulting {@link Jwt} - * * @param issuedAt The issued-at timestamp to use * @return the {@link Builder} for further configurations */ public Builder issuedAt(Instant issuedAt) { - this.claim(IAT, issuedAt); + this.claim(JwtClaimNames.IAT, issuedAt); return this; } /** * Use this issuer in the resulting {@link Jwt} - * * @param issuer The issuer to use * @return the {@link Builder} for further configurations */ public Builder issuer(String issuer) { - this.claim(ISS, issuer); + this.claim(JwtClaimNames.ISS, issuer); return this; } /** * Use this not-before timestamp in the resulting {@link Jwt} - * * @param notBefore The not-before timestamp to use * @return the {@link Builder} for further configurations */ public Builder notBefore(Instant notBefore) { - this.claim(NBF, notBefore); + this.claim(JwtClaimNames.NBF, notBefore); return this; } /** * Use this subject in the resulting {@link Jwt} - * * @param subject The subject to use * @return the {@link Builder} for further configurations */ public Builder subject(String subject) { - this.claim(SUB, subject); + this.claim(JwtClaimNames.SUB, subject); return this; } /** * Build the {@link Jwt} - * * @return The constructed {@link Jwt} */ public Jwt build() { - Instant iat = toInstant(this.claims.get(IAT)); - Instant exp = toInstant(this.claims.get(EXP)); + Instant iat = toInstant(this.claims.get(JwtClaimNames.IAT)); + Instant exp = toInstant(this.claims.get(JwtClaimNames.EXP)); return new Jwt(this.tokenValue, iat, exp, this.headers, this.claims); } @@ -267,5 +254,7 @@ public class Jwt extends AbstractOAuth2Token implements JwtClaimAccessor { } return (Instant) timestamp; } + } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimAccessor.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimAccessor.java index f7d5740201..c84be85912 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimAccessor.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimAccessor.java @@ -13,30 +13,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.jwt; -import org.springframework.security.oauth2.core.ClaimAccessor; +package org.springframework.security.oauth2.jwt; import java.net.URL; import java.time.Instant; import java.util.List; +import org.springframework.security.oauth2.core.ClaimAccessor; + /** - * A {@link ClaimAccessor} for the "claims" that may be contained - * in the JSON object JWT Claims Set of a JSON Web Token (JWT). + * A {@link ClaimAccessor} for the "claims" that may be contained in the JSON + * object JWT Claims Set of a JSON Web Token (JWT). * * @author Joe Grandja * @since 5.0 * @see ClaimAccessor * @see JwtClaimNames * @see Jwt - * @see Registered Claim Names + * @see Registered Claim Names */ public interface JwtClaimAccessor extends ClaimAccessor { /** - * Returns the Issuer {@code (iss)} claim which identifies the principal that issued the JWT. - * + * Returns the Issuer {@code (iss)} claim which identifies the principal that issued + * the JWT. * @return the Issuer identifier */ default URL getIssuer() { @@ -44,9 +46,8 @@ public interface JwtClaimAccessor extends ClaimAccessor { } /** - * Returns the Subject {@code (sub)} claim which identifies the principal - * that is the subject of the JWT. - * + * Returns the Subject {@code (sub)} claim which identifies the principal that is the + * subject of the JWT. * @return the Subject identifier */ default String getSubject() { @@ -54,9 +55,8 @@ public interface JwtClaimAccessor extends ClaimAccessor { } /** - * Returns the Audience {@code (aud)} claim which identifies the recipient(s) - * that the JWT is intended for. - * + * Returns the Audience {@code (aud)} claim which identifies the recipient(s) that the + * JWT is intended for. * @return the Audience(s) that this JWT intended for */ default List getAudience() { @@ -64,28 +64,28 @@ public interface JwtClaimAccessor extends ClaimAccessor { } /** - * Returns the Expiration time {@code (exp)} claim which identifies the expiration time - * on or after which the JWT MUST NOT be accepted for processing. - * - * @return the Expiration time on or after which the JWT MUST NOT be accepted for processing + * Returns the Expiration time {@code (exp)} claim which identifies the expiration + * time on or after which the JWT MUST NOT be accepted for processing. + * @return the Expiration time on or after which the JWT MUST NOT be accepted for + * processing */ default Instant getExpiresAt() { return this.getClaimAsInstant(JwtClaimNames.EXP); } /** - * Returns the Not Before {@code (nbf)} claim which identifies the time - * before which the JWT MUST NOT be accepted for processing. - * - * @return the Not Before time before which the JWT MUST NOT be accepted for processing + * Returns the Not Before {@code (nbf)} claim which identifies the time before which + * the JWT MUST NOT be accepted for processing. + * @return the Not Before time before which the JWT MUST NOT be accepted for + * processing */ default Instant getNotBefore() { return this.getClaimAsInstant(JwtClaimNames.NBF); } /** - * Returns the Issued at {@code (iat)} claim which identifies the time at which the JWT was issued. - * + * Returns the Issued at {@code (iat)} claim which identifies the time at which the + * JWT was issued. * @return the Issued at claim which identifies the time at which the JWT was issued */ default Instant getIssuedAt() { @@ -93,11 +93,12 @@ public interface JwtClaimAccessor extends ClaimAccessor { } /** - * Returns the JWT ID {@code (jti)} claim which provides a unique identifier for the JWT. - * + * Returns the JWT ID {@code (jti)} claim which provides a unique identifier for the + * JWT. * @return the JWT ID claim which provides a unique identifier for the JWT */ default String getId() { return this.getClaimAsString(JwtClaimNames.JTI); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimNames.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimNames.java index aad5c0f452..683c21e509 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimNames.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimNames.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; /** - * The Registered Claim Names defined by the JSON Web Token (JWT) specification - * that may be contained in the JSON object JWT Claims Set. + * The Registered Claim Names defined by the JSON Web Token (JWT) specification that may + * be contained in the JSON object JWT Claims Set. * * @author Joe Grandja * @since 5.0 * @see JwtClaimAccessor - * @see JWT Claims + * @see JWT + * Claims */ public interface JwtClaimNames { @@ -32,22 +34,26 @@ public interface JwtClaimNames { String ISS = "iss"; /** - * {@code sub} - the Subject claim identifies the principal that is the subject of the JWT + * {@code sub} - the Subject claim identifies the principal that is the subject of the + * JWT */ String SUB = "sub"; /** - * {@code aud} - the Audience claim identifies the recipient(s) that the JWT is intended for + * {@code aud} - the Audience claim identifies the recipient(s) that the JWT is + * intended for */ String AUD = "aud"; /** - * {@code exp} - the Expiration time claim identifies the expiration time on or after which the JWT MUST NOT be accepted for processing + * {@code exp} - the Expiration time claim identifies the expiration time on or after + * which the JWT MUST NOT be accepted for processing */ String EXP = "exp"; /** - * {@code nbf} - the Not Before claim identifies the time before which the JWT MUST NOT be accepted for processing + * {@code nbf} - the Not Before claim identifies the time before which the JWT MUST + * NOT be accepted for processing */ String NBF = "nbf"; diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimValidator.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimValidator.java index c9b4a91d5b..73c13c7dc2 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimValidator.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimValidator.java @@ -13,34 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; +import java.util.function.Predicate; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.util.Assert; -import java.util.function.Predicate; - /** - * Validates a claim in a {@link Jwt} against a provided {@link java.util.function.Predicate} + * Validates a claim in a {@link Jwt} against a provided + * {@link java.util.function.Predicate} * * @author Zeeshan Adnan * @since 5.3 */ public final class JwtClaimValidator implements OAuth2TokenValidator { + private final Log logger = LogFactory.getLog(getClass()); private final String claim; + private final Predicate test; + private final OAuth2Error error; /** * Constructs a {@link JwtClaimValidator} using the provided parameters - * * @param claim - is the name of the claim in {@link Jwt} to validate. * @param test - is the predicate function for the claim to test against. */ @@ -49,23 +54,19 @@ public final class JwtClaimValidator implements OAuth2TokenValidator { Assert.notNull(test, "test can not be null"); this.claim = claim; this.test = test; - this.error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, - "The " + this.claim + " claim is not valid", + this.error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, "The " + this.claim + " claim is not valid", "https://tools.ietf.org/html/rfc6750#section-3.1"); } - /** - * {@inheritDoc} - */ @Override public OAuth2TokenValidatorResult validate(Jwt token) { Assert.notNull(token, "token cannot be null"); T claimValue = token.getClaim(this.claim); - if (test.test(claimValue)) { + if (this.test.test(claimValue)) { return OAuth2TokenValidatorResult.success(); - } else { - logger.debug(error.getDescription()); - return OAuth2TokenValidatorResult.failure(error); } + this.logger.debug(this.error.getDescription()); + return OAuth2TokenValidatorResult.failure(this.error); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoder.java index 020f6362eb..7cda7fd9a9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoder.java @@ -13,36 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; /** - * Implementations of this interface are responsible for "decoding" - * a JSON Web Token (JWT) from it's compact claims representation format to a {@link Jwt}. + * Implementations of this interface are responsible for "decoding" a JSON Web + * Token (JWT) from it's compact claims representation format to a {@link Jwt}. * *

        - * JWTs may be represented using the JWS Compact Serialization format for a - * JSON Web Signature (JWS) structure or JWE Compact Serialization format for a - * JSON Web Encryption (JWE) structure. Therefore, implementors are responsible - * for verifying a JWS and/or decrypting a JWE. + * JWTs may be represented using the JWS Compact Serialization format for a JSON Web + * Signature (JWS) structure or JWE Compact Serialization format for a JSON Web Encryption + * (JWE) structure. Therefore, implementors are responsible for verifying a JWS and/or + * decrypting a JWE. * * @author Joe Grandja * @since 5.0 * @see Jwt - * @see JSON Web Token (JWT) - * @see JSON Web Signature (JWS) - * @see JSON Web Encryption (JWE) - * @see JWS Compact Serialization - * @see JWE Compact Serialization + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Encryption + * (JWE) + * @see JWS + * Compact Serialization + * @see JWE + * Compact Serialization */ @FunctionalInterface public interface JwtDecoder { /** - * Decodes the JWT from it's compact claims representation format and returns a {@link Jwt}. - * + * Decodes the JWT from it's compact claims representation format and returns a + * {@link Jwt}. * @param token the JWT value * @return a {@link Jwt} * @throws JwtException if an error occurs while attempting to decode the JWT */ Jwt decode(String token) throws JwtException; + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java index 094cc96086..c1af8a909c 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderFactory.java @@ -13,25 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; /** - * A factory for {@link JwtDecoder}(s). - * This factory should be supplied with a type that provides - * contextual information used to create a specific {@code JwtDecoder}. + * A factory for {@link JwtDecoder}(s). This factory should be supplied with a type that + * provides contextual information used to create a specific {@code JwtDecoder}. * * @author Joe Grandja * @since 5.2 * @see JwtDecoder - * - * @param The type that provides contextual information used to create a specific {@code JwtDecoder}. + * @param The type that provides contextual information used to create a specific + * {@code JwtDecoder}. */ @FunctionalInterface public interface JwtDecoderFactory { /** * Creates a {@code JwtDecoder} using the supplied "contextual" type. - * * @param context the type that provides contextual information * @return a {@link JwtDecoder} */ diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderProviderConfigurationUtils.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderProviderConfigurationUtils.java index 2496abb5b1..5cf7ace331 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderProviderConfigurationUtils.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderProviderConfigurationUtils.java @@ -13,35 +13,45 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.jwt; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.RequestEntity; -import org.springframework.http.ResponseEntity; -import org.springframework.web.client.HttpClientErrorException; -import org.springframework.web.client.RestTemplate; -import org.springframework.web.util.UriComponentsBuilder; +package org.springframework.security.oauth2.jwt; import java.net.URI; import java.util.Collections; import java.util.Map; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; + /** - * Allows resolving configuration from an - * OpenID Provider Configuration or - * Authorization Server Metadata Request based on provided - * issuer and method invoked. + * Allows resolving configuration from an OpenID + * Provider Configuration or + * Authorization Server Metadata + * Request based on provided issuer and method invoked. * * @author Thomas Vitale * @author Rafiullah Hamedy * @since 5.2 */ -class JwtDecoderProviderConfigurationUtils { +final class JwtDecoderProviderConfigurationUtils { + private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration"; + private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server"; + private static final RestTemplate rest = new RestTemplate(); - private static final ParameterizedTypeReference> typeReference = - new ParameterizedTypeReference>() {}; + + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; + + private JwtDecoderProviderConfigurationUtils() { + } static Map getConfigurationForOidcIssuerLocation(String oidcIssuerLocation) { return getConfiguration(oidcIssuerLocation, oidc(URI.create(oidcIssuerLocation))); @@ -53,36 +63,35 @@ class JwtDecoderProviderConfigurationUtils { } static void validateIssuer(Map configuration, String issuer) { - String metadataIssuer = "(unavailable)"; + String metadataIssuer = getMetadataIssuer(configuration); + Assert.state(issuer.equals(metadataIssuer), () -> "The Issuer \"" + metadataIssuer + + "\" provided in the configuration did not " + "match the requested issuer \"" + issuer + "\""); + } + + private static String getMetadataIssuer(Map configuration) { if (configuration.containsKey("issuer")) { - metadataIssuer = configuration.get("issuer").toString(); - } - if (!issuer.equals(metadataIssuer)) { - throw new IllegalStateException("The Issuer \"" + metadataIssuer + "\" provided in the configuration did not " - + "match the requested issuer \"" + issuer + "\""); + return configuration.get("issuer").toString(); } + return "(unavailable)"; } private static Map getConfiguration(String issuer, URI... uris) { - String errorMessage = "Unable to resolve the Configuration with the provided Issuer of " + - "\"" + issuer + "\""; + String errorMessage = "Unable to resolve the Configuration with the provided Issuer of " + "\"" + issuer + "\""; for (URI uri : uris) { try { RequestEntity request = RequestEntity.get(uri).build(); - ResponseEntity> response = rest.exchange(request, typeReference); + ResponseEntity> response = rest.exchange(request, STRING_OBJECT_MAP); Map configuration = response.getBody(); - - if (configuration.get("jwks_uri") == null) { - throw new IllegalArgumentException("The public JWK set URI must not be null"); - } - + Assert.isTrue(configuration.get("jwks_uri") != null, "The public JWK set URI must not be null"); return configuration; - } catch (IllegalArgumentException e) { - throw e; - } catch (RuntimeException e) { - if (!(e instanceof HttpClientErrorException && - ((HttpClientErrorException) e).getStatusCode().is4xxClientError())) { - throw new IllegalArgumentException(errorMessage, e); + } + catch (IllegalArgumentException ex) { + throw ex; + } + catch (RuntimeException ex) { + if (!(ex instanceof HttpClientErrorException + && ((HttpClientErrorException) ex).getStatusCode().is4xxClientError())) { + throw new IllegalArgumentException(errorMessage, ex); } // else try another endpoint } @@ -91,20 +100,27 @@ class JwtDecoderProviderConfigurationUtils { } private static URI oidc(URI issuer) { + // @formatter:off return UriComponentsBuilder.fromUri(issuer) .replacePath(issuer.getPath() + OIDC_METADATA_PATH) .build(Collections.emptyMap()); + // @formatter:on } private static URI oidcRfc8414(URI issuer) { + // @formatter:off return UriComponentsBuilder.fromUri(issuer) .replacePath(OIDC_METADATA_PATH + issuer.getPath()) .build(Collections.emptyMap()); + // @formatter:on } private static URI oauth(URI issuer) { + // @formatter:off return UriComponentsBuilder.fromUri(issuer) .replacePath(OAUTH_METADATA_PATH + issuer.getPath()) .build(Collections.emptyMap()); + // @formatter:on } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java index 0d2f233198..d109cc0d2d 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.util.Map; @@ -20,13 +21,12 @@ import java.util.Map; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; - /** - * Allows creating a {@link JwtDecoder} from an - * OpenID Provider Configuration or - * Authorization Server Metadata Request based on provided - * issuer and method invoked. + * Allows creating a {@link JwtDecoder} from an OpenID + * Provider Configuration or + * Authorization Server Metadata + * Request based on provided issuer and method invoked. * * @author Josh Cummings * @author Rafiullah Hamedy @@ -34,79 +34,81 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSe */ public final class JwtDecoders { + private JwtDecoders() { + } + /** - * Creates a {@link JwtDecoder} using the provided - * Issuer by making an - * OpenID Provider - * Configuration Request and using the values in the - * OpenID + * Creates a {@link JwtDecoder} using the provided Issuer + * by making an OpenID + * Provider Configuration Request and using the values in the OpenID * Provider Configuration Response to initialize the {@link JwtDecoder}. - * - * @param oidcIssuerLocation the Issuer - * @return a {@link JwtDecoder} that was initialized by the OpenID Provider Configuration. + * @param oidcIssuerLocation the Issuer + * @return a {@link JwtDecoder} that was initialized by the OpenID Provider + * Configuration. */ public static JwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation) { Assert.hasText(oidcIssuerLocation, "oidcIssuerLocation cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForOidcIssuerLocation(oidcIssuerLocation); + Map configuration = JwtDecoderProviderConfigurationUtils + .getConfigurationForOidcIssuerLocation(oidcIssuerLocation); return withProviderConfiguration(configuration, oidcIssuerLocation); } /** - * Creates a {@link JwtDecoder} using the provided - * Issuer by querying - * three different discovery endpoints serially, using the values in the first successful response to - * initialize. If an endpoint returns anything other than a 200 or a 4xx, the method will exit without - * attempting subsequent endpoints. + * Creates a {@link JwtDecoder} using the provided Issuer + * by querying three different discovery endpoints serially, using the values in the + * first successful response to initialize. If an endpoint returns anything other than + * a 200 or a 4xx, the method will exit without attempting subsequent endpoints. * - * The three endpoints are computed as follows, given that the {@code issuer} is composed of a {@code host} - * and a {@code path}: + * The three endpoints are computed as follows, given that the {@code issuer} is + * composed of a {@code host} and a {@code path}: * *

          - *
        1. - * {@code host/.well-known/openid-configuration/path}, as defined in - * RFC 8414's Compatibility Notes. - *
        2. - *
        3. - * {@code issuer/.well-known/openid-configuration}, as defined in - * - * OpenID Provider Configuration. - *
        4. - *
        5. - * {@code host/.well-known/oauth-authorization-server/path}, as defined in - * Authorization Server Metadata Request. - *
        6. + *
        7. {@code host/.well-known/openid-configuration/path}, as defined in + * RFC 8414's Compatibility + * Notes.
        8. + *
        9. {@code issuer/.well-known/openid-configuration}, as defined in + * OpenID Provider Configuration.
        10. + *
        11. {@code host/.well-known/oauth-authorization-server/path}, as defined in + * Authorization Server + * Metadata Request.
        12. *
        * * Note that the second endpoint is the equivalent of calling * {@link JwtDecoders#fromOidcIssuerLocation(String)} - * - * @param issuer the Issuer + * @param issuer the Issuer * @return a {@link JwtDecoder} that was initialized by one of the described endpoints */ public static JwtDecoder fromIssuerLocation(String issuer) { Assert.hasText(issuer, "issuer cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForIssuerLocation(issuer); + Map configuration = JwtDecoderProviderConfigurationUtils + .getConfigurationForIssuerLocation(issuer); return withProviderConfiguration(configuration, issuer); } /** - * Validate provided issuer and build {@link JwtDecoder} from - * OpenID Provider - * Configuration Response and Authorization Server Metadata - * Response. - * + * Validate provided issuer and build {@link JwtDecoder} from OpenID + * Provider Configuration Response and + * Authorization Server + * Metadata Response. * @param configuration the configuration values - * @param issuer the Issuer + * @param issuer the Issuer * @return {@link JwtDecoder} */ private static JwtDecoder withProviderConfiguration(Map configuration, String issuer) { JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer); OAuth2TokenValidator jwtValidator = JwtValidators.createDefaultWithIssuer(issuer); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(configuration.get("jwks_uri").toString()).build(); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(configuration.get("jwks_uri").toString()).build(); jwtDecoder.setJwtValidator(jwtValidator); - return jwtDecoder; } - private JwtDecoders() {} } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtException.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtException.java index fe3f58cadb..b13f0dff26 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtException.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; /** @@ -25,7 +26,6 @@ public class JwtException extends RuntimeException { /** * Constructs a {@code JwtException} using the provided parameters. - * * @param message the detail message */ public JwtException(String message) { @@ -34,11 +34,11 @@ public class JwtException extends RuntimeException { /** * Constructs a {@code JwtException} using the provided parameters. - * * @param message the detail message * @param cause the root cause */ public JwtException(String message, Throwable cause) { super(message, cause); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtIssuerValidator.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtIssuerValidator.java index da9beacce8..61c5d11882 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtIssuerValidator.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtIssuerValidator.java @@ -13,14 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; - /** * Validates the "iss" claim in a {@link Jwt}, that is matches a configured value * @@ -33,20 +32,17 @@ public final class JwtIssuerValidator implements OAuth2TokenValidator { /** * Constructs a {@link JwtIssuerValidator} using the provided parameters - * * @param issuer - The issuer that each {@link Jwt} should have. */ public JwtIssuerValidator(String issuer) { Assert.notNull(issuer, "issuer cannot be null"); - this.validator = new JwtClaimValidator(ISS, issuer::equals); + this.validator = new JwtClaimValidator(JwtClaimNames.ISS, issuer::equals); } - /** - * {@inheritDoc} - */ @Override public OAuth2TokenValidatorResult validate(Jwt token) { Assert.notNull(token, "token cannot be null"); return this.validator.validate(token); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtTimestampValidator.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtTimestampValidator.java index eef9dc8e8e..0fb72aca00 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtTimestampValidator.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtTimestampValidator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.time.Clock; @@ -30,20 +31,23 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.util.Assert; /** - * An implementation of {@link OAuth2TokenValidator} for verifying claims in a Jwt-based access token + * An implementation of {@link OAuth2TokenValidator} for verifying claims in a Jwt-based + * access token * *

        - * Because clocks can differ between the Jwt source, say the Authorization Server, and its destination, say the - * Resource Server, there is a default clock leeway exercised when deciding if the current time is within the Jwt's - * specified operating window + * Because clocks can differ between the Jwt source, say the Authorization Server, and its + * destination, say the Resource Server, there is a default clock leeway exercised when + * deciding if the current time is within the Jwt's specified operating window * * @author Josh Cummings * @since 5.1 * @see Jwt * @see OAuth2TokenValidator - * @see JSON Web Token (JWT) + * @see JSON Web Token + * (JWT) */ public final class JwtTimestampValidator implements OAuth2TokenValidator { + private final Log logger = LogFactory.getLog(getClass()); private static final Duration DEFAULT_MAX_CLOCK_SKEW = Duration.of(60, ChronoUnit.SECONDS); @@ -64,51 +68,39 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator { this.clockSkew = clockSkew; } - /** - * {@inheritDoc} - */ @Override public OAuth2TokenValidatorResult validate(Jwt jwt) { Assert.notNull(jwt, "jwt cannot be null"); - Instant expiry = jwt.getExpiresAt(); - if (expiry != null) { - if (Instant.now(this.clock).minus(clockSkew).isAfter(expiry)) { + if (Instant.now(this.clock).minus(this.clockSkew).isAfter(expiry)) { OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt expired at %s", jwt.getExpiresAt())); return OAuth2TokenValidatorResult.failure(oAuth2Error); } } - Instant notBefore = jwt.getNotBefore(); - if (notBefore != null) { - if (Instant.now(this.clock).plus(clockSkew).isBefore(notBefore)) { + if (Instant.now(this.clock).plus(this.clockSkew).isBefore(notBefore)) { OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt used before %s", jwt.getNotBefore())); return OAuth2TokenValidatorResult.failure(oAuth2Error); } } - return OAuth2TokenValidatorResult.success(); } private OAuth2Error createOAuth2Error(String reason) { - logger.debug(reason); - return new OAuth2Error( - OAuth2ErrorCodes.INVALID_REQUEST, - reason, + this.logger.debug(reason); + return new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, reason, "https://tools.ietf.org/html/rfc6750#section-3.1"); } /** - * ' - * Use this {@link Clock} with {@link Instant#now()} for assessing - * timestamp validity - * + * Use this {@link Clock} with {@link Instant#now()} for assessing timestamp validity * @param clock */ public void setClock(Clock clock) { Assert.notNull(clock, "clock cannot be null"); this.clock = clock; } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java index 3ea9e11050..94568d2dc6 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.util.ArrayList; @@ -23,13 +24,13 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.util.Assert; /** - * An exception that results from an unsuccessful - * {@link OAuth2TokenValidatorResult} + * An exception that results from an unsuccessful {@link OAuth2TokenValidatorResult} * * @author Josh Cummings * @since 5.1 */ public class JwtValidationException extends BadJwtException { + private final Collection errors; /** @@ -43,17 +44,16 @@ public class JwtValidationException extends BadJwtException { * *

         	 * 	if ( result.hasErrors() ) {
        -	 *  	Collection errors = result.getErrors();
        +	 *  	Collection<OAuth2Error> errors = result.getErrors();
         	 *  	throw new JwtValidationException(errors.iterator().next().getDescription(), errors);
         	 * 	}
         	 * 
        - * * @param message - the exception message - * @param errors - a list of {@link OAuth2Error}s with extra detail about the validation result + * @param errors - a list of {@link OAuth2Error}s with extra detail about the + * validation result */ public JwtValidationException(String message, Collection errors) { super(message); - Assert.notEmpty(errors, "errors cannot be empty"); this.errors = new ArrayList<>(errors); } @@ -65,4 +65,5 @@ public class JwtValidationException extends BadJwtException { public Collection getErrors() { return this.errors; } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidators.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidators.java index 9630c944fd..4d13ce52ab 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidators.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidators.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.util.ArrayList; @@ -24,22 +25,29 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator; /** * Provides factory methods for creating {@code OAuth2TokenValidator} + * * @author Josh Cummings * @author Rob Winch * @since 5.1 */ public final class JwtValidators { + private JwtValidators() { + } + /** *

        - * Create a {@link Jwt} Validator that contains all standard validators when an issuer is known. + * Create a {@link Jwt} Validator that contains all standard validators when an issuer + * is known. *

        *

        - * User's wanting to leverage the defaults plus additional validation can add the result of this - * method to {@code DelegatingOAuth2TokenValidator} along with the additional validators. + * User's wanting to leverage the defaults plus additional validation can add the + * result of this method to {@code DelegatingOAuth2TokenValidator} along with the + * additional validators. *

        * @param issuer the issuer - * @return - a delegating validator containing all standard validators as well as any supplied + * @return - a delegating validator containing all standard validators as well as any + * supplied */ public static OAuth2TokenValidator createDefaultWithIssuer(String issuer) { List> validators = new ArrayList<>(); @@ -53,14 +61,15 @@ public final class JwtValidators { * Create a {@link Jwt} Validator that contains all standard validators. *

        *

        - * User's wanting to leverage the defaults plus additional validation can add the result of this - * method to {@code DelegatingOAuth2TokenValidator} along with the additional validators. + * User's wanting to leverage the defaults plus additional validation can add the + * result of this method to {@code DelegatingOAuth2TokenValidator} along with the + * additional validators. *

        - * @return - a delegating validator containing all standard validators as well as any supplied + * @return - a delegating validator containing all standard validators as well as any + * supplied */ public static OAuth2TokenValidator createDefault() { return new DelegatingOAuth2TokenValidator<>(Arrays.asList(new JwtTimestampValidator())); } - private JwtValidators() {} } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverter.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverter.java index 50a6d382dc..221c6d730c 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverter.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverter.java @@ -16,13 +16,6 @@ package org.springframework.security.oauth2.jwt; -import org.springframework.core.convert.ConversionService; -import org.springframework.core.convert.TypeDescriptor; -import org.springframework.core.convert.converter.Converter; -import org.springframework.security.oauth2.core.converter.ClaimConversionService; -import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; -import org.springframework.util.Assert; - import java.net.URI; import java.net.URL; import java.time.Instant; @@ -30,28 +23,41 @@ import java.util.Collection; import java.util.HashMap; import java.util.Map; +import org.springframework.core.convert.ConversionService; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.converter.ClaimConversionService; +import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; +import org.springframework.util.Assert; + /** - * Converts a JWT claim set, claim by claim. Can be configured with custom converters - * by claim name. + * Converts a JWT claim set, claim by claim. Can be configured with custom converters by + * claim name. * * @author Josh Cummings * @since 5.1 * @see ClaimTypeConverter */ public final class MappedJwtClaimSetConverter implements Converter, Map> { - private final static ConversionService CONVERSION_SERVICE = ClaimConversionService.getSharedInstance(); - private final static TypeDescriptor OBJECT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Object.class); - private final static TypeDescriptor STRING_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(String.class); - private final static TypeDescriptor INSTANT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Instant.class); - private final static TypeDescriptor URL_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(URL.class); + + private static final ConversionService CONVERSION_SERVICE = ClaimConversionService.getSharedInstance(); + + private static final TypeDescriptor OBJECT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Object.class); + + private static final TypeDescriptor STRING_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(String.class); + + private static final TypeDescriptor INSTANT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Instant.class); + + private static final TypeDescriptor URL_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(URL.class); + private final Map> claimTypeConverters; + private final Converter, Map> delegate; /** * Constructs a {@link MappedJwtClaimSetConverter} with the provided arguments * * This will completely replace any set of default converters. - * * @param claimTypeConverters The {@link Map} of converters to use */ public MappedJwtClaimSetConverter(Map> claimTypeConverters) { @@ -64,34 +70,32 @@ public final class MappedJwtClaimSetConverter implements Converter * MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); * * - * Or, the following would supply a custom converter for the subject, leaving the other defaults - * in place: + * Or, the following would supply a custom converter for the subject, leaving the + * other defaults in place: * *
         	 * 	MappedJwtClaimsSetConverter.withDefaults(
         	 * 		Collections.singletonMap(JwtClaimNames.SUB, new UserDetailsServiceJwtSubjectConverter()));
         	 * 
        * - * To completely replace the underlying {@link Map} of converters, see {@link MappedJwtClaimSetConverter#MappedJwtClaimSetConverter(Map)}. - * + * To completely replace the underlying {@link Map} of converters, see + * {@link MappedJwtClaimSetConverter#MappedJwtClaimSetConverter(Map)}. * @param claimTypeConverters - * @return An instance of {@link MappedJwtClaimSetConverter} that contains the converters provided, - * plus any defaults that were not overridden. + * @return An instance of {@link MappedJwtClaimSetConverter} that contains the + * converters provided, plus any defaults that were not overridden. */ public static MappedJwtClaimSetConverter withDefaults(Map> claimTypeConverters) { Assert.notNull(claimTypeConverters, "claimTypeConverters cannot be null"); - Converter stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); Converter collectionStringConverter = getConverter( TypeDescriptor.collection(Collection.class, STRING_TYPE_DESCRIPTOR)); - Map> claimNameToConverter = new HashMap<>(); claimNameToConverter.put(JwtClaimNames.AUD, collectionStringConverter); claimNameToConverter.put(JwtClaimNames.EXP, MappedJwtClaimSetConverter::convertInstant); @@ -101,12 +105,11 @@ public final class MappedJwtClaimSetConverter implements Converter getConverter(TypeDescriptor targetDescriptor) { - return source -> CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor); + return (source) -> CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor); } private static Instant convertInstant(Object source) { @@ -114,9 +117,7 @@ public final class MappedJwtClaimSetConverter implements Converter "Could not coerce " + source + " into an Instant"); return result; } @@ -131,31 +132,25 @@ public final class MappedJwtClaimSetConverter implements Converter convert(Map claims) { Assert.notNull(claims, "claims cannot be null"); - Map mappedClaims = this.delegate.convert(claims); - mappedClaims = removeClaims(mappedClaims); mappedClaims = addClaims(mappedClaims); - Instant issuedAt = (Instant) mappedClaims.get(JwtClaimNames.IAT); Instant expiresAt = (Instant) mappedClaims.get(JwtClaimNames.EXP); if (issuedAt == null && expiresAt != null) { mappedClaims.put(JwtClaimNames.IAT, expiresAt.minusSeconds(1)); } - return mappedClaims; } @@ -171,11 +166,12 @@ public final class MappedJwtClaimSetConverter implements Converter addClaims(Map claims) { Map result = new HashMap<>(claims); - for (Map.Entry> entry : claimTypeConverters.entrySet()) { + for (Map.Entry> entry : this.claimTypeConverters.entrySet()) { if (!claims.containsKey(entry.getKey()) && entry.getValue().convert(null) != null) { result.put(entry.getKey(), entry.getValue().convert(null)); } } return result; } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 317d23a30d..97213b90c2 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -29,6 +29,7 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import java.util.function.Consumer; + import javax.crypto.SecretKey; import com.nimbusds.jose.JOSEException; @@ -69,7 +70,8 @@ import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; /** - * A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus configuration. + * A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus + * configuration. * * @author Josh Cummings * @author Joe Grandja @@ -77,18 +79,18 @@ import org.springframework.web.client.RestTemplate; * @since 5.2 */ public final class NimbusJwtDecoder implements JwtDecoder { - private static final String DECODING_ERROR_MESSAGE_TEMPLATE = - "An error occurred while attempting to decode the Jwt: %s"; + + private static final String DECODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to decode the Jwt: %s"; private final JWTProcessor jwtProcessor; - private Converter, Map> claimSetConverter = - MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); + private Converter, Map> claimSetConverter = MappedJwtClaimSetConverter + .withDefaults(Collections.emptyMap()); + private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); /** * Configures a {@link NimbusJwtDecoder} with the given parameters - * * @param jwtProcessor - the {@link JWTProcessor} to use */ public NimbusJwtDecoder(JWTProcessor jwtProcessor) { @@ -98,7 +100,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Use this {@link Jwt} Validator - * * @param jwtValidator - the Jwt Validator to use */ public void setJwtValidator(OAuth2TokenValidator jwtValidator) { @@ -108,7 +109,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Use the following {@link Converter} for manipulating the JWT's claim set - * * @param claimSetConverter the {@link Converter} to use */ public void setClaimSetConverter(Converter, Map> claimSetConverter) { @@ -118,7 +118,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Decode and validate the JWT from its compact claims representation format - * * @param token the JWT value * @return a validated {@link Jwt} * @throws JwtException @@ -136,7 +135,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { private JWT parse(String token) { try { return JWTParser.parse(token); - } catch (Exception ex) { + } + catch (Exception ex) { throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } } @@ -145,55 +145,54 @@ public final class NimbusJwtDecoder implements JwtDecoder { try { // Verify the signature JWTClaimsSet jwtClaimsSet = this.jwtProcessor.process(parsedJwt, null); - Map headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); Map claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); - + // @formatter:off return Jwt.withTokenValue(token) - .headers(h -> h.putAll(headers)) - .claims(c -> c.putAll(claims)) + .headers((h) -> h.putAll(headers)) + .claims((c) -> c.putAll(claims)) .build(); - } catch (RemoteKeySourceException ex) { + // @formatter:on + } + catch (RemoteKeySourceException ex) { if (ex.getCause() instanceof ParseException) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set")); - } else { - throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } - } catch (JOSEException ex) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); - } catch (Exception ex) { + } + catch (JOSEException ex) { + throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + } + catch (Exception ex) { if (ex.getCause() instanceof ParseException) { throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload")); - } else { - throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } + throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } } - private Jwt validateJwt(Jwt jwt){ + private Jwt validateJwt(Jwt jwt) { OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); if (result.hasErrors()) { Collection errors = result.getErrors(); - String validationErrorString = "Unable to validate Jwt"; - for (OAuth2Error oAuth2Error : errors) { - if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { - validationErrorString = String.format( - DECODING_ERROR_MESSAGE_TEMPLATE, oAuth2Error.getDescription()); - break; - } - } - throw new JwtValidationException( - validationErrorString, - result.getErrors()); + String validationErrorString = getJwtValidationExceptionMessage(errors); + throw new JwtValidationException(validationErrorString, errors); } - return jwt; } + private String getJwtValidationExceptionMessage(Collection errors) { + for (OAuth2Error oAuth2Error : errors) { + if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { + return String.format(DECODING_ERROR_MESSAGE_TEMPLATE, oAuth2Error.getDescription()); + } + } + return "Unable to validate Jwt"; + } + /** - * Use the given - * JWK Set uri. - * + * Use the given JWK Set + * uri. * @param jwkSetUri the JWK Set uri to use * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations */ @@ -203,7 +202,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Use the given public key to validate JWTs - * * @param key the public key to use * @return a {@link PublicKeyJwtDecoderBuilder} for further configurations */ @@ -213,7 +211,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Use the given {@code SecretKey} to validate the MAC on a JSON Web Signature (JWS). - * * @param secretKey the {@code SecretKey} used to validate the MAC * @return a {@link SecretKeyJwtDecoderBuilder} for further configurations */ @@ -223,26 +220,32 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * A builder for creating {@link NimbusJwtDecoder} instances based on a - * JWK Set uri. + * JWK Set + * uri. */ public static final class JwkSetUriJwtDecoderBuilder { + private String jwkSetUri; + private Set signatureAlgorithms = new HashSet<>(); + private RestOperations restOperations = new RestTemplate(); + private Cache cache; + private Consumer> jwtProcessorCustomizer; private JwkSetUriJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = jwkSetUri; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Append the given signing - * algorithm - * to the set of algorithms to use. - * + * algorithm to the set of algorithms to use. * @param signatureAlgorithm the algorithm to use * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations */ @@ -254,10 +257,10 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Configure the list of - * algorithms - * to use with the given {@link Consumer}. - * - * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list + * algorithms to use with the given {@link Consumer}. + * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring + * the algorithm list * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations */ public JwkSetUriJwtDecoderBuilder jwsAlgorithms(Consumer> signatureAlgorithmsConsumer) { @@ -267,11 +270,11 @@ public final class NimbusJwtDecoder implements JwtDecoder { } /** - * Use the given {@link RestOperations} to coordinate with the authorization servers indicated in the - * JWK Set uri - * as well as the - * Issuer. - * + * Use the given {@link RestOperations} to coordinate with the authorization + * servers indicated in the + * JWK Set uri as well + * as the Issuer. * @param restOperations * @return */ @@ -284,7 +287,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Use the given {@link Cache} to store * JWK Set. - * * @param cache the {@link Cache} to be used to store JWK Set * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations * @since 5.4 @@ -296,14 +298,15 @@ public final class NimbusJwtDecoder implements JwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations * @since 5.4 */ - public JwkSetUriJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public JwkSetUriJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; @@ -312,14 +315,13 @@ public final class NimbusJwtDecoder implements JwtDecoder { JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); - } else { - Set jwsAlgorithms = new HashSet<>(); - for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { - JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); - jwsAlgorithms.add(jwsAlgorithm); - } - return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } + Set jwsAlgorithms = new HashSet<>(); + for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + jwsAlgorithms.add(jwsAlgorithm); + } + return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } JWKSource jwkSource(ResourceRetriever jwkSetRetriever) { @@ -335,18 +337,15 @@ public final class NimbusJwtDecoder implements JwtDecoder { JWKSource jwkSource = jwkSource(jwkSetRetriever); ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); - // Spring Security validates the claim set independent from Nimbus - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - return jwtProcessor; } /** * Build the configured {@link NimbusJwtDecoder}. - * * @return the configured {@link NimbusJwtDecoder} */ public NimbusJwtDecoder build() { @@ -356,12 +355,14 @@ public final class NimbusJwtDecoder implements JwtDecoder { private static URL toURL(String url) { try { return new URL(url); - } catch (MalformedURLException ex) { + } + catch (MalformedURLException ex) { throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex); } } private static class NoOpJwkSetCache implements JWKSetCache { + @Override public void put(JWKSet jwkSet) { } @@ -375,10 +376,13 @@ public final class NimbusJwtDecoder implements JwtDecoder { public boolean requiresRefresh() { return true; } + } private static class CachingResourceRetriever implements ResourceRetriever { + private final Cache cache; + private final ResourceRetriever resourceRetriever; CachingResourceRetriever(Cache cache, ResourceRetriever resourceRetriever) { @@ -388,26 +392,29 @@ public final class NimbusJwtDecoder implements JwtDecoder { @Override public Resource retrieveResource(URL url) throws IOException { - String jwkSet; try { - jwkSet = this.cache.get(url.toString(), + String jwkSet = this.cache.get(url.toString(), () -> this.resourceRetriever.retrieveResource(url).getContent()); - } catch (Cache.ValueRetrievalException ex) { + return new Resource(jwkSet, "UTF-8"); + } + catch (Cache.ValueRetrievalException ex) { Throwable thrownByValueLoader = ex.getCause(); if (thrownByValueLoader instanceof IOException) { throw (IOException) thrownByValueLoader; } throw new IOException(thrownByValueLoader); - } catch (Exception ex) { + } + catch (Exception ex) { throw new IOException(ex); } - - return new Resource(jwkSet, "UTF-8"); } + } private static class RestOperationsResourceRetriever implements ResourceRetriever { + private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); + private final RestOperations restOperations; RestOperationsResourceRetriever(RestOperations restOperations) { @@ -419,46 +426,54 @@ public final class NimbusJwtDecoder implements JwtDecoder { public Resource retrieveResource(URL url) throws IOException { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); - - ResponseEntity response; - try { - RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); - response = this.restOperations.exchange(request, String.class); - } catch (Exception ex) { - throw new IOException(ex); - } - + ResponseEntity response = getResponse(url, headers); if (response.getStatusCodeValue() != 200) { throw new IOException(response.toString()); } - return new Resource(response.getBody(), "UTF-8"); } + + private ResponseEntity getResponse(URL url, HttpHeaders headers) throws IOException { + try { + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); + return this.restOperations.exchange(request, String.class); + } + catch (Exception ex) { + throw new IOException(ex); + } + } + } + } /** * A builder for creating {@link NimbusJwtDecoder} instances based on a public key. */ public static final class PublicKeyJwtDecoderBuilder { + private JWSAlgorithm jwsAlgorithm; + private RSAPublicKey key; + private Consumer> jwtProcessorCustomizer; private PublicKeyJwtDecoderBuilder(RSAPublicKey key) { Assert.notNull(key, "key cannot be null"); this.jwsAlgorithm = JWSAlgorithm.RS256; this.key = key; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Use the given signing - * algorithm. + * algorithm. * * The value should be one of - * RS256, RS384, or RS512. - * + * RS256, RS384, or RS512. * @param signatureAlgorithm the algorithm to use * @return a {@link PublicKeyJwtDecoderBuilder} for further configurations */ @@ -469,71 +484,71 @@ public final class NimbusJwtDecoder implements JwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link PublicKeyJwtDecoderBuilder} for further configurations * @since 5.4 */ - public PublicKeyJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public PublicKeyJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; } JWTProcessor processor() { - if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) { - throw new IllegalStateException("The provided key is of type RSA; " + - "however the signature algorithm is of some other type: " + - this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512."); - } - - JWSKeySelector jwsKeySelector = - new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key); + Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm), + () -> "The provided key is of type RSA; however the signature algorithm is of some other type: " + + this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512."); + JWSKeySelector jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); - // Spring Security validates the claim set independent from Nimbus - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - return jwtProcessor; } /** * Build the configured {@link NimbusJwtDecoder}. - * * @return the configured {@link NimbusJwtDecoder} */ public NimbusJwtDecoder build() { return new NimbusJwtDecoder(processor()); } + } /** - * A builder for creating {@link NimbusJwtDecoder} instances based on a {@code SecretKey}. + * A builder for creating {@link NimbusJwtDecoder} instances based on a + * {@code SecretKey}. */ public static final class SecretKeyJwtDecoderBuilder { + private final SecretKey secretKey; + private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + private Consumer> jwtProcessorCustomizer; private SecretKeyJwtDecoderBuilder(SecretKey secretKey) { Assert.notNull(secretKey, "secretKey cannot be null"); this.secretKey = secretKey; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Use the given - * algorithm - * when generating the MAC. + * algorithm when generating the MAC. * * The value should be one of - * HS256, HS384 or HS512. - * + * HS256, HS384 or HS512. * @param macAlgorithm the MAC algorithm to use * @return a {@link SecretKeyJwtDecoderBuilder} for further configurations */ @@ -544,14 +559,15 @@ public final class NimbusJwtDecoder implements JwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link SecretKeyJwtDecoderBuilder} for further configurations * @since 5.4 */ - public SecretKeyJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public SecretKeyJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; @@ -559,7 +575,6 @@ public final class NimbusJwtDecoder implements JwtDecoder { /** * Build the configured {@link NimbusJwtDecoder}. - * * @return the configured {@link NimbusJwtDecoder} */ public NimbusJwtDecoder build() { @@ -567,17 +582,17 @@ public final class NimbusJwtDecoder implements JwtDecoder { } JWTProcessor processor() { - JWSKeySelector jwsKeySelector = - new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.secretKey); + JWSKeySelector jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, + this.secretKey); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); - // Spring Security validates the claim set independent from Nimbus - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - return jwtProcessor; } + } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java index 8d698d01ae..cb3e8a7b6f 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; +import java.util.Collections; +import java.util.Map; + import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; @@ -22,44 +26,44 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.web.client.RestOperations; -import java.util.Collections; -import java.util.Map; - -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; - /** - * An implementation of a {@link JwtDecoder} that "decodes" a - * JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a - * JSON Web Signature (JWS). The public key used for verification is obtained from the - * JSON Web Key (JWK) Set {@code URL} supplied via the constructor. + * An implementation of a {@link JwtDecoder} that "decodes" a JSON Web Token (JWT) and + * additionally verifies it's digital signature if the JWT is a JSON Web Signature (JWS). + * The public key used for verification is obtained from the JSON Web Key (JWK) Set + * {@code URL} supplied via the constructor. * *

        * NOTE: This implementation uses the Nimbus JOSE + JWT SDK internally. * * @deprecated Use {@link NimbusJwtDecoder} or {@link JwtDecoders} instead - * * @author Joe Grandja * @author Josh Cummings * @since 5.0 * @see JwtDecoder * @see NimbusJwtDecoder - * @see JSON Web Token (JWT) - * @see JSON Web Signature (JWS) - * @see JSON Web Key (JWK) - * @see Nimbus JOSE + JWT SDK + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Key + * (JWK) + * @see Nimbus + * JOSE + JWT SDK */ @Deprecated public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { + private NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder jwtDecoderBuilder; + private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); - private Converter, Map> claimSetConverter = - MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); + + private Converter, Map> claimSetConverter = MappedJwtClaimSetConverter + .withDefaults(Collections.emptyMap()); private NimbusJwtDecoder delegate; /** * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters. - * * @param jwkSetUrl the JSON Web Key (JWK) Set {@code URL} */ public NimbusJwtDecoderJwkSupport(String jwkSetUrl) { @@ -68,15 +72,15 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { /** * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters. - * * @param jwkSetUrl the JSON Web Key (JWK) Set {@code URL} - * @param jwsAlgorithm the JSON Web Algorithm (JWA) used for verifying the digital signatures + * @param jwsAlgorithm the JSON Web Algorithm (JWA) used for verifying the digital + * signatures */ public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) { Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty"); Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty"); - - this.jwtDecoderBuilder = withJwkSetUri(jwkSetUrl).jwsAlgorithm(SignatureAlgorithm.from(jwsAlgorithm)); + this.jwtDecoderBuilder = NimbusJwtDecoder.withJwkSetUri(jwkSetUrl) + .jwsAlgorithm(SignatureAlgorithm.from(jwsAlgorithm)); this.delegate = makeDelegate(); } @@ -94,7 +98,6 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { /** * Use this {@link Jwt} Validator - * * @param jwtValidator - the Jwt Validator to use */ public void setJwtValidator(OAuth2TokenValidator jwtValidator) { @@ -105,7 +108,6 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { /** * Use the following {@link Converter} for manipulating the JWT's claim set - * * @param claimSetConverter the {@link Converter} to use */ public void setClaimSetConverter(Converter, Map> claimSetConverter) { @@ -116,13 +118,14 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { /** * Sets the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set. - * + * @param restOperations the {@link RestOperations} used when requesting the JSON Web + * Key (JWK) Set * @since 5.1 - * @param restOperations the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set */ public void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.jwtDecoderBuilder = this.jwtDecoderBuilder.restOperations(restOperations); this.delegate = makeDelegate(); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 9bbedfd306..122cf14c37 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -13,8 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; +import java.security.interfaces.RSAPublicKey; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; + +import javax.crypto.SecretKey; + import com.nimbusds.jose.Header; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; @@ -37,6 +50,9 @@ import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; @@ -47,24 +63,11 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import javax.crypto.SecretKey; -import java.security.interfaces.RSAPublicKey; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; -import java.util.function.Function; /** - * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a - * JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a - * JSON Web Signature (JWS). + * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a JSON Web + * Token (JWT) and additionally verifies it's digital signature if the JWT is a JSON Web + * Signature (JWS). * *

        * NOTE: This implementation uses the Nimbus JOSE + JWT SDK internally. @@ -73,21 +76,26 @@ import java.util.function.Function; * @author Joe Grandja * @since 5.1 * @see ReactiveJwtDecoder - * @see JSON Web Token (JWT) - * @see JSON Web Signature (JWS) - * @see JSON Web Key (JWK) - * @see Nimbus JOSE + JWT SDK + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Key + * (JWK) + * @see Nimbus + * JOSE + JWT SDK */ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { + private final Converter> jwtProcessor; private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); - private Converter, Map> claimSetConverter = - MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); + + private Converter, Map> claimSetConverter = MappedJwtClaimSetConverter + .withDefaults(Collections.emptyMap()); /** * Constructs a {@code NimbusReactiveJwtDecoder} using the provided parameters. - * * @param jwkSetUrl the JSON Web Key (JWK) Set {@code URL} */ public NimbusReactiveJwtDecoder(String jwkSetUrl) { @@ -96,7 +104,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Constructs a {@code NimbusReactiveJwtDecoder} using the provided parameters. - * * @param publicKey the {@code RSAPublicKey} used to verify the signature * @since 5.2 */ @@ -106,8 +113,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Constructs a {@code NimbusReactiveJwtDecoder} using the provided parameters. - * - * @param jwtProcessor the {@link Converter} used to process and verify the signed Jwt and return the Jwt Claim Set + * @param jwtProcessor the {@link Converter} used to process and verify the signed Jwt + * and return the Jwt Claim Set * @since 5.2 */ public NimbusReactiveJwtDecoder(Converter> jwtProcessor) { @@ -116,7 +123,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s. - * * @param jwtValidator the {@link OAuth2TokenValidator} to use */ public void setJwtValidator(OAuth2TokenValidator jwtValidator) { @@ -126,7 +132,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Use the following {@link Converter} for manipulating the JWT's claim set - * * @param claimSetConverter the {@link Converter} to use */ public void setClaimSetConverter(Converter, Map> claimSetConverter) { @@ -146,20 +151,26 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private JWT parse(String token) { try { return JWTParser.parse(token); - } catch (Exception ex) { + } + catch (Exception ex) { throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); } } private Mono decode(JWT parsedToken) { try { + // @formatter:off return this.jwtProcessor.convert(parsedToken) - .map(set -> createJwt(parsedToken, set)) - .map(this::validateJwt) - .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); - } catch (JwtException ex) { + .map((set) -> createJwt(parsedToken, set)) + .map(this::validateJwt) + .onErrorMap((ex) -> !(ex instanceof IllegalStateException) && !(ex instanceof JwtException), + (ex) -> new JwtException("An error occurred while attempting to decode the Jwt: ", ex)); + // @formatter:on + } + catch (JwtException ex) { throw ex; - } catch (RuntimeException ex) { + } + catch (RuntimeException ex) { throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); } } @@ -168,12 +179,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { try { Map headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); Map claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); - - return Jwt.withTokenValue(parsedJwt.getParsedString()) - .headers(h -> h.putAll(headers)) - .claims(c -> c.putAll(claims)) - .build(); - } catch (Exception ex) { + return Jwt.withTokenValue(parsedJwt.getParsedString()).headers((h) -> h.putAll(headers)) + .claims((c) -> c.putAll(claims)).build(); + } + catch (Exception ex) { throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); } } @@ -182,23 +191,24 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); if (result.hasErrors()) { Collection errors = result.getErrors(); - String validationErrorString = "Unable to validate Jwt"; - for (OAuth2Error oAuth2Error : errors) { - if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { - validationErrorString = oAuth2Error.getDescription(); - break; - } - } + String validationErrorString = getJwtValidationExceptionMessage(errors); throw new JwtValidationException(validationErrorString, errors); } - return jwt; } + private String getJwtValidationExceptionMessage(Collection errors) { + for (OAuth2Error oAuth2Error : errors) { + if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { + return oAuth2Error.getDescription(); + } + } + return "Unable to validate Jwt"; + } + /** - * Use the given - * JWK Set uri to validate JWTs. - * + * Use the given JWK Set + * uri to validate JWTs. * @param jwkSetUri the JWK Set uri to use * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations * @@ -210,7 +220,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Use the given public key to validate JWTs - * * @param key the public key to use * @return a {@link PublicKeyReactiveJwtDecoderBuilder} for further configurations * @@ -222,7 +231,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Use the given {@code SecretKey} to validate the MAC on a JSON Web Signature (JWS). - * * @param secretKey the {@code SecretKey} used to validate the MAC * @return a {@link SecretKeyReactiveJwtDecoderBuilder} for further configurations * @@ -234,7 +242,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Use the given {@link Function} to validate JWTs - * * @param source the {@link Function} * @return a {@link JwkSourceReactiveJwtDecoderBuilder} for further configurations * @@ -244,29 +251,47 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new JwkSourceReactiveJwtDecoderBuilder(source); } + private static JWTClaimsSet createClaimsSet(JWTProcessor jwtProcessor, + JWT parsedToken, C context) { + try { + return jwtProcessor.process(parsedToken, context); + } + catch (BadJOSEException ex) { + throw new BadJwtException("Failed to validate the token", ex); + } + catch (JOSEException ex) { + throw new JwtException("Failed to validate the token", ex); + } + } + /** * A builder for creating {@link NimbusReactiveJwtDecoder} instances based on a - * JWK Set uri. + * JWK Set + * uri. * * @since 5.2 */ public static final class JwkSetUriReactiveJwtDecoderBuilder { + private final String jwkSetUri; + private Set signatureAlgorithms = new HashSet<>(); + private WebClient webClient = WebClient.create(); + private Consumer> jwtProcessorCustomizer; private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = jwkSetUri; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Append the given signing - * algorithm - * to the set of algorithms to use. - * + * algorithm to the set of algorithms to use. * @param signatureAlgorithm the algorithm to use * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations */ @@ -278,24 +303,24 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Configure the list of - * algorithms - * to use with the given {@link Consumer}. - * - * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list + * algorithms to use with the given {@link Consumer}. + * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring + * the algorithm list * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations */ - public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms(Consumer> signatureAlgorithmsConsumer) { + public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms( + Consumer> signatureAlgorithmsConsumer) { Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null"); signatureAlgorithmsConsumer.accept(this.signatureAlgorithms); return this; } /** - * Use the given {@link WebClient} to coordinate with the authorization servers indicated in the - * JWK Set uri - * as well as the - * Issuer. - * + * Use the given {@link WebClient} to coordinate with the authorization servers + * indicated in the JWK + * Set uri as well as the Issuer. * @param webClient * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations */ @@ -306,14 +331,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusReactiveJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusReactiveJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations * @since 5.4 */ - public JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; @@ -321,7 +347,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Build the configured {@link NimbusReactiveJwtDecoder}. - * * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { @@ -331,14 +356,13 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); - } else { - Set jwsAlgorithms = new HashSet<>(); - for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { - JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); - jwsAlgorithms.add(jwsAlgorithm); - } - return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } + Set jwsAlgorithms = new HashSet<>(); + for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { + JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); + jwsAlgorithms.add(jwsAlgorithm); + } + return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } Converter> processor() { @@ -346,19 +370,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); JWSKeySelector jwsKeySelector = jwsKeySelector(jwkSource); jwtProcessor.setJWSKeySelector(jwsKeySelector); - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri); source.setWebClient(this.webClient); - Function expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector); - return jwt -> { + return (jwt) -> { JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader()); return source.get(selector) - .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) - .map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList))); + .onErrorMap((ex) -> new IllegalStateException("Could not obtain the keys", ex)) + .map((jwkList) -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList))); }; } @@ -374,34 +396,39 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { if (!expectedJwsAlgorithms.apply(jwsHeader.getAlgorithm())) { throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm()); } - return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); } + } /** - * A builder for creating {@link NimbusReactiveJwtDecoder} instances based on a public key. + * A builder for creating {@link NimbusReactiveJwtDecoder} instances based on a public + * key. * * @since 5.2 */ public static final class PublicKeyReactiveJwtDecoderBuilder { + private final RSAPublicKey key; + private JWSAlgorithm jwsAlgorithm; + private Consumer> jwtProcessorCustomizer; private PublicKeyReactiveJwtDecoderBuilder(RSAPublicKey key) { Assert.notNull(key, "key cannot be null"); this.key = key; this.jwsAlgorithm = JWSAlgorithm.RS256; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Use the given signing - * algorithm. - * The value should be one of - * RS256, RS384, or RS512. - * + * algorithm. The value should be one of + * RS256, RS384, or RS512. * @param signatureAlgorithm the algorithm to use * @return a {@link PublicKeyReactiveJwtDecoderBuilder} for further configurations */ @@ -412,14 +439,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusReactiveJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusReactiveJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link PublicKeyReactiveJwtDecoderBuilder} for further configurations * @since 5.4 */ - public PublicKeyReactiveJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public PublicKeyReactiveJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; @@ -427,7 +455,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Build the configured {@link NimbusReactiveJwtDecoder}. - * * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { @@ -435,50 +462,50 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } Converter> processor() { - if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) { - throw new IllegalStateException("The provided key is of type RSA; " + - "however the signature algorithm is of some other type: " + - this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512."); - } - - JWSKeySelector jwsKeySelector = - new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key); + Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm), + () -> "The provided key is of type RSA; however the signature algorithm is of some other type: " + + this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512."); + JWSKeySelector jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); - // Spring Security validates the claim set independent from Nimbus - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - - return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); + return (jwt) -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); } + } /** - * A builder for creating {@link NimbusReactiveJwtDecoder} instances based on a {@code SecretKey}. + * A builder for creating {@link NimbusReactiveJwtDecoder} instances based on a + * {@code SecretKey}. * * @since 5.2 */ public static final class SecretKeyReactiveJwtDecoderBuilder { + private final SecretKey secretKey; + private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256; + private Consumer> jwtProcessorCustomizer; private SecretKeyReactiveJwtDecoderBuilder(SecretKey secretKey) { Assert.notNull(secretKey, "secretKey cannot be null"); this.secretKey = secretKey; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Use the given - * algorithm - * when generating the MAC. + * algorithm when generating the MAC. * * The value should be one of - * HS256, HS384 or HS512. - * + * HS256, HS384 or HS512. * @param macAlgorithm the MAC algorithm to use * @return a {@link SecretKeyReactiveJwtDecoderBuilder} for further configurations */ @@ -489,14 +516,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusReactiveJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusReactiveJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link SecretKeyReactiveJwtDecoderBuilder} for further configurations * @since 5.4 */ - public SecretKeyReactiveJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public SecretKeyReactiveJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; @@ -504,7 +532,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Build the configured {@link NimbusReactiveJwtDecoder}. - * * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { @@ -512,18 +539,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } Converter> processor() { - JWSKeySelector jwsKeySelector = - new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.secretKey); + JWSKeySelector jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, + this.secretKey); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); - // Spring Security validates the claim set independent from Nimbus - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - - return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); + return (jwt) -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); } + } /** @@ -532,20 +558,24 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { * @since 5.2 */ public static final class JwkSourceReactiveJwtDecoderBuilder { + private final Function> jwkSource; + private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256; + private Consumer> jwtProcessorCustomizer; private JwkSourceReactiveJwtDecoderBuilder(Function> jwkSource) { Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; - this.jwtProcessorCustomizer = (processor) -> {}; + this.jwtProcessorCustomizer = (processor) -> { + }; } /** * Use the given signing - * algorithm. - * + * algorithm. * @param jwsAlgorithm the algorithm to use * @return a {@link JwkSourceReactiveJwtDecoderBuilder} for further configurations */ @@ -556,14 +586,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } /** - * Use the given {@link Consumer} to customize the {@link JWTProcessor ConfigurableJWTProcessor} before - * passing it to the build {@link NimbusReactiveJwtDecoder}. - * + * Use the given {@link Consumer} to customize the {@link JWTProcessor + * ConfigurableJWTProcessor} before passing it to the build + * {@link NimbusReactiveJwtDecoder}. * @param jwtProcessorCustomizer the callback used to alter the processor * @return a {@link JwkSourceReactiveJwtDecoderBuilder} for further configurations * @since 5.4 */ - public JwkSourceReactiveJwtDecoderBuilder jwtProcessorCustomizer(Consumer> jwtProcessorCustomizer) { + public JwkSourceReactiveJwtDecoderBuilder jwtProcessorCustomizer( + Consumer> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); this.jwtProcessorCustomizer = jwtProcessorCustomizer; return this; @@ -571,7 +602,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { /** * Build the configured {@link NimbusReactiveJwtDecoder}. - * * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { @@ -580,36 +610,23 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { Converter> processor() { JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); - JWSKeySelector jwsKeySelector = - new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); + JWSKeySelector jwsKeySelector = new JWSVerificationKeySelector<>(this.jwsAlgorithm, + jwkSource); DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector); - jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); - + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); this.jwtProcessorCustomizer.accept(jwtProcessor); - - return jwt -> { + return (jwt) -> { if (jwt instanceof SignedJWT) { return this.jwkSource.apply((SignedJWT) jwt) - .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) - .collectList() - .map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks))); + .onErrorMap((e) -> new IllegalStateException("Could not obtain the keys", e)).collectList() + .map((jwks) -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks))); } throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); }; } + } - private static JWTClaimsSet createClaimsSet(JWTProcessor jwtProcessor, - JWT parsedToken, C context) { - try { - return jwtProcessor.process(parsedToken, context); - } - catch (BadJOSEException e) { - throw new BadJwtException("Failed to validate the token", e); - } - catch (JOSEException e) { - throw new JwtException("Failed to validate the token", e); - } - } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSource.java index 7ffbbc7a82..40799f08a5 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSource.java @@ -16,17 +16,20 @@ package org.springframework.security.oauth2.jwt; +import java.util.List; + import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSelector; import reactor.core.publisher.Mono; -import java.util.List; - /** * A reactive version of {@link com.nimbusds.jose.jwk.source.JWKSource} + * * @author Rob Winch * @since 5.1 */ interface ReactiveJWKSource { + Mono> get(JWKSelector jwkSelector); + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSourceAdapter.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSourceAdapter.java index 784688e05d..a7a2dc8aa9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSourceAdapter.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJWKSourceAdapter.java @@ -16,20 +16,22 @@ package org.springframework.security.oauth2.jwt; +import java.util.List; + import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSelector; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; import reactor.core.publisher.Mono; -import java.util.List; - /** * Adapts a {@link JWKSource} to a {@link ReactiveJWKSource} which must be non-blocking. + * * @author Rob Winch * @since 5.1 */ class ReactiveJWKSourceAdapter implements ReactiveJWKSource { + private final JWKSource source; /** @@ -44,4 +46,5 @@ class ReactiveJWKSourceAdapter implements ReactiveJWKSource { public Mono> get(JWKSelector jwkSelector) { return Mono.fromCallable(() -> this.source.get(jwkSelector, null)); } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoder.java index 7f7431e49c..85866dea30 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoder.java @@ -13,35 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import reactor.core.publisher.Mono; /** - * Implementations of this interface are responsible for "decoding" - * a JSON Web Token (JWT) from it's compact claims representation format to a {@link Jwt}. + * Implementations of this interface are responsible for "decoding" a JSON Web + * Token (JWT) from it's compact claims representation format to a {@link Jwt}. * *

        - * JWTs may be represented using the JWS Compact Serialization format for a - * JSON Web Signature (JWS) structure or JWE Compact Serialization format for a - * JSON Web Encryption (JWE) structure. Therefore, implementors are responsible - * for verifying a JWS and/or decrypting a JWE. + * JWTs may be represented using the JWS Compact Serialization format for a JSON Web + * Signature (JWS) structure or JWE Compact Serialization format for a JSON Web Encryption + * (JWE) structure. Therefore, implementors are responsible for verifying a JWS and/or + * decrypting a JWE. * * @author Rob Winch * @since 5.1 * @see Jwt - * @see JSON Web Token (JWT) - * @see JSON Web Signature (JWS) - * @see JSON Web Encryption (JWE) - * @see JWS Compact Serialization - * @see JWE Compact Serialization + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Encryption + * (JWE) + * @see JWS + * Compact Serialization + * @see JWE + * Compact Serialization */ @FunctionalInterface public interface ReactiveJwtDecoder { /** - * Decodes the JWT from it's compact claims representation format and returns a {@link Jwt}. - * + * Decodes the JWT from it's compact claims representation format and returns a + * {@link Jwt}. * @param token the JWT value * @return a {@link Jwt} * @throws JwtException if an error occurs while attempting to decode the JWT diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java index cc93890fc5..7c6f9f0c08 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderFactory.java @@ -13,25 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; /** - * A factory for {@link ReactiveJwtDecoder}(s). - * This factory should be supplied with a type that provides - * contextual information used to create a specific {@code ReactiveJwtDecoder}. + * A factory for {@link ReactiveJwtDecoder}(s). This factory should be supplied with a + * type that provides contextual information used to create a specific + * {@code ReactiveJwtDecoder}. * * @author Joe Grandja * @since 5.2 * @see ReactiveJwtDecoder - * - * @param The type that provides contextual information used to create a specific {@code ReactiveJwtDecoder}. + * @param The type that provides contextual information used to create a specific + * {@code ReactiveJwtDecoder}. */ @FunctionalInterface public interface ReactiveJwtDecoderFactory { /** * Creates a {@code ReactiveJwtDecoder} using the supplied "contextual" type. - * * @param context the type that provides contextual information * @return a {@link ReactiveJwtDecoder} */ diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java index d062b515cf..c279cc5819 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.util.Map; @@ -20,92 +21,95 @@ import java.util.Map; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri; - /** - * Allows creating a {@link ReactiveJwtDecoder} from an - * OpenID Provider Configuration or - * Authorization Server Metadata Request based on provided - * issuer and method invoked. + * Allows creating a {@link ReactiveJwtDecoder} from an OpenID + * Provider Configuration or + * Authorization Server Metadata + * Request based on provided issuer and method invoked. * * @author Josh Cummings * @since 5.1 */ public final class ReactiveJwtDecoders { + private ReactiveJwtDecoders() { + } + /** - * Creates a {@link ReactiveJwtDecoder} using the provided - * Issuer by making an - * OpenID Provider - * Configuration Request and using the values in the - * OpenID + * Creates a {@link ReactiveJwtDecoder} using the provided Issuer + * by making an OpenID + * Provider Configuration Request and using the values in the OpenID * Provider Configuration Response to initialize the {@link ReactiveJwtDecoder}. - * - * @param oidcIssuerLocation the Issuer - * @return a {@link ReactiveJwtDecoder} that was initialized by the OpenID Provider Configuration. + * @param oidcIssuerLocation the Issuer + * @return a {@link ReactiveJwtDecoder} that was initialized by the OpenID Provider + * Configuration. */ public static ReactiveJwtDecoder fromOidcIssuerLocation(String oidcIssuerLocation) { Assert.hasText(oidcIssuerLocation, "oidcIssuerLocation cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForOidcIssuerLocation(oidcIssuerLocation); + Map configuration = JwtDecoderProviderConfigurationUtils + .getConfigurationForOidcIssuerLocation(oidcIssuerLocation); return withProviderConfiguration(configuration, oidcIssuerLocation); } /** - * Creates a {@link ReactiveJwtDecoder} using the provided - * Issuer by querying - * three different discovery endpoints serially, using the values in the first successful response to - * initialize. If an endpoint returns anything other than a 200 or a 4xx, the method will exit without - * attempting subsequent endpoints. + * Creates a {@link ReactiveJwtDecoder} using the provided Issuer + * by querying three different discovery endpoints serially, using the values in the + * first successful response to initialize. If an endpoint returns anything other than + * a 200 or a 4xx, the method will exit without attempting subsequent endpoints. * - * The three endpoints are computed as follows, given that the {@code issuer} is composed of a {@code host} - * and a {@code path}: + * The three endpoints are computed as follows, given that the {@code issuer} is + * composed of a {@code host} and a {@code path}: * *

          - *
        1. - * {@code host/.well-known/openid-configuration/path}, as defined in - * RFC 8414's Compatibility Notes. - *
        2. - *
        3. - * {@code issuer/.well-known/openid-configuration}, as defined in - * - * OpenID Provider Configuration. - *
        4. - *
        5. - * {@code host/.well-known/oauth-authorization-server/path}, as defined in - * Authorization Server Metadata Request. - *
        6. + *
        7. {@code host/.well-known/openid-configuration/path}, as defined in + * RFC 8414's Compatibility + * Notes.
        8. + *
        9. {@code issuer/.well-known/openid-configuration}, as defined in + * OpenID Provider Configuration.
        10. + *
        11. {@code host/.well-known/oauth-authorization-server/path}, as defined in + * Authorization Server + * Metadata Request.
        12. *
        * * Note that the second endpoint is the equivalent of calling * {@link ReactiveJwtDecoders#fromOidcIssuerLocation(String)} - * - * @param issuer the Issuer - * @return a {@link ReactiveJwtDecoder} that was initialized by one of the described endpoints + * @param issuer the Issuer + * @return a {@link ReactiveJwtDecoder} that was initialized by one of the described + * endpoints */ public static ReactiveJwtDecoder fromIssuerLocation(String issuer) { Assert.hasText(issuer, "issuer cannot be empty"); - Map configuration = JwtDecoderProviderConfigurationUtils.getConfigurationForIssuerLocation(issuer); + Map configuration = JwtDecoderProviderConfigurationUtils + .getConfigurationForIssuerLocation(issuer); return withProviderConfiguration(configuration, issuer); } /** - * Build {@link ReactiveJwtDecoder} from - * OpenID Provider - * Configuration Response and Authorization Server Metadata - * Response. - * + * Build {@link ReactiveJwtDecoder} from OpenID + * Provider Configuration Response and + * Authorization Server + * Metadata Response. * @param configuration the configuration values - * @param issuer the Issuer + * @param issuer the Issuer * @return {@link ReactiveJwtDecoder} */ private static ReactiveJwtDecoder withProviderConfiguration(Map configuration, String issuer) { JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer); OAuth2TokenValidator jwtValidator = JwtValidators.createDefaultWithIssuer(issuer); - NimbusReactiveJwtDecoder jwtDecoder = withJwkSetUri(configuration.get("jwks_uri").toString()).build(); + NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder + .withJwkSetUri(configuration.get("jwks_uri").toString()).build(); jwtDecoder.setJwtValidator(jwtValidator); - return jwtDecoder; } - private ReactiveJwtDecoders() {} } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java index 36c4e47703..498fa81a57 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java @@ -37,6 +37,7 @@ import org.springframework.web.reactive.function.client.WebClient; * @since 5.1 */ class ReactiveRemoteJWKSource implements ReactiveJWKSource { + /** * The cached JWK set. */ @@ -51,85 +52,81 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { this.jwkSetURL = jwkSetURL; } + @Override public Mono> get(JWKSelector jwkSelector) { + // @formatter:off return this.cachedJWKSet.get() .switchIfEmpty(Mono.defer(() -> getJWKSet())) - .flatMap(jwkSet -> get(jwkSelector, jwkSet)) - .switchIfEmpty(Mono.defer(() -> getJWKSet().map(jwkSet -> jwkSelector.select(jwkSet)))); + .flatMap((jwkSet) -> get(jwkSelector, jwkSet)) + .switchIfEmpty(Mono.defer(() -> getJWKSet() + .map((jwkSet) -> jwkSelector.select(jwkSet))) + ); + // @formatter:on } private Mono> get(JWKSelector jwkSelector, JWKSet jwkSet) { return Mono.defer(() -> { // Run the selector on the JWK set List matches = jwkSelector.select(jwkSet); - if (!matches.isEmpty()) { // Success return Mono.just(matches); } - // Refresh the JWK set if the sought key ID is not in the cached JWK set - // Looking for JWK with specific ID? String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); if (soughtKeyID == null) { // No key ID specified, return no matches return Mono.just(Collections.emptyList()); } - if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { // The key ID exists in the cached JWK set, matching // failed for some other reason, return no matches return Mono.just(Collections.emptyList()); } - return Mono.empty(); - }); } /** * Updates the cached JWK set from the configured URL. - * * @return The updated JWK set. - * * @throws RemoteKeySourceException If JWK retrieval failed. */ private Mono getJWKSet() { + // @formatter:off return this.webClient.get() .uri(this.jwkSetURL) .retrieve() .bodyToMono(String.class) .map(this::parse) - .doOnNext(jwkSet -> this.cachedJWKSet.set(Mono.just(jwkSet))) + .doOnNext((jwkSet) -> this.cachedJWKSet + .set(Mono.just(jwkSet)) + ) .cache(); + // @formatter:on } private JWKSet parse(String body) { try { return JWKSet.parse(body); } - catch (ParseException e) { - throw new RuntimeException(e); + catch (ParseException ex) { + throw new RuntimeException(ex); } } /** * Returns the first specified key ID (kid) for a JWK matcher. - * * @param jwkMatcher The JWK matcher. Must not be {@code null}. - * * @return The first key ID, {@code null} if none. */ protected static String getFirstSpecifiedKeyID(final JWKMatcher jwkMatcher) { - Set keyIDs = jwkMatcher.getKeyIDs(); - if (keyIDs == null || keyIDs.isEmpty()) { return null; } - - for (String id: keyIDs) { + for (String id : keyIDs) { if (id != null) { return id; } @@ -137,7 +134,8 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { return null; // No kid in matcher } - public void setWebClient(WebClient webClient) { + void setWebClient(WebClient webClient) { this.webClient = webClient; } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java index 4111a2866e..677e06ed60 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java @@ -34,7 +34,9 @@ import org.springframework.util.Assert; * @since 5.2 */ final class SingleKeyJWSKeySelector implements JWSKeySelector { + private final List keySet; + private final JWSAlgorithm expectedJwsAlgorithm; SingleKeyJWSKeySelector(JWSAlgorithm expectedJwsAlgorithm, Key key) { @@ -51,4 +53,5 @@ final class SingleKeyJWSKeySelector implements JWSKey } return this.keySet; } + } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/package-info.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/package-info.java index bd7e2c5e4d..e3a27e0d34 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/package-info.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/package-info.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Core classes and interfaces providing support for JSON Web Token (JWT). */ diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java index 555d92b347..7a2b7fb70d 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose; import java.security.KeyFactory; @@ -23,6 +24,7 @@ import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.X509EncodedKeySpec; import java.util.Base64; + import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; @@ -30,78 +32,84 @@ import javax.crypto.spec.SecretKeySpec; * @author Joe Grandja * @since 5.2 */ -public class TestKeys { - public static final KeyFactory kf; +public final class TestKeys { + public static final KeyFactory kf; static { try { kf = KeyFactory.getInstance("RSA"); - } catch (NoSuchAlgorithmException e) { - throw new IllegalStateException(e); + } + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException(ex); } } - public static final String DEFAULT_ENCODED_SECRET_KEY = "bCzY/M48bbkwBEWjmNSIEPfwApcvXOnkCxORBEbPr+4="; - public static final SecretKey DEFAULT_SECRET_KEY = - new SecretKeySpec(Base64.getDecoder().decode(DEFAULT_ENCODED_SECRET_KEY), "AES"); + public static final SecretKey DEFAULT_SECRET_KEY = new SecretKeySpec( + Base64.getDecoder().decode(DEFAULT_ENCODED_SECRET_KEY), "AES"); - public static final String DEFAULT_RSA_PUBLIC_KEY = - "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3FlqJr5TRskIQIgdE3Dd" + - "7D9lboWdcTUT8a+fJR7MAvQm7XXNoYkm3v7MQL1NYtDvL2l8CAnc0WdSTINU6IRv" + - "c5Kqo2Q4csNX9SHOmEfzoROjQqahEcve1jBXluoCXdYuYpx4/1tfRgG6ii4Uhxh6" + - "iI8qNMJQX+fLfqhbfYfxBQVRPywBkAbIP4x1EAsbC6FSNmkhCxiMNqEgxaIpY8C2" + - "kJdJ/ZIV+WW4noDdzpKqHcwmB8FsrumlVY/DNVvUSDIipiq9PbP4H99TXN1o746o" + - "RaNa07rq1hoCgMSSy+85SagCoxlmyE+D+of9SsMY8Ol9t0rdzpobBuhyJ/o5dfvj" + - "KwIDAQAB"; + // @formatter:off + public static final String DEFAULT_RSA_PUBLIC_KEY = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3FlqJr5TRskIQIgdE3Dd" + + "7D9lboWdcTUT8a+fJR7MAvQm7XXNoYkm3v7MQL1NYtDvL2l8CAnc0WdSTINU6IRv" + + "c5Kqo2Q4csNX9SHOmEfzoROjQqahEcve1jBXluoCXdYuYpx4/1tfRgG6ii4Uhxh6" + + "iI8qNMJQX+fLfqhbfYfxBQVRPywBkAbIP4x1EAsbC6FSNmkhCxiMNqEgxaIpY8C2" + + "kJdJ/ZIV+WW4noDdzpKqHcwmB8FsrumlVY/DNVvUSDIipiq9PbP4H99TXN1o746o" + + "RaNa07rq1hoCgMSSy+85SagCoxlmyE+D+of9SsMY8Ol9t0rdzpobBuhyJ/o5dfvj" + + "KwIDAQAB"; + // @formatter:on - public static final RSAPublicKey DEFAULT_PUBLIC_KEY = publicKey(); - - private static RSAPublicKey publicKey() { + public static final RSAPublicKey DEFAULT_PUBLIC_KEY; + static { X509EncodedKeySpec spec = new X509EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PUBLIC_KEY)); try { - return (RSAPublicKey) kf.generatePublic(spec); - } catch (InvalidKeySpecException e) { - throw new IllegalArgumentException(e); + DEFAULT_PUBLIC_KEY = (RSAPublicKey) kf.generatePublic(spec); + } + catch (InvalidKeySpecException ex) { + throw new IllegalArgumentException(ex); } } - public static final String DEFAULT_RSA_PRIVATE_KEY = - "MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcWWomvlNGyQhA" + - "iB0TcN3sP2VuhZ1xNRPxr58lHswC9Cbtdc2hiSbe/sxAvU1i0O8vaXwICdzRZ1JM" + - "g1TohG9zkqqjZDhyw1f1Ic6YR/OhE6NCpqERy97WMFeW6gJd1i5inHj/W19GAbqK" + - "LhSHGHqIjyo0wlBf58t+qFt9h/EFBVE/LAGQBsg/jHUQCxsLoVI2aSELGIw2oSDF" + - "oiljwLaQl0n9khX5ZbiegN3OkqodzCYHwWyu6aVVj8M1W9RIMiKmKr09s/gf31Nc" + - "3WjvjqhFo1rTuurWGgKAxJLL7zlJqAKjGWbIT4P6h/1Kwxjw6X23St3OmhsG6HIn" + - "+jl1++MrAgMBAAECggEBAMf820wop3pyUOwI3aLcaH7YFx5VZMzvqJdNlvpg1jbE" + - "E2Sn66b1zPLNfOIxLcBG8x8r9Ody1Bi2Vsqc0/5o3KKfdgHvnxAB3Z3dPh2WCDek" + - "lCOVClEVoLzziTuuTdGO5/CWJXdWHcVzIjPxmK34eJXioiLaTYqN3XKqKMdpD0ZG" + - "mtNTGvGf+9fQ4i94t0WqIxpMpGt7NM4RHy3+Onggev0zLiDANC23mWrTsUgect/7" + - "62TYg8g1bKwLAb9wCBT+BiOuCc2wrArRLOJgUkj/F4/gtrR9ima34SvWUyoUaKA0" + - "bi4YBX9l8oJwFGHbU9uFGEMnH0T/V0KtIB7qetReywkCgYEA9cFyfBIQrYISV/OA" + - "+Z0bo3vh2aL0QgKrSXZ924cLt7itQAHNZ2ya+e3JRlTczi5mnWfjPWZ6eJB/8MlH" + - "Gpn12o/POEkU+XjZZSPe1RWGt5g0S3lWqyx9toCS9ACXcN9tGbaqcFSVI73zVTRA" + - "8J9grR0fbGn7jaTlTX2tnlOTQ60CgYEA5YjYpEq4L8UUMFkuj+BsS3u0oEBnzuHd" + - "I9LEHmN+CMPosvabQu5wkJXLuqo2TxRnAznsA8R3pCLkdPGoWMCiWRAsCn979TdY" + - "QbqO2qvBAD2Q19GtY7lIu6C35/enQWzJUMQE3WW0OvjLzZ0l/9mA2FBRR+3F9A1d" + - "rBdnmv0c3TcCgYEAi2i+ggVZcqPbtgrLOk5WVGo9F1GqUBvlgNn30WWNTx4zIaEk" + - "HSxtyaOLTxtq2odV7Kr3LGiKxwPpn/T+Ief+oIp92YcTn+VfJVGw4Z3BezqbR8lA" + - "Uf/+HF5ZfpMrVXtZD4Igs3I33Duv4sCuqhEvLWTc44pHifVloozNxYfRfU0CgYBN" + - "HXa7a6cJ1Yp829l62QlJKtx6Ymj95oAnQu5Ez2ROiZMqXRO4nucOjGUP55Orac1a" + - "FiGm+mC/skFS0MWgW8evaHGDbWU180wheQ35hW6oKAb7myRHtr4q20ouEtQMdQIF" + - "snV39G1iyqeeAsf7dxWElydXpRi2b68i3BIgzhzebQKBgQCdUQuTsqV9y/JFpu6H" + - "c5TVvhG/ubfBspI5DhQqIGijnVBzFT//UfIYMSKJo75qqBEyP2EJSmCsunWsAFsM" + - "TszuiGTkrKcZy9G0wJqPztZZl2F2+bJgnA6nBEV7g5PA4Af+QSmaIhRwqGDAuROR" + - "47jndeyIaMTNETEmOnms+as17g=="; + // @formatter:off + public static final String DEFAULT_RSA_PRIVATE_KEY = "MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcWWomvlNGyQhA" + + "iB0TcN3sP2VuhZ1xNRPxr58lHswC9Cbtdc2hiSbe/sxAvU1i0O8vaXwICdzRZ1JM" + + "g1TohG9zkqqjZDhyw1f1Ic6YR/OhE6NCpqERy97WMFeW6gJd1i5inHj/W19GAbqK" + + "LhSHGHqIjyo0wlBf58t+qFt9h/EFBVE/LAGQBsg/jHUQCxsLoVI2aSELGIw2oSDF" + + "oiljwLaQl0n9khX5ZbiegN3OkqodzCYHwWyu6aVVj8M1W9RIMiKmKr09s/gf31Nc" + + "3WjvjqhFo1rTuurWGgKAxJLL7zlJqAKjGWbIT4P6h/1Kwxjw6X23St3OmhsG6HIn" + + "+jl1++MrAgMBAAECggEBAMf820wop3pyUOwI3aLcaH7YFx5VZMzvqJdNlvpg1jbE" + + "E2Sn66b1zPLNfOIxLcBG8x8r9Ody1Bi2Vsqc0/5o3KKfdgHvnxAB3Z3dPh2WCDek" + + "lCOVClEVoLzziTuuTdGO5/CWJXdWHcVzIjPxmK34eJXioiLaTYqN3XKqKMdpD0ZG" + + "mtNTGvGf+9fQ4i94t0WqIxpMpGt7NM4RHy3+Onggev0zLiDANC23mWrTsUgect/7" + + "62TYg8g1bKwLAb9wCBT+BiOuCc2wrArRLOJgUkj/F4/gtrR9ima34SvWUyoUaKA0" + + "bi4YBX9l8oJwFGHbU9uFGEMnH0T/V0KtIB7qetReywkCgYEA9cFyfBIQrYISV/OA" + + "+Z0bo3vh2aL0QgKrSXZ924cLt7itQAHNZ2ya+e3JRlTczi5mnWfjPWZ6eJB/8MlH" + + "Gpn12o/POEkU+XjZZSPe1RWGt5g0S3lWqyx9toCS9ACXcN9tGbaqcFSVI73zVTRA" + + "8J9grR0fbGn7jaTlTX2tnlOTQ60CgYEA5YjYpEq4L8UUMFkuj+BsS3u0oEBnzuHd" + + "I9LEHmN+CMPosvabQu5wkJXLuqo2TxRnAznsA8R3pCLkdPGoWMCiWRAsCn979TdY" + + "QbqO2qvBAD2Q19GtY7lIu6C35/enQWzJUMQE3WW0OvjLzZ0l/9mA2FBRR+3F9A1d" + + "rBdnmv0c3TcCgYEAi2i+ggVZcqPbtgrLOk5WVGo9F1GqUBvlgNn30WWNTx4zIaEk" + + "HSxtyaOLTxtq2odV7Kr3LGiKxwPpn/T+Ief+oIp92YcTn+VfJVGw4Z3BezqbR8lA" + + "Uf/+HF5ZfpMrVXtZD4Igs3I33Duv4sCuqhEvLWTc44pHifVloozNxYfRfU0CgYBN" + + "HXa7a6cJ1Yp829l62QlJKtx6Ymj95oAnQu5Ez2ROiZMqXRO4nucOjGUP55Orac1a" + + "FiGm+mC/skFS0MWgW8evaHGDbWU180wheQ35hW6oKAb7myRHtr4q20ouEtQMdQIF" + + "snV39G1iyqeeAsf7dxWElydXpRi2b68i3BIgzhzebQKBgQCdUQuTsqV9y/JFpu6H" + + "c5TVvhG/ubfBspI5DhQqIGijnVBzFT//UfIYMSKJo75qqBEyP2EJSmCsunWsAFsM" + + "TszuiGTkrKcZy9G0wJqPztZZl2F2+bJgnA6nBEV7g5PA4Af+QSmaIhRwqGDAuROR" + + "47jndeyIaMTNETEmOnms+as17g=="; + // @formatter:on - public static final RSAPrivateKey DEFAULT_PRIVATE_KEY = privateKey(); - - private static RSAPrivateKey privateKey() { + public static final RSAPrivateKey DEFAULT_PRIVATE_KEY; + static { PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PRIVATE_KEY)); try { - return (RSAPrivateKey) kf.generatePrivate(spec); - } catch (InvalidKeySpecException e) { - throw new IllegalArgumentException(e); + DEFAULT_PRIVATE_KEY = (RSAPrivateKey) kf.generatePrivate(spec); + } + catch (InvalidKeySpecException ex) { + throw new IllegalArgumentException(ex); } } + + private TestKeys() { + } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/MacAlgorithmTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/MacAlgorithmTests.java index cea7ef5054..e3435cf5c2 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/MacAlgorithmTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/MacAlgorithmTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose.jws; import org.junit.Test; @@ -38,4 +39,5 @@ public class MacAlgorithmTests { public void fromWhenAlgorithmInvalidThenDoesNotResolve() { assertThat(MacAlgorithm.from("invalid")).isNull(); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithmTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithmTests.java index 4b2c19f018..99b01996fd 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithmTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/jws/SignatureAlgorithmTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jose.jws; import org.junit.Test; @@ -44,4 +45,5 @@ public class SignatureAlgorithmTests { public void fromWhenAlgorithmInvalidThenDoesNotResolve() { assertThat(SignatureAlgorithm.from("invalid")).isNull(); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtBuilderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtBuilderTests.java index 4004ef9400..0aaf02945e 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtBuilderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtBuilderTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.time.Instant; @@ -20,10 +21,7 @@ import java.time.Instant; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.EXP; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.IAT; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link Jwt.Builder}. @@ -36,27 +34,23 @@ public class JwtBuilderTests { @Test public void buildWhenCalledTwiceThenGeneratesTwoJwts() { Jwt.Builder jwtBuilder = Jwt.withTokenValue("token"); - - Jwt first = jwtBuilder - .tokenValue("V1") + // @formatter:off + Jwt first = jwtBuilder.tokenValue("V1") .header("TEST_HEADER_1", "H1") .claim("TEST_CLAIM_1", "C1") .build(); - - Jwt second = jwtBuilder - .tokenValue("V2") + Jwt second = jwtBuilder.tokenValue("V2") .header("TEST_HEADER_1", "H2") .header("TEST_HEADER_2", "H3") .claim("TEST_CLAIM_1", "C2") .claim("TEST_CLAIM_2", "C3") .build(); - + // @formatter:on assertThat(first.getHeaders()).hasSize(1); assertThat(first.getHeaders().get("TEST_HEADER_1")).isEqualTo("H1"); assertThat(first.getClaims()).hasSize(1); assertThat(first.getClaims().get("TEST_CLAIM_1")).isEqualTo("C1"); assertThat(first.getTokenValue()).isEqualTo("V1"); - assertThat(second.getHeaders()).hasSize(2); assertThat(second.getHeaders().get("TEST_HEADER_1")).isEqualTo("H2"); assertThat(second.getHeaders().get("TEST_HEADER_2")).isEqualTo("H3"); @@ -68,115 +62,99 @@ public class JwtBuilderTests { @Test public void expiresAtWhenUsingGenericOrNamedClaimMethodRequiresInstant() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .header("needs", "a header"); - + // @formatter:on Instant now = Instant.now(); - - Jwt jwt = jwtBuilder - .expiresAt(now).build(); + Jwt jwt = jwtBuilder.expiresAt(now).build(); assertThat(jwt.getExpiresAt()).isSameAs(now); - - jwt = jwtBuilder - .expiresAt(now).build(); + jwt = jwtBuilder.expiresAt(now).build(); assertThat(jwt.getExpiresAt()).isSameAs(now); - - assertThatCode(() -> jwtBuilder - .claim(EXP, "not an instant").build()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> jwtBuilder.claim(JwtClaimNames.EXP, "not an instant").build()); } @Test public void issuedAtWhenUsingGenericOrNamedClaimMethodRequiresInstant() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .header("needs", "a header"); - + // @formatter:on Instant now = Instant.now(); - - Jwt jwt = jwtBuilder - .issuedAt(now).build(); + Jwt jwt = jwtBuilder.issuedAt(now).build(); assertThat(jwt.getIssuedAt()).isSameAs(now); - - jwt = jwtBuilder - .issuedAt(now).build(); + jwt = jwtBuilder.issuedAt(now).build(); assertThat(jwt.getIssuedAt()).isSameAs(now); - - assertThatCode(() -> jwtBuilder - .claim(IAT, "not an instant").build()) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> jwtBuilder.claim(JwtClaimNames.IAT, "not an instant").build()); } @Test public void subjectWhenUsingGenericOrNamedClaimMethodThenLastOneWins() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .header("needs", "a header"); - + // @formatter:on String generic = new String("sub"); String named = new String("sub"); - - Jwt jwt = jwtBuilder - .subject(named) - .claim(SUB, generic).build(); + Jwt jwt = jwtBuilder.subject(named).claim(JwtClaimNames.SUB, generic).build(); assertThat(jwt.getSubject()).isSameAs(generic); - - jwt = jwtBuilder - .claim(SUB, generic) - .subject(named).build(); + jwt = jwtBuilder.claim(JwtClaimNames.SUB, generic).subject(named).build(); assertThat(jwt.getSubject()).isSameAs(named); } @Test public void claimsWhenRemovingAClaimThenIsNotPresent() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .claim("needs", "a claim") .header("needs", "a header"); - - Jwt jwt = jwtBuilder - .subject("sub") - .claims(claims -> claims.remove(SUB)) + Jwt jwt = jwtBuilder.subject("sub") + .claims((claims) -> claims.remove(JwtClaimNames.SUB)) .build(); + // @formatter:on assertThat(jwt.getSubject()).isNull(); } @Test public void claimsWhenAddingAClaimThenIsPresent() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .header("needs", "a header"); - + // @formatter:on String name = new String("name"); String value = new String("value"); - Jwt jwt = jwtBuilder - .claims(claims -> claims.put(name, value)) - .build(); - + Jwt jwt = jwtBuilder.claims((claims) -> claims.put(name, value)).build(); assertThat(jwt.getClaims()).hasSize(1); assertThat(jwt.getClaims().get(name)).isSameAs(value); } @Test public void headersWhenRemovingAClaimThenIsNotPresent() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .claim("needs", "a claim") .header("needs", "a header"); - - Jwt jwt = jwtBuilder - .header("alg", "none") - .headers(headers -> headers.remove("alg")) + Jwt jwt = jwtBuilder.header("alg", "none") + .headers((headers) -> headers.remove("alg")) .build(); + // @formatter:on assertThat(jwt.getHeaders().get("alg")).isNull(); } @Test public void headersWhenAddingAClaimThenIsPresent() { + // @formatter:off Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") .claim("needs", "a claim"); - + // @formatter:on String name = new String("name"); String value = new String("value"); - Jwt jwt = jwtBuilder - .headers(headers -> headers.put(name, value)) + // @formatter:off + Jwt jwt = jwtBuilder.headers((headers) -> headers.put(name, value)) .build(); - + // @formatter:on assertThat(jwt.getHeaders()).hasSize(1); assertThat(jwt.getHeaders().get(name)).isSameAs(value); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimValidatorTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimValidatorTests.java index 968af2a8ab..820d73e322 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimValidatorTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimValidatorTests.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.jwt; -import org.junit.Test; -import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; +package org.springframework.security.oauth2.jwt; import java.util.function.Predicate; +import org.junit.Test; + +import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link JwtClaimValidator}. @@ -32,38 +32,35 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; */ public class JwtClaimValidatorTests { - private static final Predicate test = claim -> claim.equals("http://test"); - private final JwtClaimValidator validator = new JwtClaimValidator<>(ISS, test); + private static final Predicate test = (claim) -> claim.equals("http://test"); + + private final JwtClaimValidator validator = new JwtClaimValidator<>(JwtClaimNames.ISS, test); @Test public void validateWhenClaimPassesTheTestThenReturnsSuccess() { - Jwt jwt = jwt().claim(ISS, "http://test").build(); - assertThat(validator.validate(jwt)) - .isEqualTo(OAuth2TokenValidatorResult.success()); + Jwt jwt = TestJwts.jwt().claim(JwtClaimNames.ISS, "http://test").build(); + assertThat(this.validator.validate(jwt)).isEqualTo(OAuth2TokenValidatorResult.success()); } @Test public void validateWhenClaimFailsTheTestThenReturnsFailure() { - Jwt jwt = jwt().claim(ISS, "http://abc").build(); - assertThat(validator.validate(jwt).getErrors().isEmpty()) - .isFalse(); + Jwt jwt = TestJwts.jwt().claim(JwtClaimNames.ISS, "http://abc").build(); + assertThat(this.validator.validate(jwt).getErrors().isEmpty()).isFalse(); } @Test public void validateWhenClaimIsNullThenThrowsIllegalArgumentException() { - assertThatThrownBy(() -> new JwtClaimValidator(null, test)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new JwtClaimValidator<>(null, test)); } @Test - public void validateWhenTestIsNullThenThrowsIllegalArgumentException(){ - assertThatThrownBy(() -> new JwtClaimValidator<>(ISS, null)) - .isInstanceOf(IllegalArgumentException.class); + public void validateWhenTestIsNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> new JwtClaimValidator<>(JwtClaimNames.ISS, null)); } @Test public void validateWhenJwtIsNullThenThrowsIllegalArgumentException() { - assertThatThrownBy(() -> validator.validate(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.validator.validate(null)); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java index 8fa7ee5869..868fe4713b 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.net.URI; @@ -38,7 +39,9 @@ import org.springframework.http.MediaType; import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link JwtDecoders} @@ -47,40 +50,46 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Rafiullah Hamedy */ public class JwtDecodersTests { + /** - * Contains those parameters required to construct a JwtDecoder as well as any required parameters + * Contains those parameters required to construct a JwtDecoder as well as any + * required parameters */ - private static final String DEFAULT_RESPONSE_TEMPLATE = - "{\n" - + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" - + " \"id_token_signing_alg_values_supported\": [\n" - + " \"RS256\"\n" - + " ], \n" - + " \"issuer\": \"%s\", \n" - + " \"jwks_uri\": \"%s/.well-known/jwks.json\", \n" - + " \"response_types_supported\": [\n" - + " \"code\", \n" - + " \"token\", \n" - + " \"id_token\", \n" - + " \"code token\", \n" - + " \"code id_token\", \n" - + " \"token id_token\", \n" - + " \"code token id_token\", \n" - + " \"none\"\n" - + " ], \n" - + " \"subject_types_supported\": [\n" - + " \"public\"\n" - + " ], \n" - + " \"token_endpoint\": \"https://example.com/oauth2/v4/token\"\n" - + "}"; + // @formatter:off + private static final String DEFAULT_RESPONSE_TEMPLATE = "{\n" + + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" + + " \"id_token_signing_alg_values_supported\": [\n" + + " \"RS256\"\n" + + " ], \n" + + " \"issuer\": \"%s\", \n" + + " \"jwks_uri\": \"%s/.well-known/jwks.json\", \n" + + " \"response_types_supported\": [\n" + + " \"code\", \n" + + " \"token\", \n" + + " \"id_token\", \n" + + " \"code token\", \n" + + " \"code id_token\", \n" + + " \"token id_token\", \n" + + " \"code token id_token\", \n" + + " \"none\"\n" + + " ], \n" + + " \"subject_types_supported\": [\n" + + " \"public\"\n" + + " ], \n" + + " \"token_endpoint\": \"https://example.com/oauth2/v4/token\"\n" + + "}"; + // @formatter:on private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}"; + private static final String ISSUER_MISMATCH = "eyJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwczpcL1wvd3Jvbmdpc3N1ZXIiLCJleHAiOjQ2ODcyNTYwNDl9.Ax8LMI6rhB9Pv_CE3kFi1JPuLj9gZycifWrLeDpkObWEEVAsIls9zAhNFyJlG-Oo7up6_mDhZgeRfyKnpSF5GhKJtXJDCzwg0ZDVUE6rS0QadSxsMMGbl7c4y0lG_7TfLX2iWeNJukJj_oSW9KzW4FsBp1BoocWjrreesqQU3fZHbikH-c_Fs2TsAIpHnxflyEzfOFWpJ8D4DtzHXqfvieMwpy42xsPZK3LR84zlasf0Ne1tC_hLHvyHRdAXwn0CMoKxc7-8j0r9Mq8kAzUsPn9If7bMLqGkxUcTPdk5x7opAUajDZx95SXHLmtztNtBa2S6EfPJXuPKG6tM5Wq5Ug"; private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration"; + private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server"; private MockWebServer server; + private String issuer; @Before @@ -99,27 +108,33 @@ public class JwtDecodersTests { public void issuerWhenResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponse(); JwtDecoder decoder = JwtDecoders.fromOidcIssuerLocation(this.issuer); - assertThatCode(() -> decoder.decode(ISSUER_MISMATCH)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("The iss claim is not valid"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(ISSUER_MISMATCH)) + .withMessageContaining("The iss claim is not valid"); + // @formatter:on } @Test public void issuerWhenOidcFallbackResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponseOidc(); JwtDecoder decoder = JwtDecoders.fromIssuerLocation(this.issuer); - assertThatCode(() -> decoder.decode(ISSUER_MISMATCH)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("The iss claim is not valid"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(ISSUER_MISMATCH)) + .withMessageContaining("The iss claim is not valid"); + // @formatter:on } @Test public void issuerWhenOAuth2ResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponseOAuth2(); JwtDecoder decoder = JwtDecoders.fromIssuerLocation(this.issuer); - assertThatCode(() -> decoder.decode(ISSUER_MISMATCH)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("The iss claim is not valid"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(ISSUER_MISMATCH)) + .withMessageContaining("The iss claim is not valid"); + // @formatter:on } @Test @@ -149,22 +164,26 @@ public class JwtDecodersTests { @Test public void issuerWhenResponseIsNonCompliantThenThrowsRuntimeException() { prepareConfigurationResponse("{ \"missing_required_keys\" : \"and_values\" }"); - assertThatCode(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)); } @Test public void issuerWhenOidcFallbackResponseIsNonCompliantThenThrowsRuntimeException() { prepareConfigurationResponseOidc("{ \"missing_required_keys\" : \"and_values\" }"); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOAuth2ResponseIsNonCompliantThenThrowsRuntimeException() { prepareConfigurationResponseOAuth2("{ \"missing_required_keys\" : \"and_values\" }"); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } // gh-7512 @@ -172,9 +191,11 @@ public class JwtDecodersTests { public void issuerWhenResponseDoesNotContainJwksUriThenThrowsIllegalArgumentException() throws JsonMappingException, JsonProcessingException { prepareConfigurationResponse(this.buildResponseWithMissingJwksUri()); - assertThatCode(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The public JWK set URI must not be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)) + .withMessage("The public JWK set URI must not be null"); + // @formatter:on } // gh-7512 @@ -182,9 +203,12 @@ public class JwtDecodersTests { public void issuerWhenOidcFallbackResponseDoesNotContainJwksUriThenThrowsIllegalArgumentException() throws JsonMappingException, JsonProcessingException { prepareConfigurationResponseOidc(this.buildResponseWithMissingJwksUri()); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The public JWK set URI must not be null"); + .withMessage("The public JWK set URI must not be null"); + // @formatter:on } // gh-7512 @@ -192,69 +216,85 @@ public class JwtDecodersTests { public void issuerWhenOAuth2ResponseDoesNotContainJwksUriThenThrowsIllegalArgumentException() throws JsonMappingException, JsonProcessingException { prepareConfigurationResponseOAuth2(this.buildResponseWithMissingJwksUri()); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The public JWK set URI must not be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)) + .withMessage("The public JWK set URI must not be null"); + // @formatter:on } @Test public void issuerWhenResponseIsMalformedThenThrowsRuntimeException() { prepareConfigurationResponse("malformed"); - assertThatCode(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOidcFallbackResponseIsMalformedThenThrowsRuntimeException() { prepareConfigurationResponseOidc("malformed"); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOAuth2ResponseIsMalformedThenThrowsRuntimeException() { prepareConfigurationResponseOAuth2("malformed"); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenRespondingIssuerMismatchesRequestedIssuerThenThrowsIllegalStateException() { prepareConfigurationResponse(String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); - assertThatCode(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> JwtDecoders.fromOidcIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOidcFallbackRespondingIssuerMismatchesRequestedIssuerThenThrowsIllegalStateException() { prepareConfigurationResponseOidc(String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOAuth2RespondingIssuerMismatchesRequestedIssuerThenThrowsIllegalStateException() { - prepareConfigurationResponseOAuth2(String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); - assertThatCode(() -> JwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalStateException.class); + prepareConfigurationResponseOAuth2( + String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> JwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test - public void issuerWhenRequestedIssuerIsUnresponsiveThenThrowsIllegalArgumentException() - throws Exception { - + public void issuerWhenRequestedIssuerIsUnresponsiveThenThrowsIllegalArgumentException() throws Exception { this.server.shutdown(); - assertThatCode(() -> JwtDecoders.fromOidcIssuerLocation("https://issuer")) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> JwtDecoders.fromOidcIssuerLocation("https://issuer")); + // @formatter:on } @Test public void issuerWhenOidcFallbackRequestedIssuerIsUnresponsiveThenThrowsIllegalArgumentException() throws Exception { - this.server.shutdown(); - assertThatCode(() -> JwtDecoders.fromIssuerLocation("https://issuer")) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> JwtDecoders.fromIssuerLocation("https://issuer")); + // @formatter:on } private void prepareConfigurationResponse() { @@ -295,9 +335,13 @@ public class JwtDecodersTests { Dispatcher dispatcher = new Dispatcher() { @Override public MockResponse dispatch(RecordedRequest request) { - return Optional.of(request).map(RecordedRequest::getRequestUrl).map(HttpUrl::toString) - .map(responses::get) - .orElse(new MockResponse().setResponseCode(404)); + // @formatter:off + return Optional.of(request) + .map(RecordedRequest::getRequestUrl) + .map(HttpUrl::toString) + .map(responses::get) + .orElse(new MockResponse().setResponseCode(404)); + // @formatter:on } }; this.server.setDispatcher(dispatcher); @@ -309,14 +353,20 @@ public class JwtDecodersTests { private String oidc() { URI uri = URI.create(this.issuer); + // @formatter:off return UriComponentsBuilder.fromUri(uri) - .replacePath(uri.getPath() + OIDC_METADATA_PATH).toUriString(); + .replacePath(uri.getPath() + OIDC_METADATA_PATH) + .toUriString(); + // @formatter:on } private String oauth() { URI uri = URI.create(this.issuer); + // @formatter:off return UriComponentsBuilder.fromUri(uri) - .replacePath(OAUTH_METADATA_PATH + uri.getPath()).toUriString(); + .replacePath(OAUTH_METADATA_PATH + uri.getPath()) + .toUriString(); + // @formatter:on } private String jwks() { @@ -324,16 +374,20 @@ public class JwtDecodersTests { } private MockResponse response(String body) { + // @formatter:off return new MockResponse() .setBody(body) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + // @formatter:on } public String buildResponseWithMissingJwksUri() throws JsonMappingException, JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); Map response = mapper.readValue(DEFAULT_RESPONSE_TEMPLATE, - new TypeReference>(){}); + new TypeReference>() { + }); response.remove("jwks_uri"); return mapper.writeValueAsString(response); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtIssuerValidatorTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtIssuerValidatorTests.java index ba0a42c0ca..547c851c2b 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtIssuerValidatorTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtIssuerValidatorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import org.junit.Test; @@ -20,39 +21,37 @@ import org.junit.Test; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * @author Josh Cummings * @since 5.1 */ public class JwtIssuerValidatorTests { + private static final String ISSUER = "https://issuer"; private final JwtIssuerValidator validator = new JwtIssuerValidator(ISSUER); @Test public void validateWhenIssuerMatchesThenReturnsSuccess() { - Jwt jwt = jwt().claim("iss", ISSUER).build(); - + Jwt jwt = TestJwts.jwt().claim("iss", ISSUER).build(); + // @formatter:off assertThat(this.validator.validate(jwt)) .isEqualTo(OAuth2TokenValidatorResult.success()); + // @formatter:on } @Test public void validateWhenIssuerMismatchesThenReturnsError() { - Jwt jwt = jwt().claim(JwtClaimNames.ISS, "https://other").build(); - + Jwt jwt = TestJwts.jwt().claim(JwtClaimNames.ISS, "https://other").build(); OAuth2TokenValidatorResult result = this.validator.validate(jwt); - assertThat(result.getErrors()).isNotEmpty(); } @Test public void validateWhenJwtHasNoIssuerThenReturnsError() { - Jwt jwt = jwt().claim(JwtClaimNames.AUD, "https://aud").build(); - + Jwt jwt = TestJwts.jwt().claim(JwtClaimNames.AUD, "https://aud").build(); OAuth2TokenValidatorResult result = this.validator.validate(jwt); assertThat(result.getErrors()).isNotEmpty(); } @@ -60,22 +59,28 @@ public class JwtIssuerValidatorTests { // gh-6073 @Test public void validateWhenIssuerMatchesAndIsNotAUriThenReturnsSuccess() { - Jwt jwt = jwt().claim(JwtClaimNames.ISS, "issuer").build(); + Jwt jwt = TestJwts.jwt().claim(JwtClaimNames.ISS, "issuer").build(); JwtIssuerValidator validator = new JwtIssuerValidator("issuer"); - + // @formatter:off assertThat(validator.validate(jwt)) .isEqualTo(OAuth2TokenValidatorResult.success()); + // @formatter:on } @Test public void validateWhenJwtIsNullThenThrowsIllegalArgumentException() { - assertThatCode(() -> this.validator.validate(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.validator.validate(null)); + // @formatter:on } @Test public void constructorWhenNullIssuerIsGivenThenThrowsIllegalArgumentException() { - assertThatCode(() -> new JwtIssuerValidator(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtIssuerValidator(null)); + // @formatter:on } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTests.java index e59716da74..cc2e99e373 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.jwt; -import org.junit.Test; -import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; +package org.springframework.security.oauth2.jwt; import java.time.Instant; import java.util.Arrays; @@ -25,6 +23,10 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Test; + +import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -33,30 +35,43 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Joe Grandja */ public class JwtTests { + private static final String ISS_CLAIM = "iss"; + private static final String SUB_CLAIM = "sub"; + private static final String AUD_CLAIM = "aud"; + private static final String EXP_CLAIM = "exp"; + private static final String NBF_CLAIM = "nbf"; + private static final String IAT_CLAIM = "iat"; + private static final String JTI_CLAIM = "jti"; private static final String ISS_VALUE = "https://provider.com"; + private static final String SUB_VALUE = "subject1"; + private static final List AUD_VALUE = Arrays.asList("aud1", "aud2"); + private static final long EXP_VALUE = Instant.now().plusSeconds(60).toEpochMilli(); + private static final long NBF_VALUE = Instant.now().plusSeconds(5).toEpochMilli(); + private static final long IAT_VALUE = Instant.now().toEpochMilli(); + private static final String JTI_VALUE = "jwt-id-1"; private static final Map HEADERS; - private static final Map CLAIMS; - private static final String JWT_TOKEN_VALUE = "jwt-token-value"; + private static final Map CLAIMS; + + private static final String JWT_TOKEN_VALUE = "jwt-token-value"; static { HEADERS = new HashMap<>(); HEADERS.put("alg", JwsAlgorithms.RS256); - CLAIMS = new HashMap<>(); CLAIMS.put(ISS_CLAIM, ISS_VALUE); CLAIMS.put(SUB_CLAIM, SUB_VALUE); @@ -74,21 +89,20 @@ public class JwtTests { @Test(expected = IllegalArgumentException.class) public void constructorWhenHeadersIsEmptyThenThrowIllegalArgumentException() { - new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), - Instant.ofEpochMilli(EXP_VALUE), Collections.emptyMap(), CLAIMS); + new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), Instant.ofEpochMilli(EXP_VALUE), + Collections.emptyMap(), CLAIMS); } @Test(expected = IllegalArgumentException.class) public void constructorWhenClaimsIsEmptyThenThrowIllegalArgumentException() { - new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), - Instant.ofEpochMilli(EXP_VALUE), HEADERS, Collections.emptyMap()); + new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), Instant.ofEpochMilli(EXP_VALUE), HEADERS, + Collections.emptyMap()); } @Test public void constructorWhenParametersProvidedAndValidThenCreated() { - Jwt jwt = new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), - Instant.ofEpochMilli(EXP_VALUE), HEADERS, CLAIMS); - + Jwt jwt = new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), Instant.ofEpochMilli(EXP_VALUE), HEADERS, + CLAIMS); assertThat(jwt.getTokenValue()).isEqualTo(JWT_TOKEN_VALUE); assertThat(jwt.getHeaders()).isEqualTo(HEADERS); assertThat(jwt.getClaims()).isEqualTo(CLAIMS); @@ -100,4 +114,5 @@ public class JwtTests { assertThat(jwt.getIssuedAt().toEpochMilli()).isEqualTo(IAT_VALUE); assertThat(jwt.getId()).isEqualTo(JTI_VALUE); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTimestampValidatorTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTimestampValidatorTests.java index ee2f85825f..4f1708e85d 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTimestampValidatorTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTimestampValidatorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.time.Clock; @@ -31,9 +32,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.EXP; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests verifying {@link JwtTimestampValidator} @@ -41,37 +40,42 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * @author Josh Cummings */ public class JwtTimestampValidatorTests { + private static final Clock MOCK_NOW = Clock.fixed(Instant.ofEpochMilli(0), ZoneId.systemDefault()); + private static final String MOCK_TOKEN_VALUE = "token"; + private static final Instant MOCK_ISSUED_AT = Instant.MIN; + private static final Map MOCK_HEADER = Collections.singletonMap("alg", JwsAlgorithms.RS256); + private static final Map MOCK_CLAIM_SET = Collections.singletonMap("some", "claim"); @Test public void validateWhenJwtIsExpiredThenErrorMessageIndicatesExpirationTime() { Instant oneHourAgo = Instant.now().minusSeconds(3600); - - Jwt jwt = jwt().expiresAt(oneHourAgo).build(); - + Jwt jwt = TestJwts.jwt().expiresAt(oneHourAgo).build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(); - Collection details = jwtValidator.validate(jwt).getErrors(); - Collection messages = details.stream().map(OAuth2Error::getDescription).collect(Collectors.toList()); - + // @formatter:off + Collection messages = details.stream() + .map(OAuth2Error::getDescription) + .collect(Collectors.toList()); + // @formatter:on assertThat(messages).contains("Jwt expired at " + oneHourAgo); } @Test public void validateWhenJwtIsTooEarlyThenErrorMessageIndicatesNotBeforeTime() { Instant oneHourFromNow = Instant.now().plusSeconds(3600); - - Jwt jwt = jwt().notBefore(oneHourFromNow).build(); - + Jwt jwt = TestJwts.jwt().notBefore(oneHourFromNow).build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(); - Collection details = jwtValidator.validate(jwt).getErrors(); - Collection messages = details.stream().map(OAuth2Error::getDescription).collect(Collectors.toList()); - + // @formatter:off + Collection messages = details.stream() + .map(OAuth2Error::getDescription) + .collect(Collectors.toList()); + // @formatter:on assertThat(messages).contains("Jwt used before " + oneHourFromNow); } @@ -79,105 +83,83 @@ public class JwtTimestampValidatorTests { public void validateWhenConfiguredWithClockSkewThenValidatesUsingThatSkew() { Duration oneDayOff = Duration.ofDays(1); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(oneDayOff); - Instant now = Instant.now(); Instant almostOneDayAgo = now.minus(oneDayOff).plusSeconds(10); Instant almostOneDayFromNow = now.plus(oneDayOff).minusSeconds(10); Instant justOverOneDayAgo = now.minus(oneDayOff).minusSeconds(10); Instant justOverOneDayFromNow = now.plus(oneDayOff).plusSeconds(10); - - Jwt jwt = jwt() - .expiresAt(almostOneDayAgo) - .notBefore(almostOneDayFromNow) - .build(); - + Jwt jwt = TestJwts.jwt().expiresAt(almostOneDayAgo).notBefore(almostOneDayFromNow).build(); assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); - - jwt = jwt().expiresAt(justOverOneDayAgo).build(); - + jwt = TestJwts.jwt().expiresAt(justOverOneDayAgo).build(); OAuth2TokenValidatorResult result = jwtValidator.validate(jwt); - Collection messages = - result.getErrors().stream().map(OAuth2Error::getDescription).collect(Collectors.toList()); - + // @formatter:off + Collection messages = result.getErrors() + .stream() + .map(OAuth2Error::getDescription) + .collect(Collectors.toList()); + // @formatter:on assertThat(result.hasErrors()).isTrue(); assertThat(messages).contains("Jwt expired at " + justOverOneDayAgo); - - jwt = jwt().notBefore(justOverOneDayFromNow).build(); - + jwt = TestJwts.jwt().notBefore(justOverOneDayFromNow).build(); result = jwtValidator.validate(jwt); - messages = - result.getErrors().stream().map(OAuth2Error::getDescription).collect(Collectors.toList()); - + // @formatter:off + messages = result.getErrors() + .stream() + .map(OAuth2Error::getDescription) + .collect(Collectors.toList()); + // @formatter:on assertThat(result.hasErrors()).isTrue(); assertThat(messages).contains("Jwt used before " + justOverOneDayFromNow); - } @Test public void validateWhenConfiguredWithFixedClockThenValidatesUsingFixedTime() { - Jwt jwt = jwt().expiresAt(Instant.now(MOCK_NOW)).build(); - + Jwt jwt = TestJwts.jwt().expiresAt(Instant.now(MOCK_NOW)).build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(Duration.ofNanos(0)); jwtValidator.setClock(MOCK_NOW); - assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); - - jwt = jwt().notBefore(Instant.now(MOCK_NOW)).build(); - + jwt = TestJwts.jwt().notBefore(Instant.now(MOCK_NOW)).build(); assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); } @Test public void validateWhenNeitherExpiryNorNotBeforeIsSpecifiedThenReturnsSuccessfulResult() { - Jwt jwt = jwt().claims(c -> c.remove(EXP)).build(); - + Jwt jwt = TestJwts.jwt().claims((c) -> c.remove(JwtClaimNames.EXP)).build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(); assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); } @Test public void validateWhenNotBeforeIsValidAndExpiryIsNotSpecifiedThenReturnsSuccessfulResult() { - Jwt jwt = jwt() - .claims(c -> c.remove(EXP)) - .notBefore(Instant.MIN) - .build(); - + Jwt jwt = TestJwts.jwt().claims((c) -> c.remove(JwtClaimNames.EXP)).notBefore(Instant.MIN).build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(); assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); } @Test public void validateWhenExpiryIsValidAndNotBeforeIsNotSpecifiedThenReturnsSuccessfulResult() { - Jwt jwt = jwt().build(); - + Jwt jwt = TestJwts.jwt().build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(); assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); } @Test public void validateWhenBothExpiryAndNotBeforeAreValidThenReturnsSuccessfulResult() { - Jwt jwt = jwt() - .expiresAt(Instant.now(MOCK_NOW)) - .notBefore(Instant.now(MOCK_NOW)) - .build(); - + Jwt jwt = TestJwts.jwt().expiresAt(Instant.now(MOCK_NOW)).notBefore(Instant.now(MOCK_NOW)).build(); JwtTimestampValidator jwtValidator = new JwtTimestampValidator(Duration.ofNanos(0)); jwtValidator.setClock(MOCK_NOW); - assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse(); } @Test public void setClockWhenInvokedWithNullThenThrowsIllegalArgumentException() { JwtTimestampValidator jwtValidator = new JwtTimestampValidator(); - - assertThatCode(() -> jwtValidator.setClock(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> jwtValidator.setClock(null)); } @Test public void constructorWhenInvokedWithNullDurationThenThrowsIllegalArgumentException() { - assertThatCode(() -> new JwtTimestampValidator(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new JwtTimestampValidator(null)); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverterTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverterTests.java index 322890449d..403d5692ea 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverterTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverterTests.java @@ -31,10 +31,11 @@ import org.junit.Test; import org.springframework.core.convert.converter.Converter; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link MappedJwtClaimSetConverter} @@ -42,57 +43,44 @@ import static org.mockito.Mockito.when; * @author Josh Cummings */ public class MappedJwtClaimSetConverterTests { + @Test public void convertWhenUsingCustomExpiresAtConverterThenIssuedAtConverterStillConsultsIt() { Instant at = Instant.ofEpochMilli(1000000000000L); Converter expiresAtConverter = mock(Converter.class); - when(expiresAtConverter.convert(any())).thenReturn(at); - + given(expiresAtConverter.convert(any())).willReturn(at); MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter .withDefaults(Collections.singletonMap(JwtClaimNames.EXP, expiresAtConverter)); - Map source = new HashMap<>(); Map target = converter.convert(source); - - assertThat(target.get(JwtClaimNames.IAT)). - isEqualTo(Instant.ofEpochMilli(at.toEpochMilli()).minusSeconds(1)); + assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochMilli(at.toEpochMilli()).minusSeconds(1)); } @Test public void convertWhenUsingDefaultsThenBasesIssuedAtOffOfExpiration() { - MappedJwtClaimSetConverter converter = - MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map source = Collections.singletonMap(JwtClaimNames.EXP, 1000000000L); Map target = converter.convert(source); - assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(1000000000L)); assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L).minusSeconds(1)); } @Test public void convertWhenUsingDefaultsThenCoercesAudienceAccordingToJwtSpec() { - MappedJwtClaimSetConverter converter = - MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map source = Collections.singletonMap(JwtClaimNames.AUD, "audience"); Map target = converter.convert(source); - assertThat(target.get(JwtClaimNames.AUD)).isInstanceOf(Collection.class); assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience")); - source = Collections.singletonMap(JwtClaimNames.AUD, Arrays.asList("one", "two")); target = converter.convert(source); - assertThat(target.get(JwtClaimNames.AUD)).isInstanceOf(Collection.class); assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("one", "two")); } @Test public void convertWhenUsingDefaultsThenCoercesAllAttributesInJwtSpec() { - MappedJwtClaimSetConverter converter = - MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map source = new HashMap<>(); source.put(JwtClaimNames.JTI, 1); source.put(JwtClaimNames.AUD, "audience"); @@ -101,9 +89,7 @@ public class MappedJwtClaimSetConverterTests { source.put(JwtClaimNames.ISS, "https://any.url"); source.put(JwtClaimNames.NBF, 1000000000); source.put(JwtClaimNames.SUB, 1234); - Map target = converter.convert(source); - assertThat(target.get(JwtClaimNames.JTI)).isEqualTo("1"); assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience")); assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(2000000000L)); @@ -118,8 +104,7 @@ public class MappedJwtClaimSetConverterTests { Converter claimConverter = mock(Converter.class); MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter .withDefaults(Collections.singletonMap(JwtClaimNames.SUB, claimConverter)); - when(claimConverter.convert(any(Object.class))).thenReturn("1234"); - + given(claimConverter.convert(any(Object.class))).willReturn("1234"); Map source = new HashMap<>(); source.put(JwtClaimNames.JTI, 1); source.put(JwtClaimNames.AUD, "audience"); @@ -128,9 +113,7 @@ public class MappedJwtClaimSetConverterTests { source.put(JwtClaimNames.ISS, URI.create("https://any.url")); source.put(JwtClaimNames.NBF, "1000000000"); source.put(JwtClaimNames.SUB, 2345); - Map target = converter.convert(source); - assertThat(target.get(JwtClaimNames.JTI)).isEqualTo("1"); assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience")); assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(2000000000L)); @@ -142,12 +125,9 @@ public class MappedJwtClaimSetConverterTests { @Test public void convertWhenConverterReturnsNullThenClaimIsRemoved() { - MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter - .withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map source = Collections.singletonMap(JwtClaimNames.ISS, null); Map target = converter.convert(source); - assertThat(target).doesNotContainKey(JwtClaimNames.ISS); } @@ -156,11 +136,9 @@ public class MappedJwtClaimSetConverterTests { Converter claimConverter = mock(Converter.class); MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter .withDefaults(Collections.singletonMap("custom-claim", claimConverter)); - when(claimConverter.convert(any())).thenReturn("custom-value"); - + given(claimConverter.convert(any())).willReturn("custom-value"); Map source = new HashMap<>(); Map target = converter.convert(source); - assertThat(target.get("custom-claim")).isEqualTo("custom-value"); } @@ -169,8 +147,7 @@ public class MappedJwtClaimSetConverterTests { Converter claimConverter = mock(Converter.class); MappedJwtClaimSetConverter converter = new MappedJwtClaimSetConverter( Collections.singletonMap(JwtClaimNames.SUB, claimConverter)); - when(claimConverter.convert(any(Object.class))).thenReturn("1234"); - + given(claimConverter.convert(any(Object.class))).willReturn("1234"); Map source = new HashMap<>(); source.put(JwtClaimNames.JTI, new Object()); source.put(JwtClaimNames.AUD, new Object()); @@ -179,9 +156,7 @@ public class MappedJwtClaimSetConverterTests { source.put(JwtClaimNames.ISS, new Object()); source.put(JwtClaimNames.NBF, new Object()); source.put(JwtClaimNames.SUB, new Object()); - Map target = converter.convert(source); - assertThat(target.get(JwtClaimNames.JTI)).isEqualTo(source.get(JwtClaimNames.JTI)); assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(source.get(JwtClaimNames.AUD)); assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(source.get(JwtClaimNames.EXP)); @@ -193,28 +168,21 @@ public class MappedJwtClaimSetConverterTests { @Test public void convertWhenUsingDefaultsThenFailedConversionThrowsIllegalStateException() { - MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter - .withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map badIssuer = Collections.singletonMap(JwtClaimNames.ISS, "https://badly formed iss"); - assertThatCode(() -> converter.convert(badIssuer)).isInstanceOf(IllegalStateException.class); - + assertThatIllegalStateException().isThrownBy(() -> converter.convert(badIssuer)); Map badIssuedAt = Collections.singletonMap(JwtClaimNames.IAT, "badly-formed-iat"); - assertThatCode(() -> converter.convert(badIssuedAt)).isInstanceOf(IllegalStateException.class); - + assertThatIllegalStateException().isThrownBy(() -> converter.convert(badIssuedAt)); Map badExpiresAt = Collections.singletonMap(JwtClaimNames.EXP, "badly-formed-exp"); - assertThatCode(() -> converter.convert(badExpiresAt)).isInstanceOf(IllegalStateException.class); - + assertThatIllegalStateException().isThrownBy(() -> converter.convert(badExpiresAt)); Map badNotBefore = Collections.singletonMap(JwtClaimNames.NBF, "badly-formed-nbf"); - assertThatCode(() -> converter.convert(badNotBefore)).isInstanceOf(IllegalStateException.class); + assertThatIllegalStateException().isThrownBy(() -> converter.convert(badNotBefore)); } // gh-6073 @Test public void convertWhenIssuerIsNotAUriThenConvertsToString() { - MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter - .withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map nonUriIssuer = Collections.singletonMap(JwtClaimNames.ISS, "issuer"); Map target = converter.convert(nonUriIssuer); assertThat(target.get(JwtClaimNames.ISS)).isEqualTo("issuer"); @@ -223,9 +191,7 @@ public class MappedJwtClaimSetConverterTests { // gh-6073 @Test public void convertWhenIssuerIsOfTypeURLThenConvertsToString() throws Exception { - MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter - .withDefaults(Collections.emptyMap()); - + MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap()); Map issuer = Collections.singletonMap(JwtClaimNames.ISS, new URL("https://issuer")); Map target = converter.convert(issuer); assertThat(target.get(JwtClaimNames.ISS)).isEqualTo("https://issuer"); @@ -233,13 +199,12 @@ public class MappedJwtClaimSetConverterTests { @Test public void constructWhenAnyParameterIsNullThenIllegalArgumentException() { - assertThatCode(() -> new MappedJwtClaimSetConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new MappedJwtClaimSetConverter(null)); } @Test public void withDefaultsWhenAnyParameterIsNullThenIllegalArgumentException() { - assertThatCode(() -> MappedJwtClaimSetConverter.withDefaults(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> MappedJwtClaimSetConverter.withDefaults(null)); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java index 2597dd2638..28f407ded2 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.util.Arrays; @@ -22,7 +23,6 @@ import java.util.Map; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; -import org.assertj.core.api.Assertions; import org.junit.Test; import org.springframework.core.convert.converter.Converter; @@ -37,14 +37,14 @@ import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * Tests for {@link NimbusJwtDecoderJwkSupport}. @@ -53,46 +53,46 @@ import static org.mockito.Mockito.when; * @author Josh Cummings */ public class NimbusJwtDecoderJwkSupportTests { + private static final String JWK_SET_URL = "https://provider.com/oauth2/keys"; + private static final String JWS_ALGORITHM = JwsAlgorithms.RS256; private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}"; + private static final String MALFORMED_JWK_SET = "malformed"; private static final String SIGNED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJzY3AiOlsibWVzc2FnZTpyZWFkIl0sImV4cCI6NDY4Mzg5Nzc3Nn0.LtMVtIiRIwSyc3aX35Zl0JVwLTcQZAB3dyBOMHNaHCKUljwMrf20a_gT79LfhjDzE_fUVUmFiAO32W1vFnYpZSVaMDUgeIOIOpxfoe9shj_uYenAwIS-_UxqGVIJiJoXNZh_MK80ShNpvsQwamxWEEOAMBtpWNiVYNDMdfgho9n3o5_Z7Gjy8RLBo1tbDREbO9kTFwGIxm_EYpezmRCRq4w1DdS6UDW321hkwMxPnCMSWOvp-hRpmgY2yjzLgPJ6Aucmg9TJ8jloAP1DjJoF1gRR7NTAk8LOGkSjTzVYDYMbCF51YdpojhItSk80YzXiEsv1mTz4oMM49jXBmfXFMA"; + private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A"; + private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9."; private NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); @Test public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwtDecoderJwkSupport(null)); } @Test public void constructorWhenJwkSetUrlInvalidThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport("invalid.com")) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwtDecoderJwkSupport("invalid.com")); } @Test public void constructorWhenJwsAlgorithmIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(JWK_SET_URL, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwtDecoderJwkSupport(JWK_SET_URL, null)); } @Test public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { - Assertions.assertThatThrownBy(() -> this.jwtDecoder.setRestOperations(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.jwtDecoder.setRestOperations(null)); } @Test public void decodeWhenJwtInvalidThenThrowJwtException() { - assertThatThrownBy(() -> this.jwtDecoder.decode("invalid")) - .isInstanceOf(JwtException.class); + assertThatExceptionOfType(JwtException.class).isThrownBy(() -> this.jwtDecoder.decode("invalid")); } // gh-5168 @@ -100,57 +100,65 @@ public class NimbusJwtDecoderJwkSupportTests { public void decodeWhenExpClaimNullThenDoesNotThrowException() { NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL); jwtDecoder.setRestOperations(mockJwkSetResponse(JWK_SET)); - jwtDecoder.setClaimSetConverter(map -> { + jwtDecoder.setClaimSetConverter((map) -> { Map claims = new HashMap<>(map); claims.remove(JwtClaimNames.EXP); return claims; }); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)).doesNotThrowAnyException(); + jwtDecoder.decode(SIGNED_JWT); } // gh-5457 @Test public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() { - assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) - .isInstanceOf(JwtException.class) - .hasMessageContaining("Unsupported algorithm of none"); + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) + .withMessageContaining("Unsupported algorithm of none"); + // @formatter:on } @Test public void decodeWhenJwtIsMalformedThenReturnsStockException() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - assertThatCode(() -> jwtDecoder.decode(MALFORMED_JWT)) - .isInstanceOf(JwtException.class) - .hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(MALFORMED_JWT)) + .withMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); + // @formatter:on server.shutdown(); } } @Test public void decodeWhenJwkResponseIsMalformedThenReturnsStockException() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtException.class) - .hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) + .withMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); + // @formatter:on server.shutdown(); } } @Test public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtException.class) - .hasMessageContaining("An error occurred while attempting to decode the Jwt"); + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) + .withMessageContaining("An error occurred while attempting to decode the Jwt"); + // @formatter:on server.shutdown(); } } @@ -158,13 +166,13 @@ public class NimbusJwtDecoderJwkSupportTests { // gh-5603 @Test public void decodeWhenCustomRestOperationsSetThenUsed() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); RestTemplate restTemplate = spy(new RestTemplate()); jwtDecoder.setRestOperations(restTemplate); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)).doesNotThrowAnyException(); + jwtDecoder.decode(SIGNED_JWT); verify(restTemplate).exchange(any(RequestEntity.class), eq(String.class)); server.shutdown(); } @@ -172,59 +180,54 @@ public class NimbusJwtDecoderJwkSupportTests { @Test public void decodeWhenJwtFailsValidationThenReturnsCorrespondingErrorMessage() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); - NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - OAuth2Error failure = new OAuth2Error("mock-error", "mock-description", "mock-uri"); - OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); - when(jwtValidator.validate(any(Jwt.class))).thenReturn(OAuth2TokenValidatorResult.failure(failure)); + given(jwtValidator.validate(any(Jwt.class))).willReturn(OAuth2TokenValidatorResult.failure(failure)); decoder.setJwtValidator(jwtValidator); - - assertThatCode(() -> decoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(SIGNED_JWT)) + .withMessageContaining("mock-description"); + // @formatter:on } } @Test public void decodeWhenJwtValidationHasTwoErrorsThenJwtExceptionMessageShowsFirstError() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); - NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - OAuth2Error firstFailure = new OAuth2Error("mock-error", "mock-description", "mock-uri"); OAuth2Error secondFailure = new OAuth2Error("another-error", "another-description", "another-uri"); OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(firstFailure, secondFailure); - OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); - when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + given(jwtValidator.validate(any(Jwt.class))).willReturn(result); decoder.setJwtValidator(jwtValidator); - - assertThatCode(() -> decoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description") - .hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure)); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(SIGNED_JWT)) + .withMessageContaining("mock-description") + .satisfies((ex) -> assertThat(ex) + .hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure)) + ); + // @formatter:on } } @Test public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); - NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - Converter, Map> claimSetConverter = mock(Converter.class); - when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value")); + given(claimSetConverter.convert(any(Map.class))).willReturn(Collections.singletonMap("custom", "value")); decoder.setClaimSetConverter(claimSetConverter); - Jwt jwt = decoder.decode(SIGNED_JWT); assertThat(jwt.getClaims().size()).isEqualTo(1); assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); @@ -233,14 +236,14 @@ public class NimbusJwtDecoderJwkSupportTests { @Test public void setClaimSetConverterWhenIsNullThenThrowsIllegalArgumentException() { - assertThatCode(() -> jwtDecoder.setClaimSetConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.jwtDecoder.setClaimSetConverter(null)); } private static RestOperations mockJwkSetResponse(String response) { RestOperations restOperations = mock(RestOperations.class); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(new ResponseEntity<>(response, HttpStatus.OK)); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(response, HttpStatus.OK)); return restOperations; } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index 9e1e8e95b3..3181054742 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -33,6 +33,7 @@ import java.util.Date; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; + import javax.crypto.SecretKey; import com.nimbusds.jose.JOSEObjectType; @@ -53,11 +54,10 @@ import com.nimbusds.jwt.proc.BadJWTException; import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; import okhttp3.mockwebserver.MockWebServer; -import org.assertj.core.api.Assertions; import org.junit.BeforeClass; import org.junit.Test; - import org.mockito.ArgumentCaptor; + import org.springframework.cache.Cache; import org.springframework.cache.concurrent.ConcurrentMapCache; import org.springframework.core.convert.converter.Converter; @@ -75,18 +75,16 @@ import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey; -import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey; /** * Tests for {@link NimbusJwtDecoder} @@ -96,17 +94,25 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecre * @author Mykyta Bezverkhyi */ public class NimbusJwtDecoderTests { + private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}"; + private static final String MALFORMED_JWK_SET = "malformed"; private static final String SIGNED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJzY3AiOlsibWVzc2FnZTpyZWFkIl0sImV4cCI6NDY4Mzg5Nzc3Nn0.LtMVtIiRIwSyc3aX35Zl0JVwLTcQZAB3dyBOMHNaHCKUljwMrf20a_gT79LfhjDzE_fUVUmFiAO32W1vFnYpZSVaMDUgeIOIOpxfoe9shj_uYenAwIS-_UxqGVIJiJoXNZh_MK80ShNpvsQwamxWEEOAMBtpWNiVYNDMdfgho9n3o5_Z7Gjy8RLBo1tbDREbO9kTFwGIxm_EYpezmRCRq4w1DdS6UDW321hkwMxPnCMSWOvp-hRpmgY2yjzLgPJ6Aucmg9TJ8jloAP1DjJoF1gRR7NTAk8LOGkSjTzVYDYMbCF51YdpojhItSk80YzXiEsv1mTz4oMM49jXBmfXFMA"; + private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A"; + private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9."; + private static final String EMPTY_EXP_CLAIM_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJhdWQiOiJhdWRpZW5jZSJ9.D1eT0jpBEpuh74p-YT-uF81Z7rkVqIpUtJ5hWWFiVShZ9s8NIntK4Q1GlvlziiySSaVYaXtpTmDB3c8r-Z5Mj4ibihiueCSq7jaPD3sA8IMQKL-L6Uol8MSD_lSFE2n3fVBTxFeaejBKfZsDxnhzgpy8g7PncR47w8NHs-7tKO4qw7G_SV3hkNpDNoqZTfMImxyWEebgKM2pJAhN4das2CO1KAjYMfEByLcgYncE8fzdYPJhMFo2XRRSQABoeUBuKSAwIntBaOGvcb-qII_Hefc5U0cmpNItG75F2XfX803plKI4FFpAxJsbPKWSQmhs6bZOrhx0x74pY5LS3ghmJw"; private static final String JWK_SET_URI = "https://issuer/.well-known/jwks.json"; + private static final String RS512_SIGNED_JWT = "eyJhbGciOiJSUzUxMiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJleHAiOjE5NzQzMjYxMTl9.LKAx-60EBfD7jC1jb1eKcjO4uLvf3ssISV-8tN-qp7gAjSvKvj4YA9-V2mIb6jcS1X_xGmNy6EIimZXpWaBR3nJmeu-jpe85u4WaW2Ztr8ecAi-dTO7ZozwdtljKuBKKvj4u1nF70zyCNl15AozSG0W1ASrjUuWrJtfyDG6WoZ8VfNMuhtU-xUYUFvscmeZKUYQcJ1KS-oV5tHeF8aNiwQoiPC_9KXCOZtNEJFdq6-uzFdHxvOP2yex5Gbmg5hXonauIFXG2ZPPGdXzm-5xkhBpgM8U7A_6wb3So8wBvLYYm2245QUump63AJRAy8tQpwt4n9MvQxQgS3z9R-NK92A"; + private static final String RS256_SIGNED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJleHAiOjE5NzQzMjYzMzl9.CT-H2OWEqmSs1NWmnta5ealLFvM8OlbQTjGhfRcKLNxrTrzsOkqBJl-AN3k16BQU7mS32o744TiiZ29NcDlxPsr1MqTlN86-dobPiuNIDLp3A1bOVdXMcVFuMYkrNv0yW0tGS9OjEqsCCuZDkZ1by6AhsHLbGwRY-6AQdcRouZygGpOQu1hNun5j8q5DpSTY4AXKARIFlF-O3OpVbPJ0ebr3Ki-i3U9p_55H0e4-wx2bqcApWlqgofl1I8NKWacbhZgn81iibup2W7E0CzCzh71u1Mcy3xk1sYePx-dwcxJnHmxJReBBWjJZEAeCrkbnn_OCuo2fA-EQyNJtlN5F2w"; + private static final String VERIFY_KEY = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAq4yKxb6SNePdDmQi9xFCrP6QvHosErQzryknQTTTffs0t3cy3Er3lIceuhZ7yQNSCDfPFqG8GoyoKhuChRiA5D+J2ab7bqTa1QJKfnCyERoscftgN2fXPHjHoiKbpGV2tMVw8mXl//tePOAiKbMJaBUnlAvJgkk1rVm08dSwpLC1sr2M19euf9jwnRGkMRZuhp9iCPgECRke5T8Ixpv0uQjSmGHnWUKTFlbj8sM83suROR1Ue64JSGScANc5vk3huJ/J97qTC+K2oKj6L8d9O8dpc4obijEOJwpydNvTYDgbiivYeSB00KS9jlBkQ5B2QqLvLVEygDl3dp59nGx6YQIDAQAB"; private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); @@ -122,68 +128,77 @@ public class NimbusJwtDecoderTests { @Test public void constructorWhenJwtProcessorIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new NimbusJwtDecoder(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusJwtDecoder(null)); + // @formatter:on } @Test public void setClaimSetConverterWhenIsNullThenThrowsIllegalArgumentException() { - assertThatCode(() -> this.jwtDecoder.setClaimSetConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.jwtDecoder.setClaimSetConverter(null)); + // @formatter:on } @Test public void setJwtValidatorWhenNullThenThrowIllegalArgumentException() { - Assertions.assertThatThrownBy(() -> this.jwtDecoder.setJwtValidator(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.jwtDecoder.setJwtValidator(null)); + // @formatter:on } @Test public void decodeWhenJwtInvalidThenThrowJwtException() { - assertThatThrownBy(() -> this.jwtDecoder.decode("invalid")) - .isInstanceOf(BadJwtException.class); + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> this.jwtDecoder.decode("invalid")); + // @formatter:on } // gh-5168 @Test public void decodeWhenExpClaimNullThenDoesNotThrowException() { - assertThatCode(() -> this.jwtDecoder.decode(EMPTY_EXP_CLAIM_JWT)) - .doesNotThrowAnyException(); + this.jwtDecoder.decode(EMPTY_EXP_CLAIM_JWT); } @Test public void decodeWhenIatClaimNullThenDoesNotThrowException() { - assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) - .doesNotThrowAnyException(); + this.jwtDecoder.decode(SIGNED_JWT); } // gh-5457 @Test public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() { - assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) - .isInstanceOf(BadJwtException.class) - .hasMessageContaining("Unsupported algorithm of none"); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) + .withMessageContaining("Unsupported algorithm of none"); + // @formatter:on } @Test public void decodeWhenJwtIsMalformedThenReturnsStockException() { - assertThatCode(() -> this.jwtDecoder.decode(MALFORMED_JWT)) - .isInstanceOf(BadJwtException.class) - .hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.jwtDecoder.decode(MALFORMED_JWT)) + .withMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); + // @formatter:on } @Test public void decodeWhenJwtFailsValidationThenReturnsCorrespondingErrorMessage() { OAuth2Error failure = new OAuth2Error("mock-error", "mock-description", "mock-uri"); - OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); - when(jwtValidator.validate(any(Jwt.class))) - .thenReturn(OAuth2TokenValidatorResult.failure(failure)); + given(jwtValidator.validate(any(Jwt.class))).willReturn(OAuth2TokenValidatorResult.failure(failure)); this.jwtDecoder.setJwtValidator(jwtValidator); - - assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> this.jwtDecoder.decode(SIGNED_JWT)) + .withMessageContaining("mock-description"); + // @formatter:on } @Test @@ -191,40 +206,40 @@ public class NimbusJwtDecoderTests { OAuth2Error firstFailure = new OAuth2Error("mock-error", "mock-description", "mock-uri"); OAuth2Error secondFailure = new OAuth2Error("another-error", "another-description", "another-uri"); OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(firstFailure, secondFailure); - OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); - when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + given(jwtValidator.validate(any(Jwt.class))).willReturn(result); this.jwtDecoder.setJwtValidator(jwtValidator); - - assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description") - .hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure)); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> this.jwtDecoder.decode(SIGNED_JWT)) + .withMessageContaining("mock-description") + .satisfies((ex) -> assertThat(ex) + .hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure)) + ); + // @formatter:on } @Test public void decodeWhenReadingErrorPickTheFirstErrorMessage() { OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); this.jwtDecoder.setJwtValidator(jwtValidator); - OAuth2Error errorEmpty = new OAuth2Error("mock-error", "", "mock-uri"); OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); OAuth2Error error2 = new OAuth2Error("mock-error-second", "mock-description-second", "mock-uri-second"); OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(errorEmpty, error, error2); - when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); - - Assertions.assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description"); + given(jwtValidator.validate(any(Jwt.class))).willReturn(result); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> this.jwtDecoder.decode(SIGNED_JWT)) + .withMessageContaining("mock-description"); + // @formatter:on } @Test public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() { Converter, Map> claimSetConverter = mock(Converter.class); - when(claimSetConverter.convert(any(Map.class))) - .thenReturn(Collections.singletonMap("custom", "value")); + given(claimSetConverter.convert(any(Map.class))).willReturn(Collections.singletonMap("custom", "value")); this.jwtDecoder.setClaimSetConverter(claimSetConverter); - Jwt jwt = this.jwtDecoder.decode(SIGNED_JWT); assertThat(jwt.getClaims().size()).isEqualTo(1); assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); @@ -235,11 +250,11 @@ public class NimbusJwtDecoderTests { public void decodeWhenClaimSetConverterFailsThenBadJwtException() { Converter, Map> claimSetConverter = mock(Converter.class); this.jwtDecoder.setClaimSetConverter(claimSetConverter); - - when(claimSetConverter.convert(any(Map.class))).thenThrow(new IllegalArgumentException("bad conversion")); - - assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(BadJwtException.class); + given(claimSetConverter.convert(any(Map.class))).willThrow(new IllegalArgumentException("bad conversion")); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.jwtDecoder.decode(SIGNED_JWT)); + // @formatter:on } @Test @@ -252,92 +267,119 @@ public class NimbusJwtDecoderTests { @Test public void decodeWhenJwkResponseIsMalformedThenReturnsStockException() { NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(MALFORMED_JWK_SET)); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtException.class) + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) .isNotInstanceOf(BadJwtException.class) - .hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); + .withMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); + // @formatter:on } @Test public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { String jwkSetUri = server.url("/.well-known/jwks.json").toString(); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).build(); - + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(jwkSetUri).build(); server.shutdown(); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtException.class) + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) .isNotInstanceOf(BadJwtException.class) - .hasMessageContaining("An error occurred while attempting to decode the Jwt"); + .withMessageContaining("An error occurred while attempting to decode the Jwt"); + // @formatter:on } } @Test public void decodeWhenJwkEndpointIsUnresponsiveAndCacheIsConfiguredThenReturnsJwtException() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { Cache cache = new ConcurrentMapCache("test-jwk-set-cache"); String jwkSetUri = server.url("/.well-known/jwks.json").toString(); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).cache(cache).build(); - + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(jwkSetUri).cache(cache).build(); server.shutdown(); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtException.class) + // @formatter:off + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) .isNotInstanceOf(BadJwtException.class) - .hasMessageContaining("An error occurred while attempting to decode the Jwt"); + .withMessageContaining("An error occurred while attempting to decode the Jwt"); + // @formatter:on } } @Test public void withJwkSetUriWhenNullOrEmptyThenThrowsException() { - Assertions.assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withJwkSetUri(null)); + // @formatter:on } @Test public void jwsAlgorithmWhenNullThenThrowsException() { - NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI); - Assertions.assertThatCode(() -> builder.jwsAlgorithm(null)).isInstanceOf(IllegalArgumentException.class); + NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> builder.jwsAlgorithm(null)); + // @formatter:on } @Test public void restOperationsWhenNullThenThrowsException() { - NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI); - Assertions.assertThatCode(() -> builder.restOperations(null)).isInstanceOf(IllegalArgumentException.class); + NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> builder.restOperations(null)); + // @formatter:on } @Test public void cacheWhenNullThenThrowsException() { - NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI); - Assertions.assertThatCode(() -> builder.cache(null)).isInstanceOf(IllegalArgumentException.class); + NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> builder.cache(null)); + // @formatter:on } @Test public void withPublicKeyWhenNullThenThrowsException() { - assertThatThrownBy(() -> withPublicKey(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withPublicKey(null)); + // @formatter:on } @Test public void buildWhenSignatureAlgorithmMismatchesKeyTypeThenThrowsException() { - Assertions.assertThatCode(() -> withPublicKey(key()) - .signatureAlgorithm(SignatureAlgorithm.ES256) - .build()) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> NimbusJwtDecoder.withPublicKey(key()) + .signatureAlgorithm(SignatureAlgorithm.ES256) + .build() + ); + // @formatter:on } @Test public void decodeWhenUsingPublicKeyThenSuccessfullyDecodes() throws Exception { - NimbusJwtDecoder decoder = withPublicKey(key()).build(); + NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(key()).build(); + // @formatter:off assertThat(decoder.decode(RS256_SIGNED_JWT)) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } @Test public void decodeWhenUsingPublicKeyWithRs512ThenSuccessfullyDecodes() throws Exception { - NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(key()) + .signatureAlgorithm(SignatureAlgorithm.RS512) + .build(); assertThat(decoder.decode(RS512_SIGNED_JWT)) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } // gh-7049 @@ -345,23 +387,35 @@ public class NimbusJwtDecoderTests { public void decodeWhenUsingPublicKeyWithKidThenStillUsesKey() throws Exception { RSAPublicKey publicKey = TestKeys.DEFAULT_PUBLIC_KEY; RSAPrivateKey privateKey = TestKeys.DEFAULT_PRIVATE_KEY; - JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("one").build(); + // @formatter:off + JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.RS256) + .keyID("one") + .build(); JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() .subject("test-subject") .expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); + // @formatter:on SignedJWT signedJwt = signedJwt(privateKey, header, claimsSet); - NimbusJwtDecoder decoder = withPublicKey(publicKey).signatureAlgorithm(SignatureAlgorithm.RS256).build(); + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder + .withPublicKey(publicKey) + .signatureAlgorithm(SignatureAlgorithm.RS256) + .build(); assertThat(decoder.decode(signedJwt.serialize())) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } @Test public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception { - NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); - Assertions.assertThatCode(() -> decoder.decode(RS256_SIGNED_JWT)) - .isInstanceOf(BadJwtException.class); + NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512) + .build(); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder.decode(RS256_SIGNED_JWT)); + // @formatter:on } // gh-8730 @@ -370,160 +424,197 @@ public class NimbusJwtDecoderTests { RSAPublicKey publicKey = TestKeys.DEFAULT_PUBLIC_KEY; RSAPrivateKey privateKey = TestKeys.DEFAULT_PRIVATE_KEY; JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.RS256).type(new JOSEObjectType("JWS")).build(); - JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() - .expirationTime(Date.from(Instant.now().plusSeconds(60))) + JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); SignedJWT signedJwt = signedJwt(privateKey, header, claimsSet); - NimbusJwtDecoder decoder = withPublicKey(publicKey) + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(publicKey) .signatureAlgorithm(SignatureAlgorithm.RS256) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); + // @formatter:on assertThat(decoder.decode(signedJwt.serialize()).containsClaim(JwtClaimNames.EXP)).isNotNull(); } @Test public void withPublicKeyWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { - assertThatThrownBy(() -> withPublicKey(key()).jwtProcessorCustomizer(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withPublicKey(key()).jwtProcessorCustomizer(null)) + .withMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:on } @Test public void withSecretKeyWhenNullThenThrowsIllegalArgumentException() { - assertThatThrownBy(() -> withSecretKey(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("secretKey cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withSecretKey(null)) + .withMessage("secretKey cannot be null"); + // @formatter:on } @Test public void withSecretKeyWhenMacAlgorithmNullThenThrowsIllegalArgumentException() { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - assertThatThrownBy(() -> withSecretKey(secretKey).macAlgorithm(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("macAlgorithm cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withSecretKey(secretKey).macAlgorithm(null)) + .withMessage("macAlgorithm cannot be null"); + // @formatter:on } @Test public void decodeWhenUsingSecretKeyThenSuccessfullyDecodes() throws Exception { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; MacAlgorithm macAlgorithm = MacAlgorithm.HS256; + // @formatter:off JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() .subject("test-subject") .expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); + // @formatter:on SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet); - NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(macAlgorithm).build(); + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder.withSecretKey(secretKey) + .macAlgorithm(macAlgorithm) + .build(); assertThat(decoder.decode(signedJWT.serialize())) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } @Test public void decodeWhenUsingSecretKeyAndIncorrectAlgorithmThenThrowsJwtException() throws Exception { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; MacAlgorithm macAlgorithm = MacAlgorithm.HS256; + // @formatter:off JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() .subject("test-subject") .expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); + // @formatter:on SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet); - NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build(); - assertThatThrownBy(() -> decoder.decode(signedJWT.serialize())) - .isInstanceOf(BadJwtException.class) - .hasMessageContaining("Unsupported algorithm of HS256"); + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder.withSecretKey(secretKey) + .macAlgorithm(MacAlgorithm.HS512) + .build(); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder.decode(signedJWT.serialize())) + .withMessageContaining("Unsupported algorithm of HS256"); + // @formatter:on } // gh-7056 @Test public void decodeWhenUsingSecertKeyWithKidThenStillUsesKey() throws Exception { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.HS256).keyID("one").build(); + // @formatter:off + JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.HS256) + .keyID("one") + .build(); JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() .subject("test-subject") .expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); + // @formatter:on SignedJWT signedJwt = signedJwt(secretKey, header, claimsSet); - NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS256).build(); + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder.withSecretKey(secretKey) + .macAlgorithm(MacAlgorithm.HS256) + .build(); assertThat(decoder.decode(signedJwt.serialize())) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } // gh-8730 @Test public void withSecretKeyWhenUsingCustomTypeHeaderThenSuccessfullyDecodes() throws Exception { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.HS256).type(new JOSEObjectType("JWS")).build(); + // @formatter:off + JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.HS256) + .type(new JOSEObjectType("JWS")) + .build(); JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() .expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); + // @formatter:on SignedJWT signedJwt = signedJwt(secretKey, header, claimsSet); - NimbusJwtDecoder decoder = withSecretKey(secretKey) + // @formatter:off + NimbusJwtDecoder decoder = NimbusJwtDecoder.withSecretKey(secretKey) .macAlgorithm(MacAlgorithm.HS256) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); + // @formatter:on assertThat(decoder.decode(signedJwt.serialize()).containsClaim(JwtClaimNames.EXP)).isNotNull(); } @Test public void withSecretKeyWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - assertThatThrownBy(() -> withSecretKey(secretKey).jwtProcessorCustomizer(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withSecretKey(secretKey).jwtProcessorCustomizer(null)) + .withMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:on } @Test public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() { JWKSource jwkSource = mock(JWKSource.class); - JWSKeySelector jwsKeySelector = - withJwkSetUri(JWK_SET_URI).jwsKeySelector(jwkSource); + JWSKeySelector jwsKeySelector = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) + .jwsKeySelector(jwkSource); assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); - JWSVerificationKeySelector jwsVerificationKeySelector = - (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256)) - .isTrue(); + JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256)).isTrue(); } @Test public void jwsKeySelectorWhenOneAlgorithmThenReturnsSingleSelector() { JWKSource jwkSource = mock(JWKSource.class); - JWSKeySelector jwsKeySelector = - withJwkSetUri(JWK_SET_URI).jwsAlgorithm(SignatureAlgorithm.RS512) - .jwsKeySelector(jwkSource); + // @formatter:off + JWSKeySelector jwsKeySelector = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) + .jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + // @formatter:on assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); - JWSVerificationKeySelector jwsVerificationKeySelector = - (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512)) - .isTrue(); + JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512)).isTrue(); } @Test public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() { JWKSource jwkSource = mock(JWKSource.class); - JWSKeySelector jwsKeySelector = - withJwkSetUri(JWK_SET_URI) - .jwsAlgorithm(SignatureAlgorithm.RS256) - .jwsAlgorithm(SignatureAlgorithm.RS512) - .jwsKeySelector(jwkSource); + // @formatter:off + JWSKeySelector jwsKeySelector = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) + .jwsAlgorithm(SignatureAlgorithm.RS256) + .jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + // @formatter:on assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); - JWSVerificationKeySelector jwsAlgorithmMapKeySelector = - (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256)) - .isTrue(); - assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512)) - .isTrue(); + JWSVerificationKeySelector jwsAlgorithmMapKeySelector = (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256)).isTrue(); + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512)).isTrue(); } // gh-7290 @Test public void decodeWhenJwkSetRequestedThenAcceptHeaderJsonAndJwkSetJson() { RestOperations restOperations = mock(RestOperations.class); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); - JWTProcessor processor = withJwkSetUri(JWK_SET_URI) + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); + // @formatter:off + JWTProcessor processor = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .restOperations(restOperations) .processor(); + // @formatter:on NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(processor); jwtDecoder.decode(SIGNED_JWT); ArgumentCaptor requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class); @@ -534,18 +625,17 @@ public class NimbusJwtDecoderTests { @Test public void decodeWhenCacheThenStoreRetrievedJwkSetToCache() { - // given Cache cache = new ConcurrentMapCache("test-jwk-set-cache"); RestOperations restOperations = mock(RestOperations.class); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI) + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); + // @formatter:off + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .restOperations(restOperations) .cache(cache) .build(); - // when + // @formatter:on jwtDecoder.decode(SIGNED_JWT); - // then assertThat(cache.get(JWK_SET_URI, String.class)).isEqualTo(JWK_SET); ArgumentCaptor requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class); verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class)); @@ -556,17 +646,16 @@ public class NimbusJwtDecoderTests { @Test public void decodeWhenCacheThenRetrieveFromCache() { - // given RestOperations restOperations = mock(RestOperations.class); Cache cache = mock(Cache.class); - when(cache.get(eq(JWK_SET_URI), any(Callable.class))).thenReturn(JWK_SET); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI) + given(cache.get(eq(JWK_SET_URI), any(Callable.class))).willReturn(JWK_SET); + // @formatter:off + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .cache(cache) .restOperations(restOperations) .build(); - // when + // @formatter:on jwtDecoder.decode(SIGNED_JWT); - // then verify(cache).get(eq(JWK_SET_URI), any(Callable.class)); verifyNoMoreInteractions(cache); verifyNoInteractions(restOperations); @@ -574,42 +663,49 @@ public class NimbusJwtDecoderTests { @Test public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtException() { - // given Cache cache = new ConcurrentMapCache("test-jwk-set-cache"); RestOperations restOperations = mock(RestOperations.class); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenThrow(new RestClientException("Cannot retrieve JWK Set")); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI) + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willThrow(new RestClientException("Cannot retrieve JWK Set")); + // @formatter:off + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .restOperations(restOperations) .cache(cache) .build(); - // then - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(JwtException.class) + assertThatExceptionOfType(JwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) .isNotInstanceOf(BadJwtException.class) - .hasMessageContaining("An error occurred while attempting to decode the Jwt"); + .withMessageContaining("An error occurred while attempting to decode the Jwt"); + // @formatter:on } // gh-8730 @Test public void withJwkSetUriWhenUsingCustomTypeHeaderThenRefuseOmittedType() throws Exception { RestOperations restOperations = mock(RestOperations.class); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); - NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI) + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); + // @formatter:off + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .restOperations(restOperations) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); - assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) - .isInstanceOf(BadJwtException.class) - .hasMessageContaining("An error occurred while attempting to decode the Jwt: Required JOSE header \"typ\" (type) parameter is missing"); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)) + .withMessageContaining("An error occurred while attempting to decode the Jwt: " + + "Required JOSE header \"typ\" (type) parameter is missing"); + // @formatter:on } @Test public void withJwkSetUriWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { - assertThatThrownBy(() -> withJwkSetUri(JWK_SET_URI).jwtProcessorCustomizer(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI).jwtProcessorCustomizer(null)) + .withMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:on } private RSAPublicKey key() throws InvalidKeySpecException { @@ -618,7 +714,8 @@ public class NimbusJwtDecoderTests { return (RSAPublicKey) kf.generatePublic(spec); } - private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception { + private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) + throws Exception { return signedJwt(secretKey, new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet); } @@ -640,11 +737,13 @@ public class NimbusJwtDecoderTests { private static JWTProcessor withSigning(String jwkResponse) { RestOperations restOperations = mock(RestOperations.class); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(new ResponseEntity<>(jwkResponse, HttpStatus.OK)); - return withJwkSetUri(JWK_SET_URI) + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(jwkResponse, HttpStatus.OK)); + // @formatter:off + return NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .restOperations(restOperations) .processor(); + // @formatter:on } private static JWTProcessor withoutSigning() { @@ -652,16 +751,18 @@ public class NimbusJwtDecoderTests { } private static class MockJwtProcessor extends DefaultJWTProcessor { - @Override - public JWTClaimsSet process(SignedJWT signedJWT, SecurityContext context) - throws BadJOSEException { + @Override + public JWTClaimsSet process(SignedJWT signedJWT, SecurityContext context) throws BadJOSEException { try { return signedJWT.getJWTClaimsSet(); - } catch (ParseException e) { + } + catch (ParseException ex) { // Payload not a JSON object - throw new BadJWTException(e.getMessage(), e); + throw new BadJWTException(ex.getMessage(), ex); } } + } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 8c884dc8ab..e8f780b911 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -29,6 +29,7 @@ import java.util.Base64; import java.util.Collections; import java.util.Date; import java.util.Map; + import javax.crypto.SecretKey; import com.nimbusds.jose.JOSEObjectType; @@ -46,7 +47,6 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; -import org.assertj.core.api.AssertionsForClassTypes; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; @@ -64,17 +64,14 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSource; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withPublicKey; -import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withSecretKey; /** * @author Rob Winch @@ -84,27 +81,35 @@ import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.w public class NimbusReactiveJwtDecoderTests { private String expired = "eyJraWQiOiJrZXktaWQtMSIsImFsZyI6IlJTMjU2In0.eyJzY29wZSI6Im1lc3NhZ2U6cmVhZCIsImV4cCI6MTUyOTkzNzYzMX0.Dt5jFOKkB8zAmjciwvlGkj4LNStXWH0HNIfr8YYajIthBIpVgY5Hg_JL8GBmUFzKDgyusT0q60OOg8_Pdi4Lu-VTWyYutLSlNUNayMlyBaVEWfyZJnh2_OwMZr1vRys6HF-o1qZldhwcfvczHg61LwPa1ISoqaAltDTzBu9cGISz2iBUCuR0x71QhbuRNyJdjsyS96NqiM_TspyiOSxmlNch2oAef1MssOQ23CrKilIvEDsz_zk5H94q7rH0giWGdEHCENESsTJS0zvzH6r2xIWjd5WnihFpCPkwznEayxaEhrdvJqT_ceyXCIfY4m3vujPQHNDG0UshpwvDuEbPUg"; + private String messageReadToken = "eyJraWQiOiJrZXktaWQtMSIsImFsZyI6IlJTMjU2In0.eyJzY29wZSI6Im1lc3NhZ2U6cmVhZCIsImV4cCI6OTIyMzM3MjAwNjA5NjM3NX0.bnQ8IJDXmQbmIXWku0YT1HOyV_3d0iQSA_0W2CmPyELhsxFETzBEEcZ0v0xCBiswDT51rwD83wbX3YXxb84fM64AhpU8wWOxLjha4J6HJX2JnlG47ydaAVD7eWGSYTavyyQ-CwUjQWrfMVcObFZLYG11ydzRYOR9-aiHcK3AobcTcS8jZFeI8EGQV_Cd3IJ018uFCf6VnXLv7eV2kRt08Go2RiPLW47ExvD7Dzzz_wDBKfb4pNem7fDvuzB3UPcp5m9QvLZicnbS_6AvDi6P1y_DFJf-1T5gkGmX5piDH1L1jg2Yl6tjmXbk5B3VhsyjJuXE6gzq1d-xie0Z1NVOxw"; + private String unsignedToken = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9."; - private String jwkSet = - "{\n" - + " \"keys\":[\n" - + " {\n" - + " \"kty\":\"RSA\",\n" - + " \"e\":\"AQAB\",\n" - + " \"use\":\"sig\",\n" - + " \"kid\":\"key-id-1\",\n" - + " \"n\":\"qL48v1clgFw-Evm145pmh8nRYiNt72Gupsshn7Qs8dxEydCRp1DPOV_PahPk1y2nvldBNIhfNL13JOAiJ6BTiF-2ICuICAhDArLMnTH61oL1Hepq8W1xpa9gxsnL1P51thvfmiiT4RTW57koy4xIWmIp8ZXXfYgdH2uHJ9R0CQBuYKe7nEOObjxCFWC8S30huOfW2cYtv0iB23h6w5z2fDLjddX6v_FXM7ktcokgpm3_XmvT_-bL6_GGwz9k6kJOyMTubecr-WT__le8ikY66zlplYXRQh6roFfFCL21Pt8xN5zrk-0AMZUnmi8F2S2ztSBmAVJ7H71ELXsURBVZpw\"\n" - + " }\n" - + " ]\n" - + "}"; + + // @formatter:off + private String jwkSet = "{\n" + + " \"keys\":[\n" + + " {\n" + + " \"kty\":\"RSA\",\n" + + " \"e\":\"AQAB\",\n" + + " \"use\":\"sig\",\n" + + " \"kid\":\"key-id-1\",\n" + + " \"n\":\"qL48v1clgFw-Evm145pmh8nRYiNt72Gupsshn7Qs8dxEydCRp1DPOV_PahPk1y2nvldBNIhfNL13JOAiJ6BTiF-2ICuICAhDArLMnTH61oL1Hepq8W1xpa9gxsnL1P51thvfmiiT4RTW57koy4xIWmIp8ZXXfYgdH2uHJ9R0CQBuYKe7nEOObjxCFWC8S30huOfW2cYtv0iB23h6w5z2fDLjddX6v_FXM7ktcokgpm3_XmvT_-bL6_GGwz9k6kJOyMTubecr-WT__le8ikY66zlplYXRQh6roFfFCL21Pt8xN5zrk-0AMZUnmi8F2S2ztSBmAVJ7H71ELXsURBVZpw\"\n" + + " }\n" + + " ]\n" + + "}"; + // @formatter:on + private String jwkSetUri = "https://issuer/certs"; private String rsa512 = "eyJhbGciOiJSUzUxMiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJleHAiOjE5NzQzMjYxMTl9.LKAx-60EBfD7jC1jb1eKcjO4uLvf3ssISV-8tN-qp7gAjSvKvj4YA9-V2mIb6jcS1X_xGmNy6EIimZXpWaBR3nJmeu-jpe85u4WaW2Ztr8ecAi-dTO7ZozwdtljKuBKKvj4u1nF70zyCNl15AozSG0W1ASrjUuWrJtfyDG6WoZ8VfNMuhtU-xUYUFvscmeZKUYQcJ1KS-oV5tHeF8aNiwQoiPC_9KXCOZtNEJFdq6-uzFdHxvOP2yex5Gbmg5hXonauIFXG2ZPPGdXzm-5xkhBpgM8U7A_6wb3So8wBvLYYm2245QUump63AJRAy8tQpwt4n9MvQxQgS3z9R-NK92A"; + private String rsa256 = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJleHAiOjE5NzQzMjYzMzl9.CT-H2OWEqmSs1NWmnta5ealLFvM8OlbQTjGhfRcKLNxrTrzsOkqBJl-AN3k16BQU7mS32o744TiiZ29NcDlxPsr1MqTlN86-dobPiuNIDLp3A1bOVdXMcVFuMYkrNv0yW0tGS9OjEqsCCuZDkZ1by6AhsHLbGwRY-6AQdcRouZygGpOQu1hNun5j8q5DpSTY4AXKARIFlF-O3OpVbPJ0ebr3Ki-i3U9p_55H0e4-wx2bqcApWlqgofl1I8NKWacbhZgn81iibup2W7E0CzCzh71u1Mcy3xk1sYePx-dwcxJnHmxJReBBWjJZEAeCrkbnn_OCuo2fA-EQyNJtlN5F2w"; + private String publicKey = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAq4yKxb6SNePdDmQi9xFCrP6QvHosErQzryknQTTTffs0t3cy3Er3lIceuhZ7yQNSCDfPFqG8GoyoKhuChRiA5D+J2ab7bqTa1QJKfnCyERoscftgN2fXPHjHoiKbpGV2tMVw8mXl//tePOAiKbMJaBUnlAvJgkk1rVm08dSwpLC1sr2M19euf9jwnRGkMRZuhp9iCPgECRke5T8Ixpv0uQjSmGHnWUKTFlbj8sM83suROR1Ue64JSGScANc5vk3huJ/J97qTC+K2oKj6L8d9O8dpc4obijEOJwpydNvTYDgbiivYeSB00KS9jlBkQ5B2QqLvLVEygDl3dp59nGx6YQIDAQAB"; private MockWebServer server; + private NimbusReactiveJwtDecoder decoder; private static KeyFactory kf; @@ -118,7 +123,7 @@ public class NimbusReactiveJwtDecoderTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.server.enqueue(new MockResponse().setBody(jwkSet)); + this.server.enqueue(new MockResponse().setBody(this.jwkSet)); this.decoder = new NimbusReactiveJwtDecoder(this.server.url("/certs").toString()); } @@ -130,123 +135,133 @@ public class NimbusReactiveJwtDecoderTests { @Test public void decodeWhenInvalidUrl() { this.decoder = new NimbusReactiveJwtDecoder("https://s"); - - assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) - .isInstanceOf(IllegalStateException.class) - .hasCauseInstanceOf(UnknownHostException.class); - + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> this.decoder.decode(this.messageReadToken).block()) + .withCauseInstanceOf(UnknownHostException.class); + // @formatter:on } @Test public void decodeWhenMessageReadScopeThenSuccess() { Jwt jwt = this.decoder.decode(this.messageReadToken).block(); - assertThat(jwt.getClaims().get("scope")).isEqualTo("message:read"); } @Test public void decodeWhenRSAPublicKeyThenSuccess() throws Exception { - byte[] bytes = Base64.getDecoder().decode("MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqL48v1clgFw+Evm145pmh8nRYiNt72Gupsshn7Qs8dxEydCRp1DPOV/PahPk1y2nvldBNIhfNL13JOAiJ6BTiF+2ICuICAhDArLMnTH61oL1Hepq8W1xpa9gxsnL1P51thvfmiiT4RTW57koy4xIWmIp8ZXXfYgdH2uHJ9R0CQBuYKe7nEOObjxCFWC8S30huOfW2cYtv0iB23h6w5z2fDLjddX6v/FXM7ktcokgpm3/XmvT/+bL6/GGwz9k6kJOyMTubecr+WT//le8ikY66zlplYXRQh6roFfFCL21Pt8xN5zrk+0AMZUnmi8F2S2ztSBmAVJ7H71ELXsURBVZpwIDAQAB"); + byte[] bytes = Base64.getDecoder().decode( + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqL48v1clgFw+Evm145pmh8nRYiNt72Gupsshn7Qs8dxEydCRp1DPOV/PahPk1y2nvldBNIhfNL13JOAiJ6BTiF+2ICuICAhDArLMnTH61oL1Hepq8W1xpa9gxsnL1P51thvfmiiT4RTW57koy4xIWmIp8ZXXfYgdH2uHJ9R0CQBuYKe7nEOObjxCFWC8S30huOfW2cYtv0iB23h6w5z2fDLjddX6v/FXM7ktcokgpm3/XmvT/+bL6/GGwz9k6kJOyMTubecr+WT//le8ikY66zlplYXRQh6roFfFCL21Pt8xN5zrk+0AMZUnmi8F2S2ztSBmAVJ7H71ELXsURBVZpwIDAQAB"); RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA") .generatePublic(new X509EncodedKeySpec(bytes)); this.decoder = new NimbusReactiveJwtDecoder(publicKey); String noKeyId = "eyJhbGciOiJSUzI1NiJ9.eyJzY29wZSI6IiIsImV4cCI6OTIyMzM3MjAwNjA5NjM3NX0.hNVuHSUkxdLZrDfqdmKcOi0ggmNaDuB4ZPxPtJl1gwBiXzIGN6Hwl24O2BfBZiHFKUTQDs4_RvzD71mEG3DvUrcKmdYWqIB1l8KNmxQLUDG-cAPIpJmRJgCh50tf8OhOE_Cb9E1HcsOUb47kT9iz-VayNBcmo6BmyZLdEGhsdGBrc3Mkz2dd_0PF38I2Hf_cuSjn9gBjFGtiPEXJvob3PEjVTSx_zvodT8D9p3An1R3YBZf5JSd1cQisrXgDX2k1Jmf7UKKWzgfyCgnEtRWWbsUdPqo3rSEY9GDC1iSQXsFTTC1FT_JJDkwzGf011fsU5O_Ko28TARibmKTCxAKNRQ"; - - assertThatCode(() -> this.decoder.decode(noKeyId).block()) - .doesNotThrowAnyException(); + this.decoder.decode(noKeyId).block(); } @Test public void decodeWhenIssuedAtThenSuccess() { String withIssuedAt = "eyJraWQiOiJrZXktaWQtMSIsImFsZyI6IlJTMjU2In0.eyJzY29wZSI6IiIsImV4cCI6OTIyMzM3MjAwNjA5NjM3NSwiaWF0IjoxNTI5OTQyNDQ4fQ.LBzAJO-FR-uJDHST61oX4kimuQjz6QMJPW_mvEXRB6A-fMQWpfTQ089eboipAqsb33XnwWth9ELju9HMWLk0FjlWVVzwObh9FcoKelmPNR8mZIlFG-pAYGgSwi8HufyLabXHntFavBiFtqwp_z9clSOFK1RxWvt3lywEbGgtCKve0BXOjfKWiH1qe4QKGixH-NFxidvz8Qd5WbJwyb9tChC6ZKoKPv7Jp-N5KpxkY-O2iUtINvn4xOSactUsvKHgF8ZzZjvJGzG57r606OZXaNtoElQzjAPU5xDGg5liuEJzfBhvqiWCLRmSuZ33qwp3aoBnFgEw0B85gsNe3ggABg"; - Jwt jwt = this.decoder.decode(withIssuedAt).block(); - assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1529942448L)); } @Test public void decodeWhenExpiredThenFail() { - assertThatCode(() -> this.decoder.decode(this.expired).block()) - .isInstanceOf(JwtValidationException.class); + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> this.decoder.decode(this.expired).block()); } @Test public void decodeWhenNoPeriodThenFail() { - assertThatCode(() -> this.decoder.decode("").block()) - .isInstanceOf(BadJwtException.class); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.decoder.decode("").block()); + // @formatter:on } @Test public void decodeWhenInvalidJwkSetUrlThenFail() { this.decoder = new NimbusReactiveJwtDecoder("http://localhost:1280/certs"); - assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> this.decoder.decode(this.messageReadToken).block()); + // @formatter:on } @Test public void decodeWhenInvalidSignatureThenFail() { - assertThatCode(() -> this.decoder.decode(this.messageReadToken.substring(0, this.messageReadToken.length() - 2)).block()) - .isInstanceOf(BadJwtException.class); + assertThatExceptionOfType(BadJwtException.class).isThrownBy(() -> this.decoder + .decode(this.messageReadToken.substring(0, this.messageReadToken.length() - 2)).block()); } @Test public void decodeWhenAlgNoneThenFail() { - assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIm5vbmUiLA0KICAidHlwIjogIkpXVCINCn0.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block()) - .isInstanceOf(BadJwtException.class) - .hasMessage("Unsupported algorithm of none"); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.decoder + .decode("ew0KICAiYWxnIjogIm5vbmUiLA0KICAidHlwIjogIkpXVCINCn0.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.") + .block() + ) + .withMessage("Unsupported algorithm of none"); + // @formatter:on } @Test public void decodeWhenInvalidAlgMismatchThenFail() { - assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIkVTMjU2IiwNCiAgInR5cCI6ICJKV1QiDQp9.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block()) - .isInstanceOf(BadJwtException.class); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.decoder + .decode("ew0KICAiYWxnIjogIkVTMjU2IiwNCiAgInR5cCI6ICJKV1QiDQp9.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.") + .block() + ); + // @formatter:on } @Test public void decodeWhenUnsignedTokenThenMessageDoesNotMentionClass() { - assertThatCode(() -> this.decoder.decode(this.unsignedToken).block()) - .isInstanceOf(BadJwtException.class) - .hasMessage("Unsupported algorithm of none"); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.decoder.decode(this.unsignedToken).block()) + .withMessage("Unsupported algorithm of none"); + // @formatter:on } @Test public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() { OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); this.decoder.setJwtValidator(jwtValidator); - OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(error); - when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); - - assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description"); + given(jwtValidator.validate(any(Jwt.class))).willReturn(result); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> this.decoder.decode(this.messageReadToken).block()) + .withMessageContaining("mock-description"); + // @formatter:on } @Test public void decodeWhenReadingErrorPickTheFirstErrorMessage() { OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); this.decoder.setJwtValidator(jwtValidator); - OAuth2Error errorEmpty = new OAuth2Error("mock-error", "", "mock-uri"); OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); OAuth2Error error2 = new OAuth2Error("mock-error-second", "mock-description-second", "mock-uri-second"); OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(errorEmpty, error, error2); - when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); - - assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("mock-description"); + given(jwtValidator.validate(any(Jwt.class))).willReturn(result); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> this.decoder.decode(this.messageReadToken).block()) + .withMessageContaining("mock-description"); + // @formatter:on } @Test public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() { Converter, Map> claimSetConverter = mock(Converter.class); this.decoder.setClaimSetConverter(claimSetConverter); - - when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value")); - + given(claimSetConverter.convert(any(Map.class))).willReturn(Collections.singletonMap("custom", "value")); Jwt jwt = this.decoder.decode(this.messageReadToken).block(); assertThat(jwt.getClaims().size()).isEqualTo(1); assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); @@ -258,57 +273,76 @@ public class NimbusReactiveJwtDecoderTests { public void decodeWhenClaimSetConverterFailsThenBadJwtException() { Converter, Map> claimSetConverter = mock(Converter.class); this.decoder.setClaimSetConverter(claimSetConverter); - - when(claimSetConverter.convert(any(Map.class))).thenThrow(new IllegalArgumentException("bad conversion")); - - assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) - .isInstanceOf(BadJwtException.class); + given(claimSetConverter.convert(any(Map.class))).willThrow(new IllegalArgumentException("bad conversion")); + // @formatter:off + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.decoder.decode(this.messageReadToken).block()); + // @formatter:on } @Test public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() { - assertThatCode(() -> this.decoder.setJwtValidator(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.decoder.setJwtValidator(null)); + // @formatter:on } @Test public void setClaimSetConverterWhenNullThrowsIllegalArgumentException() { - assertThatCode(() -> this.decoder.setClaimSetConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.decoder.setClaimSetConverter(null)); + // @formatter:on } @Test public void withJwkSetUriWhenNullOrEmptyThenThrowsException() { - assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> NimbusReactiveJwtDecoder.withJwkSetUri(null)); } @Test public void jwsAlgorithmWhenNullThenThrowsException() { - NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = withJwkSetUri(this.jwkSetUri); - assertThatCode(() -> builder.jwsAlgorithm(null)).isInstanceOf(IllegalArgumentException.class); + NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = NimbusReactiveJwtDecoder + .withJwkSetUri(this.jwkSetUri); + assertThatIllegalArgumentException().isThrownBy(() -> builder.jwsAlgorithm(null)); } @Test public void withJwkSetUriWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { - assertThatCode(() -> withJwkSetUri(jwkSetUri).jwtProcessorCustomizer(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder + .withJwkSetUri(this.jwkSetUri) + .jwtProcessorCustomizer(null) + .build() + ) + .withMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:on } @Test public void restOperationsWhenNullThenThrowsException() { - NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = withJwkSetUri(this.jwkSetUri); - assertThatCode(() -> builder.webClient(null)).isInstanceOf(IllegalArgumentException.class); + NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = NimbusReactiveJwtDecoder + .withJwkSetUri(this.jwkSetUri); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> builder.webClient(null)); + // @formatter:on } // gh-5603 @Test public void decodeWhenSignedThenOk() { WebClient webClient = mockJwkSetResponse(this.jwkSet); - NimbusReactiveJwtDecoder decoder = withJwkSetUri(this.jwkSetUri).webClient(webClient).build(); - assertThat(decoder.decode(messageReadToken).block()) + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withJwkSetUri(this.jwkSetUri) + .webClient(webClient) + .build(); + assertThat(decoder.decode(this.messageReadToken).block()) .extracting(Jwt::getExpiresAt) .isNotNull(); + // @formatter:on verify(webClient).get(); } @@ -316,144 +350,200 @@ public class NimbusReactiveJwtDecoderTests { @Test public void withJwkSetUriWhenUsingCustomTypeHeaderThenRefuseOmittedType() { WebClient webClient = mockJwkSetResponse(this.jwkSet); - NimbusReactiveJwtDecoder decoder = withJwkSetUri(this.jwkSetUri) + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withJwkSetUri(this.jwkSetUri) .webClient(webClient) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); - assertThatCode(() -> decoder.decode(messageReadToken).block()) - .isInstanceOf(BadJwtException.class) - .hasRootCauseMessage("Required JOSE header \"typ\" (type) parameter is missing"); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder.decode(this.messageReadToken).block()) + .havingRootCause().withMessage("Required JOSE header \"typ\" (type) parameter is missing"); + // @formatter:on } @Test public void withPublicKeyWhenNullThenThrowsException() { - assertThatThrownBy(() -> withPublicKey(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder.withPublicKey(null)); + // @formatter:on } @Test public void buildWhenSignatureAlgorithmMismatchesKeyTypeThenThrowsException() { - assertThatCode(() -> withPublicKey(key()) - .signatureAlgorithm(SignatureAlgorithm.ES256) - .build()) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> NimbusReactiveJwtDecoder.withPublicKey(key()) + .signatureAlgorithm(SignatureAlgorithm.ES256) + .build() + ); + // @formatter:on } @Test public void buildWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { - assertThatCode(() -> withPublicKey(key()).jwtProcessorCustomizer(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder.withPublicKey(key()) + .jwtProcessorCustomizer(null) + .build() + ) + .withMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:on } @Test public void decodeWhenUsingPublicKeyThenSuccessfullyDecodes() throws Exception { - NimbusReactiveJwtDecoder decoder = withPublicKey(key()).build(); + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withPublicKey(key()) + .build(); assertThat(decoder.decode(this.rsa256).block()) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } @Test public void decodeWhenUsingPublicKeyWithRs512ThenSuccessfullyDecodes() throws Exception { - NimbusReactiveJwtDecoder decoder = - withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withPublicKey(key()) + .signatureAlgorithm(SignatureAlgorithm.RS512) + .build(); assertThat(decoder.decode(this.rsa512).block()) .extracting(Jwt::getSubject) .isEqualTo("test-subject"); + // @formatter:on } @Test public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception { - NimbusReactiveJwtDecoder decoder = - withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); - assertThatCode(() -> decoder.decode(this.rsa256).block()) - .isInstanceOf(BadJwtException.class); + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withPublicKey(key()) + .signatureAlgorithm(SignatureAlgorithm.RS512) + .build(); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder + .decode(this.rsa256) + .block() + ); + // @formatter:on } // gh-8730 @Test public void withPublicKeyWhenUsingCustomTypeHeaderThenRefuseOmittedType() throws Exception { - NimbusReactiveJwtDecoder decoder = withPublicKey(key()) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withPublicKey(key()) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); - - AssertionsForClassTypes.assertThatCode(() -> decoder.decode(this.rsa256).block()) - .isInstanceOf(BadJwtException.class) - .hasRootCauseMessage("Required JOSE header \"typ\" (type) parameter is missing"); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder.decode(this.rsa256).block()) + .havingRootCause().withMessage("Required JOSE header \"typ\" (type) parameter is missing"); + // @formatter:on } @Test public void withJwkSourceWhenNullThenThrowsException() { - assertThatCode(() -> withJwkSource(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder.withJwkSource(null)); + // @formatter:on } @Test public void withJwkSourceWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { - assertThatCode(() -> withJwkSource(jwt -> Flux.empty()).jwtProcessorCustomizer(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> NimbusReactiveJwtDecoder + .withJwkSource((jwt) -> Flux.empty()).jwtProcessorCustomizer(null).build()) + .withMessage("jwtProcessorCustomizer cannot be null"); } @Test public void decodeWhenCustomJwkSourceResolutionThenDecodes() { - NimbusReactiveJwtDecoder decoder = - withJwkSource(jwt -> Flux.fromIterable(parseJWKSet(this.jwkSet).getKeys())) - .build(); - + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder + .withJwkSource((jwt) -> Flux.fromIterable(parseJWKSet(this.jwkSet).getKeys())) + .build(); assertThat(decoder.decode(this.messageReadToken).block()) .extracting(Jwt::getExpiresAt) .isNotNull(); + // @formatter:on } // gh-8730 @Test public void withJwkSourceWhenUsingCustomTypeHeaderThenRefuseOmittedType() { - NimbusReactiveJwtDecoder decoder = withJwkSource(jwt -> Flux.empty()) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder + .withJwkSource((jwt) -> Flux.empty()) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); - - assertThatCode(() -> decoder.decode(this.messageReadToken).block()) - .isInstanceOf(BadJwtException.class) - .hasRootCauseMessage("Required JOSE header \"typ\" (type) parameter is missing"); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder.decode(this.messageReadToken).block()) + .havingRootCause() + .withMessage("Required JOSE header \"typ\" (type) parameter is missing"); + // @formatter:on } @Test public void withSecretKeyWhenSecretKeyNullThenThrowsIllegalArgumentException() { - assertThatThrownBy(() -> withSecretKey(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("secretKey cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder.withSecretKey(null)) + .withMessage("secretKey cannot be null"); + // @formatter:on } @Test public void withSecretKeyWhenJwtProcessorCustomizerNullThenThrowsIllegalArgumentException() { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - assertThatThrownBy(() -> withSecretKey(secretKey).jwtProcessorCustomizer(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder + .withSecretKey(secretKey) + .jwtProcessorCustomizer(null) + .build() + ) + .withMessage("jwtProcessorCustomizer cannot be null"); + // @formatter:on } @Test public void withSecretKeyWhenMacAlgorithmNullThenThrowsIllegalArgumentException() { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - assertThatThrownBy(() -> withSecretKey(secretKey).macAlgorithm(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("macAlgorithm cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> NimbusReactiveJwtDecoder + .withSecretKey(secretKey) + .macAlgorithm(null) + ) + .withMessage("macAlgorithm cannot be null"); + // @formatter:on } @Test public void decodeWhenSecretKeyThenSuccess() throws Exception { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; MacAlgorithm macAlgorithm = MacAlgorithm.HS256; + // @formatter:off JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() .subject("test-subject") .expirationTime(Date.from(Instant.now().plusSeconds(60))) .build(); + // @formatter:on SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet); - - this.decoder = withSecretKey(secretKey).macAlgorithm(macAlgorithm).build(); - Jwt jwt = this.decoder.decode(signedJWT.serialize()).block(); + // @formatter:off + this.decoder = NimbusReactiveJwtDecoder.withSecretKey(secretKey) + .macAlgorithm(macAlgorithm) + .build(); + Jwt jwt = this.decoder.decode(signedJWT.serialize()) + .block(); + // @formatter:on assertThat(jwt.getSubject()).isEqualTo("test-subject"); } @@ -461,72 +551,71 @@ public class NimbusReactiveJwtDecoderTests { @Test public void withSecretKeyWhenUsingCustomTypeHeaderThenRefuseOmittedType() { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; - NimbusReactiveJwtDecoder decoder = withSecretKey(secretKey) - .jwtProcessorCustomizer(p -> p.setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS")))) + // @formatter:off + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withSecretKey(secretKey) + .jwtProcessorCustomizer((p) -> p + .setJWSTypeVerifier(new DefaultJOSEObjectTypeVerifier<>(new JOSEObjectType("JWS"))) + ) .build(); - assertThatCode(() -> decoder.decode(messageReadToken).block()) - .isInstanceOf(BadJwtException.class) - .hasRootCauseMessage("Required JOSE header \"typ\" (type) parameter is missing"); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> decoder.decode(this.messageReadToken).block()) + .havingRootCause().withMessage("Required JOSE header \"typ\" (type) parameter is missing"); + // @formatter:on } @Test public void decodeWhenSecretKeyAndAlgorithmMismatchThenThrowsJwtException() throws Exception { SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY; MacAlgorithm macAlgorithm = MacAlgorithm.HS256; - JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() - .subject("test-subject") - .expirationTime(Date.from(Instant.now().plusSeconds(60))) - .build(); + JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().subject("test-subject") + .expirationTime(Date.from(Instant.now().plusSeconds(60))).build(); SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet); - - this.decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build(); - assertThatThrownBy(() -> this.decoder.decode(signedJWT.serialize()).block()) - .isInstanceOf(BadJwtException.class); + // @formatter:off + this.decoder = NimbusReactiveJwtDecoder.withSecretKey(secretKey) + .macAlgorithm(MacAlgorithm.HS512) + .build(); + assertThatExceptionOfType(BadJwtException.class) + .isThrownBy(() -> this.decoder.decode(signedJWT.serialize()).block()); + // @formatter:on } @Test public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() { JWKSource jwkSource = mock(JWKSource.class); - JWSKeySelector jwsKeySelector = - withJwkSetUri(this.jwkSetUri).jwsKeySelector(jwkSource); + JWSKeySelector jwsKeySelector = NimbusReactiveJwtDecoder.withJwkSetUri(this.jwkSetUri) + .jwsKeySelector(jwkSource); assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); - JWSVerificationKeySelector jwsVerificationKeySelector = - (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256)) - .isTrue(); + JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256)).isTrue(); } @Test public void jwsKeySelectorWhenOneAlgorithmThenReturnsSingleSelector() { JWKSource jwkSource = mock(JWKSource.class); - JWSKeySelector jwsKeySelector = - withJwkSetUri(this.jwkSetUri).jwsAlgorithm(SignatureAlgorithm.RS512) - .jwsKeySelector(jwkSource); + JWSKeySelector jwsKeySelector = NimbusReactiveJwtDecoder.withJwkSetUri(this.jwkSetUri) + .jwsAlgorithm(SignatureAlgorithm.RS512).jwsKeySelector(jwkSource); assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); - JWSVerificationKeySelector jwsVerificationKeySelector = - (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512)) - .isTrue(); + JWSVerificationKeySelector jwsVerificationKeySelector = (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512)).isTrue(); } @Test public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() { JWKSource jwkSource = mock(JWKSource.class); - JWSKeySelector jwsKeySelector = - withJwkSetUri(this.jwkSetUri) - .jwsAlgorithm(SignatureAlgorithm.RS256) - .jwsAlgorithm(SignatureAlgorithm.RS512) - .jwsKeySelector(jwkSource); + // @formatter:off + JWSKeySelector jwsKeySelector = NimbusReactiveJwtDecoder.withJwkSetUri(this.jwkSetUri) + .jwsAlgorithm(SignatureAlgorithm.RS256) + .jwsAlgorithm(SignatureAlgorithm.RS512) + .jwsKeySelector(jwkSource); + // @formatter:on assertThat(jwsKeySelector instanceof JWSVerificationKeySelector); - JWSVerificationKeySelector jwsAlgorithmMapKeySelector = - (JWSVerificationKeySelector) jwsKeySelector; - assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256)) - .isTrue(); - assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512)) - .isTrue(); + JWSVerificationKeySelector jwsAlgorithmMapKeySelector = (JWSVerificationKeySelector) jwsKeySelector; + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256)).isTrue(); + assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512)).isTrue(); } - private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception { + private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) + throws Exception { SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet); JWSSigner signer = new MACSigner(secretKey); signedJWT.sign(signer); @@ -536,8 +625,9 @@ public class NimbusReactiveJwtDecoderTests { private JWKSet parseJWKSet(String jwkSet) { try { return JWKSet.parse(jwkSet); - } catch (ParseException e) { - throw new IllegalArgumentException(e); + } + catch (ParseException ex) { + throw new IllegalArgumentException(ex); } } @@ -551,10 +641,11 @@ public class NimbusReactiveJwtDecoderTests { WebClient real = WebClient.builder().build(); WebClient.RequestHeadersUriSpec spec = spy(real.get()); WebClient webClient = spy(WebClient.class); - when(webClient.get()).thenReturn(spec); + given(webClient.get()).willReturn(spec); WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class); - when(responseSpec.bodyToMono(String.class)).thenReturn(Mono.just(response)); - when(spec.retrieve()).thenReturn(responseSpec); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(response)); + given(spec.retrieve()).willReturn(responseSpec); return webClient; } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java index 3e6b411fc5..6c5d22ef30 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.jwt; import java.net.URI; @@ -37,7 +38,9 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.web.util.UriComponentsBuilder; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link ReactiveJwtDecoders} @@ -46,40 +49,46 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Rafiullah Hamedy */ public class ReactiveJwtDecodersTests { + /** - * Contains those parameters required to construct a ReactiveJwtDecoder as well as any required parameters + * Contains those parameters required to construct a ReactiveJwtDecoder as well as any + * required parameters */ - private static final String DEFAULT_RESPONSE_TEMPLATE = - "{\n" - + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" - + " \"id_token_signing_alg_values_supported\": [\n" - + " \"RS256\"\n" - + " ], \n" - + " \"issuer\": \"%s\", \n" - + " \"jwks_uri\": \"%s/.well-known/jwks.json\", \n" - + " \"response_types_supported\": [\n" - + " \"code\", \n" - + " \"token\", \n" - + " \"id_token\", \n" - + " \"code token\", \n" - + " \"code id_token\", \n" - + " \"token id_token\", \n" - + " \"code token id_token\", \n" - + " \"none\"\n" - + " ], \n" - + " \"subject_types_supported\": [\n" - + " \"public\"\n" - + " ], \n" - + " \"token_endpoint\": \"https://example.com/oauth2/v4/token\"\n" - + "}"; + // @formatter:off + private static final String DEFAULT_RESPONSE_TEMPLATE = "{\n" + + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" + + " \"id_token_signing_alg_values_supported\": [\n" + + " \"RS256\"\n" + + " ], \n" + + " \"issuer\": \"%s\", \n" + + " \"jwks_uri\": \"%s/.well-known/jwks.json\", \n" + + " \"response_types_supported\": [\n" + + " \"code\", \n" + + " \"token\", \n" + + " \"id_token\", \n" + + " \"code token\", \n" + + " \"code id_token\", \n" + + " \"token id_token\", \n" + + " \"code token id_token\", \n" + + " \"none\"\n" + + " ], \n" + + " \"subject_types_supported\": [\n" + + " \"public\"\n" + + " ], \n" + + " \"token_endpoint\": \"https://example.com/oauth2/v4/token\"\n" + + "}"; + // @formatter:on private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}"; + private static final String ISSUER_MISMATCH = "eyJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwczpcL1wvd3Jvbmdpc3N1ZXIiLCJleHAiOjQ2ODcyNTYwNDl9.Ax8LMI6rhB9Pv_CE3kFi1JPuLj9gZycifWrLeDpkObWEEVAsIls9zAhNFyJlG-Oo7up6_mDhZgeRfyKnpSF5GhKJtXJDCzwg0ZDVUE6rS0QadSxsMMGbl7c4y0lG_7TfLX2iWeNJukJj_oSW9KzW4FsBp1BoocWjrreesqQU3fZHbikH-c_Fs2TsAIpHnxflyEzfOFWpJ8D4DtzHXqfvieMwpy42xsPZK3LR84zlasf0Ne1tC_hLHvyHRdAXwn0CMoKxc7-8j0r9Mq8kAzUsPn9If7bMLqGkxUcTPdk5x7opAUajDZx95SXHLmtztNtBa2S6EfPJXuPKG6tM5Wq5Ug"; private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration"; + private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server"; private MockWebServer server; + private String issuer; @Before @@ -98,56 +107,61 @@ public class ReactiveJwtDecodersTests { @Test public void issuerWhenResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponse(); - ReactiveJwtDecoder decoder = ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer); - - assertThatCode(() -> decoder.decode(ISSUER_MISMATCH).block()) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("The iss claim is not valid"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(ISSUER_MISMATCH).block()) + .withMessageContaining("The iss claim is not valid"); + // @formatter:on } @Test public void issuerWhenOidcFallbackResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponseOidc(); - ReactiveJwtDecoder decoder = ReactiveJwtDecoders.fromIssuerLocation(this.issuer); - - assertThatCode(() -> decoder.decode(ISSUER_MISMATCH).block()) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("The iss claim is not valid"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(ISSUER_MISMATCH).block()) + .withMessageContaining("The iss claim is not valid"); + // @formatter:on } @Test public void issuerWhenOAuth2ResponseIsTypicalThenReturnedDecoderValidatesIssuer() { prepareConfigurationResponseOAuth2(); - ReactiveJwtDecoder decoder = ReactiveJwtDecoders.fromIssuerLocation(this.issuer); - - assertThatCode(() -> decoder.decode(ISSUER_MISMATCH).block()) - .isInstanceOf(JwtValidationException.class) - .hasMessageContaining("The iss claim is not valid"); + // @formatter:off + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> decoder.decode(ISSUER_MISMATCH).block()) + .withMessageContaining("The iss claim is not valid"); + // @formatter:on } @Test public void issuerWhenResponseIsNonCompliantThenThrowsRuntimeException() { prepareConfigurationResponse("{ \"missing_required_keys\" : \"and_values\" }"); - - assertThatCode(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOidcFallbackResponseIsNonCompliantThenThrowsRuntimeException() { prepareConfigurationResponseOidc("{ \"missing_required_keys\" : \"and_values\" }"); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOAuth2ResponseIsNonCompliantThenThrowsRuntimeException() { prepareConfigurationResponseOAuth2("{ \"missing_required_keys\" : \"and_values\" }"); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } // gh-7512 @@ -155,9 +169,11 @@ public class ReactiveJwtDecodersTests { public void issuerWhenResponseDoesNotContainJwksUriThenThrowsIllegalArgumentException() throws JsonMappingException, JsonProcessingException { prepareConfigurationResponse(this.buildResponseWithMissingJwksUri()); - assertThatCode(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The public JWK set URI must not be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)) + .withMessage("The public JWK set URI must not be null"); + // @formatter:on } // gh-7512 @@ -165,9 +181,11 @@ public class ReactiveJwtDecodersTests { public void issuerWhenOidcFallbackResponseDoesNotContainJwksUriThenThrowsIllegalArgumentException() throws JsonMappingException, JsonProcessingException { prepareConfigurationResponseOidc(this.buildResponseWithMissingJwksUri()); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The public JWK set URI must not be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) + .withMessage("The public JWK set URI must not be null"); + // @formatter:on } // gh-7512 @@ -175,72 +193,85 @@ public class ReactiveJwtDecodersTests { public void issuerWhenOAuth2ResponseDoesNotContainJwksUriThenThrowsIllegalArgumentException() throws JsonMappingException, JsonProcessingException { prepareConfigurationResponseOAuth2(this.buildResponseWithMissingJwksUri()); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The public JWK set URI must not be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) + .withMessage("The public JWK set URI must not be null"); + // @formatter:on } @Test public void issuerWhenResponseIsMalformedThenThrowsRuntimeException() { prepareConfigurationResponse("malformed"); - - assertThatCode(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOidcFallbackResponseIsMalformedThenThrowsRuntimeException() { prepareConfigurationResponseOidc("malformed"); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOAuth2ResponseIsMalformedThenThrowsRuntimeException() { prepareConfigurationResponseOAuth2("malformed"); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(RuntimeException.class); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenRespondingIssuerMismatchesRequestedIssuerThenThrowsIllegalStateException() { prepareConfigurationResponse(String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); - - assertThatCode(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> ReactiveJwtDecoders.fromOidcIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOidcFallbackRespondingIssuerMismatchesRequestedIssuerThenThrowsIllegalStateException() { prepareConfigurationResponseOidc(String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalStateException.class); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test public void issuerWhenOAuth2RespondingIssuerMismatchesRequestedIssuerThenThrowsIllegalStateException() { - prepareConfigurationResponseOAuth2(String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)) - .isInstanceOf(IllegalStateException.class); + prepareConfigurationResponseOAuth2( + String.format(DEFAULT_RESPONSE_TEMPLATE, this.issuer + "/wrong", this.issuer)); + // @formatter:off + assertThatIllegalStateException() + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation(this.issuer)); + // @formatter:on } @Test - public void issuerWhenRequestedIssuerIsUnresponsiveThenThrowsIllegalArgumentException() - throws Exception { - + public void issuerWhenRequestedIssuerIsUnresponsiveThenThrowsIllegalArgumentException() throws Exception { this.server.shutdown(); - - assertThatCode(() -> ReactiveJwtDecoders.fromOidcIssuerLocation("https://issuer")) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ReactiveJwtDecoders.fromOidcIssuerLocation("https://issuer")); + // @formatter:on } @Test public void issuerWhenOidcFallbackRequestedIssuerIsUnresponsiveThenThrowsIllegalArgumentException() throws Exception { - this.server.shutdown(); - assertThatCode(() -> ReactiveJwtDecoders.fromIssuerLocation("https://issuer")) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> ReactiveJwtDecoders.fromIssuerLocation("https://issuer")); + // @formatter:on } private void prepareConfigurationResponse() { @@ -281,9 +312,13 @@ public class ReactiveJwtDecodersTests { Dispatcher dispatcher = new Dispatcher() { @Override public MockResponse dispatch(RecordedRequest request) { - return Optional.of(request).map(RecordedRequest::getRequestUrl).map(HttpUrl::toString) + // @formatter:off + return Optional.of(request) + .map(RecordedRequest::getRequestUrl) + .map(HttpUrl::toString) .map(responses::get) .orElse(new MockResponse().setResponseCode(404)); + // @formatter:on } }; this.server.setDispatcher(dispatcher); @@ -295,14 +330,20 @@ public class ReactiveJwtDecodersTests { private String oidc() { URI uri = URI.create(this.issuer); + // @formatter:off return UriComponentsBuilder.fromUri(uri) - .replacePath(uri.getPath() + OIDC_METADATA_PATH).toUriString(); + .replacePath(uri.getPath() + OIDC_METADATA_PATH) + .toUriString(); + // @formatter:on } private String oauth() { URI uri = URI.create(this.issuer); + // @formatter:off return UriComponentsBuilder.fromUri(uri) - .replacePath(OAUTH_METADATA_PATH + uri.getPath()).toUriString(); + .replacePath(OAUTH_METADATA_PATH + uri.getPath()) + .toUriString(); + // @formatter:on } private String jwks() { @@ -310,16 +351,19 @@ public class ReactiveJwtDecodersTests { } private MockResponse response(String body) { - return new MockResponse() - .setBody(body) + // @formatter:off + return new MockResponse().setBody(body) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + // @formatter:on } public String buildResponseWithMissingJwksUri() throws JsonMappingException, JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); Map response = mapper.readValue(DEFAULT_RESPONSE_TEMPLATE, - new TypeReference>(){}); + new TypeReference>() { + }); response.remove("jwks_uri"); return mapper.writeValueAsString(response); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java index d5551e1d68..7fecc4fc5e 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java @@ -16,6 +16,9 @@ package org.springframework.security.oauth2.jwt; +import java.util.Collections; +import java.util.List; + import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; import com.nimbusds.jose.jwk.JWKSelector; @@ -29,12 +32,9 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import java.util.Collections; -import java.util.List; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -42,6 +42,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class ReactiveRemoteJWKSourceTests { + @Mock private JWKMatcher matcher; @@ -51,6 +52,7 @@ public class ReactiveRemoteJWKSourceTests { private MockWebServer server; + // @formatter:off private String keys = "{\n" + " \"keys\": [\n" + " {\n" @@ -71,8 +73,9 @@ public class ReactiveRemoteJWKSourceTests { + " }\n" + " ]\n" + "}\n"; + // @formatter:on - + // @formatter:off private String keys2 = "{\n" + " \"keys\": [\n" + " {\n" @@ -85,30 +88,27 @@ public class ReactiveRemoteJWKSourceTests { + " }\n" + " ]\n" + "}\n"; + // @formatter:on @Before public void setup() { this.server = new MockWebServer(); this.source = new ReactiveRemoteJWKSource(this.server.url("/").toString()); - this.server.enqueue(new MockResponse().setBody(this.keys)); this.selector = new JWKSelector(this.matcher); } @Test public void getWhenMultipleRequestThenCached() { - when(this.matcher.matches(any())).thenReturn(true); - + given(this.matcher.matches(any())).willReturn(true); this.source.get(this.selector).block(); this.source.get(this.selector).block(); - assertThat(this.server.getRequestCount()).isEqualTo(1); } @Test public void getWhenMatchThenCreatesKeys() { - when(this.matcher.matches(any())).thenReturn(true); - + given(this.matcher.matches(any())).willReturn(true); List keys = this.source.get(this.selector).block(); assertThat(keys).hasSize(2); JWK key1 = keys.get(0); @@ -116,7 +116,6 @@ public class ReactiveRemoteJWKSourceTests { assertThat(key1.getAlgorithm().getName()).isEqualTo("RS256"); assertThat(key1.getKeyType()).isEqualTo(KeyType.RSA); assertThat(key1.getKeyUse()).isEqualTo(KeyUse.SIGNATURE); - JWK key2 = keys.get(1); assertThat(key2.getKeyID()).isEqualTo("7ddf54d3032d1f0d48c3618892ca74c1ac30ad77"); assertThat(key2.getAlgorithm().getName()).isEqualTo("RS256"); @@ -126,20 +125,17 @@ public class ReactiveRemoteJWKSourceTests { @Test public void getWhenNoMatchAndNoKeyIdThenEmpty() { - when(this.matcher.matches(any())).thenReturn(false); - when(this.matcher.getKeyIDs()).thenReturn(Collections.emptySet()); - + given(this.matcher.matches(any())).willReturn(false); + given(this.matcher.getKeyIDs()).willReturn(Collections.emptySet()); assertThat(this.source.get(this.selector).block()).isEmpty(); } @Test public void getWhenNoMatchAndKeyIdNotMatchThenRefreshAndFoundThenFound() { this.server.enqueue(new MockResponse().setBody(this.keys2)); - when(this.matcher.matches(any())).thenReturn(false, false, true); - when(this.matcher.getKeyIDs()).thenReturn(Collections.singleton("rotated")); - + given(this.matcher.matches(any())).willReturn(false, false, true); + given(this.matcher.getKeyIDs()).willReturn(Collections.singleton("rotated")); List keys = this.source.get(this.selector).block(); - assertThat(keys).hasSize(1); assertThat(keys.get(0).getKeyID()).isEqualTo("rotated"); } @@ -147,19 +143,17 @@ public class ReactiveRemoteJWKSourceTests { @Test public void getWhenNoMatchAndKeyIdNotMatchThenRefreshAndNotFoundThenEmpty() { this.server.enqueue(new MockResponse().setBody(this.keys2)); - when(this.matcher.matches(any())).thenReturn(false, false, false); - when(this.matcher.getKeyIDs()).thenReturn(Collections.singleton("rotated")); - + given(this.matcher.matches(any())).willReturn(false, false, false); + given(this.matcher.getKeyIDs()).willReturn(Collections.singleton("rotated")); List keys = this.source.get(this.selector).block(); - assertThat(keys).isEmpty(); } @Test public void getWhenNoMatchAndKeyIdMatchThenEmpty() { - when(this.matcher.matches(any())).thenReturn(false); - when(this.matcher.getKeyIDs()).thenReturn(Collections.singleton("7ddf54d3032d1f0d48c3618892ca74c1ac30ad77")); - + given(this.matcher.matches(any())).willReturn(false); + given(this.matcher.getKeyIDs()).willReturn(Collections.singleton("7ddf54d3032d1f0d48c3618892ca74c1ac30ad77")); assertThat(this.source.get(this.selector).block()).isEmpty(); } + } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwts.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwts.java index b253a79616..f710d485e5 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwts.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwts.java @@ -19,8 +19,13 @@ package org.springframework.security.oauth2.jwt; import java.time.Instant; import java.util.Arrays; -public class TestJwts { +public final class TestJwts { + + private TestJwts() { + } + public static Jwt.Builder jwt() { + // @formatter:off return Jwt.withTokenValue("token") .header("alg", "none") .audience(Arrays.asList("https://audience.example.org")) @@ -30,11 +35,15 @@ public class TestJwts { .jti("jti") .notBefore(Instant.MIN) .subject("mock-test-subject"); + // @formatter:on } public static Jwt user() { + // @formatter:off return jwt() .claim("sub", "mock-test-subject") .build(); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationToken.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationToken.java index 6f12be55db..940ee7b676 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationToken.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationToken.java @@ -27,53 +27,50 @@ import org.springframework.util.Assert; /** * An {@link Authentication} that contains a - * Bearer Token. + * Bearer + * Token. * - * Used by {@link BearerTokenAuthenticationFilter} to prepare an authentication attempt and supported - * by {@link JwtAuthenticationProvider}. + * Used by {@link BearerTokenAuthenticationFilter} to prepare an authentication attempt + * and supported by {@link JwtAuthenticationProvider}. * * @author Josh Cummings * @since 5.1 */ public class BearerTokenAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - private String token; + private final String token; /** * Create a {@code BearerTokenAuthenticationToken} using the provided parameter(s) - * * @param token - the bearer token */ public BearerTokenAuthenticationToken(String token) { super(Collections.emptyList()); - Assert.hasText(token, "token cannot be empty"); - this.token = token; } /** - * Get the Bearer Token - * @return the token that proves the caller's authority to perform the {@link javax.servlet.http.HttpServletRequest} + * Get the + * Bearer + * Token + * @return the token that proves the caller's authority to perform the + * {@link javax.servlet.http.HttpServletRequest} */ public String getToken() { return this.token; } - /** - * {@inheritDoc} - */ @Override public Object getCredentials() { return this.getToken(); } - /** - * {@inheritDoc} - */ @Override public Object getPrincipal() { return this.getToken(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenError.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenError.java index 26f0db3e78..e3641dafe1 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenError.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenError.java @@ -21,14 +21,16 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.util.Assert; /** - * A representation of a Bearer Token Error. + * A representation of a + * Bearer Token + * Error. * * @author Vedran Pavic * @author Josh Cummings * @since 5.1 * @see BearerTokenErrorCodes - * @see RFC 6750 Section 3: The WWW-Authenticate - * Response Header Field + * @see RFC 6750 + * Section 3: The WWW-Authenticate Response Header Field */ public final class BearerTokenError extends OAuth2Error { @@ -38,7 +40,6 @@ public final class BearerTokenError extends OAuth2Error { /** * Create a {@code BearerTokenError} using the provided parameters - * * @param errorCode the error code * @param httpStatus the HTTP status */ @@ -48,26 +49,23 @@ public final class BearerTokenError extends OAuth2Error { /** * Create a {@code BearerTokenError} using the provided parameters - * * @param errorCode the error code * @param httpStatus the HTTP status * @param description the description * @param errorUri the URI * @param scope the scope */ - public BearerTokenError(String errorCode, HttpStatus httpStatus, String description, String errorUri, String scope) { + public BearerTokenError(String errorCode, HttpStatus httpStatus, String description, String errorUri, + String scope) { super(errorCode, description, errorUri); Assert.notNull(httpStatus, "httpStatus cannot be null"); - Assert.isTrue(isDescriptionValid(description), "description contains invalid ASCII characters, it must conform to RFC 6750"); Assert.isTrue(isErrorCodeValid(errorCode), "errorCode contains invalid ASCII characters, it must conform to RFC 6750"); Assert.isTrue(isErrorUriValid(errorUri), "errorUri contains invalid ASCII characters, it must conform to RFC 6750"); - Assert.isTrue(isScopeValid(scope), - "scope contains invalid ASCII characters, it must conform to RFC 6750"); - + Assert.isTrue(isScopeValid(scope), "scope contains invalid ASCII characters, it must conform to RFC 6750"); this.httpStatus = httpStatus; this.scope = scope; } @@ -89,37 +87,39 @@ public final class BearerTokenError extends OAuth2Error { } private static boolean isDescriptionValid(String description) { - return description == null || - description.chars().allMatch(c -> - withinTheRangeOf(c, 0x20, 0x21) || - withinTheRangeOf(c, 0x23, 0x5B) || - withinTheRangeOf(c, 0x5D, 0x7E)); + // @formatter:off + return description == null || description.chars().allMatch((c) -> + withinTheRangeOf(c, 0x20, 0x21) || + withinTheRangeOf(c, 0x23, 0x5B) || + withinTheRangeOf(c, 0x5D, 0x7E)); + // @formatter:on } private static boolean isErrorCodeValid(String errorCode) { - return errorCode.chars().allMatch(c -> - withinTheRangeOf(c, 0x20, 0x21) || - withinTheRangeOf(c, 0x23, 0x5B) || - withinTheRangeOf(c, 0x5D, 0x7E)); + // @formatter:off + return errorCode.chars().allMatch((c) -> + withinTheRangeOf(c, 0x20, 0x21) || + withinTheRangeOf(c, 0x23, 0x5B) || + withinTheRangeOf(c, 0x5D, 0x7E)); + // @formatter:on } private static boolean isErrorUriValid(String errorUri) { - return errorUri == null || - errorUri.chars().allMatch(c -> - c == 0x21 || - withinTheRangeOf(c, 0x23, 0x5B) || - withinTheRangeOf(c, 0x5D, 0x7E)); + return errorUri == null || errorUri.chars() + .allMatch((c) -> c == 0x21 || withinTheRangeOf(c, 0x23, 0x5B) || withinTheRangeOf(c, 0x5D, 0x7E)); } private static boolean isScopeValid(String scope) { - return scope == null || - scope.chars().allMatch(c -> - withinTheRangeOf(c, 0x20, 0x21) || - withinTheRangeOf(c, 0x23, 0x5B) || - withinTheRangeOf(c, 0x5D, 0x7E)); + // @formatter:off + return scope == null || scope.chars().allMatch((c) -> + withinTheRangeOf(c, 0x20, 0x21) || + withinTheRangeOf(c, 0x23, 0x5B) || + withinTheRangeOf(c, 0x5D, 0x7E)); + // @formatter:on } private static boolean withinTheRangeOf(int c, int min, int max) { return c >= min && c <= max; } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorCodes.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorCodes.java index 0e74b2d6f1..84457dc674 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorCodes.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorCodes.java @@ -17,29 +17,32 @@ package org.springframework.security.oauth2.server.resource; /** - * Standard error codes defined by the OAuth 2.0 Authorization Framework: Bearer Token Usage. + * Standard error codes defined by the OAuth 2.0 Authorization Framework: Bearer Token + * Usage. * * @author Vedran Pavic * @since 5.1 - * @see RFC 6750 Section 3.1: Error Codes + * @see RFC 6750 + * Section 3.1: Error Codes */ public interface BearerTokenErrorCodes { /** - * {@code invalid_request} - The request is missing a required parameter, includes an unsupported parameter or - * parameter value, repeats the same parameter, uses more than one method for including an access token, or is - * otherwise malformed. + * {@code invalid_request} - The request is missing a required parameter, includes an + * unsupported parameter or parameter value, repeats the same parameter, uses more + * than one method for including an access token, or is otherwise malformed. */ String INVALID_REQUEST = "invalid_request"; /** - * {@code invalid_token} - The access token provided is expired, revoked, malformed, or invalid for other - * reasons. + * {@code invalid_token} - The access token provided is expired, revoked, malformed, + * or invalid for other reasons. */ String INVALID_TOKEN = "invalid_token"; /** - * {@code insufficient_scope} - The request requires higher privileges than provided by the access token. + * {@code insufficient_scope} - The request requires higher privileges than provided + * by the access token. */ String INSUFFICIENT_SCOPE = "insufficient_scope"; diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrors.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrors.java index 7bbd5387d4..eaa90b8692 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrors.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/BearerTokenErrors.java @@ -18,78 +18,76 @@ package org.springframework.security.oauth2.server.resource; import org.springframework.http.HttpStatus; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes.INSUFFICIENT_SCOPE; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes.INVALID_REQUEST; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes.INVALID_TOKEN; - /** - * A factory for creating {@link BearerTokenError} instances that correspond to the registered - * Bearer Token Error Codes. + * A factory for creating {@link BearerTokenError} instances that correspond to the + * registered Bearer Token Error + * Codes. * * @author Josh Cummings * @since 5.3 */ public final class BearerTokenErrors { + private static final BearerTokenError DEFAULT_INVALID_REQUEST = invalidRequest("Invalid request"); + private static final BearerTokenError DEFAULT_INVALID_TOKEN = invalidToken("Invalid token"); + private static final BearerTokenError DEFAULT_INSUFFICIENT_SCOPE = insufficientScope("Insufficient scope", null); private static final String DEFAULT_URI = "https://tools.ietf.org/html/rfc6750#section-3.1"; + private BearerTokenErrors() { + } + /** * Create a {@link BearerTokenError} caused by an invalid request - * * @param message a description of the error * @return a {@link BearerTokenError} */ public static BearerTokenError invalidRequest(String message) { try { - return new BearerTokenError(INVALID_REQUEST, - HttpStatus.BAD_REQUEST, - message, + return new BearerTokenError(BearerTokenErrorCodes.INVALID_REQUEST, HttpStatus.BAD_REQUEST, message, DEFAULT_URI); - } catch (IllegalArgumentException malformed) { - // some third-party library error messages are not suitable for RFC 6750's error message charset + } + catch (IllegalArgumentException ex) { + // some third-party library error messages are not suitable for RFC 6750's + // error message charset return DEFAULT_INVALID_REQUEST; } } /** * Create a {@link BearerTokenError} caused by an invalid token - * * @param message a description of the error * @return a {@link BearerTokenError} */ public static BearerTokenError invalidToken(String message) { try { - return new BearerTokenError(INVALID_TOKEN, - HttpStatus.UNAUTHORIZED, - message, + return new BearerTokenError(BearerTokenErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED, message, DEFAULT_URI); - } catch (IllegalArgumentException malformed) { - // some third-party library error messages are not suitable for RFC 6750's error message charset + } + catch (IllegalArgumentException ex) { + // some third-party library error messages are not suitable for RFC 6750's + // error message charset return DEFAULT_INVALID_TOKEN; } } /** * Create a {@link BearerTokenError} caused by an invalid token - * * @param scope the scope attribute to use in the error * @return a {@link BearerTokenError} */ public static BearerTokenError insufficientScope(String message, String scope) { try { - return new BearerTokenError(INSUFFICIENT_SCOPE, - HttpStatus.FORBIDDEN, - message, - DEFAULT_URI, - scope); - } catch (IllegalArgumentException malformed) { - // some third-party library error messages are not suitable for RFC 6750's error message charset + return new BearerTokenError(BearerTokenErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN, message, + DEFAULT_URI, scope); + } + catch (IllegalArgumentException ex) { + // some third-party library error messages are not suitable for RFC 6750's + // error message charset return DEFAULT_INSUFFICIENT_SCOPE; } } - private BearerTokenErrors() {} } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/InvalidBearerTokenException.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/InvalidBearerTokenException.java index 47ac1c72be..0ba62813da 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/InvalidBearerTokenException.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/InvalidBearerTokenException.java @@ -18,8 +18,6 @@ package org.springframework.security.oauth2.server.resource; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrors.invalidToken; - /** * An {@link OAuth2AuthenticationException} that indicates an invalid bearer token. * @@ -32,26 +30,27 @@ public class InvalidBearerTokenException extends OAuth2AuthenticationException { * Construct an instance of {@link InvalidBearerTokenException} given the provided * description. * - * The description will be wrapped into an {@link org.springframework.security.oauth2.core.OAuth2Error} - * instance as the {@code error_description}. - * + * The description will be wrapped into an + * {@link org.springframework.security.oauth2.core.OAuth2Error} instance as the + * {@code error_description}. * @param description the description */ public InvalidBearerTokenException(String description) { - super(invalidToken(description)); + super(BearerTokenErrors.invalidToken(description)); } /** * Construct an instance of {@link InvalidBearerTokenException} given the provided * description and cause * - * The description will be wrapped into an {@link org.springframework.security.oauth2.core.OAuth2Error} - * instance as the {@code error_description}. - * + * The description will be wrapped into an + * {@link org.springframework.security.oauth2.core.OAuth2Error} instance as the + * {@code error_description}. * @param description the description * @param cause the causing exception */ public InvalidBearerTokenException(String description, Throwable cause) { - super(invalidToken(description), cause); + super(BearerTokenErrors.invalidToken(description), cause); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/AbstractOAuth2TokenAuthenticationToken.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/AbstractOAuth2TokenAuthenticationToken.java index d6d63b6b92..d221136802 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/AbstractOAuth2TokenAuthenticationToken.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/AbstractOAuth2TokenAuthenticationToken.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.authentication; import java.util.Collection; @@ -28,25 +29,31 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; /** - * Base class for {@link AbstractAuthenticationToken} implementations - * that expose common attributes between different OAuth 2.0 Access Token Formats. + * Base class for {@link AbstractAuthenticationToken} implementations that expose common + * attributes between different OAuth 2.0 Access Token Formats. * *

        * For example, a {@link Jwt} could expose its {@link Jwt#getClaims() claims} via * {@link #getTokenAttributes()} or an "Introspected" OAuth 2.0 Access Token - * could expose the attributes of the Introspection Response via {@link #getTokenAttributes()}. + * could expose the attributes of the Introspection Response via + * {@link #getTokenAttributes()}. * * @author Joe Grandja * @since 5.1 * @see OAuth2AccessToken * @see Jwt - * @see 2.2 Introspection Response + * @see 2.2 + * Introspection Response */ -public abstract class AbstractOAuth2TokenAuthenticationToken extends AbstractAuthenticationToken { +public abstract class AbstractOAuth2TokenAuthenticationToken + extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private Object principal; + private Object credentials; + private T token; /** @@ -59,20 +66,14 @@ public abstract class AbstractOAuth2TokenAuthenticationToken authorities) { + protected AbstractOAuth2TokenAuthenticationToken(T token, Collection authorities) { this(token, token, token, authorities); } - protected AbstractOAuth2TokenAuthenticationToken( - T token, - Object principal, - Object credentials, + protected AbstractOAuth2TokenAuthenticationToken(T token, Object principal, Object credentials, Collection authorities) { super(authorities); @@ -83,17 +84,11 @@ public abstract class AbstractOAuth2TokenAuthenticationToken getTokenAttributes(); + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthentication.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthentication.java index ca0f1ee334..1e70e86b89 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthentication.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthentication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.authentication; import java.util.Collection; @@ -28,8 +29,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; import org.springframework.util.Assert; /** - * An {@link org.springframework.security.core.Authentication} token that represents a successful authentication as - * obtained through a bearer token. + * An {@link org.springframework.security.core.Authentication} token that represents a + * successful authentication as obtained through a bearer token. * * @author Josh Cummings * @since 5.2 @@ -39,29 +40,26 @@ public class BearerTokenAuthentication extends AbstractOAuth2TokenAuthentication private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - private Map attributes; + private final Map attributes; /** * Constructs a {@link BearerTokenAuthentication} with the provided arguments - * * @param principal The OAuth 2.0 attributes * @param credentials The verified token * @param authorities The authorities associated with the given token */ public BearerTokenAuthentication(OAuth2AuthenticatedPrincipal principal, OAuth2AccessToken credentials, Collection authorities) { - super(credentials, principal, credentials, authorities); - Assert.isTrue(credentials.getTokenType() == OAuth2AccessToken.TokenType.BEARER, "credentials must be a bearer token"); + Assert.isTrue(credentials.getTokenType() == OAuth2AccessToken.TokenType.BEARER, + "credentials must be a bearer token"); this.attributes = Collections.unmodifiableMap(new LinkedHashMap<>(principal.getAttributes())); setAuthenticated(true); } - /** - * {@inheritDoc} - */ @Override public Map getTokenAttributes() { return this.attributes; } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java index 4b8171e8cb..1962e03aac 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java @@ -32,8 +32,8 @@ import org.springframework.util.Assert; * @since 5.1 */ public class JwtAuthenticationConverter implements Converter { - private Converter> jwtGrantedAuthoritiesConverter - = new JwtGrantedAuthoritiesConverter(); + + private Converter> jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); private String principalClaimName; @@ -43,14 +43,13 @@ public class JwtAuthenticationConverter implements Converter> jwtGrantedAuthoritiesConverter) { + public void setJwtGrantedAuthoritiesConverter( + Converter> jwtGrantedAuthoritiesConverter) { Assert.notNull(jwtGrantedAuthoritiesConverter, "jwtGrantedAuthoritiesConverter cannot be null"); this.jwtGrantedAuthoritiesConverter = jwtGrantedAuthoritiesConverter; } /** - * Sets the principal claim name. - * Defaults to {@link JwtClaimNames#SUB}. - * + * Sets the principal claim name. Defaults to {@link JwtClaimNames#SUB}. * @param principalClaimName The principal claim name * @since 5.4 */ @@ -86,4 +83,5 @@ public class JwtAuthenticationConverter implements ConverterBearer Tokens - * for protecting OAuth 2.0 Resource Servers. + * Bearer + * Tokens for protecting OAuth 2.0 Resource Servers. *

        *

        - * This {@link AuthenticationProvider} is responsible for decoding and verifying a {@link Jwt}-encoded access token, - * returning its claims set as part of the {@link Authentication} statement. + * This {@link AuthenticationProvider} is responsible for decoding and verifying a + * {@link Jwt}-encoded access token, returning its claims set as part of the + * {@link Authentication} statement. *

        *

        - * Scopes are translated into {@link GrantedAuthority}s according to the following algorithm: + * Scopes are translated into {@link GrantedAuthority}s according to the following + * algorithm: * - * 1. If there is a "scope" or "scp" attribute, then - * if a {@link String}, then split by spaces and return, or - * if a {@link Collection}, then simply return - * 2. Take the resulting {@link Collection} of {@link String}s and prepend the "SCOPE_" keyword, adding - * as {@link GrantedAuthority}s. + * 1. If there is a "scope" or "scp" attribute, then if a {@link String}, then split by + * spaces and return, or if a {@link Collection}, then simply return 2. Take the resulting + * {@link Collection} of {@link String}s and prepend the "SCOPE_" keyword, adding as + * {@link GrantedAuthority}s. * * @author Josh Cummings * @author Joe Grandja @@ -57,6 +59,7 @@ import org.springframework.util.Assert; * @see JwtDecoder */ public final class JwtAuthenticationProvider implements AuthenticationProvider { + private final JwtDecoder jwtDecoder; private Converter jwtAuthenticationConverter = new JwtAuthenticationConverter(); @@ -68,35 +71,33 @@ public final class JwtAuthenticationProvider implements AuthenticationProvider { /** * Decode and validate the - * Bearer Token. - * + * Bearer + * Token. * @param authentication the authentication request object. - * * @return A successful authentication * @throws AuthenticationException if authentication failed for some reason */ @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { BearerTokenAuthenticationToken bearer = (BearerTokenAuthenticationToken) authentication; - - Jwt jwt; - try { - jwt = this.jwtDecoder.decode(bearer.getToken()); - } catch (BadJwtException failed) { - throw new InvalidBearerTokenException(failed.getMessage(), failed); - } catch (JwtException failed) { - throw new AuthenticationServiceException(failed.getMessage(), failed); - } - + Jwt jwt = getJwt(bearer); AbstractAuthenticationToken token = this.jwtAuthenticationConverter.convert(jwt); token.setDetails(bearer.getDetails()); - return token; } - /** - * {@inheritDoc} - */ + private Jwt getJwt(BearerTokenAuthenticationToken bearer) { + try { + return this.jwtDecoder.decode(bearer.getToken()); + } + catch (BadJwtException failed) { + throw new InvalidBearerTokenException(failed.getMessage(), failed); + } + catch (JwtException failed) { + throw new AuthenticationServiceException(failed.getMessage(), failed); + } + } + @Override public boolean supports(Class authentication) { return BearerTokenAuthenticationToken.class.isAssignableFrom(authentication); @@ -104,8 +105,8 @@ public final class JwtAuthenticationProvider implements AuthenticationProvider { public void setJwtAuthenticationConverter( Converter jwtAuthenticationConverter) { - Assert.notNull(jwtAuthenticationConverter, "jwtAuthenticationConverter cannot be null"); this.jwtAuthenticationConverter = jwtAuthenticationConverter; } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java index 44edd9bb55..e389e5e93c 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.authentication; import java.util.Collection; @@ -24,8 +25,8 @@ import org.springframework.security.core.Transient; import org.springframework.security.oauth2.jwt.Jwt; /** - * An implementation of an {@link AbstractOAuth2TokenAuthenticationToken} - * representing a {@link Jwt} {@code Authentication}. + * An implementation of an {@link AbstractOAuth2TokenAuthenticationToken} representing a + * {@link Jwt} {@code Authentication}. * * @author Joe Grandja * @since 5.1 @@ -34,13 +35,13 @@ import org.springframework.security.oauth2.jwt.Jwt; */ @Transient public class JwtAuthenticationToken extends AbstractOAuth2TokenAuthenticationToken { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private final String name; /** * Constructs a {@code JwtAuthenticationToken} using the provided parameters. - * * @param jwt the JWT */ public JwtAuthenticationToken(Jwt jwt) { @@ -50,7 +51,6 @@ public class JwtAuthenticationToken extends AbstractOAuth2TokenAuthenticationTok /** * Constructs a {@code JwtAuthenticationToken} using the provided parameters. - * * @param jwt the JWT * @param authorities the authorities assigned to the JWT */ @@ -62,7 +62,6 @@ public class JwtAuthenticationToken extends AbstractOAuth2TokenAuthenticationTok /** * Constructs a {@code JwtAuthenticationToken} using the provided parameters. - * * @param jwt the JWT * @param authorities the authorities assigned to the JWT * @param name the principal name @@ -73,9 +72,6 @@ public class JwtAuthenticationToken extends AbstractOAuth2TokenAuthenticationTok this.name = name; } - /** - * {@inheritDoc} - */ @Override public Map getTokenAttributes() { return this.getToken().getClaims(); @@ -88,4 +84,5 @@ public class JwtAuthenticationToken extends AbstractOAuth2TokenAuthenticationTok public String getName() { return this.name; } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java index 7cd7fd518a..f548c8ac2c 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java @@ -28,31 +28,34 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; import org.springframework.security.oauth2.jwt.Jwt; /** - * A {@link Converter} that takes a {@link Jwt} and converts it into a {@link BearerTokenAuthentication}. + * A {@link Converter} that takes a {@link Jwt} and converts it into a + * {@link BearerTokenAuthentication}. * - * In the process, it will attempt to parse either the "scope" or "scp" attribute, whichever it finds first. + * In the process, it will attempt to parse either the "scope" or "scp" attribute, + * whichever it finds first. * - * It's not intended that this implementation be configured since it is simply an adapter. If you are using, - * for example, a custom {@link JwtGrantedAuthoritiesConverter}, then it's recommended that you simply - * create your own {@link Converter} that delegates to your custom {@link JwtGrantedAuthoritiesConverter} - * and instantiates the appropriate {@link BearerTokenAuthentication}. + * It's not intended that this implementation be configured since it is simply an adapter. + * If you are using, for example, a custom {@link JwtGrantedAuthoritiesConverter}, then + * it's recommended that you simply create your own {@link Converter} that delegates to + * your custom {@link JwtGrantedAuthoritiesConverter} and instantiates the appropriate + * {@link BearerTokenAuthentication}. * * @author Josh Cummings * @since 5.2 */ public final class JwtBearerTokenAuthenticationConverter implements Converter { + private final JwtAuthenticationConverter jwtAuthenticationConverter = new JwtAuthenticationConverter(); @Override public AbstractAuthenticationToken convert(Jwt jwt) { - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt()); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), + jwt.getIssuedAt(), jwt.getExpiresAt()); Map attributes = jwt.getClaims(); - AbstractAuthenticationToken token = this.jwtAuthenticationConverter.convert(jwt); Collection authorities = token.getAuthorities(); - OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(attributes, authorities); return new BearerTokenAuthentication(principal, accessToken, authorities); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverter.java index 2e9eaec2e9..d126984a21 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverter.java @@ -36,10 +36,10 @@ import org.springframework.util.StringUtils; * @since 5.2 */ public final class JwtGrantedAuthoritiesConverter implements Converter> { + private static final String DEFAULT_AUTHORITY_PREFIX = "SCOPE_"; - private static final Collection WELL_KNOWN_AUTHORITIES_CLAIM_NAMES = - Arrays.asList("scope", "scp"); + private static final Collection WELL_KNOWN_AUTHORITIES_CLAIM_NAMES = Arrays.asList("scope", "scp"); private String authorityPrefix = DEFAULT_AUTHORITY_PREFIX; @@ -47,7 +47,6 @@ public final class JwtGrantedAuthoritiesConverter implements Converter getAuthorities(Jwt jwt) { String claimName = getAuthoritiesClaimName(jwt); - if (claimName == null) { return Collections.emptyList(); } - Object authorities = jwt.getClaim(claimName); if (authorities instanceof String) { if (StringUtils.hasText((String) authorities)) { return Arrays.asList(((String) authorities).split(" ")); - } else { - return Collections.emptyList(); } - } else if (authorities instanceof Collection) { + return Collections.emptyList(); + } + if (authorities instanceof Collection) { return (Collection) authorities; } - return Collections.emptyList(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java index 97cc3d3fe9..0542792f01 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolver.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Predicate; + import javax.servlet.http.HttpServletRequest; import com.nimbusds.jwt.JWTParser; @@ -39,27 +40,33 @@ import org.springframework.security.oauth2.server.resource.web.DefaultBearerToke import org.springframework.util.Assert; /** - * An implementation of {@link AuthenticationManagerResolver} that resolves a JWT-based {@link AuthenticationManager} - * based on the Issuer in a - * signed JWT (JWS). + * An implementation of {@link AuthenticationManagerResolver} that resolves a JWT-based + * {@link AuthenticationManager} based on the Issuer in + * a signed JWT (JWS). * - * To use, this class must be able to determine whether or not the `iss` claim is trusted. Recall that - * anyone can stand up an authorization server and issue valid tokens to a resource server. The simplest way - * to achieve this is to supply a list of trusted issuers in the constructor. + * To use, this class must be able to determine whether or not the `iss` claim is trusted. + * Recall that anyone can stand up an authorization server and issue valid tokens to a + * resource server. The simplest way to achieve this is to supply a list of trusted + * issuers in the constructor. * - * This class derives the Issuer from the `iss` claim found in the {@link HttpServletRequest}'s - * Bearer Token. + * This class derives the Issuer from the `iss` claim found in the + * {@link HttpServletRequest}'s + * Bearer + * Token. * * @author Josh Cummings * @since 5.3 */ public final class JwtIssuerAuthenticationManagerResolver implements AuthenticationManagerResolver { + private final AuthenticationManagerResolver issuerAuthenticationManagerResolver; + private final Converter issuerConverter = new JwtClaimIssuerConverter(); /** - * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided parameters - * + * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided + * parameters * @param trustedIssuers a list of trusted issuers */ public JwtIssuerAuthenticationManagerResolver(String... trustedIssuers) { @@ -67,22 +74,23 @@ public final class JwtIssuerAuthenticationManagerResolver implements Authenticat } /** - * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided parameters - * + * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided + * parameters * @param trustedIssuers a list of trusted issuers */ public JwtIssuerAuthenticationManagerResolver(Collection trustedIssuers) { Assert.notEmpty(trustedIssuers, "trustedIssuers cannot be empty"); - this.issuerAuthenticationManagerResolver = - new TrustedIssuerJwtAuthenticationManagerResolver - (Collections.unmodifiableCollection(trustedIssuers)::contains); + this.issuerAuthenticationManagerResolver = new TrustedIssuerJwtAuthenticationManagerResolver( + Collections.unmodifiableCollection(trustedIssuers)::contains); } /** - * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided parameters + * Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided + * parameters * - * Note that the {@link AuthenticationManagerResolver} provided in this constructor will need to - * verify that the issuer is trusted. This should be done via an allowlist. + * Note that the {@link AuthenticationManagerResolver} provided in this constructor + * will need to verify that the issuer is trusted. This should be done via an + * allowlist. * * One way to achieve this is with a {@link Map} where the keys are the known issuers: *

        @@ -94,19 +102,20 @@ public final class JwtIssuerAuthenticationManagerResolver implements Authenticat
         	 * 
        * * The keys in the {@link Map} are the allowed issuers. - * - * @param issuerAuthenticationManagerResolver a strategy for resolving the {@link AuthenticationManager} by the issuer + * @param issuerAuthenticationManagerResolver a strategy for resolving the + * {@link AuthenticationManager} by the issuer */ - public JwtIssuerAuthenticationManagerResolver(AuthenticationManagerResolver issuerAuthenticationManagerResolver) { + public JwtIssuerAuthenticationManagerResolver( + AuthenticationManagerResolver issuerAuthenticationManagerResolver) { Assert.notNull(issuerAuthenticationManagerResolver, "issuerAuthenticationManagerResolver cannot be null"); this.issuerAuthenticationManagerResolver = issuerAuthenticationManagerResolver; } /** - * Return an {@link AuthenticationManager} based off of the `iss` claim found in the request's bearer token - * - * @throws OAuth2AuthenticationException if the bearer token is malformed or an {@link AuthenticationManager} - * can't be derived from the issuer + * Return an {@link AuthenticationManager} based off of the `iss` claim found in the + * request's bearer token + * @throws OAuth2AuthenticationException if the bearer token is malformed or an + * {@link AuthenticationManager} can't be derived from the issuer */ @Override public AuthenticationManager resolve(HttpServletRequest request) { @@ -118,8 +127,7 @@ public final class JwtIssuerAuthenticationManagerResolver implements Authenticat return authenticationManager; } - private static class JwtClaimIssuerConverter - implements Converter { + private static class JwtClaimIssuerConverter implements Converter { private final BearerTokenResolver resolver = new DefaultBearerTokenResolver(); @@ -131,17 +139,20 @@ public final class JwtIssuerAuthenticationManagerResolver implements Authenticat if (issuer != null) { return issuer; } - } catch (Exception e) { - throw new InvalidBearerTokenException(e.getMessage(), e); + } + catch (Exception ex) { + throw new InvalidBearerTokenException(ex.getMessage(), ex); } throw new InvalidBearerTokenException("Missing issuer"); } + } private static class TrustedIssuerJwtAuthenticationManagerResolver implements AuthenticationManagerResolver { private final Map authenticationManagers = new ConcurrentHashMap<>(); + private final Predicate trustedIssuer; TrustedIssuerJwtAuthenticationManagerResolver(Predicate trustedIssuer) { @@ -151,12 +162,14 @@ public final class JwtIssuerAuthenticationManagerResolver implements Authenticat @Override public AuthenticationManager resolve(String issuer) { if (this.trustedIssuer.test(issuer)) { - return this.authenticationManagers.computeIfAbsent(issuer, k -> { + return this.authenticationManagers.computeIfAbsent(issuer, (k) -> { JwtDecoder jwtDecoder = JwtDecoders.fromIssuerLocation(issuer); return new JwtAuthenticationProvider(jwtDecoder)::authenticate; }); } return null; } + } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java index 0328d5ae3b..e73635e887 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolver.java @@ -41,17 +41,20 @@ import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; /** - * An implementation of {@link ReactiveAuthenticationManagerResolver} that resolves a JWT-based - * {@link ReactiveAuthenticationManager} based on the - * Issuer in a - * signed JWT (JWS). + * An implementation of {@link ReactiveAuthenticationManagerResolver} that resolves a + * JWT-based {@link ReactiveAuthenticationManager} based on the Issuer in + * a signed JWT (JWS). * - * To use, this class must be able to determine whether or not the `iss` claim is trusted. Recall that - * anyone can stand up an authorization server and issue valid tokens to a resource server. The simplest way - * to achieve this is to supply a list of trusted issuers in the constructor. + * To use, this class must be able to determine whether or not the `iss` claim is trusted. + * Recall that anyone can stand up an authorization server and issue valid tokens to a + * resource server. The simplest way to achieve this is to supply a list of trusted + * issuers in the constructor. * - * This class derives the Issuer from the `iss` claim found in the {@link ServerWebExchange}'s - * Bearer Token. + * This class derives the Issuer from the `iss` claim found in the + * {@link ServerWebExchange}'s + * Bearer + * Token. * * @author Josh Cummings * @author Roman Matiushchenko @@ -61,11 +64,12 @@ public final class JwtIssuerReactiveAuthenticationManagerResolver implements ReactiveAuthenticationManagerResolver { private final ReactiveAuthenticationManagerResolver issuerAuthenticationManagerResolver; + private final Converter> issuerConverter = new JwtClaimIssuerConverter(); /** - * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the provided parameters - * + * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the + * provided parameters * @param trustedIssuers a list of trusted issuers */ public JwtIssuerReactiveAuthenticationManagerResolver(String... trustedIssuers) { @@ -73,21 +77,23 @@ public final class JwtIssuerReactiveAuthenticationManagerResolver } /** - * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the provided parameters - * + * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the + * provided parameters * @param trustedIssuers a collection of trusted issuers */ public JwtIssuerReactiveAuthenticationManagerResolver(Collection trustedIssuers) { Assert.notEmpty(trustedIssuers, "trustedIssuers cannot be empty"); - this.issuerAuthenticationManagerResolver = - new TrustedIssuerJwtAuthenticationManagerResolver(new ArrayList<>(trustedIssuers)::contains); + this.issuerAuthenticationManagerResolver = new TrustedIssuerJwtAuthenticationManagerResolver( + new ArrayList<>(trustedIssuers)::contains); } /** - * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the provided parameters + * Construct a {@link JwtIssuerReactiveAuthenticationManagerResolver} using the + * provided parameters * - * Note that the {@link ReactiveAuthenticationManagerResolver} provided in this constructor will need to - * verify that the issuer is trusted. This should be done via an allowed list of issuers. + * Note that the {@link ReactiveAuthenticationManagerResolver} provided in this + * constructor will need to verify that the issuer is trusted. This should be done via + * an allowed list of issuers. * * One way to achieve this is with a {@link Map} where the keys are the known issuers: *
        @@ -95,65 +101,64 @@ public final class JwtIssuerReactiveAuthenticationManagerResolver
         	 *     authenticationManagers.put("https://issuerOne.example.org", managerOne);
         	 *     authenticationManagers.put("https://issuerTwo.example.org", managerTwo);
         	 *     JwtIssuerReactiveAuthenticationManagerResolver resolver = new JwtIssuerReactiveAuthenticationManagerResolver
        -	 *     	(issuer -> Mono.justOrEmpty(authenticationManagers.get(issuer));
        +	 *     	((issuer) -> Mono.justOrEmpty(authenticationManagers.get(issuer));
         	 * 
        * * The keys in the {@link Map} are the trusted issuers. - * - * @param issuerAuthenticationManagerResolver a strategy for resolving the {@link ReactiveAuthenticationManager} - * by the issuer + * @param issuerAuthenticationManagerResolver a strategy for resolving the + * {@link ReactiveAuthenticationManager} by the issuer */ - public JwtIssuerReactiveAuthenticationManagerResolver - (ReactiveAuthenticationManagerResolver issuerAuthenticationManagerResolver) { - + public JwtIssuerReactiveAuthenticationManagerResolver( + ReactiveAuthenticationManagerResolver issuerAuthenticationManagerResolver) { Assert.notNull(issuerAuthenticationManagerResolver, "issuerAuthenticationManagerResolver cannot be null"); this.issuerAuthenticationManagerResolver = issuerAuthenticationManagerResolver; } /** - * Return an {@link AuthenticationManager} based off of the `iss` claim found in the request's bearer token - * - * @throws OAuth2AuthenticationException if the bearer token is malformed or an {@link ReactiveAuthenticationManager} - * can't be derived from the issuer + * Return an {@link AuthenticationManager} based off of the `iss` claim found in the + * request's bearer token + * @throws OAuth2AuthenticationException if the bearer token is malformed or an + * {@link ReactiveAuthenticationManager} can't be derived from the issuer */ @Override public Mono resolve(ServerWebExchange exchange) { + // @formatter:off return this.issuerConverter.convert(exchange) - .flatMap(issuer -> - this.issuerAuthenticationManagerResolver.resolve(issuer).switchIfEmpty( - Mono.error(() -> new InvalidBearerTokenException("Invalid issuer " + issuer))) + .flatMap((issuer) -> this.issuerAuthenticationManagerResolver + .resolve(issuer) + .switchIfEmpty(Mono.error(() -> new InvalidBearerTokenException("Invalid issuer " + issuer))) ); + // @formatter:on } - private static class JwtClaimIssuerConverter - implements Converter> { + private static class JwtClaimIssuerConverter implements Converter> { - private final ServerBearerTokenAuthenticationConverter converter = - new ServerBearerTokenAuthenticationConverter(); + private final ServerBearerTokenAuthenticationConverter converter = new ServerBearerTokenAuthenticationConverter(); @Override public Mono convert(@NonNull ServerWebExchange exchange) { - return this.converter.convert(exchange).map(convertedToken -> { + return this.converter.convert(exchange).map((convertedToken) -> { BearerTokenAuthenticationToken token = (BearerTokenAuthenticationToken) convertedToken; try { String issuer = JWTParser.parse(token.getToken()).getJWTClaimsSet().getIssuer(); if (issuer == null) { throw new InvalidBearerTokenException("Missing issuer"); - } else { - return issuer; } - } catch (Exception e) { - throw new InvalidBearerTokenException(e.getMessage(), e); + return issuer; + } + catch (Exception ex) { + throw new InvalidBearerTokenException(ex.getMessage(), ex); } }); } + } private static class TrustedIssuerJwtAuthenticationManagerResolver implements ReactiveAuthenticationManagerResolver { - private final Map> authenticationManagers = - new ConcurrentHashMap<>(); + private final Map> authenticationManagers = new ConcurrentHashMap<>(); + private final Predicate trustedIssuer; TrustedIssuerJwtAuthenticationManagerResolver(Predicate trustedIssuer) { @@ -165,12 +170,15 @@ public final class JwtIssuerReactiveAuthenticationManagerResolver if (!this.trustedIssuer.test(issuer)) { return Mono.empty(); } - return this.authenticationManagers.computeIfAbsent(issuer, k -> - Mono.fromCallable(() -> - new JwtReactiveAuthenticationManager(ReactiveJwtDecoders.fromIssuerLocation(k)) - ) - .subscribeOn(Schedulers.boundedElastic()) - .cache()); + // @formatter:off + return this.authenticationManagers.computeIfAbsent(issuer, + (k) -> Mono.fromCallable(() -> new JwtReactiveAuthenticationManager(ReactiveJwtDecoders.fromIssuerLocation(k))) + .subscribeOn(Schedulers.boundedElastic()) + .cache() + ); + // @formatter:on } + } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java index ab877b2181..50c0470bb4 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java @@ -39,10 +39,11 @@ import org.springframework.util.Assert; * @since 5.1 */ public final class JwtReactiveAuthenticationManager implements ReactiveAuthenticationManager { + private final ReactiveJwtDecoder jwtDecoder; - private Converter> jwtAuthenticationConverter - = new ReactiveJwtAuthenticationConverterAdapter(new JwtAuthenticationConverter()); + private Converter> jwtAuthenticationConverter = new ReactiveJwtAuthenticationConverterAdapter( + new JwtAuthenticationConverter()); public JwtReactiveAuthenticationManager(ReactiveJwtDecoder jwtDecoder) { Assert.notNull(jwtDecoder, "jwtDecoder cannot be null"); @@ -51,33 +52,34 @@ public final class JwtReactiveAuthenticationManager implements ReactiveAuthentic @Override public Mono authenticate(Authentication authentication) { + // @formatter:off return Mono.justOrEmpty(authentication) - .filter(a -> a instanceof BearerTokenAuthenticationToken) + .filter((a) -> a instanceof BearerTokenAuthenticationToken) .cast(BearerTokenAuthenticationToken.class) .map(BearerTokenAuthenticationToken::getToken) .flatMap(this.jwtDecoder::decode) .flatMap(this.jwtAuthenticationConverter::convert) .cast(Authentication.class) .onErrorMap(JwtException.class, this::onError); + // @formatter:on } /** - * Use the given {@link Converter} for converting a {@link Jwt} into an {@link AbstractAuthenticationToken}. - * + * Use the given {@link Converter} for converting a {@link Jwt} into an + * {@link AbstractAuthenticationToken}. * @param jwtAuthenticationConverter the {@link Converter} to use */ public void setJwtAuthenticationConverter( Converter> jwtAuthenticationConverter) { - Assert.notNull(jwtAuthenticationConverter, "jwtAuthenticationConverter cannot be null"); this.jwtAuthenticationConverter = jwtAuthenticationConverter; } - private AuthenticationException onError(JwtException e) { - if (e instanceof BadJwtException) { - return new InvalidBearerTokenException(e.getMessage(), e); - } else { - return new AuthenticationServiceException(e.getMessage(), e); + private AuthenticationException onError(JwtException ex) { + if (ex instanceof BadJwtException) { + return new InvalidBearerTokenException(ex.getMessage(), ex); } + return new AuthenticationServiceException(ex.getMessage(), ex); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProvider.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProvider.java index f1818b73ac..6554f2b0bb 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProvider.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.authentication; import java.time.Instant; @@ -29,29 +30,29 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.InvalidBearerTokenException; import org.springframework.security.oauth2.server.resource.introspection.BadOpaqueTokenException; +import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionException; import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUED_AT; - /** * An {@link AuthenticationProvider} implementation for opaque - * Bearer Tokens, - * using an - * OAuth 2.0 Introspection Endpoint - * to check the token's validity and reveal its attributes. + * Bearer + * Tokens, using an + * OAuth 2.0 Introspection + * Endpoint to check the token's validity and reveal its attributes. *

        - * This {@link AuthenticationProvider} is responsible for introspecting and verifying an opaque access token, - * returning its attributes set as part of the {@link Authentication} statement. + * This {@link AuthenticationProvider} is responsible for introspecting and verifying an + * opaque access token, returning its attributes set as part of the {@link Authentication} + * statement. *

        - * Scopes are translated into {@link GrantedAuthority}s according to the following algorithm: + * Scopes are translated into {@link GrantedAuthority}s according to the following + * algorithm: *

          - *
        1. - * If there is a "scope" attribute, then convert to a {@link Collection} of {@link String}s. - *
        2. - * Take the resulting {@link Collection} and prepend the "SCOPE_" keyword to each element, adding as {@link GrantedAuthority}s. + *
        3. If there is a "scope" attribute, then convert to a {@link Collection} of + * {@link String}s. + *
        4. Take the resulting {@link Collection} and prepend the "SCOPE_" keyword to each + * element, adding as {@link GrantedAuthority}s. *
        * * @author Josh Cummings @@ -59,11 +60,11 @@ import static org.springframework.security.oauth2.server.resource.introspection. * @see AuthenticationProvider */ public final class OpaqueTokenAuthenticationProvider implements AuthenticationProvider { + private OpaqueTokenIntrospector introspector; /** * Creates a {@code OpaqueTokenAuthenticationProvider} with the provided parameters - * * @param introspector The {@link OpaqueTokenIntrospector} to use */ public OpaqueTokenAuthenticationProvider(OpaqueTokenIntrospector introspector) { @@ -73,10 +74,9 @@ public final class OpaqueTokenAuthenticationProvider implements AuthenticationPr /** * Introspect and validate the opaque - * Bearer Token. - * + * Bearer + * Token. * @param authentication the authentication request object. - * * @return A successful authentication * @throws AuthenticationException if authentication failed for some reason */ @@ -86,34 +86,34 @@ public final class OpaqueTokenAuthenticationProvider implements AuthenticationPr return null; } BearerTokenAuthenticationToken bearer = (BearerTokenAuthenticationToken) authentication; - - OAuth2AuthenticatedPrincipal principal; - try { - principal = this.introspector.introspect(bearer.getToken()); - } catch (BadOpaqueTokenException failed) { - throw new InvalidBearerTokenException(failed.getMessage()); - } catch (OAuth2IntrospectionException failed) { - throw new AuthenticationServiceException(failed.getMessage()); - } - + OAuth2AuthenticatedPrincipal principal = getOAuth2AuthenticatedPrincipal(bearer); AbstractAuthenticationToken result = convert(principal, bearer.getToken()); result.setDetails(bearer.getDetails()); return result; } - /** - * {@inheritDoc} - */ + private OAuth2AuthenticatedPrincipal getOAuth2AuthenticatedPrincipal(BearerTokenAuthenticationToken bearer) { + try { + return this.introspector.introspect(bearer.getToken()); + } + catch (BadOpaqueTokenException failed) { + throw new InvalidBearerTokenException(failed.getMessage()); + } + catch (OAuth2IntrospectionException failed) { + throw new AuthenticationServiceException(failed.getMessage()); + } + } + @Override public boolean supports(Class authentication) { return BearerTokenAuthenticationToken.class.isAssignableFrom(authentication); } private AbstractAuthenticationToken convert(OAuth2AuthenticatedPrincipal principal, String token) { - Instant iat = principal.getAttribute(ISSUED_AT); - Instant exp = principal.getAttribute(EXPIRES_AT); - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - token, iat, exp); + Instant iat = principal.getAttribute(OAuth2IntrospectionClaimNames.ISSUED_AT); + Instant exp = principal.getAttribute(OAuth2IntrospectionClaimNames.EXPIRES_AT); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, token, iat, exp); return new BearerTokenAuthentication(principal, accessToken, principal.getAuthorities()); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManager.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManager.java index 30a918a4da..ad2d848241 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManager.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManager.java @@ -30,29 +30,29 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.InvalidBearerTokenException; import org.springframework.security.oauth2.server.resource.introspection.BadOpaqueTokenException; +import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionException; import org.springframework.security.oauth2.server.resource.introspection.ReactiveOpaqueTokenIntrospector; import org.springframework.util.Assert; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUED_AT; - /** * An {@link ReactiveAuthenticationManager} implementation for opaque - * Bearer Tokens, - * using an - * OAuth 2.0 Introspection Endpoint - * to check the token's validity and reveal its attributes. + * Bearer + * Tokens, using an + * OAuth 2.0 Introspection + * Endpoint to check the token's validity and reveal its attributes. *

        - * This {@link ReactiveAuthenticationManager} is responsible for introspecting and verifying an opaque access token, - * returning its attributes set as part of the {@link Authentication} statement. + * This {@link ReactiveAuthenticationManager} is responsible for introspecting and + * verifying an opaque access token, returning its attributes set as part of the + * {@link Authentication} statement. *

        - * Scopes are translated into {@link GrantedAuthority}s according to the following algorithm: + * Scopes are translated into {@link GrantedAuthority}s according to the following + * algorithm: *

          - *
        1. - * If there is a "scope" attribute, then convert to a {@link Collection} of {@link String}s. - *
        2. - * Take the resulting {@link Collection} and prepend the "SCOPE_" keyword to each element, adding as {@link GrantedAuthority}s. + *
        3. If there is a "scope" attribute, then convert to a {@link Collection} of + * {@link String}s. + *
        4. Take the resulting {@link Collection} and prepend the "SCOPE_" keyword to each + * element, adding as {@link GrantedAuthority}s. *
        * * @author Josh Cummings @@ -60,11 +60,12 @@ import static org.springframework.security.oauth2.server.resource.introspection. * @see ReactiveAuthenticationManager */ public class OpaqueTokenReactiveAuthenticationManager implements ReactiveAuthenticationManager { + private ReactiveOpaqueTokenIntrospector introspector; /** - * Creates a {@code OpaqueTokenReactiveAuthenticationManager} with the provided parameters - * + * Creates a {@code OpaqueTokenReactiveAuthenticationManager} with the provided + * parameters * @param introspector The {@link ReactiveOpaqueTokenIntrospector} to use */ public OpaqueTokenReactiveAuthenticationManager(ReactiveOpaqueTokenIntrospector introspector) { @@ -74,33 +75,35 @@ public class OpaqueTokenReactiveAuthenticationManager implements ReactiveAuthent @Override public Mono authenticate(Authentication authentication) { + // @formatter:off return Mono.justOrEmpty(authentication) .filter(BearerTokenAuthenticationToken.class::isInstance) .cast(BearerTokenAuthenticationToken.class) .map(BearerTokenAuthenticationToken::getToken) .flatMap(this::authenticate) .cast(Authentication.class); + // @formatter:on } private Mono authenticate(String token) { + // @formatter:off return this.introspector.introspect(token) - .map(principal -> { - Instant iat = principal.getAttribute(ISSUED_AT); - Instant exp = principal.getAttribute(EXPIRES_AT); - + .map((principal) -> { + Instant iat = principal.getAttribute(OAuth2IntrospectionClaimNames.ISSUED_AT); + Instant exp = principal.getAttribute(OAuth2IntrospectionClaimNames.EXPIRES_AT); // construct token - OAuth2AccessToken accessToken = - new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, token, iat, exp); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, token, iat, exp); return new BearerTokenAuthentication(principal, accessToken, principal.getAuthorities()); }) .onErrorMap(OAuth2IntrospectionException.class, this::onError); + // @formatter:on } - private AuthenticationException onError(OAuth2IntrospectionException e) { - if (e instanceof BadOpaqueTokenException) { - return new InvalidBearerTokenException(e.getMessage(), e); - } else { - return new AuthenticationServiceException(e.getMessage(), e); + private AuthenticationException onError(OAuth2IntrospectionException ex) { + if (ex instanceof BadOpaqueTokenException) { + return new InvalidBearerTokenException(ex.getMessage(), ex); } + return new AuthenticationServiceException(ex.getMessage(), ex); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverter.java index 4bd8ba0d58..05e2f93a13 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverter.java @@ -26,32 +26,36 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; /** - * Reactive version of {@link JwtAuthenticationConverter} for converting a {@link Jwt} - * to a {@link AbstractAuthenticationToken Mono<AbstractAuthenticationToken>}. + * Reactive version of {@link JwtAuthenticationConverter} for converting a {@link Jwt} to + * a {@link AbstractAuthenticationToken Mono<AbstractAuthenticationToken>}. * * @author Eric Deandrea * @since 5.2 */ public final class ReactiveJwtAuthenticationConverter implements Converter> { - private Converter> jwtGrantedAuthoritiesConverter - = new ReactiveJwtGrantedAuthoritiesConverterAdapter(new JwtGrantedAuthoritiesConverter()); + + private Converter> jwtGrantedAuthoritiesConverter = new ReactiveJwtGrantedAuthoritiesConverterAdapter( + new JwtGrantedAuthoritiesConverter()); @Override public Mono convert(Jwt jwt) { + // @formatter:off return this.jwtGrantedAuthoritiesConverter.convert(jwt) .collectList() - .map(authorities -> new JwtAuthenticationToken(jwt, authorities)); + .map((authorities) -> new JwtAuthenticationToken(jwt, authorities)); + // @formatter:on } /** - * Sets the {@link Converter Converter<Jwt, Flux<GrantedAuthority>>} to use. - * Defaults to a reactive {@link JwtGrantedAuthoritiesConverter}. - * + * Sets the {@link Converter Converter<Jwt, Flux<GrantedAuthority>>} to + * use. Defaults to a reactive {@link JwtGrantedAuthoritiesConverter}. * @param jwtGrantedAuthoritiesConverter The converter * @see JwtGrantedAuthoritiesConverter */ - public void setJwtGrantedAuthoritiesConverter(Converter> jwtGrantedAuthoritiesConverter) { + public void setJwtGrantedAuthoritiesConverter( + Converter> jwtGrantedAuthoritiesConverter) { Assert.notNull(jwtGrantedAuthoritiesConverter, "jwtGrantedAuthoritiesConverter cannot be null"); this.jwtGrantedAuthoritiesConverter = jwtGrantedAuthoritiesConverter; } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapter.java index 1727b67086..a55e130815 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapter.java @@ -30,6 +30,7 @@ import org.springframework.util.Assert; * @since 5.1.1 */ public class ReactiveJwtAuthenticationConverterAdapter implements Converter> { + private final Converter delegate; public ReactiveJwtAuthenticationConverterAdapter(Converter delegate) { @@ -37,7 +38,9 @@ public class ReactiveJwtAuthenticationConverterAdapter implements Converter convert(Jwt jwt) { return Mono.just(jwt).map(this.delegate::convert); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapter.java index cf98c33706..5dc1b81df2 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapter.java @@ -26,11 +26,11 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; /** - * Adapts a {@link Converter Converter<Jwt, Collection<GrantedAuthority>>} to a - * {@link Converter Converter<Jwt, Flux<GrantedAuthority>>}. + * Adapts a {@link Converter Converter<Jwt, Collection<GrantedAuthority>>} to + * a {@link Converter Converter<Jwt, Flux<GrantedAuthority>>}. *

        - * Make sure the {@link Converter Converter<Jwt, Collection<GrantedAuthority>>} - * being adapted is non-blocking. + * Make sure the {@link Converter Converter<Jwt, + * Collection<GrantedAuthority>>} being adapted is non-blocking. *

        * * @author Eric Deandrea @@ -38,9 +38,11 @@ import org.springframework.util.Assert; * @see JwtGrantedAuthoritiesConverter */ public final class ReactiveJwtGrantedAuthoritiesConverterAdapter implements Converter> { + private final Converter> grantedAuthoritiesConverter; - public ReactiveJwtGrantedAuthoritiesConverterAdapter(Converter> grantedAuthoritiesConverter) { + public ReactiveJwtGrantedAuthoritiesConverterAdapter( + Converter> grantedAuthoritiesConverter) { Assert.notNull(grantedAuthoritiesConverter, "grantedAuthoritiesConverter cannot be null"); this.grantedAuthoritiesConverter = grantedAuthoritiesConverter; } @@ -49,4 +51,5 @@ public final class ReactiveJwtGrantedAuthoritiesConverterAdapter implements Conv public Flux convert(Jwt jwt) { return Flux.fromIterable(this.grantedAuthoritiesConverter.convert(jwt)); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/package-info.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/package-info.java index 2ee786b68e..3b022e8675 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/package-info.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/package-info.java @@ -15,6 +15,7 @@ */ /** - * OAuth 2.0 Resource Server {@code Authentication}s and supporting classes and interfaces. + * OAuth 2.0 Resource Server {@code Authentication}s and supporting classes and + * interfaces. */ package org.springframework.security.oauth2.server.resource.authentication; diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/BadOpaqueTokenException.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/BadOpaqueTokenException.java index b190400cae..5e155c8bce 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/BadOpaqueTokenException.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/BadOpaqueTokenException.java @@ -17,13 +17,15 @@ package org.springframework.security.oauth2.server.resource.introspection; /** - * An exception similar to {@link org.springframework.security.authentication.BadCredentialsException} - * that indicates an opaque token that is invalid in some way. + * An exception similar to + * {@link org.springframework.security.authentication.BadCredentialsException} that + * indicates an opaque token that is invalid in some way. * * @author Josh Cummings * @since 5.3 */ public class BadOpaqueTokenException extends OAuth2IntrospectionException { + public BadOpaqueTokenException(String message) { super(message); } @@ -31,4 +33,5 @@ public class BadOpaqueTokenException extends OAuth2IntrospectionException { public BadOpaqueTokenException(String message, Throwable cause) { super(message, cause); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospector.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospector.java index 440ba8b9c4..a1d427712f 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospector.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospector.java @@ -46,32 +46,26 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.AUDIENCE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.CLIENT_ID; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUED_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUER; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.NOT_BEFORE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SCOPE; - /** - * A Nimbus implementation of {@link OpaqueTokenIntrospector} that verifies and introspects - * a token using the configured - * OAuth 2.0 Introspection Endpoint. + * A Nimbus implementation of {@link OpaqueTokenIntrospector} that verifies and + * introspects a token using the configured + * OAuth 2.0 Introspection + * Endpoint. * * @author Josh Cummings * @author MD Sayem Ahmed * @since 5.2 */ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { + private Converter> requestEntityConverter; + private RestOperations restOperations; private final String authorityPrefix = "SCOPE_"; /** * Creates a {@code OpaqueTokenAuthenticationProvider} with the provided parameters - * * @param introspectionUri The introspection endpoint uri * @param clientId The client id authorized to introspect * @param clientSecret The client's secret @@ -80,7 +74,6 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { Assert.notNull(introspectionUri, "introspectionUri cannot be null"); Assert.notNull(clientId, "clientId cannot be null"); Assert.notNull(clientSecret, "clientSecret cannot be null"); - this.requestEntityConverter = this.defaultRequestEntityConverter(URI.create(introspectionUri)); RestTemplate restTemplate = new RestTemplate(); restTemplate.getInterceptors().add(new BasicAuthenticationInterceptor(clientId, clientSecret)); @@ -90,22 +83,20 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { /** * Creates a {@code OpaqueTokenAuthenticationProvider} with the provided parameters * - * The given {@link RestOperations} should perform its own client authentication against the - * introspection endpoint. - * + * The given {@link RestOperations} should perform its own client authentication + * against the introspection endpoint. * @param introspectionUri The introspection endpoint uri * @param restOperations The client for performing the introspection request */ public NimbusOpaqueTokenIntrospector(String introspectionUri, RestOperations restOperations) { Assert.notNull(introspectionUri, "introspectionUri cannot be null"); Assert.notNull(restOperations, "restOperations cannot be null"); - this.requestEntityConverter = this.defaultRequestEntityConverter(URI.create(introspectionUri)); this.restOperations = restOperations; } private Converter> defaultRequestEntityConverter(URI introspectionUri) { - return token -> { + return (token) -> { HttpHeaders headers = requestHeaders(); MultiValueMap body = requestBody(token); return new RequestEntity<>(body, headers, HttpMethod.POST, introspectionUri); @@ -124,46 +115,40 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { return body; } - /** - * {@inheritDoc} - */ @Override public OAuth2AuthenticatedPrincipal introspect(String token) { RequestEntity requestEntity = this.requestEntityConverter.convert(token); if (requestEntity == null) { throw new OAuth2IntrospectionException("requestEntityConverter returned a null entity"); } - ResponseEntity responseEntity = makeRequest(requestEntity); HTTPResponse httpResponse = adaptToNimbusResponse(responseEntity); TokenIntrospectionResponse introspectionResponse = parseNimbusResponse(httpResponse); TokenIntrospectionSuccessResponse introspectionSuccessResponse = castToNimbusSuccess(introspectionResponse); - - // relying solely on the authorization server to validate this token (not checking 'exp', for example) + // relying solely on the authorization server to validate this token (not checking + // 'exp', for example) if (!introspectionSuccessResponse.isActive()) { throw new BadOpaqueTokenException("Provided token isn't active"); } - return convertClaimsSet(introspectionSuccessResponse); } /** - * Sets the {@link Converter} used for converting the OAuth 2.0 access token to a {@link RequestEntity} - * representation of the OAuth 2.0 token introspection request. - * - * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation - * of the token introspection request + * Sets the {@link Converter} used for converting the OAuth 2.0 access token to a + * {@link RequestEntity} representation of the OAuth 2.0 token introspection request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the token introspection request */ public void setRequestEntityConverter(Converter> requestEntityConverter) { Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); - this.requestEntityConverter = requestEntityConverter; } private ResponseEntity makeRequest(RequestEntity requestEntity) { try { return this.restOperations.exchange(requestEntity, String.class); - } catch (Exception ex) { + } + catch (Exception ex) { throw new OAuth2IntrospectionException(ex.getMessage(), ex); } } @@ -172,10 +157,8 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { HTTPResponse response = new HTTPResponse(responseEntity.getStatusCodeValue()); response.setHeader(HttpHeaders.CONTENT_TYPE, responseEntity.getHeaders().getContentType().toString()); response.setContent(responseEntity.getBody()); - if (response.getStatusCode() != HTTPResponse.SC_OK) { - throw new OAuth2IntrospectionException( - "Introspection endpoint responded with " + response.getStatusCode()); + throw new OAuth2IntrospectionException("Introspection endpoint responded with " + response.getStatusCode()); } return response; } @@ -183,7 +166,8 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { private TokenIntrospectionResponse parseNimbusResponse(HTTPResponse response) { try { return TokenIntrospectionResponse.parse(response); - } catch (Exception ex) { + } + catch (Exception ex) { throw new OAuth2IntrospectionException(ex.getMessage(), ex); } } @@ -203,42 +187,43 @@ public class NimbusOpaqueTokenIntrospector implements OpaqueTokenIntrospector { for (Audience audience : response.getAudience()) { audiences.add(audience.getValue()); } - claims.put(AUDIENCE, Collections.unmodifiableList(audiences)); + claims.put(OAuth2IntrospectionClaimNames.AUDIENCE, Collections.unmodifiableList(audiences)); } if (response.getClientID() != null) { - claims.put(CLIENT_ID, response.getClientID().getValue()); + claims.put(OAuth2IntrospectionClaimNames.CLIENT_ID, response.getClientID().getValue()); } if (response.getExpirationTime() != null) { Instant exp = response.getExpirationTime().toInstant(); - claims.put(EXPIRES_AT, exp); + claims.put(OAuth2IntrospectionClaimNames.EXPIRES_AT, exp); } if (response.getIssueTime() != null) { Instant iat = response.getIssueTime().toInstant(); - claims.put(ISSUED_AT, iat); + claims.put(OAuth2IntrospectionClaimNames.ISSUED_AT, iat); } if (response.getIssuer() != null) { - claims.put(ISSUER, issuer(response.getIssuer().getValue())); + claims.put(OAuth2IntrospectionClaimNames.ISSUER, issuer(response.getIssuer().getValue())); } if (response.getNotBeforeTime() != null) { - claims.put(NOT_BEFORE, response.getNotBeforeTime().toInstant()); + claims.put(OAuth2IntrospectionClaimNames.NOT_BEFORE, response.getNotBeforeTime().toInstant()); } if (response.getScope() != null) { List scopes = Collections.unmodifiableList(response.getScope().toStringList()); - claims.put(SCOPE, scopes); - + claims.put(OAuth2IntrospectionClaimNames.SCOPE, scopes); for (String scope : scopes) { authorities.add(new SimpleGrantedAuthority(this.authorityPrefix + scope)); } } - return new OAuth2IntrospectionAuthenticatedPrincipal(claims, authorities); } private URL issuer(String uri) { try { return new URL(uri); - } catch (Exception ex) { - throw new OAuth2IntrospectionException("Invalid " + ISSUER + " value: " + uri); + } + catch (Exception ex) { + throw new OAuth2IntrospectionException( + "Invalid " + OAuth2IntrospectionClaimNames.ISSUER + " value: " + uri); } } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospector.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospector.java index 05c1e9c4c3..8b8dc9bd39 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospector.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospector.java @@ -43,31 +43,26 @@ import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.AUDIENCE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.CLIENT_ID; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUED_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUER; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.NOT_BEFORE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SCOPE; - /** - * A Nimbus implementation of {@link ReactiveOpaqueTokenIntrospector} that verifies and introspects - * a token using the configured - * OAuth 2.0 Introspection Endpoint. + * A Nimbus implementation of {@link ReactiveOpaqueTokenIntrospector} that verifies and + * introspects a token using the configured + * OAuth 2.0 Introspection + * Endpoint. * * @author Josh Cummings * @since 5.2 */ public class NimbusReactiveOpaqueTokenIntrospector implements ReactiveOpaqueTokenIntrospector { - private URI introspectionUri; - private WebClient webClient; + + private final URI introspectionUri; + + private final WebClient webClient; private String authorityPrefix = "SCOPE_"; /** - * Creates a {@code OpaqueTokenReactiveAuthenticationManager} with the provided parameters - * + * Creates a {@code OpaqueTokenReactiveAuthenticationManager} with the provided + * parameters * @param introspectionUri The introspection endpoint uri * @param clientId The client id authorized to introspect * @param clientSecret The client secret for the authorized client @@ -76,68 +71,67 @@ public class NimbusReactiveOpaqueTokenIntrospector implements ReactiveOpaqueToke Assert.hasText(introspectionUri, "introspectionUri cannot be empty"); Assert.hasText(clientId, "clientId cannot be empty"); Assert.notNull(clientSecret, "clientSecret cannot be null"); - this.introspectionUri = URI.create(introspectionUri); - this.webClient = WebClient.builder() - .defaultHeaders(h -> h.setBasicAuth(clientId, clientSecret)) - .build(); + this.webClient = WebClient.builder().defaultHeaders((h) -> h.setBasicAuth(clientId, clientSecret)).build(); } /** - * Creates a {@code OpaqueTokenReactiveAuthenticationManager} with the provided parameters - * + * Creates a {@code OpaqueTokenReactiveAuthenticationManager} with the provided + * parameters * @param introspectionUri The introspection endpoint uri * @param webClient The client for performing the introspection request */ public NimbusReactiveOpaqueTokenIntrospector(String introspectionUri, WebClient webClient) { Assert.hasText(introspectionUri, "introspectionUri cannot be null"); Assert.notNull(webClient, "webClient cannot be null"); - this.introspectionUri = URI.create(introspectionUri); this.webClient = webClient; } - /** - * {@inheritDoc} - */ @Override public Mono introspect(String token) { + // @formatter:off return Mono.just(token) .flatMap(this::makeRequest) .flatMap(this::adaptToNimbusResponse) .map(this::parseNimbusResponse) .map(this::castToNimbusSuccess) - .doOnNext(response -> validate(token, response)) + .doOnNext((response) -> validate(token, response)) .map(this::convertClaimsSet) - .onErrorMap(e -> !(e instanceof OAuth2IntrospectionException), this::onError); + .onErrorMap((e) -> !(e instanceof OAuth2IntrospectionException), this::onError); + // @formatter:on } private Mono makeRequest(String token) { + // @formatter:off return this.webClient.post() .uri(this.introspectionUri) .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_UTF8_VALUE) .body(BodyInserters.fromFormData("token", token)) .exchange(); + // @formatter:on } private Mono adaptToNimbusResponse(ClientResponse responseEntity) { HTTPResponse response = new HTTPResponse(responseEntity.rawStatusCode()); response.setHeader(HttpHeaders.CONTENT_TYPE, responseEntity.headers().contentType().get().toString()); if (response.getStatusCode() != HTTPResponse.SC_OK) { + // @formatter:off return responseEntity.bodyToFlux(DataBuffer.class) - .map(DataBufferUtils::release) - .then(Mono.error(new OAuth2IntrospectionException( - "Introspection endpoint responded with " + response.getStatusCode()))); + .map(DataBufferUtils::release) + .then(Mono.error(new OAuth2IntrospectionException( + "Introspection endpoint responded with " + response.getStatusCode())) + ); + // @formatter:on } - return responseEntity.bodyToMono(String.class) - .doOnNext(response::setContent) - .map(body -> response); + return responseEntity.bodyToMono(String.class).doOnNext(response::setContent).map((body) -> response); } private TokenIntrospectionResponse parseNimbusResponse(HTTPResponse response) { try { return TokenIntrospectionResponse.parse(response); - } catch (Exception ex) { + } + catch (Exception ex) { throw new OAuth2IntrospectionException(ex.getMessage(), ex); } } @@ -150,7 +144,8 @@ public class NimbusReactiveOpaqueTokenIntrospector implements ReactiveOpaqueToke } private void validate(String token, TokenIntrospectionSuccessResponse response) { - // relying solely on the authorization server to validate this token (not checking 'exp', for example) + // relying solely on the authorization server to validate this token (not checking + // 'exp', for example) if (!response.isActive()) { throw new BadOpaqueTokenException("Provided token isn't active"); } @@ -164,46 +159,48 @@ public class NimbusReactiveOpaqueTokenIntrospector implements ReactiveOpaqueToke for (Audience audience : response.getAudience()) { audiences.add(audience.getValue()); } - claims.put(AUDIENCE, Collections.unmodifiableList(audiences)); + claims.put(OAuth2IntrospectionClaimNames.AUDIENCE, Collections.unmodifiableList(audiences)); } if (response.getClientID() != null) { - claims.put(CLIENT_ID, response.getClientID().getValue()); + claims.put(OAuth2IntrospectionClaimNames.CLIENT_ID, response.getClientID().getValue()); } if (response.getExpirationTime() != null) { Instant exp = response.getExpirationTime().toInstant(); - claims.put(EXPIRES_AT, exp); + claims.put(OAuth2IntrospectionClaimNames.EXPIRES_AT, exp); } if (response.getIssueTime() != null) { Instant iat = response.getIssueTime().toInstant(); - claims.put(ISSUED_AT, iat); + claims.put(OAuth2IntrospectionClaimNames.ISSUED_AT, iat); } if (response.getIssuer() != null) { - claims.put(ISSUER, issuer(response.getIssuer().getValue())); + claims.put(OAuth2IntrospectionClaimNames.ISSUER, issuer(response.getIssuer().getValue())); } if (response.getNotBeforeTime() != null) { - claims.put(NOT_BEFORE, response.getNotBeforeTime().toInstant()); + claims.put(OAuth2IntrospectionClaimNames.NOT_BEFORE, response.getNotBeforeTime().toInstant()); } if (response.getScope() != null) { List scopes = Collections.unmodifiableList(response.getScope().toStringList()); - claims.put(SCOPE, scopes); + claims.put(OAuth2IntrospectionClaimNames.SCOPE, scopes); for (String scope : scopes) { authorities.add(new SimpleGrantedAuthority(this.authorityPrefix + scope)); } } - return new OAuth2IntrospectionAuthenticatedPrincipal(claims, authorities); } private URL issuer(String uri) { try { return new URL(uri); - } catch (Exception ex) { - throw new OAuth2IntrospectionException("Invalid " + ISSUER + " value: " + uri); + } + catch (Exception ex) { + throw new OAuth2IntrospectionException( + "Invalid " + OAuth2IntrospectionClaimNames.ISSUER + " value: " + uri); } } - private OAuth2IntrospectionException onError(Throwable e) { - return new OAuth2IntrospectionException(e.getMessage(), e); + private OAuth2IntrospectionException onError(Throwable ex) { + return new OAuth2IntrospectionException(ex.getMessage(), ex); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipal.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipal.java index 8e9427831e..08ec004143 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipal.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipal.java @@ -29,40 +29,39 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; * * @author David Kovac * @since 5.4 - * @see Introspection Response + * @see Introspection Response */ -public final class OAuth2IntrospectionAuthenticatedPrincipal implements OAuth2IntrospectionClaimAccessor, - OAuth2AuthenticatedPrincipal, Serializable { +public final class OAuth2IntrospectionAuthenticatedPrincipal + implements OAuth2IntrospectionClaimAccessor, OAuth2AuthenticatedPrincipal, Serializable { + private final OAuth2AuthenticatedPrincipal delegate; /** - * Constructs an {@code OAuth2IntrospectionAuthenticatedPrincipal} using the provided parameters. - * - * @param attributes the attributes of the OAuth 2.0 Token Introspection + * Constructs an {@code OAuth2IntrospectionAuthenticatedPrincipal} using the provided + * parameters. + * @param attributes the attributes of the OAuth 2.0 Token Introspection * @param authorities the authorities of the OAuth 2.0 Token Introspection */ public OAuth2IntrospectionAuthenticatedPrincipal(Map attributes, Collection authorities) { - this.delegate = new DefaultOAuth2AuthenticatedPrincipal(attributes, authorities); } /** - * Constructs an {@code OAuth2IntrospectionAuthenticatedPrincipal} using the provided parameters. - * - * @param name the name attached to the OAuth 2.0 Token Introspection - * @param attributes the attributes of the OAuth 2.0 Token Introspection + * Constructs an {@code OAuth2IntrospectionAuthenticatedPrincipal} using the provided + * parameters. + * @param name the name attached to the OAuth 2.0 Token Introspection + * @param attributes the attributes of the OAuth 2.0 Token Introspection * @param authorities the authorities of the OAuth 2.0 Token Introspection */ public OAuth2IntrospectionAuthenticatedPrincipal(String name, Map attributes, Collection authorities) { - this.delegate = new DefaultOAuth2AuthenticatedPrincipal(name, attributes, authorities); } /** * Gets the attributes of the OAuth 2.0 Token Introspection in map form. - * * @return a {@link Map} of the attribute's objects keyed by the attribute's names */ @Override @@ -71,9 +70,8 @@ public final class OAuth2IntrospectionAuthenticatedPrincipal implements OAuth2In } /** - * Get the {@link Collection} of {@link GrantedAuthority}s associated - * with this OAuth 2.0 Token Introspection - * + * Get the {@link Collection} of {@link GrantedAuthority}s associated with this OAuth + * 2.0 Token Introspection * @return the OAuth 2.0 Token Introspection authorities */ @Override @@ -81,19 +79,14 @@ public final class OAuth2IntrospectionAuthenticatedPrincipal implements OAuth2In return this.delegate.getAuthorities(); } - /** - * {@inheritDoc} - */ @Override public String getName() { return this.delegate.getName(); } - /** - * {@inheritDoc} - */ @Override public Map getClaims() { return getAttributes(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimAccessor.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimAccessor.java index 18c3c30e78..b95109d670 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimAccessor.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimAccessor.java @@ -23,20 +23,21 @@ import java.util.List; import org.springframework.security.oauth2.core.ClaimAccessor; /** - * A {@link ClaimAccessor} for the "claims" that may be contained - * in the Introspection Response. + * A {@link ClaimAccessor} for the "claims" that may be contained in the + * Introspection Response. * * @author David Kovac * @since 5.4 * @see ClaimAccessor * @see OAuth2IntrospectionClaimNames * @see OAuth2IntrospectionAuthenticatedPrincipal - * @see Introspection Response + * @see Introspection Response */ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { + /** * Returns the indicator {@code (active)} whether or not the token is currently active - * * @return the indicator whether or not the token is currently active */ default boolean isActive() { @@ -45,7 +46,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns the scopes {@code (scope)} associated with the token - * * @return the scopes associated with the token */ default String getScope() { @@ -54,7 +54,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns the client identifier {@code (client_id)} for the token - * * @return the client identifier for the token */ default String getClientId() { @@ -62,9 +61,10 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { } /** - * Returns a human-readable identifier {@code (username)} for the resource owner that authorized the token - * - * @return a human-readable identifier for the resource owner that authorized the token + * Returns a human-readable identifier {@code (username)} for the resource owner that + * authorized the token + * @return a human-readable identifier for the resource owner that authorized the + * token */ default String getUsername() { return this.getClaimAsString(OAuth2IntrospectionClaimNames.USERNAME); @@ -72,7 +72,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns the type of the token {@code (token_type)}, for example {@code bearer}. - * * @return the type of the token, for example {@code bearer}. */ default String getTokenType() { @@ -81,7 +80,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns a timestamp {@code (exp)} indicating when the token expires - * * @return a timestamp indicating when the token expires */ default Instant getExpiresAt() { @@ -90,7 +88,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns a timestamp {@code (iat)} indicating when the token was issued - * * @return a timestamp indicating when the token was issued */ default Instant getIssuedAt() { @@ -98,8 +95,8 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { } /** - * Returns a timestamp {@code (nbf)} indicating when the token is not to be used before - * + * Returns a timestamp {@code (nbf)} indicating when the token is not to be used + * before * @return a timestamp indicating when the token is not to be used before */ default Instant getNotBefore() { @@ -107,9 +104,10 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { } /** - * Returns usually a machine-readable identifier {@code (sub)} of the resource owner who authorized the token - * - * @return usually a machine-readable identifier of the resource owner who authorized the token + * Returns usually a machine-readable identifier {@code (sub)} of the resource owner + * who authorized the token + * @return usually a machine-readable identifier of the resource owner who authorized + * the token */ default String getSubject() { return this.getClaimAsString(OAuth2IntrospectionClaimNames.SUBJECT); @@ -117,7 +115,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns the intended audience {@code (aud)} for the token - * * @return the intended audience for the token */ default List getAudience() { @@ -126,7 +123,6 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns the issuer {@code (iss)} of the token - * * @return the issuer of the token */ default URL getIssuer() { @@ -135,10 +131,10 @@ public interface OAuth2IntrospectionClaimAccessor extends ClaimAccessor { /** * Returns the identifier {@code (jti)} for the token - * * @return the identifier for the token */ default String getId() { return this.getClaimAsString(OAuth2IntrospectionClaimNames.JTI); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimNames.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimNames.java index d2f011d18a..c21e4bb91a 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimNames.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionClaimNames.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.introspection; /** * The names of the "Introspection Claims" defined by an - * Introspection Response. + * Introspection + * Response. * * @author Josh Cummings * @since 5.2 @@ -40,7 +42,8 @@ public interface OAuth2IntrospectionClaimNames { String CLIENT_ID = "client_id"; /** - * {@code username} - A human-readable identifier for the resource owner that authorized the token + * {@code username} - A human-readable identifier for the resource owner that + * authorized the token */ String USERNAME = "username"; @@ -65,7 +68,8 @@ public interface OAuth2IntrospectionClaimNames { String NOT_BEFORE = "nbf"; /** - * {@code sub} - Usually a machine-readable identifier of the resource owner who authorized the token + * {@code sub} - Usually a machine-readable identifier of the resource owner who + * authorized the token */ String SUBJECT = "sub"; @@ -83,4 +87,5 @@ public interface OAuth2IntrospectionClaimNames { * {@code jti} - The identifier for the token */ String JTI = "jti"; + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionException.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionException.java index ffd468fc54..e2649ba975 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionException.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionException.java @@ -23,6 +23,7 @@ package org.springframework.security.oauth2.server.resource.introspection; * @since 5.2 */ public class OAuth2IntrospectionException extends RuntimeException { + public OAuth2IntrospectionException(String message) { super(message); } @@ -30,4 +31,5 @@ public class OAuth2IntrospectionException extends RuntimeException { public OAuth2IntrospectionException(String message, Throwable cause) { super(message, cause); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OpaqueTokenIntrospector.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OpaqueTokenIntrospector.java index 672647d0fb..056a7766b2 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OpaqueTokenIntrospector.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/OpaqueTokenIntrospector.java @@ -24,11 +24,12 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; * A contract for introspecting and verifying an OAuth 2.0 token. * * A typical implementation of this interface will make a request to an - * OAuth 2.0 Introspection Endpoint - * to verify the token and return its attributes, indicating a successful verification. + * OAuth 2.0 Introspection + * Endpoint to verify the token and return its attributes, indicating a successful + * verification. * - * Another sensible implementation of this interface would be to query a backing store - * of tokens, for example a distributed cache. + * Another sensible implementation of this interface would be to query a backing store of + * tokens, for example a distributed cache. * * @author Josh Cummings * @since 5.2 @@ -40,9 +41,9 @@ public interface OpaqueTokenIntrospector { * Introspect and verify the given token, returning its attributes. * * Returning a {@link Map} is indicative that the token is valid. - * * @param token the token to introspect * @return the token's attributes */ OAuth2AuthenticatedPrincipal introspect(String token); + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/ReactiveOpaqueTokenIntrospector.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/ReactiveOpaqueTokenIntrospector.java index 60d1078571..2d91da5ff4 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/ReactiveOpaqueTokenIntrospector.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/introspection/ReactiveOpaqueTokenIntrospector.java @@ -26,11 +26,12 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; * A contract for introspecting and verifying an OAuth 2.0 token. * * A typical implementation of this interface will make a request to an - * OAuth 2.0 Introspection Endpoint - * to verify the token and return its attributes, indicating a successful verification. + * OAuth 2.0 Introspection + * Endpoint to verify the token and return its attributes, indicating a successful + * verification. * - * Another sensible implementation of this interface would be to query a backing store - * of tokens, for example a distributed cache. + * Another sensible implementation of this interface would be to query a backing store of + * tokens, for example a distributed cache. * * @author Josh Cummings * @since 5.2 @@ -42,9 +43,9 @@ public interface ReactiveOpaqueTokenIntrospector { * Introspect and verify the given token, returning its attributes. * * Returning a {@link Map} is indicative that the token is valid. - * * @param token the token to introspect * @return the token's attributes */ Mono introspect(String token); + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPoint.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPoint.java index b6861448be..f28cdd2c3b 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPoint.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPoint.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.resource.web; import java.util.LinkedHashMap; import java.util.Map; + import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -31,76 +32,62 @@ import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.util.StringUtils; /** - * An {@link AuthenticationEntryPoint} implementation used to commence authentication of protected resource requests - * using {@link BearerTokenAuthenticationFilter}. + * An {@link AuthenticationEntryPoint} implementation used to commence authentication of + * protected resource requests using {@link BearerTokenAuthenticationFilter}. *

        - * Uses information provided by {@link BearerTokenError} to set HTTP response status code and populate - * {@code WWW-Authenticate} HTTP header. + * Uses information provided by {@link BearerTokenError} to set HTTP response status code + * and populate {@code WWW-Authenticate} HTTP header. * * @author Vedran Pavic - * @see BearerTokenError - * @see RFC 6750 Section 3: The WWW-Authenticate - * Response Header Field * @since 5.1 + * @see BearerTokenError + * @see RFC 6750 + * Section 3: The WWW-Authenticate Response Header Field */ public final class BearerTokenAuthenticationEntryPoint implements AuthenticationEntryPoint { private String realmName; /** - * Collect error details from the provided parameters and format according to - * RFC 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and {@code scope}. - * - * @param request that resulted in an AuthenticationException - * @param response so that the user agent can begin authentication + * Collect error details from the provided parameters and format according to RFC + * 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and + * {@code scope}. + * @param request that resulted in an AuthenticationException + * @param response so that the user agent can begin authentication * @param authException that caused the invocation */ @Override - public void commence( - HttpServletRequest request, HttpServletResponse response, + public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) { - HttpStatus status = HttpStatus.UNAUTHORIZED; - Map parameters = new LinkedHashMap<>(); - if (this.realmName != null) { parameters.put("realm", this.realmName); } - if (authException instanceof OAuth2AuthenticationException) { OAuth2Error error = ((OAuth2AuthenticationException) authException).getError(); - parameters.put("error", error.getErrorCode()); - if (StringUtils.hasText(error.getDescription())) { parameters.put("error_description", error.getDescription()); } - if (StringUtils.hasText(error.getUri())) { parameters.put("error_uri", error.getUri()); } - if (error instanceof BearerTokenError) { BearerTokenError bearerTokenError = (BearerTokenError) error; - if (StringUtils.hasText(bearerTokenError.getScope())) { parameters.put("scope", bearerTokenError.getScope()); } - status = ((BearerTokenError) error).getHttpStatus(); } } - String wwwAuthenticate = computeWWWAuthenticateHeaderValue(parameters); - response.addHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticate); response.setStatus(status.value()); } /** * Set the default realm name to use in the bearer token error response - * * @param realmName */ public void setRealmName(String realmName) { @@ -110,20 +97,18 @@ public final class BearerTokenAuthenticationEntryPoint implements Authentication private static String computeWWWAuthenticateHeaderValue(Map parameters) { StringBuilder wwwAuthenticate = new StringBuilder(); wwwAuthenticate.append("Bearer"); - if (!parameters.isEmpty()) { wwwAuthenticate.append(" "); int i = 0; for (Map.Entry entry : parameters.entrySet()) { wwwAuthenticate.append(entry.getKey()).append("=\"").append(entry.getValue()).append("\""); - if (i != parameters.size() - 1) { wwwAuthenticate.append(", "); } i++; } } - return wwwAuthenticate.toString(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java index c96ab90fd5..e69618545e 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.resource.web; import java.io.IOException; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -40,38 +41,39 @@ import org.springframework.web.filter.OncePerRequestFilter; /** * Authenticates requests that contain an OAuth 2.0 - * Bearer Token. + * Bearer + * Token. * - * This filter should be wired with an {@link AuthenticationManager} that can authenticate a - * {@link BearerTokenAuthenticationToken}. + * This filter should be wired with an {@link AuthenticationManager} that can authenticate + * a {@link BearerTokenAuthenticationToken}. * * @author Josh Cummings * @author Vedran Pavic * @author Joe Grandja * @since 5.1 - * @see The OAuth 2.0 Authorization Framework: Bearer Token Usage + * @see The OAuth 2.0 + * Authorization Framework: Bearer Token Usage * @see JwtAuthenticationProvider */ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter { + private final AuthenticationManagerResolver authenticationManagerResolver; - private final AuthenticationDetailsSource authenticationDetailsSource = - new WebAuthenticationDetailsSource(); + private final AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private BearerTokenResolver bearerTokenResolver = new DefaultBearerTokenResolver(); private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint(); - private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) -> - authenticationEntryPoint.commence(request, response, exception); + private AuthenticationFailureHandler authenticationFailureHandler = (request, response, + exception) -> this.authenticationEntryPoint.commence(request, response, exception); /** * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s) * @param authenticationManagerResolver */ - public BearerTokenAuthenticationFilter - (AuthenticationManagerResolver authenticationManagerResolver) { - + public BearerTokenAuthenticationFilter( + AuthenticationManagerResolver authenticationManagerResolver) { Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); this.authenticationManagerResolver = authenticationManagerResolver; } @@ -82,13 +84,13 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter */ public BearerTokenAuthenticationFilter(AuthenticationManager authenticationManager) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); - this.authenticationManagerResolver = request -> authenticationManager; + this.authenticationManagerResolver = (request) -> authenticationManager; } /** - * Extract any Bearer Token from - * the request and attempt an authentication. - * + * Extract any + * Bearer + * Token from the request and attempt an authentication. * @param request * @param response * @param filterChain @@ -98,49 +100,38 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - - final boolean debug = this.logger.isDebugEnabled(); - String token; - try { token = this.bearerTokenResolver.resolve(request); - } catch ( OAuth2AuthenticationException invalid ) { + } + catch (OAuth2AuthenticationException invalid) { this.authenticationEntryPoint.commence(request, response, invalid); return; } - if (token == null) { filterChain.doFilter(request, response); return; } - BearerTokenAuthenticationToken authenticationRequest = new BearerTokenAuthenticationToken(token); - authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); - try { AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request); Authentication authenticationResult = authenticationManager.authenticate(authenticationRequest); - SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authenticationResult); SecurityContextHolder.setContext(context); - filterChain.doFilter(request, response); - } catch (AuthenticationException failed) { + } + catch (AuthenticationException failed) { SecurityContextHolder.clearContext(); - - if (debug) { - this.logger.debug("Authentication request for failed!", failed); - } - + this.logger.debug("Authentication request for failed!", failed); this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed); } } /** - * Set the {@link BearerTokenResolver} to use. Defaults to {@link DefaultBearerTokenResolver}. + * Set the {@link BearerTokenResolver} to use. Defaults to + * {@link DefaultBearerTokenResolver}. * @param bearerTokenResolver the {@code BearerTokenResolver} to use */ public void setBearerTokenResolver(BearerTokenResolver bearerTokenResolver) { @@ -149,7 +140,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter } /** - * Set the {@link AuthenticationEntryPoint} to use. Defaults to {@link BearerTokenAuthenticationEntryPoint}. + * Set the {@link AuthenticationEntryPoint} to use. Defaults to + * {@link BearerTokenAuthenticationEntryPoint}. * @param authenticationEntryPoint the {@code AuthenticationEntryPoint} to use */ public void setAuthenticationEntryPoint(final AuthenticationEntryPoint authenticationEntryPoint) { @@ -158,7 +150,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter } /** - * Set the {@link AuthenticationFailureHandler} to use. Default implementation invokes {@link AuthenticationEntryPoint}. + * Set the {@link AuthenticationFailureHandler} to use. Default implementation invokes + * {@link AuthenticationEntryPoint}. * @param authenticationFailureHandler the {@code AuthenticationFailureHandler} to use * @since 5.2 */ @@ -166,4 +159,5 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); this.authenticationFailureHandler = authenticationFailureHandler; } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenResolver.java index bebd4fc0f7..f7bd2efd3f 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenResolver.java @@ -21,20 +21,22 @@ import javax.servlet.http.HttpServletRequest; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; /** - * A strategy for resolving Bearer Tokens - * from the {@link HttpServletRequest}. + * A strategy for resolving + * Bearer + * Tokens from the {@link HttpServletRequest}. * * @author Vedran Pavic * @since 5.1 - * @see RFC 6750 Section 2: Authenticated Requests + * @see RFC 6750 + * Section 2: Authenticated Requests */ @FunctionalInterface public interface BearerTokenResolver { /** - * Resolve any Bearer Token - * value from the request. - * + * Resolve any + * Bearer + * Token value from the request. * @param request the request * @return the Bearer Token value or {@code null} if none found * @throws OAuth2AuthenticationException if the found token is invalid diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolver.java index 248d3d9606..db2fd78187 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolver.java @@ -18,28 +18,27 @@ package org.springframework.security.oauth2.server.resource.web; import java.util.regex.Matcher; import java.util.regex.Pattern; + import javax.servlet.http.HttpServletRequest; import org.springframework.http.HttpHeaders; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.BearerTokenError; +import org.springframework.security.oauth2.server.resource.BearerTokenErrors; import org.springframework.util.StringUtils; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrors.invalidRequest; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrors.invalidToken; - /** * The default {@link BearerTokenResolver} implementation based on RFC 6750. * * @author Vedran Pavic * @since 5.1 - * @see RFC 6750 Section 2: Authenticated Requests + * @see RFC 6750 + * Section 2: Authenticated Requests */ public final class DefaultBearerTokenResolver implements BearerTokenResolver { - private static final Pattern authorizationPattern = Pattern.compile( - "^Bearer (?[a-zA-Z0-9-._~+/]+=*)$", - Pattern.CASE_INSENSITIVE); + private static final Pattern authorizationPattern = Pattern.compile("^Bearer (?[a-zA-Z0-9-._~+/]+=*)$", + Pattern.CASE_INSENSITIVE); private boolean allowFormEncodedBodyParameter = false; @@ -47,40 +46,40 @@ public final class DefaultBearerTokenResolver implements BearerTokenResolver { private String bearerTokenHeaderName = HttpHeaders.AUTHORIZATION; - /** - * {@inheritDoc} - */ @Override public String resolve(HttpServletRequest request) { String authorizationHeaderToken = resolveFromAuthorizationHeader(request); String parameterToken = resolveFromRequestParameters(request); if (authorizationHeaderToken != null) { if (parameterToken != null) { - BearerTokenError error = invalidRequest("Found multiple bearer tokens in the request"); + BearerTokenError error = BearerTokenErrors + .invalidRequest("Found multiple bearer tokens in the request"); throw new OAuth2AuthenticationException(error); } return authorizationHeaderToken; } - else if (parameterToken != null && isParameterTokenSupportedForRequest(request)) { + if (parameterToken != null && isParameterTokenSupportedForRequest(request)) { return parameterToken; } return null; } /** - * Set if transport of access token using form-encoded body parameter is supported. Defaults to {@code false}. - * @param allowFormEncodedBodyParameter if the form-encoded body parameter is supported + * Set if transport of access token using form-encoded body parameter is supported. + * Defaults to {@code false}. + * @param allowFormEncodedBodyParameter if the form-encoded body parameter is + * supported */ public void setAllowFormEncodedBodyParameter(boolean allowFormEncodedBodyParameter) { this.allowFormEncodedBodyParameter = allowFormEncodedBodyParameter; } /** - * Set if transport of access token using URI query parameter is supported. Defaults to {@code false}. - * - * The spec recommends against using this mechanism for sending bearer tokens, and even goes as far as - * stating that it was only included for completeness. + * Set if transport of access token using URI query parameter is supported. Defaults + * to {@code false}. * + * The spec recommends against using this mechanism for sending bearer tokens, and + * even goes as far as stating that it was only included for completeness. * @param allowUriQueryParameter if the URI query parameter is supported */ public void setAllowUriQueryParameter(boolean allowUriQueryParameter) { @@ -91,8 +90,8 @@ public final class DefaultBearerTokenResolver implements BearerTokenResolver { * Set this value to configure what header is checked when resolving a Bearer Token. * This value is defaulted to {@link HttpHeaders#AUTHORIZATION}. * - * This allows other headers to be used as the Bearer Token source such as {@link HttpHeaders#PROXY_AUTHORIZATION} - * + * This allows other headers to be used as the Bearer Token source such as + * {@link HttpHeaders#PROXY_AUTHORIZATION} * @param bearerTokenHeaderName the header to check when retrieving the Bearer Token. * @since 5.4 */ @@ -102,30 +101,26 @@ public final class DefaultBearerTokenResolver implements BearerTokenResolver { private String resolveFromAuthorizationHeader(HttpServletRequest request) { String authorization = request.getHeader(this.bearerTokenHeaderName); - if (StringUtils.startsWithIgnoreCase(authorization, "bearer")) { - Matcher matcher = authorizationPattern.matcher(authorization); - - if (!matcher.matches()) { - BearerTokenError error = invalidToken("Bearer token is malformed"); - throw new OAuth2AuthenticationException(error); - } - - return matcher.group("token"); + if (!StringUtils.startsWithIgnoreCase(authorization, "bearer")) { + return null; } - return null; + Matcher matcher = authorizationPattern.matcher(authorization); + if (!matcher.matches()) { + BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed"); + throw new OAuth2AuthenticationException(error); + } + return matcher.group("token"); } private static String resolveFromRequestParameters(HttpServletRequest request) { String[] values = request.getParameterValues("access_token"); - if (values == null || values.length == 0) { + if (values == null || values.length == 0) { return null; } - if (values.length == 1) { return values[0]; } - - BearerTokenError error = invalidRequest("Found multiple bearer tokens in the request"); + BearerTokenError error = BearerTokenErrors.invalidRequest("Found multiple bearer tokens in the request"); throw new OAuth2AuthenticationException(error); } @@ -133,4 +128,5 @@ public final class DefaultBearerTokenResolver implements BearerTokenResolver { return ((this.allowFormEncodedBodyParameter && "POST".equals(request.getMethod())) || (this.allowUriQueryParameter && "GET".equals(request.getMethod()))); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolver.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolver.java index 77593895e9..abbabcd845 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolver.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolver.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.resource.web; import javax.servlet.http.HttpServletRequest; + import org.springframework.util.Assert; /** @@ -38,4 +39,5 @@ public class HeaderBearerTokenResolver implements BearerTokenResolver { public String resolve(HttpServletRequest request) { return request.getHeader(this.header); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandler.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandler.java index d535f95601..043412492f 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandler.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandler.java @@ -16,6 +16,12 @@ package org.springframework.security.oauth2.server.resource.web.access; +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.security.access.AccessDeniedException; @@ -24,18 +30,16 @@ import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; import org.springframework.security.web.access.AccessDeniedHandler; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.LinkedHashMap; -import java.util.Map; - /** * Translates any {@link AccessDeniedException} into an HTTP response in accordance with - * RFC 6750 Section 3: The WWW-Authenticate. + * RFC 6750 + * Section 3: The WWW-Authenticate. *

        - * So long as the class can prove that the request has a valid OAuth 2.0 {@link Authentication}, then will return an - * insufficient scope error; otherwise, - * it will simply indicate the scheme (Bearer) and any configured realm. + * So long as the class can prove that the request has a valid OAuth 2.0 + * {@link Authentication}, then will return an + * insufficient + * scope error; otherwise, it will simply indicate the scheme (Bearer) and any + * configured realm. * * @author Josh Cummings * @since 5.1 @@ -45,39 +49,33 @@ public final class BearerTokenAccessDeniedHandler implements AccessDeniedHandler private String realmName; /** - * Collect error details from the provided parameters and format according to - * RFC 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and {@code scope}. - * - * @param request that resulted in an AccessDeniedException - * @param response so that the user agent can be advised of the failure + * Collect error details from the provided parameters and format according to RFC + * 6750, specifically {@code error}, {@code error_description}, {@code error_uri}, and + * {@code scope}. + * @param request that resulted in an AccessDeniedException + * @param response so that the user agent can be advised of the failure * @param accessDeniedException that caused the invocation */ @Override - public void handle( - HttpServletRequest request, HttpServletResponse response, + public void handle(HttpServletRequest request, HttpServletResponse response, AccessDeniedException accessDeniedException) { - Map parameters = new LinkedHashMap<>(); - if (this.realmName != null) { parameters.put("realm", this.realmName); } - if (request.getUserPrincipal() instanceof AbstractOAuth2TokenAuthenticationToken) { parameters.put("error", BearerTokenErrorCodes.INSUFFICIENT_SCOPE); - parameters.put("error_description", "The request requires higher privileges than provided by the access token."); + parameters.put("error_description", + "The request requires higher privileges than provided by the access token."); parameters.put("error_uri", "https://tools.ietf.org/html/rfc6750#section-3.1"); } - String wwwAuthenticate = computeWWWAuthenticateHeaderValue(parameters); - response.addHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticate); response.setStatus(HttpStatus.FORBIDDEN.value()); } /** * Set the default realm name to use in the bearer token error response - * * @param realmName */ public void setRealmName(String realmName) { @@ -87,20 +85,18 @@ public final class BearerTokenAccessDeniedHandler implements AccessDeniedHandler private static String computeWWWAuthenticateHeaderValue(Map parameters) { StringBuilder wwwAuthenticate = new StringBuilder(); wwwAuthenticate.append("Bearer"); - if (!parameters.isEmpty()) { wwwAuthenticate.append(" "); int i = 0; for (Map.Entry entry : parameters.entrySet()) { wwwAuthenticate.append(entry.getKey()).append("=\"").append(entry.getValue()).append("\""); - if (i != parameters.size() - 1) { wwwAuthenticate.append(", "); } i++; } } - return wwwAuthenticate.toString(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandler.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandler.java index c4e74650c4..ce74a6d0d1 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandler.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandler.java @@ -16,6 +16,13 @@ package org.springframework.security.oauth2.server.resource.web.access.server; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; + +import reactor.core.publisher.Mono; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.security.access.AccessDeniedException; @@ -24,50 +31,45 @@ import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.Arrays; -import java.util.Collection; -import java.util.LinkedHashMap; -import java.util.Map; /** * Translates any {@link AccessDeniedException} into an HTTP response in accordance with - * RFC 6750 Section 3: The WWW-Authenticate. + * RFC 6750 + * Section 3: The WWW-Authenticate. * - * So long as the class can prove that the request has a valid OAuth 2.0 {@link Authentication}, then will return an - * insufficient scope error; otherwise, - * it will simply indicate the scheme (Bearer) and any configured realm. + * So long as the class can prove that the request has a valid OAuth 2.0 + * {@link Authentication}, then will return an + * insufficient + * scope error; otherwise, it will simply indicate the scheme (Bearer) and any + * configured realm. * * @author Josh Cummings * @since 5.1 * */ public class BearerTokenServerAccessDeniedHandler implements ServerAccessDeniedHandler { - private static final Collection WELL_KNOWN_SCOPE_ATTRIBUTE_NAMES = - Arrays.asList("scope", "scp"); + + private static final Collection WELL_KNOWN_SCOPE_ATTRIBUTE_NAMES = Arrays.asList("scope", "scp"); private String realmName; @Override public Mono handle(ServerWebExchange exchange, AccessDeniedException denied) { - Map parameters = new LinkedHashMap<>(); - if (this.realmName != null) { parameters.put("realm", this.realmName); } - + // @formatter:off return exchange.getPrincipal() .filter(AbstractOAuth2TokenAuthenticationToken.class::isInstance) - .map(token -> errorMessageParameters(parameters)) + .map((token) -> errorMessageParameters(parameters)) .switchIfEmpty(Mono.just(parameters)) - .flatMap(params -> respond(exchange, params)); + .flatMap((params) -> respond(exchange, params)); + // @formatter:on } /** * Set the default realm name to use in the bearer token error response - * * @param realmName */ public final void setRealmName(String realmName) { @@ -76,9 +78,9 @@ public class BearerTokenServerAccessDeniedHandler implements ServerAccessDeniedH private static Map errorMessageParameters(Map parameters) { parameters.put("error", BearerTokenErrorCodes.INSUFFICIENT_SCOPE); - parameters.put("error_description", "The request requires higher privileges than provided by the access token."); + parameters.put("error_description", + "The request requires higher privileges than provided by the access token."); parameters.put("error_uri", "https://tools.ietf.org/html/rfc6750#section-3.1"); - return parameters; } @@ -103,7 +105,7 @@ public class BearerTokenServerAccessDeniedHandler implements ServerAccessDeniedH i++; } } - return wwwAuthenticate.toString(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java index 700531d588..9d39a3a590 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunction.java @@ -29,14 +29,16 @@ import org.springframework.web.reactive.function.client.ExchangeFunction; /** * An {@link ExchangeFilterFunction} that adds the - * Bearer Token - * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. + * Bearer + * Token from an existing {@link AbstractOAuth2Token} tied to the current + * {@link Authentication}. * - * Suitable for Reactive applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} - * configuration: + * Suitable for Reactive applications, applying it to a typical + * {@link org.springframework.web.reactive.function.client.WebClient} configuration: * *

        - *  @Bean
        +
        + *  @Bean
          *  WebClient webClient() {
          *      ServerBearerExchangeFilterFunction bearer = new ServerBearerExchangeFilterFunction();
          *      return WebClient.builder()
        @@ -47,35 +49,39 @@ import org.springframework.web.reactive.function.client.ExchangeFunction;
          * @author Josh Cummings
          * @since 5.2
          */
        -public final class ServerBearerExchangeFilterFunction
        -		implements ExchangeFilterFunction {
        +public final class ServerBearerExchangeFilterFunction implements ExchangeFilterFunction {
         
        -	/**
        -	 * {@inheritDoc}
        -	 */
         	@Override
         	public Mono filter(ClientRequest request, ExchangeFunction next) {
        -		return oauth2Token()
        -				.map(token -> bearer(request, token))
        +		// @formatter:off
        +		return oauth2Token().map((token) -> bearer(request, token))
         				.defaultIfEmpty(request)
         				.flatMap(next::exchange);
        +		// @formatter:on
         	}
         
         	private Mono oauth2Token() {
        +		// @formatter:off
         		return currentAuthentication()
        -				.filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token)
        +				.filter((authentication) -> authentication.getCredentials() instanceof AbstractOAuth2Token)
         				.map(Authentication::getCredentials)
         				.cast(AbstractOAuth2Token.class);
        +		// @formatter:on
         	}
         
         	private Mono currentAuthentication() {
        +		// @formatter:off
         		return ReactiveSecurityContextHolder.getContext()
         				.map(SecurityContext::getAuthentication);
        +		// @formatter:on
         	}
         
         	private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) {
        +		// @formatter:off
         		return ClientRequest.from(request)
        -				.headers(headers -> headers.setBearerAuth(token.getTokenValue()))
        +				.headers((headers) -> headers.setBearerAuth(token.getTokenValue()))
         				.build();
        +		// @formatter:on
         	}
        +
         }
        diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java
        index f59af70f02..981bd2c2a1 100644
        --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java
        +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunction.java
        @@ -30,14 +30,16 @@ import org.springframework.web.reactive.function.client.ExchangeFunction;
         
         /**
          * An {@link ExchangeFilterFunction} that adds the
        - * Bearer Token
        - * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}.
        + * Bearer
        + * Token from an existing {@link AbstractOAuth2Token} tied to the current
        + * {@link Authentication}.
          *
        - * Suitable for Servlet applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient}
        - * configuration:
        + * Suitable for Servlet applications, applying it to a typical
        + * {@link org.springframework.web.reactive.function.client.WebClient} configuration:
          *
          * 
        - *  @Bean
        +
        + *  @Bean
          *  WebClient webClient() {
          *      ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction();
          *      return WebClient.builder()
        @@ -45,39 +47,38 @@ import org.springframework.web.reactive.function.client.ExchangeFunction;
          *  }
          * 
        * - * To locate the bearer token, this looks in the Reactor {@link Context} for a key of type {@link Authentication}. + * To locate the bearer token, this looks in the Reactor {@link Context} for a key of type + * {@link Authentication}. * * Registering * {@see org.springframework.security.config.annotation.web.configuration.OAuth2ResourceServerConfiguration.OAuth2ResourceServerWebFluxSecurityConfiguration.BearerRequestContextSubscriberRegistrar}, - * as a {@code @Bean} will take care of this automatically, - * but certainly an application can supply a {@link Context} of its own to override. + * as a {@code @Bean} will take care of this automatically, but certainly an application + * can supply a {@link Context} of its own to override. * * @author Josh Cummings * @since 5.2 */ -public final class ServletBearerExchangeFilterFunction - implements ExchangeFilterFunction { +public final class ServletBearerExchangeFilterFunction implements ExchangeFilterFunction { - static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = - "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES"; + static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES"; - /** - * {@inheritDoc} - */ @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return oauth2Token() - .map(token -> bearer(request, token)) + // @formatter:off + return oauth2Token().map((token) -> bearer(request, token)) .defaultIfEmpty(request) .flatMap(next::exchange); + // @formatter:on } private Mono oauth2Token() { + // @formatter:off return Mono.subscriberContext() .flatMap(this::currentAuthentication) - .filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token) + .filter((authentication) -> authentication.getCredentials() instanceof AbstractOAuth2Token) .map(Authentication::getCredentials) .cast(AbstractOAuth2Token.class); + // @formatter:on } private Mono currentAuthentication(Context ctx) { @@ -85,7 +86,8 @@ public final class ServletBearerExchangeFilterFunction } private T getAttribute(Context ctx, Class clazz) { - // NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds this key + // NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds + // this key if (!ctx.hasKey(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY)) { return null; } @@ -94,8 +96,11 @@ public final class ServletBearerExchangeFilterFunction } private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { + // @formatter:off return ClientRequest.from(request) - .headers(headers -> headers.setBearerAuth(token.getTokenValue())) + .headers((headers) -> headers.setBearerAuth(token.getTokenValue())) .build(); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPoint.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPoint.java index 3a050f30cb..167fa10c0b 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPoint.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPoint.java @@ -16,6 +16,11 @@ package org.springframework.security.oauth2.server.resource.web.server; +import java.util.LinkedHashMap; +import java.util.Map; + +import reactor.core.publisher.Mono; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.server.reactive.ServerHttpResponse; @@ -28,26 +33,21 @@ import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.LinkedHashMap; -import java.util.Map; /** - * An {@link AuthenticationEntryPoint} implementation used to commence authentication of protected resource requests - * using {@link BearerTokenAuthenticationFilter}. + * An {@link AuthenticationEntryPoint} implementation used to commence authentication of + * protected resource requests using {@link BearerTokenAuthenticationFilter}. *

        - * Uses information provided by {@link BearerTokenError} to set HTTP response status code and populate - * {@code WWW-Authenticate} HTTP header. + * Uses information provided by {@link BearerTokenError} to set HTTP response status code + * and populate {@code WWW-Authenticate} HTTP header. * * @author Rob Winch - * @see BearerTokenError - * @see RFC 6750 Section 3: The WWW-Authenticate - * Response Header Field * @since 5.1 + * @see BearerTokenError + * @see RFC 6750 + * Section 3: The WWW-Authenticate Response Header Field */ -public final class BearerTokenServerAuthenticationEntryPoint implements - ServerAuthenticationEntryPoint { +public final class BearerTokenServerAuthenticationEntryPoint implements ServerAuthenticationEntryPoint { private String realmName; @@ -59,7 +59,6 @@ public final class BearerTokenServerAuthenticationEntryPoint implements public Mono commence(ServerWebExchange exchange, AuthenticationException authException) { return Mono.defer(() -> { HttpStatus status = getStatus(authException); - Map parameters = createParameters(authException); String wwwAuthenticate = computeWWWAuthenticateHeaderValue(parameters); ServerHttpResponse response = exchange.getResponse(); @@ -74,23 +73,17 @@ public final class BearerTokenServerAuthenticationEntryPoint implements if (this.realmName != null) { parameters.put("realm", this.realmName); } - if (authException instanceof OAuth2AuthenticationException) { OAuth2Error error = ((OAuth2AuthenticationException) authException).getError(); - parameters.put("error", error.getErrorCode()); - if (StringUtils.hasText(error.getDescription())) { parameters.put("error_description", error.getDescription()); } - if (StringUtils.hasText(error.getUri())) { parameters.put("error_uri", error.getUri()); } - if (error instanceof BearerTokenError) { BearerTokenError bearerTokenError = (BearerTokenError) error; - if (StringUtils.hasText(bearerTokenError.getScope())) { parameters.put("scope", bearerTokenError.getScope()); } @@ -112,7 +105,6 @@ public final class BearerTokenServerAuthenticationEntryPoint implements private static String computeWWWAuthenticateHeaderValue(Map parameters) { StringBuilder wwwAuthenticate = new StringBuilder(); wwwAuthenticate.append("Bearer"); - if (!parameters.isEmpty()) { wwwAuthenticate.append(" "); int i = 0; @@ -124,7 +116,7 @@ public final class BearerTokenServerAuthenticationEntryPoint implements i++; } } - return wwwAuthenticate.toString(); } + } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverter.java index d59c943e6c..e4d3d6dfc9 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverter.java @@ -28,39 +28,39 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.BearerTokenError; +import org.springframework.security.oauth2.server.resource.BearerTokenErrors; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrors.invalidRequest; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrors.invalidToken; - /** - * A strategy for resolving Bearer Tokens - * from the {@link ServerWebExchange}. + * A strategy for resolving + * Bearer + * Tokens from the {@link ServerWebExchange}. * * @author Rob Winch * @since 5.1 - * @see RFC 6750 Section 2: Authenticated Requests + * @see RFC 6750 + * Section 2: Authenticated Requests */ -public class ServerBearerTokenAuthenticationConverter - implements ServerAuthenticationConverter { - private static final Pattern authorizationPattern = Pattern.compile( - "^Bearer (?[a-zA-Z0-9-._~+/]+=*)$", - Pattern.CASE_INSENSITIVE); +public class ServerBearerTokenAuthenticationConverter implements ServerAuthenticationConverter { + + private static final Pattern authorizationPattern = Pattern.compile("^Bearer (?[a-zA-Z0-9-._~+/]+=*)$", + Pattern.CASE_INSENSITIVE); private boolean allowUriQueryParameter = false; + private String bearerTokenHeaderName = HttpHeaders.AUTHORIZATION; + @Override public Mono convert(ServerWebExchange exchange) { - return Mono.fromCallable(() -> token(exchange.getRequest())) - .map(token -> { - if (token.isEmpty()) { - BearerTokenError error = invalidTokenError(); - throw new OAuth2AuthenticationException(error); - } - return new BearerTokenAuthenticationToken(token); - }); + return Mono.fromCallable(() -> token(exchange.getRequest())).map((token) -> { + if (token.isEmpty()) { + BearerTokenError error = invalidTokenError(); + throw new OAuth2AuthenticationException(error); + } + return new BearerTokenAuthenticationToken(token); + }); } private String token(ServerHttpRequest request) { @@ -68,23 +68,24 @@ public class ServerBearerTokenAuthenticationConverter String parameterToken = request.getQueryParams().getFirst("access_token"); if (authorizationHeaderToken != null) { if (parameterToken != null) { - BearerTokenError error = invalidRequest("Found multiple bearer tokens in the request"); + BearerTokenError error = BearerTokenErrors + .invalidRequest("Found multiple bearer tokens in the request"); throw new OAuth2AuthenticationException(error); } return authorizationHeaderToken; } - else if (parameterToken != null && isParameterTokenSupportedForRequest(request)) { + if (parameterToken != null && isParameterTokenSupportedForRequest(request)) { return parameterToken; } return null; } /** - * Set if transport of access token using URI query parameter is supported. Defaults to {@code false}. - * - * The spec recommends against using this mechanism for sending bearer tokens, and even goes as far as - * stating that it was only included for completeness. + * Set if transport of access token using URI query parameter is supported. Defaults + * to {@code false}. * + * The spec recommends against using this mechanism for sending bearer tokens, and + * even goes as far as stating that it was only included for completeness. * @param allowUriQueryParameter if the URI query parameter is supported */ public void setAllowUriQueryParameter(boolean allowUriQueryParameter) { @@ -95,8 +96,8 @@ public class ServerBearerTokenAuthenticationConverter * Set this value to configure what header is checked when resolving a Bearer Token. * This value is defaulted to {@link HttpHeaders#AUTHORIZATION}. * - * This allows other headers to be used as the Bearer Token source such as {@link HttpHeaders#PROXY_AUTHORIZATION} - * + * This allows other headers to be used as the Bearer Token source such as + * {@link HttpHeaders#PROXY_AUTHORIZATION} * @param bearerTokenHeaderName the header to check when retrieving the Bearer Token. * @since 5.4 */ @@ -106,24 +107,23 @@ public class ServerBearerTokenAuthenticationConverter private String resolveFromAuthorizationHeader(HttpHeaders headers) { String authorization = headers.getFirst(this.bearerTokenHeaderName); - if (StringUtils.startsWithIgnoreCase(authorization, "bearer")) { - Matcher matcher = authorizationPattern.matcher(authorization); - - if (!matcher.matches() ) { - BearerTokenError error = invalidTokenError(); - throw new OAuth2AuthenticationException(error); - } - - return matcher.group("token"); + if (!StringUtils.startsWithIgnoreCase(authorization, "bearer")) { + return null; } - return null; + Matcher matcher = authorizationPattern.matcher(authorization); + if (!matcher.matches()) { + BearerTokenError error = invalidTokenError(); + throw new OAuth2AuthenticationException(error); + } + return matcher.group("token"); } private static BearerTokenError invalidTokenError() { - return invalidToken("Bearer token is malformed"); + return BearerTokenErrors.invalidToken("Bearer token is malformed"); } private boolean isParameterTokenSupportedForRequest(ServerHttpRequest request) { return this.allowUriQueryParameter && HttpMethod.GET.equals(request.getMethod()); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AuthenticatedPrincipals.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AuthenticatedPrincipals.java index 92c9437a34..11cfb3bcc7 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AuthenticatedPrincipals.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/core/TestOAuth2AuthenticatedPrincipals.java @@ -36,9 +36,14 @@ import org.springframework.security.oauth2.server.resource.introspection.OAuth2I * * @author Josh Cummings */ -public class TestOAuth2AuthenticatedPrincipals { +public final class TestOAuth2AuthenticatedPrincipals { + + private TestOAuth2AuthenticatedPrincipals() { + } + public static OAuth2AuthenticatedPrincipal active() { - return active(attributes -> {}); + return active((attributes) -> { + }); } public static OAuth2AuthenticatedPrincipal active(Consumer> attributesConsumer) { @@ -53,18 +58,18 @@ public class TestOAuth2AuthenticatedPrincipals { attributes.put(OAuth2IntrospectionClaimNames.SUBJECT, "Z5O3upPC88QrAjx00dis"); attributes.put(OAuth2IntrospectionClaimNames.USERNAME, "jdoe"); attributesConsumer.accept(attributes); - - Collection authorities = - Arrays.asList(new SimpleGrantedAuthority("SCOPE_read"), - new SimpleGrantedAuthority("SCOPE_write"), new SimpleGrantedAuthority("SCOPE_dolphin")); + Collection authorities = Arrays.asList(new SimpleGrantedAuthority("SCOPE_read"), + new SimpleGrantedAuthority("SCOPE_write"), new SimpleGrantedAuthority("SCOPE_dolphin")); return new OAuth2IntrospectionAuthenticatedPrincipal(attributes, authorities); } private static URL url(String url) { try { return new URL(url); - } catch (IOException e) { - throw new UncheckedIOException(e); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); } } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationTokenTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationTokenTests.java index ff9b361048..9669a52d77 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationTokenTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenAuthenticationTokenTests.java @@ -19,7 +19,7 @@ package org.springframework.security.oauth2.server.resource; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link BearerTokenAuthenticationToken} @@ -27,26 +27,31 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Josh Cummings */ public class BearerTokenAuthenticationTokenTests { + @Test public void constructorWhenTokenIsNullThenThrowsException() { - assertThatCode(() -> new BearerTokenAuthenticationToken(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("token cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenAuthenticationToken(null)) + .withMessageContaining("token cannot be empty"); + // @formatter:on } @Test public void constructorWhenTokenIsEmptyThenThrowsException() { - assertThatCode(() -> new BearerTokenAuthenticationToken("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("token cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenAuthenticationToken("")) + .withMessageContaining("token cannot be empty"); + // @formatter:on } @Test public void constructorWhenTokenHasValueThenConstructedCorrectly() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token"); - assertThat(token.getToken()).isEqualTo("token"); assertThat(token.getPrincipal()).isEqualTo("token"); assertThat(token.getCredentials()).isEqualTo("token"); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorTests.java index b8d78569e0..63111e86ab 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorTests.java @@ -21,7 +21,7 @@ import org.junit.Test; import org.springframework.http.HttpStatus; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link BearerTokenError} @@ -44,7 +44,6 @@ public class BearerTokenErrorTests { @Test public void constructorWithErrorCodeWhenErrorCodeIsValidThenCreated() { BearerTokenError error = new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, null, null); - assertThat(error.getErrorCode()).isEqualTo(TEST_ERROR_CODE); assertThat(error.getHttpStatus()).isEqualTo(TEST_HTTP_STATUS); assertThat(error.getDescription()).isNull(); @@ -54,27 +53,35 @@ public class BearerTokenErrorTests { @Test public void constructorWithErrorCodeAndHttpStatusWhenErrorCodeIsNullThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(null, TEST_HTTP_STATUS, null, null)) - .isInstanceOf(IllegalArgumentException.class).hasMessage("errorCode cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(null, TEST_HTTP_STATUS, null, null)) + .withMessage("errorCode cannot be empty"); + // @formatter:on } @Test public void constructorWithErrorCodeAndHttpStatusWhenErrorCodeIsEmptyThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError("", TEST_HTTP_STATUS, null, null)) - .isInstanceOf(IllegalArgumentException.class).hasMessage("errorCode cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError("", TEST_HTTP_STATUS, null, null)) + .withMessage("errorCode cannot be empty"); + // @formatter:on } @Test public void constructorWithErrorCodeAndHttpStatusWhenHttpStatusIsNullThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(TEST_ERROR_CODE, null, null, null)) - .isInstanceOf(IllegalArgumentException.class).hasMessage("httpStatus cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(TEST_ERROR_CODE, null, null, null)) + .withMessage("httpStatus cannot be null"); + // @formatter:on } @Test public void constructorWithAllParametersWhenAllParametersAreValidThenCreated() { BearerTokenError error = new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE); - assertThat(error.getErrorCode()).isEqualTo(TEST_ERROR_CODE); assertThat(error.getHttpStatus()).isEqualTo(TEST_HTTP_STATUS); assertThat(error.getDescription()).isEqualTo(TEST_DESCRIPTION); @@ -84,55 +91,77 @@ public class BearerTokenErrorTests { @Test public void constructorWithAllParametersWhenErrorCodeIsNullThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(null, TEST_HTTP_STATUS, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE)) - .isInstanceOf(IllegalArgumentException.class).hasMessage("errorCode cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(null, TEST_HTTP_STATUS, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE)) + .withMessage("errorCode cannot be empty"); + // @formatter:on } @Test public void constructorWithAllParametersWhenErrorCodeIsEmptyThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError("", TEST_HTTP_STATUS, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE)) - .isInstanceOf(IllegalArgumentException.class).hasMessage("errorCode cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError("", TEST_HTTP_STATUS, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE)) + .withMessage("errorCode cannot be empty"); + // @formatter:on } @Test public void constructorWithAllParametersWhenHttpStatusIsNullThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(TEST_ERROR_CODE, null, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE)) - .isInstanceOf(IllegalArgumentException.class).hasMessage("httpStatus cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(TEST_ERROR_CODE, null, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE)) + .withMessage("httpStatus cannot be null"); + // @formatter:on } @Test public void constructorWithAllParametersWhenErrorCodeIsInvalidThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(TEST_ERROR_CODE + "\"", TEST_HTTP_STATUS, TEST_DESCRIPTION, - TEST_URI, TEST_SCOPE)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("errorCode") - .hasMessageContaining("RFC 6750"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(TEST_ERROR_CODE + "\"", + TEST_HTTP_STATUS, TEST_DESCRIPTION, TEST_URI, TEST_SCOPE) + ) + .withMessageContaining("errorCode") + .withMessageContaining("RFC 6750"); + // @formatter:on } @Test public void constructorWithAllParametersWhenDescriptionIsInvalidThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, TEST_DESCRIPTION + "\"", - TEST_URI, TEST_SCOPE)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("description") - .hasMessageContaining("RFC 6750"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, + TEST_DESCRIPTION + "\"", TEST_URI, TEST_SCOPE) + ) + .withMessageContaining("description") + .withMessageContaining("RFC 6750"); + // @formatter:on } @Test public void constructorWithAllParametersWhenErrorUriIsInvalidThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, TEST_DESCRIPTION, - TEST_URI + "\"", TEST_SCOPE)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("errorUri") - .hasMessageContaining("RFC 6750"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, TEST_DESCRIPTION, + TEST_URI + "\"", TEST_SCOPE) + ) + .withMessageContaining("errorUri") + .withMessageContaining("RFC 6750"); + // @formatter:on } @Test public void constructorWithAllParametersWhenScopeIsInvalidThenThrowIllegalArgumentException() { - assertThatCode(() -> new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, TEST_DESCRIPTION, - TEST_URI, TEST_SCOPE + "\"")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("scope") - .hasMessageContaining("RFC 6750"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenError(TEST_ERROR_CODE, TEST_HTTP_STATUS, + TEST_DESCRIPTION, TEST_URI, TEST_SCOPE + "\"") + ) + .withMessageContaining("scope") + .withMessageContaining("RFC 6750"); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorsTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorsTests.java index c244e68ba4..e8baca0464 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorsTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/BearerTokenErrorsTests.java @@ -18,22 +18,19 @@ package org.springframework.security.oauth2.server.resource; import org.junit.Test; +import org.springframework.http.HttpStatus; + import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.HttpStatus.BAD_REQUEST; -import static org.springframework.http.HttpStatus.FORBIDDEN; -import static org.springframework.http.HttpStatus.UNAUTHORIZED; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes.INSUFFICIENT_SCOPE; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes.INVALID_REQUEST; -import static org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes.INVALID_TOKEN; public class BearerTokenErrorsTests { + @Test public void invalidRequestWhenMessageGivenThenBearerTokenErrorReturned() { String message = "message"; BearerTokenError error = BearerTokenErrors.invalidRequest(message); - assertThat(error.getErrorCode()).isSameAs(INVALID_REQUEST); + assertThat(error.getErrorCode()).isSameAs(BearerTokenErrorCodes.INVALID_REQUEST); assertThat(error.getDescription()).isSameAs(message); - assertThat(error.getHttpStatus()).isSameAs(BAD_REQUEST); + assertThat(error.getHttpStatus()).isSameAs(HttpStatus.BAD_REQUEST); assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); } @@ -41,9 +38,9 @@ public class BearerTokenErrorsTests { public void invalidRequestWhenInvalidMessageGivenThenDefaultBearerTokenErrorReturned() { String message = "has \"invalid\" chars"; BearerTokenError error = BearerTokenErrors.invalidRequest(message); - assertThat(error.getErrorCode()).isSameAs(INVALID_REQUEST); + assertThat(error.getErrorCode()).isSameAs(BearerTokenErrorCodes.INVALID_REQUEST); assertThat(error.getDescription()).isEqualTo("Invalid request"); - assertThat(error.getHttpStatus()).isSameAs(BAD_REQUEST); + assertThat(error.getHttpStatus()).isSameAs(HttpStatus.BAD_REQUEST); assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); } @@ -51,9 +48,9 @@ public class BearerTokenErrorsTests { public void invalidTokenWhenMessageGivenThenBearerTokenErrorReturned() { String message = "message"; BearerTokenError error = BearerTokenErrors.invalidToken(message); - assertThat(error.getErrorCode()).isSameAs(INVALID_TOKEN); + assertThat(error.getErrorCode()).isSameAs(BearerTokenErrorCodes.INVALID_TOKEN); assertThat(error.getDescription()).isSameAs(message); - assertThat(error.getHttpStatus()).isSameAs(UNAUTHORIZED); + assertThat(error.getHttpStatus()).isSameAs(HttpStatus.UNAUTHORIZED); assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); } @@ -61,9 +58,9 @@ public class BearerTokenErrorsTests { public void invalidTokenWhenInvalidMessageGivenThenDefaultBearerTokenErrorReturned() { String message = "has \"invalid\" chars"; BearerTokenError error = BearerTokenErrors.invalidToken(message); - assertThat(error.getErrorCode()).isSameAs(INVALID_TOKEN); + assertThat(error.getErrorCode()).isSameAs(BearerTokenErrorCodes.INVALID_TOKEN); assertThat(error.getDescription()).isEqualTo("Invalid token"); - assertThat(error.getHttpStatus()).isSameAs(UNAUTHORIZED); + assertThat(error.getHttpStatus()).isSameAs(HttpStatus.UNAUTHORIZED); assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); } @@ -72,9 +69,9 @@ public class BearerTokenErrorsTests { String message = "message"; String scope = "scope"; BearerTokenError error = BearerTokenErrors.insufficientScope(message, scope); - assertThat(error.getErrorCode()).isSameAs(INSUFFICIENT_SCOPE); + assertThat(error.getErrorCode()).isSameAs(BearerTokenErrorCodes.INSUFFICIENT_SCOPE); assertThat(error.getDescription()).isSameAs(message); - assertThat(error.getHttpStatus()).isSameAs(FORBIDDEN); + assertThat(error.getHttpStatus()).isSameAs(HttpStatus.FORBIDDEN); assertThat(error.getScope()).isSameAs(scope); assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); } @@ -83,10 +80,11 @@ public class BearerTokenErrorsTests { public void insufficientScopeWhenInvalidMessageGivenThenDefaultBearerTokenErrorReturned() { String message = "has \"invalid\" chars"; BearerTokenError error = BearerTokenErrors.insufficientScope(message, "scope"); - assertThat(error.getErrorCode()).isSameAs(INSUFFICIENT_SCOPE); + assertThat(error.getErrorCode()).isSameAs(BearerTokenErrorCodes.INSUFFICIENT_SCOPE); assertThat(error.getDescription()).isSameAs("Insufficient scope"); - assertThat(error.getHttpStatus()).isSameAs(FORBIDDEN); + assertThat(error.getHttpStatus()).isSameAs(HttpStatus.FORBIDDEN); assertThat(error.getScope()).isNull(); assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/DefaultAuthenticationEventPublisherBearerTokenTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/DefaultAuthenticationEventPublisherBearerTokenTests.java index d87c0d8acf..10fdb5fe57 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/DefaultAuthenticationEventPublisherBearerTokenTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/DefaultAuthenticationEventPublisherBearerTokenTests.java @@ -22,13 +22,13 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.security.authentication.DefaultAuthenticationEventPublisher; import org.springframework.security.authentication.event.AuthenticationFailureBadCredentialsEvent; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link DefaultAuthenticationEventPublisher}'s bearer token use cases @@ -36,17 +36,18 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * {@see DefaultAuthenticationEventPublisher} */ public class DefaultAuthenticationEventPublisherBearerTokenTests { + DefaultAuthenticationEventPublisher publisher; @Test public void publishAuthenticationFailureWhenInvalidBearerTokenExceptionThenMaps() { ApplicationEventPublisher appPublisher = mock(ApplicationEventPublisher.class); - Authentication authentication = new JwtAuthenticationToken(jwt().build()); + Authentication authentication = new JwtAuthenticationToken(TestJwts.jwt().build()); Exception cause = new Exception(); this.publisher = new DefaultAuthenticationEventPublisher(appPublisher); this.publisher.publishAuthenticationFailure(new InvalidBearerTokenException("invalid"), authentication); this.publisher.publishAuthenticationFailure(new InvalidBearerTokenException("invalid", cause), authentication); - verify(appPublisher, times(2)).publishEvent( - isA(AuthenticationFailureBadCredentialsEvent.class)); + verify(appPublisher, times(2)).publishEvent(isA(AuthenticationFailureBadCredentialsEvent.class)); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthenticationTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthenticationTests.java index 8aa7fefc7f..1f4cfdf292 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthenticationTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/BearerTokenAuthenticationTests.java @@ -33,12 +33,10 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.CLIENT_ID; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SUBJECT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.USERNAME; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link BearerTokenAuthentication} @@ -46,33 +44,39 @@ import static org.springframework.security.oauth2.server.resource.introspection. * @author Josh Cummings */ public class BearerTokenAuthenticationTests { - private final OAuth2AccessToken token = - new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", Instant.now(), Instant.now().plusSeconds(3600)); + + private final OAuth2AccessToken token = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plusSeconds(3600)); + private final String name = "sub"; + private Map attributesMap = new HashMap<>(); + private DefaultOAuth2AuthenticatedPrincipal principal; + private final Collection authorities = AuthorityUtils.createAuthorityList("USER"); @Before public void setUp() { - this.attributesMap.put(SUBJECT, this.name); - this.attributesMap.put(CLIENT_ID, "client_id"); - this.attributesMap.put(USERNAME, "username"); + this.attributesMap.put(OAuth2IntrospectionClaimNames.SUBJECT, this.name); + this.attributesMap.put(OAuth2IntrospectionClaimNames.CLIENT_ID, "client_id"); + this.attributesMap.put(OAuth2IntrospectionClaimNames.USERNAME, "username"); this.principal = new DefaultOAuth2AuthenticatedPrincipal(this.attributesMap, null); } @Test public void getNameWhenConfiguredInConstructorThenReturnsName() { - OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(this.name, this.attributesMap, this.authorities); - BearerTokenAuthentication authenticated = new BearerTokenAuthentication(principal, this.token, this.authorities); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(this.name, this.attributesMap, + this.authorities); + BearerTokenAuthentication authenticated = new BearerTokenAuthentication(principal, this.token, + this.authorities); assertThat(authenticated.getName()).isEqualTo(this.name); } @Test public void getNameWhenHasNoSubjectThenReturnsNull() { - OAuth2AuthenticatedPrincipal principal = - new DefaultOAuth2AuthenticatedPrincipal(Collections.singletonMap("claim", "value"), null); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal( + Collections.singletonMap("claim", "value"), null); BearerTokenAuthentication authenticated = new BearerTokenAuthentication(principal, this.token, null); assertThat(authenticated.getName()).isNull(); } @@ -80,43 +84,50 @@ public class BearerTokenAuthenticationTests { @Test public void getNameWhenTokenHasUsernameThenReturnsUsernameAttribute() { BearerTokenAuthentication authenticated = new BearerTokenAuthentication(this.principal, this.token, null); - assertThat(authenticated.getName()).isEqualTo(this.principal.getAttribute(SUBJECT)); + // @formatter:off + assertThat(authenticated.getName()) + .isEqualTo(this.principal.getAttribute(OAuth2IntrospectionClaimNames.SUBJECT)); + // @formatter:on } @Test public void constructorWhenTokenIsNullThenThrowsException() { - assertThatCode(() -> new BearerTokenAuthentication(this.principal, null, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("token cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenAuthentication(this.principal, null, null)) + .withMessageContaining("token cannot be null"); + // @formatter:on } @Test public void constructorWhenCredentialIsNullThenThrowsException() { - assertThatCode(() -> new BearerTokenAuthentication(null, this.token, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("principal cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenAuthentication(null, this.token, null)) + .withMessageContaining("principal cannot be null"); + // @formatter:on } @Test public void constructorWhenPassingAllAttributesThenTokenIsAuthenticated() { - OAuth2AuthenticatedPrincipal principal = - new DefaultOAuth2AuthenticatedPrincipal("harris", Collections.singletonMap("claim", "value"), null); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal("harris", + Collections.singletonMap("claim", "value"), null); BearerTokenAuthentication authenticated = new BearerTokenAuthentication(principal, this.token, null); assertThat(authenticated.isAuthenticated()).isTrue(); } @Test public void getTokenAttributesWhenHasTokenThenReturnsThem() { - BearerTokenAuthentication authenticated = - new BearerTokenAuthentication(this.principal, this.token, Collections.emptyList()); + BearerTokenAuthentication authenticated = new BearerTokenAuthentication(this.principal, this.token, + Collections.emptyList()); assertThat(authenticated.getTokenAttributes()).isEqualTo(this.principal.getAttributes()); } @Test public void getAuthoritiesWhenHasAuthoritiesThenReturnsThem() { List authorities = AuthorityUtils.createAuthorityList("USER"); - BearerTokenAuthentication authenticated = - new BearerTokenAuthentication(this.principal, this.token, authorities); + BearerTokenAuthentication authenticated = new BearerTokenAuthentication(this.principal, this.token, + authorities); assertThat(authenticated.getAuthorities()).isEqualTo(authorities); } @@ -137,7 +148,7 @@ public class BearerTokenAuthenticationTests { JSONObject attributes = new JSONObject(Collections.singletonMap("iss", new URL("https://idp.example.com"))); OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(attributes, null); BearerTokenAuthentication token = new BearerTokenAuthentication(principal, this.token, null); - assertThatCode(token::toString) - .doesNotThrowAnyException(); + token.toString(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java index 6ade846fbc..037d1070ba 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java @@ -26,10 +26,10 @@ import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link JwtAuthenticationConverter} @@ -38,17 +38,15 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * @author Evgeniy Cheban */ public class JwtAuthenticationConverterTests { + JwtAuthenticationConverter jwtAuthenticationConverter = new JwtAuthenticationConverter(); @Test public void convertWhenDefaultGrantedAuthoritiesConverterSet() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - + Jwt jwt = TestJwts.jwt().claim("scope", "message:read message:write").build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt); Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), new SimpleGrantedAuthority("SCOPE_message:write")); } @@ -61,48 +59,48 @@ public class JwtAuthenticationConverterTests { @Test public void convertWithOverriddenGrantedAuthoritiesConverter() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - - Converter> grantedAuthoritiesConverter = - token -> Arrays.asList(new SimpleGrantedAuthority("blah")); - + Jwt jwt = TestJwts.jwt().claim("scope", "message:read message:write").build(); + Converter> grantedAuthoritiesConverter = (token) -> Arrays + .asList(new SimpleGrantedAuthority("blah")); this.jwtAuthenticationConverter.setJwtGrantedAuthoritiesConverter(grantedAuthoritiesConverter); - AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt); Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("blah")); + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("blah")); } @Test public void whenSettingNullPrincipalClaimName() { + // @formatter:off assertThatIllegalArgumentException() .isThrownBy(() -> this.jwtAuthenticationConverter.setPrincipalClaimName(null)) .withMessage("principalClaimName cannot be empty"); + // @formatter:on } @Test public void whenSettingEmptyPrincipalClaimName() { + // @formatter:off assertThatIllegalArgumentException() .isThrownBy(() -> this.jwtAuthenticationConverter.setPrincipalClaimName("")) .withMessage("principalClaimName cannot be empty"); + // @formatter:on } @Test public void whenSettingBlankPrincipalClaimName() { + // @formatter:off assertThatIllegalArgumentException() .isThrownBy(() -> this.jwtAuthenticationConverter.setPrincipalClaimName(" ")) .withMessage("principalClaimName cannot be empty"); + // @formatter:on } @Test public void convertWhenPrincipalClaimNameSet() { this.jwtAuthenticationConverter.setPrincipalClaimName("user_id"); - - Jwt jwt = jwt().claim("user_id", "100").build(); + Jwt jwt = TestJwts.jwt().claim("user_id", "100").build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt); - assertThat(authentication.getName()).isEqualTo("100"); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java index 50db5d9187..9e92ce52e4 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.authentication; import java.util.function.Predicate; @@ -30,14 +31,14 @@ import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link JwtAuthenticationProvider} @@ -46,6 +47,7 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; */ @RunWith(MockitoJUnitRunner.class) public class JwtAuthenticationProviderTests { + @Mock Converter jwtAuthenticationConverter; @@ -56,60 +58,52 @@ public class JwtAuthenticationProviderTests { @Before public void setup() { - this.provider = - new JwtAuthenticationProvider(this.jwtDecoder); - this.provider.setJwtAuthenticationConverter(jwtAuthenticationConverter); + this.provider = new JwtAuthenticationProvider(this.jwtDecoder); + this.provider.setJwtAuthenticationConverter(this.jwtAuthenticationConverter); } @Test public void authenticateWhenJwtDecodesThenAuthenticationHasAttributesContainedInJwt() { BearerTokenAuthenticationToken token = this.authentication(); - - Jwt jwt = jwt().claim("name", "value").build(); - - when(this.jwtDecoder.decode("token")).thenReturn(jwt); - when(this.jwtAuthenticationConverter.convert(jwt)).thenReturn(new JwtAuthenticationToken(jwt)); - - JwtAuthenticationToken authentication = - (JwtAuthenticationToken) this.provider.authenticate(token); - + Jwt jwt = TestJwts.jwt().claim("name", "value").build(); + given(this.jwtDecoder.decode("token")).willReturn(jwt); + given(this.jwtAuthenticationConverter.convert(jwt)).willReturn(new JwtAuthenticationToken(jwt)); + JwtAuthenticationToken authentication = (JwtAuthenticationToken) this.provider.authenticate(token); assertThat(authentication.getTokenAttributes()).containsEntry("name", "value"); } @Test public void authenticateWhenJwtDecodeFailsThenRespondsWithInvalidToken() { BearerTokenAuthenticationToken token = this.authentication(); - - when(this.jwtDecoder.decode("token")).thenThrow(BadJwtException.class); - - assertThatCode(() -> this.provider.authenticate(token)) - .matches(failed -> failed instanceof OAuth2AuthenticationException) + given(this.jwtDecoder.decode("token")).willThrow(BadJwtException.class); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) .matches(errorCode(BearerTokenErrorCodes.INVALID_TOKEN)); + // @formatter:on } @Test public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() { BearerTokenAuthenticationToken token = this.authentication(); - - when(this.jwtDecoder.decode(token.getToken())).thenThrow(new BadJwtException("with \"invalid\" chars")); - - assertThatCode(() -> this.provider.authenticate(token)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasFieldOrPropertyWithValue( - "error.description", - "Invalid token"); + given(this.jwtDecoder.decode(token.getToken())).willThrow(new BadJwtException("with \"invalid\" chars")); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies((ex) -> assertThat(ex).hasFieldOrPropertyWithValue("error.description", "Invalid token")); + // @formatter:on } // gh-7785 @Test public void authenticateWhenDecoderFailsGenericallyThenThrowsGenericException() { BearerTokenAuthenticationToken token = this.authentication(); - - when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("no jwk set")); - - assertThatCode(() -> this.provider.authenticate(token)) - .isInstanceOf(AuthenticationException.class) + given(this.jwtDecoder.decode(token.getToken())).willThrow(new JwtException("no jwk set")); + // @formatter:off + assertThatExceptionOfType(AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) .isNotInstanceOf(OAuth2AuthenticationException.class); + // @formatter:on } @Test @@ -117,16 +111,15 @@ public class JwtAuthenticationProviderTests { BearerTokenAuthenticationToken token = this.authentication(); Object details = mock(Object.class); token.setDetails(details); - - Jwt jwt = jwt().build(); + Jwt jwt = TestJwts.jwt().build(); JwtAuthenticationToken authentication = new JwtAuthenticationToken(jwt); - - when(this.jwtDecoder.decode(token.getToken())).thenReturn(jwt); - when(this.jwtAuthenticationConverter.convert(jwt)).thenReturn(authentication); - + given(this.jwtDecoder.decode(token.getToken())).willReturn(jwt); + given(this.jwtAuthenticationConverter.convert(jwt)).willReturn(authentication); + // @formatter:off assertThat(this.provider.authenticate(token)) - .isEqualTo(authentication) - .hasFieldOrPropertyWithValue("details", details); + .isEqualTo(authentication).hasFieldOrPropertyWithValue("details", + details); + // @formatter:on } @Test @@ -139,7 +132,7 @@ public class JwtAuthenticationProviderTests { } private Predicate errorCode(String errorCode) { - return failed -> - ((OAuth2AuthenticationException) failed).getError().getErrorCode() == errorCode; + return (failed) -> ((OAuth2AuthenticationException) failed).getError().getErrorCode() == errorCode; } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java index f439abf7d2..dcd487b2d3 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java @@ -24,11 +24,11 @@ import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jwt.Jwt; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.oauth2.jose.jws.JwsAlgorithms.RS256; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link JwtAuthenticationToken} @@ -42,7 +42,6 @@ public class JwtAuthenticationTokenTests { public void getNameWhenJwtHasSubjectThenReturnsSubject() { Jwt jwt = builder().subject("Carl").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt); - assertThat(token.getName()).isEqualTo("Carl"); } @@ -50,15 +49,13 @@ public class JwtAuthenticationTokenTests { public void getNameWhenJwtHasNoSubjectThenReturnsNull() { Jwt jwt = builder().claim("claim", "value").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt); - assertThat(token.getName()).isNull(); } @Test public void constructorWhenJwtIsNullThenThrowsException() { - assertThatCode(() -> new JwtAuthenticationToken(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("token cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> new JwtAuthenticationToken(null)) + .withMessageContaining("token cannot be null"); } @Test @@ -66,7 +63,6 @@ public class JwtAuthenticationTokenTests { Collection authorities = AuthorityUtils.createAuthorityList("test"); Jwt jwt = builder().claim("claim", "value").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt, authorities); - assertThat(token.getAuthorities()).isEqualTo(authorities); assertThat(token.getPrincipal()).isEqualTo(jwt); assertThat(token.getCredentials()).isEqualTo(jwt); @@ -79,7 +75,6 @@ public class JwtAuthenticationTokenTests { public void constructorWhenUsingOnlyJwtThenConstructedCorrectly() { Jwt jwt = builder().claim("claim", "value").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt); - assertThat(token.getAuthorities()).isEmpty(); assertThat(token.getPrincipal()).isEqualTo(jwt); assertThat(token.getCredentials()).isEqualTo(jwt); @@ -92,7 +87,6 @@ public class JwtAuthenticationTokenTests { public void getNameWhenConstructedWithJwtThenReturnsSubject() { Jwt jwt = builder().subject("Hayden").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt); - assertThat(token.getName()).isEqualTo("Hayden"); } @@ -101,7 +95,6 @@ public class JwtAuthenticationTokenTests { Collection authorities = AuthorityUtils.createAuthorityList("test"); Jwt jwt = builder().subject("Hayden").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt, authorities); - assertThat(token.getName()).isEqualTo("Hayden"); } @@ -110,7 +103,6 @@ public class JwtAuthenticationTokenTests { Collection authorities = AuthorityUtils.createAuthorityList("test"); Jwt jwt = builder().claim("claim", "value").build(); JwtAuthenticationToken token = new JwtAuthenticationToken(jwt, authorities, "Hayden"); - assertThat(token.getName()).isEqualTo("Hayden"); } @@ -118,13 +110,13 @@ public class JwtAuthenticationTokenTests { public void getNameWhenConstructedWithNoSubjectThenReturnsNull() { Collection authorities = AuthorityUtils.createAuthorityList("test"); Jwt jwt = builder().claim("claim", "value").build(); - assertThat(new JwtAuthenticationToken(jwt, authorities, null).getName()).isNull(); assertThat(new JwtAuthenticationToken(jwt, authorities).getName()).isNull(); assertThat(new JwtAuthenticationToken(jwt).getName()).isNull(); } private Jwt.Builder builder() { - return Jwt.withTokenValue("token").header("alg", RS256); + return Jwt.withTokenValue("token").header("alg", JwsAlgorithms.RS256); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java index f96073aeb7..10d4744258 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java @@ -32,18 +32,18 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Josh Cummings */ public class JwtBearerTokenAuthenticationConverterTests { - private final JwtBearerTokenAuthenticationConverter converter = - new JwtBearerTokenAuthenticationConverter(); + + private final JwtBearerTokenAuthenticationConverter converter = new JwtBearerTokenAuthenticationConverter(); @Test public void convertWhenJwtThenBearerTokenAuthentication() { + // @formatter:off Jwt jwt = Jwt.withTokenValue("token-value") .claim("claim", "value") .header("header", "value") .build(); - + // @formatter:on AbstractAuthenticationToken token = this.converter.convert(jwt); - assertThat(token).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication bearerToken = (BearerTokenAuthentication) token; assertThat(bearerToken.getToken().getTokenValue()).isEqualTo("token-value"); @@ -53,33 +53,32 @@ public class JwtBearerTokenAuthenticationConverterTests { @Test public void convertWhenJwtWithScopeAttributeThenBearerTokenAuthentication() { + // @formatter:off Jwt jwt = Jwt.withTokenValue("token-value") .claim("scope", "message:read message:write") .header("header", "value") .build(); - + // @formatter:on AbstractAuthenticationToken token = this.converter.convert(jwt); - assertThat(token).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication bearerToken = (BearerTokenAuthentication) token; - assertThat(bearerToken.getAuthorities()) - .containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), - new SimpleGrantedAuthority("SCOPE_message:write")); + assertThat(bearerToken.getAuthorities()).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), + new SimpleGrantedAuthority("SCOPE_message:write")); } @Test public void convertWhenJwtWithScpAttributeThenBearerTokenAuthentication() { + // @formatter:off Jwt jwt = Jwt.withTokenValue("token-value") .claim("scp", Arrays.asList("message:read", "message:write")) .header("header", "value") .build(); - + // @formatter:on AbstractAuthenticationToken token = this.converter.convert(jwt); - assertThat(token).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication bearerToken = (BearerTokenAuthentication) token; - assertThat(bearerToken.getAuthorities()) - .containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), - new SimpleGrantedAuthority("SCOPE_message:write")); + assertThat(bearerToken.getAuthorities()).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), + new SimpleGrantedAuthority("SCOPE_message:write")); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverterTests.java index 70ecf618e7..b0de546d85 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtGrantedAuthoritiesConverterTests.java @@ -25,9 +25,9 @@ import org.junit.Test; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link JwtGrantedAuthoritiesConverter} @@ -45,199 +45,214 @@ public class JwtGrantedAuthoritiesConverterTests { @Test public void convertWhenTokenHasScopeAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scope", "message:read message:write") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), new SimpleGrantedAuthority("SCOPE_message:write")); } @Test public void convertWithCustomAuthorityPrefixWhenTokenHasScopeAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scope", "message:read message:write") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthorityPrefix("ROLE_"); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("ROLE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("ROLE_message:read"), new SimpleGrantedAuthority("ROLE_message:write")); } @Test public void convertWithBlankAsCustomAuthorityPrefixWhenTokenHasScopeAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scope", "message:read message:write") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthorityPrefix(""); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("message:read"), new SimpleGrantedAuthority("message:write")); } @Test public void convertWhenTokenHasEmptyScopeAttributeThenTranslatedToNoAuthorities() { - Jwt jwt = jwt().claim("scope", "").build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scope", "") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasScpAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scp", Arrays.asList("message:read", "message:write")).build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", Arrays.asList("message:read", "message:write")) + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), new SimpleGrantedAuthority("SCOPE_message:write")); } @Test public void convertWithCustomAuthorityPrefixWhenTokenHasScpAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scp", Arrays.asList("message:read", "message:write")).build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", Arrays.asList("message:read", "message:write")) + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthorityPrefix("ROLE_"); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("ROLE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("ROLE_message:read"), new SimpleGrantedAuthority("ROLE_message:write")); } @Test public void convertWithBlankAsCustomAuthorityPrefixWhenTokenHasScpAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scp", "message:read message:write").build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", "message:read message:write") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthorityPrefix(""); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("message:read"), new SimpleGrantedAuthority("message:write")); } @Test public void convertWhenTokenHasEmptyScpAttributeThenTranslatedToNoAuthorities() { - Jwt jwt = jwt().claim("scp", Collections.emptyList()).build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", Collections.emptyList()) + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasBothScopeAndScpThenScopeAttributeIsTranslatedToAuthorities() { - Jwt jwt = jwt() - .claim("scp", Arrays.asList("message:read", "message:write")) - .claim("scope", "missive:read missive:write") - .build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", Arrays.asList("message:read", "message:write")) + .claim("scope", "missive:read missive:write") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_missive:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("SCOPE_missive:read"), new SimpleGrantedAuthority("SCOPE_missive:write")); } @Test public void convertWhenTokenHasEmptyScopeAndNonEmptyScpThenScopeAttributeIsTranslatedToNoAuthorities() { - Jwt jwt = jwt() - .claim("scp", Arrays.asList("message:read", "message:write")) - .claim("scope", "") - .build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", Arrays.asList("message:read", "message:write")) + .claim("scope", "") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasEmptyScopeAndEmptyScpAttributeThenTranslatesToNoAuthorities() { - Jwt jwt = jwt() - .claim("scp", Collections.emptyList()) - .claim("scope", Collections.emptyList()) - .build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scp", Collections.emptyList()) + .claim("scope", Collections.emptyList()) + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasNoScopeAndNoScpAttributeThenTranslatesToNoAuthorities() { - Jwt jwt = jwt().claim("roles", Arrays.asList("message:read", "message:write")).build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("roles", Arrays.asList("message:read", "message:write")) + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasUnsupportedTypeForScopeThenTranslatesToNoAuthorities() { - Jwt jwt = jwt().claim("scope", new String[] {"message:read", "message:write"}).build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scope", new String[] { "message:read", "message:write" }) + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasCustomClaimNameThenCustomClaimNameAttributeIsTranslatedToAuthorities() { - Jwt jwt = jwt() + // @formatter:off + Jwt jwt = TestJwts.jwt() .claim("roles", Arrays.asList("message:read", "message:write")) .claim("scope", "missive:read missive:write") .build(); - + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthoritiesClaimName("roles"); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), new SimpleGrantedAuthority("SCOPE_message:write")); } @Test public void convertWhenTokenHasEmptyCustomClaimNameThenCustomClaimNameAttributeIsTranslatedToNoAuthorities() { - Jwt jwt = jwt() + // @formatter:off + Jwt jwt = TestJwts.jwt() .claim("roles", Collections.emptyList()) .claim("scope", "missive:read missive:write") .build(); - + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthoritiesClaimName("roles"); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } @Test public void convertWhenTokenHasNoCustomClaimNameThenCustomClaimNameAttributeIsTranslatedToNoAuthorities() { - Jwt jwt = jwt().claim("scope", "missive:read missive:write").build(); - + // @formatter:off + Jwt jwt = TestJwts.jwt() + .claim("scope", "missive:read missive:write") + .build(); + // @formatter:on JwtGrantedAuthoritiesConverter jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); jwtGrantedAuthoritiesConverter.setAuthoritiesClaimName("roles"); Collection authorities = jwtGrantedAuthoritiesConverter.convert(jwt); - assertThat(authorities).isEmpty(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java index 21df5189fe..354c551105 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerAuthenticationManagerResolverTests.java @@ -38,23 +38,25 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManagerResolver; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; /** * Tests for {@link JwtIssuerAuthenticationManagerResolver} */ public class JwtIssuerAuthenticationManagerResolverTests { - private static final String DEFAULT_RESPONSE_TEMPLATE = "{\n" - + " \"issuer\": \"%s\", \n" - + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" - + "}"; + + private static final String DEFAULT_RESPONSE_TEMPLATE = "{\n" + " \"issuer\": \"%s\", \n" + + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" + "}"; private String jwt = jwt("iss", "trusted"); + private String evil = jwt("iss", "\""); + private String noIssuer = jwt("sub", "sub"); @Test @@ -62,125 +64,130 @@ public class JwtIssuerAuthenticationManagerResolverTests { try (MockWebServer server = new MockWebServer()) { server.start(); String issuer = server.url("").toString(); - server.enqueue(new MockResponse() - .setResponseCode(200) + // @formatter:off + server.enqueue(new MockResponse().setResponseCode(200) .setHeader("Content-Type", "application/json") - .setBody(String.format(DEFAULT_RESPONSE_TEMPLATE, issuer, issuer))); + .setBody(String.format(DEFAULT_RESPONSE_TEMPLATE, issuer, issuer) + )); + // @formatter:on JWSObject jws = new JWSObject(new JWSHeader(JWSAlgorithm.RS256), - new Payload(new JSONObject(Collections.singletonMap(ISS, issuer)))); + new Payload(new JSONObject(Collections.singletonMap(JwtClaimNames.ISS, issuer)))); jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); - - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver(issuer); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + issuer); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + jws.serialize()); - - AuthenticationManager authenticationManager = - authenticationManagerResolver.resolve(request); + AuthenticationManager authenticationManager = authenticationManagerResolver.resolve(request); assertThat(authenticationManager).isNotNull(); - - AuthenticationManager cachedAuthenticationManager = - authenticationManagerResolver.resolve(request); + AuthenticationManager cachedAuthenticationManager = authenticationManagerResolver.resolve(request); assertThat(authenticationManager).isSameAs(cachedAuthenticationManager); } } @Test public void resolveWhenUsingUntrustedIssuerThenException() { - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver("other", "issuers"); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + "other", "issuers"); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + this.jwt); - - assertThatCode(() -> authenticationManagerResolver.resolve(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(request)) + .withMessageContaining("Invalid issuer"); + // @formatter:on } @Test public void resolveWhenUsingCustomIssuerAuthenticationManagerResolverThenUses() { AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver(issuer -> authenticationManager); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + (issuer) -> authenticationManager); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + this.jwt); - - assertThat(authenticationManagerResolver.resolve(request)) - .isSameAs(authenticationManager); + assertThat(authenticationManagerResolver.resolve(request)).isSameAs(authenticationManager); } @Test public void resolveWhenUsingExternalSourceThenRespondsToChanges() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + this.jwt); - Map authenticationManagers = new HashMap<>(); - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver(authenticationManagers::get); - assertThatCode(() -> authenticationManagerResolver.resolve(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Invalid issuer"); - + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + authenticationManagers::get); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(request)) + .withMessageContaining("Invalid issuer"); + // @formatter:on AuthenticationManager authenticationManager = mock(AuthenticationManager.class); authenticationManagers.put("trusted", authenticationManager); - assertThat(authenticationManagerResolver.resolve(request)) - .isSameAs(authenticationManager); - + assertThat(authenticationManagerResolver.resolve(request)).isSameAs(authenticationManager); authenticationManagers.clear(); - assertThatCode(() -> authenticationManagerResolver.resolve(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(request)) + .withMessageContaining("Invalid issuer"); + // @formatter:on } @Test public void resolveWhenBearerTokenMalformedThenException() { - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver("trusted"); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + "trusted"); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer jwt"); - assertThatCode(() -> authenticationManagerResolver.resolve(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageNotContaining("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(request)) + .withMessageNotContaining("Invalid issuer"); + // @formatter:on } @Test public void resolveWhenBearerTokenNoIssuerThenException() { - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver("trusted"); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + "trusted"); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + this.noIssuer); - assertThatCode(() -> authenticationManagerResolver.resolve(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Missing issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(request)) + .withMessageContaining("Missing issuer"); + // @formatter:on } @Test public void resolveWhenBearerTokenEvilThenGenericException() { - JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerAuthenticationManagerResolver("trusted"); + JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerAuthenticationManagerResolver( + "trusted"); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + this.evil); - assertThatCode(() -> authenticationManagerResolver.resolve(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver + .resolve(request) + ) + .withMessage("Invalid issuer"); + // @formatter:on } @Test public void constructorWhenNullOrEmptyIssuersThenException() { - assertThatCode(() -> new JwtIssuerAuthenticationManagerResolver((Collection) null)) - .isInstanceOf(IllegalArgumentException.class); - assertThatCode(() -> new JwtIssuerAuthenticationManagerResolver(Collections.emptyList())) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtIssuerAuthenticationManagerResolver((Collection) null)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtIssuerAuthenticationManagerResolver(Collections.emptyList())); } @Test public void constructorWhenNullAuthenticationManagerResolverThenException() { - assertThatCode(() -> new JwtIssuerAuthenticationManagerResolver((AuthenticationManagerResolver) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtIssuerAuthenticationManagerResolver((AuthenticationManagerResolver) null)); } private String jwt(String claim, String value) { PlainJWT jwt = new PlainJWT(new JWTClaimsSet.Builder().claim(claim, value).build()); return jwt.serialize(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java index d694707096..a794aa1fde 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtIssuerReactiveAuthenticationManagerResolverTests.java @@ -40,137 +40,139 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.ISS; /** * Tests for {@link JwtIssuerReactiveAuthenticationManagerResolver} */ public class JwtIssuerReactiveAuthenticationManagerResolverTests { + + // @formatter:off private static final String DEFAULT_RESPONSE_TEMPLATE = "{\n" + " \"issuer\": \"%s\", \n" + " \"jwks_uri\": \"%s/.well-known/jwks.json\" \n" + "}"; + // @formatter:on private String jwt = jwt("iss", "trusted"); + private String evil = jwt("iss", "\""); + private String noIssuer = jwt("sub", "sub"); @Test public void resolveWhenUsingTrustedIssuerThenReturnsAuthenticationManager() throws Exception { try (MockWebServer server = new MockWebServer()) { String issuer = server.url("").toString(); - server.enqueue(new MockResponse() - .setResponseCode(200) - .setHeader("Content-Type", "application/json") + server.enqueue(new MockResponse().setResponseCode(200).setHeader("Content-Type", "application/json") .setBody(String.format(DEFAULT_RESPONSE_TEMPLATE, issuer, issuer))); JWSObject jws = new JWSObject(new JWSHeader(JWSAlgorithm.RS256), - new Payload(new JSONObject(Collections.singletonMap(ISS, issuer)))); + new Payload(new JSONObject(Collections.singletonMap(JwtClaimNames.ISS, issuer)))); jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY)); - - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver(issuer); + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + issuer); MockServerWebExchange exchange = withBearerToken(jws.serialize()); - - ReactiveAuthenticationManager authenticationManager = - authenticationManagerResolver.resolve(exchange).block(); + ReactiveAuthenticationManager authenticationManager = authenticationManagerResolver.resolve(exchange) + .block(); assertThat(authenticationManager).isNotNull(); - - ReactiveAuthenticationManager cachedAuthenticationManager = - authenticationManagerResolver.resolve(exchange).block(); + ReactiveAuthenticationManager cachedAuthenticationManager = authenticationManagerResolver.resolve(exchange) + .block(); assertThat(authenticationManager).isSameAs(cachedAuthenticationManager); } } @Test public void resolveWhenUsingUntrustedIssuerThenException() { - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver("other", "issuers"); + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + "other", "issuers"); MockServerWebExchange exchange = withBearerToken(this.jwt); - - assertThatCode(() -> authenticationManagerResolver.resolve(exchange).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .withMessageContaining("Invalid issuer"); + // @formatter:on } @Test public void resolveWhenUsingCustomIssuerAuthenticationManagerResolverThenUses() { ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class); - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver(issuer -> Mono.just(authenticationManager)); + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + (issuer) -> Mono.just(authenticationManager)); MockServerWebExchange exchange = withBearerToken(this.jwt); - - assertThat(authenticationManagerResolver.resolve(exchange).block()) - .isSameAs(authenticationManager); + assertThat(authenticationManagerResolver.resolve(exchange).block()).isSameAs(authenticationManager); } @Test public void resolveWhenUsingExternalSourceThenRespondsToChanges() { MockServerWebExchange exchange = withBearerToken(this.jwt); - Map authenticationManagers = new HashMap<>(); - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver(issuer -> Mono.justOrEmpty(authenticationManagers.get(issuer))); - assertThatCode(() -> authenticationManagerResolver.resolve(exchange).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Invalid issuer"); - + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + (issuer) -> Mono.justOrEmpty(authenticationManagers.get(issuer))); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .withMessageContaining("Invalid issuer"); ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class); authenticationManagers.put("trusted", authenticationManager); - assertThat(authenticationManagerResolver.resolve(exchange).block()) - .isSameAs(authenticationManager); - + assertThat(authenticationManagerResolver.resolve(exchange).block()).isSameAs(authenticationManager); authenticationManagers.clear(); - assertThatCode(() -> authenticationManagerResolver.resolve(exchange).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .withMessageContaining("Invalid issuer"); + // @formatter:on } @Test public void resolveWhenBearerTokenMalformedThenException() { - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver("trusted"); + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + "trusted"); MockServerWebExchange exchange = withBearerToken("jwt"); - assertThatCode(() -> authenticationManagerResolver.resolve(exchange).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageNotContaining("Invalid issuer"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .withMessageNotContaining("Invalid issuer"); + // @formatter:on } @Test public void resolveWhenBearerTokenNoIssuerThenException() { - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver("trusted"); + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + "trusted"); MockServerWebExchange exchange = withBearerToken(this.noIssuer); - assertThatCode(() -> authenticationManagerResolver.resolve(exchange).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Missing issuer"); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .withMessageContaining("Missing issuer"); } @Test public void resolveWhenBearerTokenEvilThenGenericException() { - JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = - new JwtIssuerReactiveAuthenticationManagerResolver("trusted"); + JwtIssuerReactiveAuthenticationManagerResolver authenticationManagerResolver = new JwtIssuerReactiveAuthenticationManagerResolver( + "trusted"); MockServerWebExchange exchange = withBearerToken(this.evil); - assertThatCode(() -> authenticationManagerResolver.resolve(exchange).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessage("Invalid token"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> authenticationManagerResolver.resolve(exchange).block()) + .withMessage("Invalid token"); + // @formatter:on } @Test public void constructorWhenNullOrEmptyIssuersThenException() { - assertThatCode(() -> new JwtIssuerReactiveAuthenticationManagerResolver((Collection) null)) - .isInstanceOf(IllegalArgumentException.class); - assertThatCode(() -> new JwtIssuerReactiveAuthenticationManagerResolver(Collections.emptyList())) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtIssuerReactiveAuthenticationManagerResolver((Collection) null)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtIssuerReactiveAuthenticationManagerResolver(Collections.emptyList())); } @Test public void constructorWhenNullAuthenticationManagerResolverThenException() { - assertThatCode(() -> new JwtIssuerReactiveAuthenticationManagerResolver((ReactiveAuthenticationManagerResolver) null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy( + () -> new JwtIssuerReactiveAuthenticationManagerResolver((ReactiveAuthenticationManagerResolver) null)); } private String jwt(String claim, String value) { @@ -179,8 +181,12 @@ public class JwtIssuerReactiveAuthenticationManagerResolverTests { } private MockServerWebExchange withBearerToken(String token) { + // @formatter:off MockServerHttpRequest request = MockServerHttpRequest.get("/") - .header("Authorization", "Bearer " + token).build(); + .header("Authorization", "Bearer " + token) + .build(); + // @formatter:on return MockServerWebExchange.from(request); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java index 8317aaba58..989044a046 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java @@ -32,13 +32,14 @@ import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch @@ -46,6 +47,7 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; */ @RunWith(MockitoJUnitRunner.class) public class JwtReactiveAuthenticationManagerTests { + @Mock private ReactiveJwtDecoder jwtDecoder; @@ -56,82 +58,91 @@ public class JwtReactiveAuthenticationManagerTests { @Before public void setup() { this.manager = new JwtReactiveAuthenticationManager(this.jwtDecoder); - this.jwt = jwt().claim("scope", "message:read message:write").build(); + // @formatter:off + this.jwt = TestJwts.jwt() + .claim("scope", "message:read message:write") + .build(); + // @formatter:on } @Test public void constructorWhenJwtDecoderNullThenIllegalArgumentException() { this.jwtDecoder = null; - assertThatCode(() -> new JwtReactiveAuthenticationManager(this.jwtDecoder)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JwtReactiveAuthenticationManager(this.jwtDecoder)); + // @formatter:on } @Test public void authenticateWhenWrongTypeThenEmpty() { TestingAuthenticationToken token = new TestingAuthenticationToken("foo", "bar"); - assertThat(this.manager.authenticate(token).block()).isNull(); } @Test public void authenticateWhenEmptyJwtThenEmpty() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(token.getToken())).thenReturn(Mono.empty()); - + given(this.jwtDecoder.decode(token.getToken())).willReturn(Mono.empty()); assertThat(this.manager.authenticate(token).block()).isNull(); } @Test public void authenticateWhenJwtExceptionThenOAuth2AuthenticationException() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.error(new BadJwtException("Oops"))); - - assertThatCode(() -> this.manager.authenticate(token).block()) - .isInstanceOf(OAuth2AuthenticationException.class); + given(this.jwtDecoder.decode(any())).willReturn(Mono.error(new BadJwtException("Oops"))); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(token).block()); } // gh-7549 @Test public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(token.getToken())).thenThrow(new BadJwtException("with \"invalid\" chars")); - - assertThatCode(() -> this.manager.authenticate(token).block()) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasFieldOrPropertyWithValue( - "error.description", - "Invalid token"); + given(this.jwtDecoder.decode(token.getToken())).willThrow(new BadJwtException("with \"invalid\" chars")); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(token).block()) + .satisfies((ex) -> assertThat(ex) + .hasFieldOrPropertyWithValue("error.description", "Invalid token") + ); + // @formatter:on } // gh-7785 @Test public void authenticateWhenDecoderFailsGenericallyThenThrowsGenericException() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("no jwk set")); - - assertThatCode(() -> this.manager.authenticate(token).block()) - .isInstanceOf(AuthenticationException.class) + given(this.jwtDecoder.decode(token.getToken())).willThrow(new JwtException("no jwk set")); + // @formatter:off + assertThatExceptionOfType(AuthenticationException.class) + .isThrownBy(() -> this.manager.authenticate(token).block()) .isNotInstanceOf(OAuth2AuthenticationException.class); + // @formatter:on } @Test public void authenticateWhenNotJwtExceptionThenPropagates() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.error(new RuntimeException("Oops"))); - - assertThatCode(() -> this.manager.authenticate(token).block()) - .isInstanceOf(RuntimeException.class); + given(this.jwtDecoder.decode(any())).willReturn(Mono.error(new RuntimeException("Oops"))); + // @formatter:off + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> this.manager.authenticate(token).block()); + // @formatter:on } @Test public void authenticateWhenJwtThenSuccess() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(token.getToken())).thenReturn(Mono.just(this.jwt)); - + given(this.jwtDecoder.decode(token.getToken())).willReturn(Mono.just(this.jwt)); Authentication authentication = this.manager.authenticate(token).block(); - assertThat(authentication).isNotNull(); assertThat(authentication.isAuthenticated()).isTrue(); - assertThat(authentication.getAuthorities()).extracting(GrantedAuthority::getAuthority).containsOnly("SCOPE_message:read", "SCOPE_message:write"); + // @formatter:off + assertThat(authentication.getAuthorities()) + .extracting(GrantedAuthority::getAuthority) + .containsOnly("SCOPE_message:read", "SCOPE_message:write"); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProviderTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProviderTests.java index 61496ff77d..9bf34998dd 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProviderTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenAuthenticationProviderTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.authentication; import java.net.URL; @@ -26,6 +27,7 @@ import org.junit.Test; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionAuthenticatedPrincipal; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames; @@ -33,19 +35,11 @@ import org.springframework.security.oauth2.server.resource.introspection.OAuth2I import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals.active; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ACTIVE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.AUDIENCE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUER; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.NOT_BEFORE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SCOPE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SUBJECT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.USERNAME; /** * Tests for {@link OpaqueTokenAuthenticationProvider} @@ -53,68 +47,71 @@ import static org.springframework.security.oauth2.server.resource.introspection. * @author Josh Cummings */ public class OpaqueTokenAuthenticationProviderTests { + @Test public void authenticateWhenActiveTokenThenOk() throws Exception { - OAuth2AuthenticatedPrincipal principal = active(attributes -> attributes.put("extension_field", "twenty-seven")); + OAuth2AuthenticatedPrincipal principal = TestOAuth2AuthenticatedPrincipals + .active((attributes) -> attributes.put("extension_field", "twenty-seven")); OpaqueTokenIntrospector introspector = mock(OpaqueTokenIntrospector.class); - when(introspector.introspect(any())).thenReturn(principal); + given(introspector.introspect(any())).willReturn(principal); OpaqueTokenAuthenticationProvider provider = new OpaqueTokenAuthenticationProvider(introspector); - - Authentication result = - provider.authenticate(new BearerTokenAuthenticationToken("token")); - + Authentication result = provider.authenticate(new BearerTokenAuthenticationToken("token")); assertThat(result.getPrincipal()).isInstanceOf(OAuth2IntrospectionAuthenticatedPrincipal.class); - Map attributes = ((OAuth2AuthenticatedPrincipal) result.getPrincipal()).getAttributes(); + // @formatter:off assertThat(attributes) .isNotNull() - .containsEntry(ACTIVE, true) - .containsEntry(AUDIENCE, Arrays.asList("https://protected.example.net/resource")) + .containsEntry(OAuth2IntrospectionClaimNames.ACTIVE, true) + .containsEntry(OAuth2IntrospectionClaimNames.AUDIENCE, + Arrays.asList("https://protected.example.net/resource")) .containsEntry(OAuth2IntrospectionClaimNames.CLIENT_ID, "l238j323ds-23ij4") - .containsEntry(EXPIRES_AT, Instant.ofEpochSecond(1419356238)) - .containsEntry(ISSUER, new URL("https://server.example.com/")) - .containsEntry(NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) - .containsEntry(SCOPE, Arrays.asList("read", "write", "dolphin")) - .containsEntry(SUBJECT, "Z5O3upPC88QrAjx00dis") - .containsEntry(USERNAME, "jdoe") + .containsEntry(OAuth2IntrospectionClaimNames.EXPIRES_AT, Instant.ofEpochSecond(1419356238)) + .containsEntry(OAuth2IntrospectionClaimNames.ISSUER, new URL("https://server.example.com/")) + .containsEntry(OAuth2IntrospectionClaimNames.NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) + .containsEntry(OAuth2IntrospectionClaimNames.SCOPE, Arrays.asList("read", "write", "dolphin")) + .containsEntry(OAuth2IntrospectionClaimNames.SUBJECT, "Z5O3upPC88QrAjx00dis") + .containsEntry(OAuth2IntrospectionClaimNames.USERNAME, "jdoe") .containsEntry("extension_field", "twenty-seven"); - - assertThat(result.getAuthorities()).extracting("authority") - .containsExactly("SCOPE_read", "SCOPE_write", "SCOPE_dolphin"); + assertThat(result.getAuthorities()) + .extracting("authority") + .containsExactly("SCOPE_read", "SCOPE_write", + "SCOPE_dolphin"); + // @formatter:on } @Test public void authenticateWhenMissingScopeAttributeThenNoAuthorities() { - OAuth2AuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal(Collections.singletonMap("claim", "value"), null); + OAuth2AuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal( + Collections.singletonMap("claim", "value"), null); OpaqueTokenIntrospector introspector = mock(OpaqueTokenIntrospector.class); - when(introspector.introspect(any())).thenReturn(principal); + given(introspector.introspect(any())).willReturn(principal); OpaqueTokenAuthenticationProvider provider = new OpaqueTokenAuthenticationProvider(introspector); - - Authentication result = - provider.authenticate(new BearerTokenAuthenticationToken("token")); + Authentication result = provider.authenticate(new BearerTokenAuthenticationToken("token")); assertThat(result.getPrincipal()).isInstanceOf(OAuth2AuthenticatedPrincipal.class); - Map attributes = ((OAuth2AuthenticatedPrincipal) result.getPrincipal()).getAttributes(); + // @formatter:off assertThat(attributes) .isNotNull() - .doesNotContainKey(SCOPE); - + .doesNotContainKey(OAuth2IntrospectionClaimNames.SCOPE); + // @formatter:on assertThat(result.getAuthorities()).isEmpty(); } @Test public void authenticateWhenIntrospectionEndpointThrowsExceptionThenInvalidToken() { OpaqueTokenIntrospector introspector = mock(OpaqueTokenIntrospector.class); - when(introspector.introspect(any())).thenThrow(new OAuth2IntrospectionException("with \"invalid\" chars")); + given(introspector.introspect(any())).willThrow(new OAuth2IntrospectionException("with \"invalid\" chars")); OpaqueTokenAuthenticationProvider provider = new OpaqueTokenAuthenticationProvider(introspector); - - assertThatCode(() -> provider.authenticate(new BearerTokenAuthenticationToken("token"))) - .isInstanceOf(AuthenticationServiceException.class); + assertThatExceptionOfType(AuthenticationServiceException.class) + .isThrownBy(() -> provider.authenticate(new BearerTokenAuthenticationToken("token"))); } @Test public void constructorWhenIntrospectionClientIsNullThenIllegalArgumentException() { - assertThatCode(() -> new OpaqueTokenAuthenticationProvider(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OpaqueTokenAuthenticationProvider(null)); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManagerTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManagerTests.java index e7b92f8a45..2dafb8f241 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/OpaqueTokenReactiveAuthenticationManagerTests.java @@ -28,6 +28,7 @@ import reactor.core.publisher.Mono; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionAuthenticatedPrincipal; import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames; @@ -35,19 +36,11 @@ import org.springframework.security.oauth2.server.resource.introspection.OAuth2I import org.springframework.security.oauth2.server.resource.introspection.ReactiveOpaqueTokenIntrospector; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals.active; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ACTIVE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.AUDIENCE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUER; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.NOT_BEFORE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SCOPE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SUBJECT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.USERNAME; /** * Tests for {@link OpaqueTokenReactiveAuthenticationManager} @@ -55,72 +48,68 @@ import static org.springframework.security.oauth2.server.resource.introspection. * @author Josh Cummings */ public class OpaqueTokenReactiveAuthenticationManagerTests { + @Test public void authenticateWhenActiveTokenThenOk() throws Exception { - OAuth2AuthenticatedPrincipal authority = active(attributes -> attributes.put("extension_field", "twenty-seven")); + OAuth2AuthenticatedPrincipal authority = TestOAuth2AuthenticatedPrincipals + .active((attributes) -> attributes.put("extension_field", "twenty-seven")); ReactiveOpaqueTokenIntrospector introspector = mock(ReactiveOpaqueTokenIntrospector.class); - when(introspector.introspect(any())).thenReturn(Mono.just(authority)); - OpaqueTokenReactiveAuthenticationManager provider = - new OpaqueTokenReactiveAuthenticationManager(introspector); - - Authentication result = - provider.authenticate(new BearerTokenAuthenticationToken("token")).block(); - + given(introspector.introspect(any())).willReturn(Mono.just(authority)); + OpaqueTokenReactiveAuthenticationManager provider = new OpaqueTokenReactiveAuthenticationManager(introspector); + Authentication result = provider.authenticate(new BearerTokenAuthenticationToken("token")).block(); assertThat(result.getPrincipal()).isInstanceOf(OAuth2IntrospectionAuthenticatedPrincipal.class); - Map attributes = ((OAuth2AuthenticatedPrincipal) result.getPrincipal()).getAttributes(); + // @formatter:off assertThat(attributes) .isNotNull() - .containsEntry(ACTIVE, true) - .containsEntry(AUDIENCE, Arrays.asList("https://protected.example.net/resource")) + .containsEntry(OAuth2IntrospectionClaimNames.ACTIVE, true) + .containsEntry(OAuth2IntrospectionClaimNames.AUDIENCE, + Arrays.asList("https://protected.example.net/resource")) .containsEntry(OAuth2IntrospectionClaimNames.CLIENT_ID, "l238j323ds-23ij4") - .containsEntry(EXPIRES_AT, Instant.ofEpochSecond(1419356238)) - .containsEntry(ISSUER, new URL("https://server.example.com/")) - .containsEntry(NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) - .containsEntry(SCOPE, Arrays.asList("read", "write", "dolphin")) - .containsEntry(SUBJECT, "Z5O3upPC88QrAjx00dis") - .containsEntry(USERNAME, "jdoe") + .containsEntry(OAuth2IntrospectionClaimNames.EXPIRES_AT, Instant.ofEpochSecond(1419356238)) + .containsEntry(OAuth2IntrospectionClaimNames.ISSUER, new URL("https://server.example.com/")) + .containsEntry(OAuth2IntrospectionClaimNames.NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) + .containsEntry(OAuth2IntrospectionClaimNames.SCOPE, Arrays.asList("read", "write", "dolphin")) + .containsEntry(OAuth2IntrospectionClaimNames.SUBJECT, "Z5O3upPC88QrAjx00dis") + .containsEntry(OAuth2IntrospectionClaimNames.USERNAME, "jdoe") .containsEntry("extension_field", "twenty-seven"); - - assertThat(result.getAuthorities()).extracting("authority") - .containsExactly("SCOPE_read", "SCOPE_write", "SCOPE_dolphin"); + assertThat(result.getAuthorities()) + .extracting("authority") + .containsExactly("SCOPE_read", "SCOPE_write", + "SCOPE_dolphin"); + // @formatter:on } @Test public void authenticateWhenMissingScopeAttributeThenNoAuthorities() { - OAuth2AuthenticatedPrincipal authority = new OAuth2IntrospectionAuthenticatedPrincipal(Collections.singletonMap("claim", "value"), null); + OAuth2AuthenticatedPrincipal authority = new OAuth2IntrospectionAuthenticatedPrincipal( + Collections.singletonMap("claim", "value"), null); ReactiveOpaqueTokenIntrospector introspector = mock(ReactiveOpaqueTokenIntrospector.class); - when(introspector.introspect(any())).thenReturn(Mono.just(authority)); - OpaqueTokenReactiveAuthenticationManager provider = - new OpaqueTokenReactiveAuthenticationManager(introspector); - - Authentication result = - provider.authenticate(new BearerTokenAuthenticationToken("token")).block(); + given(introspector.introspect(any())).willReturn(Mono.just(authority)); + OpaqueTokenReactiveAuthenticationManager provider = new OpaqueTokenReactiveAuthenticationManager(introspector); + Authentication result = provider.authenticate(new BearerTokenAuthenticationToken("token")).block(); assertThat(result.getPrincipal()).isInstanceOf(OAuth2IntrospectionAuthenticatedPrincipal.class); - Map attributes = ((OAuth2AuthenticatedPrincipal) result.getPrincipal()).getAttributes(); - assertThat(attributes) - .isNotNull() - .doesNotContainKey(SCOPE); - + assertThat(attributes).isNotNull().doesNotContainKey(OAuth2IntrospectionClaimNames.SCOPE); assertThat(result.getAuthorities()).isEmpty(); } @Test public void authenticateWhenIntrospectionEndpointThrowsExceptionThenInvalidToken() { ReactiveOpaqueTokenIntrospector introspector = mock(ReactiveOpaqueTokenIntrospector.class); - when(introspector.introspect(any())) - .thenReturn(Mono.error(new OAuth2IntrospectionException("with \"invalid\" chars"))); - OpaqueTokenReactiveAuthenticationManager provider = - new OpaqueTokenReactiveAuthenticationManager(introspector); - - assertThatCode(() -> provider.authenticate(new BearerTokenAuthenticationToken("token")).block()) - .isInstanceOf(AuthenticationServiceException.class); + given(introspector.introspect(any())) + .willReturn(Mono.error(new OAuth2IntrospectionException("with \"invalid\" chars"))); + OpaqueTokenReactiveAuthenticationManager provider = new OpaqueTokenReactiveAuthenticationManager(introspector); + assertThatExceptionOfType(AuthenticationServiceException.class) + .isThrownBy(() -> provider.authenticate(new BearerTokenAuthenticationToken("token")).block()); } @Test public void constructorWhenIntrospectionClientIsNullThenIllegalArgumentException() { - assertThatCode(() -> new OpaqueTokenReactiveAuthenticationManager(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OpaqueTokenReactiveAuthenticationManager(null)); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapterTests.java index 6e076f0677..292a49812f 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterAdapterTests.java @@ -26,9 +26,9 @@ import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link ReactiveJwtAuthenticationConverterAdapter} @@ -36,84 +36,76 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * @author Josh Cummings */ public class ReactiveJwtAuthenticationConverterAdapterTests { + Converter converter = new JwtAuthenticationConverter(); - ReactiveJwtAuthenticationConverterAdapter jwtAuthenticationConverter = - new ReactiveJwtAuthenticationConverterAdapter(converter); + + ReactiveJwtAuthenticationConverterAdapter jwtAuthenticationConverter = new ReactiveJwtAuthenticationConverterAdapter( + this.converter); @Test public void convertWhenTokenHasScopeAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - + Jwt jwt = TestJwts.jwt().claim("scope", "message:read message:write").build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), - new SimpleGrantedAuthority("SCOPE_message:write")); + // @formatter:off + assertThat(authorities) + .containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), + new SimpleGrantedAuthority("SCOPE_message:write")); + // @formatter:on } @Test public void convertWhenTokenHasEmptyScopeAttributeThenTranslatedToNoAuthorities() { - Jwt jwt = jwt().claim("scope", "").build(); - + Jwt jwt = TestJwts.jwt().claim("scope", "").build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); - Collection authorities = authentication.getAuthorities(); - assertThat(authorities).containsExactly(); } @Test public void convertWhenTokenHasScpAttributeThenTranslatedToAuthorities() { - Jwt jwt = jwt().claim("scp", Arrays.asList("message:read", "message:write")).build(); - + Jwt jwt = TestJwts.jwt().claim("scp", Arrays.asList("message:read", "message:write")).build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); - Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), - new SimpleGrantedAuthority("SCOPE_message:write")); + // @formatter:off + assertThat(authorities) + .containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), + new SimpleGrantedAuthority("SCOPE_message:write")); + // @formatter:on } @Test public void convertWhenTokenHasEmptyScpAttributeThenTranslatedToNoAuthorities() { - Jwt jwt = jwt().claim("scp", Arrays.asList()).build(); - + Jwt jwt = TestJwts.jwt().claim("scp", Arrays.asList()).build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); - Collection authorities = authentication.getAuthorities(); - assertThat(authorities).containsExactly(); } @Test public void convertWhenTokenHasBothScopeAndScpThenScopeAttributeIsTranslatedToAuthorities() { - Jwt jwt = jwt() - .claim("scp", Arrays.asList("message:read", "message:write")) - .claim("scope", "missive:read missive:write") - .build(); - + Jwt jwt = TestJwts.jwt().claim("scp", Arrays.asList("message:read", "message:write")) + .claim("scope", "missive:read missive:write").build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); - Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_missive:read"), - new SimpleGrantedAuthority("SCOPE_missive:write")); + // @formatter:off + assertThat(authorities) + .containsExactly(new SimpleGrantedAuthority("SCOPE_missive:read"), + new SimpleGrantedAuthority("SCOPE_missive:write")); + // @formatter:on } @Test public void convertWhenTokenHasEmptyScopeAndNonEmptyScpThenScopeAttributeIsTranslatedToNoAuthorities() { - Jwt jwt = jwt() + // @formatter:off + Jwt jwt = TestJwts.jwt() .claim("scp", Arrays.asList("message:read", "message:write")) .claim("scope", "") .build(); - + // @formatter:on AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); - Collection authorities = authentication.getAuthorities(); - assertThat(authorities).containsExactly(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterTests.java index 809c76be1d..7022bb254c 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtAuthenticationConverterTests.java @@ -26,10 +26,10 @@ import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link ReactiveJwtAuthenticationConverter} @@ -38,17 +38,15 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * @since 5.2 */ public class ReactiveJwtAuthenticationConverterTests { + ReactiveJwtAuthenticationConverter jwtAuthenticationConverter = new ReactiveJwtAuthenticationConverter(); @Test public void convertWhenDefaultGrantedAuthoritiesConverterSet() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - + Jwt jwt = TestJwts.jwt().claim("scope", "message:read message:write").build(); AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("SCOPE_message:read"), + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("SCOPE_message:read"), new SimpleGrantedAuthority("SCOPE_message:write")); } @@ -61,17 +59,13 @@ public class ReactiveJwtAuthenticationConverterTests { @Test public void convertWithOverriddenGrantedAuthoritiesConverter() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - - Converter> grantedAuthoritiesConverter = - token -> Flux.just(new SimpleGrantedAuthority("blah")); - + Jwt jwt = TestJwts.jwt().claim("scope", "message:read message:write").build(); + Converter> grantedAuthoritiesConverter = (token) -> Flux + .just(new SimpleGrantedAuthority("blah")); this.jwtAuthenticationConverter.setJwtGrantedAuthoritiesConverter(grantedAuthoritiesConverter); - AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt).block(); Collection authorities = authentication.getAuthorities(); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("blah")); + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("blah")); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapterTests.java index e7ff494522..9c099c050b 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/ReactiveJwtGrantedAuthoritiesConverterAdapterTests.java @@ -26,10 +26,10 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** * Tests for {@link ReactiveJwtGrantedAuthoritiesConverterAdapter} @@ -38,27 +38,24 @@ import static org.springframework.security.oauth2.jwt.TestJwts.jwt; * @since 5.2 */ public class ReactiveJwtGrantedAuthoritiesConverterAdapterTests { + @Test public void convertWithGrantedAuthoritiesConverter() { - Jwt jwt = jwt().claim("scope", "message:read message:write").build(); - - Converter> grantedAuthoritiesConverter = - token -> Arrays.asList(new SimpleGrantedAuthority("blah")); - - Collection authorities = - new ReactiveJwtGrantedAuthoritiesConverterAdapter(grantedAuthoritiesConverter) - .convert(jwt) - .toStream() - .collect(Collectors.toList()); - - assertThat(authorities).containsExactly( - new SimpleGrantedAuthority("blah")); + Jwt jwt = TestJwts.jwt().claim("scope", "message:read message:write").build(); + Converter> grantedAuthoritiesConverter = (token) -> Arrays + .asList(new SimpleGrantedAuthority("blah")); + Collection authorities = new ReactiveJwtGrantedAuthoritiesConverterAdapter( + grantedAuthoritiesConverter).convert(jwt).toStream().collect(Collectors.toList()); + assertThat(authorities).containsExactly(new SimpleGrantedAuthority("blah")); } @Test public void whenConstructingWithInvalidConverter() { + // @formatter:off assertThatIllegalArgumentException() .isThrownBy(() -> new ReactiveJwtGrantedAuthoritiesConverterAdapter(null)) .withMessage("grantedAuthoritiesConverter cannot be null"); + // @formatter:on } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/TestBearerTokenAuthentications.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/TestBearerTokenAuthentications.java index 72a2584005..2ef4b005a6 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/TestBearerTokenAuthentications.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/TestBearerTokenAuthentications.java @@ -33,19 +33,18 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; * * @author Josh Cummings */ -public class TestBearerTokenAuthentications { - public static BearerTokenAuthentication bearer() { - Collection authorities = - AuthorityUtils.createAuthorityList("SCOPE_USER"); - OAuth2AuthenticatedPrincipal principal = - new DefaultOAuth2AuthenticatedPrincipal( - Collections.singletonMap("sub", "user"), - authorities); - OAuth2AccessToken token = - new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", Instant.now(), Instant.now().plusSeconds(86400), - new HashSet<>(Arrays.asList("USER"))); +public final class TestBearerTokenAuthentications { + private TestBearerTokenAuthentications() { + } + + public static BearerTokenAuthentication bearer() { + Collection authorities = AuthorityUtils.createAuthorityList("SCOPE_USER"); + OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal( + Collections.singletonMap("sub", "user"), authorities); + OAuth2AccessToken token = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), + Instant.now().plusSeconds(86400), new HashSet<>(Arrays.asList("USER"))); return new BearerTokenAuthentication(principal, token, authorities); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospectorTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospectorTests.java index 8b49fb98e4..49ca325058 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospectorTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusOpaqueTokenIntrospectorTests.java @@ -43,20 +43,13 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.AUDIENCE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUER; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.NOT_BEFORE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SCOPE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SUBJECT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.USERNAME; /** * Tests for {@link NimbusOpaqueTokenIntrospector} @@ -64,237 +57,237 @@ import static org.springframework.security.oauth2.server.resource.introspection. public class NimbusOpaqueTokenIntrospectorTests { private static final String INTROSPECTION_URL = "https://server.example.com"; + private static final String CLIENT_ID = "client"; + private static final String CLIENT_SECRET = "secret"; - private static final String ACTIVE_RESPONSE = "{\n" + - " \"active\": true,\n" + - " \"client_id\": \"l238j323ds-23ij4\",\n" + - " \"username\": \"jdoe\",\n" + - " \"scope\": \"read write dolphin\",\n" + - " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + - " \"aud\": \"https://protected.example.net/resource\",\n" + - " \"iss\": \"https://server.example.com/\",\n" + - " \"exp\": 1419356238,\n" + - " \"iat\": 1419350238,\n" + - " \"extension_field\": \"twenty-seven\"\n" + - " }"; + // @formatter:off + private static final String ACTIVE_RESPONSE = "{\n" + + " \"active\": true,\n" + + " \"client_id\": \"l238j323ds-23ij4\",\n" + + " \"username\": \"jdoe\",\n" + + " \"scope\": \"read write dolphin\",\n" + + " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + + " \"aud\": \"https://protected.example.net/resource\",\n" + + " \"iss\": \"https://server.example.com/\",\n" + + " \"exp\": 1419356238,\n" + + " \"iat\": 1419350238,\n" + + " \"extension_field\": \"twenty-seven\"\n" + + " }"; + // @formatter:on - private static final String INACTIVE_RESPONSE = "{\n" + - " \"active\": false\n" + - " }"; + // @formatter:off + private static final String INACTIVE_RESPONSE = "{\n" + + " \"active\": false\n" + + " }"; + // @formatter:on - private static final String INVALID_RESPONSE = "{\n" + - " \"client_id\": \"l238j323ds-23ij4\",\n" + - " \"username\": \"jdoe\",\n" + - " \"scope\": \"read write dolphin\",\n" + - " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + - " \"aud\": \"https://protected.example.net/resource\",\n" + - " \"iss\": \"https://server.example.com/\",\n" + - " \"exp\": 1419356238,\n" + - " \"iat\": 1419350238,\n" + - " \"extension_field\": \"twenty-seven\"\n" + - " }"; + // @formatter:off + private static final String INVALID_RESPONSE = "{\n" + + " \"client_id\": \"l238j323ds-23ij4\",\n" + + " \"username\": \"jdoe\",\n" + + " \"scope\": \"read write dolphin\",\n" + + " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + + " \"aud\": \"https://protected.example.net/resource\",\n" + + " \"iss\": \"https://server.example.com/\",\n" + + " \"exp\": 1419356238,\n" + + " \"iat\": 1419350238,\n" + + " \"extension_field\": \"twenty-seven\"\n" + + " }"; + // @formatter:on - private static final String MALFORMED_ISSUER_RESPONSE = "{\n" + - " \"active\" : \"true\",\n" + - " \"iss\" : \"badissuer\"\n" + - " }"; + // @formatter:off + private static final String MALFORMED_ISSUER_RESPONSE = "{\n" + + " \"active\" : \"true\",\n" + + " \"iss\" : \"badissuer\"\n" + + " }"; + // @formatter:on - private static final String MALFORMED_SCOPE_RESPONSE = "{\n" + - " \"active\": true,\n" + - " \"client_id\": \"l238j323ds-23ij4\",\n" + - " \"username\": \"jdoe\",\n" + - " \"scope\": [ \"read\", \"write\", \"dolphin\" ],\n" + - " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + - " \"aud\": \"https://protected.example.net/resource\",\n" + - " \"iss\": \"https://server.example.com/\",\n" + - " \"exp\": 1419356238,\n" + - " \"iat\": 1419350238,\n" + - " \"extension_field\": \"twenty-seven\"\n" + - " }"; + // @formatter:off + private static final String MALFORMED_SCOPE_RESPONSE = "{\n" + + " \"active\": true,\n" + + " \"client_id\": \"l238j323ds-23ij4\",\n" + + " \"username\": \"jdoe\",\n" + + " \"scope\": [ \"read\", \"write\", \"dolphin\" ],\n" + + " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + + " \"aud\": \"https://protected.example.net/resource\",\n" + + " \"iss\": \"https://server.example.com/\",\n" + + " \"exp\": 1419356238,\n" + + " \"iat\": 1419350238,\n" + + " \"extension_field\": \"twenty-seven\"\n" + + " }"; + // @formatter:on private static final ResponseEntity ACTIVE = response(ACTIVE_RESPONSE); + private static final ResponseEntity INACTIVE = response(INACTIVE_RESPONSE); + private static final ResponseEntity INVALID = response(INVALID_RESPONSE); + private static final ResponseEntity MALFORMED_ISSUER = response(MALFORMED_ISSUER_RESPONSE); + private static final ResponseEntity MALFORMED_SCOPE = response(MALFORMED_SCOPE_RESPONSE); @Test public void introspectWhenActiveTokenThenOk() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.setDispatcher(requiresAuth(CLIENT_ID, CLIENT_SECRET, ACTIVE_RESPONSE)); - String introspectUri = server.url("/introspect").toString(); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(introspectUri, CLIENT_ID, CLIENT_SECRET); - + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(introspectUri, CLIENT_ID, + CLIENT_SECRET); OAuth2AuthenticatedPrincipal authority = introspectionClient.introspect("token"); + // @formatter:off assertThat(authority.getAttributes()) .isNotNull() .containsEntry(OAuth2IntrospectionClaimNames.ACTIVE, true) - .containsEntry(AUDIENCE, Arrays.asList("https://protected.example.net/resource")) + .containsEntry(OAuth2IntrospectionClaimNames.AUDIENCE, + Arrays.asList("https://protected.example.net/resource")) .containsEntry(OAuth2IntrospectionClaimNames.CLIENT_ID, "l238j323ds-23ij4") - .containsEntry(EXPIRES_AT, Instant.ofEpochSecond(1419356238)) - .containsEntry(ISSUER, new URL("https://server.example.com/")) - .containsEntry(SCOPE, Arrays.asList("read", "write", "dolphin")) - .containsEntry(SUBJECT, "Z5O3upPC88QrAjx00dis") - .containsEntry(USERNAME, "jdoe") + .containsEntry(OAuth2IntrospectionClaimNames.EXPIRES_AT, Instant.ofEpochSecond(1419356238)) + .containsEntry(OAuth2IntrospectionClaimNames.ISSUER, new URL("https://server.example.com/")) + .containsEntry(OAuth2IntrospectionClaimNames.SCOPE, Arrays.asList("read", "write", "dolphin")) + .containsEntry(OAuth2IntrospectionClaimNames.SUBJECT, "Z5O3upPC88QrAjx00dis") + .containsEntry(OAuth2IntrospectionClaimNames.USERNAME, "jdoe") .containsEntry("extension_field", "twenty-seven"); + // @formatter:on } } @Test public void introspectWhenBadClientCredentialsThenError() throws IOException { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.setDispatcher(requiresAuth(CLIENT_ID, CLIENT_SECRET, ACTIVE_RESPONSE)); - String introspectUri = server.url("/introspect").toString(); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(introspectUri, CLIENT_ID, "wrong"); - - assertThatCode(() -> introspectionClient.introspect("token")) - .isInstanceOf(OAuth2IntrospectionException.class); + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(introspectUri, CLIENT_ID, + "wrong"); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token")); } } @Test public void introspectWhenInactiveTokenThenInvalidToken() { RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(INACTIVE); - - assertThatCode(() -> introspectionClient.introspect("token")) - .isInstanceOf(OAuth2IntrospectionException.class) - .extracting("message") - .isEqualTo("Provided token isn't active"); + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))).willReturn(INACTIVE); + // @formatter:off + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token")) + .withMessage("Provided token isn't active"); + // @formatter:on } @Test public void introspectWhenActiveTokenThenParsesValuesInResponse() { Map introspectedValues = new HashMap<>(); introspectedValues.put(OAuth2IntrospectionClaimNames.ACTIVE, true); - introspectedValues.put(AUDIENCE, Arrays.asList("aud")); - introspectedValues.put(NOT_BEFORE, 29348723984L); - + introspectedValues.put(OAuth2IntrospectionClaimNames.AUDIENCE, Arrays.asList("aud")); + introspectedValues.put(OAuth2IntrospectionClaimNames.NOT_BEFORE, 29348723984L); RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(response(new JSONObject(introspectedValues).toJSONString())); - + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(response(new JSONObject(introspectedValues).toJSONString())); OAuth2AuthenticatedPrincipal authority = introspectionClient.introspect("token"); + // @formatter:off assertThat(authority.getAttributes()) .isNotNull() .containsEntry(OAuth2IntrospectionClaimNames.ACTIVE, true) - .containsEntry(AUDIENCE, Arrays.asList("aud")) - .containsEntry(NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) + .containsEntry(OAuth2IntrospectionClaimNames.AUDIENCE, Arrays.asList("aud")) + .containsEntry(OAuth2IntrospectionClaimNames.NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) .doesNotContainKey(OAuth2IntrospectionClaimNames.CLIENT_ID) - .doesNotContainKey(SCOPE); + .doesNotContainKey(OAuth2IntrospectionClaimNames.SCOPE); + // @formatter:on } @Test public void introspectWhenIntrospectionEndpointThrowsExceptionThenInvalidToken() { RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenThrow(new IllegalStateException("server was unresponsive")); - - assertThatCode(() -> introspectionClient.introspect("token")) - .isInstanceOf(OAuth2IntrospectionException.class) - .extracting("message") - .isEqualTo("server was unresponsive"); + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willThrow(new IllegalStateException("server was unresponsive")); + // @formatter:off + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token")) + .withMessage("server was unresponsive"); + // @formatter:on } - @Test public void introspectWhenIntrospectionEndpointReturnsMalformedResponseThenInvalidToken() { RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(response("malformed")); - - assertThatCode(() -> introspectionClient.introspect("token")) - .isInstanceOf(OAuth2IntrospectionException.class); + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))).willReturn(response("malformed")); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token")); } @Test public void introspectWhenIntrospectionTokenReturnsInvalidResponseThenInvalidToken() { RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(INVALID); - - assertThatCode(() -> introspectionClient.introspect("token")) - .isInstanceOf(OAuth2IntrospectionException.class); + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))).willReturn(INVALID); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token")); } @Test public void introspectWhenIntrospectionTokenReturnsMalformedIssuerResponseThenInvalidToken() { RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(MALFORMED_ISSUER); - - assertThatCode(() -> introspectionClient.introspect("token")) - .isInstanceOf(OAuth2IntrospectionException.class); + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))).willReturn(MALFORMED_ISSUER); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token")); } // gh-7563 @Test public void introspectWhenIntrospectionTokenReturnsMalformedScopeThenEmptyAuthorities() { RestOperations restOperations = mock(RestOperations.class); - OpaqueTokenIntrospector introspectionClient = - new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, restOperations); - when(restOperations.exchange(any(RequestEntity.class), eq(String.class))) - .thenReturn(MALFORMED_SCOPE); - + OpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))).willReturn(MALFORMED_SCOPE); OAuth2AuthenticatedPrincipal principal = introspectionClient.introspect("token"); assertThat(principal.getAuthorities()).isEmpty(); - assertThat((Object) principal.getAttribute("scope")) - .isNotNull() - .isInstanceOf(JSONArray.class); JSONArray scope = principal.getAttribute("scope"); assertThat(scope).containsExactly("read", "write", "dolphin"); } @Test public void constructorWhenIntrospectionUriIsNullThenIllegalArgumentException() { - assertThatCode(() -> new NimbusOpaqueTokenIntrospector(null, CLIENT_ID, CLIENT_SECRET)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusOpaqueTokenIntrospector(null, CLIENT_ID, CLIENT_SECRET)); } @Test public void constructorWhenClientIdIsNullThenIllegalArgumentException() { - assertThatCode(() -> new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, null, CLIENT_SECRET)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, null, CLIENT_SECRET)); } @Test public void constructorWhenClientSecretIsNullThenIllegalArgumentException() { - assertThatCode(() -> new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, CLIENT_ID, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, CLIENT_ID, null)); } @Test public void constructorWhenRestOperationsIsNullThenIllegalArgumentException() { - assertThatCode(() -> new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, null)); } @Test public void setRequestEntityConverterWhenConverterIsNullThenExceptionIsThrown() { RestOperations restOperations = mock(RestOperations.class); - - NimbusOpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector( - INTROSPECTION_URL, restOperations - ); - + NimbusOpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); assertThatExceptionOfType(IllegalArgumentException.class) .isThrownBy(() -> introspectionClient.setRequestEntityConverter(null)); } @@ -306,15 +299,12 @@ public class NimbusOpaqueTokenIntrospectorTests { Converter> requestEntityConverter = mock(Converter.class); RequestEntity requestEntity = mock(RequestEntity.class); String tokenToIntrospect = "some token"; - when(requestEntityConverter.convert(tokenToIntrospect)).thenReturn(requestEntity); - when(restOperations.exchange(requestEntity, String.class)).thenReturn(ACTIVE); - NimbusOpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector( - INTROSPECTION_URL, restOperations - ); + given(requestEntityConverter.convert(tokenToIntrospect)).willReturn(requestEntity); + given(restOperations.exchange(requestEntity, String.class)).willReturn(ACTIVE); + NimbusOpaqueTokenIntrospector introspectionClient = new NimbusOpaqueTokenIntrospector(INTROSPECTION_URL, + restOperations); introspectionClient.setRequestEntityConverter(requestEntityConverter); - introspectionClient.introspect(tokenToIntrospect); - verify(requestEntityConverter).convert(tokenToIntrospect); } @@ -329,10 +319,12 @@ public class NimbusOpaqueTokenIntrospectorTests { @Override public MockResponse dispatch(RecordedRequest request) { String authorization = request.getHeader(HttpHeaders.AUTHORIZATION); + // @formatter:off return Optional.ofNullable(authorization) - .filter(a -> isAuthorized(authorization, username, password)) - .map(a -> ok(response)) + .filter((a) -> isAuthorized(authorization, username, password)) + .map((a) -> ok(response)) .orElse(unauthorized()); + // @formatter:on } }; } @@ -343,11 +335,14 @@ public class NimbusOpaqueTokenIntrospectorTests { } private static MockResponse ok(String response) { + // @formatter:off return new MockResponse().setBody(response) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + // @formatter:on } private static MockResponse unauthorized() { return new MockResponse().setResponseCode(401); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospectorTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospectorTests.java index cc7f50d3f9..666bf301ac 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospectorTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/NimbusReactiveOpaqueTokenIntrospectorTests.java @@ -41,219 +41,221 @@ import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.AUDIENCE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.EXPIRES_AT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.ISSUER; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.NOT_BEFORE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SCOPE; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SUBJECT; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.USERNAME; /** * Tests for {@link NimbusReactiveOpaqueTokenIntrospector} */ public class NimbusReactiveOpaqueTokenIntrospectorTests { + private static final String INTROSPECTION_URL = "https://server.example.com"; + private static final String CLIENT_ID = "client"; + private static final String CLIENT_SECRET = "secret"; - private static final String ACTIVE_RESPONSE = "{\n" + - " \"active\": true,\n" + - " \"client_id\": \"l238j323ds-23ij4\",\n" + - " \"username\": \"jdoe\",\n" + - " \"scope\": \"read write dolphin\",\n" + - " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + - " \"aud\": \"https://protected.example.net/resource\",\n" + - " \"iss\": \"https://server.example.com/\",\n" + - " \"exp\": 1419356238,\n" + - " \"iat\": 1419350238,\n" + - " \"extension_field\": \"twenty-seven\"\n" + - " }"; + // @formatter:off + private static final String ACTIVE_RESPONSE = "{\n" + + " \"active\": true,\n" + + " \"client_id\": \"l238j323ds-23ij4\",\n" + + " \"username\": \"jdoe\",\n" + + " \"scope\": \"read write dolphin\",\n" + + " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + + " \"aud\": \"https://protected.example.net/resource\",\n" + + " \"iss\": \"https://server.example.com/\",\n" + + " \"exp\": 1419356238,\n" + + " \"iat\": 1419350238,\n" + + " \"extension_field\": \"twenty-seven\"\n" + + " }"; + // @formatter:on - private static final String INACTIVE_RESPONSE = "{\n" + - " \"active\": false\n" + - " }"; + // @formatter:off + private static final String INACTIVE_RESPONSE = "{\n" + + " \"active\": false\n" + + " }"; + // @formatter:on - private static final String INVALID_RESPONSE = "{\n" + - " \"client_id\": \"l238j323ds-23ij4\",\n" + - " \"username\": \"jdoe\",\n" + - " \"scope\": \"read write dolphin\",\n" + - " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + - " \"aud\": \"https://protected.example.net/resource\",\n" + - " \"iss\": \"https://server.example.com/\",\n" + - " \"exp\": 1419356238,\n" + - " \"iat\": 1419350238,\n" + - " \"extension_field\": \"twenty-seven\"\n" + - " }"; + // @formatter:off + private static final String INVALID_RESPONSE = "{\n" + + " \"client_id\": \"l238j323ds-23ij4\",\n" + + " \"username\": \"jdoe\",\n" + + " \"scope\": \"read write dolphin\",\n" + + " \"sub\": \"Z5O3upPC88QrAjx00dis\",\n" + + " \"aud\": \"https://protected.example.net/resource\",\n" + + " \"iss\": \"https://server.example.com/\",\n" + + " \"exp\": 1419356238,\n" + + " \"iat\": 1419350238,\n" + + " \"extension_field\": \"twenty-seven\"\n" + + " }"; + // @formatter:on - private static final String MALFORMED_ISSUER_RESPONSE = "{\n" + - " \"active\" : \"true\",\n" + - " \"iss\" : \"badissuer\"\n" + - " }"; + // @formatter:off + private static final String MALFORMED_ISSUER_RESPONSE = "{\n" + + " \"active\" : \"true\",\n" + + " \"iss\" : \"badissuer\"\n" + + " }"; + // @formatter:on @Test public void authenticateWhenActiveTokenThenOk() throws Exception { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.setDispatcher(requiresAuth(CLIENT_ID, CLIENT_SECRET, ACTIVE_RESPONSE)); - String introspectUri = server.url("/introspect").toString(); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(introspectUri, CLIENT_ID, CLIENT_SECRET); - + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + introspectUri, CLIENT_ID, CLIENT_SECRET); OAuth2AuthenticatedPrincipal authority = introspectionClient.introspect("token").block(); + // @formatter:off assertThat(authority.getAttributes()) .isNotNull() .containsEntry(OAuth2IntrospectionClaimNames.ACTIVE, true) - .containsEntry(AUDIENCE, Arrays.asList("https://protected.example.net/resource")) + .containsEntry(OAuth2IntrospectionClaimNames.AUDIENCE, + Arrays.asList("https://protected.example.net/resource")) .containsEntry(OAuth2IntrospectionClaimNames.CLIENT_ID, "l238j323ds-23ij4") - .containsEntry(EXPIRES_AT, Instant.ofEpochSecond(1419356238)) - .containsEntry(ISSUER, new URL("https://server.example.com/")) - .containsEntry(SCOPE, Arrays.asList("read", "write", "dolphin")) - .containsEntry(SUBJECT, "Z5O3upPC88QrAjx00dis") - .containsEntry(USERNAME, "jdoe") + .containsEntry(OAuth2IntrospectionClaimNames.EXPIRES_AT, Instant.ofEpochSecond(1419356238)) + .containsEntry(OAuth2IntrospectionClaimNames.ISSUER, new URL("https://server.example.com/")) + .containsEntry(OAuth2IntrospectionClaimNames.SCOPE, Arrays.asList("read", "write", "dolphin")) + .containsEntry(OAuth2IntrospectionClaimNames.SUBJECT, "Z5O3upPC88QrAjx00dis") + .containsEntry(OAuth2IntrospectionClaimNames.USERNAME, "jdoe") .containsEntry("extension_field", "twenty-seven"); + // @formatter:on } } @Test public void authenticateWhenBadClientCredentialsThenAuthenticationException() throws IOException { - try ( MockWebServer server = new MockWebServer() ) { + try (MockWebServer server = new MockWebServer()) { server.setDispatcher(requiresAuth(CLIENT_ID, CLIENT_SECRET, ACTIVE_RESPONSE)); - String introspectUri = server.url("/introspect").toString(); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(introspectUri, CLIENT_ID, "wrong"); + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + introspectUri, CLIENT_ID, "wrong"); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token").block()); - assertThatCode(() -> introspectionClient.introspect("token").block()) - .isInstanceOf(OAuth2IntrospectionException.class); } } @Test public void authenticateWhenInactiveTokenThenInvalidToken() { WebClient webClient = mockResponse(INACTIVE_RESPONSE); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, webClient); - - assertThatCode(() -> introspectionClient.introspect("token").block()) - .isInstanceOf(BadOpaqueTokenException.class) - .extracting("message") - .isEqualTo("Provided token isn't active"); + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + INTROSPECTION_URL, webClient); + assertThatExceptionOfType(BadOpaqueTokenException.class) + .isThrownBy(() -> introspectionClient.introspect("token").block()) + .withMessage("Provided token isn't active"); } @Test public void authenticateWhenActiveTokenThenParsesValuesInResponse() { Map introspectedValues = new HashMap<>(); introspectedValues.put(OAuth2IntrospectionClaimNames.ACTIVE, true); - introspectedValues.put(AUDIENCE, Arrays.asList("aud")); - introspectedValues.put(NOT_BEFORE, 29348723984L); - + introspectedValues.put(OAuth2IntrospectionClaimNames.AUDIENCE, Arrays.asList("aud")); + introspectedValues.put(OAuth2IntrospectionClaimNames.NOT_BEFORE, 29348723984L); WebClient webClient = mockResponse(new JSONObject(introspectedValues).toJSONString()); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, webClient); - + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + INTROSPECTION_URL, webClient); OAuth2AuthenticatedPrincipal authority = introspectionClient.introspect("token").block(); + // @formatter:off assertThat(authority.getAttributes()) .isNotNull() .containsEntry(OAuth2IntrospectionClaimNames.ACTIVE, true) - .containsEntry(AUDIENCE, Arrays.asList("aud")) - .containsEntry(NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) + .containsEntry(OAuth2IntrospectionClaimNames.AUDIENCE, Arrays.asList("aud")) + .containsEntry(OAuth2IntrospectionClaimNames.NOT_BEFORE, Instant.ofEpochSecond(29348723984L)) .doesNotContainKey(OAuth2IntrospectionClaimNames.CLIENT_ID) - .doesNotContainKey(SCOPE); + .doesNotContainKey(OAuth2IntrospectionClaimNames.SCOPE); + // @formatter:on } @Test public void authenticateWhenIntrospectionEndpointThrowsExceptionThenInvalidToken() { WebClient webClient = mockResponse(new IllegalStateException("server was unresponsive")); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, webClient); - - assertThatCode(() -> introspectionClient.introspect("token").block()) - .isInstanceOf(OAuth2IntrospectionException.class) - .extracting("message") - .isEqualTo("server was unresponsive"); + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + INTROSPECTION_URL, webClient); + // @formatter:off + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token").block()) + .withMessage("server was unresponsive"); + // @formatter:on } @Test public void authenticateWhenIntrospectionEndpointReturnsMalformedResponseThenInvalidToken() { WebClient webClient = mockResponse("malformed"); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, webClient); - - assertThatCode(() -> introspectionClient.introspect("token").block()) - .isInstanceOf(OAuth2IntrospectionException.class); + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + INTROSPECTION_URL, webClient); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token").block()); } @Test public void authenticateWhenIntrospectionTokenReturnsInvalidResponseThenInvalidToken() { WebClient webClient = mockResponse(INVALID_RESPONSE); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, webClient); - - assertThatCode(() -> introspectionClient.introspect("token").block()) - .isInstanceOf(OAuth2IntrospectionException.class); + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + INTROSPECTION_URL, webClient); + // @formatter:off + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token").block()); + // @formatter:on } @Test public void authenticateWhenIntrospectionTokenReturnsMalformedIssuerResponseThenInvalidToken() { WebClient webClient = mockResponse(MALFORMED_ISSUER_RESPONSE); - NimbusReactiveOpaqueTokenIntrospector introspectionClient = - new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, webClient); - - assertThatCode(() -> introspectionClient.introspect("token").block()) - .isInstanceOf(OAuth2IntrospectionException.class); + NimbusReactiveOpaqueTokenIntrospector introspectionClient = new NimbusReactiveOpaqueTokenIntrospector( + INTROSPECTION_URL, webClient); + assertThatExceptionOfType(OAuth2IntrospectionException.class) + .isThrownBy(() -> introspectionClient.introspect("token").block()); } @Test public void constructorWhenIntrospectionUriIsEmptyThenIllegalArgumentException() { - assertThatCode(() -> new NimbusReactiveOpaqueTokenIntrospector("", CLIENT_ID, CLIENT_SECRET)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusReactiveOpaqueTokenIntrospector("", CLIENT_ID, CLIENT_SECRET)); } @Test public void constructorWhenClientIdIsEmptyThenIllegalArgumentException() { - assertThatCode(() -> new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, "", CLIENT_SECRET)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, "", CLIENT_SECRET)); } @Test public void constructorWhenClientSecretIsNullThenIllegalArgumentException() { - assertThatCode(() -> new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, CLIENT_ID, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, CLIENT_ID, null)); } @Test public void constructorWhenRestOperationsIsNullThenIllegalArgumentException() { - assertThatCode(() -> new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new NimbusReactiveOpaqueTokenIntrospector(INTROSPECTION_URL, null)); } private WebClient mockResponse(String response) { WebClient real = WebClient.builder().build(); WebClient.RequestBodyUriSpec spec = spy(real.post()); WebClient webClient = spy(WebClient.class); - when(webClient.post()).thenReturn(spec); + given(webClient.post()).willReturn(spec); ClientResponse clientResponse = mock(ClientResponse.class); - when(clientResponse.rawStatusCode()).thenReturn(200); - when(clientResponse.statusCode()).thenReturn(HttpStatus.OK); - when(clientResponse.bodyToMono(String.class)).thenReturn(Mono.just(response)); + given(clientResponse.rawStatusCode()).willReturn(200); + given(clientResponse.statusCode()).willReturn(HttpStatus.OK); + given(clientResponse.bodyToMono(String.class)).willReturn(Mono.just(response)); ClientResponse.Headers headers = mock(ClientResponse.Headers.class); - when(headers.contentType()).thenReturn(Optional.of(MediaType.APPLICATION_JSON_UTF8)); - when(clientResponse.headers()).thenReturn(headers); - when(spec.exchange()).thenReturn(Mono.just(clientResponse)); + given(headers.contentType()).willReturn(Optional.of(MediaType.APPLICATION_JSON_UTF8)); + given(clientResponse.headers()).willReturn(headers); + given(spec.exchange()).willReturn(Mono.just(clientResponse)); return webClient; } - private WebClient mockResponse(Throwable t) { + private WebClient mockResponse(Throwable ex) { WebClient real = WebClient.builder().build(); WebClient.RequestBodyUriSpec spec = spy(real.post()); WebClient webClient = spy(WebClient.class); - when(webClient.post()).thenReturn(spec); - when(spec.exchange()).thenThrow(t); + given(webClient.post()).willReturn(spec); + given(spec.exchange()).willThrow(ex); return webClient; } @@ -262,10 +264,12 @@ public class NimbusReactiveOpaqueTokenIntrospectorTests { @Override public MockResponse dispatch(RecordedRequest request) { String authorization = request.getHeader(HttpHeaders.AUTHORIZATION); + // @formatter:off return Optional.ofNullable(authorization) - .filter(a -> isAuthorized(authorization, username, password)) - .map(a -> ok(response)) + .filter((a) -> isAuthorized(authorization, username, password)) + .map((a) -> ok(response)) .orElse(unauthorized()); + // @formatter:on } }; } @@ -276,11 +280,14 @@ public class NimbusReactiveOpaqueTokenIntrospectorTests { } private static MockResponse ok(String response) { + // @formatter:off return new MockResponse().setBody(response) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + // @formatter:on } private static MockResponse unauthorized() { return new MockResponse().setResponseCode(401); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipalTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipalTests.java index 83b6f318f4..bb96cb17ec 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipalTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/introspection/OAuth2IntrospectionAuthenticatedPrincipalTests.java @@ -16,9 +16,6 @@ package org.springframework.security.oauth2.server.resource.introspection; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; - import java.time.Instant; import java.util.Arrays; import java.util.Collection; @@ -28,47 +25,72 @@ import java.util.List; import java.util.Map; import org.junit.Test; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + /** * Tests for {@link OAuth2IntrospectionAuthenticatedPrincipal} * * @author David Kovac */ public class OAuth2IntrospectionAuthenticatedPrincipalTests { + private static final String AUTHORITY = "SCOPE_read"; + private static final Collection AUTHORITIES = AuthorityUtils.createAuthorityList(AUTHORITY); private static final String SUBJECT = "test-subject"; private static final String ACTIVE_CLAIM = "active"; + private static final String CLIENT_ID_CLAIM = "client_id"; + private static final String USERNAME_CLAIM = "username"; + private static final String TOKEN_TYPE_CLAIM = "token_type"; + private static final String EXP_CLAIM = "exp"; + private static final String IAT_CLAIM = "iat"; + private static final String NBF_CLAIM = "nbf"; + private static final String SUB_CLAIM = "sub"; + private static final String AUD_CLAIM = "aud"; + private static final String ISS_CLAIM = "iss"; + private static final String JTI_CLAIM = "jti"; private static final boolean ACTIVE_VALUE = true; + private static final String CLIENT_ID_VALUE = "client-id-1"; + private static final String USERNAME_VALUE = "username-1"; + private static final String TOKEN_TYPE_VALUE = "token-type-1"; + private static final long EXP_VALUE = Instant.now().plusSeconds(60).getEpochSecond(); + private static final long IAT_VALUE = Instant.now().getEpochSecond(); + private static final long NBF_VALUE = Instant.now().plusSeconds(5).getEpochSecond(); + private static final String SUB_VALUE = "subject1"; + private static final List AUD_VALUE = Arrays.asList("aud1", "aud2"); + private static final String ISS_VALUE = "https://provider.com"; + private static final String JTI_VALUE = "jwt-id-1"; private static final Map CLAIMS; - static { CLAIMS = new HashMap<>(); CLAIMS.put(ACTIVE_CLAIM, ACTIVE_VALUE); @@ -86,36 +108,32 @@ public class OAuth2IntrospectionAuthenticatedPrincipalTests { @Test public void constructorWhenAttributesIsNullOrEmptyThenIllegalArgumentException() { - assertThatCode(() -> new OAuth2IntrospectionAuthenticatedPrincipal(null, AUTHORITIES)) - .isInstanceOf(IllegalArgumentException.class); - - assertThatCode(() -> new OAuth2IntrospectionAuthenticatedPrincipal(Collections.emptyMap(), AUTHORITIES)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2IntrospectionAuthenticatedPrincipal(null, AUTHORITIES)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2IntrospectionAuthenticatedPrincipal(Collections.emptyMap(), AUTHORITIES)); } @Test public void constructorWhenAuthoritiesIsNullOrEmptyThenNoAuthorities() { - Collection authorities = - new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, null).getAuthorities(); + Collection authorities = new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, null) + .getAuthorities(); assertThat(authorities).isEmpty(); - - authorities = new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, - Collections.emptyList()).getAuthorities(); + authorities = new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, Collections.emptyList()).getAuthorities(); assertThat(authorities).isEmpty(); } @Test public void constructorWhenNameIsNullThenFallsbackToSubAttribute() { - OAuth2AuthenticatedPrincipal principal = - new OAuth2IntrospectionAuthenticatedPrincipal(null, CLAIMS, AUTHORITIES); + OAuth2AuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal(null, CLAIMS, + AUTHORITIES); assertThat(principal.getName()).isEqualTo(CLAIMS.get(SUB_CLAIM)); } @Test public void constructorWhenAttributesAuthoritiesProvidedThenCreated() { - OAuth2IntrospectionAuthenticatedPrincipal principal = - new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, AUTHORITIES); - + OAuth2IntrospectionAuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, + AUTHORITIES); assertThat(principal.getName()).isEqualTo(CLAIMS.get(SUB_CLAIM)); assertThat(principal.getAttributes()).isEqualTo(CLAIMS); assertThat(principal.getClaims()).isEqualTo(CLAIMS); @@ -136,9 +154,8 @@ public class OAuth2IntrospectionAuthenticatedPrincipalTests { @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2IntrospectionAuthenticatedPrincipal principal = - new OAuth2IntrospectionAuthenticatedPrincipal(SUBJECT, CLAIMS, AUTHORITIES); - + OAuth2IntrospectionAuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal(SUBJECT, + CLAIMS, AUTHORITIES); assertThat(principal.getName()).isEqualTo(SUBJECT); assertThat(principal.getAttributes()).isEqualTo(CLAIMS); assertThat(principal.getClaims()).isEqualTo(CLAIMS); @@ -159,26 +176,30 @@ public class OAuth2IntrospectionAuthenticatedPrincipalTests { @Test public void getNameWhenInConstructorThenReturns() { - OAuth2AuthenticatedPrincipal principal = - new OAuth2IntrospectionAuthenticatedPrincipal(SUB_VALUE, CLAIMS, AUTHORITIES); + OAuth2AuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal(SUB_VALUE, CLAIMS, + AUTHORITIES); assertThat(principal.getName()).isEqualTo(SUB_VALUE); } @Test public void getAttributeWhenGivenKeyThenReturnsValue() { - OAuth2AuthenticatedPrincipal principal = - new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, AUTHORITIES); - - assertThat((Object) principal.getAttribute(ACTIVE_CLAIM)).isEqualTo(ACTIVE_VALUE); - assertThat((Object) principal.getAttribute(CLIENT_ID_CLAIM)).isEqualTo(CLIENT_ID_VALUE); - assertThat((Object) principal.getAttribute(USERNAME_CLAIM)).isEqualTo(USERNAME_VALUE); - assertThat((Object) principal.getAttribute(TOKEN_TYPE_CLAIM)).isEqualTo(TOKEN_TYPE_VALUE); - assertThat((Object) principal.getAttribute(EXP_CLAIM)).isEqualTo(EXP_VALUE); - assertThat((Object) principal.getAttribute(IAT_CLAIM)).isEqualTo(IAT_VALUE); - assertThat((Object) principal.getAttribute(NBF_CLAIM)).isEqualTo(NBF_VALUE); - assertThat((Object) principal.getAttribute(SUB_CLAIM)).isEqualTo(SUB_VALUE); - assertThat((Object) principal.getAttribute(AUD_CLAIM)).isEqualTo(AUD_VALUE); - assertThat((Object) principal.getAttribute(ISS_CLAIM)).isEqualTo(ISS_VALUE); - assertThat((Object) principal.getAttribute(JTI_CLAIM)).isEqualTo(JTI_VALUE); + OAuth2AuthenticatedPrincipal principal = new OAuth2IntrospectionAuthenticatedPrincipal(CLAIMS, AUTHORITIES); + assertHasEqualAttribute(principal, ACTIVE_CLAIM, ACTIVE_VALUE); + assertHasEqualAttribute(principal, CLIENT_ID_CLAIM, CLIENT_ID_VALUE); + assertHasEqualAttribute(principal, USERNAME_CLAIM, USERNAME_VALUE); + assertHasEqualAttribute(principal, TOKEN_TYPE_CLAIM, TOKEN_TYPE_VALUE); + assertHasEqualAttribute(principal, EXP_CLAIM, EXP_VALUE); + assertHasEqualAttribute(principal, IAT_CLAIM, IAT_VALUE); + assertHasEqualAttribute(principal, NBF_CLAIM, NBF_VALUE); + assertHasEqualAttribute(principal, SUB_CLAIM, SUB_VALUE); + assertHasEqualAttribute(principal, AUD_CLAIM, AUD_VALUE); + assertHasEqualAttribute(principal, ISS_CLAIM, ISS_VALUE); + assertHasEqualAttribute(principal, JTI_CLAIM, JTI_VALUE); } + + private void assertHasEqualAttribute(OAuth2AuthenticatedPrincipal principal, String name, Object expected) { + Object value = principal.getAttribute(name); + assertThat(value).isEqualTo(expected); + } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPointTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPointTests.java index fc5bca4482..93e112e268 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPointTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationEntryPointTests.java @@ -28,7 +28,6 @@ import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; /** * Tests for {@link BearerTokenAuthenticationEntryPoint}. @@ -46,129 +45,88 @@ public class BearerTokenAuthenticationEntryPointTests { } @Test - public void commenceWhenNoBearerTokenErrorThenStatus401AndAuthHeader() - throws Exception { - + public void commenceWhenNoBearerTokenErrorThenStatus401AndAuthHeader() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - this.authenticationEntryPoint.commence(request, response, new BadCredentialsException("test")); - assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer"); } @Test - public void commenceWhenNoBearerTokenErrorAndRealmSetThenStatus401AndAuthHeaderWithRealm() - throws Exception { - + public void commenceWhenNoBearerTokenErrorAndRealmSetThenStatus401AndAuthHeaderWithRealm() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - this.authenticationEntryPoint.setRealmName("test"); this.authenticationEntryPoint.commence(request, response, new BadCredentialsException("test")); - assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer realm=\"test\""); } @Test - public void commenceWhenInvalidRequestErrorThenStatus400AndHeaderWithError() - throws Exception { - + public void commenceWhenInvalidRequestErrorThenStatus400AndHeaderWithError() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - BearerTokenError error = new BearerTokenError( - BearerTokenErrorCodes.INVALID_REQUEST, - HttpStatus.BAD_REQUEST, - null, - null); - - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_REQUEST, HttpStatus.BAD_REQUEST, + null, null); + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer error=\"invalid_request\""); } @Test - public void commenceWhenInvalidRequestErrorThenStatus400AndHeaderWithErrorDetails() - throws Exception { - + public void commenceWhenInvalidRequestErrorThenStatus400AndHeaderWithErrorDetails() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_REQUEST, HttpStatus.BAD_REQUEST, "The access token expired", null, null); - - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getHeader("WWW-Authenticate")) .isEqualTo("Bearer error=\"invalid_request\", error_description=\"The access token expired\""); } @Test - public void commenceWhenInvalidRequestErrorThenStatus400AndHeaderWithErrorUri() - throws Exception { - + public void commenceWhenInvalidRequestErrorThenStatus400AndHeaderWithErrorUri() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_REQUEST, HttpStatus.BAD_REQUEST, null, "https://example.com", null); - - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getHeader("WWW-Authenticate")) .isEqualTo("Bearer error=\"invalid_request\", error_uri=\"https://example.com\""); } @Test - public void commenceWhenInvalidTokenErrorThenStatus401AndHeaderWithError() - throws Exception { - + public void commenceWhenInvalidTokenErrorThenStatus401AndHeaderWithError() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED, null, null); - - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer error=\"invalid_token\""); } @Test - public void commenceWhenInsufficientScopeErrorThenStatus403AndHeaderWithError() - throws Exception { - + public void commenceWhenInsufficientScopeErrorThenStatus403AndHeaderWithError() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN, null, null); - - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer error=\"insufficient_scope\""); } @Test - public void commenceWhenInsufficientScopeErrorThenStatus403AndHeaderWithErrorAndScope() - throws Exception { - + public void commenceWhenInsufficientScopeErrorThenStatus403AndHeaderWithErrorAndScope() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN, null, null, "test.read test.write"); - - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getHeader("WWW-Authenticate")) .isEqualTo("Bearer error=\"insufficient_scope\", scope=\"test.read test.write\""); @@ -177,16 +135,12 @@ public class BearerTokenAuthenticationEntryPointTests { @Test public void commenceWhenInsufficientScopeAndRealmSetThenStatus403AndHeaderWithErrorAndAllDetails() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN, "Insufficient scope", "https://example.com", "test.read test.write"); - this.authenticationEntryPoint.setRealmName("test"); - this.authenticationEntryPoint.commence(request, response, - new OAuth2AuthenticationException(error)); - + this.authenticationEntryPoint.commence(request, response, new OAuth2AuthenticationException(error)); assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo( "Bearer realm=\"test\", error=\"insufficient_scope\", error_description=\"Insufficient scope\", " @@ -195,8 +149,7 @@ public class BearerTokenAuthenticationEntryPointTests { @Test public void setRealmNameWhenNullRealmNameThenNoExceptionThrown() { - assertThatCode(() -> this.authenticationEntryPoint.setRealmName(null)) - .doesNotThrowAnyException(); + this.authenticationEntryPoint.setRealmName(null); } } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java index 25339f031b..90a04ac322 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.oauth2.server.resource.web; import java.io.IOException; + import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -40,11 +42,11 @@ import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; /** * Tests {@link BearerTokenAuthenticationFilterTests} @@ -53,6 +55,7 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class BearerTokenAuthenticationFilterTests { + @Mock AuthenticationEntryPoint authenticationEntryPoint; @@ -83,138 +86,110 @@ public class BearerTokenAuthenticationFilterTests { @Test public void doFilterWhenBearerTokenPresentThenAuthenticates() throws ServletException, IOException { - when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token"); - - BearerTokenAuthenticationFilter filter = - addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager)); + given(this.bearerTokenResolver.resolve(this.request)).willReturn("token"); + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); filter.doFilter(this.request, this.response, this.filterChain); - - ArgumentCaptor captor = - ArgumentCaptor.forClass(BearerTokenAuthenticationToken.class); - + ArgumentCaptor captor = ArgumentCaptor + .forClass(BearerTokenAuthenticationToken.class); verify(this.authenticationManager).authenticate(captor.capture()); - assertThat(captor.getValue().getPrincipal()).isEqualTo("token"); } @Test public void doFilterWhenUsingAuthenticationManagerResolverThenAuthenticates() throws Exception { - BearerTokenAuthenticationFilter filter = - addMocks(new BearerTokenAuthenticationFilter(this.authenticationManagerResolver)); - - when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token"); - when(this.authenticationManagerResolver.resolve(any())).thenReturn(this.authenticationManager); - + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManagerResolver)); + given(this.bearerTokenResolver.resolve(this.request)).willReturn("token"); + given(this.authenticationManagerResolver.resolve(any())).willReturn(this.authenticationManager); filter.doFilter(this.request, this.response, this.filterChain); - - ArgumentCaptor captor = - ArgumentCaptor.forClass(BearerTokenAuthenticationToken.class); - + ArgumentCaptor captor = ArgumentCaptor + .forClass(BearerTokenAuthenticationToken.class); verify(this.authenticationManager).authenticate(captor.capture()); - assertThat(captor.getValue().getPrincipal()).isEqualTo("token"); } @Test - public void doFilterWhenNoBearerTokenPresentThenDoesNotAuthenticate() - throws ServletException, IOException { - - when(this.bearerTokenResolver.resolve(this.request)).thenReturn(null); - + public void doFilterWhenNoBearerTokenPresentThenDoesNotAuthenticate() throws ServletException, IOException { + given(this.bearerTokenResolver.resolve(this.request)).willReturn(null); dontAuthenticate(); } @Test public void doFilterWhenMalformedBearerTokenThenPropagatesError() throws ServletException, IOException { - BearerTokenError error = new BearerTokenError( - BearerTokenErrorCodes.INVALID_REQUEST, - HttpStatus.BAD_REQUEST, - "description", - "uri"); - + BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_REQUEST, HttpStatus.BAD_REQUEST, + "description", "uri"); OAuth2AuthenticationException exception = new OAuth2AuthenticationException(error); - - when(this.bearerTokenResolver.resolve(this.request)).thenThrow(exception); - + given(this.bearerTokenResolver.resolve(this.request)).willThrow(exception); dontAuthenticate(); - verify(this.authenticationEntryPoint).commence(this.request, this.response, exception); } @Test - public void doFilterWhenAuthenticationFailsWithDefaultHandlerThenPropagatesError() throws ServletException, IOException { - BearerTokenError error = new BearerTokenError( - BearerTokenErrorCodes.INVALID_TOKEN, - HttpStatus.UNAUTHORIZED, - "description", - "uri" - ); - + public void doFilterWhenAuthenticationFailsWithDefaultHandlerThenPropagatesError() + throws ServletException, IOException { + BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED, + "description", "uri"); OAuth2AuthenticationException exception = new OAuth2AuthenticationException(error); - - when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token"); - when(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class))) - .thenThrow(exception); - - BearerTokenAuthenticationFilter filter = - addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager)); + given(this.bearerTokenResolver.resolve(this.request)).willReturn("token"); + given(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class))).willThrow(exception); + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); filter.doFilter(this.request, this.response, this.filterChain); - verify(this.authenticationEntryPoint).commence(this.request, this.response, exception); } @Test - public void doFilterWhenAuthenticationFailsWithCustomHandlerThenPropagatesError() throws ServletException, IOException { - BearerTokenError error = new BearerTokenError( - BearerTokenErrorCodes.INVALID_TOKEN, - HttpStatus.UNAUTHORIZED, - "description", - "uri" - ); - + public void doFilterWhenAuthenticationFailsWithCustomHandlerThenPropagatesError() + throws ServletException, IOException { + BearerTokenError error = new BearerTokenError(BearerTokenErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED, + "description", "uri"); OAuth2AuthenticationException exception = new OAuth2AuthenticationException(error); - - when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token"); - when(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class))) - .thenThrow(exception); - - BearerTokenAuthenticationFilter filter = - addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager)); + given(this.bearerTokenResolver.resolve(this.request)).willReturn("token"); + given(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class))).willThrow(exception); + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); filter.setAuthenticationFailureHandler(this.authenticationFailureHandler); filter.doFilter(this.request, this.response, this.filterChain); - verify(this.authenticationFailureHandler).onAuthenticationFailure(this.request, this.response, exception); } @Test public void setAuthenticationEntryPointWhenNullThenThrowsException() { BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager); - assertThatCode(() -> filter.setAuthenticationEntryPoint(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("authenticationEntryPoint cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> filter.setAuthenticationEntryPoint(null)) + .withMessageContaining("authenticationEntryPoint cannot be null"); + // @formatter:on } @Test public void setBearerTokenResolverWhenNullThenThrowsException() { BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager); - assertThatCode(() -> filter.setBearerTokenResolver(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("bearerTokenResolver cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> filter.setBearerTokenResolver(null)) + .withMessageContaining("bearerTokenResolver cannot be null"); + // @formatter:on } @Test public void constructorWhenNullAuthenticationManagerThenThrowsException() { - assertThatCode(() -> new BearerTokenAuthenticationFilter((AuthenticationManager) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("authenticationManager cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenAuthenticationFilter((AuthenticationManager) null)) + .withMessageContaining("authenticationManager cannot be null"); + // @formatter:on } @Test public void constructorWhenNullAuthenticationManagerResolverThenThrowsException() { - assertThatCode(() -> - new BearerTokenAuthenticationFilter((AuthenticationManagerResolver) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("authenticationManagerResolver cannot be null"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new BearerTokenAuthenticationFilter((AuthenticationManagerResolver) null)) + .withMessageContaining("authenticationManagerResolver cannot be null"); + // @formatter:on } private BearerTokenAuthenticationFilter addMocks(BearerTokenAuthenticationFilter filter) { @@ -223,13 +198,11 @@ public class BearerTokenAuthenticationFilterTests { return filter; } - private void dontAuthenticate() - throws ServletException, IOException { - - BearerTokenAuthenticationFilter filter = - addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager)); + private void dontAuthenticate() throws ServletException, IOException { + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); filter.doFilter(this.request, this.response, this.filterChain); - verifyNoMoreInteractions(this.authenticationManager); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolverTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolverTests.java index d8584c3f88..e80ec201af 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolverTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/DefaultBearerTokenResolverTests.java @@ -25,7 +25,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link DefaultBearerTokenResolver}. @@ -33,7 +33,9 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Vedran Pavic */ public class DefaultBearerTokenResolverTests { + private static final String CUSTOM_HEADER = "custom-header"; + private static final String TEST_TOKEN = "test-token"; private DefaultBearerTokenResolver resolver; @@ -47,7 +49,6 @@ public class DefaultBearerTokenResolverTests { public void resolveWhenValidHeaderIsPresentThenTokenIsResolved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isEqualTo(TEST_TOKEN); } @@ -57,7 +58,6 @@ public class DefaultBearerTokenResolverTests { String token = TEST_TOKEN + "=="; MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer " + token); - assertThat(this.resolver.resolve(request)).isEqualTo(token); } @@ -66,7 +66,6 @@ public class DefaultBearerTokenResolverTests { this.resolver.setBearerTokenHeaderName(CUSTOM_HEADER); MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader(CUSTOM_HEADER, "Bearer " + TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isEqualTo(TEST_TOKEN); } @@ -74,14 +73,12 @@ public class DefaultBearerTokenResolverTests { public void resolveWhenLowercaseHeaderIsPresentThenTokenIsResolved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("authorization", "bearer " + TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isEqualTo(TEST_TOKEN); } @Test public void resolveWhenNoHeaderIsPresentThenTokenIsNotResolved() { MockHttpServletRequest request = new MockHttpServletRequest(); - assertThat(this.resolver.resolve(request)).isNull(); } @@ -89,7 +86,6 @@ public class DefaultBearerTokenResolverTests { public void resolveWhenHeaderWithWrongSchemeIsPresentThenTokenIsNotResolved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + Base64.getEncoder().encodeToString("test:test".getBytes())); - assertThat(this.resolver.resolve(request)).isNull(); } @@ -97,18 +93,16 @@ public class DefaultBearerTokenResolverTests { public void resolveWhenHeaderWithMissingTokenIsPresentThenAuthenticationExceptionIsThrown() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer "); - - assertThatCode(() -> this.resolver.resolve(request)).isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining(("Bearer token is malformed")); + assertThatExceptionOfType(OAuth2AuthenticationException.class).isThrownBy(() -> this.resolver.resolve(request)) + .withMessageContaining(("Bearer token is malformed")); } @Test public void resolveWhenHeaderWithInvalidCharactersIsPresentThenAuthenticationExceptionIsThrown() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Bearer an\"invalid\"token"); - - assertThatCode(() -> this.resolver.resolve(request)).isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining(("Bearer token is malformed")); + assertThatExceptionOfType(OAuth2AuthenticationException.class).isThrownBy(() -> this.resolver.resolve(request)) + .withMessageContaining(("Bearer token is malformed")); } @Test @@ -118,9 +112,8 @@ public class DefaultBearerTokenResolverTests { request.setMethod("POST"); request.setContentType("application/x-www-form-urlencoded"); request.addParameter("access_token", TEST_TOKEN); - - assertThatCode(() -> this.resolver.resolve(request)).isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Found multiple bearer tokens in the request"); + assertThatExceptionOfType(OAuth2AuthenticationException.class).isThrownBy(() -> this.resolver.resolve(request)) + .withMessageContaining("Found multiple bearer tokens in the request"); } @Test @@ -129,29 +122,25 @@ public class DefaultBearerTokenResolverTests { request.addHeader("Authorization", "Bearer " + TEST_TOKEN); request.setMethod("GET"); request.addParameter("access_token", TEST_TOKEN); - - assertThatCode(() -> this.resolver.resolve(request)).isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Found multiple bearer tokens in the request"); + assertThatExceptionOfType(OAuth2AuthenticationException.class).isThrownBy(() -> this.resolver.resolve(request)) + .withMessageContaining("Found multiple bearer tokens in the request"); } @Test public void resolveWhenRequestContainsTwoAccessTokenParametersThenAuthenticationExceptionIsThrown() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addParameter("access_token", "token1", "token2"); - - assertThatCode(() -> this.resolver.resolve(request)).isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Found multiple bearer tokens in the request"); + assertThatExceptionOfType(OAuth2AuthenticationException.class).isThrownBy(() -> this.resolver.resolve(request)) + .withMessageContaining("Found multiple bearer tokens in the request"); } @Test public void resolveWhenFormParameterIsPresentAndSupportedThenTokenIsResolved() { this.resolver.setAllowFormEncodedBodyParameter(true); - MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("POST"); request.setContentType("application/x-www-form-urlencoded"); request.addParameter("access_token", TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isEqualTo(TEST_TOKEN); } @@ -161,18 +150,15 @@ public class DefaultBearerTokenResolverTests { request.setMethod("POST"); request.setContentType("application/x-www-form-urlencoded"); request.addParameter("access_token", TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isNull(); } @Test public void resolveWhenQueryParameterIsPresentAndSupportedThenTokenIsResolved() { this.resolver.setAllowUriQueryParameter(true); - MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); request.addParameter("access_token", TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isEqualTo(TEST_TOKEN); } @@ -181,7 +167,7 @@ public class DefaultBearerTokenResolverTests { MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); request.addParameter("access_token", TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isNull(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolverTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolverTests.java index ca8d047873..075af21b7b 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolverTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/HeaderBearerTokenResolverTests.java @@ -21,7 +21,7 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link HeaderBearerTokenResolver} @@ -38,30 +38,33 @@ public class HeaderBearerTokenResolverTests { @Test public void constructorWhenHeaderNullThenThrowIllegalArgumentException() { - assertThatCode(() -> { new HeaderBearerTokenResolver(null); }) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("header cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new HeaderBearerTokenResolver(null)) + .withMessage("header cannot be empty"); + // @formatter:on } @Test public void constructorWhenHeaderEmptyThenThrowIllegalArgumentException() { - assertThatCode(() -> { new HeaderBearerTokenResolver(""); }) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("header cannot be empty"); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new HeaderBearerTokenResolver("")) + .withMessage("header cannot be empty"); + // @formatter:on } @Test public void resolveWhenTokenPresentThenTokenIsResolved() { MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader(CORRECT_HEADER, TEST_TOKEN); - assertThat(this.resolver.resolve(request)).isEqualTo(TEST_TOKEN); } @Test public void resolveWhenTokenNotPresentThenTokenIsNotResolved() { MockHttpServletRequest request = new MockHttpServletRequest(); - assertThat(this.resolver.resolve(request)).isNull(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java index a4da50ea00..0aa0aaf600 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/MockExchangeFunction.java @@ -32,6 +32,7 @@ import static org.mockito.Mockito.mock; * @since 5.1 */ public class MockExchangeFunction implements ExchangeFunction { + private List requests = new ArrayList<>(); private ClientResponse response = mock(ClientResponse.class); @@ -55,4 +56,5 @@ public class MockExchangeFunction implements ExchangeFunction { return Mono.just(this.response); }); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandlerTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandlerTests.java index 3799e77c4d..fd481e8f96 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandlerTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/BearerTokenAccessDeniedHandlerTests.java @@ -16,8 +16,12 @@ package org.springframework.security.oauth2.server.resource.web.access; +import java.util.Collections; +import java.util.Map; + import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -25,11 +29,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; -import java.util.Collections; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; /** * Tests for {@link BearerTokenAccessDeniedHandlerTests} @@ -37,6 +37,7 @@ import static org.assertj.core.api.Assertions.assertThatCode; * @author Josh Cummings */ public class BearerTokenAccessDeniedHandlerTests { + private BearerTokenAccessDeniedHandler accessDeniedHandler; @Before @@ -45,34 +46,24 @@ public class BearerTokenAccessDeniedHandlerTests { } @Test - public void handleWhenNotOAuth2AuthenticatedThenStatus403() - throws Exception { - + public void handleWhenNotOAuth2AuthenticatedThenStatus403() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - Authentication authentication = new TestingAuthenticationToken("user", "pass"); request.setUserPrincipal(authentication); - this.accessDeniedHandler.handle(request, response, null); - assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer"); } @Test - public void handleWhenNotOAuth2AuthenticatedAndRealmSetThenStatus403AndAuthHeaderWithRealm() - throws Exception { - + public void handleWhenNotOAuth2AuthenticatedAndRealmSetThenStatus403AndAuthHeaderWithRealm() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - Authentication authentication = new TestingAuthenticationToken("user", "pass"); request.setUserPrincipal(authentication); - this.accessDeniedHandler.setRealmName("test"); this.accessDeniedHandler.handle(request, response, null); - assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer realm=\"test\""); } @@ -80,25 +71,23 @@ public class BearerTokenAccessDeniedHandlerTests { @Test public void handleWhenOAuth2AuthenticatedThenStatus403AndAuthHeaderWithInsufficientScopeErrorAttribute() throws Exception { - MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); - Authentication token = new TestingOAuth2TokenAuthenticationToken(Collections.emptyMap()); request.setUserPrincipal(token); - this.accessDeniedHandler.handle(request, response, null); - assertThat(response.getStatus()).isEqualTo(403); - assertThat(response.getHeader("WWW-Authenticate")).isEqualTo("Bearer error=\"insufficient_scope\", " + - "error_description=\"The request requires higher privileges than provided by the access token.\", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""); + // @formatter:off + assertThat(response.getHeader("WWW-Authenticate")) + .isEqualTo("Bearer error=\"insufficient_scope\", " + + "error_description=\"The request requires higher privileges than provided by the access token.\", " + + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""); + // @formatter:on } @Test public void setRealmNameWhenNullRealmNameThenNoExceptionThrown() { - assertThatCode(() -> this.accessDeniedHandler.setRealmName(null)) - .doesNotThrowAnyException(); + this.accessDeniedHandler.setRealmName(null); } static class TestingOAuth2TokenAuthenticationToken @@ -117,9 +106,13 @@ public class BearerTokenAccessDeniedHandlerTests { } static class TestingOAuth2Token extends AbstractOAuth2Token { + TestingOAuth2Token(String tokenValue) { super(tokenValue); } + } + } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandlerTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandlerTests.java index c4fb17ee65..20f9a47d5f 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandlerTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/access/server/BearerTokenServerAccessDeniedHandlerTests.java @@ -16,8 +16,14 @@ package org.springframework.security.oauth2.server.resource.web.access.server; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpStatus; import org.springframework.mock.http.server.reactive.MockServerHttpResponse; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -25,18 +31,13 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.Arrays; -import java.util.Collections; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class BearerTokenServerAccessDeniedHandlerTests { + private BearerTokenServerAccessDeniedHandler accessDeniedHandler; @Before @@ -46,56 +47,47 @@ public class BearerTokenServerAccessDeniedHandlerTests { @Test public void handleWhenNotOAuth2AuthenticatedThenStatus403() { - Authentication token = new TestingAuthenticationToken("user", "pass"); ServerWebExchange exchange = mock(ServerWebExchange.class); - when(exchange.getPrincipal()).thenReturn(Mono.just(token)); - when(exchange.getResponse()).thenReturn(new MockServerHttpResponse()); - + given(exchange.getPrincipal()).willReturn(Mono.just(token)); + given(exchange.getResponse()).willReturn(new MockServerHttpResponse()); this.accessDeniedHandler.handle(exchange, null).block(); - assertThat(exchange.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); - assertThat(exchange.getResponse().getHeaders().get("WWW-Authenticate")).isEqualTo( - Arrays.asList("Bearer")); + assertThat(exchange.getResponse().getHeaders().get("WWW-Authenticate")).isEqualTo(Arrays.asList("Bearer")); } @Test public void handleWhenNotOAuth2AuthenticatedAndRealmSetThenStatus403AndAuthHeaderWithRealm() { - Authentication token = new TestingAuthenticationToken("user", "pass"); ServerWebExchange exchange = mock(ServerWebExchange.class); - when(exchange.getPrincipal()).thenReturn(Mono.just(token)); - when(exchange.getResponse()).thenReturn(new MockServerHttpResponse()); - + given(exchange.getPrincipal()).willReturn(Mono.just(token)); + given(exchange.getResponse()).willReturn(new MockServerHttpResponse()); this.accessDeniedHandler.setRealmName("test"); this.accessDeniedHandler.handle(exchange, null).block(); - assertThat(exchange.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); - assertThat(exchange.getResponse().getHeaders().get("WWW-Authenticate")).isEqualTo( - Arrays.asList("Bearer realm=\"test\"")); + assertThat(exchange.getResponse().getHeaders().get("WWW-Authenticate")) + .isEqualTo(Arrays.asList("Bearer realm=\"test\"")); } @Test public void handleWhenOAuth2AuthenticatedThenStatus403AndAuthHeaderWithInsufficientScopeErrorAttribute() { - Authentication token = new TestingOAuth2TokenAuthenticationToken(Collections.emptyMap()); ServerWebExchange exchange = mock(ServerWebExchange.class); - when(exchange.getPrincipal()).thenReturn(Mono.just(token)); - when(exchange.getResponse()).thenReturn(new MockServerHttpResponse()); - + given(exchange.getPrincipal()).willReturn(Mono.just(token)); + given(exchange.getResponse()).willReturn(new MockServerHttpResponse()); this.accessDeniedHandler.handle(exchange, null).block(); - assertThat(exchange.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); - assertThat(exchange.getResponse().getHeaders().get("WWW-Authenticate")).isEqualTo( - Arrays.asList("Bearer error=\"insufficient_scope\", " + - "error_description=\"The request requires higher privileges than provided by the access token.\", " + - "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"")); + // @formatter:off + assertThat(exchange.getResponse().getHeaders().get("WWW-Authenticate")) + .isEqualTo(Arrays.asList("Bearer error=\"insufficient_scope\", " + + "error_description=\"The request requires higher privileges than provided by the access token.\", " + + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"")); + // @formatter:on } @Test public void setRealmNameWhenNullRealmNameThenNoExceptionThrown() { - assertThatCode(() -> this.accessDeniedHandler.setRealmName(null)) - .doesNotThrowAnyException(); + this.accessDeniedHandler.setRealmName(null); } static class TestingOAuth2TokenAuthenticationToken @@ -114,9 +106,13 @@ public class BearerTokenServerAccessDeniedHandlerTests { } static class TestingOAuth2Token extends AbstractOAuth2Token { + TestingOAuth2Token(String tokenValue) { super(tokenValue); } + } + } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java index 22bdb72bcd..bafa9ed4c8 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServerBearerExchangeFilterFunctionTests.java @@ -25,6 +25,7 @@ import java.util.Map; import org.junit.Test; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; @@ -34,7 +35,6 @@ import org.springframework.security.oauth2.server.resource.web.MockExchangeFunct import org.springframework.web.reactive.function.client.ClientRequest; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.HttpMethod.GET; /** * Tests for {@link ServerBearerExchangeFilterFunction} @@ -42,15 +42,16 @@ import static org.springframework.http.HttpMethod.GET; * @author Josh Cummings */ public class ServerBearerExchangeFilterFunctionTests { + private ServerBearerExchangeFilterFunction function = new ServerBearerExchangeFilterFunction(); private MockExchangeFunction exchange = new MockExchangeFunction(); - private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token-0", - Instant.now(), - Instant.now().plus(Duration.ofDays(1))); - private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken(accessToken) { + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token-0", + Instant.now(), Instant.now().plus(Duration.ofDays(1))); + + private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken( + this.accessToken) { @Override public Map getTokenAttributes() { return Collections.emptyMap(); @@ -59,23 +60,16 @@ public class ServerBearerExchangeFilterFunctionTests { @Test public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); this.function.filter(request, this.exchange).block(); - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); } @Test public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); this.function.filter(request, this.exchange) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) - .block(); - + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)).block(); assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } @@ -83,29 +77,21 @@ public class ServerBearerExchangeFilterFunctionTests { // gh-7353 @Test public void filterWhenAuthenticatedWithOtherTokenThenAuthorizationHeaderNull() throws Exception { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); TestingAuthenticationToken token = new TestingAuthenticationToken("user", "pass"); this.function.filter(request, this.exchange) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(token)) - .block(); - - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) - .isNull(); + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(token)).block(); + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); } @Test public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .header(HttpHeaders.AUTHORIZATION, "Existing") - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .header(HttpHeaders.AUTHORIZATION, "Existing").build(); this.function.filter(request, this.exchange) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) - .block(); - + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)).block(); HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java index e5fd79b339..54b4f164d8 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/reactive/function/client/ServletBearerExchangeFilterFunctionTests.java @@ -29,6 +29,7 @@ import org.mockito.junit.MockitoJUnitRunner; import reactor.util.context.Context; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -37,8 +38,6 @@ import org.springframework.security.oauth2.server.resource.web.MockExchangeFunct import org.springframework.web.reactive.function.client.ClientRequest; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.HttpMethod.GET; -import static org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY; /** * Tests for {@link ServletBearerExchangeFilterFunction} @@ -47,16 +46,16 @@ import static org.springframework.security.oauth2.server.resource.web.reactive.f */ @RunWith(MockitoJUnitRunner.class) public class ServletBearerExchangeFilterFunctionTests { + private ServletBearerExchangeFilterFunction function = new ServletBearerExchangeFilterFunction(); private MockExchangeFunction exchange = new MockExchangeFunction(); - private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token-0", - Instant.now(), - Instant.now().plus(Duration.ofDays(1))); + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token-0", + Instant.now(), Instant.now().plus(Duration.ofDays(1))); - private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken(accessToken) { + private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken( + this.accessToken) { @Override public Map getTokenAttributes() { return Collections.emptyMap(); @@ -65,53 +64,33 @@ public class ServletBearerExchangeFilterFunctionTests { @Test public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); this.function.filter(request, this.exchange).block(); - - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) - .isNull(); + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); } // gh-7353 @Test public void filterWhenAuthenticatedWithOtherTokenThenAuthorizationHeaderNull() { TestingAuthenticationToken token = new TestingAuthenticationToken("user", "pass"); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(context(token)) - .block(); - - assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) - .isNull(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); + this.function.filter(request, this.exchange).subscriberContext(context(token)).block(); + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); } @Test public void filterWhenAuthenticatedThenAuthorizationHeader() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(context(this.authentication)) - .block(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); + this.function.filter(request, this.exchange).subscriberContext(context(this.authentication)).block(); assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) .isEqualTo("Bearer " + this.accessToken.getTokenValue()); } @Test public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) - .header(HttpHeaders.AUTHORIZATION, "Existing") - .build(); - - this.function.filter(request, this.exchange) - .subscriberContext(context(this.authentication)) - .block(); - + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .header(HttpHeaders.AUTHORIZATION, "Existing").build(); + this.function.filter(request, this.exchange).subscriberContext(context(this.authentication)).block(); HttpHeaders headers = this.exchange.getRequest().headers(); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } @@ -119,6 +98,8 @@ public class ServletBearerExchangeFilterFunctionTests { private Context context(Authentication authentication) { Map, Object> contextAttributes = new HashMap<>(); contextAttributes.put(Authentication.class, authentication); - return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes); + return Context.of(ServletBearerExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, + contextAttributes); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPointTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPointTests.java index 722ff98608..b3d76dca21 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPointTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/BearerTokenServerAuthenticationEntryPointTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.resource.web.server; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; @@ -28,13 +29,14 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.server.resource.BearerTokenError; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch * @since 5.1 */ public class BearerTokenServerAuthenticationEntryPointTests { + private BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); @@ -42,7 +44,6 @@ public class BearerTokenServerAuthenticationEntryPointTests { @Test public void commenceWhenNotOAuth2AuthenticationExceptionThenBearer() { this.entryPoint.commence(this.exchange, new BadCredentialsException("")).block(); - assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Bearer"); assertThat(getResponse().getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); } @@ -50,10 +51,9 @@ public class BearerTokenServerAuthenticationEntryPointTests { @Test public void commenceWhenRealmNameThenHasRealmName() { this.entryPoint.setRealmName("Realm"); - this.entryPoint.commence(this.exchange, new BadCredentialsException("")).block(); - - assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Bearer realm=\"Realm\""); + assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)) + .isEqualTo("Bearer realm=\"Realm\""); assertThat(getResponse().getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); } @@ -61,10 +61,9 @@ public class BearerTokenServerAuthenticationEntryPointTests { public void commenceWhenOAuth2AuthenticationExceptionThenContainsErrorInformation() { OAuth2Error oauthError = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST); OAuth2AuthenticationException exception = new OAuth2AuthenticationException(oauthError); - this.entryPoint.commence(this.exchange, exception).block(); - - assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Bearer error=\"invalid_request\""); + assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)) + .isEqualTo("Bearer error=\"invalid_request\""); assertThat(getResponse().getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); } @@ -72,29 +71,26 @@ public class BearerTokenServerAuthenticationEntryPointTests { public void commenceWhenOAuth2ErrorCompleteThenContainsErrorInformation() { OAuth2Error oauthError = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, "Oops", "https://example.com"); OAuth2AuthenticationException exception = new OAuth2AuthenticationException(oauthError); - this.entryPoint.commence(this.exchange, exception).block(); - - assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Bearer error=\"invalid_request\", error_description=\"Oops\", error_uri=\"https://example.com\""); + assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo( + "Bearer error=\"invalid_request\", error_description=\"Oops\", error_uri=\"https://example.com\""); assertThat(getResponse().getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); } @Test public void commenceWhenBearerTokenThenErrorInformation() { - OAuth2Error oauthError = new BearerTokenError(OAuth2ErrorCodes.INVALID_REQUEST, - HttpStatus.BAD_REQUEST, "Oops", "https://example.com"); + OAuth2Error oauthError = new BearerTokenError(OAuth2ErrorCodes.INVALID_REQUEST, HttpStatus.BAD_REQUEST, "Oops", + "https://example.com"); OAuth2AuthenticationException exception = new OAuth2AuthenticationException(oauthError); - this.entryPoint.commence(this.exchange, exception).block(); - - assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Bearer error=\"invalid_request\", error_description=\"Oops\", error_uri=\"https://example.com\""); + assertThat(getResponse().getHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo( + "Bearer error=\"invalid_request\", error_description=\"Oops\", error_uri=\"https://example.com\""); assertThat(getResponse().getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); } @Test public void commenceWhenNoSubscriberThenNothingHappens() { this.entryPoint.commence(this.exchange, new BadCredentialsException("")); - assertThat(getResponse().getHeaders()).isEmpty(); assertThat(getResponse().getStatusCode()).isNull(); } @@ -102,4 +98,5 @@ public class BearerTokenServerAuthenticationEntryPointTests { private MockServerHttpResponse getResponse() { return this.exchange.getResponse(); } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverterTests.java index e7932d6f66..9a8cf9c1b3 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/server/ServerBearerTokenAuthenticationConverterTests.java @@ -31,15 +31,16 @@ import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.catchThrowableOfType; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rob Winch * @since 5.1 */ public class ServerBearerTokenAuthenticationConverterTests { + private static final String CUSTOM_HEADER = "custom-header"; + private static final String TEST_TOKEN = "test-token"; private ServerBearerTokenAuthenticationConverter converter; @@ -51,10 +52,10 @@ public class ServerBearerTokenAuthenticationConverterTests { @Test public void resolveWhenValidHeaderIsPresentThenTokenIsResolved() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_TOKEN); - + // @formatter:on assertThat(convertToToken(request).getToken()).isEqualTo(TEST_TOKEN); } @@ -62,117 +63,116 @@ public class ServerBearerTokenAuthenticationConverterTests { @Test public void resolveWhenHeaderEndsWithPaddingIndicatorThenTokenIsResolved() { String token = TEST_TOKEN + "=="; - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(HttpHeaders.AUTHORIZATION, "Bearer " + token); - + // @formatter:on assertThat(convertToToken(request).getToken()).isEqualTo(token); } @Test public void resolveWhenCustomDefinedHeaderIsValidAndPresentThenTokenIsResolved() { this.converter.setBearerTokenHeaderName(CUSTOM_HEADER); - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(CUSTOM_HEADER, "Bearer " + TEST_TOKEN); - + // @formatter:on assertThat(convertToToken(request).getToken()).isEqualTo(TEST_TOKEN); } // gh-7011 @Test public void resolveWhenValidHeaderIsEmptyStringThenTokenIsResolved() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") - .header(HttpHeaders.AUTHORIZATION, "Bearer "); - - OAuth2AuthenticationException expected = catchThrowableOfType(() -> convertToToken(request), - OAuth2AuthenticationException.class); - BearerTokenError error = (BearerTokenError) expected.getError(); - assertThat(error.getErrorCode()).isEqualTo(BearerTokenErrorCodes.INVALID_TOKEN); - assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); - assertThat(error.getHttpStatus()).isEqualTo(HttpStatus.UNAUTHORIZED); + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/").header(HttpHeaders.AUTHORIZATION, + "Bearer "); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> convertToToken(request)) + .satisfies((ex) -> { + BearerTokenError error = (BearerTokenError) ex.getError(); + assertThat(error.getErrorCode()).isEqualTo(BearerTokenErrorCodes.INVALID_TOKEN); + assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); + assertThat(error.getHttpStatus()).isEqualTo(HttpStatus.UNAUTHORIZED); + }); + // @formatter:on } @Test public void resolveWhenLowercaseHeaderIsPresentThenTokenIsResolved() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(HttpHeaders.AUTHORIZATION, "bearer " + TEST_TOKEN); - + // @formatter:on assertThat(convertToToken(request).getToken()).isEqualTo(TEST_TOKEN); } @Test public void resolveWhenNoHeaderIsPresentThenTokenIsNotResolved() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/"); - + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/"); assertThat(convertToToken(request)).isNull(); } @Test public void resolveWhenHeaderWithWrongSchemeIsPresentThenTokenIsNotResolved() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") - .header(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder().encodeToString("test:test".getBytes())); - + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") + .header(HttpHeaders.AUTHORIZATION, + "Basic " + Base64.getEncoder().encodeToString("test:test".getBytes())); + // @formatter:on assertThat(convertToToken(request)).isNull(); } @Test public void resolveWhenHeaderWithMissingTokenIsPresentThenAuthenticationExceptionIsThrown() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(HttpHeaders.AUTHORIZATION, "Bearer "); - - assertThatCode(() -> convertToToken(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining(("Bearer token is malformed")); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> convertToToken(request)) + .withMessageContaining(("Bearer token is malformed")); + // @formatter:on } @Test public void resolveWhenHeaderWithInvalidCharactersIsPresentThenAuthenticationExceptionIsThrown() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(HttpHeaders.AUTHORIZATION, "Bearer an\"invalid\"token"); - - assertThatCode(() -> convertToToken(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining(("Bearer token is malformed")); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> convertToToken(request)) + .withMessageContaining(("Bearer token is malformed")); + // @formatter:on } // gh-8865 @Test public void resolveWhenHeaderWithInvalidCharactersIsPresentAndNotSubscribedThenNoneExceptionIsThrown() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .header(HttpHeaders.AUTHORIZATION, "Bearer an\"invalid\"token"); - - assertThatCode(() -> this.converter.convert(MockServerWebExchange.from(request))) - .doesNotThrowAnyException(); + // @formatter:on + this.converter.convert(MockServerWebExchange.from(request)); } @Test public void resolveWhenValidHeaderIsPresentTogetherWithQueryParameterThenAuthenticationExceptionIsThrown() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .queryParam("access_token", TEST_TOKEN) .header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_TOKEN); - - assertThatCode(() -> convertToToken(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .hasMessageContaining("Found multiple bearer tokens in the request"); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> convertToToken(request)) + .withMessageContaining("Found multiple bearer tokens in the request"); + // @formatter:on } @Test public void resolveWhenQueryParameterIsPresentAndSupportedThenTokenIsResolved() { this.converter.setAllowUriQueryParameter(true); - - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .queryParam("access_token", TEST_TOKEN); - + // @formatter:on assertThat(convertToToken(request).getToken()).isEqualTo(TEST_TOKEN); } @@ -180,25 +180,26 @@ public class ServerBearerTokenAuthenticationConverterTests { @Test public void resolveWhenQueryParameterIsEmptyAndSupportedThenOAuth2AuthenticationException() { this.converter.setAllowUriQueryParameter(true); - - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .queryParam("access_token", ""); - - OAuth2AuthenticationException expected = catchThrowableOfType(() -> convertToToken(request), - OAuth2AuthenticationException.class); - BearerTokenError error = (BearerTokenError) expected.getError(); - assertThat(error.getErrorCode()).isEqualTo(BearerTokenErrorCodes.INVALID_TOKEN); - assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); - assertThat(error.getHttpStatus()).isEqualTo(HttpStatus.UNAUTHORIZED); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> convertToToken(request)) + .satisfies((ex) -> { + BearerTokenError error = (BearerTokenError) ex.getError(); + assertThat(error.getErrorCode()).isEqualTo(BearerTokenErrorCodes.INVALID_TOKEN); + assertThat(error.getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); + assertThat(error.getHttpStatus()).isEqualTo(HttpStatus.UNAUTHORIZED); + }); + // @formatter:on } @Test public void resolveWhenQueryParameterIsPresentAndNotSupportedThenTokenIsNotResolved() { - MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest - .get("/") + // @formatter:off + MockServerHttpRequest.BaseBuilder request = MockServerHttpRequest.get("/") .queryParam("access_token", TEST_TOKEN); - + // @formatter:on assertThat(convertToToken(request)).isNull(); } @@ -208,6 +209,11 @@ public class ServerBearerTokenAuthenticationConverterTests { private BearerTokenAuthenticationToken convertToToken(MockServerHttpRequest request) { MockServerWebExchange exchange = MockServerWebExchange.from(request); - return this.converter.convert(exchange).cast(BearerTokenAuthenticationToken.class).block(); + // @formatter:off + return this.converter.convert(exchange) + .cast(BearerTokenAuthenticationToken.class) + .block(); + // @formatter:on } + } diff --git a/openid/src/main/java/org/springframework/security/openid/AuthenticationCancelledException.java b/openid/src/main/java/org/springframework/security/openid/AuthenticationCancelledException.java index e5e89d5e38..089b710612 100644 --- a/openid/src/main/java/org/springframework/security/openid/AuthenticationCancelledException.java +++ b/openid/src/main/java/org/springframework/security/openid/AuthenticationCancelledException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import org.springframework.security.core.AuthenticationException; @@ -20,20 +21,21 @@ import org.springframework.security.core.AuthenticationException; /** * Indicates that OpenID authentication was cancelled * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley, Opsera Ltd + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class AuthenticationCancelledException extends AuthenticationException { - // ~ Constructors - // =================================================================================================== public AuthenticationCancelledException(String msg) { super(msg); } - public AuthenticationCancelledException(String msg, Throwable t) { - super(msg, t); + public AuthenticationCancelledException(String msg, Throwable cause) { + super(msg, cause); } + } diff --git a/openid/src/main/java/org/springframework/security/openid/AxFetchListFactory.java b/openid/src/main/java/org/springframework/security/openid/AxFetchListFactory.java index 3c99c94626..f6b4f5bd57 100644 --- a/openid/src/main/java/org/springframework/security/openid/AxFetchListFactory.java +++ b/openid/src/main/java/org/springframework/security/openid/AxFetchListFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import java.util.List; @@ -24,20 +25,22 @@ import java.util.List; * This allows the list of attributes for a fetch request to be tailored for different * OpenID providers, since they do not all support the same attributes. * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Luke Taylor * @since 3.1 + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public interface AxFetchListFactory { /** * Builds the list of attributes which should be added to the fetch request for the * supplied OpenID identifier. - * * @param identifier the claimed_identity * @return the attributes to fetch for this identifier */ List createAttributeList(String identifier); + } diff --git a/openid/src/main/java/org/springframework/security/openid/NullAxFetchListFactory.java b/openid/src/main/java/org/springframework/security/openid/NullAxFetchListFactory.java index 75df033bac..2d34debff5 100644 --- a/openid/src/main/java/org/springframework/security/openid/NullAxFetchListFactory.java +++ b/openid/src/main/java/org/springframework/security/openid/NullAxFetchListFactory.java @@ -13,20 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import java.util.Collections; import java.util.List; /** - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Luke Taylor * @since 3.1 + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class NullAxFetchListFactory implements AxFetchListFactory { + + @Override public List createAttributeList(String identifier) { return Collections.emptyList(); } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java b/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java index 7ab6f47615..33e65240e8 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenID4JavaConsumer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import java.util.ArrayList; @@ -38,190 +39,159 @@ import org.openid4java.message.ParameterList; import org.openid4java.message.ax.AxMessage; import org.openid4java.message.ax.FetchRequest; import org.openid4java.message.ax.FetchResponse; + import org.springframework.util.StringUtils; /** * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. * @author Ray Krueger * @author Luke Taylor */ +@Deprecated @SuppressWarnings("unchecked") public class OpenID4JavaConsumer implements OpenIDConsumer { - private static final String DISCOVERY_INFO_KEY = DiscoveryInformation.class.getName(); - private static final String ATTRIBUTE_LIST_KEY = "SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"; - // ~ Instance fields - // ================================================================================================ + private static final String DISCOVERY_INFO_KEY = DiscoveryInformation.class.getName(); + + private static final String ATTRIBUTE_LIST_KEY = "SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST"; protected final Log logger = LogFactory.getLog(getClass()); private final ConsumerManager consumerManager; - private final AxFetchListFactory attributesToFetchFactory; - // ~ Constructors - // =================================================================================================== + private final AxFetchListFactory attributesToFetchFactory; public OpenID4JavaConsumer() throws ConsumerException { this(new ConsumerManager(), new NullAxFetchListFactory()); } - public OpenID4JavaConsumer(AxFetchListFactory attributesToFetchFactory) - throws ConsumerException { + public OpenID4JavaConsumer(AxFetchListFactory attributesToFetchFactory) throws ConsumerException { this(new ConsumerManager(), attributesToFetchFactory); } - public OpenID4JavaConsumer(ConsumerManager consumerManager, - AxFetchListFactory attributesToFetchFactory) { + public OpenID4JavaConsumer(ConsumerManager consumerManager, AxFetchListFactory attributesToFetchFactory) { this.consumerManager = consumerManager; this.attributesToFetchFactory = attributesToFetchFactory; } - // ~ Methods - // ======================================================================================================== - - public String beginConsumption(HttpServletRequest req, String identityUrl, - String returnToUrl, String realm) throws OpenIDConsumerException { - List discoveries; - - try { - discoveries = consumerManager.discover(identityUrl); - } - catch (DiscoveryException e) { - throw new OpenIDConsumerException("Error during discovery", e); - } - - DiscoveryInformation information = consumerManager.associate(discoveries); + @Override + public String beginConsumption(HttpServletRequest req, String identityUrl, String returnToUrl, String realm) + throws OpenIDConsumerException { + List discoveries = getDiscoveries(identityUrl); + DiscoveryInformation information = this.consumerManager.associate(discoveries); req.getSession().setAttribute(DISCOVERY_INFO_KEY, information); + AuthRequest authReq = getAuthRequest(req, identityUrl, returnToUrl, realm, information); + return authReq.getDestinationUrl(true); + } - AuthRequest authReq; - + private List getDiscoveries(String identityUrl) throws OpenIDConsumerException { try { - authReq = consumerManager.authenticate(information, returnToUrl, realm); - - logger.debug("Looking up attribute fetch list for identifier: " + identityUrl); - - List attributesToFetch = attributesToFetchFactory - .createAttributeList(identityUrl); + return this.consumerManager.discover(identityUrl); + } + catch (DiscoveryException ex) { + throw new OpenIDConsumerException("Error during discovery", ex); + } + } + private AuthRequest getAuthRequest(HttpServletRequest req, String identityUrl, String returnToUrl, String realm, + DiscoveryInformation information) throws OpenIDConsumerException { + try { + AuthRequest authReq = this.consumerManager.authenticate(information, returnToUrl, realm); + this.logger.debug("Looking up attribute fetch list for identifier: " + identityUrl); + List attributesToFetch = this.attributesToFetchFactory.createAttributeList(identityUrl); if (!attributesToFetch.isEmpty()) { req.getSession().setAttribute(ATTRIBUTE_LIST_KEY, attributesToFetch); FetchRequest fetchRequest = FetchRequest.createFetchRequest(); for (OpenIDAttribute attr : attributesToFetch) { - if (logger.isDebugEnabled()) { - logger.debug("Adding attribute " + attr.getType() - + " to fetch request"); + if (this.logger.isDebugEnabled()) { + this.logger.debug("Adding attribute " + attr.getType() + " to fetch request"); } - fetchRequest.addAttribute(attr.getName(), attr.getType(), - attr.isRequired(), attr.getCount()); + fetchRequest.addAttribute(attr.getName(), attr.getType(), attr.isRequired(), attr.getCount()); } authReq.addExtension(fetchRequest); } + return authReq; } - catch (MessageException | ConsumerException e) { - throw new OpenIDConsumerException( - "Error processing ConsumerManager authentication", e); + catch (MessageException | ConsumerException ex) { + throw new OpenIDConsumerException("Error processing ConsumerManager authentication", ex); } - - return authReq.getDestinationUrl(true); } - public OpenIDAuthenticationToken endConsumption(HttpServletRequest request) - throws OpenIDConsumerException { + @Override + public OpenIDAuthenticationToken endConsumption(HttpServletRequest request) throws OpenIDConsumerException { // extract the parameters from the authentication response // (which comes in as a HTTP request from the OpenID provider) ParameterList openidResp = new ParameterList(request.getParameterMap()); - // retrieve the previously stored discovery information - DiscoveryInformation discovered = (DiscoveryInformation) request.getSession() - .getAttribute(DISCOVERY_INFO_KEY); - + DiscoveryInformation discovered = (DiscoveryInformation) request.getSession().getAttribute(DISCOVERY_INFO_KEY); if (discovered == null) { throw new OpenIDConsumerException( "DiscoveryInformation is not available. Possible causes are lost session or replay attack"); } - - List attributesToFetch = (List) request - .getSession().getAttribute(ATTRIBUTE_LIST_KEY); - + List attributesToFetch = (List) request.getSession() + .getAttribute(ATTRIBUTE_LIST_KEY); request.getSession().removeAttribute(DISCOVERY_INFO_KEY); request.getSession().removeAttribute(ATTRIBUTE_LIST_KEY); - // extract the receiving URL from the HTTP request StringBuffer receivingURL = request.getRequestURL(); String queryString = request.getQueryString(); - if (StringUtils.hasLength(queryString)) { receivingURL.append("?").append(request.getQueryString()); } - // verify the response VerificationResult verification; - try { - verification = consumerManager.verify(receivingURL.toString(), openidResp, - discovered); + verification = this.consumerManager.verify(receivingURL.toString(), openidResp, discovered); } - catch (MessageException | AssociationException | DiscoveryException e) { - throw new OpenIDConsumerException("Error verifying openid response", e); + catch (MessageException | AssociationException | DiscoveryException ex) { + throw new OpenIDConsumerException("Error verifying openid response", ex); } - // examine the verification result and extract the verified identifier Identifier verified = verification.getVerifiedId(); - if (verified == null) { Identifier id = discovered.getClaimedIdentifier(); return new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.FAILURE, - id == null ? "Unknown" : id.getIdentifier(), + (id != null) ? id.getIdentifier() : "Unknown", "Verification status message: [" + verification.getStatusMsg() + "]", - Collections. emptyList()); + Collections.emptyList()); } - - List attributes = fetchAxAttributes( - verification.getAuthResponse(), attributesToFetch); - - return new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.SUCCESS, - verified.getIdentifier(), "some message", attributes); + List attributes = fetchAxAttributes(verification.getAuthResponse(), attributesToFetch); + return new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.SUCCESS, verified.getIdentifier(), + "some message", attributes); } - List fetchAxAttributes(Message authSuccess, - List attributesToFetch) throws OpenIDConsumerException { - - if (attributesToFetch == null - || !authSuccess.hasExtension(AxMessage.OPENID_NS_AX)) { + List fetchAxAttributes(Message authSuccess, List attributesToFetch) + throws OpenIDConsumerException { + if (attributesToFetch == null || !authSuccess.hasExtension(AxMessage.OPENID_NS_AX)) { return Collections.emptyList(); } - - logger.debug("Extracting attributes retrieved by attribute exchange"); - + this.logger.debug("Extracting attributes retrieved by attribute exchange"); List attributes = Collections.emptyList(); - try { MessageExtension ext = authSuccess.getExtension(AxMessage.OPENID_NS_AX); if (ext instanceof FetchResponse) { FetchResponse fetchResp = (FetchResponse) ext; attributes = new ArrayList<>(attributesToFetch.size()); - for (OpenIDAttribute attr : attributesToFetch) { List values = fetchResp.getAttributeValues(attr.getName()); if (!values.isEmpty()) { - OpenIDAttribute fetched = new OpenIDAttribute(attr.getName(), - attr.getType(), values); + OpenIDAttribute fetched = new OpenIDAttribute(attr.getName(), attr.getType(), values); fetched.setRequired(attr.isRequired()); attributes.add(fetched); } } } } - catch (MessageException e) { - throw new OpenIDConsumerException("Attribute retrieval failed", e); + catch (MessageException ex) { + throw new OpenIDConsumerException("Attribute retrieval failed", ex); } - - if (logger.isDebugEnabled()) { - logger.debug("Retrieved attributes" + attributes); + if (this.logger.isDebugEnabled()) { + this.logger.debug("Retrieved attributes" + attributes); } - return attributes; } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAttribute.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAttribute.java index c15d45856d..003c067146 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAttribute.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAttribute.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import java.io.Serializable; @@ -27,16 +28,22 @@ import org.springframework.util.Assert; * should be requested during a fetch request, or to hold values for an attribute which * are returned during the authentication process. * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Luke Taylor * @since 3.0 + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenIDAttribute implements Serializable { + private final String name; + private final String typeIdentifier; + private boolean required = false; + private int count = 1; private final List values; @@ -58,14 +65,14 @@ public class OpenIDAttribute implements Serializable { * The attribute name */ public String getName() { - return name; + return this.name; } /** * The attribute type Identifier (a URI). */ public String getType() { - return typeIdentifier; + return this.typeIdentifier; } /** @@ -73,7 +80,7 @@ public class OpenIDAttribute implements Serializable { * Defaults to "false". */ public boolean isRequired() { - return required; + return this.required; } public void setRequired(boolean required) { @@ -85,7 +92,7 @@ public class OpenIDAttribute implements Serializable { * request. Defaults to 1. */ public int getCount() { - return count; + return this.count; } public void setCount(int count) { @@ -96,20 +103,20 @@ public class OpenIDAttribute implements Serializable { * The values obtained from an attribute exchange. */ public List getValues() { - Assert.notNull(values, - "Cannot read values from an authentication request attribute"); - return values; + Assert.notNull(this.values, "Cannot read values from an authentication request attribute"); + return this.values; } @Override public String toString() { StringBuilder result = new StringBuilder("["); - result.append(name); - if (values != null) { + result.append(this.name); + if (this.values != null) { result.append(":"); - result.append(values.toString()); + result.append(this.values.toString()); } result.append("]"); return result.toString(); } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java index 9c3999d704..c0c3cffbb8 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java @@ -16,7 +16,22 @@ package org.springframework.security.openid; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLEncoder; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.openid4java.consumer.ConsumerException; + import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -25,15 +40,6 @@ import org.springframework.security.web.authentication.rememberme.AbstractRememb import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.MalformedURLException; -import java.net.URL; -import java.net.URLEncoder; -import java.util.*; - /** * Filter which processes OpenID authentication requests. *

        @@ -59,58 +65,47 @@ import java.util.*; * where it should (normally) be processed by an OpenIDAuthenticationProvider in * order to load the authorities for the user. * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley * @author Ray Krueger * @author Luke Taylor * @since 2.0 * @see OpenIDAuthenticationProvider + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessingFilter { - // ~ Static fields/initializers - // ===================================================================================== public static final String DEFAULT_CLAIMED_IDENTITY_FIELD = "openid_identifier"; - // ~ Instance fields - // ================================================================================================ - private OpenIDConsumer consumer; - private String claimedIdentityFieldName = DEFAULT_CLAIMED_IDENTITY_FIELD; - private Map realmMapping = Collections.emptyMap(); - private Set returnToUrlParameters = Collections.emptySet(); - // ~ Constructors - // =================================================================================================== + private String claimedIdentityFieldName = DEFAULT_CLAIMED_IDENTITY_FIELD; + + private Map realmMapping = Collections.emptyMap(); + + private Set returnToUrlParameters = Collections.emptySet(); public OpenIDAuthenticationFilter() { super("/login/openid"); } - // ~ Methods - // ======================================================================================================== - @Override public void afterPropertiesSet() { super.afterPropertiesSet(); - - if (consumer == null) { + if (this.consumer == null) { try { - consumer = new OpenID4JavaConsumer(); + this.consumer = new OpenID4JavaConsumer(); } - catch (ConsumerException e) { - throw new IllegalArgumentException("Failed to initialize OpenID", e); + catch (ConsumerException ex) { + throw new IllegalArgumentException("Failed to initialize OpenID", ex); } } - - if (returnToUrlParameters.isEmpty() - && getRememberMeServices() instanceof AbstractRememberMeServices) { - returnToUrlParameters = new HashSet<>(); - returnToUrlParameters - .add(((AbstractRememberMeServices) getRememberMeServices()) - .getParameter()); + if (this.returnToUrlParameters.isEmpty() && getRememberMeServices() instanceof AbstractRememberMeServices) { + this.returnToUrlParameters = new HashSet<>(); + this.returnToUrlParameters.add(((AbstractRememberMeServices) getRememberMeServices()).getParameter()); } } @@ -124,109 +119,88 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing * */ @Override - public Authentication attemptAuthentication(HttpServletRequest request, - HttpServletResponse response) throws AuthenticationException, IOException { + public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) + throws AuthenticationException, IOException { OpenIDAuthenticationToken token; - String identity = request.getParameter("openid.identity"); - if (!StringUtils.hasText(identity)) { String claimedIdentity = obtainUsername(request); - try { String returnToUrl = buildReturnToUrl(request); String realm = lookupRealm(returnToUrl); - String openIdUrl = consumer.beginConsumption(request, claimedIdentity, - returnToUrl, realm); - if (logger.isDebugEnabled()) { - logger.debug("return_to is '" + returnToUrl + "', realm is '" + realm - + "'"); - logger.debug("Redirecting to " + openIdUrl); + String openIdUrl = this.consumer.beginConsumption(request, claimedIdentity, returnToUrl, realm); + if (this.logger.isDebugEnabled()) { + this.logger.debug("return_to is '" + returnToUrl + "', realm is '" + realm + "'"); + this.logger.debug("Redirecting to " + openIdUrl); } response.sendRedirect(openIdUrl); - // Indicate to parent class that authentication is continuing. return null; } - catch (OpenIDConsumerException e) { - logger.debug("Failed to consume claimedIdentity: " + claimedIdentity, e); + catch (OpenIDConsumerException ex) { + this.logger.debug("Failed to consume claimedIdentity: " + claimedIdentity, ex); throw new AuthenticationServiceException( "Unable to process claimed identity '" + claimedIdentity + "'"); } } - - if (logger.isDebugEnabled()) { - logger.debug("Supplied OpenID identity is " + identity); + if (this.logger.isDebugEnabled()) { + this.logger.debug("Supplied OpenID identity is " + identity); } - try { - token = consumer.endConsumption(request); + token = this.consumer.endConsumption(request); } - catch (OpenIDConsumerException oice) { - throw new AuthenticationServiceException("Consumer error", oice); + catch (OpenIDConsumerException ex) { + throw new AuthenticationServiceException("Consumer error", ex); } - - token.setDetails(authenticationDetailsSource.buildDetails(request)); - + token.setDetails(this.authenticationDetailsSource.buildDetails(request)); // delegate to the authentication provider - Authentication authentication = this.getAuthenticationManager().authenticate( - token); - + Authentication authentication = this.getAuthenticationManager().authenticate(token); return authentication; } protected String lookupRealm(String returnToUrl) { - String mapping = realmMapping.get(returnToUrl); - + String mapping = this.realmMapping.get(returnToUrl); if (mapping == null) { try { URL url = new URL(returnToUrl); int port = url.getPort(); - - StringBuilder realmBuffer = new StringBuilder(returnToUrl.length()) - .append(url.getProtocol()).append("://").append(url.getHost()); + StringBuilder realmBuffer = new StringBuilder(returnToUrl.length()).append(url.getProtocol()) + .append("://").append(url.getHost()); if (port > 0) { realmBuffer.append(":").append(port); } realmBuffer.append("/"); mapping = realmBuffer.toString(); } - catch (MalformedURLException e) { - logger.warn("returnToUrl was not a valid URL: [" + returnToUrl + "]", e); + catch (MalformedURLException ex) { + this.logger.warn("returnToUrl was not a valid URL: [" + returnToUrl + "]", ex); } } - return mapping; } /** * Builds the return_to URL that will be sent to the OpenID service provider. * By default returns the URL of the current request. - * * @param request the current request which is being processed by this filter * @return The return_to URL. */ protected String buildReturnToUrl(HttpServletRequest request) { StringBuffer sb = request.getRequestURL(); - - Iterator iterator = returnToUrlParameters.iterator(); + Iterator iterator = this.returnToUrlParameters.iterator(); boolean isFirst = true; - while (iterator.hasNext()) { String name = iterator.next(); // Assume for simplicity that there is only one value String value = request.getParameter(name); - if (value == null) { continue; } - if (isFirst) { sb.append("?"); isFirst = false; } sb.append(utf8UrlEncode(name)).append("=").append(utf8UrlEncode(value)); - if (iterator.hasNext()) { sb.append("&"); } @@ -238,13 +212,11 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing * Reads the claimedIdentityFieldName from the submitted request. */ protected String obtainUsername(HttpServletRequest req) { - String claimedIdentity = req.getParameter(claimedIdentityFieldName); - + String claimedIdentity = req.getParameter(this.claimedIdentityFieldName); if (!StringUtils.hasText(claimedIdentity)) { - logger.error("No claimed identity supplied in authentication request"); + this.logger.error("No claimed identity supplied in authentication request"); return ""; } - return claimedIdentity.trim(); } @@ -259,7 +231,6 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing * protocol, hostname and port followed by a trailing slash. This means that * https://foo.example.com/login/openid will automatically become * http://foo.example.com:80/ - * * @param realmMapping containing returnToUrl -> realm mappings */ public void setRealmMapping(Map realmMapping) { @@ -269,7 +240,6 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing /** * The name of the request parameter containing the OpenID identity, as submitted from * the initial login form. - * * @param claimedIdentityFieldName defaults to "openid_identifier" */ public void setClaimedIdentityFieldName(String claimedIdentityFieldName) { @@ -284,7 +254,6 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing * Specifies any extra parameters submitted along with the identity field which should * be appended to the {@code return_to} URL which is assembled by * {@link #buildReturnToUrl}. - * * @param returnToUrlParameters the set of parameter names. If not set, it will * default to the parameter name used by the {@code RememberMeServices} obtained from * the parent class (if one is set). @@ -296,7 +265,6 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing /** * Performs URL encoding with UTF-8 - * * @param value the value to URL encode * @return the encoded value */ @@ -304,11 +272,12 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing try { return URLEncoder.encode(value, "UTF-8"); } - catch (UnsupportedEncodingException e) { + catch (UnsupportedEncodingException ex) { Error err = new AssertionError( "The Java platform guarantees UTF-8 support, but it seemingly is not present."); - err.initCause(e); + err.initCause(ex); throw err; } } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationProvider.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationProvider.java index e42baa6d1f..b1f71d54ff 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationProvider.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import org.springframework.beans.factory.InitializingBean; @@ -44,76 +45,54 @@ import org.springframework.util.Assert; * {@code Authentication} token, so additional properties such as email addresses, * telephone numbers etc can easily be stored. * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley, Opsera Ltd. * @author Luke Taylor + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ -public class OpenIDAuthenticationProvider - implements AuthenticationProvider, InitializingBean { - // ~ Instance fields - // ================================================================================================ +@Deprecated +public class OpenIDAuthenticationProvider implements AuthenticationProvider, InitializingBean { private AuthenticationUserDetailsService userDetailsService; + private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); - // ~ Methods - // ======================================================================================================== - + @Override public void afterPropertiesSet() { Assert.notNull(this.userDetailsService, "The userDetailsService must be set"); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.authentication.AuthenticationProvider#authenticate - * (org.springframework.security.Authentication) - */ - public Authentication authenticate(final Authentication authentication) - throws AuthenticationException { - + @Override + public Authentication authenticate(final Authentication authentication) throws AuthenticationException { if (!supports(authentication.getClass())) { return null; } - - if (authentication instanceof OpenIDAuthenticationToken) { - OpenIDAuthenticationToken response = (OpenIDAuthenticationToken) authentication; - OpenIDAuthenticationStatus status = response.getStatus(); - - // handle the various possibilities - if (status == OpenIDAuthenticationStatus.SUCCESS) { - // Lookup user details - UserDetails userDetails = this.userDetailsService - .loadUserDetails(response); - - return createSuccessfulAuthentication(userDetails, response); - - } - else if (status == OpenIDAuthenticationStatus.CANCELLED) { - throw new AuthenticationCancelledException("Log in cancelled"); - } - else if (status == OpenIDAuthenticationStatus.ERROR) { - throw new AuthenticationServiceException( - "Error message from server: " + response.getMessage()); - } - else if (status == OpenIDAuthenticationStatus.FAILURE) { - throw new BadCredentialsException( - "Log in failed - identity could not be verified"); - } - else if (status == OpenIDAuthenticationStatus.SETUP_NEEDED) { - throw new AuthenticationServiceException( - "The server responded setup was needed, which shouldn't happen"); - } - else { - throw new AuthenticationServiceException( - "Unrecognized return value " + status.toString()); - } + if (!(authentication instanceof OpenIDAuthenticationToken)) { + return null; } - - return null; + OpenIDAuthenticationToken response = (OpenIDAuthenticationToken) authentication; + OpenIDAuthenticationStatus status = response.getStatus(); + // handle the various possibilities + if (status == OpenIDAuthenticationStatus.SUCCESS) { + // Lookup user details + UserDetails userDetails = this.userDetailsService.loadUserDetails(response); + return createSuccessfulAuthentication(userDetails, response); + } + if (status == OpenIDAuthenticationStatus.CANCELLED) { + throw new AuthenticationCancelledException("Log in cancelled"); + } + if (status == OpenIDAuthenticationStatus.ERROR) { + throw new AuthenticationServiceException("Error message from server: " + response.getMessage()); + } + if (status == OpenIDAuthenticationStatus.FAILURE) { + throw new BadCredentialsException("Log in failed - identity could not be verified"); + } + if (status == OpenIDAuthenticationStatus.SETUP_NEEDED) { + throw new AuthenticationServiceException("The server responded setup was needed, which shouldn't happen"); + } + throw new AuthenticationServiceException("Unrecognized return value " + status.toString()); } /** @@ -123,24 +102,21 @@ public class OpenIDAuthenticationProvider * The default implementation just creates a new OpenIDAuthenticationToken from the * original, but with the UserDetails as the principal and including the authorities * loaded by the UserDetailsService. - * * @param userDetails the loaded UserDetails object * @param auth the token passed to the authenticate method, containing * @return the token which will represent the authenticated user. */ - protected Authentication createSuccessfulAuthentication(UserDetails userDetails, - OpenIDAuthenticationToken auth) { + protected Authentication createSuccessfulAuthentication(UserDetails userDetails, OpenIDAuthenticationToken auth) { return new OpenIDAuthenticationToken(userDetails, - this.authoritiesMapper.mapAuthorities(userDetails.getAuthorities()), - auth.getIdentityUrl(), auth.getAttributes()); + this.authoritiesMapper.mapAuthorities(userDetails.getAuthorities()), auth.getIdentityUrl(), + auth.getAttributes()); } /** * Used to load the {@code UserDetails} for the authenticated OpenID user. */ public void setUserDetailsService(UserDetailsService userDetailsService) { - this.userDetailsService = new UserDetailsByNameServiceWrapper<>( - userDetailsService); + this.userDetailsService = new UserDetailsByNameServiceWrapper<>(userDetailsService); } /** @@ -151,13 +127,7 @@ public class OpenIDAuthenticationProvider this.userDetailsService = userDetailsService; } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.authentication.AuthenticationProvider#supports(java - * .lang.Class) - */ + @Override public boolean supports(Class authentication) { return OpenIDAuthenticationToken.class.isAssignableFrom(authentication); } @@ -165,4 +135,5 @@ public class OpenIDAuthenticationProvider public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) { this.authoritiesMapper = authoritiesMapper; } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationStatus.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationStatus.java index db2eee832b..e7c4450ed2 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationStatus.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationStatus.java @@ -13,17 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; /** * Authentication status codes, based on JanRain status codes - * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author JanRain Inc. * @author Robin Bramley, Opsera Ltd * @author Luke Taylor + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ public enum OpenIDAuthenticationStatus { @@ -47,15 +48,13 @@ public enum OpenIDAuthenticationStatus { private final String name; - // ~ Constructors - // =================================================================================================== - OpenIDAuthenticationStatus(String name) { this.name = name; } @Override public String toString() { - return name; + return this.name; } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationToken.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationToken.java index 0625e07727..3a3d0de93e 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationToken.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationToken.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import java.util.ArrayList; @@ -26,29 +27,29 @@ import org.springframework.security.core.SpringSecurityCoreVersion; /** * OpenID Authentication Token * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenIDAuthenticationToken extends AbstractAuthenticationToken { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - // ~ Instance fields - // ================================================================================================ - private final OpenIDAuthenticationStatus status; + private final Object principal; + private final String identityUrl; + private final String message; + private final List attributes; - // ~ Constructors - // =================================================================================================== - - public OpenIDAuthenticationToken(OpenIDAuthenticationStatus status, - String identityUrl, String message, List attributes) { + public OpenIDAuthenticationToken(OpenIDAuthenticationStatus status, String identityUrl, String message, + List attributes) { super(new ArrayList<>(0)); this.principal = identityUrl; this.status = status; @@ -60,14 +61,11 @@ public class OpenIDAuthenticationToken extends AbstractAuthenticationToken { /** * Created by the OpenIDAuthenticationProvider on successful authentication. - * * @param principal usually the UserDetails returned by the configured * UserDetailsService used by the OpenIDAuthenticationProvider. - * */ - public OpenIDAuthenticationToken(Object principal, - Collection authorities, String identityUrl, - List attributes) { + public OpenIDAuthenticationToken(Object principal, Collection authorities, + String identityUrl, List attributes) { super(authorities); this.principal = principal; this.status = OpenIDAuthenticationStatus.SUCCESS; @@ -78,23 +76,21 @@ public class OpenIDAuthenticationToken extends AbstractAuthenticationToken { setAuthenticated(true); } - // ~ Methods - // ======================================================================================================== - /** * Returns 'null' always, as no credentials are processed by the OpenID provider. * @see org.springframework.security.core.Authentication#getCredentials() */ + @Override public Object getCredentials() { return null; } public String getIdentityUrl() { - return identityUrl; + return this.identityUrl; } public String getMessage() { - return message; + return this.message; } /** @@ -102,20 +98,22 @@ public class OpenIDAuthenticationToken extends AbstractAuthenticationToken { * * @see org.springframework.security.core.Authentication#getPrincipal() */ + @Override public Object getPrincipal() { - return principal; + return this.principal; } public OpenIDAuthenticationStatus getStatus() { - return status; + return this.status; } public List getAttributes() { - return attributes; + return this.attributes; } @Override public String toString() { - return "[" + super.toString() + ", attributes : " + attributes + "]"; + return "[" + super.toString() + ", attributes : " + this.attributes + "]"; } + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDConsumer.java b/openid/src/main/java/org/springframework/security/openid/OpenIDConsumer.java index 303b143a9a..671b960bb5 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDConsumer.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDConsumer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import javax.servlet.http.HttpServletRequest; @@ -20,18 +21,19 @@ import javax.servlet.http.HttpServletRequest; /** * An interface for OpenID library implementations * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Ray Krueger * @author Robin Bramley, Opsera Ltd + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public interface OpenIDConsumer { /** * Given the request, the claimedIdentity, the return to url, and a realm, lookup the * openId authentication page the user should be redirected to. - * * @param req HttpServletRequest * @param claimedIdentity String URI the user presented during authentication * @param returnToUrl String URI of the URL we want the user sent back to by the OP @@ -39,10 +41,9 @@ public interface OpenIDConsumer { * @return String URI to redirect user to for authentication * @throws OpenIDConsumerException if anything bad happens */ - String beginConsumption(HttpServletRequest req, String claimedIdentity, - String returnToUrl, String realm) throws OpenIDConsumerException; - - OpenIDAuthenticationToken endConsumption(HttpServletRequest req) + String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) throws OpenIDConsumerException; + OpenIDAuthenticationToken endConsumption(HttpServletRequest req) throws OpenIDConsumerException; + } diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDConsumerException.java b/openid/src/main/java/org/springframework/security/openid/OpenIDConsumerException.java index f184032c15..b020f0efe4 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDConsumerException.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDConsumerException.java @@ -13,25 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; /** * Thrown by an OpenIDConsumer if it cannot process a request * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley, Opsera Ltd + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenIDConsumerException extends Exception { - // ~ Constructors - // =================================================================================================== public OpenIDConsumerException(String message) { super(message); } - public OpenIDConsumerException(String message, Throwable t) { - super(message, t); + public OpenIDConsumerException(String message, Throwable cause) { + super(message, cause); } + } diff --git a/openid/src/main/java/org/springframework/security/openid/RegexBasedAxFetchListFactory.java b/openid/src/main/java/org/springframework/security/openid/RegexBasedAxFetchListFactory.java index b59481bb38..9a41eb1090 100644 --- a/openid/src/main/java/org/springframework/security/openid/RegexBasedAxFetchListFactory.java +++ b/openid/src/main/java/org/springframework/security/openid/RegexBasedAxFetchListFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import java.util.Collections; @@ -22,13 +23,16 @@ import java.util.Map; import java.util.regex.Pattern; /** - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Luke Taylor * @since 3.1 + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class RegexBasedAxFetchListFactory implements AxFetchListFactory { + private final Map> idToAttributes; /** @@ -36,9 +40,9 @@ public class RegexBasedAxFetchListFactory implements AxFetchListFactory { * which should be fetched for that pattern. */ public RegexBasedAxFetchListFactory(Map> regexMap) { - idToAttributes = new LinkedHashMap<>(); + this.idToAttributes = new LinkedHashMap<>(); for (Map.Entry> entry : regexMap.entrySet()) { - idToAttributes.put(Pattern.compile(entry.getKey()), entry.getValue()); + this.idToAttributes.put(Pattern.compile(entry.getKey()), entry.getValue()); } } @@ -46,13 +50,13 @@ public class RegexBasedAxFetchListFactory implements AxFetchListFactory { * Iterates through the patterns stored in the map and returns the list of attributes * defined for the first match. If no match is found, returns an empty list. */ + @Override public List createAttributeList(String identifier) { - for (Map.Entry> entry : idToAttributes.entrySet()) { + for (Map.Entry> entry : this.idToAttributes.entrySet()) { if (entry.getKey().matcher(identifier).matches()) { return entry.getValue(); } } - return Collections.emptyList(); } diff --git a/openid/src/main/java/org/springframework/security/openid/package-info.java b/openid/src/main/java/org/springframework/security/openid/package-info.java index 3abaa47aa6..62b2897d07 100644 --- a/openid/src/main/java/org/springframework/security/openid/package-info.java +++ b/openid/src/main/java/org/springframework/security/openid/package-info.java @@ -13,5 +13,5 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.openid; +package org.springframework.security.openid; diff --git a/openid/src/test/java/org/springframework/security/openid/MockOpenIDConsumer.java b/openid/src/test/java/org/springframework/security/openid/MockOpenIDConsumer.java index ce8f1a2382..8f4cb141be 100644 --- a/openid/src/test/java/org/springframework/security/openid/MockOpenIDConsumer.java +++ b/openid/src/test/java/org/springframework/security/openid/MockOpenIDConsumer.java @@ -13,21 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; import javax.servlet.http.HttpServletRequest; /** - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley, Opsera Ltd + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class MockOpenIDConsumer implements OpenIDConsumer { - // ~ Instance fields - // ================================================================================================ private OpenIDAuthenticationToken token; + private String redirectUrl; public MockOpenIDConsumer() { @@ -46,21 +48,18 @@ public class MockOpenIDConsumer implements OpenIDConsumer { this.token = token; } - // ~ Methods - // ======================================================================================================== - - public String beginConsumption(HttpServletRequest req, String claimedIdentity, - String returnToUrl, String realm) { - return redirectUrl; + @Override + public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) { + return this.redirectUrl; } + @Override public OpenIDAuthenticationToken endConsumption(HttpServletRequest req) { - return token; + return this.token; } /** * Set the redirectUrl to be returned by beginConsumption - * * @param redirectUrl */ public void setRedirectUrl(String redirectUrl) { @@ -73,10 +72,10 @@ public class MockOpenIDConsumer implements OpenIDConsumer { /** * Set the token to be returned by endConsumption - * * @param token */ public void setToken(OpenIDAuthenticationToken token) { this.token = token; } + } diff --git a/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java b/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java index 971c19a15f..2b9887d8e7 100644 --- a/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java +++ b/openid/src/test/java/org/springframework/security/openid/OpenID4JavaConsumerTests.java @@ -13,14 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.openid; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import java.util.Arrays; +import java.util.List; -import org.junit.*; +import org.junit.Test; import org.mockito.ArgumentMatchers; import org.openid4java.association.AssociationException; import org.openid4java.consumer.ConsumerException; @@ -35,17 +34,25 @@ import org.openid4java.message.MessageException; import org.openid4java.message.ParameterList; import org.openid4java.message.ax.AxMessage; import org.openid4java.message.ax.FetchResponse; + import org.springframework.mock.web.MockHttpServletRequest; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; /** - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Luke Taylor + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenID4JavaConsumerTests { + List attributes = Arrays.asList(new OpenIDAttribute("a", "b"), new OpenIDAttribute("b", "b", Arrays.asList("c"))); @@ -55,25 +62,16 @@ public class OpenID4JavaConsumerTests { ConsumerManager mgr = mock(ConsumerManager.class); AuthRequest authReq = mock(AuthRequest.class); DiscoveryInformation di = mock(DiscoveryInformation.class); - - when(mgr.authenticate(any(DiscoveryInformation.class), any(), any())) - .thenReturn(authReq); - when(mgr.associate(any())).thenReturn(di); - - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, - new MockAttributesFactory()); - + given(mgr.authenticate(any(DiscoveryInformation.class), any(), any())).willReturn(authReq); + given(mgr.associate(any())).willReturn(di); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new MockAttributesFactory()); MockHttpServletRequest request = new MockHttpServletRequest(); consumer.beginConsumption(request, "", "", ""); - - assertThat(request.getSession().getAttribute( - "SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST")).isEqualTo(attributes); - assertThat( - request.getSession().getAttribute(DiscoveryInformation.class.getName())).isEqualTo(di); - + assertThat(request.getSession().getAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST")) + .isEqualTo(this.attributes); + assertThat(request.getSession().getAttribute(DiscoveryInformation.class.getName())).isEqualTo(di); // Check with empty attribute fetch list consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); - request = new MockHttpServletRequest(); consumer.beginConsumption(request, "", "", ""); } @@ -81,28 +79,23 @@ public class OpenID4JavaConsumerTests { @Test(expected = OpenIDConsumerException.class) public void discoveryExceptionRaisesOpenIDException() throws Exception { ConsumerManager mgr = mock(ConsumerManager.class); - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, - new NullAxFetchListFactory()); - when(mgr.discover(any())).thenThrow(new DiscoveryException("msg")); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + given(mgr.discover(any())).willThrow(new DiscoveryException("msg")); consumer.beginConsumption(new MockHttpServletRequest(), "", "", ""); } @Test - public void messageOrConsumerAuthenticationExceptionRaisesOpenIDException() - throws Exception { + public void messageOrConsumerAuthenticationExceptionRaisesOpenIDException() throws Exception { ConsumerManager mgr = mock(ConsumerManager.class); - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, - new NullAxFetchListFactory()); - - when(mgr.authenticate(ArgumentMatchers.any(), any(), any())) - .thenThrow(new MessageException("msg"), new ConsumerException("msg")); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + given(mgr.authenticate(ArgumentMatchers.any(), any(), any())) + .willThrow(new MessageException("msg"), new ConsumerException("msg")); try { consumer.beginConsumption(new MockHttpServletRequest(), "", "", ""); fail("OpenIDConsumerException was not thrown"); } catch (OpenIDConsumerException expected) { } - try { consumer.beginConsumption(new MockHttpServletRequest(), "", "", ""); fail("OpenIDConsumerException was not thrown"); @@ -114,126 +107,90 @@ public class OpenID4JavaConsumerTests { @Test public void failedVerificationReturnsFailedAuthenticationStatus() throws Exception { ConsumerManager mgr = mock(ConsumerManager.class); - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, - new NullAxFetchListFactory()); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); VerificationResult vr = mock(VerificationResult.class); DiscoveryInformation di = mock(DiscoveryInformation.class); - - when( - mgr.verify(any(), any(ParameterList.class), - any(DiscoveryInformation.class))).thenReturn(vr); - + given(mgr.verify(any(), any(ParameterList.class), any(DiscoveryInformation.class))).willReturn(vr); MockHttpServletRequest request = new MockHttpServletRequest(); - request.getSession().setAttribute(DiscoveryInformation.class.getName(), di); - OpenIDAuthenticationToken auth = consumer.endConsumption(request); - assertThat(auth.getStatus()).isEqualTo(OpenIDAuthenticationStatus.FAILURE); } @Test public void verificationExceptionsRaiseOpenIDException() throws Exception { ConsumerManager mgr = mock(ConsumerManager.class); - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, - new NullAxFetchListFactory()); - - when( - mgr.verify(any(), any(ParameterList.class), - any(DiscoveryInformation.class))) - .thenThrow(new MessageException("")) - .thenThrow(new AssociationException("")) - .thenThrow(new DiscoveryException("")); - + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); + given(mgr.verify(any(), any(ParameterList.class), any(DiscoveryInformation.class))) + .willThrow(new MessageException(""), new AssociationException(""), new DiscoveryException("")); MockHttpServletRequest request = new MockHttpServletRequest(); request.setQueryString("x=5"); - try { consumer.endConsumption(request); fail("OpenIDConsumerException was not thrown"); } catch (OpenIDConsumerException expected) { } - try { consumer.endConsumption(request); fail("OpenIDConsumerException was not thrown"); } catch (OpenIDConsumerException expected) { } - try { consumer.endConsumption(request); fail("OpenIDConsumerException was not thrown"); } catch (OpenIDConsumerException expected) { } - } @SuppressWarnings("serial") @Test public void successfulVerificationReturnsExpectedAuthentication() throws Exception { ConsumerManager mgr = mock(ConsumerManager.class); - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, - new NullAxFetchListFactory()); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(mgr, new NullAxFetchListFactory()); VerificationResult vr = mock(VerificationResult.class); DiscoveryInformation di = mock(DiscoveryInformation.class); Identifier id = (Identifier) () -> "id"; Message msg = mock(Message.class); - - when( - mgr.verify(any(), any(ParameterList.class), - any(DiscoveryInformation.class))).thenReturn(vr); - when(vr.getVerifiedId()).thenReturn(id); - when(vr.getAuthResponse()).thenReturn(msg); - + given(mgr.verify(any(), any(ParameterList.class), any(DiscoveryInformation.class))).willReturn(vr); + given(vr.getVerifiedId()).willReturn(id); + given(vr.getAuthResponse()).willReturn(msg); MockHttpServletRequest request = new MockHttpServletRequest(); - request.getSession().setAttribute(DiscoveryInformation.class.getName(), di); - request.getSession().setAttribute( - "SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST", attributes); - + request.getSession().setAttribute("SPRING_SECURITY_OPEN_ID_ATTRIBUTES_FETCH_LIST", this.attributes); OpenIDAuthenticationToken auth = consumer.endConsumption(request); - assertThat(auth.getStatus()).isEqualTo(OpenIDAuthenticationStatus.SUCCESS); } @Test public void fetchAttributesReturnsExpectedValues() throws Exception { - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer( - new NullAxFetchListFactory()); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(new NullAxFetchListFactory()); Message msg = mock(Message.class); FetchResponse fr = mock(FetchResponse.class); - when(msg.hasExtension(AxMessage.OPENID_NS_AX)).thenReturn(true); - when(msg.getExtension(AxMessage.OPENID_NS_AX)).thenReturn(fr); - when(fr.getAttributeValues("a")).thenReturn(Arrays.asList("x", "y")); - - List fetched = consumer.fetchAxAttributes(msg, attributes); - + given(msg.hasExtension(AxMessage.OPENID_NS_AX)).willReturn(true); + given(msg.getExtension(AxMessage.OPENID_NS_AX)).willReturn(fr); + given(fr.getAttributeValues("a")).willReturn(Arrays.asList("x", "y")); + List fetched = consumer.fetchAxAttributes(msg, this.attributes); assertThat(fetched).hasSize(1); assertThat(fetched.get(0).getValues()).hasSize(2); } @Test(expected = OpenIDConsumerException.class) - public void messageExceptionFetchingAttributesRaisesOpenIDException() - throws Exception { - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer( - new NullAxFetchListFactory()); + public void messageExceptionFetchingAttributesRaisesOpenIDException() throws Exception { + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(new NullAxFetchListFactory()); Message msg = mock(Message.class); FetchResponse fr = mock(FetchResponse.class); - when(msg.hasExtension(AxMessage.OPENID_NS_AX)).thenReturn(true); - when(msg.getExtension(AxMessage.OPENID_NS_AX)) - .thenThrow(new MessageException("")); - when(fr.getAttributeValues("a")).thenReturn(Arrays.asList("x", "y")); - - consumer.fetchAxAttributes(msg, attributes); + given(msg.hasExtension(AxMessage.OPENID_NS_AX)).willReturn(true); + given(msg.getExtension(AxMessage.OPENID_NS_AX)).willThrow(new MessageException("")); + given(fr.getAttributeValues("a")).willReturn(Arrays.asList("x", "y")); + consumer.fetchAxAttributes(msg, this.attributes); } @Test(expected = OpenIDConsumerException.class) public void missingDiscoveryInformationThrowsException() throws Exception { - OpenID4JavaConsumer consumer = new OpenID4JavaConsumer( - new NullAxFetchListFactory()); + OpenID4JavaConsumer consumer = new OpenID4JavaConsumer(new NullAxFetchListFactory()); consumer.endConsumption(new MockHttpServletRequest()); } @@ -245,8 +202,11 @@ public class OpenID4JavaConsumerTests { private class MockAttributesFactory implements AxFetchListFactory { + @Override public List createAttributeList(String identifier) { - return attributes; + return OpenID4JavaConsumerTests.this.attributes; } + } + } diff --git a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java index dbaf2eeb68..546bb81a4e 100644 --- a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java +++ b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.openid; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +package org.springframework.security.openid; import java.net.URI; import java.util.Collections; @@ -27,34 +25,47 @@ import javax.servlet.http.HttpServletResponse; import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + /** * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenIDAuthenticationFilterTests { OpenIDAuthenticationFilter filter; + private static final String REDIRECT_URL = "https://www.example.com/redirect"; + private static final String CLAIMED_IDENTITY_URL = "https://www.example.com/identity"; + private static final String REQUEST_PATH = "/login/openid"; - private static final String FILTER_PROCESS_URL = "http://localhost:8080" - + REQUEST_PATH; + + private static final String FILTER_PROCESS_URL = "http://localhost:8080" + REQUEST_PATH; + private static final String DEFAULT_TARGET_URL = FILTER_PROCESS_URL; @Before public void setUp() { - filter = new OpenIDAuthenticationFilter(); - filter.setConsumer(new MockOpenIDConsumer(REDIRECT_URL)); + this.filter = new OpenIDAuthenticationFilter(); + this.filter.setConsumer(new MockOpenIDConsumer(REDIRECT_URL)); SavedRequestAwareAuthenticationSuccessHandler successHandler = new SavedRequestAwareAuthenticationSuccessHandler(); - filter.setAuthenticationSuccessHandler(new SavedRequestAwareAuthenticationSuccessHandler()); + this.filter.setAuthenticationSuccessHandler(new SavedRequestAwareAuthenticationSuccessHandler()); successHandler.setDefaultTargetUrl(DEFAULT_TARGET_URL); - filter.setAuthenticationManager(a -> a); - filter.afterPropertiesSet(); + this.filter.setAuthenticationManager((a) -> a); + this.filter.afterPropertiesSet(); } @Test @@ -64,26 +75,23 @@ public class OpenIDAuthenticationFilterTests { req.setRequestURI(REQUEST_PATH); req.setServerPort(8080); MockHttpServletResponse response = new MockHttpServletResponse(); - req.setParameter("openid_identifier", " " + CLAIMED_IDENTITY_URL); req.setRemoteHost("www.example.com"); - - filter.setConsumer(new MockOpenIDConsumer() { - public String beginConsumption(HttpServletRequest req, - String claimedIdentity, String returnToUrl, String realm) { + this.filter.setConsumer(new MockOpenIDConsumer() { + @Override + public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, + String realm) { assertThat(claimedIdentity).isEqualTo(CLAIMED_IDENTITY_URL); assertThat(returnToUrl).isEqualTo(DEFAULT_TARGET_URL); assertThat(realm).isEqualTo("http://localhost:8080/"); return REDIRECT_URL; } }); - FilterChain fc = mock(FilterChain.class); - filter.doFilter(req, response, fc); + this.filter.doFilter(req, response, fc); assertThat(response.getRedirectedUrl()).isEqualTo(REDIRECT_URL); // Filter chain shouldn't proceed - verify(fc, never()).doFilter(any(HttpServletRequest.class), - any(HttpServletResponse.class)); + verify(fc, never()).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } /** @@ -96,9 +104,8 @@ public class OpenIDAuthenticationFilterTests { String paramValue = "https://example.com/path?a=b&c=d"; MockHttpServletRequest req = new MockHttpServletRequest("GET", REQUEST_PATH); req.addParameter(paramName, paramValue); - filter.setReturnToUrlParameters(Collections.singleton(paramName)); - - URI returnTo = new URI(filter.buildReturnToUrl(req)); + this.filter.setReturnToUrlParameters(Collections.singleton(paramName)); + URI returnTo = new URI(this.filter.buildReturnToUrl(req)); String query = returnTo.getRawQuery(); assertThat(count(query, '=')).isEqualTo(1); assertThat(count(query, '&')).isZero(); @@ -116,4 +123,5 @@ public class OpenIDAuthenticationFilterTests { } return count; } + } diff --git a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java index 902ca9b8fe..65ad76bf72 100644 --- a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java +++ b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationProviderTests.java @@ -16,10 +16,8 @@ package org.springframework.security.openid; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import org.junit.Test; + import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -32,23 +30,23 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsByNameServiceWrapper; import org.springframework.security.core.userdetails.UserDetailsService; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link OpenIDAuthenticationProvider} * - * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are - * encouraged to migrate - * to OpenID Connect, which is supported by spring-security-oauth2. * @author Robin Bramley, Opsera Ltd + * @deprecated The OpenID 1.0 and 2.0 protocols have been deprecated and users are + * encouraged to + * migrate to OpenID Connect, which is + * supported by spring-security-oauth2. */ +@Deprecated public class OpenIDAuthenticationProviderTests { - // ~ Static fields/initializers - // ===================================================================================== private static final String USERNAME = "user.acegiopenid.com"; - // ~ Methods - // ======================================================================================================== - /* * Test method for * 'org.springframework.security.authentication.openid.OpenIDAuthenticationProvider. @@ -59,12 +57,9 @@ public class OpenIDAuthenticationProviderTests { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); provider.setAuthoritiesMapper(new NullAuthoritiesMapper()); - - Authentication preAuth = new OpenIDAuthenticationToken( - OpenIDAuthenticationStatus.CANCELLED, USERNAME, "", null); - + Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.CANCELLED, USERNAME, "", + null); assertThat(preAuth.isAuthenticated()).isFalse(); - try { provider.authenticate(preAuth); fail("Should throw an AuthenticationException"); @@ -83,12 +78,8 @@ public class OpenIDAuthenticationProviderTests { public void testAuthenticateError() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); - - Authentication preAuth = new OpenIDAuthenticationToken( - OpenIDAuthenticationStatus.ERROR, USERNAME, "", null); - + Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.ERROR, USERNAME, "", null); assertThat(preAuth.isAuthenticated()).isFalse(); - try { provider.authenticate(preAuth); fail("Should throw an AuthenticationException"); @@ -107,21 +98,15 @@ public class OpenIDAuthenticationProviderTests { public void testAuthenticateFailure() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setAuthenticationUserDetailsService( - new UserDetailsByNameServiceWrapper<>( - new MockUserDetailsService())); - - Authentication preAuth = new OpenIDAuthenticationToken( - OpenIDAuthenticationStatus.FAILURE, USERNAME, "", null); - + new UserDetailsByNameServiceWrapper<>(new MockUserDetailsService())); + Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.FAILURE, USERNAME, "", null); assertThat(preAuth.isAuthenticated()).isFalse(); - try { provider.authenticate(preAuth); fail("Should throw an AuthenticationException"); } catch (BadCredentialsException expected) { - assertThat("Log in failed - identity could not be verified").isEqualTo( - expected.getMessage()); + assertThat("Log in failed - identity could not be verified").isEqualTo(expected.getMessage()); } } @@ -134,20 +119,16 @@ public class OpenIDAuthenticationProviderTests { public void testAuthenticateSetupNeeded() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); - - Authentication preAuth = new OpenIDAuthenticationToken( - OpenIDAuthenticationStatus.SETUP_NEEDED, USERNAME, "", null); - + Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.SETUP_NEEDED, USERNAME, "", + null); assertThat(preAuth.isAuthenticated()).isFalse(); - try { provider.authenticate(preAuth); fail("Should throw an AuthenticationException"); } catch (AuthenticationServiceException expected) { - assertThat( - "The server responded setup was needed, which shouldn't happen").isEqualTo( - expected.getMessage()); + assertThat("The server responded setup was needed, which shouldn't happen") + .isEqualTo(expected.getMessage()); } } @@ -160,14 +141,9 @@ public class OpenIDAuthenticationProviderTests { public void testAuthenticateSuccess() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); - - Authentication preAuth = new OpenIDAuthenticationToken( - OpenIDAuthenticationStatus.SUCCESS, USERNAME, "", null); - + Authentication preAuth = new OpenIDAuthenticationToken(OpenIDAuthenticationStatus.SUCCESS, USERNAME, "", null); assertThat(preAuth.isAuthenticated()).isFalse(); - Authentication postAuth = provider.authenticate(preAuth); - assertThat(postAuth).isNotNull(); assertThat(postAuth instanceof OpenIDAuthenticationToken).isTrue(); assertThat(postAuth.isAuthenticated()).isTrue(); @@ -175,15 +151,13 @@ public class OpenIDAuthenticationProviderTests { assertThat(postAuth.getPrincipal() instanceof UserDetails).isTrue(); assertThat(postAuth.getAuthorities()).isNotNull(); assertThat(postAuth.getAuthorities().size() > 0).isTrue(); - assertThat( - ((OpenIDAuthenticationToken) postAuth).getStatus() == OpenIDAuthenticationStatus.SUCCESS).isTrue(); + assertThat(((OpenIDAuthenticationToken) postAuth).getStatus() == OpenIDAuthenticationStatus.SUCCESS).isTrue(); assertThat(((OpenIDAuthenticationToken) postAuth).getMessage() == null).isTrue(); } @Test public void testDetectsMissingAuthoritiesPopulator() throws Exception { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); - try { provider.afterPropertiesSet(); fail("Should have thrown Exception"); @@ -202,9 +176,7 @@ public class OpenIDAuthenticationProviderTests { public void testDoesntSupport() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); - - assertThat( - provider.supports(UsernamePasswordAuthenticationToken.class)).isFalse(); + assertThat(provider.supports(UsernamePasswordAuthenticationToken.class)).isFalse(); } /* @@ -216,9 +188,7 @@ public class OpenIDAuthenticationProviderTests { public void testIgnoresUserPassAuthToken() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); - - UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken( - USERNAME, "password"); + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(USERNAME, "password"); assertThat(provider.authenticate(token)).isNull(); } @@ -231,7 +201,6 @@ public class OpenIDAuthenticationProviderTests { public void testSupports() { OpenIDAuthenticationProvider provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); - assertThat(provider.supports(OpenIDAuthenticationToken.class)).isTrue(); } @@ -242,10 +211,9 @@ public class OpenIDAuthenticationProviderTests { provider.afterPropertiesSet(); fail("IllegalArgumentException expected, ssoAuthoritiesPopulator is null"); } - catch (IllegalArgumentException e) { + catch (IllegalArgumentException ex) { // expected } - provider = new OpenIDAuthenticationProvider(); provider.setUserDetailsService(new MockUserDetailsService()); provider.afterPropertiesSet(); @@ -253,10 +221,12 @@ public class OpenIDAuthenticationProviderTests { static class MockUserDetailsService implements UserDetailsService { - public UserDetails loadUserByUsername(String ssoUserId) - throws AuthenticationException { + @Override + public UserDetails loadUserByUsername(String ssoUserId) throws AuthenticationException { return new User(ssoUserId, "password", true, true, true, true, AuthorityUtils.createAuthorityList("ROLE_A", "ROLE_B")); } + } + } diff --git a/remoting/src/main/java/org/springframework/security/remoting/dns/DnsResolver.java b/remoting/src/main/java/org/springframework/security/remoting/dns/DnsResolver.java index 1bed997ea0..990fdeb6fe 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/dns/DnsResolver.java +++ b/remoting/src/main/java/org/springframework/security/remoting/dns/DnsResolver.java @@ -27,14 +27,12 @@ public interface DnsResolver { /** * Resolves the IP Address (A record) to the specified host name. Throws * DnsEntryNotFoundException if there is no record. - * * @param hostname The hostname for which you need the IP Address * @return IP Address as a String * @throws DnsEntryNotFoundException No record found * @throws DnsLookupException Unknown DNS error */ - String resolveIpAddress(String hostname) throws DnsEntryNotFoundException, - DnsLookupException; + String resolveIpAddress(String hostname) throws DnsEntryNotFoundException, DnsLookupException; /** *

        @@ -56,22 +54,19 @@ public interface DnsResolver { * The method will return the record with highest priority (which means the lowest * number in the DNS record) and if there are more than one records with the same * priority, it will return the one with the highest weight. You will find more - * informatione about DNS service records at Wikipedia. - * + * informatione about DNS service records at + * Wikipedia. * @param serviceType The service type you are searching for, e.g. ldap, kerberos, ... * @param domain The domain, in which you are searching for the service * @return The hostname of the service * @throws DnsEntryNotFoundException No record found * @throws DnsLookupException Unknown DNS error */ - String resolveServiceEntry(String serviceType, String domain) - throws DnsEntryNotFoundException, DnsLookupException; + String resolveServiceEntry(String serviceType, String domain) throws DnsEntryNotFoundException, DnsLookupException; /** * Resolves the host name for the specified service and then the IP Address for this * host in one call. - * * @param serviceType The service type you are searching for, e.g. ldap, kerberos, ... * @param domain The domain, in which you are searching for the service * @return IP Address of the service diff --git a/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java b/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java index cc699e15df..fb406cb15c 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java +++ b/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java @@ -46,7 +46,6 @@ public class JndiDnsResolver implements DnsResolver { /** * Allows to inject an own JNDI context factory. - * * @param ctxFactory factory to use, when a DirContext is needed * @see InitialDirContext * @see DirContext @@ -55,35 +54,17 @@ public class JndiDnsResolver implements DnsResolver { this.ctxFactory = ctxFactory; } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.remoting.dns.DnsResolver#resolveIpAddress(java.lang - * .String) - */ + @Override public String resolveIpAddress(String hostname) { return resolveIpAddress(hostname, this.ctxFactory.getCtx()); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.remoting.dns.DnsResolver#resolveServiceEntry(java. - * lang.String, java.lang.String) - */ + @Override public String resolveServiceEntry(String serviceType, String domain) { return resolveServiceEntry(serviceType, domain, this.ctxFactory.getCtx()); } - /* - * (non-Javadoc) - * - * @see - * org.springframework.security.remoting.dns.DnsResolver#resolveServiceIpAddress(java - * .lang.String, java.lang.String) - */ + @Override public String resolveServiceIpAddress(String serviceType, String domain) { DirContext ctx = this.ctxFactory.getCtx(); String hostname = resolveServiceEntry(serviceType, domain, ctx); @@ -99,33 +80,29 @@ public class JndiDnsResolver implements DnsResolver { // only the first. return dnsRecord.get().toString(); } - catch (NamingException e) { - throw new DnsLookupException("DNS lookup failed for: " + hostname, e); + catch (NamingException ex) { + throw new DnsLookupException("DNS lookup failed for: " + hostname, ex); } } // This method is needed, so that we can use only one DirContext for // resolveServiceIpAddress(). - private String resolveServiceEntry(String serviceType, String domain, - DirContext ctx) { + private String resolveServiceEntry(String serviceType, String domain, DirContext ctx) { String result = null; try { - String query = new StringBuilder("_").append(serviceType).append("._tcp.") - .append(domain).toString(); + String query = new StringBuilder("_").append(serviceType).append("._tcp.").append(domain).toString(); Attribute dnsRecord = lookup(query, ctx, "SRV"); // There are maybe more records defined, we will return the one // with the highest priority (lowest number) and the highest weight // (highest number) int highestPriority = -1; int highestWeight = -1; - - for (NamingEnumeration recordEnum = dnsRecord.getAll(); recordEnum - .hasMoreElements();) { + for (NamingEnumeration recordEnum = dnsRecord.getAll(); recordEnum.hasMoreElements();) { String[] record = recordEnum.next().toString().split(" "); if (record.length != 4) { - throw new DnsLookupException("Wrong service record for query " + query - + ": [" + Arrays.toString(record) + "]"); + throw new DnsLookupException( + "Wrong service record for query " + query + ": [" + Arrays.toString(record) + "]"); } int priority = Integer.parseInt(record[0]); int weight = Integer.parseInt(record[1]); @@ -142,11 +119,9 @@ public class JndiDnsResolver implements DnsResolver { } } } - catch (NamingException e) { - throw new DnsLookupException( - "DNS lookup failed for service " + serviceType + " at " + domain, e); + catch (NamingException ex) { + throw new DnsLookupException("DNS lookup failed for service " + serviceType + " at " + domain, ex); } - // remove the "." at the end if (result.endsWith(".")) { result = result.substring(0, result.length() - 1); @@ -157,34 +132,31 @@ public class JndiDnsResolver implements DnsResolver { private Attribute lookup(String query, DirContext ictx, String recordType) { try { Attributes dnsResult = ictx.getAttributes(query, new String[] { recordType }); - return dnsResult.get(recordType); } - catch (NamingException e) { - if (e instanceof NameNotFoundException) { - throw new DnsEntryNotFoundException("DNS entry not found for:" + query, - e); + catch (NamingException ex) { + if (ex instanceof NameNotFoundException) { + throw new DnsEntryNotFoundException("DNS entry not found for:" + query, ex); } - throw new DnsLookupException("DNS lookup failed for: " + query, e); + throw new DnsLookupException("DNS lookup failed for: " + query, ex); } } private static class DefaultInitialContextFactory implements InitialContextFactory { + @Override public DirContext getCtx() { Hashtable env = new Hashtable<>(); - env.put(Context.INITIAL_CONTEXT_FACTORY, - "com.sun.jndi.dns.DnsContextFactory"); + env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.dns.DnsContextFactory"); env.put(Context.PROVIDER_URL, "dns:"); // This is needed for IBM JDK/JRE - InitialDirContext ictx; try { - ictx = new InitialDirContext(env); + return new InitialDirContext(env); } - catch (NamingException e) { - throw new DnsLookupException( - "Cannot create InitialDirContext for DNS lookup", e); + catch (NamingException ex) { + throw new DnsLookupException("Cannot create InitialDirContext for DNS lookup", ex); } - return ictx; } + } + } diff --git a/remoting/src/main/java/org/springframework/security/remoting/dns/package-info.java b/remoting/src/main/java/org/springframework/security/remoting/dns/package-info.java index 69ba68a47b..6f8c8a3f9b 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/dns/package-info.java +++ b/remoting/src/main/java/org/springframework/security/remoting/dns/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * DNS resolution. */ package org.springframework.security.remoting.dns; - diff --git a/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutor.java b/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutor.java index 59e61d31b6..648c886428 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutor.java +++ b/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutor.java @@ -22,6 +22,7 @@ import java.util.Base64; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.remoting.httpinvoker.SimpleHttpInvokerRequestExecutor; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; @@ -34,32 +35,20 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Ben Alex * @author Rob Winch */ -public class AuthenticationSimpleHttpInvokerRequestExecutor extends - SimpleHttpInvokerRequestExecutor { - // ~ Static fields/initializers - // ===================================================================================== +public class AuthenticationSimpleHttpInvokerRequestExecutor extends SimpleHttpInvokerRequestExecutor { - private static final Log logger = LogFactory - .getLog(AuthenticationSimpleHttpInvokerRequestExecutor.class); - - // ~ Instance fields - // ================================================================================================ + private static final Log logger = LogFactory.getLog(AuthenticationSimpleHttpInvokerRequestExecutor.class); private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); - // ~ Methods - // ======================================================================================================== - /** * Provided so subclasses can perform additional configuration if required (eg set * additional request headers for non-security related information etc). - * * @param con the HTTP connection to prepare * @param contentLength the length of the content to send * */ - protected void doPrepareConnection(HttpURLConnection con, int contentLength) - throws IOException { + protected void doPrepareConnection(HttpURLConnection con, int contentLength) throws IOException { } /** @@ -73,20 +62,18 @@ public class AuthenticationSimpleHttpInvokerRequestExecutor extends * The SecurityContextHolder is used to obtain the relevant principal and * credentials. *

        - * * @param con the HTTP connection to prepare * @param contentLength the length of the content to send - * * @throws IOException if thrown by HttpURLConnection methods */ - protected void prepareConnection(HttpURLConnection con, int contentLength) - throws IOException { + @Override + protected void prepareConnection(HttpURLConnection con, int contentLength) throws IOException { super.prepareConnection(con, contentLength); Authentication auth = SecurityContextHolder.getContext().getAuthentication(); if ((auth != null) && (auth.getName() != null) && (auth.getCredentials() != null) - && !trustResolver.isAnonymous(auth)) { + && !this.trustResolver.isAnonymous(auth)) { String base64 = auth.getName() + ":" + auth.getCredentials().toString(); con.setRequestProperty("Authorization", "Basic " + new String(Base64.getEncoder().encode(base64.getBytes()))); @@ -105,4 +92,5 @@ public class AuthenticationSimpleHttpInvokerRequestExecutor extends doPrepareConnection(con, contentLength); } + } diff --git a/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/package-info.java b/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/package-info.java index fcf9e723b4..cd052eba4d 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/package-info.java +++ b/remoting/src/main/java/org/springframework/security/remoting/httpinvoker/package-info.java @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Enables use of Spring's HttpInvoker extension points to - * present the principal and credentials located - * in the ContextHolder via BASIC authentication. + * Enables use of Spring's HttpInvoker extension points to present the + * principal and credentials located in the + * ContextHolder via BASIC authentication. *

        * The beans are wired as follows: * @@ -32,4 +33,3 @@ *

        */ package org.springframework.security.remoting.httpinvoker; - diff --git a/remoting/src/main/java/org/springframework/security/remoting/package-info.java b/remoting/src/main/java/org/springframework/security/remoting/package-info.java index 8346f9098a..1439310acb 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/package-info.java +++ b/remoting/src/main/java/org/springframework/security/remoting/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Remote client related functionality. */ package org.springframework.security.remoting; - diff --git a/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocation.java b/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocation.java index dae5befd63..56563301ae 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocation.java +++ b/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocation.java @@ -16,17 +16,19 @@ package org.springframework.security.remoting.rmi; +import java.lang.reflect.InvocationTargetException; + import org.aopalliance.intercept.MethodInvocation; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.remoting.support.RemoteInvocation; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.core.context.SecurityContextHolder; -import java.lang.reflect.InvocationTargetException; - /** * The actual {@code RemoteInvocation} that is passed from the client to the server. *

        @@ -46,93 +48,72 @@ public class ContextPropagatingRemoteInvocation extends RemoteInvocation { private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; - private static final Log logger = LogFactory - .getLog(ContextPropagatingRemoteInvocation.class); - - // ~ Instance fields - // ================================================================================================ + private static final Log logger = LogFactory.getLog(ContextPropagatingRemoteInvocation.class); private final String principal; - private final String credentials; - // ~ Constructors - // =================================================================================================== + private final String credentials; /** * Constructs the object, storing the principal and credentials extracted from the * client-side security context. - * * @param methodInvocation the method to invoke */ public ContextPropagatingRemoteInvocation(MethodInvocation methodInvocation) { super(methodInvocation); - Authentication currentUser = SecurityContextHolder.getContext() - .getAuthentication(); - + Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); if (currentUser != null) { - principal = currentUser.getName(); + this.principal = currentUser.getName(); Object userCredentials = currentUser.getCredentials(); - credentials = userCredentials == null ? null : userCredentials.toString(); + this.credentials = (userCredentials != null) ? userCredentials.toString() : null; } else { - principal = credentials = null; + this.credentials = null; + this.principal = null; } - if (logger.isDebugEnabled()) { - logger.debug("RemoteInvocation now has principal: " + principal); - if (credentials == null) { + logger.debug("RemoteInvocation now has principal: " + this.principal); + if (this.credentials == null) { logger.debug("RemoteInvocation now has null credentials."); } } } - // ~ Methods - // ======================================================================================================== - /** * Invoked on the server-side. *

        * The transmitted principal and credentials will be used to create an unauthenticated - * {@code Authentication} instance for processing by the {@code AuthenticationManager}. - * + * {@code Authentication} instance for processing by the + * {@code AuthenticationManager}. * @param targetObject the target object to apply the invocation to - * * @return the invocation result - * * @throws NoSuchMethodException if the method name could not be resolved * @throws IllegalAccessException if the method could not be accessed * @throws InvocationTargetException if the method invocation resulted in an exception */ - public Object invoke(Object targetObject) throws NoSuchMethodException, - IllegalAccessException, InvocationTargetException { - - if (principal != null) { - Authentication request = createAuthenticationRequest(principal, credentials); + @Override + public Object invoke(Object targetObject) + throws NoSuchMethodException, IllegalAccessException, InvocationTargetException { + if (this.principal != null) { + Authentication request = createAuthenticationRequest(this.principal, this.credentials); request.setAuthenticated(false); SecurityContextHolder.getContext().setAuthentication(request); - - if (logger.isDebugEnabled()) { - logger.debug("Set SecurityContextHolder to contain: " + request); - } + logger.debug(LogMessage.format("Set SecurityContextHolder to contain: %s", request)); } - try { return super.invoke(targetObject); } finally { SecurityContextHolder.clearContext(); - - if (logger.isDebugEnabled()) { - logger.debug("Cleared SecurityContextHolder."); - } + logger.debug("Cleared SecurityContextHolder."); } } /** * Creates the server-side authentication request object. */ - protected Authentication createAuthenticationRequest(String principal, - String credentials) { + protected Authentication createAuthenticationRequest(String principal, String credentials) { return new UsernamePasswordAuthenticationToken(principal, credentials); } + } diff --git a/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationFactory.java b/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationFactory.java index 47dc7985ac..6ea30971a4 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationFactory.java +++ b/remoting/src/main/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationFactory.java @@ -34,10 +34,10 @@ import org.springframework.remoting.support.RemoteInvocationFactory; * @author Ben Alex */ public class ContextPropagatingRemoteInvocationFactory implements RemoteInvocationFactory { - // ~ Methods - // ======================================================================================================== + @Override public RemoteInvocation createRemoteInvocation(MethodInvocation methodInvocation) { return new ContextPropagatingRemoteInvocation(methodInvocation); } + } diff --git a/remoting/src/main/java/org/springframework/security/remoting/rmi/package-info.java b/remoting/src/main/java/org/springframework/security/remoting/rmi/package-info.java index 605b65ac59..ead4d44863 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/rmi/package-info.java +++ b/remoting/src/main/java/org/springframework/security/remoting/rmi/package-info.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** - * Enables use of Spring's RMI remoting extension points to propagate the SecurityContextHolder (which - * should contain an Authentication request token) from one JVM to the remote JVM. + * Enables use of Spring's RMI remoting extension points to propagate the + * SecurityContextHolder (which should contain an Authentication + * request token) from one JVM to the remote JVM. *

        - * The beans are wired as follows: - *

        + * The beans are wired as follows: 
          * <bean id="test" class="org.springframework.remoting.rmi.RmiProxyFactoryBean">
          *   <property name="serviceUrl"><value>rmi://localhost/Test</value></property>
          *   <property name="serviceInterface"><value>test.TargetInterface</value></property>
        @@ -31,4 +32,3 @@
          * 
        */ package org.springframework.security.remoting.rmi; - diff --git a/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java b/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java index 4f2aee1959..9dc14e37e9 100644 --- a/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java +++ b/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java @@ -16,9 +16,6 @@ package org.springframework.security.remoting.dns; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - import javax.naming.NameNotFoundException; import javax.naming.NamingException; import javax.naming.directory.Attributes; @@ -29,83 +26,77 @@ import javax.naming.directory.DirContext; import org.junit.Before; import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + /** - * * @author Mike Wiesner * @since 3.0 */ public class JndiDnsResolverTests { private JndiDnsResolver dnsResolver; + private InitialContextFactory contextFactory; + private DirContext context; @Before public void setup() { - contextFactory = mock(InitialContextFactory.class); - context = mock(DirContext.class); - dnsResolver = new JndiDnsResolver(); - dnsResolver.setCtxFactory(contextFactory); - when(contextFactory.getCtx()).thenReturn(context); + this.contextFactory = mock(InitialContextFactory.class); + this.context = mock(DirContext.class); + this.dnsResolver = new JndiDnsResolver(); + this.dnsResolver.setCtxFactory(this.contextFactory); + given(this.contextFactory.getCtx()).willReturn(this.context); } @Test public void testResolveIpAddress() throws Exception { Attributes records = new BasicAttributes("A", "63.246.7.80"); - - when(context.getAttributes("www.springsource.com", new String[] { "A" })) - .thenReturn(records); - - String ipAddress = dnsResolver.resolveIpAddress("www.springsource.com"); + given(this.context.getAttributes("www.springsource.com", new String[] { "A" })).willReturn(records); + String ipAddress = this.dnsResolver.resolveIpAddress("www.springsource.com"); assertThat(ipAddress).isEqualTo("63.246.7.80"); } @Test(expected = DnsEntryNotFoundException.class) public void testResolveIpAddressNotExisting() throws Exception { - when(context.getAttributes(any(String.class), any(String[].class))).thenThrow( - new NameNotFoundException("not found")); - - dnsResolver.resolveIpAddress("notexisting.ansdansdugiuzgguzgioansdiandwq.foo"); + given(this.context.getAttributes(any(String.class), any(String[].class))) + .willThrow(new NameNotFoundException("not found")); + this.dnsResolver.resolveIpAddress("notexisting.ansdansdugiuzgguzgioansdiandwq.foo"); } @Test public void testResolveServiceEntry() throws Exception { BasicAttributes records = createSrvRecords(); - - when(context.getAttributes("_ldap._tcp.springsource.com", new String[] { "SRV" })) - .thenReturn(records); - - String hostname = dnsResolver.resolveServiceEntry("ldap", "springsource.com"); + given(this.context.getAttributes("_ldap._tcp.springsource.com", new String[] { "SRV" })).willReturn(records); + String hostname = this.dnsResolver.resolveServiceEntry("ldap", "springsource.com"); assertThat(hostname).isEqualTo("kdc.springsource.com"); } @Test(expected = DnsEntryNotFoundException.class) public void testResolveServiceEntryNotExisting() throws Exception { - when(context.getAttributes(any(String.class), any(String[].class))).thenThrow( - new NameNotFoundException("not found")); - - dnsResolver.resolveServiceEntry("wrong", "secpod.de"); + given(this.context.getAttributes(any(String.class), any(String[].class))) + .willThrow(new NameNotFoundException("not found")); + this.dnsResolver.resolveServiceEntry("wrong", "secpod.de"); } @Test public void testResolveServiceIpAddress() throws Exception { BasicAttributes srvRecords = createSrvRecords(); BasicAttributes aRecords = new BasicAttributes("A", "63.246.7.80"); - when(context.getAttributes("_ldap._tcp.springsource.com", new String[] { "SRV" })) - .thenReturn(srvRecords); - when(context.getAttributes("kdc.springsource.com", new String[] { "A" })) - .thenReturn(aRecords); - - String ipAddress = dnsResolver - .resolveServiceIpAddress("ldap", "springsource.com"); + given(this.context.getAttributes("_ldap._tcp.springsource.com", new String[] { "SRV" })).willReturn(srvRecords); + given(this.context.getAttributes("kdc.springsource.com", new String[] { "A" })).willReturn(aRecords); + String ipAddress = this.dnsResolver.resolveServiceIpAddress("ldap", "springsource.com"); assertThat(ipAddress).isEqualTo("63.246.7.80"); } @Test(expected = DnsLookupException.class) public void testUnknowError() throws Exception { - when(context.getAttributes(any(String.class), any(String[].class))).thenThrow( - new NamingException("error")); - dnsResolver.resolveIpAddress(""); + given(this.context.getAttributes(any(String.class), any(String[].class))) + .willThrow(new NamingException("error")); + this.dnsResolver.resolveIpAddress(""); } private BasicAttributes createSrvRecords() { @@ -121,4 +112,5 @@ public class JndiDnsResolverTests { records.put(record); return records; } + } diff --git a/remoting/src/test/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutorTests.java b/remoting/src/test/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutorTests.java index 4bc83c05d5..4bc006bcd9 100644 --- a/remoting/src/test/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutorTests.java +++ b/remoting/src/test/java/org/springframework/security/remoting/httpinvoker/AuthenticationSimpleHttpInvokerRequestExecutorTests.java @@ -16,8 +16,6 @@ package org.springframework.security.remoting.httpinvoker; -import static org.assertj.core.api.Assertions.assertThat; - import java.net.HttpURLConnection; import java.net.URL; import java.util.HashMap; @@ -25,12 +23,15 @@ import java.util.Map; import org.junit.After; import org.junit.Test; + import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.assertThat; + /** * Tests {@link AuthenticationSimpleHttpInvokerRequestExecutor}. * @@ -39,8 +40,6 @@ import org.springframework.security.core.context.SecurityContextHolder; */ public class AuthenticationSimpleHttpInvokerRequestExecutorTests { - // ~ Methods - // ======================================================================================================== @After public void tearDown() { SecurityContextHolder.clearContext(); @@ -49,33 +48,27 @@ public class AuthenticationSimpleHttpInvokerRequestExecutorTests { @Test public void testNormalOperation() throws Exception { // Setup client-side context - Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken( - "Aladdin", "open sesame"); + Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken("Aladdin", "open sesame"); SecurityContextHolder.getContext().setAuthentication(clientSideAuthentication); - // Create a connection and ensure our executor sets its // properties correctly AuthenticationSimpleHttpInvokerRequestExecutor executor = new AuthenticationSimpleHttpInvokerRequestExecutor(); HttpURLConnection conn = new MockHttpURLConnection(new URL("https://localhost/")); executor.prepareConnection(conn, 10); - // Check connection properties // See https://tools.ietf.org/html/rfc1945 section 11.1 for example // we are comparing against - assertThat(conn.getRequestProperty("Authorization")).isEqualTo( - "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="); + assertThat(conn.getRequestProperty("Authorization")).isEqualTo("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="); } @Test public void testNullContextHolderIsNull() throws Exception { SecurityContextHolder.getContext().setAuthentication(null); - // Create a connection and ensure our executor sets its // properties correctly AuthenticationSimpleHttpInvokerRequestExecutor executor = new AuthenticationSimpleHttpInvokerRequestExecutor(); HttpURLConnection conn = new MockHttpURLConnection(new URL("https://localhost/")); executor.prepareConnection(conn, 10); - // Check connection properties (shouldn't be an Authorization header) assertThat(conn.getRequestProperty("Authorization")).isNull(); } @@ -83,23 +76,18 @@ public class AuthenticationSimpleHttpInvokerRequestExecutorTests { // SEC-1975 @Test public void testNullContextHolderWhenAnonymous() throws Exception { - AnonymousAuthenticationToken anonymous = new AnonymousAuthenticationToken("key", - "principal", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + AnonymousAuthenticationToken anonymous = new AnonymousAuthenticationToken("key", "principal", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); SecurityContextHolder.getContext().setAuthentication(anonymous); - // Create a connection and ensure our executor sets its // properties correctly AuthenticationSimpleHttpInvokerRequestExecutor executor = new AuthenticationSimpleHttpInvokerRequestExecutor(); HttpURLConnection conn = new MockHttpURLConnection(new URL("https://localhost/")); executor.prepareConnection(conn, 10); - // Check connection properties (shouldn't be an Authorization header) assertThat(conn.getRequestProperty("Authorization")).isNull(); } - // ~ Inner Classes - // ================================================================================================== - private class MockHttpURLConnection extends HttpURLConnection { private Map requestProperties = new HashMap<>(); @@ -108,24 +96,31 @@ public class AuthenticationSimpleHttpInvokerRequestExecutorTests { super(u); } + @Override public void connect() { throw new UnsupportedOperationException("mock not implemented"); } + @Override public void disconnect() { throw new UnsupportedOperationException("mock not implemented"); } + @Override public String getRequestProperty(String key) { - return requestProperties.get(key); + return this.requestProperties.get(key); } + @Override public void setRequestProperty(String key, String value) { - requestProperties.put(key, value); + this.requestProperties.put(key, value); } + @Override public boolean usingProxy() { throw new UnsupportedOperationException("mock not implemented"); } + } + } diff --git a/remoting/src/test/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationTests.java b/remoting/src/test/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationTests.java index c5877fd021..b7512c245e 100644 --- a/remoting/src/test/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationTests.java +++ b/remoting/src/test/java/org/springframework/security/remoting/rmi/ContextPropagatingRemoteInvocationTests.java @@ -16,14 +16,12 @@ package org.springframework.security.remoting.rmi; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; - import java.lang.reflect.Method; import org.aopalliance.intercept.MethodInvocation; import org.junit.After; import org.junit.Test; + import org.springframework.security.TargetObject; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; @@ -31,6 +29,9 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.util.SimpleMethodInvocation; import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link ContextPropagatingRemoteInvocation} and * {@link ContextPropagatingRemoteInvocationFactory}. @@ -39,8 +40,6 @@ import org.springframework.test.util.ReflectionTestUtils; */ public class ContextPropagatingRemoteInvocationTests { - // ~ Methods - // ======================================================================================================== @After public void tearDown() { SecurityContextHolder.clearContext(); @@ -49,52 +48,40 @@ public class ContextPropagatingRemoteInvocationTests { private ContextPropagatingRemoteInvocation getRemoteInvocation() throws Exception { Class clazz = TargetObject.class; Method method = clazz.getMethod("makeLowerCase", new Class[] { String.class }); - MethodInvocation mi = new SimpleMethodInvocation(new TargetObject(), method, - "SOME_STRING"); - + MethodInvocation mi = new SimpleMethodInvocation(new TargetObject(), method, "SOME_STRING"); ContextPropagatingRemoteInvocationFactory factory = new ContextPropagatingRemoteInvocationFactory(); - return (ContextPropagatingRemoteInvocation) factory.createRemoteInvocation(mi); } @Test public void testContextIsResetEvenIfExceptionOccurs() throws Exception { // Setup client-side context - Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken( - "rod", "koala"); + Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken("rod", "koala"); SecurityContextHolder.getContext().setAuthentication(clientSideAuthentication); - ContextPropagatingRemoteInvocation remoteInvocation = getRemoteInvocation(); - try { // Set up the wrong arguments. remoteInvocation.setArguments(new Object[] {}); remoteInvocation.invoke(TargetObject.class.newInstance()); fail("Expected IllegalArgumentException"); } - catch (IllegalArgumentException e) { + catch (IllegalArgumentException ex) { // expected } - - assertThat( - SecurityContextHolder.getContext().getAuthentication()).withFailMessage( - "Authentication must be null").isNull(); + assertThat(SecurityContextHolder.getContext().getAuthentication()) + .withFailMessage("Authentication must be null").isNull(); } @Test public void testNormalOperation() throws Exception { // Setup client-side context - Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken( - "rod", "koala"); + Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken("rod", "koala"); SecurityContextHolder.getContext().setAuthentication(clientSideAuthentication); - ContextPropagatingRemoteInvocation remoteInvocation = getRemoteInvocation(); - // Set to null, as ContextPropagatingRemoteInvocation already obtained // a copy and nulling is necessary to ensure the Context delivered by // ContextPropagatingRemoteInvocation is used on server-side SecurityContextHolder.clearContext(); - // The result from invoking the TargetObject should contain the // Authentication class delivered via the SecurityContextHolder assertThat(remoteInvocation.invoke(new TargetObject())).isEqualTo( @@ -104,24 +91,19 @@ public class ContextPropagatingRemoteInvocationTests { @Test public void testNullContextHolderDoesNotCauseInvocationProblems() throws Exception { SecurityContextHolder.clearContext(); // just to be explicit - ContextPropagatingRemoteInvocation remoteInvocation = getRemoteInvocation(); SecurityContextHolder.clearContext(); // unnecessary, but for // explicitness - - assertThat(remoteInvocation.invoke(new TargetObject())).isEqualTo( - "some_string Authentication empty"); + assertThat(remoteInvocation.invoke(new TargetObject())).isEqualTo("some_string Authentication empty"); } // SEC-1867 @Test public void testNullCredentials() throws Exception { - Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken( - "rod", null); + Authentication clientSideAuthentication = new UsernamePasswordAuthenticationToken("rod", null); SecurityContextHolder.getContext().setAuthentication(clientSideAuthentication); - ContextPropagatingRemoteInvocation remoteInvocation = getRemoteInvocation(); - assertThat( - ReflectionTestUtils.getField(remoteInvocation, "credentials")).isNull(); + assertThat(ReflectionTestUtils.getField(remoteInvocation, "credentials")).isNull(); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchange.java b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchange.java index 1979d5e263..e08708b570 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchange.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchange.java @@ -17,6 +17,7 @@ package org.springframework.security.rsocket.api; import io.rsocket.Payload; + import org.springframework.util.MimeType; /** @@ -26,6 +27,7 @@ import org.springframework.util.MimeType; * @since 5.2 */ public interface PayloadExchange { + PayloadExchangeType getType(); Payload getPayload(); @@ -33,4 +35,5 @@ public interface PayloadExchange { MimeType getDataMimeType(); MimeType getMetadataMimeType(); + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchangeType.java b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchangeType.java index e31d84b50e..9daa62532c 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchangeType.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadExchangeType.java @@ -23,14 +23,16 @@ package org.springframework.security.rsocket.api; * @since 5.2 */ public enum PayloadExchangeType { + /** - * The Setup. Can - * be used to determine if a Payload is part of the connection + * The Setup. Can be + * used to determine if a Payload is part of the connection */ SETUP(false), /** - * A Fire and Forget exchange. + * A Fire and Forget + * exchange. */ FIRE_AND_FORGET(true), @@ -41,9 +43,9 @@ public enum PayloadExchangeType { REQUEST_RESPONSE(true), /** - * A Request Stream - * exchange. This is only represents the request portion. The {@link #PAYLOAD} type - * represents the data that submitted. + * A Request + * Stream exchange. This is only represents the request portion. The + * {@link #PAYLOAD} type represents the data that submitted. */ REQUEST_STREAM(true), @@ -77,4 +79,5 @@ public enum PayloadExchangeType { public boolean isRequest() { return this.isRequest; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptor.java index 2f6f36c02b..692cfa72cf 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptor.java @@ -19,14 +19,15 @@ package org.springframework.security.rsocket.api; import reactor.core.publisher.Mono; /** - * Contract for interception-style, chained processing of Payloads that may - * be used to implement cross-cutting, application-agnostic requirements such - * as security, timeouts, and others. + * Contract for interception-style, chained processing of Payloads that may be used to + * implement cross-cutting, application-agnostic requirements such as security, timeouts, + * and others. * * @author Rob Winch * @since 5.2 */ public interface PayloadInterceptor { + /** * Process the Web request and (optionally) delegate to the next * {@code PayloadInterceptor} through the given {@link PayloadInterceptorChain}. @@ -35,4 +36,5 @@ public interface PayloadInterceptor { * @return {@code Mono} to indicate when payload processing is complete */ Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain); + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptorChain.java b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptorChain.java index 30307ffc05..edfe567d75 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptorChain.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/api/PayloadInterceptorChain.java @@ -19,16 +19,18 @@ package org.springframework.security.rsocket.api; import reactor.core.publisher.Mono; /** - * Contract to allow a {@link PayloadInterceptor} to delegate to the next in the chain. - * * + * Contract to allow a {@link PayloadInterceptor} to delegate to the next in the chain. * + * * @author Rob Winch * @since 5.2 */ public interface PayloadInterceptorChain { + /** * Process the payload exchange. * @param exchange the current server exchange * @return {@code Mono} to indicate when request processing is complete */ Mono next(PayloadExchange exchange); + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptor.java index 69fc5f6015..ae1b9e2846 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptor.java @@ -16,18 +16,19 @@ package org.springframework.security.rsocket.authentication; +import java.util.List; + +import reactor.core.publisher.Mono; + import org.springframework.core.Ordered; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.util.Assert; -import reactor.core.publisher.Mono; -import org.springframework.security.rsocket.api.PayloadInterceptorChain; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadInterceptor; - -import java.util.List; +import org.springframework.security.rsocket.api.PayloadInterceptorChain; +import org.springframework.util.Assert; /** * If {@link ReactiveSecurityContextHolder} is empty populates an @@ -38,16 +39,17 @@ import java.util.List; */ public class AnonymousPayloadInterceptor implements PayloadInterceptor, Ordered { - private String key; - private Object principal; - private List authorities; + private final String key; + + private final Object principal; + + private final List authorities; private int order; /** * Creates a filter with a principal named "anonymousUser" and the single authority * "ROLE_ANONYMOUS". - * * @param key the key to identify tokens created by this filter */ public AnonymousPayloadInterceptor(String key) { @@ -55,12 +57,11 @@ public class AnonymousPayloadInterceptor implements PayloadInterceptor, Ordered } /** - * @param key key the key to identify tokens created by this filter - * @param principal the principal which will be used to represent anonymous users + * @param key key the key to identify tokens created by this filter + * @param principal the principal which will be used to represent anonymous users * @param authorities the authority list for anonymous users */ - public AnonymousPayloadInterceptor(String key, Object principal, - List authorities) { + public AnonymousPayloadInterceptor(String key, Object principal, List authorities) { Assert.hasLength(key, "key cannot be null or empty"); Assert.notNull(principal, "Anonymous authentication principal must be set"); Assert.notNull(authorities, "Anonymous authorities must be set"); @@ -80,14 +81,13 @@ public class AnonymousPayloadInterceptor implements PayloadInterceptor, Ordered @Override public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) { - return ReactiveSecurityContextHolder.getContext() - .switchIfEmpty(Mono.defer(() -> { - AnonymousAuthenticationToken authentication = new AnonymousAuthenticationToken( - this.key, this.principal, this.authorities); - return chain.next(exchange) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) - .then(Mono.empty()); - })) - .flatMap(securityContext -> chain.next(exchange)); + return ReactiveSecurityContextHolder.getContext().switchIfEmpty(Mono.defer(() -> { + AnonymousAuthenticationToken authentication = new AnonymousAuthenticationToken(this.key, this.principal, + this.authorities); + return chain.next(exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .then(Mono.empty()); + })).flatMap((securityContext) -> chain.next(exchange)); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadExchangeConverter.java index 9b7e1ddae0..096c1bb179 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadExchangeConverter.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadExchangeConverter.java @@ -16,11 +16,16 @@ package org.springframework.security.rsocket.authentication; +import java.nio.charset.StandardCharsets; +import java.util.Map; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.metadata.AuthMetadataCodec; import io.rsocket.metadata.WellKnownAuthType; +import io.rsocket.metadata.WellKnownMimeType; +import reactor.core.publisher.Mono; + import org.springframework.core.codec.ByteArrayDecoder; import org.springframework.messaging.rsocket.DefaultMetadataExtractor; import org.springframework.messaging.rsocket.MetadataExtractor; @@ -30,36 +35,35 @@ import org.springframework.security.oauth2.server.resource.BearerTokenAuthentica import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Mono; - -import java.nio.charset.StandardCharsets; -import java.util.Map; /** - * Converts from the {@link PayloadExchange} for - * Authentication Extension. - * For - * Simple - * a {@link UsernamePasswordAuthenticationToken} is returned. For - * Bearer + * Converts from the {@link PayloadExchange} for Authentication + * Extension. For Simple + * a {@link UsernamePasswordAuthenticationToken} is returned. For Bearer * a {@link BearerTokenAuthenticationToken} is returned. * * @author Rob Winch * @since 5.3 */ public class AuthenticationPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter { - private static final MimeType COMPOSITE_METADATA_MIME_TYPE = MimeTypeUtils.parseMimeType( - WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); - private static final MimeType AUTHENTICATION_MIME_TYPE = MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()); + private static final MimeType COMPOSITE_METADATA_MIME_TYPE = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); - private MetadataExtractor metadataExtractor = createDefaultExtractor(); + private static final MimeType AUTHENTICATION_MIME_TYPE = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()); + + private final MetadataExtractor metadataExtractor = createDefaultExtractor(); @Override public Mono convert(PayloadExchange exchange) { - return Mono.fromCallable(() -> this.metadataExtractor - .extract(exchange.getPayload(), this.COMPOSITE_METADATA_MIME_TYPE)) - .flatMap(metadata -> Mono.justOrEmpty(authentication(metadata))); + return Mono + .fromCallable(() -> this.metadataExtractor.extract(exchange.getPayload(), + AuthenticationPayloadExchangeConverter.COMPOSITE_METADATA_MIME_TYPE)) + .flatMap((metadata) -> Mono.justOrEmpty(authentication(metadata))); } private Authentication authentication(Map metadata) { @@ -74,7 +78,8 @@ public class AuthenticationPayloadExchangeConverter implements PayloadExchangeAu WellKnownAuthType wellKnownAuthType = AuthMetadataCodec.readWellKnownAuthType(rawAuthentication); if (WellKnownAuthType.SIMPLE.equals(wellKnownAuthType)) { return simple(rawAuthentication); - } else if (WellKnownAuthType.BEARER.equals(wellKnownAuthType)) { + } + if (WellKnownAuthType.BEARER.equals(wellKnownAuthType)) { return bearer(rawAuthentication); } throw new IllegalArgumentException("Unknown Mime Type " + wellKnownAuthType); @@ -99,4 +104,5 @@ public class AuthenticationPayloadExchangeConverter implements PayloadExchangeAu result.metadataToExtract(AUTHENTICATION_MIME_TYPE, byte[].class, "authentication"); return result; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptor.java index 2765f1e7b7..0a0aa58777 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptor.java @@ -16,6 +16,8 @@ package org.springframework.security.rsocket.authentication; +import reactor.core.publisher.Mono; + import org.springframework.core.Ordered; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; @@ -24,7 +26,6 @@ import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadInterceptor; import org.springframework.security.rsocket.api.PayloadInterceptorChain; import org.springframework.util.Assert; -import reactor.core.publisher.Mono; /** * Uses the provided {@code ReactiveAuthenticationManager} to authenticate a Payload. If @@ -40,8 +41,7 @@ public class AuthenticationPayloadInterceptor implements PayloadInterceptor, Ord private int order; - private PayloadExchangeAuthenticationConverter authenticationConverter = - new BasicAuthenticationPayloadExchangeConverter(); + private PayloadExchangeAuthenticationConverter authenticationConverter = new BasicAuthenticationPayloadExchangeConverter(); /** * Creates a new instance @@ -65,22 +65,20 @@ public class AuthenticationPayloadInterceptor implements PayloadInterceptor, Ord * Sets the convert to be used * @param authenticationConverter */ - public void setAuthenticationConverter( - PayloadExchangeAuthenticationConverter authenticationConverter) { + public void setAuthenticationConverter(PayloadExchangeAuthenticationConverter authenticationConverter) { Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); this.authenticationConverter = authenticationConverter; } + @Override public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) { - return this.authenticationConverter.convert(exchange) - .switchIfEmpty(chain.next(exchange).then(Mono.empty())) - .flatMap(a -> this.authenticationManager.authenticate(a)) - .flatMap(a -> onAuthenticationSuccess(chain.next(exchange), a)); + return this.authenticationConverter.convert(exchange).switchIfEmpty(chain.next(exchange).then(Mono.empty())) + .flatMap((a) -> this.authenticationManager.authenticate(a)) + .flatMap((a) -> onAuthenticationSuccess(chain.next(exchange), a)); } private Mono onAuthenticationSuccess(Mono payload, Authentication authentication) { - return payload - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); + return payload.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); } } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BasicAuthenticationPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BasicAuthenticationPayloadExchangeConverter.java index dac42d0ba4..1a806c3bb8 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BasicAuthenticationPayloadExchangeConverter.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BasicAuthenticationPayloadExchangeConverter.java @@ -17,6 +17,8 @@ package org.springframework.security.rsocket.authentication; import io.rsocket.metadata.WellKnownMimeType; +import reactor.core.publisher.Mono; + import org.springframework.messaging.rsocket.DefaultMetadataExtractor; import org.springframework.messaging.rsocket.MetadataExtractor; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -26,7 +28,6 @@ import org.springframework.security.rsocket.metadata.BasicAuthenticationDecoder; import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Mono; /** * Converts from the {@link PayloadExchange} to a @@ -38,23 +39,26 @@ import reactor.core.publisher.Mono; */ public class BasicAuthenticationPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter { - private MimeType metadataMimetype = MimeTypeUtils.parseMimeType( - WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + private MimeType metadataMimetype = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); private MetadataExtractor metadataExtractor = createDefaultExtractor(); @Override public Mono convert(PayloadExchange exchange) { - return Mono.fromCallable(() -> this.metadataExtractor - .extract(exchange.getPayload(), this.metadataMimetype)) - .flatMap(metadata -> Mono.justOrEmpty(metadata.get(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE.toString()))) + return Mono.fromCallable(() -> this.metadataExtractor.extract(exchange.getPayload(), this.metadataMimetype)) + .flatMap((metadata) -> Mono + .justOrEmpty(metadata.get(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE.toString()))) .cast(UsernamePasswordMetadata.class) - .map(credentials -> new UsernamePasswordAuthenticationToken(credentials.getUsername(), credentials.getPassword())); + .map((credentials) -> new UsernamePasswordAuthenticationToken(credentials.getUsername(), + credentials.getPassword())); } private static MetadataExtractor createDefaultExtractor() { DefaultMetadataExtractor result = new DefaultMetadataExtractor(new BasicAuthenticationDecoder()); - result.metadataToExtract(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE, UsernamePasswordMetadata.class, (String) null); + result.metadataToExtract(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE, + UsernamePasswordMetadata.class, (String) null); return result; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BearerPayloadExchangeConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BearerPayloadExchangeConverter.java index 9b1d92a58d..030b1d260f 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BearerPayloadExchangeConverter.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/BearerPayloadExchangeConverter.java @@ -16,27 +16,28 @@ package org.springframework.security.rsocket.authentication; +import java.nio.charset.StandardCharsets; + import io.netty.buffer.ByteBuf; import io.rsocket.metadata.CompositeMetadata; +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.metadata.BearerTokenMetadata; -import reactor.core.publisher.Mono; - -import java.nio.charset.StandardCharsets; /** - * Converts from the {@link PayloadExchange} to a - * {@link BearerTokenAuthenticationToken} by extracting - * {@link BearerTokenMetadata#BEARER_AUTHENTICATION_MIME_TYPE} from the metadata. - * @author Rob Winch + * Converts from the {@link PayloadExchange} to a {@link BearerTokenAuthenticationToken} + * by extracting {@link BearerTokenMetadata#BEARER_AUTHENTICATION_MIME_TYPE} from the + * metadata. + * + * @author Rob Winch * @since 5.2 */ public class BearerPayloadExchangeConverter implements PayloadExchangeAuthenticationConverter { - private static final String BEARER_MIME_TYPE_VALUE = - BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE.toString(); + private static final String BEARER_MIME_TYPE_VALUE = BearerTokenMetadata.BEARER_AUTHENTICATION_MIME_TYPE.toString(); @Override public Mono convert(PayloadExchange exchange) { @@ -51,4 +52,5 @@ public class BearerPayloadExchangeConverter implements PayloadExchangeAuthentica } return Mono.empty(); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/PayloadExchangeAuthenticationConverter.java b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/PayloadExchangeAuthenticationConverter.java index dbb247246e..1b1e5d98ff 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authentication/PayloadExchangeAuthenticationConverter.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authentication/PayloadExchangeAuthenticationConverter.java @@ -16,15 +16,19 @@ package org.springframework.security.rsocket.authentication; +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.rsocket.api.PayloadExchange; -import reactor.core.publisher.Mono; /** * Converts from a {@link PayloadExchange} to an {@link Authentication} + * * @author Rob Winch * @since 5.2 */ public interface PayloadExchangeAuthenticationConverter { + Mono convert(PayloadExchange exchange); + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptor.java index c439f98a88..05b030afb7 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptor.java @@ -16,16 +16,17 @@ package org.springframework.security.rsocket.authorization; +import reactor.core.publisher.Mono; + import org.springframework.core.Ordered; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; -import org.springframework.util.Assert; -import reactor.core.publisher.Mono; -import org.springframework.security.rsocket.api.PayloadInterceptorChain; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadInterceptor; +import org.springframework.security.rsocket.api.PayloadInterceptorChain; +import org.springframework.util.Assert; /** * Provides authorization of the {@link PayloadExchange}. @@ -34,12 +35,12 @@ import org.springframework.security.rsocket.api.PayloadInterceptor; * @since 5.2 */ public class AuthorizationPayloadInterceptor implements PayloadInterceptor, Ordered { + private final ReactiveAuthorizationManager authorizationManager; private int order; - public AuthorizationPayloadInterceptor( - ReactiveAuthorizationManager authorizationManager) { + public AuthorizationPayloadInterceptor(ReactiveAuthorizationManager authorizationManager) { Assert.notNull(authorizationManager, "authorizationManager cannot be null"); this.authorizationManager = authorizationManager; } @@ -55,11 +56,12 @@ public class AuthorizationPayloadInterceptor implements PayloadInterceptor, Orde @Override public Mono intercept(PayloadExchange exchange, PayloadInterceptorChain chain) { - return ReactiveSecurityContextHolder.getContext() - .filter(c -> c.getAuthentication() != null) + return ReactiveSecurityContextHolder.getContext().filter((c) -> c.getAuthentication() != null) .map(SecurityContext::getAuthentication) - .switchIfEmpty(Mono.error(() -> new AuthenticationCredentialsNotFoundException("An Authentication (possibly AnonymousAuthenticationToken) is required."))) - .as(authentication -> this.authorizationManager.verify(authentication, exchange)) + .switchIfEmpty(Mono.error(() -> new AuthenticationCredentialsNotFoundException( + "An Authentication (possibly AnonymousAuthenticationToken) is required."))) + .as((authentication) -> this.authorizationManager.verify(authentication, exchange)) .then(chain.next(exchange)); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java b/rsocket/src/main/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java index 0382491137..295d460b2f 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManager.java @@ -16,31 +16,36 @@ package org.springframework.security.rsocket.authorization; -import org.springframework.security.authorization.AuthorizationDecision; -import org.springframework.security.authorization.ReactiveAuthorizationManager; -import org.springframework.security.core.Authentication; -import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import org.springframework.security.rsocket.api.PayloadExchange; -import org.springframework.security.rsocket.util.matcher.PayloadExchangeAuthorizationContext; -import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcher; -import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcherEntry; - import java.util.ArrayList; import java.util.List; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.security.authorization.AuthorizationDecision; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.rsocket.api.PayloadExchange; +import org.springframework.security.rsocket.util.matcher.PayloadExchangeAuthorizationContext; +import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcher; +import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcher.MatchResult; +import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcherEntry; +import org.springframework.util.Assert; + /** * Maps a @{code List} of {@link PayloadExchangeMatcher} instances to - * @{code ReactiveAuthorizationManager} instances. * + * @{code ReactiveAuthorizationManager} instances. * @author Rob Winch * @since 5.2 */ -public class PayloadExchangeMatcherReactiveAuthorizationManager implements ReactiveAuthorizationManager { +public final class PayloadExchangeMatcherReactiveAuthorizationManager + implements ReactiveAuthorizationManager { + private final List>> mappings; - private PayloadExchangeMatcherReactiveAuthorizationManager(List>> mappings) { + private PayloadExchangeMatcherReactiveAuthorizationManager( + List>> mappings) { Assert.notEmpty(mappings, "mappings cannot be null"); this.mappings = mappings; } @@ -48,22 +53,19 @@ public class PayloadExchangeMatcherReactiveAuthorizationManager implements React @Override public Mono check(Mono authentication, PayloadExchange exchange) { return Flux.fromIterable(this.mappings) - .concatMap(mapping -> mapping.getMatcher().matches(exchange) - .filter(PayloadExchangeMatcher.MatchResult::isMatch) - .map(r -> r.getVariables()) - .flatMap(variables -> mapping.getEntry() - .check(authentication, new PayloadExchangeAuthorizationContext(exchange, variables)) - ) - ) - .next() - .switchIfEmpty(Mono.fromCallable(() -> new AuthorizationDecision(false))); + .concatMap((mapping) -> mapping.getMatcher().matches(exchange) + .filter(PayloadExchangeMatcher.MatchResult::isMatch).map(MatchResult::getVariables) + .flatMap((variables) -> mapping.getEntry().check(authentication, + new PayloadExchangeAuthorizationContext(exchange, variables)))) + .next().switchIfEmpty(Mono.fromCallable(() -> new AuthorizationDecision(false))); } public static PayloadExchangeMatcherReactiveAuthorizationManager.Builder builder() { return new PayloadExchangeMatcherReactiveAuthorizationManager.Builder(); } - public static class Builder { + public static final class Builder { + private final List>> mappings = new ArrayList<>(); private Builder() { @@ -78,5 +80,7 @@ public class PayloadExchangeMatcherReactiveAuthorizationManager implements React public PayloadExchangeMatcherReactiveAuthorizationManager build() { return new PayloadExchangeMatcherReactiveAuthorizationManager(this.mappings); } + } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/ContextPayloadInterceptorChain.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/ContextPayloadInterceptorChain.java index c92785b7d0..6697abc725 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/ContextPayloadInterceptorChain.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/ContextPayloadInterceptorChain.java @@ -16,18 +16,20 @@ package org.springframework.security.rsocket.core; -import org.springframework.security.rsocket.api.PayloadExchange; -import org.springframework.security.rsocket.api.PayloadInterceptor; -import org.springframework.security.rsocket.api.PayloadInterceptorChain; -import reactor.core.publisher.Mono; -import reactor.util.context.Context; - import java.util.List; import java.util.ListIterator; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import org.springframework.security.rsocket.api.PayloadExchange; +import org.springframework.security.rsocket.api.PayloadInterceptor; +import org.springframework.security.rsocket.api.PayloadInterceptorChain; + /** - * A {@link PayloadInterceptorChain} which exposes the Reactor {@link Context} via a member variable. - * This class is not Thread safe, so a new instance must be created for each Thread. + * A {@link PayloadInterceptorChain} which exposes the Reactor {@link Context} via a + * member variable. This class is not Thread safe, so a new instance must be created for + * each Thread. * * Internally {@code ContextPayloadInterceptorChain} is used to ensure that the Reactor * {@code Context} is captured so it can be transferred to subscribers outside of this @@ -71,14 +73,10 @@ class ContextPayloadInterceptorChain implements PayloadInterceptorChain { this.next = next; } + @Override public Mono next(PayloadExchange exchange) { - return Mono.defer(() -> - shouldIntercept() ? - this.currentInterceptor.intercept(exchange, this.next) : - Mono.subscriberContext() - .doOnNext(c -> this.context = c) - .then() - ); + return Mono.defer(() -> shouldIntercept() ? this.currentInterceptor.intercept(exchange, this.next) + : Mono.subscriberContext().doOnNext((c) -> this.context = c).then()); } Context getContext() { @@ -96,4 +94,5 @@ class ContextPayloadInterceptorChain implements PayloadInterceptorChain { public String toString() { return getClass().getSimpleName() + "[currentInterceptor=" + this.currentInterceptor + "]"; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/DefaultPayloadExchange.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/DefaultPayloadExchange.java index e49ccb5154..e824c23f9e 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/DefaultPayloadExchange.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/DefaultPayloadExchange.java @@ -17,6 +17,7 @@ package org.springframework.security.rsocket.core; import io.rsocket.Payload; + import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadExchangeType; import org.springframework.util.Assert; @@ -69,4 +70,5 @@ public class DefaultPayloadExchange implements PayloadExchange { public MimeType getDataMimeType() { return this.dataMimeType; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java index b1c7271ed9..3120cab77c 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java @@ -16,25 +16,28 @@ package org.springframework.security.rsocket.core; +import java.util.List; + import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.util.RSocketProxy; import org.reactivestreams.Publisher; -import org.springframework.security.rsocket.api.PayloadExchangeType; -import org.springframework.security.rsocket.api.PayloadInterceptor; -import org.springframework.util.MimeType; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.util.context.Context; -import java.util.List; +import org.springframework.security.rsocket.api.PayloadExchangeType; +import org.springframework.security.rsocket.api.PayloadInterceptor; +import org.springframework.util.MimeType; /** * Combines the {@link PayloadInterceptor} with an {@link RSocketProxy} + * * @author Rob Winch * @since 5.2 */ class PayloadInterceptorRSocket extends RSocketProxy { + private final List interceptors; private final MimeType metadataMimeType; @@ -43,14 +46,12 @@ class PayloadInterceptorRSocket extends RSocketProxy { private final Context context; - PayloadInterceptorRSocket(RSocket delegate, - List interceptors, MimeType metadataMimeType, + PayloadInterceptorRSocket(RSocket delegate, List interceptors, MimeType metadataMimeType, MimeType dataMimeType) { this(delegate, interceptors, metadataMimeType, dataMimeType, Context.empty()); } - PayloadInterceptorRSocket(RSocket delegate, - List interceptors, MimeType metadataMimeType, + PayloadInterceptorRSocket(RSocket delegate, List interceptors, MimeType metadataMimeType, MimeType dataMimeType, Context context) { super(delegate); this.metadataMimeType = metadataMimeType; @@ -71,71 +72,52 @@ class PayloadInterceptorRSocket extends RSocketProxy { @Override public Mono fireAndForget(Payload payload) { return intercept(PayloadExchangeType.FIRE_AND_FORGET, payload) - .flatMap(context -> - this.source.fireAndForget(payload) - .subscriberContext(context) - ); + .flatMap((context) -> this.source.fireAndForget(payload).subscriberContext(context)); } @Override public Mono requestResponse(Payload payload) { return intercept(PayloadExchangeType.REQUEST_RESPONSE, payload) - .flatMap(context -> - this.source.requestResponse(payload) - .subscriberContext(context) - ); + .flatMap((context) -> this.source.requestResponse(payload).subscriberContext(context)); } @Override public Flux requestStream(Payload payload) { return intercept(PayloadExchangeType.REQUEST_STREAM, payload) - .flatMapMany(context -> - this.source.requestStream(payload) - .subscriberContext(context) - ); + .flatMapMany((context) -> this.source.requestStream(payload).subscriberContext(context)); } @Override public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads) - .switchOnFirst((signal, innerFlux) -> { - Payload firstPayload = signal.get(); - return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload) - .flatMapMany(context -> - innerFlux - .skip(1) - .flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p)) - .transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads)) - .transform(securedPayloads -> this.source.requestChannel(securedPayloads)) - .subscriberContext(context) - ); - }); + return Flux.from(payloads).switchOnFirst((signal, innerFlux) -> { + Payload firstPayload = signal.get(); + return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload).flatMapMany((context) -> innerFlux + .skip(1).flatMap((p) -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p)) + .transform((securedPayloads) -> Flux.concat(Flux.just(firstPayload), securedPayloads)) + .transform((securedPayloads) -> this.source.requestChannel(securedPayloads)) + .subscriberContext(context)); + }); } @Override public Mono metadataPush(Payload payload) { return intercept(PayloadExchangeType.METADATA_PUSH, payload) - .flatMap(c -> this.source - .metadataPush(payload) - .subscriberContext(c) - ); + .flatMap((c) -> this.source.metadataPush(payload).subscriberContext(c)); } private Mono intercept(PayloadExchangeType type, Payload payload) { return Mono.defer(() -> { ContextPayloadInterceptorChain chain = new ContextPayloadInterceptorChain(this.interceptors); - DefaultPayloadExchange exchange = new DefaultPayloadExchange(type, payload, - this.metadataMimeType, this.dataMimeType); - return chain.next(exchange) - .then(Mono.fromCallable(() -> chain.getContext())) - .defaultIfEmpty(Context.empty()) - .subscriberContext(this.context); + DefaultPayloadExchange exchange = new DefaultPayloadExchange(type, payload, this.metadataMimeType, + this.dataMimeType); + return chain.next(exchange).then(Mono.fromCallable(() -> chain.getContext())) + .defaultIfEmpty(Context.empty()).subscriberContext(this.context); }); } @Override public String toString() { - return getClass().getSimpleName() + "[source=" + this.source + ",interceptors=" - + this.interceptors + "]"; + return getClass().getSimpleName() + "[source=" + this.source + ",interceptors=" + this.interceptors + "]"; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java index a5a849096b..ebd0ed7d5c 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java @@ -16,11 +16,16 @@ package org.springframework.security.rsocket.core; +import java.util.List; + import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; import io.rsocket.metadata.WellKnownMimeType; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + import org.springframework.lang.Nullable; import org.springframework.security.rsocket.api.PayloadExchangeType; import org.springframework.security.rsocket.api.PayloadInterceptor; @@ -28,16 +33,13 @@ import org.springframework.util.Assert; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import org.springframework.util.StringUtils; -import reactor.core.publisher.Mono; -import reactor.util.context.Context; - -import java.util.List; /** * @author Rob Winch * @since 5.2 */ class PayloadSocketAcceptor implements SocketAcceptor { + private final SocketAcceptor delegate; private final List interceptors; @@ -45,8 +47,8 @@ class PayloadSocketAcceptor implements SocketAcceptor { @Nullable private MimeType defaultDataMimeType; - private MimeType defaultMetadataMimeType = - MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + private MimeType defaultMetadataMimeType = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); PayloadSocketAcceptor(SocketAcceptor delegate, List interceptors) { Assert.notNull(delegate, "delegate cannot be null"); @@ -64,16 +66,15 @@ class PayloadSocketAcceptor implements SocketAcceptor { public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { MimeType dataMimeType = parseMimeType(setup.dataMimeType(), this.defaultDataMimeType); Assert.notNull(dataMimeType, "No `dataMimeType` in ConnectionSetupPayload and no default value"); - MimeType metadataMimeType = parseMimeType(setup.metadataMimeType(), this.defaultMetadataMimeType); Assert.notNull(metadataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value"); - // FIXME do we want to make the sendingSocket available in the PayloadExchange return intercept(setup, dataMimeType, metadataMimeType) - .flatMap(ctx -> this.delegate.accept(setup, sendingSocket) - .map(acceptingSocket -> new PayloadInterceptorRSocket(acceptingSocket, this.interceptors, metadataMimeType, dataMimeType, ctx)) - .subscriberContext(ctx) - ); + .flatMap( + (ctx) -> this.delegate.accept(setup, sendingSocket) + .map((acceptingSocket) -> new PayloadInterceptorRSocket(acceptingSocket, + this.interceptors, metadataMimeType, dataMimeType, ctx)) + .subscriberContext(ctx)); } private Mono intercept(Payload payload, MimeType dataMimeType, MimeType metadataMimeType) { @@ -81,8 +82,7 @@ class PayloadSocketAcceptor implements SocketAcceptor { ContextPayloadInterceptorChain chain = new ContextPayloadInterceptorChain(this.interceptors); DefaultPayloadExchange exchange = new DefaultPayloadExchange(PayloadExchangeType.SETUP, payload, metadataMimeType, dataMimeType); - return chain.next(exchange) - .then(Mono.fromCallable(() -> chain.getContext())) + return chain.next(exchange).then(Mono.fromCallable(() -> chain.getContext())) .defaultIfEmpty(Context.empty()); }); } @@ -91,12 +91,13 @@ class PayloadSocketAcceptor implements SocketAcceptor { return StringUtils.hasText(str) ? MimeTypeUtils.parseMimeType(str) : defaultMimeType; } - public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) { + void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) { this.defaultDataMimeType = defaultDataMimeType; } - public void setDefaultMetadataMimeType(MimeType defaultMetadataMimeType) { + void setDefaultMetadataMimeType(MimeType defaultMetadataMimeType) { Assert.notNull(defaultMetadataMimeType, "defaultMetadataMimeType cannot be null"); this.defaultMetadataMimeType = defaultMetadataMimeType; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptor.java index 52cba75c45..45af27cb83 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptor.java @@ -16,17 +16,18 @@ package org.springframework.security.rsocket.core; +import java.util.List; + import io.rsocket.SocketAcceptor; import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.plugins.SocketAcceptorInterceptor; + import org.springframework.lang.Nullable; import org.springframework.security.rsocket.api.PayloadInterceptor; import org.springframework.util.Assert; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import java.util.List; - /** * A {@link SocketAcceptorInterceptor} that applies the {@link PayloadInterceptor}s * @@ -40,8 +41,8 @@ public class PayloadSocketAcceptorInterceptor implements SocketAcceptorIntercept @Nullable private MimeType defaultDataMimeType; - private MimeType defaultMetadataMimeType = - MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + private MimeType defaultMetadataMimeType = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); public PayloadSocketAcceptorInterceptor(List interceptors) { this.interceptors = interceptors; @@ -49,8 +50,7 @@ public class PayloadSocketAcceptorInterceptor implements SocketAcceptorIntercept @Override public SocketAcceptor apply(SocketAcceptor socketAcceptor) { - PayloadSocketAcceptor acceptor = new PayloadSocketAcceptor( - socketAcceptor, this.interceptors); + PayloadSocketAcceptor acceptor = new PayloadSocketAcceptor(socketAcceptor, this.interceptors); acceptor.setDefaultDataMimeType(this.defaultDataMimeType); acceptor.setDefaultMetadataMimeType(this.defaultMetadataMimeType); return acceptor; @@ -64,4 +64,5 @@ public class PayloadSocketAcceptorInterceptor implements SocketAcceptorIntercept Assert.notNull(defaultMetadataMimeType, "defaultMetadataMimeType cannot be null"); this.defaultMetadataMimeType = defaultMetadataMimeType; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/SecuritySocketAcceptorInterceptor.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/SecuritySocketAcceptorInterceptor.java index f8f9cbaf81..9eeede18b5 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/SecuritySocketAcceptorInterceptor.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/SecuritySocketAcceptorInterceptor.java @@ -18,16 +18,19 @@ package org.springframework.security.rsocket.core; import io.rsocket.SocketAcceptor; import io.rsocket.plugins.SocketAcceptorInterceptor; + import org.springframework.util.Assert; /** - * A SocketAcceptorInterceptor that applies Security through a delegate {@link SocketAcceptorInterceptor}. This allows - * security to be applied lazily to an application. + * A SocketAcceptorInterceptor that applies Security through a delegate + * {@link SocketAcceptorInterceptor}. This allows security to be applied lazily to an + * application. * * @author Rob Winch * @since 5.2 */ public class SecuritySocketAcceptorInterceptor implements SocketAcceptorInterceptor { + private final SocketAcceptorInterceptor acceptorInterceptor; public SecuritySocketAcceptorInterceptor(SocketAcceptorInterceptor acceptorInterceptor) { @@ -39,4 +42,5 @@ public class SecuritySocketAcceptorInterceptor implements SocketAcceptorIntercep public SocketAcceptor apply(SocketAcceptor socketAcceptor) { return this.acceptorInterceptor.apply(socketAcceptor); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java index 6b8a45fe41..9876fa477e 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoder.java @@ -16,63 +16,62 @@ package org.springframework.security.rsocket.metadata; +import java.util.Map; + import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.core.ResolvableType; import org.springframework.core.codec.AbstractDecoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.util.MimeType; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.util.Map; /** * Decodes {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE} * * @author Rob Winch * @since 5.2 - * @deprecated Basic Authentication did not evolve into a standard. Use Simple Authentication instead. + * @deprecated Basic Authentication did not evolve into a standard. Use Simple + * Authentication instead. */ @Deprecated public class BasicAuthenticationDecoder extends AbstractDecoder { + public BasicAuthenticationDecoder() { super(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE); } @Override - public Flux decode(Publisher input, - ResolvableType elementType, MimeType mimeType, Map hints) { - return Flux.from(input) - .map(DataBuffer::asByteBuffer) - .map(byteBuffer -> { - byte[] sizeBytes = new byte[4]; - byteBuffer.get(sizeBytes); - - int usernameSize = 4; - byte[] usernameBytes = new byte[usernameSize]; - byteBuffer.get(usernameBytes); - byte[] passwordBytes = new byte[byteBuffer.remaining()]; - byteBuffer.get(passwordBytes); - String username = new String(usernameBytes); - String password = new String(passwordBytes); - return new UsernamePasswordMetadata(username, password); - }); + public Flux decode(Publisher input, ResolvableType elementType, + MimeType mimeType, Map hints) { + return Flux.from(input).map(DataBuffer::asByteBuffer).map((byteBuffer) -> { + byte[] sizeBytes = new byte[4]; + byteBuffer.get(sizeBytes); + int usernameSize = 4; + byte[] usernameBytes = new byte[usernameSize]; + byteBuffer.get(usernameBytes); + byte[] passwordBytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(passwordBytes); + String username = new String(usernameBytes); + String password = new String(passwordBytes); + return new UsernamePasswordMetadata(username, password); + }); } @Override - public Mono decodeToMono(Publisher input, - ResolvableType elementType, MimeType mimeType, Map hints) { - return Mono.from(input) - .map(DataBuffer::asByteBuffer) - .map(byteBuffer -> { - int usernameSize = byteBuffer.getInt(); - byte[] usernameBytes = new byte[usernameSize]; - byteBuffer.get(usernameBytes); - byte[] passwordBytes = new byte[byteBuffer.remaining()]; - byteBuffer.get(passwordBytes); - String username = new String(usernameBytes); - String password = new String(passwordBytes); - return new UsernamePasswordMetadata(username, password); - }); + public Mono decodeToMono(Publisher input, ResolvableType elementType, + MimeType mimeType, Map hints) { + return Mono.from(input).map(DataBuffer::asByteBuffer).map((byteBuffer) -> { + int usernameSize = byteBuffer.getInt(); + byte[] usernameBytes = new byte[usernameSize]; + byteBuffer.get(usernameBytes); + byte[] passwordBytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(passwordBytes); + String username = new String(usernameBytes); + String password = new String(passwordBytes); + return new UsernamePasswordMetadata(username, password); + }); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java index 75e3f909ac..d1b6f739ad 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BasicAuthenticationEncoder.java @@ -16,47 +16,45 @@ package org.springframework.security.rsocket.metadata; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Map; + import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + import org.springframework.core.ResolvableType; import org.springframework.core.codec.AbstractEncoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.util.MimeType; -import reactor.core.publisher.Flux; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Map; /** * Encodes {@link UsernamePasswordMetadata#BASIC_AUTHENTICATION_MIME_TYPE} * * @author Rob Winch * @since 5.2 - * @deprecated Basic Authentication did not evolve into a standard. use {@link SimpleAuthenticationEncoder} + * @deprecated Basic Authentication did not evolve into a standard. use + * {@link SimpleAuthenticationEncoder} */ @Deprecated -public class BasicAuthenticationEncoder extends - AbstractEncoder { +public class BasicAuthenticationEncoder extends AbstractEncoder { public BasicAuthenticationEncoder() { super(UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE); } @Override - public Flux encode( - Publisher inputStream, - DataBufferFactory bufferFactory, ResolvableType elementType, - MimeType mimeType, Map hints) { - return Flux.from(inputStream).map(credentials -> - encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); + public Flux encode(Publisher inputStream, + DataBufferFactory bufferFactory, ResolvableType elementType, MimeType mimeType, Map hints) { + return Flux.from(inputStream) + .map((credentials) -> encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); } @Override - public DataBuffer encodeValue(UsernamePasswordMetadata credentials, - DataBufferFactory bufferFactory, ResolvableType valueType, MimeType mimeType, - Map hints) { + public DataBuffer encodeValue(UsernamePasswordMetadata credentials, DataBufferFactory bufferFactory, + ResolvableType valueType, MimeType mimeType, Map hints) { String username = credentials.getUsername(); String password = credentials.getPassword(); byte[] usernameBytes = username.getBytes(StandardCharsets.UTF_8); @@ -69,10 +67,12 @@ public class BasicAuthenticationEncoder extends metadata.write(password.getBytes(StandardCharsets.UTF_8)); release = false; return metadata; - } finally { + } + finally { if (release) { DataBufferUtils.release(metadata); } } } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenAuthenticationEncoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenAuthenticationEncoder.java index ead8f960fa..e822fa8bff 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenAuthenticationEncoder.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenAuthenticationEncoder.java @@ -16,10 +16,14 @@ package org.springframework.security.rsocket.metadata; +import java.util.Map; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.metadata.AuthMetadataCodec; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + import org.springframework.core.ResolvableType; import org.springframework.core.codec.AbstractEncoder; import org.springframework.core.io.buffer.DataBuffer; @@ -27,20 +31,19 @@ import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; - -import java.util.Map; /** - * Encodes Bearer Authentication. + * Encodes Bearer + * Authentication. * * @author Rob Winch * @since 5.3 */ -public class BearerTokenAuthenticationEncoder extends - AbstractEncoder { +public class BearerTokenAuthenticationEncoder extends AbstractEncoder { - private static final MimeType AUTHENTICATION_MIME_TYPE = MimeTypeUtils.parseMimeType("message/x.rsocket.authentication.v0"); + private static final MimeType AUTHENTICATION_MIME_TYPE = MimeTypeUtils + .parseMimeType("message/x.rsocket.authentication.v0"); private NettyDataBufferFactory defaultBufferFactory = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); @@ -49,23 +52,19 @@ public class BearerTokenAuthenticationEncoder extends } @Override - public Flux encode( - Publisher inputStream, - DataBufferFactory bufferFactory, ResolvableType elementType, - MimeType mimeType, Map hints) { - return Flux.from(inputStream).map(credentials -> - encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); + public Flux encode(Publisher inputStream, + DataBufferFactory bufferFactory, ResolvableType elementType, MimeType mimeType, Map hints) { + return Flux.from(inputStream) + .map((credentials) -> encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); } @Override - public DataBuffer encodeValue(BearerTokenMetadata credentials, - DataBufferFactory bufferFactory, ResolvableType valueType, MimeType mimeType, - Map hints) { + public DataBuffer encodeValue(BearerTokenMetadata credentials, DataBufferFactory bufferFactory, + ResolvableType valueType, MimeType mimeType, Map hints) { String token = credentials.getToken(); NettyDataBufferFactory factory = nettyFactory(bufferFactory); ByteBufAllocator allocator = factory.getByteBufAllocator(); - ByteBuf simpleAuthentication = AuthMetadataCodec - .encodeBearerMetadata(allocator, token.toCharArray()); + ByteBuf simpleAuthentication = AuthMetadataCodec.encodeBearerMetadata(allocator, token.toCharArray()); return factory.wrap(simpleAuthentication); } @@ -75,4 +74,5 @@ public class BearerTokenAuthenticationEncoder extends } return this.defaultBufferFactory; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java index 5998b07cdd..364c4e4209 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/BearerTokenMetadata.java @@ -14,28 +14,30 @@ * limitations under the License. */ - package org.springframework.security.rsocket.metadata; import org.springframework.http.MediaType; import org.springframework.util.MimeType; /** - * Represents a bearer token that has been encoded into a - * {@link Payload#metadata()}. + * Represents a bearer token that has been encoded into a {@link Payload#metadata()}. * * @author Rob Winch * @since 5.2 */ public class BearerTokenMetadata { + /** * Represents a bearer token which is encoded as a String. * * See rsocket/rsocket#272 - * @deprecated Basic did not evolve into the standard. Instead use Simple Authentication MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()) + * @deprecated Basic did not evolve into the standard. Instead use Simple + * Authentication + * MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()) */ @Deprecated - public static final MimeType BEARER_AUTHENTICATION_MIME_TYPE = new MediaType("message", "x.rsocket.authentication.bearer.v0"); + public static final MimeType BEARER_AUTHENTICATION_MIME_TYPE = new MediaType("message", + "x.rsocket.authentication.bearer.v0"); private final String token; @@ -46,4 +48,5 @@ public class BearerTokenMetadata { public String getToken() { return this.token; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/SimpleAuthenticationEncoder.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/SimpleAuthenticationEncoder.java index eb4b0c737e..1c31395de7 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/SimpleAuthenticationEncoder.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/SimpleAuthenticationEncoder.java @@ -16,10 +16,14 @@ package org.springframework.security.rsocket.metadata; +import java.util.Map; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.metadata.AuthMetadataCodec; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + import org.springframework.core.ResolvableType; import org.springframework.core.codec.AbstractEncoder; import org.springframework.core.io.buffer.DataBuffer; @@ -27,22 +31,19 @@ import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; - -import java.util.Map; /** - * Encodes - * Simple + * Encodes Simple * Authentication. * * @author Rob Winch * @since 5.3 */ -public class SimpleAuthenticationEncoder extends - AbstractEncoder { +public class SimpleAuthenticationEncoder extends AbstractEncoder { - private static final MimeType AUTHENTICATION_MIME_TYPE = MimeTypeUtils.parseMimeType("message/x.rsocket.authentication.v0"); + private static final MimeType AUTHENTICATION_MIME_TYPE = MimeTypeUtils + .parseMimeType("message/x.rsocket.authentication.v0"); private NettyDataBufferFactory defaultBufferFactory = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); @@ -51,24 +52,21 @@ public class SimpleAuthenticationEncoder extends } @Override - public Flux encode( - Publisher inputStream, - DataBufferFactory bufferFactory, ResolvableType elementType, - MimeType mimeType, Map hints) { - return Flux.from(inputStream).map(credentials -> - encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); + public Flux encode(Publisher inputStream, + DataBufferFactory bufferFactory, ResolvableType elementType, MimeType mimeType, Map hints) { + return Flux.from(inputStream) + .map((credentials) -> encodeValue(credentials, bufferFactory, elementType, mimeType, hints)); } @Override - public DataBuffer encodeValue(UsernamePasswordMetadata credentials, - DataBufferFactory bufferFactory, ResolvableType valueType, MimeType mimeType, - Map hints) { + public DataBuffer encodeValue(UsernamePasswordMetadata credentials, DataBufferFactory bufferFactory, + ResolvableType valueType, MimeType mimeType, Map hints) { String username = credentials.getUsername(); String password = credentials.getPassword(); NettyDataBufferFactory factory = nettyFactory(bufferFactory); ByteBufAllocator allocator = factory.getByteBufAllocator(); - ByteBuf simpleAuthentication = AuthMetadataCodec - .encodeSimpleMetadata(allocator, username.toCharArray(), password.toCharArray()); + ByteBuf simpleAuthentication = AuthMetadataCodec.encodeSimpleMetadata(allocator, username.toCharArray(), + password.toCharArray()); return factory.wrap(simpleAuthentication); } @@ -78,4 +76,5 @@ public class SimpleAuthenticationEncoder extends } return this.defaultBufferFactory; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java index dab4ad6ea5..de32ecc872 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/metadata/UsernamePasswordMetadata.java @@ -17,6 +17,7 @@ package org.springframework.security.rsocket.metadata; import io.rsocket.Payload; + import org.springframework.http.MediaType; import org.springframework.util.MimeType; @@ -28,15 +29,19 @@ import org.springframework.util.MimeType; * @since 5.2 */ public final class UsernamePasswordMetadata { + /** * Represents a username password which is encoded as * {@code ${username-bytes-length}${username-bytes}${password-bytes}}. * * See rsocket/rsocket#272 - * @deprecated Basic did not evolve into the standard. Instead use Simple Authentication MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()) + * @deprecated Basic did not evolve into the standard. Instead use Simple + * Authentication + * MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_AUTHENTICATION.getString()) */ @Deprecated - public static final MimeType BASIC_AUTHENTICATION_MIME_TYPE = new MediaType("message", "x.rsocket.authentication.basic.v0"); + public static final MimeType BASIC_AUTHENTICATION_MIME_TYPE = new MediaType("message", + "x.rsocket.authentication.basic.v0"); private final String username; @@ -54,4 +59,5 @@ public final class UsernamePasswordMetadata { public String getPassword() { return this.password; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeAuthorizationContext.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeAuthorizationContext.java index 7312dbb4d0..7236d7f248 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeAuthorizationContext.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeAuthorizationContext.java @@ -16,17 +16,19 @@ package org.springframework.security.rsocket.util.matcher; -import org.springframework.security.rsocket.api.PayloadExchange; - import java.util.Collections; import java.util.Map; +import org.springframework.security.rsocket.api.PayloadExchange; + /** * @author Rob Winch * @since 5.2 */ public class PayloadExchangeAuthorizationContext { + private final PayloadExchange exchange; + private final Map variables; public PayloadExchangeAuthorizationContext(PayloadExchange exchange) { @@ -45,4 +47,5 @@ public class PayloadExchangeAuthorizationContext { public Map getVariables() { return Collections.unmodifiableMap(this.variables); } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcher.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcher.java index cf8f00bc98..582bf8a2e4 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcher.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcher.java @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.rsocket.util.matcher; -import org.springframework.security.rsocket.api.PayloadExchange; -import reactor.core.publisher.Mono; +package org.springframework.security.rsocket.util.matcher; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import reactor.core.publisher.Mono; + +import org.springframework.security.rsocket.api.PayloadExchange; + /** * An interface for determining if a {@link PayloadExchangeMatcher} matches. + * * @author Rob Winch * @since 5.2 */ @@ -40,7 +43,9 @@ public interface PayloadExchangeMatcher { * The result of matching */ class MatchResult { + private final boolean match; + private final Map variables; private MatchResult(boolean match, Map variables) { @@ -49,7 +54,7 @@ public interface PayloadExchangeMatcher { } public boolean isMatch() { - return match; + return this.match; } /** @@ -57,7 +62,7 @@ public interface PayloadExchangeMatcher { * @return */ public Map getVariables() { - return variables; + return this.variables; } /** @@ -70,12 +75,15 @@ public interface PayloadExchangeMatcher { /** * - * Creates an instance of {@link MatchResult} that is a match with the specified variables + * Creates an instance of {@link MatchResult} that is a match with the specified + * variables * @param variables * @return */ public static Mono match(Map variables) { - return Mono.just(new MatchResult(true, variables == null ? null : new HashMap(variables))); + MatchResult result = new MatchResult(true, + (variables != null) ? new HashMap(variables) : null); + return Mono.just(result); } /** @@ -85,5 +93,7 @@ public interface PayloadExchangeMatcher { public static Mono notMatch() { return Mono.just(new MatchResult(false, Collections.emptyMap())); } + } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcherEntry.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcherEntry.java index 6f7aa6ae28..2431fc8bf4 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcherEntry.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatcherEntry.java @@ -20,7 +20,9 @@ package org.springframework.security.rsocket.util.matcher; * @author Rob Winch */ public class PayloadExchangeMatcherEntry { + private final PayloadExchangeMatcher matcher; + private final T entry; public PayloadExchangeMatcherEntry(PayloadExchangeMatcher matcher, T entry) { @@ -35,4 +37,5 @@ public class PayloadExchangeMatcherEntry { public T getEntry() { return this.entry; } + } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatchers.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatchers.java index a6e110ebeb..ef96a1b1c9 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatchers.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/PayloadExchangeMatchers.java @@ -16,42 +16,51 @@ package org.springframework.security.rsocket.util.matcher; +import reactor.core.publisher.Mono; + import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadExchangeType; -import reactor.core.publisher.Mono; /** * @author Rob Winch */ -public abstract class PayloadExchangeMatchers { +public final class PayloadExchangeMatchers { + + private PayloadExchangeMatchers() { + } public static PayloadExchangeMatcher setup() { return new PayloadExchangeMatcher() { + + @Override public Mono matches(PayloadExchange exchange) { - return PayloadExchangeType.SETUP.equals(exchange.getType()) ? - MatchResult.match() : - MatchResult.notMatch(); + return PayloadExchangeType.SETUP.equals(exchange.getType()) ? MatchResult.match() + : MatchResult.notMatch(); } + }; } public static PayloadExchangeMatcher anyRequest() { return new PayloadExchangeMatcher() { + + @Override public Mono matches(PayloadExchange exchange) { - return exchange.getType().isRequest() ? - MatchResult.match() : - MatchResult.notMatch(); + return exchange.getType().isRequest() ? MatchResult.match() : MatchResult.notMatch(); } + }; } public static PayloadExchangeMatcher anyExchange() { return new PayloadExchangeMatcher() { + + @Override public Mono matches(PayloadExchange exchange) { return MatchResult.match(); } + }; } - private PayloadExchangeMatchers() {} } diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcher.java b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcher.java index ce5bf542d7..d4c4ab8dbe 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcher.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcher.java @@ -16,19 +16,19 @@ package org.springframework.security.rsocket.util.matcher; +import java.util.Map; +import java.util.Optional; + +import reactor.core.publisher.Mono; + import org.springframework.messaging.rsocket.MetadataExtractor; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.util.Assert; import org.springframework.util.RouteMatcher; -import reactor.core.publisher.Mono; -import java.util.Map; -import java.util.Optional; +// FIXME: Pay attention to the package this goes into. It requires spring-messaging for the MetadataExtractor. /** - * FIXME: Pay attention to the package this goes into. It requires spring-messaging for - * the MetadataExtractor. - * * @author Rob Winch * @since 5.2 */ @@ -40,8 +40,7 @@ public class RoutePayloadExchangeMatcher implements PayloadExchangeMatcher { private final RouteMatcher routeMatcher; - public RoutePayloadExchangeMatcher(MetadataExtractor metadataExtractor, - RouteMatcher routeMatcher, String pattern) { + public RoutePayloadExchangeMatcher(MetadataExtractor metadataExtractor, RouteMatcher routeMatcher, String pattern) { Assert.notNull(pattern, "pattern cannot be null"); this.metadataExtractor = metadataExtractor; this.routeMatcher = routeMatcher; @@ -50,12 +49,12 @@ public class RoutePayloadExchangeMatcher implements PayloadExchangeMatcher { @Override public Mono matches(PayloadExchange exchange) { - Map metadata = this.metadataExtractor - .extract(exchange.getPayload(), exchange.getMetadataMimeType()); + Map metadata = this.metadataExtractor.extract(exchange.getPayload(), + exchange.getMetadataMimeType()); return Optional.ofNullable((String) metadata.get(MetadataExtractor.ROUTE_KEY)) - .map(routeValue -> this.routeMatcher.parseRoute(routeValue)) - .map(route -> this.routeMatcher.matchAndExtract(this.pattern, route)) - .map(v -> MatchResult.match(v)) - .orElse(MatchResult.notMatch()); + .map((routeValue) -> this.routeMatcher.parseRoute(routeValue)) + .map((route) -> this.routeMatcher.matchAndExtract(this.pattern, route)).map((v) -> MatchResult.match(v)) + .orElse(MatchResult.notMatch()); } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java index 5aa69706fe..67465ccc21 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AnonymousPayloadInterceptorTests.java @@ -16,11 +16,14 @@ package org.springframework.security.rsocket.authentication; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; + import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; @@ -29,15 +32,15 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.rsocket.api.PayloadExchange; -import java.util.List; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class AnonymousPayloadInterceptorTests { + @Mock private PayloadExchange exchange; @@ -51,57 +54,46 @@ public class AnonymousPayloadInterceptorTests { @Test public void constructorKeyWhenKeyNullThenException() { String key = null; - assertThatCode(() -> new AnonymousPayloadInterceptor(key)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new AnonymousPayloadInterceptor(key)); } @Test public void constructorKeyPrincipalAuthoritiesWhenKeyNullThenException() { String key = null; - assertThatCode(() -> new AnonymousPayloadInterceptor(key, "principal", - AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new AnonymousPayloadInterceptor(key, "principal", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); } @Test public void constructorKeyPrincipalAuthoritiesWhenPrincipalNullThenException() { Object principal = null; - assertThatCode(() -> new AnonymousPayloadInterceptor("key", principal, - AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new AnonymousPayloadInterceptor("key", principal, + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); } @Test public void constructorKeyPrincipalAuthoritiesWhenAuthoritiesNullThenException() { List authorities = null; - assertThatCode(() -> new AnonymousPayloadInterceptor("key", "principal", - authorities)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> new AnonymousPayloadInterceptor("key", "principal", authorities)); } @Test public void interceptWhenNoAuthenticationThenAnonymousAuthentication() { AuthenticationPayloadInterceptorChain chain = new AuthenticationPayloadInterceptorChain(); - this.interceptor.intercept(this.exchange, chain).block(); - Authentication authentication = chain.getAuthentication(); - assertThat(authentication).isInstanceOf(AnonymousAuthenticationToken.class); } @Test public void interceptWhenAuthenticationThenOriginalAuthentication() { AuthenticationPayloadInterceptorChain chain = new AuthenticationPayloadInterceptorChain(); - TestingAuthenticationToken expected = - new TestingAuthenticationToken("test", "password"); - + TestingAuthenticationToken expected = new TestingAuthenticationToken("test", "password"); this.interceptor.intercept(this.exchange, chain) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expected)) - .block(); - + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expected)).block(); Authentication authentication = chain.getAuthentication(); - assertThat(authentication).isEqualTo(expected); } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java index 30693d6b45..e1effd435f 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorChain.java @@ -16,30 +16,33 @@ package org.springframework.security.rsocket.authentication; +import reactor.core.publisher.Mono; + import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; -import reactor.core.publisher.Mono; -import org.springframework.security.rsocket.api.PayloadInterceptorChain; import org.springframework.security.rsocket.api.PayloadExchange; +import org.springframework.security.rsocket.api.PayloadInterceptorChain; /** * @author Rob Winch */ class AuthenticationPayloadInterceptorChain implements PayloadInterceptorChain { + private Authentication authentication; @Override public Mono next(PayloadExchange exchange) { return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) - .doOnNext(a -> this.setAuthentication(a)).then(); + .doOnNext((a) -> this.setAuthentication(a)).then(); } - public Authentication getAuthentication() { + Authentication getAuthentication() { return this.authentication; } - public void setAuthentication(Authentication authentication) { + void setAuthentication(Authentication authentication) { this.authentication = authentication; } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java index 7da8b63bcf..474d6afe49 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authentication/AuthenticationPayloadInterceptorTests.java @@ -16,6 +16,8 @@ package org.springframework.security.rsocket.authentication; +import java.util.Map; + import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.rsocket.Payload; @@ -28,6 +30,10 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; + import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; @@ -37,33 +43,31 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadExchangeType; +import org.springframework.security.rsocket.api.PayloadInterceptorChain; +import org.springframework.security.rsocket.core.DefaultPayloadExchange; import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder; import org.springframework.security.rsocket.metadata.UsernamePasswordMetadata; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import reactor.test.publisher.PublisherProbe; -import org.springframework.security.rsocket.core.DefaultPayloadExchange; -import org.springframework.security.rsocket.api.PayloadInterceptorChain; -import org.springframework.security.rsocket.api.PayloadExchange; -import java.util.Map; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class AuthenticationPayloadInterceptorTests { - static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType( - WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + static final MimeType COMPOSITE_METADATA = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + @Mock ReactiveAuthenticationManager authenticationManager; @@ -72,76 +76,55 @@ public class AuthenticationPayloadInterceptorTests { @Test public void constructorWhenAuthenticationManagerNullThenException() { - assertThatCode(() -> new AuthenticationPayloadInterceptor(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new AuthenticationPayloadInterceptor(null)); } @Test public void interceptWhenBasicCredentialsThenAuthenticates() { - AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor( - this.authenticationManager); + AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor(this.authenticationManager); PayloadExchange exchange = createExchange(); - TestingAuthenticationToken expectedAuthentication = - new TestingAuthenticationToken("user", "password"); - when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just( - expectedAuthentication)); - + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password"); + given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(expectedAuthentication)); AuthenticationPayloadInterceptorChain authenticationPayloadChain = new AuthenticationPayloadInterceptorChain(); - interceptor.intercept(exchange, authenticationPayloadChain) - .block(); - + interceptor.intercept(exchange, authenticationPayloadChain).block(); Authentication authentication = authenticationPayloadChain.getAuthentication(); - verify(this.authenticationManager).authenticate(this.authenticationArg.capture()); - assertThat(this.authenticationArg.getValue()).isEqualToComparingFieldByField(new UsernamePasswordAuthenticationToken("user", "password")); + assertThat(this.authenticationArg.getValue()) + .isEqualToComparingFieldByField(new UsernamePasswordAuthenticationToken("user", "password")); assertThat(authentication).isEqualTo(expectedAuthentication); } @Test public void interceptWhenAuthenticationSuccessThenChainSubscribedOnce() { - AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor( - this.authenticationManager); - + AuthenticationPayloadInterceptor interceptor = new AuthenticationPayloadInterceptor(this.authenticationManager); PayloadExchange exchange = createExchange(); - TestingAuthenticationToken expectedAuthentication = - new TestingAuthenticationToken("user", "password"); - when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just( - expectedAuthentication)); - + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password"); + given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(expectedAuthentication)); PublisherProbe voidResult = PublisherProbe.empty(); PayloadInterceptorChain chain = mock(PayloadInterceptorChain.class); - when(chain.next(any())).thenReturn(voidResult.mono()); - - + given(chain.next(any())).willReturn(voidResult.mono()); StepVerifier.create(interceptor.intercept(exchange, chain)) - .then(() -> assertThat(voidResult.subscribeCount()).isEqualTo(1)) - .verifyComplete(); + .then(() -> assertThat(voidResult.subscribeCount()).isEqualTo(1)).verifyComplete(); } private Payload createRequestPayload() { - UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); BasicAuthenticationEncoder encoder = new BasicAuthenticationEncoder(); DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); - ResolvableType elementType = ResolvableType - .forClass(UsernamePasswordMetadata.class); + ResolvableType elementType = ResolvableType.forClass(UsernamePasswordMetadata.class); MimeType mimeType = UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE; Map hints = null; - DataBuffer dataBuffer = encoder.encodeValue(credentials, factory, - elementType, mimeType, hints); - + DataBuffer dataBuffer = encoder.encodeValue(credentials, factory, elementType, mimeType, hints); ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; CompositeByteBuf metadata = allocator.compositeBuffer(); - CompositeMetadataCodec.encodeAndAddMetadata( - metadata, allocator, mimeType.toString(), NettyDataBufferFactory.toByteBuf(dataBuffer)); - - return DefaultPayload.create(allocator.buffer(), - metadata); + CompositeMetadataCodec.encodeAndAddMetadata(metadata, allocator, mimeType.toString(), + NettyDataBufferFactory.toByteBuf(dataBuffer)); + return DefaultPayload.create(allocator.buffer(), metadata); } private PayloadExchange createExchange() { - return new DefaultPayloadExchange(PayloadExchangeType.REQUEST_RESPONSE, createRequestPayload(), COMPOSITE_METADATA, - MediaType.APPLICATION_JSON); + return new DefaultPayloadExchange(PayloadExchangeType.REQUEST_RESPONSE, createRequestPayload(), + COMPOSITE_METADATA, MediaType.APPLICATION_JSON); } } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java index 5e5178a46d..a4bbd99bd9 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/AuthorizationPayloadInterceptorTests.java @@ -20,28 +20,30 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; -import org.springframework.security.access.AccessDeniedException; -import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.authorization.ReactiveAuthorizationManager; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import reactor.test.publisher.PublisherProbe; import reactor.util.context.Context; -import org.springframework.security.rsocket.api.PayloadInterceptorChain; -import org.springframework.security.rsocket.api.PayloadExchange; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.when; -import static org.springframework.security.authorization.AuthenticatedReactiveAuthorizationManager.authenticated; -import static org.springframework.security.authorization.AuthorityReactiveAuthorizationManager.hasRole; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.authorization.AuthenticatedReactiveAuthorizationManager; +import org.springframework.security.authorization.AuthorityReactiveAuthorizationManager; +import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.rsocket.api.PayloadExchange; +import org.springframework.security.rsocket.api.PayloadInterceptorChain; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class AuthorizationPayloadInterceptorTests { + @Mock private ReactiveAuthorizationManager authorizationManager; @@ -57,61 +59,44 @@ public class AuthorizationPayloadInterceptorTests { @Test public void interceptWhenAuthenticationEmptyAndSubscribedThenException() { - when(this.chain.next(any())).thenReturn(this.chainResult.mono()); - - AuthorizationPayloadInterceptor interceptor = - new AuthorizationPayloadInterceptor(authenticated()); - + given(this.chain.next(any())).willReturn(this.chainResult.mono()); + AuthorizationPayloadInterceptor interceptor = new AuthorizationPayloadInterceptor( + AuthenticatedReactiveAuthorizationManager.authenticated()); StepVerifier.create(interceptor.intercept(this.exchange, this.chain)) - .then(() -> this.chainResult.assertWasNotSubscribed()) - .verifyError(AuthenticationCredentialsNotFoundException.class); + .then(() -> this.chainResult.assertWasNotSubscribed()) + .verifyError(AuthenticationCredentialsNotFoundException.class); } @Test public void interceptWhenAuthenticationNotSubscribedAndEmptyThenCompletes() { - when(this.chain.next(any())).thenReturn(this.chainResult.mono()); - when(this.authorizationManager.verify(any(), any())) - .thenReturn(this.managerResult.mono()); - - AuthorizationPayloadInterceptor interceptor = - new AuthorizationPayloadInterceptor(this.authorizationManager); - + given(this.chain.next(any())).willReturn(this.chainResult.mono()); + given(this.authorizationManager.verify(any(), any())).willReturn(this.managerResult.mono()); + AuthorizationPayloadInterceptor interceptor = new AuthorizationPayloadInterceptor(this.authorizationManager); StepVerifier.create(interceptor.intercept(this.exchange, this.chain)) - .then(() -> this.chainResult.assertWasSubscribed()) - .verifyComplete(); + .then(() -> this.chainResult.assertWasSubscribed()).verifyComplete(); } @Test public void interceptWhenNotAuthorizedThenException() { - when(this.chain.next(any())).thenReturn(this.chainResult.mono()); - - AuthorizationPayloadInterceptor interceptor = - new AuthorizationPayloadInterceptor(hasRole("USER")); + given(this.chain.next(any())).willReturn(this.chainResult.mono()); + AuthorizationPayloadInterceptor interceptor = new AuthorizationPayloadInterceptor( + AuthorityReactiveAuthorizationManager.hasRole("USER")); Context userContext = ReactiveSecurityContextHolder .withAuthentication(new TestingAuthenticationToken("user", "password")); - - Mono intercept = interceptor.intercept(this.exchange, this.chain) - .subscriberContext(userContext); - - StepVerifier.create(intercept) - .then(() -> this.chainResult.assertWasNotSubscribed()) + Mono intercept = interceptor.intercept(this.exchange, this.chain).subscriberContext(userContext); + StepVerifier.create(intercept).then(() -> this.chainResult.assertWasNotSubscribed()) .verifyError(AccessDeniedException.class); } @Test public void interceptWhenAuthorizedThenContinues() { - when(this.chain.next(any())).thenReturn(this.chainResult.mono()); - - AuthorizationPayloadInterceptor interceptor = - new AuthorizationPayloadInterceptor(authenticated()); + given(this.chain.next(any())).willReturn(this.chainResult.mono()); + AuthorizationPayloadInterceptor interceptor = new AuthorizationPayloadInterceptor( + AuthenticatedReactiveAuthorizationManager.authenticated()); Context userContext = ReactiveSecurityContextHolder .withAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER")); - - Mono intercept = interceptor.intercept(this.exchange, this.chain) - .subscriberContext(userContext); - - StepVerifier.create(intercept) - .then(() -> this.chainResult.assertWasSubscribed()) - .verifyComplete(); + Mono intercept = interceptor.intercept(this.exchange, this.chain).subscriberContext(userContext); + StepVerifier.create(intercept).then(() -> this.chainResult.assertWasSubscribed()).verifyComplete(); } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTests.java similarity index 61% rename from rsocket/src/test/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java rename to rsocket/src/test/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTests.java index 03a614b792..395c528108 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTest.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/authorization/PayloadExchangeMatcherReactiveAuthorizationManagerTests.java @@ -20,6 +20,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.rsocket.api.PayloadExchange; @@ -27,17 +29,16 @@ import org.springframework.security.rsocket.util.matcher.PayloadExchangeAuthoriz import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcher; import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatcherEntry; import org.springframework.security.rsocket.util.matcher.PayloadExchangeMatchers; -import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) -public class PayloadExchangeMatcherReactiveAuthorizationManagerTest { +public class PayloadExchangeMatcherReactiveAuthorizationManagerTests { @Mock private ReactiveAuthorizationManager authz; @@ -51,58 +52,45 @@ public class PayloadExchangeMatcherReactiveAuthorizationManagerTest { @Test public void checkWhenGrantedThenGranted() { AuthorizationDecision expected = new AuthorizationDecision(true); - when(this.authz.check(any(), any())).thenReturn(Mono.just( - expected)); - PayloadExchangeMatcherReactiveAuthorizationManager manager = - PayloadExchangeMatcherReactiveAuthorizationManager.builder() - .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) - .build(); - - assertThat(manager.check(Mono.empty(), this.exchange).block()) - .isEqualTo(expected); + given(this.authz.check(any(), any())).willReturn(Mono.just(expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = PayloadExchangeMatcherReactiveAuthorizationManager + .builder().add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) + .build(); + assertThat(manager.check(Mono.empty(), this.exchange).block()).isEqualTo(expected); } @Test public void checkWhenDeniedThenDenied() { AuthorizationDecision expected = new AuthorizationDecision(false); - when(this.authz.check(any(), any())).thenReturn(Mono.just( - expected)); - PayloadExchangeMatcherReactiveAuthorizationManager manager = - PayloadExchangeMatcherReactiveAuthorizationManager.builder() - .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) - .build(); - - assertThat(manager.check(Mono.empty(), this.exchange).block()) - .isEqualTo(expected); + given(this.authz.check(any(), any())).willReturn(Mono.just(expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = PayloadExchangeMatcherReactiveAuthorizationManager + .builder().add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) + .build(); + assertThat(manager.check(Mono.empty(), this.exchange).block()).isEqualTo(expected); } @Test public void checkWhenFirstMatchThenSecondUsed() { AuthorizationDecision expected = new AuthorizationDecision(true); - when(this.authz.check(any(), any())).thenReturn(Mono.just( - expected)); - PayloadExchangeMatcherReactiveAuthorizationManager manager = - PayloadExchangeMatcherReactiveAuthorizationManager.builder() - .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) - .add(new PayloadExchangeMatcherEntry<>(e -> PayloadExchangeMatcher.MatchResult.notMatch(), this.authz2)) - .build(); - - assertThat(manager.check(Mono.empty(), this.exchange).block()) - .isEqualTo(expected); + given(this.authz.check(any(), any())).willReturn(Mono.just(expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = PayloadExchangeMatcherReactiveAuthorizationManager + .builder().add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz)) + .add(new PayloadExchangeMatcherEntry<>((e) -> PayloadExchangeMatcher.MatchResult.notMatch(), + this.authz2)) + .build(); + assertThat(manager.check(Mono.empty(), this.exchange).block()).isEqualTo(expected); } @Test public void checkWhenSecondMatchThenSecondUsed() { AuthorizationDecision expected = new AuthorizationDecision(true); - when(this.authz2.check(any(), any())).thenReturn(Mono.just( - expected)); - PayloadExchangeMatcherReactiveAuthorizationManager manager = - PayloadExchangeMatcherReactiveAuthorizationManager.builder() - .add(new PayloadExchangeMatcherEntry<>(e -> PayloadExchangeMatcher.MatchResult.notMatch(), this.authz)) - .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz2)) - .build(); - - assertThat(manager.check(Mono.empty(), this.exchange).block()) - .isEqualTo(expected); + given(this.authz2.check(any(), any())).willReturn(Mono.just(expected)); + PayloadExchangeMatcherReactiveAuthorizationManager manager = PayloadExchangeMatcherReactiveAuthorizationManager + .builder() + .add(new PayloadExchangeMatcherEntry<>((e) -> PayloadExchangeMatcher.MatchResult.notMatch(), + this.authz)) + .add(new PayloadExchangeMatcherEntry<>(PayloadExchangeMatchers.anyExchange(), this.authz2)).build(); + assertThat(manager.check(Mono.empty(), this.exchange).block()).isEqualTo(expected); } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java index f434ed4c70..b8e54b7aa0 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java @@ -25,10 +25,13 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; /** - * A {@link SocketAcceptor} that captures the {@link SecurityContext} and then continues with the {@link RSocket} + * A {@link SocketAcceptor} that captures the {@link SecurityContext} and then continues + * with the {@link RSocket} + * * @author Rob Winch */ class CaptureSecurityContextSocketAcceptor implements SocketAcceptor { + private final RSocket accept; private SecurityContext securityContext; @@ -40,11 +43,11 @@ class CaptureSecurityContextSocketAcceptor implements SocketAcceptor { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { return ReactiveSecurityContextHolder.getContext() - .doOnNext(securityContext -> this.securityContext = securityContext) - .thenReturn(this.accept); + .doOnNext((securityContext) -> this.securityContext = securityContext).thenReturn(this.accept); } - public SecurityContext getSecurityContext() { + SecurityContext getSecurityContext() { return this.securityContext; } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java index a925ac676e..3a153b8c61 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java @@ -16,6 +16,10 @@ package org.springframework.security.rsocket.core; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.metadata.WellKnownMimeType; @@ -28,6 +32,12 @@ import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; +import reactor.test.publisher.TestPublisher; + import org.springframework.http.MediaType; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; @@ -37,26 +47,17 @@ import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadExchangeType; import org.springframework.security.rsocket.api.PayloadInterceptor; import org.springframework.security.rsocket.api.PayloadInterceptorChain; -import org.springframework.security.rsocket.core.DefaultPayloadExchange; -import org.springframework.security.rsocket.core.PayloadInterceptorRSocket; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import reactor.test.publisher.PublisherProbe; -import reactor.test.publisher.TestPublisher; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -64,8 +65,8 @@ import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class PayloadInterceptorRSocketTests { - static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType( - WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + static final MimeType COMPOSITE_METADATA = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); @Mock RSocket delegate; @@ -94,43 +95,33 @@ public class PayloadInterceptorRSocketTests { public void constructorWhenNullDelegateThenException() { this.delegate = null; List interceptors = Arrays.asList(this.interceptor); - assertThatCode(() -> { - new PayloadInterceptorRSocket(this.delegate, interceptors, - metadataMimeType, dataMimeType); - }) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, + this.metadataMimeType, this.dataMimeType)); } @Test public void constructorWhenNullInterceptorsThenException() { List interceptors = null; - assertThatCode(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, - metadataMimeType, dataMimeType)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, + this.metadataMimeType, this.dataMimeType)); } @Test public void constructorWhenEmptyInterceptorsThenException() { List interceptors = Collections.emptyList(); - assertThatCode(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, - metadataMimeType, dataMimeType)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new PayloadInterceptorRSocket(this.delegate, interceptors, + this.metadataMimeType, this.dataMimeType)); } // single interceptor - @Test public void fireAndForgetWhenInterceptorCompletesThenDelegateSubscribed() { - when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); - when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withChainNext()); + given(this.delegate.fireAndForget(any())).willReturn(this.voidResult.mono()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.fireAndForget(this.payload)) - .then(() -> this.voidResult.assertWasSubscribed()) - .verifyComplete(); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.fireAndForget(this.payload)).then(() -> this.voidResult.assertWasSubscribed()) + .verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -138,15 +129,12 @@ public class PayloadInterceptorRSocketTests { @Test public void fireAndForgetWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); StepVerifier.create(interceptor.fireAndForget(this.payload)) - .then(() -> this.voidResult.assertWasNotSubscribed()) - .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); - + .then(() -> this.voidResult.assertWasNotSubscribed()) + .verifyErrorSatisfies((e) -> assertThat(e).isEqualTo(expected)); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -154,21 +142,17 @@ public class PayloadInterceptorRSocketTests { @Test public void fireAndForgetWhenSecurityContextThenDelegateContext() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); - when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); - when(this.delegate.fireAndForget(any())).thenReturn(Mono.empty()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withAuthenticated(authentication)); + given(this.delegate.fireAndForget(any())).willReturn(Mono.empty()); RSocket assertAuthentication = new RSocketProxy(this.delegate) { @Override public Mono fireAndForget(Payload payload) { - return assertAuthentication(authentication) - .flatMap(a -> super.fireAndForget(payload)); + return assertAuthentication(authentication).flatMap((a) -> super.fireAndForget(payload)); } }; PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); interceptor.fireAndForget(this.payload).block(); - verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).fireAndForget(this.payload); @@ -176,18 +160,13 @@ public class PayloadInterceptorRSocketTests { @Test public void requestResponseWhenInterceptorCompletesThenDelegateSubscribed() { - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); - when(this.delegate.requestResponse(any())).thenReturn(this.payloadResult.mono()); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()); + given(this.delegate.requestResponse(any())).willReturn(this.payloadResult.mono()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); StepVerifier.create(interceptor.requestResponse(this.payload)) - .then(() -> this.payloadResult.assertSubscribers()) - .then(() -> this.payloadResult.emit(this.payload)) - .expectNext(this.payload) - .verifyComplete(); - + .then(() -> this.payloadResult.assertSubscribers()).then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).requestResponse(this.payload); @@ -196,13 +175,11 @@ public class PayloadInterceptorRSocketTests { @Test public void requestResponseWhenInterceptorErrorsThenDelegateNotInvoked() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - assertThatCode(() -> interceptor.requestResponse(this.payload).block()).isEqualTo(expected); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> interceptor.requestResponse(this.payload).block()).isEqualTo(expected); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verifyZeroInteractions(this.delegate); @@ -211,25 +188,19 @@ public class PayloadInterceptorRSocketTests { @Test public void requestResponseWhenSecurityContextThenDelegateContext() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); - when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); - when(this.delegate.requestResponse(any())).thenReturn(this.payloadResult.mono()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withAuthenticated(authentication)); + given(this.delegate.requestResponse(any())).willReturn(this.payloadResult.mono()); RSocket assertAuthentication = new RSocketProxy(this.delegate) { @Override public Mono requestResponse(Payload payload) { - return assertAuthentication(authentication) - .flatMap(a -> super.requestResponse(payload)); + return assertAuthentication(authentication).flatMap((a) -> super.requestResponse(payload)); } }; PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); StepVerifier.create(interceptor.requestResponse(this.payload)) - .then(() -> this.payloadResult.assertSubscribers()) - .then(() -> this.payloadResult.emit(this.payload)) - .expectNext(this.payload) - .verifyComplete(); - + .then(() -> this.payloadResult.assertSubscribers()).then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).requestResponse(this.payload); @@ -237,18 +208,12 @@ public class PayloadInterceptorRSocketTests { @Test public void requestStreamWhenInterceptorCompletesThenDelegateSubscribed() { - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); - when(this.delegate.requestStream(any())).thenReturn(this.payloadResult.flux()); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()); + given(this.delegate.requestStream(any())).willReturn(this.payloadResult.flux()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.requestStream(this.payload)) - .then(() -> this.payloadResult.assertSubscribers()) - .then(() -> this.payloadResult.emit(this.payload)) - .expectNext(this.payload) - .verifyComplete(); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.requestStream(this.payload)).then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)).expectNext(this.payload).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -256,15 +221,12 @@ public class PayloadInterceptorRSocketTests { @Test public void requestStreamWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); StepVerifier.create(interceptor.requestStream(this.payload)) - .then(() -> this.payloadResult.assertNoSubscribers()) - .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); - + .then(() -> this.payloadResult.assertNoSubscribers()) + .verifyErrorSatisfies((e) -> assertThat(e).isEqualTo(expected)); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -272,25 +234,18 @@ public class PayloadInterceptorRSocketTests { @Test public void requestStreamWhenSecurityContextThenDelegateContext() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); - when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); - when(this.delegate.requestStream(any())).thenReturn(this.payloadResult.flux()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withAuthenticated(authentication)); + given(this.delegate.requestStream(any())).willReturn(this.payloadResult.flux()); RSocket assertAuthentication = new RSocketProxy(this.delegate) { @Override public Flux requestStream(Payload payload) { - return assertAuthentication(authentication) - .flatMapMany(a -> super.requestStream(payload)); + return assertAuthentication(authentication).flatMapMany((a) -> super.requestStream(payload)); } }; PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.requestStream(this.payload)) - .then(() -> this.payloadResult.assertSubscribers()) - .then(() -> this.payloadResult.emit(this.payload)) - .expectNext(this.payload) - .verifyComplete(); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.requestStream(this.payload)).then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)).expectNext(this.payload).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).requestStream(this.payload); @@ -298,18 +253,13 @@ public class PayloadInterceptorRSocketTests { @Test public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() { - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); - when(this.delegate.requestChannel(any())).thenReturn(this.payloadResult.flux()); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()); + given(this.delegate.requestChannel(any())).willReturn(this.payloadResult.flux()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); StepVerifier.create(interceptor.requestChannel(Flux.just(this.payload))) - .then(() -> this.payloadResult.assertSubscribers()) - .then(() -> this.payloadResult.emit(this.payload)) - .expectNext(this.payload) - .verifyComplete(); - + .then(() -> this.payloadResult.assertSubscribers()).then(() -> this.payloadResult.emit(this.payload)) + .expectNext(this.payload).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).requestChannel(any()); @@ -318,15 +268,12 @@ public class PayloadInterceptorRSocketTests { @Test public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); - StepVerifier.create(interceptor.requestChannel(Flux.just(this.payload))) - .then(() -> this.payloadResult.assertNoSubscribers()) - .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); - + .then(() -> this.payloadResult.assertNoSubscribers()) + .verifyErrorSatisfies((e) -> assertThat(e).isEqualTo(expected)); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -335,25 +282,18 @@ public class PayloadInterceptorRSocketTests { public void requestChannelWhenSecurityContextThenDelegateContext() { Mono payload = Mono.just(this.payload); TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); - when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); - when(this.delegate.requestChannel(any())).thenReturn(this.payloadResult.flux()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withAuthenticated(authentication)); + given(this.delegate.requestChannel(any())).willReturn(this.payloadResult.flux()); RSocket assertAuthentication = new RSocketProxy(this.delegate) { @Override public Flux requestChannel(Publisher payload) { - return assertAuthentication(authentication) - .flatMapMany(a -> super.requestChannel(payload)); + return assertAuthentication(authentication).flatMapMany((a) -> super.requestChannel(payload)); } }; PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.requestChannel(payload)) - .then(() -> this.payloadResult.assertSubscribers()) - .then(() -> this.payloadResult.emit(this.payload)) - .expectNext(this.payload) - .verifyComplete(); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.requestChannel(payload)).then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(this.payload)).expectNext(this.payload).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).requestChannel(any()); @@ -361,16 +301,12 @@ public class PayloadInterceptorRSocketTests { @Test public void metadataPushWhenInterceptorCompletesThenDelegateSubscribed() { - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); - when(this.delegate.metadataPush(any())).thenReturn(this.voidResult.mono()); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()); + given(this.delegate.metadataPush(any())).willReturn(this.voidResult.mono()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.metadataPush(this.payload)) - .then(() -> this.voidResult.assertWasSubscribed()) - .verifyComplete(); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.metadataPush(this.payload)).then(() -> this.voidResult.assertWasSubscribed()) + .verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -378,15 +314,11 @@ public class PayloadInterceptorRSocketTests { @Test public void metadataPushWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.metadataPush(this.payload)) - .then(() -> this.voidResult.assertWasNotSubscribed()) - .verifyErrorSatisfies(e -> assertThat(e).isEqualTo(expected)); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.metadataPush(this.payload)).then(() -> this.voidResult.assertWasNotSubscribed()) + .verifyErrorSatisfies((e) -> assertThat(e).isEqualTo(expected)); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); } @@ -394,22 +326,17 @@ public class PayloadInterceptorRSocketTests { @Test public void metadataPushWhenSecurityContextThenDelegateContext() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password"); - when(this.interceptor.intercept(any(), any())).thenAnswer(withAuthenticated(authentication)); - when(this.delegate.metadataPush(any())).thenReturn(this.voidResult.mono()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withAuthenticated(authentication)); + given(this.delegate.metadataPush(any())).willReturn(this.voidResult.mono()); RSocket assertAuthentication = new RSocketProxy(this.delegate) { @Override public Mono metadataPush(Payload payload) { - return assertAuthentication(authentication) - .flatMap(a -> super.metadataPush(payload)); + return assertAuthentication(authentication).flatMap((a) -> super.metadataPush(payload)); } }; PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(assertAuthentication, - Arrays.asList(this.interceptor), metadataMimeType, dataMimeType); - - StepVerifier.create(interceptor.metadataPush(this.payload)) - .verifyComplete(); - + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType); + StepVerifier.create(interceptor.metadataPush(this.payload)).verifyComplete(); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.delegate).metadataPush(this.payload); @@ -417,37 +344,27 @@ public class PayloadInterceptorRSocketTests { } // multiple interceptors - @Test public void fireAndForgetWhenInterceptorsCompleteThenDelegateInvoked() { - when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); - when(this.interceptor2.intercept(any(), any())).thenAnswer(withChainNext()); - when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withChainNext()); + given(this.interceptor2.intercept(any(), any())).willAnswer(withChainNext()); + given(this.delegate.fireAndForget(any())).willReturn(this.voidResult.mono()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, - dataMimeType); - + Arrays.asList(this.interceptor, this.interceptor2), this.metadataMimeType, this.dataMimeType); interceptor.fireAndForget(this.payload).block(); - verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); this.voidResult.assertWasSubscribed(); } - @Test public void fireAndForgetWhenInterceptorsMutatesPayloadThenDelegateInvoked() { - when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); - when(this.interceptor2.intercept(any(), any())).thenAnswer(withChainNext()); - when(this.delegate.fireAndForget(any())).thenReturn(this.voidResult.mono()); - + given(this.interceptor.intercept(any(), any())).willAnswer(withChainNext()); + given(this.interceptor2.intercept(any(), any())).willAnswer(withChainNext()); + given(this.delegate.fireAndForget(any())).willReturn(this.voidResult.mono()); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, - dataMimeType); - + Arrays.asList(this.interceptor, this.interceptor2), this.metadataMimeType, this.dataMimeType); interceptor.fireAndForget(this.payload).block(); - verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.interceptor2).intercept(any(), any()); @@ -458,14 +375,11 @@ public class PayloadInterceptorRSocketTests { @Test public void fireAndForgetWhenInterceptor1ErrorsThenInterceptor2AndDelegateNotInvoked() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, - dataMimeType); - - assertThatCode(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected); - + Arrays.asList(this.interceptor, this.interceptor2), this.metadataMimeType, this.dataMimeType); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verifyZeroInteractions(this.interceptor2); @@ -475,15 +389,12 @@ public class PayloadInterceptorRSocketTests { @Test public void fireAndForgetWhenInterceptor2ErrorsThenInterceptor2AndDelegateNotInvoked() { RuntimeException expected = new RuntimeException("Oops"); - when(this.interceptor.intercept(any(), any())).thenAnswer(withChainNext()); - when(this.interceptor2.intercept(any(), any())).thenReturn(Mono.error(expected)); - + given(this.interceptor.intercept(any(), any())).willAnswer(withChainNext()); + given(this.interceptor2.intercept(any(), any())).willReturn(Mono.error(expected)); PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, - Arrays.asList(this.interceptor, this.interceptor2), metadataMimeType, - dataMimeType); - - assertThatCode(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected); - + Arrays.asList(this.interceptor, this.interceptor2), this.metadataMimeType, this.dataMimeType); + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> interceptor.fireAndForget(this.payload).block()).isEqualTo(expected); verify(this.interceptor).intercept(this.exchange.capture(), any()); assertThat(this.exchange.getValue().getPayload()).isEqualTo(this.payload); verify(this.interceptor2).intercept(any(), any()); @@ -491,25 +402,25 @@ public class PayloadInterceptorRSocketTests { } private Mono assertAuthentication(Authentication authentication) { - return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .doOnNext(a -> assertThat(a).isEqualTo(authentication)); + return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) + .doOnNext((a) -> assertThat(a).isEqualTo(authentication)); } private Answer withAuthenticated(Authentication authentication) { - return invocation -> { + return (invocation) -> { PayloadInterceptorChain c = (PayloadInterceptorChain) invocation.getArguments()[1]; - return c.next(new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, this.metadataMimeType, - this.dataMimeType)) + return c.next(new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, + this.metadataMimeType, this.dataMimeType)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)); }; } private static Answer> withChainNext() { - return invocation -> { + return (invocation) -> { PayloadExchange exchange = (PayloadExchange) invocation.getArguments()[0]; PayloadInterceptorChain chain = (PayloadInterceptorChain) invocation.getArguments()[1]; return chain.next(exchange); }; } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptorTests.java index d06ac483d0..1727033cee 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptorTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorInterceptorTests.java @@ -16,6 +16,9 @@ package org.springframework.security.rsocket.core; +import java.util.Arrays; +import java.util.List; + import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -27,27 +30,24 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.http.MediaType; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadInterceptor; -import org.springframework.security.rsocket.core.PayloadInterceptorRSocket; -import org.springframework.security.rsocket.core.PayloadSocketAcceptorInterceptor; -import reactor.core.publisher.Mono; - -import java.util.Arrays; -import java.util.List; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Matchers.any; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class PayloadSocketAcceptorInterceptorTests { + @Mock private PayloadInterceptor interceptor; @@ -75,21 +75,18 @@ public class PayloadSocketAcceptorInterceptorTests { @Test public void applyWhenDefaultMetadataMimeTypeThenDefaulted() { - when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); - + given(this.setupPayload.dataMimeType()).willReturn(MediaType.APPLICATION_JSON_VALUE); PayloadExchange exchange = captureExchange(); - - assertThat(exchange.getMetadataMimeType().toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getMetadataMimeType().toString()) + .isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } @Test public void acceptWhenDefaultMetadataMimeTypeOverrideThenDefaulted() { this.acceptorInterceptor.setDefaultMetadataMimeType(MediaType.APPLICATION_JSON); - when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); - + given(this.setupPayload.dataMimeType()).willReturn(MediaType.APPLICATION_JSON_VALUE); PayloadExchange exchange = captureExchange(); - assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } @@ -97,29 +94,23 @@ public class PayloadSocketAcceptorInterceptorTests { @Test public void acceptWhenDefaultDataMimeTypeThenDefaulted() { this.acceptorInterceptor.setDefaultDataMimeType(MediaType.APPLICATION_JSON); - PayloadExchange exchange = captureExchange(); - - assertThat(exchange.getMetadataMimeType().toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getMetadataMimeType().toString()) + .isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } private PayloadExchange captureExchange() { - when(this.socketAcceptor.accept(any(), any())).thenReturn(Mono.just(this.rSocket)); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); - + given(this.socketAcceptor.accept(any(), any())).willReturn(Mono.just(this.rSocket)); + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()); SocketAcceptor wrappedAcceptor = this.acceptorInterceptor.apply(this.socketAcceptor); RSocket result = wrappedAcceptor.accept(this.setupPayload, this.rSocket).block(); - assertThat(result).isInstanceOf(PayloadInterceptorRSocket.class); - - when(this.rSocket.fireAndForget(any())).thenReturn(Mono.empty()); - + given(this.rSocket.fireAndForget(any())).willReturn(Mono.empty()); result.fireAndForget(this.payload).block(); - - ArgumentCaptor exchangeArg = - ArgumentCaptor.forClass(PayloadExchange.class); + ArgumentCaptor exchangeArg = ArgumentCaptor.forClass(PayloadExchange.class); verify(this.interceptor, times(2)).intercept(exchangeArg.capture(), any()); return exchangeArg.getValue(); } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java index 69b7e2356d..4c3722a1e9 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java @@ -43,11 +43,11 @@ import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadInterceptor; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.mockito.Matchers.any; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -83,33 +83,34 @@ public class PayloadSocketAcceptorTests { @Test public void constructorWhenNullDelegateThenException() { this.delegate = null; - assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); } @Test public void constructorWhenNullInterceptorsThenException() { this.interceptors = null; - assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); } @Test public void constructorWhenEmptyInterceptorsThenException() { this.interceptors = Collections.emptyList(); - assertThatCode(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); + assertThatIllegalArgumentException() + .isThrownBy(() -> new PayloadSocketAcceptor(this.delegate, this.interceptors)); } @Test public void acceptWhenDataMimeTypeNullThenException() { - assertThatCode(() -> this.acceptor.accept(this.setupPayload, this.rSocket) - .block()).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.acceptor.accept(this.setupPayload, this.rSocket).block()); } @Test public void acceptWhenDefaultMetadataMimeTypeThenDefaulted() { - when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); - + given(this.setupPayload.dataMimeType()).willReturn(MediaType.APPLICATION_JSON_VALUE); PayloadExchange exchange = captureExchange(); - assertThat(exchange.getMetadataMimeType().toString()) .isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); @@ -118,10 +119,8 @@ public class PayloadSocketAcceptorTests { @Test public void acceptWhenDefaultMetadataMimeTypeOverrideThenDefaulted() { this.acceptor.setDefaultMetadataMimeType(MediaType.APPLICATION_JSON); - when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); - + given(this.setupPayload.dataMimeType()).willReturn(MediaType.APPLICATION_JSON_VALUE); PayloadExchange exchange = captureExchange(); - assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } @@ -129,61 +128,51 @@ public class PayloadSocketAcceptorTests { @Test public void acceptWhenDefaultDataMimeTypeThenDefaulted() { this.acceptor.setDefaultDataMimeType(MediaType.APPLICATION_JSON); - PayloadExchange exchange = captureExchange(); - - assertThat(exchange.getMetadataMimeType() - .toString()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + assertThat(exchange.getMetadataMimeType().toString()) + .isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } @Test public void acceptWhenExplicitMimeTypeThenThenOverrideDefault() { - when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE); - when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); - + given(this.setupPayload.metadataMimeType()).willReturn(MediaType.TEXT_PLAIN_VALUE); + given(this.setupPayload.dataMimeType()).willReturn(MediaType.APPLICATION_JSON_VALUE); PayloadExchange exchange = captureExchange(); - assertThat(exchange.getMetadataMimeType()).isEqualTo(MediaType.TEXT_PLAIN); assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON); } - @Test // gh-8654 public void acceptWhenDelegateAcceptRequiresReactiveSecurityContext() { - when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE); - when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE); - SecurityContext expectedSecurityContext = new SecurityContextImpl(new TestingAuthenticationToken("user", "password", "ROLE_USER")); - CaptureSecurityContextSocketAcceptor captureSecurityContext = new CaptureSecurityContextSocketAcceptor(this.rSocket); + given(this.setupPayload.metadataMimeType()).willReturn(MediaType.TEXT_PLAIN_VALUE); + given(this.setupPayload.dataMimeType()).willReturn(MediaType.APPLICATION_JSON_VALUE); + SecurityContext expectedSecurityContext = new SecurityContextImpl( + new TestingAuthenticationToken("user", "password", "ROLE_USER")); + CaptureSecurityContextSocketAcceptor captureSecurityContext = new CaptureSecurityContextSocketAcceptor( + this.rSocket); PayloadInterceptor authenticateInterceptor = (exchange, chain) -> { - Context withSecurityContext = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedSecurityContext)); - return chain.next(exchange) - .subscriberContext(withSecurityContext); + Context withSecurityContext = ReactiveSecurityContextHolder + .withSecurityContext(Mono.just(expectedSecurityContext)); + return chain.next(exchange).subscriberContext(withSecurityContext); }; List interceptors = Arrays.asList(authenticateInterceptor); this.acceptor = new PayloadSocketAcceptor(captureSecurityContext, interceptors); - this.acceptor.accept(this.setupPayload, this.rSocket).block(); - assertThat(captureSecurityContext.getSecurityContext()).isEqualTo(expectedSecurityContext); } private PayloadExchange captureExchange() { - when(this.delegate.accept(any(), any())).thenReturn(Mono.just(this.rSocket)); - when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()); - + given(this.delegate.accept(any(), any())).willReturn(Mono.just(this.rSocket)); + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()); RSocket result = this.acceptor.accept(this.setupPayload, this.rSocket).block(); - assertThat(result).isInstanceOf(PayloadInterceptorRSocket.class); - - when(this.rSocket.fireAndForget(any())).thenReturn(Mono.empty()); - + given(this.rSocket.fireAndForget(any())).willReturn(Mono.empty()); result.fireAndForget(this.payload).block(); - - ArgumentCaptor exchangeArg = - ArgumentCaptor.forClass(PayloadExchange.class); + ArgumentCaptor exchangeArg = ArgumentCaptor.forClass(PayloadExchange.class); verify(this.interceptor, times(2)).intercept(exchangeArg.capture(), any()); return exchangeArg.getValue(); } + } diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java index 2654d2378c..df8aa2c1be 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/metadata/BasicAuthenticationDecoderTests.java @@ -16,14 +16,15 @@ package org.springframework.security.rsocket.metadata; +import java.util.Map; + import org.junit.Test; +import reactor.core.publisher.Mono; + import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.util.MimeType; -import reactor.core.publisher.Mono; - -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -31,24 +32,20 @@ import static org.assertj.core.api.Assertions.assertThat; * @author Rob Winch */ public class BasicAuthenticationDecoderTests { + @Test public void basicAuthenticationWhenEncodedThenDecodes() { BasicAuthenticationEncoder encoder = new BasicAuthenticationEncoder(); BasicAuthenticationDecoder decoder = new BasicAuthenticationDecoder(); - UsernamePasswordMetadata expectedCredentials = - new UsernamePasswordMetadata("rob", "password"); + UsernamePasswordMetadata expectedCredentials = new UsernamePasswordMetadata("rob", "password"); DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); - ResolvableType elementType = ResolvableType - .forClass(UsernamePasswordMetadata.class); + ResolvableType elementType = ResolvableType.forClass(UsernamePasswordMetadata.class); MimeType mimeType = UsernamePasswordMetadata.BASIC_AUTHENTICATION_MIME_TYPE; Map hints = null; - - DataBuffer dataBuffer = encoder.encodeValue(expectedCredentials, factory, - elementType, mimeType, hints); + DataBuffer dataBuffer = encoder.encodeValue(expectedCredentials, factory, elementType, mimeType, hints); UsernamePasswordMetadata actualCredentials = decoder .decodeToMono(Mono.just(dataBuffer), elementType, mimeType, hints).block(); - assertThat(actualCredentials).isEqualToComparingFieldByField(expectedCredentials); } -} \ No newline at end of file +} diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcherTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcherTests.java index 6d0b2ae73b..747e6e2c8a 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcherTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/util/matcher/RoutePayloadExchangeMatcherTests.java @@ -16,6 +16,9 @@ package org.springframework.security.rsocket.util.matcher; +import java.util.Collections; +import java.util.Map; + import io.rsocket.Payload; import io.rsocket.metadata.WellKnownMimeType; import org.junit.Before; @@ -23,29 +26,28 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; + import org.springframework.http.MediaType; import org.springframework.messaging.rsocket.MetadataExtractor; -import org.springframework.security.rsocket.core.DefaultPayloadExchange; import org.springframework.security.rsocket.api.PayloadExchange; import org.springframework.security.rsocket.api.PayloadExchangeType; +import org.springframework.security.rsocket.core.DefaultPayloadExchange; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import org.springframework.util.RouteMatcher; -import java.util.Collections; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; /** * @author Rob Winch */ @RunWith(MockitoJUnitRunner.class) public class RoutePayloadExchangeMatcherTests { - static final MimeType COMPOSITE_METADATA = MimeTypeUtils.parseMimeType( - WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + static final MimeType COMPOSITE_METADATA = MimeTypeUtils + .parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); @Mock private MetadataExtractor metadataExtractor; @@ -69,14 +71,13 @@ public class RoutePayloadExchangeMatcherTests { public void setup() { this.pattern = "a.b"; this.matcher = new RoutePayloadExchangeMatcher(this.metadataExtractor, this.routeMatcher, this.pattern); - this.exchange = new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, COMPOSITE_METADATA, - MediaType.APPLICATION_JSON); + this.exchange = new DefaultPayloadExchange(PayloadExchangeType.REQUEST_CHANNEL, this.payload, + COMPOSITE_METADATA, MediaType.APPLICATION_JSON); } @Test public void matchesWhenNoRouteThenNotMatch() { - when(this.metadataExtractor.extract(any(), any())) - .thenReturn(Collections.emptyMap()); + given(this.metadataExtractor.extract(any(), any())).willReturn(Collections.emptyMap()); PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); assertThat(result.isMatch()).isFalse(); } @@ -84,8 +85,8 @@ public class RoutePayloadExchangeMatcherTests { @Test public void matchesWhenNotMatchThenNotMatch() { String route = "route"; - when(this.metadataExtractor.extract(any(), any())) - .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); + given(this.metadataExtractor.extract(any(), any())) + .willReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); assertThat(result.isMatch()).isFalse(); } @@ -93,10 +94,10 @@ public class RoutePayloadExchangeMatcherTests { @Test public void matchesWhenMatchAndNoVariablesThenMatch() { String route = "route"; - when(this.metadataExtractor.extract(any(), any())) - .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); - when(this.routeMatcher.parseRoute(any())).thenReturn(this.route); - when(this.routeMatcher.matchAndExtract(any(), any())).thenReturn(Collections.emptyMap()); + given(this.metadataExtractor.extract(any(), any())) + .willReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); + given(this.routeMatcher.parseRoute(any())).willReturn(this.route); + given(this.routeMatcher.matchAndExtract(any(), any())).willReturn(Collections.emptyMap()); PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); assertThat(result.isMatch()).isTrue(); } @@ -105,12 +106,13 @@ public class RoutePayloadExchangeMatcherTests { public void matchesWhenMatchAndVariablesThenMatchAndVariables() { String route = "route"; Map variables = Collections.singletonMap("a", "b"); - when(this.metadataExtractor.extract(any(), any())) - .thenReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); - when(this.routeMatcher.parseRoute(any())).thenReturn(this.route); - when(this.routeMatcher.matchAndExtract(any(), any())).thenReturn(variables); + given(this.metadataExtractor.extract(any(), any())) + .willReturn(Collections.singletonMap(MetadataExtractor.ROUTE_KEY, route)); + given(this.routeMatcher.parseRoute(any())).willReturn(this.route); + given(this.routeMatcher.matchAndExtract(any(), any())).willReturn(variables); PayloadExchangeMatcher.MatchResult result = this.matcher.matches(this.exchange).block(); assertThat(result.isMatch()).isTrue(); assertThat(result.getVariables()).containsAllEntriesOf(variables); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java index a533a2753e..a923760be1 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/OpenSamlInitializationService.java @@ -20,6 +20,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; + import javax.xml.XMLConstants; import net.shibboleth.utilities.java.support.xml.BasicParserPool; @@ -28,29 +29,30 @@ import org.apache.commons.logging.LogFactory; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.config.InitializationService; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.springframework.security.saml2.Saml2Exception; -import static java.lang.Boolean.FALSE; -import static java.lang.Boolean.TRUE; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.setParserPool; - /** - * An initialization service for initializing OpenSAML. Each Spring Security OpenSAML-based component invokes - * the {@link #initialize()} method at static initialization time. + * An initialization service for initializing OpenSAML. Each Spring Security + * OpenSAML-based component invokes the {@link #initialize()} method at static + * initialization time. * - * {@link #initialize()} is idempotent and may be safely called in custom classes that need OpenSAML to be - * initialized in order to function correctly. It's recommended that you call this {@link #initialize()} method - * when using Spring Security and OpenSAML instead of OpenSAML's {@link InitializationService#initialize()}. + * {@link #initialize()} is idempotent and may be safely called in custom classes that + * need OpenSAML to be initialized in order to function correctly. It's recommended that + * you call this {@link #initialize()} method when using Spring Security and OpenSAML + * instead of OpenSAML's {@link InitializationService#initialize()}. * - * The primary purpose of {@link #initialize()} is to prepare OpenSAML's {@link XMLObjectProviderRegistry} - * with some reasonable defaults. Any changes that Spring Security makes to the registry happen in this method. + * The primary purpose of {@link #initialize()} is to prepare OpenSAML's + * {@link XMLObjectProviderRegistry} with some reasonable defaults. Any changes that + * Spring Security makes to the registry happen in this method. * - * To override those defaults, call {@link #requireInitialize(Consumer)} and change the registry: + * To override those defaults, call {@link #requireInitialize(Consumer)} and change the + * registry: * *
          * 	static {
        - *  	OpenSamlInitializationService.requireInitialize(registry -> {
        + *  	OpenSamlInitializationService.requireInitialize((registry) -> {
          *  	 	registry.setParserPool(...);
          *  		registry.getBuilderFactory().registerBuilder(...);
          *  	});
        @@ -59,45 +61,53 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.setP
          *
          * {@link #requireInitialize(Consumer)} may only be called once per application.
          *
        - * If the application already initialized OpenSAML before {@link #requireInitialize(Consumer)} was called,
        - * then the configuration changes will not be applied and an exception will be thrown. The reason for this is to
        - * alert you to the fact that there are likely some initialization ordering problems in your application that
        - * would otherwise lead to an unpredictable state.
        + * If the application already initialized OpenSAML before
        + * {@link #requireInitialize(Consumer)} was called, then the configuration changes will
        + * not be applied and an exception will be thrown. The reason for this is to alert you to
        + * the fact that there are likely some initialization ordering problems in your
        + * application that would otherwise lead to an unpredictable state.
          *
        - * If you must change the registry's configuration in multiple places in your application, you are expected
        - * to handle the initialization ordering issues yourself instead of trying to call {@link #requireInitialize(Consumer)}
        - * multiple times.
        + * If you must change the registry's configuration in multiple places in your application,
        + * you are expected to handle the initialization ordering issues yourself instead of
        + * trying to call {@link #requireInitialize(Consumer)} multiple times.
          *
          * @author Josh Cummings
          * @since 5.4
          */
        -public class OpenSamlInitializationService {
        +public final class OpenSamlInitializationService {
        +
         	private static final Log log = LogFactory.getLog(OpenSamlInitializationService.class);
        +
         	private static final AtomicBoolean initialized = new AtomicBoolean(false);
         
        +	private OpenSamlInitializationService() {
        +	}
        +
         	/**
         	 * Ready OpenSAML for use and configure it with reasonable defaults.
         	 *
        -	 * Initialization is guaranteed to happen only once per application. This method will passively return
        -	 * {@code false} if initialization already took place earlier in the application.
        -	 *
        -	 * @return whether or not initialization was performed. The first thread to initialize OpenSAML will
        -	 * return {@code true} while the rest will return {@code false}.
        +	 * Initialization is guaranteed to happen only once per application. This method will
        +	 * passively return {@code false} if initialization already took place earlier in the
        +	 * application.
        +	 * @return whether or not initialization was performed. The first thread to initialize
        +	 * OpenSAML will return {@code true} while the rest will return {@code false}.
         	 * @throws Saml2Exception if OpenSAML failed to initialize
         	 */
         	public static boolean initialize() {
        -		return initialize(registry -> {});
        +		return initialize((registry) -> {
        +		});
         	}
         
         	/**
        -	 * Ready OpenSAML for use, configure it with reasonable defaults, and modify the {@link XMLObjectProviderRegistry}
        -	 * using the provided {@link Consumer}.
        +	 * Ready OpenSAML for use, configure it with reasonable defaults, and modify the
        +	 * {@link XMLObjectProviderRegistry} using the provided {@link Consumer}.
         	 *
        -	 * Initialization is guaranteed to happen only once per application. This method will throw an exception
        -	 * if initialization already took place earlier in the application.
        -	 *
        -	 * @param registryConsumer the {@link Consumer} to further configure the {@link XMLObjectProviderRegistry}
        -	 * @throws Saml2Exception if initialization already happened previously or if OpenSAML failed to initialize
        +	 * Initialization is guaranteed to happen only once per application. This method will
        +	 * throw an exception if initialization already took place earlier in the application.
        +	 * @param registryConsumer the {@link Consumer} to further configure the
        +	 * {@link XMLObjectProviderRegistry}
        +	 * @throws Saml2Exception if initialization already happened previously or if OpenSAML
        +	 * failed to initialize
         	 */
         	public static void requireInitialize(Consumer registryConsumer) {
         		if (!initialize(registryConsumer)) {
        @@ -108,39 +118,39 @@ public class OpenSamlInitializationService {
         	private static boolean initialize(Consumer registryConsumer) {
         		if (initialized.compareAndSet(false, true)) {
         			log.trace("Initializing OpenSAML");
        -
         			try {
         				InitializationService.initialize();
        -			} catch (Exception e) {
        -				throw new Saml2Exception(e);
         			}
        -
        +			catch (Exception ex) {
        +				throw new Saml2Exception(ex);
        +			}
         			BasicParserPool parserPool = new BasicParserPool();
         			parserPool.setMaxPoolSize(50);
        -
        -			Map parserBuilderFeatures = new HashMap<>();
        -			parserBuilderFeatures.put("http://apache.org/xml/features/disallow-doctype-decl", TRUE);
        -			parserBuilderFeatures.put(XMLConstants.FEATURE_SECURE_PROCESSING, TRUE);
        -			parserBuilderFeatures.put("http://xml.org/sax/features/external-general-entities", FALSE);
        -			parserBuilderFeatures.put("http://apache.org/xml/features/validation/schema/normalized-value", FALSE);
        -			parserBuilderFeatures.put("http://xml.org/sax/features/external-parameter-entities", FALSE);
        -			parserBuilderFeatures.put("http://apache.org/xml/features/dom/defer-node-expansion", FALSE);
        -			parserPool.setBuilderFeatures(parserBuilderFeatures);
        -
        +			parserPool.setBuilderFeatures(getParserBuilderFeatures());
         			try {
         				parserPool.initialize();
        -			} catch (Exception e) {
        -				throw new Saml2Exception(e);
         			}
        -			setParserPool(parserPool);
        -
        +			catch (Exception ex) {
        +				throw new Saml2Exception(ex);
        +			}
        +			XMLObjectProviderRegistrySupport.setParserPool(parserPool);
         			registryConsumer.accept(ConfigurationService.get(XMLObjectProviderRegistry.class));
        -
         			log.debug("Initialized OpenSAML");
         			return true;
        -		} else {
        -			log.debug("Refused to re-initialize OpenSAML");
        -			return false;
         		}
        +		log.debug("Refused to re-initialize OpenSAML");
        +		return false;
         	}
        +
        +	private static Map getParserBuilderFeatures() {
        +		Map parserBuilderFeatures = new HashMap<>();
        +		parserBuilderFeatures.put("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE);
        +		parserBuilderFeatures.put(XMLConstants.FEATURE_SECURE_PROCESSING, Boolean.TRUE);
        +		parserBuilderFeatures.put("http://xml.org/sax/features/external-general-entities", Boolean.FALSE);
        +		parserBuilderFeatures.put("http://apache.org/xml/features/validation/schema/normalized-value", Boolean.FALSE);
        +		parserBuilderFeatures.put("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE);
        +		parserBuilderFeatures.put("http://apache.org/xml/features/dom/defer-node-expansion", Boolean.FALSE);
        +		return parserBuilderFeatures;
        +	}
        +
         }
        diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2Error.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2Error.java
        index 94b0450276..6709092ced 100644
        --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2Error.java
        +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2Error.java
        @@ -25,21 +25,23 @@ import org.springframework.util.Assert;
          * A representation of an SAML 2.0 Error.
          *
          * 

        - * At a minimum, an error response will contain an error code. - * The commonly used error code are defined in this class - * or a new codes can be defined in the future as arbitrary strings. + * At a minimum, an error response will contain an error code. The commonly used error + * code are defined in this class or a new codes can be defined in the future as arbitrary + * strings. *

        + * * @since 5.2 */ public class Saml2Error implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private final String errorCode; + private final String description; /** * Constructs a {@code Saml2Error} using the provided parameters. - * * @param errorCode the error code * @param description the error description */ @@ -51,7 +53,6 @@ public class Saml2Error implements Serializable { /** * Returns the error code. - * * @return the error code */ public final String getErrorCode() { @@ -60,7 +61,6 @@ public class Saml2Error implements Serializable { /** * Returns the error description. - * * @return the error description */ public final String getDescription() { @@ -69,7 +69,7 @@ public class Saml2Error implements Serializable { @Override public String toString() { - return "[" + this.getErrorCode() + "] " + - (this.getDescription() != null ? this.getDescription() : ""); + return "[" + this.getErrorCode() + "] " + ((this.getDescription() != null) ? this.getDescription() : ""); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java index 810c4338c7..63753f1121 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java @@ -22,80 +22,84 @@ package org.springframework.security.saml2.core; * @since 5.2 */ public interface Saml2ErrorCodes { + /** - * SAML Data does not represent a SAML 2 Response object. - * A valid XML object was received, but that object was not a - * SAML 2 Response object of type {@code ResponseType} per specification + * SAML Data does not represent a SAML 2 Response object. A valid XML object was + * received, but that object was not a SAML 2 Response object of type + * {@code ResponseType} per specification * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=46 */ String UNKNOWN_RESPONSE_CLASS = "unknown_response_class"; + /** - * The response data is malformed or incomplete. - * An invalid XML object was received, and XML unmarshalling failed. + * The response data is malformed or incomplete. An invalid XML object was received, + * and XML unmarshalling failed. */ String MALFORMED_RESPONSE_DATA = "malformed_response_data"; + /** - * Response destination does not match the request URL. - * A SAML 2 response object was received at a URL that - * did not match the URL stored in the {code Destination} attribute - * in the Response object. + * Response destination does not match the request URL. A SAML 2 response object was + * received at a URL that did not match the URL stored in the {code Destination} + * attribute in the Response object. * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38 */ String INVALID_DESTINATION = "invalid_destination"; + /** - * The assertion was not valid. - * The assertion used for authentication failed validation. - * Details around the failure will be present in the error description. + * The assertion was not valid. The assertion used for authentication failed + * validation. Details around the failure will be present in the error description. */ String INVALID_ASSERTION = "invalid_assertion"; + /** - * The signature of response or assertion was invalid. - * Either the response or the assertion was missing a signature - * or the signature could not be verified using the system's - * configured credentials. Most commonly the IDP's - * X509 certificate. + * The signature of response or assertion was invalid. Either the response or the + * assertion was missing a signature or the signature could not be verified using the + * system's configured credentials. Most commonly the IDP's X509 certificate. */ String INVALID_SIGNATURE = "invalid_signature"; + /** - * The assertion did not contain a subject element. - * The subject element, type SubjectType, contains - * a {@code NameID} or an {@code EncryptedID} that is used - * to assign the authenticated principal an identifier, - * typically a username. + * The assertion did not contain a subject element. The subject element, type + * SubjectType, contains a {@code NameID} or an {@code EncryptedID} that is used to + * assign the authenticated principal an identifier, typically a username. * * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=18 */ String SUBJECT_NOT_FOUND = "subject_not_found"; + /** - * The subject did not contain a user identifier - * The assertion contained a subject element, but the subject - * element did not have a {@code NameID} or {@code EncryptedID} - * element + * The subject did not contain a user identifier The assertion contained a subject + * element, but the subject element did not have a {@code NameID} or + * {@code EncryptedID} element * * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=18 */ String USERNAME_NOT_FOUND = "username_not_found"; + /** - * The system failed to decrypt an assertion or a name identifier. - * This error code will be thrown if the decryption of either a - * {@code EncryptedAssertion} or {@code EncryptedID} fails. + * The system failed to decrypt an assertion or a name identifier. This error code + * will be thrown if the decryption of either a {@code EncryptedAssertion} or + * {@code EncryptedID} fails. * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=17 */ String DECRYPTION_ERROR = "decryption_error"; + /** * An Issuer element contained a value that didn't * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=15 */ String INVALID_ISSUER = "invalid_issuer"; + /** - * An error happened during validation. - * Used when internal, non classified, errors are caught during the - * authentication process. + * An error happened during validation. Used when internal, non classified, errors are + * caught during the authentication process. */ String INTERNAL_VALIDATION_ERROR = "internal_validation_error"; + /** - * The relying party registration was not found. - * The registration ID did not correspond to any relying party registration. + * The relying party registration was not found. The registration ID did not + * correspond to any relying party registration. */ String RELYING_PARTY_REGISTRATION_NOT_FOUND = "relying_party_registration_not_found"; + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java index 3df1e7c3da..2a68b496a2 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java @@ -30,6 +30,7 @@ import org.springframework.util.Assert; * @since 5.4 */ public final class Saml2ResponseValidatorResult { + static final Saml2ResponseValidatorResult NO_ERRORS = new Saml2ResponseValidatorResult(Collections.emptyList()); private final Collection errors; @@ -41,7 +42,6 @@ public final class Saml2ResponseValidatorResult { /** * Say whether this result indicates success - * * @return whether this result has errors */ public boolean hasErrors() { @@ -50,17 +50,16 @@ public final class Saml2ResponseValidatorResult { /** * Return error details regarding the validation attempt - * - * @return the collection of results in this result, if any; returns an empty list otherwise + * @return the collection of results in this result, if any; returns an empty list + * otherwise */ public Collection getErrors() { return Collections.unmodifiableCollection(this.errors); } /** - * Return a new {@link Saml2ResponseValidatorResult} that contains - * both the given {@link Saml2Error} and the errors from the result - * + * Return a new {@link Saml2ResponseValidatorResult} that contains both the given + * {@link Saml2Error} and the errors from the result * @param error the {@link Saml2Error} to append * @return a new {@link Saml2ResponseValidatorResult} for further reporting */ @@ -72,10 +71,8 @@ public final class Saml2ResponseValidatorResult { } /** - * Return a new {@link Saml2ResponseValidatorResult} that contains - * the errors from the given {@link Saml2ResponseValidatorResult} as well - * as this result. - * + * Return a new {@link Saml2ResponseValidatorResult} that contains the errors from the + * given {@link Saml2ResponseValidatorResult} as well as this result. * @param result the {@link Saml2ResponseValidatorResult} to merge with this one * @return a new {@link Saml2ResponseValidatorResult} for further reporting */ @@ -88,7 +85,6 @@ public final class Saml2ResponseValidatorResult { /** * Construct a successful {@link Saml2ResponseValidatorResult} - * * @return an {@link Saml2ResponseValidatorResult} with no errors */ public static Saml2ResponseValidatorResult success() { @@ -97,7 +93,6 @@ public final class Saml2ResponseValidatorResult { /** * Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail - * * @param errors the list of errors * @return an {@link Saml2ResponseValidatorResult} with the errors specified */ @@ -107,7 +102,6 @@ public final class Saml2ResponseValidatorResult { /** * Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail - * * @param errors the list of errors * @return an {@link Saml2ResponseValidatorResult} with the errors specified */ @@ -118,4 +112,5 @@ public final class Saml2ResponseValidatorResult { return new Saml2ResponseValidatorResult(errors); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2X509Credential.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2X509Credential.java index 8569b390db..4fde34733c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2X509Credential.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2X509Credential.java @@ -13,50 +13,42 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.saml2.core; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.Arrays; import java.util.LinkedHashSet; import java.util.Objects; import java.util.Set; import org.springframework.util.Assert; -import static java.util.Arrays.asList; -import static org.springframework.util.Assert.notEmpty; -import static org.springframework.util.Assert.notNull; -import static org.springframework.util.Assert.state; - /** - * An object for holding a public certificate, any associated private key, and its intended - * - * usages - * - * (Line 584, Section 4.3 Credentials). + * An object for holding a public certificate, any associated private key, and its + * intended + * usages (Line 584, Section 4.3 Credentials). * - * @since 5.4 * @author Filip Hanik * @author Josh Cummings + * @since 5.4 */ public final class Saml2X509Credential { - public enum Saml2X509CredentialType { - VERIFICATION, - ENCRYPTION, - SIGNING, - DECRYPTION, - } private final PrivateKey privateKey; + private final X509Certificate certificate; + private final Set credentialTypes; /** * Creates a {@link Saml2X509Credential} using the provided parameters - * * @param certificate the credential's public certificiate - * @param types the credential's intended usages, must be one of {@link Saml2X509CredentialType#VERIFICATION} or - * {@link Saml2X509CredentialType#ENCRYPTION} or both. + * @param types the credential's intended usages, must be one of + * {@link Saml2X509CredentialType#VERIFICATION} or + * {@link Saml2X509CredentialType#ENCRYPTION} or both. */ public Saml2X509Credential(X509Certificate certificate, Saml2X509CredentialType... types) { this(null, false, certificate, types); @@ -65,11 +57,11 @@ public final class Saml2X509Credential { /** * Creates a {@link Saml2X509Credential} using the provided parameters - * * @param privateKey the credential's private key * @param certificate the credential's public certificate - * @param types the credential's intended usages, must be one of {@link Saml2X509CredentialType#SIGNING} or - * {@link Saml2X509CredentialType#DECRYPTION} or both. + * @param types the credential's intended usages, must be one of + * {@link Saml2X509CredentialType#SIGNING} or + * {@link Saml2X509CredentialType#DECRYPTION} or both. */ public Saml2X509Credential(PrivateKey privateKey, X509Certificate certificate, Saml2X509CredentialType... types) { this(privateKey, true, certificate, types); @@ -78,7 +70,6 @@ public final class Saml2X509Credential { /** * Creates a {@link Saml2X509Credential} using the provided parameters - * * @param privateKey the credential's private key * @param certificate the credential's public certificate * @param types the credential's intended usages @@ -125,26 +116,22 @@ public final class Saml2X509Credential { return new Saml2X509Credential(privateKey, certificate, Saml2X509Credential.Saml2X509CredentialType.SIGNING); } - private Saml2X509Credential( - PrivateKey privateKey, - boolean keyRequired, - X509Certificate certificate, + private Saml2X509Credential(PrivateKey privateKey, boolean keyRequired, X509Certificate certificate, Saml2X509CredentialType... types) { - notNull(certificate, "certificate cannot be null"); - notEmpty(types, "credentials types cannot be empty"); + Assert.notNull(certificate, "certificate cannot be null"); + Assert.notEmpty(types, "credentials types cannot be empty"); if (keyRequired) { - notNull(privateKey, "privateKey cannot be null"); + Assert.notNull(privateKey, "privateKey cannot be null"); } this.privateKey = privateKey; this.certificate = certificate; - this.credentialTypes = new LinkedHashSet<>(asList(types)); + this.credentialTypes = new LinkedHashSet<>(Arrays.asList(types)); } /** * Get the private key for this credential - * * @return the private key, may be null - * @see {@link #Saml2X509Credential(PrivateKey, X509Certificate, Saml2X509CredentialType...)} + * @see #Saml2X509Credential(PrivateKey, X509Certificate, Saml2X509CredentialType...) */ public PrivateKey getPrivateKey() { return this.privateKey; @@ -152,7 +139,6 @@ public final class Saml2X509Credential { /** * Get the public certificate for this credential - * * @return the public certificate */ public X509Certificate getCertificate() { @@ -161,7 +147,6 @@ public final class Saml2X509Credential { /** * Indicate whether this credential can be used for signing - * * @return true if the credential has a {@link Saml2X509CredentialType#SIGNING} type */ public boolean isSigningCredential() { @@ -170,8 +155,8 @@ public final class Saml2X509Credential { /** * Indicate whether this credential can be used for decryption - * - * @return true if the credential has a {@link Saml2X509CredentialType#DECRYPTION} type + * @return true if the credential has a {@link Saml2X509CredentialType#DECRYPTION} + * type */ public boolean isDecryptionCredential() { return getCredentialTypes().contains(Saml2X509CredentialType.DECRYPTION); @@ -179,8 +164,8 @@ public final class Saml2X509Credential { /** * Indicate whether this credential can be used for verification - * - * @return true if the credential has a {@link Saml2X509CredentialType#VERIFICATION} type + * @return true if the credential has a {@link Saml2X509CredentialType#VERIFICATION} + * type */ public boolean isVerificationCredential() { return getCredentialTypes().contains(Saml2X509CredentialType.VERIFICATION); @@ -188,8 +173,8 @@ public final class Saml2X509Credential { /** * Indicate whether this credential can be used for encryption - * - * @return true if the credential has a {@link Saml2X509CredentialType#ENCRYPTION} type + * @return true if the credential has a {@link Saml2X509CredentialType#ENCRYPTION} + * type */ public boolean isEncryptionCredential() { return getCredentialTypes().contains(Saml2X509CredentialType.ENCRYPTION); @@ -197,7 +182,6 @@ public final class Saml2X509Credential { /** * List all this credential's intended usages - * * @return the set of this credential's intended usages */ public Set getCredentialTypes() { @@ -206,12 +190,15 @@ public final class Saml2X509Credential { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } Saml2X509Credential that = (Saml2X509Credential) o; - return Objects.equals(this.privateKey, that.privateKey) && - this.certificate.equals(that.certificate) && - this.credentialTypes.equals(that.credentialTypes); + return Objects.equals(this.privateKey, that.privateKey) && this.certificate.equals(that.certificate) + && this.credentialTypes.equals(that.credentialTypes); } @Override @@ -228,7 +215,20 @@ public final class Saml2X509Credential { break; } } - state(valid, () -> usage +" is not a valid usage for this credential"); + Assert.state(valid, () -> usage + " is not a valid usage for this credential"); } } + + public enum Saml2X509CredentialType { + + VERIFICATION, + + ENCRYPTION, + + SIGNING, + + DECRYPTION, + + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/credentials/Saml2X509Credential.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/credentials/Saml2X509Credential.java index a83fc06e75..0649f09e3d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/credentials/Saml2X509Credential.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/credentials/Saml2X509Credential.java @@ -13,53 +13,45 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.saml2.credentials; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.Arrays; import java.util.LinkedHashSet; import java.util.Objects; import java.util.Set; import org.springframework.util.Assert; -import static java.util.Arrays.asList; -import static org.springframework.util.Assert.notEmpty; -import static org.springframework.util.Assert.notNull; -import static org.springframework.util.Assert.state; - /** * Saml2X509Credential is meant to hold an X509 certificate, or an X509 certificate and a * private key. Per: * https://www.oasis-open.org/committees/download.php/8958/sstc-saml-implementation-guidelines-draft-01.pdf - * Line: 584, Section 4.3 Credentials Used for both signing, signature verification and encryption/decryption + * Line: 584, Section 4.3 Credentials Used for both signing, signature verification and + * encryption/decryption * * @since 5.2 - * @deprecated Use {@link org.springframework.security.saml2.core.Saml2X509Credential} instead + * @deprecated Use {@link org.springframework.security.saml2.core.Saml2X509Credential} + * instead */ @Deprecated public class Saml2X509Credential { - /** - * @deprecated Use {@link org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType} instead - */ - @Deprecated - public enum Saml2X509CredentialType { - VERIFICATION, - ENCRYPTION, - SIGNING, - DECRYPTION, - } private final PrivateKey privateKey; + private final X509Certificate certificate; + private final Set credentialTypes; /** * Creates a Saml2X509Credentials representing Identity Provider credentials for * verification, encryption or both. * @param certificate an IDP X509Certificate, cannot be null - * @param types credential types, must be one of {@link Saml2X509CredentialType#VERIFICATION} or - * {@link Saml2X509CredentialType#ENCRYPTION} or both. + * @param types credential types, must be one of + * {@link Saml2X509CredentialType#VERIFICATION} or + * {@link Saml2X509CredentialType#ENCRYPTION} or both. */ public Saml2X509Credential(X509Certificate certificate, Saml2X509CredentialType... types) { this(null, false, certificate, types); @@ -70,9 +62,11 @@ public class Saml2X509Credential { * Creates a Saml2X509Credentials representing Service Provider credentials for * signing, decryption or both. * @param privateKey a private key used for signing or decryption, cannot be null - * @param certificate an SP X509Certificate shared with identity providers, cannot be null - * @param types credential types, must be one of {@link Saml2X509CredentialType#SIGNING} or - * {@link Saml2X509CredentialType#DECRYPTION} or both. + * @param certificate an SP X509Certificate shared with identity providers, cannot be + * null + * @param types credential types, must be one of + * {@link Saml2X509CredentialType#SIGNING} or + * {@link Saml2X509CredentialType#DECRYPTION} or both. */ public Saml2X509Credential(PrivateKey privateKey, X509Certificate certificate, Saml2X509CredentialType... types) { this(privateKey, true, certificate, types); @@ -87,25 +81,21 @@ public class Saml2X509Credential { this.credentialTypes = types; } - private Saml2X509Credential( - PrivateKey privateKey, - boolean keyRequired, - X509Certificate certificate, + private Saml2X509Credential(PrivateKey privateKey, boolean keyRequired, X509Certificate certificate, Saml2X509CredentialType... types) { - notNull(certificate, "certificate cannot be null"); - notEmpty(types, "credentials types cannot be empty"); + Assert.notNull(certificate, "certificate cannot be null"); + Assert.notEmpty(types, "credentials types cannot be empty"); if (keyRequired) { - notNull(privateKey, "privateKey cannot be null"); + Assert.notNull(privateKey, "privateKey cannot be null"); } this.privateKey = privateKey; this.certificate = certificate; - this.credentialTypes = new LinkedHashSet<>(asList(types)); + this.credentialTypes = new LinkedHashSet<>(Arrays.asList(types)); } - /** - * Returns true if the credential has a private key and can be used for signing, the types will contain - * {@link Saml2X509CredentialType#SIGNING}. + * Returns true if the credential has a private key and can be used for signing, the + * types will contain {@link Saml2X509CredentialType#SIGNING}. * @return true if the credential is a {@link Saml2X509CredentialType#SIGNING} type */ public boolean isSigningCredential() { @@ -113,8 +103,8 @@ public class Saml2X509Credential { } /** - * Returns true if the credential has a private key and can be used for decryption, the types will contain - * {@link Saml2X509CredentialType#DECRYPTION}. + * Returns true if the credential has a private key and can be used for decryption, + * the types will contain {@link Saml2X509CredentialType#DECRYPTION}. * @return true if the credential is a {@link Saml2X509CredentialType#DECRYPTION} type */ public boolean isDecryptionCredential() { @@ -122,18 +112,20 @@ public class Saml2X509Credential { } /** - * Returns true if the credential has a certificate and can be used for signature verification, the types will contain - * {@link Saml2X509CredentialType#VERIFICATION}. - * @return true if the credential is a {@link Saml2X509CredentialType#VERIFICATION} type + * Returns true if the credential has a certificate and can be used for signature + * verification, the types will contain {@link Saml2X509CredentialType#VERIFICATION}. + * @return true if the credential is a {@link Saml2X509CredentialType#VERIFICATION} + * type */ public boolean isSignatureVerficationCredential() { return getCredentialTypes().contains(Saml2X509CredentialType.VERIFICATION); } /** - * Returns true if the credential has a certificate and can be used for signature verification, the types will contain - * {@link Saml2X509CredentialType#VERIFICATION}. - * @return true if the credential is a {@link Saml2X509CredentialType#VERIFICATION} type + * Returns true if the credential has a certificate and can be used for signature + * verification, the types will contain {@link Saml2X509CredentialType#VERIFICATION}. + * @return true if the credential is a {@link Saml2X509CredentialType#VERIFICATION} + * type */ public boolean isEncryptionCredential() { return getCredentialTypes().contains(Saml2X509CredentialType.ENCRYPTION); @@ -150,7 +142,7 @@ public class Saml2X509Credential { /** * Returns the private key, or null if this credential type doesn't require one. * @return the private key, or null - * @see {@link #Saml2X509Credential(PrivateKey, X509Certificate, Saml2X509CredentialType...)} + * @see #Saml2X509Credential(PrivateKey, X509Certificate, Saml2X509CredentialType...) */ public PrivateKey getPrivateKey() { return this.privateKey; @@ -166,12 +158,15 @@ public class Saml2X509Credential { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } Saml2X509Credential that = (Saml2X509Credential) o; - return Objects.equals(this.privateKey, that.privateKey) && - this.certificate.equals(that.certificate) && - this.credentialTypes.equals(that.credentialTypes); + return Objects.equals(this.privateKey, that.privateKey) && this.certificate.equals(that.certificate) + && this.credentialTypes.equals(that.credentialTypes); } @Override @@ -188,7 +183,26 @@ public class Saml2X509Credential { break; } } - state(valid, () -> usage +" is not a valid usage for this credential"); + Assert.state(valid, () -> usage + " is not a valid usage for this credential"); } } + + /** + * @deprecated Use + * {@link org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType} + * instead + */ + @Deprecated + public enum Saml2X509CredentialType { + + VERIFICATION, + + ENCRYPTION, + + SIGNING, + + DECRYPTION, + + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java index 99ef2cdc22..028ecd6bae 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java @@ -16,39 +16,41 @@ package org.springframework.security.saml2.provider.service.authentication; +import java.nio.charset.Charset; + import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; -import java.nio.charset.Charset; - /** * Data holder for {@code AuthNRequest} parameters to be sent using either the - * {@link Saml2MessageBinding#POST} or {@link Saml2MessageBinding#REDIRECT} binding. - * Data will be encoded and possibly deflated, but will not be escaped for transport, - * ie URL encoded, {@link org.springframework.web.util.UriUtils#encode(String, Charset)} - * or HTML encoded, {@link org.springframework.web.util.HtmlUtils#htmlEscape(String)}. - * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf (line 2031) + * {@link Saml2MessageBinding#POST} or {@link Saml2MessageBinding#REDIRECT} binding. Data + * will be encoded and possibly deflated, but will not be escaped for transport, ie URL + * encoded, {@link org.springframework.web.util.UriUtils#encode(String, Charset)} or HTML + * encoded, {@link org.springframework.web.util.HtmlUtils#htmlEscape(String)}. + * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf + * (line 2031) * + * @since 5.3 * @see Saml2AuthenticationRequestFactory#createPostAuthenticationRequest(Saml2AuthenticationRequestContext) * @see Saml2AuthenticationRequestFactory#createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext) - * @since 5.3 */ -abstract class AbstractSaml2AuthenticationRequest { +public abstract class AbstractSaml2AuthenticationRequest { private final String samlRequest; + private final String relayState; + private final String authenticationRequestUri; /** * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest} - * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or null + * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or + * null * @param relayState - RelayState value that accompanies the request, may be null - * @param authenticationRequestUri - The authenticationRequestUri, a URL, where to send the XML message, cannot be empty or null + * @param authenticationRequestUri - The authenticationRequestUri, a URL, where to + * send the XML message, cannot be empty or null */ - AbstractSaml2AuthenticationRequest( - String samlRequest, - String relayState, - String authenticationRequestUri) { + AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) { Assert.hasText(samlRequest, "samlRequest cannot be null or empty"); Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty"); this.authenticationRequestUri = authenticationRequestUri; @@ -57,9 +59,10 @@ abstract class AbstractSaml2AuthenticationRequest { } /** - * Returns the AuthNRequest XML value to be sent. This value is already encoded for transport. - * If {@link #getBinding()} is {@link Saml2MessageBinding#REDIRECT} the value is deflated and SAML encoded. - * If {@link #getBinding()} is {@link Saml2MessageBinding#POST} the value is SAML encoded. + * Returns the AuthNRequest XML value to be sent. This value is already encoded for + * transport. If {@link #getBinding()} is {@link Saml2MessageBinding#REDIRECT} the + * value is deflated and SAML encoded. If {@link #getBinding()} is + * {@link Saml2MessageBinding#POST} the value is SAML encoded. * @return the SAMLRequest parameter value */ public String getSamlRequest() { @@ -83,8 +86,9 @@ abstract class AbstractSaml2AuthenticationRequest { } /** - * Returns the binding this AuthNRequest will be sent and - * encoded with. If {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be automatically applied. + * Returns the binding this AuthNRequest will be sent and encoded with. If + * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be + * automatically applied. * @return the binding this message will be sent with. */ public abstract Saml2MessageBinding getBinding(); @@ -92,9 +96,12 @@ abstract class AbstractSaml2AuthenticationRequest { /** * A builder for {@link AbstractSaml2AuthenticationRequest} and its subclasses. */ - static class Builder> { + public static class Builder> { + String authenticationRequestUri; + String samlRequest; + String relayState; protected Builder() { @@ -109,12 +116,10 @@ abstract class AbstractSaml2AuthenticationRequest { return (T) this; } - /** * Sets the {@code RelayState} parameter that will accompany this AuthNRequest - * - * @param relayState the relay state value, unencoded. if null or empty, the parameter will be removed from the - * map. + * @param relayState the relay state value, unencoded. if null or empty, the + * parameter will be removed from the map. * @return this object */ public T relayState(String relayState) { @@ -124,7 +129,6 @@ abstract class AbstractSaml2AuthenticationRequest { /** * Sets the {@code SAMLRequest} parameter that will accompany this AuthNRequest - * * @param samlRequest the SAMLRequest parameter. * @return this object */ @@ -134,8 +138,8 @@ abstract class AbstractSaml2AuthenticationRequest { } /** - * Sets the {@code authenticationRequestUri}, a URL that will receive the AuthNRequest message - * + * Sets the {@code authenticationRequestUri}, a URL that will receive the + * AuthNRequest message * @param authenticationRequestUri the relay state value, unencoded. * @return this object */ @@ -143,6 +147,7 @@ abstract class AbstractSaml2AuthenticationRequest { this.authenticationRequestUri = authenticationRequestUri; return _this(); } + } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java index b474c20aed..8e1ac9270b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipal.java @@ -16,12 +16,12 @@ package org.springframework.security.saml2.provider.service.authentication; -import org.springframework.util.Assert; - import java.io.Serializable; import java.util.List; import java.util.Map; +import org.springframework.util.Assert; + /** * Default implementation of a {@link Saml2AuthenticatedPrincipal}. * @@ -31,12 +31,12 @@ import java.util.Map; public class DefaultSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPrincipal, Serializable { private final String name; + private final Map> attributes; public DefaultSaml2AuthenticatedPrincipal(String name, Map> attributes) { Assert.notNull(name, "name cannot be null"); Assert.notNull(attributes, "attributes cannot be null"); - this.name = name; this.attributes = attributes; } @@ -50,4 +50,5 @@ public class DefaultSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPri public Map> getAttributes() { return this.attributes; } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java index 11fa522588..26264076ab 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayInputStream; @@ -20,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -29,6 +31,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; + import javax.annotation.Nonnull; import javax.xml.namespace.QName; @@ -55,6 +58,7 @@ import org.opensaml.saml.criterion.ProtocolCriterion; import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion; import org.opensaml.saml.saml2.assertion.ConditionValidator; import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator; +import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; import org.opensaml.saml.saml2.assertion.StatementValidator; import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator; import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator; @@ -97,6 +101,7 @@ import org.w3c.dom.Document; import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; @@ -113,59 +118,49 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import static java.util.Arrays.asList; -import static java.util.Collections.singleton; -import static java.util.Collections.singletonList; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.CLOCK_SKEW; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.COND_VALID_AUDIENCES; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.DECRYPTION_ERROR; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_DESTINATION; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ISSUER; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND; -import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.failure; -import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success; -import static org.springframework.util.Assert.notNull; - /** - * Implementation of {@link AuthenticationProvider} for SAML authentications when receiving a - * {@code Response} object containing an {@code Assertion}. This implementation uses - * the {@code OpenSAML 3} library. + * Implementation of {@link AuthenticationProvider} for SAML authentications when + * receiving a {@code Response} object containing an {@code Assertion}. This + * implementation uses the {@code OpenSAML 3} library. * *

        - * The {@link OpenSamlAuthenticationProvider} supports {@link Saml2AuthenticationToken} objects - * that contain a SAML response in its decoded XML format {@link Saml2AuthenticationToken#getSaml2Response()} - * along with the information about the asserting party, the identity provider (IDP), as well as - * the relying party, the service provider (SP, this application). + * The {@link OpenSamlAuthenticationProvider} supports {@link Saml2AuthenticationToken} + * objects that contain a SAML response in its decoded XML format + * {@link Saml2AuthenticationToken#getSaml2Response()} along with the information about + * the asserting party, the identity provider (IDP), as well as the relying party, the + * service provider (SP, this application). *

        *

        - * The {@link Saml2AuthenticationToken} will be processed into a SAML Response object. - * The SAML response object can be signed. If the Response is signed, a signature will not be required on the assertion. + * The {@link Saml2AuthenticationToken} will be processed into a SAML Response object. The + * SAML response object can be signed. If the Response is signed, a signature will not be + * required on the assertion. *

        *

        - * While a response object can contain a list of assertion, this provider will only leverage - * the first valid assertion for the purpose of authentication. Assertions that do not pass validation - * will be ignored. If no valid assertions are found a {@link Saml2AuthenticationException} is thrown. + * While a response object can contain a list of assertion, this provider will only + * leverage the first valid assertion for the purpose of authentication. Assertions that + * do not pass validation will be ignored. If no valid assertions are found a + * {@link Saml2AuthenticationException} is thrown. *

        *

        - * This provider supports two types of encrypted SAML elements - *

        - * If the assertion is encrypted, then signature validation on the assertion is no longer required. + * This provider supports two types of encrypted SAML elements + * + * If the assertion is encrypted, then signature validation on the assertion is no longer + * required. *

        *

        - * This provider does not perform an X509 certificate validation on the configured asserting party, IDP, verification - * certificates. + * This provider does not perform an X509 certificate validation on the configured + * asserting party, IDP, verification certificates. *

        + * * @since 5.2 - * @see SAML 2 StatusResponse + * @see SAML 2 + * StatusResponse * @see OpenSAML 3 */ public final class OpenSamlAuthenticationProvider implements AuthenticationProvider { @@ -177,49 +172,45 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class); private final XMLObjectProviderRegistry registry; + private final ResponseUnmarshaller responseUnmarshaller; + private final ParserPool parserPool; - private Converter> authoritiesExtractor = - (a -> singletonList(new SimpleGrantedAuthority("ROLE_USER"))); - private GrantedAuthoritiesMapper authoritiesMapper = (a -> a); + private Converter> authoritiesExtractor = ((a) -> Collections + .singletonList(new SimpleGrantedAuthority("ROLE_USER"))); + + private GrantedAuthoritiesMapper authoritiesMapper = ((a) -> a); + private Duration responseTimeValidationSkew = Duration.ofMinutes(5); - private Converter responseAuthenticationConverter = - responseToken -> { - Response response = responseToken.response; - Saml2AuthenticationToken token = responseToken.token; - Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); - String username = assertion.getSubject().getNameID().getValue(); - Map> attributes = getAssertionAttributes(assertion); - return new Saml2Authentication( - new DefaultSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(), - this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); - }; + private Converter responseAuthenticationConverter = ( + responseToken) -> { + Response response = responseToken.response; + Saml2AuthenticationToken token = responseToken.token; + Assertion assertion = CollectionUtils.firstElement(response.getAssertions()); + String username = assertion.getSubject().getNameID().getValue(); + Map> attributes = getAssertionAttributes(assertion); + return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes), + token.getSaml2Response(), this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion))); + }; - private Converter assertionSignatureValidator = - createDefaultAssertionValidator(INVALID_SIGNATURE, - assertionToken -> { - SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token); - return SAML20AssertionValidators.createSignatureValidator(engine); - }, - assertionToken -> - new ValidationContext(Collections.singletonMap(SIGNATURE_REQUIRED, false)) - ); + private Converter assertionSignatureValidator = createDefaultAssertionValidator( + Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { + SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token); + return SAML20AssertionValidators.createSignatureValidator(engine); + }, (assertionToken) -> new ValidationContext( + Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false))); - private Converter assertionValidator = - createDefaultAssertionValidator(INVALID_ASSERTION, - assertionToken -> SAML20AssertionValidators.attributeValidator, - assertionToken -> createValidationContext( - assertionToken, - params -> params.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis()) - )); + private Converter assertionValidator = createDefaultAssertionValidator( + Saml2ErrorCodes.INVALID_ASSERTION, (assertionToken) -> SAML20AssertionValidators.attributeValidator, + (assertionToken) -> createValidationContext(assertionToken, (params) -> params + .put(SAML2AssertionValidationParameters.CLOCK_SKEW, this.responseTimeValidationSkew.toMillis()))); + + private Converter signatureTrustEngineConverter = new SignatureTrustEngineConverter(); - private Converter signatureTrustEngineConverter = - new SignatureTrustEngineConverter(); private Converter decrypterConverter = new DecrypterConverter(); - /** * Creates an {@link OpenSamlAuthenticationProvider} */ @@ -231,7 +222,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi } /** - * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response. + * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML + * 2.0 Response. * * You can still invoke the default validator by delgating to * {@link #createDefaultAssertionValidator}, like so: @@ -259,15 +251,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * })); *
        * - * Consider taking a look at {@link #createValidationContext} to see how it - * constructs a {@link ValidationContext}. + * Consider taking a look at {@link #createValidationContext} to see how it constructs + * a {@link ValidationContext}. * * It is not necessary to delegate to the default validator. You can safely replace it * entirely with your own. Note that signature verification is performed as a separate * step from this validator. * * This method takes precedence over {@link #setResponseTimeValidationSkew}. - * * @param assertionValidator * @since 5.4 */ @@ -280,8 +271,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi * Set the {@link Converter} to use for converting a validated {@link Response} into * an {@link AbstractAuthenticationToken}. * - * You can delegate to the default behavior by calling {@link #createDefaultResponseAuthenticationConverter()} - * like so: + * You can delegate to the default behavior by calling + * {@link #createDefaultResponseAuthenticationConverter()} like so: * *
         	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
        @@ -296,7 +287,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         	 *
         	 * This method takes precedence over {@link #setAuthoritiesExtractor(Converter)} and
         	 * {@link #setAuthoritiesMapper(GrantedAuthoritiesMapper)}.
        -	 *
         	 * @param responseAuthenticationConverter the {@link Converter} to use
         	 * @since 5.4
         	 */
        @@ -307,26 +297,28 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         	}
         
         	/**
        -	 * Sets the {@link Converter} used for extracting assertion attributes that
        -	 * can be mapped to authorities.
        -	 * @param authoritiesExtractor the {@code Converter} used for mapping the
        -	 *                             assertion attributes to authorities
        +	 * Sets the {@link Converter} used for extracting assertion attributes that can be
        +	 * mapped to authorities.
        +	 * @param authoritiesExtractor the {@code Converter} used for mapping the assertion
        +	 * attributes to authorities
         	 * @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead
         	 */
        -	public void setAuthoritiesExtractor(Converter> authoritiesExtractor) {
        +	public void setAuthoritiesExtractor(
        +			Converter> authoritiesExtractor) {
         		Assert.notNull(authoritiesExtractor, "authoritiesExtractor cannot be null");
         		this.authoritiesExtractor = authoritiesExtractor;
         	}
         
         	/**
        -	 * Sets the {@link GrantedAuthoritiesMapper} used for mapping assertion attributes
        -	 * to a new set of authorities which will be associated to the {@link Saml2Authentication}.
        -	 * Note: This implementation is only retrieving
        -	 * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities
        +	 * Sets the {@link GrantedAuthoritiesMapper} used for mapping assertion attributes to
        +	 * a new set of authorities which will be associated to the
        +	 * {@link Saml2Authentication}. Note: This implementation is only retrieving
        +	 * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the
        +	 * user's authorities
         	 * @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead
         	 */
         	public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
        -		notNull(authoritiesMapper, "authoritiesMapper cannot be null");
        +		Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
         		this.authoritiesMapper = authoritiesMapper;
         	}
         
        @@ -340,64 +332,56 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         		this.responseTimeValidationSkew = responseTimeValidationSkew;
         	}
         
        -
         	/**
        -	 * Construct a default strategy for validating each SAML 2.0 Assertion and
        -	 * associated {@link Authentication} token
        -	 *
        +	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
        +	 * {@link Authentication} token
         	 * @return the default assertion validator strategy
         	 * @since 5.4
         	 */
        -	public static Converter
        -			createDefaultAssertionValidator() {
        +	public static Converter createDefaultAssertionValidator() {
         
        -		return createDefaultAssertionValidator(INVALID_ASSERTION,
        -				assertionToken -> SAML20AssertionValidators.attributeValidator,
        -				assertionToken -> createValidationContext(assertionToken, params -> {}));
        +		return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
        +				(assertionToken) -> SAML20AssertionValidators.attributeValidator,
        +				(assertionToken) -> createValidationContext(assertionToken, (params) -> {
        +				}));
         	}
         
         	/**
        -	 * Construct a default strategy for validating each SAML 2.0 Assertion and
        -	 * associated {@link Authentication} token
        -	 *
        +	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
        +	 * {@link Authentication} token
        +	 * @param contextConverter the conversion strategy to use to generate a
        +	 * {@link ValidationContext} for each assertion being validated
         	 * @return the default assertion validator strategy
        -	 * @param contextConverter the conversion strategy to use to generate a {@link ValidationContext}
        -	 * for each assertion being validated
         	 * @since 5.4
         	 */
        -	public static Converter
        -			createDefaultAssertionValidator(Converter contextConverter) {
        +	public static Converter createDefaultAssertionValidator(
        +			Converter contextConverter) {
         
        -		return createDefaultAssertionValidator(INVALID_ASSERTION,
        -				assertionToken -> SAML20AssertionValidators.attributeValidator,
        -				contextConverter);
        +		return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
        +				(assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter);
         	}
         
         	/**
        -	 * Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication}
        -	 * token into a {@link Saml2Authentication}
        -	 *
        +	 * Construct a default strategy for converting a SAML 2.0 Response and
        +	 * {@link Authentication} token into a {@link Saml2Authentication}
         	 * @return the default response authentication converter strategy
         	 * @since 5.4
         	 */
        -	public static Converter
        -			createDefaultResponseAuthenticationConverter() {
        -		return responseToken -> {
        +	public static Converter createDefaultResponseAuthenticationConverter() {
        +		return (responseToken) -> {
         			Saml2AuthenticationToken token = responseToken.token;
         			Response response = responseToken.response;
         			Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
         			String username = assertion.getSubject().getNameID().getValue();
         			Map> attributes = getAssertionAttributes(assertion);
        -			return new Saml2Authentication(
        -					new DefaultSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
        -					Collections.singleton(new SimpleGrantedAuthority("ROLE_USER")));
        +			return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
        +					token.getSaml2Response(), Collections.singleton(new SimpleGrantedAuthority("ROLE_USER")));
         		};
         	}
         
         	/**
         	 * @param authentication the authentication request object, must be of type
        -	 *                       {@link Saml2AuthenticationToken}
        -	 *
        +	 * {@link Saml2AuthenticationToken}
         	 * @return {@link Saml2Authentication} if the assertion is valid
         	 * @throws AuthenticationException if a validation exception occurs
         	 */
        @@ -409,16 +393,15 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         			Response response = parse(serializedResponse);
         			process(token, response);
         			return this.responseAuthenticationConverter.convert(new ResponseToken(response, token));
        -		} catch (Saml2AuthenticationException e) {
        -			throw e;
        -		} catch (Exception e) {
        -			throw authException(INTERNAL_VALIDATION_ERROR, e.getMessage(), e);
        +		}
        +		catch (Saml2AuthenticationException ex) {
        +			throw ex;
        +		}
        +		catch (Exception ex) {
        +			throw createAuthenticationException(Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR, ex.getMessage(), ex);
         		}
         	}
         
        -	/**
        -	 * {@inheritDoc}
        -	 */
         	@Override
         	public boolean supports(Class authentication) {
         		return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
        @@ -430,37 +413,35 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         
         	private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException {
         		try {
        -			Document document = this.parserPool.parse(new ByteArrayInputStream(
        -					response.getBytes(StandardCharsets.UTF_8)));
        +			Document document = this.parserPool
        +					.parse(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)));
         			Element element = document.getDocumentElement();
         			return (Response) this.responseUnmarshaller.unmarshall(element);
         		}
        -		catch (Exception e) {
        -			throw authException(MALFORMED_RESPONSE_DATA, e.getMessage(), e);
        +		catch (Exception ex) {
        +			throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, ex.getMessage(), ex);
         		}
         	}
         
         	private void process(Saml2AuthenticationToken token, Response response) {
         		String issuer = response.getIssuer().getValue();
        -		if (logger.isDebugEnabled()) {
        -			logger.debug("Processing SAML response from " + issuer);
        -		}
        -
        +		logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
         		boolean responseSigned = response.isSigned();
         		Saml2ResponseValidatorResult result = validateResponse(token, response);
         
         		Decrypter decrypter = this.decrypterConverter.convert(token);
         		List assertions = decryptAssertions(decrypter, response);
         		if (!isSigned(responseSigned, assertions)) {
        -			throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
        -					"Please either sign the response or all of the assertions.");
        +			String description = "Either the response or one of the assertions is unsigned. "
        +					+ "Please either sign the response or all of the assertions.";
        +			throw createAuthenticationException(Saml2ErrorCodes.INVALID_SIGNATURE, description, null);
         		}
         		result = result.concat(validateAssertions(token, response));
         
         		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
         		NameID nameId = decryptPrincipal(decrypter, firstAssertion);
         		if (nameId == null || nameId.getValue() == null) {
        -			Saml2Error error = new Saml2Error(SUBJECT_NOT_FOUND,
        +			Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
         					"Assertion [" + firstAssertion.getID() + "] is missing a subject");
         			result = result.concat(error);
         		}
        @@ -468,89 +449,91 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         		if (result.hasErrors()) {
         			Collection errors = result.getErrors();
         			if (logger.isTraceEnabled()) {
        -				logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]: " +
        -						errors);
        -			} else if (logger.isDebugEnabled()) {
        -				logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
        +				logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID()
        +						+ "]: " + errors);
        +			}
        +			else if (logger.isDebugEnabled()) {
        +				logger.debug(
        +						"Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
         			}
         			Saml2Error first = errors.iterator().next();
        -			throw authException(first.getErrorCode(), first.getDescription());
        -		} else {
        +			throw createAuthenticationException(first.getErrorCode(), first.getDescription(), null);
        +		}
        +		else {
         			if (logger.isDebugEnabled()) {
         				logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
         			}
         		}
         	}
         
        -	private Saml2ResponseValidatorResult validateResponse
        -			(Saml2AuthenticationToken token, Response response) {
        +	private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken token, Response response) {
         
         		Collection errors = new ArrayList<>();
         		String issuer = response.getIssuer().getValue();
        -
         		if (response.isSigned()) {
         			SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
         			try {
         				profileValidator.validate(response.getSignature());
        -			} catch (Exception e) {
        -				errors.add(new Saml2Error(INVALID_SIGNATURE,
        +			}
        +			catch (Exception ex) {
        +				errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
         						"Invalid signature for SAML Response [" + response.getID() + "]: "));
         			}
         
         			try {
         				CriteriaSet criteriaSet = new CriteriaSet();
         				criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
        -				criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
        +				criteriaSet.add(
        +						new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
         				criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
         				if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
        -					errors.add(new Saml2Error(INVALID_SIGNATURE,
        +					errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
         							"Invalid signature for SAML Response [" + response.getID() + "]"));
         				}
        -			} catch (Exception e) {
        -				errors.add(new Saml2Error(INVALID_SIGNATURE,
        +			}
        +			catch (Exception ex) {
        +				errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
         						"Invalid signature for SAML Response [" + response.getID() + "]: "));
         			}
         		}
        -
         		String destination = response.getDestination();
         		String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
         		if (StringUtils.hasText(destination) && !destination.equals(location)) {
         			String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
        -			errors.add(new Saml2Error(INVALID_DESTINATION, message));
        +			errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
         		}
        -
         		String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
         		if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
         			String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
        -			errors.add(new Saml2Error(INVALID_ISSUER, message));
        +			errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
         		}
         
        -		return failure(errors);
        +		return Saml2ResponseValidatorResult.failure(errors);
         	}
         
        -	private List decryptAssertions
        -			(Decrypter decrypter, Response response) {
        +	private List decryptAssertions(Decrypter decrypter, Response response) {
         		List assertions = new ArrayList<>();
         		for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
         			try {
         				Assertion assertion = decrypter.decrypt(encryptedAssertion);
         				assertions.add(assertion);
        -			} catch (DecryptionException e) {
        -				throw authException(DECRYPTION_ERROR, e.getMessage(), e);
        +			}
        +			catch (DecryptionException ex) {
        +				throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
         			}
         		}
         		response.getAssertions().addAll(assertions);
         		return response.getAssertions();
         	}
         
        -	private Saml2ResponseValidatorResult validateAssertions
        -			(Saml2AuthenticationToken token, Response response) {
        +	private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
         		List assertions = response.getAssertions();
         		if (assertions.isEmpty()) {
        -			throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
        +			throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
        +					"No assertions found in response.", null);
         		}
         
        -		Saml2ResponseValidatorResult result = success();
        +		Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
         		if (logger.isDebugEnabled()) {
         			logger.debug("Validating " + assertions.size() + " assertions");
         		}
        @@ -560,25 +543,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         				logger.trace("Validating assertion " + assertion.getID());
         			}
         			AssertionToken assertionToken = new AssertionToken(assertion, token);
        -			result = result
        -					.concat(this.assertionSignatureValidator.convert(assertionToken))
        +			result = result.concat(this.assertionSignatureValidator.convert(assertionToken))
         					.concat(this.assertionValidator.convert(assertionToken));
         		}
         
         		return result;
         	}
         
        +	private void addValidationException(Map exceptions, String code,
        +			String message, Exception cause) {
        +		exceptions.put(code, createAuthenticationException(code, message, cause));
        +	}
        +
         	private boolean isSigned(boolean responseSigned, List assertions) {
         		if (responseSigned) {
         			return true;
         		}
        -
         		for (Assertion assertion : assertions) {
         			if (!assertion.isSigned()) {
         				return false;
         			}
         		}
        -
         		return true;
         	}
         
        @@ -593,8 +578,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         			NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID());
         			assertion.getSubject().setNameID(nameId);
         			return nameId;
        -		} catch (DecryptionException e) {
        -			throw authException(DECRYPTION_ERROR, e.getMessage(), e);
        +		}
        +		catch (DecryptionException ex) {
        +			throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
         		}
         	}
         
        @@ -602,7 +588,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         		Map> attributeMap = new LinkedHashMap<>();
         		for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
         			for (Attribute attribute : attributeStatement.getAttributes()) {
        -
         				List attributeValues = new ArrayList<>();
         				for (XMLObject xmlObject : attribute.getAttributeValues()) {
         					Object attributeValue = getXmlObjectValue(xmlObject);
        @@ -611,7 +596,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         					}
         				}
         				attributeMap.put(attribute.getName(), attributeValues);
        -
         			}
         		}
         		return attributeMap;
        @@ -632,21 +616,64 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         		}
         		if (xmlObject instanceof XSBoolean) {
         			XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue();
        -			return xsBooleanValue != null ? xsBooleanValue.getValue() : null;
        +			return (xsBooleanValue != null) ? xsBooleanValue.getValue() : null;
         		}
         		if (xmlObject instanceof XSDateTime) {
         			DateTime dateTime = ((XSDateTime) xmlObject).getValue();
        -			return dateTime != null ? Instant.ofEpochMilli(dateTime.getMillis()) : null;
        +			return (dateTime != null) ? Instant.ofEpochMilli(dateTime.getMillis()) : null;
         		}
         		return null;
         	}
         
        -	private static class SignatureTrustEngineConverter implements Converter {
        +	private static Saml2AuthenticationException createAuthenticationException(String code, String message,
        +			Exception cause) {
        +		return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
        +	}
        +
        +	private static Converter createDefaultAssertionValidator(
        +			String errorCode, Converter validatorConverter,
        +			Converter contextConverter) {
        +
        +		return (assertionToken) -> {
        +			Assertion assertion = assertionToken.assertion;
        +			SAML20AssertionValidator validator = validatorConverter.convert(assertionToken);
        +			ValidationContext context = contextConverter.convert(assertionToken);
        +			try {
        +				ValidationResult result = validator.validate(assertion, context);
        +				if (result == ValidationResult.VALID) {
        +					return Saml2ResponseValidatorResult.success();
        +				}
        +			}
        +			catch (Exception ex) {
        +				String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
        +						((Response) assertion.getParent()).getID(), ex.getMessage());
        +				return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
        +			}
        +			String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s", assertion.getID(),
        +					((Response) assertion.getParent()).getID(), context.getValidationFailureMessage());
        +			return Saml2ResponseValidatorResult.failure(new Saml2Error(errorCode, message));
        +		};
        +	}
        +
        +	private static ValidationContext createValidationContext(AssertionToken assertionToken,
        +			Consumer> paramsConsumer) {
        +		String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId();
        +		String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
        +		Map params = new HashMap<>();
        +		params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
        +		params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
        +		paramsConsumer.accept(params);
        +		return new ValidationContext(params);
        +	}
        +
        +	private static class SignatureTrustEngineConverter
        +			implements Converter {
         
         		@Override
         		public SignatureTrustEngine convert(Saml2AuthenticationToken token) {
         			Set credentials = new HashSet<>();
        -			Collection keys = token.getRelyingPartyRegistration().getAssertingPartyDetails().getVerificationX509Credentials();
        +			Collection keys = token.getRelyingPartyRegistration().getAssertingPartyDetails()
        +					.getVerificationX509Credentials();
         			for (Saml2X509Credential key : keys) {
         				BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
         				cred.setUsageType(UsageType.SIGNING);
        @@ -654,55 +681,20 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         				credentials.add(cred);
         			}
         			CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
        -			return new ExplicitKeySignatureTrustEngine(
        -					credentialsResolver,
        -					DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()
        -			);
        +			return new ExplicitKeySignatureTrustEngine(credentialsResolver,
        +					DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver());
         		}
        -	}
         
        -	private static Converter createDefaultAssertionValidator(
        -			String errorCode,
        -			Converter validatorConverter,
        -			Converter contextConverter) {
        -
        -		return assertionToken -> {
        -			Assertion assertion = assertionToken.assertion;
        -			SAML20AssertionValidator validator = validatorConverter.convert(assertionToken);
        -			ValidationContext context = contextConverter.convert(assertionToken);
        -			try {
        -				ValidationResult result = validator.validate(assertion, context);
        -				if (result == ValidationResult.VALID) {
        -					return success();
        -				}
        -			} catch (Exception e) {
        -				String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
        -						assertion.getID(), ((Response) assertion.getParent()).getID(),
        -						e.getMessage());
        -				return failure(new Saml2Error(errorCode, message));
        -			}
        -			String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
        -					assertion.getID(), ((Response) assertion.getParent()).getID(),
        -					context.getValidationFailureMessage());
        -			return failure(new Saml2Error(errorCode, message));
        -		};
        -	}
        -
        -	private static ValidationContext createValidationContext(
        -			AssertionToken assertionToken, Consumer> paramsConsumer) {
        -		String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId();
        -		String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
        -		Map params = new HashMap<>();
        -		params.put(COND_VALID_AUDIENCES, singleton(audience));
        -		params.put(SC_VALID_RECIPIENTS, singleton(recipient));
        -		paramsConsumer.accept(params);
        -		return new ValidationContext(params);
         	}
         
         	private static class SAML20AssertionValidators {
        +
         		private static final Collection conditions = new ArrayList<>();
        +
         		private static final Collection subjects = new ArrayList<>();
        +
         		private static final Collection statements = new ArrayList<>();
        +
         		private static final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
         
         		static {
        @@ -733,18 +725,18 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         			});
         		}
         
        -		private static final SAML20AssertionValidator attributeValidator =
        -			new SAML20AssertionValidator(conditions, subjects, statements, null, null) {
        -				@Nonnull
        -				@Override
        -				protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
        -					return ValidationResult.VALID;
        -				}
        -			};
        +		private static final SAML20AssertionValidator attributeValidator = new SAML20AssertionValidator(conditions,
        +				subjects, statements, null, null) {
        +			@Nonnull
        +			@Override
        +			protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
        +				return ValidationResult.VALID;
        +			}
        +		};
         
         		static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) {
        -			return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
        -					engine, validator) {
        +			return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), engine,
        +					validator) {
         				@Nonnull
         				@Override
         				protected ValidationResult validateConditions(Assertion assertion, ValidationContext context) {
        @@ -763,17 +755,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         					return ValidationResult.VALID;
         				}
         			};
        +
         		}
        +
         	}
         
         	private static class DecrypterConverter implements Converter {
        +
         		private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
        -				asList(
        -						new InlineEncryptedKeyResolver(),
        -						new EncryptedElementTypeEncryptedKeyResolver(),
        -						new SimpleRetrievalMethodEncryptedKeyResolver()
        -				)
        -		);
        +				Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(),
        +						new SimpleRetrievalMethodEncryptedKeyResolver()));
         
         		@Override
         		public Decrypter convert(Saml2AuthenticationToken token) {
        @@ -787,31 +778,19 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         			decrypter.setRootInNewDocument(true);
         			return decrypter;
         		}
        -	}
         
        -	private static Saml2Error validationError(String code, String description) {
        -		return new Saml2Error(code, description);
        -	}
        -
        -	private static Saml2AuthenticationException authException(String code, String description)
        -			throws Saml2AuthenticationException {
        -
        -		return new Saml2AuthenticationException(validationError(code, description));
        -	}
        -
        -	private static Saml2AuthenticationException authException(String code, String description, Exception cause)
        -			throws Saml2AuthenticationException {
        -
        -		return new Saml2AuthenticationException(validationError(code, description), cause);
         	}
         
         	/**
        -	 * A tuple containing an OpenSAML {@link Response} and its associated authentication token.
        +	 * A tuple containing an OpenSAML {@link Response} and its associated authentication
        +	 * token.
         	 *
         	 * @since 5.4
         	 */
         	public static class ResponseToken {
        +
         		private final Saml2AuthenticationToken token;
        +
         		private final Response response;
         
         		ResponseToken(Response response, Saml2AuthenticationToken token) {
        @@ -826,15 +805,19 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         		public Saml2AuthenticationToken getToken() {
         			return this.token;
         		}
        +
         	}
         
         	/**
        -	 * A tuple containing an OpenSAML {@link Assertion} and its associated authentication token.
        +	 * A tuple containing an OpenSAML {@link Assertion} and its associated authentication
        +	 * token.
         	 *
         	 * @since 5.4
         	 */
         	public static class AssertionToken {
        +
         		private final Saml2AuthenticationToken token;
        +
         		private final Assertion assertion;
         
         		AssertionToken(Assertion assertion, Saml2AuthenticationToken token) {
        @@ -849,5 +832,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
         		public Saml2AuthenticationToken getToken() {
         			return this.token;
         		}
        +
         	}
        +
         }
        diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java
        index 8900cc00a2..0915607b2f 100644
        --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java
        +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java
        @@ -57,17 +57,14 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
         import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
         import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
         import org.springframework.util.Assert;
        +import org.springframework.util.StringUtils;
         import org.springframework.web.util.UriUtils;
         
        -import static java.nio.charset.StandardCharsets.UTF_8;
        -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
        -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
        -import static org.springframework.util.StringUtils.hasText;
        -
         /**
          * @since 5.2
          */
         public class OpenSamlAuthenticationRequestFactory implements Saml2AuthenticationRequestFactory {
        +
         	static {
         		OpenSamlInitializationService.initialize();
         	}
        @@ -75,19 +72,19 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         	private Clock clock = Clock.systemUTC();
         
         	private AuthnRequestMarshaller marshaller;
        +
         	private AuthnRequestBuilder authnRequestBuilder;
        +
         	private IssuerBuilder issuerBuilder;
         
        -	private Converter protocolBindingResolver =
        -			context -> {
        -				if (context == null) {
        -					return SAMLConstants.SAML2_POST_BINDING_URI;
        -				}
        -				return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
        -			};
        +	private Converter protocolBindingResolver = (context) -> {
        +		if (context == null) {
        +			return SAMLConstants.SAML2_POST_BINDING_URI;
        +		}
        +		return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
        +	};
         
        -	private Converter authenticationRequestContextConverter
        -			= this::createAuthnRequest;
        +	private Converter authenticationRequestContextConverter = this::createAuthnRequest;
         
         	/**
         	 * Creates an {@link OpenSamlAuthenticationRequestFactory}
        @@ -98,81 +95,64 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         				.getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
         		this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory()
         				.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
        -		this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory()
        -				.getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
        +		this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
         	}
         
         	@Override
         	@Deprecated
         	public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
        -		AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(),
        -				request.getDestination(), request.getAssertionConsumerServiceUrl(),
        -				this.protocolBindingResolver.convert(null));
        +		AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(), request.getDestination(),
        +				request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null));
         		for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
         			if (credential.isSigningCredential()) {
        -				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), request.getIssuer());
        +				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
        +						request.getIssuer());
         				return serialize(sign(authnRequest, cred));
         			}
         		}
         		throw new IllegalArgumentException("No signing credential provided");
         	}
         
        -	/**
        -	 * {@inheritDoc}
        -	 */
         	@Override
         	public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
         		AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
        -		String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
        -			serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
        -			serialize(authnRequest);
        +		String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()
        +				? serialize(sign(authnRequest, context.getRelyingPartyRegistration())) : serialize(authnRequest);
         
         		return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
        -				.samlRequest(samlEncode(xml.getBytes(UTF_8)))
        -				.build();
        +				.samlRequest(Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))).build();
         	}
         
        -	/**
        -	 * {@inheritDoc}
        -	 */
         	@Override
        -	public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
        +	public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(
        +			Saml2AuthenticationRequestContext context) {
         		AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
         		String xml = serialize(authnRequest);
         		Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
        -		String deflatedAndEncoded = samlEncode(samlDeflate(xml));
        -		result.samlRequest(deflatedAndEncoded)
        -				.relayState(context.getRelayState());
        -
        +		String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
        +		result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
         		if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
        -			Collection signingCredentials = context.getRelyingPartyRegistration().getSigningX509Credentials();
        +			Collection signingCredentials = context.getRelyingPartyRegistration()
        +					.getSigningX509Credentials();
         			for (Saml2X509Credential credential : signingCredentials) {
         				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), "");
        -				Map signedParams = signQueryParameters(
        -						cred,
        -						deflatedAndEncoded,
        +				Map signedParams = signQueryParameters(cred, deflatedAndEncoded,
         						context.getRelayState());
        -				return result
        -						.samlRequest(signedParams.get("SAMLRequest"))
        -						.relayState(signedParams.get("RelayState"))
        -						.sigAlg(signedParams.get("SigAlg"))
        -						.signature(signedParams.get("Signature"))
        -						.build();
        +				return result.samlRequest(signedParams.get("SAMLRequest")).relayState(signedParams.get("RelayState"))
        +						.sigAlg(signedParams.get("SigAlg")).signature(signedParams.get("Signature")).build();
         			}
         			throw new Saml2Exception("No signing credential provided");
         		}
        -
         		return result.build();
         	}
         
         	private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
        -		return createAuthnRequest(context.getIssuer(),
        -				context.getDestination(), context.getAssertionConsumerServiceUrl(),
        -				this.protocolBindingResolver.convert(context));
        +		return createAuthnRequest(context.getIssuer(), context.getDestination(),
        +				context.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(context));
         	}
         
        -	private AuthnRequest createAuthnRequest
        -			(String issuer, String destination, String assertionConsumerServiceUrl, String protocolBinding) {
        +	private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl,
        +			String protocolBinding) {
         		AuthnRequest auth = this.authnRequestBuilder.buildObject();
         		auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
         		auth.setIssueInstant(new DateTime(this.clock.millis()));
        @@ -189,7 +169,6 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         
         	/**
         	 * Set the {@link AuthnRequest} post-processor resolver
        -	 *
         	 * @param authenticationRequestContextConverter
         	 * @since 5.4
         	 */
        @@ -200,10 +179,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         	}
         
         	/**
        -	 * '
        -	 * Use this {@link Clock} with {@link Instant#now()} for generating
        -	 * timestamps
        -	 *
        +	 * ' Use this {@link Clock} with {@link Instant#now()} for generating timestamps
         	 * @param clock
         	 */
         	public void setClock(Clock clock) {
        @@ -214,30 +190,30 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         	/**
         	 * Sets the {@code protocolBinding} to use when generating authentication requests.
         	 * Acceptable values are {@link SAMLConstants#SAML2_POST_BINDING_URI} and
        -	 * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI}
        -	 * The IDP will be reading this value in the {@code AuthNRequest} to determine how to
        -	 * send the Response/Assertion to the ACS URL, assertion consumer service URL.
        -	 *
        +	 * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI} The IDP will be reading this value
        +	 * in the {@code AuthNRequest} to determine how to send the Response/Assertion to the
        +	 * ACS URL, assertion consumer service URL.
         	 * @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or
         	 * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI}
         	 * @throws IllegalArgumentException if the protocolBinding is not valid
        -	 * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding(Saml2MessageBinding)}
        +	 * @deprecated Use
        +	 * {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding(Saml2MessageBinding)}
         	 * instead
         	 */
         	@Deprecated
         	public void setProtocolBinding(String protocolBinding) {
        -		boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding) ||
        -				SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding);
        +		boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding)
        +				|| SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding);
         		if (!isAllowedBinding) {
         			throw new IllegalArgumentException("Invalid protocol binding: " + protocolBinding);
         		}
        -		this.protocolBindingResolver = context -> protocolBinding;
        +		this.protocolBindingResolver = (context) -> protocolBinding;
         	}
         
         	private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
         		for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) {
        -			Credential cred = getSigningCredential(
        -					credential.getCertificate(), credential.getPrivateKey(), relyingPartyRegistration.getEntityId());
        +			Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
        +					relyingPartyRegistration.getEntityId());
         			return sign(authnRequest, cred);
         		}
         		throw new IllegalArgumentException("No signing credential provided");
        @@ -252,8 +228,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         		try {
         			SignatureSupport.signObject(authnRequest, parameters);
         			return authnRequest;
        -		} catch (MarshallingException | SignatureException | SecurityException e) {
        -			throw new Saml2Exception(e);
        +		}
        +		catch (MarshallingException | SignatureException | SecurityException ex) {
        +			throw new Saml2Exception(ex);
         		}
         	}
         
        @@ -264,49 +241,32 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         		return cred;
         	}
         
        -	private Map signQueryParameters(
        -			Credential credential,
        -			String samlRequest,
        -			String relayState) {
        +	private Map signQueryParameters(Credential credential, String samlRequest, String relayState) {
         		Assert.notNull(samlRequest, "samlRequest cannot be null");
         		String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256;
         		StringBuilder queryString = new StringBuilder();
        -		queryString
        -				.append("SAMLRequest")
        -				.append("=")
        -				.append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
        +		queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
         				.append("&");
        -		if (hasText(relayState)) {
        -			queryString
        -					.append("RelayState")
        -					.append("=")
        -					.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1))
        -					.append("&");
        +		if (StringUtils.hasText(relayState)) {
        +			queryString.append("RelayState").append("=")
        +					.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
         		}
        -		queryString
        -				.append("SigAlg")
        -				.append("=")
        -				.append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
        -
        +		queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
         		try {
        -			byte[] rawSignature = XMLSigningUtil.signWithURI(
        -					credential,
        -					algorithmUri,
        -					queryString.toString().getBytes(StandardCharsets.UTF_8)
        -			);
        +			byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
        +					queryString.toString().getBytes(StandardCharsets.UTF_8));
         			String b64Signature = Saml2Utils.samlEncode(rawSignature);
        -
         			Map result = new LinkedHashMap<>();
         			result.put("SAMLRequest", samlRequest);
        -			if (hasText(relayState)) {
        +			if (StringUtils.hasText(relayState)) {
         				result.put("RelayState", relayState);
         			}
         			result.put("SigAlg", algorithmUri);
         			result.put("Signature", b64Signature);
         			return result;
         		}
        -		catch (SecurityException e) {
        -			throw new Saml2Exception(e);
        +		catch (SecurityException ex) {
        +			throw new Saml2Exception(ex);
         		}
         	}
         
        @@ -314,8 +274,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
         		try {
         			Element element = this.marshaller.marshall(authnRequest);
         			return SerializeSupport.nodeToString(element);
        -		} catch (MarshallingException e) {
        -			throw new Saml2Exception(e);
        +		}
        +		catch (MarshallingException ex) {
        +			throw new Saml2Exception(ex);
         		}
         	}
        +
         }
        diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java
        index 54cb297ffb..5996b0a4c5 100644
        --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java
        +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java
        @@ -16,14 +16,14 @@
         
         package org.springframework.security.saml2.provider.service.authentication;
         
        -import org.springframework.lang.Nullable;
        -import org.springframework.security.core.AuthenticatedPrincipal;
        -import org.springframework.util.CollectionUtils;
        -
         import java.util.Collections;
         import java.util.List;
         import java.util.Map;
         
        +import org.springframework.lang.Nullable;
        +import org.springframework.security.core.AuthenticatedPrincipal;
        +import org.springframework.util.CollectionUtils;
        +
         /**
          * Saml2 representation of an {@link AuthenticatedPrincipal}.
          *
        @@ -31,9 +31,9 @@ import java.util.Map;
          * @since 5.2.2
          */
         public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
        +
         	/**
         	 * Get the first value of Saml2 token attribute by name
        -	 *
         	 * @param name the name of the attribute
         	 * @param  the type of the attribute
         	 * @return the first attribute value or {@code null} otherwise
        @@ -47,7 +47,6 @@ public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
         
         	/**
         	 * Get the Saml2 token attribute by name
        -	 *
         	 * @param name the name of the attribute
         	 * @param  the type of the attribute
         	 * @return the attribute or {@code null} otherwise
        @@ -60,11 +59,11 @@ public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
         
         	/**
         	 * Get the Saml2 token attributes
        -	 *
         	 * @return the Saml2 token attributes
         	 * @since 5.4
         	 */
         	default Map> getAttributes() {
         		return Collections.emptyMap();
         	}
        +
         }
        diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Authentication.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Authentication.java
        index a2a5951648..d37792456b 100644
        --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Authentication.java
        +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Authentication.java
        @@ -16,32 +16,32 @@
         
         package org.springframework.security.saml2.provider.service.authentication;
         
        +import java.util.Collection;
        +
         import org.springframework.security.authentication.AbstractAuthenticationToken;
         import org.springframework.security.core.AuthenticatedPrincipal;
         import org.springframework.security.core.Authentication;
         import org.springframework.security.core.GrantedAuthority;
         import org.springframework.util.Assert;
         
        -import java.util.Collection;
        -
         /**
        - * An implementation of an {@link AbstractAuthenticationToken}
        - * that represents an authenticated SAML 2.0 {@link Authentication}.
        + * An implementation of an {@link AbstractAuthenticationToken} that represents an
        + * authenticated SAML 2.0 {@link Authentication}.
          * 

        - * The {@link Authentication} associates valid SAML assertion - * data with a Spring Security authentication object - * The complete assertion is contained in the object in String format, - * {@link Saml2Authentication#getSaml2Response()} + * The {@link Authentication} associates valid SAML assertion data with a Spring Security + * authentication object The complete assertion is contained in the object in String + * format, {@link Saml2Authentication#getSaml2Response()} + * * @since 5.2 * @see AbstractAuthenticationToken */ public class Saml2Authentication extends AbstractAuthenticationToken { private final AuthenticatedPrincipal principal; + private final String saml2Response; - public Saml2Authentication(AuthenticatedPrincipal principal, - String saml2Response, + public Saml2Authentication(AuthenticatedPrincipal principal, String saml2Response, Collection authorities) { super(authorities); Assert.notNull(principal, "principal cannot be null"); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationException.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationException.java index 12b16a7358..a4f4610833 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationException.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationException.java @@ -27,23 +27,23 @@ import org.springframework.util.Assert; *

        * There are a number of scenarios where an error may occur, for example: *

          - *
        • The response or assertion request is missing or malformed
        • - *
        • Missing or invalid subject
        • - *
        • Missing or invalid signatures
        • - *
        • The time period validation for the assertion fails
        • - *
        • One of the assertion conditions was not met
        • - *
        • Decryption failed
        • - *
        • Unable to locate a subject identifier, commonly known as username
        • + *
        • The response or assertion request is missing or malformed
        • + *
        • Missing or invalid subject
        • + *
        • Missing or invalid signatures
        • + *
        • The time period validation for the assertion fails
        • + *
        • One of the assertion conditions was not met
        • + *
        • Decryption failed
        • + *
        • Unable to locate a subject identifier, commonly known as username
        • *
        * * @since 5.2 */ public class Saml2AuthenticationException extends AuthenticationException { - private Saml2Error error; + + private final Saml2Error error; /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * * @param error the {@link Saml2Error SAML 2.0 Error} */ public Saml2AuthenticationException(Saml2Error error) { @@ -52,90 +52,99 @@ public class Saml2AuthenticationException extends AuthenticationException { /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * * @param error the {@link Saml2Error SAML 2.0 Error} * @param cause the root cause */ public Saml2AuthenticationException(Saml2Error error, Throwable cause) { - this(error, cause.getMessage(), cause); + this(error, (cause != null) ? cause.getMessage() : error.getDescription(), cause); } /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * * @param error the {@link Saml2Error SAML 2.0 Error} * @param message the detail message */ public Saml2AuthenticationException(Saml2Error error, String message) { - super(message); - this.setError(error); + this(error, message, null); } /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * * @param error the {@link Saml2Error SAML 2.0 Error} * @param message the detail message * @param cause the root cause */ public Saml2AuthenticationException(Saml2Error error, String message, Throwable cause) { super(message, cause); - this.setError(error); + Assert.notNull(error, "error cannot be null"); + this.error = error; } /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * - * @param error the {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error SAML 2.0 Error} - * @deprecated Use {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error} constructor instead + * @param error the + * {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error + * SAML 2.0 Error} + * @deprecated Use + * {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error} + * constructor instead */ @Deprecated - public Saml2AuthenticationException(org.springframework.security.saml2.provider.service.authentication.Saml2Error error) { + public Saml2AuthenticationException( + org.springframework.security.saml2.provider.service.authentication.Saml2Error error) { this(error, error.getDescription()); } /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * - * @param error the {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error SAML 2.0 Error} + * @param error the + * {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error + * SAML 2.0 Error} * @param cause the root cause - * @deprecated Use {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error} constructor instead + * @deprecated Use + * {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error} + * constructor instead */ @Deprecated - public Saml2AuthenticationException(org.springframework.security.saml2.provider.service.authentication.Saml2Error error, Throwable cause) { + public Saml2AuthenticationException( + org.springframework.security.saml2.provider.service.authentication.Saml2Error error, Throwable cause) { this(error, cause.getMessage(), cause); } /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * * @param error the {@link Saml2Error SAML 2.0 Error} * @param message the detail message * @deprecated Use {@link Saml2Error} constructor instead */ @Deprecated - public Saml2AuthenticationException(org.springframework.security.saml2.provider.service.authentication.Saml2Error error, String message) { - super(message); - this.setError(error); + public Saml2AuthenticationException( + org.springframework.security.saml2.provider.service.authentication.Saml2Error error, String message) { + this(error, message, null); } /** * Constructs a {@code Saml2AuthenticationException} using the provided parameters. - * - * @param error the {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error SAML 2.0 Error} + * @param error the + * {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error + * SAML 2.0 Error} * @param message the detail message * @param cause the root cause - * @deprecated Use {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error} constructor instead + * @deprecated Use + * {@link org.springframework.security.saml2.provider.service.authentication.Saml2Error} + * constructor instead */ @Deprecated - public Saml2AuthenticationException(org.springframework.security.saml2.provider.service.authentication.Saml2Error error, String message, Throwable cause) { + public Saml2AuthenticationException( + org.springframework.security.saml2.provider.service.authentication.Saml2Error error, String message, + Throwable cause) { super(message, cause); - this.setError(error); + Assert.notNull(error, "error cannot be null"); + this.error = new Saml2Error(error.getErrorCode(), error.getDescription()); } /** * Get the associated {@link Saml2Error} - * * @return the associated {@link Saml2Error} */ public Saml2Error getSaml2Error() { @@ -144,7 +153,6 @@ public class Saml2AuthenticationException extends AuthenticationException { /** * Returns the {@link Saml2Error SAML 2.0 Error}. - * * @return the {@link Saml2Error} * @deprecated Use {@link #getSaml2Error()} instead */ @@ -154,20 +162,12 @@ public class Saml2AuthenticationException extends AuthenticationException { this.error.getErrorCode(), this.error.getDescription()); } - private void setError(Saml2Error error) { - Assert.notNull(error, "error cannot be null"); - this.error = error; - } - - private void setError(org.springframework.security.saml2.provider.service.authentication.Saml2Error error) { - setError(new Saml2Error(error.getErrorCode(), error.getDescription())); - } - @Override public String toString() { final StringBuffer sb = new StringBuffer("Saml2AuthenticationException{"); - sb.append("error=").append(error); + sb.append("error=").append(this.error); sb.append('}'); return sb.toString(); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequest.java index 1fcc1a80a2..1cac7d9fb5 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequest.java @@ -16,33 +16,35 @@ package org.springframework.security.saml2.provider.service.authentication; -import org.springframework.security.saml2.credentials.Saml2X509Credential; -import org.springframework.util.Assert; - import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.function.Consumer; +import org.springframework.security.saml2.credentials.Saml2X509Credential; +import org.springframework.util.Assert; + /** - * Data holder for information required to send an {@code AuthNRequest} - * from the service provider to the identity provider - * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf (line 2031) + * Data holder for information required to send an {@code AuthNRequest} from the service + * provider to the identity provider + * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf + * (line 2031) * * @since 5.2 * @deprecated use {@link Saml2AuthenticationRequestContext} */ @Deprecated public final class Saml2AuthenticationRequest { + private final String issuer; + private final List credentials; + private final String destination; + private final String assertionConsumerServiceUrl; - private Saml2AuthenticationRequest( - String issuer, - String destination, - String assertionConsumerServiceUrl, + private Saml2AuthenticationRequest(String issuer, String destination, String assertionConsumerServiceUrl, List credentials) { Assert.hasText(issuer, "issuer cannot be null"); Assert.hasText(destination, "destination cannot be null"); @@ -58,10 +60,9 @@ public final class Saml2AuthenticationRequest { } } - /** - * returns the issuer, the local SP entity ID, for this authentication request. - * This property should be used to populate the {@code AuthNRequest.Issuer} XML element. + * returns the issuer, the local SP entity ID, for this authentication request. This + * property should be used to populate the {@code AuthNRequest.Issuer} XML element. * This value typically is a URI, but can be an arbitrary string. * @return issuer */ @@ -70,8 +71,9 @@ public final class Saml2AuthenticationRequest { } /** - * returns the destination, the WEB Single Sign On URI, for this authentication request. - * This property populates the {@code AuthNRequest#Destination} XML attribute. + * returns the destination, the WEB Single Sign On URI, for this authentication + * request. This property populates the {@code AuthNRequest#Destination} XML + * attribute. * @return destination */ public String getDestination() { @@ -79,17 +81,18 @@ public final class Saml2AuthenticationRequest { } /** - * Returns the desired {@code AssertionConsumerServiceUrl} that this SP wishes to receive the - * assertion on. The IDP may or may not honor this request. - * This property populates the {@code AuthNRequest#AssertionConsumerServiceURL} XML attribute. + * Returns the desired {@code AssertionConsumerServiceUrl} that this SP wishes to + * receive the assertion on. The IDP may or may not honor this request. This property + * populates the {@code AuthNRequest#AssertionConsumerServiceURL} XML attribute. * @return the AssertionConsumerServiceURL value */ public String getAssertionConsumerServiceUrl() { - return assertionConsumerServiceUrl; + return this.assertionConsumerServiceUrl; } /** - * Returns a list of credentials that can be used to sign the {@code AuthNRequest} object + * Returns a list of credentials that can be used to sign the {@code AuthNRequest} + * object * @return signing credentials */ public List getCredentials() { @@ -97,8 +100,7 @@ public final class Saml2AuthenticationRequest { } /** - * A builder for {@link Saml2AuthenticationRequest}. - * returns a builder object + * A builder for {@link Saml2AuthenticationRequest}. returns a builder object */ public static Builder builder() { return new Builder(); @@ -106,25 +108,25 @@ public final class Saml2AuthenticationRequest { /** * A builder for {@link Saml2AuthenticationRequest}. - * @param context a context object to copy values from. - * returns a builder object + * @param context a context object to copy values from. returns a builder object */ public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) { - return new Builder() - .assertionConsumerServiceUrl(context.getAssertionConsumerServiceUrl()) - .issuer(context.getIssuer()) - .destination(context.getDestination()) - .credentials(c -> c.addAll(context.getRelyingPartyRegistration().getCredentials())) - ; + return new Builder().assertionConsumerServiceUrl(context.getAssertionConsumerServiceUrl()) + .issuer(context.getIssuer()).destination(context.getDestination()) + .credentials((c) -> c.addAll(context.getRelyingPartyRegistration().getCredentials())); } /** * A builder for {@link Saml2AuthenticationRequest}. */ - public static class Builder { + public static final class Builder { + private String issuer; + private List credentials = new LinkedList<>(); + private String destination; + private String assertionConsumerServiceUrl; private Builder() { @@ -141,14 +143,12 @@ public final class Saml2AuthenticationRequest { } /** - * Modifies the collection of {@link Saml2X509Credential} credentials - * used in communication between IDP and SP, specifically signing the - * authentication request. - * For example: - * + * Modifies the collection of {@link Saml2X509Credential} credentials used in + * communication between IDP and SP, specifically signing the authentication + * request. For example: * Saml2X509Credential credential = ...; * return Saml2AuthenticationRequest.withLocalSpEntityId("id") - * .credentials(c -> c.add(credential)) + * .credentials((c) -> c.add(credential)) * ... * .build(); * @@ -161,7 +161,8 @@ public final class Saml2AuthenticationRequest { } /** - * Sets the Destination for the authentication request. Typically the {@code Service Provider EntityID} + * Sets the Destination for the authentication request. Typically the + * {@code Service Provider EntityID} * @param destination - a required value * @return this {@code Builder} */ @@ -184,15 +185,13 @@ public final class Saml2AuthenticationRequest { /** * Creates a {@link Saml2AuthenticationRequest} object. * @return the Saml2AuthenticationRequest object - * @throws {@link IllegalArgumentException} if a required property is not set + * @throws IllegalArgumentException if a required property is not set */ public Saml2AuthenticationRequest build() { - return new Saml2AuthenticationRequest( - this.issuer, - this.destination, - this.assertionConsumerServiceUrl, - this.credentials - ); + return new Saml2AuthenticationRequest(this.issuer, this.destination, this.assertionConsumerServiceUrl, + this.credentials); } + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java index 01343f2b3c..a8bf6aeb51 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java @@ -20,26 +20,27 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.util.Assert; /** - * Data holder for information required to create an {@code AuthNRequest} - * to be sent from the service provider to the identity provider - *
        + * Data holder for information required to create an {@code AuthNRequest} to be sent from + * the service provider to the identity provider * Assertions and Protocols for SAML 2 (line 2031) * + * @since 5.3 * @see Saml2AuthenticationRequestFactory#createPostAuthenticationRequest(Saml2AuthenticationRequestContext) * @see Saml2AuthenticationRequestFactory#createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext) - * @since 5.3 */ public class Saml2AuthenticationRequestContext { + private final RelyingPartyRegistration relyingPartyRegistration; + private final String issuer; + private final String assertionConsumerServiceUrl; + private final String relayState; - protected Saml2AuthenticationRequestContext( - RelyingPartyRegistration relyingPartyRegistration, - String issuer, - String assertionConsumerServiceUrl, - String relayState) { + protected Saml2AuthenticationRequestContext(RelyingPartyRegistration relyingPartyRegistration, String issuer, + String assertionConsumerServiceUrl, String relayState) { Assert.hasText(issuer, "issuer cannot be null or empty"); Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null"); Assert.hasText(assertionConsumerServiceUrl, "spAssertionConsumerServiceUrl cannot be null or empty"); @@ -50,7 +51,8 @@ public class Saml2AuthenticationRequestContext { } /** - * Returns the {@link RelyingPartyRegistration} configuration for which the AuthNRequest is intended for. + * Returns the {@link RelyingPartyRegistration} configuration for which the + * AuthNRequest is intended for. * @return the {@link RelyingPartyRegistration} configuration */ public RelyingPartyRegistration getRelyingPartyRegistration() { @@ -59,8 +61,8 @@ public class Saml2AuthenticationRequestContext { /** * Returns the {@code Issuer} value to be used in the {@code AuthNRequest} object. - * This property should be used to populate the {@code AuthNRequest.Issuer} XML element. - * This value typically is a URI, but can be an arbitrary string. + * This property should be used to populate the {@code AuthNRequest.Issuer} XML + * element. This value typically is a URI, but can be an arbitrary string. * @return the Issuer value */ public String getIssuer() { @@ -68,13 +70,13 @@ public class Saml2AuthenticationRequestContext { } /** - * Returns the desired {@code AssertionConsumerServiceUrl} that this SP wishes to receive the - * assertion on. The IDP may or may not honor this request. - * This property populates the {@code AuthNRequest.AssertionConsumerServiceURL} XML attribute. + * Returns the desired {@code AssertionConsumerServiceUrl} that this SP wishes to + * receive the assertion on. The IDP may or may not honor this request. This property + * populates the {@code AuthNRequest.AssertionConsumerServiceURL} XML attribute. * @return the AssertionConsumerServiceURL value */ public String getAssertionConsumerServiceUrl() { - return assertionConsumerServiceUrl; + return this.assertionConsumerServiceUrl; } /** @@ -86,8 +88,9 @@ public class Saml2AuthenticationRequestContext { } /** - * Returns the {@code Destination}, the WEB Single Sign On URI, for this authentication request. - * This property can also populate the {@code AuthNRequest.Destination} XML attribute. + * Returns the {@code Destination}, the WEB Single Sign On URI, for this + * authentication request. This property can also populate the + * {@code AuthNRequest.Destination} XML attribute. * @return the Destination value */ public String getDestination() { @@ -105,10 +108,14 @@ public class Saml2AuthenticationRequestContext { /** * A builder for {@link Saml2AuthenticationRequestContext}. */ - public static class Builder { + public static final class Builder { + private String issuer; + private String assertionConsumerServiceUrl; + private String relayState; + private RelyingPartyRegistration relyingPartyRegistration; private Builder() { @@ -125,7 +132,8 @@ public class Saml2AuthenticationRequestContext { } /** - * Sets the {@link RelyingPartyRegistration} used to build the authentication request. + * Sets the {@link RelyingPartyRegistration} used to build the authentication + * request. * @param relyingPartyRegistration - a required value * @return this {@code Builder} */ @@ -147,7 +155,8 @@ public class Saml2AuthenticationRequestContext { /** * Sets the {@code RelayState} parameter that will accompany this AuthNRequest - * @param relayState the relay state value, unencoded. if null or empty, the parameter will be removed from the map. + * @param relayState the relay state value, unencoded. if null or empty, the + * parameter will be removed from the map. * @return this object */ public Builder relayState(String relayState) { @@ -158,15 +167,13 @@ public class Saml2AuthenticationRequestContext { /** * Creates a {@link Saml2AuthenticationRequestContext} object. * @return the Saml2AuthenticationRequest object - * @throws {@link IllegalArgumentException} if a required property is not set + * @throws IllegalArgumentException if a required property is not set */ public Saml2AuthenticationRequestContext build() { - return new Saml2AuthenticationRequestContext( - this.relyingPartyRegistration, - this.issuer, - this.assertionConsumerServiceUrl, - this.relayState - ); + return new Saml2AuthenticationRequestContext(this.relyingPartyRegistration, this.issuer, + this.assertionConsumerServiceUrl, this.relayState); } + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactory.java index b70a4fe69c..db2b13585b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactory.java @@ -22,14 +22,10 @@ import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import static org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequest.withAuthenticationRequestContext; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode; - /** - * Component that generates AuthenticationRequest, samlp:AuthnRequestType XML, and accompanying - * signature data. - * as defined by https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf + * Component that generates AuthenticationRequest, samlp:AuthnRequestType + * XML, and accompanying signature data. as defined by + * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf * Page 50, Line 2147 * * @since 5.2 @@ -37,81 +33,83 @@ import static org.springframework.security.saml2.provider.service.authentication public interface Saml2AuthenticationRequestFactory { /** - * Creates an authentication request from the Service Provider, sp, to the Identity Provider, idp. - * The authentication result is an XML string that may be signed, encrypted, both or neither. - * This method only returns the {@code SAMLRequest} string for the request, and for a complete - * set of data parameters please use {@link #createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext)} - * or {@link #createPostAuthenticationRequest(Saml2AuthenticationRequestContext)} - * - * @param request information about the identity provider, - * the recipient of this authentication request and accompanying data - * @return XML data in the format of a String. This data may be signed, encrypted, both signed and encrypted with the - * signature embedded in the XML or neither signed and encrypted + * Creates an authentication request from the Service Provider, sp, to the Identity + * Provider, idp. The authentication result is an XML string that may be signed, + * encrypted, both or neither. This method only returns the {@code SAMLRequest} string + * for the request, and for a complete set of data parameters please use + * {@link #createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext)} or + * {@link #createPostAuthenticationRequest(Saml2AuthenticationRequestContext)} + * @param request information about the identity provider, the recipient of this + * authentication request and accompanying data + * @return XML data in the format of a String. This data may be signed, encrypted, + * both signed and encrypted with the signature embedded in the XML or neither signed + * and encrypted * @throws Saml2Exception when a SAML library exception occurs * @since 5.2 - * @deprecated please use {@link #createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext)} - * or {@link #createPostAuthenticationRequest(Saml2AuthenticationRequestContext)} - * This method will be removed in future versions of Spring Security + * @deprecated please use + * {@link #createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext)} or + * {@link #createPostAuthenticationRequest(Saml2AuthenticationRequestContext)} This + * method will be removed in future versions of Spring Security */ @Deprecated String createAuthenticationRequest(Saml2AuthenticationRequest request); /** - * Creates all the necessary AuthNRequest parameters for a REDIRECT binding. - * If the {@link Saml2AuthenticationRequestContext} doesn't contain any {@link Saml2X509CredentialType#SIGNING} credentials - * the result will not contain any signatures. - * The data set will be signed and encoded for REDIRECT binding including the DEFLATE encoding. - * It will contain the following parameters to be sent as part of the query string: - * {@code SAMLRequest, RelayState, SigAlg, Signature}. - * The default implementation, for sake of backwards compatibility, of this method returns the - * SAMLRequest message with an XML signature embedded, that should only be used for the{@link Saml2MessageBinding#POST} - * binding, but works over {@link Saml2MessageBinding#POST} with most providers. - * @param context - information about the identity provider, the recipient of this authentication request and - * accompanying data - * @return a {@link Saml2RedirectAuthenticationRequest} object with applicable http parameters - * necessary to make the AuthNRequest over a POST or REDIRECT binding. - * All parameters will be SAML encoded/deflated, but escaped, ie URI encoded or encoded for Form Data. + * Creates all the necessary AuthNRequest parameters for a REDIRECT binding. If the + * {@link Saml2AuthenticationRequestContext} doesn't contain any + * {@link Saml2X509CredentialType#SIGNING} credentials the result will not contain any + * signatures. The data set will be signed and encoded for REDIRECT binding including + * the DEFLATE encoding. It will contain the following parameters to be sent as part + * of the query string: {@code SAMLRequest, RelayState, SigAlg, Signature}. The + * default implementation, for sake of backwards compatibility, of this method returns + * the SAMLRequest message with an XML signature embedded, that should only be used + * for the{@link Saml2MessageBinding#POST} binding, but works over + * {@link Saml2MessageBinding#POST} with most providers. + * @param context - information about the identity provider, the recipient of this + * authentication request and accompanying data + * @return a {@link Saml2RedirectAuthenticationRequest} object with applicable http + * parameters necessary to make the AuthNRequest over a POST or REDIRECT binding. All + * parameters will be SAML encoded/deflated, but escaped, ie URI encoded or encoded + * for Form Data. * @throws Saml2Exception when a SAML library exception occurs * @since 5.3 */ default Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest( - Saml2AuthenticationRequestContext context - ) { - //backwards compatible with 5.2.x settings - Saml2AuthenticationRequest.Builder resultBuilder = withAuthenticationRequestContext(context); + Saml2AuthenticationRequestContext context) { + // backwards compatible with 5.2.x settings + Saml2AuthenticationRequest.Builder resultBuilder = Saml2AuthenticationRequest + .withAuthenticationRequestContext(context); String samlRequest = createAuthenticationRequest(resultBuilder.build()); - samlRequest = samlEncode(samlDeflate(samlRequest)); - return Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context) - .samlRequest(samlRequest) + samlRequest = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(samlRequest)); + return Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context).samlRequest(samlRequest) .build(); } - /** - * Creates all the necessary AuthNRequest parameters for a POST binding. - * If the {@link Saml2AuthenticationRequestContext} doesn't contain any {@link Saml2X509CredentialType#SIGNING} credentials - * the result will not contain any signatures. - * The data set will be signed and encoded for POST binding and if applicable signed with XML signatures. - * will contain the following parameters to be sent as part of the form data: {@code SAMLRequest, RelayState}. - * The default implementation of this method returns the SAMLRequest message with an XML signature embedded, - * that should only be used for the {@link Saml2MessageBinding#POST} binding. - * @param context - information about the identity provider, the recipient of this authentication request and - * accompanying data - * @return a {@link Saml2PostAuthenticationRequest} object with applicable http parameters - * necessary to make the AuthNRequest over a POST binding. - * All parameters will be SAML encoded but not escaped for Form Data. + * Creates all the necessary AuthNRequest parameters for a POST binding. If the + * {@link Saml2AuthenticationRequestContext} doesn't contain any + * {@link Saml2X509CredentialType#SIGNING} credentials the result will not contain any + * signatures. The data set will be signed and encoded for POST binding and if + * applicable signed with XML signatures. will contain the following parameters to be + * sent as part of the form data: {@code SAMLRequest, RelayState}. The default + * implementation of this method returns the SAMLRequest message with an XML signature + * embedded, that should only be used for the {@link Saml2MessageBinding#POST} + * binding. + * @param context - information about the identity provider, the recipient of this + * authentication request and accompanying data + * @return a {@link Saml2PostAuthenticationRequest} object with applicable http + * parameters necessary to make the AuthNRequest over a POST binding. All parameters + * will be SAML encoded but not escaped for Form Data. * @throws Saml2Exception when a SAML library exception occurs * @since 5.3 */ - default Saml2PostAuthenticationRequest createPostAuthenticationRequest( - Saml2AuthenticationRequestContext context - ) { - //backwards compatible with 5.2.x settings - Saml2AuthenticationRequest.Builder resultBuilder = withAuthenticationRequestContext(context); + default Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { + // backwards compatible with 5.2.x settings + Saml2AuthenticationRequest.Builder resultBuilder = Saml2AuthenticationRequest + .withAuthenticationRequestContext(context); String samlRequest = createAuthenticationRequest(resultBuilder.build()); - samlRequest = samlEncode(samlRequest.getBytes(StandardCharsets.UTF_8)); - return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context) - .samlRequest(samlRequest) + samlRequest = Saml2Utils.samlEncode(samlRequest.getBytes(StandardCharsets.UTF_8)); + return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context).samlRequest(samlRequest) .build(); } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java index 22146994f0..5f4f8fdb33 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java @@ -24,37 +24,34 @@ import org.springframework.security.saml2.credentials.Saml2X509Credential; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.Assert; -import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId; - /** - * Represents an incoming SAML 2.0 response containing an assertion that has not been validated. - * {@link Saml2AuthenticationToken#isAuthenticated()} will always return false. + * Represents an incoming SAML 2.0 response containing an assertion that has not been + * validated. {@link Saml2AuthenticationToken#isAuthenticated()} will always return false. * - * @since 5.2 * @author Filip Hanik * @author Josh Cummings + * @since 5.2 */ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { private final RelyingPartyRegistration relyingPartyRegistration; + private final String saml2Response; /** * Creates a {@link Saml2AuthenticationToken} with the provided parameters * - * Note that the given {@link RelyingPartyRegistration} should have all its - * templates resolved at this point. See + * Note that the given {@link RelyingPartyRegistration} should have all its templates + * resolved at this point. See * {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter} * for an example of performing that resolution. - * - * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to use + * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to + * use * @param saml2Response the SAML 2.0 response to authenticate * * @since 5.4 */ - public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, - String saml2Response) { - + public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) { super(Collections.emptyList()); Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null"); Assert.notNull(saml2Response, "saml2Response cannot be null"); @@ -65,26 +62,23 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Creates an authentication token from an incoming SAML 2 Response object * @param saml2Response inflated and decoded XML representation of the SAML 2 Response - * @param recipientUri the URL that the SAML 2 Response was received at. Used for validation + * @param recipientUri the URL that the SAML 2 Response was received at. Used for + * validation * @param idpEntityId the entity ID of the asserting entity * @param localSpEntityId the configured local SP, the relying party, entity ID - * @param credentials the credentials configured for signature verification and decryption - * @deprecated Use {@link Saml2AuthenticationToken(RelyingPartyRegistration, String)} instead + * @param credentials the credentials configured for signature verification and + * decryption + * @deprecated Use {@link #Saml2AuthenticationToken(RelyingPartyRegistration, String)} + * instead */ @Deprecated - public Saml2AuthenticationToken(String saml2Response, - String recipientUri, - String idpEntityId, - String localSpEntityId, - List credentials) { + public Saml2AuthenticationToken(String saml2Response, String recipientUri, String idpEntityId, + String localSpEntityId, List credentials) { super(null); - this.relyingPartyRegistration = withRegistrationId(idpEntityId) - .entityId(localSpEntityId) - .assertionConsumerServiceLocation(recipientUri) - .credentials(c -> c.addAll(credentials)) - .assertingPartyDetails(assertingParty -> assertingParty - .entityId(idpEntityId) - .singleSignOnServiceLocation(idpEntityId)) + this.relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId(idpEntityId) + .entityId(localSpEntityId).assertionConsumerServiceLocation(recipientUri) + .credentials((c) -> c.addAll(credentials)).assertingPartyDetails((assertingParty) -> assertingParty + .entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId)) .build(); this.saml2Response = saml2Response; } @@ -109,7 +103,6 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Get the resolved {@link RelyingPartyRegistration} associated with the request - * * @return the resolved {@link RelyingPartyRegistration} * @since 5.4 */ @@ -128,7 +121,8 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Returns the URI that the SAML 2 Response object came in on * @return URI as a string - * @deprecated Use {@link #getRelyingPartyRegistration().getAssertionConsumerServiceLocation()} instead + * @deprecated Use + * {@code getRelyingPartyRegistration().getAssertionConsumerServiceLocation()} instead */ @Deprecated public String getRecipientUri() { @@ -138,7 +132,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Returns the configured entity ID of the receiving relying party, SP * @return an entityID for the configured local relying party - * @deprecated Use {@link #getRelyingPartyRegistration().getEntityId()} instead + * @deprecated Use {@code getRelyingPartyRegistration().getEntityId()} instead */ @Deprecated public String getLocalSpEntityId() { @@ -147,8 +141,9 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Returns all the credentials associated with the relying party configuraiton - * @return - * @deprecated Get the credentials through {@link #getRelyingPartyRegistration()} instead + * @return all associated credentials + * @deprecated Get the credentials through {@link #getRelyingPartyRegistration()} + * instead */ @Deprecated public List getX509Credentials() { @@ -166,7 +161,6 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * The state of this object cannot be changed. Will always throw an exception * @param authenticated ignored - * @throws {@link IllegalArgumentException} */ @Override public void setAuthenticated(boolean authenticated) { @@ -176,10 +170,13 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken { /** * Returns the configured IDP, asserting party, entity ID * @return a string representing the entity ID - * @deprecated Use {@link #getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()} instead + * @deprecated Use + * {@code getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()} + * instead */ @Deprecated public String getIdpEntityId() { return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Error.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Error.java index 721b0d5b5a..4c2b35afae 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Error.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Error.java @@ -24,22 +24,23 @@ import org.springframework.security.core.SpringSecurityCoreVersion; * A representation of an SAML 2.0 Error. * *

        - * At a minimum, an error response will contain an error code. - * The commonly used error code are defined in this class - * or a new codes can be defined in the future as arbitrary strings. + * At a minimum, an error response will contain an error code. The commonly used error + * code are defined in this class or a new codes can be defined in the future as arbitrary + * strings. *

        + * * @since 5.2 * @deprecated Use {@link org.springframework.security.saml2.core.Saml2Error} instead */ @Deprecated public class Saml2Error implements Serializable { + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private final org.springframework.security.saml2.core.Saml2Error error; /** * Constructs a {@code Saml2Error} using the provided parameters. - * * @param errorCode the error code * @param description the error description */ @@ -49,7 +50,6 @@ public class Saml2Error implements Serializable { /** * Returns the error code. - * * @return the error code */ public final String getErrorCode() { @@ -58,7 +58,6 @@ public class Saml2Error implements Serializable { /** * Returns the error description. - * * @return the error description */ public final String getDescription() { @@ -67,7 +66,7 @@ public class Saml2Error implements Serializable { @Override public String toString() { - return "[" + this.getErrorCode() + "] " + - (this.getDescription() != null ? this.getDescription() : ""); + return "[" + this.getErrorCode() + "] " + ((this.getDescription() != null) ? this.getDescription() : ""); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2ErrorCodes.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2ErrorCodes.java index b525f3a8e7..fbf31b24ec 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2ErrorCodes.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2ErrorCodes.java @@ -24,80 +24,84 @@ package org.springframework.security.saml2.provider.service.authentication; */ @Deprecated public interface Saml2ErrorCodes { + /** - * SAML Data does not represent a SAML 2 Response object. - * A valid XML object was received, but that object was not a - * SAML 2 Response object of type {@code ResponseType} per specification + * SAML Data does not represent a SAML 2 Response object. A valid XML object was + * received, but that object was not a SAML 2 Response object of type + * {@code ResponseType} per specification * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=46 */ String UNKNOWN_RESPONSE_CLASS = org.springframework.security.saml2.core.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS; + /** - * The response data is malformed or incomplete. - * An invalid XML object was received, and XML unmarshalling failed. + * The response data is malformed or incomplete. An invalid XML object was received, + * and XML unmarshalling failed. */ String MALFORMED_RESPONSE_DATA = org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA; + /** - * Response destination does not match the request URL. - * A SAML 2 response object was received at a URL that - * did not match the URL stored in the {code Destination} attribute - * in the Response object. + * Response destination does not match the request URL. A SAML 2 response object was + * received at a URL that did not match the URL stored in the {code Destination} + * attribute in the Response object. * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38 */ String INVALID_DESTINATION = org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_DESTINATION; + /** - * The assertion was not valid. - * The assertion used for authentication failed validation. - * Details around the failure will be present in the error description. + * The assertion was not valid. The assertion used for authentication failed + * validation. Details around the failure will be present in the error description. */ String INVALID_ASSERTION = org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; + /** - * The signature of response or assertion was invalid. - * Either the response or the assertion was missing a signature - * or the signature could not be verified using the system's - * configured credentials. Most commonly the IDP's - * X509 certificate. + * The signature of response or assertion was invalid. Either the response or the + * assertion was missing a signature or the signature could not be verified using the + * system's configured credentials. Most commonly the IDP's X509 certificate. */ String INVALID_SIGNATURE = org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; + /** - * The assertion did not contain a subject element. - * The subject element, type SubjectType, contains - * a {@code NameID} or an {@code EncryptedID} that is used - * to assign the authenticated principal an identifier, - * typically a username. + * The assertion did not contain a subject element. The subject element, type + * SubjectType, contains a {@code NameID} or an {@code EncryptedID} that is used to + * assign the authenticated principal an identifier, typically a username. * * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=18 */ String SUBJECT_NOT_FOUND = org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND; + /** - * The subject did not contain a user identifier - * The assertion contained a subject element, but the subject - * element did not have a {@code NameID} or {@code EncryptedID} - * element + * The subject did not contain a user identifier The assertion contained a subject + * element, but the subject element did not have a {@code NameID} or + * {@code EncryptedID} element * * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=18 */ String USERNAME_NOT_FOUND = org.springframework.security.saml2.core.Saml2ErrorCodes.USERNAME_NOT_FOUND; + /** - * The system failed to decrypt an assertion or a name identifier. - * This error code will be thrown if the decryption of either a - * {@code EncryptedAssertion} or {@code EncryptedID} fails. + * The system failed to decrypt an assertion or a name identifier. This error code + * will be thrown if the decryption of either a {@code EncryptedAssertion} or + * {@code EncryptedID} fails. * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=17 */ String DECRYPTION_ERROR = org.springframework.security.saml2.core.Saml2ErrorCodes.DECRYPTION_ERROR; + /** * An Issuer element contained a value that didn't * https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=15 */ String INVALID_ISSUER = org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ISSUER; + /** - * An error happened during validation. - * Used when internal, non classified, errors are caught during the - * authentication process. + * An error happened during validation. Used when internal, non classified, errors are + * caught during the authentication process. */ String INTERNAL_VALIDATION_ERROR = org.springframework.security.saml2.core.Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR; + /** - * The relying party registration was not found. - * The registration ID did not correspond to any relying party registration. + * The relying party registration was not found. The registration ID did not + * correspond to any relying party registration. */ String RELYING_PARTY_REGISTRATION_NOT_FOUND = org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND; + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java index d621a7c704..5fc84dd078 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java @@ -18,22 +18,18 @@ package org.springframework.security.saml2.provider.service.authentication; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; - /** - * Data holder for information required to send an {@code AuthNRequest} over a POST binding - * from the service provider to the identity provider - * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf (line 2031) + * Data holder for information required to send an {@code AuthNRequest} over a POST + * binding from the service provider to the identity provider + * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf + * (line 2031) * - * @see Saml2AuthenticationRequestFactory * @since 5.3 + * @see Saml2AuthenticationRequestFactory */ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest { - private Saml2PostAuthenticationRequest( - String samlRequest, - String relayState, - String authenticationRequestUri) { + Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) { super(samlRequest, relayState, authenticationRequestUri); } @@ -42,30 +38,28 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR */ @Override public Saml2MessageBinding getBinding() { - return POST; + return Saml2MessageBinding.POST; } /** - * Constructs a {@link Builder} from a {@link Saml2AuthenticationRequestContext} object. - * By default the {@link Saml2PostAuthenticationRequest#getAuthenticationRequestUri()} will be set to the - * {@link Saml2AuthenticationRequestContext#getDestination()} value. - * @param context input providing {@code Destination}, {@code RelayState}, and {@code Issuer} objects. + * Constructs a {@link Builder} from a {@link Saml2AuthenticationRequestContext} + * object. By default the + * {@link Saml2PostAuthenticationRequest#getAuthenticationRequestUri()} will be set to + * the {@link Saml2AuthenticationRequestContext#getDestination()} value. + * @param context input providing {@code Destination}, {@code RelayState}, and + * {@code Issuer} objects. * @return a modifiable builder object */ public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) { - return new Builder() - .authenticationRequestUri(context.getDestination()) - .relayState(context.getRelayState()) - ; + return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState()); } /** * Builder class for a {@link Saml2PostAuthenticationRequest} object. */ - public static class Builder extends AbstractSaml2AuthenticationRequest.Builder { + public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder { private Builder() { - super(); } /** @@ -73,13 +67,9 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR * @return an immutable {@link Saml2PostAuthenticationRequest} object. */ public Saml2PostAuthenticationRequest build() { - return new Saml2PostAuthenticationRequest( - this.samlRequest, - this.relayState, - this.authenticationRequestUri - ); + return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri); } + } - } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java index fdfc8372aa..80fec1d392 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java @@ -18,26 +18,22 @@ package org.springframework.security.saml2.provider.service.authentication; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; - /** - * Data holder for information required to send an {@code AuthNRequest} over a REDIRECT binding - * from the service provider to the identity provider - * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf (line 2031) + * Data holder for information required to send an {@code AuthNRequest} over a REDIRECT + * binding from the service provider to the identity provider + * https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf + * (line 2031) * - * @see Saml2AuthenticationRequestFactory * @since 5.3 + * @see Saml2AuthenticationRequestFactory */ -public class Saml2RedirectAuthenticationRequest extends AbstractSaml2AuthenticationRequest { +public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2AuthenticationRequest { private final String sigAlg; + private final String signature; - private Saml2RedirectAuthenticationRequest( - String samlRequest, - String sigAlg, - String signature, - String relayState, + private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState, String authenticationRequestUri) { super(samlRequest, relayState, authenticationRequestUri); this.sigAlg = sigAlg; @@ -61,36 +57,36 @@ public class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authenticat } /** - * @return {@link Saml2MessageBinding#REDIRECT} + * @return {@link Saml2MessageBinding#REDIRECT} */ @Override public Saml2MessageBinding getBinding() { - return REDIRECT; + return Saml2MessageBinding.REDIRECT; } /** - * Constructs a {@link Saml2RedirectAuthenticationRequest.Builder} from a {@link Saml2AuthenticationRequestContext} object. - * By default the {@link Saml2RedirectAuthenticationRequest#getAuthenticationRequestUri()} will be set to the - * {@link Saml2AuthenticationRequestContext#getDestination()} value. - * @param context input providing {@code Destination}, {@code RelayState}, and {@code Issuer} objects. + * Constructs a {@link Saml2RedirectAuthenticationRequest.Builder} from a + * {@link Saml2AuthenticationRequestContext} object. By default the + * {@link Saml2RedirectAuthenticationRequest#getAuthenticationRequestUri()} will be + * set to the {@link Saml2AuthenticationRequestContext#getDestination()} value. + * @param context input providing {@code Destination}, {@code RelayState}, and + * {@code Issuer} objects. * @return a modifiable builder object */ public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) { - return new Builder() - .authenticationRequestUri(context.getDestination()) - .relayState(context.getRelayState()) - ; + return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState()); } /** * Builder class for a {@link Saml2RedirectAuthenticationRequest} object. */ - public static class Builder extends AbstractSaml2AuthenticationRequest.Builder { + public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder { + private String sigAlg; + private String signature; private Builder() { - super(); } /** @@ -118,16 +114,10 @@ public class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authenticat * @return an immutable {@link Saml2RedirectAuthenticationRequest} object. */ public Saml2RedirectAuthenticationRequest build() { - return new Saml2RedirectAuthenticationRequest( - this.samlRequest, - this.sigAlg, - this.signature, - this.relayState, - this.authenticationRequestUri - ); + return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature, + this.relayState, this.authenticationRequestUri); } } - } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java index ae271df111..f8f1066a79 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java @@ -16,26 +16,27 @@ package org.springframework.security.saml2.provider.service.authentication; -import org.apache.commons.codec.binary.Base64; -import org.springframework.security.saml2.Saml2Exception; - import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.zip.Deflater.DEFLATED; +import org.apache.commons.codec.binary.Base64; + +import org.springframework.security.saml2.Saml2Exception; /** * @since 5.3 */ final class Saml2Utils { + private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }); - private static Base64 BASE64 = new Base64(0, new byte[]{'\n'}); + private Saml2Utils() { + } static String samlEncode(byte[] b) { return BASE64.encodeAsString(b); @@ -48,13 +49,13 @@ final class Saml2Utils { static byte[] samlDeflate(String s) { try { ByteArrayOutputStream b = new ByteArrayOutputStream(); - DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true)); - deflater.write(s.getBytes(UTF_8)); + DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(Deflater.DEFLATED, true)); + deflater.write(s.getBytes(StandardCharsets.UTF_8)); deflater.finish(); return b.toByteArray(); } - catch (IOException e) { - throw new Saml2Exception("Unable to deflate string", e); + catch (IOException ex) { + throw new Saml2Exception("Unable to deflate string", ex); } } @@ -64,10 +65,11 @@ final class Saml2Utils { InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); iout.write(b); iout.finish(); - return new String(out.toByteArray(), UTF_8); + return new String(out.toByteArray(), StandardCharsets.UTF_8); } - catch (IOException e) { - throw new Saml2Exception("Unable to inflate string", e); + catch (IOException ex) { + throw new Saml2Exception("Unable to inflate string", ex); } } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java index bc27e60ec9..edcd9c35c6 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java @@ -21,10 +21,12 @@ import java.util.ArrayList; import java.util.Base64; import java.util.Collection; import java.util.List; + import javax.xml.namespace.QName; import net.shibboleth.utilities.java.support.xml.SerializeSupport; import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.metadata.AssertionConsumerService; import org.opensaml.saml.saml2.metadata.EntityDescriptor; @@ -43,18 +45,16 @@ import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.Assert; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; - /** - * Resolves the SAML 2.0 Relying Party Metadata for a given {@link RelyingPartyRegistration} - * using the OpenSAML API. + * Resolves the SAML 2.0 Relying Party Metadata for a given + * {@link RelyingPartyRegistration} using the OpenSAML API. * * @author Jakub Kubrynski * @author Josh Cummings * @since 5.4 */ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { + static { OpenSamlInitializationService.initialize(); } @@ -62,22 +62,17 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { private final EntityDescriptorMarshaller entityDescriptorMarshaller; public OpenSamlMetadataResolver() { - this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) - getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME); + this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) XMLObjectProviderRegistrySupport + .getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME); Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null"); } - /** - * {@inheritDoc} - */ @Override public String resolve(RelyingPartyRegistration relyingPartyRegistration) { EntityDescriptor entityDescriptor = build(EntityDescriptor.ELEMENT_QNAME); entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId()); - SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration); entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor); - return serialize(entityDescriptor); } @@ -85,10 +80,10 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { SPSSODescriptor spSsoDescriptor = build(SPSSODescriptor.DEFAULT_ELEMENT_NAME); spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS); spSsoDescriptor.setWantAssertionsSigned(true); - spSsoDescriptor.getKeyDescriptors().addAll(buildKeys( - registration.getSigningX509Credentials(), UsageType.SIGNING)); - spSsoDescriptor.getKeyDescriptors().addAll(buildKeys( - registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION)); + spSsoDescriptor.getKeyDescriptors() + .addAll(buildKeys(registration.getSigningX509Credentials(), UsageType.SIGNING)); + spSsoDescriptor.getKeyDescriptors() + .addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION)); spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration)); return spSsoDescriptor; } @@ -107,16 +102,14 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { KeyInfo keyInfo = build(KeyInfo.DEFAULT_ELEMENT_NAME); X509Certificate x509Certificate = build(X509Certificate.DEFAULT_ELEMENT_NAME); X509Data x509Data = build(X509Data.DEFAULT_ELEMENT_NAME); - try { x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded()))); - } catch (CertificateEncodingException e) { + } + catch (CertificateEncodingException ex) { throw new Saml2Exception("Cannot encode certificate " + certificate.toString()); } - x509Data.getX509Certificates().add(x509Certificate); keyInfo.getX509Datas().add(x509Data); - keyDescriptor.setUse(usageType); keyDescriptor.setKeyInfo(keyInfo); return keyDescriptor; @@ -132,20 +125,21 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver { @SuppressWarnings("unchecked") private T build(QName elementName) { - XMLObjectBuilder builder = getBuilderFactory().getBuilder(elementName); + XMLObjectBuilder builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName); if (builder == null) { throw new Saml2Exception("Unable to resolve Builder for " + elementName); } return (T) builder.buildObject(elementName); } - private String serialize(EntityDescriptor entityDescriptor) { try { Element element = this.entityDescriptorMarshaller.marshall(entityDescriptor); return SerializeSupport.prettyPrintXML(element); - } catch (Exception e) { - throw new Saml2Exception(e); + } + catch (Exception ex) { + throw new Saml2Exception(ex); } } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java index adfafcf56c..999a124771 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java @@ -19,18 +19,20 @@ package org.springframework.security.saml2.provider.service.metadata; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; /** - * Resolves the SAML 2.0 Relying Party Metadata for a given {@link RelyingPartyRegistration} + * Resolves the SAML 2.0 Relying Party Metadata for a given + * {@link RelyingPartyRegistration} * * @author Jakub Kubrynski * @author Josh Cummings * @since 5.4 */ public interface Saml2MetadataResolver { + /** * Resolve the given relying party's metadata - * * @param relyingPartyRegistration the relying party * @return the relying party's metadata */ String resolve(RelyingPartyRegistration relyingPartyRegistration); + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java index 4f28cea5e8..495c86123b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java @@ -16,17 +16,14 @@ package org.springframework.security.saml2.provider.service.registration; -import org.springframework.util.Assert; - +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; -import static java.util.Arrays.asList; -import static org.springframework.util.Assert.notEmpty; -import static org.springframework.util.Assert.notNull; +import org.springframework.util.Assert; /** * @since 5.2 @@ -37,23 +34,22 @@ public class InMemoryRelyingPartyRegistrationRepository private final Map byRegistrationId; public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) { - this(asList(registrations)); + this(Arrays.asList(registrations)); } public InMemoryRelyingPartyRegistrationRepository(Collection registrations) { - notEmpty(registrations, "registrations cannot be empty"); + Assert.notEmpty(registrations, "registrations cannot be empty"); this.byRegistrationId = createMappingToIdentityProvider(registrations); } private static Map createMappingToIdentityProvider( - Collection rps - ) { + Collection rps) { LinkedHashMap result = new LinkedHashMap<>(); for (RelyingPartyRegistration rp : rps) { - notNull(rp, "relying party collection cannot contain null values"); + Assert.notNull(rp, "relying party collection cannot contain null values"); String key = rp.getRegistrationId(); - notNull(rp, "relying party identifier cannot be null"); - Assert.isNull(result.get(key), () -> "relying party duplicate identifier '" + key+"' detected."); + Assert.notNull(rp, "relying party identifier cannot be null"); + Assert.isNull(result.get(key), () -> "relying party duplicate identifier '" + key + "' detected."); result.put(key, rp); } return Collections.unmodifiableMap(result); @@ -61,7 +57,7 @@ public class InMemoryRelyingPartyRegistrationRepository @Override public RelyingPartyRegistration findByRegistrationId(String id) { - return this.byRegistrationId.get(id); + return this.byRegistrationId.get(id); } @Override diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter.java index 77c3e0c988..1aefa5489d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter.java @@ -27,6 +27,7 @@ import java.util.List; import net.shibboleth.utilities.java.support.xml.ParserPool; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.metadata.EntityDescriptor; import org.opensaml.saml.saml2.metadata.IDPSSODescriptor; import org.opensaml.saml.saml2.metadata.KeyDescriptor; @@ -47,19 +48,14 @@ import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; -import static java.lang.Boolean.TRUE; -import static org.opensaml.saml.common.xml.SAMLConstants.SAML20P_NS; -import static org.springframework.security.saml2.core.Saml2X509Credential.encryption; -import static org.springframework.security.saml2.core.Saml2X509Credential.verification; -import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId; - /** - * An {@link HttpMessageConverter} that takes an {@code IDPSSODescriptor} in an HTTP response - * and converts it into a {@link RelyingPartyRegistration.Builder}. + * An {@link HttpMessageConverter} that takes an {@code IDPSSODescriptor} in an HTTP + * response and converts it into a {@link RelyingPartyRegistration.Builder}. * - * The primary use case for this is constructing a {@link RelyingPartyRegistration} for inclusion in a - * {@link RelyingPartyRegistrationRepository}. To do so, you can include an instance of this converter in a - * {@link org.springframework.web.client.RestOperations} like so: + * The primary use case for this is constructing a {@link RelyingPartyRegistration} for + * inclusion in a {@link RelyingPartyRegistrationRepository}. To do so, you can include an + * instance of this converter in a {@link org.springframework.web.client.RestOperations} + * like so: * *
          * 		RestOperations rest = new RestTemplate(Collections.singletonList(
        @@ -69,11 +65,12 @@ import static org.springframework.security.saml2.provider.service.registration.R
          * 		RelyingPartyRegistration registration = builder.registrationId("registration-id").build();
          * 
        * - * Note that this will only configure the asserting party (IDP) half of the {@link RelyingPartyRegistration}, - * meaning where and how to send AuthnRequests, how to verify Assertions, etc. + * Note that this will only configure the asserting party (IDP) half of the + * {@link RelyingPartyRegistration}, meaning where and how to send AuthnRequests, how to + * verify Assertions, etc. * - * To further configure the {@link RelyingPartyRegistration} with relying party (SP) information, you may - * invoke the appropriate methods on the builder. + * To further configure the {@link RelyingPartyRegistration} with relying party (SP) + * information, you may invoke the appropriate methods on the builder. * * @author Josh Cummings * @since 5.4 @@ -86,6 +83,7 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter } private final EntityDescriptorUnmarshaller unmarshaller; + private final ParserPool parserPool; /** @@ -98,39 +96,26 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter this.parserPool = registry.getParserPool(); } - /** - * {@inheritDoc} - */ @Override public boolean canRead(Class clazz, MediaType mediaType) { return RelyingPartyRegistration.Builder.class.isAssignableFrom(clazz); } - /** - * {@inheritDoc} - */ @Override public boolean canWrite(Class clazz, MediaType mediaType) { return false; } - /** - * {@inheritDoc} - */ @Override public List getSupportedMediaTypes() { return Arrays.asList(MediaType.APPLICATION_XML, MediaType.TEXT_XML); } - /** - * {@inheritDoc} - */ @Override - public RelyingPartyRegistration.Builder read(Class clazz, HttpInputMessage inputMessage) - throws IOException, HttpMessageNotReadableException { - + public RelyingPartyRegistration.Builder read(Class clazz, + HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { EntityDescriptor descriptor = entityDescriptor(inputMessage.getBody()); - IDPSSODescriptor idpssoDescriptor = descriptor.getIDPSSODescriptor(SAML20P_NS); + IDPSSODescriptor idpssoDescriptor = descriptor.getIDPSSODescriptor(SAMLConstants.SAML20P_NS); if (idpssoDescriptor == null) { throw new Saml2Exception("Metadata response is missing the necessary IDPSSODescriptor element"); } @@ -140,54 +125,84 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter if (keyDescriptor.getUse().equals(UsageType.SIGNING)) { List certificates = certificates(keyDescriptor); for (X509Certificate certificate : certificates) { - verification.add(verification(certificate)); + verification.add(Saml2X509Credential.verification(certificate)); } } if (keyDescriptor.getUse().equals(UsageType.ENCRYPTION)) { List certificates = certificates(keyDescriptor); for (X509Certificate certificate : certificates) { - encryption.add(encryption(certificate)); + encryption.add(Saml2X509Credential.encryption(certificate)); } } if (keyDescriptor.getUse().equals(UsageType.UNSPECIFIED)) { List certificates = certificates(keyDescriptor); for (X509Certificate certificate : certificates) { - verification.add(verification(certificate)); - encryption.add(encryption(certificate)); + verification.add(Saml2X509Credential.verification(certificate)); + encryption.add(Saml2X509Credential.encryption(certificate)); } } } if (verification.isEmpty()) { - throw new Saml2Exception("Metadata response is missing verification certificates, necessary for verifying SAML assertions"); + throw new Saml2Exception( + "Metadata response is missing verification certificates, necessary for verifying SAML assertions"); } - RelyingPartyRegistration.Builder builder = withRegistrationId(descriptor.getEntityID()) - .assertingPartyDetails(party -> party - .entityId(descriptor.getEntityID()) - .wantAuthnRequestsSigned(TRUE.equals(idpssoDescriptor.getWantAuthnRequestsSigned())) - .verificationX509Credentials(c -> c.addAll(verification)) - .encryptionX509Credentials(c -> c.addAll(encryption))); + RelyingPartyRegistration.Builder builder = RelyingPartyRegistration.withRegistrationId(descriptor.getEntityID()) + .assertingPartyDetails((party) -> party.entityId(descriptor.getEntityID()) + .wantAuthnRequestsSigned(Boolean.TRUE.equals(idpssoDescriptor.getWantAuthnRequestsSigned())) + .verificationX509Credentials((c) -> c.addAll(verification)) + .encryptionX509Credentials((c) -> c.addAll(encryption))); for (SingleSignOnService singleSignOnService : idpssoDescriptor.getSingleSignOnServices()) { Saml2MessageBinding binding; if (singleSignOnService.getBinding().equals(Saml2MessageBinding.POST.getUrn())) { binding = Saml2MessageBinding.POST; - } else if (singleSignOnService.getBinding().equals(Saml2MessageBinding.REDIRECT.getUrn())) { + } + else if (singleSignOnService.getBinding().equals(Saml2MessageBinding.REDIRECT.getUrn())) { binding = Saml2MessageBinding.REDIRECT; - } else { + } + else { continue; } - builder.assertingPartyDetails(party -> party - .singleSignOnServiceLocation(singleSignOnService.getLocation()) - .singleSignOnServiceBinding(binding)); + builder.assertingPartyDetails( + (party) -> party.singleSignOnServiceLocation(singleSignOnService.getLocation()) + .singleSignOnServiceBinding(binding)); return builder; } - throw new Saml2Exception("Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests"); + throw new Saml2Exception( + "Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests"); + } + + private List getVerification(IDPSSODescriptor idpssoDescriptor) { + List verification = new ArrayList<>(); + for (KeyDescriptor keyDescriptor : idpssoDescriptor.getKeyDescriptors()) { + if (keyDescriptor.getUse().equals(UsageType.SIGNING)) { + List certificates = certificates(keyDescriptor); + for (X509Certificate certificate : certificates) { + verification.add(Saml2X509Credential.verification(certificate)); + } + } + } + return verification; + } + + private List getEncryption(IDPSSODescriptor idpssoDescriptor) { + List encryption = new ArrayList<>(); + for (KeyDescriptor keyDescriptor : idpssoDescriptor.getKeyDescriptors()) { + if (keyDescriptor.getUse().equals(UsageType.ENCRYPTION)) { + List certificates = certificates(keyDescriptor); + for (X509Certificate certificate : certificates) { + encryption.add(Saml2X509Credential.encryption(certificate)); + } + } + } + return encryption; } private List certificates(KeyDescriptor keyDescriptor) { try { return KeyInfoSupport.getCertificates(keyDescriptor.getKeyInfo()); - } catch (CertificateException e) { - throw new Saml2Exception(e); + } + catch (CertificateException ex) { + throw new Saml2Exception(ex); } } @@ -196,13 +211,16 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter Document document = this.parserPool.parse(inputStream); Element element = document.getDocumentElement(); return (EntityDescriptor) this.unmarshaller.unmarshall(element); - } catch (Exception e) { - throw new Saml2Exception(e); + } + catch (Exception ex) { + throw new Saml2Exception(ex); } } @Override - public void write(RelyingPartyRegistration.Builder builder, MediaType contentType, HttpOutputMessage outputMessage) throws HttpMessageNotWritableException { + public void write(RelyingPartyRegistration.Builder builder, MediaType contentType, HttpOutputMessage outputMessage) + throws HttpMessageNotWritableException { throw new HttpMessageNotWritableException("This converter cannot write a RelyingPartyRegistration.Builder"); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index b1435d7336..137e97f88d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -31,10 +31,12 @@ import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.util.Assert; /** - * Represents a configured relying party (aka Service Provider) and asserting party (aka Identity Provider) pair. + * Represents a configured relying party (aka Service Provider) and asserting party (aka + * Identity Provider) pair. * *

        - * Each RP/AP pair is uniquely identified using a {@code registrationId}, an arbitrary string. + * Each RP/AP pair is uniquely identified using a {@code registrationId}, an arbitrary + * string. * *

        * A fully configured registration may look like: @@ -54,39 +56,41 @@ import org.springframework.util.Assert; * RelyingPartyRegistration rp = RelyingPartyRegistration.withRegistrationId(registrationId) * .entityId(relyingPartyEntityId) * .assertionConsumerServiceLocation(assertingConsumerServiceLocation) - * .signingX509Credentials(c -> c.add(relyingPartySigningCredential)) - * .assertingPartyDetails(details -> details + * .signingX509Credentials((c) -> c.add(relyingPartySigningCredential)) + * .assertingPartyDetails((details) -> details * .entityId(assertingPartyEntityId)); * .singleSignOnServiceLocation(singleSignOnServiceLocation)) - * .verifyingX509Credentials(c -> c.add(assertingPartyVerificationCredential)) + * .verifyingX509Credentials((c) -> c.add(assertingPartyVerificationCredential)) * .build(); * * - * @since 5.2 * @author Filip Hanik * @author Josh Cummings + * @since 5.2 */ -public class RelyingPartyRegistration { +public final class RelyingPartyRegistration { private final String registrationId; + private final String entityId; + private final String assertionConsumerServiceLocation; + private final Saml2MessageBinding assertionConsumerServiceBinding; + private final ProviderDetails providerDetails; + private final List credentials; + private final Collection decryptionX509Credentials; + private final Collection signingX509Credentials; - private RelyingPartyRegistration( - String registrationId, - String entityId, - String assertionConsumerServiceLocation, - Saml2MessageBinding assertionConsumerServiceBinding, - ProviderDetails providerDetails, + private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation, + Saml2MessageBinding assertionConsumerServiceBinding, ProviderDetails providerDetails, Collection credentials, Collection decryptionX509Credentials, Collection signingX509Credentials) { - Assert.hasText(registrationId, "registrationId cannot be empty"); Assert.hasText(entityId, "entityId cannot be empty"); Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty"); @@ -105,8 +109,7 @@ public class RelyingPartyRegistration { Assert.notNull(signingX509Credentials, "signingX509Credentials cannot be null"); for (Saml2X509Credential c : signingX509Credentials) { Assert.notNull(c, "signingX509Credentials cannot contain null elements"); - Assert.isTrue(c.isSigningCredential(), - "All signingX509Credentials must have a usage of SIGNING set"); + Assert.isTrue(c.isSigningCredential(), "All signingX509Credentials must have a usage of SIGNING set"); } this.registrationId = registrationId; this.entityId = entityId; @@ -120,7 +123,6 @@ public class RelyingPartyRegistration { /** * Get the unique registration id for this RP/AP pair - * * @return the unique registration id for this RP/AP pair */ public String getRegistrationId() { @@ -128,18 +130,17 @@ public class RelyingPartyRegistration { } /** - * Get the relying party's - * EntityID. + * Get the relying party's EntityID. * *

        - * Equivalent to the value found in the relying party's - * <EntityDescriptor EntityID="..."/> + * Equivalent to the value found in the relying party's <EntityDescriptor + * EntityID="..."/> * *

        - * This value may contain a number of placeholders, which need to be - * resolved before use. They are {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}. - * + * This value may contain a number of placeholders, which need to be resolved before + * use. They are {@code baseUrl}, {@code registrationId}, {@code baseScheme}, + * {@code baseHost}, and {@code basePort}. * @return the relying party's EntityID * @since 5.4 */ @@ -148,14 +149,13 @@ public class RelyingPartyRegistration { } /** - * Get the AssertionConsumerService Location. - * Equivalent to the value found in <AssertionConsumerService Location="..."/> - * in the relying party's <SPSSODescriptor>. - * - * This value may contain a number of placeholders, which need to be - * resolved before use. They are {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}. + * Get the AssertionConsumerService Location. Equivalent to the value found in + * <AssertionConsumerService Location="..."/> in the relying party's + * <SPSSODescriptor>. * + * This value may contain a number of placeholders, which need to be resolved before + * use. They are {@code baseUrl}, {@code registrationId}, {@code baseScheme}, + * {@code baseHost}, and {@code basePort}. * @return the AssertionConsumerService Location * @since 5.4 */ @@ -164,10 +164,9 @@ public class RelyingPartyRegistration { } /** - * Get the AssertionConsumerService Binding. - * Equivalent to the value found in <AssertionConsumerService Binding="..."/> - * in the relying party's <SPSSODescriptor>. - * + * Get the AssertionConsumerService Binding. Equivalent to the value found in + * <AssertionConsumerService Binding="..."/> in the relying party's + * <SPSSODescriptor>. * @return the AssertionConsumerService Binding * @since 5.4 */ @@ -176,9 +175,10 @@ public class RelyingPartyRegistration { } /** - * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated with this relying party - * - * @return the {@link Collection} of decryption {@link Saml2X509Credential}s associated with this relying party + * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated + * with this relying party + * @return the {@link Collection} of decryption {@link Saml2X509Credential}s + * associated with this relying party * @since 5.4 */ public Collection getDecryptionX509Credentials() { @@ -186,9 +186,10 @@ public class RelyingPartyRegistration { } /** - * Get the {@link Collection} of signing {@link Saml2X509Credential}s associated with this relying party - * - * @return the {@link Collection} of signing {@link Saml2X509Credential}s associated with this relying party + * Get the {@link Collection} of signing {@link Saml2X509Credential}s associated with + * this relying party + * @return the {@link Collection} of signing {@link Saml2X509Credential}s associated + * with this relying party * @since 5.4 */ public Collection getSigningX509Credentials() { @@ -197,7 +198,6 @@ public class RelyingPartyRegistration { /** * Get the configuration details for the Asserting Party - * * @return the {@link AssertingPartyDetails} * @since 5.4 */ @@ -208,7 +208,8 @@ public class RelyingPartyRegistration { /** * Returns the entity ID of the IDP, the asserting party. * @return entity ID of the asserting party - * @deprecated use {@link AssertingPartyDetails#getEntityId} from {@link #getAssertingPartyDetails} + * @deprecated use {@link AssertingPartyDetails#getEntityId} from + * {@link #getAssertingPartyDetails} */ @Deprecated public String getRemoteIdpEntityId() { @@ -217,8 +218,8 @@ public class RelyingPartyRegistration { /** * returns the URL template for which ACS URL authentication requests should contain - * Possible variables are {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}. + * Possible variables are {@code baseUrl}, {@code registrationId}, {@code baseScheme}, + * {@code baseHost}, and {@code basePort}. * @return string containing the ACS URL template, with or without variables present * @deprecated Use {@link #getAssertionConsumerServiceLocation} instead */ @@ -228,10 +229,11 @@ public class RelyingPartyRegistration { } /** - * Contains the URL for which to send the SAML 2 Authentication Request to initiate - * a single sign on flow. + * Contains the URL for which to send the SAML 2 Authentication Request to initiate a + * single sign on flow. * @return a IDP URL that accepts REDIRECT or POST binding for authentication requests - * @deprecated use {@link AssertingPartyDetails#getSingleSignOnServiceLocation} from {@link #getAssertingPartyDetails} + * @deprecated use {@link AssertingPartyDetails#getSingleSignOnServiceLocation} from + * {@link #getAssertingPartyDetails} */ @Deprecated public String getIdpWebSsoUrl() { @@ -251,8 +253,8 @@ public class RelyingPartyRegistration { /** * The local relying party, or Service Provider, can generate it's entity ID based on - * possible variables of {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}, for example + * possible variables of {@code baseUrl}, {@code registrationId}, {@code baseScheme}, + * {@code baseHost}, and {@code basePort}, for example * {@code {baseUrl}/saml2/service-provider-metadata/{registrationId}} * @return a string containing the entity ID or entity ID template * @deprecated Use {@link #getEntityId} instead @@ -263,10 +265,11 @@ public class RelyingPartyRegistration { } /** - * Returns a list of configured credentials to be used in message exchanges between relying party, SP, and - * asserting party, IDP. + * Returns a list of configured credentials to be used in message exchanges between + * relying party, SP, and asserting party, IDP. * @return a list of credentials - * @deprecated Instead of retrieving all credentials, use the appropriate method for obtaining the correct type + * @deprecated Instead of retrieving all credentials, use the appropriate method for + * obtaining the correct type */ @Deprecated public List getCredentials() { @@ -277,11 +280,13 @@ public class RelyingPartyRegistration { * @return a filtered list containing only credentials of type * {@link org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType#VERIFICATION}. * Returns an empty list of credentials are not found - * @deprecated Use {@link #getAssertingPartyDetails().getSigningX509Credentials()} instead + * @deprecated Use {code #getAssertingPartyDetails().getSigningX509Credentials()} + * instead */ @Deprecated public List getVerificationCredentials() { - return filterCredentials(org.springframework.security.saml2.credentials.Saml2X509Credential::isSignatureVerficationCredential); + return filterCredentials( + org.springframework.security.saml2.credentials.Saml2X509Credential::isSignatureVerficationCredential); } /** @@ -292,18 +297,21 @@ public class RelyingPartyRegistration { */ @Deprecated public List getSigningCredentials() { - return filterCredentials(org.springframework.security.saml2.credentials.Saml2X509Credential::isSigningCredential); + return filterCredentials( + org.springframework.security.saml2.credentials.Saml2X509Credential::isSigningCredential); } /** * @return a filtered list containing only credentials of type * {@link org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType#ENCRYPTION}. * Returns an empty list of credentials are not found - * @deprecated Use {@link AssertingPartyDetails#getEncryptionX509Credentials()} instead + * @deprecated Use {@link AssertingPartyDetails#getEncryptionX509Credentials()} + * instead */ @Deprecated public List getEncryptionCredentials() { - return filterCredentials(org.springframework.security.saml2.credentials.Saml2X509Credential::isEncryptionCredential); + return filterCredentials( + org.springframework.security.saml2.credentials.Saml2X509Credential::isEncryptionCredential); } /** @@ -314,12 +322,12 @@ public class RelyingPartyRegistration { */ @Deprecated public List getDecryptionCredentials() { - return filterCredentials(org.springframework.security.saml2.credentials.Saml2X509Credential::isDecryptionCredential); + return filterCredentials( + org.springframework.security.saml2.credentials.Saml2X509Credential::isDecryptionCredential); } private List filterCredentials( Function filter) { - List result = new LinkedList<>(); for (org.springframework.security.saml2.credentials.Saml2X509Credential c : this.credentials) { if (filter.apply(c)) { @@ -330,7 +338,8 @@ public class RelyingPartyRegistration { } /** - * Creates a {@code RelyingPartyRegistration} {@link Builder} with a known {@code registrationId} + * Creates a {@code RelyingPartyRegistration} {@link Builder} with a known + * {@code registrationId} * @param registrationId a string identifier for the {@code RelyingPartyRegistration} * @return {@code Builder} to create a {@code RelyingPartyRegistration} object */ @@ -340,50 +349,99 @@ public class RelyingPartyRegistration { } /** - * Creates a {@code RelyingPartyRegistration} {@link Builder} based on an existing object + * Creates a {@code RelyingPartyRegistration} {@link Builder} based on an existing + * object * @param registration the {@code RelyingPartyRegistration} * @return {@code Builder} to create a {@code RelyingPartyRegistration} object */ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { Assert.notNull(registration, "registration cannot be null"); - return withRegistrationId(registration.getRegistrationId()) - .entityId(registration.getEntityId()) - .signingX509Credentials(c -> c.addAll(registration.getSigningX509Credentials())) - .decryptionX509Credentials(c -> c.addAll(registration.getDecryptionX509Credentials())) + return withRegistrationId(registration.getRegistrationId()).entityId(registration.getEntityId()) + .signingX509Credentials((c) -> c.addAll(registration.getSigningX509Credentials())) + .decryptionX509Credentials((c) -> c.addAll(registration.getDecryptionX509Credentials())) .assertionConsumerServiceLocation(registration.getAssertionConsumerServiceLocation()) .assertionConsumerServiceBinding(registration.getAssertionConsumerServiceBinding()) - .assertingPartyDetails(assertingParty -> assertingParty - .entityId(registration.getAssertingPartyDetails().getEntityId()) - .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) - .verificationX509Credentials(c -> c.addAll(registration.getAssertingPartyDetails().getVerificationX509Credentials())) - .encryptionX509Credentials(c -> c.addAll(registration.getAssertingPartyDetails().getEncryptionX509Credentials())) - .singleSignOnServiceLocation(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()) - .singleSignOnServiceBinding(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding()) - ); + .assertingPartyDetails((assertingParty) -> assertingParty + .entityId(registration.getAssertingPartyDetails().getEntityId()) + .wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) + .verificationX509Credentials((c) -> c + .addAll(registration.getAssertingPartyDetails().getVerificationX509Credentials())) + .encryptionX509Credentials( + (c) -> c.addAll(registration.getAssertingPartyDetails().getEncryptionX509Credentials())) + .singleSignOnServiceLocation( + registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()) + .singleSignOnServiceBinding( + registration.getAssertingPartyDetails().getSingleSignOnServiceBinding())); } + private static Saml2X509Credential fromDeprecated( + org.springframework.security.saml2.credentials.Saml2X509Credential credential) { + PrivateKey privateKey = credential.getPrivateKey(); + X509Certificate certificate = credential.getCertificate(); + Set credentialTypes = new HashSet<>(); + if (credential.isSigningCredential()) { + credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.SIGNING); + } + if (credential.isSignatureVerficationCredential()) { + credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.VERIFICATION); + } + if (credential.isEncryptionCredential()) { + credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION); + } + if (credential.isDecryptionCredential()) { + credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.DECRYPTION); + } + return new Saml2X509Credential(privateKey, certificate, credentialTypes); + } + + private static org.springframework.security.saml2.credentials.Saml2X509Credential toDeprecated( + Saml2X509Credential credential) { + PrivateKey privateKey = credential.getPrivateKey(); + X509Certificate certificate = credential.getCertificate(); + Set credentialTypes = new HashSet<>(); + if (credential.isSigningCredential()) { + credentialTypes.add( + org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING); + } + if (credential.isVerificationCredential()) { + credentialTypes.add( + org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION); + } + if (credential.isEncryptionCredential()) { + credentialTypes.add( + org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION); + } + if (credential.isDecryptionCredential()) { + credentialTypes.add( + org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION); + } + return new org.springframework.security.saml2.credentials.Saml2X509Credential(privateKey, certificate, + credentialTypes); + } /** * The configuration metadata of the Asserting party * * @since 5.4 */ - public final static class AssertingPartyDetails { + public static final class AssertingPartyDetails { + private final String entityId; + private final boolean wantAuthnRequestsSigned; + private final Collection verificationX509Credentials; + private final Collection encryptionX509Credentials; + private final String singleSignOnServiceLocation; + private final Saml2MessageBinding singleSignOnServiceBinding; - private AssertingPartyDetails( - String entityId, - boolean wantAuthnRequestsSigned, + private AssertingPartyDetails(String entityId, boolean wantAuthnRequestsSigned, Collection verificationX509Credentials, - Collection encryptionX509Credentials, - String singleSignOnServiceLocation, + Collection encryptionX509Credentials, String singleSignOnServiceLocation, Saml2MessageBinding singleSignOnServiceBinding) { - Assert.hasText(entityId, "entityId cannot be null or empty"); Assert.notNull(verificationX509Credentials, "verificationX509Credentials cannot be null"); for (Saml2X509Credential credential : verificationX509Credentials) { @@ -408,18 +466,17 @@ public class RelyingPartyRegistration { } /** - * Get the asserting party's - * EntityID. + * Get the asserting party's EntityID. * *

        - * Equivalent to the value found in the asserting party's - * <EntityDescriptor EntityID="..."/> + * Equivalent to the value found in the asserting party's <EntityDescriptor + * EntityID="..."/> * *

        - * This value may contain a number of placeholders, which need to be - * resolved before use. They are {@code baseUrl}, {@code registrationId}, + * This value may contain a number of placeholders, which need to be resolved + * before use. They are {@code baseUrl}, {@code registrationId}, * {@code baseScheme}, {@code baseHost}, and {@code basePort}. - * * @return the asserting party's EntityID */ public String getEntityId() { @@ -427,9 +484,8 @@ public class RelyingPartyRegistration { } /** - * Get the WantAuthnRequestsSigned setting, indicating the asserting party's preference that - * relying parties should sign the AuthnRequest before sending. - * + * Get the WantAuthnRequestsSigned setting, indicating the asserting party's + * preference that relying parties should sign the AuthnRequest before sending. * @return the WantAuthnRequestsSigned value */ public boolean getWantAuthnRequestsSigned() { @@ -437,9 +493,10 @@ public class RelyingPartyRegistration { } /** - * Get all verification {@link Saml2X509Credential}s associated with this asserting party - * - * @return all verification {@link Saml2X509Credential}s associated with this asserting party + * Get all verification {@link Saml2X509Credential}s associated with this + * asserting party + * @return all verification {@link Saml2X509Credential}s associated with this + * asserting party * @since 5.4 */ public Collection getVerificationX509Credentials() { @@ -447,9 +504,10 @@ public class RelyingPartyRegistration { } /** - * Get all encryption {@link Saml2X509Credential}s associated with this asserting party - * - * @return all encryption {@link Saml2X509Credential}s associated with this asserting party + * Get all encryption {@link Saml2X509Credential}s associated with this asserting + * party + * @return all encryption {@link Saml2X509Credential}s associated with this + * asserting party * @since 5.4 */ public Collection getEncryptionX509Credentials() { @@ -457,14 +515,13 @@ public class RelyingPartyRegistration { } /** - * Get the - * SingleSignOnService + * Get the SingleSignOnService * Location. * *

        - * Equivalent to the value found in <SingleSignOnService Location="..."/> - * in the asserting party's <IDPSSODescriptor>. - * + * Equivalent to the value found in <SingleSignOnService Location="..."/> in + * the asserting party's <IDPSSODescriptor>. * @return the SingleSignOnService Location */ public String getSingleSignOnServiceLocation() { @@ -472,34 +529,38 @@ public class RelyingPartyRegistration { } /** - * Get the - * SingleSignOnService + * Get the SingleSignOnService * Binding. * *

        - * Equivalent to the value found in <SingleSignOnService Binding="..."/> - * in the asserting party's <IDPSSODescriptor>. - * + * Equivalent to the value found in <SingleSignOnService Binding="..."/> in + * the asserting party's <IDPSSODescriptor>. * @return the SingleSignOnService Location */ public Saml2MessageBinding getSingleSignOnServiceBinding() { return this.singleSignOnServiceBinding; } - public final static class Builder { + public static final class Builder { + private String entityId; + private boolean wantAuthnRequestsSigned = true; + private Collection verificationX509Credentials = new HashSet<>(); + private Collection encryptionX509Credentials = new HashSet<>(); + private String singleSignOnServiceLocation; + private Saml2MessageBinding singleSignOnServiceBinding = Saml2MessageBinding.REDIRECT; /** - * Set the asserting party's - * EntityID. - * Equivalent to the value found in the asserting party's - * <EntityDescriptor EntityID="..."/> - * + * Set the asserting party's EntityID. + * Equivalent to the value found in the asserting party's <EntityDescriptor + * EntityID="..."/> * @param entityId the asserting party's EntityID * @return the {@link ProviderDetails.Builder} for further configuration */ @@ -509,9 +570,9 @@ public class RelyingPartyRegistration { } /** - * Set the WantAuthnRequestsSigned setting, indicating the asserting party's preference that - * relying parties should sign the AuthnRequest before sending. - * + * Set the WantAuthnRequestsSigned setting, indicating the asserting party's + * preference that relying parties should sign the AuthnRequest before + * sending. * @param wantAuthnRequestsSigned the WantAuthnRequestsSigned setting * @return the {@link ProviderDetails.Builder} for further configuration */ @@ -522,9 +583,10 @@ public class RelyingPartyRegistration { /** * Apply this {@link Consumer} to the list of {@link Saml2X509Credential}s - * - * @param credentialsConsumer a {@link Consumer} of the {@link List} of {@link Saml2X509Credential}s - * @return the {@link RelyingPartyRegistration.Builder} for further configuration + * @param credentialsConsumer a {@link Consumer} of the {@link List} of + * {@link Saml2X509Credential}s + * @return the {@link RelyingPartyRegistration.Builder} for further + * configuration * @since 5.4 */ public Builder verificationX509Credentials(Consumer> credentialsConsumer) { @@ -534,9 +596,10 @@ public class RelyingPartyRegistration { /** * Apply this {@link Consumer} to the list of {@link Saml2X509Credential}s - * - * @param credentialsConsumer a {@link Consumer} of the {@link List} of {@link Saml2X509Credential}s - * @return the {@link RelyingPartyRegistration.Builder} for further configuration + * @param credentialsConsumer a {@link Consumer} of the {@link List} of + * {@link Saml2X509Credential}s + * @return the {@link RelyingPartyRegistration.Builder} for further + * configuration * @since 5.4 */ public Builder encryptionX509Credentials(Consumer> credentialsConsumer) { @@ -545,14 +608,13 @@ public class RelyingPartyRegistration { } /** - * Set the - * SingleSignOnService + * Set the SingleSignOnService * Location. * *

        - * Equivalent to the value found in <SingleSignOnService Location="..."/> - * in the asserting party's <IDPSSODescriptor>. - * + * Equivalent to the value found in <SingleSignOnService + * Location="..."/> in the asserting party's <IDPSSODescriptor>. * @param singleSignOnServiceLocation the SingleSignOnService Location * @return the {@link ProviderDetails.Builder} for further configuration */ @@ -562,14 +624,13 @@ public class RelyingPartyRegistration { } /** - * Set the - * SingleSignOnService + * Set the SingleSignOnService * Binding. * *

        * Equivalent to the value found in <SingleSignOnService Binding="..."/> * in the asserting party's <IDPSSODescriptor>. - * * @param singleSignOnServiceBinding the SingleSignOnService Binding * @return the {@link ProviderDetails.Builder} for further configuration */ @@ -579,29 +640,29 @@ public class RelyingPartyRegistration { } /** - * Creates an immutable ProviderDetails object representing the configuration for an Identity Provider, IDP + * Creates an immutable ProviderDetails object representing the configuration + * for an Identity Provider, IDP * @return immutable ProviderDetails object */ public AssertingPartyDetails build() { - return new AssertingPartyDetails( - this.entityId, - this.wantAuthnRequestsSigned, - this.verificationX509Credentials, - this.encryptionX509Credentials, - this.singleSignOnServiceLocation, - this.singleSignOnServiceBinding - ); + return new AssertingPartyDetails(this.entityId, this.wantAuthnRequestsSigned, + this.verificationX509Credentials, this.encryptionX509Credentials, + this.singleSignOnServiceLocation, this.singleSignOnServiceBinding); } + } + } /** * Configuration for IDP SSO endpoint configuration + * * @since 5.3 * @deprecated Use {@link AssertingPartyDetails} instead */ @Deprecated - public final static class ProviderDetails { + public static final class ProviderDetails { + private final AssertingPartyDetails assertingPartyDetails; private ProviderDetails(AssertingPartyDetails assertingPartyDetails) { @@ -618,17 +679,18 @@ public class RelyingPartyRegistration { } /** - * Contains the URL for which to send the SAML 2 Authentication Request to initiate - * a single sign on flow. - * @return a IDP URL that accepts REDIRECT or POST binding for authentication requests + * Contains the URL for which to send the SAML 2 Authentication Request to + * initiate a single sign on flow. + * @return a IDP URL that accepts REDIRECT or POST binding for authentication + * requests */ public String getWebSsoUrl() { return this.assertingPartyDetails.getSingleSignOnServiceLocation(); } /** - * @return {@code true} if AuthNRequests from this relying party to the IDP should be signed - * {@code false} if no signature is required. + * @return {@code true} if AuthNRequests from this relying party to the IDP should + * be signed {@code false} if no signature is required. */ public boolean isSignAuthNRequest() { return this.assertingPartyDetails.getWantAuthnRequestsSigned(); @@ -643,20 +705,20 @@ public class RelyingPartyRegistration { /** * Builder for IDP SSO endpoint configuration + * * @since 5.3 * @deprecated Use {@link AssertingPartyDetails.Builder} instead */ @Deprecated - public final static class Builder { - private final AssertingPartyDetails.Builder assertingPartyDetailsBuilder = - new AssertingPartyDetails.Builder(); + public static final class Builder { + + private final AssertingPartyDetails.Builder assertingPartyDetailsBuilder = new AssertingPartyDetails.Builder(); /** - * Set the asserting party's - * EntityID. - * Equivalent to the value found in the asserting party's - * <EntityDescriptor EntityID="..."/> - * + * Set the asserting party's EntityID. + * Equivalent to the value found in the asserting party's <EntityDescriptor + * EntityID="..."/> * @param entityId the asserting party's EntityID * @return the {@link Builder} for further configuration * @since 5.4 @@ -667,9 +729,10 @@ public class RelyingPartyRegistration { } /** - * Sets the {@code SSO URL} for the remote asserting party, the Identity Provider. - * - * @param url - a URL that accepts authentication requests via REDIRECT or POST bindings + * Sets the {@code SSO URL} for the remote asserting party, the Identity + * Provider. + * @param url - a URL that accepts authentication requests via REDIRECT or + * POST bindings * @return this object */ public Builder webSsoUrl(String url) { @@ -679,7 +742,6 @@ public class RelyingPartyRegistration { /** * Set to true if the AuthNRequest message should be signed - * * @param signAuthNRequest true if the message should be signed * @return this object */ @@ -688,11 +750,10 @@ public class RelyingPartyRegistration { return this; } - /** * Sets the message binding to be used when sending an AuthNRequest message - * - * @param binding either {@link Saml2MessageBinding#POST} or {@link Saml2MessageBinding#REDIRECT} + * @param binding either {@link Saml2MessageBinding#POST} or + * {@link Saml2MessageBinding#REDIRECT} * @return this object */ public Builder binding(Saml2MessageBinding binding) { @@ -701,30 +762,40 @@ public class RelyingPartyRegistration { } /** - * Creates an immutable ProviderDetails object representing the configuration for an Identity Provider, IDP + * Creates an immutable ProviderDetails object representing the configuration + * for an Identity Provider, IDP * @return immutable ProviderDetails object */ public ProviderDetails build() { return new ProviderDetails(this.assertingPartyDetailsBuilder.build()); } + } + } - public final static class Builder { + public static final class Builder { + private String registrationId; + private String entityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; + private Collection signingX509Credentials = new HashSet<>(); + private Collection decryptionX509Credentials = new HashSet<>(); + private String assertionConsumerServiceLocation = "{baseUrl}/login/saml2/sso/{registrationId}"; + private Saml2MessageBinding assertionConsumerServiceBinding = Saml2MessageBinding.POST; + private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder(); + private Collection credentials = new HashSet<>(); private Builder(String registrationId) { this.registrationId = registrationId; } - /** * Sets the {@code registrationId} template. Often be used in URL paths * @param id registrationId for this object, should be unique @@ -736,15 +807,14 @@ public class RelyingPartyRegistration { } /** - * Set the relying party's - * EntityID. - * Equivalent to the value found in the relying party's - * <EntityDescriptor EntityID="..."/> - * - * This value may contain a number of placeholders. - * They are {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}. + * Set the relying party's EntityID. + * Equivalent to the value found in the relying party's <EntityDescriptor + * EntityID="..."/> * + * This value may contain a number of placeholders. They are {@code baseUrl}, + * {@code registrationId}, {@code baseScheme}, {@code baseHost}, and + * {@code basePort}. * @return the {@link Builder} for further configuration * @since 5.4 */ @@ -754,10 +824,11 @@ public class RelyingPartyRegistration { } /** - * Apply this {@link Consumer} to the {@link Collection} of {@link Saml2X509Credential}s - * for the purposes of modifying the {@link Collection} - * - * @param credentialsConsumer - the {@link Consumer} for modifying the {@link Collection} + * Apply this {@link Consumer} to the {@link Collection} of + * {@link Saml2X509Credential}s for the purposes of modifying the + * {@link Collection} + * @param credentialsConsumer - the {@link Consumer} for modifying the + * {@link Collection} * @return the {@link Builder} for further configuration * @since 5.4 */ @@ -767,10 +838,11 @@ public class RelyingPartyRegistration { } /** - * Apply this {@link Consumer} to the {@link Collection} of {@link Saml2X509Credential}s - * for the purposes of modifying the {@link Collection} - * - * @param credentialsConsumer - the {@link Consumer} for modifying the {@link Collection} + * Apply this {@link Consumer} to the {@link Collection} of + * {@link Saml2X509Credential}s for the purposes of modifying the + * {@link Collection} + * @param credentialsConsumer - the {@link Consumer} for modifying the + * {@link Collection} * @return the {@link Builder} for further configuration * @since 5.4 */ @@ -780,18 +852,18 @@ public class RelyingPartyRegistration { } /** - * Set the AssertionConsumerService + * Set the AssertionConsumerService * Location. * *

        - * Equivalent to the value found in <AssertionConsumerService Location="..."/> - * in the relying party's <SPSSODescriptor> + * Equivalent to the value found in <AssertionConsumerService + * Location="..."/> in the relying party's <SPSSODescriptor> * *

        - * This value may contain a number of placeholders. - * They are {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}. - * + * This value may contain a number of placeholders. They are {@code baseUrl}, + * {@code registrationId}, {@code baseScheme}, {@code baseHost}, and + * {@code basePort}. * @param assertionConsumerServiceLocation * @return the {@link Builder} for further configuration * @since 5.4 @@ -802,13 +874,13 @@ public class RelyingPartyRegistration { } /** - * Set the AssertionConsumerService + * Set the AssertionConsumerService * Binding. * *

        - * Equivalent to the value found in <AssertionConsumerService Binding="..."/> - * in the relying party's <SPSSODescriptor> - * + * Equivalent to the value found in <AssertionConsumerService + * Binding="..."/> in the relying party's <SPSSODescriptor> * @param assertionConsumerServiceBinding * @return the {@link Builder} for further configuration * @since 5.4 @@ -820,7 +892,6 @@ public class RelyingPartyRegistration { /** * Apply this {@link Consumer} to further configure the Asserting Party details - * * @param assertingPartyDetails The {@link Consumer} to apply * @return the {@link Builder} for further configuration * @since 5.4 @@ -831,34 +902,37 @@ public class RelyingPartyRegistration { } /** - * Modifies the collection of {@link Saml2X509Credential} objects - * used in communication between IDP and SP - * For example: - * + * Modifies the collection of {@link Saml2X509Credential} objects used in + * communication between IDP and SP For example: * Saml2X509Credential credential = ...; * return RelyingPartyRegistration.withRegistrationId("id") - * .credentials(c -> c.add(credential)) + * .credentials((c) -> c.add(credential)) * ... * .build(); * * @param credentials - a consumer that can modify the collection of credentials * @return this object - * @deprecated Use {@link #signingX509Credentials} or {@link #decryptionX509Credentials} instead - * for relying party keys or {@link AssertingPartyDetails.Builder#verificationX509Credentials} or - * {@link AssertingPartyDetails.Builder#encryptionX509Credentials} for asserting party keys + * @deprecated Use {@link #signingX509Credentials} or + * {@link #decryptionX509Credentials} instead for relying party keys or + * {@link AssertingPartyDetails.Builder#verificationX509Credentials} or + * {@link AssertingPartyDetails.Builder#encryptionX509Credentials} for asserting + * party keys */ @Deprecated - public Builder credentials(Consumer> credentials) { + public Builder credentials( + Consumer> credentials) { credentials.accept(this.credentials); return this; } /** - * Assertion Consumer - * Service URL template. It can contain variables {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}. - * @param assertionConsumerServiceUrlTemplate the Assertion Consumer Service URL template (i.e. - * "{baseUrl}/login/saml2/sso/{registrationId}". + * Assertion + * Consumer Service URL template. It can contain variables {@code baseUrl}, + * {@code registrationId}, {@code baseScheme}, {@code baseHost}, and + * {@code basePort}. + * @param assertionConsumerServiceUrlTemplate the Assertion Consumer Service URL + * template (i.e. "{baseUrl}/login/saml2/sso/{registrationId}". * @return this object * @deprecated Use {@link #assertionConsumerServiceLocation} instead. */ @@ -869,33 +943,38 @@ public class RelyingPartyRegistration { } /** - * Sets the {@code entityId} for the remote asserting party, the Identity Provider. + * Sets the {@code entityId} for the remote asserting party, the Identity + * Provider. * @param entityId the IDP entityId * @return this object - * @deprecated use {@link #assertingPartyDetails(Consumer< AssertingPartyDetails.Builder >)} + * @deprecated use + * {@code #assertingPartyDetails(Consumer)} */ @Deprecated public Builder remoteIdpEntityId(String entityId) { - assertingPartyDetails(idp -> idp.entityId(entityId)); + assertingPartyDetails((idp) -> idp.entityId(entityId)); return this; } /** * Sets the {@code SSO URL} for the remote asserting party, the Identity Provider. - * @param url - a URL that accepts authentication requests via REDIRECT or POST bindings + * @param url - a URL that accepts authentication requests via REDIRECT or POST + * bindings * @return this object - * @deprecated use {@link #assertingPartyDetails(Consumer< AssertingPartyDetails.Builder >)} + * @deprecated use + * {@code #assertingPartyDetails(Consumer)} */ @Deprecated public Builder idpWebSsoUrl(String url) { - assertingPartyDetails(config -> config.singleSignOnServiceLocation(url)); + assertingPartyDetails((config) -> config.singleSignOnServiceLocation(url)); return this; } /** - * Sets the local relying party, or Service Provider, entity Id template. - * can generate it's entity ID based on possible variables of {@code baseUrl}, {@code registrationId}, - * {@code baseScheme}, {@code baseHost}, and {@code basePort}, for example + * Sets the local relying party, or Service Provider, entity Id template. can + * generate it's entity ID based on possible variables of {@code baseUrl}, + * {@code registrationId}, {@code baseScheme}, {@code baseHost}, and + * {@code basePort}, for example * {@code {baseUrl}/saml2/service-provider-metadata/{registrationId}} * @return a string containing the entity ID or entity ID template * @deprecated Use {@link #entityId} instead @@ -919,25 +998,24 @@ public class RelyingPartyRegistration { } /** - * Constructs a RelyingPartyRegistration object based on the builder configurations + * Constructs a RelyingPartyRegistration object based on the builder + * configurations * @return a RelyingPartyRegistration instance */ public RelyingPartyRegistration build() { for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : this.credentials) { Saml2X509Credential mapped = fromDeprecated(credential); if (credential.isSigningCredential()) { - signingX509Credentials(c -> c.add(mapped)); + signingX509Credentials((c) -> c.add(mapped)); } if (credential.isDecryptionCredential()) { - decryptionX509Credentials(c -> c.add(mapped)); + decryptionX509Credentials((c) -> c.add(mapped)); } if (credential.isSignatureVerficationCredential()) { - this.providerDetails.assertingPartyDetailsBuilder - .verificationX509Credentials(c -> c.add(mapped)); + this.providerDetails.assertingPartyDetailsBuilder.verificationX509Credentials((c) -> c.add(mapped)); } if (credential.isEncryptionCredential()) { - this.providerDetails.assertingPartyDetailsBuilder - .encryptionX509Credentials(c -> c.add(mapped)); + this.providerDetails.assertingPartyDetailsBuilder.encryptionX509Credentials((c) -> c.add(mapped)); } } @@ -953,55 +1031,12 @@ public class RelyingPartyRegistration { for (Saml2X509Credential credential : this.providerDetails.assertingPartyDetailsBuilder.encryptionX509Credentials) { this.credentials.add(toDeprecated(credential)); } - - return new RelyingPartyRegistration( - this.registrationId, - this.entityId, - this.assertionConsumerServiceLocation, - this.assertionConsumerServiceBinding, - this.providerDetails.build(), - this.credentials, - this.decryptionX509Credentials, - this.signingX509Credentials - ); + return new RelyingPartyRegistration(this.registrationId, this.entityId, + this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, + this.providerDetails.build(), this.credentials, this.decryptionX509Credentials, + this.signingX509Credentials); } + } - private static Saml2X509Credential fromDeprecated(org.springframework.security.saml2.credentials.Saml2X509Credential credential) { - PrivateKey privateKey = credential.getPrivateKey(); - X509Certificate certificate = credential.getCertificate(); - Set credentialTypes = new HashSet<>(); - if (credential.isSigningCredential()) { - credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.SIGNING); - } - if (credential.isSignatureVerficationCredential()) { - credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.VERIFICATION); - } - if (credential.isEncryptionCredential()) { - credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION); - } - if (credential.isDecryptionCredential()) { - credentialTypes.add(Saml2X509Credential.Saml2X509CredentialType.DECRYPTION); - } - return new Saml2X509Credential(privateKey, certificate, credentialTypes); - } - - private static org.springframework.security.saml2.credentials.Saml2X509Credential toDeprecated(Saml2X509Credential credential) { - PrivateKey privateKey = credential.getPrivateKey(); - X509Certificate certificate = credential.getCertificate(); - Set credentialTypes = new HashSet<>(); - if (credential.isSigningCredential()) { - credentialTypes.add(org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING); - } - if (credential.isVerificationCredential()) { - credentialTypes.add(org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION); - } - if (credential.isEncryptionCredential()) { - credentialTypes.add(org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION); - } - if (credential.isDecryptionCredential()) { - credentialTypes.add(org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION); - } - return new org.springframework.security.saml2.credentials.Saml2X509Credential(privateKey, certificate, credentialTypes); - } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java index 7f2d6943f5..1c681d92a3 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java @@ -19,15 +19,14 @@ package org.springframework.security.saml2.provider.service.registration; /** * A repository for {@link RelyingPartyRegistration}s * - * @since 5.2 * @author Filip Hanik + * @since 5.2 */ public interface RelyingPartyRegistrationRepository { /** - * Returns the relying party registration identified by the provided {@code registrationId}, - * or {@code null} if not found. - * + * Returns the relying party registration identified by the provided + * {@code registrationId}, or {@code null} if not found. * @param registrationId the registration identifier * @return the {@link RelyingPartyRegistration} if found, otherwise {@code null} */ diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrations.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrations.java index 29b9363afb..05e8695fb4 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrations.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrations.java @@ -30,12 +30,16 @@ import org.springframework.web.client.RestTemplate; * @since 5.4 */ public final class RelyingPartyRegistrations { - private static final RestOperations rest = new RestTemplate - (Arrays.asList(new OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter())); + + private static final RestOperations rest = new RestTemplate( + Arrays.asList(new OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter())); + + private RelyingPartyRegistrations() { + } /** - * Return a {@link RelyingPartyRegistration.Builder} based off of the given - * SAML 2.0 Asserting Party (IDP) metadata. + * Return a {@link RelyingPartyRegistration.Builder} based off of the given SAML 2.0 + * Asserting Party (IDP) metadata. * * Note that by default the registrationId is set to be the given metadata location, * but this will most often not be sufficient. To complete the configuration, most @@ -48,21 +52,23 @@ public final class RelyingPartyRegistrations { * .build(); * * - * Also note that an {@code IDPSSODescriptor} typically only contains information about - * the asserting party. Thus, you will need to remember to still populate anything about the - * relying party, like any private keys the relying party will use for signing AuthnRequests. - * + * Also note that an {@code IDPSSODescriptor} typically only contains information + * about the asserting party. Thus, you will need to remember to still populate + * anything about the relying party, like any private keys the relying party will use + * for signing AuthnRequests. * @param metadataLocation * @return the {@link RelyingPartyRegistration.Builder} for further configuration */ public static RelyingPartyRegistration.Builder fromMetadataLocation(String metadataLocation) { try { return rest.getForObject(metadataLocation, RelyingPartyRegistration.Builder.class); - } catch (RestClientException e) { - if (e.getCause() instanceof Saml2Exception) { - throw (Saml2Exception) e.getCause(); + } + catch (RestClientException ex) { + if (ex.getCause() instanceof Saml2Exception) { + throw (Saml2Exception) ex.getCause(); } - throw new Saml2Exception(e); + throw new Saml2Exception(ex); } } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/Saml2MessageBinding.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/Saml2MessageBinding.java index 154dcc88f4..958f608e7e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/Saml2MessageBinding.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/Saml2MessageBinding.java @@ -17,17 +17,17 @@ package org.springframework.security.saml2.provider.service.registration; /** - * The type of bindings that messages are exchanged using - * Supported bindings are {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST} - * and {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect}. - * In addition there is support for {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect} - * with an XML signature in the message rather than query parameters. + * The type of bindings that messages are exchanged using Supported bindings are + * {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST} and + * {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect}. In addition there is + * support for {@code urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect} with an XML + * signature in the message rather than query parameters. * @since 5.3 */ public enum Saml2MessageBinding { - POST("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"), - REDIRECT("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"); + POST("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"), REDIRECT( + "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"); private final String urn; @@ -40,6 +40,7 @@ public enum Saml2MessageBinding { * @return URN value representing this binding */ public String getUrn() { - return urn; + return this.urn; } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java new file mode 100644 index 0000000000..d6544958aa --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.servlet.filter; + +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * @since 5.3 + */ +final class Saml2ServletUtils { + + private static final char PATH_DELIMITER = '/'; + + private Saml2ServletUtils() { + } + + static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { + if (!StringUtils.hasText(template)) { + return baseUrl; + } + String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); + String registrationId = relyingParty.getRegistrationId(); + Map uriVariables = new HashMap<>(); + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", (host != null) ? host : ""); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", (port != -1) ? ":" + port : ""); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path)) { + if (path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + } + uriVariables.put("basePath", (path != null) ? path : ""); + uriVariables.put("baseUrl", uriComponents.toUriString()); + uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); + uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); + return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString(); + } + + static String getApplicationUri(HttpServletRequest request) { + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + .replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); + return uriComponents.toUriString(); + } + +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index c073ff0092..2c2f833c83 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.saml2.core.Saml2Error; +import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; @@ -30,9 +31,7 @@ import org.springframework.security.web.authentication.AbstractAuthenticationPro import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy; import org.springframework.util.Assert; - -import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND; -import static org.springframework.util.StringUtils.hasText; +import org.springframework.util.StringUtils; /** * @since 5.2 @@ -40,12 +39,14 @@ import static org.springframework.util.StringUtils.hasText; public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter { public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}"; + private final AuthenticationConverter authenticationConverter; /** - * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured - * to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL - * @param relyingPartyRegistrationRepository - repository of configured SAML 2 entities. Required. + * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is + * configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL + * @param relyingPartyRegistrationRepository - repository of configured SAML 2 + * entities. Required. */ public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { this(relyingPartyRegistrationRepository, DEFAULT_FILTER_PROCESSES_URI); @@ -53,35 +54,32 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce /** * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter - * @param relyingPartyRegistrationRepository - repository of configured SAML 2 entities. Required. - * @param filterProcessesUrl the processing URL, must contain a {registrationId} variable. Required. + * @param relyingPartyRegistrationRepository - repository of configured SAML 2 + * entities. Required. + * @param filterProcessesUrl the processing URL, must contain a {registrationId} + * variable. Required. */ - public Saml2WebSsoAuthenticationFilter( - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, String filterProcessesUrl) { - this(new Saml2AuthenticationTokenConverter - (new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), - filterProcessesUrl); + this(new Saml2AuthenticationTokenConverter( + new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), filterProcessesUrl); } /** * Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters - * - * @param authenticationConverter the strategy for converting an {@link HttpServletRequest} - * into an {@link Authentication} - * @param filterProcessingUrl the processing URL, must contain a {registrationId} variable + * @param authenticationConverter the strategy for converting an + * {@link HttpServletRequest} into an {@link Authentication} + * @param filterProcessingUrl the processing URL, must contain a {registrationId} + * variable * @since 5.4 */ - public Saml2WebSsoAuthenticationFilter( - AuthenticationConverter authenticationConverter, + public Saml2WebSsoAuthenticationFilter(AuthenticationConverter authenticationConverter, String filterProcessingUrl) { super(filterProcessingUrl); Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); Assert.hasText(filterProcessingUrl, "filterProcessesUrl must contain a URL pattern"); - Assert.isTrue( - filterProcessingUrl.contains("{registrationId}"), - "filterProcessesUrl must contain a {registrationId} match variable" - ); + Assert.isTrue(filterProcessingUrl.contains("{registrationId}"), + "filterProcessesUrl must contain a {registrationId} match variable"); this.authenticationConverter = authenticationConverter; setAllowSessionCreation(true); setSessionAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy()); @@ -89,7 +87,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce @Override protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response) { - return (super.requiresAuthentication(request, response) && hasText(request.getParameter("SAMLResponse"))); + return (super.requiresAuthentication(request, response) + && StringUtils.hasText(request.getParameter("SAMLResponse"))); } @Override @@ -97,10 +96,11 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce throws AuthenticationException { Authentication authentication = this.authenticationConverter.convert(request); if (authentication == null) { - Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND, + Saml2Error saml2Error = new Saml2Error(Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND, "No relying party registration found"); throw new Saml2AuthenticationException(saml2Error); } return getAuthenticationManager().authenticate(authentication); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 3fa3e9522c..731c6a3c66 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -17,6 +17,8 @@ package org.springframework.security.saml2.provider.service.servlet.filter; import java.io.IOException; +import java.nio.charset.StandardCharsets; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -43,55 +45,58 @@ import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; -import static java.nio.charset.StandardCharsets.ISO_8859_1; - /** * This {@code Filter} formulates a - * SAML 2.0 AuthnRequest (line 1968) - * and redirects to a configured asserting party. + * SAML 2.0 + * AuthnRequest (line 1968) and redirects to a configured asserting party. * *

        - * It supports the - * HTTP-Redirect (line 520) - * and - * HTTP-POST (line 753) - * bindings. + * It supports the HTTP-Redirect + * (line 520) and HTTP-POST + * (line 753) bindings. * *

        - * By default, this {@code Filter} responds to authentication requests - * at the {@code URI} {@code /oauth2/authorization/{registrationId}}. - * The {@code URI} template variable {@code {registrationId}} represents the - * {@link RelyingPartyRegistration#getRegistrationId() registration identifier} of the relying party - * that is used for initiating the authentication request. + * By default, this {@code Filter} responds to authentication requests at the {@code URI} + * {@code /oauth2/authorization/{registrationId}}. The {@code URI} template variable + * {@code {registrationId}} represents the + * {@link RelyingPartyRegistration#getRegistrationId() registration identifier} of the + * relying party that is used for initiating the authentication request. * - * @since 5.2 * @author Filip Hanik * @author Josh Cummings + * @since 5.2 */ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter { private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver; + private Saml2AuthenticationRequestFactory authenticationRequestFactory; private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); /** - * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters - * - * @param relyingPartyRegistrationRepository a repository for relying party configurations - * @deprecated use the constructor that takes a {@link Saml2AuthenticationRequestFactory} + * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided + * parameters + * @param relyingPartyRegistrationRepository a repository for relying party + * configurations + * @deprecated use the constructor that takes a + * {@link Saml2AuthenticationRequestFactory} */ @Deprecated - public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + public Saml2WebSsoAuthenticationRequestFilter( + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { this(new DefaultSaml2AuthenticationRequestContextResolver( new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory()); } /** - * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters - * - * @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext} + * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided + * parameters + * @param authenticationRequestContextResolver a strategy for formulating a + * {@link Saml2AuthenticationRequestContext} * @since 5.4 */ public Saml2WebSsoAuthenticationRequestFilter( @@ -105,9 +110,10 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter } /** - * Use the given {@link Saml2AuthenticationRequestFactory} for formulating the SAML 2.0 AuthnRequest - * - * @param authenticationRequestFactory the {@link Saml2AuthenticationRequestFactory} to use + * Use the given {@link Saml2AuthenticationRequestFactory} for formulating the SAML + * 2.0 AuthnRequest + * @param authenticationRequestFactory the {@link Saml2AuthenticationRequestFactory} + * to use * @deprecated use the constructor instead */ @Deprecated @@ -118,7 +124,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter /** * Use the given {@link RequestMatcher} that activates this filter for a given request - * * @param redirectMatcher the {@link RequestMatcher} to use */ public void setRedirectMatcher(RequestMatcher redirectMatcher) { @@ -126,13 +131,9 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter this.redirectMatcher = redirectMatcher; } - /** - * {@inheritDoc} - */ @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - MatchResult matcher = this.redirectMatcher.matcher(request); if (!matcher.isMatch()) { filterChain.doFilter(request, response); @@ -147,41 +148,37 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration(); if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { sendRedirect(response, context); - } else { + } + else { sendPost(response, context); } } private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException { - Saml2RedirectAuthenticationRequest authenticationRequest = - this.authenticationRequestFactory.createRedirectAuthenticationRequest(context); + Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory + .createRedirectAuthenticationRequest(context); UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(authenticationRequest.getAuthenticationRequestUri()); addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder); addParameter("RelayState", authenticationRequest.getRelayState(), uriBuilder); addParameter("SigAlg", authenticationRequest.getSigAlg(), uriBuilder); addParameter("Signature", authenticationRequest.getSignature(), uriBuilder); - String redirectUrl = uriBuilder - .build(true) - .toUriString(); + String redirectUrl = uriBuilder.build(true).toUriString(); response.sendRedirect(redirectUrl); } private void addParameter(String name, String value, UriComponentsBuilder builder) { Assert.hasText(name, "name cannot be empty or null"); if (StringUtils.hasText(value)) { - builder.queryParam( - UriUtils.encode(name, ISO_8859_1), - UriUtils.encode(value, ISO_8859_1) - ); + builder.queryParam(UriUtils.encode(name, StandardCharsets.ISO_8859_1), + UriUtils.encode(value, StandardCharsets.ISO_8859_1)); } } - private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context) - throws IOException { - Saml2PostAuthenticationRequest authenticationRequest = - this.authenticationRequestFactory.createPostAuthenticationRequest(context); + private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException { + Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory + .createPostAuthenticationRequest(context); String html = createSamlPostRequestFormData(authenticationRequest); response.setContentType(MediaType.TEXT_HTML_VALUE); response.getWriter().write(html); @@ -191,42 +188,42 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter String authenticationRequestUri = authenticationRequest.getAuthenticationRequestUri(); String relayState = authenticationRequest.getRelayState(); String samlRequest = authenticationRequest.getSamlRequest(); - StringBuilder postHtml = new StringBuilder() - .append("\n") - .append("\n") - .append(" \n") - .append(" \n") - .append(" \n") - .append(" \n") - .append("

        \n") - .append(" Note: Since your browser does not support JavaScript,\n") - .append(" you must press the Continue button once to proceed.\n") - .append("

        \n") - .append(" \n") - .append(" \n") - .append("
        \n") - .append("
        \n") - .append(" \n"); + StringBuilder html = new StringBuilder(); + html.append("\n"); + html.append("\n").append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append("
        \n"); + html.append(" \n"); if (StringUtils.hasText(relayState)) { - postHtml - .append(" \n"); + html.append(" \n"); } - postHtml - .append("
        \n") - .append(" \n") - .append(" \n") - .append(" \n") - .append(" \n") - .append(""); - return postHtml.toString(); + html.append("
        \n"); + html.append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append(" \n"); + html.append(""); + return html.toString(); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java index 3768233fdf..10b667847c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java @@ -19,11 +19,13 @@ package org.springframework.security.saml2.provider.service.web; import java.util.HashMap; import java.util.Map; import java.util.function.Function; + import javax.servlet.http.HttpServletRequest; import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -31,17 +33,13 @@ import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; -import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; -import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; -import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; - /** * A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the - * registration id from the request, querying a {@link RelyingPartyRegistrationRepository}, - * and resolving any template values. + * registration id from the request, querying a + * {@link RelyingPartyRegistrationRepository}, and resolving any template values. * - * @since 5.4 * @author Josh Cummings + * @since 5.4 */ public final class DefaultRelyingPartyRegistrationResolver implements Converter { @@ -49,11 +47,11 @@ public final class DefaultRelyingPartyRegistrationResolver private static final char PATH_DELIMITER = '/'; private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final Converter registrationIdResolver = new RegistrationIdResolver(); - public DefaultRelyingPartyRegistrationResolver - (RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { - + public DefaultRelyingPartyRegistrationResolver( + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; } @@ -64,66 +62,57 @@ public final class DefaultRelyingPartyRegistrationResolver if (registrationId == null) { return null; } - RelyingPartyRegistration relyingPartyRegistration = - this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository + .findByRegistrationId(registrationId); if (relyingPartyRegistration == null) { return null; } - String applicationUri = getApplicationUri(request); Function templateResolver = templateResolver(applicationUri, relyingPartyRegistration); String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId()); - String assertionConsumerServiceLocation = templateResolver.apply( - relyingPartyRegistration.getAssertionConsumerServiceLocation()); - return withRelyingPartyRegistration(relyingPartyRegistration) - .entityId(relyingPartyEntityId) - .assertionConsumerServiceLocation(assertionConsumerServiceLocation) + String assertionConsumerServiceLocation = templateResolver + .apply(relyingPartyRegistration.getAssertionConsumerServiceLocation()); + return RelyingPartyRegistration.withRelyingPartyRegistration(relyingPartyRegistration) + .entityId(relyingPartyEntityId).assertionConsumerServiceLocation(assertionConsumerServiceLocation) .build(); } private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { - return template -> resolveUrlTemplate(template, applicationUri, relyingParty); + return (template) -> resolveUrlTemplate(template, applicationUri, relyingParty); } private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); String registrationId = relyingParty.getRegistrationId(); Map uriVariables = new HashMap<>(); - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) - .replaceQuery(null) - .fragment(null) + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null) .build(); String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", scheme == null ? "" : scheme); + uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); String host = uriComponents.getHost(); - uriVariables.put("baseHost", host == null ? "" : host); + uriVariables.put("baseHost", (host != null) ? host : ""); // following logic is based on HierarchicalUriComponents#toUriString() int port = uriComponents.getPort(); - uriVariables.put("basePort", port == -1 ? "" : ":" + port); + uriVariables.put("basePort", (port == -1) ? "" : ":" + port); String path = uriComponents.getPath(); if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { path = PATH_DELIMITER + path; } - uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("basePath", (path != null) ? path : ""); uriVariables.put("baseUrl", uriComponents.toUriString()); uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); - - return UriComponentsBuilder.fromUriString(template) - .buildAndExpand(uriVariables) - .toUriString(); + return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString(); } private static String getApplicationUri(HttpServletRequest request) { - UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) - .replacePath(request.getContextPath()) - .replaceQuery(null) - .fragment(null) - .build(); + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + .replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); return uriComponents.toUriString(); } private static class RegistrationIdResolver implements Converter { + private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); @Override @@ -131,5 +120,7 @@ public final class DefaultRelyingPartyRegistrationResolver RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); return result.getVariables().get("registrationId"); } + } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java index b9d15b7860..a6cdb3ed91 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java @@ -27,27 +27,26 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.util.Assert; /** - * The default implementation for {@link Saml2AuthenticationRequestContextResolver} - * which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext} + * The default implementation for {@link Saml2AuthenticationRequestContextResolver} which + * uses the current request and given relying party to formulate a + * {@link Saml2AuthenticationRequestContext} * * @author Shazin Sadakath * @author Josh Cummings * @since 5.4 */ -public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver { +public final class DefaultSaml2AuthenticationRequestContextResolver + implements Saml2AuthenticationRequestContextResolver { private final Log logger = LogFactory.getLog(getClass()); private final Converter relyingPartyRegistrationResolver; - public DefaultSaml2AuthenticationRequestContextResolver - (Converter relyingPartyRegistrationResolver) { + public DefaultSaml2AuthenticationRequestContextResolver( + Converter relyingPartyRegistrationResolver) { this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; } - /** - * {@inheritDoc} - */ @Override public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); @@ -56,20 +55,19 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S return null; } if (this.logger.isDebugEnabled()) { - this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + - relyingParty.getRegistrationId() + "]"); + this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + + relyingParty.getRegistrationId() + "]"); } return createRedirectAuthenticationRequestContext(request, relyingParty); } - private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( - HttpServletRequest request, RelyingPartyRegistration relyingParty) { + private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(HttpServletRequest request, + RelyingPartyRegistration relyingParty) { - return Saml2AuthenticationRequestContext.builder() - .issuer(relyingParty.getEntityId()) + return Saml2AuthenticationRequestContext.builder().issuer(relyingParty.getEntityId()) .relyingPartyRegistration(relyingParty) .assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation()) - .relayState(request.getParameter("RelayState")) - .build(); + .relayState(request.getParameter("RelayState")).build(); } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java index db24c8ff90..233d93b23d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java @@ -22,7 +22,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A /** * This {@code Saml2AuthenticationRequestContextResolver} formulates a - * SAML 2.0 AuthnRequest (line 1968) + * SAML 2.0 + * AuthnRequest (line 1968) * * @author Shazin Sadakath * @author Josh Cummings @@ -31,11 +32,11 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A public interface Saml2AuthenticationRequestContextResolver { /** - * This {@code resolve} method is defined to create a {@link Saml2AuthenticationRequestContext} - * - * + * This {@code resolve} method is defined to create a + * {@link Saml2AuthenticationRequestContext} * @param request the current request * @return the created {@link Saml2AuthenticationRequestContext} for the request */ Saml2AuthenticationRequestContext resolve(HttpServletRequest request); + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java index 632ef955b9..bcce7e6fa8 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java @@ -18,8 +18,10 @@ package org.springframework.security.saml2.provider.service.web; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; + import javax.servlet.http.HttpServletRequest; import org.apache.commons.codec.binary.Base64; @@ -32,36 +34,32 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.util.Assert; -import static java.nio.charset.StandardCharsets.UTF_8; - /** - * An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} appropriate - * for authenticated a SAML 2.0 Assertion against an + * An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken} + * appropriate for authenticated a SAML 2.0 Assertion against an * {@link org.springframework.security.authentication.AuthenticationManager}. * * @author Josh Cummings * @since 5.4 */ public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter { - private static Base64 BASE64 = new Base64(0, new byte[]{'\n'}); + + private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }); private final Converter relyingPartyRegistrationResolver; /** - * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for resolving + * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for + * resolving {@link RelyingPartyRegistration}s + * @param relyingPartyRegistrationResolver the strategy for resolving * {@link RelyingPartyRegistration}s - * - * @param relyingPartyRegistrationResolver the strategy for resolving {@link RelyingPartyRegistration}s */ - public Saml2AuthenticationTokenConverter - (Converter relyingPartyRegistrationResolver) { + public Saml2AuthenticationTokenConverter( + Converter relyingPartyRegistrationResolver) { Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null"); this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; } - /** - * {@inheritDoc} - */ @Override public Saml2AuthenticationToken convert(HttpServletRequest request) { RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request); @@ -81,9 +79,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo if (HttpMethod.GET.matches(request.getMethod())) { return samlInflate(b); } - else { - return new String(b, UTF_8); - } + return new String(b, StandardCharsets.UTF_8); } private byte[] samlDecode(String s) { @@ -93,13 +89,14 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo private String samlInflate(byte[] b) { try { ByteArrayOutputStream out = new ByteArrayOutputStream(); - InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); - iout.write(b); - iout.finish(); - return new String(out.toByteArray(), UTF_8); + InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true)); + inflaterOutputStream.write(b); + inflaterOutputStream.finish(); + return new String(out.toByteArray(), StandardCharsets.UTF_8); } - catch (IOException e) { - throw new Saml2Exception("Unable to inflate string", e); + catch (IOException ex) { + throw new Saml2Exception("Unable to inflate string", ex); } } + } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java index 50cb27803c..9e328cb6c3 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java @@ -17,6 +17,7 @@ package org.springframework.security.saml2.provider.service.web; import java.io.IOException; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -42,6 +43,7 @@ import org.springframework.web.filter.OncePerRequestFilter; public final class Saml2MetadataFilter extends OncePerRequestFilter { private final Converter relyingPartyRegistrationConverter; + private final Saml2MetadataResolver saml2MetadataResolver; private RequestMatcher requestMatcher = new AntPathRequestMatcher( @@ -58,20 +60,16 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws ServletException, IOException { - RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request); if (!matcher.isMatch()) { chain.doFilter(request, response); return; } - - RelyingPartyRegistration relyingPartyRegistration = - this.relyingPartyRegistrationConverter.convert(request); + RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request); if (relyingPartyRegistration == null) { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); return; } - String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration); String registrationId = relyingPartyRegistration.getRegistrationId(); writeMetadataToResponse(response, registrationId, metadata); @@ -79,7 +77,6 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata) throws IOException { - response.setContentType(MediaType.APPLICATION_XML_VALUE); response.setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"saml-" + registrationId + "-metadata.xml\""); @@ -88,13 +85,13 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter { } /** - * Set the {@link RequestMatcher} that determines whether this filter should - * handle the incoming {@link HttpServletRequest} - * + * Set the {@link RequestMatcher} that determines whether this filter should handle + * the incoming {@link HttpServletRequest} * @param requestMatcher */ public void setRequestMatcher(RequestMatcher requestMatcher) { Assert.notNull(requestMatcher, "requestMatcher cannot be null"); this.requestMatcher = requestMatcher; } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java index 5a09bf70d9..7e7cb5a0d1 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/OpenSamlInitializationServiceTests.java @@ -23,7 +23,7 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.springframework.security.saml2.Saml2Exception; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link OpenSamlInitializationService} @@ -37,8 +37,9 @@ public class OpenSamlInitializationServiceTests { OpenSamlInitializationService.initialize(); XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); assertThat(registry.getParserPool()).isNotNull(); - assertThatCode(() -> OpenSamlInitializationService.requireInitialize(r -> {})) - .isInstanceOf(Saml2Exception.class) - .hasMessageContaining("OpenSAML was already initialized previously"); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSamlInitializationService.requireInitialize((r) -> { + })).withMessageContaining("OpenSAML was already initialized previously"); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java index fa96940f13..8c15cd58c5 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.saml2.core; import org.junit.Test; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for verifying {@link Saml2ResponseValidatorResult} @@ -26,8 +27,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * @author Josh Cummings */ public class Saml2ResponseValidatorResultTests { - private static final Saml2Error DETAIL = new Saml2Error( - "error", "description"); + + private static final Saml2Error DETAIL = new Saml2Error("error", "description"); @Test public void successWhenInvokedThenReturnsSuccessfulResult() { @@ -64,9 +65,7 @@ public class Saml2ResponseValidatorResultTests { @Test public void concatResultWhenInvokedThenReturnsCopyContainingAll() { Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL); - Saml2ResponseValidatorResult merged = failure - .concat(failure) - .concat(failure); + Saml2ResponseValidatorResult merged = failure.concat(failure).concat(failure); assertThat(merged.hasErrors()).isTrue(); assertThat(merged.getErrors()).containsExactly(DETAIL, DETAIL, DETAIL); @@ -75,15 +74,22 @@ public class Saml2ResponseValidatorResultTests { @Test public void concatErrorWhenNullThenIllegalArgument() { - assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL) - .concat((Saml2Error) null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL) + .concat((Saml2Error) null) + ); + // @formatter:on } @Test public void concatResultWhenNullThenIllegalArgument() { - assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL) - .concat((Saml2ResponseValidatorResult) null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL) + .concat((Saml2ResponseValidatorResult) null) + ); + // @formatter:on } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java index 3de1cecc9b..a518b911a3 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2Utils.java @@ -18,6 +18,7 @@ package org.springframework.security.saml2.core; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; import java.util.zip.Inflater; @@ -27,12 +28,12 @@ import org.apache.commons.codec.binary.Base64; import org.springframework.security.saml2.Saml2Exception; -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.zip.Deflater.DEFLATED; - public final class Saml2Utils { - private static Base64 BASE64 = new Base64(0, new byte[]{'\n'}); + private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }); + + private Saml2Utils() { + } public static String samlEncode(byte[] b) { return BASE64.encodeAsString(b); @@ -44,27 +45,29 @@ public final class Saml2Utils { public static byte[] samlDeflate(String s) { try { - ByteArrayOutputStream b = new ByteArrayOutputStream(); - DeflaterOutputStream deflater = new DeflaterOutputStream(b, new Deflater(DEFLATED, true)); - deflater.write(s.getBytes(UTF_8)); - deflater.finish(); - return b.toByteArray(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(out, + new Deflater(Deflater.DEFLATED, true)); + deflaterOutputStream.write(s.getBytes(StandardCharsets.UTF_8)); + deflaterOutputStream.finish(); + return out.toByteArray(); } - catch (IOException e) { - throw new Saml2Exception("Unable to deflate string", e); + catch (IOException ex) { + throw new Saml2Exception("Unable to deflate string", ex); } } public static String samlInflate(byte[] b) { try { ByteArrayOutputStream out = new ByteArrayOutputStream(); - InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true)); - iout.write(b); - iout.finish(); - return new String(out.toByteArray(), UTF_8); + InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true)); + inflaterOutputStream.write(b); + inflaterOutputStream.finish(); + return new String(out.toByteArray(), StandardCharsets.UTF_8); } - catch (IOException e) { - throw new Saml2Exception("Unable to inflate string", e); + catch (IOException ex) { + throw new Saml2Exception("Unable to inflate string", ex); } } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java index e3f7b32109..5f41ee0e08 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java @@ -17,6 +17,7 @@ package org.springframework.security.saml2.core; import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.security.PrivateKey; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; @@ -27,12 +28,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.springframework.security.converter.RsaKeyConverters; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.SIGNING; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION; +import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; public class Saml2X509CredentialTests { @@ -40,159 +36,161 @@ public class Saml2X509CredentialTests { public ExpectedException exception = ExpectedException.none(); private PrivateKey key; + private X509Certificate certificate; @Before public void setup() throws Exception { - String keyData = "-----BEGIN PRIVATE KEY-----\n" + - "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + - "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + - "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + - "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + - "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + - "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + - "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + - "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + - "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + - "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + - "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + - "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + - "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + - "INrtuLp4YHbgk1mi\n" + - "-----END PRIVATE KEY-----"; - key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(UTF_8))); + String keyData = "-----BEGIN PRIVATE KEY-----\n" + + "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + + "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + + "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + + "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + + "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + + "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + + "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + + "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + + "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + + "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + + "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + + "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + + "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + "INrtuLp4YHbgk1mi\n" + + "-----END PRIVATE KEY-----"; + this.key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(StandardCharsets.UTF_8))); final CertificateFactory factory = CertificateFactory.getInstance("X.509"); - String certificateData = "-----BEGIN CERTIFICATE-----\n" + - "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + - "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + - "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + - "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + - "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + - "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + - "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + - "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + - "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + - "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + - "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + - "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + - "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + - "-----END CERTIFICATE-----"; - certificate = (X509Certificate) factory - .generateCertificate(new ByteArrayInputStream(certificateData.getBytes(UTF_8))); + String certificateData = "-----BEGIN CERTIFICATE-----\n" + + "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + + "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + + "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + + "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + + "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + + "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + + "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + + "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + + "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + + "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + + "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + + "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----"; + this.certificate = (X509Certificate) factory + .generateCertificate(new ByteArrayInputStream(certificateData.getBytes(StandardCharsets.UTF_8))); } @Test public void constructorWhenRelyingPartyWithCredentialsThenItSucceeds() { - new Saml2X509Credential(key, certificate, SIGNING); - new Saml2X509Credential(key, certificate, SIGNING, DECRYPTION); - new Saml2X509Credential(key, certificate, DECRYPTION); - Saml2X509Credential.signing(key, certificate); - Saml2X509Credential.decryption(key, certificate); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING, + Saml2X509CredentialType.DECRYPTION); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.DECRYPTION); + Saml2X509Credential.signing(this.key, this.certificate); + Saml2X509Credential.decryption(this.key, this.certificate); } @Test public void constructorWhenAssertingPartyWithCredentialsThenItSucceeds() { - new Saml2X509Credential(certificate, VERIFICATION); - new Saml2X509Credential(certificate, VERIFICATION, ENCRYPTION); - new Saml2X509Credential(certificate, ENCRYPTION); - Saml2X509Credential.verification(certificate); - Saml2X509Credential.encryption(certificate); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION, + Saml2X509CredentialType.ENCRYPTION); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.ENCRYPTION); + Saml2X509Credential.verification(this.certificate); + Saml2X509Credential.encryption(this.certificate); } @Test public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, (X509Certificate) null, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, certificate, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenRelyingPartyWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(key, null, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenAssertingPartyWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(key, certificate, ENCRYPTION); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION); } @Test public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(key, certificate, VERIFICATION); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION); } @Test public void constructorWhenAssertingPartyWithSigningUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(certificate, SIGNING); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(certificate, DECRYPTION); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION); } @Test public void factoryWhenRelyingPartyForSigningWithoutCredentialsThenItFails() { - exception.expect(IllegalArgumentException.class); + this.exception.expect(IllegalArgumentException.class); Saml2X509Credential.signing(null, null); } @Test public void factoryWhenRelyingPartyForSigningWithoutPrivateKeyThenItFails() { - exception.expect(IllegalArgumentException.class); - Saml2X509Credential.signing(null, certificate); + this.exception.expect(IllegalArgumentException.class); + Saml2X509Credential.signing(null, this.certificate); } @Test public void factoryWhenRelyingPartyForSigningWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); - Saml2X509Credential.signing(key, null); + this.exception.expect(IllegalArgumentException.class); + Saml2X509Credential.signing(this.key, null); } @Test public void factoryWhenRelyingPartyForDecryptionWithoutCredentialsThenItFails() { - exception.expect(IllegalArgumentException.class); + this.exception.expect(IllegalArgumentException.class); Saml2X509Credential.decryption(null, null); } @Test public void factoryWhenRelyingPartyForDecryptionWithoutPrivateKeyThenItFails() { - exception.expect(IllegalArgumentException.class); - Saml2X509Credential.decryption(null, certificate); + this.exception.expect(IllegalArgumentException.class); + Saml2X509Credential.decryption(null, this.certificate); } @Test public void factoryWhenRelyingPartyForDecryptionWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); - Saml2X509Credential.decryption(key, null); + this.exception.expect(IllegalArgumentException.class); + Saml2X509Credential.decryption(this.key, null); } @Test public void factoryWhenAssertingPartyForVerificationWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); + this.exception.expect(IllegalArgumentException.class); Saml2X509Credential.verification(null); } @Test public void factoryWhenAssertingPartyForEncryptionWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); + this.exception.expect(IllegalArgumentException.class); Saml2X509Credential.encryption(null); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java index f2d711882e..55cd6b53b9 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/TestSaml2X509Credentials.java @@ -17,6 +17,7 @@ package org.springframework.security.saml2.core; import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.security.KeyException; import java.security.PrivateKey; import java.security.cert.CertificateException; @@ -26,154 +27,147 @@ import java.security.cert.X509Certificate; import org.opensaml.security.crypto.KeySupport; import org.springframework.security.saml2.Saml2Exception; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.SIGNING; -import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION; +import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; public final class TestSaml2X509Credentials { + + private TestSaml2X509Credentials() { + } + public static Saml2X509Credential assertingPartySigningCredential() { - return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING); + return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING); } public static Saml2X509Credential assertingPartyEncryptingCredential() { - return new Saml2X509Credential(spCertificate(), ENCRYPTION); + return new Saml2X509Credential(spCertificate(), Saml2X509CredentialType.ENCRYPTION); } public static Saml2X509Credential assertingPartyPrivateCredential() { - return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING, DECRYPTION); + return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING, + Saml2X509CredentialType.DECRYPTION); } public static Saml2X509Credential relyingPartyVerifyingCredential() { - return new Saml2X509Credential(idpCertificate(), VERIFICATION); + return new Saml2X509Credential(idpCertificate(), Saml2X509CredentialType.VERIFICATION); } public static Saml2X509Credential relyingPartySigningCredential() { - return new Saml2X509Credential(spPrivateKey(), spCertificate(), SIGNING); + return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.SIGNING); } public static Saml2X509Credential relyingPartyDecryptingCredential() { - return new Saml2X509Credential(spPrivateKey(), spCertificate(), DECRYPTION); + return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.DECRYPTION); } private static X509Certificate certificate(String cert) { ByteArrayInputStream certBytes = new ByteArrayInputStream(cert.getBytes()); try { - return (X509Certificate) CertificateFactory - .getInstance("X.509") - .generateCertificate(certBytes); + return (X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate(certBytes); } - catch (CertificateException e) { - throw new Saml2Exception(e); + catch (CertificateException ex) { + throw new Saml2Exception(ex); } } private static PrivateKey privateKey(String key) { try { - return KeySupport.decodePrivateKey(key.getBytes(UTF_8), new char[0]); + return KeySupport.decodePrivateKey(key.getBytes(StandardCharsets.UTF_8), new char[0]); } - catch (KeyException e) { - throw new Saml2Exception(e); + catch (KeyException ex) { + throw new Saml2Exception(ex); } } - private static X509Certificate idpCertificate() { - return certificate("-----BEGIN CERTIFICATE-----\n" - + "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" - + "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" - + "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" - + "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" - + "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" - + "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" - + "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" - + "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" - + "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" - + "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" - + "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" - + "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" - + "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" - + "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" - + "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" - + "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" - + "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" - + "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" - + "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" - + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" - + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" - + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" - + "-----END CERTIFICATE-----\n"); + private static X509Certificate idpCertificate() { + return certificate( + "-----BEGIN CERTIFICATE-----\n" + "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" + + "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" + + "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" + + "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" + + "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" + + "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" + + "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" + + "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" + + "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" + + "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" + + "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" + + "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" + + "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" + + "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" + + "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" + + "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" + + "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" + + "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" + + "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" + + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" + + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" + + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + "-----END CERTIFICATE-----\n"); } - private static PrivateKey idpPrivateKey() { - return privateKey("-----BEGIN PRIVATE KEY-----\n" - + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4cn62E1xLqpN3\n" - + "4PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz2ZivLwZX\n" - + "W+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWWRDodcoHE\n" - + "fDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQnX8Ttl7h\n" - + "Z6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5cljz0X/T\n" - + "Xy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gphiJH3jvZ7\n" - + "I+J5lS8VAgMBAAECggEBAKyxBlIS7mcp3chvq0RF7B3PHFJMMzkwE+t3pLJcs4cZ\n" - + "nezh/KbREfP70QjXzk/llnZCvxeIs5vRu24vbdBm79qLHqBuHp8XfHHtuo2AfoAQ\n" - + "l4h047Xc/+TKMivnPQ0jX9qqndKDLqZDf5wnbslDmlskvF0a/MjsLU0TxtOfo+dB\n" - + "t55FW11cGqxZwhS5Gnr+cbw3OkHz23b9gEOt9qfwPVepeysbmm9FjU+k4yVa7rAN\n" - + "xcbzVb6Y7GCITe2tgvvEHmjB9BLmWrH3mZ3Af17YU/iN6TrpPd6Sj3QoS+2wGtAe\n" - + "HbUs3CKJu7bIHcj4poal6Kh8519S+erJTtqQ8M0ZiEECgYEA43hLYAPaUueFkdfh\n" - + "9K/7ClH6436CUH3VdizwUXi26fdhhV/I/ot6zLfU2mgEHU22LBECWQGtAFm8kv0P\n" - + "zPn+qjaR3e62l5PIlSYbnkIidzoDZ2ztu4jF5LgStlTJQPteFEGgZVl5o9DaSZOq\n" - + "Yd7G3XqXuQ1VGMW58G5FYJPtA1cCgYEAz5TPUtK+R2KXHMjUwlGY9AefQYRYmyX2\n" - + "Tn/OFgKvY8lpAkMrhPKONq7SMYc8E9v9G7A0dIOXvW7QOYSapNhKU+np3lUafR5F\n" - + "4ZN0bxZ9qjHbn3AMYeraKjeutHvlLtbHdIc1j3sxe/EzltRsYmiqLdEBW0p6hwWg\n" - + "tyGhYWVyaXMCgYAfDOKtHpmEy5nOCLwNXKBWDk7DExfSyPqEgSnk1SeS1HP5ctPK\n" - + "+1st6sIhdiVpopwFc+TwJWxqKdW18tlfT5jVv1E2DEnccw3kXilS9xAhWkfwrEvf\n" - + "V5I74GydewFl32o+NZ8hdo9GL1I8zO1rIq/et8dSOWGuWf9BtKu/vTGTTQKBgFxU\n" - + "VjsCnbvmsEwPUAL2hE/WrBFaKocnxXx5AFNt8lEyHtDwy4Sg1nygGcIJ4sD6koQk\n" - + "RdClT3LkvR04TAiSY80bN/i6ZcPNGUwSaDGZEWAIOSWbkwZijZNFnSGOEgxZX/IG\n" - + "yd39766vREEMTwEeiMNEOZQ/dmxkJm4OOVe25cLdAoGACOtPnq1Fxay80UYBf4rQ\n" - + "+bJ9yX1ulB8WIree1hD7OHSB2lRHxrVYWrglrTvkh63Lgx+EcsTV788OsvAVfPPz\n" - + "BZrn8SdDlQqalMxUBYEFwnsYD3cQ8yOUnijFVC4xNcdDv8OIqVgSk4KKxU5AshaA\n" + "xk6Mox+u8Cc2eAK12H13i+8=\n" - + "-----END PRIVATE KEY-----\n"); + return privateKey( + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4cn62E1xLqpN3\n" + + "4PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz2ZivLwZX\n" + + "W+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWWRDodcoHE\n" + + "fDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQnX8Ttl7h\n" + + "Z6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5cljz0X/T\n" + + "Xy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gphiJH3jvZ7\n" + + "I+J5lS8VAgMBAAECggEBAKyxBlIS7mcp3chvq0RF7B3PHFJMMzkwE+t3pLJcs4cZ\n" + + "nezh/KbREfP70QjXzk/llnZCvxeIs5vRu24vbdBm79qLHqBuHp8XfHHtuo2AfoAQ\n" + + "l4h047Xc/+TKMivnPQ0jX9qqndKDLqZDf5wnbslDmlskvF0a/MjsLU0TxtOfo+dB\n" + + "t55FW11cGqxZwhS5Gnr+cbw3OkHz23b9gEOt9qfwPVepeysbmm9FjU+k4yVa7rAN\n" + + "xcbzVb6Y7GCITe2tgvvEHmjB9BLmWrH3mZ3Af17YU/iN6TrpPd6Sj3QoS+2wGtAe\n" + + "HbUs3CKJu7bIHcj4poal6Kh8519S+erJTtqQ8M0ZiEECgYEA43hLYAPaUueFkdfh\n" + + "9K/7ClH6436CUH3VdizwUXi26fdhhV/I/ot6zLfU2mgEHU22LBECWQGtAFm8kv0P\n" + + "zPn+qjaR3e62l5PIlSYbnkIidzoDZ2ztu4jF5LgStlTJQPteFEGgZVl5o9DaSZOq\n" + + "Yd7G3XqXuQ1VGMW58G5FYJPtA1cCgYEAz5TPUtK+R2KXHMjUwlGY9AefQYRYmyX2\n" + + "Tn/OFgKvY8lpAkMrhPKONq7SMYc8E9v9G7A0dIOXvW7QOYSapNhKU+np3lUafR5F\n" + + "4ZN0bxZ9qjHbn3AMYeraKjeutHvlLtbHdIc1j3sxe/EzltRsYmiqLdEBW0p6hwWg\n" + + "tyGhYWVyaXMCgYAfDOKtHpmEy5nOCLwNXKBWDk7DExfSyPqEgSnk1SeS1HP5ctPK\n" + + "+1st6sIhdiVpopwFc+TwJWxqKdW18tlfT5jVv1E2DEnccw3kXilS9xAhWkfwrEvf\n" + + "V5I74GydewFl32o+NZ8hdo9GL1I8zO1rIq/et8dSOWGuWf9BtKu/vTGTTQKBgFxU\n" + + "VjsCnbvmsEwPUAL2hE/WrBFaKocnxXx5AFNt8lEyHtDwy4Sg1nygGcIJ4sD6koQk\n" + + "RdClT3LkvR04TAiSY80bN/i6ZcPNGUwSaDGZEWAIOSWbkwZijZNFnSGOEgxZX/IG\n" + + "yd39766vREEMTwEeiMNEOZQ/dmxkJm4OOVe25cLdAoGACOtPnq1Fxay80UYBf4rQ\n" + + "+bJ9yX1ulB8WIree1hD7OHSB2lRHxrVYWrglrTvkh63Lgx+EcsTV788OsvAVfPPz\n" + + "BZrn8SdDlQqalMxUBYEFwnsYD3cQ8yOUnijFVC4xNcdDv8OIqVgSk4KKxU5AshaA\n" + + "xk6Mox+u8Cc2eAK12H13i+8=\n" + "-----END PRIVATE KEY-----\n"); } private static X509Certificate spCertificate() { - - return certificate("-----BEGIN CERTIFICATE-----\n" + - "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + - "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + - "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + - "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + - "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + - "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + - "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + - "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + - "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + - "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + - "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + - "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + - "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + - "-----END CERTIFICATE-----"); + return certificate( + "-----BEGIN CERTIFICATE-----\n" + "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + + "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + + "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + + "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + + "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + + "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + + "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + + "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + + "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + + "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + + "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + + "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----"); } private static PrivateKey spPrivateKey() { - return privateKey("-----BEGIN PRIVATE KEY-----\n" + - "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + - "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + - "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + - "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + - "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + - "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + - "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + - "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + - "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + - "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + - "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + - "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + - "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + - "INrtuLp4YHbgk1mi\n" + - "-----END PRIVATE KEY-----"); + return privateKey( + "-----BEGIN PRIVATE KEY-----\n" + "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + + "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + + "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + + "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + + "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + + "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + + "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + + "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + + "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + + "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + + "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + + "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + + "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + "INrtuLp4YHbgk1mi\n" + + "-----END PRIVATE KEY-----"); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java index 292619040d..dd9d9ba715 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java @@ -16,23 +16,19 @@ package org.springframework.security.saml2.credentials; -import org.springframework.security.converter.RsaKeyConverters; +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.security.PrivateKey; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import java.io.ByteArrayInputStream; -import java.security.PrivateKey; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION; +import org.springframework.security.converter.RsaKeyConverters; +import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType; public class Saml2X509CredentialTests { @@ -40,110 +36,111 @@ public class Saml2X509CredentialTests { public ExpectedException exception = ExpectedException.none(); private Saml2X509Credential credential; + private PrivateKey key; + private X509Certificate certificate; @Before public void setup() throws Exception { - String keyData = "-----BEGIN PRIVATE KEY-----\n" + - "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + - "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + - "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + - "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + - "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + - "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + - "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + - "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + - "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + - "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + - "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + - "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + - "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + - "INrtuLp4YHbgk1mi\n" + - "-----END PRIVATE KEY-----"; - key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(UTF_8))); + String keyData = "-----BEGIN PRIVATE KEY-----\n" + + "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + + "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + + "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + + "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + + "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + + "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + + "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + + "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + + "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + + "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + + "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + + "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + + "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + "INrtuLp4YHbgk1mi\n" + + "-----END PRIVATE KEY-----"; + this.key = RsaKeyConverters.pkcs8().convert(new ByteArrayInputStream(keyData.getBytes(StandardCharsets.UTF_8))); final CertificateFactory factory = CertificateFactory.getInstance("X.509"); - String certificateData = "-----BEGIN CERTIFICATE-----\n" + - "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + - "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + - "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + - "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + - "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + - "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + - "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + - "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + - "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + - "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + - "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + - "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + - "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + - "-----END CERTIFICATE-----"; - certificate = (X509Certificate) factory - .generateCertificate(new ByteArrayInputStream(certificateData.getBytes(UTF_8))); + String certificateData = "-----BEGIN CERTIFICATE-----\n" + + "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + + "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + + "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + + "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + + "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + + "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + + "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + + "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + + "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + + "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + + "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + + "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----"; + this.certificate = (X509Certificate) factory + .generateCertificate(new ByteArrayInputStream(certificateData.getBytes(StandardCharsets.UTF_8))); } @Test public void constructorWhenRelyingPartyWithCredentialsThenItSucceeds() { - new Saml2X509Credential(key, certificate, SIGNING); - new Saml2X509Credential(key, certificate, SIGNING, DECRYPTION); - new Saml2X509Credential(key, certificate, DECRYPTION); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.SIGNING, + Saml2X509CredentialType.DECRYPTION); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.DECRYPTION); } @Test public void constructorWhenAssertingPartyWithCredentialsThenItSucceeds() { - new Saml2X509Credential(certificate, VERIFICATION); - new Saml2X509Credential(certificate, VERIFICATION, ENCRYPTION); - new Saml2X509Credential(certificate, ENCRYPTION); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.VERIFICATION, + Saml2X509CredentialType.ENCRYPTION); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.ENCRYPTION); } @Test public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, (X509Certificate) null, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, certificate, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenRelyingPartyWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(key, null, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenAssertingPartyWithoutCertificateThenItFails() { - exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, SIGNING); + this.exception.expect(IllegalArgumentException.class); + new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(key, certificate, ENCRYPTION); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION); } @Test public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(key, certificate, VERIFICATION); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION); } @Test public void constructorWhenAssertingPartyWithSigningUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(certificate, SIGNING); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING); } @Test public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() { - exception.expect(IllegalStateException.class); - new Saml2X509Credential(certificate, DECRYPTION); + this.exception.expect(IllegalStateException.class); + new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION); } - } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/TestSaml2X509Credentials.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/TestSaml2X509Credentials.java index 001e864f89..87c8c2a57d 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/TestSaml2X509Credentials.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/TestSaml2X509Credentials.java @@ -17,6 +17,7 @@ package org.springframework.security.saml2.credentials; import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.security.KeyException; import java.security.PrivateKey; import java.security.cert.CertificateException; @@ -26,154 +27,147 @@ import java.security.cert.X509Certificate; import org.opensaml.security.crypto.KeySupport; import org.springframework.security.saml2.Saml2Exception; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING; -import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION; +import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType; public final class TestSaml2X509Credentials { + + private TestSaml2X509Credentials() { + } + public static Saml2X509Credential assertingPartySigningCredential() { - return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING); + return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING); } public static Saml2X509Credential assertingPartyEncryptingCredential() { - return new Saml2X509Credential(spCertificate(), ENCRYPTION); + return new Saml2X509Credential(spCertificate(), Saml2X509CredentialType.ENCRYPTION); } public static Saml2X509Credential assertingPartyPrivateCredential() { - return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), SIGNING, DECRYPTION); + return new Saml2X509Credential(idpPrivateKey(), idpCertificate(), Saml2X509CredentialType.SIGNING, + Saml2X509CredentialType.DECRYPTION); } public static Saml2X509Credential relyingPartyVerifyingCredential() { - return new Saml2X509Credential(idpCertificate(), VERIFICATION); + return new Saml2X509Credential(idpCertificate(), Saml2X509CredentialType.VERIFICATION); } public static Saml2X509Credential relyingPartySigningCredential() { - return new Saml2X509Credential(spPrivateKey(), spCertificate(), SIGNING); + return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.SIGNING); } public static Saml2X509Credential relyingPartyDecryptingCredential() { - return new Saml2X509Credential(spPrivateKey(), spCertificate(), DECRYPTION); + return new Saml2X509Credential(spPrivateKey(), spCertificate(), Saml2X509CredentialType.DECRYPTION); } private static X509Certificate certificate(String cert) { ByteArrayInputStream certBytes = new ByteArrayInputStream(cert.getBytes()); try { - return (X509Certificate) CertificateFactory - .getInstance("X.509") - .generateCertificate(certBytes); + return (X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate(certBytes); } - catch (CertificateException e) { - throw new Saml2Exception(e); + catch (CertificateException ex) { + throw new Saml2Exception(ex); } } private static PrivateKey privateKey(String key) { try { - return KeySupport.decodePrivateKey(key.getBytes(UTF_8), new char[0]); + return KeySupport.decodePrivateKey(key.getBytes(StandardCharsets.UTF_8), new char[0]); } - catch (KeyException e) { - throw new Saml2Exception(e); + catch (KeyException ex) { + throw new Saml2Exception(ex); } } - private static X509Certificate idpCertificate() { - return certificate("-----BEGIN CERTIFICATE-----\n" - + "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" - + "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" - + "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" - + "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" - + "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" - + "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" - + "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" - + "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" - + "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" - + "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" - + "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" - + "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" - + "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" - + "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" - + "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" - + "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" - + "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" - + "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" - + "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" - + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" - + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" - + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" - + "-----END CERTIFICATE-----\n"); + private static X509Certificate idpCertificate() { + return certificate( + "-----BEGIN CERTIFICATE-----\n" + "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" + + "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" + + "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" + + "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" + + "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" + + "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" + + "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" + + "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" + + "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" + + "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" + + "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" + + "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" + + "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" + + "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" + + "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" + + "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" + + "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" + + "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" + + "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" + + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" + + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" + + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + "-----END CERTIFICATE-----\n"); } - private static PrivateKey idpPrivateKey() { - return privateKey("-----BEGIN PRIVATE KEY-----\n" - + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4cn62E1xLqpN3\n" - + "4PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz2ZivLwZX\n" - + "W+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWWRDodcoHE\n" - + "fDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQnX8Ttl7h\n" - + "Z6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5cljz0X/T\n" - + "Xy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gphiJH3jvZ7\n" - + "I+J5lS8VAgMBAAECggEBAKyxBlIS7mcp3chvq0RF7B3PHFJMMzkwE+t3pLJcs4cZ\n" - + "nezh/KbREfP70QjXzk/llnZCvxeIs5vRu24vbdBm79qLHqBuHp8XfHHtuo2AfoAQ\n" - + "l4h047Xc/+TKMivnPQ0jX9qqndKDLqZDf5wnbslDmlskvF0a/MjsLU0TxtOfo+dB\n" - + "t55FW11cGqxZwhS5Gnr+cbw3OkHz23b9gEOt9qfwPVepeysbmm9FjU+k4yVa7rAN\n" - + "xcbzVb6Y7GCITe2tgvvEHmjB9BLmWrH3mZ3Af17YU/iN6TrpPd6Sj3QoS+2wGtAe\n" - + "HbUs3CKJu7bIHcj4poal6Kh8519S+erJTtqQ8M0ZiEECgYEA43hLYAPaUueFkdfh\n" - + "9K/7ClH6436CUH3VdizwUXi26fdhhV/I/ot6zLfU2mgEHU22LBECWQGtAFm8kv0P\n" - + "zPn+qjaR3e62l5PIlSYbnkIidzoDZ2ztu4jF5LgStlTJQPteFEGgZVl5o9DaSZOq\n" - + "Yd7G3XqXuQ1VGMW58G5FYJPtA1cCgYEAz5TPUtK+R2KXHMjUwlGY9AefQYRYmyX2\n" - + "Tn/OFgKvY8lpAkMrhPKONq7SMYc8E9v9G7A0dIOXvW7QOYSapNhKU+np3lUafR5F\n" - + "4ZN0bxZ9qjHbn3AMYeraKjeutHvlLtbHdIc1j3sxe/EzltRsYmiqLdEBW0p6hwWg\n" - + "tyGhYWVyaXMCgYAfDOKtHpmEy5nOCLwNXKBWDk7DExfSyPqEgSnk1SeS1HP5ctPK\n" - + "+1st6sIhdiVpopwFc+TwJWxqKdW18tlfT5jVv1E2DEnccw3kXilS9xAhWkfwrEvf\n" - + "V5I74GydewFl32o+NZ8hdo9GL1I8zO1rIq/et8dSOWGuWf9BtKu/vTGTTQKBgFxU\n" - + "VjsCnbvmsEwPUAL2hE/WrBFaKocnxXx5AFNt8lEyHtDwy4Sg1nygGcIJ4sD6koQk\n" - + "RdClT3LkvR04TAiSY80bN/i6ZcPNGUwSaDGZEWAIOSWbkwZijZNFnSGOEgxZX/IG\n" - + "yd39766vREEMTwEeiMNEOZQ/dmxkJm4OOVe25cLdAoGACOtPnq1Fxay80UYBf4rQ\n" - + "+bJ9yX1ulB8WIree1hD7OHSB2lRHxrVYWrglrTvkh63Lgx+EcsTV788OsvAVfPPz\n" - + "BZrn8SdDlQqalMxUBYEFwnsYD3cQ8yOUnijFVC4xNcdDv8OIqVgSk4KKxU5AshaA\n" + "xk6Mox+u8Cc2eAK12H13i+8=\n" - + "-----END PRIVATE KEY-----\n"); + return privateKey( + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4cn62E1xLqpN3\n" + + "4PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz2ZivLwZX\n" + + "W+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWWRDodcoHE\n" + + "fDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQnX8Ttl7h\n" + + "Z6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5cljz0X/T\n" + + "Xy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gphiJH3jvZ7\n" + + "I+J5lS8VAgMBAAECggEBAKyxBlIS7mcp3chvq0RF7B3PHFJMMzkwE+t3pLJcs4cZ\n" + + "nezh/KbREfP70QjXzk/llnZCvxeIs5vRu24vbdBm79qLHqBuHp8XfHHtuo2AfoAQ\n" + + "l4h047Xc/+TKMivnPQ0jX9qqndKDLqZDf5wnbslDmlskvF0a/MjsLU0TxtOfo+dB\n" + + "t55FW11cGqxZwhS5Gnr+cbw3OkHz23b9gEOt9qfwPVepeysbmm9FjU+k4yVa7rAN\n" + + "xcbzVb6Y7GCITe2tgvvEHmjB9BLmWrH3mZ3Af17YU/iN6TrpPd6Sj3QoS+2wGtAe\n" + + "HbUs3CKJu7bIHcj4poal6Kh8519S+erJTtqQ8M0ZiEECgYEA43hLYAPaUueFkdfh\n" + + "9K/7ClH6436CUH3VdizwUXi26fdhhV/I/ot6zLfU2mgEHU22LBECWQGtAFm8kv0P\n" + + "zPn+qjaR3e62l5PIlSYbnkIidzoDZ2ztu4jF5LgStlTJQPteFEGgZVl5o9DaSZOq\n" + + "Yd7G3XqXuQ1VGMW58G5FYJPtA1cCgYEAz5TPUtK+R2KXHMjUwlGY9AefQYRYmyX2\n" + + "Tn/OFgKvY8lpAkMrhPKONq7SMYc8E9v9G7A0dIOXvW7QOYSapNhKU+np3lUafR5F\n" + + "4ZN0bxZ9qjHbn3AMYeraKjeutHvlLtbHdIc1j3sxe/EzltRsYmiqLdEBW0p6hwWg\n" + + "tyGhYWVyaXMCgYAfDOKtHpmEy5nOCLwNXKBWDk7DExfSyPqEgSnk1SeS1HP5ctPK\n" + + "+1st6sIhdiVpopwFc+TwJWxqKdW18tlfT5jVv1E2DEnccw3kXilS9xAhWkfwrEvf\n" + + "V5I74GydewFl32o+NZ8hdo9GL1I8zO1rIq/et8dSOWGuWf9BtKu/vTGTTQKBgFxU\n" + + "VjsCnbvmsEwPUAL2hE/WrBFaKocnxXx5AFNt8lEyHtDwy4Sg1nygGcIJ4sD6koQk\n" + + "RdClT3LkvR04TAiSY80bN/i6ZcPNGUwSaDGZEWAIOSWbkwZijZNFnSGOEgxZX/IG\n" + + "yd39766vREEMTwEeiMNEOZQ/dmxkJm4OOVe25cLdAoGACOtPnq1Fxay80UYBf4rQ\n" + + "+bJ9yX1ulB8WIree1hD7OHSB2lRHxrVYWrglrTvkh63Lgx+EcsTV788OsvAVfPPz\n" + + "BZrn8SdDlQqalMxUBYEFwnsYD3cQ8yOUnijFVC4xNcdDv8OIqVgSk4KKxU5AshaA\n" + + "xk6Mox+u8Cc2eAK12H13i+8=\n" + "-----END PRIVATE KEY-----\n"); } private static X509Certificate spCertificate() { - - return certificate("-----BEGIN CERTIFICATE-----\n" + - "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + - "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + - "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + - "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + - "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + - "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + - "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + - "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + - "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + - "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + - "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + - "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + - "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + - "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + - "-----END CERTIFICATE-----"); + return certificate( + "-----BEGIN CERTIFICATE-----\n" + "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + + "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + + "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + + "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + + "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + + "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + + "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + + "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + + "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + + "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + + "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + + "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + "-----END CERTIFICATE-----"); } private static PrivateKey spPrivateKey() { - return privateKey("-----BEGIN PRIVATE KEY-----\n" + - "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + - "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + - "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + - "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + - "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + - "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + - "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + - "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + - "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + - "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + - "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + - "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + - "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + - "INrtuLp4YHbgk1mi\n" + - "-----END PRIVATE KEY-----"); + return privateKey( + "-----BEGIN PRIVATE KEY-----\n" + "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + + "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + + "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + + "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + + "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + + "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + + "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + + "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + + "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + + "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + + "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + + "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + + "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + "INrtuLp4YHbgk1mi\n" + + "-----END PRIVATE KEY-----"); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java index 0352be6741..c91651ef24 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/DefaultSaml2AuthenticatedPrincipalTests.java @@ -16,17 +16,17 @@ package org.springframework.security.saml2.provider.service.authentication; -import org.joda.time.DateTime; -import org.junit.Test; - import java.time.Instant; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.joda.time.DateTime; +import org.junit.Test; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; public class DefaultSaml2AuthenticatedPrincipalTests { @@ -43,16 +43,14 @@ public class DefaultSaml2AuthenticatedPrincipalTests { public void createDefaultSaml2AuthenticatedPrincipalWhenNameNullThenException() { Map> attributes = new LinkedHashMap<>(); attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); - assertThatCode(() -> new DefaultSaml2AuthenticatedPrincipal(null, attributes)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("name cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> new DefaultSaml2AuthenticatedPrincipal(null, attributes)) + .withMessageContaining("name cannot be null"); } @Test public void createDefaultSaml2AuthenticatedPrincipalWhenAttributesNullThenException() { - assertThatCode(() -> new DefaultSaml2AuthenticatedPrincipal("user", null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("attributes cannot be null"); + assertThatIllegalArgumentException().isThrownBy(() -> new DefaultSaml2AuthenticatedPrincipal("user", null)) + .withMessageContaining("attributes cannot be null"); } @Test @@ -75,16 +73,13 @@ public class DefaultSaml2AuthenticatedPrincipalTests { public void getAttributeWhenDistinctValuesThenReturnsValues() { final Boolean registered = true; final Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis()); - Map> attributes = new LinkedHashMap<>(); attributes.put("registration", Arrays.asList(registered, registeredDate)); - DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user", attributes); - List registrationInfo = principal.getAttribute("registration"); - assertThat(registrationInfo).isNotNull(); assertThat((Boolean) registrationInfo.get(0)).isEqualTo(registered); assertThat((Instant) registrationInfo.get(1)).isEqualTo(registeredDate); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index ecf2f785f7..a1bac457fd 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; + import javax.xml.namespace.QName; import net.shibboleth.utilities.java.support.xml.SerializeSupport; @@ -38,9 +39,11 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.Marshaller; import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.common.assertion.ValidationContext; +import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.EncryptedAssertion; @@ -56,36 +59,17 @@ import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.Saml2Error; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.credentials.Saml2X509Credential; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; +import org.springframework.util.StringUtils; -import static java.util.Collections.singleton; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION; -import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE; -import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; -import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator; -import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultResponseAuthenticationConverter; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signedResponseWithOneAssertion; -import static org.springframework.util.StringUtils.hasText; /** * Tests for {@link OpenSamlAuthenticationProvider} @@ -96,23 +80,27 @@ import static org.springframework.util.StringUtils.hasText; public class OpenSamlAuthenticationProviderTests { private static String DESTINATION = "https://localhost/login/saml2/sso/idp-alias"; + private static String RELYING_PARTY_ENTITY_ID = "https://localhost/saml2/service-provider-metadata/idp-alias"; + private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal - ("name", Collections.emptyMap()); - private Saml2Authentication authentication = new Saml2Authentication - (this.principal, "response", Collections.emptyList()); + + private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("name", + Collections.emptyMap()); + + private Saml2Authentication authentication = new Saml2Authentication(this.principal, "response", + Collections.emptyList()); @Rule public ExpectedException exception = ExpectedException.none(); @Test public void supportsWhenSaml2AuthenticationTokenThenReturnTrue() { - assertThat(this.provider.supports(Saml2AuthenticationToken.class)) - .withFailMessage(OpenSamlAuthenticationProvider.class + "should support " + Saml2AuthenticationToken.class) + .withFailMessage( + OpenSamlAuthenticationProvider.class + "should support " + Saml2AuthenticationToken.class) .isTrue(); } @@ -126,123 +114,114 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); - - Assertion assertion = (Assertion) getBuilderFactory().getBuilder(Assertion.DEFAULT_ELEMENT_NAME) - .buildObject(Assertion.DEFAULT_ELEMENT_NAME); - this.provider.authenticate(token(serialize(assertion), relyingPartyVerifyingCredential())); + Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory() + .getBuilder(Assertion.DEFAULT_ELEMENT_NAME).buildObject(Assertion.DEFAULT_ELEMENT_NAME); + this.provider + .authenticate(token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential())); } @Test public void authenticateWhenXmlErrorThenThrowAuthenticationException() { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); - - Saml2AuthenticationToken token = token("invalid xml", relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token("invalid xml", + TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION)); - - Response response = response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion()); - signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Response response = TestOpenSamlObjects.response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID); + response.getAssertions().add(TestOpenSamlObjects.assertion()); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() { this.exception.expect( - authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.") - ); - - Saml2AuthenticationToken token = token(response(), assertingPartySigningCredential()); + authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.")); + Saml2AuthenticationToken token = token(TestOpenSamlObjects.response(), + TestSaml2X509Credentials.assertingPartySigningCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE)); - - Response response = response(); - response.getAssertions().add(assertion()); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Response response = TestOpenSamlObjects.response(); + response.getAssertions().add(TestOpenSamlObjects.assertion()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() throws Exception { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_ASSERTION)); - - Response response = response(); - Assertion assertion = assertion(); - assertion - .getSubject() - .getSubjectConfirmations() - .get(0) - .getSubjectConfirmationData() + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData() .setNotOnOrAfter(DateTime.now().minus(Duration.standardDays(3))); - signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test - public void authenticateWhenMissingSubjectThenThrowAuthenticationException() { + public void authenticateWhenMissingSubjectThenThrowAuthenticationException() { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); - - Response response = response(); - Assertion assertion = assertion(); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); assertion.setSubject(null); - signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); - - Response response = response(); - Assertion assertion = assertion(); - assertion - .getSubject() - .getNameID() - .setValue(null); - signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + assertion.getSubject().getNameID().setValue(null); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenAssertionContainsValidationAddressThenItSucceeds() throws Exception { - Response response = response(); - Assertion assertion = assertion(); - assertion.getSubject().getSubjectConfirmations().forEach( - sc -> sc.getSubjectConfirmationData().setAddress("10.10.10.10") - ); - signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + assertion.getSubject().getSubjectConfirmations() + .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { - Response response = response(); - Assertion assertion = assertion(); - List attributes = attributeStatements(); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + List attributes = TestOpenSamlObjects.attributeStatements(); assertion.getAttributeStatements().addAll(attributes); - signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); Authentication authentication = this.provider.authenticate(token); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); - Map expected = new LinkedHashMap<>(); expected.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com")); expected.put("name", Collections.singletonList("John Doe")); @@ -251,7 +230,6 @@ public class OpenSamlAuthenticationProviderTests { expected.put("registered", Collections.singletonList(true)); Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis()); expected.put("registeredDate", Collections.singletonList(registeredDate)); - assertThat((String) principal.getFirstAttribute("name")).isEqualTo("John Doe"); assertThat(principal.getAttributes()).isEqualTo(expected); } @@ -259,84 +237,94 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception { this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE)); - - Response response = response(); - EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); + Response response = TestOpenSamlObjects.response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(response, relyingPartyDecryptingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenEncryptedAssertionWithSignatureThenItSucceeds() throws Exception { - Response response = response(); - Assertion assertion = signed(assertion(), assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); - EncryptedAssertion encryptedAssertion = encrypted(assertion, assertingPartyEncryptingCredential()); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion(), + TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(), relyingPartyDecryptingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), + TestSaml2X509Credentials.relyingPartyDecryptingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenEncryptedAssertionWithResponseSignatureThenItSucceeds() throws Exception { - Response response = response(); - EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); + Response response = TestOpenSamlObjects.response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(), relyingPartyDecryptingCredential()); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), + TestSaml2X509Credentials.relyingPartyDecryptingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenEncryptedNameIdWithSignatureThenItSucceeds() throws Exception { - Response response = response(); - Assertion assertion = assertion(); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); NameID nameId = assertion.getSubject().getNameID(); - EncryptedID encryptedID = encrypted(nameId, assertingPartyEncryptingCredential()); + EncryptedID encryptedID = TestOpenSamlObjects.encrypted(nameId, + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); assertion.getSubject().setNameID(null); assertion.getSubject().setEncryptedID(encryptedID); response.getAssertions().add(assertion); - signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(), relyingPartyDecryptingCredential()); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), + TestSaml2X509Credentials.relyingPartyDecryptingCredential()); this.provider.authenticate(token); } - @Test public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() throws Exception { - this.exception.expect( - authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData") - ); - - Response response = response(); - EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); + this.exception + .expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); + Response response = TestOpenSamlObjects.response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(serialize(response), relyingPartyVerifyingCredential()); + Saml2AuthenticationToken token = token(serialize(response), + TestSaml2X509Credentials.relyingPartyVerifyingCredential()); this.provider.authenticate(token); } @Test public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationException() throws Exception { - this.exception.expect( - authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData") - ); - - Response response = response(); - EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential()); + this.exception + .expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); + Response response = TestOpenSamlObjects.response(); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(serialize(response), assertingPartyPrivateCredential()); + Saml2AuthenticationToken token = token(serialize(response), + TestSaml2X509Credentials.assertingPartyPrivateCredential()); this.provider.authenticate(token); } @Test public void writeObjectWhenTypeIsSaml2AuthenticationThenNoException() throws IOException { - Response response = response(); - Assertion assertion = signed(assertion(), assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); - EncryptedAssertion encryptedAssertion = encrypted(assertion, assertingPartyEncryptingCredential()); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion(), + TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, + TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential(), relyingPartyDecryptingCredential()); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(), + TestSaml2X509Credentials.relyingPartyDecryptingCredential()); Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); - // the following code will throw an exception if authentication isn't serializable ByteArrayOutputStream byteStream = new ByteArrayOutputStream(1024); ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteStream); @@ -346,48 +334,58 @@ public class OpenSamlAuthenticationProviderTests { @Test public void createDefaultAssertionValidatorWhenAssertionThenValidates() { - Response response = signedResponseWithOneAssertion(); + Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); Assertion assertion = response.getAssertions().get(0); - OpenSamlAuthenticationProvider.AssertionToken assertionToken = - new OpenSamlAuthenticationProvider.AssertionToken(assertion, token()); - assertThat( - createDefaultAssertionValidator().convert(assertionToken) - .hasErrors()).isFalse(); + OpenSamlAuthenticationProvider.AssertionToken assertionToken = new OpenSamlAuthenticationProvider.AssertionToken( + assertion, token()); + assertThat(OpenSamlAuthenticationProvider.createDefaultAssertionValidator().convert(assertionToken).hasErrors()) + .isFalse(); } @Test public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() { OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> - createDefaultAssertionValidator(token -> new ValidationContext()).convert(assertionToken) - .concat(new Saml2Error("wrong error", "wrong error"))); - Response response = response(); - Assertion assertion = assertion(); + // @formatter:off + provider.setAssertionValidator((assertionToken) -> OpenSamlAuthenticationProvider + .createDefaultAssertionValidator((token) -> new ValidationContext()) + .convert(assertionToken) + .concat(new Saml2Error("wrong error", "wrong error")) + ); + // @formatter:on + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); assertion.getConditions().getConditions().add(oneTimeUse); response.getAssertions().add(assertion); - signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - assertThatThrownBy(() -> provider.authenticate(token)) - .isInstanceOf(Saml2AuthenticationException.class) - .hasFieldOrPropertyWithValue("error.errorCode", INVALID_ASSERTION); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + // @formatter:off + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> provider.authenticate(token)).isInstanceOf(Saml2AuthenticationException.class) + .satisfies((error) -> assertThat(error.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_ASSERTION)); + // @formatter:on } @Test public void authenticateWhenCustomAssertionValidatorThenUses() { - Converter validator = - mock(Converter.class); + Converter validator = mock( + Converter.class); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> - createDefaultAssertionValidator().convert(assertionToken) - .concat(validator.convert(assertionToken))); - Response response = response(); - Assertion assertion = assertion(); + // @formatter:off + provider.setAssertionValidator((assertionToken) -> OpenSamlAuthenticationProvider.createDefaultAssertionValidator() + .convert(assertionToken) + .concat(validator.convert(assertionToken)) + ); + // @formatter:on + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); response.getAssertions().add(assertion); - signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class))) - .thenReturn(success()); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + given(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class))) + .willReturn(Saml2ResponseValidatorResult.success()); provider.authenticate(token); verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)); } @@ -395,83 +393,97 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() { OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(assertionToken -> success()); - Response response = response(); - Assertion assertion = assertion(); - signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature + provider.setAssertionValidator((assertionToken) -> Saml2ResponseValidatorResult.success()); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); + TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.relyingPartyDecryptingCredential(), + RELYING_PARTY_ENTITY_ID); // broken + // signature response.getAssertions().add(assertion); - signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - assertThatThrownBy(() -> provider.authenticate(token)) - .isInstanceOf(Saml2AuthenticationException.class) - .hasFieldOrPropertyWithValue("error.errorCode", INVALID_SIGNATURE); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + // @formatter:off + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> provider.authenticate(token)) + .satisfies((error) -> assertThat(error.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_SIGNATURE)); + // @formatter:on } @Test public void authenticateWhenValidationContextCustomizedThenUsers() { Map parameters = new HashMap<>(); - parameters.put(SC_VALID_RECIPIENTS, singleton("blah")); + parameters.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton("blah")); ValidationContext context = mock(ValidationContext.class); - when(context.getStaticParameters()).thenReturn(parameters); + given(context.getStaticParameters()).willReturn(parameters); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); - provider.setAssertionValidator(createDefaultAssertionValidator(assertionToken -> context)); - Response response = response(); - Assertion assertion = assertion(); + provider.setAssertionValidator( + OpenSamlAuthenticationProvider.createDefaultAssertionValidator((assertionToken) -> context)); + Response response = TestOpenSamlObjects.response(); + Assertion assertion = TestOpenSamlObjects.assertion(); response.getAssertions().add(assertion); - signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - assertThatThrownBy(() -> provider.authenticate(token)) - .isInstanceOf(Saml2AuthenticationException.class) - .hasMessageContaining("Invalid assertion"); + TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), + ASSERTING_PARTY_ENTITY_ID); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + // @formatter:off + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> provider.authenticate(token)).isInstanceOf(Saml2AuthenticationException.class) + .satisfies((error) -> assertThat(error).hasMessageContaining("Invalid assertion")); + // @formatter:on verify(context, atLeastOnce()).getStaticParameters(); } @Test public void setAssertionValidatorWhenNullThenIllegalArgument() { - assertThatCode(() -> this.provider.setAssertionValidator(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.provider.setAssertionValidator(null)); + // @formatter:on } @Test public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() { - Response response = signedResponseWithOneAssertion(); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); - OpenSamlAuthenticationProvider.ResponseToken responseToken = - new OpenSamlAuthenticationProvider.ResponseToken(response, token); - Saml2Authentication authentication = createDefaultResponseAuthenticationConverter() - .convert(responseToken); + Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); + OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken( + response, token); + Saml2Authentication authentication = OpenSamlAuthenticationProvider + .createDefaultResponseAuthenticationConverter().convert(responseToken); assertThat(authentication.getName()).isEqualTo("test@saml.user"); } @Test public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() { - Converter authenticationConverter = - mock(Converter.class); + Converter authenticationConverter = mock( + Converter.class); OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider(); provider.setResponseAuthenticationConverter(authenticationConverter); - Response response = signedResponseWithOneAssertion(); - Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential()); + Response response = TestOpenSamlObjects.signedResponseWithOneAssertion(); + Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); provider.authenticate(token); verify(authenticationConverter).convert(any()); } @Test public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() { - assertThatCode(() -> this.provider.setResponseAuthenticationConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.provider.setResponseAuthenticationConverter(null)); + // @formatter:on } private T build(QName qName) { - return (T) getBuilderFactory().getBuilder(qName).buildObject(qName); + return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); } private String serialize(XMLObject object) { try { - Marshaller marshaller = getMarshallerFactory().getMarshaller(object); + Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object); Element element = marshaller.marshall(object); return SerializeSupport.nodeToString(element); - } catch (MarshallingException e) { - throw new Saml2Exception(e); + } + catch (MarshallingException ex) { + throw new Saml2Exception(ex); } } @@ -490,7 +502,7 @@ public class OpenSamlAuthenticationProviderTests { if (!code.equals(ex.getError().getErrorCode())) { return false; } - if (hasText(description)) { + if (StringUtils.hasText(description)) { if (!description.equals(ex.getError().getDescription())) { return false; } @@ -500,15 +512,14 @@ public class OpenSamlAuthenticationProviderTests { @Override public void describeTo(Description desc) { - String excepting = "Saml2AuthenticationException[code="+code+"; description="+description+"]"; + String excepting = "Saml2AuthenticationException[code=" + code + "; description=" + description + "]"; desc.appendText(excepting); - } }; } private Saml2AuthenticationToken token() { - return token(response(), relyingPartyVerifyingCredential()); + return token(TestOpenSamlObjects.response(), TestSaml2X509Credentials.relyingPartyVerifyingCredential()); } private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) { @@ -517,7 +528,8 @@ public class OpenSamlAuthenticationProviderTests { } private Saml2AuthenticationToken token(String payload, Saml2X509Credential... credentials) { - return new Saml2AuthenticationToken(payload, - DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, Arrays.asList(credentials)); + return new Saml2AuthenticationToken(payload, DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, + Arrays.asList(credentials)); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index bd4313b599..99fc66b954 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -17,12 +17,14 @@ package org.springframework.security.saml2.provider.service.authentication; import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; @@ -31,25 +33,16 @@ import org.w3c.dom.Element; import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getParserPool; -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getUnmarshallerFactory; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate; -import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest; -import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; /** * Tests for {@link OpenSamlAuthenticationRequestFactory} @@ -57,10 +50,13 @@ import static org.springframework.security.saml2.provider.service.registration.S public class OpenSamlAuthenticationRequestFactoryTests { private OpenSamlAuthenticationRequestFactory factory; + private Saml2AuthenticationRequestContext.Builder contextBuilder; + private Saml2AuthenticationRequestContext context; private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; + private RelyingPartyRegistration relyingPartyRegistration; private AuthnRequestUnmarshaller unmarshaller; @@ -72,120 +68,109 @@ public class OpenSamlAuthenticationRequestFactoryTests { public void setUp() { this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id") .assertionConsumerServiceLocation("template") - .providerDetails(c -> c.webSsoUrl("https://destination/sso")) - .providerDetails(c -> c.entityId("remote-entity-id")) - .localEntityIdTemplate("local-entity-id") - .credentials(c -> c.add(relyingPartySigningCredential())); + .providerDetails((c) -> c.webSsoUrl("https://destination/sso")) + .providerDetails((c) -> c.entityId("remote-entity-id")).localEntityIdTemplate("local-entity-id") + .credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential())); this.relyingPartyRegistration = this.relyingPartyRegistrationBuilder.build(); - contextBuilder = Saml2AuthenticationRequestContext.builder() - .issuer("https://issuer") - .relyingPartyRegistration(relyingPartyRegistration) + this.contextBuilder = Saml2AuthenticationRequestContext.builder().issuer("https://issuer") + .relyingPartyRegistration(this.relyingPartyRegistration) .assertionConsumerServiceUrl("https://issuer/sso"); - context = contextBuilder.build(); - factory = new OpenSamlAuthenticationRequestFactory(); - this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory() + this.context = this.contextBuilder.build(); + this.factory = new OpenSamlAuthenticationRequestFactory(); + this.unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory() .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); } @Test public void createAuthenticationRequestWhenInvokingDeprecatedMethodThenReturnsXML() { - Saml2AuthenticationRequest request = Saml2AuthenticationRequest.withAuthenticationRequestContext(context).build(); - String result = factory.createAuthenticationRequest(request); - assertThat(result.replace("\n", "")).startsWith(" c.signAuthNRequest(false)) - .build() - ) + RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration) + .providerDetails((c) -> c.signAuthNRequest(false)).build()) .build(); - Saml2RedirectAuthenticationRequest result = factory.createRedirectAuthenticationRequest(context); + Saml2RedirectAuthenticationRequest result = this.factory.createRedirectAuthenticationRequest(this.context); assertThat(result.getSamlRequest()).isNotEmpty(); assertThat(result.getRelayState()).isEqualTo("Relay State Value"); assertThat(result.getSigAlg()).isNull(); assertThat(result.getSignature()).isNull(); - assertThat(result.getBinding()).isEqualTo(REDIRECT); + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); } @Test public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() { - context = contextBuilder - .relayState("Relay State Value") + this.context = this.contextBuilder.relayState("Relay State Value") .relyingPartyRegistration( - withRelyingPartyRegistration(relyingPartyRegistration) - .providerDetails(c -> c.signAuthNRequest(false)) - .build() - ) + RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration) + .providerDetails((c) -> c.signAuthNRequest(false)).build()) .build(); - Saml2PostAuthenticationRequest result = factory.createPostAuthenticationRequest(context); + Saml2PostAuthenticationRequest result = this.factory.createPostAuthenticationRequest(this.context); assertThat(result.getSamlRequest()).isNotEmpty(); assertThat(result.getRelayState()).isEqualTo("Relay State Value"); - assertThat(result.getBinding()).isEqualTo(POST); - assertThat(new String(samlDecode(result.getSamlRequest()), UTF_8)) + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()), StandardCharsets.UTF_8)) .doesNotContain("ds:Signature"); } @Test public void createPostAuthenticationRequestWhenSignRequestThenSignatureIsPresent() { - context = contextBuilder - .relayState("Relay State Value") + this.context = this.contextBuilder.relayState("Relay State Value") .relyingPartyRegistration( - withRelyingPartyRegistration(relyingPartyRegistration) - .build() - ) + RelyingPartyRegistration.withRelyingPartyRegistration(this.relyingPartyRegistration).build()) .build(); - Saml2PostAuthenticationRequest result = factory.createPostAuthenticationRequest(context); + Saml2PostAuthenticationRequest result = this.factory.createPostAuthenticationRequest(this.context); assertThat(result.getSamlRequest()).isNotEmpty(); assertThat(result.getRelayState()).isEqualTo("Relay State Value"); - assertThat(result.getBinding()).isEqualTo(POST); - assertThat(new String(samlDecode(result.getSamlRequest()), UTF_8)) + assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()), StandardCharsets.UTF_8)) .contains("ds:Signature"); } @Test public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() { - AuthnRequest authn = getAuthNRequest(POST); + AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); Assert.assertEquals(SAMLConstants.SAML2_POST_BINDING_URI, authn.getProtocolBinding()); } @Test public void createAuthenticationRequestWhenSetUriThenReturnsCorrectBinding() { - factory.setProtocolBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); - AuthnRequest authn = getAuthNRequest(POST); + this.factory.setProtocolBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); + AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); Assert.assertEquals(SAMLConstants.SAML2_REDIRECT_BINDING_URI, authn.getProtocolBinding()); } @Test public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage(containsString("my-invalid-binding")); - factory.setProtocolBinding("my-invalid-binding"); + this.exception.expect(IllegalArgumentException.class); + this.exception.expectMessage(containsString("my-invalid-binding")); + this.factory.setProtocolBinding("my-invalid-binding"); } @Test public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() { - Converter authenticationRequestContextConverter = - mock(Converter.class); - when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest()); + Converter authenticationRequestContextConverter = mock( + Converter.class); + given(authenticationRequestContextConverter.convert(this.context)) + .willReturn(TestOpenSamlObjects.authnRequest()); this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter); this.factory.createPostAuthenticationRequest(this.context); @@ -194,9 +179,10 @@ public class OpenSamlAuthenticationRequestFactoryTests { @Test public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() { - Converter authenticationRequestContextConverter = - mock(Converter.class); - when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest()); + Converter authenticationRequestContextConverter = mock( + Converter.class); + given(authenticationRequestContextConverter.convert(this.context)) + .willReturn(TestOpenSamlObjects.authnRequest()); this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter); this.factory.createRedirectAuthenticationRequest(this.context); @@ -205,44 +191,45 @@ public class OpenSamlAuthenticationRequestFactoryTests { @Test public void setAuthenticationRequestContextConverterWhenNullThenException() { - assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.factory.setAuthenticationRequestContextConverter(null)); + // @formatter:on } @Test public void createPostAuthenticationRequestWhenAssertionConsumerServiceBindingThenUses() { RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder - .assertionConsumerServiceBinding(REDIRECT) - .build(); + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build(); Saml2AuthenticationRequestContext context = this.contextBuilder - .relyingPartyRegistration(relyingPartyRegistration) - .build(); + .relyingPartyRegistration(relyingPartyRegistration).build(); Saml2PostAuthenticationRequest request = this.factory.createPostAuthenticationRequest(context); String samlRequest = request.getSamlRequest(); - String inflated = new String(samlDecode(samlRequest)); + String inflated = new String(Saml2Utils.samlDecode(samlRequest)); assertThat(inflated).contains("ProtocolBinding=\"" + SAMLConstants.SAML2_REDIRECT_BINDING_URI + "\""); } private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) { - AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ? - factory.createRedirectAuthenticationRequest(context) : - factory.createPostAuthenticationRequest(context); + AbstractSaml2AuthenticationRequest result = (binding == Saml2MessageBinding.REDIRECT) + ? this.factory.createRedirectAuthenticationRequest(this.context) + : this.factory.createPostAuthenticationRequest(this.context); String samlRequest = result.getSamlRequest(); assertThat(samlRequest).isNotEmpty(); - if (result.getBinding() == REDIRECT) { - samlRequest = samlInflate(samlDecode(samlRequest)); + if (result.getBinding() == Saml2MessageBinding.REDIRECT) { + samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); } else { - samlRequest = new String(samlDecode(samlRequest), UTF_8); + samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); } try { - Document document = getParserPool().parse( - new ByteArrayInputStream(samlRequest.getBytes(UTF_8))); + Document document = XMLObjectProviderRegistrySupport.getParserPool() + .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); Element element = document.getDocumentElement(); return (AuthnRequest) this.unmarshaller.unmarshall(element); } - catch (Exception e) { - throw new Saml2Exception(e); + catch (Exception ex) { + throw new Saml2Exception(ex); } } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java index 563473a35d..c5b611e6d6 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java @@ -20,12 +20,10 @@ import java.util.UUID; import org.junit.Test; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; -import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; /** * Tests for {@link Saml2AuthenticationRequestFactory} default interface methods @@ -34,40 +32,35 @@ public class Saml2AuthenticationRequestFactoryTests { private RelyingPartyRegistration registration = RelyingPartyRegistration.withRegistrationId("id") .assertionConsumerServiceUrlTemplate("template") - .providerDetails(c -> c.webSsoUrl("https://example.com/destination")) - .providerDetails(c -> c.entityId("remote-entity-id")) - .localEntityIdTemplate("local-entity-id") - .credentials(c -> c.add(relyingPartySigningCredential())) - .build(); + .providerDetails((c) -> c.webSsoUrl("https://example.com/destination")) + .providerDetails((c) -> c.entityId("remote-entity-id")).localEntityIdTemplate("local-entity-id") + .credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential())).build(); @Test public void createAuthenticationRequestParametersWhenRedirectDefaultIsUsedMessageIsDeflatedAndEncoded() { - final String value = "Test String: "+ UUID.randomUUID().toString(); - Saml2AuthenticationRequestFactory factory = request -> value; + final String value = "Test String: " + UUID.randomUUID().toString(); + Saml2AuthenticationRequestFactory factory = (request) -> value; Saml2AuthenticationRequestContext request = Saml2AuthenticationRequestContext.builder() - .relyingPartyRegistration(registration) - .issuer("https://example.com/issuer") - .assertionConsumerServiceUrl("https://example.com/acs-url") - .build(); + .relyingPartyRegistration(this.registration).issuer("https://example.com/issuer") + .assertionConsumerServiceUrl("https://example.com/acs-url").build(); Saml2RedirectAuthenticationRequest response = factory.createRedirectAuthenticationRequest(request); String resultValue = response.getSamlRequest(); - byte[] decoded = samlDecode(resultValue); - String inflated = samlInflate(decoded); + byte[] decoded = Saml2Utils.samlDecode(resultValue); + String inflated = Saml2Utils.samlInflate(decoded); assertThat(inflated).isEqualTo(value); } @Test public void createAuthenticationRequestParametersWhenPostDefaultIsUsedMessageIsEncoded() { - final String value = "Test String: "+ UUID.randomUUID().toString(); - Saml2AuthenticationRequestFactory factory = request -> value; + final String value = "Test String: " + UUID.randomUUID().toString(); + Saml2AuthenticationRequestFactory factory = (request) -> value; Saml2AuthenticationRequestContext request = Saml2AuthenticationRequestContext.builder() - .relyingPartyRegistration(registration) - .issuer("https://example.com/issuer") - .assertionConsumerServiceUrl("https://example.com/acs-url") - .build(); + .relyingPartyRegistration(this.registration).issuer("https://example.com/issuer") + .assertionConsumerServiceUrl("https://example.com/acs-url").build(); Saml2PostAuthenticationRequest response = factory.createPostAuthenticationRequest(request); String resultValue = response.getSamlRequest(); - byte[] decoded = samlDecode(resultValue); + byte[] decoded = Saml2Utils.samlDecode(resultValue); assertThat(new String(decoded)).isEqualTo(value); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java index b237a64498..c607d12527 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; + import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; import javax.xml.namespace.QName; @@ -32,6 +33,7 @@ import org.apache.xml.security.encryption.XMLCipherParameters; import org.joda.time.DateTime; import org.joda.time.Duration; import org.opensaml.core.xml.XMLObject; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.core.xml.schema.XSAny; import org.opensaml.core.xml.schema.XSBoolean; @@ -49,6 +51,7 @@ import org.opensaml.core.xml.schema.impl.XSURIBuilder; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.common.SignableSAMLObject; import org.opensaml.saml.common.assertion.ValidationContext; +import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; @@ -82,22 +85,26 @@ import org.opensaml.xmlsec.signature.support.SignatureSupport; import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; - -import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory; -import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS; -import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; public final class TestOpenSamlObjects { + static { OpenSamlInitializationService.initialize(); } - private static String USERNAME = "test@saml.user"; + private static String DESTINATION = "https://localhost/login/saml2/sso/idp-alias"; + private static String RELYING_PARTY_ENTITY_ID = "https://localhost/saml2/service-provider-metadata/idp-alias"; + private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp"; - private static SecretKey SECRET_KEY = - new SecretKeySpec(Base64.getDecoder().decode("shOnwNMoCv88HKMEa91+FlYoD5RNvzMTAL5LGxZKIFk="), "AES"); + + private static SecretKey SECRET_KEY = new SecretKeySpec( + Base64.getDecoder().decode("shOnwNMoCv88HKMEa91+FlYoD5RNvzMTAL5LGxZKIFk="), "AES"); + + private TestOpenSamlObjects() { + } static Response response() { return response(DESTINATION, ASSERTING_PARTY_ENTITY_ID); @@ -105,7 +112,7 @@ public final class TestOpenSamlObjects { static Response response(String destination, String issuerEntityId) { Response response = build(Response.DEFAULT_ELEMENT_NAME); - response.setID("R"+UUID.randomUUID().toString()); + response.setID("R" + UUID.randomUUID().toString()); response.setIssueInstant(DateTime.now()); response.setVersion(SAMLVersion.VERSION_20); response.setID("_" + UUID.randomUUID().toString()); @@ -117,28 +124,22 @@ public final class TestOpenSamlObjects { static Response signedResponseWithOneAssertion() { Response response = response(); response.getAssertions().add(assertion()); - return signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); + return signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); } static Assertion assertion() { return assertion(USERNAME, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, DESTINATION); } - static Assertion assertion( - String username, - String issuerEntityId, - String recipientEntityId, - String recipientUri - ) { + static Assertion assertion(String username, String issuerEntityId, String recipientEntityId, String recipientUri) { Assertion assertion = build(Assertion.DEFAULT_ELEMENT_NAME); - assertion.setID("A"+ UUID.randomUUID().toString()); + assertion.setID("A" + UUID.randomUUID().toString()); assertion.setIssueInstant(DateTime.now()); assertion.setVersion(SAMLVersion.VERSION_20); assertion.setIssueInstant(DateTime.now()); assertion.setIssuer(issuer(issuerEntityId)); assertion.setSubject(subject(username)); assertion.setConditions(conditions()); - SubjectConfirmation subjectConfirmation = subjectConfirmation(); subjectConfirmation.setMethod(SubjectConfirmation.METHOD_BEARER); SubjectConfirmationData confirmationData = subjectConfirmationData(recipientEntityId); @@ -156,11 +157,9 @@ public final class TestOpenSamlObjects { static Subject subject(String principalName) { Subject subject = build(Subject.DEFAULT_ELEMENT_NAME); - if (principalName != null) { subject.setNameID(nameId(principalName)); } - return subject; } @@ -206,7 +205,8 @@ public final class TestOpenSamlObjects { return cred; } - static Credential getSigningCredential(org.springframework.security.saml2.credentials.Saml2X509Credential credential, String entityId) { + static Credential getSigningCredential( + org.springframework.security.saml2.credentials.Saml2X509Credential credential, String entityId) { BasicCredential cred = getBasicCredential(credential); cred.setEntityId(entityId); cred.setUsageType(UsageType.SIGNING); @@ -214,17 +214,12 @@ public final class TestOpenSamlObjects { } static BasicCredential getBasicCredential(Saml2X509Credential credential) { - return CredentialSupport.getSimpleCredential( - credential.getCertificate(), - credential.getPrivateKey() - ); + return CredentialSupport.getSimpleCredential(credential.getCertificate(), credential.getPrivateKey()); } - static BasicCredential getBasicCredential(org.springframework.security.saml2.credentials.Saml2X509Credential credential) { - return CredentialSupport.getSimpleCredential( - credential.getCertificate(), - credential.getPrivateKey() - ); + static BasicCredential getBasicCredential( + org.springframework.security.saml2.credentials.Saml2X509Credential credential) { + return CredentialSupport.getSimpleCredential(credential.getCertificate(), credential.getPrivateKey()); } static T signed(T signable, Saml2X509Credential credential, String entityId) { @@ -236,14 +231,15 @@ public final class TestOpenSamlObjects { parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); try { SignatureSupport.signObject(signable, parameters); - } catch (MarshallingException | SignatureException | SecurityException e) { - throw new Saml2Exception(e); } - + catch (MarshallingException | SignatureException | SecurityException ex) { + throw new Saml2Exception(ex); + } return signable; } - static T signed(T signable, org.springframework.security.saml2.credentials.Saml2X509Credential credential, String entityId) { + static T signed(T signable, + org.springframework.security.saml2.credentials.Saml2X509Credential credential, String entityId) { SignatureSigningParameters parameters = new SignatureSigningParameters(); Credential signingCredential = getSigningCredential(credential, entityId); parameters.setSigningCredential(signingCredential); @@ -252,10 +248,10 @@ public final class TestOpenSamlObjects { parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); try { SignatureSupport.signObject(signable, parameters); - } catch (MarshallingException | SignatureException | SecurityException e) { - throw new Saml2Exception(e); } - + catch (MarshallingException | SignatureException | SecurityException ex) { + throw new Saml2Exception(ex); + } return signable; } @@ -265,19 +261,20 @@ public final class TestOpenSamlObjects { try { return encrypter.encrypt(assertion); } - catch (EncryptionException e) { - throw new Saml2Exception("Unable to encrypt assertion.", e); + catch (EncryptionException ex) { + throw new Saml2Exception("Unable to encrypt assertion.", ex); } } - static EncryptedAssertion encrypted(Assertion assertion, org.springframework.security.saml2.credentials.Saml2X509Credential credential) { + static EncryptedAssertion encrypted(Assertion assertion, + org.springframework.security.saml2.credentials.Saml2X509Credential credential) { X509Certificate certificate = credential.getCertificate(); Encrypter encrypter = getEncrypter(certificate); try { return encrypter.encrypt(assertion); } - catch (EncryptionException e) { - throw new Saml2Exception("Unable to encrypt assertion.", e); + catch (EncryptionException ex) { + throw new Saml2Exception("Unable to encrypt assertion.", ex); } } @@ -287,113 +284,100 @@ public final class TestOpenSamlObjects { try { return encrypter.encrypt(nameId); } - catch (EncryptionException e) { - throw new Saml2Exception("Unable to encrypt nameID.", e); + catch (EncryptionException ex) { + throw new Saml2Exception("Unable to encrypt nameID.", ex); } } - static EncryptedID encrypted(NameID nameId, org.springframework.security.saml2.credentials.Saml2X509Credential credential) { + static EncryptedID encrypted(NameID nameId, + org.springframework.security.saml2.credentials.Saml2X509Credential credential) { X509Certificate certificate = credential.getCertificate(); Encrypter encrypter = getEncrypter(certificate); try { return encrypter.encrypt(nameId); } - catch (EncryptionException e) { - throw new Saml2Exception("Unable to encrypt nameID.", e); + catch (EncryptionException ex) { + throw new Saml2Exception("Unable to encrypt nameID.", ex); } } private static Encrypter getEncrypter(X509Certificate certificate) { String dataAlgorithm = XMLCipherParameters.AES_256; String keyAlgorithm = XMLCipherParameters.RSA_1_5; - BasicCredential dataCredential = new BasicCredential(SECRET_KEY); DataEncryptionParameters dataEncryptionParameters = new DataEncryptionParameters(); dataEncryptionParameters.setEncryptionCredential(dataCredential); dataEncryptionParameters.setAlgorithm(dataAlgorithm); - Credential credential = CredentialSupport.getSimpleCredential(certificate, null); KeyEncryptionParameters keyEncryptionParameters = new KeyEncryptionParameters(); keyEncryptionParameters.setEncryptionCredential(credential); keyEncryptionParameters.setAlgorithm(keyAlgorithm); - Encrypter encrypter = new Encrypter(dataEncryptionParameters, keyEncryptionParameters); Encrypter.KeyPlacement keyPlacement = Encrypter.KeyPlacement.valueOf("PEER"); encrypter.setKeyPlacement(keyPlacement); - return encrypter; } static List attributeStatements() { List attributeStatements = new ArrayList<>(); - AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder(); AttributeBuilder attributeBuilder = new AttributeBuilder(); - AttributeStatement attrStmt1 = attributeStatementBuilder.buildObject(); - Attribute emailAttr = attributeBuilder.buildObject(); emailAttr.setName("email"); - XSAny email1 = new XSAnyBuilder() - .buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSAny.TYPE_NAME); // gh-8864 + XSAny email1 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSAny.TYPE_NAME); // gh-8864 email1.setTextContent("john.doe@example.com"); emailAttr.getAttributeValues().add(email1); XSAny email2 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME); email2.setTextContent("doe.john@example.com"); emailAttr.getAttributeValues().add(email2); attrStmt1.getAttributes().add(emailAttr); - Attribute nameAttr = attributeBuilder.buildObject(); nameAttr.setName("name"); XSString name = new XSStringBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSString.TYPE_NAME); name.setValue("John Doe"); nameAttr.getAttributeValues().add(name); attrStmt1.getAttributes().add(nameAttr); - Attribute ageAttr = attributeBuilder.buildObject(); ageAttr.setName("age"); XSInteger age = new XSIntegerBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSInteger.TYPE_NAME); age.setValue(21); ageAttr.getAttributeValues().add(age); attrStmt1.getAttributes().add(ageAttr); - attributeStatements.add(attrStmt1); - AttributeStatement attrStmt2 = attributeStatementBuilder.buildObject(); - Attribute websiteAttr = attributeBuilder.buildObject(); websiteAttr.setName("website"); XSURI uri = new XSURIBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSURI.TYPE_NAME); uri.setValue("https://johndoe.com/"); websiteAttr.getAttributeValues().add(uri); attrStmt2.getAttributes().add(websiteAttr); - Attribute registeredAttr = attributeBuilder.buildObject(); registeredAttr.setName("registered"); - XSBoolean registered = new XSBooleanBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSBoolean.TYPE_NAME); + XSBoolean registered = new XSBooleanBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, + XSBoolean.TYPE_NAME); registered.setValue(new XSBooleanValue(true, false)); registeredAttr.getAttributeValues().add(registered); attrStmt2.getAttributes().add(registeredAttr); - Attribute registeredDateAttr = attributeBuilder.buildObject(); registeredDateAttr.setName("registeredDate"); - XSDateTime registeredDate = new XSDateTimeBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSDateTime.TYPE_NAME); + XSDateTime registeredDate = new XSDateTimeBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, + XSDateTime.TYPE_NAME); registeredDate.setValue(DateTime.parse("1970-01-01T00:00:00Z")); registeredDateAttr.getAttributeValues().add(registeredDate); attrStmt2.getAttributes().add(registeredDateAttr); - attributeStatements.add(attrStmt2); - return attributeStatements; } static ValidationContext validationContext() { Map params = new HashMap<>(); - params.put(SC_VALID_RECIPIENTS, Collections.singleton(DESTINATION)); + params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(DESTINATION)); return new ValidationContext(params); } static T build(QName qName) { - return (T) getBuilderFactory().getBuilder(qName).buildObject(qName); + return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationRequestContexts.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationRequestContexts.java index 57a721146e..ad3aa83a47 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationRequestContexts.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationRequestContexts.java @@ -16,17 +16,20 @@ package org.springframework.security.saml2.provider.service.authentication; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; /** * Test {@link Saml2AuthenticationRequestContext}s */ -public class TestSaml2AuthenticationRequestContexts { +public final class TestSaml2AuthenticationRequestContexts { + + private TestSaml2AuthenticationRequestContexts() { + } + public static Saml2AuthenticationRequestContext.Builder authenticationRequestContext() { - return Saml2AuthenticationRequestContext.builder() - .relayState("relayState") - .issuer("issuer") - .relyingPartyRegistration(relyingPartyRegistration().build()) + return Saml2AuthenticationRequestContext.builder().relayState("relayState").issuer("issuer") + .relyingPartyRegistration(TestRelyingPartyRegistrations.relyingPartyRegistration().build()) .assertionConsumerServiceUrl("assertionConsumerServiceUrl"); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java index 161e6124a9..2613e452b3 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java @@ -18,13 +18,12 @@ package org.springframework.security.saml2.provider.service.metadata; import org.junit.Test; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.full; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials; /** * Tests for {@link OpenSamlMetadataResolver} @@ -33,21 +32,12 @@ public class OpenSamlMetadataResolverTests { @Test public void resolveWhenRelyingPartyThenMetadataMatches() { - // given - RelyingPartyRegistration relyingPartyRegistration = full() - .assertionConsumerServiceBinding(REDIRECT) - .build(); + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full() + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build(); OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver(); - - // when String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration); - - // then - assertThat(metadata) - .contains("") + assertThat(metadata).contains("") .contains("") .contains("MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh") .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"") @@ -56,25 +46,17 @@ public class OpenSamlMetadataResolverTests { @Test public void resolveWhenRelyingPartyNoCredentialsThenMetadataMatches() { - // given - RelyingPartyRegistration relyingPartyRegistration = noCredentials() - .assertingPartyDetails(party -> party - .verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())) - ) + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials( + (c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) .build(); OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver(); - - // when String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration); - - // then - assertThat(metadata) - .contains("") + assertThat(metadata).contains("") .doesNotContain("") .contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"") .contains("Location=\"https://rp.example.org/acs\" index=\"1\""); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests.java index e9a85b8a01..30ce7e3261 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests.java @@ -25,38 +25,32 @@ import java.util.Base64; import org.junit.Before; import org.junit.Test; +import org.springframework.http.HttpStatus; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.security.saml2.Saml2Exception; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.http.HttpStatus.OK; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests { - private static final String CERTIFICATE = - "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYDVQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYDVQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwXc2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0BwaXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAaBgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQDDBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlrQHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWWRDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQnX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gphiJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduOnRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+vZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLuxbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6zV9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk"; - private static final String ENTITY_DESCRIPTOR_TEMPLATE = - "\n" + - "\n%s" + - ""; - private static final String IDP_SSO_DESCRIPTOR_TEMPLATE = - "\n" + - "%s\n" + - ""; - private static final String KEY_DESCRIPTOR_TEMPLATE = - "\n" + - "\n" + - "\n" + - "" + CERTIFICATE + "\n" + - "\n" + - "\n" + - ""; - private static final String SINGLE_SIGN_ON_SERVICE_TEMPLATE = - ""; + private static final String CERTIFICATE = "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYDVQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYDVQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwXc2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0BwaXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAaBgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQDDBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlrQHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWWRDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQnX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gphiJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduOnRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+vZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLuxbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6zV9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk"; + + private static final String ENTITY_DESCRIPTOR_TEMPLATE = "\n" + + "\n%s" + + ""; + + private static final String IDP_SSO_DESCRIPTOR_TEMPLATE = "\n" + + "%s\n" + ""; + + private static final String KEY_DESCRIPTOR_TEMPLATE = "\n" + + "\n" + "\n" + + "" + CERTIFICATE + "\n" + "\n" + "\n" + + ""; + + private static final String SINGLE_SIGN_ON_SERVICE_TEMPLATE = ""; private OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverter converter; @@ -67,50 +61,45 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests { @Test public void readWhenMissingIDPSSODescriptorThenException() { - MockClientHttpResponse response = new MockClientHttpResponse - ((String.format(ENTITY_DESCRIPTOR_TEMPLATE, "")).getBytes(), OK); - assertThatCode(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response)) - .isInstanceOf(Saml2Exception.class) - .hasMessageContaining("Metadata response is missing the necessary IDPSSODescriptor element"); + MockClientHttpResponse response = new MockClientHttpResponse( + (String.format(ENTITY_DESCRIPTOR_TEMPLATE, "")).getBytes(), HttpStatus.OK); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response)) + .withMessageContaining("Metadata response is missing the necessary IDPSSODescriptor element"); } @Test public void readWhenMissingVerificationKeyThenException() { - String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, - String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, "")); - MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK); - assertThatCode(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response)) - .isInstanceOf(Saml2Exception.class) - .hasMessageContaining("Metadata response is missing verification certificates, necessary for verifying SAML assertions"); + String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, "")); + MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response)) + .withMessageContaining( + "Metadata response is missing verification certificates, necessary for verifying SAML assertions"); } @Test public void readWhenMissingSingleSignOnServiceThenException() { String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, - String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, - String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"signing\"") - )); - MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK); - assertThatCode(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response)) - .isInstanceOf(Saml2Exception.class) - .hasMessageContaining("Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests"); + String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"signing\""))); + MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> this.converter.read(RelyingPartyRegistration.Builder.class, response)) + .withMessageContaining( + "Metadata response is missing a SingleSignOnService, necessary for sending AuthnRequests"); } @Test public void readWhenDescriptorFullySpecifiedThenConfigures() throws Exception { String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, - String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"signing\"") + - String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"encryption\"") + - String.format(SINGLE_SIGN_ON_SERVICE_TEMPLATE) - )); - MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK); - RelyingPartyRegistration registration = - this.converter.read(RelyingPartyRegistration.Builder.class, response) - .registrationId("one") - .build(); - RelyingPartyRegistration.AssertingPartyDetails details = - registration.getAssertingPartyDetails(); + String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"signing\"") + + String.format(KEY_DESCRIPTOR_TEMPLATE, "use=\"encryption\"") + + String.format(SINGLE_SIGN_ON_SERVICE_TEMPLATE))); + MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK); + RelyingPartyRegistration registration = this.converter.read(RelyingPartyRegistration.Builder.class, response) + .registrationId("one").build(); + RelyingPartyRegistration.AssertingPartyDetails details = registration.getAssertingPartyDetails(); assertThat(details.getWantAuthnRequestsSigned()).isFalse(); assertThat(details.getSingleSignOnServiceLocation()).isEqualTo("sso-location"); assertThat(details.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); @@ -125,18 +114,12 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests { @Test public void readWhenKeyDescriptorHasNoUseThenConfiguresBothKeyTypes() throws Exception { - String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, - String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, - String.format(KEY_DESCRIPTOR_TEMPLATE, "") + - String.format(SINGLE_SIGN_ON_SERVICE_TEMPLATE) - )); - MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), OK); - RelyingPartyRegistration registration = - this.converter.read(RelyingPartyRegistration.Builder.class, response) - .registrationId("one") - .build(); - RelyingPartyRegistration.AssertingPartyDetails details = - registration.getAssertingPartyDetails(); + String payload = String.format(ENTITY_DESCRIPTOR_TEMPLATE, String.format(IDP_SSO_DESCRIPTOR_TEMPLATE, + String.format(KEY_DESCRIPTOR_TEMPLATE, "") + String.format(SINGLE_SIGN_ON_SERVICE_TEMPLATE))); + MockClientHttpResponse response = new MockClientHttpResponse(payload.getBytes(), HttpStatus.OK); + RelyingPartyRegistration registration = this.converter.read(RelyingPartyRegistration.Builder.class, response) + .registrationId("one").build(); + RelyingPartyRegistration.AssertingPartyDetails details = registration.getAssertingPartyDetails(); assertThat(details.getVerificationX509Credentials().iterator().next().getCertificate()) .isEqualTo(x509Certificate(CERTIFICATE)); assertThat(details.getEncryptionX509Credentials()).hasSize(1); @@ -147,10 +130,11 @@ public class OpenSamlRelyingPartyRegistrationBuilderHttpMessageConverterTests { X509Certificate x509Certificate(String data) { try { InputStream certificate = new ByteArrayInputStream(Base64.getDecoder().decode(data.getBytes())); - return (X509Certificate) CertificateFactory.getInstance("X.509") - .generateCertificate(certificate); - } catch (Exception e) { - throw new IllegalArgumentException(e); + return (X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate(certificate); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); } } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java index d8cd44acd5..8f6c444913 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java @@ -18,33 +18,26 @@ package org.springframework.security.saml2.provider.service.registration; import org.junit.Test; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; -import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; public class RelyingPartyRegistrationTests { @Test public void withRelyingPartyRegistrationWorks() { - RelyingPartyRegistration registration = relyingPartyRegistration() - .providerDetails(p -> p.binding(POST)) - .providerDetails(p -> p.signAuthNRequest(false)) - .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT) - .build(); + RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .providerDetails((p) -> p.binding(Saml2MessageBinding.POST)) + .providerDetails((p) -> p.signAuthNRequest(false)) + .assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build(); RelyingPartyRegistration copy = RelyingPartyRegistration.withRelyingPartyRegistration(registration).build(); compareRegistrations(registration, copy); } private void compareRegistrations(RelyingPartyRegistration registration, RelyingPartyRegistration copy) { - assertThat(copy.getRegistrationId()) - .isEqualTo(registration.getRegistrationId()) - .isEqualTo("simplesamlphp"); - assertThat(copy.getProviderDetails().getEntityId()) - .isEqualTo(registration.getProviderDetails().getEntityId()) + assertThat(copy.getRegistrationId()).isEqualTo(registration.getRegistrationId()).isEqualTo("simplesamlphp"); + assertThat(copy.getProviderDetails().getEntityId()).isEqualTo(registration.getProviderDetails().getEntityId()) .isEqualTo(copy.getAssertingPartyDetails().getEntityId()) .isEqualTo(registration.getAssertingPartyDetails().getEntityId()) .isEqualTo("https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/metadata.php"); @@ -53,38 +46,27 @@ public class RelyingPartyRegistrationTests { .isEqualTo(copy.getAssertionConsumerServiceLocation()) .isEqualTo(registration.getAssertionConsumerServiceLocation()) .isEqualTo("{baseUrl}" + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI); - assertThat(copy.getCredentials()) - .containsAll(registration.getCredentials()) - .containsExactly( - registration.getCredentials().get(0), - registration.getCredentials().get(1) - ); - assertThat(copy.getLocalEntityIdTemplate()) - .isEqualTo(registration.getLocalEntityIdTemplate()) - .isEqualTo(copy.getEntityId()) - .isEqualTo(registration.getEntityId()) + assertThat(copy.getCredentials()).containsAll(registration.getCredentials()) + .containsExactly(registration.getCredentials().get(0), registration.getCredentials().get(1)); + assertThat(copy.getLocalEntityIdTemplate()).isEqualTo(registration.getLocalEntityIdTemplate()) + .isEqualTo(copy.getEntityId()).isEqualTo(registration.getEntityId()) .isEqualTo("{baseUrl}/saml2/service-provider-metadata/{registrationId}"); - assertThat(copy.getProviderDetails().getWebSsoUrl()) - .isEqualTo(registration.getProviderDetails().getWebSsoUrl()) + assertThat(copy.getProviderDetails().getWebSsoUrl()).isEqualTo(registration.getProviderDetails().getWebSsoUrl()) .isEqualTo(copy.getAssertingPartyDetails().getSingleSignOnServiceLocation()) .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()) .isEqualTo("https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php"); - assertThat(copy.getProviderDetails().getBinding()) - .isEqualTo(registration.getProviderDetails().getBinding()) + assertThat(copy.getProviderDetails().getBinding()).isEqualTo(registration.getProviderDetails().getBinding()) .isEqualTo(copy.getAssertingPartyDetails().getSingleSignOnServiceBinding()) .isEqualTo(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding()) - .isEqualTo(POST); + .isEqualTo(Saml2MessageBinding.POST); assertThat(copy.getProviderDetails().isSignAuthNRequest()) .isEqualTo(registration.getProviderDetails().isSignAuthNRequest()) .isEqualTo(copy.getAssertingPartyDetails().getWantAuthnRequestsSigned()) - .isEqualTo(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) - .isFalse(); + .isEqualTo(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()).isFalse(); assertThat(copy.getAssertionConsumerServiceBinding()) .isEqualTo(registration.getAssertionConsumerServiceBinding()); - assertThat(copy.getDecryptionX509Credentials()) - .isEqualTo(registration.getDecryptionX509Credentials()); - assertThat(copy.getSigningX509Credentials()) - .isEqualTo(registration.getSigningX509Credentials()); + assertThat(copy.getDecryptionX509Credentials()).isEqualTo(registration.getDecryptionX509Credentials()); + assertThat(copy.getSigningX509Credentials()).isEqualTo(registration.getSigningX509Credentials()); assertThat(copy.getAssertingPartyDetails().getEncryptionX509Credentials()) .isEqualTo(registration.getAssertingPartyDetails().getEncryptionX509Credentials()); assertThat(copy.getAssertingPartyDetails().getVerificationX509Credentials()) @@ -93,16 +75,12 @@ public class RelyingPartyRegistrationTests { @Test public void buildWhenUsingDefaultsThenAssertionConsumerServiceBindingDefaultsToPost() { - RelyingPartyRegistration relyingPartyRegistration = withRegistrationId("id") - .entityId("entity-id") - .assertionConsumerServiceLocation("location") - .assertingPartyDetails(assertingParty -> assertingParty - .entityId("entity-id") - .singleSignOnServiceLocation("location")) - .credentials(c -> c.add(relyingPartyVerifyingCredential())) - .build(); - - assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()) - .isEqualTo(POST); + RelyingPartyRegistration relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id") + .entityId("entity-id").assertionConsumerServiceLocation("location") + .assertingPartyDetails((assertingParty) -> assertingParty.entityId("entity-id") + .singleSignOnServiceLocation("location")) + .credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())).build(); + assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding()).isEqualTo(Saml2MessageBinding.POST); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationsTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationsTests.java index c3dfaed2b2..3c64c8e653 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationsTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationsTests.java @@ -23,114 +23,78 @@ import org.junit.Test; import org.springframework.security.saml2.Saml2Exception; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * Tests for {@link RelyingPartyRegistration} */ public class RelyingPartyRegistrationsTests { - private static final String IDP_SSO_DESCRIPTOR_PAYLOAD = - "\n" + - " \n" + - " \n" + - " \n" + - " example.com\n" + - " \n" + - " \n" + - " \n" + - " Consortium GARR IdP\n" + - " \n" + - " \n" + - " Consortium GARR IdP\n" + - " \n" + - " \n" + - " \n" + - " This Identity Provider gives support for the Consortium GARR's user community\n" + - " \n" + - " \n" + - " Questo Identity Provider di test fornisce supporto alla comunita' utenti GARR\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " MIIDZjCCAk6gAwIBAgIVAL9O+PA7SXtlwZZY8MVSE9On1cVWMA0GCSqGSIb3DQEB\n" + - " BQUAMCkxJzAlBgNVBAMTHmlkZW0tcHVwYWdlbnQuZG16LWludC51bmltby5pdDAe\n" + - " Fw0xMzA3MjQwMDQ0MTRaFw0zMzA3MjQwMDQ0MTRaMCkxJzAlBgNVBAMTHmlkZW0t\n" + - " cHVwYWdlbnQuZG16LWludC51bmltby5pdDCCASIwDQYJKoZIhvcNAMIIDQADggEP\n" + - " ADCCAQoCggEBAIAcp/VyzZGXUF99kwj4NvL/Rwv4YvBgLWzpCuoxqHZ/hmBwJtqS\n" + - " v0y9METBPFbgsF3hCISnxbcmNVxf/D0MoeKtw1YPbsUmow/bFe+r72hZ+IVAcejN\n" + - " iDJ7t5oTjsRN1t1SqvVVk6Ryk5AZhpFW+W9pE9N6c7kJ16Rp2/mbtax9OCzxpece\n" + - " byi1eiLfIBmkcRawL/vCc2v6VLI18i6HsNVO3l2yGosKCbuSoGDx2fCdAOk/rgdz\n" + - " cWOvFsIZSKuD+FVbSS/J9GVs7yotsS4PRl4iX9UMnfDnOMfO7bcBgbXtDl4SCU1v\n" + - " dJrRw7IL/pLz34Rv9a8nYitrzrxtLOp3nYUCAwEAAaOBhDCBgTBgBgMIIDEEWTBX\n" + - " gh5pZGVtLXB1cGFnZW50LmRtei1pbnQudW5pbW8uaXSGNWh0dHBzOi8vaWRlbS1w\n" + - " dXBhZ2VudC5kbXotaW50LnVuaW1vLml0L2lkcC9zaGliYm9sZXRoMB0GA1UdDgQW\n" + - " BBT8PANzz+adGnTRe8ldcyxAwe4VnzANBgkqhkiG9w0BAQUFAAOCAQEAOEnO8Clu\n" + - " 9z/Lf/8XOOsTdxJbV29DIF3G8KoQsB3dBsLwPZVEAQIP6ceS32Xaxrl6FMTDDNkL\n" + - " qUvvInUisw0+I5zZwYHybJQCletUWTnz58SC4C9G7FpuXHFZnOGtRcgGD1NOX4UU\n" + - " duus/4nVcGSLhDjszZ70Xtj0gw2Sn46oQPHTJ81QZ3Y9ih+Aj1c9OtUSBwtWZFkU\n" + - " yooAKoR8li68Yb21zN2N65AqV+ndL98M8xUYMKLONuAXStDeoVCipH6PJ09Z5U2p\n" + - " V5p4IQRV6QBsNw9CISJFuHzkVYTH5ZxzN80Ru46vh4y2M0Nu8GQ9I085KoZkrf5e\n" + - " Cq53OZt9ISjHEw==\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " \n" + - " Consortium GARR\n" + - " \n" + - " \n" + - " Consortium GARR\n" + - " \n" + - " \n" + - " \n" + - " Consortium GARR\n" + - " \n" + - " \n" + - " Consortium GARR\n" + - " \n" + - " \n" + - " \n" + - " https://example.org\n" + - " \n" + - " \n" + - " \n" + - " \n" + - " mailto:technical.contact@example.com\n" + - " \n" + - " \n" + - ""; + + private static final String IDP_SSO_DESCRIPTOR_PAYLOAD = "\n" + " \n" + + " \n" + + " \n" + " example.com\n" + + " \n" + " \n" + " \n" + + " Consortium GARR IdP\n" + " \n" + + " \n" + " Consortium GARR IdP\n" + + " \n" + " \n" + " \n" + + " This Identity Provider gives support for the Consortium GARR's user community\n" + + " \n" + " \n" + + " Questo Identity Provider di test fornisce supporto alla comunita' utenti GARR\n" + + " \n" + " \n" + " \n" + " \n" + + " \n" + " \n" + " \n" + + " \n" + + " MIIDZjCCAk6gAwIBAgIVAL9O+PA7SXtlwZZY8MVSE9On1cVWMA0GCSqGSIb3DQEB\n" + + " BQUAMCkxJzAlBgNVBAMTHmlkZW0tcHVwYWdlbnQuZG16LWludC51bmltby5pdDAe\n" + + " Fw0xMzA3MjQwMDQ0MTRaFw0zMzA3MjQwMDQ0MTRaMCkxJzAlBgNVBAMTHmlkZW0t\n" + + " cHVwYWdlbnQuZG16LWludC51bmltby5pdDCCASIwDQYJKoZIhvcNAMIIDQADggEP\n" + + " ADCCAQoCggEBAIAcp/VyzZGXUF99kwj4NvL/Rwv4YvBgLWzpCuoxqHZ/hmBwJtqS\n" + + " v0y9METBPFbgsF3hCISnxbcmNVxf/D0MoeKtw1YPbsUmow/bFe+r72hZ+IVAcejN\n" + + " iDJ7t5oTjsRN1t1SqvVVk6Ryk5AZhpFW+W9pE9N6c7kJ16Rp2/mbtax9OCzxpece\n" + + " byi1eiLfIBmkcRawL/vCc2v6VLI18i6HsNVO3l2yGosKCbuSoGDx2fCdAOk/rgdz\n" + + " cWOvFsIZSKuD+FVbSS/J9GVs7yotsS4PRl4iX9UMnfDnOMfO7bcBgbXtDl4SCU1v\n" + + " dJrRw7IL/pLz34Rv9a8nYitrzrxtLOp3nYUCAwEAAaOBhDCBgTBgBgMIIDEEWTBX\n" + + " gh5pZGVtLXB1cGFnZW50LmRtei1pbnQudW5pbW8uaXSGNWh0dHBzOi8vaWRlbS1w\n" + + " dXBhZ2VudC5kbXotaW50LnVuaW1vLml0L2lkcC9zaGliYm9sZXRoMB0GA1UdDgQW\n" + + " BBT8PANzz+adGnTRe8ldcyxAwe4VnzANBgkqhkiG9w0BAQUFAAOCAQEAOEnO8Clu\n" + + " 9z/Lf/8XOOsTdxJbV29DIF3G8KoQsB3dBsLwPZVEAQIP6ceS32Xaxrl6FMTDDNkL\n" + + " qUvvInUisw0+I5zZwYHybJQCletUWTnz58SC4C9G7FpuXHFZnOGtRcgGD1NOX4UU\n" + + " duus/4nVcGSLhDjszZ70Xtj0gw2Sn46oQPHTJ81QZ3Y9ih+Aj1c9OtUSBwtWZFkU\n" + + " yooAKoR8li68Yb21zN2N65AqV+ndL98M8xUYMKLONuAXStDeoVCipH6PJ09Z5U2p\n" + + " V5p4IQRV6QBsNw9CISJFuHzkVYTH5ZxzN80Ru46vh4y2M0Nu8GQ9I085KoZkrf5e\n" + + " Cq53OZt9ISjHEw==\n" + " \n" + + " \n" + " \n" + " \n" + " \n" + + " \n" + + " \n" + " \n" + " \n" + + " \n" + " Consortium GARR\n" + + " \n" + " \n" + + " Consortium GARR\n" + " \n" + " \n" + + " \n" + " Consortium GARR\n" + + " \n" + " \n" + + " Consortium GARR\n" + " \n" + " \n" + + " \n" + " https://example.org\n" + + " \n" + " \n" + " \n" + + " \n" + + " mailto:technical.contact@example.com\n" + + " \n" + " \n" + ""; @Test public void fromMetadataLocationWhenResolvableThenPopulatesBuilder() throws Exception { try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody(IDP_SSO_DESCRIPTOR_PAYLOAD).setResponseCode(200)); RelyingPartyRegistration registration = RelyingPartyRegistrations - .fromMetadataLocation(server.url("/").toString()) - .entityId("rp") - .build(); + .fromMetadataLocation(server.url("/").toString()).entityId("rp").build(); RelyingPartyRegistration.AssertingPartyDetails details = registration.getAssertingPartyDetails(); assertThat(details.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); assertThat(details.getSingleSignOnServiceLocation()) .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); - assertThat(details.getSingleSignOnServiceBinding()) - .isEqualTo(Saml2MessageBinding.POST); + assertThat(details.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); assertThat(details.getVerificationX509Credentials()).hasSize(1); assertThat(details.getEncryptionX509Credentials()).hasSize(1); } @@ -142,8 +106,8 @@ public class RelyingPartyRegistrationsTests { server.enqueue(new MockResponse().setBody(IDP_SSO_DESCRIPTOR_PAYLOAD).setResponseCode(200)); String url = server.url("/").toString(); server.shutdown(); - assertThatCode(() -> RelyingPartyRegistrations.fromMetadataLocation(url)) - .isInstanceOf(Saml2Exception.class); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> RelyingPartyRegistrations.fromMetadataLocation(url)); } } @@ -152,8 +116,9 @@ public class RelyingPartyRegistrationsTests { try (MockWebServer server = new MockWebServer()) { server.enqueue(new MockResponse().setBody("malformed").setResponseCode(200)); String url = server.url("/").toString(); - assertThatCode(() -> RelyingPartyRegistrations.fromMetadataLocation(url)) - .isInstanceOf(Saml2Exception.class); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> RelyingPartyRegistrations.fromMetadataLocation(url)); } } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java index c378576f26..a0e6aa8a0c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java @@ -16,57 +16,49 @@ package org.springframework.security.saml2.provider.service.registration; -import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.credentials.Saml2X509Credential; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; - /** * Preconfigured test data for {@link RelyingPartyRegistration} objects */ -public class TestRelyingPartyRegistrations { +public final class TestRelyingPartyRegistrations { + + private TestRelyingPartyRegistrations() { + } public static RelyingPartyRegistration.Builder relyingPartyRegistration() { String registrationId = "simplesamlphp"; - String rpEntityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; - Saml2X509Credential signingCredential = relyingPartySigningCredential(); - String assertionConsumerServiceLocation = "{baseUrl}" + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; - + Saml2X509Credential signingCredential = TestSaml2X509Credentials.relyingPartySigningCredential(); + String assertionConsumerServiceLocation = "{baseUrl}" + + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; String apEntityId = "https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/metadata.php"; - Saml2X509Credential verificationCertificate = relyingPartyVerifyingCredential(); + Saml2X509Credential verificationCertificate = TestSaml2X509Credentials.relyingPartyVerifyingCredential(); String singleSignOnServiceLocation = "https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php"; - - return RelyingPartyRegistration.withRegistrationId(registrationId) - .entityId(rpEntityId) + return RelyingPartyRegistration.withRegistrationId(registrationId).entityId(rpEntityId) .assertionConsumerServiceLocation(assertionConsumerServiceLocation) - .credentials(c -> c.add(signingCredential)) - .providerDetails(c -> c - .entityId(apEntityId) - .webSsoUrl(singleSignOnServiceLocation)) - .credentials(c -> c.add(verificationCertificate)); + .credentials((c) -> c.add(signingCredential)) + .providerDetails((c) -> c.entityId(apEntityId).webSsoUrl(singleSignOnServiceLocation)) + .credentials((c) -> c.add(verificationCertificate)); } public static RelyingPartyRegistration.Builder noCredentials() { - return RelyingPartyRegistration.withRegistrationId("registration-id") - .entityId("rp-entity-id") - .assertionConsumerServiceLocation("https://rp.example.org/acs") - .assertingPartyDetails(party -> party - .entityId("ap-entity-id") - .singleSignOnServiceLocation("https://ap.example.org/sso") - ); + return RelyingPartyRegistration.withRegistrationId("registration-id").entityId("rp-entity-id") + .assertionConsumerServiceLocation("https://rp.example.org/acs").assertingPartyDetails((party) -> party + .entityId("ap-entity-id").singleSignOnServiceLocation("https://ap.example.org/sso")); } public static RelyingPartyRegistration.Builder full() { return noCredentials() - .signingX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential())) - .decryptionX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential())) - .assertingPartyDetails(party -> party - .verificationX509Credentials(c -> c.add( - TestSaml2X509Credentials.relyingPartyVerifyingCredential()) - ) - ); + .signingX509Credentials((c) -> c.add(org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartySigningCredential())) + .decryptionX509Credentials((c) -> c.add(org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartyDecryptingCredential())) + .assertingPartyDetails((party) -> party.verificationX509Credentials( + (c) -> c.add(org.springframework.security.saml2.core.TestSaml2X509Credentials + .relyingPartyVerifyingCredential()))); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java index b685adab56..6ebb014b51 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java @@ -16,28 +16,32 @@ package org.springframework.security.saml2.provider.service.servlet.filter; +import javax.servlet.http.HttpServletResponse; + import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; -import javax.servlet.http.HttpServletResponse; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class Saml2WebSsoAuthenticationFilterTests { private Saml2WebSsoAuthenticationFilter filter; + private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); + private MockHttpServletRequest request = new MockHttpServletRequest(); + private HttpServletResponse response = new MockHttpServletResponse(); @Rule @@ -45,51 +49,50 @@ public class Saml2WebSsoAuthenticationFilterTests { @Before public void setup() { - filter = new Saml2WebSsoAuthenticationFilter(repository); - request.setPathInfo("/login/saml2/sso/idp-registration-id"); - request.setParameter("SAMLResponse", "xml-data-goes-here"); + this.filter = new Saml2WebSsoAuthenticationFilter(this.repository); + this.request.setPathInfo("/login/saml2/sso/idp-registration-id"); + this.request.setParameter("SAMLResponse", "xml-data-goes-here"); } @Test public void constructingFilterWithMissingRegistrationIdVariableThenThrowsException() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage("filterProcessesUrl must contain a {registrationId} match variable"); - filter = new Saml2WebSsoAuthenticationFilter(repository, "/url/missing/variable"); + this.exception.expect(IllegalArgumentException.class); + this.exception.expectMessage("filterProcessesUrl must contain a {registrationId} match variable"); + this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable"); } @Test public void constructingFilterWithValidRegistrationIdVariableThenSucceeds() { - filter = new Saml2WebSsoAuthenticationFilter(repository, "/url/variable/is/present/{registrationId}"); + this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/variable/is/present/{registrationId}"); } @Test public void requiresAuthenticationWhenHappyPathThenReturnsTrue() { - Assert.assertTrue(filter.requiresAuthentication(request, response)); + Assert.assertTrue(this.filter.requiresAuthentication(this.request, this.response)); } @Test public void requiresAuthenticationWhenCustomProcessingUrlThenReturnsTrue() { - filter = new Saml2WebSsoAuthenticationFilter(repository, "/some/other/path/{registrationId}"); - request.setPathInfo("/some/other/path/idp-registration-id"); - request.setParameter("SAMLResponse", "xml-data-goes-here"); - Assert.assertTrue(filter.requiresAuthentication(request, response)); + this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}"); + this.request.setPathInfo("/some/other/path/idp-registration-id"); + this.request.setParameter("SAMLResponse", "xml-data-goes-here"); + Assert.assertTrue(this.filter.requiresAuthentication(this.request, this.response)); } @Test public void attemptAuthenticationWhenRegistrationIdDoesNotExistThenThrowsException() { - when(repository.findByRegistrationId("non-existent-id")).thenReturn(null); - - filter = new Saml2WebSsoAuthenticationFilter(repository, "/some/other/path/{registrationId}"); - - request.setPathInfo("/some/other/path/non-existent-id"); - request.setParameter("SAMLResponse", "response"); - + given(this.repository.findByRegistrationId("non-existent-id")).willReturn(null); + this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}"); + this.request.setPathInfo("/some/other/path/non-existent-id"); + this.request.setParameter("SAMLResponse", "response"); try { - filter.attemptAuthentication(request, response); + this.filter.attemptAuthentication(this.request, this.response); failBecauseExceptionWasNotThrown(Saml2AuthenticationException.class); - } catch (Exception e) { - assertThat(e).isInstanceOf(Saml2AuthenticationException.class); - assertThat(e.getMessage()).isEqualTo("No relying party registration found"); + } + catch (Exception ex) { + assertThat(ex).isInstanceOf(Saml2AuthenticationException.class); + assertThat(ex.getMessage()).isEqualTo("No relying party registration found"); } } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index 1ea4d636c9..6079de5bcb 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.servlet.filter; import java.io.IOException; import java.nio.charset.StandardCharsets; + import javax.servlet.ServletException; import org.junit.Before; @@ -26,158 +27,136 @@ import org.junit.Test; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; +import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; -import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; -import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; public class Saml2WebSsoAuthenticationRequestFilterTests { private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO"; + private Saml2WebSsoAuthenticationRequestFilter filter; + private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); + private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class); - private Saml2AuthenticationRequestContextResolver resolver = - mock(Saml2AuthenticationRequestContextResolver.class); + + private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class); + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockFilterChain filterChain; + private RelyingPartyRegistration.Builder rpBuilder; @Before public void setup() { - filter = new Saml2WebSsoAuthenticationRequestFilter(repository); - request = new MockHttpServletRequest(); - response = new MockHttpServletResponse(); - request.setPathInfo("/saml2/authenticate/registration-id"); - - filterChain = new MockFilterChain(); - - rpBuilder = RelyingPartyRegistration - .withRegistrationId("registration-id") - .providerDetails(c -> c.entityId("idp-entity-id")) - .providerDetails(c -> c.webSsoUrl(IDP_SSO_URL)) + this.filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.request.setPathInfo("/saml2/authenticate/registration-id"); + this.filterChain = new MockFilterChain(); + this.rpBuilder = RelyingPartyRegistration.withRegistrationId("registration-id") + .providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL)) .assertionConsumerServiceUrlTemplate("template") - .credentials(c -> c.add(assertingPartyPrivateCredential())); + .credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential())); } @Test public void doFilterWhenNoRelayStateThenRedirectDoesNotContainParameter() throws ServletException, IOException { - when(repository.findByRegistrationId("registration-id")).thenReturn(rpBuilder.build()); - filter.doFilterInternal(request, response, filterChain); - assertThat(response.getHeader("Location")) - .doesNotContain("RelayState=") - .startsWith(IDP_SSO_URL); + given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build()); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getHeader("Location")).doesNotContain("RelayState=").startsWith(IDP_SSO_URL); } @Test public void doFilterWhenRelayStateThenRedirectDoesContainParameter() throws ServletException, IOException { - when(repository.findByRegistrationId("registration-id")).thenReturn(rpBuilder.build()); - request.setParameter("RelayState", "my-relay-state"); - filter.doFilterInternal(request, response, filterChain); - assertThat(response.getHeader("Location")) - .contains("RelayState=my-relay-state") - .startsWith(IDP_SSO_URL); + given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build()); + this.request.setParameter("RelayState", "my-relay-state"); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getHeader("Location")).contains("RelayState=my-relay-state").startsWith(IDP_SSO_URL); } @Test public void doFilterWhenRelayStateThatRequiresEncodingThenRedirectDoesContainsEncodedParameter() throws Exception { - when(repository.findByRegistrationId("registration-id")).thenReturn(rpBuilder.build()); + given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build()); final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param"; final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1); - request.setParameter("RelayState", relayStateValue); - filter.doFilterInternal(request, response, filterChain); - assertThat(response.getHeader("Location")) - .contains("RelayState="+relayStateEncoded) + this.request.setParameter("RelayState", relayStateValue); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getHeader("Location")).contains("RelayState=" + relayStateEncoded) .startsWith(IDP_SSO_URL); } @Test public void doFilterWhenSimpleSignatureSpecifiedThenSignatureParametersAreInTheRedirectURL() throws Exception { - when(repository.findByRegistrationId("registration-id")).thenReturn( - rpBuilder - .build() - ); + given(this.repository.findByRegistrationId("registration-id")).willReturn(this.rpBuilder.build()); final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param"; final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1); - request.setParameter("RelayState", relayStateValue); - filter.doFilterInternal(request, response, filterChain); - assertThat(response.getHeader("Location")) - .contains("RelayState="+relayStateEncoded) - .contains("SigAlg=") - .contains("Signature=") - .startsWith(IDP_SSO_URL); + this.request.setParameter("RelayState", relayStateValue); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getHeader("Location")).contains("RelayState=" + relayStateEncoded).contains("SigAlg=") + .contains("Signature=").startsWith(IDP_SSO_URL); } @Test public void doFilterWhenSignatureIsDisabledThenSignatureParametersAreNotInTheRedirectURL() throws Exception { - when(repository.findByRegistrationId("registration-id")).thenReturn( - rpBuilder - .providerDetails(c -> c.signAuthNRequest(false)) - .build() - ); + given(this.repository.findByRegistrationId("registration-id")) + .willReturn(this.rpBuilder.providerDetails((c) -> c.signAuthNRequest(false)).build()); final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param"; final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1); - request.setParameter("RelayState", relayStateValue); - filter.doFilterInternal(request, response, filterChain); - assertThat(response.getHeader("Location")) - .contains("RelayState="+relayStateEncoded) - .doesNotContain("SigAlg=") - .doesNotContain("Signature=") - .startsWith(IDP_SSO_URL); + this.request.setParameter("RelayState", relayStateValue); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getHeader("Location")).contains("RelayState=" + relayStateEncoded) + .doesNotContain("SigAlg=").doesNotContain("Signature=").startsWith(IDP_SSO_URL); } @Test public void doFilterWhenPostFormDataIsPresent() throws Exception { - when(repository.findByRegistrationId("registration-id")).thenReturn( - rpBuilder - .providerDetails(c -> c.binding(POST)) - .build() - ); + given(this.repository.findByRegistrationId("registration-id")) + .willReturn(this.rpBuilder.providerDetails((c) -> c.binding(Saml2MessageBinding.POST)).build()); final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param&javascript{alert('1');}"; final String relayStateEncoded = HtmlUtils.htmlEscape(relayStateValue); - request.setParameter("RelayState", relayStateValue); - filter.doFilterInternal(request, response, filterChain); - assertThat(response.getHeader("Location")).isNull(); - assertThat(response.getContentAsString()) + this.request.setParameter("RelayState", relayStateValue); + this.filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getHeader("Location")).isNull(); + assertThat(this.response.getContentAsString()) .contains("
        ") .contains(" c.binding(POST)) - .build(); + .providerDetails((c) -> c.binding(Saml2MessageBinding.POST)).build(); Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class); - when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); - when(authenticationRequest.getRelayState()).thenReturn("relay"); - when(authenticationRequest.getSamlRequest()).thenReturn("saml"); - when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty); - when(this.factory.createPostAuthenticationRequest(any())) - .thenReturn(authenticationRequest); - - Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter - (this.repository); + given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri"); + given(authenticationRequest.getRelayState()).willReturn("relay"); + given(authenticationRequest.getSamlRequest()).willReturn("saml"); + given(this.repository.findByRegistrationId("registration-id")).willReturn(relyingParty); + given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); filter.setAuthenticationRequestFactory(this.factory); filter.doFilterInternal(this.request, this.response, this.filterChain); - assertThat(this.response.getContentAsString()) - .contains("") + assertThat(this.response.getContentAsString()).contains("") .contains(" c.binding(POST)) - .build(); + .providerDetails((c) -> c.binding(Saml2MessageBinding.POST)).build(); Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class); - when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); - when(authenticationRequest.getRelayState()).thenReturn("relay"); - when(authenticationRequest.getSamlRequest()).thenReturn("saml"); - when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext() - .relyingPartyRegistration(relyingParty) - .build()); - when(this.factory.createPostAuthenticationRequest(any())) - .thenReturn(authenticationRequest); - - Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter - (this.resolver, this.factory); + given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri"); + given(authenticationRequest.getRelayState()).willReturn("relay"); + given(authenticationRequest.getSamlRequest()).willReturn("saml"); + given(this.resolver.resolve(this.request)).willReturn(TestSaml2AuthenticationRequestContexts + .authenticationRequestContext().relyingPartyRegistration(relyingParty).build()); + given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver, + this.factory); filter.doFilterInternal(this.request, this.response, this.filterChain); - assertThat(this.response.getContentAsString()) - .contains("") + assertThat(this.response.getContentAsString()).contains("") .contains(" filter.setRedirectMatcher(null)) - .isInstanceOf(IllegalArgumentException.class); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); + assertThatIllegalArgumentException().isThrownBy(() -> filter.setRedirectMatcher(null)); } @Test public void setAuthenticationRequestFactoryWhenNullThenException() { Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); - assertThatCode(() -> filter.setAuthenticationRequestFactory(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> filter.setAuthenticationRequestFactory(null)); } @Test public void doFilterWhenRequestMatcherFailsThenSkipsFilter() throws Exception { - Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter - (this.repository); - filter.setRedirectMatcher(request -> false); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); + filter.setRedirectMatcher((request) -> false); filter.doFilter(this.request, this.response, this.filterChain); verifyNoInteractions(this.repository); } @Test public void doFilterWhenRelyingPartyRegistrationNotFoundThenUnauthorized() throws Exception { - Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter - (this.repository); + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.response.getStatus()).isEqualTo(401); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java index 693075f803..41418978c4 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java @@ -22,20 +22,24 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultRelyingPartyRegistrationResolver} */ public class DefaultRelyingPartyRegistrationResolverTests { - private final RelyingPartyRegistration registration = relyingPartyRegistration().build(); - private final RelyingPartyRegistrationRepository repository = - new InMemoryRelyingPartyRegistrationRepository(this.registration); - private final DefaultRelyingPartyRegistrationResolver resolver = - new DefaultRelyingPartyRegistrationResolver(this.repository); + + private final RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .build(); + + private final RelyingPartyRegistrationRepository repository = new InMemoryRelyingPartyRegistrationRepository( + this.registration); + + private final DefaultRelyingPartyRegistrationResolver resolver = new DefaultRelyingPartyRegistrationResolver( + this.repository); @Test public void resolveWhenRequestContainsRegistrationIdThenResolves() { @@ -43,8 +47,7 @@ public class DefaultRelyingPartyRegistrationResolverTests { request.setPathInfo("/some/path/" + this.registration.getRegistrationId()); RelyingPartyRegistration registration = this.resolver.convert(request); assertThat(registration).isNotNull(); - assertThat(registration.getRegistrationId()) - .isEqualTo(this.registration.getRegistrationId()); + assertThat(registration.getRegistrationId()).isEqualTo(this.registration.getRegistrationId()); assertThat(registration.getEntityId()) .isEqualTo("http://localhost/saml2/service-provider-metadata/" + this.registration.getRegistrationId()); assertThat(registration.getAssertionConsumerServiceLocation()) @@ -68,7 +71,7 @@ public class DefaultRelyingPartyRegistrationResolverTests { @Test public void constructorWhenNullRelyingPartyRegistrationThenIllegalArgument() { - assertThatCode(() -> new DefaultRelyingPartyRegistrationResolver(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new DefaultRelyingPartyRegistrationResolver(null)); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java index 80f2cd6afc..1905a0db3b 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java @@ -20,12 +20,12 @@ import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link DefaultSaml2AuthenticationRequestContextResolver} @@ -36,36 +36,38 @@ import static org.springframework.security.saml2.credentials.TestSaml2X509Creden public class DefaultSaml2AuthenticationRequestContextResolverTests { private static final String ASSERTING_PARTY_SSO_URL = "https://idp.example.com/sso"; + private static final String RELYING_PARTY_SSO_URL = "https://sp.example.com/sso"; + private static final String ASSERTING_PARTY_ENTITY_ID = "asserting-party-entity-id"; + private static final String RELYING_PARTY_ENTITY_ID = "relying-party-entity-id"; + private static final String REGISTRATION_ID = "registration-id"; private MockHttpServletRequest request; + private RelyingPartyRegistration.Builder relyingPartyBuilder; - private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver - = new DefaultSaml2AuthenticationRequestContextResolver( - new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build())); + + private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver((id) -> this.relyingPartyBuilder.build())); @Before public void setup() { this.request = new MockHttpServletRequest(); this.request.setPathInfo("/saml2/authenticate/registration-id"); - this.relyingPartyBuilder = RelyingPartyRegistration - .withRegistrationId(REGISTRATION_ID) + this.relyingPartyBuilder = RelyingPartyRegistration.withRegistrationId(REGISTRATION_ID) .localEntityIdTemplate(RELYING_PARTY_ENTITY_ID) - .providerDetails(c -> c.entityId(ASSERTING_PARTY_ENTITY_ID)) - .providerDetails(c -> c.webSsoUrl(ASSERTING_PARTY_SSO_URL)) + .providerDetails((c) -> c.entityId(ASSERTING_PARTY_ENTITY_ID)) + .providerDetails((c) -> c.webSsoUrl(ASSERTING_PARTY_SSO_URL)) .assertionConsumerServiceUrlTemplate(RELYING_PARTY_SSO_URL) - .credentials(c -> c.add(relyingPartyVerifyingCredential())); + .credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential())); } @Test public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() { this.request.addParameter("RelayState", "relay-state"); - Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request); - + Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(this.request); assertThat(context).isNotNull(); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL); assertThat(context.getRelayState()).isEqualTo("relay-state"); @@ -77,29 +79,22 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests { @Test public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() { - this.relyingPartyBuilder - .assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}"); - Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request); - + this.relyingPartyBuilder.assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}"); + Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(this.request); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id"); } @Test public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() { - this.relyingPartyBuilder - .assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}"); - Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request); - + this.relyingPartyBuilder.assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}"); + Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(this.request); assertThat(context.getAssertionConsumerServiceUrl()) .isEqualTo("http://localhost/saml2/authenticate/registration-id"); } @Test public void resolveWhenRelyingPartyNullThenException() { - assertThatCode(() -> - this.authenticationRequestContextResolver.resolve(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationRequestContextResolver.resolve(null)); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java index 74b987ea63..91038de6ae 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java @@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web; import java.io.IOException; import java.nio.charset.StandardCharsets; + import javax.servlet.http.HttpServletRequest; import org.junit.Test; @@ -31,63 +32,63 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.core.Saml2Utils; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.util.StreamUtils; import org.springframework.web.util.UriUtils; -import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; +import static org.mockito.BDDMockito.given; @RunWith(MockitoJUnitRunner.class) public class Saml2AuthenticationTokenConverterTests { + @Mock Converter relyingPartyRegistrationResolver; - RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistration().build(); + RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration() + .build(); @Test public void convertWhenSamlResponseThenToken() { - Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter - (this.relyingPartyRegistrationResolver); - when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) - .thenReturn(this.relyingPartyRegistration); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); - request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(UTF_8))); + request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8))); Saml2AuthenticationToken token = converter.convert(request); assertThat(token.getSaml2Response()).isEqualTo("response"); assertThat(token.getRelyingPartyRegistration().getRegistrationId()) - .isEqualTo(relyingPartyRegistration.getRegistrationId()); + .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); } @Test public void convertWhenNoSamlResponseThenNull() { - Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter - (this.relyingPartyRegistrationResolver); - when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) - .thenReturn(this.relyingPartyRegistration); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); assertThat(converter.convert(request)).isNull(); } @Test public void convertWhenNoRelyingPartyRegistrationThenNull() { - Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter - (this.relyingPartyRegistrationResolver); - when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) - .thenReturn(null); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))).willReturn(null); MockHttpServletRequest request = new MockHttpServletRequest(); assertThat(converter.convert(request)).isNull(); } @Test public void convertWhenGetRequestThenInflates() { - Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter - (this.relyingPartyRegistrationResolver); - when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) - .thenReturn(this.relyingPartyRegistration); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("GET"); byte[] deflated = Saml2Utils.samlDeflate("response"); @@ -96,21 +97,20 @@ public class Saml2AuthenticationTokenConverterTests { Saml2AuthenticationToken token = converter.convert(request); assertThat(token.getSaml2Response()).isEqualTo("response"); assertThat(token.getRelyingPartyRegistration().getRegistrationId()) - .isEqualTo(relyingPartyRegistration.getRegistrationId()); + .isEqualTo(this.relyingPartyRegistration.getRegistrationId()); } @Test public void constructorWhenResolverIsNullThenIllegalArgument() { - assertThatCode(() -> new Saml2AuthenticationTokenConverter(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null)); } @Test public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception { - Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter - (this.relyingPartyRegistrationResolver); - when(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) - .thenReturn(this.relyingPartyRegistration); + Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter( + this.relyingPartyRegistrationResolver); + given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class))) + .willReturn(this.relyingPartyRegistration); MockHttpServletRequest request = new MockHttpServletRequest(); request.setParameter("SAMLResponse", getSsoCircleEncodedXml()); Saml2AuthenticationToken token = converter.convert(request); @@ -118,8 +118,7 @@ public class Saml2AuthenticationTokenConverterTests { } private void validateSsoCircleXml(String xml) { - assertThat(xml) - .contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") + assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"") .contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"") .contains("https://idp.ssocircle.com"); } @@ -127,6 +126,7 @@ public class Saml2AuthenticationTokenConverterTests { private String getSsoCircleEncodedXml() throws IOException { ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded"); String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8); - return UriUtils.decode(response, UTF_8); + return UriUtils.decode(response, StandardCharsets.UTF_8); } + } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTest.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java similarity index 76% rename from saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTest.java rename to saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java index 4c490bd284..12e024a3a1 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTest.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java @@ -16,44 +16,50 @@ package org.springframework.security.saml2.provider.service.web; +import javax.servlet.FilterChain; + import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.core.TestSaml2X509Credentials; import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; -import javax.servlet.FilterChain; - import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential; -import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials; /** * Tests for {@link Saml2MetadataFilter} */ -public class Saml2MetadataFilterTest { +public class Saml2MetadataFilterTests { RelyingPartyRegistrationRepository repository; + Saml2MetadataResolver resolver; + Saml2MetadataFilter filter; + MockHttpServletRequest request; + MockHttpServletResponse response; + FilterChain chain; @Before public void setup() { this.repository = mock(RelyingPartyRegistrationRepository.class); this.resolver = mock(Saml2MetadataResolver.class); - this.filter = new Saml2MetadataFilter( - new DefaultRelyingPartyRegistrationResolver(this.repository), this.resolver); + this.filter = new Saml2MetadataFilter(new DefaultRelyingPartyRegistrationResolver(this.repository), + this.resolver); this.request = new MockHttpServletRequest(); this.response = new MockHttpServletResponse(); this.chain = mock(FilterChain.class); @@ -61,61 +67,39 @@ public class Saml2MetadataFilterTest { @Test public void doFilterWhenMatcherSucceedsThenResolverInvoked() throws Exception { - // given this.request.setPathInfo("/saml2/service-provider-metadata/registration-id"); - - // when this.filter.doFilter(this.request, this.response, this.chain); - - // then verifyNoInteractions(this.chain); verify(this.repository).findByRegistrationId("registration-id"); } @Test public void doFilterWhenMatcherFailsThenProcessesFilterChain() throws Exception { - // given this.request.setPathInfo("/saml2/authenticate/registration-id"); - - // when this.filter.doFilter(this.request, this.response, this.chain); - - // then verify(this.chain).doFilter(this.request, this.response); } @Test public void doFilterWhenNoRelyingPartyRegistrationThenUnauthorized() throws Exception { - // given this.request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration"); - when(this.repository.findByRegistrationId("invalidRegistration")).thenReturn(null); - - // when + given(this.repository.findByRegistrationId("invalidRegistration")).willReturn(null); this.filter.doFilter(this.request, this.response, this.chain); - - // then verifyNoInteractions(this.chain); assertThat(this.response.getStatus()).isEqualTo(401); } @Test public void doFilterWhenRelyingPartyRegistrationFoundThenInvokesMetadataResolver() throws Exception { - // given this.request.setPathInfo("/saml2/service-provider-metadata/validRegistration"); - RelyingPartyRegistration validRegistration = noCredentials() - .assertingPartyDetails(party -> party - .verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential()))) + RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.noCredentials() + .assertingPartyDetails((party) -> party.verificationX509Credentials( + (c) -> c.add(TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) .build(); - String generatedMetadata = "test"; - when(this.resolver.resolve(validRegistration)).thenReturn(generatedMetadata); - - this.filter = new Saml2MetadataFilter(request -> validRegistration, this.resolver); - - // when + given(this.resolver.resolve(validRegistration)).willReturn(generatedMetadata); + this.filter = new Saml2MetadataFilter((request) -> validRegistration, this.resolver); this.filter.doFilter(this.request, this.response, this.chain); - - // then verifyNoInteractions(this.chain); assertThat(this.response.getStatus()).isEqualTo(200); assertThat(this.response.getContentAsString()).isEqualTo(generatedMetadata); @@ -124,21 +108,16 @@ public class Saml2MetadataFilterTest { @Test public void doFilterWhenCustomRequestMatcherThenUses() throws Exception { - // given this.request.setPathInfo("/path"); this.filter.setRequestMatcher(new AntPathRequestMatcher("/path")); - - // when this.filter.doFilter(this.request, this.response, this.chain); - - // then verifyNoInteractions(this.chain); verify(this.repository).findByRegistrationId("path"); } @Test public void setRequestMatcherWhenNullThenIllegalArgument() { - assertThatCode(() -> this.filter.setRequestMatcher(null)) - .isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null)); } + } diff --git a/samples/boot/hellorsocket/src/integration-test/java/sample/HelloRSocketApplicationITests.java b/samples/boot/hellorsocket/src/integration-test/java/sample/HelloRSocketApplicationITests.java index 420588c356..0709dd93bc 100644 --- a/samples/boot/hellorsocket/src/integration-test/java/sample/HelloRSocketApplicationITests.java +++ b/samples/boot/hellorsocket/src/integration-test/java/sample/HelloRSocketApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.beans.factory.annotation.Autowired; @@ -52,7 +53,7 @@ public class HelloRSocketApplicationITests { public void messageWhenAuthenticatedThenSuccess() { UsernamePasswordMetadata credentials = new UsernamePasswordMetadata("user", "password"); RSocketRequester requester = this.requester - .rsocketStrategies(builder -> builder.encoder(new BasicAuthenticationEncoder())) + .rsocketStrategies((builder) -> builder.encoder(new BasicAuthenticationEncoder())) .setupMetadata(credentials, BASIC_AUTHENTICATION_MIME_TYPE) .connectTcp("localhost", this.port) .block(); diff --git a/samples/boot/hellowebflux-method/src/integration-test/java/sample/HelloWebfluxMethodApplicationITests.java b/samples/boot/hellowebflux-method/src/integration-test/java/sample/HelloWebfluxMethodApplicationITests.java index e5dcd999ca..d95956b2d6 100644 --- a/samples/boot/hellowebflux-method/src/integration-test/java/sample/HelloWebfluxMethodApplicationITests.java +++ b/samples/boot/hellowebflux-method/src/integration-test/java/sample/HelloWebfluxMethodApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.util.function.Consumer; @@ -69,11 +70,11 @@ public class HelloWebfluxMethodApplicationITests { } private Consumer robsCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("rob", "rob"); + return (httpHeaders) -> httpHeaders.setBasicAuth("rob", "rob"); } private Consumer adminCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("admin", "admin"); + return (httpHeaders) -> httpHeaders.setBasicAuth("admin", "admin"); } } diff --git a/samples/boot/hellowebflux-method/src/main/java/sample/SecurityConfig.java b/samples/boot/hellowebflux-method/src/main/java/sample/SecurityConfig.java index cf8cd0e92b..fbbcdc835e 100644 --- a/samples/boot/hellowebflux-method/src/main/java/sample/SecurityConfig.java +++ b/samples/boot/hellowebflux-method/src/main/java/sample/SecurityConfig.java @@ -40,7 +40,7 @@ public class SecurityConfig { return http // Demonstrate that method security works // Best practice to use both for defense in depth - .authorizeExchange(exchanges -> exchanges + .authorizeExchange((exchanges) -> exchanges .anyExchange().permitAll() ) .httpBasic(withDefaults()) diff --git a/samples/boot/hellowebflux-method/src/test/java/sample/HelloWebfluxMethodApplicationTests.java b/samples/boot/hellowebflux-method/src/test/java/sample/HelloWebfluxMethodApplicationTests.java index fdf850a114..6a2e2898ed 100644 --- a/samples/boot/hellowebflux-method/src/test/java/sample/HelloWebfluxMethodApplicationTests.java +++ b/samples/boot/hellowebflux-method/src/test/java/sample/HelloWebfluxMethodApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockUser; @@ -58,8 +59,6 @@ public class HelloWebfluxMethodApplicationTests { .expectStatus().isUnauthorized(); } - // --- Basic Authentication --- - @Test public void messageWhenUserThenForbidden() { this.rest @@ -81,8 +80,6 @@ public class HelloWebfluxMethodApplicationTests { .expectBody(String.class).isEqualTo("Hello World!"); } - // --- WithMockUser --- - @Test @WithMockUser public void messageWhenWithMockUserThenForbidden() { @@ -104,8 +101,6 @@ public class HelloWebfluxMethodApplicationTests { .expectBody(String.class).isEqualTo("Hello World!"); } - // --- mutateWith mockUser --- - @Test public void messageWhenMutateWithMockUserThenForbidden() { this.rest @@ -128,10 +123,10 @@ public class HelloWebfluxMethodApplicationTests { } private Consumer robsCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("rob", "rob"); + return (httpHeaders) -> httpHeaders.setBasicAuth("rob", "rob"); } private Consumer adminCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("admin", "admin"); + return (httpHeaders) -> httpHeaders.setBasicAuth("admin", "admin"); } } diff --git a/samples/boot/hellowebflux/src/integration-test/java/sample/HelloWebfluxApplicationITests.java b/samples/boot/hellowebflux/src/integration-test/java/sample/HelloWebfluxApplicationITests.java index a749cbb592..6d1dda2fac 100644 --- a/samples/boot/hellowebflux/src/integration-test/java/sample/HelloWebfluxApplicationITests.java +++ b/samples/boot/hellowebflux/src/integration-test/java/sample/HelloWebfluxApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.util.function.Consumer; @@ -68,10 +69,10 @@ public class HelloWebfluxApplicationITests { } private Consumer userCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "user"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "user"); } private Consumer invalidCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "INVALID"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "INVALID"); } } diff --git a/samples/boot/hellowebflux/src/test/java/sample/HelloWebfluxApplicationTests.java b/samples/boot/hellowebflux/src/test/java/sample/HelloWebfluxApplicationTests.java index 7d6c2a7bf6..cc7892116f 100644 --- a/samples/boot/hellowebflux/src/test/java/sample/HelloWebfluxApplicationTests.java +++ b/samples/boot/hellowebflux/src/test/java/sample/HelloWebfluxApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockUser; @@ -104,10 +105,10 @@ public class HelloWebfluxApplicationTests { } private Consumer userCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "user"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "user"); } private Consumer invalidCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "INVALID"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "INVALID"); } } diff --git a/samples/boot/hellowebfluxfn/src/integration-test/java/sample/HelloWebfluxFnApplicationITests.java b/samples/boot/hellowebfluxfn/src/integration-test/java/sample/HelloWebfluxFnApplicationITests.java index 80e805ff4f..de42538a78 100644 --- a/samples/boot/hellowebfluxfn/src/integration-test/java/sample/HelloWebfluxFnApplicationITests.java +++ b/samples/boot/hellowebfluxfn/src/integration-test/java/sample/HelloWebfluxFnApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.util.function.Consumer; @@ -74,10 +75,10 @@ public class HelloWebfluxFnApplicationITests { } private Consumer userCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "user"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "user"); } private Consumer invalidCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "INVALID"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "INVALID"); } } diff --git a/samples/boot/hellowebfluxfn/src/main/java/sample/HelloUserController.java b/samples/boot/hellowebfluxfn/src/main/java/sample/HelloUserController.java index e1fd699e3d..e0c0810cee 100644 --- a/samples/boot/hellowebfluxfn/src/main/java/sample/HelloUserController.java +++ b/samples/boot/hellowebfluxfn/src/main/java/sample/HelloUserController.java @@ -36,7 +36,7 @@ public class HelloUserController { public Mono hello(ServerRequest serverRequest) { return serverRequest.principal() .map(Principal::getName) - .flatMap(username -> + .flatMap((username) -> ServerResponse.ok() .contentType(MediaType.APPLICATION_JSON) .syncBody(Collections.singletonMap("message", "Hello " + username + "!")) diff --git a/samples/boot/hellowebfluxfn/src/test/java/sample/HelloWebfluxFnApplicationTests.java b/samples/boot/hellowebfluxfn/src/test/java/sample/HelloWebfluxFnApplicationTests.java index 3bdc2f7e03..826fe95cf3 100644 --- a/samples/boot/hellowebfluxfn/src/test/java/sample/HelloWebfluxFnApplicationTests.java +++ b/samples/boot/hellowebfluxfn/src/test/java/sample/HelloWebfluxFnApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockUser; @@ -105,10 +106,10 @@ public class HelloWebfluxFnApplicationTests { } private Consumer userCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "user"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "user"); } private Consumer invalidCredentials() { - return httpHeaders -> httpHeaders.setBasicAuth("user", "INVALID"); + return (httpHeaders) -> httpHeaders.setBasicAuth("user", "INVALID"); } } diff --git a/samples/boot/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldApplicationTests.java b/samples/boot/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldApplicationTests.java index 3a2bb375d5..faeeb015a2 100644 --- a/samples/boot/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldApplicationTests.java +++ b/samples/boot/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.Test; diff --git a/samples/boot/helloworld/src/main/java/org/springframework/security/samples/HelloWorldApplication.java b/samples/boot/helloworld/src/main/java/org/springframework/security/samples/HelloWorldApplication.java index 7ffeb8607c..170bcadef3 100644 --- a/samples/boot/helloworld/src/main/java/org/springframework/security/samples/HelloWorldApplication.java +++ b/samples/boot/helloworld/src/main/java/org/springframework/security/samples/HelloWorldApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/boot/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 9d7cf4c34c..c3c71cf2fe 100644 --- a/samples/boot/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/boot/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.context.annotation.Bean; @@ -34,11 +35,11 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorize -> authorize + .authorizeRequests((authorize) -> authorize .antMatchers("/css/**", "/index").permitAll() .antMatchers("/user/**").hasRole("USER") ) - .formLogin(formLogin -> formLogin + .formLogin((formLogin) -> formLogin .loginPage("/login") .failureUrl("/login-error") ); diff --git a/samples/boot/helloworld/src/main/java/org/springframework/security/samples/web/MainController.java b/samples/boot/helloworld/src/main/java/org/springframework/security/samples/web/MainController.java index b4d9c9d9e1..793a8e97eb 100644 --- a/samples/boot/helloworld/src/main/java/org/springframework/security/samples/web/MainController.java +++ b/samples/boot/helloworld/src/main/java/org/springframework/security/samples/web/MainController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.web; import org.springframework.stereotype.Controller; diff --git a/samples/boot/insecure/src/integration-test/java/org/springframework/security/samples/InsecureApplicationTests.java b/samples/boot/insecure/src/integration-test/java/org/springframework/security/samples/InsecureApplicationTests.java index 55200f47b3..f2c7542e5b 100644 --- a/samples/boot/insecure/src/integration-test/java/org/springframework/security/samples/InsecureApplicationTests.java +++ b/samples/boot/insecure/src/integration-test/java/org/springframework/security/samples/InsecureApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.Test; diff --git a/samples/boot/insecure/src/main/java/org/springframework/security/samples/InsecureApplication.java b/samples/boot/insecure/src/main/java/org/springframework/security/samples/InsecureApplication.java index a5e15a119b..aab24fa9d3 100644 --- a/samples/boot/insecure/src/main/java/org/springframework/security/samples/InsecureApplication.java +++ b/samples/boot/insecure/src/main/java/org/springframework/security/samples/InsecureApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/insecure/src/main/java/org/springframework/security/samples/web/MainController.java b/samples/boot/insecure/src/main/java/org/springframework/security/samples/web/MainController.java index 06c85aa0d6..f7de44adcc 100644 --- a/samples/boot/insecure/src/main/java/org/springframework/security/samples/web/MainController.java +++ b/samples/boot/insecure/src/main/java/org/springframework/security/samples/web/MainController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.web; import org.springframework.stereotype.Controller; diff --git a/samples/boot/oauth2authorizationserver/src/main/java/sample/AuthorizationServerConfiguration.java b/samples/boot/oauth2authorizationserver/src/main/java/sample/AuthorizationServerConfiguration.java index 8a47b9b7a0..4b20b0c88c 100644 --- a/samples/boot/oauth2authorizationserver/src/main/java/sample/AuthorizationServerConfiguration.java +++ b/samples/boot/oauth2authorizationserver/src/main/java/sample/AuthorizationServerConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.math.BigInteger; @@ -166,7 +167,7 @@ class UserConfig extends WebSecurityConfigurerAdapter { .and() .httpBasic() .and() - .csrf().ignoringRequestMatchers(request -> "/introspect".equals(request.getRequestURI())); + .csrf().ignoringRequestMatchers((request) -> "/introspect".equals(request.getRequestURI())); } @Bean diff --git a/samples/boot/oauth2authorizationserver/src/main/java/sample/OAuth2AuthorizationServerApplication.java b/samples/boot/oauth2authorizationserver/src/main/java/sample/OAuth2AuthorizationServerApplication.java index f0ab3a4358..13602ec74e 100644 --- a/samples/boot/oauth2authorizationserver/src/main/java/sample/OAuth2AuthorizationServerApplication.java +++ b/samples/boot/oauth2authorizationserver/src/main/java/sample/OAuth2AuthorizationServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2authorizationserver/src/test/java/sample/OAuth2AuthorizationServerApplicationTests.java b/samples/boot/oauth2authorizationserver/src/test/java/sample/OAuth2AuthorizationServerApplicationTests.java index d9efa503bc..05377bf3f8 100644 --- a/samples/boot/oauth2authorizationserver/src/test/java/sample/OAuth2AuthorizationServerApplicationTests.java +++ b/samples/boot/oauth2authorizationserver/src/test/java/sample/OAuth2AuthorizationServerApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; diff --git a/samples/boot/oauth2login-webflux/src/integration-test/java/sample/OAuth2LoginApplicationTests.java b/samples/boot/oauth2login-webflux/src/integration-test/java/sample/OAuth2LoginApplicationTests.java index 77be0fa63b..c3e2b63af6 100644 --- a/samples/boot/oauth2login-webflux/src/integration-test/java/sample/OAuth2LoginApplicationTests.java +++ b/samples/boot/oauth2login-webflux/src/integration-test/java/sample/OAuth2LoginApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -45,7 +46,7 @@ public class OAuth2LoginApplicationTests { @Test public void requestWhenMockOidcLoginThenIndex() { this.clientRegistrationRepository.findByRegistrationId("github") - .map(clientRegistration -> + .map((clientRegistration) -> this.test.mutateWith(mockOAuth2Login().clientRegistration(clientRegistration)) .get().uri("/") .exchange() diff --git a/samples/boot/oauth2login-webflux/src/main/java/sample/ReactiveOAuth2LoginApplication.java b/samples/boot/oauth2login-webflux/src/main/java/sample/ReactiveOAuth2LoginApplication.java index 42fb0d9412..f6fd94e849 100644 --- a/samples/boot/oauth2login-webflux/src/main/java/sample/ReactiveOAuth2LoginApplication.java +++ b/samples/boot/oauth2login-webflux/src/main/java/sample/ReactiveOAuth2LoginApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2login-webflux/src/test/java/sample/OAuth2LoginControllerTests.java b/samples/boot/oauth2login-webflux/src/test/java/sample/OAuth2LoginControllerTests.java index d1fb3c5b3e..11ae5f8ba6 100644 --- a/samples/boot/oauth2login-webflux/src/test/java/sample/OAuth2LoginControllerTests.java +++ b/samples/boot/oauth2login-webflux/src/test/java/sample/OAuth2LoginControllerTests.java @@ -65,12 +65,12 @@ public class OAuth2LoginControllerTests { .bindToController(this.controller) .apply(springSecurity()) .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .argumentResolvers(c -> { + .argumentResolvers((c) -> { c.addCustomResolver(new AuthenticationPrincipalArgumentResolver(new ReactiveAdapterRegistry())); c.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver (this.clientRegistrationRepository, this.authorizedClientRepository)); }) - .viewResolvers(c -> c.viewResolver(this.viewResolver)) + .viewResolvers((c) -> c.viewResolver(this.viewResolver)) .build(); } diff --git a/samples/boot/oauth2login/src/integration-test/java/sample/OAuth2LoginApplicationTests.java b/samples/boot/oauth2login/src/integration-test/java/sample/OAuth2LoginApplicationTests.java index 077a51abf3..77867dbbc8 100644 --- a/samples/boot/oauth2login/src/integration-test/java/sample/OAuth2LoginApplicationTests.java +++ b/samples/boot/oauth2login/src/integration-test/java/sample/OAuth2LoginApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.net.URI; @@ -307,7 +308,7 @@ public class OAuth2LoginApplicationTests { private HtmlAnchor getClientAnchorElement(HtmlPage page, ClientRegistration clientRegistration) { Optional clientAnchorElement = page.getAnchors().stream() - .filter(e -> e.asText().equals(clientRegistration.getClientName())).findFirst(); + .filter((e) -> e.asText().equals(clientRegistration.getClientName())).findFirst(); return (clientAnchorElement.orElse(null)); } @@ -334,17 +335,17 @@ public class OAuth2LoginApplicationTests { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) - .oauth2Login(oauth2Login -> + .oauth2Login((oauth2Login) -> oauth2Login - .tokenEndpoint(tokenEndpoint -> + .tokenEndpoint((tokenEndpoint) -> tokenEndpoint .accessTokenResponseClient(this.mockAccessTokenResponseClient()) ) - .userInfoEndpoint(userInfoEndpoint -> + .userInfoEndpoint((userInfoEndpoint) -> userInfoEndpoint .userService(this.mockUserService()) ) diff --git a/samples/boot/oauth2login/src/main/java/sample/OAuth2LoginApplication.java b/samples/boot/oauth2login/src/main/java/sample/OAuth2LoginApplication.java index 6496b2d0a3..dc69b3f9ec 100644 --- a/samples/boot/oauth2login/src/main/java/sample/OAuth2LoginApplication.java +++ b/samples/boot/oauth2login/src/main/java/sample/OAuth2LoginApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java b/samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java index e70bd8744a..fa46a60ad7 100644 --- a/samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java +++ b/samples/boot/oauth2login/src/main/java/sample/web/OAuth2LoginController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.web; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2login/src/test/java/sample/web/OAuth2LoginControllerTests.java b/samples/boot/oauth2login/src/test/java/sample/web/OAuth2LoginControllerTests.java index 3779bb7ff8..6a3af42ff7 100644 --- a/samples/boot/oauth2login/src/test/java/sample/web/OAuth2LoginControllerTests.java +++ b/samples/boot/oauth2login/src/test/java/sample/web/OAuth2LoginControllerTests.java @@ -63,7 +63,7 @@ public class OAuth2LoginControllerTests { this.mvc.perform(get("/").with(oauth2Login() .clientRegistration(clientRegistration) - .attributes(a -> a.put("sub", "spring-security")))) + .attributes((a) -> a.put("sub", "spring-security")))) .andExpect(model().attribute("userName", "spring-security")) .andExpect(model().attribute("clientName", "my-client-name")) .andExpect(model().attribute("userAttributes", Collections.singletonMap("sub", "spring-security"))); diff --git a/samples/boot/oauth2resourceserver-jwe/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java b/samples/boot/oauth2resourceserver-jwe/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java index b10249a255..cd25346ce8 100644 --- a/samples/boot/oauth2resourceserver-jwe/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java +++ b/samples/boot/oauth2resourceserver-jwe/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -60,8 +61,6 @@ public class OAuth2ResourceServerApplicationITests { .andExpect(content().string(containsString("Hello, subject!"))); } - // -- tests with scopes - @Test public void performWhenValidBearerTokenThenScopedRequestsAlsoWork() throws Exception { diff --git a/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerApplication.java b/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerApplication.java index e854631aa2..a0841c00f3 100644 --- a/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerApplication.java +++ b/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerController.java b/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerController.java index 87b123b7f8..f0bcdbe64f 100644 --- a/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerController.java +++ b/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java b/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java index 827a9628f7..dfe4135d4b 100644 --- a/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java +++ b/samples/boot/oauth2resourceserver-jwe/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.net.URL; @@ -68,12 +69,12 @@ public class OAuth2ResourceServerSecurityConfiguration extends WebSecurityConfig protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers("/message/**").hasAuthority("SCOPE_message:read") .anyRequest().authenticated() ) - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer .jwt(withDefaults()) ); diff --git a/samples/boot/oauth2resourceserver-multitenancy/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java b/samples/boot/oauth2resourceserver-multitenancy/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java index c50157dd92..be6d3f559c 100644 --- a/samples/boot/oauth2resourceserver-multitenancy/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java +++ b/samples/boot/oauth2resourceserver-multitenancy/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -60,8 +61,6 @@ public class OAuth2ResourceServerApplicationITests { .andExpect(content().string(containsString("Hello, subject for tenant one!"))); } - // -- tests with scopes - @Test public void tenantOnePerformWhenValidBearerTokenThenScopedRequestsAlsoWork() throws Exception { @@ -96,8 +95,6 @@ public class OAuth2ResourceServerApplicationITests { .andExpect(content().string(containsString("Hello, subject for tenant two!"))); } - // -- tests with scopes - @Test public void tenantTwoPerformWhenValidBearerTokenThenScopedRequestsAlsoWork() throws Exception { diff --git a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java index f0663200de..9e3f58e628 100644 --- a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java +++ b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java @@ -156,10 +156,10 @@ public class MockWebServerPropertySource extends PropertySource i if ("/introspect".equals(request.getPath())) { return Optional.ofNullable(request.getHeader(HttpHeaders.AUTHORIZATION)) - .filter(authorization -> isAuthorized(authorization, "client", "secret")) - .map(authorization -> parseBody(request.getBody())) - .map(parameters -> parameters.get("token")) - .map(token -> { + .filter((authorization) -> isAuthorized(authorization, "client", "secret")) + .map((authorization) -> parseBody(request.getBody())) + .map((parameters) -> parameters.get("token")) + .map((token) -> { if ("00ed5855-1869-47a0-b0c9-0f3ce520aee7".equals(token)) { return NO_SCOPES_RESPONSE; } else if ("b43d1500-c405-4dc9-b9c9-6cfd966c34c9".equals(token)) { @@ -181,8 +181,8 @@ public class MockWebServerPropertySource extends PropertySource i private Map parseBody(Buffer body) { return Stream.of(body.readUtf8().split("&")) - .map(parameter -> parameter.split("=")) - .collect(Collectors.toMap(parts -> parts[0], parts -> parts[1])); + .map((parameter) -> parameter.split("=")) + .collect(Collectors.toMap((parts) -> parts[0], (parts) -> parts[1])); } private static MockResponse response(String body, int status) { diff --git a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerApplication.java b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerApplication.java index 06fc946601..d5c70cc70d 100644 --- a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerApplication.java +++ b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerController.java b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerController.java index 18165789a7..1dce6e718a 100644 --- a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerController.java +++ b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java index a52933b648..75bf3865b9 100644 --- a/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java +++ b/samples/boot/oauth2resourceserver-multitenancy/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import javax.servlet.http.HttpServletRequest; @@ -36,11 +37,11 @@ public class OAuth2ResourceServerSecurityConfiguration extends WebSecurityConfig protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authz -> authz + .authorizeRequests((authz) -> authz .antMatchers("/message/**").hasAuthority("SCOPE_message:read") .anyRequest().authenticated() ) - .oauth2ResourceServer(oauth2 -> oauth2 + .oauth2ResourceServer((oauth2) -> oauth2 .authenticationManagerResolver(this.authenticationManagerResolver) ); // @formatter:on diff --git a/samples/boot/oauth2resourceserver-opaque/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java b/samples/boot/oauth2resourceserver-opaque/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java index cbeeb8ebda..10bd3fc7ec 100644 --- a/samples/boot/oauth2resourceserver-opaque/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java +++ b/samples/boot/oauth2resourceserver-opaque/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -60,8 +61,6 @@ public class OAuth2ResourceServerApplicationITests { .andExpect(content().string(containsString("Hello, subject!"))); } - // -- tests with scopes - @Test public void performWhenValidBearerTokenThenScopedRequestsAlsoWork() throws Exception { diff --git a/samples/boot/oauth2resourceserver-opaque/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java b/samples/boot/oauth2resourceserver-opaque/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java index 32e1cdbd35..0bc6b57558 100644 --- a/samples/boot/oauth2resourceserver-opaque/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java +++ b/samples/boot/oauth2resourceserver-opaque/src/main/java/org/springframework/boot/env/MockWebServerPropertySource.java @@ -142,10 +142,10 @@ public class MockWebServerPropertySource extends PropertySource i private MockResponse doDispatch(RecordedRequest request) { if ("/introspect".equals(request.getPath())) { return Optional.ofNullable(request.getHeader(HttpHeaders.AUTHORIZATION)) - .filter(authorization -> isAuthorized(authorization, "client", "secret")) - .map(authorization -> parseBody(request.getBody())) - .map(parameters -> parameters.get("token")) - .map(token -> { + .filter((authorization) -> isAuthorized(authorization, "client", "secret")) + .map((authorization) -> parseBody(request.getBody())) + .map((parameters) -> parameters.get("token")) + .map((token) -> { if ("00ed5855-1869-47a0-b0c9-0f3ce520aee7".equals(token)) { return NO_SCOPES_RESPONSE; } else if ("b43d1500-c405-4dc9-b9c9-6cfd966c34c9".equals(token)) { @@ -167,8 +167,8 @@ public class MockWebServerPropertySource extends PropertySource i private Map parseBody(Buffer body) { return Stream.of(body.readUtf8().split("&")) - .map(parameter -> parameter.split("=")) - .collect(Collectors.toMap(parts -> parts[0], parts -> parts[1])); + .map((parameter) -> parameter.split("=")) + .collect(Collectors.toMap((parts) -> parts[0], (parts) -> parts[1])); } private static MockResponse response(String body, int status) { diff --git a/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerApplication.java b/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerApplication.java index 06fc946601..d5c70cc70d 100644 --- a/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerApplication.java +++ b/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerController.java b/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerController.java index f06db58ffd..857d29be66 100644 --- a/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerController.java +++ b/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java b/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java index 42d424fd7a..6fc3444c73 100644 --- a/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java +++ b/samples/boot/oauth2resourceserver-opaque/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.beans.factory.annotation.Value; @@ -35,15 +36,15 @@ public class OAuth2ResourceServerSecurityConfiguration extends WebSecurityConfig protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers(HttpMethod.GET, "/message/**").hasAuthority("SCOPE_message:read") .antMatchers(HttpMethod.POST, "/message/**").hasAuthority("SCOPE_message:write") .anyRequest().authenticated() ) - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer - .opaqueToken(opaqueToken -> + .opaqueToken((opaqueToken) -> opaqueToken .introspectionUri(this.introspectionUri) .introspectionClientCredentials(this.clientId, this.clientSecret) diff --git a/samples/boot/oauth2resourceserver-opaque/src/test/java/sample/OAuth2ResourceServerControllerTests.java b/samples/boot/oauth2resourceserver-opaque/src/test/java/sample/OAuth2ResourceServerControllerTests.java index 48839acd3d..ee3318f246 100644 --- a/samples/boot/oauth2resourceserver-opaque/src/test/java/sample/OAuth2ResourceServerControllerTests.java +++ b/samples/boot/oauth2resourceserver-opaque/src/test/java/sample/OAuth2ResourceServerControllerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -45,13 +46,13 @@ public class OAuth2ResourceServerControllerTests { @Test public void indexGreetsAuthenticatedUser() throws Exception { - this.mvc.perform(get("/").with(opaqueToken().attributes(a -> a.put("sub", "ch4mpy")))) + this.mvc.perform(get("/").with(opaqueToken().attributes((a) -> a.put("sub", "ch4mpy")))) .andExpect(content().string(is("Hello, ch4mpy!"))); } @Test public void messageCanBeReadWithScopeMessageReadAuthority() throws Exception { - this.mvc.perform(get("/message").with(opaqueToken().attributes(a -> a.put("scope", "message:read")))) + this.mvc.perform(get("/message").with(opaqueToken().attributes((a) -> a.put("scope", "message:read")))) .andExpect(content().string(is("secret message"))); this.mvc.perform(get("/message") diff --git a/samples/boot/oauth2resourceserver-static/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java b/samples/boot/oauth2resourceserver-static/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java index a5d644662a..851d3bd242 100644 --- a/samples/boot/oauth2resourceserver-static/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java +++ b/samples/boot/oauth2resourceserver-static/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -60,8 +61,6 @@ public class OAuth2ResourceServerApplicationITests { .andExpect(content().string(containsString("Hello, subject!"))); } - // -- tests with scopes - @Test public void performWhenValidBearerTokenThenScopedRequestsAlsoWork() throws Exception { diff --git a/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerApplication.java b/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerApplication.java index e854631aa2..a0841c00f3 100644 --- a/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerApplication.java +++ b/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerController.java b/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerController.java index 87b123b7f8..f0bcdbe64f 100644 --- a/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerController.java +++ b/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java b/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java index bdc248d5aa..a57bed2d51 100644 --- a/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java +++ b/samples/boot/oauth2resourceserver-static/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import java.security.interfaces.RSAPublicKey; @@ -38,14 +39,14 @@ public class OAuth2ResourceServerSecurityConfiguration extends WebSecurityConfig protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers("/message/**").hasAuthority("SCOPE_message:read") .anyRequest().authenticated() ) - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer - .jwt(jwt -> + .jwt((jwt) -> jwt.decoder(jwtDecoder()) ) ); diff --git a/samples/boot/oauth2resourceserver-webflux/src/integration-test/java/sample/ServerOAuth2ResourceServerApplicationITests.java b/samples/boot/oauth2resourceserver-webflux/src/integration-test/java/sample/ServerOAuth2ResourceServerApplicationITests.java index 13d09b16c1..9cbdcf2b36 100644 --- a/samples/boot/oauth2resourceserver-webflux/src/integration-test/java/sample/ServerOAuth2ResourceServerApplicationITests.java +++ b/samples/boot/oauth2resourceserver-webflux/src/integration-test/java/sample/ServerOAuth2ResourceServerApplicationITests.java @@ -38,8 +38,8 @@ import static org.hamcrest.Matchers.containsString; @RunWith(SpringJUnit4ClassRunner.class) public class ServerOAuth2ResourceServerApplicationITests { - Consumer noScopesToken = http -> http.setBearerAuth("eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJzdWJqZWN0IiwiZXhwIjo0NjgzODA1MTI4fQ.ULEPdHG-MK5GlrTQMhgqcyug2brTIZaJIrahUeq9zaiwUSdW83fJ7W1IDd2Z3n4a25JY2uhEcoV95lMfccHR6y_2DLrNvfta22SumY9PEDF2pido54LXG6edIGgarnUbJdR4rpRe_5oRGVa8gDx8FnuZsNv6StSZHAzw5OsuevSTJ1UbJm4UfX3wiahFOQ2OI6G-r5TB2rQNdiPHuNyzG5yznUqRIZ7-GCoMqHMaC-1epKxiX8gYXRROuUYTtcMNa86wh7OVDmvwVmFioRcR58UWBRoO1XQexTtOQq_t8KYsrPZhb9gkyW8x2bAQF-d0J0EJY8JslaH6n4RBaZISww"); - Consumer messageReadToken = http -> http.setBearerAuth("eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJzdWJqZWN0Iiwic2NvcGUiOiJtZXNzYWdlOnJlYWQiLCJleHAiOjQ2ODM4MDUxNDF9.h-j6FKRFdnTdmAueTZCdep45e6DPwqM68ZQ8doIJ1exi9YxAlbWzOwId6Bd0L5YmCmp63gGQgsBUBLzwnZQ8kLUgUOBEC3UzSWGRqMskCY9_k9pX0iomX6IfF3N0PaYs0WPC4hO1s8wfZQ-6hKQ4KigFi13G9LMLdH58PRMK0pKEvs3gCbHJuEPw-K5ORlpdnleUTQIwINafU57cmK3KocTeknPAM_L716sCuSYGvDl6xUTXO7oPdrXhS_EhxLP6KxrpI1uD4Ea_5OWTh7S0Wx5LLDfU6wBG1DowN20d374zepOIEkR-Jnmr_QlR44vmRqS5ncrF-1R0EGcPX49U6A"); + Consumer noScopesToken = (http) -> http.setBearerAuth("eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJzdWJqZWN0IiwiZXhwIjo0NjgzODA1MTI4fQ.ULEPdHG-MK5GlrTQMhgqcyug2brTIZaJIrahUeq9zaiwUSdW83fJ7W1IDd2Z3n4a25JY2uhEcoV95lMfccHR6y_2DLrNvfta22SumY9PEDF2pido54LXG6edIGgarnUbJdR4rpRe_5oRGVa8gDx8FnuZsNv6StSZHAzw5OsuevSTJ1UbJm4UfX3wiahFOQ2OI6G-r5TB2rQNdiPHuNyzG5yznUqRIZ7-GCoMqHMaC-1epKxiX8gYXRROuUYTtcMNa86wh7OVDmvwVmFioRcR58UWBRoO1XQexTtOQq_t8KYsrPZhb9gkyW8x2bAQF-d0J0EJY8JslaH6n4RBaZISww"); + Consumer messageReadToken = (http) -> http.setBearerAuth("eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJzdWJqZWN0Iiwic2NvcGUiOiJtZXNzYWdlOnJlYWQiLCJleHAiOjQ2ODM4MDUxNDF9.h-j6FKRFdnTdmAueTZCdep45e6DPwqM68ZQ8doIJ1exi9YxAlbWzOwId6Bd0L5YmCmp63gGQgsBUBLzwnZQ8kLUgUOBEC3UzSWGRqMskCY9_k9pX0iomX6IfF3N0PaYs0WPC4hO1s8wfZQ-6hKQ4KigFi13G9LMLdH58PRMK0pKEvs3gCbHJuEPw-K5ORlpdnleUTQIwINafU57cmK3KocTeknPAM_L716sCuSYGvDl6xUTXO7oPdrXhS_EhxLP6KxrpI1uD4Ea_5OWTh7S0Wx5LLDfU6wBG1DowN20d374zepOIEkR-Jnmr_QlR44vmRqS5ncrF-1R0EGcPX49U6A"); @Autowired private WebTestClient rest; @@ -55,8 +55,6 @@ public class ServerOAuth2ResourceServerApplicationITests { .expectBody(String.class).isEqualTo("Hello, subject!"); } - // -- tests with scopes - @Test public void getWhenValidBearerTokenThenScopedRequestsAlsoWork() { diff --git a/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/OAuth2ResourceServerController.java b/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/OAuth2ResourceServerController.java index 4e0fed8cdd..ea34b99e20 100644 --- a/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/OAuth2ResourceServerController.java +++ b/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/OAuth2ResourceServerController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/SecurityConfig.java b/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/SecurityConfig.java index fad9309c02..d4f6429b51 100644 --- a/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/SecurityConfig.java +++ b/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/SecurityConfig.java @@ -34,13 +34,13 @@ public class SecurityConfig { @Bean SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { http - .authorizeExchange(exchanges -> + .authorizeExchange((exchanges) -> exchanges .pathMatchers(HttpMethod.GET, "/message/**").hasAuthority("SCOPE_message:read") .pathMatchers(HttpMethod.POST, "/message/**").hasAuthority("SCOPE_message:write") .anyExchange().authenticated() ) - .oauth2ResourceServer(oauth2ResourceServer -> + .oauth2ResourceServer((oauth2ResourceServer) -> oauth2ResourceServer .jwt(withDefaults()) ); diff --git a/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/ServerOAuth2ResourceServerApplication.java b/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/ServerOAuth2ResourceServerApplication.java index 4c63ed00ab..c3e3575e26 100644 --- a/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/ServerOAuth2ResourceServerApplication.java +++ b/samples/boot/oauth2resourceserver-webflux/src/main/java/sample/ServerOAuth2ResourceServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2resourceserver-webflux/src/test/java/sample/OAuth2ResourceServerControllerTests.java b/samples/boot/oauth2resourceserver-webflux/src/test/java/sample/OAuth2ResourceServerControllerTests.java index 3f1291c6fa..8fc72574f1 100644 --- a/samples/boot/oauth2resourceserver-webflux/src/test/java/sample/OAuth2ResourceServerControllerTests.java +++ b/samples/boot/oauth2resourceserver-webflux/src/test/java/sample/OAuth2ResourceServerControllerTests.java @@ -50,14 +50,14 @@ public class OAuth2ResourceServerControllerTests { @Test public void indexGreetsAuthenticatedUser() { - this.rest.mutateWith(mockJwt().jwt(jwt -> jwt.subject("test-subject"))) + this.rest.mutateWith(mockJwt().jwt((jwt) -> jwt.subject("test-subject"))) .get().uri("/").exchange() .expectBody(String.class).isEqualTo("Hello, test-subject!"); } @Test public void messageCanBeReadWithScopeMessageReadAuthority() { - this.rest.mutateWith(mockJwt().jwt(jwt -> jwt.claim("scope", "message:read"))) + this.rest.mutateWith(mockJwt().jwt((jwt) -> jwt.claim("scope", "message:read"))) .get().uri("/message").exchange() .expectBody(String.class).isEqualTo("secret message"); @@ -78,7 +78,7 @@ public class OAuth2ResourceServerControllerTests { Jwt jwt = jwt().claim("scope", "").build(); when(this.jwtDecoder.decode(anyString())).thenReturn(Mono.just(jwt)); this.rest.post().uri("/message") - .headers(headers -> headers.setBearerAuth(jwt.getTokenValue())) + .headers((headers) -> headers.setBearerAuth(jwt.getTokenValue())) .syncBody("Hello message").exchange() .expectStatus().isForbidden(); } @@ -88,7 +88,7 @@ public class OAuth2ResourceServerControllerTests { Jwt jwt = jwt().claim("scope", "message:read").build(); when(this.jwtDecoder.decode(anyString())).thenReturn(Mono.just(jwt)); this.rest.post().uri("/message") - .headers(headers -> headers.setBearerAuth(jwt.getTokenValue())) + .headers((headers) -> headers.setBearerAuth(jwt.getTokenValue())) .syncBody("Hello message").exchange() .expectStatus().isForbidden(); } @@ -98,7 +98,7 @@ public class OAuth2ResourceServerControllerTests { Jwt jwt = jwt().claim("scope", "message:write").build(); when(this.jwtDecoder.decode(anyString())).thenReturn(Mono.just(jwt)); this.rest.post().uri("/message") - .headers(headers -> headers.setBearerAuth(jwt.getTokenValue())) + .headers((headers) -> headers.setBearerAuth(jwt.getTokenValue())) .syncBody("Hello message").exchange() .expectStatus().isOk() .expectBody(String.class).isEqualTo("Message was created. Content: Hello message"); diff --git a/samples/boot/oauth2resourceserver/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java b/samples/boot/oauth2resourceserver/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java index 6479417499..bb7de58665 100644 --- a/samples/boot/oauth2resourceserver/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java +++ b/samples/boot/oauth2resourceserver/src/integration-test/java/sample/OAuth2ResourceServerApplicationITests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -62,8 +63,6 @@ public class OAuth2ResourceServerApplicationITests { .andExpect(content().string(containsString("Hello, subject!"))); } - // -- tests with scopes - @Test public void performWhenValidBearerTokenThenScopedRequestsAlsoWork() throws Exception { diff --git a/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerApplication.java b/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerApplication.java index e854631aa2..a0841c00f3 100644 --- a/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerApplication.java +++ b/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerController.java b/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerController.java index 241f054885..c761078dd8 100644 --- a/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerController.java +++ b/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java b/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java index 819e0bc5e3..d7e157cfc6 100644 --- a/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java +++ b/samples/boot/oauth2resourceserver/src/main/java/sample/OAuth2ResourceServerSecurityConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.beans.factory.annotation.Value; @@ -37,7 +38,7 @@ public class OAuth2ResourceServerSecurityConfiguration extends WebSecurityConfig protected void configure(HttpSecurity http) throws Exception { // @formatter:off http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers(HttpMethod.GET, "/message/**").hasAuthority("SCOPE_message:read") .antMatchers(HttpMethod.POST, "/message/**").hasAuthority("SCOPE_message:write") diff --git a/samples/boot/oauth2resourceserver/src/test/java/sample/OAuth2ResourceServerControllerTests.java b/samples/boot/oauth2resourceserver/src/test/java/sample/OAuth2ResourceServerControllerTests.java index 46ec794827..565212b073 100644 --- a/samples/boot/oauth2resourceserver/src/test/java/sample/OAuth2ResourceServerControllerTests.java +++ b/samples/boot/oauth2resourceserver/src/test/java/sample/OAuth2ResourceServerControllerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; @@ -47,13 +48,13 @@ public class OAuth2ResourceServerControllerTests { @Test public void indexGreetsAuthenticatedUser() throws Exception { - mockMvc.perform(get("/").with(jwt().jwt(jwt -> jwt.subject("ch4mpy")))) + mockMvc.perform(get("/").with(jwt().jwt((jwt) -> jwt.subject("ch4mpy")))) .andExpect(content().string(is("Hello, ch4mpy!"))); } @Test public void messageCanBeReadWithScopeMessageReadAuthority() throws Exception { - mockMvc.perform(get("/message").with(jwt().jwt(jwt -> jwt.claim("scope", "message:read")))) + mockMvc.perform(get("/message").with(jwt().jwt((jwt) -> jwt.claim("scope", "message:read")))) .andExpect(content().string(is("secret message"))); mockMvc.perform(get("/message") @@ -79,7 +80,7 @@ public class OAuth2ResourceServerControllerTests { public void messageCanNotBeCreatedWithScopeMessageReadAuthority() throws Exception { mockMvc.perform(post("/message") .content("Hello message") - .with(jwt().jwt(jwt -> jwt.claim("scope", "message:read")))) + .with(jwt().jwt((jwt) -> jwt.claim("scope", "message:read")))) .andExpect(status().isForbidden()); } @@ -88,7 +89,7 @@ public class OAuth2ResourceServerControllerTests { throws Exception { mockMvc.perform(post("/message") .content("Hello message") - .with(jwt().jwt(jwt -> jwt.claim("scope", "message:write")))) + .with(jwt().jwt((jwt) -> jwt.claim("scope", "message:write")))) .andExpect(status().isOk()) .andExpect(content().string(is("Message was created. Content: Hello message"))); } diff --git a/samples/boot/oauth2webclient-webflux/src/main/java/sample/OAuth2WebClientWebFluxApplication.java b/samples/boot/oauth2webclient-webflux/src/main/java/sample/OAuth2WebClientWebFluxApplication.java index 1882440595..cdd6416f93 100644 --- a/samples/boot/oauth2webclient-webflux/src/main/java/sample/OAuth2WebClientWebFluxApplication.java +++ b/samples/boot/oauth2webclient-webflux/src/main/java/sample/OAuth2WebClientWebFluxApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/SecurityConfig.java b/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/SecurityConfig.java index bea4effe96..33fb8b7867 100644 --- a/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/SecurityConfig.java +++ b/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.config; import org.springframework.context.annotation.Bean; @@ -34,7 +35,7 @@ public class SecurityConfig { @Bean SecurityWebFilterChain configure(ServerHttpSecurity http) { http - .authorizeExchange(exchanges -> + .authorizeExchange((exchanges) -> exchanges .pathMatchers("/", "/public/**").permitAll() .anyExchange().authenticated() diff --git a/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/OAuth2WebClientController.java b/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/OAuth2WebClientController.java index f7daadb18b..a1ddae90aa 100644 --- a/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/OAuth2WebClientController.java +++ b/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/OAuth2WebClientController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.web; import reactor.core.publisher.Mono; diff --git a/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java b/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java index 89961e84f2..558277394f 100644 --- a/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java +++ b/samples/boot/oauth2webclient-webflux/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.web; import reactor.core.publisher.Mono; diff --git a/samples/boot/oauth2webclient/src/main/java/sample/OAuth2WebClientApplication.java b/samples/boot/oauth2webclient/src/main/java/sample/OAuth2WebClientApplication.java index 89cf9a5fcd..4c4a6b7de1 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/OAuth2WebClientApplication.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/OAuth2WebClientApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/oauth2webclient/src/main/java/sample/config/SecurityConfig.java b/samples/boot/oauth2webclient/src/main/java/sample/config/SecurityConfig.java index 80cae2f658..e5602d812d 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/config/SecurityConfig.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.config; import org.springframework.context.annotation.Bean; @@ -35,7 +36,7 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .mvcMatchers("/", "/public/**").permitAll() .anyRequest().authenticated() diff --git a/samples/boot/oauth2webclient/src/main/java/sample/web/OAuth2WebClientController.java b/samples/boot/oauth2webclient/src/main/java/sample/web/OAuth2WebClientController.java index 226b5e9bcd..f20b479d1f 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/web/OAuth2WebClientController.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/web/OAuth2WebClientController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.web; import org.springframework.stereotype.Controller; diff --git a/samples/boot/oauth2webclient/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java b/samples/boot/oauth2webclient/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java index b0b8200f30..883db445d0 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/web/RegisteredOAuth2AuthorizedClientController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.web; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; diff --git a/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java b/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java index 176826fea4..3b6115f459 100644 --- a/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java +++ b/samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java @@ -483,7 +483,7 @@ public class Saml2LoginIntegrationTests { String code, Matcher message ) { - return result -> { + return (result) -> { final HttpSession session = result.getRequest().getSession(false); AssertionErrors.assertNotNull("HttpSession", session); Object exception = session.getAttribute(AUTHENTICATION_EXCEPTION); diff --git a/samples/boot/saml2login/src/main/java/sample/Saml2LoginApplication.java b/samples/boot/saml2login/src/main/java/sample/Saml2LoginApplication.java index 7162c406ab..08202d28de 100644 --- a/samples/boot/saml2login/src/main/java/sample/Saml2LoginApplication.java +++ b/samples/boot/saml2login/src/main/java/sample/Saml2LoginApplication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.springframework.boot.SpringApplication; diff --git a/samples/boot/webflux-form/src/integration-test/java/sample/WebfluxFormApplicationTests.java b/samples/boot/webflux-form/src/integration-test/java/sample/WebfluxFormApplicationTests.java index 41cda60a5c..c9062faa59 100644 --- a/samples/boot/webflux-form/src/integration-test/java/sample/WebfluxFormApplicationTests.java +++ b/samples/boot/webflux-form/src/integration-test/java/sample/WebfluxFormApplicationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Before; diff --git a/samples/boot/webflux-form/src/main/java/sample/WebfluxFormSecurityConfig.java b/samples/boot/webflux-form/src/main/java/sample/WebfluxFormSecurityConfig.java index c87bc0a4af..94642742dc 100644 --- a/samples/boot/webflux-form/src/main/java/sample/WebfluxFormSecurityConfig.java +++ b/samples/boot/webflux-form/src/main/java/sample/WebfluxFormSecurityConfig.java @@ -46,13 +46,13 @@ public class WebfluxFormSecurityConfig { @Bean SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { http - .authorizeExchange(exchanges -> + .authorizeExchange((exchanges) -> exchanges .pathMatchers("/login").permitAll() .anyExchange().authenticated() ) .httpBasic(withDefaults()) - .formLogin(formLogin -> + .formLogin((formLogin) -> formLogin .loginPage("/login") ); diff --git a/samples/boot/webflux-x509/src/main/java/sample/MeController.java b/samples/boot/webflux-x509/src/main/java/sample/MeController.java index f7a3958784..7c10c6aedf 100644 --- a/samples/boot/webflux-x509/src/main/java/sample/MeController.java +++ b/samples/boot/webflux-x509/src/main/java/sample/MeController.java @@ -35,6 +35,6 @@ public class MeController { public Mono me() { return ReactiveSecurityContextHolder.getContext() .map(SecurityContext::getAuthentication) - .map(authentication -> "Hello, " + authentication.getName()); + .map((authentication) -> "Hello, " + authentication.getName()); } } diff --git a/samples/boot/webflux-x509/src/main/java/sample/WebfluxX509Application.java b/samples/boot/webflux-x509/src/main/java/sample/WebfluxX509Application.java index 02e145c14b..89813d073b 100644 --- a/samples/boot/webflux-x509/src/main/java/sample/WebfluxX509Application.java +++ b/samples/boot/webflux-x509/src/main/java/sample/WebfluxX509Application.java @@ -46,7 +46,7 @@ public class WebfluxX509Application { // @formatter:off http .x509(withDefaults()) - .authorizeExchange(exchanges -> + .authorizeExchange((exchanges) -> exchanges .anyExchange().authenticated() ); diff --git a/samples/boot/webflux-x509/src/test/java/sample/WebfluxX509ApplicationTest.java b/samples/boot/webflux-x509/src/test/java/sample/WebfluxX509ApplicationTest.java index 604692d5ad..e6ed9cf200 100644 --- a/samples/boot/webflux-x509/src/test/java/sample/WebfluxX509ApplicationTest.java +++ b/samples/boot/webflux-x509/src/test/java/sample/WebfluxX509ApplicationTest.java @@ -55,7 +55,7 @@ public class WebfluxX509ApplicationTest { .exchange() .expectStatus().isOk() .expectBody() - .consumeWith(result -> { + .consumeWith((result) -> { String responseBody = new String(result.getResponseBody()); assertThat(responseBody).contains("Hello, client"); }); @@ -79,7 +79,7 @@ public class WebfluxX509ApplicationTest { .trustManager(devCA) .keyManager(clientKey, clientCrt); - HttpClient httpClient = HttpClient.create().secure(sslContextSpec -> sslContextSpec.sslContext(sslContextBuilder)); + HttpClient httpClient = HttpClient.create().secure((sslContextSpec) -> sslContextSpec.sslContext(sslContextBuilder)); ClientHttpConnector httpConnector = new ReactorClientHttpConnector(httpClient); return WebTestClient diff --git a/samples/javaconfig/aspectj/src/main/java/sample/aspectj/AspectjSecurityConfig.java b/samples/javaconfig/aspectj/src/main/java/sample/aspectj/AspectjSecurityConfig.java index 0bfe89b7f7..038b8271f3 100644 --- a/samples/javaconfig/aspectj/src/main/java/sample/aspectj/AspectjSecurityConfig.java +++ b/samples/javaconfig/aspectj/src/main/java/sample/aspectj/AspectjSecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.springframework.beans.factory.annotation.Autowired; diff --git a/samples/javaconfig/aspectj/src/main/java/sample/aspectj/SecuredService.java b/samples/javaconfig/aspectj/src/main/java/sample/aspectj/SecuredService.java index bf954915b8..9f81cf1733 100644 --- a/samples/javaconfig/aspectj/src/main/java/sample/aspectj/SecuredService.java +++ b/samples/javaconfig/aspectj/src/main/java/sample/aspectj/SecuredService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.springframework.security.access.annotation.Secured; diff --git a/samples/javaconfig/aspectj/src/main/java/sample/aspectj/Service.java b/samples/javaconfig/aspectj/src/main/java/sample/aspectj/Service.java index 2c5a433018..70b8c0c71f 100644 --- a/samples/javaconfig/aspectj/src/main/java/sample/aspectj/Service.java +++ b/samples/javaconfig/aspectj/src/main/java/sample/aspectj/Service.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.springframework.security.access.annotation.Secured; diff --git a/samples/javaconfig/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java b/samples/javaconfig/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java index 1c7cf918bf..344de63e11 100644 --- a/samples/javaconfig/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java +++ b/samples/javaconfig/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.junit.After; diff --git a/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 7de3308deb..c33dc58cf2 100644 --- a/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index d71dd73fc7..45773f3eb7 100644 --- a/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/concurrency/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; @@ -42,14 +43,14 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { protected void configure( HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) .formLogin(withDefaults()) - .sessionManagement(sessionManagement -> + .sessionManagement((sessionManagement) -> sessionManagement - .sessionConcurrency(sessionConcurrency -> + .sessionConcurrency((sessionConcurrency) -> sessionConcurrency .maximumSessions(1) .expiredUrl("/login?expired") diff --git a/samples/javaconfig/concurrency/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/concurrency/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index e77a3ef7c1..dd807f16d1 100644 --- a/samples/javaconfig/concurrency/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/concurrency/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/javaconfig/data/src/main/java/samples/DataConfig.java b/samples/javaconfig/data/src/main/java/samples/DataConfig.java index 063d637110..cd401a5df4 100644 --- a/samples/javaconfig/data/src/main/java/samples/DataConfig.java +++ b/samples/javaconfig/data/src/main/java/samples/DataConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples; import org.springframework.context.annotation.Bean; diff --git a/samples/javaconfig/data/src/main/java/samples/data/Message.java b/samples/javaconfig/data/src/main/java/samples/data/Message.java index 94e04ce520..7c567e9c12 100644 --- a/samples/javaconfig/data/src/main/java/samples/data/Message.java +++ b/samples/javaconfig/data/src/main/java/samples/data/Message.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.data; import java.util.Calendar; diff --git a/samples/javaconfig/data/src/main/java/samples/data/MessageRepository.java b/samples/javaconfig/data/src/main/java/samples/data/MessageRepository.java index f356f3065a..c98a1d3fde 100644 --- a/samples/javaconfig/data/src/main/java/samples/data/MessageRepository.java +++ b/samples/javaconfig/data/src/main/java/samples/data/MessageRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.data; import org.springframework.data.jpa.repository.JpaRepository; diff --git a/samples/javaconfig/data/src/main/java/samples/data/SecurityMessageRepository.java b/samples/javaconfig/data/src/main/java/samples/data/SecurityMessageRepository.java index e9a2f8eee1..f6b2ac956d 100644 --- a/samples/javaconfig/data/src/main/java/samples/data/SecurityMessageRepository.java +++ b/samples/javaconfig/data/src/main/java/samples/data/SecurityMessageRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.data; import org.springframework.data.jpa.repository.Query; diff --git a/samples/javaconfig/data/src/main/java/samples/data/User.java b/samples/javaconfig/data/src/main/java/samples/data/User.java index 55190206ba..15d4f37d77 100644 --- a/samples/javaconfig/data/src/main/java/samples/data/User.java +++ b/samples/javaconfig/data/src/main/java/samples/data/User.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.data; import javax.persistence.Entity; diff --git a/samples/javaconfig/data/src/test/java/samples/data/SecurityMessageRepositoryTests.java b/samples/javaconfig/data/src/test/java/samples/data/SecurityMessageRepositoryTests.java index 3ca8aa3203..9972659fd2 100644 --- a/samples/javaconfig/data/src/test/java/samples/data/SecurityMessageRepositoryTests.java +++ b/samples/javaconfig/data/src/test/java/samples/data/SecurityMessageRepositoryTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.data; import org.junit.After; diff --git a/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/FormJcTests.java b/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/FormJcTests.java index 8a56b13168..1ca1e669a1 100644 --- a/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/FormJcTests.java +++ b/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/FormJcTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index 8becdc4b33..22ddd1c639 100644 --- a/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index e4956e3eed..af6fbaa988 100644 --- a/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/javaconfig/form/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 5b002bade0..8f9cbba6f6 100644 --- a/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/form/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; @@ -29,17 +30,17 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers("/resources/**").permitAll() .anyRequest().authenticated() ) - .formLogin(formLogin -> + .formLogin((formLogin) -> formLogin .loginPage("/login") .permitAll() ) - .logout(logout -> + .logout((logout) -> logout .permitAll() ); diff --git a/samples/javaconfig/form/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/form/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index e77a3ef7c1..dd807f16d1 100644 --- a/samples/javaconfig/form/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/form/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 4182b2b95c..595d2413e0 100644 --- a/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; diff --git a/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/mvc/MessageJsonController.java b/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/mvc/MessageJsonController.java index 65bc3d1f12..696d3c75e7 100644 --- a/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/mvc/MessageJsonController.java +++ b/samples/javaconfig/hellojs/src/main/java/org/springframework/security/samples/mvc/MessageJsonController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.mvc; import java.util.ArrayList; diff --git a/samples/javaconfig/hellojs/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/hellojs/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index e77a3ef7c1..dd807f16d1 100644 --- a/samples/javaconfig/hellojs/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/hellojs/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index edce6c794f..45f8ae8d78 100644 --- a/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.core.annotation.Order; diff --git a/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 2ad15cf836..c0719e88f8 100644 --- a/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/hellomvc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; diff --git a/samples/javaconfig/hellomvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/hellomvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index 140318cd59..f7117bf46b 100644 --- a/samples/javaconfig/hellomvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/hellomvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.*; diff --git a/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldJcTests.java b/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldJcTests.java index 85e7fcad64..d2f43eb8bd 100644 --- a/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldJcTests.java +++ b/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldJcTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index fae8d98220..ed527c3957 100644 --- a/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index 2532ba2406..30822ea2a9 100644 --- a/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/javaconfig/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index c3bdfe2417..0d80bfac16 100644 --- a/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.context.annotation.Bean; diff --git a/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityWebApplicationInitializer.java b/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityWebApplicationInitializer.java index 0605ea4647..008b87b870 100644 --- a/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityWebApplicationInitializer.java +++ b/samples/javaconfig/helloworld/src/main/java/org/springframework/security/samples/config/SecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index a657c60795..ad65d6ec9b 100644 --- a/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/inmemory/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.context.annotation.Bean; diff --git a/samples/javaconfig/inmemory/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/inmemory/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index 1352614e87..507cc875da 100644 --- a/samples/javaconfig/inmemory/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/inmemory/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.*; diff --git a/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/JdbcJcTests.java b/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/JdbcJcTests.java index 985d2c6f63..dad6a0052c 100644 --- a/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/JdbcJcTests.java +++ b/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/JdbcJcTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index d3bbbdb25b..71033bf941 100644 --- a/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index 2532ba2406..30822ea2a9 100644 --- a/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/javaconfig/jdbc/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 7715efc3e9..e3dbb1cf51 100644 --- a/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/jdbc/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import javax.sql.DataSource; diff --git a/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/LdapJcTests.java b/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/LdapJcTests.java index 7e8f7f6436..99df0ac959 100644 --- a/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/LdapJcTests.java +++ b/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/LdapJcTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index 8becdc4b33..22ddd1c639 100644 --- a/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index 2532ba2406..30822ea2a9 100644 --- a/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/javaconfig/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 0d07a12d62..116219d418 100644 --- a/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/ldap/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/DataConfiguration.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/DataConfiguration.java index 8cf7d9695d..9e174051d8 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/DataConfiguration.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/DataConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import javax.sql.DataSource; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/MessageWebApplicationInitializer.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/MessageWebApplicationInitializer.java index bf46ff3364..428042060f 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/MessageWebApplicationInitializer.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/MessageWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import javax.servlet.Filter; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/RootConfiguration.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/RootConfiguration.java index 56b3ed6828..ce3d173848 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/RootConfiguration.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/config/RootConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.context.annotation.ComponentScan; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/Message.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/Message.java index 6e180460ae..23e038f702 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/Message.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/Message.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.data; import java.util.Calendar; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/MessageRepository.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/MessageRepository.java index 9ed5d1ec49..22a417facf 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/MessageRepository.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/data/MessageRepository.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.data; import org.springframework.data.repository.CrudRepository; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/DefaultController.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/DefaultController.java index fff3fc600f..62b42b489f 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/DefaultController.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/DefaultController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.mvc; import org.springframework.stereotype.Controller; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/MessageController.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/MessageController.java index 85aa9bda1b..8c909eb0a6 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/MessageController.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/MessageController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.mvc; import javax.validation.Valid; diff --git a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/config/WebMvcConfiguration.java b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/config/WebMvcConfiguration.java index 73fed516fb..dfbef33d6c 100644 --- a/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/config/WebMvcConfiguration.java +++ b/samples/javaconfig/messages/src/main/java/org/springframework/security/samples/mvc/config/WebMvcConfiguration.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.mvc.config; import org.springframework.beans.factory.annotation.Autowired; diff --git a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 3a4ec0ad52..c194eba095 100644 --- a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 452a80bdd0..9645aac23b 100644 --- a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -31,64 +32,64 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers("/resources/**").permitAll() .anyRequest().authenticated() ) - .openidLogin(openidLogin -> + .openidLogin((openidLogin) -> openidLogin .loginPage("/login") .permitAll() .authenticationUserDetailsService(new CustomUserDetailsService()) - .attributeExchange(googleExchange -> + .attributeExchange((googleExchange) -> googleExchange .identifierPattern("https://www.google.com/.*") - .attribute(emailAttribute -> + .attribute((emailAttribute) -> emailAttribute .name("email") .type("https://axschema.org/contact/email") .required(true) ) - .attribute(firstnameAttribute -> + .attribute((firstnameAttribute) -> firstnameAttribute .name("firstname") .type("https://axschema.org/namePerson/first") .required(true) ) - .attribute(lastnameAttribute -> + .attribute((lastnameAttribute) -> lastnameAttribute .name("lastname") .type("https://axschema.org/namePerson/last") .required(true) ) ) - .attributeExchange(yahooExchange -> + .attributeExchange((yahooExchange) -> yahooExchange .identifierPattern(".*yahoo.com.*") - .attribute(emailAttribute -> + .attribute((emailAttribute) -> emailAttribute .name("email") .type("https://axschema.org/contact/email") .required(true) ) - .attribute(fullnameAttribute -> + .attribute((fullnameAttribute) -> fullnameAttribute .name("fullname") .type("https://axschema.org/namePerson") .required(true) ) ) - .attributeExchange(myopenidExchange -> + .attributeExchange((myopenidExchange) -> myopenidExchange .identifierPattern(".*myopenid.com.*") - .attribute(emailAttribute -> + .attribute((emailAttribute) -> emailAttribute .name("email") .type("https://schema.openid.net/contact/email") .required(true) ) - .attribute(fullnameAttribute -> + .attribute((fullnameAttribute) -> fullnameAttribute .name("fullname") .type("https://schema.openid.net/namePerson") diff --git a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/mvc/UserController.java b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/mvc/UserController.java index d700a3e832..7b20b9d788 100644 --- a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/mvc/UserController.java +++ b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/mvc/UserController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.mvc; import org.springframework.security.openid.OpenIDAuthenticationToken; diff --git a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/security/CustomUserDetailsService.java b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/security/CustomUserDetailsService.java index faaa81afbe..935a1efcf5 100644 --- a/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/security/CustomUserDetailsService.java +++ b/samples/javaconfig/openid/src/main/java/org/springframework/security/samples/security/CustomUserDetailsService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.security; import org.springframework.security.core.authority.AuthorityUtils; diff --git a/samples/javaconfig/openid/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/openid/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index 07a2cfb2c8..776fa09e04 100644 --- a/samples/javaconfig/openid/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/openid/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 1544f08b3b..512365bcf0 100644 --- a/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/preauth/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -26,12 +27,12 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers("/login", "/resources/**").permitAll() .anyRequest().authenticated() ) - .jee(jee -> + .jee((jee) -> jee .mappableRoles("USER", "ADMIN") ); diff --git a/samples/javaconfig/preauth/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/preauth/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index e77a3ef7c1..dd807f16d1 100644 --- a/samples/javaconfig/preauth/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/preauth/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index c2e36065a0..2570717394 100644 --- a/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/rememberme/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; @@ -41,12 +42,12 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .antMatchers("/resources/**").permitAll() .anyRequest().authenticated() ) - .formLogin(formLogin -> + .formLogin((formLogin) -> formLogin .loginPage("/login") .permitAll() diff --git a/samples/javaconfig/rememberme/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/rememberme/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index e77a3ef7c1..dd807f16d1 100644 --- a/samples/javaconfig/rememberme/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/rememberme/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 7de3308deb..c33dc58cf2 100644 --- a/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 0d6af74358..15a24f5b50 100644 --- a/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import java.io.ByteArrayInputStream; @@ -56,11 +57,11 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { return RelyingPartyRegistration.withRegistrationId(registrationId) .entityId(localEntityIdTemplate) .assertionConsumerServiceLocation(acsUrlTemplate) - .signingX509Credentials(c -> c.add(signingCredential)) - .assertingPartyDetails(config -> config + .signingX509Credentials((c) -> c.add(signingCredential)) + .assertingPartyDetails((config) -> config .entityId(idpEntityId) .singleSignOnServiceLocation(webSsoEndpoint) - .verificationX509Credentials(c -> c.add(idpVerificationCertificate))) + .verificationX509Credentials((c) -> c.add(idpVerificationCertificate))) .build(); } diff --git a/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index 0d79b19a50..8a416b95c1 100644 --- a/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; @@ -51,7 +52,7 @@ public class SecurityConfigTests { Saml2WebSsoAuthenticationFilter filter = (Saml2WebSsoAuthenticationFilter) filters .stream() .filter( - f -> f instanceof Saml2WebSsoAuthenticationFilter + (f) -> f instanceof Saml2WebSsoAuthenticationFilter ) .findFirst() .get(); diff --git a/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java b/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java index 4ed43fc3cb..f851f82de0 100644 --- a/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java +++ b/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/MessageSecurityWebApplicationInitializer.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; diff --git a/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/SecurityConfig.java b/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/SecurityConfig.java index 1fa5356df4..040dd75762 100644 --- a/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/SecurityConfig.java +++ b/samples/javaconfig/x509/src/main/java/org/springframework/security/samples/config/SecurityConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.springframework.beans.factory.annotation.Autowired; @@ -42,7 +43,7 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { @Override protected void configure(HttpSecurity http) throws Exception { http - .authorizeRequests(authorizeRequests -> + .authorizeRequests((authorizeRequests) -> authorizeRequests .anyRequest().authenticated() ) diff --git a/samples/javaconfig/x509/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/javaconfig/x509/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index e77a3ef7c1..dd807f16d1 100644 --- a/samples/javaconfig/x509/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/javaconfig/x509/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import org.junit.Test; diff --git a/samples/xml/aspectj/src/main/java/sample/aspectj/SecuredService.java b/samples/xml/aspectj/src/main/java/sample/aspectj/SecuredService.java index bf954915b8..9f81cf1733 100644 --- a/samples/xml/aspectj/src/main/java/sample/aspectj/SecuredService.java +++ b/samples/xml/aspectj/src/main/java/sample/aspectj/SecuredService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.springframework.security.access.annotation.Secured; diff --git a/samples/xml/aspectj/src/main/java/sample/aspectj/Service.java b/samples/xml/aspectj/src/main/java/sample/aspectj/Service.java index 2c5a433018..70b8c0c71f 100644 --- a/samples/xml/aspectj/src/main/java/sample/aspectj/Service.java +++ b/samples/xml/aspectj/src/main/java/sample/aspectj/Service.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.springframework.security.access.annotation.Secured; diff --git a/samples/xml/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java b/samples/xml/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java index a195b18224..5e78a92108 100644 --- a/samples/xml/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java +++ b/samples/xml/aspectj/src/test/java/sample/aspectj/AspectJInterceptorTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.aspectj; import org.junit.After; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleProxyTests.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleProxyTests.java index de65160781..906952cb63 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleProxyTests.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleProxyTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas; import java.util.HashMap; @@ -116,14 +117,14 @@ public class CasSampleProxyTests { public void extremelySecurePageWhenReusingTicketThenDisplays() { this.login.to(this::serviceParam).assertAt().login("rod"); Map ptCache = new HashMap<>(); - this.extremelySecure.to(url -> url + "?ticket=" + ptCache.computeIfAbsent(url, this::getPt)).assertAt(); - this.extremelySecure.to(url -> url + "?ticket=" + ptCache.get(url)).assertAt(); + this.extremelySecure.to((url) -> url + "?ticket=" + ptCache.computeIfAbsent(url, this::getPt)).assertAt(); + this.extremelySecure.to((url) -> url + "?ticket=" + ptCache.get(url)).assertAt(); } @Test public void securePageWhenInvalidTicketThenFails() { this.login.to(this::serviceParam).assertAt().login("scott"); - this.secure.to(url -> url + "?ticket=invalid"); + this.secure.to((url) -> url + "?ticket=invalid"); this.unauthorized.assertAt(); } diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleTests.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleTests.java index c03ef891fc..e820531969 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleTests.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/CasSampleTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas; import org.junit.After; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/JettyCasService.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/JettyCasService.java index 1491beb6f2..dc9454b7e0 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/JettyCasService.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/JettyCasService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas; import java.io.File; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/AccessDeniedPage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/AccessDeniedPage.java index e1426a360f..cb143f14ff 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/AccessDeniedPage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/AccessDeniedPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ExtremelySecurePage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ExtremelySecurePage.java index 7ad5e5d0ee..1524396506 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ExtremelySecurePage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ExtremelySecurePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/HomePage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/HomePage.java index 25cf8dbded..87317329d7 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/HomePage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LocalLogoutPage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LocalLogoutPage.java index 075a792080..7cac6aad53 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LocalLogoutPage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LocalLogoutPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LoginPage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LoginPage.java index 44c509905c..436e39eb8b 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LoginPage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ProxyTicketSamplePage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ProxyTicketSamplePage.java index 4f4bd78fa3..c570414f3a 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ProxyTicketSamplePage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/ProxyTicketSamplePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/SecurePage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/SecurePage.java index 7b5afe0093..338ce55ccf 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/SecurePage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/SecurePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/UnauthorizedPage.java b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/UnauthorizedPage.java index 4b8e099ab1..90b967d5b6 100644 --- a/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/UnauthorizedPage.java +++ b/samples/xml/cas/cassample/src/integration-test/java/org/springframework/security/samples/cas/pages/UnauthorizedPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/cas/cassample/src/main/java/org/springframework/security/samples/cas/web/ProxyTicketSampleServlet.java b/samples/xml/cas/cassample/src/main/java/org/springframework/security/samples/cas/web/ProxyTicketSampleServlet.java index 29f6241644..6b66a0b871 100644 --- a/samples/xml/cas/cassample/src/main/java/org/springframework/security/samples/cas/web/ProxyTicketSampleServlet.java +++ b/samples/xml/cas/cassample/src/main/java/org/springframework/security/samples/cas/web/ProxyTicketSampleServlet.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.cas.web; import java.io.IOException; diff --git a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/ContactsTests.java b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/ContactsTests.java index 5a7f5aee8c..db1e4e2e33 100644 --- a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/ContactsTests.java +++ b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/ContactsTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/AddPage.java b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/AddPage.java index cc58bacb08..303e426375 100644 --- a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/AddPage.java +++ b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/AddPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/ContactsPage.java b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/ContactsPage.java index dd045ee110..a6b4ff5d6f 100644 --- a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/ContactsPage.java +++ b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/ContactsPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.By; @@ -65,19 +66,19 @@ public class ContactsPage { } Predicate byEmail(final String val) { - return e -> e.findElements(By.xpath("td[position()=3 and normalize-space()='" + val + "']")).size() == 1; + return (e) -> e.findElements(By.xpath("td[position()=3 and normalize-space()='" + val + "']")).size() == 1; } Predicate byName(final String val) { - return e -> e.findElements(By.xpath("td[position()=2 and normalize-space()='" + val + "']")).size() == 1; + return (e) -> e.findElements(By.xpath("td[position()=2 and normalize-space()='" + val + "']")).size() == 1; } public DeleteContactLink andHasContact(final String name, final String email) { return this.contacts.stream() .filter(byEmail(email).and(byName(name))) - .map(e -> e.findElement(By.cssSelector("td:nth-child(4) > a"))) + .map((e) -> e.findElement(By.cssSelector("td:nth-child(4) > a"))) .findFirst() - .map(e -> new DeleteContactLink(webDriver, e)) + .map((e) -> new DeleteContactLink(webDriver, e)) .get(); } diff --git a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index 66370e9f8a..8c985c8d83 100644 --- a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index dc39dfe732..d0f9efefbf 100644 --- a/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/xml/contacts/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/contacts/src/main/java/sample/contact/AddDeleteContactController.java b/samples/xml/contacts/src/main/java/sample/contact/AddDeleteContactController.java index 3f0ac3d9de..8c6ca3c256 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/AddDeleteContactController.java +++ b/samples/xml/contacts/src/main/java/sample/contact/AddDeleteContactController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import org.springframework.beans.factory.annotation.Autowired; diff --git a/samples/xml/contacts/src/main/java/sample/contact/AddPermission.java b/samples/xml/contacts/src/main/java/sample/contact/AddPermission.java index 31439eac66..254841e42e 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/AddPermission.java +++ b/samples/xml/contacts/src/main/java/sample/contact/AddPermission.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import org.springframework.security.acls.domain.BasePermission; @@ -23,15 +24,12 @@ import org.springframework.security.acls.domain.BasePermission; * @author Ben Alex */ public class AddPermission { - // ~ Instance fields - // ================================================================================================ + public Contact contact; public Integer permission = BasePermission.READ.getMask(); public String recipient; - // ~ Methods - // ======================================================================================================== public Contact getContact() { return contact; diff --git a/samples/xml/contacts/src/main/java/sample/contact/AddPermissionValidator.java b/samples/xml/contacts/src/main/java/sample/contact/AddPermissionValidator.java index 5c4b51f2b2..f56b7647e2 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/AddPermissionValidator.java +++ b/samples/xml/contacts/src/main/java/sample/contact/AddPermissionValidator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import org.springframework.security.acls.domain.BasePermission; @@ -27,8 +28,7 @@ import org.springframework.validation.Validator; * @author Ben Alex */ public class AddPermissionValidator implements Validator { - // ~ Methods - // ======================================================================================================== + @SuppressWarnings("unchecked") public boolean supports(Class clazz) { diff --git a/samples/xml/contacts/src/main/java/sample/contact/AdminPermissionController.java b/samples/xml/contacts/src/main/java/sample/contact/AdminPermissionController.java index c44bf22f78..7f048f1666 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/AdminPermissionController.java +++ b/samples/xml/contacts/src/main/java/sample/contact/AdminPermissionController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import java.util.HashMap; diff --git a/samples/xml/contacts/src/main/java/sample/contact/ClientApplication.java b/samples/xml/contacts/src/main/java/sample/contact/ClientApplication.java index c93a5d6c68..1203b77e1a 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/ClientApplication.java +++ b/samples/xml/contacts/src/main/java/sample/contact/ClientApplication.java @@ -36,20 +36,15 @@ import org.springframework.util.StopWatch; * @author Ben Alex */ public class ClientApplication { - // ~ Instance fields - // ================================================================================================ + private final ListableBeanFactory beanFactory; - // ~ Constructors - // =================================================================================================== public ClientApplication(ListableBeanFactory beanFactory) { this.beanFactory = beanFactory; } - // ~ Methods - // ======================================================================================================== public void invokeContactManager(Authentication authentication, int nrOfCalls) { StopWatch stopWatch = new StopWatch(nrOfCalls + " ContactManager call(s)"); diff --git a/samples/xml/contacts/src/main/java/sample/contact/Contact.java b/samples/xml/contacts/src/main/java/sample/contact/Contact.java index b7819ce405..ec719e856b 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/Contact.java +++ b/samples/xml/contacts/src/main/java/sample/contact/Contact.java @@ -24,15 +24,12 @@ import java.io.Serializable; * @author Ben Alex */ public class Contact implements Serializable { - // ~ Instance fields - // ================================================================================================ + private Long id; private String email; private String name; - // ~ Constructors - // =================================================================================================== public Contact(String name, String email) { this.name = name; @@ -42,8 +39,6 @@ public class Contact implements Serializable { public Contact() { } - // ~ Methods - // ======================================================================================================== /** * @return Returns the email. diff --git a/samples/xml/contacts/src/main/java/sample/contact/ContactDao.java b/samples/xml/contacts/src/main/java/sample/contact/ContactDao.java index c882e2856b..9f245582d2 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/ContactDao.java +++ b/samples/xml/contacts/src/main/java/sample/contact/ContactDao.java @@ -24,8 +24,7 @@ import java.util.List; * @author Ben Alex */ public interface ContactDao { - // ~ Methods - // ======================================================================================================== + void create(Contact contact); diff --git a/samples/xml/contacts/src/main/java/sample/contact/ContactDaoSpring.java b/samples/xml/contacts/src/main/java/sample/contact/ContactDaoSpring.java index 36619dd260..f2b8029516 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/ContactDaoSpring.java +++ b/samples/xml/contacts/src/main/java/sample/contact/ContactDaoSpring.java @@ -30,12 +30,10 @@ import org.springframework.jdbc.core.support.JdbcDaoSupport; */ public class ContactDaoSpring extends JdbcDaoSupport implements ContactDao { - // ~ Methods - // ======================================================================================================== public void create(final Contact contact) { getJdbcTemplate().update("insert into contacts values (?, ?, ?)", - ps -> { + (ps) -> { ps.setLong(1, contact.getId()); ps.setString(2, contact.getName()); ps.setString(3, contact.getEmail()); @@ -44,13 +42,13 @@ public class ContactDaoSpring extends JdbcDaoSupport implements ContactDao { public void delete(final Long contactId) { getJdbcTemplate().update("delete from contacts where id = ?", - ps -> ps.setLong(1, contactId)); + (ps) -> ps.setLong(1, contactId)); } public void update(final Contact contact) { getJdbcTemplate().update( "update contacts set contact_name = ?, address = ? where id = ?", - ps -> { + (ps) -> { ps.setString(1, contact.getName()); ps.setString(2, contact.getEmail()); ps.setLong(3, contact.getId()); diff --git a/samples/xml/contacts/src/main/java/sample/contact/ContactManager.java b/samples/xml/contacts/src/main/java/sample/contact/ContactManager.java index 924df446d2..5d3ff14a47 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/ContactManager.java +++ b/samples/xml/contacts/src/main/java/sample/contact/ContactManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import org.springframework.security.access.prepost.PostFilter; @@ -28,8 +29,7 @@ import java.util.List; * @author Ben Alex */ public interface ContactManager { - // ~ Methods - // ======================================================================================================== + @PreAuthorize("hasPermission(#contact, admin)") void addPermission(Contact contact, Sid recipient, Permission permission); diff --git a/samples/xml/contacts/src/main/java/sample/contact/ContactManagerBackend.java b/samples/xml/contacts/src/main/java/sample/contact/ContactManagerBackend.java index ea7eabba7b..3ba1a7dbf2 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/ContactManagerBackend.java +++ b/samples/xml/contacts/src/main/java/sample/contact/ContactManagerBackend.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import org.springframework.security.acls.domain.BasePermission; @@ -49,15 +50,12 @@ import java.util.Random; @Transactional public class ContactManagerBackend extends ApplicationObjectSupport implements ContactManager, InitializingBean { - // ~ Instance fields - // ================================================================================================ + private ContactDao contactDao; private MutableAclService mutableAclService; private int counter = 1000; - // ~ Methods - // ======================================================================================================== public void afterPropertiesSet() { Assert.notNull(contactDao, "contactDao required"); diff --git a/samples/xml/contacts/src/main/java/sample/contact/DataSourcePopulator.java b/samples/xml/contacts/src/main/java/sample/contact/DataSourcePopulator.java index c58b1fd76c..2315426909 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/DataSourcePopulator.java +++ b/samples/xml/contacts/src/main/java/sample/contact/DataSourcePopulator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import java.util.Random; @@ -43,8 +44,7 @@ import org.springframework.util.Assert; * @author Ben Alex */ public class DataSourcePopulator implements InitializingBean { - // ~ Instance fields - // ================================================================================================ + JdbcTemplate template; private MutableAclService mutableAclService; @@ -60,8 +60,6 @@ public class DataSourcePopulator implements InitializingBean { "Parklin", "Findlay", "Robinson", "Giugni", "Lang", "Chi", "Carmichael" }; private int createEntities = 50; - // ~ Methods - // ======================================================================================================== public void afterPropertiesSet() { Assert.notNull(mutableAclService, "mutableAclService required"); @@ -164,7 +162,7 @@ public class DataSourcePopulator implements InitializingBean { for (int i = 1; i < createEntities; i++) { final ObjectIdentity objectIdentity = new ObjectIdentityImpl(Contact.class, (long) i); - tt.execute(arg0 -> { + tt.execute((arg0) -> { mutableAclService.createAcl(objectIdentity); return null; @@ -269,7 +267,7 @@ public class DataSourcePopulator implements InitializingBean { } private void updateAclInTransaction(final MutableAcl acl) { - tt.execute(arg0 -> { + tt.execute((arg0) -> { mutableAclService.updateAcl(acl); return null; diff --git a/samples/xml/contacts/src/main/java/sample/contact/IndexController.java b/samples/xml/contacts/src/main/java/sample/contact/IndexController.java index 09c576862f..b46b4e9810 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/IndexController.java +++ b/samples/xml/contacts/src/main/java/sample/contact/IndexController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import java.util.HashMap; @@ -44,16 +45,12 @@ public class IndexController { BasePermission.DELETE, BasePermission.ADMINISTRATION }; private final static Permission[] HAS_ADMIN = new Permission[] { BasePermission.ADMINISTRATION }; - // ~ Instance fields - // ================================================================================================ @Autowired private ContactManager contactManager; @Autowired private PermissionEvaluator permissionEvaluator; - // ~ Methods - // ======================================================================================================== /** * The public index page, used for unauthenticated users. diff --git a/samples/xml/contacts/src/main/java/sample/contact/WebContact.java b/samples/xml/contacts/src/main/java/sample/contact/WebContact.java index df38b0c5f0..c4ae49df67 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/WebContact.java +++ b/samples/xml/contacts/src/main/java/sample/contact/WebContact.java @@ -22,14 +22,11 @@ package sample.contact; * @author Ben Alex */ public class WebContact { - // ~ Instance fields - // ================================================================================================ + private String email; private String name; - // ~ Methods - // ======================================================================================================== public String getEmail() { return email; diff --git a/samples/xml/contacts/src/main/java/sample/contact/WebContactValidator.java b/samples/xml/contacts/src/main/java/sample/contact/WebContactValidator.java index f4f27d4306..8d0b754489 100644 --- a/samples/xml/contacts/src/main/java/sample/contact/WebContactValidator.java +++ b/samples/xml/contacts/src/main/java/sample/contact/WebContactValidator.java @@ -25,8 +25,7 @@ import org.springframework.validation.Validator; * @author Ben Alex */ public class WebContactValidator implements Validator { - // ~ Methods - // ======================================================================================================== + @SuppressWarnings("unchecked") public boolean supports(Class clazz) { diff --git a/samples/xml/contacts/src/test/java/sample/contact/ContactManagerTests.java b/samples/xml/contacts/src/test/java/sample/contact/ContactManagerTests.java index 9a9b8ec023..2c64395c1d 100644 --- a/samples/xml/contacts/src/test/java/sample/contact/ContactManagerTests.java +++ b/samples/xml/contacts/src/test/java/sample/contact/ContactManagerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.contact; import static org.assertj.core.api.Assertions.assertThat; @@ -44,14 +45,11 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; "/applicationContext-common-business.xml" }) @RunWith(SpringJUnit4ClassRunner.class) public class ContactManagerTests { - // ~ Instance fields - // ================================================================================================ + @Autowired protected ContactManager contactManager; - // ~ Methods - // ======================================================================================================== void assertContainsContact(long id, List contacts) { for (Contact contact : contacts) { diff --git a/samples/xml/dms/src/main/java/sample/dms/AbstractElement.java b/samples/xml/dms/src/main/java/sample/dms/AbstractElement.java index 046e500b91..a65f334179 100755 --- a/samples/xml/dms/src/main/java/sample/dms/AbstractElement.java +++ b/samples/xml/dms/src/main/java/sample/dms/AbstractElement.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms; import java.util.ArrayList; diff --git a/samples/xml/dms/src/main/java/sample/dms/DataSourcePopulator.java b/samples/xml/dms/src/main/java/sample/dms/DataSourcePopulator.java index 9e2250c288..476a73dd91 100755 --- a/samples/xml/dms/src/main/java/sample/dms/DataSourcePopulator.java +++ b/samples/xml/dms/src/main/java/sample/dms/DataSourcePopulator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms; import javax.sql.DataSource; diff --git a/samples/xml/dms/src/main/java/sample/dms/Directory.java b/samples/xml/dms/src/main/java/sample/dms/Directory.java index 0326cfc41a..06a0b20991 100755 --- a/samples/xml/dms/src/main/java/sample/dms/Directory.java +++ b/samples/xml/dms/src/main/java/sample/dms/Directory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms; /** diff --git a/samples/xml/dms/src/main/java/sample/dms/DocumentDao.java b/samples/xml/dms/src/main/java/sample/dms/DocumentDao.java index 46d27f4f56..1419d02808 100755 --- a/samples/xml/dms/src/main/java/sample/dms/DocumentDao.java +++ b/samples/xml/dms/src/main/java/sample/dms/DocumentDao.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms; /** diff --git a/samples/xml/dms/src/main/java/sample/dms/DocumentDaoImpl.java b/samples/xml/dms/src/main/java/sample/dms/DocumentDaoImpl.java index fc1c73afc0..0d83bf75f7 100755 --- a/samples/xml/dms/src/main/java/sample/dms/DocumentDaoImpl.java +++ b/samples/xml/dms/src/main/java/sample/dms/DocumentDaoImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms; import java.util.List; diff --git a/samples/xml/dms/src/main/java/sample/dms/File.java b/samples/xml/dms/src/main/java/sample/dms/File.java index cec040822a..1ceaf8ecdb 100755 --- a/samples/xml/dms/src/main/java/sample/dms/File.java +++ b/samples/xml/dms/src/main/java/sample/dms/File.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms; import org.springframework.util.Assert; diff --git a/samples/xml/dms/src/main/java/sample/dms/secured/SecureDataSourcePopulator.java b/samples/xml/dms/src/main/java/sample/dms/secured/SecureDataSourcePopulator.java index 845848dbe7..7487c73db7 100755 --- a/samples/xml/dms/src/main/java/sample/dms/secured/SecureDataSourcePopulator.java +++ b/samples/xml/dms/src/main/java/sample/dms/secured/SecureDataSourcePopulator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms.secured; import javax.sql.DataSource; diff --git a/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDao.java b/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDao.java index 004e33c298..d4132d2979 100755 --- a/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDao.java +++ b/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDao.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms.secured; import sample.dms.DocumentDao; diff --git a/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDaoImpl.java b/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDaoImpl.java index 2fad553bb6..caae4574b3 100755 --- a/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDaoImpl.java +++ b/samples/xml/dms/src/main/java/sample/dms/secured/SecureDocumentDaoImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.dms.secured; import org.springframework.security.acls.domain.BasePermission; diff --git a/samples/xml/dms/src/test/java/sample/DmsIntegrationTests.java b/samples/xml/dms/src/test/java/sample/DmsIntegrationTests.java index bf28298f2e..11fbd3f6a6 100644 --- a/samples/xml/dms/src/test/java/sample/DmsIntegrationTests.java +++ b/samples/xml/dms/src/test/java/sample/DmsIntegrationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; /* @@ -122,13 +123,9 @@ public class DmsIntegrationTests extends AbstractTransactionalJUnit4SpringContex // plus 10 files AbstractElement[] nonHomeElements = this.documentDao.findElements(nonHomeDir); - assertThat(nonHomeElements).hasSize(shouldBeFiltered ? 11 : 12); // cannot - // see - // the user's - // "confidential" - // sub-directory - // when - // filtering + assertThat(nonHomeElements).hasSize(shouldBeFiltered ? 11 : 12); + + // cannot see the user's "confidential" sub-directory when filtering // Attempt to read the other user's confidential directory from the returned // results diff --git a/samples/xml/dms/src/test/java/sample/SecureDmsIntegrationTests.java b/samples/xml/dms/src/test/java/sample/SecureDmsIntegrationTests.java index 33b6b34641..63965afb26 100644 --- a/samples/xml/dms/src/test/java/sample/SecureDmsIntegrationTests.java +++ b/samples/xml/dms/src/test/java/sample/SecureDmsIntegrationTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; diff --git a/samples/xml/gae/src/main/java/samples/gae/security/AppRole.java b/samples/xml/gae/src/main/java/samples/gae/security/AppRole.java index 030893db66..84743e6044 100644 --- a/samples/xml/gae/src/main/java/samples/gae/security/AppRole.java +++ b/samples/xml/gae/src/main/java/samples/gae/security/AppRole.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.security; import org.springframework.security.core.GrantedAuthority; diff --git a/samples/xml/gae/src/main/java/samples/gae/security/GaeAuthenticationFilter.java b/samples/xml/gae/src/main/java/samples/gae/security/GaeAuthenticationFilter.java index 1d828d4032..838529efe3 100644 --- a/samples/xml/gae/src/main/java/samples/gae/security/GaeAuthenticationFilter.java +++ b/samples/xml/gae/src/main/java/samples/gae/security/GaeAuthenticationFilter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.security; import java.io.IOException; diff --git a/samples/xml/gae/src/main/java/samples/gae/security/GaeUserAuthentication.java b/samples/xml/gae/src/main/java/samples/gae/security/GaeUserAuthentication.java index b8d1f41a10..961ea99952 100644 --- a/samples/xml/gae/src/main/java/samples/gae/security/GaeUserAuthentication.java +++ b/samples/xml/gae/src/main/java/samples/gae/security/GaeUserAuthentication.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.security; import java.util.Collection; diff --git a/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationEntryPoint.java b/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationEntryPoint.java index 0a4077f60f..4929c73cc4 100644 --- a/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationEntryPoint.java +++ b/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationEntryPoint.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.security; import java.io.IOException; diff --git a/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationProvider.java b/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationProvider.java index 5f1a4bbada..6a8f1a37ae 100644 --- a/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationProvider.java +++ b/samples/xml/gae/src/main/java/samples/gae/security/GoogleAccountsAuthenticationProvider.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.security; import com.google.appengine.api.users.User; diff --git a/samples/xml/gae/src/main/java/samples/gae/users/GaeDatastoreUserRegistry.java b/samples/xml/gae/src/main/java/samples/gae/users/GaeDatastoreUserRegistry.java index 7545ca8b78..ea5b5c29b4 100644 --- a/samples/xml/gae/src/main/java/samples/gae/users/GaeDatastoreUserRegistry.java +++ b/samples/xml/gae/src/main/java/samples/gae/users/GaeDatastoreUserRegistry.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.users; import com.google.appengine.api.datastore.DatastoreService; diff --git a/samples/xml/gae/src/main/java/samples/gae/users/GaeUser.java b/samples/xml/gae/src/main/java/samples/gae/users/GaeUser.java index cb4eab10c4..0d167aa5d7 100644 --- a/samples/xml/gae/src/main/java/samples/gae/users/GaeUser.java +++ b/samples/xml/gae/src/main/java/samples/gae/users/GaeUser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.users; import java.io.Serializable; diff --git a/samples/xml/gae/src/main/java/samples/gae/users/InMemoryUserRegistry.java b/samples/xml/gae/src/main/java/samples/gae/users/InMemoryUserRegistry.java index 1d7e2a77df..be1e481b3f 100644 --- a/samples/xml/gae/src/main/java/samples/gae/users/InMemoryUserRegistry.java +++ b/samples/xml/gae/src/main/java/samples/gae/users/InMemoryUserRegistry.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.users; import java.util.Collections; diff --git a/samples/xml/gae/src/main/java/samples/gae/users/UserRegistry.java b/samples/xml/gae/src/main/java/samples/gae/users/UserRegistry.java index d805ccb489..8a53378e54 100644 --- a/samples/xml/gae/src/main/java/samples/gae/users/UserRegistry.java +++ b/samples/xml/gae/src/main/java/samples/gae/users/UserRegistry.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.users; /** diff --git a/samples/xml/gae/src/main/java/samples/gae/validation/Forename.java b/samples/xml/gae/src/main/java/samples/gae/validation/Forename.java index d22c682ec2..3028c73c3c 100644 --- a/samples/xml/gae/src/main/java/samples/gae/validation/Forename.java +++ b/samples/xml/gae/src/main/java/samples/gae/validation/Forename.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.validation; import static java.lang.annotation.ElementType.*; diff --git a/samples/xml/gae/src/main/java/samples/gae/validation/ForenameValidator.java b/samples/xml/gae/src/main/java/samples/gae/validation/ForenameValidator.java index 986d0f375f..0e4d56bc31 100644 --- a/samples/xml/gae/src/main/java/samples/gae/validation/ForenameValidator.java +++ b/samples/xml/gae/src/main/java/samples/gae/validation/ForenameValidator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.validation; import java.util.regex.Pattern; diff --git a/samples/xml/gae/src/main/java/samples/gae/validation/Surname.java b/samples/xml/gae/src/main/java/samples/gae/validation/Surname.java index a01836e222..5f10e8f4c9 100644 --- a/samples/xml/gae/src/main/java/samples/gae/validation/Surname.java +++ b/samples/xml/gae/src/main/java/samples/gae/validation/Surname.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.validation; import static java.lang.annotation.ElementType.*; diff --git a/samples/xml/gae/src/main/java/samples/gae/validation/SurnameValidator.java b/samples/xml/gae/src/main/java/samples/gae/validation/SurnameValidator.java index f21e185401..a894922c7f 100644 --- a/samples/xml/gae/src/main/java/samples/gae/validation/SurnameValidator.java +++ b/samples/xml/gae/src/main/java/samples/gae/validation/SurnameValidator.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.validation; import java.util.regex.Pattern; diff --git a/samples/xml/gae/src/main/java/samples/gae/web/GaeAppController.java b/samples/xml/gae/src/main/java/samples/gae/web/GaeAppController.java index 79b7e32f7f..c19ffb4194 100644 --- a/samples/xml/gae/src/main/java/samples/gae/web/GaeAppController.java +++ b/samples/xml/gae/src/main/java/samples/gae/web/GaeAppController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.web; import java.io.IOException; diff --git a/samples/xml/gae/src/main/java/samples/gae/web/RegistrationController.java b/samples/xml/gae/src/main/java/samples/gae/web/RegistrationController.java index a14ae77f88..0a535a2a23 100644 --- a/samples/xml/gae/src/main/java/samples/gae/web/RegistrationController.java +++ b/samples/xml/gae/src/main/java/samples/gae/web/RegistrationController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.web; import java.util.EnumSet; diff --git a/samples/xml/gae/src/main/java/samples/gae/web/RegistrationForm.java b/samples/xml/gae/src/main/java/samples/gae/web/RegistrationForm.java index 62d98a00f4..6001e532f0 100644 --- a/samples/xml/gae/src/main/java/samples/gae/web/RegistrationForm.java +++ b/samples/xml/gae/src/main/java/samples/gae/web/RegistrationForm.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.web; import samples.gae.validation.Forename; diff --git a/samples/xml/gae/src/test/java/samples/gae/security/AppRoleTests.java b/samples/xml/gae/src/test/java/samples/gae/security/AppRoleTests.java index be7eb06751..8db7aea192 100644 --- a/samples/xml/gae/src/test/java/samples/gae/security/AppRoleTests.java +++ b/samples/xml/gae/src/test/java/samples/gae/security/AppRoleTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.security; import static org.assertj.core.api.Assertions.*; diff --git a/samples/xml/gae/src/test/java/samples/gae/users/GaeDataStoreUserRegistryTests.java b/samples/xml/gae/src/test/java/samples/gae/users/GaeDataStoreUserRegistryTests.java index 28e6eb964d..69e24d47a7 100644 --- a/samples/xml/gae/src/test/java/samples/gae/users/GaeDataStoreUserRegistryTests.java +++ b/samples/xml/gae/src/test/java/samples/gae/users/GaeDataStoreUserRegistryTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.gae.users; import static org.assertj.core.api.Assertions.assertThat; diff --git a/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldXmlTests.java b/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldXmlTests.java index f723f945d9..68abc3a173 100644 --- a/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldXmlTests.java +++ b/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/HelloWorldXmlTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index fae8d98220..ed527c3957 100644 --- a/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index 2532ba2406..30822ea2a9 100644 --- a/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/xml/helloworld/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/HelloInsecureTests.java b/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/HelloInsecureTests.java index f93bbffe38..fb6231a809 100644 --- a/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/HelloInsecureTests.java +++ b/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/HelloInsecureTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index 7388a703ed..9211300f62 100644 --- a/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/xml/insecure/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/insecuremvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java b/samples/xml/insecuremvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java index 4dccf7f223..ec2d3f088e 100644 --- a/samples/xml/insecuremvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java +++ b/samples/xml/insecuremvc/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.config; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; diff --git a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/JaasXmlTests.java b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/JaasXmlTests.java index c16683fbcf..a6345a8221 100644 --- a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/JaasXmlTests.java +++ b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/JaasXmlTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index 8e800a251f..69625fe8d2 100644 --- a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index 668b9d0aed..9aa9fea10f 100644 --- a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java index e2c97392a0..819ad73e04 100644 --- a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java +++ b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java index e20e0984cb..d2e067cb23 100644 --- a/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java +++ b/samples/xml/jaas/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/jaas/src/main/java/samples/jaas/RoleUserAuthorityGranter.java b/samples/xml/jaas/src/main/java/samples/jaas/RoleUserAuthorityGranter.java index 4878471107..8f1fee9841 100644 --- a/samples/xml/jaas/src/main/java/samples/jaas/RoleUserAuthorityGranter.java +++ b/samples/xml/jaas/src/main/java/samples/jaas/RoleUserAuthorityGranter.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.jaas; import java.security.Principal; diff --git a/samples/xml/jaas/src/main/java/samples/jaas/UsernameEqualsPasswordLoginModule.java b/samples/xml/jaas/src/main/java/samples/jaas/UsernameEqualsPasswordLoginModule.java index fe6fc7ba83..c390c8cbc1 100644 --- a/samples/xml/jaas/src/main/java/samples/jaas/UsernameEqualsPasswordLoginModule.java +++ b/samples/xml/jaas/src/main/java/samples/jaas/UsernameEqualsPasswordLoginModule.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package samples.jaas; import java.io.Serializable; @@ -34,15 +35,12 @@ import javax.security.auth.spi.LoginModule; * @author Rob Winch */ public class UsernameEqualsPasswordLoginModule implements LoginModule { - // ~ Instance fields - // ================================================================================================ + private String password; private String username; private Subject subject; - // ~ Methods - // ======================================================================================================== @Override public boolean abort() { diff --git a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/LdapXmlTests.java b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/LdapXmlTests.java index 5dba285788..29f44441ea 100644 --- a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/LdapXmlTests.java +++ b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/LdapXmlTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples; import org.junit.After; diff --git a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java index 8e800a251f..69625fe8d2 100644 --- a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java +++ b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/HomePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java index 668b9d0aed..9aa9fea10f 100644 --- a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java +++ b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LoginPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java index e2c97392a0..819ad73e04 100644 --- a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java +++ b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/LogoutPage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java index f03bdf3870..81c7cb4050 100644 --- a/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java +++ b/samples/xml/ldap/src/integration-test/java/org/springframework/security/samples/pages/SecurePage.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.pages; import org.openqa.selenium.WebDriver; diff --git a/samples/xml/oauth2login/src/main/java/sample/config/WebConfig.java b/samples/xml/oauth2login/src/main/java/sample/config/WebConfig.java index 8761f0c77e..53db0361a8 100644 --- a/samples/xml/oauth2login/src/main/java/sample/config/WebConfig.java +++ b/samples/xml/oauth2login/src/main/java/sample/config/WebConfig.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.config; import org.springframework.beans.BeansException; diff --git a/samples/xml/oauth2login/src/main/java/sample/web/OAuth2LoginController.java b/samples/xml/oauth2login/src/main/java/sample/web/OAuth2LoginController.java index 0e66bd9d94..4165428145 100644 --- a/samples/xml/oauth2login/src/main/java/sample/web/OAuth2LoginController.java +++ b/samples/xml/oauth2login/src/main/java/sample/web/OAuth2LoginController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample.web; import org.springframework.security.core.annotation.AuthenticationPrincipal; diff --git a/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetails.java b/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetails.java index 7a5e400ec2..066eb0bbdc 100644 --- a/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetails.java +++ b/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetails.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.openid; import java.util.Collection; diff --git a/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetailsService.java b/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetailsService.java index 27a3bb409d..b92fe9b556 100644 --- a/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetailsService.java +++ b/samples/xml/openid/src/main/java/org/springframework/security/samples/openid/CustomUserDetailsService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.openid; import java.util.HashMap; diff --git a/samples/xml/preauth/src/test/java/sample/PreAuthXmlTests.java b/samples/xml/preauth/src/test/java/sample/PreAuthXmlTests.java index be5528e86c..d5875080c6 100644 --- a/samples/xml/preauth/src/test/java/sample/PreAuthXmlTests.java +++ b/samples/xml/preauth/src/test/java/sample/PreAuthXmlTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package sample; import org.junit.Test; diff --git a/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/LoginForm.java b/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/LoginForm.java index bbf22962d7..f635fb2f25 100644 --- a/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/LoginForm.java +++ b/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/LoginForm.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.servletapi.mvc; /** diff --git a/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/ServletApiController.java b/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/ServletApiController.java index 95170c1e9a..bdfd1a6988 100644 --- a/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/ServletApiController.java +++ b/samples/xml/servletapi/src/main/java/org/springframework/security/samples/servletapi/mvc/ServletApiController.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.samples.servletapi.mvc; import java.io.IOException; diff --git a/samples/xml/tutorial/src/main/java/bigbank/Account.java b/samples/xml/tutorial/src/main/java/bigbank/Account.java index a5ef9cd8a4..438df6aa77 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/Account.java +++ b/samples/xml/tutorial/src/main/java/bigbank/Account.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank; /** diff --git a/samples/xml/tutorial/src/main/java/bigbank/BankDao.java b/samples/xml/tutorial/src/main/java/bigbank/BankDao.java index 4411896ee2..84ebc6c077 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/BankDao.java +++ b/samples/xml/tutorial/src/main/java/bigbank/BankDao.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank; public interface BankDao { diff --git a/samples/xml/tutorial/src/main/java/bigbank/BankDaoStub.java b/samples/xml/tutorial/src/main/java/bigbank/BankDaoStub.java index 78d78f85b2..0a782379fd 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/BankDaoStub.java +++ b/samples/xml/tutorial/src/main/java/bigbank/BankDaoStub.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank; import java.util.HashMap; diff --git a/samples/xml/tutorial/src/main/java/bigbank/BankService.java b/samples/xml/tutorial/src/main/java/bigbank/BankService.java index 0e8f6014bc..5f76bb98f3 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/BankService.java +++ b/samples/xml/tutorial/src/main/java/bigbank/BankService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank; import org.springframework.security.access.prepost.PreAuthorize; diff --git a/samples/xml/tutorial/src/main/java/bigbank/BankServiceImpl.java b/samples/xml/tutorial/src/main/java/bigbank/BankServiceImpl.java index 5feede7f61..e12f22845a 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/BankServiceImpl.java +++ b/samples/xml/tutorial/src/main/java/bigbank/BankServiceImpl.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank; import org.springframework.util.Assert; diff --git a/samples/xml/tutorial/src/main/java/bigbank/SeedData.java b/samples/xml/tutorial/src/main/java/bigbank/SeedData.java index 21b7791c0d..e4b15389e9 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/SeedData.java +++ b/samples/xml/tutorial/src/main/java/bigbank/SeedData.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank; import org.springframework.beans.factory.InitializingBean; diff --git a/samples/xml/tutorial/src/main/java/bigbank/web/ListAccounts.java b/samples/xml/tutorial/src/main/java/bigbank/web/ListAccounts.java index f17aad051b..7fd239b140 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/web/ListAccounts.java +++ b/samples/xml/tutorial/src/main/java/bigbank/web/ListAccounts.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank.web; import javax.servlet.http.HttpServletRequest; diff --git a/samples/xml/tutorial/src/main/java/bigbank/web/PostAccounts.java b/samples/xml/tutorial/src/main/java/bigbank/web/PostAccounts.java index 3a5f66c26a..e32ea4a0b7 100644 --- a/samples/xml/tutorial/src/main/java/bigbank/web/PostAccounts.java +++ b/samples/xml/tutorial/src/main/java/bigbank/web/PostAccounts.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package bigbank.web; import javax.servlet.http.HttpServletRequest; diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/TagLibConfig.java b/taglibs/src/main/java/org/springframework/security/taglibs/TagLibConfig.java index ec4360991a..64083a679a 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/TagLibConfig.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/TagLibConfig.java @@ -16,11 +16,11 @@ package org.springframework.security.taglibs; +import javax.servlet.jsp.tagext.Tag; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import javax.servlet.jsp.tagext.Tag; - /** * internal configuration class for taglibs. * @@ -29,39 +29,37 @@ import javax.servlet.jsp.tagext.Tag; * @author Luke Taylor */ public final class TagLibConfig { + static Log logger = LogFactory.getLog("spring-security-taglibs"); static final boolean DISABLE_UI_SECURITY; + static final String SECURED_UI_PREFIX; + static final String SECURED_UI_SUFFIX; static { String db = System.getProperty("spring.security.disableUISecurity"); String prefix = System.getProperty("spring.security.securedUIPrefix"); String suffix = System.getProperty("spring.security.securedUISuffix"); - - SECURED_UI_PREFIX = prefix == null ? "" : prefix; - SECURED_UI_SUFFIX = suffix == null ? "" : suffix; - + SECURED_UI_PREFIX = (prefix != null) ? prefix : ""; + SECURED_UI_SUFFIX = (suffix != null) ? suffix : ""; DISABLE_UI_SECURITY = "true".equals(db); - if (DISABLE_UI_SECURITY) { logger.warn("***** UI security is disabled. All unauthorized content will be displayed *****"); } } + private TagLibConfig() { + } + /** * Returns EVAL_BODY_INCLUDE if the authorized flag is true or UI security has been * disabled. Otherwise returns SKIP_BODY. - * * @param authorized whether the user is authorized to see the content or not */ public static int evalOrSkip(boolean authorized) { - if (authorized || DISABLE_UI_SECURITY) { - return Tag.EVAL_BODY_INCLUDE; - } - - return Tag.SKIP_BODY; + return (authorized || DISABLE_UI_SECURITY) ? Tag.EVAL_BODY_INCLUDE : Tag.SKIP_BODY; } public static boolean isUiSecurityDisabled() { @@ -75,4 +73,5 @@ public final class TagLibConfig { public static String getSecuredUiSuffix() { return SECURED_UI_SUFFIX; } + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java index 1da2aa966c..95fb17928f 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.authz; import java.io.IOException; @@ -55,8 +56,11 @@ import org.springframework.util.StringUtils; * @since 3.1.0 */ public abstract class AbstractAuthorizeTag { + private String access; + private String url; + private String method = "GET"; /** @@ -85,34 +89,23 @@ public abstract class AbstractAuthorizeTag { *
      12. url, method
      13. * * The above combinations are mutually exclusive and evaluated in the given order. - * * @return the result of the authorization decision * @throws IOException */ public boolean authorize() throws IOException { - boolean isAuthorized; - if (StringUtils.hasText(getAccess())) { - isAuthorized = authorizeUsingAccessExpression(); - + return authorizeUsingAccessExpression(); } - else if (StringUtils.hasText(getUrl())) { - isAuthorized = authorizeUsingUrlCheck(); - + if (StringUtils.hasText(getUrl())) { + return authorizeUsingUrlCheck(); } - else { - isAuthorized = false; - - } - - return isAuthorized; + return false; } /** * Make an authorization decision based on a Spring EL expression. See the * "Expression-Based Access Control" chapter in Spring Security for details on what * expressions can be used. - * * @return the result of the authorization decision * @throws IOException */ @@ -120,55 +113,41 @@ public abstract class AbstractAuthorizeTag { if (SecurityContextHolder.getContext().getAuthentication() == null) { return false; } - SecurityExpressionHandler handler = getExpressionHandler(); - Expression accessExpression; try { accessExpression = handler.getExpressionParser().parseExpression(getAccess()); - } - catch (ParseException e) { - IOException ioException = new IOException(); - ioException.initCause(e); - throw ioException; + catch (ParseException ex) { + throw new IOException(ex); } - - return ExpressionUtils.evaluateAsBoolean(accessExpression, - createExpressionEvaluationContext(handler)); + return ExpressionUtils.evaluateAsBoolean(accessExpression, createExpressionEvaluationContext(handler)); } /** * Allows the {@code EvaluationContext} to be customized for variable lookup etc. */ - protected EvaluationContext createExpressionEvaluationContext( - SecurityExpressionHandler handler) { - FilterInvocation f = new FilterInvocation(getRequest(), getResponse(), - (request, response) -> { - throw new UnsupportedOperationException(); - }); - - return handler.createEvaluationContext(SecurityContextHolder.getContext() - .getAuthentication(), f); + protected EvaluationContext createExpressionEvaluationContext(SecurityExpressionHandler handler) { + FilterInvocation f = new FilterInvocation(getRequest(), getResponse(), (request, response) -> { + throw new UnsupportedOperationException(); + }); + return handler.createEvaluationContext(SecurityContextHolder.getContext().getAuthentication(), f); } /** * Make an authorization decision based on the URL and HTTP method attributes. True is * returned if the user is allowed to access the given URL as defined. - * * @return the result of the authorization decision * @throws IOException */ public boolean authorizeUsingUrlCheck() throws IOException { String contextPath = ((HttpServletRequest) getRequest()).getContextPath(); - Authentication currentUser = SecurityContextHolder.getContext() - .getAuthentication(); - return getPrivilegeEvaluator().isAllowed(contextPath, getUrl(), getMethod(), - currentUser); + Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); + return getPrivilegeEvaluator().isAllowed(contextPath, getUrl(), getMethod(), currentUser); } public String getAccess() { - return access; + return this.access; } public void setAccess(String access) { @@ -176,7 +155,7 @@ public abstract class AbstractAuthorizeTag { } public String getUrl() { - return url; + return this.url; } public void setUrl(String url) { @@ -184,32 +163,26 @@ public abstract class AbstractAuthorizeTag { } public String getMethod() { - return method; + return this.method; } public void setMethod(String method) { this.method = (method != null) ? method.toUpperCase() : null; } - /*------------- Private helper methods -----------------*/ - @SuppressWarnings({ "unchecked", "rawtypes" }) - private SecurityExpressionHandler getExpressionHandler() - throws IOException { - ApplicationContext appContext = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext()); - Map handlers = appContext - .getBeansOfType(SecurityExpressionHandler.class); - - for (SecurityExpressionHandler h : handlers.values()) { - if (FilterInvocation.class.equals(GenericTypeResolver.resolveTypeArgument( - h.getClass(), SecurityExpressionHandler.class))) { - return h; + private SecurityExpressionHandler getExpressionHandler() throws IOException { + ApplicationContext appContext = SecurityWebApplicationContextUtils + .findRequiredWebApplicationContext(getServletContext()); + Map handlers = appContext.getBeansOfType(SecurityExpressionHandler.class); + for (SecurityExpressionHandler handler : handlers.values()) { + if (FilterInvocation.class.equals( + GenericTypeResolver.resolveTypeArgument(handler.getClass(), SecurityExpressionHandler.class))) { + return handler; } } - - throw new IOException( - "No visible WebSecurityExpressionHandler instance could be found in the application " - + "context. There must be at least one in order to support expressions in JSP 'authorize' tags."); + throw new IOException("No visible WebSecurityExpressionHandler instance could be found in the application " + + "context. There must be at least one in order to support expressions in JSP 'authorize' tags."); } private WebInvocationPrivilegeEvaluator getPrivilegeEvaluator() throws IOException { @@ -218,17 +191,15 @@ public abstract class AbstractAuthorizeTag { if (privEvaluatorFromRequest != null) { return privEvaluatorFromRequest; } - - ApplicationContext ctx = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext()); - Map wipes = ctx - .getBeansOfType(WebInvocationPrivilegeEvaluator.class); - + ApplicationContext ctx = SecurityWebApplicationContextUtils + .findRequiredWebApplicationContext(getServletContext()); + Map wipes = ctx.getBeansOfType(WebInvocationPrivilegeEvaluator.class); if (wipes.size() == 0) { throw new IOException( "No visible WebInvocationPrivilegeEvaluator instance could be found in the application " + "context. There must be at least one in order to support the use of URL access checks in 'authorize' tags."); } - return (WebInvocationPrivilegeEvaluator) wipes.values().toArray()[0]; } + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java index 0d1c6149e3..16904672a6 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.authz; import java.util.ArrayList; @@ -27,6 +28,7 @@ import javax.servlet.jsp.tagext.TagSupport; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.context.ApplicationContext; import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.core.Authentication; @@ -34,7 +36,6 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.taglibs.TagLibConfig; import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; - /** * An implementation of {@link Tag} that allows its body through if all authorizations are * granted to the request's principal. @@ -53,57 +54,41 @@ import org.springframework.security.web.context.support.SecurityWebApplicationCo * @author Rob Winch */ public class AccessControlListTag extends TagSupport { - // ~ Static fields/initializers - // ===================================================================================== protected static final Log logger = LogFactory.getLog(AccessControlListTag.class); - // ~ Instance fields - // ================================================================================================ - private ApplicationContext applicationContext; + private Object domainObject; + private PermissionEvaluator permissionEvaluator; + private String hasPermission = ""; + private String var; - // ~ Methods - // ======================================================================================================== - + @Override public int doStartTag() throws JspException { - if ((null == hasPermission) || "".equals(hasPermission)) { + if ((null == this.hasPermission) || "".equals(this.hasPermission)) { return skipBody(); } - initializeIfRequired(); - - if (domainObject == null) { - if (logger.isDebugEnabled()) { - logger.debug("domainObject resolved to null, so including tag body"); - } - + if (this.domainObject == null) { + logger.debug("domainObject resolved to null, so including tag body"); // Of course they have access to a null object! return evalBody(); } - - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); if (authentication == null) { - if (logger.isDebugEnabled()) { - logger.debug("SecurityContextHolder did not return a non-null Authentication object, so skipping tag body"); - } - + logger.debug("SecurityContextHolder did not return a non-null Authentication object, so skipping tag body"); return skipBody(); } - - List requiredPermissions = parseHasPermission(hasPermission); + List requiredPermissions = parseHasPermission(this.hasPermission); for (Object requiredPermission : requiredPermissions) { - if (!permissionEvaluator.hasPermission(authentication, domainObject, - requiredPermission)) { + if (!this.permissionEvaluator.hasPermission(authentication, this.domainObject, requiredPermission)) { return skipBody(); } } - return evalBody(); } @@ -115,7 +100,7 @@ public class AccessControlListTag extends TagSupport { try { parsedPermission = Integer.parseInt(permissionToParse); } - catch (NumberFormatException notBitMask) { + catch (NumberFormatException ex) { } parsedPermissions.add(parsedPermission); } @@ -123,68 +108,60 @@ public class AccessControlListTag extends TagSupport { } private int skipBody() { - if (var != null) { - pageContext.setAttribute(var, Boolean.FALSE, PageContext.PAGE_SCOPE); + if (this.var != null) { + this.pageContext.setAttribute(this.var, Boolean.FALSE, PageContext.PAGE_SCOPE); } return TagLibConfig.evalOrSkip(false); } private int evalBody() { - if (var != null) { - pageContext.setAttribute(var, Boolean.TRUE, PageContext.PAGE_SCOPE); + if (this.var != null) { + this.pageContext.setAttribute(this.var, Boolean.TRUE, PageContext.PAGE_SCOPE); } return TagLibConfig.evalOrSkip(true); } /** * Allows test cases to override where application context obtained from. - * * @param pageContext so the ServletContext can be accessed as required * by Spring's WebApplicationContextUtils - * * @return the Spring application context (never null) */ protected ApplicationContext getContext(PageContext pageContext) { ServletContext servletContext = pageContext.getServletContext(); - return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext); } public Object getDomainObject() { - return domainObject; + return this.domainObject; } public String getHasPermission() { - return hasPermission; + return this.hasPermission; } private void initializeIfRequired() throws JspException { - if (applicationContext != null) { + if (this.applicationContext != null) { return; } - - this.applicationContext = getContext(pageContext); - - permissionEvaluator = getBeanOfType(PermissionEvaluator.class); + this.applicationContext = getContext(this.pageContext); + this.permissionEvaluator = getBeanOfType(PermissionEvaluator.class); } private T getBeanOfType(Class type) throws JspException { - Map map = applicationContext.getBeansOfType(type); - - for (ApplicationContext context = applicationContext.getParent(); context != null; context = context + Map map = this.applicationContext.getBeansOfType(type); + for (ApplicationContext context = this.applicationContext.getParent(); context != null; context = context .getParent()) { map.putAll(context.getBeansOfType(type)); } - if (map.size() == 0) { return null; } - else if (map.size() == 1) { + if (map.size() == 1) { return map.values().iterator().next(); } - - throw new JspException("Found incorrect number of " + type.getSimpleName() - + " instances in " + "application context - you must have only have one!"); + throw new JspException("Found incorrect number of " + type.getSimpleName() + " instances in " + + "application context - you must have only have one!"); } public void setDomainObject(Object domainObject) { @@ -198,4 +175,5 @@ public class AccessControlListTag extends TagSupport { public void setVar(String var) { this.var = var; } + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java index e740ffb6ec..d7fd43627c 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java @@ -16,15 +16,6 @@ package org.springframework.security.taglibs.authz; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.web.util.TextEscapeUtils; - -import org.springframework.beans.BeanWrapperImpl; -import org.springframework.beans.BeansException; -import org.springframework.web.util.TagUtils; - import java.io.IOException; import javax.servlet.jsp.JspException; @@ -32,6 +23,14 @@ import javax.servlet.jsp.PageContext; import javax.servlet.jsp.tagext.Tag; import javax.servlet.jsp.tagext.TagSupport; +import org.springframework.beans.BeanWrapperImpl; +import org.springframework.beans.BeansException; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.util.TextEscapeUtils; +import org.springframework.web.util.TagUtils; + /** * An {@link javax.servlet.jsp.tagext.Tag} implementation that allows convenient access to * the current Authentication object. @@ -43,17 +42,15 @@ import javax.servlet.jsp.tagext.TagSupport; */ public class AuthenticationTag extends TagSupport { - // ~ Instance fields - // ================================================================================================ - private String var; - private String property; - private int scope; - private boolean scopeSpecified; - private boolean htmlEscape = true; - // ~ Methods - // ======================================================================================================== + private String property; + + private int scope; + + private boolean scopeSpecified; + + private boolean htmlEscape = true; public AuthenticationTag() { init(); @@ -61,9 +58,9 @@ public class AuthenticationTag extends TagSupport { // resets local state private void init() { - var = null; - scopeSpecified = false; - scope = PageContext.PAGE_SCOPE; + this.var = null; + this.scopeSpecified = false; + this.scope = PageContext.PAGE_SCOPE; } public void setVar(String var) { @@ -79,55 +76,53 @@ public class AuthenticationTag extends TagSupport { this.scopeSpecified = true; } + @Override public int doStartTag() throws JspException { return super.doStartTag(); } + @Override public int doEndTag() throws JspException { Object result = null; // determine the value by... - if (property != null) { + if (this.property != null) { if ((SecurityContextHolder.getContext() == null) || !(SecurityContextHolder.getContext() instanceof SecurityContext) || (SecurityContextHolder.getContext().getAuthentication() == null)) { return Tag.EVAL_PAGE; } - Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - if (auth.getPrincipal() == null) { return Tag.EVAL_PAGE; } - try { BeanWrapperImpl wrapper = new BeanWrapperImpl(auth); - result = wrapper.getPropertyValue(property); + result = wrapper.getPropertyValue(this.property); } - catch (BeansException e) { - throw new JspException(e); + catch (BeansException ex) { + throw new JspException(ex); } } - - if (var != null) { + if (this.var != null) { /* * Store the result, letting an IllegalArgumentException propagate back if the * scope is invalid (e.g., if an attempt is made to store something in the * session without any HttpSession existing). */ if (result != null) { - pageContext.setAttribute(var, result, scope); + this.pageContext.setAttribute(this.var, result, this.scope); } else { - if (scopeSpecified) { - pageContext.removeAttribute(var, scope); + if (this.scopeSpecified) { + this.pageContext.removeAttribute(this.var, this.scope); } else { - pageContext.removeAttribute(var); + this.pageContext.removeAttribute(this.var); } } } else { - if (htmlEscape) { + if (this.htmlEscape) { writeMessage(TextEscapeUtils.escapeEntities(String.valueOf(result))); } else { @@ -139,7 +134,7 @@ public class AuthenticationTag extends TagSupport { protected void writeMessage(String msg) throws JspException { try { - pageContext.getOut().write(String.valueOf(msg)); + this.pageContext.getOut().write(String.valueOf(msg)); } catch (IOException ioe) { throw new JspException(ioe); @@ -158,6 +153,7 @@ public class AuthenticationTag extends TagSupport { * overridden. */ protected boolean isHtmlEscape() { - return htmlEscape; + return this.htmlEscape; } + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/JspAuthorizeTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/JspAuthorizeTag.java index 5ee16f7a3e..e5b4cedb78 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/JspAuthorizeTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/JspAuthorizeTag.java @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.authz; import java.io.IOException; -import java.util.*; +import java.util.List; import javax.servlet.ServletContext; import javax.servlet.ServletRequest; @@ -43,8 +44,8 @@ import org.springframework.security.web.FilterInvocation; * A JSP {@link Tag} implementation of {@link AbstractAuthorizeTag}. * * @author Rossen Stoyanchev - * @see AbstractAuthorizeTag * @since 3.1.0 + * @see AbstractAuthorizeTag */ public class JspAuthorizeTag extends AbstractAuthorizeTag implements Tag { @@ -61,105 +62,101 @@ public class JspAuthorizeTag extends AbstractAuthorizeTag implements Tag { /** * Invokes the base class {@link AbstractAuthorizeTag#authorize()} method to decide if * the body of the tag should be skipped or not. - * * @return {@link Tag#SKIP_BODY} or {@link Tag#EVAL_BODY_INCLUDE} */ + @Override public int doStartTag() throws JspException { try { - authorized = super.authorize(); - - if (!authorized && TagLibConfig.isUiSecurityDisabled()) { - pageContext.getOut().write(TagLibConfig.getSecuredUiPrefix()); + this.authorized = super.authorize(); + if (!this.authorized && TagLibConfig.isUiSecurityDisabled()) { + this.pageContext.getOut().write(TagLibConfig.getSecuredUiPrefix()); } - - if (var != null) { - pageContext.setAttribute(var, authorized, PageContext.PAGE_SCOPE); + if (this.var != null) { + this.pageContext.setAttribute(this.var, this.authorized, PageContext.PAGE_SCOPE); } - - return TagLibConfig.evalOrSkip(authorized); - + return TagLibConfig.evalOrSkip(this.authorized); } - catch (IOException e) { - throw new JspException(e); + catch (IOException ex) { + throw new JspException(ex); } } @Override - protected EvaluationContext createExpressionEvaluationContext( - SecurityExpressionHandler handler) { - return new PageContextVariableLookupEvaluationContext( - super.createExpressionEvaluationContext(handler)); + protected EvaluationContext createExpressionEvaluationContext(SecurityExpressionHandler handler) { + return new PageContextVariableLookupEvaluationContext(super.createExpressionEvaluationContext(handler)); } /** * Default processing of the end tag returning EVAL_PAGE. - * * @return EVAL_PAGE * @see Tag#doEndTag() */ + @Override public int doEndTag() throws JspException { try { - if (!authorized && TagLibConfig.isUiSecurityDisabled()) { - pageContext.getOut().write(TagLibConfig.getSecuredUiSuffix()); + if (!this.authorized && TagLibConfig.isUiSecurityDisabled()) { + this.pageContext.getOut().write(TagLibConfig.getSecuredUiSuffix()); } } - catch (IOException e) { - throw new JspException(e); + catch (IOException ex) { + throw new JspException(ex); } - return EVAL_PAGE; } public String getId() { - return id; + return this.id; } public void setId(String id) { this.id = id; } + @Override public Tag getParent() { - return parent; + return this.parent; } + @Override public void setParent(Tag parent) { this.parent = parent; } public String getVar() { - return var; + return this.var; } public void setVar(String var) { this.var = var; } + @Override public void release() { - parent = null; - id = null; + this.parent = null; + this.id = null; } + @Override public void setPageContext(PageContext pageContext) { this.pageContext = pageContext; } @Override protected ServletRequest getRequest() { - return pageContext.getRequest(); + return this.pageContext.getRequest(); } @Override protected ServletResponse getResponse() { - return pageContext.getResponse(); + return this.pageContext.getResponse(); } @Override protected ServletContext getServletContext() { - return pageContext.getServletContext(); + return this.pageContext.getServletContext(); } - private final class PageContextVariableLookupEvaluationContext implements - EvaluationContext { + private final class PageContextVariableLookupEvaluationContext implements EvaluationContext { private EvaluationContext delegate; @@ -167,54 +164,65 @@ public class JspAuthorizeTag extends AbstractAuthorizeTag implements Tag { this.delegate = delegate; } + @Override public TypedValue getRootObject() { - return delegate.getRootObject(); + return this.delegate.getRootObject(); } + @Override public List getConstructorResolvers() { - return delegate.getConstructorResolvers(); + return this.delegate.getConstructorResolvers(); } + @Override public List getMethodResolvers() { - return delegate.getMethodResolvers(); + return this.delegate.getMethodResolvers(); } + @Override public List getPropertyAccessors() { - return delegate.getPropertyAccessors(); + return this.delegate.getPropertyAccessors(); } + @Override public TypeLocator getTypeLocator() { - return delegate.getTypeLocator(); + return this.delegate.getTypeLocator(); } + @Override public TypeConverter getTypeConverter() { - return delegate.getTypeConverter(); + return this.delegate.getTypeConverter(); } + @Override public TypeComparator getTypeComparator() { - return delegate.getTypeComparator(); + return this.delegate.getTypeComparator(); } + @Override public OperatorOverloader getOperatorOverloader() { - return delegate.getOperatorOverloader(); + return this.delegate.getOperatorOverloader(); } + @Override public BeanResolver getBeanResolver() { - return delegate.getBeanResolver(); + return this.delegate.getBeanResolver(); } + @Override public void setVariable(String name, Object value) { - delegate.setVariable(name, value); + this.delegate.setVariable(name, value); } + @Override public Object lookupVariable(String name) { - Object result = delegate.lookupVariable(name); - + Object result = this.delegate.lookupVariable(name); if (result == null) { - result = pageContext.findAttribute(name); + result = JspAuthorizeTag.this.pageContext.findAttribute(name); } return result; } + } } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/package-info.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/package-info.java index b0d0ce4388..c1d65c388d 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/package-info.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * JSP Security tag library implementation. */ package org.springframework.security.taglibs.authz; - diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/csrf/AbstractCsrfTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/csrf/AbstractCsrfTag.java index aff3d5f936..d9ad9145d2 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/csrf/AbstractCsrfTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/csrf/AbstractCsrfTag.java @@ -16,36 +16,35 @@ package org.springframework.security.taglibs.csrf; -import org.springframework.security.web.csrf.CsrfToken; +import java.io.IOException; import javax.servlet.jsp.JspException; import javax.servlet.jsp.tagext.TagSupport; -import java.io.IOException; + +import org.springframework.security.web.csrf.CsrfToken; /** * An abstract tag for handling CSRF operations. * - * @since 3.2.2 * @author Nick Williams + * @since 3.2.2 */ abstract class AbstractCsrfTag extends TagSupport { @Override public int doEndTag() throws JspException { - - CsrfToken token = (CsrfToken) this.pageContext.getRequest().getAttribute( - CsrfToken.class.getName()); + CsrfToken token = (CsrfToken) this.pageContext.getRequest().getAttribute(CsrfToken.class.getName()); if (token != null) { try { this.pageContext.getOut().write(this.handleToken(token)); } - catch (IOException e) { - throw new JspException(e); + catch (IOException ex) { + throw new JspException(ex); } } - return EVAL_PAGE; } protected abstract String handleToken(CsrfToken token); + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfInputTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfInputTag.java index 1b339a8b34..799c0eab6b 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfInputTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfInputTag.java @@ -22,14 +22,14 @@ import org.springframework.security.web.csrf.CsrfToken; * A JSP tag that prints out a hidden form field for the CSRF token. See the JSP Tab * Library documentation for more information. * - * @since 3.2.2 * @author Nick Williams + * @since 3.2.2 */ public class CsrfInputTag extends AbstractCsrfTag { @Override public String handleToken(CsrfToken token) { - return ""; + return ""; } + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTag.java index 11ad6d0a3b..389dd3d51c 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTag.java @@ -22,16 +22,16 @@ import org.springframework.security.web.csrf.CsrfToken; * A JSP tag that prints out a meta tags holding the CSRF form field name and token value * for use in JavaScrip code. See the JSP Tab Library documentation for more information. * - * @since 3.2.2 * @author Nick Williams + * @since 3.2.2 */ public class CsrfMetaTagsTag extends AbstractCsrfTag { @Override public String handleToken(CsrfToken token) { - return "" + "" + ""; + return "" + + "" + + ""; } + } diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/package-info.java b/taglibs/src/main/java/org/springframework/security/taglibs/package-info.java index 719ba06d3e..de2fd6e483 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/package-info.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/package-info.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Security related tag libraries that can be used in JSPs and templates. */ package org.springframework.security.taglibs; - diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/TldTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/TldTests.java index 44176488a9..37fad7c6c3 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/TldTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/TldTests.java @@ -16,34 +16,29 @@ package org.springframework.security.taglibs; -import org.junit.Test; -import org.w3c.dom.Document; +import java.io.File; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; -import java.io.File; + +import org.junit.Test; +import org.w3c.dom.Document; import static org.assertj.core.api.Assertions.assertThat; public class TldTests { - //SEC-2324 + // SEC-2324 @Test - public void testTldVersionIsCorrect() throws Exception{ + public void testTldVersionIsCorrect() throws Exception { String SPRING_SECURITY_VERSION = "springSecurityVersion"; - String version = System.getProperty(SPRING_SECURITY_VERSION); - File securityTld = new File("src/main/resources/META-INF/security.tld"); - DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); DocumentBuilder documentBuilder = documentBuilderFactory.newDocumentBuilder(); Document document = documentBuilder.parse(securityTld); - String tlibVersion = document.getElementsByTagName("tlib-version").item(0).getTextContent(); - assertThat(version).startsWith(tlibVersion); } - } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java index 7a7497bd94..d203c3b017 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java @@ -13,16 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.authz; -import static org.assertj.core.api.Assertions.*; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import java.io.IOException; import java.util.Collections; @@ -33,6 +26,7 @@ import javax.servlet.ServletResponse; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockServletContext; @@ -44,23 +38,33 @@ import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; import org.springframework.security.web.access.expression.DefaultWebSecurityExpressionHandler; import org.springframework.web.context.WebApplicationContext; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + /** - * * @author Rob Winch * */ public class AbstractAuthorizeTagTests { + private AbstractAuthorizeTag tag; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockServletContext servletContext; @Before public void setup() { - tag = new AuthzTag(); - request = new MockHttpServletRequest(); - response = new MockHttpServletResponse(); - servletContext = new MockServletContext(); + this.tag = new AuthzTag(); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.servletContext = new MockServletContext(); } @After @@ -72,12 +76,9 @@ public class AbstractAuthorizeTagTests { public void privilegeEvaluatorFromRequest() throws IOException { String uri = "/something"; WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class); - tag.setUrl(uri); - request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, - expected); - - tag.authorizeUsingUrlCheck(); - + this.tag.setUrl(uri); + this.request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, expected); + this.tag.authorizeUsingUrlCheck(); verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any()); } @@ -85,13 +86,12 @@ public class AbstractAuthorizeTagTests { public void privilegeEvaluatorFromChildContext() throws IOException { String uri = "/something"; WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class); - tag.setUrl(uri); + this.tag.setUrl(uri); WebApplicationContext wac = mock(WebApplicationContext.class); - when(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class)).thenReturn(Collections.singletonMap("wipe", expected)); - servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); - - tag.authorizeUsingUrlCheck(); - + given(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class)) + .willReturn(Collections.singletonMap("wipe", expected)); + this.servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + this.tag.authorizeUsingUrlCheck(); verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any()); } @@ -100,29 +100,31 @@ public class AbstractAuthorizeTagTests { public void expressionFromChildContext() throws IOException { SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "USER")); DefaultWebSecurityExpressionHandler expected = new DefaultWebSecurityExpressionHandler(); - tag.setAccess("permitAll"); + this.tag.setAccess("permitAll"); WebApplicationContext wac = mock(WebApplicationContext.class); - when(wac.getBeansOfType(SecurityExpressionHandler.class)).thenReturn(Collections.singletonMap("wipe", expected)); - servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); - - assertThat(tag.authorize()).isTrue(); + given(wac.getBeansOfType(SecurityExpressionHandler.class)) + .willReturn(Collections.singletonMap("wipe", expected)); + this.servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + assertThat(this.tag.authorize()).isTrue(); } private class AuthzTag extends AbstractAuthorizeTag { @Override protected ServletRequest getRequest() { - return request; + return AbstractAuthorizeTagTests.this.request; } @Override protected ServletResponse getResponse() { - return response; + return AbstractAuthorizeTagTests.this.response; } @Override protected ServletContext getServletContext() { - return servletContext; + return AbstractAuthorizeTagTests.this.servletContext; } + } + } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java index cba9d213ea..84d562ddb9 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java @@ -13,12 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.authz; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import java.util.HashMap; +import java.util.Map; + +import javax.servlet.ServletContext; +import javax.servlet.jsp.tagext.Tag; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; -import org.junit.*; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockPageContext; @@ -29,42 +36,42 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.context.WebApplicationContext; -import javax.servlet.ServletContext; -import javax.servlet.jsp.tagext.Tag; -import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; /** - * * @author Luke Taylor * @author Rob Winch * @since 3.0 */ @SuppressWarnings("unchecked") public class AccessControlListTagTests { + AccessControlListTag tag; + PermissionEvaluator pe; + MockPageContext pageContext; + Authentication bob = new TestingAuthenticationToken("bob", "bobspass", "A"); @Before @SuppressWarnings("rawtypes") public void setup() { - SecurityContextHolder.getContext().setAuthentication(bob); - tag = new AccessControlListTag(); + SecurityContextHolder.getContext().setAuthentication(this.bob); + this.tag = new AccessControlListTag(); WebApplicationContext ctx = mock(WebApplicationContext.class); - - pe = mock(PermissionEvaluator.class); - + this.pe = mock(PermissionEvaluator.class); Map beanMap = new HashMap(); - beanMap.put("pe", pe); - when(ctx.getBeansOfType(PermissionEvaluator.class)).thenReturn(beanMap); - + beanMap.put("pe", this.pe); + given(ctx.getBeansOfType(PermissionEvaluator.class)).willReturn(beanMap); MockServletContext servletCtx = new MockServletContext(); - servletCtx.setAttribute( - WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); - pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), - new MockHttpServletResponse()); - tag.setPageContext(pageContext); + servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); + this.pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse()); + this.tag.setPageContext(this.pageContext); } @After @@ -75,108 +82,96 @@ public class AccessControlListTagTests { @Test public void bodyIsEvaluatedIfAclGrantsAccess() throws Exception { Object domainObject = new Object(); - when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(true); - - tag.setDomainObject(domainObject); - tag.setHasPermission("READ"); - tag.setVar("allowed"); - assertThat(tag.getDomainObject()).isSameAs(domainObject); - assertThat(tag.getHasPermission()).isEqualTo("READ"); - - assertThat(tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); - assertThat((Boolean) pageContext.getAttribute("allowed")).isTrue(); + given(this.pe.hasPermission(this.bob, domainObject, "READ")).willReturn(true); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("READ"); + this.tag.setVar("allowed"); + assertThat(this.tag.getDomainObject()).isSameAs(domainObject); + assertThat(this.tag.getHasPermission()).isEqualTo("READ"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); } @Test public void childContext() throws Exception { - ServletContext servletContext = pageContext.getServletContext(); + ServletContext servletContext = this.pageContext.getServletContext(); WebApplicationContext wac = (WebApplicationContext) servletContext .getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); servletContext.removeAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); - Object domainObject = new Object(); - when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(true); - - tag.setDomainObject(domainObject); - tag.setHasPermission("READ"); - tag.setVar("allowed"); - assertThat(tag.getDomainObject()).isSameAs(domainObject); - assertThat(tag.getHasPermission()).isEqualTo("READ"); - - assertThat(tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); - assertThat((Boolean) pageContext.getAttribute("allowed")).isTrue(); + given(this.pe.hasPermission(this.bob, domainObject, "READ")).willReturn(true); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("READ"); + this.tag.setVar("allowed"); + assertThat(this.tag.getDomainObject()).isSameAs(domainObject); + assertThat(this.tag.getHasPermission()).isEqualTo("READ"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); } // SEC-2022 @Test public void multiHasPermissionsAreSplit() throws Exception { Object domainObject = new Object(); - when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(true); - when(pe.hasPermission(bob, domainObject, "WRITE")).thenReturn(true); - - tag.setDomainObject(domainObject); - tag.setHasPermission("READ,WRITE"); - tag.setVar("allowed"); - assertThat(tag.getDomainObject()).isSameAs(domainObject); - assertThat(tag.getHasPermission()).isEqualTo("READ,WRITE"); - - assertThat(tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); - assertThat((Boolean) pageContext.getAttribute("allowed")).isTrue(); - verify(pe).hasPermission(bob, domainObject, "READ"); - verify(pe).hasPermission(bob, domainObject, "WRITE"); - verifyNoMoreInteractions(pe); + given(this.pe.hasPermission(this.bob, domainObject, "READ")).willReturn(true); + given(this.pe.hasPermission(this.bob, domainObject, "WRITE")).willReturn(true); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("READ,WRITE"); + this.tag.setVar("allowed"); + assertThat(this.tag.getDomainObject()).isSameAs(domainObject); + assertThat(this.tag.getHasPermission()).isEqualTo("READ,WRITE"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); + verify(this.pe).hasPermission(this.bob, domainObject, "READ"); + verify(this.pe).hasPermission(this.bob, domainObject, "WRITE"); + verifyNoMoreInteractions(this.pe); } // SEC-2023 @Test public void hasPermissionsBitMaskSupported() throws Exception { Object domainObject = new Object(); - when(pe.hasPermission(bob, domainObject, 1)).thenReturn(true); - when(pe.hasPermission(bob, domainObject, 2)).thenReturn(true); - - tag.setDomainObject(domainObject); - tag.setHasPermission("1,2"); - tag.setVar("allowed"); - assertThat(tag.getDomainObject()).isSameAs(domainObject); - assertThat(tag.getHasPermission()).isEqualTo("1,2"); - - assertThat(tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); - assertThat((Boolean) pageContext.getAttribute("allowed")).isTrue(); - verify(pe).hasPermission(bob, domainObject, 1); - verify(pe).hasPermission(bob, domainObject, 2); - verifyNoMoreInteractions(pe); + given(this.pe.hasPermission(this.bob, domainObject, 1)).willReturn(true); + given(this.pe.hasPermission(this.bob, domainObject, 2)).willReturn(true); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("1,2"); + this.tag.setVar("allowed"); + assertThat(this.tag.getDomainObject()).isSameAs(domainObject); + assertThat(this.tag.getHasPermission()).isEqualTo("1,2"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); + verify(this.pe).hasPermission(this.bob, domainObject, 1); + verify(this.pe).hasPermission(this.bob, domainObject, 2); + verifyNoMoreInteractions(this.pe); } @Test public void hasPermissionsMixedBitMaskSupported() throws Exception { Object domainObject = new Object(); - when(pe.hasPermission(bob, domainObject, 1)).thenReturn(true); - when(pe.hasPermission(bob, domainObject, "WRITE")).thenReturn(true); - - tag.setDomainObject(domainObject); - tag.setHasPermission("1,WRITE"); - tag.setVar("allowed"); - assertThat(tag.getDomainObject()).isSameAs(domainObject); - assertThat(tag.getHasPermission()).isEqualTo("1,WRITE"); - - assertThat(tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); - assertThat((Boolean) pageContext.getAttribute("allowed")).isTrue(); - verify(pe).hasPermission(bob, domainObject, 1); - verify(pe).hasPermission(bob, domainObject, "WRITE"); - verifyNoMoreInteractions(pe); + given(this.pe.hasPermission(this.bob, domainObject, 1)).willReturn(true); + given(this.pe.hasPermission(this.bob, domainObject, "WRITE")).willReturn(true); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("1,WRITE"); + this.tag.setVar("allowed"); + assertThat(this.tag.getDomainObject()).isSameAs(domainObject); + assertThat(this.tag.getHasPermission()).isEqualTo("1,WRITE"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); + verify(this.pe).hasPermission(this.bob, domainObject, 1); + verify(this.pe).hasPermission(this.bob, domainObject, "WRITE"); + verifyNoMoreInteractions(this.pe); } @Test public void bodyIsSkippedIfAclDeniesAccess() throws Exception { Object domainObject = new Object(); - when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(false); - - tag.setDomainObject(domainObject); - tag.setHasPermission("READ"); - tag.setVar("allowed"); - - assertThat(tag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat((Boolean) pageContext.getAttribute("allowed")).isFalse(); + given(this.pe.hasPermission(this.bob, domainObject, "READ")).willReturn(false); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("READ"); + this.tag.setVar("allowed"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isFalse(); } + } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java index d768b992a6..555052ff36 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java @@ -16,35 +16,33 @@ package org.springframework.security.taglibs.authz; -import static org.assertj.core.api.Assertions.*; - import javax.servlet.jsp.JspException; import javax.servlet.jsp.tagext.Tag; import org.junit.After; import org.junit.Test; + import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + /** * Tests {@link AuthenticationTag}. * * @author Ben Alex */ public class AuthenticationTagTests { - // ~ Instance fields - // ================================================================================================ private final MyAuthenticationTag authenticationTag = new MyAuthenticationTag(); - private final Authentication auth = new TestingAuthenticationToken(new User( - "rodUserDetails", "koala", true, true, true, true, - AuthorityUtils.NO_AUTHORITIES), "koala", AuthorityUtils.NO_AUTHORITIES); - // ~ Methods - // ======================================================================================================== + private final Authentication auth = new TestingAuthenticationToken( + new User("rodUserDetails", "koala", true, true, true, true, AuthorityUtils.NO_AUTHORITIES), "koala", + AuthorityUtils.NO_AUTHORITIES); @After public void tearDown() { @@ -53,72 +51,64 @@ public class AuthenticationTagTests { @Test public void testOperationWhenPrincipalIsAUserDetailsInstance() throws JspException { - SecurityContextHolder.getContext().setAuthentication(auth); - - authenticationTag.setProperty("name"); - assertThat(authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat(authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); - assertThat(authenticationTag.getLastMessage()).isEqualTo("rodUserDetails"); + SecurityContextHolder.getContext().setAuthentication(this.auth); + this.authenticationTag.setProperty("name"); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + assertThat(this.authenticationTag.getLastMessage()).isEqualTo("rodUserDetails"); } @Test public void testOperationWhenPrincipalIsAString() throws JspException { SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("rodAsString", "koala", - AuthorityUtils.NO_AUTHORITIES)); - - authenticationTag.setProperty("principal"); - assertThat(authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat(authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); - assertThat(authenticationTag.getLastMessage()).isEqualTo("rodAsString"); + new TestingAuthenticationToken("rodAsString", "koala", AuthorityUtils.NO_AUTHORITIES)); + this.authenticationTag.setProperty("principal"); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + assertThat(this.authenticationTag.getLastMessage()).isEqualTo("rodAsString"); } @Test public void testNestedPropertyIsReadCorrectly() throws JspException { - SecurityContextHolder.getContext().setAuthentication(auth); - - authenticationTag.setProperty("principal.username"); - assertThat(authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat(authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); - assertThat(authenticationTag.getLastMessage()).isEqualTo("rodUserDetails"); + SecurityContextHolder.getContext().setAuthentication(this.auth); + this.authenticationTag.setProperty("principal.username"); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + assertThat(this.authenticationTag.getLastMessage()).isEqualTo("rodUserDetails"); } @Test public void testOperationWhenPrincipalIsNull() throws JspException { - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken(null, "koala", - AuthorityUtils.NO_AUTHORITIES)); - - authenticationTag.setProperty("principal"); - assertThat(authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat(authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + SecurityContextHolder.getContext() + .setAuthentication(new TestingAuthenticationToken(null, "koala", AuthorityUtils.NO_AUTHORITIES)); + this.authenticationTag.setProperty("principal"); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); } @Test public void testOperationWhenSecurityContextIsNull() throws Exception { SecurityContextHolder.getContext().setAuthentication(null); - - authenticationTag.setProperty("principal"); - assertThat(authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat(authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); - assertThat(authenticationTag.getLastMessage()).isNull(); + this.authenticationTag.setProperty("principal"); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + assertThat(this.authenticationTag.getLastMessage()).isNull(); } @Test public void testSkipsBodyIfNullOrEmptyOperation() throws Exception { - authenticationTag.setProperty(""); - assertThat(authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); - assertThat(authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + this.authenticationTag.setProperty(""); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); } @Test public void testThrowsExceptionForUnrecognisedProperty() { - SecurityContextHolder.getContext().setAuthentication(auth); - authenticationTag.setProperty("qsq"); - + SecurityContextHolder.getContext().setAuthentication(this.auth); + this.authenticationTag.setProperty("qsq"); try { - authenticationTag.doStartTag(); - authenticationTag.doEndTag(); + this.authenticationTag.doStartTag(); + this.authenticationTag.doEndTag(); fail("Should have throwns JspException"); } catch (JspException expected) { @@ -127,37 +117,36 @@ public class AuthenticationTagTests { @Test public void htmlEscapingIsUsedByDefault() throws Exception { - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("<>& ", "")); - authenticationTag.setProperty("name"); - authenticationTag.doStartTag(); - authenticationTag.doEndTag(); - assertThat(authenticationTag.getLastMessage()).isEqualTo("<>& "); + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("<>& ", "")); + this.authenticationTag.setProperty("name"); + this.authenticationTag.doStartTag(); + this.authenticationTag.doEndTag(); + assertThat(this.authenticationTag.getLastMessage()).isEqualTo("<>& "); } @Test public void settingHtmlEscapeToFalsePreventsEscaping() throws Exception { - SecurityContextHolder.getContext().setAuthentication( - new TestingAuthenticationToken("<>& ", "")); - authenticationTag.setProperty("name"); - authenticationTag.setHtmlEscape("false"); - authenticationTag.doStartTag(); - authenticationTag.doEndTag(); - assertThat(authenticationTag.getLastMessage()).isEqualTo("<>& "); + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("<>& ", "")); + this.authenticationTag.setProperty("name"); + this.authenticationTag.setHtmlEscape("false"); + this.authenticationTag.doStartTag(); + this.authenticationTag.doEndTag(); + assertThat(this.authenticationTag.getLastMessage()).isEqualTo("<>& "); } - // ~ Inner Classes - // ================================================================================================== - private class MyAuthenticationTag extends AuthenticationTag { + String lastMessage = null; - public String getLastMessage() { - return lastMessage; + String getLastMessage() { + return this.lastMessage; } + @Override protected void writeMessage(String msg) { - lastMessage = msg; + this.lastMessage = msg; } + } + } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthorizeTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthorizeTagTests.java index 25c8e926e3..5111ff448b 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthorizeTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthorizeTagTests.java @@ -16,9 +16,6 @@ package org.springframework.security.taglibs.authz; -import static org.mockito.Mockito.*; -import static org.assertj.core.api.Assertions.assertThat; - import javax.servlet.jsp.JspException; import javax.servlet.jsp.tagext.Tag; @@ -28,6 +25,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -42,43 +40,41 @@ import org.springframework.security.web.access.expression.DefaultWebSecurityExpr import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.StaticWebApplicationContext; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; + /** * @author Francois Beausoleil * @author Luke Taylor */ @RunWith(MockitoJUnitRunner.class) public class AuthorizeTagTests { - // ~ Instance fields - // ================================================================================================ @Mock private PermissionEvaluator permissionEvaluator; - private JspAuthorizeTag authorizeTag; - private MockHttpServletRequest request = new MockHttpServletRequest(); - private final TestingAuthenticationToken currentUser = new TestingAuthenticationToken( - "abc", "123", "ROLE SUPERVISOR", "ROLE_TELLER"); - // ~ Methods - // ======================================================================================================== + private JspAuthorizeTag authorizeTag; + + private MockHttpServletRequest request = new MockHttpServletRequest(); + + private final TestingAuthenticationToken currentUser = new TestingAuthenticationToken("abc", "123", + "ROLE SUPERVISOR", "ROLE_TELLER"); @Before public void setUp() { - SecurityContextHolder.getContext().setAuthentication(currentUser); + SecurityContextHolder.getContext().setAuthentication(this.currentUser); StaticWebApplicationContext ctx = new StaticWebApplicationContext(); - BeanDefinitionBuilder webExpressionHandler = BeanDefinitionBuilder .rootBeanDefinition(DefaultWebSecurityExpressionHandler.class); - webExpressionHandler.addPropertyValue("permissionEvaluator", permissionEvaluator); - - ctx.registerBeanDefinition("expressionHandler", - webExpressionHandler.getBeanDefinition()); + webExpressionHandler.addPropertyValue("permissionEvaluator", this.permissionEvaluator); + ctx.registerBeanDefinition("expressionHandler", webExpressionHandler.getBeanDefinition()); ctx.registerSingleton("wipe", MockWebInvocationPrivilegeEvaluator.class); MockServletContext servletCtx = new MockServletContext(); - servletCtx.setAttribute( - WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); - authorizeTag = new JspAuthorizeTag(); - authorizeTag.setPageContext(new MockPageContext(servletCtx, request, - new MockHttpServletResponse())); + servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); + this.authorizeTag = new JspAuthorizeTag(); + this.authorizeTag.setPageContext(new MockPageContext(servletCtx, this.request, new MockHttpServletResponse())); } @After @@ -87,83 +83,81 @@ public class AuthorizeTagTests { } // access attribute tests - @Test public void taglibsDocumentationHasPermissionOr() throws Exception { Object domain = new Object(); - request.setAttribute("domain", domain); - authorizeTag - .setAccess("hasPermission(#domain,'read') or hasPermission(#domain,'write')"); - when(permissionEvaluator.hasPermission(eq(currentUser), eq(domain), anyString())) - .thenReturn(true); - - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + this.request.setAttribute("domain", domain); + this.authorizeTag.setAccess("hasPermission(#domain,'read') or hasPermission(#domain,'write')"); + given(this.permissionEvaluator.hasPermission(eq(this.currentUser), eq(domain), anyString())).willReturn(true); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); } @Test public void skipsBodyIfNoAuthenticationPresent() throws Exception { SecurityContextHolder.clearContext(); - authorizeTag.setAccess("permitAll"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + this.authorizeTag.setAccess("permitAll"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); } @Test public void skipsBodyIfAccessExpressionDeniesAccess() throws Exception { - authorizeTag.setAccess("denyAll"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + this.authorizeTag.setAccess("denyAll"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); } @Test public void showsBodyIfAccessExpressionAllowsAccess() throws Exception { - authorizeTag.setAccess("permitAll"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + this.authorizeTag.setAccess("permitAll"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); } @Test public void requestAttributeIsResolvedAsElVariable() throws JspException { - request.setAttribute("blah", "blah"); - authorizeTag.setAccess("#blah == 'blah'"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + this.request.setAttribute("blah", "blah"); + this.authorizeTag.setAccess("#blah == 'blah'"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); } // url attribute tests @Test public void skipsBodyWithUrlSetIfNoAuthenticationPresent() throws Exception { SecurityContextHolder.clearContext(); - authorizeTag.setUrl("/something"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + this.authorizeTag.setUrl("/something"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); } @Test public void skipsBodyIfUrlIsNotAllowed() throws Exception { - authorizeTag.setUrl("/notallowed"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + this.authorizeTag.setUrl("/notallowed"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); } @Test public void evaluatesBodyIfUrlIsAllowed() throws Exception { - authorizeTag.setUrl("/allowed"); - authorizeTag.setMethod("GET"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + this.authorizeTag.setUrl("/allowed"); + this.authorizeTag.setMethod("GET"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); } @Test public void skipsBodyIfMethodIsNotAllowed() throws Exception { - authorizeTag.setUrl("/allowed"); - authorizeTag.setMethod("POST"); - assertThat(authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + this.authorizeTag.setUrl("/allowed"); + this.authorizeTag.setMethod("POST"); + assertThat(this.authorizeTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); } - public static class MockWebInvocationPrivilegeEvaluator implements - WebInvocationPrivilegeEvaluator { + public static class MockWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator { + @Override public boolean isAllowed(String uri, Authentication authentication) { return "/allowed".equals(uri); } - public boolean isAllowed(String contextPath, String uri, String method, - Authentication authentication) { + @Override + public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) { return "/allowed".equals(uri) && (method == null || "GET".equals(method)); } + } + } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/csrf/AbstractCsrfTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/csrf/AbstractCsrfTagTests.java index f34050034a..51727e71dd 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/csrf/AbstractCsrfTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/csrf/AbstractCsrfTagTests.java @@ -13,10 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.csrf; +import java.io.UnsupportedEncodingException; + +import javax.servlet.jsp.JspException; +import javax.servlet.jsp.tagext.Tag; + import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockPageContext; @@ -24,12 +31,7 @@ import org.springframework.mock.web.MockServletContext; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; -import javax.servlet.jsp.JspException; -import javax.servlet.jsp.tagext.TagSupport; - -import java.io.UnsupportedEncodingException; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Nick Williams @@ -37,7 +39,9 @@ import static org.assertj.core.api.Assertions.*; public class AbstractCsrfTagTests { public MockTag tag; + private MockHttpServletRequest request; + private MockHttpServletResponse response; @Before @@ -45,60 +49,48 @@ public class AbstractCsrfTagTests { MockServletContext servletContext = new MockServletContext(); this.request = new MockHttpServletRequest(servletContext); this.response = new MockHttpServletResponse(); - MockPageContext pageContext = new MockPageContext(servletContext, this.request, - this.response); + MockPageContext pageContext = new MockPageContext(servletContext, this.request, this.response); this.tag = new MockTag(); this.tag.setPageContext(pageContext); } @Test public void noCsrfDoesNotRender() throws JspException, UnsupportedEncodingException { - this.tag.handleReturn = "shouldNotBeRendered"; - int returned = this.tag.doEndTag(); - - assertThat(returned).as("The returned value is not correct.").isEqualTo(TagSupport.EVAL_PAGE); - assertThat(this.response.getContentAsString()).withFailMessage("The output value is not correct.").isEqualTo(""); + assertThat(returned).as("The returned value is not correct.").isEqualTo(Tag.EVAL_PAGE); + assertThat(this.response.getContentAsString()).withFailMessage("The output value is not correct.") + .isEqualTo(""); } @Test - public void hasCsrfRendersReturnedValue() throws JspException, - UnsupportedEncodingException { - - CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", - "abc123def456ghi789"); + public void hasCsrfRendersReturnedValue() throws JspException, UnsupportedEncodingException { + CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", "abc123def456ghi789"); this.request.setAttribute(CsrfToken.class.getName(), token); - this.tag.handleReturn = "fooBarBazQux"; - int returned = this.tag.doEndTag(); - - assertThat(returned).as("The returned value is not correct.").isEqualTo(TagSupport.EVAL_PAGE); - assertThat(this.response.getContentAsString()).withFailMessage("The output value is not correct.").isEqualTo("fooBarBazQux"); + assertThat(returned).as("The returned value is not correct.").isEqualTo(Tag.EVAL_PAGE); + assertThat(this.response.getContentAsString()).withFailMessage("The output value is not correct.") + .isEqualTo("fooBarBazQux"); assertThat(this.tag.token).as("The token is not correct.").isSameAs(token); } @Test - public void hasCsrfRendersDifferentValue() throws JspException, - UnsupportedEncodingException { - - CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", - "abc123def456ghi789"); + public void hasCsrfRendersDifferentValue() throws JspException, UnsupportedEncodingException { + CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", "abc123def456ghi789"); this.request.setAttribute(CsrfToken.class.getName(), token); - this.tag.handleReturn = ""; - int returned = this.tag.doEndTag(); - - assertThat(returned).as("The returned value is not correct.").isEqualTo(TagSupport.EVAL_PAGE); - assertThat(this.response.getContentAsString()).withFailMessage("The output value is not correct.").isEqualTo(""); + assertThat(returned).as("The returned value is not correct.").isEqualTo(Tag.EVAL_PAGE); + assertThat(this.response.getContentAsString()).withFailMessage("The output value is not correct.") + .isEqualTo(""); assertThat(this.tag.token).as("The token is not correct.").isSameAs(token); } private static class MockTag extends AbstractCsrfTag { private CsrfToken token; + private String handleReturn; @Override @@ -106,5 +98,7 @@ public class AbstractCsrfTagTests { this.token = token; return this.handleReturn; } + } + } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfInputTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfInputTagTests.java index 4f68345705..aa9b84251a 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfInputTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfInputTagTests.java @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.csrf; import org.junit.Before; import org.junit.Test; + import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Nick Williams @@ -36,24 +38,20 @@ public class CsrfInputTagTests { @Test public void handleTokenReturnsHiddenInput() { - CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", - "abc123def456ghi789"); - + CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", "abc123def456ghi789"); String value = this.tag.handleToken(token); - assertThat(value).as("The returned value should not be null.").isNotNull(); - assertThat( - value).withFailMessage("The output is not correct.").isEqualTo(""); + assertThat(value).withFailMessage("The output is not correct.") + .isEqualTo(""); } @Test public void handleTokenReturnsHiddenInputDifferentTokenValue() { - CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "csrfParameter", - "fooBarBazQux"); - + CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "csrfParameter", "fooBarBazQux"); String value = this.tag.handleToken(token); - assertThat(value).as("The returned value should not be null.").isNotNull(); - assertThat(value).withFailMessage("The output is not correct.").isEqualTo(""); + assertThat(value).withFailMessage("The output is not correct.") + .isEqualTo(""); } + } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTagTests.java index f71f19e9ca..6e07a33e72 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/csrf/CsrfMetaTagsTagTests.java @@ -13,14 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.taglibs.csrf; import org.junit.Before; import org.junit.Test; + import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Nick Williams @@ -36,27 +38,24 @@ public class CsrfMetaTagsTagTests { @Test public void handleTokenRendersTags() { - CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", - "abc123def456ghi789"); - + CsrfToken token = new DefaultCsrfToken("X-Csrf-Token", "_csrf", "abc123def456ghi789"); String value = this.tag.handleToken(token); - assertThat(value).as("The returned value should not be null.").isNotNull(); - assertThat(value).withFailMessage("The output is not correct.").isEqualTo("" - + "" - + ""); + assertThat(value).withFailMessage("The output is not correct.") + .isEqualTo("" + + "" + + ""); } @Test public void handleTokenRendersTagsDifferentToken() { - CsrfToken token = new DefaultCsrfToken("csrfHeader", "csrfParameter", - "fooBarBazQux"); - + CsrfToken token = new DefaultCsrfToken("csrfHeader", "csrfParameter", "fooBarBazQux"); String value = this.tag.handleToken(token); - assertThat(value).as("The returned value should not be null.").isNotNull(); - assertThat(value).withFailMessage("The output is not correct.").isEqualTo("" - + "" - + ""); + assertThat(value).withFailMessage("The output is not correct.") + .isEqualTo("" + + "" + + ""); } + } diff --git a/test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolder.java b/test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolder.java index fc22058465..f62938d499 100644 --- a/test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolder.java +++ b/test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolder.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context; import javax.servlet.FilterChain; @@ -46,7 +47,8 @@ import org.springframework.util.Assert; * *
      14. The test is ran. When used with {@link MockMvc} it is typically used with * {@link SecurityMockMvcRequestPostProcessors#testSecurityContext()}. Which ensures the - * {@link SecurityContext} from {@link TestSecurityContextHolder} is properly populated.
      15. + * {@link SecurityContext} from {@link TestSecurityContextHolder} is properly + * populated. *
      16. After the test is executed, the {@link TestSecurityContextHolder} and the * {@link SecurityContextHolder} are cleared out
      17. * @@ -54,12 +56,14 @@ import org.springframework.util.Assert; * @author Rob Winch * @author Tadaya Tsuyukubo * @since 4.0 - * */ public final class TestSecurityContextHolder { private static final ThreadLocal contextHolder = new ThreadLocal<>(); + private TestSecurityContextHolder() { + } + /** * Clears the {@link SecurityContext} from {@link TestSecurityContextHolder} and * {@link SecurityContextHolder}. @@ -71,17 +75,14 @@ public final class TestSecurityContextHolder { /** * Gets the {@link SecurityContext} from {@link TestSecurityContextHolder}. - * * @return the {@link SecurityContext} from {@link TestSecurityContextHolder}. */ public static SecurityContext getContext() { SecurityContext ctx = contextHolder.get(); - if (ctx == null) { ctx = getDefaultContext(); contextHolder.set(ctx); } - return ctx; } @@ -97,10 +98,9 @@ public final class TestSecurityContextHolder { } /** - * Creates a new {@link SecurityContext} with the given {@link Authentication}. - * The {@link SecurityContext} is set on {@link TestSecurityContextHolder} and + * Creates a new {@link SecurityContext} with the given {@link Authentication}. The + * {@link SecurityContext} is set on {@link TestSecurityContextHolder} and * {@link SecurityContextHolder}. - * * @param authentication the {@link Authentication} to use * @since 5.1.1 */ @@ -114,13 +114,10 @@ public final class TestSecurityContextHolder { /** * Gets the default {@link SecurityContext} by delegating to the * {@link SecurityContextHolder} - * * @return the default {@link SecurityContext} */ private static SecurityContext getDefaultContext() { return SecurityContextHolder.getContext(); } - private TestSecurityContextHolder() { - } } diff --git a/test/src/main/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListeners.java b/test/src/main/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListeners.java index 0ae6271523..265f4fe1cd 100644 --- a/test/src/main/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListeners.java +++ b/test/src/main/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListeners.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.annotation; import java.lang.annotation.Documented; @@ -28,11 +29,10 @@ import org.springframework.security.test.context.support.WithSecurityContextTest import org.springframework.test.context.TestExecutionListeners; /** - * There are many times a user may want to use Spring Security's test support - * (i.e. WithMockUser) but have no need for any other - * {@link TestExecutionListeners} (i.e. no need to setup an - * {@link ApplicationContext}). This annotation is a meta annotation that only - * enables Spring Security's {@link TestExecutionListeners}. + * There are many times a user may want to use Spring Security's test support (i.e. + * WithMockUser) but have no need for any other {@link TestExecutionListeners} (i.e. no + * need to setup an {@link ApplicationContext}). This annotation is a meta annotation that + * only enables Spring Security's {@link TestExecutionListeners}. * * @author Rob Winch * @since 4.0.2 @@ -43,7 +43,8 @@ import org.springframework.test.context.TestExecutionListeners; @Inherited @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) -@TestExecutionListeners(inheritListeners = false, listeners = {WithSecurityContextTestExecutionListener.class, - ReactorContextTestExecutionListener.class}) +@TestExecutionListeners(inheritListeners = false, + listeners = { WithSecurityContextTestExecutionListener.class, ReactorContextTestExecutionListener.class }) public @interface SecurityTestExecutionListeners { + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/DelegatingTestExecutionListener.java b/test/src/main/java/org/springframework/security/test/context/support/DelegatingTestExecutionListener.java index 3b34dda6bf..a26db38020 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/DelegatingTestExecutionListener.java +++ b/test/src/main/java/org/springframework/security/test/context/support/DelegatingTestExecutionListener.java @@ -25,8 +25,7 @@ import org.springframework.util.Assert; * @author Rob Winch * @since 5.0 */ -class DelegatingTestExecutionListener - extends AbstractTestExecutionListener { +class DelegatingTestExecutionListener extends AbstractTestExecutionListener { private final TestExecutionListener delegate; @@ -37,36 +36,37 @@ class DelegatingTestExecutionListener @Override public void beforeTestClass(TestContext testContext) throws Exception { - delegate.beforeTestClass(testContext); + this.delegate.beforeTestClass(testContext); } @Override public void prepareTestInstance(TestContext testContext) throws Exception { - delegate.prepareTestInstance(testContext); + this.delegate.prepareTestInstance(testContext); } @Override public void beforeTestMethod(TestContext testContext) throws Exception { - delegate.beforeTestMethod(testContext); + this.delegate.beforeTestMethod(testContext); } @Override public void beforeTestExecution(TestContext testContext) throws Exception { - delegate.beforeTestExecution(testContext); + this.delegate.beforeTestExecution(testContext); } @Override public void afterTestExecution(TestContext testContext) throws Exception { - delegate.afterTestExecution(testContext); + this.delegate.afterTestExecution(testContext); } @Override public void afterTestMethod(TestContext testContext) throws Exception { - delegate.afterTestMethod(testContext); + this.delegate.afterTestMethod(testContext); } @Override public void afterTestClass(TestContext testContext) throws Exception { - delegate.afterTestClass(testContext); + this.delegate.afterTestClass(testContext); } + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java b/test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java index 0ad8200938..dd6dbad701 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java +++ b/test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java @@ -17,6 +17,12 @@ package org.springframework.security.test.context.support; import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; @@ -25,11 +31,6 @@ import org.springframework.test.context.TestContext; import org.springframework.test.context.TestExecutionListener; import org.springframework.test.context.support.AbstractTestExecutionListener; import org.springframework.util.ClassUtils; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Hooks; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Operators; -import reactor.util.context.Context; /** * Sets up the Reactor Context with the Authentication from the TestSecurityContextHolder @@ -40,10 +41,10 @@ import reactor.util.context.Context; * @see WithSecurityContextTestExecutionListener * @see org.springframework.security.test.context.annotation.SecurityTestExecutionListeners */ -public class ReactorContextTestExecutionListener - extends DelegatingTestExecutionListener { +public class ReactorContextTestExecutionListener extends DelegatingTestExecutionListener { private static final String HOOKS_CLASS_NAME = "reactor.core.publisher.Hooks"; + private static final String CONTEXT_OPERATOR_KEY = SecurityContext.class.getName(); public ReactorContextTestExecutionListener() { @@ -51,70 +52,11 @@ public class ReactorContextTestExecutionListener } private static TestExecutionListener createDelegate() { - return ClassUtils.isPresent(HOOKS_CLASS_NAME, ReactorContextTestExecutionListener.class.getClassLoader()) ? - new DelegateTestExecutionListener() : - new AbstractTestExecutionListener() {}; - } - - private static class DelegateTestExecutionListener extends AbstractTestExecutionListener { - @Override - public void beforeTestMethod(TestContext testContext) { - SecurityContext securityContext = TestSecurityContextHolder.getContext(); - Hooks.onLastOperator(CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> new SecuritySubContext<>(sub, securityContext))); - } - - @Override - public void afterTestMethod(TestContext testContext) { - Hooks.resetOnLastOperator(CONTEXT_OPERATOR_KEY); - } - - private static class SecuritySubContext implements CoreSubscriber { - private static String CONTEXT_DEFAULTED_ATTR_NAME = SecuritySubContext.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); - - private final CoreSubscriber delegate; - private final SecurityContext securityContext; - - SecuritySubContext(CoreSubscriber delegate, SecurityContext securityContext) { - this.delegate = delegate; - this.securityContext = securityContext; - } - - @Override - public Context currentContext() { - Context context = delegate.currentContext(); - if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) { - return context; - } - context = context.put(CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE); - Authentication authentication = securityContext.getAuthentication(); - if (authentication == null) { - return context; - } - Context toMerge = ReactiveSecurityContextHolder.withSecurityContext( - Mono.just(this.securityContext)); - return toMerge.putAll(context); - } - - @Override - public void onSubscribe(Subscription s) { - delegate.onSubscribe(s); - } - - @Override - public void onNext(T t) { - delegate.onNext(t); - } - - @Override - public void onError(Throwable t) { - delegate.onError(t); - } - - @Override - public void onComplete() { - delegate.onComplete(); - } + if (!ClassUtils.isPresent(HOOKS_CLASS_NAME, ReactorContextTestExecutionListener.class.getClassLoader())) { + return new AbstractTestExecutionListener() { + }; } + return new DelegateTestExecutionListener(); } /** @@ -124,4 +66,72 @@ public class ReactorContextTestExecutionListener public int getOrder() { return 11000; } + + private static class DelegateTestExecutionListener extends AbstractTestExecutionListener { + + @Override + public void beforeTestMethod(TestContext testContext) { + SecurityContext securityContext = TestSecurityContextHolder.getContext(); + Hooks.onLastOperator(CONTEXT_OPERATOR_KEY, + Operators.lift((s, sub) -> new SecuritySubContext<>(sub, securityContext))); + } + + @Override + public void afterTestMethod(TestContext testContext) { + Hooks.resetOnLastOperator(CONTEXT_OPERATOR_KEY); + } + + private static class SecuritySubContext implements CoreSubscriber { + + private static String CONTEXT_DEFAULTED_ATTR_NAME = SecuritySubContext.class.getName() + .concat(".CONTEXT_DEFAULTED_ATTR_NAME"); + + private final CoreSubscriber delegate; + + private final SecurityContext securityContext; + + SecuritySubContext(CoreSubscriber delegate, SecurityContext securityContext) { + this.delegate = delegate; + this.securityContext = securityContext; + } + + @Override + public Context currentContext() { + Context context = this.delegate.currentContext(); + if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) { + return context; + } + context = context.put(CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE); + Authentication authentication = this.securityContext.getAuthentication(); + if (authentication == null) { + return context; + } + Context toMerge = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(this.securityContext)); + return toMerge.putAll(context); + } + + @Override + public void onSubscribe(Subscription s) { + this.delegate.onSubscribe(s); + } + + @Override + public void onNext(T t) { + this.delegate.onNext(t); + } + + @Override + public void onError(Throwable ex) { + this.delegate.onError(ex); + } + + @Override + public void onComplete() { + this.delegate.onComplete(); + } + + } + + } + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/TestExecutionEvent.java b/test/src/main/java/org/springframework/security/test/context/support/TestExecutionEvent.java index 2051494549..6ea8ed13d0 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/TestExecutionEvent.java +++ b/test/src/main/java/org/springframework/security/test/context/support/TestExecutionEvent.java @@ -19,20 +19,26 @@ package org.springframework.security.test.context.support; import org.springframework.test.context.TestContext; /** - * Represents the events on the methods of {@link org.springframework.test.context.TestExecutionListener} + * Represents the events on the methods of + * {@link org.springframework.test.context.TestExecutionListener} * * @author Rob Winch * @since 5.1 */ public enum TestExecutionEvent { + /** - * Associated to {@link org.springframework.test.context.TestExecutionListener#beforeTestMethod(TestContext)} + * Associated to + * {@link org.springframework.test.context.TestExecutionListener#beforeTestMethod(TestContext)} * event. */ TEST_METHOD, + /** - * Associated to {@link org.springframework.test.context.TestExecutionListener#beforeTestExecution(TestContext)} + * Associated to + * {@link org.springframework.test.context.TestExecutionListener#beforeTestExecution(TestContext)} * event. */ TEST_EXECUTION + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUser.java b/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUser.java index 7ce99146e3..6cb49ac9f0 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUser.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Documented; @@ -28,12 +29,12 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.test.context.TestContext; /** - * When used with {@link WithSecurityContextTestExecutionListener} this - * annotation can be added to a test method to emulate running with an anonymous - * user. The {@link SecurityContext} that is used will contain an - * {@link AnonymousAuthenticationToken}. This is useful when a user wants to run - * a majority of tests as a specific user and wishes to override a few methods - * to be anonymous. For example: + * When used with {@link WithSecurityContextTestExecutionListener} this annotation can be + * added to a test method to emulate running with an anonymous user. The + * {@link SecurityContext} that is used will contain an + * {@link AnonymousAuthenticationToken}. This is useful when a user wants to run a + * majority of tests as a specific user and wishes to override a few methods to be + * anonymous. For example: * *
          * 
        @@ -47,8 +48,7 @@ import org.springframework.test.context.TestContext;
          *
          *     // ... lots of tests ran with a default user ...
          * }
        - * 
        - * 
        + * * * @author Rob Winch * @since 4.1 @@ -69,4 +69,5 @@ public @interface WithAnonymousUser { */ @AliasFor(annotation = WithSecurityContext.class) TestExecutionEvent setupBefore() default TestExecutionEvent.TEST_METHOD; + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUserSecurityContextFactory.java b/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUserSecurityContextFactory.java index 8d7cf4a043..d812fa9877 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUserSecurityContextFactory.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUserSecurityContextFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.util.List; @@ -25,18 +26,16 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; /** - * A {@link WithAnonymousUserSecurityContextFactory} that runs with an {@link AnonymousAuthenticationToken}. - * . - * - * @see WithUserDetails + * A {@link WithAnonymousUserSecurityContextFactory} that runs with an + * {@link AnonymousAuthenticationToken}. . * * @author Rob Winch * @since 4.1 + * @see WithUserDetails */ +final class WithAnonymousUserSecurityContextFactory implements WithSecurityContextFactory { -final class WithAnonymousUserSecurityContextFactory implements - WithSecurityContextFactory { - + @Override public SecurityContext createSecurityContext(WithAnonymousUser withUser) { List authorities = AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"); Authentication authentication = new AnonymousAuthenticationToken("key", "anonymous", authorities); @@ -44,4 +43,5 @@ final class WithAnonymousUserSecurityContextFactory implements context.setAuthentication(authentication); return context; } -} \ No newline at end of file + +} diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithMockUser.java b/test/src/main/java/org/springframework/security/test/context/support/WithMockUser.java index d7500ffcb3..c4c62244d2 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithMockUser.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithMockUser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Documented; @@ -56,6 +57,7 @@ import org.springframework.test.web.servlet.MockMvc; @Documented @WithSecurityContext(factory = WithMockUserSecurityContextFactory.class) public @interface WithMockUser { + /** * Convenience mechanism for specifying the username. The default is "user". If * {@link #username()} is specified it will be used instead of {@link #value()} @@ -78,9 +80,9 @@ public @interface WithMockUser { * with "ROLE_". For example, the default will result in "ROLE_USER" being used. *

        *

        - * If {@link #authorities()} is specified this property cannot be changed from the default. + * If {@link #authorities()} is specified this property cannot be changed from the + * default. *

        - * * @return */ String[] roles() default { "USER" }; @@ -94,7 +96,6 @@ public @interface WithMockUser { * If this property is specified then {@link #roles()} is not used. This differs from * {@link #roles()} in that it does not prefix the values passed in automatically. *

        - * * @return */ String[] authorities() default {}; @@ -114,4 +115,5 @@ public @interface WithMockUser { */ @AliasFor(annotation = WithSecurityContext.class) TestExecutionEvent setupBefore() default TestExecutionEvent.TEST_METHOD; + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java b/test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java index d2cc86d966..513723c1f1 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.util.ArrayList; @@ -26,6 +27,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; /** @@ -35,40 +37,32 @@ import org.springframework.util.StringUtils; * @since 4.0 * @see WithMockUser */ -final class WithMockUserSecurityContextFactory implements - WithSecurityContextFactory { +final class WithMockUserSecurityContextFactory implements WithSecurityContextFactory { + @Override public SecurityContext createSecurityContext(WithMockUser withUser) { - String username = StringUtils.hasLength(withUser.username()) ? withUser - .username() : withUser.value(); - if (username == null) { - throw new IllegalArgumentException(withUser - + " cannot have null username on both username and value properties"); - } - + String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value(); + Assert.notNull(username, () -> withUser + " cannot have null username on both username and value properties"); List grantedAuthorities = new ArrayList<>(); for (String authority : withUser.authorities()) { grantedAuthorities.add(new SimpleGrantedAuthority(authority)); } - if (grantedAuthorities.isEmpty()) { for (String role : withUser.roles()) { - if (role.startsWith("ROLE_")) { - throw new IllegalArgumentException("roles cannot start with ROLE_ Got " - + role); - } + Assert.isTrue(!role.startsWith("ROLE_"), () -> "roles cannot start with ROLE_ Got " + role); grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_" + role)); } - } else if (!(withUser.roles().length == 1 && "USER".equals(withUser.roles()[0]))) { - throw new IllegalStateException("You cannot define roles attribute "+ Arrays.asList(withUser.roles())+" with authorities attribute "+ Arrays.asList(withUser.authorities())); } - - User principal = new User(username, withUser.password(), true, true, true, true, - grantedAuthorities); - Authentication authentication = new UsernamePasswordAuthenticationToken( - principal, principal.getPassword(), principal.getAuthorities()); + else if (!(withUser.roles().length == 1 && "USER".equals(withUser.roles()[0]))) { + throw new IllegalStateException("You cannot define roles attribute " + Arrays.asList(withUser.roles()) + + " with authorities attribute " + Arrays.asList(withUser.authorities())); + } + User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities); + Authentication authentication = new UsernamePasswordAuthenticationToken(principal, principal.getPassword(), + principal.getAuthorities()); SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); return context; } + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContext.java b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContext.java index 5f401eecf8..d696fe5e1b 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContext.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContext.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Annotation; @@ -54,11 +55,11 @@ import org.springframework.test.context.TestContext; @Inherited @Documented public @interface WithSecurityContext { + /** * The {@link WithUserDetailsSecurityContextFactory} to use to create the * {@link SecurityContext}. It can contain {@link Autowired} and other Spring * annotations. - * * @return */ Class> factory(); diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextFactory.java b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextFactory.java index e99060b5e7..86dbafc668 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextFactory.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Annotation; @@ -24,22 +25,21 @@ import org.springframework.security.test.context.TestSecurityContextHolder; * An API that works with WithUserTestExcecutionListener for creating a * {@link SecurityContext} that is populated in the {@link TestSecurityContextHolder}. * - * @author Rob Winch - * * @param + * @author Rob Winch + * @since 4.0 * @see WithSecurityContext * @see WithMockUser * @see WithUserDetails - * @since 4.0 */ public interface WithSecurityContextFactory { /** * Create a {@link SecurityContext} given an Annotation. - * * @param annotation the {@link Annotation} to create the {@link SecurityContext} * from. Cannot be null. * @return the {@link SecurityContext} to use. Cannot be null. */ SecurityContext createSecurityContext(A annotation); -} \ No newline at end of file + +} diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java index 0d7321a9f0..e184e6a28d 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Annotation; @@ -47,10 +48,10 @@ import org.springframework.test.web.servlet.MockMvc; * @see ReactorContextTestExecutionListener * @see org.springframework.security.test.context.annotation.SecurityTestExecutionListeners */ -public class WithSecurityContextTestExecutionListener - extends AbstractTestExecutionListener { +public class WithSecurityContextTestExecutionListener extends AbstractTestExecutionListener { - static final String SECURITY_CONTEXT_ATTR_NAME = WithSecurityContextTestExecutionListener.class.getName().concat(".SECURITY_CONTEXT"); + static final String SECURITY_CONTEXT_ATTR_NAME = WithSecurityContextTestExecutionListener.class.getName() + .concat(".SECURITY_CONTEXT"); /** * Sets up the {@link SecurityContext} for each test method. First the specific method @@ -60,21 +61,18 @@ public class WithSecurityContextTestExecutionListener */ @Override public void beforeTestMethod(TestContext testContext) { - TestSecurityContext testSecurityContext = createTestSecurityContext( - testContext.getTestMethod(), testContext); + TestSecurityContext testSecurityContext = createTestSecurityContext(testContext.getTestMethod(), testContext); if (testSecurityContext == null) { - testSecurityContext = createTestSecurityContext(testContext.getTestClass(), - testContext); + testSecurityContext = createTestSecurityContext(testContext.getTestClass(), testContext); } if (testSecurityContext == null) { return; } - - Supplier supplier = testSecurityContext - .getSecurityContextSupplier(); + Supplier supplier = testSecurityContext.getSecurityContextSupplier(); if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) { TestSecurityContextHolder.setContext(supplier.get()); - } else { + } + else { testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier); } } @@ -92,19 +90,17 @@ public class WithSecurityContextTestExecutionListener } } - private TestSecurityContext createTestSecurityContext(AnnotatedElement annotated, - TestContext context) { - WithSecurityContext withSecurityContext = AnnotatedElementUtils - .findMergedAnnotation(annotated, WithSecurityContext.class); + private TestSecurityContext createTestSecurityContext(AnnotatedElement annotated, TestContext context) { + WithSecurityContext withSecurityContext = AnnotatedElementUtils.findMergedAnnotation(annotated, + WithSecurityContext.class); return createTestSecurityContext(annotated, withSecurityContext, context); } - private TestSecurityContext createTestSecurityContext(Class annotated, - TestContext context) { + private TestSecurityContext createTestSecurityContext(Class annotated, TestContext context) { MetaAnnotationUtils.AnnotationDescriptor withSecurityContextDescriptor = MetaAnnotationUtils .findAnnotationDescriptor(annotated, WithSecurityContext.class); - WithSecurityContext withSecurityContext = withSecurityContextDescriptor == null - ? null : withSecurityContextDescriptor.getAnnotation(); + WithSecurityContext withSecurityContext = (withSecurityContextDescriptor != null) + ? withSecurityContextDescriptor.getAnnotation() : null; return createTestSecurityContext(annotated, withSecurityContext, context); } @@ -114,35 +110,32 @@ public class WithSecurityContextTestExecutionListener if (withSecurityContext == null) { return null; } - withSecurityContext = AnnotationUtils - .synthesizeAnnotation(withSecurityContext, annotated); + withSecurityContext = AnnotationUtils.synthesizeAnnotation(withSecurityContext, annotated); WithSecurityContextFactory factory = createFactory(withSecurityContext, context); Class type = (Class) GenericTypeResolver - .resolveTypeArgument(factory.getClass(), - WithSecurityContextFactory.class); + .resolveTypeArgument(factory.getClass(), WithSecurityContextFactory.class); Annotation annotation = findAnnotation(annotated, type); Supplier supplier = () -> { try { return factory.createSecurityContext(annotation); - } catch (RuntimeException e) { - throw new IllegalStateException( - "Unable to create SecurityContext using " + annotation, e); + } + catch (RuntimeException ex) { + throw new IllegalStateException("Unable to create SecurityContext using " + annotation, ex); } }; TestExecutionEvent initialize = withSecurityContext.setupBefore(); return new TestSecurityContext(supplier, initialize); } - private Annotation findAnnotation(AnnotatedElement annotated, - Class type) { + private Annotation findAnnotation(AnnotatedElement annotated, Class type) { Annotation findAnnotation = AnnotationUtils.findAnnotation(annotated, type); if (findAnnotation != null) { return findAnnotation; } Annotation[] allAnnotations = AnnotationUtils.getAnnotations(annotated); for (Annotation annotationToTest : allAnnotations) { - WithSecurityContext withSecurityContext = AnnotationUtils.findAnnotation( - annotationToTest.annotationType(), WithSecurityContext.class); + WithSecurityContext withSecurityContext = AnnotationUtils.findAnnotation(annotationToTest.annotationType(), + WithSecurityContext.class); if (withSecurityContext != null) { return annotationToTest; } @@ -150,19 +143,17 @@ public class WithSecurityContextTestExecutionListener return null; } - private WithSecurityContextFactory createFactory( - WithSecurityContext withSecurityContext, TestContext testContext) { - Class> clazz = withSecurityContext - .factory(); + private WithSecurityContextFactory createFactory(WithSecurityContext withSecurityContext, + TestContext testContext) { + Class> clazz = withSecurityContext.factory(); try { - return testContext.getApplicationContext().getAutowireCapableBeanFactory() - .createBean(clazz); + return testContext.getApplicationContext().getAutowireCapableBeanFactory().createBean(clazz); } - catch (IllegalStateException e) { + catch (IllegalStateException ex) { return BeanUtils.instantiateClass(clazz); } - catch (Exception e) { - throw new RuntimeException(e); + catch (Exception ex) { + throw new RuntimeException(ex); } } @@ -184,21 +175,24 @@ public class WithSecurityContextTestExecutionListener } static class TestSecurityContext { + private final Supplier securityContextSupplier; + private final TestExecutionEvent testExecutionEvent; - TestSecurityContext(Supplier securityContextSupplier, - TestExecutionEvent testExecutionEvent) { + TestSecurityContext(Supplier securityContextSupplier, TestExecutionEvent testExecutionEvent) { this.securityContextSupplier = securityContextSupplier; this.testExecutionEvent = testExecutionEvent; } - public Supplier getSecurityContextSupplier() { + Supplier getSecurityContextSupplier() { return this.securityContextSupplier; } - public TestExecutionEvent getTestExecutionEvent() { + TestExecutionEvent getTestExecutionEvent() { return this.testExecutionEvent; } + } + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithUserDetails.java b/test/src/main/java/org/springframework/security/test/context/support/WithUserDetails.java index 453539946c..7da8db4ba8 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithUserDetails.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithUserDetails.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Documented; @@ -55,18 +56,17 @@ import org.springframework.test.web.servlet.MockMvc; @Documented @WithSecurityContext(factory = WithUserDetailsSecurityContextFactory.class) public @interface WithUserDetails { + /** * The username to look up in the {@link UserDetailsService} - * * @return */ String value() default "user"; /** - * The bean name for the {@link UserDetailsService} to use. If this is not - * provided, then the lookup is done by type and expects only a single + * The bean name for the {@link UserDetailsService} to use. If this is not provided, + * then the lookup is done by type and expects only a single * {@link UserDetailsService} bean to be exposed. - * * @return the bean name for the {@link UserDetailsService} to use. * @since 4.1 */ @@ -81,4 +81,5 @@ public @interface WithUserDetails { */ @AliasFor(annotation = WithSecurityContext.class) TestExecutionEvent setupBefore() default TestExecutionEvent.TEST_METHOD; + } diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java b/test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java index caa8dc6351..cabc9e348b 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import org.springframework.beans.factory.BeanFactory; @@ -35,16 +36,14 @@ import org.springframework.util.StringUtils; * A {@link WithUserDetailsSecurityContextFactory} that works with {@link WithUserDetails} * . * - * @see WithUserDetails - * * @author Rob Winch * @since 4.0 + * @see WithUserDetails */ +final class WithUserDetailsSecurityContextFactory implements WithSecurityContextFactory { -final class WithUserDetailsSecurityContextFactory implements - WithSecurityContextFactory { - - private static final boolean reactorPresent = ClassUtils.isPresent("reactor.core.publisher.Mono", WithUserDetailsSecurityContextFactory.class.getClassLoader()); + private static final boolean reactorPresent = ClassUtils.isPresent("reactor.core.publisher.Mono", + WithUserDetailsSecurityContextFactory.class.getClassLoader()); private BeanFactory beans; @@ -53,14 +52,15 @@ final class WithUserDetailsSecurityContextFactory implements this.beans = beans; } + @Override public SecurityContext createSecurityContext(WithUserDetails withUser) { String beanName = withUser.userDetailsServiceBeanName(); UserDetailsService userDetailsService = findUserDetailsService(beanName); String username = withUser.value(); Assert.hasLength(username, "value() must be non empty String"); UserDetails principal = userDetailsService.loadUserByUsername(username); - Authentication authentication = new UsernamePasswordAuthenticationToken( - principal, principal.getPassword(), principal.getAuthorities()); + Authentication authentication = new UsernamePasswordAuthenticationToken(principal, principal.getPassword(), + principal.getAuthorities()); SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); return context; @@ -73,35 +73,35 @@ final class WithUserDetailsSecurityContextFactory implements return reactive; } } - return StringUtils.hasLength(beanName) - ? this.beans.getBean(beanName, UserDetailsService.class) - : this.beans.getBean(UserDetailsService.class); + return StringUtils.hasLength(beanName) ? this.beans.getBean(beanName, UserDetailsService.class) + : this.beans.getBean(UserDetailsService.class); } - public UserDetailsService findAndAdaptReactiveUserDetailsService(String beanName) { + UserDetailsService findAndAdaptReactiveUserDetailsService(String beanName) { try { - ReactiveUserDetailsService reactiveUserDetailsService = StringUtils - .hasLength(beanName) ? - this.beans.getBean(beanName, ReactiveUserDetailsService.class) : - this.beans.getBean(ReactiveUserDetailsService.class); + ReactiveUserDetailsService reactiveUserDetailsService = StringUtils.hasLength(beanName) + ? this.beans.getBean(beanName, ReactiveUserDetailsService.class) + : this.beans.getBean(ReactiveUserDetailsService.class); return new ReactiveUserDetailsServiceAdapter(reactiveUserDetailsService); - } catch(NoSuchBeanDefinitionException | BeanNotOfRequiredTypeException notReactive) { + } + catch (NoSuchBeanDefinitionException | BeanNotOfRequiredTypeException ex) { return null; } } - private class ReactiveUserDetailsServiceAdapter implements UserDetailsService { + private final class ReactiveUserDetailsServiceAdapter implements UserDetailsService { + private final ReactiveUserDetailsService userDetailsService; - private ReactiveUserDetailsServiceAdapter( - ReactiveUserDetailsService userDetailsService) { + private ReactiveUserDetailsServiceAdapter(ReactiveUserDetailsService userDetailsService) { this.userDetailsService = userDetailsService; } @Override - public UserDetails loadUserByUsername(String username) - throws UsernameNotFoundException { + public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException { return this.userDetailsService.findByUsername(username).block(); } + } + } diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java index 997f41eb23..6b21de46cc 100644 --- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java @@ -67,6 +67,7 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter; @@ -89,9 +90,6 @@ import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; -import static java.lang.Boolean.TRUE; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; - /** * Test utilities for working with Spring Security and * {@link org.springframework.test.web.reactive.server.WebTestClient.Builder#apply(WebTestClientConfigurer)}. @@ -99,7 +97,10 @@ import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; * @author Rob Winch * @since 5.0 */ -public class SecurityMockServerConfigurers { +public final class SecurityMockServerConfigurers { + + private SecurityMockServerConfigurers() { + } /** * Sets up Spring Security's {@link WebTestClient} test support @@ -107,50 +108,50 @@ public class SecurityMockServerConfigurers { */ public static MockServerConfigurer springSecurity() { return new MockServerConfigurer() { + + @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { - builder.filters( filters -> filters.add(0, new MutatorFilter())); + builder.filters((filters) -> filters.add(0, new MutatorFilter())); } + }; } /** * Updates the ServerWebExchange to use the provided Authentication as the Principal - * * @param authentication the Authentication to use. * @return the configurer to use */ - public static T mockAuthentication(Authentication authentication) { + public static T mockAuthentication( + Authentication authentication) { return (T) new MutatorWebTestClientConfigurer(() -> Mono.just(authentication).map(SecurityContextImpl::new)); } /** - * Updates the ServerWebExchange to use the provided UserDetails to create a UsernamePasswordAuthenticationToken as - * the Principal - * + * Updates the ServerWebExchange to use the provided UserDetails to create a + * UsernamePasswordAuthenticationToken as the Principal * @param userDetails the UserDetails to use. * @return the configurer to use */ public static T mockUser(UserDetails userDetails) { - return mockAuthentication(new UsernamePasswordAuthenticationToken(userDetails, userDetails.getPassword(), userDetails.getAuthorities())); + return mockAuthentication(new UsernamePasswordAuthenticationToken(userDetails, userDetails.getPassword(), + userDetails.getAuthorities())); } /** - * Updates the ServerWebExchange to use a UserDetails to create a UsernamePasswordAuthenticationToken as - * the Principal. This uses a default username of "user", password of "password", and granted authorities of - * "ROLE_USER". - * + * Updates the ServerWebExchange to use a UserDetails to create a + * UsernamePasswordAuthenticationToken as the Principal. This uses a default username + * of "user", password of "password", and granted authorities of "ROLE_USER". * @return the {@link UserExchangeMutator} to use */ public static UserExchangeMutator mockUser() { return mockUser("user"); } - /** - * Updates the ServerWebExchange to use a UserDetails to create a UsernamePasswordAuthenticationToken as - * the Principal. This uses a default password of "password" and granted authorities of - * "ROLE_USER". - * + * Updates the ServerWebExchange to use a UserDetails to create a + * UsernamePasswordAuthenticationToken as the Principal. This uses a default password + * of "password" and granted authorities of "ROLE_USER". * @return the {@link WebTestClientConfigurer} to use */ public static UserExchangeMutator mockUser(String username) { @@ -159,11 +160,9 @@ public class SecurityMockServerConfigurers { /** * Updates the ServerWebExchange to establish a {@link SecurityContext} that has a - * {@link JwtAuthenticationToken} for the - * {@link Authentication} and a {@link Jwt} for the - * {@link Authentication#getPrincipal()}. All details are - * declarative and do not require the JWT to be valid. - * + * {@link JwtAuthenticationToken} for the {@link Authentication} and a {@link Jwt} for + * the {@link Authentication#getPrincipal()}. All details are declarative and do not + * require the JWT to be valid. * @return the {@link JwtMutator} to further configure or use * @since 5.2 */ @@ -173,11 +172,9 @@ public class SecurityMockServerConfigurers { /** * Updates the ServerWebExchange to establish a {@link SecurityContext} that has a - * {@link BearerTokenAuthentication} for the - * {@link Authentication} and an {@link OAuth2AuthenticatedPrincipal} for the - * {@link Authentication#getPrincipal()}. All details are - * declarative and do not require the token to be valid. - * + * {@link BearerTokenAuthentication} for the {@link Authentication} and an + * {@link OAuth2AuthenticatedPrincipal} for the {@link Authentication#getPrincipal()}. + * All details are declarative and do not require the token to be valid. * @return the {@link OpaqueTokenMutator} to further configure or use * @since 5.3 */ @@ -187,43 +184,39 @@ public class SecurityMockServerConfigurers { /** * Updates the ServerWebExchange to establish a {@link SecurityContext} that has a - * {@link OAuth2AuthenticationToken} for the - * {@link Authentication}. All details are + * {@link OAuth2AuthenticationToken} for the {@link Authentication}. All details are * declarative and do not require the corresponding OAuth 2.0 tokens to be valid. - * * @return the {@link OAuth2LoginMutator} to further configure or use * @since 5.3 */ public static OAuth2LoginMutator mockOAuth2Login() { - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", - null, null, Collections.singleton("read")); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, + null, Collections.singleton("read")); return new OAuth2LoginMutator(accessToken); } /** * Updates the ServerWebExchange to establish a {@link SecurityContext} that has a - * {@link OAuth2AuthenticationToken} for the - * {@link Authentication}. All details are + * {@link OAuth2AuthenticationToken} for the {@link Authentication}. All details are * declarative and do not require the corresponding OAuth 2.0 tokens to be valid. - * * @return the {@link OidcLoginMutator} to further configure or use * @since 5.3 */ public static OidcLoginMutator mockOidcLogin() { - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", - null, null, Collections.singleton("read")); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, + null, Collections.singleton("read")); return new OidcLoginMutator(accessToken); } /** - * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the session. - * All details are declarative and do not require the corresponding OAuth 2.0 tokens to be valid. + * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the + * session. All details are declarative and do not require the corresponding OAuth 2.0 + * tokens to be valid. * *

        - * The support works by associating the authorized client to the ServerWebExchange - * via the {@link WebSessionServerOAuth2AuthorizedClientRepository} + * The support works by associating the authorized client to the ServerWebExchange via + * the {@link WebSessionServerOAuth2AuthorizedClientRepository} *

        - * * @return the {@link OAuth2ClientMutator} to further configure or use * @since 5.3 */ @@ -232,15 +225,16 @@ public class SecurityMockServerConfigurers { } /** - * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the session. - * All details are declarative and do not require the corresponding OAuth 2.0 tokens to be valid. + * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the + * session. All details are declarative and do not require the corresponding OAuth 2.0 + * tokens to be valid. * *

        - * The support works by associating the authorized client to the ServerWebExchange - * via the {@link WebSessionServerOAuth2AuthorizedClientRepository} + * The support works by associating the authorized client to the ServerWebExchange via + * the {@link WebSessionServerOAuth2AuthorizedClientRepository} *

        - * - * @param registrationId The registration id associated with the {@link OAuth2AuthorizedClient} + * @param registrationId The registration id associated with the + * {@link OAuth2AuthorizedClient} * @return the {@link OAuth2ClientMutator} to further configure or use * @since 5.3 */ @@ -252,20 +246,21 @@ public class SecurityMockServerConfigurers { return new CsrfMutator(); } - public static class CsrfMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class CsrfMutator implements WebTestClientConfigurer, MockServerConfigurer { - @Override - public void afterConfigurerAdded(WebTestClient.Builder builder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, - @Nullable ClientHttpConnector connector) { - CsrfWebFilter filter = new CsrfWebFilter(); - filter.setRequireCsrfProtectionMatcher( e -> ServerWebExchangeMatcher.MatchResult.notMatch()); - httpHandlerBuilder.filters( filters -> filters.add(0, filter)); + private CsrfMutator() { } @Override - public void afterConfigureAdded( - WebTestClient.MockServerSpec serverSpec) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { + CsrfWebFilter filter = new CsrfWebFilter(); + filter.setRequireCsrfProtectionMatcher((e) -> ServerWebExchangeMatcher.MatchResult.notMatch()); + httpHandlerBuilder.filters((filters) -> filters.add(0, filter)); + } + + @Override + public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { } @@ -274,14 +269,15 @@ public class SecurityMockServerConfigurers { } - private CsrfMutator() {} } /** - * Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}}. Defaults to use a - * password of "password" and granted authorities of "ROLE_USER". + * Updates the WebServerExchange using {@code {@link + * SecurityMockServerConfigurers#mockUser(UserDetails)}}. Defaults to use a password + * of "password" and granted authorities of "ROLE_USER". */ - public static class UserExchangeMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class UserExchangeMutator implements WebTestClientConfigurer, MockServerConfigurer { + private final User.UserBuilder userBuilder; private UserExchangeMutator(String username) { @@ -301,9 +297,8 @@ public class SecurityMockServerConfigurers { } /** - * Specifies the roles to use. Default is "USER". This is similar to authorities except each role is - * automatically prefixed with "ROLE_USER". - * + * Specifies the roles to use. Default is "USER". This is similar to authorities + * except each role is automatically prefixed with "ROLE_USER". * @param roles the roles to use. * @return the UserExchangeMutator */ @@ -314,7 +309,6 @@ public class SecurityMockServerConfigurers { /** * Specifies the {@code GrantedAuthority}s to use. Default is "ROLE_USER". - * * @param authorities the authorities to use. * @return the UserExchangeMutator */ @@ -325,7 +319,6 @@ public class SecurityMockServerConfigurers { /** * Specifies the {@code GrantedAuthority}s to use. Default is "ROLE_USER". - * * @param authorities the authorities to use. * @return the UserExchangeMutator */ @@ -375,37 +368,46 @@ public class SecurityMockServerConfigurers { } @Override - public void afterConfigurerAdded(WebTestClient.Builder builder, @Nullable WebHttpHandlerBuilder webHttpHandlerBuilder, @Nullable ClientHttpConnector clientHttpConnector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder webHttpHandlerBuilder, + @Nullable ClientHttpConnector clientHttpConnector) { configurer().afterConfigurerAdded(builder, webHttpHandlerBuilder, clientHttpConnector); } private T configurer() { return mockUser(this.userBuilder.build()); } + } - private static class MutatorWebTestClientConfigurer implements WebTestClientConfigurer, MockServerConfigurer { + private static final class MutatorWebTestClientConfigurer implements WebTestClientConfigurer, MockServerConfigurer { + private final Supplier> context; private MutatorWebTestClientConfigurer(Supplier> context) { this.context = context; } + @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { builder.filters(addSetupMutatorFilter()); } @Override - public void afterConfigurerAdded(WebTestClient.Builder builder, @Nullable WebHttpHandlerBuilder webHttpHandlerBuilder, @Nullable ClientHttpConnector clientHttpConnector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder webHttpHandlerBuilder, + @Nullable ClientHttpConnector clientHttpConnector) { webHttpHandlerBuilder.filters(addSetupMutatorFilter()); } private Consumer> addSetupMutatorFilter() { - return filters -> filters.add(0, new SetupMutatorFilter(this.context)); + return (filters) -> filters.add(0, new SetupMutatorFilter(this.context)); } + } - private static class SetupMutatorFilter implements WebFilter { + private static final class SetupMutatorFilter implements WebFilter { + private final Supplier> context; private SetupMutatorFilter(Supplier> context) { @@ -414,12 +416,14 @@ public class SecurityMockServerConfigurers { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain webFilterChain) { - exchange.getAttributes().computeIfAbsent(MutatorFilter.ATTRIBUTE_NAME, key -> this.context); + exchange.getAttributes().computeIfAbsent(MutatorFilter.ATTRIBUTE_NAME, (key) -> this.context); return webFilterChain.filter(exchange); } + } private static class MutatorFilter implements WebFilter { + public static final String ATTRIBUTE_NAME = "context"; @Override @@ -428,46 +432,47 @@ public class SecurityMockServerConfigurers { if (context != null) { exchange.getAttributes().remove(ATTRIBUTE_NAME); return webFilterChain.filter(exchange) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(context.get())); + .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(context.get())); } return webFilterChain.filter(exchange); } + } /** - * Updates the WebServerExchange using - * {@code {@link SecurityMockServerConfigurers#mockAuthentication(Authentication)}}. + * Updates the WebServerExchange using {@code {@link + * SecurityMockServerConfigurers#mockAuthentication(Authentication)}}. * * @author Jérôme Wacongne <ch4mp@c4-soft.com> * @author Josh Cummings * @since 5.2 */ - public static class JwtMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class JwtMutator implements WebTestClientConfigurer, MockServerConfigurer { + private Jwt jwt; - private Converter> authoritiesConverter = - new JwtGrantedAuthoritiesConverter(); + + private Converter> authoritiesConverter = new JwtGrantedAuthoritiesConverter(); private JwtMutator() { - jwt((jwt) -> {}); + jwt((jwt) -> { + }); } /** - * Use the given {@link Jwt.Builder} {@link Consumer} to configure the underlying {@link Jwt} + * Use the given {@link Jwt.Builder} {@link Consumer} to configure the underlying + * {@link Jwt} * - * This method first creates a default {@link Jwt.Builder} instance with default values for - * the {@code alg}, {@code sub}, and {@code scope} claims. The {@link Consumer} can then modify - * these or provide additional configuration. - * - * Calling {@link SecurityMockServerConfigurers#mockJwt()} is the equivalent of calling - * {@code SecurityMockMvcRequestPostProcessors.mockJwt().jwt(() -> {})}. + * This method first creates a default {@link Jwt.Builder} instance with default + * values for the {@code alg}, {@code sub}, and {@code scope} claims. The + * {@link Consumer} can then modify these or provide additional configuration. * + * Calling {@link SecurityMockServerConfigurers#mockJwt()} is the equivalent of + * calling {@code SecurityMockMvcRequestPostProcessors.mockJwt().jwt(() -> {})}. * @param jwtBuilderConsumer For configuring the underlying {@link Jwt} * @return the {@link JwtMutator} for further configuration */ public JwtMutator jwt(Consumer jwtBuilderConsumer) { - Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") - .header("alg", "none") - .claim(SUB, "user") + Jwt.Builder jwtBuilder = Jwt.withTokenValue("token").header("alg", "none").claim(JwtClaimNames.SUB, "user") .claim("scope", "read"); jwtBuilderConsumer.accept(jwtBuilder); this.jwt = jwtBuilder.build(); @@ -476,7 +481,6 @@ public class SecurityMockServerConfigurers { /** * Use the given {@link Jwt} - * * @param jwt The {@link Jwt} to use * @return the {@link JwtMutator} for further configuration */ @@ -492,7 +496,7 @@ public class SecurityMockServerConfigurers { */ public JwtMutator authorities(Collection authorities) { Assert.notNull(authorities, "authorities cannot be null"); - this.authoritiesConverter = jwt -> authorities; + this.authoritiesConverter = (jwt) -> authorities; return this; } @@ -503,16 +507,15 @@ public class SecurityMockServerConfigurers { */ public JwtMutator authorities(GrantedAuthority... authorities) { Assert.notNull(authorities, "authorities cannot be null"); - this.authoritiesConverter = jwt -> Arrays.asList(authorities); + this.authoritiesConverter = (jwt) -> Arrays.asList(authorities); return this; } /** * Provides the configured {@link Jwt} so that custom authorities can be derived * from it - * - * @param authoritiesConverter the conversion strategy from {@link Jwt} to a {@link Collection} - * of {@link GrantedAuthority}s + * @param authoritiesConverter the conversion strategy from {@link Jwt} to a + * {@link Collection} of {@link GrantedAuthority}s * @return the {@link JwtMutator} for further configuration */ public JwtMutator authorities(Converter> authoritiesConverter) { @@ -532,10 +535,8 @@ public class SecurityMockServerConfigurers { } @Override - public void afterConfigurerAdded( - WebTestClient.Builder builder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, - @Nullable ClientHttpConnector connector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { httpHandlerBuilder.filter((exchange, chain) -> { CsrfWebFilter.skipExchange(exchange); return chain.filter(exchange); @@ -544,26 +545,31 @@ public class SecurityMockServerConfigurers { } private T configurer() { - return mockAuthentication(new JwtAuthenticationToken(this.jwt, this.authoritiesConverter.convert(this.jwt))); + return mockAuthentication( + new JwtAuthenticationToken(this.jwt, this.authoritiesConverter.convert(this.jwt))); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OpaqueTokenMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class OpaqueTokenMutator implements WebTestClientConfigurer, MockServerConfigurer { + private Supplier> attributes = this::defaultAttributes; + private Supplier> authorities = this::defaultAuthorities; private Supplier principal = this::defaultPrincipal; - private OpaqueTokenMutator() { } + private OpaqueTokenMutator() { + } /** * Mutate the attributes using the given {@link Consumer} - * - * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of attributes + * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of + * attributes * @return the {@link OpaqueTokenMutator} for further configuration */ public OpaqueTokenMutator attributes(Consumer> attributesConsumer) { @@ -623,10 +629,8 @@ public class SecurityMockServerConfigurers { } @Override - public void afterConfigurerAdded( - WebTestClient.Builder builder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, - @Nullable ClientHttpConnector connector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { httpHandlerBuilder.filter((exchange, chain) -> { CsrfWebFilter.skipExchange(exchange); return chain.filter(exchange); @@ -637,8 +641,8 @@ public class SecurityMockServerConfigurers { private T configurer() { OAuth2AuthenticatedPrincipal principal = this.principal.get(); OAuth2AccessToken accessToken = getOAuth2AccessToken(principal); - BearerTokenAuthentication token = new BearerTokenAuthentication - (principal, accessToken, principal.getAuthorities()); + BearerTokenAuthentication token = new BearerTokenAuthentication(principal, accessToken, + principal.getAuthorities()); return mockAuthentication(token); } @@ -666,21 +670,18 @@ public class SecurityMockServerConfigurers { } private OAuth2AuthenticatedPrincipal defaultPrincipal() { - return new OAuth2IntrospectionAuthenticatedPrincipal - (this.attributes.get(), this.authorities.get()); + return new OAuth2IntrospectionAuthenticatedPrincipal(this.attributes.get(), this.authorities.get()); } private Collection getAuthorities(Collection scopes) { - return scopes.stream() - .map(scope -> new SimpleGrantedAuthority("SCOPE_" + scope)) + return scopes.stream().map((scope) -> new SimpleGrantedAuthority("SCOPE_" + scope)) .collect(Collectors.toList()); } private OAuth2AccessToken getOAuth2AccessToken(OAuth2AuthenticatedPrincipal principal) { Instant expiresAt = getInstant(principal.getAttributes(), "exp"); Instant issuedAt = getInstant(principal.getAttributes(), "iat"); - return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", issuedAt, expiresAt); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", issuedAt, expiresAt); } private Instant getInstant(Map attributes, String name) { @@ -693,24 +694,28 @@ public class SecurityMockServerConfigurers { } throw new IllegalArgumentException(name + " attribute must be of type Instant"); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OAuth2LoginMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class OAuth2LoginMutator implements WebTestClientConfigurer, MockServerConfigurer { + private final String nameAttributeKey = "sub"; private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; private Supplier> authorities = this::defaultAuthorities; + private Supplier> attributes = this::defaultAttributes; + private Supplier oauth2User = this::defaultPrincipal; - private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository = - new WebSessionServerOAuth2AuthorizedClientRepository(); + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); private OAuth2LoginMutator(OAuth2AccessToken accessToken) { this.accessToken = accessToken; @@ -719,7 +724,6 @@ public class SecurityMockServerConfigurers { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OAuth2LoginMutator} for further configuration */ @@ -732,7 +736,6 @@ public class SecurityMockServerConfigurers { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OAuth2LoginMutator} for further configuration */ @@ -745,8 +748,8 @@ public class SecurityMockServerConfigurers { /** * Mutate the attributes using the given {@link Consumer} - * - * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of attributes + * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of + * attributes * @return the {@link OAuth2LoginMutator} for further configuration */ public OAuth2LoginMutator attributes(Consumer> attributesConsumer) { @@ -762,7 +765,6 @@ public class SecurityMockServerConfigurers { /** * Use the provided {@link OAuth2User} as the authenticated user. - * * @param oauth2User the {@link OAuth2User} to use * @return the {@link OAuth2LoginMutator} for further configuration */ @@ -777,9 +779,9 @@ public class SecurityMockServerConfigurers { * The supplied {@link ClientRegistration} will be registered into an * {@link WebSessionServerOAuth2AuthorizedClientRepository}. Tests relying on * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an {@link WebSessionServerOAuth2AuthorizedClientRepository} bean - * to the application context. - * + * annotations should register an + * {@link WebSessionServerOAuth2AuthorizedClientRepository} bean to the + * application context. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OAuth2LoginMutator} for further configuration */ @@ -791,34 +793,24 @@ public class SecurityMockServerConfigurers { @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { OAuth2AuthenticationToken token = getToken(); - mockOAuth2Client() - .accessToken(this.accessToken) - .clientRegistration(this.clientRegistration) - .principalName(token.getPrincipal().getName()) - .beforeServerCreated(builder); + mockOAuth2Client().accessToken(this.accessToken).clientRegistration(this.clientRegistration) + .principalName(token.getPrincipal().getName()).beforeServerCreated(builder); mockAuthentication(token).beforeServerCreated(builder); } @Override public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { OAuth2AuthenticationToken token = getToken(); - mockOAuth2Client() - .accessToken(this.accessToken) - .clientRegistration(this.clientRegistration) - .principalName(token.getPrincipal().getName()) - .afterConfigureAdded(serverSpec); + mockOAuth2Client().accessToken(this.accessToken).clientRegistration(this.clientRegistration) + .principalName(token.getPrincipal().getName()).afterConfigureAdded(serverSpec); mockAuthentication(token).afterConfigureAdded(serverSpec); } @Override - public void afterConfigurerAdded( - WebTestClient.Builder builder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, - @Nullable ClientHttpConnector connector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { OAuth2AuthenticationToken token = getToken(); - mockOAuth2Client() - .accessToken(this.accessToken) - .clientRegistration(this.clientRegistration) + mockOAuth2Client().accessToken(this.accessToken).clientRegistration(this.clientRegistration) .principalName(token.getPrincipal().getName()) .afterConfigurerAdded(builder, httpHandlerBuilder, connector); mockAuthentication(token).afterConfigurerAdded(builder, httpHandlerBuilder, connector); @@ -826,14 +818,13 @@ public class SecurityMockServerConfigurers { private OAuth2AuthenticationToken getToken() { OAuth2User oauth2User = this.oauth2User.get(); - return new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId()); + return new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(), + this.clientRegistration.getRegistrationId()); } private ClientRegistration.Builder clientRegistrationBuilder() { - return ClientRegistration.withRegistrationId("test") - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .clientId("test-client") - .tokenUri("https://token-uri.example.org"); + return ClientRegistration.withRegistrationId("test").authorizationGrantType(AuthorizationGrantType.PASSWORD) + .clientId("test-client").tokenUri("https://token-uri.example.org"); } private Collection defaultAuthorities() { @@ -854,22 +845,28 @@ public class SecurityMockServerConfigurers { private OAuth2User defaultPrincipal() { return new DefaultOAuth2User(this.authorities.get(), this.attributes.get(), this.nameAttributeKey); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OidcLoginMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class OidcLoginMutator implements WebTestClientConfigurer, MockServerConfigurer { + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; + private OidcIdToken idToken; + private OidcUserInfo userInfo; + private Supplier oidcUser = this::defaultPrincipal; + private Collection authorities; - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = - new WebSessionServerOAuth2AuthorizedClientRepository(); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); private OidcLoginMutator(OAuth2AccessToken accessToken) { this.accessToken = accessToken; @@ -878,7 +875,6 @@ public class SecurityMockServerConfigurers { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OidcLoginMutator} for further configuration */ @@ -891,7 +887,6 @@ public class SecurityMockServerConfigurers { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OidcLoginMutator} for further configuration */ @@ -904,8 +899,8 @@ public class SecurityMockServerConfigurers { /** * Use the provided {@link OidcIdToken} when constructing the authenticated user - * - * @param idTokenBuilderConsumer a {@link Consumer} of a {@link OidcIdToken.Builder} + * @param idTokenBuilderConsumer a {@link Consumer} of a + * {@link OidcIdToken.Builder} * @return the {@link OidcLoginMutator} for further configuration */ public OidcLoginMutator idToken(Consumer idTokenBuilderConsumer) { @@ -919,8 +914,8 @@ public class SecurityMockServerConfigurers { /** * Use the provided {@link OidcUserInfo} when constructing the authenticated user - * - * @param userInfoBuilderConsumer a {@link Consumer} of a {@link OidcUserInfo.Builder} + * @param userInfoBuilderConsumer a {@link Consumer} of a + * {@link OidcUserInfo.Builder} * @return the {@link OidcLoginMutator} for further configuration */ public OidcLoginMutator userInfoToken(Consumer userInfoBuilderConsumer) { @@ -934,9 +929,8 @@ public class SecurityMockServerConfigurers { /** * Use the provided {@link OidcUser} as the authenticated user. *

        - * Supplying an {@link OidcUser} will take precedence over {@link #idToken}, {@link #userInfo}, - * and list of {@link GrantedAuthority}s to use. - * + * Supplying an {@link OidcUser} will take precedence over {@link #idToken}, + * {@link #userInfo}, and list of {@link GrantedAuthority}s to use. * @param oidcUser the {@link OidcUser} to use * @return the {@link OidcLoginMutator} for further configuration */ @@ -951,9 +945,9 @@ public class SecurityMockServerConfigurers { * The supplied {@link ClientRegistration} will be registered into an * {@link WebSessionServerOAuth2AuthorizedClientRepository}. Tests relying on * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an {@link WebSessionServerOAuth2AuthorizedClientRepository} bean - * to the application context. - * + * annotations should register an + * {@link WebSessionServerOAuth2AuthorizedClientRepository} bean to the + * application context. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OidcLoginMutator} for further configuration */ @@ -965,70 +959,57 @@ public class SecurityMockServerConfigurers { @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { OAuth2AuthenticationToken token = getToken(); - mockOAuth2Client() - .accessToken(this.accessToken) - .principalName(token.getPrincipal().getName()) - .clientRegistration(this.clientRegistration) - .beforeServerCreated(builder); + mockOAuth2Client().accessToken(this.accessToken).principalName(token.getPrincipal().getName()) + .clientRegistration(this.clientRegistration).beforeServerCreated(builder); mockAuthentication(token).beforeServerCreated(builder); } @Override public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { OAuth2AuthenticationToken token = getToken(); - mockOAuth2Client() - .accessToken(this.accessToken) - .principalName(token.getPrincipal().getName()) - .clientRegistration(this.clientRegistration) - .afterConfigureAdded(serverSpec); + mockOAuth2Client().accessToken(this.accessToken).principalName(token.getPrincipal().getName()) + .clientRegistration(this.clientRegistration).afterConfigureAdded(serverSpec); mockAuthentication(token).afterConfigureAdded(serverSpec); } @Override - public void afterConfigurerAdded( - WebTestClient.Builder builder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, - @Nullable ClientHttpConnector connector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { OAuth2AuthenticationToken token = getToken(); - mockOAuth2Client() - .accessToken(this.accessToken) - .principalName(token.getPrincipal().getName()) + mockOAuth2Client().accessToken(this.accessToken).principalName(token.getPrincipal().getName()) .clientRegistration(this.clientRegistration) .afterConfigurerAdded(builder, httpHandlerBuilder, connector); mockAuthentication(token).afterConfigurerAdded(builder, httpHandlerBuilder, connector); } private ClientRegistration.Builder clientRegistrationBuilder() { - return ClientRegistration.withRegistrationId("test") - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .clientId("test-client") - .tokenUri("https://token-uri.example.org"); + return ClientRegistration.withRegistrationId("test").authorizationGrantType(AuthorizationGrantType.PASSWORD) + .clientId("test-client").tokenUri("https://token-uri.example.org"); } private OAuth2AuthenticationToken getToken() { OidcUser oidcUser = this.oidcUser.get(); - return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId()); + return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), + this.clientRegistration.getRegistrationId()); } private Collection getAuthorities() { - if (this.authorities == null) { - Set authorities = new LinkedHashSet<>(); - authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo())); - for (String authority : this.accessToken.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); - } - return authorities; - } else { + if (this.authorities != null) { return this.authorities; } + Set authorities = new LinkedHashSet<>(); + authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo())); + for (String authority : this.accessToken.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } + return authorities; } private OidcIdToken getOidcIdToken() { - if (this.idToken == null) { - return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user")); - } else { + if (this.idToken != null) { return this.idToken; } + return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user")); } private OidcUserInfo getOidcUserInfo() { @@ -1038,35 +1019,41 @@ public class SecurityMockServerConfigurers { private OidcUser defaultPrincipal() { return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OAuth2ClientMutator implements WebTestClientConfigurer, MockServerConfigurer { + public static final class OAuth2ClientMutator implements WebTestClientConfigurer, MockServerConfigurer { + private String registrationId = "test"; + private ClientRegistration clientRegistration; + private String principalName = "user"; + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, null, Collections.singleton("read")); - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository = - new WebSessionServerOAuth2AuthorizedClientRepository(); + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); private OAuth2ClientMutator() { } private OAuth2ClientMutator(String registrationId) { this.registrationId = registrationId; - clientRegistration(c -> {}); + clientRegistration((c) -> { + }); } /** * Use this {@link ClientRegistration} - * * @param clientRegistration - * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration + * @return the + * {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} + * for further configuration */ public OAuth2ClientMutator clientRegistration(ClientRegistration clientRegistration) { this.clientRegistration = clientRegistration; @@ -1075,13 +1062,13 @@ public class SecurityMockServerConfigurers { /** * Use this {@link Consumer} to configure a {@link ClientRegistration} - * * @param clientRegistrationConfigurer the {@link ClientRegistration} configurer - * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration + * @return the + * {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} + * for further configuration */ - public OAuth2ClientMutator clientRegistration - (Consumer clientRegistrationConfigurer) { - + public OAuth2ClientMutator clientRegistration( + Consumer clientRegistrationConfigurer) { ClientRegistration.Builder builder = clientRegistrationBuilder(); clientRegistrationConfigurer.accept(builder); this.clientRegistration = builder.build(); @@ -1090,7 +1077,6 @@ public class SecurityMockServerConfigurers { /** * Use this as the resource owner's principal name - * * @param principalName the resource owner's principal name * @return the {@link OAuth2ClientMutator} for further configuration */ @@ -1102,16 +1088,16 @@ public class SecurityMockServerConfigurers { /** * Use this {@link OAuth2AccessToken} - * * @param accessToken the {@link OAuth2AccessToken} to use - * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration + * @return the + * {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} + * for further configuration */ public OAuth2ClientMutator accessToken(OAuth2AccessToken accessToken) { this.accessToken = accessToken; return this; } - @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { builder.filters(addAuthorizedClientFilter()); @@ -1119,25 +1105,22 @@ public class SecurityMockServerConfigurers { @Override public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { - } @Override - public void afterConfigurerAdded( - WebTestClient.Builder builder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, - @Nullable ClientHttpConnector connector) { + public void afterConfigurerAdded(WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { httpHandlerBuilder.filters(addAuthorizedClientFilter()); } private Consumer> addAuthorizedClientFilter() { OAuth2AuthorizedClient client = getClient(); - return filters -> filters.add(0, (exchange, chain) -> { + return (filters) -> filters.add(0, (exchange, chain) -> { ReactiveOAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServerTestUtils .getOAuth2AuthorizedClientManager(exchange); if (!(authorizationClientManager instanceof TestReactiveOAuth2AuthorizedClientManager)) { - authorizationClientManager = - new TestReactiveOAuth2AuthorizedClientManager(authorizationClientManager); + authorizationClientManager = new TestReactiveOAuth2AuthorizedClientManager( + authorizationClientManager); OAuth2ClientServerTestUtils.setOAuth2AuthorizedClientManager(exchange, authorizationClientManager); } TestReactiveOAuth2AuthorizedClientManager.enable(exchange); @@ -1147,33 +1130,29 @@ public class SecurityMockServerConfigurers { } private OAuth2AuthorizedClient getClient() { - if (this.clientRegistration == null) { - throw new IllegalArgumentException("Please specify a ClientRegistration via one " + - "of the clientRegistration methods"); - } + Assert.notNull(this.clientRegistration, + "Please specify a ClientRegistration via one of the clientRegistration methods"); return new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); } private ClientRegistration.Builder clientRegistrationBuilder() { return ClientRegistration.withRegistrationId(this.registrationId) - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .clientId("test-client") - .clientSecret("test-secret") - .tokenUri("https://idp.example.org/oauth/token"); + .authorizationGrantType(AuthorizationGrantType.PASSWORD).clientId("test-client") + .clientSecret("test-secret").tokenUri("https://idp.example.org/oauth/token"); } /** - * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the - * request is wrapped + * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for + * testing when the request is wrapped */ - private static class TestReactiveOAuth2AuthorizedClientManager + private static final class TestReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { - final static String TOKEN_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class - .getName().concat(".TOKEN"); + static final String TOKEN_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class.getName() + .concat(".TOKEN"); - final static String ENABLED_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class - .getName().concat(".ENABLED"); + static final String ENABLED_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class.getName() + .concat(".ENABLED"); private final ReactiveOAuth2AuthorizedClientManager delegate; @@ -1183,60 +1162,62 @@ public class SecurityMockServerConfigurers { @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { - ServerWebExchange exchange = - authorizeRequest.getAttribute(ServerWebExchange.class.getName()); + ServerWebExchange exchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); if (isEnabled(exchange)) { OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME); return Mono.just(client); - } else { - return this.delegate.authorize(authorizeRequest); } + return this.delegate.authorize(authorizeRequest); } - public static void enable(ServerWebExchange exchange) { - exchange.getAttributes().put(ENABLED_ATTR_NAME, TRUE); + static void enable(ServerWebExchange exchange) { + exchange.getAttributes().put(ENABLED_ATTR_NAME, Boolean.TRUE); } - public boolean isEnabled(ServerWebExchange exchange) { - return TRUE.equals(exchange.getAttribute(ENABLED_ATTR_NAME)); + boolean isEnabled(ServerWebExchange exchange) { + return Boolean.TRUE.equals(exchange.getAttribute(ENABLED_ATTR_NAME)); } + } - private static class OAuth2ClientServerTestUtils { - private static final ServerOAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO = - new WebSessionServerOAuth2AuthorizedClientRepository(); + private static final class OAuth2ClientServerTestUtils { + + private static final ServerOAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO = new WebSessionServerOAuth2AuthorizedClientRepository(); + + private OAuth2ClientServerTestUtils() { + } /** - * Gets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}. - * If one is not found, one based off of {@link WebSessionServerOAuth2AuthorizedClientRepository} is used. - * + * Gets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified + * {@link ServerWebExchange}. If one is not found, one based off of + * {@link WebSessionServerOAuth2AuthorizedClientRepository} is used. * @param exchange the {@link ServerWebExchange} to obtain the * {@link ReactiveOAuth2AuthorizedClientManager} * @return the {@link ReactiveOAuth2AuthorizedClientManager} for the specified * {@link ServerWebExchange} */ - public static ReactiveOAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(ServerWebExchange exchange) { - OAuth2AuthorizedClientArgumentResolver resolver = - findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class); + static ReactiveOAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(ServerWebExchange exchange) { + OAuth2AuthorizedClientArgumentResolver resolver = findResolver(exchange, + OAuth2AuthorizedClientArgumentResolver.class); if (resolver == null) { - return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient - (authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), exchange); + return (authorizeRequest) -> DEFAULT_CLIENT_REPO.loadAuthorizedClient( + authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), exchange); } - return (ReactiveOAuth2AuthorizedClientManager) - ReflectionTestUtils.getField(resolver, "authorizedClientManager"); + return (ReactiveOAuth2AuthorizedClientManager) ReflectionTestUtils.getField(resolver, + "authorizedClientManager"); } /** - * Sets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}. - * + * Sets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified + * {@link ServerWebExchange}. * @param exchange the {@link ServerWebExchange} to obtain the * {@link ReactiveOAuth2AuthorizedClientManager} * @param manager the {@link ReactiveOAuth2AuthorizedClientManager} to set */ - public static void setOAuth2AuthorizedClientManager(ServerWebExchange exchange, + static void setOAuth2AuthorizedClientManager(ServerWebExchange exchange, ReactiveOAuth2AuthorizedClientManager manager) { - OAuth2AuthorizedClientArgumentResolver resolver = - findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class); + OAuth2AuthorizedClientArgumentResolver resolver = findResolver(exchange, + OAuth2AuthorizedClientArgumentResolver.class); if (resolver == null) { return; } @@ -1246,14 +1227,16 @@ public class SecurityMockServerConfigurers { @SuppressWarnings("unchecked") static T findResolver(ServerWebExchange exchange, Class resolverClass) { - if (!ClassUtils.isPresent - ("org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter", null)) { + if (!ClassUtils.isPresent( + "org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter", + null)) { return null; } return WebFluxClasspathGuard.findResolver(exchange, resolverClass); } private static class WebFluxClasspathGuard { + static T findResolver(ServerWebExchange exchange, Class resolverClass) { RequestMappingHandlerAdapter handlerAdapter = getRequestMappingHandlerAdapter(exchange); @@ -1264,8 +1247,8 @@ public class SecurityMockServerConfigurers { if (configurer == null) { return null; } - List resolvers = (List) - ReflectionTestUtils.invokeGetterMethod(configurer, "customResolvers"); + List resolvers = (List) ReflectionTestUtils + .invokeGetterMethod(configurer, "customResolvers"); if (resolvers == null) { return null; } @@ -1277,7 +1260,8 @@ public class SecurityMockServerConfigurers { return null; } - private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServerWebExchange exchange) { + private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter( + ServerWebExchange exchange) { ApplicationContext context = exchange.getApplicationContext(); if (context != null) { String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class); @@ -1287,10 +1271,11 @@ public class SecurityMockServerConfigurers { } return null; } + } - private OAuth2ClientServerTestUtils() { - } } + } + } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java index fb5a69d24d..58058b4537 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; +import javax.servlet.ServletContext; + import org.springframework.beans.Mergeable; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; @@ -25,8 +28,6 @@ import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilde import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.ServletContext; - import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -35,14 +36,15 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder * * @author Rob Winch * @since 4.0 - * */ public final class SecurityMockMvcRequestBuilders { + private SecurityMockMvcRequestBuilders() { + } + /** * Creates a request (including any necessary {@link CsrfToken}) that will submit a * form based login to POST "/login". - * * @return the FormLoginRequestBuilder for further customizations */ public static FormLoginRequestBuilder formLogin() { @@ -52,9 +54,7 @@ public final class SecurityMockMvcRequestBuilders { /** * Creates a request (including any necessary {@link CsrfToken}) that will submit a * form based login to POST {@code loginProcessingUrl}. - * * @param loginProcessingUrl the URL to POST to - * * @return the FormLoginRequestBuilder for further customizations */ public static FormLoginRequestBuilder formLogin(String loginProcessingUrl) { @@ -63,7 +63,6 @@ public final class SecurityMockMvcRequestBuilders { /** * Creates a logout request. - * * @return the LogoutRequestBuilder for additional customizations */ public static LogoutRequestBuilder logout() { @@ -73,9 +72,7 @@ public final class SecurityMockMvcRequestBuilders { /** * Creates a logout request (including any necessary {@link CsrfToken}) to the * specified {@code logoutUrl} - * * @param logoutUrl the logout request URL - * * @return the LogoutRequestBuilder for additional customizations */ public static LogoutRequestBuilder logout(String logoutUrl) { @@ -89,28 +86,30 @@ public final class SecurityMockMvcRequestBuilders { * @since 4.0 */ public static final class LogoutRequestBuilder implements RequestBuilder, Mergeable { + private String logoutUrl = "/logout"; + private RequestPostProcessor postProcessor = csrf(); + private Mergeable parent; + private LogoutRequestBuilder() { + } + @Override public MockHttpServletRequest buildRequest(ServletContext servletContext) { - MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl) - .accept(MediaType.TEXT_HTML, MediaType.ALL); - + MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl).accept(MediaType.TEXT_HTML, + MediaType.ALL); if (this.parent != null) { logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent); } - MockHttpServletRequest request = logoutRequest.buildRequest(servletContext); logoutRequest.postProcessRequest(request); - return this.postProcessor.postProcessRequest(request); } /** * Specifies the logout URL to POST to. Defaults to "/logout". - * * @param logoutUrl the logout URL to POST to. Defaults to "/logout". * @return the {@link LogoutRequestBuilder} for additional customizations */ @@ -121,14 +120,12 @@ public final class SecurityMockMvcRequestBuilders { /** * Specifies the logout URL to POST to. - * * @param logoutUrl the logout URL to POST to. * @param uriVars the URI variables * @return the {@link LogoutRequestBuilder} for additional customizations */ public LogoutRequestBuilder logoutUrl(String logoutUrl, Object... uriVars) { - this.logoutUrl = UriComponentsBuilder.fromPath(logoutUrl) - .buildAndExpand(uriVars).encode().toString(); + this.logoutUrl = UriComponentsBuilder.fromPath(logoutUrl).buildAndExpand(uriVars).encode().toString(); return this; } @@ -145,13 +142,10 @@ public final class SecurityMockMvcRequestBuilders { if (parent instanceof Mergeable) { this.parent = (Mergeable) parent; return this; - } else { - throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); } + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); } - private LogoutRequestBuilder() { - } } /** @@ -161,36 +155,40 @@ public final class SecurityMockMvcRequestBuilders { * @since 4.0 */ public static final class FormLoginRequestBuilder implements RequestBuilder, Mergeable { + private String usernameParam = "username"; + private String passwordParam = "password"; + private String username = "user"; + private String password = "password"; + private String loginProcessingUrl = "/login"; + private MediaType acceptMediaType = MediaType.APPLICATION_FORM_URLENCODED; + private Mergeable parent; private RequestPostProcessor postProcessor = csrf(); + private FormLoginRequestBuilder() { + } + @Override public MockHttpServletRequest buildRequest(ServletContext servletContext) { - MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl) - .accept(this.acceptMediaType) - .param(this.usernameParam, this.username) - .param(this.passwordParam, this.password); - + MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl).accept(this.acceptMediaType) + .param(this.usernameParam, this.username).param(this.passwordParam, this.password); if (this.parent != null) { loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent); } - MockHttpServletRequest request = loginRequest.buildRequest(servletContext); loginRequest.postProcessRequest(request); - return this.postProcessor.postProcessRequest(request); } /** * Specifies the URL to POST to. Default is "/login" - * * @param loginProcessingUrl the URL to POST to. Default is "/login" * @return the {@link FormLoginRequestBuilder} for additional customizations */ @@ -201,14 +199,13 @@ public final class SecurityMockMvcRequestBuilders { /** * Specifies the URL to POST to. - * * @param loginProcessingUrl the URL to POST to * @param uriVars the URI variables * @return the {@link FormLoginRequestBuilder} for additional customizations */ public FormLoginRequestBuilder loginProcessingUrl(String loginProcessingUrl, Object... uriVars) { - this.loginProcessingUrl = UriComponentsBuilder.fromPath(loginProcessingUrl) - .buildAndExpand(uriVars).encode().toString(); + this.loginProcessingUrl = UriComponentsBuilder.fromPath(loginProcessingUrl).buildAndExpand(uriVars).encode() + .toString(); return this; } @@ -256,14 +253,12 @@ public final class SecurityMockMvcRequestBuilders { /** * Specify both the password parameter name and the password. - * * @param passwordParameter the HTTP parameter to place the password. Default is * "password". * @param password the value of the password parameter. Default is "password". * @return the {@link FormLoginRequestBuilder} for additional customizations */ - public FormLoginRequestBuilder password(String passwordParameter, - String password) { + public FormLoginRequestBuilder password(String passwordParameter, String password) { passwordParam(passwordParameter); this.password = password; return this; @@ -271,7 +266,6 @@ public final class SecurityMockMvcRequestBuilders { /** * Specify both the password parameter name and the password. - * * @param usernameParameter the HTTP parameter to place the username. Default is * "username". * @param username the value of the username parameter. Default is "user". @@ -285,7 +279,6 @@ public final class SecurityMockMvcRequestBuilders { /** * Specify a media type to set as the Accept header in the request. - * * @param acceptMediaType the {@link MediaType} to set the Accept header to. * Default is: MediaType.APPLICATION_FORM_URLENCODED * @return the {@link FormLoginRequestBuilder} for additional customizations @@ -305,18 +298,13 @@ public final class SecurityMockMvcRequestBuilders { if (parent == null) { return this; } - if (parent instanceof Mergeable ) { + if (parent instanceof Mergeable) { this.parent = (Mergeable) parent; return this; - } else { - throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); } + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); } - private FormLoginRequestBuilder() { - } } - private SecurityMockMvcRequestBuilders() { - } } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 866396e754..a1d88b2e09 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import java.io.IOException; @@ -35,6 +36,7 @@ import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; + import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -78,6 +80,7 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter; @@ -104,9 +107,6 @@ import org.springframework.web.context.support.WebApplicationContextUtils; import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter; -import static java.lang.Boolean.TRUE; -import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; - /** * Contains {@link MockMvc} {@link RequestPostProcessor} implementations for Spring * Security. @@ -116,10 +116,12 @@ import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB; */ public final class SecurityMockMvcRequestPostProcessors { + private SecurityMockMvcRequestPostProcessors() { + } + /** * Creates a DigestRequestPostProcessor that enables easily adding digest based * authentication to a request. - * * @return the DigestRequestPostProcessor to use */ public static DigestRequestPostProcessor digest() { @@ -129,7 +131,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Creates a DigestRequestPostProcessor that enables easily adding digest based * authentication to a request. - * * @param username the username to use * @return the DigestRequestPostProcessor to use */ @@ -149,28 +150,24 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Finds an X509Cetificate using a resoureName and populates it on the request. - * * @param resourceName the name of the X509Certificate resource * @return the * {@link org.springframework.test.web.servlet.request.RequestPostProcessor} to use. * @throws IOException * @throws CertificateException */ - public static RequestPostProcessor x509(String resourceName) - throws IOException, CertificateException { + public static RequestPostProcessor x509(String resourceName) throws IOException, CertificateException { ResourceLoader loader = new DefaultResourceLoader(); Resource resource = loader.getResource(resourceName); InputStream inputStream = resource.getInputStream(); CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); - X509Certificate certificate = (X509Certificate) certFactory - .generateCertificate(inputStream); + X509Certificate certificate = (X509Certificate) certFactory.generateCertificate(inputStream); return x509(certificate); } /** * Creates a {@link RequestPostProcessor} that will automatically populate a valid * {@link CsrfToken} in the request. - * * @return the {@link CsrfRequestPostProcessor} for further customizations. */ public static CsrfRequestPostProcessor csrf() { @@ -180,7 +177,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Creates a {@link RequestPostProcessor} that can be used to ensure that the * resulting request is ran with the user in the {@link TestSecurityContextHolder}. - * * @return the {@link RequestPostProcessor} to sue */ public static RequestPostProcessor testSecurityContext() { @@ -207,7 +203,6 @@ public final class SecurityMockMvcRequestPostProcessors { *

      18. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      19. * - * * @param username the username to populate * @return the {@link UserRequestPostProcessor} for additional customization */ @@ -235,7 +230,6 @@ public final class SecurityMockMvcRequestPostProcessors { *
      20. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      21. * - * * @param user the UserDetails to populate * @return the {@link RequestPostProcessor} to use */ @@ -244,15 +238,14 @@ public final class SecurityMockMvcRequestPostProcessors { } /** - * Establish a {@link SecurityContext} that has a - * {@link JwtAuthenticationToken} for the - * {@link Authentication} and a {@link Jwt} for the - * {@link Authentication#getPrincipal()}. All details are - * declarative and do not require the JWT to be valid. + * Establish a {@link SecurityContext} that has a {@link JwtAuthenticationToken} for + * the {@link Authentication} and a {@link Jwt} for the + * {@link Authentication#getPrincipal()}. All details are declarative and do not + * require the JWT to be valid. * *

        - * The support works by associating the authentication to the HttpServletRequest. To associate - * the request to the SecurityContextHolder you need to ensure that the + * The support works by associating the authentication to the HttpServletRequest. To + * associate the request to the SecurityContextHolder you need to ensure that the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few * ways to do this are: *

        @@ -263,7 +256,6 @@ public final class SecurityMockMvcRequestPostProcessors { *
      22. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      23. * - * * @return the {@link JwtRequestPostProcessor} for additional customization */ public static JwtRequestPostProcessor jwt() { @@ -271,15 +263,14 @@ public final class SecurityMockMvcRequestPostProcessors { } /** - * Establish a {@link SecurityContext} that has a - * {@link BearerTokenAuthentication} for the - * {@link Authentication} and a {@link OAuth2AuthenticatedPrincipal} for the - * {@link Authentication#getPrincipal()}. All details are - * declarative and do not require the token to be valid + * Establish a {@link SecurityContext} that has a {@link BearerTokenAuthentication} + * for the {@link Authentication} and a {@link OAuth2AuthenticatedPrincipal} for the + * {@link Authentication#getPrincipal()}. All details are declarative and do not + * require the token to be valid * *

        - * The support works by associating the authentication to the HttpServletRequest. To associate - * the request to the SecurityContextHolder you need to ensure that the + * The support works by associating the authentication to the HttpServletRequest. To + * associate the request to the SecurityContextHolder you need to ensure that the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few * ways to do this are: *

        @@ -290,7 +281,6 @@ public final class SecurityMockMvcRequestPostProcessors { *
      24. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      25. * - * * @return the {@link OpaqueTokenRequestPostProcessor} for additional customization * @since 5.3 */ @@ -316,7 +306,6 @@ public final class SecurityMockMvcRequestPostProcessors { *
      26. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      27. * - * * @param authentication the Authentication to populate * @return the {@link RequestPostProcessor} to use */ @@ -348,7 +337,6 @@ public final class SecurityMockMvcRequestPostProcessors { * // ... lots of tests ran with a default user ... * } * - * * @return the {@link RequestPostProcessor} to use */ public static RequestPostProcessor anonymous() { @@ -373,7 +361,6 @@ public final class SecurityMockMvcRequestPostProcessors { * Convenience mechanism for setting the Authorization header to use HTTP Basic with * the given username and password. This method will automatically perform the * necessary Base64 encoding. - * * @param username the username to include in the Authorization header. * @param password the password to include in the Authorization header. * @return the {@link RequestPostProcessor} to use @@ -383,15 +370,14 @@ public final class SecurityMockMvcRequestPostProcessors { } /** - * Establish a {@link SecurityContext} that has a - * {@link OAuth2AuthenticationToken} for the - * {@link Authentication}, a {@link OAuth2User} as the principal, - * and a {@link OAuth2AuthorizedClient} in the session. All details are - * declarative and do not require associated tokens to be valid. + * Establish a {@link SecurityContext} that has a {@link OAuth2AuthenticationToken} + * for the {@link Authentication}, a {@link OAuth2User} as the principal, and a + * {@link OAuth2AuthorizedClient} in the session. All details are declarative and do + * not require associated tokens to be valid. * *

        - * The support works by associating the authentication to the HttpServletRequest. To associate - * the request to the SecurityContextHolder you need to ensure that the + * The support works by associating the authentication to the HttpServletRequest. To + * associate the request to the SecurityContextHolder you need to ensure that the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few * ways to do this are: *

        @@ -402,26 +388,24 @@ public final class SecurityMockMvcRequestPostProcessors { *
      28. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      29. * - * * @return the {@link OidcLoginRequestPostProcessor} for additional customization * @since 5.3 */ public static OAuth2LoginRequestPostProcessor oauth2Login() { - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", - null, null, Collections.singleton("read")); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, + null, Collections.singleton("read")); return new OAuth2LoginRequestPostProcessor(accessToken); } /** - * Establish a {@link SecurityContext} that has a - * {@link OAuth2AuthenticationToken} for the - * {@link Authentication}, a {@link OidcUser} as the principal, - * and a {@link OAuth2AuthorizedClient} in the session. All details are - * declarative and do not require associated tokens to be valid. + * Establish a {@link SecurityContext} that has a {@link OAuth2AuthenticationToken} + * for the {@link Authentication}, a {@link OidcUser} as the principal, and a + * {@link OAuth2AuthorizedClient} in the session. All details are declarative and do + * not require associated tokens to be valid. * *

        - * The support works by associating the authentication to the HttpServletRequest. To associate - * the request to the SecurityContextHolder you need to ensure that the + * The support works by associating the authentication to the HttpServletRequest. To + * associate the request to the SecurityContextHolder you need to ensure that the * SecurityContextPersistenceFilter is associated with the MockMvc instance. A few * ways to do this are: *

        @@ -432,13 +416,12 @@ public final class SecurityMockMvcRequestPostProcessors { *
      30. Manually adding {@link SecurityContextPersistenceFilter} to the MockMvc * instance may make sense when using MockMvcBuilders standaloneSetup
      31. * - * * @return the {@link OidcLoginRequestPostProcessor} for additional customization * @since 5.3 */ public static OidcLoginRequestPostProcessor oidcLogin() { - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", - null, null, Collections.singleton("read")); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, + null, Collections.singleton("read")); return new OidcLoginRequestPostProcessor(accessToken); } @@ -450,7 +433,6 @@ public final class SecurityMockMvcRequestPostProcessors { * The support works by associating the authorized client to the HttpServletRequest * via the {@link HttpSessionOAuth2AuthorizedClientRepository} *

        - * * @return the {@link OAuth2ClientRequestPostProcessor} for additional customization * @since 5.3 */ @@ -466,7 +448,6 @@ public final class SecurityMockMvcRequestPostProcessors { * The support works by associating the authorized client to the HttpServletRequest * via the {@link HttpSessionOAuth2AuthorizedClientRepository} *

        - * * @param registrationId The registration id for the {@link OAuth2AuthorizedClient} * @return the {@link OAuth2ClientRequestPostProcessor} for additional customization * @since 5.3 @@ -478,7 +459,8 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Populates the X509Certificate instances onto the request */ - private static class X509RequestPostProcessor implements RequestPostProcessor { + private static final class X509RequestPostProcessor implements RequestPostProcessor { + private final X509Certificate[] certificates; private X509RequestPostProcessor(X509Certificate... certificates) { @@ -488,10 +470,10 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - request.setAttribute("javax.servlet.request.X509Certificate", - this.certificates); + request.setAttribute("javax.servlet.request.X509Certificate", this.certificates); return request; } + } /** @@ -500,31 +482,26 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Rob Winch * @since 4.0 */ - public static class CsrfRequestPostProcessor implements RequestPostProcessor { + public static final class CsrfRequestPostProcessor implements RequestPostProcessor { private boolean asHeader; private boolean useInvalidToken; - /* - * (non-Javadoc) - * - * @see org.springframework.test.web.servlet.request.RequestPostProcessor - * #postProcessRequest (org.springframework.mock.web.MockHttpServletRequest) - */ + private CsrfRequestPostProcessor() { + } + @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request); if (!(repository instanceof TestCsrfTokenRepository)) { - repository = new TestCsrfTokenRepository( - new HttpSessionCsrfTokenRepository()); + repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository()); WebTestUtils.setCsrfTokenRepository(request, repository); } TestCsrfTokenRepository.enable(request); CsrfToken token = repository.generateToken(request); repository.saveToken(token, request, new MockHttpServletResponse()); - String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() - : token.getToken(); + String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken(); if (this.asHeader) { request.addHeader(token.getHeaderName(), tokenValue); } @@ -537,7 +514,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Instead of using the {@link CsrfToken} as a request parameter (default) will * populate the {@link CsrfToken} as a header. - * * @return the {@link CsrfRequestPostProcessor} for additional customizations */ public CsrfRequestPostProcessor asHeader() { @@ -547,7 +523,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Populates an invalid token value on the request. - * * @return the {@link CsrfRequestPostProcessor} for additional customizations */ public CsrfRequestPostProcessor useInvalidToken() { @@ -555,19 +530,15 @@ public final class SecurityMockMvcRequestPostProcessors { return this; } - private CsrfRequestPostProcessor() { - } - /** * Used to wrap the CsrfTokenRepository to provide support for testing when the * request is wrapped (i.e. Spring Session is in use). */ static class TestCsrfTokenRepository implements CsrfTokenRepository { - final static String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName() - .concat(".TOKEN"); - final static String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class - .getName().concat(".ENABLED"); + static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN"); + + static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED"); private final CsrfTokenRepository delegate; @@ -581,8 +552,7 @@ public final class SecurityMockMvcRequestPostProcessors { } @Override - public void saveToken(CsrfToken token, HttpServletRequest request, - HttpServletResponse response) { + public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) { if (isEnabled(request)) { request.setAttribute(TOKEN_ATTR_NAME, token); } @@ -601,17 +571,20 @@ public final class SecurityMockMvcRequestPostProcessors { } } - public static void enable(HttpServletRequest request) { - request.setAttribute(ENABLED_ATTR_NAME, TRUE); + static void enable(HttpServletRequest request) { + request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); } - public boolean isEnabled(HttpServletRequest request) { - return TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); + boolean isEnabled(HttpServletRequest request) { + return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); } + } + } public static class DigestRequestPostProcessor implements RequestPostProcessor { + private String username = "user"; private String password = "password"; @@ -664,24 +637,20 @@ public final class SecurityMockMvcRequestPostProcessors { String toDigest = expiryTime + ":" + "key"; String signatureValue = md5Hex(toDigest); String nonceValue = expiryTime + ":" + signatureValue; - return new String(Base64.getEncoder().encode(nonceValue.getBytes())); } private String createAuthorizationHeader(MockHttpServletRequest request) { String uri = request.getRequestURI(); - String responseDigest = generateDigest(this.username, this.realm, - this.password, request.getMethod(), uri, this.qop, this.nonce, - this.nc, this.cnonce); - return "Digest username=\"" + this.username + "\", realm=\"" + this.realm - + "\", nonce=\"" + this.nonce + "\", uri=\"" + uri + "\", response=\"" - + responseDigest + "\", qop=" + this.qop + ", nc=" + this.nc - + ", cnonce=\"" + this.cnonce + "\""; + String responseDigest = generateDigest(this.username, this.realm, this.password, request.getMethod(), uri, + this.qop, this.nonce, this.nc, this.cnonce); + return "Digest username=\"" + this.username + "\", realm=\"" + this.realm + "\", nonce=\"" + this.nonce + + "\", uri=\"" + uri + "\", response=\"" + responseDigest + "\", qop=" + this.qop + ", nc=" + + this.nc + ", cnonce=\"" + this.cnonce + "\""; } @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - request.addHeader("Authorization", createAuthorizationHeader(request)); return request; } @@ -691,7 +660,6 @@ public final class SecurityMockMvcRequestPostProcessors { * Both the server and user agent should compute the response * independently. Provided as a static method to simplify the coding of user * agents. - * * @param username the user's login name. * @param realm the name of the realm. * @param password the user's password in plaintext or ready-encoded. @@ -704,54 +672,41 @@ public final class SecurityMockMvcRequestPostProcessors { * @return the MD5 of the digest authentication response, encoded in hex * @throws IllegalArgumentException if the supplied qop value is unsupported. */ - private static String generateDigest(String username, String realm, - String password, String httpMethod, String uri, String qop, String nonce, - String nc, String cnonce) throws IllegalArgumentException { + private static String generateDigest(String username, String realm, String password, String httpMethod, + String uri, String qop, String nonce, String nc, String cnonce) throws IllegalArgumentException { String a1Md5 = encodePasswordInA1Format(username, realm, password); String a2 = httpMethod + ":" + uri; String a2Md5 = md5Hex(a2); - - String digest; - if (qop == null) { // as per RFC 2069 compliant clients (also reaffirmed by RFC 2617) - digest = a1Md5 + ":" + nonce + ":" + a2Md5; + return md5Hex(a1Md5 + ":" + nonce + ":" + a2Md5); } - else if ("auth".equals(qop)) { + if ("auth".equals(qop)) { // As per RFC 2617 compliant clients - digest = a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" - + a2Md5; + return md5Hex(a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5); } - else { - throw new IllegalArgumentException( - "This method does not support a qop: '" + qop + "'"); - } - - return md5Hex(digest); + throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'"); } - static String encodePasswordInA1Format(String username, String realm, - String password) { - String a1 = username + ":" + realm + ":" + password; - - return md5Hex(a1); + static String encodePasswordInA1Format(String username, String realm, String password) { + return md5Hex(username + ":" + realm + ":" + password); } private static String md5Hex(String a2) { return DigestUtils.md5DigestAsHex(a2.getBytes(StandardCharsets.UTF_8)); } + } /** * Support class for {@link RequestPostProcessor}'s that establish a Spring Security * context */ - private static abstract class SecurityContextRequestPostProcessorSupport { + private abstract static class SecurityContextRequestPostProcessorSupport { /** * Saves the specified {@link Authentication} into an empty * {@link SecurityContext} using the {@link SecurityContextRepository}. - * * @param authentication the {@link Authentication} to save * @param request the {@link HttpServletRequest} to use */ @@ -763,30 +718,21 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Saves the {@link SecurityContext} using the {@link SecurityContextRepository} - * * @param securityContext the {@link SecurityContext} to save * @param request the {@link HttpServletRequest} to use */ final void save(SecurityContext securityContext, HttpServletRequest request) { - SecurityContextRepository securityContextRepository = WebTestUtils - .getSecurityContextRepository(request); + SecurityContextRepository securityContextRepository = WebTestUtils.getSecurityContextRepository(request); boolean isTestRepository = securityContextRepository instanceof TestSecurityContextRepository; if (!isTestRepository) { - securityContextRepository = new TestSecurityContextRepository( - securityContextRepository); - WebTestUtils.setSecurityContextRepository(request, - securityContextRepository); + securityContextRepository = new TestSecurityContextRepository(securityContextRepository); + WebTestUtils.setSecurityContextRepository(request, securityContextRepository); } - HttpServletResponse response = new MockHttpServletResponse(); - - HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder( - request, response); + HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(request, response); securityContextRepository.loadContext(requestResponseHolder); - request = requestResponseHolder.getRequest(); response = requestResponseHolder.getResponse(); - securityContextRepository.saveContext(securityContext, request, response); } @@ -794,9 +740,9 @@ public final class SecurityMockMvcRequestPostProcessors { * Used to wrap the SecurityContextRepository to provide support for testing in * stateless mode */ - static class TestSecurityContextRepository implements SecurityContextRepository { - private final static String ATTR_NAME = TestSecurityContextRepository.class - .getName().concat(".REPO"); + static final class TestSecurityContextRepository implements SecurityContextRepository { + + private static final String ATTR_NAME = TestSecurityContextRepository.class.getName().concat(".REPO"); private final SecurityContextRepository delegate; @@ -805,35 +751,33 @@ public final class SecurityMockMvcRequestPostProcessors { } @Override - public SecurityContext loadContext( - HttpRequestResponseHolder requestResponseHolder) { + public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) { SecurityContext result = getContext(requestResponseHolder.getRequest()); // always load from the delegate to ensure the request/response in the // holder are updated // remember the SecurityContextRepository is used in many different // locations - SecurityContext delegateResult = this.delegate - .loadContext(requestResponseHolder); - return result == null ? delegateResult : result; + SecurityContext delegateResult = this.delegate.loadContext(requestResponseHolder); + return (result != null) ? result : delegateResult; } @Override - public void saveContext(SecurityContext context, HttpServletRequest request, - HttpServletResponse response) { + public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) { request.setAttribute(ATTR_NAME, context); this.delegate.saveContext(context, request, response); } @Override public boolean containsContext(HttpServletRequest request) { - return getContext(request) != null - || this.delegate.containsContext(request); + return getContext(request) != null || this.delegate.containsContext(request); } private static SecurityContext getContext(HttpServletRequest request) { return (SecurityContext) request.getAttribute(ATTR_NAME); } + } + } /** @@ -844,26 +788,25 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Rob Winch * @since 4.0 */ - private final static class TestSecurityContextHolderPostProcessor extends - SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { + private static final class TestSecurityContextHolderPostProcessor extends SecurityContextRequestPostProcessorSupport + implements RequestPostProcessor { + private SecurityContext EMPTY = SecurityContextHolder.createEmptyContext(); @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { // TestSecurityContextHolder is only a default value - SecurityContext existingContext = TestSecurityContextRepository - .getContext(request); + SecurityContext existingContext = TestSecurityContextRepository.getContext(request); if (existingContext != null) { return request; } - SecurityContext context = TestSecurityContextHolder.getContext(); if (!this.EMPTY.equals(context)) { save(context, request); } - return request; } + } /** @@ -873,8 +816,8 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Rob Winch * @since 4.0 */ - private final static class SecurityContextRequestPostProcessor extends - SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { + private static final class SecurityContextRequestPostProcessor extends SecurityContextRequestPostProcessorSupport + implements RequestPostProcessor { private final SecurityContext securityContext; @@ -887,6 +830,7 @@ public final class SecurityMockMvcRequestPostProcessors { save(this.securityContext, request); return request; } + } /** @@ -897,8 +841,9 @@ public final class SecurityMockMvcRequestPostProcessors { * @since 4.0 * */ - private final static class AuthenticationRequestPostProcessor extends - SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { + private static final class AuthenticationRequestPostProcessor extends SecurityContextRequestPostProcessorSupport + implements RequestPostProcessor { + private final Authentication authentication; private AuthenticationRequestPostProcessor(Authentication authentication) { @@ -912,6 +857,7 @@ public final class SecurityMockMvcRequestPostProcessors { save(this.authentication, request); return request; } + } /** @@ -922,14 +868,13 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Rob Winch * @since 4.0 */ - private final static class UserDetailsRequestPostProcessor - implements RequestPostProcessor { + private static final class UserDetailsRequestPostProcessor implements RequestPostProcessor { + private final RequestPostProcessor delegate; UserDetailsRequestPostProcessor(UserDetails user) { - Authentication token = new UsernamePasswordAuthenticationToken(user, - user.getPassword(), user.getAuthorities()); - + Authentication token = new UsernamePasswordAuthenticationToken(user, user.getPassword(), + user.getAuthorities()); this.delegate = new AuthenticationRequestPostProcessor(token); } @@ -937,6 +882,7 @@ public final class SecurityMockMvcRequestPostProcessors { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { return this.delegate.postProcessRequest(request); } + } /** @@ -946,8 +892,8 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Rob Winch * @since 4.0 */ - public final static class UserRequestPostProcessor extends - SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { + public static final class UserRequestPostProcessor extends SecurityContextRequestPostProcessorSupport + implements RequestPostProcessor { private String username; @@ -955,8 +901,7 @@ public final class SecurityMockMvcRequestPostProcessors { private static final String ROLE_PREFIX = "ROLE_"; - private Collection authorities = AuthorityUtils - .createAuthorityList("ROLE_USER"); + private Collection authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); private boolean enabled = true; @@ -978,27 +923,19 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Specify the roles of the user to authenticate as. This method is similar to * {@link #authorities(GrantedAuthority...)}, but just not as flexible. - * * @param roles The roles to populate. Note that if the role does not start with * {@link #ROLE_PREFIX} it will automatically be prepended. This means by default * {@code roles("ROLE_USER")} and {@code roles("USER")} are equivalent. + * @return the UserRequestPostProcessor for further customizations * @see #authorities(GrantedAuthority...) * @see #ROLE_PREFIX - * @return the UserRequestPostProcessor for further customizations */ public UserRequestPostProcessor roles(String... roles) { - List authorities = new ArrayList<>( - roles.length); + List authorities = new ArrayList<>(roles.length); for (String role : roles) { - if (role.startsWith(ROLE_PREFIX)) { - throw new IllegalArgumentException( - "Role should not start with " + ROLE_PREFIX - + " since this method automatically prefixes with this value. Got " - + role); - } - else { - authorities.add(new SimpleGrantedAuthority(ROLE_PREFIX + role)); - } + Assert.isTrue(!role.startsWith(ROLE_PREFIX), () -> "Role should not start with " + ROLE_PREFIX + + " since this method automatically prefixes with this value. Got " + role); + authorities.add(new SimpleGrantedAuthority(ROLE_PREFIX + role)); } this.authorities = authorities; return this; @@ -1006,10 +943,9 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Populates the user's {@link GrantedAuthority}'s. The default is ROLE_USER. - * * @param authorities - * @see #roles(String...) * @return the UserRequestPostProcessor for further customizations + * @see #roles(String...) */ public UserRequestPostProcessor authorities(GrantedAuthority... authorities) { return authorities(Arrays.asList(authorities)); @@ -1017,20 +953,17 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Populates the user's {@link GrantedAuthority}'s. The default is ROLE_USER. - * * @param authorities - * @see #roles(String...) * @return the UserRequestPostProcessor for further customizations + * @see #roles(String...) */ - public UserRequestPostProcessor authorities( - Collection authorities) { + public UserRequestPostProcessor authorities(Collection authorities) { this.authorities = authorities; return this; } /** * Populates the user's password. The default is "password" - * * @param password the user's password * @return the UserRequestPostProcessor for further customizations */ @@ -1041,8 +974,7 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - UserDetailsRequestPostProcessor delegate = new UserDetailsRequestPostProcessor( - createUser()); + UserDetailsRequestPostProcessor delegate = new UserDetailsRequestPostProcessor(createUser()); return delegate.postProcessRequest(request); } @@ -1051,36 +983,32 @@ public final class SecurityMockMvcRequestPostProcessors { * @return the {@link User} for the principal */ private User createUser() { - return new User(this.username, this.password, this.enabled, - this.accountNonExpired, this.credentialsNonExpired, - this.accountNonLocked, this.authorities); + return new User(this.username, this.password, this.enabled, this.accountNonExpired, + this.credentialsNonExpired, this.accountNonLocked, this.authorities); } + } - private static class AnonymousRequestPostProcessor extends - SecurityContextRequestPostProcessorSupport implements RequestPostProcessor { + private static class AnonymousRequestPostProcessor extends SecurityContextRequestPostProcessorSupport + implements RequestPostProcessor { + private AuthenticationRequestPostProcessor delegate = new AuthenticationRequestPostProcessor( new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); - /* - * (non-Javadoc) - * - * @see org.springframework.test.web.servlet.request.RequestPostProcessor# - * postProcessRequest(org.springframework.mock.web.MockHttpServletRequest) - */ @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { return this.delegate.postProcessRequest(request); } + } - private static class HttpBasicRequestPostProcessor implements RequestPostProcessor { + private static final class HttpBasicRequestPostProcessor implements RequestPostProcessor { + private String headerValue; private HttpBasicRequestPostProcessor(String username, String password) { - byte[] toEncode; - toEncode = (username + ":" + password).getBytes(StandardCharsets.UTF_8); + byte[] toEncode = (username + ":" + password).getBytes(StandardCharsets.UTF_8); this.headerValue = "Basic " + new String(Base64.getEncoder().encode(toEncode)); } @@ -1089,6 +1017,7 @@ public final class SecurityMockMvcRequestPostProcessors { request.addHeader("Authorization", this.headerValue); return request; } + } /** @@ -1096,32 +1025,32 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Josh Cummings * @since 5.2 */ - public final static class JwtRequestPostProcessor implements RequestPostProcessor { + public static final class JwtRequestPostProcessor implements RequestPostProcessor { + private Jwt jwt; - private Converter> authoritiesConverter = - new JwtGrantedAuthoritiesConverter(); + + private Converter> authoritiesConverter = new JwtGrantedAuthoritiesConverter(); private JwtRequestPostProcessor() { - this.jwt((jwt) -> {}); + this.jwt((jwt) -> { + }); } /** - * Use the given {@link Jwt.Builder} {@link Consumer} to configure the underlying {@link Jwt} + * Use the given {@link Jwt.Builder} {@link Consumer} to configure the underlying + * {@link Jwt} * - * This method first creates a default {@link Jwt.Builder} instance with default values for - * the {@code alg}, {@code sub}, and {@code scope} claims. The {@link Consumer} can then modify - * these or provide additional configuration. - * - * Calling {@link SecurityMockMvcRequestPostProcessors#jwt()} is the equivalent of calling - * {@code SecurityMockMvcRequestPostProcessors.jwt().jwt(() -> {})}. + * This method first creates a default {@link Jwt.Builder} instance with default + * values for the {@code alg}, {@code sub}, and {@code scope} claims. The + * {@link Consumer} can then modify these or provide additional configuration. * + * Calling {@link SecurityMockMvcRequestPostProcessors#jwt()} is the equivalent of + * calling {@code SecurityMockMvcRequestPostProcessors.jwt().jwt(() -> {})}. * @param jwtBuilderConsumer For configuring the underlying {@link Jwt} * @return the {@link JwtRequestPostProcessor} for additional customization */ public JwtRequestPostProcessor jwt(Consumer jwtBuilderConsumer) { - Jwt.Builder jwtBuilder = Jwt.withTokenValue("token") - .header("alg", "none") - .claim(SUB, "user") + Jwt.Builder jwtBuilder = Jwt.withTokenValue("token").header("alg", "none").claim(JwtClaimNames.SUB, "user") .claim("scope", "read"); jwtBuilderConsumer.accept(jwtBuilder); this.jwt = jwtBuilder.build(); @@ -1130,7 +1059,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the given {@link Jwt} - * * @param jwt The {@link Jwt} to use * @return the {@link JwtRequestPostProcessor} for additional customization */ @@ -1146,7 +1074,7 @@ public final class SecurityMockMvcRequestPostProcessors { */ public JwtRequestPostProcessor authorities(Collection authorities) { Assert.notNull(authorities, "authorities cannot be null"); - this.authoritiesConverter = jwt -> authorities; + this.authoritiesConverter = (jwt) -> authorities; return this; } @@ -1157,16 +1085,15 @@ public final class SecurityMockMvcRequestPostProcessors { */ public JwtRequestPostProcessor authorities(GrantedAuthority... authorities) { Assert.notNull(authorities, "authorities cannot be null"); - this.authoritiesConverter = jwt -> Arrays.asList(authorities); + this.authoritiesConverter = (jwt) -> Arrays.asList(authorities); return this; } /** * Provides the configured {@link Jwt} so that custom authorities can be derived * from it - * - * @param authoritiesConverter the conversion strategy from {@link Jwt} to a {@link Collection} - * of {@link GrantedAuthority}s + * @param authoritiesConverter the conversion strategy from {@link Jwt} to a + * {@link Collection} of {@link GrantedAuthority}s * @return the {@link JwtRequestPostProcessor} for further configuration */ public JwtRequestPostProcessor authorities(Converter> authoritiesConverter) { @@ -1189,18 +1116,21 @@ public final class SecurityMockMvcRequestPostProcessors { * @author Josh Cummings * @since 5.3 */ - public final static class OpaqueTokenRequestPostProcessor implements RequestPostProcessor { + public static final class OpaqueTokenRequestPostProcessor implements RequestPostProcessor { + private Supplier> attributes = this::defaultAttributes; + private Supplier> authorities = this::defaultAuthorities; private Supplier principal = this::defaultPrincipal; - private OpaqueTokenRequestPostProcessor() { } + private OpaqueTokenRequestPostProcessor() { + } /** * Mutate the attributes using the given {@link Consumer} - * - * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of attributes + * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of + * attributes * @return the {@link OpaqueTokenRequestPostProcessor} for further configuration */ public OpaqueTokenRequestPostProcessor attributes(Consumer> attributesConsumer) { @@ -1254,8 +1184,8 @@ public final class SecurityMockMvcRequestPostProcessors { CsrfFilter.skipRequest(request); OAuth2AuthenticatedPrincipal principal = this.principal.get(); OAuth2AccessToken accessToken = getOAuth2AccessToken(principal); - BearerTokenAuthentication token = new BearerTokenAuthentication - (principal, accessToken, principal.getAuthorities()); + BearerTokenAuthentication token = new BearerTokenAuthentication(principal, accessToken, + principal.getAuthorities()); return new AuthenticationRequestPostProcessor(token).postProcessRequest(request); } @@ -1283,21 +1213,18 @@ public final class SecurityMockMvcRequestPostProcessors { } private OAuth2AuthenticatedPrincipal defaultPrincipal() { - return new OAuth2IntrospectionAuthenticatedPrincipal - (this.attributes.get(), this.authorities.get()); + return new OAuth2IntrospectionAuthenticatedPrincipal(this.attributes.get(), this.authorities.get()); } private Collection getAuthorities(Collection scopes) { - return scopes.stream() - .map(scope -> new SimpleGrantedAuthority("SCOPE_" + scope)) + return scopes.stream().map((scope) -> new SimpleGrantedAuthority("SCOPE_" + scope)) .collect(Collectors.toList()); } private OAuth2AccessToken getOAuth2AccessToken(OAuth2AuthenticatedPrincipal principal) { Instant expiresAt = getInstant(principal.getAttributes(), "exp"); Instant issuedAt = getInstant(principal.getAttributes(), "iat"); - return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "token", issuedAt, expiresAt); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", issuedAt, expiresAt); } private Instant getInstant(Map attributes, String name) { @@ -1310,20 +1237,25 @@ public final class SecurityMockMvcRequestPostProcessors { } throw new IllegalArgumentException(name + " attribute must be of type Instant"); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OAuth2LoginRequestPostProcessor implements RequestPostProcessor { + public static final class OAuth2LoginRequestPostProcessor implements RequestPostProcessor { + private final String nameAttributeKey = "sub"; private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; private Supplier> authorities = this::defaultAuthorities; + private Supplier> attributes = this::defaultAttributes; + private Supplier oauth2User = this::defaultPrincipal; private OAuth2LoginRequestPostProcessor(OAuth2AccessToken accessToken) { @@ -1333,7 +1265,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OAuth2LoginRequestPostProcessor} for further configuration */ @@ -1346,7 +1277,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OAuth2LoginRequestPostProcessor} for further configuration */ @@ -1359,8 +1289,8 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Mutate the attributes using the given {@link Consumer} - * - * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of attributes + * @param attributesConsumer The {@link Consumer} for mutating the {@Map} of + * attributes * @return the {@link OAuth2LoginRequestPostProcessor} for further configuration */ public OAuth2LoginRequestPostProcessor attributes(Consumer> attributesConsumer) { @@ -1376,7 +1306,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided {@link OAuth2User} as the authenticated user. - * * @param oauth2User the {@link OAuth2User} to use * @return the {@link OAuth2LoginRequestPostProcessor} for further configuration */ @@ -1391,9 +1320,9 @@ public final class SecurityMockMvcRequestPostProcessors { * The supplied {@link ClientRegistration} will be registered into an * {@link HttpSessionOAuth2AuthorizedClientRepository}. Tests relying on * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an {@link HttpSessionOAuth2AuthorizedClientRepository} bean - * to the application context. - * + * annotations should register an + * {@link HttpSessionOAuth2AuthorizedClientRepository} bean to the application + * context. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OAuth2LoginRequestPostProcessor} for further configuration */ @@ -1405,22 +1334,16 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { OAuth2User oauth2User = this.oauth2User.get(); - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken - (oauth2User, oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId()); - + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(), + this.clientRegistration.getRegistrationId()); request = new AuthenticationRequestPostProcessor(token).postProcessRequest(request); - return new OAuth2ClientRequestPostProcessor() - .clientRegistration(this.clientRegistration) - .principalName(oauth2User.getName()) - .accessToken(this.accessToken) - .postProcessRequest(request); + return new OAuth2ClientRequestPostProcessor().clientRegistration(this.clientRegistration) + .principalName(oauth2User.getName()).accessToken(this.accessToken).postProcessRequest(request); } private ClientRegistration.Builder clientRegistrationBuilder() { - return ClientRegistration.withRegistrationId("test") - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .clientId("test-client") - .tokenUri("https://token-uri.example.org"); + return ClientRegistration.withRegistrationId("test").authorizationGrantType(AuthorizationGrantType.PASSWORD) + .clientId("test-client").tokenUri("https://token-uri.example.org"); } private Collection defaultAuthorities() { @@ -1441,18 +1364,25 @@ public final class SecurityMockMvcRequestPostProcessors { private OAuth2User defaultPrincipal() { return new DefaultOAuth2User(this.authorities.get(), this.attributes.get(), this.nameAttributeKey); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OidcLoginRequestPostProcessor implements RequestPostProcessor { + public static final class OidcLoginRequestPostProcessor implements RequestPostProcessor { + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken; + private OidcIdToken idToken; + private OidcUserInfo userInfo; + private Supplier oidcUser = this::defaultPrincipal; + private Collection authorities; private OidcLoginRequestPostProcessor(OAuth2AccessToken accessToken) { @@ -1462,7 +1392,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ @@ -1475,7 +1404,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided authorities in the {@link Authentication} - * * @param authorities the authorities to use * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ @@ -1488,8 +1416,8 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided {@link OidcIdToken} when constructing the authenticated user - * - * @param idTokenBuilderConsumer a {@link Consumer} of a {@link OidcIdToken.Builder} + * @param idTokenBuilderConsumer a {@link Consumer} of a + * {@link OidcIdToken.Builder} * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ public OidcLoginRequestPostProcessor idToken(Consumer idTokenBuilderConsumer) { @@ -1503,8 +1431,8 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided {@link OidcUserInfo} when constructing the authenticated user - * - * @param userInfoBuilderConsumer a {@link Consumer} of a {@link OidcUserInfo.Builder} + * @param userInfoBuilderConsumer a {@link Consumer} of a + * {@link OidcUserInfo.Builder} * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ public OidcLoginRequestPostProcessor userInfoToken(Consumer userInfoBuilderConsumer) { @@ -1517,8 +1445,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use the provided {@link OidcUser} as the authenticated user. - * - * * @param oidcUser the {@link OidcUser} to use * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ @@ -1533,9 +1459,9 @@ public final class SecurityMockMvcRequestPostProcessors { * The supplied {@link ClientRegistration} will be registered into an * {@link HttpSessionOAuth2AuthorizedClientRepository}. Tests relying on * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an {@link HttpSessionOAuth2AuthorizedClientRepository} bean - * to the application context. - * + * annotations should register an + * {@link HttpSessionOAuth2AuthorizedClientRepository} bean to the application + * context. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ @@ -1547,39 +1473,32 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { OidcUser oidcUser = this.oidcUser.get(); - return new OAuth2LoginRequestPostProcessor(this.accessToken) - .oauth2User(oidcUser) - .clientRegistration(this.clientRegistration) - .postProcessRequest(request); + return new OAuth2LoginRequestPostProcessor(this.accessToken).oauth2User(oidcUser) + .clientRegistration(this.clientRegistration).postProcessRequest(request); } private ClientRegistration.Builder clientRegistrationBuilder() { - return ClientRegistration.withRegistrationId("test") - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .clientId("test-client") - .tokenUri("https://token-uri.example.org"); + return ClientRegistration.withRegistrationId("test").authorizationGrantType(AuthorizationGrantType.PASSWORD) + .clientId("test-client").tokenUri("https://token-uri.example.org"); } private Collection getAuthorities() { - if (this.authorities == null) { - Set authorities = new LinkedHashSet<>(); - authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo())); - for (String authority : this.accessToken.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); - } - return authorities; - } else { + if (this.authorities != null) { return this.authorities; } + Set authorities = new LinkedHashSet<>(); + authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo())); + for (String authority : this.accessToken.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } + return authorities; } private OidcIdToken getOidcIdToken() { - if (this.idToken == null) { - return new OidcIdToken("id-token", null, null, - Collections.singletonMap(IdTokenClaimNames.SUB, "user")); - } else { + if (this.idToken != null) { return this.idToken; } + return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user")); } private OidcUserInfo getOidcUserInfo() { @@ -1589,16 +1508,21 @@ public final class SecurityMockMvcRequestPostProcessors { private OidcUser defaultPrincipal() { return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo); } + } /** * @author Josh Cummings * @since 5.3 */ - public final static class OAuth2ClientRequestPostProcessor implements RequestPostProcessor { + public static final class OAuth2ClientRequestPostProcessor implements RequestPostProcessor { + private String registrationId = "test"; + private ClientRegistration clientRegistration; + private String principalName = "user"; + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, null, Collections.singleton("read")); @@ -1607,12 +1531,12 @@ public final class SecurityMockMvcRequestPostProcessors { private OAuth2ClientRequestPostProcessor(String registrationId) { this.registrationId = registrationId; - clientRegistration(c -> {}); + clientRegistration((c) -> { + }); } /** * Use this {@link ClientRegistration} - * * @param clientRegistration * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration */ @@ -1623,13 +1547,11 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use this {@link Consumer} to configure a {@link ClientRegistration} - * * @param clientRegistrationConfigurer the {@link ClientRegistration} configurer * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration */ - public OAuth2ClientRequestPostProcessor clientRegistration - (Consumer clientRegistrationConfigurer) { - + public OAuth2ClientRequestPostProcessor clientRegistration( + Consumer clientRegistrationConfigurer) { ClientRegistration.Builder builder = clientRegistrationBuilder(); clientRegistrationConfigurer.accept(builder); this.clientRegistration = builder.build(); @@ -1638,7 +1560,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use this as the resource owner's principal name - * * @param principalName the resource owner's principal name * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration */ @@ -1650,7 +1571,6 @@ public final class SecurityMockMvcRequestPostProcessors { /** * Use this {@link OAuth2AccessToken} - * * @param accessToken the {@link OAuth2AccessToken} to use * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration */ @@ -1662,17 +1582,15 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { if (this.clientRegistration == null) { - throw new IllegalArgumentException("Please specify a ClientRegistration via one " + - "of the clientRegistration methods"); + throw new IllegalArgumentException( + "Please specify a ClientRegistration via one " + "of the clientRegistration methods"); } - OAuth2AuthorizedClient client = new OAuth2AuthorizedClient - (this.clientRegistration, this.principalName, this.accessToken); - + OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, + this.accessToken); OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils .getOAuth2AuthorizedClientManager(request); if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) { - authorizationClientManager = - new TestOAuth2AuthorizedClientManager(authorizationClientManager); + authorizationClientManager = new TestOAuth2AuthorizedClientManager(authorizationClientManager); OAuth2ClientServletTestUtils.setOAuth2AuthorizedClientManager(request, authorizationClientManager); } TestOAuth2AuthorizedClientManager.enable(request); @@ -1682,24 +1600,20 @@ public final class SecurityMockMvcRequestPostProcessors { private ClientRegistration.Builder clientRegistrationBuilder() { return ClientRegistration.withRegistrationId(this.registrationId) - .authorizationGrantType(AuthorizationGrantType.PASSWORD) - .clientId("test-client") - .clientSecret("test-secret") - .tokenUri("https://idp.example.org/oauth/token"); + .authorizationGrantType(AuthorizationGrantType.PASSWORD).clientId("test-client") + .clientSecret("test-secret").tokenUri("https://idp.example.org/oauth/token"); } /** - * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the - * request is wrapped + * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for + * testing when the request is wrapped */ - private static class TestOAuth2AuthorizedClientManager - implements OAuth2AuthorizedClientManager { + private static final class TestOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { - final static String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName() - .concat(".TOKEN"); + static final String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName().concat(".TOKEN"); - final static String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientManager.class - .getName().concat(".ENABLED"); + static final String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName() + .concat(".ENABLED"); private final OAuth2AuthorizedClientManager delegate; @@ -1709,59 +1623,61 @@ public final class SecurityMockMvcRequestPostProcessors { @Override public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { - HttpServletRequest request = - authorizeRequest.getAttribute(HttpServletRequest.class.getName()); + HttpServletRequest request = authorizeRequest.getAttribute(HttpServletRequest.class.getName()); if (isEnabled(request)) { return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME); - } else { - return this.delegate.authorize(authorizeRequest); } + return this.delegate.authorize(authorizeRequest); } - public static void enable(HttpServletRequest request) { - request.setAttribute(ENABLED_ATTR_NAME, TRUE); + static void enable(HttpServletRequest request) { + request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); } - public boolean isEnabled(HttpServletRequest request) { - return TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); + boolean isEnabled(HttpServletRequest request) { + return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); } + } - private static class OAuth2ClientServletTestUtils { - private static final OAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO = - new HttpSessionOAuth2AuthorizedClientRepository(); + private static final class OAuth2ClientServletTestUtils { + + private static final OAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO = new HttpSessionOAuth2AuthorizedClientRepository(); + + private OAuth2ClientServletTestUtils() { + } /** - * Gets the {@link OAuth2AuthorizedClientManager} for the specified {@link HttpServletRequest}. - * If one is not found, one based off of {@link HttpSessionOAuth2AuthorizedClientRepository} is used. - * + * Gets the {@link OAuth2AuthorizedClientManager} for the specified + * {@link HttpServletRequest}. If one is not found, one based off of + * {@link HttpSessionOAuth2AuthorizedClientRepository} is used. * @param request the {@link HttpServletRequest} to obtain the * {@link OAuth2AuthorizedClientManager} * @return the {@link OAuth2AuthorizedClientManager} for the specified * {@link HttpServletRequest} */ - public static OAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(HttpServletRequest request) { - OAuth2AuthorizedClientArgumentResolver resolver = - findResolver(request, OAuth2AuthorizedClientArgumentResolver.class); + static OAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(HttpServletRequest request) { + OAuth2AuthorizedClientArgumentResolver resolver = findResolver(request, + OAuth2AuthorizedClientArgumentResolver.class); if (resolver == null) { - return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient - (authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), request); + return (authorizeRequest) -> DEFAULT_CLIENT_REPO.loadAuthorizedClient( + authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), request); } - return (OAuth2AuthorizedClientManager) - ReflectionTestUtils.getField(resolver, "authorizedClientManager"); + return (OAuth2AuthorizedClientManager) ReflectionTestUtils.getField(resolver, + "authorizedClientManager"); } /** - * Sets the {@link OAuth2AuthorizedClientManager} for the specified {@link HttpServletRequest}. - * + * Sets the {@link OAuth2AuthorizedClientManager} for the specified + * {@link HttpServletRequest}. * @param request the {@link HttpServletRequest} to obtain the * {@link OAuth2AuthorizedClientManager} * @param manager the {@link OAuth2AuthorizedClientManager} to set */ - public static void setOAuth2AuthorizedClientManager(HttpServletRequest request, + static void setOAuth2AuthorizedClientManager(HttpServletRequest request, OAuth2AuthorizedClientManager manager) { - OAuth2AuthorizedClientArgumentResolver resolver = - findResolver(request, OAuth2AuthorizedClientArgumentResolver.class); + OAuth2AuthorizedClientArgumentResolver resolver = findResolver(request, + OAuth2AuthorizedClientArgumentResolver.class); if (resolver == null) { return; } @@ -1771,14 +1687,15 @@ public final class SecurityMockMvcRequestPostProcessors { @SuppressWarnings("unchecked") static T findResolver(HttpServletRequest request, Class resolverClass) { - if (!ClassUtils.isPresent - ("org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter", null)) { + if (!ClassUtils.isPresent( + "org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter", null)) { return null; } return WebMvcClasspathGuard.findResolver(request, resolverClass); } private static class WebMvcClasspathGuard { + static T findResolver(HttpServletRequest request, Class resolverClass) { ServletContext servletContext = request.getServletContext(); @@ -1798,9 +1715,9 @@ public final class SecurityMockMvcRequestPostProcessors { return null; } - private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServletContext servletContext) { - WebApplicationContext context = WebApplicationContextUtils - .getWebApplicationContext(servletContext); + private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter( + ServletContext servletContext) { + WebApplicationContext context = WebApplicationContextUtils.getWebApplicationContext(servletContext); if (context != null) { String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class); if (names.length > 0) { @@ -1809,13 +1726,11 @@ public final class SecurityMockMvcRequestPostProcessors { } return null; } + } - private OAuth2ClientServletTestUtils() { - } } + } - private SecurityMockMvcRequestPostProcessors() { - } } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchers.java b/test/src/main/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchers.java index a4f93652ac..d27f1c5afb 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchers.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchers.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.response; import java.util.ArrayList; @@ -28,13 +29,11 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.test.util.AssertionErrors; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; -import static org.springframework.test.util.AssertionErrors.assertEquals; -import static org.springframework.test.util.AssertionErrors.assertTrue; - /** * Security related {@link MockMvc} {@link ResultMatcher}s. * @@ -44,9 +43,11 @@ import static org.springframework.test.util.AssertionErrors.assertTrue; */ public final class SecurityMockMvcResultMatchers { + private SecurityMockMvcResultMatchers() { + } + /** * {@link ResultMatcher} that verifies that a specified user is authenticated. - * * @return the {@link AuthenticatedMatcher} to use */ public static AuthenticatedMatcher authenticated() { @@ -55,23 +56,20 @@ public final class SecurityMockMvcResultMatchers { /** * {@link ResultMatcher} that verifies that no user is authenticated. - * * @return the {@link AuthenticatedMatcher} to use */ public static ResultMatcher unauthenticated() { return new UnAuthenticatedMatcher(); } - private static abstract class AuthenticationMatcher> - implements ResultMatcher { + private abstract static class AuthenticationMatcher> implements ResultMatcher { protected SecurityContext load(MvcResult result) { - HttpRequestResponseHolder holder = new HttpRequestResponseHolder( - result.getRequest(), result.getResponse()); - SecurityContextRepository repository = WebTestUtils - .getSecurityContextRepository(result.getRequest()); + HttpRequestResponseHolder holder = new HttpRequestResponseHolder(result.getRequest(), result.getResponse()); + SecurityContextRepository repository = WebTestUtils.getSecurityContextRepository(result.getRequest()); return repository.loadContext(holder); } + } /** @@ -81,67 +79,61 @@ public final class SecurityMockMvcResultMatchers { * @author Rob Winch * @since 4.0 */ - public static final class AuthenticatedMatcher - extends AuthenticationMatcher { + public static final class AuthenticatedMatcher extends AuthenticationMatcher { private SecurityContext expectedContext; + private Authentication expectedAuthentication; + private Object expectedAuthenticationPrincipal; + private String expectedAuthenticationName; + private Collection expectedGrantedAuthorities; + private Consumer assertAuthentication; + AuthenticatedMatcher() { + } + @Override public void match(MvcResult result) { SecurityContext context = load(result); - Authentication auth = context.getAuthentication(); - - assertTrue("Authentication should not be null", auth != null); - + AssertionErrors.assertTrue("Authentication should not be null", auth != null); if (this.assertAuthentication != null) { this.assertAuthentication.accept(auth); } - if (this.expectedContext != null) { - assertEquals(this.expectedContext + " does not equal " + context, - this.expectedContext, context); + AssertionErrors.assertEquals(this.expectedContext + " does not equal " + context, this.expectedContext, + context); } - if (this.expectedAuthentication != null) { - assertEquals( - this.expectedAuthentication + " does not equal " - + context.getAuthentication(), + AssertionErrors.assertEquals( + this.expectedAuthentication + " does not equal " + context.getAuthentication(), this.expectedAuthentication, context.getAuthentication()); } - if (this.expectedAuthenticationPrincipal != null) { - assertTrue("Authentication cannot be null", - context.getAuthentication() != null); - assertEquals( + AssertionErrors.assertTrue("Authentication cannot be null", context.getAuthentication() != null); + AssertionErrors.assertEquals( this.expectedAuthenticationPrincipal + " does not equal " + context.getAuthentication().getPrincipal(), - this.expectedAuthenticationPrincipal, - context.getAuthentication().getPrincipal()); + this.expectedAuthenticationPrincipal, context.getAuthentication().getPrincipal()); } - if (this.expectedAuthenticationName != null) { - assertTrue("Authentication cannot be null", auth != null); + AssertionErrors.assertTrue("Authentication cannot be null", auth != null); String name = auth.getName(); - assertEquals(this.expectedAuthenticationName + " does not equal " + name, + AssertionErrors.assertEquals(this.expectedAuthenticationName + " does not equal " + name, this.expectedAuthenticationName, name); } - if (this.expectedGrantedAuthorities != null) { - assertTrue("Authentication cannot be null", auth != null); - Collection authorities = auth - .getAuthorities(); - assertTrue( - authorities + " does not contain the same authorities as " - + this.expectedGrantedAuthorities, + AssertionErrors.assertTrue("Authentication cannot be null", auth != null); + Collection authorities = auth.getAuthorities(); + AssertionErrors.assertTrue( + authorities + " does not contain the same authorities as " + this.expectedGrantedAuthorities, authorities.containsAll(this.expectedGrantedAuthorities)); - assertTrue(this.expectedGrantedAuthorities - + " does not contain the same authorities as " + authorities, + AssertionErrors.assertTrue( + this.expectedGrantedAuthorities + " does not contain the same authorities as " + authorities, this.expectedGrantedAuthorities.containsAll(authorities)); } } @@ -158,7 +150,6 @@ public final class SecurityMockMvcResultMatchers { /** * Specifies the expected username - * * @param expected the expected username * @return the {@link AuthenticatedMatcher} for further customization */ @@ -168,7 +159,6 @@ public final class SecurityMockMvcResultMatchers { /** * Specifies the expected {@link SecurityContext} - * * @param expected the expected {@link SecurityContext} * @return the {@link AuthenticatedMatcher} for further customization */ @@ -179,7 +169,6 @@ public final class SecurityMockMvcResultMatchers { /** * Specifies the expected {@link Authentication} - * * @param expected the expected {@link Authentication} * @return the {@link AuthenticatedMatcher} for further customization */ @@ -190,7 +179,6 @@ public final class SecurityMockMvcResultMatchers { /** * Specifies the expected principal - * * @param expected the expected principal * @return the {@link AuthenticatedMatcher} for further customization */ @@ -201,7 +189,6 @@ public final class SecurityMockMvcResultMatchers { /** * Specifies the expected {@link Authentication#getName()} - * * @param expected the expected {@link Authentication#getName()} * @return the {@link AuthenticatedMatcher} for further customization */ @@ -212,19 +199,16 @@ public final class SecurityMockMvcResultMatchers { /** * Specifies the {@link Authentication#getAuthorities()} - * * @param expected the {@link Authentication#getAuthorities()} * @return the {@link AuthenticatedMatcher} for further customization */ - public AuthenticatedMatcher withAuthorities( - Collection expected) { + public AuthenticatedMatcher withAuthorities(Collection expected) { this.expectedGrantedAuthorities = expected; return this; } /** * Specifies the {@link Authentication#getAuthorities()} - * * @param roles the roles. Each value is automatically prefixed with "ROLE_" * @return the {@link AuthenticatedMatcher} for further customization */ @@ -236,8 +220,6 @@ public final class SecurityMockMvcResultMatchers { return withAuthorities(authorities); } - AuthenticatedMatcher() { - } } /** @@ -247,24 +229,22 @@ public final class SecurityMockMvcResultMatchers { * @author Rob Winch * @since 4.0 */ - private static final class UnAuthenticatedMatcher - extends AuthenticationMatcher { + private static final class UnAuthenticatedMatcher extends AuthenticationMatcher { + private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); + private UnAuthenticatedMatcher() { + } + @Override public void match(MvcResult result) { SecurityContext context = load(result); Authentication authentication = context.getAuthentication(); - assertTrue("Expected anonymous Authentication got " + context, - authentication == null - || this.trustResolver.isAnonymous(authentication)); + AssertionErrors.assertTrue("Expected anonymous Authentication got " + context, + authentication == null || this.trustResolver.isAnonymous(authentication)); } - private UnAuthenticatedMatcher() { - } } - private SecurityMockMvcResultMatchers() { - } } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java b/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java index 41e0b9fcee..98173963c1 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.setup; +import java.io.IOException; + import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -26,10 +29,9 @@ import org.springframework.security.config.BeanIds; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder; import org.springframework.test.web.servlet.setup.MockMvcConfigurerAdapter; +import org.springframework.util.Assert; import org.springframework.web.context.WebApplicationContext; -import java.io.IOException; - import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext; /** @@ -41,6 +43,7 @@ import static org.springframework.security.test.web.servlet.request.SecurityMock * @since 4.0 */ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter { + private final DelegateFilter delegateFilter; /** @@ -64,26 +67,17 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter { } @Override - public RequestPostProcessor beforeMockMvcCreated( - ConfigurableMockMvcBuilder builder, WebApplicationContext context) { + public RequestPostProcessor beforeMockMvcCreated(ConfigurableMockMvcBuilder builder, + WebApplicationContext context) { String securityBeanId = BeanIds.SPRING_SECURITY_FILTER_CHAIN; - if (getSpringSecurityFilterChain() == null - && context.containsBean(securityBeanId)) { - setSpringSecurityFitlerChain(context.getBean(securityBeanId, - Filter.class)); + if (getSpringSecurityFilterChain() == null && context.containsBean(securityBeanId)) { + setSpringSecurityFitlerChain(context.getBean(securityBeanId, Filter.class)); } - - if (getSpringSecurityFilterChain() == null) { - throw new IllegalStateException( - "springSecurityFilterChain cannot be null. Ensure a Bean with the name " - + securityBeanId - + " implementing Filter is present or inject the Filter to be used."); - } - + Assert.state(getSpringSecurityFilterChain() != null, + () -> "springSecurityFilterChain cannot be null. Ensure a Bean with the name " + securityBeanId + + " implementing Filter is present or inject the Filter to be used."); // This is used by other test support to obtain the FilterChainProxy - context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, - getSpringSecurityFilterChain()); - + context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, getSpringSecurityFilterChain()); return testSecurityContext(); } @@ -96,11 +90,13 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter { } /** - * Allows adding in {@link #afterConfigurerAdded(ConfigurableMockMvcBuilder)} to preserve Filter order and then - * lazily set the delegate in {@link #beforeMockMvcCreated(ConfigurableMockMvcBuilder, WebApplicationContext)}. + * Allows adding in {@link #afterConfigurerAdded(ConfigurableMockMvcBuilder)} to + * preserve Filter order and then lazily set the delegate in + * {@link #beforeMockMvcCreated(ConfigurableMockMvcBuilder, WebApplicationContext)}. * - * {@link org.springframework.web.filter.DelegatingFilterProxy} is not used because it is not easy to lazily set - * the delegate or get the delegate which is necessary for the test infrastructure. + * {@link org.springframework.web.filter.DelegatingFilterProxy} is not used because it + * is not easy to lazily set the delegate or get the delegate which is necessary for + * the test infrastructure. */ static class DelegateFilter implements Filter { @@ -119,11 +115,9 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter { Filter getDelegate() { Filter result = this.delegate; - if (result == null) { - throw new IllegalStateException("delegate cannot be null. Ensure a Bean with the name " - + BeanIds.SPRING_SECURITY_FILTER_CHAIN - + " implementing Filter is present or inject the Filter to be used."); - } + Assert.state(result != null, + () -> "delegate cannot be null. Ensure a Bean with the name " + BeanIds.SPRING_SECURITY_FILTER_CHAIN + + " implementing Filter is present or inject the Filter to be used."); return result; } @@ -144,18 +138,20 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter { } @Override - public int hashCode() { - return getDelegate().hashCode(); + public boolean equals(Object obj) { + return getDelegate().equals(obj); } @Override - public boolean equals(Object obj) { - return getDelegate().equals(obj); + public int hashCode() { + return getDelegate().hashCode(); } @Override public String toString() { return getDelegate().toString(); } + } + } diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurers.java b/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurers.java index 9145c57e90..07332b3244 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurers.java @@ -13,21 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.setup; +import javax.servlet.Filter; + import org.springframework.test.web.servlet.setup.MockMvcConfigurer; import org.springframework.util.Assert; -import javax.servlet.Filter; - /** * Provides Security related * {@link org.springframework.test.web.servlet.setup.MockMvcConfigurer} implementations. * - * @since 4.0 * @author Rob Winch + * @since 4.0 */ public final class SecurityMockMvcConfigurers { + + private SecurityMockMvcConfigurers() { + } + /** * Configures the MockMvcBuilder for use with Spring Security. Specifically the * configurer adds the Spring Bean named "springSecurityFilterChain" as a Filter. It @@ -35,7 +40,6 @@ public final class SecurityMockMvcConfigurers { * by applying * {@link org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors#testSecurityContext()} * . - * * @return the {@link org.springframework.test.web.servlet.setup.MockMvcConfigurer} to * use */ @@ -49,15 +53,13 @@ public final class SecurityMockMvcConfigurers { * TestSecurityContextHolder is leveraged for each request by applying * {@link org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors#testSecurityContext()} * . - * * @param springSecurityFilterChain the Filter to be added - * * @return the {@link org.springframework.test.web.servlet.setup.MockMvcConfigurer} to * use */ public static MockMvcConfigurer springSecurity(Filter springSecurityFilterChain) { - Assert.notNull(springSecurityFilterChain, - "springSecurityFilterChain cannot be null"); + Assert.notNull(springSecurityFilterChain, "springSecurityFilterChain cannot be null"); return new SecurityMockMvcConfigurer(springSecurityFilterChain); } + } diff --git a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java index 98294cd08a..8f9a79f51d 100644 --- a/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java +++ b/test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.support; import java.util.List; @@ -41,23 +42,25 @@ import org.springframework.web.context.support.WebApplicationContextUtils; * @since 4.0 */ public abstract class WebTestUtils { + private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository(); + private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); + private WebTestUtils() { + } + /** * Gets the {@link SecurityContextRepository} for the specified * {@link HttpServletRequest}. If one is not found, a default * {@link HttpSessionSecurityContextRepository} is used. - * * @param request the {@link HttpServletRequest} to obtain the * {@link SecurityContextRepository} * @return the {@link SecurityContextRepository} for the specified * {@link HttpServletRequest} */ - public static SecurityContextRepository getSecurityContextRepository( - HttpServletRequest request) { - SecurityContextPersistenceFilter filter = findFilter(request, - SecurityContextPersistenceFilter.class); + public static SecurityContextRepository getSecurityContextRepository(HttpServletRequest request) { + SecurityContextPersistenceFilter filter = findFilter(request, SecurityContextPersistenceFilter.class); if (filter == null) { return DEFAULT_CONTEXT_REPO; } @@ -67,15 +70,13 @@ public abstract class WebTestUtils { /** * Sets the {@link SecurityContextRepository} for the specified * {@link HttpServletRequest}. - * * @param request the {@link HttpServletRequest} to obtain the * {@link SecurityContextRepository} * @param securityContextRepository the {@link SecurityContextRepository} to set */ public static void setSecurityContextRepository(HttpServletRequest request, SecurityContextRepository securityContextRepository) { - SecurityContextPersistenceFilter filter = findFilter(request, - SecurityContextPersistenceFilter.class); + SecurityContextPersistenceFilter filter = findFilter(request, SecurityContextPersistenceFilter.class); if (filter != null) { ReflectionTestUtils.setField(filter, "repo", securityContextRepository); } @@ -84,7 +85,6 @@ public abstract class WebTestUtils { /** * Gets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. * If one is not found, the default {@link HttpSessionCsrfTokenRepository} is used. - * * @param request the {@link HttpServletRequest} to obtain the * {@link CsrfTokenRepository} * @return the {@link CsrfTokenRepository} for the specified @@ -95,19 +95,16 @@ public abstract class WebTestUtils { if (filter == null) { return DEFAULT_TOKEN_REPO; } - return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, - "tokenRepository"); + return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository"); } /** * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}. - * * @param request the {@link HttpServletRequest} to obtain the * {@link CsrfTokenRepository} * @param repository the {@link CsrfTokenRepository} to set */ - public static void setCsrfTokenRepository(HttpServletRequest request, - CsrfTokenRepository repository) { + public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) { CsrfFilter filter = findFilter(request, CsrfFilter.class); if (filter != null) { ReflectionTestUtils.setField(filter, "tokenRepository", repository); @@ -115,15 +112,13 @@ public abstract class WebTestUtils { } @SuppressWarnings("unchecked") - static T findFilter(HttpServletRequest request, - Class filterClass) { + static T findFilter(HttpServletRequest request, Class filterClass) { ServletContext servletContext = request.getServletContext(); Filter springSecurityFilterChain = getSpringSecurityFilterChain(servletContext); if (springSecurityFilterChain == null) { return null; } - List filters = ReflectionTestUtils - .invokeMethod(springSecurityFilterChain, "getFilters", request); + List filters = ReflectionTestUtils.invokeMethod(springSecurityFilterChain, "getFilters", request); if (filters == null) { return null; } @@ -136,25 +131,22 @@ public abstract class WebTestUtils { } private static Filter getSpringSecurityFilterChain(ServletContext servletContext) { - Filter result = (Filter) servletContext - .getAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN); + Filter result = (Filter) servletContext.getAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN); if (result != null) { return result; } WebApplicationContext webApplicationContext = WebApplicationContextUtils .getWebApplicationContext(servletContext); - if (webApplicationContext != null) { - try { - return webApplicationContext.getBean( - AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME, - Filter.class); - } - catch (NoSuchBeanDefinitionException notFound) { - } + if (webApplicationContext == null) { + return null; + } + try { + String beanName = AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME; + return webApplicationContext.getBean(beanName, Filter.class); + } + catch (NoSuchBeanDefinitionException ex) { + return null; } - return null; } - private WebTestUtils() { - } } diff --git a/test/src/test/java/org/springframework/security/test/context/TestSecurityContextHolderTests.java b/test/src/test/java/org/springframework/security/test/context/TestSecurityContextHolderTests.java index a15f9ed5c2..28362616aa 100644 --- a/test/src/test/java/org/springframework/security/test/context/TestSecurityContextHolderTests.java +++ b/test/src/test/java/org/springframework/security/test/context/TestSecurityContextHolderTests.java @@ -13,25 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.context; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +package org.springframework.security.test.context; import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + public class TestSecurityContextHolderTests { private SecurityContext context; @Before public void setup() { - context = SecurityContextHolder.createEmptyContext(); + this.context = SecurityContextHolder.createEmptyContext(); } @After @@ -41,13 +43,11 @@ public class TestSecurityContextHolderTests { @Test public void clearContextClearsBoth() { - SecurityContextHolder.setContext(context); - TestSecurityContextHolder.setContext(context); - + SecurityContextHolder.setContext(this.context); + TestSecurityContextHolder.setContext(this.context); TestSecurityContextHolder.clearContext(); - - assertThat(SecurityContextHolder.getContext()).isNotSameAs(context); - assertThat(TestSecurityContextHolder.getContext()).isNotSameAs(context); + assertThat(SecurityContextHolder.getContext()).isNotSameAs(this.context); + assertThat(TestSecurityContextHolder.getContext()).isNotSameAs(this.context); } @Test @@ -58,18 +58,16 @@ public class TestSecurityContextHolderTests { @Test public void setContextSetsBoth() { - TestSecurityContextHolder.setContext(context); - - assertThat(TestSecurityContextHolder.getContext()).isSameAs(context); - assertThat(SecurityContextHolder.getContext()).isSameAs(context); + TestSecurityContextHolder.setContext(this.context); + assertThat(TestSecurityContextHolder.getContext()).isSameAs(this.context); + assertThat(SecurityContextHolder.getContext()).isSameAs(this.context); } @Test public void setContextWithAuthentication() { Authentication authentication = mock(Authentication.class); - TestSecurityContextHolder.setAuthentication(authentication); - assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication); } + } diff --git a/test/src/test/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListenerTests.java index 0e378b6346..5bd5b512b8 100644 --- a/test/src/test/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListenerTests.java @@ -13,21 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.annotation; -import static org.assertj.core.api.Assertions.assertThat; +import java.security.Principal; import org.junit.Test; import org.junit.runner.RunWith; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import java.security.Principal; +import static org.assertj.core.api.Assertions.assertThat; @RunWith(SpringJUnit4ClassRunner.class) @SecurityTestExecutionListeners @@ -39,16 +41,12 @@ public class SecurityTestExecutionListenerTests { assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("user"); } - @WithMockUser @Test public void reactorContextTestSecurityContextHolderExecutionListenerTestIsRegistered() { - Mono name = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .map(Principal::getName); - - StepVerifier.create(name) - .expectNext("user") - .verifyComplete(); + Mono name = ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) + .map(Principal::getName); + StepVerifier.create(name).expectNext("user").verifyComplete(); } + } diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/CustomUserDetails.java b/test/src/test/java/org/springframework/security/test/context/showcase/CustomUserDetails.java index 6300d5a883..ccbc7a00f0 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/CustomUserDetails.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/CustomUserDetails.java @@ -13,20 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase; +import java.util.Collection; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.userdetails.UserDetails; -import java.util.Collection; - /** * @author Rob Winch */ public class CustomUserDetails implements UserDetails { + private final String name; + private final String username; + private final Collection authorities; public CustomUserDetails(String name, String username) { @@ -35,36 +39,44 @@ public class CustomUserDetails implements UserDetails { this.authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); } + @Override public Collection getAuthorities() { - return authorities; + return this.authorities; } + @Override public String getPassword() { return null; } + @Override public String getUsername() { - return username; + return this.username; } + @Override public boolean isAccountNonExpired() { return true; } + @Override public boolean isAccountNonLocked() { return true; } + @Override public boolean isCredentialsNonExpired() { return true; } + @Override public boolean isEnabled() { return true; } @Override public String toString() { - return "CustomUserDetails{" + "username='" + username + '\'' + '}'; + return "CustomUserDetails{" + "username='" + this.username + '\'' + '}'; } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUser.java b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUser.java index 73058bc2c4..a3c61627f2 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUser.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUser.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase; import org.springframework.security.test.context.support.WithSecurityContext; @@ -22,6 +23,7 @@ import org.springframework.security.test.context.support.WithSecurityContext; */ @WithSecurityContext(factory = WithMockCustomUserSecurityContextFactory.class) public @interface WithMockCustomUser { + /** * The username to be used. The default is rob * @return @@ -33,7 +35,6 @@ public @interface WithMockCustomUser { * {@link org.springframework.security.core.GrantedAuthority} will be created for each * value within roles. Each value in roles will automatically be prefixed with * "ROLE_". For example, the default will result in "ROLE_USER" being used. - * * @return */ String[] roles() default { "USER" }; @@ -43,4 +44,5 @@ public @interface WithMockCustomUser { * @return */ String name() default "Rob Winch"; + } diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUserSecurityContextFactory.java b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUserSecurityContextFactory.java index 67d0b28453..d174584cb1 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUserSecurityContextFactory.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUserSecurityContextFactory.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; @@ -24,16 +25,16 @@ import org.springframework.security.test.context.support.WithSecurityContextFact /** * @author Rob Winch */ -public class WithMockCustomUserSecurityContextFactory implements - WithSecurityContextFactory { +public class WithMockCustomUserSecurityContextFactory implements WithSecurityContextFactory { + + @Override public SecurityContext createSecurityContext(WithMockCustomUser customUser) { SecurityContext context = SecurityContextHolder.createEmptyContext(); - - CustomUserDetails principal = new CustomUserDetails(customUser.name(), - customUser.username()); - Authentication auth = new UsernamePasswordAuthenticationToken(principal, - "password", principal.getAuthorities()); + CustomUserDetails principal = new CustomUserDetails(customUser.name(), customUser.username()); + Authentication auth = new UsernamePasswordAuthenticationToken(principal, "password", + principal.getAuthorities()); context.setAuthentication(auth); return context; } + } diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParent.java b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParent.java index 8085335ea1..1f2c5c6016 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParent.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParent.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase; import org.springframework.security.test.context.support.WithMockUser; @@ -23,4 +24,4 @@ import org.springframework.security.test.context.support.WithMockUser; @WithMockUser public class WithMockUserParent { -} \ No newline at end of file +} diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParentTests.java b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParentTests.java index 0f25882073..d7294d510f 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParentTests.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserParentTests.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase; -import static org.assertj.core.api.Assertions.assertThat; +import org.junit.Test; +import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.ComponentScan; @@ -26,8 +28,7 @@ import org.springframework.security.test.context.showcase.service.MessageService import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import org.junit.Test; -import org.junit.runner.RunWith; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Eddú Meléndez @@ -41,20 +42,23 @@ public class WithMockUserParentTests extends WithMockUserParent { @Test public void getMessageWithMockUser() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("user"); } @EnableGlobalMethodSecurity(prePostEnabled = true) @ComponentScan(basePackageClasses = HelloMessageService.class) static class Config { - // @formatter:off + @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserTests.java b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserTests.java index eafe930b31..b14a95a33f 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserTests.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/WithMockUserTests.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.context.showcase; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.test.context.showcase; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.ComponentScan; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; @@ -30,60 +30,64 @@ import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Rob Winch */ - @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = WithMockUserTests.Config.class) public class WithMockUserTests { + @Autowired private MessageService messageService; @Test(expected = AuthenticationCredentialsNotFoundException.class) public void getMessageUnauthenticated() { - messageService.getMessage(); + this.messageService.getMessage(); } @Test @WithMockUser public void getMessageWithMockUser() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("user"); } @Test @WithMockUser("customUsername") public void getMessageWithMockUserCustomUsername() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("customUsername"); } @Test @WithMockUser(username = "admin", roles = { "USER", "ADMIN" }) public void getMessageWithMockUserCustomUser() { - String message = messageService.getMessage(); - assertThat(message).contains("admin").contains("ROLE_USER") - .contains("ROLE_ADMIN"); + String message = this.messageService.getMessage(); + assertThat(message).contains("admin").contains("ROLE_USER").contains("ROLE_ADMIN"); } @Test @WithMockUser(username = "admin", authorities = { "ADMIN", "USER" }) public void getMessageWithMockUserCustomAuthorities() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("admin").contains("ADMIN").contains("USER").doesNotContain("ROLE_"); } @EnableGlobalMethodSecurity(prePostEnabled = true) @ComponentScan(basePackageClasses = HelloMessageService.class) static class Config { - // @formatter:off + @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/WithUserDetailsTests.java b/test/src/test/java/org/springframework/security/test/context/showcase/WithUserDetailsTests.java index 55a8af3d38..cfd2a040c7 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/WithUserDetailsTests.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/WithUserDetailsTests.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.context.showcase; -import static org.assertj.core.api.Assertions.assertThat; +package org.springframework.security.test.context.showcase; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; @@ -35,25 +35,27 @@ import org.springframework.security.test.context.support.WithUserDetails; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Rob Winch */ - @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = WithUserDetailsTests.Config.class) public class WithUserDetailsTests { + @Autowired private MessageService messageService; @Test(expected = AuthenticationCredentialsNotFoundException.class) public void getMessageUnauthenticated() { - messageService.getMessage(); + this.messageService.getMessage(); } @Test @WithUserDetails public void getMessageWithUserDetails() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("user"); assertThat(getPrincipal()).isInstanceOf(CustomUserDetails.class); } @@ -61,45 +63,46 @@ public class WithUserDetailsTests { @Test @WithUserDetails("customUsername") public void getMessageWithUserDetailsCustomUsername() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("customUsername"); assertThat(getPrincipal()).isInstanceOf(CustomUserDetails.class); } @Test - @WithUserDetails(value="customUsername", userDetailsServiceBeanName="myUserDetailsService") + @WithUserDetails(value = "customUsername", userDetailsServiceBeanName = "myUserDetailsService") public void getMessageWithUserDetailsServiceBeanName() { - String message = messageService.getMessage(); + String message = this.messageService.getMessage(); assertThat(message).contains("customUsername"); assertThat(getPrincipal()).isInstanceOf(CustomUserDetails.class); } - @EnableGlobalMethodSecurity(prePostEnabled = true) - @ComponentScan(basePackageClasses = HelloMessageService.class) - static class Config { - // @formatter:off - @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { - auth - .userDetailsService(myUserDetailsService()); - } - // @formatter:on - - @Bean - public UserDetailsService myUserDetailsService() { - return new CustomUserDetailsService(); - } - } - private Object getPrincipal() { return SecurityContextHolder.getContext().getAuthentication().getPrincipal(); } + @EnableGlobalMethodSecurity(prePostEnabled = true) + @ComponentScan(basePackageClasses = HelloMessageService.class) + static class Config { + + @Autowired + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + auth.userDetailsService(myUserDetailsService()); + } + + @Bean + UserDetailsService myUserDetailsService() { + return new CustomUserDetailsService(); + } + + } + static class CustomUserDetailsService implements UserDetailsService { - public UserDetails loadUserByUsername(final String username) - throws UsernameNotFoundException { + @Override + public UserDetails loadUserByUsername(final String username) throws UsernameNotFoundException { return new CustomUserDetails("name", username); } + } + } diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/service/HelloMessageService.java b/test/src/test/java/org/springframework/security/test/context/showcase/service/HelloMessageService.java index 79c23e6a0a..95ad519c5e 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/service/HelloMessageService.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/service/HelloMessageService.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase.service; import org.springframework.security.access.prepost.PreAuthorize; @@ -26,10 +27,11 @@ import org.springframework.stereotype.Component; @Component public class HelloMessageService implements MessageService { + @Override @PreAuthorize("authenticated") public String getMessage() { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); return "Hello " + authentication; } + } diff --git a/test/src/test/java/org/springframework/security/test/context/showcase/service/MessageService.java b/test/src/test/java/org/springframework/security/test/context/showcase/service/MessageService.java index 4a78e2c546..f42757f9ee 100644 --- a/test/src/test/java/org/springframework/security/test/context/showcase/service/MessageService.java +++ b/test/src/test/java/org/springframework/security/test/context/showcase/service/MessageService.java @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.showcase.service; /** * @author Rob Winch */ public interface MessageService { + String getMessage(); + } diff --git a/test/src/test/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListenerTests.java index 9a47499de7..b9e88381ad 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListenerTests.java @@ -20,7 +20,6 @@ package org.springframework.security.test.context.support; * @author Rob Winch * @since 5.0 */ - import java.util.concurrent.ForkJoinPool; import org.junit.After; @@ -28,8 +27,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.security.core.context.SecurityContext; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -37,6 +34,8 @@ import reactor.test.StepVerifier; import org.springframework.core.OrderComparator; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.test.context.TestContext; @@ -49,8 +48,7 @@ public class ReactorContextTestExecutionListenerTests { @Mock private TestContext testContext; - private ReactorContextTestExecutionListener listener = - new ReactorContextTestExecutionListener(); + private ReactorContextTestExecutionListener listener = new ReactorContextTestExecutionListener(); @After public void cleanup() { @@ -61,49 +59,119 @@ public class ReactorContextTestExecutionListenerTests { @Test public void beforeTestMethodWhenSecurityContextEmptyThenReactorContextNull() throws Exception { this.listener.beforeTestMethod(this.testContext); - - Mono result = ReactiveSecurityContextHolder - .getContext(); - - StepVerifier.create(result) - .verifyComplete(); + Mono result = ReactiveSecurityContextHolder.getContext(); + StepVerifier.create(result).verifyComplete(); } @Test public void beforeTestMethodWhenNullAuthenticationThenReactorContextNull() throws Exception { TestSecurityContextHolder.setContext(new SecurityContextImpl()); - this.listener.beforeTestMethod(this.testContext); - - Mono result = ReactiveSecurityContextHolder - .getContext(); - - StepVerifier.create(result) - .verifyComplete(); + Mono result = ReactiveSecurityContextHolder.getContext(); + StepVerifier.create(result).verifyComplete(); } @Test public void beforeTestMethodWhenAuthenticationThenReactorContextHasAuthentication() throws Exception { - TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", + "ROLE_USER"); TestSecurityContextHolder.setAuthentication(expectedAuthentication); - this.listener.beforeTestMethod(this.testContext); - assertAuthentication(expectedAuthentication); } @Test public void beforeTestMethodWhenCustomContext() throws Exception { - TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", + "ROLE_USER"); SecurityContext context = new CustomContext(expectedAuthentication); TestSecurityContextHolder.setContext(context); - this.listener.beforeTestMethod(this.testContext); - assertSecurityContext(context); } + @Test + public void beforeTestMethodWhenExistingAuthenticationThenReactorContextHasOriginalAuthentication() + throws Exception { + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", + "ROLE_USER"); + TestingAuthenticationToken contextHolder = new TestingAuthenticationToken("contextHolder", "password", + "ROLE_USER"); + TestSecurityContextHolder.setAuthentication(contextHolder); + this.listener.beforeTestMethod(this.testContext); + Mono authentication = Mono.just("any") + .flatMap((s) -> ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication)) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expectedAuthentication)); + StepVerifier.create(authentication).expectNext(expectedAuthentication).verifyComplete(); + } + + @Test + public void beforeTestMethodWhenClearThenReactorContextDoesNotOverride() throws Exception { + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", + "ROLE_USER"); + TestingAuthenticationToken contextHolder = new TestingAuthenticationToken("contextHolder", "password", + "ROLE_USER"); + TestSecurityContextHolder.setAuthentication(contextHolder); + this.listener.beforeTestMethod(this.testContext); + Mono authentication = Mono.just("any") + .flatMap((s) -> ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication)) + .subscriberContext(ReactiveSecurityContextHolder.clearContext()); + StepVerifier.create(authentication).verifyComplete(); + } + + @Test + public void afterTestMethodWhenSecurityContextEmptyThenNoError() throws Exception { + this.listener.beforeTestMethod(this.testContext); + this.listener.afterTestMethod(this.testContext); + } + + @Test + public void afterTestMethodWhenSetupThenReactorContextNull() throws Exception { + beforeTestMethodWhenAuthenticationThenReactorContextHasAuthentication(); + this.listener.afterTestMethod(this.testContext); + assertThat(Mono.subscriberContext().block().isEmpty()).isTrue(); + } + + @Test + public void afterTestMethodWhenDifferentHookIsRegistered() throws Exception { + Object obj = new Object(); + Hooks.onLastOperator("CUSTOM_HOOK", (p) -> Mono.just(obj)); + this.listener.afterTestMethod(this.testContext); + Object result = Mono.subscriberContext().block(); + assertThat(result).isEqualTo(obj); + } + + @Test + public void orderWhenComparedToWithSecurityContextTestExecutionListenerIsAfter() { + OrderComparator comparator = new OrderComparator(); + WithSecurityContextTestExecutionListener withSecurity = new WithSecurityContextTestExecutionListener(); + ReactorContextTestExecutionListener reactorContext = new ReactorContextTestExecutionListener(); + assertThat(comparator.compare(withSecurity, reactorContext)).isLessThan(0); + } + + @Test + public void checkSecurityContextResolutionWhenSubscribedContextCalledOnTheDifferentThreadThanWithSecurityContextTestExecutionListener() + throws Exception { + TestingAuthenticationToken contextHolder = new TestingAuthenticationToken("contextHolder", "password", + "ROLE_USER"); + TestSecurityContextHolder.setAuthentication(contextHolder); + this.listener.beforeTestMethod(this.testContext); + ForkJoinPool.commonPool().submit(() -> assertAuthentication(contextHolder)).join(); + } + + public void assertAuthentication(Authentication expected) { + Mono authentication = ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication); + StepVerifier.create(authentication).expectNext(expected).verifyComplete(); + } + + private void assertSecurityContext(SecurityContext expected) { + Mono securityContext = ReactiveSecurityContextHolder.getContext(); + StepVerifier.create(securityContext).expectNext(expected).verifyComplete(); + } + static class CustomContext implements SecurityContext { + private Authentication authentication; CustomContext(Authentication authentication) { @@ -119,107 +187,7 @@ public class ReactorContextTestExecutionListenerTests { public void setAuthentication(Authentication authentication) { this.authentication = authentication; } + } - @Test - public void beforeTestMethodWhenExistingAuthenticationThenReactorContextHasOriginalAuthentication() throws Exception { - TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); - TestingAuthenticationToken contextHolder = new TestingAuthenticationToken("contextHolder", "password", "ROLE_USER"); - TestSecurityContextHolder.setAuthentication(contextHolder); - - this.listener.beforeTestMethod(this.testContext); - - Mono authentication = Mono.just("any") - .flatMap(s -> ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - ) - .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(expectedAuthentication)); - - StepVerifier.create(authentication) - .expectNext(expectedAuthentication) - .verifyComplete(); - } - - @Test - public void beforeTestMethodWhenClearThenReactorContextDoesNotOverride() throws Exception { - TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); - TestingAuthenticationToken contextHolder = new TestingAuthenticationToken("contextHolder", "password", "ROLE_USER"); - TestSecurityContextHolder.setAuthentication(contextHolder); - - this.listener.beforeTestMethod(this.testContext); - - Mono authentication = Mono.just("any") - .flatMap(s -> ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - ) - .subscriberContext(ReactiveSecurityContextHolder.clearContext()); - - StepVerifier.create(authentication) - .verifyComplete(); - } - - @Test - public void afterTestMethodWhenSecurityContextEmptyThenNoError() throws Exception { - this.listener.beforeTestMethod(this.testContext); - - this.listener.afterTestMethod(this.testContext); - } - - @Test - public void afterTestMethodWhenSetupThenReactorContextNull() throws Exception { - beforeTestMethodWhenAuthenticationThenReactorContextHasAuthentication(); - - this.listener.afterTestMethod(this.testContext); - - assertThat(Mono.subscriberContext().block().isEmpty()).isTrue(); - } - - @Test - public void afterTestMethodWhenDifferentHookIsRegistered() throws Exception { - Object obj = new Object(); - - Hooks.onLastOperator("CUSTOM_HOOK", p -> Mono.just(obj)); - this.listener.afterTestMethod(this.testContext); - - Object result = Mono.subscriberContext().block(); - assertThat(result).isEqualTo(obj); - } - - @Test - public void orderWhenComparedToWithSecurityContextTestExecutionListenerIsAfter() { - OrderComparator comparator = new OrderComparator(); - WithSecurityContextTestExecutionListener withSecurity = new WithSecurityContextTestExecutionListener(); - ReactorContextTestExecutionListener reactorContext = new ReactorContextTestExecutionListener(); - assertThat(comparator.compare(withSecurity, reactorContext)).isLessThan(0); - } - - @Test - public void checkSecurityContextResolutionWhenSubscribedContextCalledOnTheDifferentThreadThanWithSecurityContextTestExecutionListener() throws Exception { - TestingAuthenticationToken contextHolder = new TestingAuthenticationToken("contextHolder", "password", "ROLE_USER"); - TestSecurityContextHolder.setAuthentication(contextHolder); - - this.listener.beforeTestMethod(this.testContext); - - ForkJoinPool.commonPool() - .submit(() -> assertAuthentication(contextHolder)) - .join(); - } - - public void assertAuthentication(Authentication expected) { - Mono authentication = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication); - - StepVerifier.create(authentication) - .expectNext(expected) - .verifyComplete(); - } - - - private void assertSecurityContext(SecurityContext expected) { - Mono securityContext = ReactiveSecurityContextHolder.getContext(); - - StepVerifier.create(securityContext) - .expectNext(expected) - .verifyComplete(); - } } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithAnonymousUserTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithAnonymousUserTests.java index 838becec27..94780c93db 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithAnonymousUserTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithAnonymousUserTests.java @@ -17,6 +17,7 @@ package org.springframework.security.test.context.support; import org.junit.Test; + import org.springframework.core.annotation.AnnotatedElementUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -26,40 +27,41 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 5.0 */ public class WithAnonymousUserTests { + @Test public void defaults() { WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(Annotated.class, - WithSecurityContext.class); - + WithSecurityContext.class); assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_METHOD); } - @WithAnonymousUser - private class Annotated { - } - @Test public void findMergedAnnotationWhenSetupExplicitThenOverridden() { - WithSecurityContext context = AnnotatedElementUtils - .findMergedAnnotation(SetupExplicit.class, + WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(SetupExplicit.class, WithSecurityContext.class); - assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_METHOD); } - @WithAnonymousUser(setupBefore = TestExecutionEvent.TEST_METHOD) - private class SetupExplicit { - } - @Test public void findMergedAnnotationWhenSetupOverriddenThenOverridden() { WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(SetupOverridden.class, - WithSecurityContext.class); - + WithSecurityContext.class); assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_EXECUTION); } + @WithAnonymousUser + private class Annotated { + + } + + @WithAnonymousUser(setupBefore = TestExecutionEvent.TEST_METHOD) + private class SetupExplicit { + + } + @WithAnonymousUser(setupBefore = TestExecutionEvent.TEST_EXECUTION) private class SetupOverridden { + } + } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactoryTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactoryTests.java index a129baeece..56b90ca9fe 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactoryTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactoryTests.java @@ -13,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.context.support; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; +package org.springframework.security.test.context.support; import org.junit.Before; import org.junit.Test; @@ -24,6 +22,9 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + @RunWith(MockitoJUnitRunner.class) public class WithMockUserSecurityContextFactoryTests { @@ -34,77 +35,68 @@ public class WithMockUserSecurityContextFactoryTests { @Before public void setup() { - factory = new WithMockUserSecurityContextFactory(); + this.factory = new WithMockUserSecurityContextFactory(); } @Test(expected = IllegalArgumentException.class) public void usernameNull() { - factory.createSecurityContext(withUser); + this.factory.createSecurityContext(this.withUser); } @Test public void valueDefaultsUsername() { - when(withUser.value()).thenReturn("valueUser"); - when(withUser.password()).thenReturn("password"); - when(withUser.roles()).thenReturn(new String[] { "USER" }); - when(withUser.authorities()).thenReturn(new String[] {}); - - assertThat(factory.createSecurityContext(withUser).getAuthentication().getName()) - .isEqualTo(withUser.value()); + given(this.withUser.value()).willReturn("valueUser"); + given(this.withUser.password()).willReturn("password"); + given(this.withUser.roles()).willReturn(new String[] { "USER" }); + given(this.withUser.authorities()).willReturn(new String[] {}); + assertThat(this.factory.createSecurityContext(this.withUser).getAuthentication().getName()) + .isEqualTo(this.withUser.value()); } @Test public void usernamePrioritizedOverValue() { - when(withUser.username()).thenReturn("customUser"); - when(withUser.password()).thenReturn("password"); - when(withUser.roles()).thenReturn(new String[] { "USER" }); - when(withUser.authorities()).thenReturn(new String[] {}); - - assertThat(factory.createSecurityContext(withUser).getAuthentication().getName()) - .isEqualTo(withUser.username()); + given(this.withUser.username()).willReturn("customUser"); + given(this.withUser.password()).willReturn("password"); + given(this.withUser.roles()).willReturn(new String[] { "USER" }); + given(this.withUser.authorities()).willReturn(new String[] {}); + assertThat(this.factory.createSecurityContext(this.withUser).getAuthentication().getName()) + .isEqualTo(this.withUser.username()); } @Test public void rolesWorks() { - when(withUser.value()).thenReturn("valueUser"); - when(withUser.password()).thenReturn("password"); - when(withUser.roles()).thenReturn(new String[] { "USER", "CUSTOM" }); - when(withUser.authorities()).thenReturn(new String[] {}); - - assertThat( - factory.createSecurityContext(withUser).getAuthentication() - .getAuthorities()).extracting("authority").containsOnly( - "ROLE_USER", "ROLE_CUSTOM"); + given(this.withUser.value()).willReturn("valueUser"); + given(this.withUser.password()).willReturn("password"); + given(this.withUser.roles()).willReturn(new String[] { "USER", "CUSTOM" }); + given(this.withUser.authorities()).willReturn(new String[] {}); + assertThat(this.factory.createSecurityContext(this.withUser).getAuthentication().getAuthorities()) + .extracting("authority").containsOnly("ROLE_USER", "ROLE_CUSTOM"); } @Test public void authoritiesWorks() { - when(withUser.value()).thenReturn("valueUser"); - when(withUser.password()).thenReturn("password"); - when(withUser.roles()).thenReturn(new String[] { "USER" }); - when(withUser.authorities()).thenReturn(new String[] { "USER", "CUSTOM" }); - - assertThat( - factory.createSecurityContext(withUser).getAuthentication() - .getAuthorities()).extracting("authority").containsOnly( - "USER", "CUSTOM"); + given(this.withUser.value()).willReturn("valueUser"); + given(this.withUser.password()).willReturn("password"); + given(this.withUser.roles()).willReturn(new String[] { "USER" }); + given(this.withUser.authorities()).willReturn(new String[] { "USER", "CUSTOM" }); + assertThat(this.factory.createSecurityContext(this.withUser).getAuthentication().getAuthorities()) + .extracting("authority").containsOnly("USER", "CUSTOM"); } @Test(expected = IllegalStateException.class) public void authoritiesAndRolesInvalid() { - when(withUser.value()).thenReturn("valueUser"); - when(withUser.roles()).thenReturn(new String[] { "CUSTOM" }); - when(withUser.authorities()).thenReturn(new String[] { "USER", "CUSTOM" }); - - factory.createSecurityContext(withUser); + given(this.withUser.value()).willReturn("valueUser"); + given(this.withUser.roles()).willReturn(new String[] { "CUSTOM" }); + given(this.withUser.authorities()).willReturn(new String[] { "USER", "CUSTOM" }); + this.factory.createSecurityContext(this.withUser); } @Test(expected = IllegalArgumentException.class) public void rolesWithRolePrefixFails() { - when(withUser.value()).thenReturn("valueUser"); - when(withUser.roles()).thenReturn(new String[] { "ROLE_FAIL" }); - when(withUser.authorities()).thenReturn(new String[] {}); - - factory.createSecurityContext(withUser); + given(this.withUser.value()).willReturn("valueUser"); + given(this.withUser.roles()).willReturn(new String[] { "ROLE_FAIL" }); + given(this.withUser.authorities()).willReturn(new String[] {}); + this.factory.createSecurityContext(this.withUser); } + } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithMockUserTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithMockUserTests.java index ed040a3caf..763cfdf9b2 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithMockUserTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithMockUserTests.java @@ -16,55 +16,54 @@ package org.springframework.security.test.context.support; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.core.annotation.AnnotatedElementUtils; +import static org.assertj.core.api.Assertions.assertThat; + public class WithMockUserTests { @Test public void defaults() { - WithMockUser mockUser = AnnotatedElementUtils.findMergedAnnotation(Annotated.class, - WithMockUser.class); + WithMockUser mockUser = AnnotatedElementUtils.findMergedAnnotation(Annotated.class, WithMockUser.class); assertThat(mockUser.value()).isEqualTo("user"); assertThat(mockUser.username()).isEmpty(); assertThat(mockUser.password()).isEqualTo("password"); assertThat(mockUser.roles()).containsOnly("USER"); assertThat(mockUser.setupBefore()).isEqualByComparingTo(TestExecutionEvent.TEST_METHOD); - WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(Annotated.class, - WithSecurityContext.class); - + WithSecurityContext.class); assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_METHOD); } - @WithMockUser - private class Annotated { - } - @Test public void findMergedAnnotationWhenSetupExplicitThenOverridden() { - WithSecurityContext context = AnnotatedElementUtils - .findMergedAnnotation(SetupExplicit.class, + WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(SetupExplicit.class, WithSecurityContext.class); - assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_METHOD); } - @WithMockUser(setupBefore = TestExecutionEvent.TEST_METHOD) - private class SetupExplicit { - } - @Test public void findMergedAnnotationWhenSetupOverriddenThenOverridden() { WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(SetupOverridden.class, - WithSecurityContext.class); - + WithSecurityContext.class); assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_EXECUTION); } + @WithMockUser + private class Annotated { + + } + + @WithMockUser(setupBefore = TestExecutionEvent.TEST_METHOD) + private class SetupExplicit { + + } + @WithMockUser(setupBefore = TestExecutionEvent.TEST_EXECUTION) private class SetupOverridden { + } + } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java index c44e3e2a36..e4849015f3 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; import java.lang.annotation.Annotation; @@ -45,11 +46,12 @@ import org.springframework.test.context.support.AbstractTestExecutionListener; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class WithSecurityContextTestExcecutionListenerTests { + private ConfigurableApplicationContext context; @Mock @@ -59,15 +61,15 @@ public class WithSecurityContextTestExcecutionListenerTests { @Before public void setup() { - listener = new WithSecurityContextTestExecutionListener(); - context = new AnnotationConfigApplicationContext(Config.class); + this.listener = new WithSecurityContextTestExecutionListener(); + this.context = new AnnotationConfigApplicationContext(Config.class); } @After public void cleanup() { TestSecurityContextHolder.clearContext(); - if (context != null) { - context.close(); + if (this.context != null) { + this.context.close(); } } @@ -75,36 +77,28 @@ public class WithSecurityContextTestExcecutionListenerTests { @SuppressWarnings({ "rawtypes", "unchecked" }) public void beforeTestMethodNullSecurityContextNoError() throws Exception { Class testClass = FakeTest.class; - when(testContext.getTestClass()).thenReturn(testClass); - when(testContext.getTestMethod()).thenReturn( - ReflectionUtils.findMethod(testClass, "testNoAnnotation")); - - listener.beforeTestMethod(testContext); + given(this.testContext.getTestClass()).willReturn(testClass); + given(this.testContext.getTestMethod()).willReturn(ReflectionUtils.findMethod(testClass, "testNoAnnotation")); + this.listener.beforeTestMethod(this.testContext); } @Test @SuppressWarnings({ "rawtypes", "unchecked" }) public void beforeTestMethodNoApplicationContext() throws Exception { Class testClass = FakeTest.class; - when(testContext.getApplicationContext()).thenThrow(new IllegalStateException()); - when(testContext.getTestMethod()).thenReturn( - ReflectionUtils.findMethod(testClass, "testWithMockUser")); - - listener.beforeTestMethod(testContext); - - assertThat(TestSecurityContextHolder.getContext().getAuthentication().getName()) - .isEqualTo("user"); + given(this.testContext.getApplicationContext()).willThrow(new IllegalStateException()); + given(this.testContext.getTestMethod()).willReturn(ReflectionUtils.findMethod(testClass, "testWithMockUser")); + this.listener.beforeTestMethod(this.testContext); + assertThat(TestSecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("user"); } + // gh-3962 @Test public void withSecurityContextAfterSqlScripts() { SqlScriptsTestExecutionListener sql = new SqlScriptsTestExecutionListener(); WithSecurityContextTestExecutionListener security = new WithSecurityContextTestExecutionListener(); - List listeners = Arrays.asList(security, sql); - AnnotationAwareOrderComparator.sort(listeners); - assertThat(listeners).containsExactly(sql, security); } @@ -113,29 +107,22 @@ public class WithSecurityContextTestExcecutionListenerTests { public void orderOverridden() { AbstractTestExecutionListener otherListener = new AbstractTestExecutionListener() { }; - List listeners = new ArrayList<>(); listeners.add(otherListener); listeners.add(this.listener); - AnnotationAwareOrderComparator.sort(listeners); - assertThat(listeners).containsSequence(this.listener, otherListener); } @Test // gh-3837 public void handlesGenericAnnotation() throws Exception { - Method method = ReflectionUtils.findMethod( - WithSecurityContextTestExcecutionListenerTests.class, + Method method = ReflectionUtils.findMethod(WithSecurityContextTestExcecutionListenerTests.class, "handlesGenericAnnotationTestMethod"); TestContext testContext = mock(TestContext.class); - when(testContext.getTestMethod()).thenReturn(method); - when(testContext.getApplicationContext()) - .thenThrow(new IllegalStateException("")); - + given(testContext.getTestMethod()).willReturn(method); + given(testContext.getApplicationContext()).willThrow(new IllegalStateException("")); this.listener.beforeTestMethod(testContext); - assertThat(SecurityContextHolder.getContext().getAuthentication().getPrincipal()) .isInstanceOf(WithSuperClassWithSecurityContext.class); } @@ -147,11 +134,12 @@ public class WithSecurityContextTestExcecutionListenerTests { @Retention(RetentionPolicy.RUNTIME) @WithSecurityContext(factory = SuperClassWithSecurityContextFactory.class) @interface WithSuperClassWithSecurityContext { + String username() default "WithSuperClassWithSecurityContext"; + } - static class SuperClassWithSecurityContextFactory - implements WithSecurityContextFactory { + static class SuperClassWithSecurityContextFactory implements WithSecurityContextFactory { @Override public SecurityContext createSecurityContext(Annotation annotation) { @@ -159,19 +147,23 @@ public class WithSecurityContextTestExcecutionListenerTests { context.setAuthentication(new TestingAuthenticationToken(annotation, "NA")); return context; } + } static class FakeTest { - public void testNoAnnotation() { + + void testNoAnnotation() { } @WithMockUser - public void testWithMockUser() { - + void testWithMockUser() { } + } @Configuration static class Config { + } + } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java index bb3c8cfd0c..64186a5b96 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java @@ -16,6 +16,9 @@ package org.springframework.security.test.context.support; +import java.lang.reflect.Method; +import java.util.function.Supplier; + import org.junit.After; import org.junit.ClassRule; import org.junit.Rule; @@ -25,6 +28,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Configuration; @@ -37,15 +41,12 @@ import org.springframework.test.context.TestContext; import org.springframework.test.context.junit4.rules.SpringClassRule; import org.springframework.test.context.junit4.rules.SpringMethodRule; -import java.lang.reflect.Method; -import java.util.function.Supplier; - -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -54,8 +55,10 @@ import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @ContextConfiguration(classes = WithSecurityContextTestExecutionListenerTests.NoOpConfiguration.class) public class WithSecurityContextTestExecutionListenerTests { + @ClassRule public static final SpringClassRule spring = new SpringClassRule(); + @Rule public final SpringMethodRule springMethod = new SpringMethodRule(); @@ -75,49 +78,43 @@ public class WithSecurityContextTestExecutionListenerTests { @Test public void beforeTestMethodWhenWithMockUserTestExecutionDefaultThenSecurityContextSet() throws Exception { Method testMethod = TheTest.class.getMethod("withMockUserDefault"); - when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext); - when(this.testContext.getTestMethod()).thenReturn(testMethod); - + given(this.testContext.getApplicationContext()).willReturn(this.applicationContext); + given(this.testContext.getTestMethod()).willReturn(testMethod); this.listener.beforeTestMethod(this.testContext); - assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNotNull(); - verify(this.testContext, never()).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class)); + verify(this.testContext, never()).setAttribute( + eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class)); } @Test public void beforeTestMethodWhenWithMockUserTestMethodThenSecurityContextSet() throws Exception { Method testMethod = TheTest.class.getMethod("withMockUserTestMethod"); - when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext); - when(this.testContext.getTestMethod()).thenReturn(testMethod); - + given(this.testContext.getApplicationContext()).willReturn(this.applicationContext); + given(this.testContext.getTestMethod()).willReturn(testMethod); this.listener.beforeTestMethod(this.testContext); - assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNotNull(); - verify(this.testContext, never()).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class)); + verify(this.testContext, never()).setAttribute( + eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class)); } @Test public void beforeTestMethodWhenWithMockUserTestExecutionThenTestContextSet() throws Exception { Method testMethod = TheTest.class.getMethod("withMockUserTestExecution"); - when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext); - when(this.testContext.getTestMethod()).thenReturn(testMethod); - + given(this.testContext.getApplicationContext()).willReturn(this.applicationContext); + given(this.testContext.getTestMethod()).willReturn(testMethod); this.listener.beforeTestMethod(this.testContext); - assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNull(); - verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME) - , ArgumentMatchers.>any()); + verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), + ArgumentMatchers.>any()); } @Test @SuppressWarnings("unchecked") public void beforeTestMethodWhenWithMockUserTestExecutionThenTestContextSupplierOk() throws Exception { Method testMethod = TheTest.class.getMethod("withMockUserTestExecution"); - when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext); - when(this.testContext.getTestMethod()).thenReturn(testMethod); - + given(this.testContext.getApplicationContext()).willReturn(this.applicationContext); + given(this.testContext.getTestMethod()).willReturn(testMethod); this.listener.beforeTestMethod(this.testContext); - ArgumentCaptor> supplierCaptor = ArgumentCaptor.forClass(Supplier.class); verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), supplierCaptor.capture()); @@ -128,10 +125,9 @@ public class WithSecurityContextTestExecutionListenerTests { // gh-6591 public void beforeTestMethodWhenTestExecutionThenDelayFactoryCreate() throws Exception { Method testMethod = TheTest.class.getMethod("withUserDetails"); - when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext); + given(this.testContext.getApplicationContext()).willReturn(this.applicationContext); // do not set a UserDetailsService Bean so it would fail if looked up - when(this.testContext.getTestMethod()).thenReturn(testMethod); - + given(this.testContext.getTestMethod()).willReturn(testMethod); this.listener.beforeTestMethod(this.testContext); // bean lookup of UserDetailsService would fail if it has already been looked up } @@ -139,7 +135,6 @@ public class WithSecurityContextTestExecutionListenerTests { @Test public void beforeTestExecutionWhenTestContextNullThenSecurityContextNotSet() { this.listener.beforeTestExecution(this.testContext); - assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNull(); } @@ -148,17 +143,20 @@ public class WithSecurityContextTestExecutionListenerTests { SecurityContextImpl securityContext = new SecurityContextImpl(); securityContext.setAuthentication(new TestingAuthenticationToken("user", "passsword", "ROLE_USER")); Supplier supplier = () -> securityContext; - when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(supplier); - + given(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)) + .willReturn(supplier); this.listener.beforeTestExecution(this.testContext); - - assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isEqualTo(securityContext.getAuthentication()); + assertThat(TestSecurityContextHolder.getContext().getAuthentication()) + .isEqualTo(securityContext.getAuthentication()); } @Configuration - static class NoOpConfiguration {} + static class NoOpConfiguration { + + } static class TheTest { + @WithMockUser(setupBefore = TestExecutionEvent.TEST_EXECUTION) public void withMockUserTestExecution() { } @@ -174,6 +172,7 @@ public class WithSecurityContextTestExecutionListenerTests { @WithUserDetails(setupBefore = TestExecutionEvent.TEST_EXECUTION) public void withUserDetails() { } + } } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactoryTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactoryTests.java index 50b6e8aff0..5a2d710699 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactoryTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactoryTests.java @@ -13,17 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.context.support; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +package org.springframework.security.test.context.support; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanNotOfRequiredTypeException; import org.springframework.beans.factory.NoSuchBeanDefinitionException; @@ -33,17 +32,22 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; -import reactor.core.publisher.Mono; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.verify; @RunWith(MockitoJUnitRunner.class) public class WithUserDetailsSecurityContextFactoryTests { @Mock private ReactiveUserDetailsService reactiveUserDetailsService; + @Mock private UserDetailsService userDetailsService; + @Mock private UserDetails userDetails; + @Mock private BeanFactory beans; @@ -54,34 +58,31 @@ public class WithUserDetailsSecurityContextFactoryTests { @Before public void setup() { - factory = new WithUserDetailsSecurityContextFactory(beans); + this.factory = new WithUserDetailsSecurityContextFactory(this.beans); } @Test(expected = IllegalArgumentException.class) public void createSecurityContextNullValue() { - factory.createSecurityContext(withUserDetails); + this.factory.createSecurityContext(this.withUserDetails); } @Test(expected = IllegalArgumentException.class) public void createSecurityContextEmptyValue() { - - when(withUserDetails.value()).thenReturn(""); - factory.createSecurityContext(withUserDetails); + given(this.withUserDetails.value()).willReturn(""); + this.factory.createSecurityContext(this.withUserDetails); } @Test public void createSecurityContextWithExistingUser() { String username = "user"; - when(this.beans.getBean(ReactiveUserDetailsService.class)).thenThrow(new NoSuchBeanDefinitionException("")); - when(beans.getBean(UserDetailsService.class)).thenReturn(userDetailsService); - when(withUserDetails.value()).thenReturn(username); - when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); - - SecurityContext context = factory.createSecurityContext(withUserDetails); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); - assertThat(context.getAuthentication().getPrincipal()).isEqualTo(userDetails); - verify(beans).getBean(UserDetailsService.class); + given(this.beans.getBean(ReactiveUserDetailsService.class)).willThrow(new NoSuchBeanDefinitionException("")); + given(this.beans.getBean(UserDetailsService.class)).willReturn(this.userDetailsService); + given(this.withUserDetails.value()).willReturn(username); + given(this.userDetailsService.loadUserByUsername(username)).willReturn(this.userDetails); + SecurityContext context = this.factory.createSecurityContext(this.withUserDetails); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); + assertThat(context.getAuthentication().getPrincipal()).isEqualTo(this.userDetails); + verify(this.beans).getBean(UserDetailsService.class); } // gh-3346 @@ -89,30 +90,27 @@ public class WithUserDetailsSecurityContextFactoryTests { public void createSecurityContextWithUserDetailsServiceName() { String beanName = "secondUserDetailsServiceBean"; String username = "user"; - when(this.beans.getBean(beanName, ReactiveUserDetailsService.class)).thenThrow(new BeanNotOfRequiredTypeException("", ReactiveUserDetailsService.class, UserDetailsService.class)); - when(withUserDetails.value()).thenReturn(username); - when(withUserDetails.userDetailsServiceBeanName()).thenReturn(beanName); - when(userDetailsService.loadUserByUsername(username)).thenReturn(userDetails); - when(beans.getBean(beanName, UserDetailsService.class)).thenReturn(userDetailsService); - - SecurityContext context = factory.createSecurityContext(withUserDetails); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); - assertThat(context.getAuthentication().getPrincipal()).isEqualTo(userDetails); - verify(beans).getBean(beanName, UserDetailsService.class); + given(this.beans.getBean(beanName, ReactiveUserDetailsService.class)).willThrow( + new BeanNotOfRequiredTypeException("", ReactiveUserDetailsService.class, UserDetailsService.class)); + given(this.withUserDetails.value()).willReturn(username); + given(this.withUserDetails.userDetailsServiceBeanName()).willReturn(beanName); + given(this.userDetailsService.loadUserByUsername(username)).willReturn(this.userDetails); + given(this.beans.getBean(beanName, UserDetailsService.class)).willReturn(this.userDetailsService); + SecurityContext context = this.factory.createSecurityContext(this.withUserDetails); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); + assertThat(context.getAuthentication().getPrincipal()).isEqualTo(this.userDetails); + verify(this.beans).getBean(beanName, UserDetailsService.class); } @Test public void createSecurityContextWithReactiveUserDetailsService() { String username = "user"; - when(withUserDetails.value()).thenReturn(username); - when(this.beans.getBean(ReactiveUserDetailsService.class)).thenReturn(this.reactiveUserDetailsService); - when(this.reactiveUserDetailsService.findByUsername(username)).thenReturn(Mono.just(userDetails)); - - SecurityContext context = factory.createSecurityContext(withUserDetails); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); - assertThat(context.getAuthentication().getPrincipal()).isEqualTo(userDetails); + given(this.withUserDetails.value()).willReturn(username); + given(this.beans.getBean(ReactiveUserDetailsService.class)).willReturn(this.reactiveUserDetailsService); + given(this.reactiveUserDetailsService.findByUsername(username)).willReturn(Mono.just(this.userDetails)); + SecurityContext context = this.factory.createSecurityContext(this.withUserDetails); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); + assertThat(context.getAuthentication().getPrincipal()).isEqualTo(this.userDetails); verify(this.beans).getBean(ReactiveUserDetailsService.class); } @@ -120,15 +118,15 @@ public class WithUserDetailsSecurityContextFactoryTests { public void createSecurityContextWithReactiveUserDetailsServiceAndBeanName() { String beanName = "secondUserDetailsServiceBean"; String username = "user"; - when(withUserDetails.value()).thenReturn(username); - when(withUserDetails.userDetailsServiceBeanName()).thenReturn(beanName); - when(this.beans.getBean(beanName, ReactiveUserDetailsService.class)).thenReturn(this.reactiveUserDetailsService); - when(this.reactiveUserDetailsService.findByUsername(username)).thenReturn(Mono.just(userDetails)); - - SecurityContext context = factory.createSecurityContext(withUserDetails); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); - assertThat(context.getAuthentication().getPrincipal()).isEqualTo(userDetails); + given(this.withUserDetails.value()).willReturn(username); + given(this.withUserDetails.userDetailsServiceBeanName()).willReturn(beanName); + given(this.beans.getBean(beanName, ReactiveUserDetailsService.class)) + .willReturn(this.reactiveUserDetailsService); + given(this.reactiveUserDetailsService.findByUsername(username)).willReturn(Mono.just(this.userDetails)); + SecurityContext context = this.factory.createSecurityContext(this.withUserDetails); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); + assertThat(context.getAuthentication().getPrincipal()).isEqualTo(this.userDetails); verify(this.beans).getBean(beanName, ReactiveUserDetailsService.class); } + } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsTests.java index 3c4f12749d..b2d041d980 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithUserDetailsTests.java @@ -13,56 +13,54 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.context.support; -import static org.assertj.core.api.Assertions.assertThat; - import org.junit.Test; + import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotationUtils; +import static org.assertj.core.api.Assertions.assertThat; + public class WithUserDetailsTests { @Test public void defaults() { - WithUserDetails userDetails = AnnotationUtils.findAnnotation(Annotated.class, - WithUserDetails.class); + WithUserDetails userDetails = AnnotationUtils.findAnnotation(Annotated.class, WithUserDetails.class); assertThat(userDetails.value()).isEqualTo("user"); - - WithSecurityContext context = AnnotatedElementUtils - .findMergedAnnotation(Annotated.class, + WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(Annotated.class, WithSecurityContext.class); - assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_METHOD); } - @WithUserDetails - private static class Annotated { - } - @Test public void findMergedAnnotationWhenSetupExplicitThenOverridden() { - WithSecurityContext context = AnnotatedElementUtils - .findMergedAnnotation(SetupExplicit.class, + WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(SetupExplicit.class, WithSecurityContext.class); - assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_METHOD); } - @WithUserDetails(setupBefore = TestExecutionEvent.TEST_METHOD) - private class SetupExplicit { - } - @Test public void findMergedAnnotationWhenSetupOverriddenThenOverridden() { - WithSecurityContext context = AnnotatedElementUtils - .findMergedAnnotation(SetupOverridden.class, + WithSecurityContext context = AnnotatedElementUtils.findMergedAnnotation(SetupOverridden.class, WithSecurityContext.class); - assertThat(context.setupBefore()).isEqualTo(TestExecutionEvent.TEST_EXECUTION); } + @WithUserDetails + private static class Annotated { + + } + + @WithUserDetails(setupBefore = TestExecutionEvent.TEST_METHOD) + private class SetupExplicit { + + } + @WithUserDetails(setupBefore = TestExecutionEvent.TEST_EXECUTION) private class SetupOverridden { + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/AbstractMockServerConfigurersTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/AbstractMockServerConfigurersTests.java index 6af2661f64..5389b138f2 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/AbstractMockServerConfigurersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/AbstractMockServerConfigurersTests.java @@ -34,21 +34,18 @@ import static org.assertj.core.api.Assertions.assertThat; * @since 5.0 */ abstract class AbstractMockServerConfigurersTests { + protected PrincipalController controller = new PrincipalController(); + protected SecurityContextController securityContextController = new SecurityContextController(); - protected User.UserBuilder userBuilder = User - .withUsername("user") - .password("password") - .roles("USER"); + protected User.UserBuilder userBuilder = User.withUsername("user").password("password").roles("USER"); protected void assertPrincipalCreatedFromUserDetails(Principal principal, UserDetails originalUserDetails) { assertThat(principal).isInstanceOf(UsernamePasswordAuthenticationToken.class); - UsernamePasswordAuthenticationToken authentication = (UsernamePasswordAuthenticationToken) principal; assertThat(authentication.getCredentials()).isEqualTo(originalUserDetails.getPassword()); assertThat(authentication.getAuthorities()).containsOnlyElementsOf(originalUserDetails.getAuthorities()); - UserDetails userDetails = (UserDetails) authentication.getPrincipal(); assertThat(userDetails.getPassword()).isEqualTo(authentication.getCredentials()); assertThat(authentication.getAuthorities()).containsOnlyElementsOf(userDetails.getAuthorities()); @@ -56,6 +53,7 @@ abstract class AbstractMockServerConfigurersTests { @RestController protected static class PrincipalController { + volatile Principal principal; @RequestMapping("/**") @@ -74,10 +72,12 @@ abstract class AbstractMockServerConfigurersTests { assertThat(this.principal).isEqualTo(expected); this.principal = null; } + } @RestController protected static class SecurityContextController { + volatile SecurityContext securityContext; @RequestMapping("/**") @@ -91,5 +91,7 @@ abstract class AbstractMockServerConfigurersTests { this.securityContext = null; return result; } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurerOpaqueTokenTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurerOpaqueTokenTests.java index c44da6255a..4ddd34771d 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurerOpaqueTokenTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurerOpaqueTokenTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.reactive.server; import java.util.List; @@ -28,16 +29,14 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals; import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication; +import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames; import org.springframework.security.web.reactive.result.method.annotation.CurrentSecurityContextArgumentResolver; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.test.web.reactive.server.WebTestClient; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals.active; -import static org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionClaimNames.SUBJECT; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOpaqueToken; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; /** * @author Josh Cummings @@ -45,78 +44,59 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock */ @RunWith(MockitoJUnitRunner.class) public class SecurityMockServerConfigurerOpaqueTokenTests extends AbstractMockServerConfigurersTests { + private GrantedAuthority authority1 = new SimpleGrantedAuthority("one"); private GrantedAuthority authority2 = new SimpleGrantedAuthority("two"); - private WebTestClient client = WebTestClient - .bindToController(securityContextController) + private WebTestClient client = WebTestClient.bindToController(this.securityContextController) .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .argumentResolvers(resolvers -> resolvers.addCustomResolver( - new CurrentSecurityContextArgumentResolver(new ReactiveAdapterRegistry()))) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + .argumentResolvers((resolvers) -> resolvers + .addCustomResolver(new CurrentSecurityContextArgumentResolver(new ReactiveAdapterRegistry()))) + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); @Test public void mockOpaqueTokenWhenUsingDefaultsThenBearerTokenAuthentication() { - this.client - .mutateWith(mockOpaqueToken()) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat(context.getAuthentication()).isInstanceOf( - BearerTokenAuthentication.class); + this.client.mutateWith(SecurityMockServerConfigurers.mockOpaqueToken()).get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat(context.getAuthentication()).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication token = (BearerTokenAuthentication) context.getAuthentication(); assertThat(token.getAuthorities()).isNotEmpty(); assertThat(token.getToken()).isNotNull(); - assertThat(token.getTokenAttributes().get(SUBJECT)).isEqualTo("user"); + assertThat(token.getTokenAttributes().get(OAuth2IntrospectionClaimNames.SUBJECT)).isEqualTo("user"); } @Test public void mockOpaqueTokenWhenAuthoritiesThenBearerTokenAuthentication() { this.client - .mutateWith(mockOpaqueToken() - .authorities(this.authority1, this.authority2)) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(this.authority1, this.authority2); + .mutateWith( + SecurityMockServerConfigurers.mockOpaqueToken().authorities(this.authority1, this.authority2)) + .get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1, + this.authority2); } @Test public void mockOpaqueTokenWhenAttributesThenBearerTokenAuthentication() { String sub = new String("my-subject"); this.client - .mutateWith(mockOpaqueToken() - .attributes(attributes -> attributes.put(SUBJECT, sub))) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); + .mutateWith(SecurityMockServerConfigurers.mockOpaqueToken() + .attributes((attributes) -> attributes.put(OAuth2IntrospectionClaimNames.SUBJECT, sub))) + .get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); assertThat(context.getAuthentication()).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication token = (BearerTokenAuthentication) context.getAuthentication(); - assertThat(token.getTokenAttributes().get(SUBJECT)).isSameAs(sub); + assertThat(token.getTokenAttributes().get(OAuth2IntrospectionClaimNames.SUBJECT)).isSameAs(sub); } @Test public void mockOpaqueTokenWhenPrincipalThenBearerTokenAuthentication() { - OAuth2AuthenticatedPrincipal principal = active(); - this.client - .mutateWith(mockOpaqueToken() - .principal(principal)) - .get() - .exchange() + OAuth2AuthenticatedPrincipal principal = TestOAuth2AuthenticatedPrincipals.active(); + this.client.mutateWith(SecurityMockServerConfigurers.mockOpaqueToken().principal(principal)).get().exchange() .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); + SecurityContext context = this.securityContextController.removeSecurityContext(); assertThat(context.getAuthentication()).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication token = (BearerTokenAuthentication) context.getAuthentication(); assertThat(token.getPrincipal()).isSameAs(principal); @@ -124,34 +104,27 @@ public class SecurityMockServerConfigurerOpaqueTokenTests extends AbstractMockSe @Test public void mockOpaqueTokenWhenPrincipalSpecifiedThenLastCalledTakesPrecedence() { - OAuth2AuthenticatedPrincipal principal = active(a -> a.put("scope", "user")); - + OAuth2AuthenticatedPrincipal principal = TestOAuth2AuthenticatedPrincipals + .active((a) -> a.put("scope", "user")); this.client - .mutateWith(mockOpaqueToken() - .attributes(a -> a.put(SUBJECT, "foo")) - .principal(principal)) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); + .mutateWith(SecurityMockServerConfigurers.mockOpaqueToken() + .attributes((a) -> a.put(OAuth2IntrospectionClaimNames.SUBJECT, "foo")).principal(principal)) + .get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); assertThat(context.getAuthentication()).isInstanceOf(BearerTokenAuthentication.class); BearerTokenAuthentication token = (BearerTokenAuthentication) context.getAuthentication(); - assertThat((String) ((OAuth2AuthenticatedPrincipal) token.getPrincipal()).getAttribute(SUBJECT)) - .isEqualTo(principal.getAttribute(SUBJECT)); - + assertThat((String) ((OAuth2AuthenticatedPrincipal) token.getPrincipal()) + .getAttribute(OAuth2IntrospectionClaimNames.SUBJECT)) + .isEqualTo(principal.getAttribute(OAuth2IntrospectionClaimNames.SUBJECT)); this.client - .mutateWith(mockOpaqueToken() - .principal(principal) - .attributes(a -> a.put(SUBJECT, "bar"))) - .get() - .exchange() - .expectStatus().isOk(); - - context = securityContextController.removeSecurityContext(); + .mutateWith(SecurityMockServerConfigurers.mockOpaqueToken().principal(principal) + .attributes((a) -> a.put(OAuth2IntrospectionClaimNames.SUBJECT, "bar"))) + .get().exchange().expectStatus().isOk(); + context = this.securityContextController.removeSecurityContext(); assertThat(context.getAuthentication()).isInstanceOf(BearerTokenAuthentication.class); token = (BearerTokenAuthentication) context.getAuthentication(); - assertThat((String) ((OAuth2AuthenticatedPrincipal) token.getPrincipal()).getAttribute(SUBJECT)) - .isEqualTo("bar"); + assertThat((String) ((OAuth2AuthenticatedPrincipal) token.getPrincipal()) + .getAttribute(OAuth2IntrospectionClaimNames.SUBJECT)).isEqualTo("bar"); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java index cce61fba7b..97735c1ab9 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java @@ -20,6 +20,7 @@ import java.util.concurrent.ForkJoinPool; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -31,9 +32,6 @@ import org.springframework.security.web.server.context.SecurityContextServerWebE import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockAuthentication; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; - /** * @author Rob Winch * @since 5.0 @@ -42,128 +40,75 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock @SecurityTestExecutionListeners public class SecurityMockServerConfigurersAnnotatedTests extends AbstractMockServerConfigurersTests { - WebTestClient client = WebTestClient - .bindToController(controller) - .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + WebTestClient client = WebTestClient.bindToController(this.controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); @Test @WithMockUser public void withMockUserWhenOnMethodThenSuccess() { - client - .get() - .exchange() - .expectStatus().isOk(); - + this.client.get().exchange().expectStatus().isOk(); Authentication authentication = TestSecurityContextHolder.getContext().getAuthentication(); - controller.assertPrincipalIsEqualTo(authentication); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser public void withMockUserWhenGlobalMockPrincipalThenOverridesAnnotation() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); - client = WebTestClient - .bindToController(controller) - .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .apply(mockAuthentication(authentication)) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); - - client - .get() - .exchange() - .expectStatus().isOk(); - - controller.assertPrincipalIsEqualTo(authentication); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", + "ROLE_USER"); + this.client = WebTestClient.bindToController(this.controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()) + .apply(SecurityMockServerConfigurers.mockAuthentication(authentication)).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); + this.client.get().exchange().expectStatus().isOk(); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser public void withMockUserWhenMutateWithMockPrincipalThenOverridesAnnotation() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); - client - .mutateWith(mockAuthentication(authentication)) - .get() - .exchange() - .expectStatus().isOk(); - - controller.assertPrincipalIsEqualTo(authentication); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", + "ROLE_USER"); + this.client.mutateWith(SecurityMockServerConfigurers.mockAuthentication(authentication)).get().exchange() + .expectStatus().isOk(); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser public void withMockUserWhenMutateWithMockPrincipalAndNoMutateThenOverridesAnnotationAndUsesAnnotation() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); - client - .mutateWith(mockAuthentication(authentication)) - .get() - .exchange() - .expectStatus().isOk(); - - controller.assertPrincipalIsEqualTo(authentication); - - - client - .get() - .exchange() - .expectStatus().isOk(); - - assertPrincipalCreatedFromUserDetails(controller.removePrincipal(), userBuilder.build()); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", + "ROLE_USER"); + this.client.mutateWith(SecurityMockServerConfigurers.mockAuthentication(authentication)).get().exchange() + .expectStatus().isOk(); + this.controller.assertPrincipalIsEqualTo(authentication); + this.client.get().exchange().expectStatus().isOk(); + assertPrincipalCreatedFromUserDetails(this.controller.removePrincipal(), this.userBuilder.build()); } @Test @WithMockUser public void withMockUserWhenOnMethodAndRequestIsExecutedOnDifferentThreadThenSuccess() { Authentication authentication = TestSecurityContextHolder.getContext().getAuthentication(); - ForkJoinPool - .commonPool() - .submit(() -> - client - .get() - .exchange() - .expectStatus() - .isOk() - ) - .join(); - - controller.assertPrincipalIsEqualTo(authentication); + ForkJoinPool.commonPool().submit(() -> this.client.get().exchange().expectStatus().isOk()).join(); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser public void withMockUserAndWithCallOnSeparateThreadWhenMutateWithMockPrincipalAndNoMutateThenOverridesAnnotationAndUsesAnnotation() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); - - ForkJoinPool - .commonPool() - .submit(() -> - client - .mutateWith(mockAuthentication(authentication)) - .get() - .exchange() - .expectStatus().isOk() - ) - .join(); - - controller.assertPrincipalIsEqualTo(authentication); - - - ForkJoinPool - .commonPool() - .submit(() -> - client - .get() - .exchange() - .expectStatus().isOk() - ) - .join(); - - assertPrincipalCreatedFromUserDetails(controller.removePrincipal(), userBuilder.build()); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", + "ROLE_USER"); + ForkJoinPool.commonPool() + .submit(() -> this.client.mutateWith(SecurityMockServerConfigurers.mockAuthentication(authentication)) + .get().exchange().expectStatus().isOk()) + .join(); + this.controller.assertPrincipalIsEqualTo(authentication); + ForkJoinPool.commonPool().submit(() -> this.client.get().exchange().expectStatus().isOk()).join(); + assertPrincipalCreatedFromUserDetails(this.controller.removePrincipal(), this.userBuilder.build()); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java index 25c2afae1d..a96f00493a 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java @@ -16,8 +16,11 @@ package org.springframework.security.test.web.reactive.server; +import java.security.Principal; + import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.core.Authentication; @@ -28,11 +31,7 @@ import org.springframework.security.web.server.context.SecurityContextServerWebE import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; -import java.security.Principal; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockUser; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; /** * @author Rob Winch @@ -42,49 +41,37 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock @RunWith(SpringRunner.class) @SecurityTestExecutionListeners public class SecurityMockServerConfigurersClassAnnotatedTests extends AbstractMockServerConfigurersTests { - WebTestClient client = WebTestClient - .bindToController(controller) - .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + + WebTestClient client = WebTestClient.bindToController(this.controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); @Test public void wheMockUserWhenClassAnnotatedThenSuccess() { - client - .get() - .exchange() - .expectStatus().isOk() - .expectBody(String.class).consumeWith( response -> assertThat(response.getResponseBody()).contains("\"username\":\"user\"")); - + this.client.get().exchange().expectStatus().isOk().expectBody(String.class) + .consumeWith((response) -> assertThat(response.getResponseBody()).contains("\"username\":\"user\"")); Authentication authentication = TestSecurityContextHolder.getContext().getAuthentication(); - controller.assertPrincipalIsEqualTo(authentication); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser("method-user") public void withMockUserWhenClassAndMethodAnnotationThenMethodOverrides() { - client - .get() - .exchange() - .expectStatus().isOk() - .expectBody(String.class).consumeWith( response -> assertThat(response.getResponseBody()).contains("\"username\":\"method-user\"")); - + this.client.get().exchange().expectStatus().isOk().expectBody(String.class).consumeWith( + (response) -> assertThat(response.getResponseBody()).contains("\"username\":\"method-user\"")); Authentication authentication = TestSecurityContextHolder.getContext().getAuthentication(); - controller.assertPrincipalIsEqualTo(authentication); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test public void withMockUserWhenMutateWithThenMustateWithOverrides() { - client - .mutateWith(mockUser("mutateWith-mockUser")) - .get() - .exchange() - .expectStatus().isOk() - .expectBody(String.class).consumeWith( response -> assertThat(response.getResponseBody()).contains("\"username\":\"mutateWith-mockUser\"")); - - Principal principal = controller.removePrincipal(); - assertPrincipalCreatedFromUserDetails(principal, userBuilder.username("mutateWith-mockUser").build()); + this.client.mutateWith(SecurityMockServerConfigurers.mockUser("mutateWith-mockUser")).get().exchange() + .expectStatus().isOk().expectBody(String.class) + .consumeWith((response) -> assertThat(response.getResponseBody()) + .contains("\"username\":\"mutateWith-mockUser\"")); + Principal principal = this.controller.removePrincipal(); + assertPrincipalCreatedFromUserDetails(principal, this.userBuilder.username("mutateWith-mockUser").build()); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersJwtTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersJwtTests.java index e92100fa39..7a0aa28c8b 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersJwtTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersJwtTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.reactive.server; import java.util.Arrays; @@ -37,8 +38,6 @@ import org.springframework.security.web.server.context.SecurityContextServerWebE import org.springframework.test.web.reactive.server.WebTestClient; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockJwt; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; /** * @author Jérôme Wacongne <ch4mp@c4-soft.com> @@ -47,33 +46,25 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock */ @RunWith(MockitoJUnitRunner.class) public class SecurityMockServerConfigurersJwtTests extends AbstractMockServerConfigurersTests { + @Mock GrantedAuthority authority1; @Mock GrantedAuthority authority2; - WebTestClient client = WebTestClient - .bindToController(securityContextController) + WebTestClient client = WebTestClient.bindToController(this.securityContextController) .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .argumentResolvers(resolvers -> resolvers.addCustomResolver( - new CurrentSecurityContextArgumentResolver(new ReactiveAdapterRegistry()))) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + .argumentResolvers((resolvers) -> resolvers + .addCustomResolver(new CurrentSecurityContextArgumentResolver(new ReactiveAdapterRegistry()))) + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); @Test public void mockJwtWhenUsingDefaultsTheCreatesJwtAuthentication() { - client - .mutateWith(mockJwt()) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat(context.getAuthentication()).isInstanceOf( - JwtAuthenticationToken.class); + this.client.mutateWith(SecurityMockServerConfigurers.mockJwt()).get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat(context.getAuthentication()).isInstanceOf(JwtAuthenticationToken.class); JwtAuthenticationToken token = (JwtAuthenticationToken) context.getAuthentication(); assertThat(token.getAuthorities()).isNotEmpty(); assertThat(token.getToken()).isNotNull(); @@ -84,79 +75,57 @@ public class SecurityMockServerConfigurersJwtTests extends AbstractMockServerCon @Test public void mockJwtWhenProvidingBuilderConsumerThenProducesJwtAuthentication() { String name = new String("user"); - client - .mutateWith(mockJwt().jwt(jwt -> jwt.subject(name))) - .get() - .exchange() + this.client.mutateWith(SecurityMockServerConfigurers.mockJwt().jwt((jwt) -> jwt.subject(name))).get().exchange() .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat(context.getAuthentication()).isInstanceOf( - JwtAuthenticationToken.class); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat(context.getAuthentication()).isInstanceOf(JwtAuthenticationToken.class); JwtAuthenticationToken token = (JwtAuthenticationToken) context.getAuthentication(); assertThat(token.getToken().getSubject()).isSameAs(name); } @Test public void mockJwtWhenProvidingCustomAuthoritiesThenProducesJwtAuthentication() { - client - .mutateWith(mockJwt().jwt(jwt -> jwt.claim("scope", "ignored authorities")) - .authorities(this.authority1, this.authority2)) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(this.authority1, this.authority2); + this.client.mutateWith(SecurityMockServerConfigurers.mockJwt() + .jwt((jwt) -> jwt.claim("scope", "ignored authorities")).authorities(this.authority1, this.authority2)) + .get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1, + this.authority2); } @Test public void mockJwtWhenProvidingScopedAuthoritiesThenProducesJwtAuthentication() { - client - .mutateWith(mockJwt().jwt(jwt -> jwt.claim("scope", "scoped authorities"))) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(new SimpleGrantedAuthority("SCOPE_scoped"), - new SimpleGrantedAuthority("SCOPE_authorities")); + this.client + .mutateWith( + SecurityMockServerConfigurers.mockJwt().jwt((jwt) -> jwt.claim("scope", "scoped authorities"))) + .get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly( + new SimpleGrantedAuthority("SCOPE_scoped"), new SimpleGrantedAuthority("SCOPE_authorities")); } @Test public void mockJwtWhenProvidingGrantedAuthoritiesThenProducesJwtAuthentication() { - client - .mutateWith(mockJwt().jwt(jwt -> jwt.claim("scope", "ignored authorities")) - .authorities(jwt -> Arrays.asList(this.authority1))) - .get() - .exchange() - .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(this.authority1); + this.client + .mutateWith( + SecurityMockServerConfigurers.mockJwt().jwt((jwt) -> jwt.claim("scope", "ignored authorities")) + .authorities((jwt) -> Arrays.asList(this.authority1))) + .get().exchange().expectStatus().isOk(); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1); } @Test public void mockJwtWhenProvidingPreparedJwtThenProducesJwtAuthentication() { - Jwt originalToken = TestJwts.jwt() - .header("header1", "value1") - .subject("some_user") - .build(); - this.client - .mutateWith(mockJwt().jwt(originalToken)) - .get() - .exchange() + Jwt originalToken = TestJwts.jwt().header("header1", "value1").subject("some_user").build(); + this.client.mutateWith(SecurityMockServerConfigurers.mockJwt().jwt(originalToken)).get().exchange() .expectStatus().isOk(); - - SecurityContext context = securityContextController.removeSecurityContext(); - assertThat(context.getAuthentication()).isInstanceOf( - JwtAuthenticationToken.class); + SecurityContext context = this.securityContextController.removeSecurityContext(); + assertThat(context.getAuthentication()).isInstanceOf(JwtAuthenticationToken.class); JwtAuthenticationToken retrievedToken = (JwtAuthenticationToken) context.getAuthentication(); assertThat(retrievedToken.getToken().getSubject()).isEqualTo("some_user"); assertThat(retrievedToken.getToken().getTokenValue()).isEqualTo("token"); assertThat(retrievedToken.getToken().getHeaders().get("header1")).isEqualTo("value1"); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java index c860da6b46..8990a7101e 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java @@ -30,9 +30,11 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; @@ -42,18 +44,15 @@ import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Client; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; @RunWith(MockitoJUnitRunner.class) public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMockServerConfigurersTests { + private OAuth2LoginController controller = new OAuth2LoginController(); @Mock @@ -66,37 +65,26 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock @Before public void setup() { - this.client = WebTestClient - .bindToController(this.controller) - .argumentResolvers(c -> c.addCustomResolver( - new OAuth2AuthorizedClientArgumentResolver - (this.clientRegistrationRepository, this.authorizedClientRepository))) + this.client = WebTestClient.bindToController(this.controller) + .argumentResolvers((c) -> c.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver( + this.clientRegistrationRepository, this.authorizedClientRepository))) .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); } @Test - public void oauth2ClientWhenUsingDefaultsThenException() - throws Exception { - + public void oauth2ClientWhenUsingDefaultsThenException() throws Exception { WebHttpHandlerBuilder builder = WebHttpHandlerBuilder.webHandler(new DispatcherHandler()); - assertThatCode(() -> mockOAuth2Client().beforeServerCreated(builder)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ClientRegistration"); + assertThatIllegalArgumentException() + .isThrownBy(() -> SecurityMockServerConfigurers.mockOAuth2Client().beforeServerCreated(builder)) + .withMessageContaining("ClientRegistration"); } @Test - public void oauth2ClientWhenUsingRegistrationIdThenProducesAuthorizedClient() - throws Exception { - - this.client.mutateWith(mockOAuth2Client("registration-id")) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + public void oauth2ClientWhenUsingRegistrationIdThenProducesAuthorizedClient() throws Exception { + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Client("registration-id")).get().uri("/client") + .exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); @@ -105,16 +93,11 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock } @Test - public void oauth2ClientWhenClientRegistrationThenUses() - throws Exception { - - ClientRegistration clientRegistration = clientRegistration() + public void oauth2ClientWhenClientRegistrationThenUses() throws Exception { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() .registrationId("registration-id").clientId("client-id").build(); - this.client.mutateWith(mockOAuth2Client().clientRegistration(clientRegistration)) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Client().clientRegistration(clientRegistration)) + .get().uri("/client").exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); @@ -123,15 +106,11 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock } @Test - public void oauth2ClientWhenClientRegistrationConsumerThenUses() - throws Exception { - - this.client.mutateWith(mockOAuth2Client("registration-id") - .clientRegistration(c -> c.clientId("client-id"))) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + public void oauth2ClientWhenClientRegistrationConsumerThenUses() throws Exception { + this.client + .mutateWith(SecurityMockServerConfigurers.mockOAuth2Client("registration-id") + .clientRegistration((c) -> c.clientId("client-id"))) + .get().uri("/client").exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); @@ -142,25 +121,19 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock @Test public void oauth2ClientWhenPrincipalNameThenUses() throws Exception { - this.client.mutateWith(mockOAuth2Client("registration-id") - .principalName("test-subject")) - .get().uri("/client") - .exchange() - .expectStatus().isOk() - .expectBody(String.class).isEqualTo("test-subject"); + this.client + .mutateWith( + SecurityMockServerConfigurers.mockOAuth2Client("registration-id").principalName("test-subject")) + .get().uri("/client").exchange().expectStatus().isOk().expectBody(String.class) + .isEqualTo("test-subject"); } @Test - public void oauth2ClientWhenAccessTokenThenUses() - throws Exception { - - OAuth2AccessToken accessToken = noScopes(); - this.client.mutateWith(mockOAuth2Client("registration-id") - .accessToken(accessToken)) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + public void oauth2ClientWhenAccessTokenThenUses() throws Exception { + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); + this.client + .mutateWith(SecurityMockServerConfigurers.mockOAuth2Client("registration-id").accessToken(accessToken)) + .get().uri("/client").exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); @@ -170,39 +143,35 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock @Test public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exception { - this.client.mutateWith(mockOAuth2Client("registration-id")) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Client("registration-id")).get().uri("/client") + .exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getClientId()).isEqualTo("test-client"); - - client = new OAuth2AuthorizedClient(clientRegistration().build(), "sub", noScopes()); - when(this.authorizedClientRepository - .loadAuthorizedClient(eq("registration-id"), any(Authentication.class), any(ServerWebExchange.class))) - .thenReturn(Mono.just(client)); - this.client - .get().uri("/client") - .exchange() - .expectStatus().isOk(); + client = new OAuth2AuthorizedClient(TestClientRegistrations.clientRegistration().build(), "sub", + TestOAuth2AccessTokens.noScopes()); + given(this.authorizedClientRepository.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), + any(ServerWebExchange.class))).willReturn(Mono.just(client)); + this.client.get().uri("/client").exchange().expectStatus().isOk(); client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getClientId()).isEqualTo("client-id"); - verify(this.authorizedClientRepository).loadAuthorizedClient( - eq("registration-id"), any(Authentication.class), any(ServerWebExchange.class)); + verify(this.authorizedClientRepository).loadAuthorizedClient(eq("registration-id"), any(Authentication.class), + any(ServerWebExchange.class)); } @RestController static class OAuth2LoginController { + volatile OAuth2AuthorizedClient authorizedClient; @GetMapping("/client") - String authorizedClient - (@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { + String authorizedClient( + @RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { this.authorizedClient = authorizedClient; return authorizedClient.getPrincipalName(); } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2LoginTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2LoginTests.java index 2e13d5a0de..0820d65d50 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2LoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2LoginTests.java @@ -44,11 +44,10 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Login; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; @RunWith(MockitoJUnitRunner.class) public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockServerConfigurersTests { + private OAuth2LoginController controller = new OAuth2LoginController(); @Mock @@ -61,42 +60,31 @@ public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockS @Before public void setup() { - this.client = WebTestClient - .bindToController(this.controller) - .argumentResolvers(c -> c.addCustomResolver( - new OAuth2AuthorizedClientArgumentResolver - (this.clientRegistrationRepository, this.authorizedClientRepository))) + this.client = WebTestClient.bindToController(this.controller) + .argumentResolvers((c) -> c.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver( + this.clientRegistrationRepository, this.authorizedClientRepository))) .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); } @Test public void oauth2LoginWhenUsingDefaultsThenProducesDefaultAuthentication() { - this.client.mutateWith(mockOAuth2Login()) - .get().uri("/token") - .exchange() + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login()).get().uri("/token").exchange() .expectStatus().isOk(); - OAuth2AuthenticationToken token = this.controller.token; assertThat(token).isNotNull(); assertThat(token.getAuthorizedClientRegistrationId()).isEqualTo("test"); assertThat(token.getPrincipal()).isInstanceOf(OAuth2User.class); - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("sub", "user"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("sub", "user"); assertThat((Collection) token.getPrincipal().getAuthorities()) .contains(new SimpleGrantedAuthority("SCOPE_read")); } @Test public void oauth2LoginWhenUsingDefaultsThenProducesDefaultAuthorizedClient() { - this.client.mutateWith(mockOAuth2Login()) - .get().uri("/client") - .exchange() + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login()).get().uri("/client").exchange() .expectStatus().isOk(); - OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("test"); @@ -106,12 +94,10 @@ public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockS @Test public void oauth2LoginWhenAuthoritiesSpecifiedThenGrantsAccess() { - this.client.mutateWith(mockOAuth2Login() - .authorities(new SimpleGrantedAuthority("SCOPE_admin"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + this.client + .mutateWith(SecurityMockServerConfigurers.mockOAuth2Login() + .authorities(new SimpleGrantedAuthority("SCOPE_admin"))) + .get().uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; assertThat((Collection) token.getPrincipal().getAuthorities()) .contains(new SimpleGrantedAuthority("SCOPE_admin")); @@ -119,78 +105,48 @@ public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockS @Test public void oauth2LoginWhenAttributeSpecifiedThenUserHasAttribute() { - this.client.mutateWith(mockOAuth2Login() - .attributes(a -> a.put("iss", "https://idp.example.org"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + this.client + .mutateWith(SecurityMockServerConfigurers.mockOAuth2Login() + .attributes((a) -> a.put("iss", "https://idp.example.org"))) + .get().uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("iss", "https://idp.example.org"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("iss", "https://idp.example.org"); } @Test public void oauth2LoginWhenNameSpecifiedThenUserHasName() throws Exception { - OAuth2User oauth2User = new DefaultOAuth2User( - AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), - Collections.singletonMap("custom-attribute", "test-subject"), - "custom-attribute"); - - this.client.mutateWith(mockOAuth2Login() - .oauth2User(oauth2User)) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), + Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"); + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login().oauth2User(oauth2User)).get() + .uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getName()) - .isEqualTo("test-subject"); - - this.client.mutateWith(mockOAuth2Login() - .oauth2User(oauth2User)) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + assertThat(token.getPrincipal().getName()).isEqualTo("test-subject"); + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login().oauth2User(oauth2User)).get() + .uri("/client").exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; - assertThat(client.getPrincipalName()) - .isEqualTo("test-subject"); + assertThat(client.getPrincipalName()).isEqualTo("test-subject"); } @Test public void oauth2LoginWhenOAuth2UserSpecifiedThenLastCalledTakesPrecedence() throws Exception { - OAuth2User oauth2User = new DefaultOAuth2User( - AuthorityUtils.createAuthorityList("SCOPE_read"), - Collections.singletonMap("sub", "subject"), - "sub"); - - this.client.mutateWith(mockOAuth2Login() - .attributes(a -> a.put("subject", "foo")) - .oauth2User(oauth2User)) - .get().uri("/token") - .exchange() + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("SCOPE_read"), + Collections.singletonMap("sub", "subject"), "sub"); + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login() + .attributes((a) -> a.put("subject", "foo")).oauth2User(oauth2User)).get().uri("/token").exchange() .expectStatus().isOk(); - OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("sub", "subject"); - - this.client.mutateWith(mockOAuth2Login() - .oauth2User(oauth2User) - .attributes(a -> a.put("sub", "bar"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + assertThat(token.getPrincipal().getAttributes()).containsEntry("sub", "subject"); + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login().oauth2User(oauth2User) + .attributes((a) -> a.put("sub", "bar"))).get().uri("/token").exchange().expectStatus().isOk(); token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("sub", "bar"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("sub", "bar"); } @RestController static class OAuth2LoginController { + volatile OAuth2AuthenticationToken token; + volatile OAuth2AuthorizedClient authorizedClient; @GetMapping("/token") @@ -200,10 +156,11 @@ public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockS } @GetMapping("/client") - String authorizedClient - (@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { + String authorizedClient(@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { this.authorizedClient = authorizedClient; return authorizedClient.getPrincipalName(); } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOidcLoginTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOidcLoginTests.java index 3388f5518d..b9361bfbdd 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOidcLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOidcLoginTests.java @@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; @@ -44,13 +45,10 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Login; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOidcLogin; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; @RunWith(MockitoJUnitRunner.class) public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockServerConfigurersTests { + private OAuth2LoginController controller = new OAuth2LoginController(); @Mock @@ -63,44 +61,32 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer @Before public void setup() { - this.client = WebTestClient - .bindToController(this.controller) - .argumentResolvers(c -> c.addCustomResolver( - new OAuth2AuthorizedClientArgumentResolver - (this.clientRegistrationRepository, this.authorizedClientRepository))) + this.client = WebTestClient.bindToController(this.controller) + .argumentResolvers((c) -> c.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver( + this.clientRegistrationRepository, this.authorizedClientRepository))) .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); } @Test public void oidcLoginWhenUsingDefaultsThenProducesDefaultAuthentication() { - this.client.mutateWith(mockOidcLogin()) - .get().uri("/token") - .exchange() + this.client.mutateWith(SecurityMockServerConfigurers.mockOidcLogin()).get().uri("/token").exchange() .expectStatus().isOk(); - OAuth2AuthenticationToken token = this.controller.token; assertThat(token).isNotNull(); assertThat(token.getAuthorizedClientRegistrationId()).isEqualTo("test"); assertThat(token.getPrincipal()).isInstanceOf(OidcUser.class); - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("sub", "user"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("sub", "user"); assertThat((Collection) token.getPrincipal().getAuthorities()) .contains(new SimpleGrantedAuthority("SCOPE_read")); - assertThat(((OidcUser) token.getPrincipal()).getIdToken().getTokenValue()) - .isEqualTo("id-token"); + assertThat(((OidcUser) token.getPrincipal()).getIdToken().getTokenValue()).isEqualTo("id-token"); } @Test public void oidcLoginWhenUsingDefaultsThenProducesDefaultAuthorizedClient() { - this.client.mutateWith(mockOidcLogin()) - .get().uri("/client") - .exchange() + this.client.mutateWith(SecurityMockServerConfigurers.mockOidcLogin()).get().uri("/client").exchange() .expectStatus().isOk(); - OAuth2AuthorizedClient client = this.controller.authorizedClient; assertThat(client).isNotNull(); assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("test"); @@ -110,12 +96,10 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer @Test public void oidcLoginWhenAuthoritiesSpecifiedThenGrantsAccess() { - this.client.mutateWith(mockOidcLogin() - .authorities(new SimpleGrantedAuthority("SCOPE_admin"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + this.client + .mutateWith(SecurityMockServerConfigurers.mockOidcLogin() + .authorities(new SimpleGrantedAuthority("SCOPE_admin"))) + .get().uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; assertThat((Collection) token.getPrincipal().getAuthorities()) .contains(new SimpleGrantedAuthority("SCOPE_admin")); @@ -123,90 +107,60 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer @Test public void oidcLoginWhenIdTokenSpecifiedThenUserHasClaims() { - this.client.mutateWith(mockOidcLogin() - .idToken(i -> i.issuer("https://idp.example.org"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + this.client + .mutateWith(SecurityMockServerConfigurers.mockOidcLogin() + .idToken((i) -> i.issuer("https://idp.example.org"))) + .get().uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("iss", "https://idp.example.org"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("iss", "https://idp.example.org"); } @Test public void oidcLoginWhenUserInfoSpecifiedThenUserHasClaims() throws Exception { - this.client.mutateWith(mockOidcLogin() - .userInfoToken(u -> u.email("email@email"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + this.client + .mutateWith(SecurityMockServerConfigurers.mockOidcLogin().userInfoToken((u) -> u.email("email@email"))) + .get().uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("email", "email@email"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("email", "email@email"); } @Test public void oidcUserWhenNameSpecifiedThenUserHasName() throws Exception { - OidcUser oidcUser = new DefaultOidcUser( - AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), + OidcUser oidcUser = new DefaultOidcUser(AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), OidcIdToken.withTokenValue("id-token").claim("custom-attribute", "test-subject").build(), "custom-attribute"); - - this.client.mutateWith(mockOAuth2Login() - .oauth2User(oidcUser)) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login().oauth2User(oidcUser)).get().uri("/token") + .exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getName()) - .isEqualTo("test-subject"); - - this.client.mutateWith(mockOAuth2Login() - .oauth2User(oidcUser)) - .get().uri("/client") - .exchange() - .expectStatus().isOk(); - + assertThat(token.getPrincipal().getName()).isEqualTo("test-subject"); + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Login().oauth2User(oidcUser)).get() + .uri("/client").exchange().expectStatus().isOk(); OAuth2AuthorizedClient client = this.controller.authorizedClient; - assertThat(client.getPrincipalName()) - .isEqualTo("test-subject"); + assertThat(client.getPrincipalName()).isEqualTo("test-subject"); } // gh-7794 @Test public void oidcLoginWhenOidcUserSpecifiedThenLastCalledTakesPrecedence() throws Exception { - OidcUser oidcUser = new DefaultOidcUser( - AuthorityUtils.createAuthorityList("SCOPE_read"), idToken().build()); - - this.client.mutateWith(mockOidcLogin() - .idToken(i -> i.subject("foo")) - .oidcUser(oidcUser)) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + OidcUser oidcUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("SCOPE_read"), + TestOidcIdTokens.idToken().build()); + this.client.mutateWith( + SecurityMockServerConfigurers.mockOidcLogin().idToken((i) -> i.subject("foo")).oidcUser(oidcUser)).get() + .uri("/token").exchange().expectStatus().isOk(); OAuth2AuthenticationToken token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("sub", "subject"); - - this.client.mutateWith(mockOidcLogin() - .oidcUser(oidcUser) - .idToken(i -> i.subject("bar"))) - .get().uri("/token") - .exchange() - .expectStatus().isOk(); - + assertThat(token.getPrincipal().getAttributes()).containsEntry("sub", "subject"); + this.client.mutateWith( + SecurityMockServerConfigurers.mockOidcLogin().oidcUser(oidcUser).idToken((i) -> i.subject("bar"))).get() + .uri("/token").exchange().expectStatus().isOk(); token = this.controller.token; - assertThat(token.getPrincipal().getAttributes()) - .containsEntry("sub", "bar"); + assertThat(token.getPrincipal().getAttributes()).containsEntry("sub", "bar"); } @RestController static class OAuth2LoginController { + volatile OAuth2AuthenticationToken token; + volatile OAuth2AuthorizedClient authorizedClient; @GetMapping("/token") @@ -216,10 +170,11 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer } @GetMapping("/client") - String authorizedClient - (@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { + String authorizedClient(@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) { this.authorizedClient = authorizedClient; return authorizedClient.getPrincipalName(); } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java index 11203e8d09..3b11ff0c73 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java @@ -16,7 +16,10 @@ package org.springframework.security.test.web.reactive.server; +import java.security.Principal; + import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -27,157 +30,99 @@ import org.springframework.security.web.server.context.SecurityContextServerWebE import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.test.web.reactive.server.WebTestClient; -import java.security.Principal; - import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*; /** * @author Rob Winch * @since 5.0 */ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests { - WebTestClient client = WebTestClient - .bindToController(controller) - .webFilter( new CsrfWebFilter(), new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); + + WebTestClient client = WebTestClient.bindToController(this.controller) + .webFilter(new CsrfWebFilter(), new SecurityContextServerWebExchangeWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); @Test public void mockAuthenticationWhenLocalThenSuccess() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); - client - .mutateWith(mockAuthentication(authentication)) - .get() - .exchange() - .expectStatus().isOk(); - controller.assertPrincipalIsEqualTo(authentication); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", + "ROLE_USER"); + this.client.mutateWith(SecurityMockServerConfigurers.mockAuthentication(authentication)).get().exchange() + .expectStatus().isOk(); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test public void mockAuthenticationWhenGlobalThenSuccess() { - TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); - client = WebTestClient - .bindToController(controller) - .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .apply(mockAuthentication(authentication)) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); - client - .get() - .exchange() - .expectStatus().isOk(); - controller.assertPrincipalIsEqualTo(authentication); + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", + "ROLE_USER"); + this.client = WebTestClient.bindToController(this.controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()) + .apply(SecurityMockServerConfigurers.mockAuthentication(authentication)).configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); + this.client.get().exchange().expectStatus().isOk(); + this.controller.assertPrincipalIsEqualTo(authentication); } @Test public void mockUserWhenDefaultsThenSuccess() { - client - .mutateWith(mockUser()) - .get() - .exchange() - .expectStatus().isOk(); - - Principal actual = controller.removePrincipal(); - - assertPrincipalCreatedFromUserDetails(actual, userBuilder.build()); + this.client.mutateWith(SecurityMockServerConfigurers.mockUser()).get().exchange().expectStatus().isOk(); + Principal actual = this.controller.removePrincipal(); + assertPrincipalCreatedFromUserDetails(actual, this.userBuilder.build()); } @Test public void mockUserWhenGlobalThenSuccess() { - client = WebTestClient - .bindToController(controller) - .webFilter(new SecurityContextServerWebExchangeWebFilter()) - .apply(springSecurity()) - .apply(mockUser()) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); - client - .get() - .exchange() - .expectStatus().isOk(); - - Principal actual = controller.removePrincipal(); - - assertPrincipalCreatedFromUserDetails(actual, userBuilder.build()); + this.client = WebTestClient.bindToController(this.controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()).apply(SecurityMockServerConfigurers.mockUser()) + .configureClient().defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); + this.client.get().exchange().expectStatus().isOk(); + Principal actual = this.controller.removePrincipal(); + assertPrincipalCreatedFromUserDetails(actual, this.userBuilder.build()); } @Test public void mockUserStringWhenLocalThenSuccess() { - client - .mutateWith(mockUser(userBuilder.build().getUsername())) - .get() - .exchange() - .expectStatus().isOk(); - - Principal actual = controller.removePrincipal(); - - assertPrincipalCreatedFromUserDetails(actual, userBuilder.build()); + this.client.mutateWith(SecurityMockServerConfigurers.mockUser(this.userBuilder.build().getUsername())).get() + .exchange().expectStatus().isOk(); + Principal actual = this.controller.removePrincipal(); + assertPrincipalCreatedFromUserDetails(actual, this.userBuilder.build()); } @Test public void mockUserStringWhenCustomThenSuccess() { this.userBuilder = User.withUsername("admin").password("secret").roles("USER", "ADMIN"); - client - .mutateWith(mockUser("admin").password("secret").roles("USER", "ADMIN")) - .get() - .exchange() - .expectStatus().isOk(); - - Principal actual = controller.removePrincipal(); - - assertPrincipalCreatedFromUserDetails(actual, userBuilder.build()); + this.client + .mutateWith(SecurityMockServerConfigurers.mockUser("admin").password("secret").roles("USER", "ADMIN")) + .get().exchange().expectStatus().isOk(); + Principal actual = this.controller.removePrincipal(); + assertPrincipalCreatedFromUserDetails(actual, this.userBuilder.build()); } @Test public void mockUserUserDetailsLocalThenSuccess() { UserDetails userDetails = this.userBuilder.build(); - client - .mutateWith(mockUser(userDetails)) - .get() - .exchange() - .expectStatus().isOk(); - - Principal actual = controller.removePrincipal(); - - assertPrincipalCreatedFromUserDetails(actual, userBuilder.build()); + this.client.mutateWith(SecurityMockServerConfigurers.mockUser(userDetails)).get().exchange().expectStatus() + .isOk(); + Principal actual = this.controller.removePrincipal(); + assertPrincipalCreatedFromUserDetails(actual, this.userBuilder.build()); } @Test public void csrfWhenMutateWithThenDisablesCsrf() { - this.client - .post() - .exchange() - .expectStatus().isEqualTo(HttpStatus.FORBIDDEN) - .expectBody().consumeWith( b -> assertThat(new String(b.getResponseBody())).contains("CSRF")); - - this.client - .mutateWith(csrf()) - .post() - .exchange() - .expectStatus().isOk(); - + this.client.post().exchange().expectStatus().isEqualTo(HttpStatus.FORBIDDEN).expectBody() + .consumeWith((b) -> assertThat(new String(b.getResponseBody())).contains("CSRF")); + this.client.mutateWith(SecurityMockServerConfigurers.csrf()).post().exchange().expectStatus().isOk(); } @Test public void csrfWhenGlobalThenDisablesCsrf() { - this.client = WebTestClient - .bindToController(this.controller) - .webFilter(new CsrfWebFilter()) - .apply(springSecurity()) - .apply(csrf()) - .configureClient() - .build(); - - this.client - .get() - .exchange() - .expectStatus().isOk(); - + this.client = WebTestClient.bindToController(this.controller).webFilter(new CsrfWebFilter()) + .apply(SecurityMockServerConfigurers.springSecurity()).apply(SecurityMockServerConfigurers.csrf()) + .configureClient().build(); + this.client.get().exchange().expectStatus().isOk(); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/Sec2935Tests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/Sec2935Tests.java index 03499732a8..ebca73b566 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/Sec2935Tests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/Sec2935Tests.java @@ -13,19 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.request; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.test.web.servlet.request; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; @@ -40,11 +35,16 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + /** * @author Rob Winch */ - - @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration @WebAppConfiguration @@ -57,89 +57,55 @@ public class Sec2935Tests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .apply(springSecurity()) - .build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } // SEC-2935 @Test public void postProcessorUserNoUser() throws Exception { - mvc - .perform(get("/admin/abc").with(user("user").roles("ADMIN", "USER"))) - .andExpect(status().isNotFound()) - .andExpect(authenticated().withUsername("user")); - - mvc - .perform(get("/admin/abc")) - .andExpect(status().isUnauthorized()) - .andExpect(unauthenticated()); + this.mvc.perform(get("/admin/abc").with(user("user").roles("ADMIN", "USER"))).andExpect(status().isNotFound()) + .andExpect(authenticated().withUsername("user")); + this.mvc.perform(get("/admin/abc")).andExpect(status().isUnauthorized()).andExpect(unauthenticated()); } @Test public void postProcessorUserOtherUser() throws Exception { - mvc - .perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))) - .andExpect(status().isNotFound()) - .andExpect(authenticated().withUsername("user1")); - - mvc - .perform(get("/admin/abc").with(user("user2").roles("USER"))) - .andExpect(status().isForbidden()) - .andExpect(authenticated().withUsername("user2")); + this.mvc.perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))).andExpect(status().isNotFound()) + .andExpect(authenticated().withUsername("user1")); + this.mvc.perform(get("/admin/abc").with(user("user2").roles("USER"))).andExpect(status().isForbidden()) + .andExpect(authenticated().withUsername("user2")); } @WithMockUser @Test public void postProcessorUserWithMockUser() throws Exception { - mvc - .perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))) - .andExpect(status().isNotFound()) - .andExpect(authenticated().withUsername("user1")); - - mvc - .perform(get("/admin/abc")) - .andExpect(status().isForbidden()) - .andExpect(authenticated().withUsername("user")); + this.mvc.perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))).andExpect(status().isNotFound()) + .andExpect(authenticated().withUsername("user1")); + this.mvc.perform(get("/admin/abc")).andExpect(status().isForbidden()) + .andExpect(authenticated().withUsername("user")); } // SEC-2941 @Test public void defaultRequest() throws Exception { - mvc = MockMvcBuilders.webAppContextSetup(context) - .apply(springSecurity()) - .defaultRequest(get("/").with(user("default"))) - .build(); - - mvc - .perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))) - .andExpect(status().isNotFound()) - .andExpect(authenticated().withUsername("user1")); - - mvc - .perform(get("/admin/abc")) - .andExpect(status().isForbidden()) - .andExpect(authenticated().withUsername("default")); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()) + .defaultRequest(get("/").with(user("default"))).build(); + this.mvc.perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))).andExpect(status().isNotFound()) + .andExpect(authenticated().withUsername("user1")); + this.mvc.perform(get("/admin/abc")).andExpect(status().isForbidden()) + .andExpect(authenticated().withUsername("default")); } @Ignore @WithMockUser @Test public void defaultRequestOverridesWithMockUser() throws Exception { - mvc = MockMvcBuilders.webAppContextSetup(context) - .apply(springSecurity()) - .defaultRequest(get("/").with(user("default"))) - .build(); - - mvc - .perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))) - .andExpect(status().isNotFound()) - .andExpect(authenticated().withUsername("user1")); - - mvc - .perform(get("/admin/abc")) - .andExpect(status().isForbidden()) - .andExpect(authenticated().withUsername("default")); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()) + .defaultRequest(get("/").with(user("default"))).build(); + this.mvc.perform(get("/admin/abc").with(user("user1").roles("ADMIN", "USER"))).andExpect(status().isNotFound()) + .andExpect(authenticated().withUsername("user1")); + this.mvc.perform(get("/admin/abc")).andExpect(status().isForbidden()) + .andExpect(authenticated().withUsername("default")); } @EnableWebSecurity @@ -148,17 +114,21 @@ public class Sec2935Tests { @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .httpBasic(); + // @formatter:on } @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { auth.inMemoryAuthentication(); } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index 3a00f6a080..e236a9295f 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; @@ -29,16 +33,15 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.test.web.servlet.setup.MockMvcBuilders; -import java.util.Arrays; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; public class SecurityMockMvcRequestBuildersFormLoginTests { + private MockServletContext servletContext; @Before @@ -51,29 +54,24 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { MockHttpServletRequest request = formLogin().buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); - assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())) - .isEqualTo(token.getToken()); + assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/login"); assertThat(request.getParameter("_csrf")).isNotNull(); } @Test public void custom() { - MockHttpServletRequest request = formLogin("/login").user("username", "admin") - .password("password", "secret").buildRequest(this.servletContext); - + MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") + .buildRequest(this.servletContext); CsrfToken token = (CsrfToken) request .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); - assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())) - .isEqualTo(token.getToken()); + assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/login"); } @@ -81,32 +79,26 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { public void customWithUriVars() { MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2") .user("username", "admin").password("password", "secret").buildRequest(this.servletContext); - CsrfToken token = (CsrfToken) request .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); - assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())) - .isEqualTo(token.getToken()); + assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/uri-login/val1/val2"); } /** - * spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together - * with our request builders. (gh-7572) + * spring-restdocs uses postprocessors to do its trick. It will work only if these are + * merged together with our request builders. (gh-7572) * @throws Exception */ @Test public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception { RequestPostProcessor postProcessor = mock(RequestPostProcessor.class); - when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0)); + given(postProcessor.postProcessRequest(any())).willAnswer((i) -> i.getArgument(0)); MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object()) - .defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor)) - .build(); - - + .defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor)).build(); MvcResult mvcResult = mockMvc.perform(formLogin()).andReturn(); assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name()); assertThat(mvcResult.getRequest().getHeader("Accept")) @@ -121,10 +113,9 @@ public class SecurityMockMvcRequestBuildersFormLoginTests { // gh-3920 @Test public void usesAcceptMediaForContentNegotiation() { - MockHttpServletRequest request = formLogin("/login").user("username", "admin") - .password("password", "secret").buildRequest(this.servletContext); - - assertThat(request.getHeader("Accept")) - .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); + MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret") + .buildRequest(this.servletContext); + assertThat(request.getHeader("Accept")).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index b6271e9409..dfdcd71507 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; @@ -29,74 +33,64 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.test.web.servlet.setup.MockMvcBuilders; -import java.util.Arrays; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.when; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout; public class SecurityMockMvcRequestBuildersFormLogoutTests { + private MockServletContext servletContext; @Before public void setup() { - servletContext = new MockServletContext(); + this.servletContext = new MockServletContext(); } @Test public void defaults() { - MockHttpServletRequest request = logout().buildRequest(servletContext); - - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); - + MockHttpServletRequest request = logout().buildRequest(this.servletContext); + CsrfToken token = (CsrfToken) request + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())).isEqualTo( - token.getToken()); + assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/logout"); } @Test public void custom() { - MockHttpServletRequest request = logout("/admin/logout").buildRequest( - servletContext); - - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); - + MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext); + CsrfToken token = (CsrfToken) request + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())).isEqualTo( - token.getToken()); + assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); } @Test public void customWithUriVars() { - MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2").buildRequest( - servletContext); - - CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); - + MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2") + .buildRequest(this.servletContext); + CsrfToken token = (CsrfToken) request + .getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())).isEqualTo( - token.getToken()); + assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken()); assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2"); } /** - * spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together - * with our request builders. (gh-7572) + * spring-restdocs uses postprocessors to do its trick. It will work only if these are + * merged together with our request builders. (gh-7572) * @throws Exception */ @Test public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception { RequestPostProcessor postProcessor = mock(RequestPostProcessor.class); - when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0)); + given(postProcessor.postProcessRequest(any())).willAnswer((i) -> i.getArgument(0)); MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object()) - .defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor)) - .build(); - + .defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor)).build(); MvcResult mvcResult = mockMvc.perform(logout()).andReturn(); assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name()); assertThat(mvcResult.getRequest().getHeader("Accept")) @@ -105,4 +99,5 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests { assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty(); verify(postProcessor).postProcessRequest(any()); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationStatelessTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationStatelessTests.java index d0a00f5cdd..63e7b6c177 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationStatelessTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationStatelessTests.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.ContextConfiguration; @@ -35,8 +37,8 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.*; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -52,51 +54,54 @@ public class SecurityMockMvcRequestPostProcessorsAuthenticationStatelessTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } // SEC-2593 @Test public void userRequestPostProcessorWorksWithStateless() throws Exception { - mvc.perform(get("/").with(user("user"))).andExpect(status().is2xxSuccessful()); + this.mvc.perform(get("/").with(user("user"))).andExpect(status().is2xxSuccessful()); } // SEC-2593 @WithMockUser @Test public void withMockUserWorksWithStateless() throws Exception { - mvc.perform(get("/")).andExpect(status().is2xxSuccessful()); + this.mvc.perform(get("/")).andExpect(status().is2xxSuccessful()); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { super.configure(http); - + // @formatter:off http .sessionManagement() .sessionCreationPolicy(SessionCreationPolicy.STATELESS); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication(); + // @formatter:on } - // @formatter:on @RestController static class Controller { + @RequestMapping - public String hello() { + String hello() { return "Hello"; } + } + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationTests.java index 357e386719..b73620e033 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsAuthenticationTests.java @@ -13,15 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.request; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; +package org.springframework.security.test.web.servlet.request; import javax.servlet.http.HttpServletResponse; @@ -32,9 +25,11 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; @@ -42,12 +37,21 @@ import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.web.context.SecurityContextRepository; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; + @RunWith(PowerMockRunner.class) @PrepareOnlyThisForTest(WebTestUtils.class) -@PowerMockIgnore({"javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*"}) +@PowerMockIgnore({ "javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", + "javax.xml.parsers.*" }) public class SecurityMockMvcRequestPostProcessorsAuthenticationTests { + @Captor private ArgumentCaptor contextCaptor; + @Mock private SecurityContextRepository repository; @@ -58,7 +62,7 @@ public class SecurityMockMvcRequestPostProcessorsAuthenticationTests { @Before public void setup() { - request = new MockHttpServletRequest(); + this.request = new MockHttpServletRequest(); mockWebTestUtils(); } @@ -69,16 +73,16 @@ public class SecurityMockMvcRequestPostProcessorsAuthenticationTests { @Test public void userDetails() { - authentication(authentication).postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + authentication(this.authentication).postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); - assertThat(context.getAuthentication()).isSameAs(authentication); + SecurityContext context = this.contextCaptor.getValue(); + assertThat(context.getAuthentication()).isSameAs(this.authentication); } private void mockWebTestUtils() { - spy(WebTestUtils.class); - when(WebTestUtils.getSecurityContextRepository(request)).thenReturn(repository); + PowerMockito.spy(WebTestUtils.class); + PowerMockito.when(WebTestUtils.getSecurityContextRepository(this.request)).thenReturn(this.repository); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCertificateTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCertificateTests.java index 2d471ebb38..22b3ccb895 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCertificateTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCertificateTests.java @@ -13,22 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; +import java.security.cert.X509Certificate; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.mock.web.MockHttpServletRequest; -import java.security.cert.X509Certificate; +import org.springframework.mock.web.MockHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509; @RunWith(MockitoJUnitRunner.class) public class SecurityMockMvcRequestPostProcessorsCertificateTests { + @Mock private X509Certificate certificate; @@ -36,30 +39,25 @@ public class SecurityMockMvcRequestPostProcessorsCertificateTests { @Before public void setup() { - request = new MockHttpServletRequest(); + this.request = new MockHttpServletRequest(); } @Test public void x509SingleCertificate() { - MockHttpServletRequest postProcessedRequest = x509(certificate) - .postProcessRequest(request); - + MockHttpServletRequest postProcessedRequest = x509(this.certificate).postProcessRequest(this.request); X509Certificate[] certificates = (X509Certificate[]) postProcessedRequest .getAttribute("javax.servlet.request.X509Certificate"); - - assertThat(certificates).containsOnly(certificate); + assertThat(certificates).containsOnly(this.certificate); } @Test public void x509ResourceName() throws Exception { - MockHttpServletRequest postProcessedRequest = x509("rod.cer").postProcessRequest( - request); - + MockHttpServletRequest postProcessedRequest = x509("rod.cer").postProcessRequest(this.request); X509Certificate[] certificates = (X509Certificate[]) postProcessedRequest .getAttribute("javax.servlet.request.X509Certificate"); - assertThat(certificates).hasSize(1); - assertThat(certificates[0].getSubjectDN().getName()).isEqualTo( - "CN=rod, OU=Spring Security, O=Spring Framework"); + assertThat(certificates[0].getSubjectDN().getName()) + .isEqualTo("CN=rod, OU=Spring Security, O=Spring Framework"); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java index b65a219236..b74fd99528 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -45,7 +47,7 @@ public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests { // SEC-3836 @Test public void findCookieCsrfTokenRepository() { - MockHttpServletRequest request = post("/").buildRequest(wac.getServletContext()); + MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext()); CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request); assertThat(csrfTokenRepository).isNotNull(); assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository); @@ -53,6 +55,7 @@ public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests { @EnableWebSecurity static class Config extends WebSecurityConfigurerAdapter { + static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository(); @Override @@ -65,5 +68,7 @@ public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests { // Enable the DebugFilter web.debug(true); } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java index 26caba94e7..7118bf8889 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import java.io.IOException; @@ -62,10 +63,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @ContextConfiguration @WebAppConfiguration public class SecurityMockMvcRequestPostProcessorsCsrfTests { + @Autowired WebApplicationContext wac; + @Autowired TheController controller; + @Autowired FilterChainProxy springSecurityFilterChain; @@ -151,12 +155,10 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { public void csrfWhenUsedThenDoesNotImpactOriginalRepository() throws Exception { // @formatter:off this.mockMvc.perform(post("/").with(csrf())); - MockHttpServletRequest request = new MockHttpServletRequest(); HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository(); CsrfToken token = repo.generateToken(request); repo.saveToken(token, request, new MockHttpServletResponse()); - MockHttpServletRequestBuilder requestWithCsrf = post("/") .param(token.getParameterName(), token.getToken()) .session((MockHttpSession) request.getSession()); @@ -169,6 +171,10 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { return new CsrfParamResultMatcher(); } + public static ResultMatcher csrfAsHeader() { + return new CsrfHeaderResultMatcher(); + } + static class CsrfParamResultMatcher implements ResultMatcher { @Override @@ -177,10 +183,7 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { assertThat(request.getParameter("_csrf")).isNotNull(); assertThat(request.getHeader("X-CSRF-TOKEN")).isNull(); } - } - public static ResultMatcher csrfAsHeader() { - return new CsrfHeaderResultMatcher(); } static class CsrfHeaderResultMatcher implements ResultMatcher { @@ -191,18 +194,19 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { assertThat(request.getParameter("_csrf")).isNull(); assertThat(request.getHeader("X-CSRF-TOKEN")).isNotNull(); } + } static class SessionRepositoryFilter extends OncePerRequestFilter { @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { filterChain.doFilter(new SessionRequestWrapper(request), response); } static class SessionRequestWrapper extends HttpServletRequestWrapper { + HttpSession session = new MockHttpSession(); SessionRequestWrapper(HttpServletRequest request) { @@ -218,21 +222,28 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { public HttpSession getSession() { return this.session; } + } + } @EnableWebSecurity static class Config extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) { } @RestController static class TheController { + @RequestMapping("/") String index() { return "Hi"; } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsDigestTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsDigestTests.java index 39bae21942..5c06373a90 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsDigestTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsDigestTests.java @@ -13,11 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -28,17 +36,13 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.web.authentication.www.DigestAuthenticationEntryPoint; import org.springframework.security.web.authentication.www.DigestAuthenticationFilter; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import java.io.IOException; - import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.digest; public class SecurityMockMvcRequestPostProcessorsDigestTests { private DigestAuthenticationFilter filter; + private MockHttpServletRequest request; private String username; @@ -50,16 +54,15 @@ public class SecurityMockMvcRequestPostProcessorsDigestTests { @Before public void setup() { this.password = "password"; - request = new MockHttpServletRequest(); - - entryPoint = new DigestAuthenticationEntryPoint(); - entryPoint.setKey("key"); - entryPoint.setRealmName("Spring Security"); - filter = new DigestAuthenticationFilter(); - filter.setUserDetailsService(username -> new User(username, password, AuthorityUtils - .createAuthorityList("ROLE_USER"))); - filter.setAuthenticationEntryPoint(entryPoint); - filter.afterPropertiesSet(); + this.request = new MockHttpServletRequest(); + this.entryPoint = new DigestAuthenticationEntryPoint(); + this.entryPoint.setKey("key"); + this.entryPoint.setRealmName("Spring Security"); + this.filter = new DigestAuthenticationFilter(); + this.filter.setUserDetailsService( + (username) -> new User(username, this.password, AuthorityUtils.createAuthorityList("ROLE_USER"))); + this.filter.setAuthenticationEntryPoint(this.entryPoint); + this.filter.afterPropertiesSet(); } @After @@ -69,38 +72,32 @@ public class SecurityMockMvcRequestPostProcessorsDigestTests { @Test public void digestWithFilter() throws Exception { - MockHttpServletRequest postProcessedRequest = digest() - .postProcessRequest(request); - + MockHttpServletRequest postProcessedRequest = digest().postProcessRequest(this.request); assertThat(extractUser()).isEqualTo("user"); } @Test public void digestWithFilterCustomUsername() throws Exception { String username = "admin"; - MockHttpServletRequest postProcessedRequest = digest(username) - .postProcessRequest(request); - + MockHttpServletRequest postProcessedRequest = digest(username).postProcessRequest(this.request); assertThat(extractUser()).isEqualTo(username); } @Test public void digestWithFilterCustomPassword() throws Exception { String username = "custom"; - password = "secret"; - MockHttpServletRequest postProcessedRequest = digest(username).password(password) - .postProcessRequest(request); - + this.password = "secret"; + MockHttpServletRequest postProcessedRequest = digest(username).password(this.password) + .postProcessRequest(this.request); assertThat(extractUser()).isEqualTo(username); } @Test public void digestWithFilterCustomRealm() throws Exception { String username = "admin"; - entryPoint.setRealmName("Custom"); - MockHttpServletRequest postProcessedRequest = digest(username).realm( - entryPoint.getRealmName()).postProcessRequest(request); - + this.entryPoint.setRealmName("Custom"); + MockHttpServletRequest postProcessedRequest = digest(username).realm(this.entryPoint.getRealmName()) + .postProcessRequest(this.request); assertThat(extractUser()).isEqualTo(username); } @@ -108,20 +105,20 @@ public class SecurityMockMvcRequestPostProcessorsDigestTests { public void digestWithFilterFails() throws Exception { String username = "admin"; MockHttpServletRequest postProcessedRequest = digest(username).realm("Invalid") - .postProcessRequest(request); - + .postProcessRequest(this.request); assertThat(extractUser()).isNull(); } private String extractUser() throws IOException, ServletException { - filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain() { + this.filter.doFilter(this.request, new MockHttpServletResponse(), new MockFilterChain() { @Override public void doFilter(ServletRequest request, ServletResponse response) { - Authentication authentication = SecurityContextHolder.getContext() - .getAuthentication(); - username = authentication == null ? null : authentication.getName(); + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + SecurityMockMvcRequestPostProcessorsDigestTests.this.username = (authentication != null) + ? authentication.getName() : null; } }); - return username; + return this.username; } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsJwtTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsJwtTests.java index 24441de7dc..5b7dd1a042 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsJwtTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsJwtTests.java @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import java.util.Arrays; import java.util.List; + import javax.servlet.http.HttpServletResponse; import org.junit.After; @@ -60,6 +62,7 @@ import static org.springframework.security.test.web.servlet.request.SecurityMock */ @RunWith(MockitoJUnitRunner.class) public class SecurityMockMvcRequestPostProcessorsJwtTests { + @Captor private ArgumentCaptor contextCaptor; @@ -70,6 +73,7 @@ public class SecurityMockMvcRequestPostProcessorsJwtTests { @Mock private GrantedAuthority authority1; + @Mock private GrantedAuthority authority2; @@ -91,12 +95,10 @@ public class SecurityMockMvcRequestPostProcessorsJwtTests { @Test public void jwtWhenUsingDefaultsThenProducesDefaultJwtAuthentication() { jwt().postProcessRequest(this.request); - verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); SecurityContext context = this.contextCaptor.getValue(); - assertThat(context.getAuthentication()).isInstanceOf( - JwtAuthenticationToken.class); + assertThat(context.getAuthentication()).isInstanceOf(JwtAuthenticationToken.class); JwtAuthenticationToken token = (JwtAuthenticationToken) context.getAuthentication(); assertThat(token.getAuthorities()).isNotEmpty(); assertThat(token.getToken()).isNotNull(); @@ -107,64 +109,50 @@ public class SecurityMockMvcRequestPostProcessorsJwtTests { @Test public void jwtWhenProvidingBuilderConsumerThenProducesJwtAuthentication() { String name = new String("user"); - jwt().jwt(jwt -> jwt.subject(name)).postProcessRequest(this.request); - + jwt().jwt((jwt) -> jwt.subject(name)).postProcessRequest(this.request); verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); SecurityContext context = this.contextCaptor.getValue(); - assertThat(context.getAuthentication()).isInstanceOf( - JwtAuthenticationToken.class); + assertThat(context.getAuthentication()).isInstanceOf(JwtAuthenticationToken.class); JwtAuthenticationToken token = (JwtAuthenticationToken) context.getAuthentication(); assertThat(token.getToken().getSubject()).isSameAs(name); } @Test public void jwtWhenProvidingCustomAuthoritiesThenProducesJwtAuthentication() { - jwt().jwt(jwt -> jwt.claim("scope", "ignored authorities")) - .authorities(this.authority1, this.authority2) + jwt().jwt((jwt) -> jwt.claim("scope", "ignored authorities")).authorities(this.authority1, this.authority2) .postProcessRequest(this.request); - verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); SecurityContext context = this.contextCaptor.getValue(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(this.authority1, this.authority2); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1, + this.authority2); } @Test public void jwtWhenProvidingScopedAuthoritiesThenProducesJwtAuthentication() { - jwt().jwt(jwt -> jwt.claim("scope", "scoped authorities")) - .postProcessRequest(this.request); - + jwt().jwt((jwt) -> jwt.claim("scope", "scoped authorities")).postProcessRequest(this.request); verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); SecurityContext context = this.contextCaptor.getValue(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(new SimpleGrantedAuthority("SCOPE_scoped"), - new SimpleGrantedAuthority("SCOPE_authorities")); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly( + new SimpleGrantedAuthority("SCOPE_scoped"), new SimpleGrantedAuthority("SCOPE_authorities")); } @Test public void jwtWhenProvidingGrantedAuthoritiesThenProducesJwtAuthentication() { - jwt().jwt(jwt -> jwt.claim("scope", "ignored authorities")) - .authorities(jwt -> Arrays.asList(this.authority1)) - .postProcessRequest(this.request); - + jwt().jwt((jwt) -> jwt.claim("scope", "ignored authorities")) + .authorities((jwt) -> Arrays.asList(this.authority1)).postProcessRequest(this.request); verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); SecurityContext context = this.contextCaptor.getValue(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(this.authority1); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1); } @Test public void jwtWhenProvidingPreparedJwtThenUsesItForAuthentication() { - Jwt originalToken = TestJwts.jwt() - .header("header1", "value1") - .subject("some_user") - .build(); + Jwt originalToken = TestJwts.jwt().header("header1", "value1").subject("some_user").build(); jwt().jwt(originalToken).postProcessRequest(this.request); - verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); SecurityContext context = this.contextCaptor.getValue(); @@ -173,4 +161,5 @@ public class SecurityMockMvcRequestPostProcessorsJwtTests { assertThat(retrievedToken.getToken().getTokenValue()).isEqualTo("token"); assertThat(retrievedToken.getToken().getHeaders().get("header1")).isEqualTo("value1"); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java index 13450d318f..ce6557f374 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import javax.servlet.http.HttpServletRequest; @@ -33,8 +34,10 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -46,14 +49,12 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; -import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oauth2Client; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -69,6 +70,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @ContextConfiguration @WebAppConfiguration public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests { + @Autowired WebApplicationContext context; @@ -89,20 +91,15 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests { TestSecurityContextHolder.clearContext(); } - @Test - public void oauth2ClientWhenUsingDefaultsThenException() - throws Exception { - - assertThatCode(() -> oauth2Client().postProcessRequest(new MockHttpServletRequest())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ClientRegistration"); + public void oauth2ClientWhenUsingDefaultsThenException() throws Exception { + assertThatIllegalArgumentException() + .isThrownBy(() -> oauth2Client().postProcessRequest(new MockHttpServletRequest())) + .withMessageContaining("ClientRegistration"); } @Test - public void oauth2ClientWhenUsingDefaultsThenProducesDefaultAuthorizedClient() - throws Exception { - + public void oauth2ClientWhenUsingDefaultsThenProducesDefaultAuthorizedClient() throws Exception { this.mvc.perform(get("/access-token").with(oauth2Client("registration-id"))) .andExpect(content().string("access-token")); this.mvc.perform(get("/client-id").with(oauth2Client("registration-id"))) @@ -110,66 +107,60 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests { } @Test - public void oauth2ClientWhenClientRegistrationThenUses() - throws Exception { - - ClientRegistration clientRegistration = clientRegistration() + public void oauth2ClientWhenClientRegistrationThenUses() throws Exception { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() .registrationId("registration-id").clientId("client-id").build(); - this.mvc.perform(get("/client-id") - .with(oauth2Client().clientRegistration(clientRegistration))) + this.mvc.perform(get("/client-id").with(oauth2Client().clientRegistration(clientRegistration))) .andExpect(content().string("client-id")); } @Test - public void oauth2ClientWhenClientRegistrationConsumerThenUses() - throws Exception { - + public void oauth2ClientWhenClientRegistrationConsumerThenUses() throws Exception { this.mvc.perform(get("/client-id") - .with(oauth2Client("registration-id").clientRegistration(c -> c.clientId("client-id")))) + .with(oauth2Client("registration-id").clientRegistration((c) -> c.clientId("client-id")))) .andExpect(content().string("client-id")); } @Test public void oauth2ClientWhenPrincipalNameThenUses() throws Exception { - this.mvc.perform(get("/principal-name") - .with(oauth2Client("registration-id").principalName("test-subject"))) + this.mvc.perform(get("/principal-name").with(oauth2Client("registration-id").principalName("test-subject"))) .andExpect(content().string("test-subject")); } @Test public void oauth2ClientWhenAccessTokenThenUses() throws Exception { - OAuth2AccessToken accessToken = noScopes(); - this.mvc.perform(get("/access-token") - .with(oauth2Client("registration-id").accessToken(accessToken))) + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); + this.mvc.perform(get("/access-token").with(oauth2Client("registration-id").accessToken(accessToken))) .andExpect(content().string("no-scopes")); } @Test public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exception { - this.mvc.perform(get("/client-id") - .with(oauth2Client("registration-id"))) + this.mvc.perform(get("/client-id").with(oauth2Client("registration-id"))) .andExpect(content().string("test-client")); - - OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(clientRegistration().build(), "sub", noScopes()); + OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(TestClientRegistrations.clientRegistration().build(), + "sub", TestOAuth2AccessTokens.noScopes()); OAuth2AuthorizedClientRepository repository = this.context.getBean(OAuth2AuthorizedClientRepository.class); - when(repository.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), any(HttpServletRequest.class))) - .thenReturn(client); - this.mvc.perform(get("/client-id")) - .andExpect(content().string("client-id")); - verify(repository).loadAuthorizedClient( - eq("registration-id"), any(Authentication.class), any(HttpServletRequest.class)); + given(repository.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), + any(HttpServletRequest.class))).willReturn(client); + this.mvc.perform(get("/client-id")).andExpect(content().string("client-id")); + verify(repository).loadAuthorizedClient(eq("registration-id"), any(Authentication.class), + any(HttpServletRequest.class)); } @EnableWebSecurity @EnableWebMvc static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .authorizeRequests(authz -> authz + .authorizeRequests((authz) -> authz .anyRequest().permitAll() ) .oauth2Client(); + // @formatter:on } @Bean @@ -177,7 +168,6 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests { return mock(ClientRegistrationRepository.class); } - @Bean OAuth2AuthorizedClientRepository authorizedClientRepository() { return mock(OAuth2AuthorizedClientRepository.class); @@ -185,20 +175,27 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests { @RestController static class PrincipalController { + @GetMapping("/access-token") - String accessToken(@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { + String accessToken( + @RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { return authorizedClient.getAccessToken().getTokenValue(); } @GetMapping("/principal-name") - String principalName(@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { + String principalName( + @RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { return authorizedClient.getPrincipalName(); } @GetMapping("/client-id") - String clientId(@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { + String clientId( + @RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { return authorizedClient.getClientRegistration().getClientId(); } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2LoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2LoginTests.java index 52e8941eaf..f5307845af 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2LoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2LoginTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import java.util.Collection; @@ -37,6 +38,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; @@ -52,7 +54,6 @@ import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import static org.mockito.Mockito.mock; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oauth2Login; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -69,6 +70,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @ContextConfiguration @WebAppConfiguration public class SecurityMockMvcRequestPostProcessorsOAuth2LoginTests { + @Autowired WebApplicationContext context; @@ -85,93 +87,74 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2LoginTests { } @Test - public void oauth2LoginWhenUsingDefaultsThenProducesDefaultAuthentication() - throws Exception { - - this.mvc.perform(get("/name").with(oauth2Login())) - .andExpect(content().string("user")); - this.mvc.perform(get("/admin/id-token/name").with(oauth2Login())) - .andExpect(status().isForbidden()); + public void oauth2LoginWhenUsingDefaultsThenProducesDefaultAuthentication() throws Exception { + this.mvc.perform(get("/name").with(oauth2Login())).andExpect(content().string("user")); + this.mvc.perform(get("/admin/id-token/name").with(oauth2Login())).andExpect(status().isForbidden()); } @Test - public void oauth2LoginWhenUsingDefaultsThenProducesDefaultAuthorizedClient() - throws Exception { - - this.mvc.perform(get("/client-id").with(oauth2Login())) - .andExpect(content().string("test-client")); + public void oauth2LoginWhenUsingDefaultsThenProducesDefaultAuthorizedClient() throws Exception { + this.mvc.perform(get("/client-id").with(oauth2Login())).andExpect(content().string("test-client")); } @Test public void oauth2LoginWhenAuthoritiesSpecifiedThenGrantsAccess() throws Exception { - this.mvc.perform(get("/admin/scopes") - .with(oauth2Login().authorities(new SimpleGrantedAuthority("SCOPE_admin")))) + this.mvc.perform( + get("/admin/scopes").with(oauth2Login().authorities(new SimpleGrantedAuthority("SCOPE_admin")))) .andExpect(content().string("[\"SCOPE_admin\"]")); } @Test public void oauth2LoginWhenAttributeSpecifiedThenUserHasAttribute() throws Exception { - this.mvc.perform(get("/attributes/iss") - .with(oauth2Login().attributes(a -> a.put("iss", "https://idp.example.org")))) + this.mvc.perform( + get("/attributes/iss").with(oauth2Login().attributes((a) -> a.put("iss", "https://idp.example.org")))) .andExpect(content().string("https://idp.example.org")); } @Test public void oauth2LoginWhenNameSpecifiedThenUserHasName() throws Exception { - OAuth2User oauth2User = new DefaultOAuth2User( - AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), - Collections.singletonMap("custom-attribute", "test-subject"), - "custom-attribute"); - this.mvc.perform(get("/attributes/custom-attribute") - .with(oauth2Login().oauth2User(oauth2User))) + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), + Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"); + this.mvc.perform(get("/attributes/custom-attribute").with(oauth2Login().oauth2User(oauth2User))) .andExpect(content().string("test-subject")); - - this.mvc.perform(get("/name") - .with(oauth2Login().oauth2User(oauth2User))) + this.mvc.perform(get("/name").with(oauth2Login().oauth2User(oauth2User))) .andExpect(content().string("test-subject")); - - this.mvc.perform(get("/client-name") - .with(oauth2Login().oauth2User(oauth2User))) + this.mvc.perform(get("/client-name").with(oauth2Login().oauth2User(oauth2User))) .andExpect(content().string("test-subject")); } @Test public void oauth2LoginWhenClientRegistrationSpecifiedThenUses() throws Exception { this.mvc.perform(get("/client-id") - .with(oauth2Login().clientRegistration(clientRegistration().build()))) + .with(oauth2Login().clientRegistration(TestClientRegistrations.clientRegistration().build()))) .andExpect(content().string("client-id")); } @Test public void oauth2LoginWhenOAuth2UserSpecifiedThenLastCalledTakesPrecedence() throws Exception { - OAuth2User oauth2User = new DefaultOAuth2User( - AuthorityUtils.createAuthorityList("SCOPE_read"), - Collections.singletonMap("username", "user"), - "username"); - + OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("SCOPE_read"), + Collections.singletonMap("username", "user"), "username"); this.mvc.perform(get("/attributes/sub") - .with(oauth2Login() - .attributes(a -> a.put("sub", "bar")) - .oauth2User(oauth2User))) - .andExpect(status().isOk()) - .andExpect(content().string("no-attribute")); + .with(oauth2Login().attributes((a) -> a.put("sub", "bar")).oauth2User(oauth2User))) + .andExpect(status().isOk()).andExpect(content().string("no-attribute")); this.mvc.perform(get("/attributes/sub") - .with(oauth2Login() - .oauth2User(oauth2User) - .attributes(a -> a.put("sub", "bar")))) + .with(oauth2Login().oauth2User(oauth2User).attributes((a) -> a.put("sub", "bar")))) .andExpect(content().string("bar")); } @EnableWebSecurity @EnableWebMvc static class OAuth2LoginConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http - .authorizeRequests(authorize -> authorize + .authorizeRequests((authorize) -> authorize .mvcMatchers("/admin/**").hasAuthority("SCOPE_admin") .anyRequest().hasAuthority("SCOPE_read") ).oauth2Login(); + // @formatter:on } @Bean @@ -186,6 +169,7 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2LoginTests { @RestController static class PrincipalController { + @GetMapping("/name") String name(@AuthenticationPrincipal OAuth2User oauth2User) { return oauth2User.getName(); @@ -202,19 +186,19 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2LoginTests { } @GetMapping("/attributes/{attribute}") - String attributes( - @AuthenticationPrincipal OAuth2User oauth2User, @PathVariable("attribute") String attribute) { - + String attributes(@AuthenticationPrincipal OAuth2User oauth2User, + @PathVariable("attribute") String attribute) { return Optional.ofNullable((String) oauth2User.getAttribute(attribute)).orElse("no-attribute"); } @GetMapping("/admin/scopes") List scopes( @AuthenticationPrincipal(expression = "authorities") Collection authorities) { - - return authorities.stream().map(GrantedAuthority::getAuthority) - .collect(Collectors.toList()); + return authorities.stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList()); } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOidcLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOidcLoginTests.java index 4728d78a4d..91fa711355 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOidcLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOidcLoginTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; import java.util.Collection; @@ -38,6 +39,7 @@ import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2Aut import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.test.context.TestSecurityContextHolder; @@ -53,7 +55,6 @@ import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import static org.mockito.Mockito.mock; -import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -70,6 +71,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @ContextConfiguration @WebAppConfiguration public class SecurityMockMvcRequestPostProcessorsOidcLoginTests { + @Autowired WebApplicationContext context; @@ -91,94 +93,71 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests { } @Test - public void oidcLoginWhenUsingDefaultsThenProducesDefaultAuthentication() - throws Exception { - - this.mvc.perform(get("/name").with(oidcLogin())) - .andExpect(content().string("user")); - this.mvc.perform(get("/admin/id-token/name").with(oidcLogin())) - .andExpect(status().isForbidden()); + public void oidcLoginWhenUsingDefaultsThenProducesDefaultAuthentication() throws Exception { + this.mvc.perform(get("/name").with(oidcLogin())).andExpect(content().string("user")); + this.mvc.perform(get("/admin/id-token/name").with(oidcLogin())).andExpect(status().isForbidden()); } @Test - public void oidcLoginWhenUsingDefaultsThenProducesDefaultAuthorizedClient() - throws Exception { - - this.mvc.perform(get("/access-token").with(oidcLogin())) - .andExpect(content().string("access-token")); + public void oidcLoginWhenUsingDefaultsThenProducesDefaultAuthorizedClient() throws Exception { + this.mvc.perform(get("/access-token").with(oidcLogin())).andExpect(content().string("access-token")); } @Test public void oidcLoginWhenAuthoritiesSpecifiedThenGrantsAccess() throws Exception { - this.mvc.perform(get("/admin/scopes") - .with(oidcLogin().authorities(new SimpleGrantedAuthority("SCOPE_admin")))) + this.mvc.perform(get("/admin/scopes").with(oidcLogin().authorities(new SimpleGrantedAuthority("SCOPE_admin")))) .andExpect(content().string("[\"SCOPE_admin\"]")); } @Test public void oidcLoginWhenIdTokenSpecifiedThenUserHasClaims() throws Exception { - this.mvc.perform(get("/id-token/iss") - .with(oidcLogin().idToken(i -> i.issuer("https://idp.example.org")))) + this.mvc.perform(get("/id-token/iss").with(oidcLogin().idToken((i) -> i.issuer("https://idp.example.org")))) .andExpect(content().string("https://idp.example.org")); } @Test public void oidcLoginWhenUserInfoSpecifiedThenUserHasClaims() throws Exception { - this.mvc.perform(get("/user-info/email") - .with(oidcLogin().userInfoToken(u -> u.email("email@email")))) + this.mvc.perform(get("/user-info/email").with(oidcLogin().userInfoToken((u) -> u.email("email@email")))) .andExpect(content().string("email@email")); } @Test public void oidcLoginWhenNameSpecifiedThenUserHasName() throws Exception { - OidcUser oidcUser = new DefaultOidcUser( - AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), + OidcUser oidcUser = new DefaultOidcUser(AuthorityUtils.commaSeparatedStringToAuthorityList("SCOPE_read"), OidcIdToken.withTokenValue("id-token").claim("custom-attribute", "test-subject").build(), "custom-attribute"); - - this.mvc.perform(get("/id-token/custom-attribute") - .with(oidcLogin().oidcUser(oidcUser))) + this.mvc.perform(get("/id-token/custom-attribute").with(oidcLogin().oidcUser(oidcUser))) .andExpect(content().string("test-subject")); - - this.mvc.perform(get("/name") - .with(oidcLogin().oidcUser(oidcUser))) - .andExpect(content().string("test-subject")); - - this.mvc.perform(get("/client-name") - .with(oidcLogin().oidcUser(oidcUser))) + this.mvc.perform(get("/name").with(oidcLogin().oidcUser(oidcUser))).andExpect(content().string("test-subject")); + this.mvc.perform(get("/client-name").with(oidcLogin().oidcUser(oidcUser))) .andExpect(content().string("test-subject")); } // gh-7794 @Test public void oidcLoginWhenOidcUserSpecifiedThenLastCalledTakesPrecedence() throws Exception { - OidcUser oidcUser = new DefaultOidcUser( - AuthorityUtils.createAuthorityList("SCOPE_read"), idToken().build()); - - this.mvc.perform(get("/id-token/sub") - .with(oidcLogin() - .idToken(i -> i.subject("foo")) - .oidcUser(oidcUser))) - .andExpect(status().isOk()) - .andExpect(content().string("subject")); - this.mvc.perform(get("/id-token/sub") - .with(oidcLogin() - .oidcUser(oidcUser) - .idToken(i -> i.subject("bar")))) + OidcUser oidcUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("SCOPE_read"), + TestOidcIdTokens.idToken().build()); + this.mvc.perform(get("/id-token/sub").with(oidcLogin().idToken((i) -> i.subject("foo")).oidcUser(oidcUser))) + .andExpect(status().isOk()).andExpect(content().string("subject")); + this.mvc.perform(get("/id-token/sub").with(oidcLogin().oidcUser(oidcUser).idToken((i) -> i.subject("bar")))) .andExpect(content().string("bar")); } @EnableWebSecurity @EnableWebMvc static class OAuth2LoginConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .mvcMatchers("/admin/**").hasAuthority("SCOPE_admin") .anyRequest().hasAuthority("SCOPE_read") .and() .oauth2Login(); + // @formatter:on } @Bean @@ -186,7 +165,6 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests { return mock(ClientRegistrationRepository.class); } - @Bean OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository() { return mock(OAuth2AuthorizedClientRepository.class); @@ -194,6 +172,7 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests { @RestController static class PrincipalController { + @GetMapping("/name") String name(@AuthenticationPrincipal OidcUser oidcUser) { return oidcUser.getName(); @@ -220,11 +199,13 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests { } @GetMapping("/admin/scopes") - List scopes(@AuthenticationPrincipal(expression = "authorities") - Collection authorities) { - return authorities.stream().map(GrantedAuthority::getAuthority) - .collect(Collectors.toList()); + List scopes( + @AuthenticationPrincipal(expression = "authorities") Collection authorities) { + return authorities.stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList()); } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOpaqueTokenTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOpaqueTokenTests.java index 041a4d2e0f..764f51ec1c 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOpaqueTokenTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOpaqueTokenTests.java @@ -33,6 +33,7 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals; import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -45,9 +46,8 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.oauth2.core.TestOAuth2AuthenticatedPrincipals.active; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.opaqueToken; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -64,6 +64,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @ContextConfiguration @WebAppConfiguration public class SecurityMockMvcRequestPostProcessorsOpaqueTokenTests { + @Autowired WebApplicationContext context; @@ -80,19 +81,15 @@ public class SecurityMockMvcRequestPostProcessorsOpaqueTokenTests { } @Test - public void opaqueTokenWhenUsingDefaultsThenProducesDefaultAuthentication() - throws Exception { - - this.mvc.perform(get("/name").with(opaqueToken())) - .andExpect(content().string("user")); - this.mvc.perform(get("/admin/scopes").with(opaqueToken())) - .andExpect(status().isForbidden()); + public void opaqueTokenWhenUsingDefaultsThenProducesDefaultAuthentication() throws Exception { + this.mvc.perform(get("/name").with(opaqueToken())).andExpect(content().string("user")); + this.mvc.perform(get("/admin/scopes").with(opaqueToken())).andExpect(status().isForbidden()); } @Test public void opaqueTokenWhenAttributeSpecifiedThenUserHasAttribute() throws Exception { - this.mvc.perform(get("/opaque-token/iss") - .with(opaqueToken().attributes(a -> a.put("iss", "https://idp.example.org")))) + this.mvc.perform( + get("/opaque-token/iss").with(opaqueToken().attributes((a) -> a.put("iss", "https://idp.example.org")))) .andExpect(content().string("https://idp.example.org")); } @@ -100,36 +97,31 @@ public class SecurityMockMvcRequestPostProcessorsOpaqueTokenTests { public void opaqueTokenWhenPrincipalSpecifiedThenAuthenticationHasPrincipal() throws Exception { Collection authorities = Collections.singleton(new SimpleGrantedAuthority("SCOPE_read")); OAuth2AuthenticatedPrincipal principal = mock(OAuth2AuthenticatedPrincipal.class); - when(principal.getName()).thenReturn("ben"); - when(principal.getAuthorities()).thenReturn(authorities); - - this.mvc.perform(get("/name").with(opaqueToken().principal(principal))) - .andExpect(content().string("ben")); + given(principal.getName()).willReturn("ben"); + given(principal.getAuthorities()).willReturn(authorities); + this.mvc.perform(get("/name").with(opaqueToken().principal(principal))).andExpect(content().string("ben")); } // gh-7800 @Test public void opaqueTokenWhenPrincipalSpecifiedThenLastCalledTakesPrecedence() throws Exception { - OAuth2AuthenticatedPrincipal principal = active(a -> a.put("scope", "user")); - + OAuth2AuthenticatedPrincipal principal = TestOAuth2AuthenticatedPrincipals + .active((a) -> a.put("scope", "user")); this.mvc.perform(get("/opaque-token/sub") - .with(opaqueToken() - .attributes(a -> a.put("sub", "foo")) - .principal(principal))) - .andExpect(status().isOk()) - .andExpect(content().string((String) principal.getAttribute("sub"))); + .with(opaqueToken().attributes((a) -> a.put("sub", "foo")).principal(principal))) + .andExpect(status().isOk()).andExpect(content().string((String) principal.getAttribute("sub"))); this.mvc.perform(get("/opaque-token/sub") - .with(opaqueToken() - .principal(principal) - .attributes(a -> a.put("sub", "bar")))) + .with(opaqueToken().principal(principal).attributes((a) -> a.put("sub", "bar")))) .andExpect(content().string("bar")); } @EnableWebSecurity @EnableWebMvc static class OAuth2LoginConfig extends WebSecurityConfigurerAdapter { + @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .mvcMatchers("/admin/**").hasAuthority("SCOPE_admin") @@ -138,10 +130,12 @@ public class SecurityMockMvcRequestPostProcessorsOpaqueTokenTests { .oauth2ResourceServer() .opaqueToken() .introspector(mock(OpaqueTokenIntrospector.class)); + // @formatter:on } @RestController static class PrincipalController { + @GetMapping("/name") String name(@AuthenticationPrincipal OAuth2AuthenticatedPrincipal principal) { return principal.getName(); @@ -150,17 +144,17 @@ public class SecurityMockMvcRequestPostProcessorsOpaqueTokenTests { @GetMapping("/opaque-token/{attribute}") String tokenAttribute(@AuthenticationPrincipal OAuth2AuthenticatedPrincipal principal, @PathVariable("attribute") String attribute) { - return principal.getAttribute(attribute); } @GetMapping("/admin/scopes") - List scopes(@AuthenticationPrincipal(expression = "authorities") - Collection authorities) { - - return authorities.stream().map(GrantedAuthority::getAuthority) - .collect(Collectors.toList()); + List scopes( + @AuthenticationPrincipal(expression = "authorities") Collection authorities) { + return authorities.stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList()); } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsSecurityContextTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsSecurityContextTests.java index 2ba412c1a7..6f2521f37c 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsSecurityContextTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsSecurityContextTests.java @@ -13,15 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.request; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.securityContext; +package org.springframework.security.test.web.servlet.request; import javax.servlet.http.HttpServletResponse; @@ -32,21 +25,32 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.web.context.SecurityContextRepository; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.securityContext; + @RunWith(PowerMockRunner.class) @PrepareOnlyThisForTest(WebTestUtils.class) -@PowerMockIgnore({"javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*"}) +@PowerMockIgnore({ "javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", + "javax.xml.parsers.*" }) public class SecurityMockMvcRequestPostProcessorsSecurityContextTests { + @Captor private ArgumentCaptor contextCaptor; + @Mock private SecurityContextRepository repository; @@ -57,7 +61,7 @@ public class SecurityMockMvcRequestPostProcessorsSecurityContextTests { @Before public void setup() { - request = new MockHttpServletRequest(); + this.request = new MockHttpServletRequest(); mockWebTestUtils(); } @@ -68,16 +72,16 @@ public class SecurityMockMvcRequestPostProcessorsSecurityContextTests { @Test public void userDetails() { - securityContext(expectedContext).postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + securityContext(this.expectedContext).postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); + SecurityContext context = this.contextCaptor.getValue(); assertThat(context).isSameAs(this.expectedContext); } private void mockWebTestUtils() { - spy(WebTestUtils.class); - when(WebTestUtils.getSecurityContextRepository(request)).thenReturn(repository); + PowerMockito.spy(WebTestUtils.class); + PowerMockito.when(WebTestUtils.getSecurityContextRepository(this.request)).thenReturn(this.repository); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextStatelessTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextStatelessTests.java index a82709e4ad..2ed0a2ffc7 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextStatelessTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextStatelessTests.java @@ -13,16 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.request; +import javax.servlet.Filter; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.test.context.ContextConfiguration; @@ -35,10 +39,7 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import javax.servlet.Filter; - import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext; - import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -57,46 +58,45 @@ public class SecurityMockMvcRequestPostProcessorsTestSecurityContextStatelessTes @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .addFilters(springSecurityFilterChain) + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).addFilters(this.springSecurityFilterChain) .defaultRequest(get("/").with(testSecurityContext())).build(); } @Test @WithMockUser public void testSecurityContextWithMockUserWorksWithStateless() throws Exception { - mvc.perform(get("/")).andExpect(status().is2xxSuccessful()); + this.mvc.perform(get("/")).andExpect(status().is2xxSuccessful()); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { super.configure(http); - + // @formatter:off http .sessionManagement() .sessionCreationPolicy(SessionCreationPolicy.STATELESS); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { - auth - .inMemoryAuthentication(); + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + auth.inMemoryAuthentication(); } - // @formatter:on @RestController static class Controller { + @RequestMapping - public String hello() { + String hello() { return "Hello"; } + } + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextTests.java index 65ed0f969b..0eb7913127 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsTestSecurityContextTests.java @@ -13,14 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.request; -import static org.powermock.api.mockito.PowerMockito.*; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.any; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext; +package org.springframework.security.test.web.servlet.request; import javax.servlet.http.HttpServletResponse; @@ -29,21 +23,32 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.web.context.SecurityContextRepository; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext; + @RunWith(PowerMockRunner.class) @PrepareOnlyThisForTest(WebTestUtils.class) -@PowerMockIgnore({"javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*"}) +@PowerMockIgnore({ "javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", + "javax.xml.parsers.*" }) public class SecurityMockMvcRequestPostProcessorsTestSecurityContextTests { + @Mock private SecurityContext context; + @Mock private SecurityContextRepository repository; @@ -51,7 +56,7 @@ public class SecurityMockMvcRequestPostProcessorsTestSecurityContextTests { @Before public void setup() { - request = new MockHttpServletRequest(); + this.request = new MockHttpServletRequest(); mockWebTestUtils(); } @@ -62,25 +67,22 @@ public class SecurityMockMvcRequestPostProcessorsTestSecurityContextTests { @Test public void testSecurityContextSaves() { - TestSecurityContextHolder.setContext(context); - - testSecurityContext().postProcessRequest(request); - - verify(repository).saveContext(eq(context), eq(request), - any(HttpServletResponse.class)); + TestSecurityContextHolder.setContext(this.context); + testSecurityContext().postProcessRequest(this.request); + verify(this.repository).saveContext(eq(this.context), eq(this.request), any(HttpServletResponse.class)); } // Ensure it does not fail if TestSecurityContextHolder is not initialized @Test public void testSecurityContextNoContext() { - testSecurityContext().postProcessRequest(request); - - verify(repository, never()).saveContext(any(SecurityContext.class), eq(request), + testSecurityContext().postProcessRequest(this.request); + verify(this.repository, never()).saveContext(any(SecurityContext.class), eq(this.request), any(HttpServletResponse.class)); } private void mockWebTestUtils() { - spy(WebTestUtils.class); - when(WebTestUtils.getSecurityContextRepository(request)).thenReturn(repository); + PowerMockito.spy(WebTestUtils.class); + PowerMockito.when(WebTestUtils.getSecurityContextRepository(this.request)).thenReturn(this.repository); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserDetailsTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserDetailsTests.java index c2e8136e52..0f05f727c9 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserDetailsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserDetailsTests.java @@ -13,15 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.request; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +package org.springframework.security.test.web.servlet.request; import javax.servlet.http.HttpServletResponse; @@ -32,9 +25,11 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.context.SecurityContext; @@ -43,12 +38,21 @@ import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.web.context.SecurityContextRepository; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; + @RunWith(PowerMockRunner.class) @PrepareOnlyThisForTest(WebTestUtils.class) -@PowerMockIgnore({"javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*"}) +@PowerMockIgnore({ "javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", + "javax.xml.parsers.*" }) public class SecurityMockMvcRequestPostProcessorsUserDetailsTests { + @Captor private ArgumentCaptor contextCaptor; + @Mock private SecurityContextRepository repository; @@ -59,7 +63,7 @@ public class SecurityMockMvcRequestPostProcessorsUserDetailsTests { @Before public void setup() { - request = new MockHttpServletRequest(); + this.request = new MockHttpServletRequest(); mockWebTestUtils(); } @@ -70,18 +74,17 @@ public class SecurityMockMvcRequestPostProcessorsUserDetailsTests { @Test public void userDetails() { - user(userDetails).postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + user(this.userDetails).postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); - assertThat(context.getAuthentication().getPrincipal()).isSameAs(userDetails); + SecurityContext context = this.contextCaptor.getValue(); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); + assertThat(context.getAuthentication().getPrincipal()).isSameAs(this.userDetails); } private void mockWebTestUtils() { - spy(WebTestUtils.class); - when(WebTestUtils.getSecurityContextRepository(request)).thenReturn(repository); + PowerMockito.spy(WebTestUtils.class); + PowerMockito.when(WebTestUtils.getSecurityContextRepository(this.request)).thenReturn(this.repository); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserTests.java index bc9e2f088a..1c6fb34678 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsUserTests.java @@ -13,15 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.request; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.powermock.api.mockito.PowerMockito.spy; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +package org.springframework.security.test.web.servlet.request; import java.util.Arrays; import java.util.List; @@ -35,9 +28,11 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; +import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.GrantedAuthority; @@ -46,12 +41,21 @@ import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.web.support.WebTestUtils; import org.springframework.security.web.context.SecurityContextRepository; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; + @RunWith(PowerMockRunner.class) @PrepareOnlyThisForTest(WebTestUtils.class) -@PowerMockIgnore({"javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*"}) +@PowerMockIgnore({ "javax.security.auth.*", "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", + "javax.xml.parsers.*" }) public class SecurityMockMvcRequestPostProcessorsUserTests { + @Captor private ArgumentCaptor contextCaptor; + @Mock private SecurityContextRepository repository; @@ -59,12 +63,13 @@ public class SecurityMockMvcRequestPostProcessorsUserTests { @Mock private GrantedAuthority authority1; + @Mock private GrantedAuthority authority2; @Before public void setup() { - request = new MockHttpServletRequest(); + this.request = new MockHttpServletRequest(); mockWebTestUtils(); } @@ -76,72 +81,60 @@ public class SecurityMockMvcRequestPostProcessorsUserTests { @Test public void userWithDefaults() { String username = "userabc"; - - user(username).postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + user(username).postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); + SecurityContext context = this.contextCaptor.getValue(); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); assertThat(context.getAuthentication().getName()).isEqualTo(username); assertThat(context.getAuthentication().getCredentials()).isEqualTo("password"); - assertThat(context.getAuthentication().getAuthorities()).extracting("authority") - .containsOnly("ROLE_USER"); + assertThat(context.getAuthentication().getAuthorities()).extracting("authority").containsOnly("ROLE_USER"); } @Test public void userWithCustom() { String username = "customuser"; - - user(username).roles("CUSTOM", "ADMIN").password("newpass") - .postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + user(username).roles("CUSTOM", "ADMIN").password("newpass").postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); - assertThat(context.getAuthentication()).isInstanceOf( - UsernamePasswordAuthenticationToken.class); + SecurityContext context = this.contextCaptor.getValue(); + assertThat(context.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class); assertThat(context.getAuthentication().getName()).isEqualTo(username); assertThat(context.getAuthentication().getCredentials()).isEqualTo("newpass"); - assertThat(context.getAuthentication().getAuthorities()).extracting("authority") - .containsOnly("ROLE_CUSTOM", "ROLE_ADMIN"); + assertThat(context.getAuthentication().getAuthorities()).extracting("authority").containsOnly("ROLE_CUSTOM", + "ROLE_ADMIN"); } @Test public void userCustomAuthoritiesVarargs() { String username = "customuser"; - - user(username).authorities(authority1, authority2).postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + user(username).authorities(this.authority1, this.authority2).postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(authority1, authority2); + SecurityContext context = this.contextCaptor.getValue(); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1, + this.authority2); } @Test(expected = IllegalArgumentException.class) public void userRolesWithRolePrefixErrors() { - user("user").roles("ROLE_INVALID").postProcessRequest(request); + user("user").roles("ROLE_INVALID").postProcessRequest(this.request); } @Test public void userCustomAuthoritiesList() { String username = "customuser"; - - user(username).authorities(Arrays.asList(authority1, authority2)) - .postProcessRequest(request); - - verify(repository).saveContext(contextCaptor.capture(), eq(request), + user(username).authorities(Arrays.asList(this.authority1, this.authority2)).postProcessRequest(this.request); + verify(this.repository).saveContext(this.contextCaptor.capture(), eq(this.request), any(HttpServletResponse.class)); - SecurityContext context = contextCaptor.getValue(); - assertThat((List) context.getAuthentication().getAuthorities()) - .containsOnly(authority1, authority2); + SecurityContext context = this.contextCaptor.getValue(); + assertThat((List) context.getAuthentication().getAuthorities()).containsOnly(this.authority1, + this.authority2); } private void mockWebTestUtils() { - spy(WebTestUtils.class); - when(WebTestUtils.getSecurityContextRepository(request)).thenReturn(repository); + PowerMockito.spy(WebTestUtils.class); + PowerMockito.when(WebTestUtils.getSecurityContextRepository(this.request)).thenReturn(this.repository); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/response/Gh3409Tests.java b/test/src/test/java/org/springframework/security/test/web/servlet/response/Gh3409Tests.java index d879c54691..a4c3f8869c 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/response/Gh3409Tests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/response/Gh3409Tests.java @@ -69,7 +69,6 @@ public class Gh3409Tests { this.mockMvc .perform(get("/public/") .with(securityContext(new SecurityContextImpl()))); - this.mockMvc .perform(get("/public/")) .andExpect(unauthenticated()); @@ -82,7 +81,6 @@ public class Gh3409Tests { this.mockMvc .perform(get("/") .with(securityContext(new SecurityContextImpl()))); - this.mockMvc .perform(get("/")) .andExpect(unauthenticated()); @@ -104,7 +102,8 @@ public class Gh3409Tests { .formLogin().and() .httpBasic(); // @formatter:on - } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchersTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchersTests.java index 9a63cfe4fa..3441b8cbe8 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchersTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.response; import org.junit.Before; @@ -47,6 +48,7 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv @ContextConfiguration(classes = SecurityMockMvcResultMatchersTests.Config.class) @WebAppConfiguration public class SecurityMockMvcResultMatchersTests { + @Autowired private WebApplicationContext context; @@ -64,16 +66,14 @@ public class SecurityMockMvcResultMatchersTests { @Test public void withAuthenticationWhenMatchesThenSuccess() throws Exception { - this.mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthentication(auth -> - assertThat(auth).isInstanceOf(UsernamePasswordAuthenticationToken.class))); + this.mockMvc.perform(formLogin()).andExpect(authenticated().withAuthentication( + (auth) -> assertThat(auth).isInstanceOf(UsernamePasswordAuthenticationToken.class))); } @Test(expected = AssertionError.class) public void withAuthenticationWhenNotMatchesThenFails() throws Exception { - this.mockMvc - .perform(formLogin()) - .andExpect(authenticated().withAuthentication(auth -> assertThat(auth.getName()).isEqualTo("notmatch"))); + this.mockMvc.perform(formLogin()).andExpect( + authenticated().withAuthentication((auth) -> assertThat(auth.getName()).isEqualTo("notmatch"))); } // SEC-2719 @@ -100,20 +100,25 @@ public class SecurityMockMvcResultMatchersTests { @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off + @Override @Bean public UserDetailsService userDetailsService() { + // @formatter:off UserDetails user = User.withDefaultPasswordEncoder().username("user").password("password").roles("USER", "SELLER").build(); + // @formatter:on return new InMemoryUserDetailsManager(user); } - // @formatter:on @RestController static class Controller { + @RequestMapping("/") - public String ok() { + String ok() { return "ok"; } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockWithAuthoritiesMvcResultMatchersTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockWithAuthoritiesMvcResultMatchersTests.java index 5f4b7af1f9..82335524fb 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockWithAuthoritiesMvcResultMatchersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/response/SecurityMockWithAuthoritiesMvcResultMatchersTests.java @@ -13,11 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.response; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +package org.springframework.security.test.web.servlet.response; import java.util.ArrayList; import java.util.List; @@ -45,10 +42,15 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = SecurityMockWithAuthoritiesMvcResultMatchersTests.Config.class) @WebAppConfiguration public class SecurityMockWithAuthoritiesMvcResultMatchersTests { + @Autowired private WebApplicationContext context; @@ -56,8 +58,7 @@ public class SecurityMockWithAuthoritiesMvcResultMatchersTests { @Before public void setup() { - mockMvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()) - .build(); + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test @@ -65,35 +66,39 @@ public class SecurityMockWithAuthoritiesMvcResultMatchersTests { List grantedAuthorities = new ArrayList<>(); grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_ADMIN")); grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_SELLER")); - mockMvc.perform(formLogin()) - .andExpect(authenticated().withAuthorities(grantedAuthorities)); + this.mockMvc.perform(formLogin()).andExpect(authenticated().withAuthorities(grantedAuthorities)); } @Test(expected = AssertionError.class) public void withAuthoritiesFailsIfNotAllRoles() throws Exception { List grantedAuthorities = new ArrayList<>(); grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_ADMIN")); - mockMvc.perform(formLogin()).andExpect(authenticated().withAuthorities(grantedAuthorities)); + this.mockMvc.perform(formLogin()).andExpect(authenticated().withAuthorities(grantedAuthorities)); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off + @Override @Bean public UserDetailsService userDetailsService() { + // @formatter:off UserDetails user = User.withDefaultPasswordEncoder().username("user").password("password").roles("ADMIN", "SELLER").build(); return new InMemoryUserDetailsManager(user); + // @formatter:on } - // @formatter:on @RestController static class Controller { + @RequestMapping("/") - public String ok() { + String ok() { return "ok"; } + } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java index 9bb74d7966..675a57dba2 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java @@ -13,97 +13,96 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.setup; +import javax.servlet.Filter; +import javax.servlet.ServletContext; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.security.config.BeanIds; import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder; import org.springframework.web.context.WebApplicationContext; -import javax.servlet.Filter; -import javax.servlet.ServletContext; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class SecurityMockMvcConfigurerTests { + @Mock private Filter filter; + @Mock private Filter beanFilter; + @Mock private ConfigurableMockMvcBuilder builder; + @Mock private WebApplicationContext context; + @Mock private ServletContext servletContext; @Before public void setup() { - when(this.context.getServletContext()).thenReturn(this.servletContext); + given(this.context.getServletContext()).willReturn(this.servletContext); } @Test public void beforeMockMvcCreatedOverrideBean() throws Exception { returnFilterBean(); SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(this.filter); - configurer.afterConfigurerAdded(this.builder); configurer.beforeMockMvcCreated(this.builder, this.context); - assertFilterAdded(this.filter); - verify(this.servletContext).setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, - this.filter); + verify(this.servletContext).setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, this.filter); } @Test public void beforeMockMvcCreatedBean() throws Exception { returnFilterBean(); SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(); - configurer.afterConfigurerAdded(this.builder); configurer.beforeMockMvcCreated(this.builder, this.context); - assertFilterAdded(this.beanFilter); } @Test public void beforeMockMvcCreatedNoBean() throws Exception { SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(this.filter); - configurer.afterConfigurerAdded(this.builder); configurer.beforeMockMvcCreated(this.builder, this.context); - assertFilterAdded(this.filter); } @Test(expected = IllegalStateException.class) public void beforeMockMvcCreatedNoFilter() { SecurityMockMvcConfigurer configurer = new SecurityMockMvcConfigurer(); - configurer.afterConfigurerAdded(this.builder); configurer.beforeMockMvcCreated(this.builder, this.context); } private void assertFilterAdded(Filter filter) { - ArgumentCaptor filterArg = ArgumentCaptor.forClass( - SecurityMockMvcConfigurer.DelegateFilter.class); + ArgumentCaptor filterArg = ArgumentCaptor + .forClass(SecurityMockMvcConfigurer.DelegateFilter.class); verify(this.builder).addFilters(filterArg.capture()); assertThat(filterArg.getValue().getDelegate()).isEqualTo(filter); } private void returnFilterBean() { - when(this.context.containsBean(anyString())).thenReturn(true); - when(this.context.getBean(anyString(), eq(Filter.class))) - .thenReturn(this.beanFilter); + given(this.context.containsBean(anyString())).willReturn(true); + given(this.context.getBean(anyString(), eq(Filter.class))).willReturn(this.beanFilter); } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurersTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurersTests.java index 162114d6ee..089e5dc8b1 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurersTests.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.setup; +import javax.servlet.Filter; + import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; @@ -29,8 +33,6 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import javax.servlet.Filter; - import static org.mockito.Mockito.mock; import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -42,46 +44,43 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @RunWith(SpringRunner.class) @WebAppConfiguration public class SecurityMockMvcConfigurersTests { + @Autowired WebApplicationContext wac; Filter noOpFilter = mock(Filter.class); /** - * Since noOpFilter is first does not continue the chain, security will not be invoked and the status should be OK - * + * Since noOpFilter is first does not continue the chain, security will not be invoked + * and the status should be OK * @throws Exception */ @Test public void applySpringSecurityWhenAddFilterFirstThenFilterFirst() throws Exception { - MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.wac) - .addFilters(this.noOpFilter) - .apply(springSecurity()) - .build(); - - mockMvc.perform(get("/")) - .andExpect(status().isOk()); + MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).addFilters(this.noOpFilter) + .apply(springSecurity()).build(); + mockMvc.perform(get("/")).andExpect(status().isOk()); } /** - * Since noOpFilter is second security will be invoked and the status will be not OK. We know this because if noOpFilter - * were first security would not be invoked sincet noOpFilter does not continue the FilterChain + * Since noOpFilter is second security will be invoked and the status will be not OK. + * We know this because if noOpFilter were first security would not be invoked sincet + * noOpFilter does not continue the FilterChain * @throws Exception */ @Test public void applySpringSecurityWhenAddFilterSecondThenSecurityFirst() throws Exception { - MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.wac) - .apply(springSecurity()) - .addFilters(this.noOpFilter) - .build(); - - mockMvc.perform(get("/")) - .andExpect(status().is4xxClientError()); + MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).apply(springSecurity()) + .addFilters(this.noOpFilter).build(); + mockMvc.perform(get("/")).andExpect(status().is4xxClientError()); } @Configuration @EnableWebMvc @EnableWebSecurity @Import(AuthenticationTestConfiguration.class) - static class Config {} + static class Config { + + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CsrfShowcaseTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CsrfShowcaseTests.java index bf5f71edfe..92a9e392ae 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CsrfShowcaseTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CsrfShowcaseTests.java @@ -13,17 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.csrf; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.test.web.servlet.showcase.csrf; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -37,6 +33,12 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = CsrfShowcaseTests.Config.class) @WebAppConfiguration @@ -49,22 +51,22 @@ public class CsrfShowcaseTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test public void postWithCsrfWorks() throws Exception { - mvc.perform(post("/").with(csrf())).andExpect(status().isNotFound()); + this.mvc.perform(post("/").with(csrf())).andExpect(status().isNotFound()); } @Test public void postWithCsrfWorksWithPut() throws Exception { - mvc.perform(put("/").with(csrf())).andExpect(status().isNotFound()); + this.mvc.perform(put("/").with(csrf())).andExpect(status().isNotFound()); } @Test public void postWithNoCsrfForbidden() throws Exception { - mvc.perform(post("/")).andExpect(status().isForbidden()); + this.mvc.perform(post("/")).andExpect(status().isForbidden()); } @EnableWebSecurity @@ -75,13 +77,15 @@ public class CsrfShowcaseTests { protected void configure(HttpSecurity http) { } - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CustomCsrfShowcaseTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CustomCsrfShowcaseTests.java index dfcdf91bfc..c0382f7eaa 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CustomCsrfShowcaseTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/CustomCsrfShowcaseTests.java @@ -13,18 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.csrf; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.test.web.servlet.showcase.csrf; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; @@ -41,6 +36,13 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = CustomCsrfShowcaseTests.Config.class) @WebAppConfiguration @@ -56,47 +58,49 @@ public class CustomCsrfShowcaseTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .defaultRequest(get("/").with(csrf())).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).defaultRequest(get("/").with(csrf())) + .apply(springSecurity()).build(); } @Test public void postWithCsrfWorks() throws Exception { - mvc.perform(post("/").with(csrf())).andExpect(status().isNotFound()); + this.mvc.perform(post("/").with(csrf())).andExpect(status().isNotFound()); } @Test public void postWithCsrfWorksWithPut() throws Exception { - mvc.perform(put("/").with(csrf())).andExpect(status().isNotFound()); + this.mvc.perform(put("/").with(csrf())).andExpect(status().isNotFound()); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .csrf() .csrfTokenRepository(repo()); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on @Bean - public CsrfTokenRepository repo() { + CsrfTokenRepository repo() { HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository(); repo.setParameterName("custom_csrf"); return repo; } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/DefaultCsrfShowcaseTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/DefaultCsrfShowcaseTests.java index 4554ca660d..136852f546 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/DefaultCsrfShowcaseTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/csrf/DefaultCsrfShowcaseTests.java @@ -13,21 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.csrf; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +package org.springframework.security.test.web.servlet.showcase.csrf; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.web.WebAppConfiguration; @@ -36,6 +33,13 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = DefaultCsrfShowcaseTests.Config.class) @WebAppConfiguration @@ -48,18 +52,18 @@ public class DefaultCsrfShowcaseTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .defaultRequest(get("/").with(csrf())).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).defaultRequest(get("/").with(csrf())) + .apply(springSecurity()).build(); } @Test public void postWithCsrfWorks() throws Exception { - mvc.perform(post("/")).andExpect(status().isNotFound()); + this.mvc.perform(post("/")).andExpect(status().isNotFound()); } @Test public void postWithCsrfWorksWithPut() throws Exception { - mvc.perform(put("/")).andExpect(status().isNotFound()); + this.mvc.perform(put("/")).andExpect(status().isNotFound()); } @EnableWebSecurity @@ -70,13 +74,15 @@ public class DefaultCsrfShowcaseTests { protected void configure(HttpSecurity http) { } - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/AuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/AuthenticationTests.java index 660872ea58..4a3d138737 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/AuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/AuthenticationTests.java @@ -13,23 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.login; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.*; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.*; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*; +package org.springframework.security.test.web.servlet.showcase.login; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.http.MediaType; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; @@ -42,6 +37,15 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = AuthenticationTests.Config.class) @WebAppConfiguration @@ -54,48 +58,46 @@ public class AuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .apply(springSecurity()) - .defaultRequest(get("/").accept(MediaType.TEXT_HTML)) - .build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()) + .defaultRequest(get("/").accept(MediaType.TEXT_HTML)).build(); } @Test public void requiresAuthentication() throws Exception { - mvc.perform(get("/")).andExpect(status().isFound()); + this.mvc.perform(get("/")).andExpect(status().isFound()); } @Test public void httpBasicAuthenticationSuccess() throws Exception { - mvc.perform(get("/secured/butnotfound").with(httpBasic("user", "password"))) - .andExpect(status().isNotFound()) - .andExpect(authenticated().withUsername("user")); + this.mvc.perform(get("/secured/butnotfound").with(httpBasic("user", "password"))) + .andExpect(status().isNotFound()).andExpect(authenticated().withUsername("user")); } @Test public void authenticationSuccess() throws Exception { - mvc.perform(formLogin()).andExpect(status().isFound()) - .andExpect(redirectedUrl("/")) + this.mvc.perform(formLogin()).andExpect(status().isFound()).andExpect(redirectedUrl("/")) .andExpect(authenticated().withUsername("user")); } @Test public void authenticationFailed() throws Exception { - mvc.perform(formLogin().user("user").password("invalid")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/login?error")) - .andExpect(unauthenticated()); + this.mvc.perform(formLogin().user("user").password("invalid")).andExpect(status().isFound()) + .andExpect(redirectedUrl("/login?error")).andExpect(unauthenticated()); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off + + @Override @Bean public UserDetailsService userDetailsService() { + // @formatter:off UserDetails user = User.withDefaultPasswordEncoder().username("user").password("password").roles("USER").build(); return new InMemoryUserDetailsManager(user); + // @formatter:on } - // @formatter:on + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomConfigAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomConfigAuthenticationTests.java index c2a45bc74e..d58e832673 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomConfigAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomConfigAuthenticationTests.java @@ -13,23 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.login; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.*; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.*; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +package org.springframework.security.test.web.servlet.showcase.login; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; @@ -44,6 +39,15 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = CustomConfigAuthenticationTests.Config.class) @WebAppConfiguration @@ -59,31 +63,26 @@ public class CustomConfigAuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test public void authenticationSuccess() throws Exception { - mvc.perform( - formLogin("/authenticate").user("user", "user").password("pass", - "password")).andExpect(status().isFound()) - .andExpect(redirectedUrl("/")) + this.mvc.perform(formLogin("/authenticate").user("user", "user").password("pass", "password")) + .andExpect(status().isFound()).andExpect(redirectedUrl("/")) .andExpect(authenticated().withUsername("user")); } @Test public void withUserSuccess() throws Exception { - mvc.perform(get("/").with(user("user"))) - .andExpect(status().isNotFound()) + this.mvc.perform(get("/").with(user("user"))).andExpect(status().isNotFound()) .andExpect(authenticated().withUsername("user")); } @Test public void authenticationFailed() throws Exception { - mvc.perform( - formLogin("/authenticate").user("user", "notfound").password("pass", - "invalid")).andExpect(status().isFound()) - .andExpect(redirectedUrl("/authenticate?error")) + this.mvc.perform(formLogin("/authenticate").user("user", "notfound").password("pass", "invalid")) + .andExpect(status().isFound()).andExpect(redirectedUrl("/authenticate?error")) .andExpect(unauthenticated()); } @@ -91,9 +90,9 @@ public class CustomConfigAuthenticationTests { @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -105,22 +104,24 @@ public class CustomConfigAuthenticationTests { .usernameParameter("user") .passwordParameter("pass") .loginPage("/authenticate"); + // @formatter:on } - // @formatter:on // @formatter:off + @Override @Bean public UserDetailsService userDetailsService() { UserDetails user = User.withDefaultPasswordEncoder().username("user").password("password").roles("USER").build(); return new InMemoryUserDetailsManager(user); } // @formatter:on - @Bean - public SecurityContextRepository securityContextRepository() { + SecurityContextRepository securityContextRepository() { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); repo.setSpringSecurityContextKey("CUSTOM"); return repo; } + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomLoginRequestBuilderAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomLoginRequestBuilderAuthenticationTests.java index 820f66d49a..a52b469534 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomLoginRequestBuilderAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/login/CustomLoginRequestBuilderAuthenticationTests.java @@ -13,25 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.login; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.*; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +package org.springframework.security.test.web.servlet.showcase.login; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; -import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders; import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.FormLoginRequestBuilder; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -41,6 +38,13 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = CustomLoginRequestBuilderAuthenticationTests.Config.class) @WebAppConfiguration @@ -53,37 +57,32 @@ public class CustomLoginRequestBuilderAuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test public void authenticationSuccess() throws Exception { - mvc.perform(login()) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/")) + this.mvc.perform(login()).andExpect(status().isFound()).andExpect(redirectedUrl("/")) .andExpect(authenticated().withUsername("user")); } @Test public void authenticationFailed() throws Exception { - mvc.perform(login().user("notfound").password("invalid")) - .andExpect(status().isFound()) - .andExpect(redirectedUrl("/authenticate?error")) - .andExpect(unauthenticated()); + this.mvc.perform(login().user("notfound").password("invalid")).andExpect(status().isFound()) + .andExpect(redirectedUrl("/authenticate?error")).andExpect(unauthenticated()); } static FormLoginRequestBuilder login() { - return SecurityMockMvcRequestBuilders.formLogin("/authenticate") - .userParameter("user").passwordParam("pass"); + return formLogin("/authenticate").userParameter("user").passwordParam("pass"); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .anyRequest().authenticated() @@ -92,15 +91,18 @@ public class CustomLoginRequestBuilderAuthenticationTests { .usernameParameter("user") .passwordParameter("pass") .loginPage("/authenticate"); + // @formatter:on } - // @formatter:on // @formatter:off + @Override @Bean public UserDetailsService userDetailsService() { UserDetails user = User.withDefaultPasswordEncoder().username("user").password("password").roles("USER").build(); return new InMemoryUserDetailsManager(user); } // @formatter:on + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/DefaultfSecurityRequestsTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/DefaultfSecurityRequestsTests.java index ea580da186..591042179d 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/DefaultfSecurityRequestsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/DefaultfSecurityRequestsTests.java @@ -13,17 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.secured; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.*; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.test.web.servlet.showcase.secured; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -37,6 +33,14 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.anonymous; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = DefaultfSecurityRequestsTests.Config.class) @WebAppConfiguration @@ -49,15 +53,14 @@ public class DefaultfSecurityRequestsTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .defaultRequest(get("/").with(user("user").roles("ADMIN"))) - .apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context) + .defaultRequest(get("/").with(user("user").roles("ADMIN"))).apply(springSecurity()).build(); } @Test public void requestProtectedUrlWithUser() throws Exception { - mvc.perform(get("/")) - // Ensure we got past Security + this.mvc.perform(get("/")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user")); @@ -65,8 +68,8 @@ public class DefaultfSecurityRequestsTests { @Test public void requestProtectedUrlWithAdmin() throws Exception { - mvc.perform(get("/admin")) - // Ensure we got past Security + this.mvc.perform(get("/admin")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user")); @@ -74,7 +77,7 @@ public class DefaultfSecurityRequestsTests { @Test public void requestProtectedUrlWithAnonymous() throws Exception { - mvc.perform(get("/admin").with(anonymous())) + this.mvc.perform(get("/admin").with(anonymous())) // Ensure we got past Security .andExpect(status().isUnauthorized()) // Ensure it appears we are authenticated with user @@ -85,25 +88,27 @@ public class DefaultfSecurityRequestsTests { @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .httpBasic(); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/SecurityRequestsTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/SecurityRequestsTests.java index e1bb963c04..71f1947496 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/SecurityRequestsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/SecurityRequestsTests.java @@ -13,24 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.secured; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.*; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +package org.springframework.security.test.web.servlet.showcase.secured; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.core.Authentication; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; @@ -42,6 +38,13 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = SecurityRequestsTests.Config.class) @WebAppConfiguration @@ -57,13 +60,13 @@ public class SecurityRequestsTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test public void requestProtectedUrlWithUser() throws Exception { - mvc.perform(get("/").with(user("user"))) - // Ensure we got past Security + this.mvc.perform(get("/").with(user("user"))) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user")); @@ -71,8 +74,8 @@ public class SecurityRequestsTests { @Test public void requestProtectedUrlWithAdmin() throws Exception { - mvc.perform(get("/admin").with(user("admin").roles("ADMIN"))) - // Ensure we got past Security + this.mvc.perform(get("/admin").with(user("admin").roles("ADMIN"))) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with admin .andExpect(authenticated().withUsername("admin")); @@ -80,9 +83,9 @@ public class SecurityRequestsTests { @Test public void requestProtectedUrlWithUserDetails() throws Exception { - UserDetails user = userDetailsService.loadUserByUsername("user"); - mvc.perform(get("/").with(user(user))) - // Ensure we got past Security + UserDetails user = this.userDetailsService.loadUserByUsername("user"); + this.mvc.perform(get("/").with(user(user))) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withAuthenticationPrincipal(user)); @@ -90,10 +93,9 @@ public class SecurityRequestsTests { @Test public void requestProtectedUrlWithAuthentication() throws Exception { - Authentication authentication = new TestingAuthenticationToken("test", "notused", - "ROLE_USER"); - mvc.perform(get("/").with(authentication(authentication))) - // Ensure we got past Security + Authentication authentication = new TestingAuthenticationToken("test", "notused", "ROLE_USER"); + this.mvc.perform(get("/").with(authentication(authentication))) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withAuthentication(authentication)); @@ -103,31 +105,33 @@ public class SecurityRequestsTests { @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .formLogin(); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on @Override @Bean public UserDetailsService userDetailsServiceBean() throws Exception { return super.userDetailsServiceBean(); } + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithAdminRob.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithAdminRob.java index bf65b937c5..0b44126df1 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithAdminRob.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithAdminRob.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.showcase.secured; import java.lang.annotation.Documented; @@ -28,6 +29,7 @@ import org.springframework.security.test.context.support.WithMockUser; @Retention(RetentionPolicy.RUNTIME) @Inherited @Documented -@WithMockUser(value="rob", roles="ADMIN") +@WithMockUser(value = "rob", roles = "ADMIN") public @interface WithAdminRob { + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserAuthenticationTests.java index 46b6df9861..e95bdd1816 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserAuthenticationTests.java @@ -13,16 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.showcase.secured; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers; import org.springframework.test.context.ContextConfiguration; @@ -49,15 +51,15 @@ public class WithUserAuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context) - .apply(SecurityMockMvcConfigurers.springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(SecurityMockMvcConfigurers.springSecurity()) + .build(); } @Test @WithMockUser public void requestProtectedUrlWithUser() throws Exception { - mvc.perform(get("/")) - // Ensure we got past Security + this.mvc.perform(get("/")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user")); @@ -66,8 +68,8 @@ public class WithUserAuthenticationTests { @Test @WithAdminRob public void requestProtectedUrlWithAdminRob() throws Exception { - mvc.perform(get("/")) - // Ensure we got past Security + this.mvc.perform(get("/")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("rob").withRoles("ADMIN")); @@ -76,8 +78,8 @@ public class WithUserAuthenticationTests { @Test @WithMockUser(roles = "ADMIN") public void requestProtectedUrlWithAdmin() throws Exception { - mvc.perform(get("/admin")) - // Ensure we got past Security + this.mvc.perform(get("/admin")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user").withRoles("ADMIN")); @@ -87,25 +89,27 @@ public class WithUserAuthenticationTests { @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .formLogin(); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserClassLevelAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserClassLevelAuthenticationTests.java index 7ebb3e4c51..d9906782c0 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserClassLevelAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserClassLevelAuthenticationTests.java @@ -13,17 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.secured; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.test.web.servlet.showcase.secured; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -39,6 +35,12 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = WithUserClassLevelAuthenticationTests.Config.class) @WebAppConfiguration @@ -52,13 +54,13 @@ public class WithUserClassLevelAuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test public void requestProtectedUrlWithUser() throws Exception { - mvc.perform(get("/")) - // Ensure we got past Security + this.mvc.perform(get("/")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user")); @@ -66,8 +68,8 @@ public class WithUserClassLevelAuthenticationTests { @Test public void requestProtectedUrlWithAdmin() throws Exception { - mvc.perform(get("/admin")) - // Ensure we got past Security + this.mvc.perform(get("/admin")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user").withRoles("ADMIN")); @@ -76,7 +78,7 @@ public class WithUserClassLevelAuthenticationTests { @Test @WithAnonymousUser public void requestProtectedUrlWithAnonymous() throws Exception { - mvc.perform(get("/")) + this.mvc.perform(get("/")) // Ensure did not get past security .andExpect(status().isUnauthorized()) // Ensure not authenticated @@ -87,25 +89,27 @@ public class WithUserClassLevelAuthenticationTests { @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .httpBasic(); + // @formatter:on } - // @formatter:on - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER"); + // @formatter:on } - // @formatter:on + } -} \ No newline at end of file + +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsAuthenticationTests.java index 75ea96fcdd..eda6e95166 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsAuthenticationTests.java @@ -13,16 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.test.web.servlet.showcase.secured; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; -import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +package org.springframework.security.test.web.servlet.showcase.secured; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; @@ -39,6 +36,11 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; +import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration(classes = WithUserDetailsAuthenticationTests.Config.class) @WebAppConfiguration @@ -51,14 +53,14 @@ public class WithUserDetailsAuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test @WithUserDetails public void requestProtectedUrlWithUser() throws Exception { - mvc.perform(get("/")) - // Ensure we got past Security + this.mvc.perform(get("/")) + // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user .andExpect(authenticated().withUsername("user")); @@ -67,29 +69,28 @@ public class WithUserDetailsAuthenticationTests { @Test @WithUserDetails("admin") public void requestProtectedUrlWithAdmin() throws Exception { - mvc.perform(get("/admin")) + this.mvc.perform(get("/admin")) // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user - .andExpect( - authenticated().withUsername("admin").withRoles("ADMIN", "USER")); + .andExpect(authenticated().withUsername("admin").withRoles("ADMIN", "USER")); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .formLogin(); + // @formatter:on } - // @formatter:on @Bean @Override @@ -97,14 +98,16 @@ public class WithUserDetailsAuthenticationTests { return super.userDetailsServiceBean(); } - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } - // @formatter:on + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsClassLevelAuthenticationTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsClassLevelAuthenticationTests.java index 56b2f324b8..0fb8741231 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsClassLevelAuthenticationTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/showcase/secured/WithUserDetailsClassLevelAuthenticationTests.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.servlet.showcase.secured; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.test.context.support.WithUserDetails; import org.springframework.test.context.ContextConfiguration; @@ -52,44 +54,42 @@ public class WithUserDetailsClassLevelAuthenticationTests { @Before public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).apply(springSecurity()).build(); + this.mvc = MockMvcBuilders.webAppContextSetup(this.context).apply(springSecurity()).build(); } @Test public void requestRootUrlWithAdmin() throws Exception { - mvc.perform(get("/")) + this.mvc.perform(get("/")) // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user - .andExpect( - authenticated().withUsername("admin").withRoles("ADMIN", "USER")); + .andExpect(authenticated().withUsername("admin").withRoles("ADMIN", "USER")); } @Test public void requestProtectedUrlWithAdmin() throws Exception { - mvc.perform(get("/admin")) + this.mvc.perform(get("/admin")) // Ensure we got past Security .andExpect(status().isNotFound()) // Ensure it appears we are authenticated with user - .andExpect( - authenticated().withUsername("admin").withRoles("ADMIN", "USER")); + .andExpect(authenticated().withUsername("admin").withRoles("ADMIN", "USER")); } @EnableWebSecurity @EnableWebMvc static class Config extends WebSecurityConfigurerAdapter { - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .authorizeRequests() .antMatchers("/admin/**").hasRole("ADMIN") .anyRequest().authenticated() .and() .formLogin(); + // @formatter:on } - // @formatter:on @Bean @Override @@ -97,14 +97,16 @@ public class WithUserDetailsClassLevelAuthenticationTests { return super.userDetailsServiceBean(); } - // @formatter:off @Autowired - public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off auth .inMemoryAuthentication() .withUser("user").password("password").roles("USER").and() .withUser("admin").password("password").roles("USER", "ADMIN"); + // @formatter:on } - // @formatter:on + } + } diff --git a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java index c7aead4ed9..874e6c9f1c 100644 --- a/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java +++ b/test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.test.web.support; import org.junit.After; @@ -42,17 +43,18 @@ import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.security.test.web.support.WebTestUtils.getCsrfTokenRepository; -import static org.springframework.security.test.web.support.WebTestUtils.getSecurityContextRepository; @RunWith(MockitoJUnitRunner.class) public class WebTestUtilsTests { + @Mock private SecurityContextRepository contextRepo; + @Mock private CsrfTokenRepository csrfRepo; private MockHttpServletRequest request; + private ConfigurableApplicationContext context; @Before @@ -69,21 +71,21 @@ public class WebTestUtilsTests { @Test public void getCsrfTokenRepositorytNoWac() { - assertThat(getCsrfTokenRepository(this.request)) + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) .isInstanceOf(HttpSessionCsrfTokenRepository.class); } @Test public void getCsrfTokenRepositorytNoSecurity() { loadConfig(Config.class); - assertThat(getCsrfTokenRepository(this.request)) + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) .isInstanceOf(HttpSessionCsrfTokenRepository.class); } @Test public void getCsrfTokenRepositorytSecurityNoCsrf() { loadConfig(SecurityNoCsrfConfig.class); - assertThat(getCsrfTokenRepository(this.request)) + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)) .isInstanceOf(HttpSessionCsrfTokenRepository.class); } @@ -92,28 +94,27 @@ public class WebTestUtilsTests { CustomSecurityConfig.CONTEXT_REPO = this.contextRepo; CustomSecurityConfig.CSRF_REPO = this.csrfRepo; loadConfig(CustomSecurityConfig.class); - assertThat(getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); + assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo); } // getSecurityContextRepository - @Test public void getSecurityContextRepositoryNoWac() { - assertThat(getSecurityContextRepository(this.request)) + assertThat(WebTestUtils.getSecurityContextRepository(this.request)) .isInstanceOf(HttpSessionSecurityContextRepository.class); } @Test public void getSecurityContextRepositoryNoSecurity() { loadConfig(Config.class); - assertThat(getSecurityContextRepository(this.request)) + assertThat(WebTestUtils.getSecurityContextRepository(this.request)) .isInstanceOf(HttpSessionSecurityContextRepository.class); } @Test public void getSecurityContextRepositorySecurityNoCsrf() { loadConfig(SecurityNoCsrfConfig.class); - assertThat(getSecurityContextRepository(this.request)) + assertThat(WebTestUtils.getSecurityContextRepository(this.request)) .isInstanceOf(HttpSessionSecurityContextRepository.class); } @@ -122,44 +123,34 @@ public class WebTestUtilsTests { CustomSecurityConfig.CONTEXT_REPO = this.contextRepo; CustomSecurityConfig.CSRF_REPO = this.csrfRepo; loadConfig(CustomSecurityConfig.class); - assertThat(getSecurityContextRepository(this.request)).isSameAs(this.contextRepo); + assertThat(WebTestUtils.getSecurityContextRepository(this.request)).isSameAs(this.contextRepo); } // gh-3343 @Test public void findFilterNoMatchingFilters() { loadConfig(PartialSecurityConfig.class); - - assertThat(WebTestUtils.findFilter(this.request, - SecurityContextPersistenceFilter.class)).isNull(); + assertThat(WebTestUtils.findFilter(this.request, SecurityContextPersistenceFilter.class)).isNull(); } @Test public void findFilterNoSpringSecurityFilterChainInContext() { loadConfig(NoSecurityConfig.class); - CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository()); FilterChainProxy springSecurityFilterChain = new FilterChainProxy( new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind)); - this.request.getServletContext().setAttribute( - BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain); - - assertThat(WebTestUtils.findFilter(this.request, toFind.getClass())) - .isEqualTo(toFind); + this.request.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain); + assertThat(WebTestUtils.findFilter(this.request, toFind.getClass())).isEqualTo(toFind); } @Test public void findFilterExplicitWithSecurityFilterInContext() { loadConfig(SecurityConfigWithDefaults.class); - CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository()); FilterChainProxy springSecurityFilterChain = new FilterChainProxy( new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind)); - this.request.getServletContext().setAttribute( - BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain); - - assertThat(WebTestUtils.findFilter(this.request, toFind.getClass())) - .isSameAs(toFind); + this.request.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain); + assertThat(WebTestUtils.findFilter(this.request, toFind.getClass())).isSameAs(toFind); } private void loadConfig(Class config) { @@ -167,12 +158,13 @@ public class WebTestUtilsTests { context.register(config); context.refresh(); this.context = context; - this.request.getServletContext().setAttribute( - WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context); + this.request.getServletContext().setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, + context); } @Configuration static class Config { + } @EnableWebSecurity @@ -182,44 +174,50 @@ public class WebTestUtilsTests { protected void configure(HttpSecurity http) throws Exception { http.csrf().disable(); } + } @EnableWebSecurity static class CustomSecurityConfig extends WebSecurityConfigurerAdapter { + static CsrfTokenRepository CSRF_REPO; static SecurityContextRepository CONTEXT_REPO; - // @formatter:off @Override protected void configure(HttpSecurity http) throws Exception { + // @formatter:off http .csrf() .csrfTokenRepository(CSRF_REPO) .and() .securityContext() .securityContextRepository(CONTEXT_REPO); + // @formatter:on } - // @formatter:on + } @EnableWebSecurity static class PartialSecurityConfig extends WebSecurityConfigurerAdapter { - // @formatter:off @Override public void configure(HttpSecurity http) { + // @formatter:off http .antMatcher("/willnotmatchthis"); + // @formatter:on } - // @formatter:on + } @Configuration static class NoSecurityConfig { + } @EnableWebSecurity static class SecurityConfigWithDefaults extends WebSecurityConfigurerAdapter { } + } diff --git a/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java index 3e992ad67f..ef59b28453 100644 --- a/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/AuthenticationEntryPoint.java @@ -16,23 +16,21 @@ package org.springframework.security.web; -import org.springframework.security.core.AuthenticationException; -import org.springframework.security.web.access.ExceptionTranslationFilter; - import java.io.IOException; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.web.access.ExceptionTranslationFilter; + /** * Used by {@link ExceptionTranslationFilter} to commence an authentication scheme. * * @author Ben Alex */ public interface AuthenticationEntryPoint { - // ~ Methods - // ======================================================================================================== /** * Commences an authentication scheme. @@ -44,12 +42,11 @@ public interface AuthenticationEntryPoint { *

        * Implementations should modify the headers on the ServletResponse as * necessary to commence the authentication process. - * * @param request that resulted in an AuthenticationException * @param response so that the user agent can begin authentication * @param authException that caused the invocation - * */ - void commence(HttpServletRequest request, HttpServletResponse response, - AuthenticationException authException) throws IOException, ServletException; + void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) + throws IOException, ServletException; + } diff --git a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java index 983f5ce29c..b135699766 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web; import java.io.IOException; @@ -22,7 +23,10 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.Assert; /** * Simple implementation of RedirectStrategy which is the default used throughout @@ -45,15 +49,11 @@ public class DefaultRedirectStrategy implements RedirectStrategy { * information (HTTP or HTTPS), so will cause problems if a redirect is being * performed to change to HTTPS, for example. */ - public void sendRedirect(HttpServletRequest request, HttpServletResponse response, - String url) throws IOException { + @Override + public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException { String redirectUrl = calculateRedirectUrl(request.getContextPath(), url); redirectUrl = response.encodeRedirectURL(redirectUrl); - - if (logger.isDebugEnabled()) { - logger.debug("Redirecting to '" + redirectUrl + "'"); - } - + this.logger.debug(LogMessage.format("Redirecting to '%s'", redirectUrl)); response.sendRedirect(redirectUrl); } @@ -62,30 +62,20 @@ public class DefaultRedirectStrategy implements RedirectStrategy { if (isContextRelative()) { return url; } - else { - return contextPath + url; - } + return contextPath + url; } - // Full URL, including http(s):// - if (!isContextRelative()) { return url; } - - if (!url.contains(contextPath)) { - throw new IllegalArgumentException("The fully qualified URL does not include context path."); - } - + Assert.isTrue(url.contains(contextPath), "The fully qualified URL does not include context path."); // Calculate the relative URL from the fully qualified URL, minus the last // occurrence of the scheme and base context. - url = url.substring(url.lastIndexOf("://") + 3); // strip off scheme + url = url.substring(url.lastIndexOf("://") + 3); url = url.substring(url.indexOf(contextPath) + contextPath.length()); - if (url.length() > 1 && url.charAt(0) == '/') { url = url.substring(1); } - return url; } @@ -98,10 +88,11 @@ public class DefaultRedirectStrategy implements RedirectStrategy { } /** - * Returns true, if the redirection URL should be calculated - * minus the protocol and context path (defaults to false). + * Returns true, if the redirection URL should be calculated minus the + * protocol and context path (defaults to false). */ protected boolean isContextRelative() { - return contextRelative; + return this.contextRelative; } + } diff --git a/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java b/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java index d4e5a0e4fa..6e52979b6e 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java +++ b/web/src/main/java/org/springframework/security/web/DefaultSecurityFilterChain.java @@ -13,26 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.security.web.util.matcher.RequestMatcher; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import javax.servlet.Filter; import javax.servlet.http.HttpServletRequest; -import java.util.*; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; +import org.springframework.security.web.util.matcher.RequestMatcher; /** * Standard implementation of {@code SecurityFilterChain}. * * @author Luke Taylor - * * @since 3.1 */ public final class DefaultSecurityFilterChain implements SecurityFilterChain { + private static final Log logger = LogFactory.getLog(DefaultSecurityFilterChain.class); + private final RequestMatcher requestMatcher; + private final List filters; public DefaultSecurityFilterChain(RequestMatcher requestMatcher, Filter... filters) { @@ -40,25 +48,28 @@ public final class DefaultSecurityFilterChain implements SecurityFilterChain { } public DefaultSecurityFilterChain(RequestMatcher requestMatcher, List filters) { - logger.info("Creating filter chain: " + requestMatcher + ", " + filters); + logger.info(LogMessage.format("Creating filter chain: %s, %s", requestMatcher, filters)); this.requestMatcher = requestMatcher; this.filters = new ArrayList<>(filters); } public RequestMatcher getRequestMatcher() { - return requestMatcher; + return this.requestMatcher; } + @Override public List getFilters() { - return filters; + return this.filters; } + @Override public boolean matches(HttpServletRequest request) { - return requestMatcher.matches(request); + return this.requestMatcher.matches(request); } @Override public String toString() { - return "[ " + requestMatcher + ", " + filters + "]"; + return "[ " + this.requestMatcher + ", " + this.filters + "]"; } + } diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index a27c4dd324..6a71c73e26 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -16,20 +16,10 @@ package org.springframework.security.web; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.web.firewall.DefaultRequestRejectedHandler; -import org.springframework.security.web.firewall.FirewalledRequest; -import org.springframework.security.web.firewall.HttpFirewall; -import org.springframework.security.web.firewall.RequestRejectedException; -import org.springframework.security.web.firewall.RequestRejectedHandler; -import org.springframework.security.web.firewall.StrictHttpFirewall; -import org.springframework.security.web.util.matcher.RequestMatcher; -import org.springframework.security.web.util.UrlUtils; -import org.springframework.util.Assert; -import org.springframework.web.filter.DelegatingFilterProxy; -import org.springframework.web.filter.GenericFilterBean; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import javax.servlet.Filter; import javax.servlet.FilterChain; @@ -38,8 +28,23 @@ import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.util.*; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.firewall.DefaultRequestRejectedHandler; +import org.springframework.security.web.firewall.FirewalledRequest; +import org.springframework.security.web.firewall.HttpFirewall; +import org.springframework.security.web.firewall.RequestRejectedException; +import org.springframework.security.web.firewall.RequestRejectedHandler; +import org.springframework.security.web.firewall.StrictHttpFirewall; +import org.springframework.security.web.util.UrlUtils; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.web.filter.DelegatingFilterProxy; +import org.springframework.web.filter.GenericFilterBean; /** * Delegates {@code Filter} requests to a list of Spring-managed filter beans. As of @@ -59,9 +64,9 @@ import java.util.*; * and a list of filters which should be applied to matching requests. Most applications * will only contain a single filter chain, and if you are using the namespace, you don't * have to set the chains explicitly. If you require finer-grained control, you can make - * use of the {@code } namespace element. This defines a URI pattern - * and the list of filters (as comma-separated bean names) which should be applied to - * requests which match the pattern. An example configuration might look like this: + * use of the {@code } namespace element. This defines a URI pattern and the + * list of filters (as comma-separated bean names) which should be applied to requests + * which match the pattern. An example configuration might look like this: * *

          *  <bean id="myfilterChainProxy" class="org.springframework.security.web.FilterChainProxy">
        @@ -136,16 +141,10 @@ import java.util.*;
          * @author Rob Winch
          */
         public class FilterChainProxy extends GenericFilterBean {
        -	// ~ Static fields/initializers
        -	// =====================================================================================
         
         	private static final Log logger = LogFactory.getLog(FilterChainProxy.class);
         
        -	// ~ Instance fields
        -	// ================================================================================================
        -
        -	private final static String FILTER_APPLIED = FilterChainProxy.class.getName().concat(
        -			".APPLIED");
        +	private static final String FILTER_APPLIED = FilterChainProxy.class.getName().concat(".APPLIED");
         
         	private List filterChains;
         
        @@ -155,9 +154,6 @@ public class FilterChainProxy extends GenericFilterBean {
         
         	private RequestRejectedHandler requestRejectedHandler = new DefaultRequestRejectedHandler();
         
        -	// ~ Methods
        -	// ========================================================================================================
        -
         	public FilterChainProxy() {
         	}
         
        @@ -171,83 +167,67 @@ public class FilterChainProxy extends GenericFilterBean {
         
         	@Override
         	public void afterPropertiesSet() {
        -		filterChainValidator.validate(this);
        +		this.filterChainValidator.validate(this);
         	}
         
         	@Override
        -	public void doFilter(ServletRequest request, ServletResponse response,
        -			FilterChain chain) throws IOException, ServletException {
        +	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        +			throws IOException, ServletException {
         		boolean clearContext = request.getAttribute(FILTER_APPLIED) == null;
        -		if (clearContext) {
        -			try {
        -				request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
        -				doFilterInternal(request, response, chain);
        -			} catch (RequestRejectedException e) {
        -				this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, e);
        -			}
        -			finally {
        -				SecurityContextHolder.clearContext();
        -				request.removeAttribute(FILTER_APPLIED);
        -			}
        -		}
        -		else {
        +		if (!clearContext) {
         			doFilterInternal(request, response, chain);
        +			return;
        +		}
        +		try {
        +			request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
        +			doFilterInternal(request, response, chain);
        +		}
        +		catch (RequestRejectedException ex) {
        +			this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
        +		}
        +		finally {
        +			SecurityContextHolder.clearContext();
        +			request.removeAttribute(FILTER_APPLIED);
         		}
         	}
         
        -	private void doFilterInternal(ServletRequest request, ServletResponse response,
        -			FilterChain chain) throws IOException, ServletException {
        -
        -		FirewalledRequest fwRequest = firewall
        -				.getFirewalledRequest((HttpServletRequest) request);
        -		HttpServletResponse fwResponse = firewall
        -				.getFirewalledResponse((HttpServletResponse) response);
        -
        -		List filters = getFilters(fwRequest);
        -
        +	private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain)
        +			throws IOException, ServletException {
        +		FirewalledRequest firewallRequest = this.firewall.getFirewalledRequest((HttpServletRequest) request);
        +		HttpServletResponse firewallResponse = this.firewall.getFirewalledResponse((HttpServletResponse) response);
        +		List filters = getFilters(firewallRequest);
         		if (filters == null || filters.size() == 0) {
        -			if (logger.isDebugEnabled()) {
        -				logger.debug(UrlUtils.buildRequestUrl(fwRequest)
        -						+ (filters == null ? " has no matching filters"
        -								: " has an empty filter list"));
        -			}
        -
        -			fwRequest.reset();
        -
        -			chain.doFilter(fwRequest, fwResponse);
        -
        +			logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(firewallRequest)
        +					+ ((filters != null) ? " has an empty filter list" : " has no matching filters")));
        +			firewallRequest.reset();
        +			chain.doFilter(firewallRequest, firewallResponse);
         			return;
         		}
        -
        -		VirtualFilterChain vfc = new VirtualFilterChain(fwRequest, chain, filters);
        -		vfc.doFilter(fwRequest, fwResponse);
        +		VirtualFilterChain virtualFilterChain = new VirtualFilterChain(firewallRequest, chain, filters);
        +		virtualFilterChain.doFilter(firewallRequest, firewallResponse);
         	}
         
         	/**
         	 * Returns the first filter chain matching the supplied URL.
        -	 *
         	 * @param request the request to match
         	 * @return an ordered array of Filters defining the filter chain
         	 */
         	private List getFilters(HttpServletRequest request) {
        -		for (SecurityFilterChain chain : filterChains) {
        +		for (SecurityFilterChain chain : this.filterChains) {
         			if (chain.matches(request)) {
         				return chain.getFilters();
         			}
         		}
        -
         		return null;
         	}
         
         	/**
         	 * Convenience method, mainly for testing.
        -	 *
         	 * @param url the URL
         	 * @return matching filter list
         	 */
         	public List getFilters(String url) {
        -		return getFilters(firewall.getFirewalledRequest((new FilterInvocation(url, "GET")
        -				.getRequest())));
        +		return getFilters(this.firewall.getFirewalledRequest((new FilterInvocation(url, "GET").getRequest())));
         	}
         
         	/**
        @@ -255,13 +235,12 @@ public class FilterChainProxy extends GenericFilterBean {
         	 * applied to incoming requests.
         	 */
         	public List getFilterChains() {
        -		return Collections.unmodifiableList(filterChains);
        +		return Collections.unmodifiableList(this.filterChains);
         	}
         
         	/**
         	 * Used (internally) to specify a validation strategy for the filters in each
         	 * configured chain.
        -	 *
         	 * @param filterChainValidator the validator instance which will be invoked on during
         	 * initialization to check the {@code FilterChainProxy} instance.
         	 */
        @@ -273,7 +252,6 @@ public class FilterChainProxy extends GenericFilterBean {
         	 * Sets the "firewall" implementation which will be used to validate and wrap (or
         	 * potentially reject) the incoming requests. The default implementation should be
         	 * satisfactory for most requirements.
        -	 *
         	 * @param firewall
         	 */
         	public void setFirewall(HttpFirewall firewall) {
        @@ -281,10 +259,10 @@ public class FilterChainProxy extends GenericFilterBean {
         	}
         
         	/**
        -	 * Sets the {@link RequestRejectedHandler} to be used for requests rejected by the firewall.
        -	 *
        -	 * @since 5.2
        +	 * Sets the {@link RequestRejectedHandler} to be used for requests rejected by the
        +	 * firewall.
         	 * @param requestRejectedHandler the {@link RequestRejectedHandler}
        +	 * @since 5.2
         	 */
         	public void setRequestRejectedHandler(RequestRejectedHandler requestRejectedHandler) {
         		Assert.notNull(requestRejectedHandler, "requestRejectedHandler may not be null");
        @@ -296,28 +274,29 @@ public class FilterChainProxy extends GenericFilterBean {
         		StringBuilder sb = new StringBuilder();
         		sb.append("FilterChainProxy[");
         		sb.append("Filter Chains: ");
        -		sb.append(filterChains);
        +		sb.append(this.filterChains);
         		sb.append("]");
        -
         		return sb.toString();
         	}
         
        -	// ~ Inner Classes
        -	// ==================================================================================================
        -
         	/**
         	 * Internal {@code FilterChain} implementation that is used to pass a request through
         	 * the additional internal list of filters which match the request.
         	 */
        -	private static class VirtualFilterChain implements FilterChain {
        +	private static final class VirtualFilterChain implements FilterChain {
        +
         		private final FilterChain originalChain;
        +
         		private final List additionalFilters;
        +
         		private final FirewalledRequest firewalledRequest;
        +
         		private final int size;
        +
         		private int currentPosition = 0;
         
        -		private VirtualFilterChain(FirewalledRequest firewalledRequest,
        -				FilterChain chain, List additionalFilters) {
        +		private VirtualFilterChain(FirewalledRequest firewalledRequest, FilterChain chain,
        +				List additionalFilters) {
         			this.originalChain = chain;
         			this.additionalFilters = additionalFilters;
         			this.size = additionalFilters.size();
        @@ -325,44 +304,37 @@ public class FilterChainProxy extends GenericFilterBean {
         		}
         
         		@Override
        -		public void doFilter(ServletRequest request, ServletResponse response)
        -				throws IOException, ServletException {
        -			if (currentPosition == size) {
        -				if (logger.isDebugEnabled()) {
        -					logger.debug(UrlUtils.buildRequestUrl(firewalledRequest)
        -							+ " reached end of additional filter chain; proceeding with original chain");
        -				}
        -
        +		public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
        +			if (this.currentPosition == this.size) {
        +				logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest)
        +						+ " reached end of additional filter chain; proceeding with original chain"));
         				// Deactivate path stripping as we exit the security filter chain
         				this.firewalledRequest.reset();
        -
        -				originalChain.doFilter(request, response);
        -			}
        -			else {
        -				currentPosition++;
        -
        -				Filter nextFilter = additionalFilters.get(currentPosition - 1);
        -
        -				if (logger.isDebugEnabled()) {
        -					logger.debug(UrlUtils.buildRequestUrl(firewalledRequest)
        -							+ " at position " + currentPosition + " of " + size
        -							+ " in additional filter chain; firing Filter: '"
        -							+ nextFilter.getClass().getSimpleName() + "'");
        -				}
        -
        -				nextFilter.doFilter(request, response, this);
        +				this.originalChain.doFilter(request, response);
        +				return;
         			}
        +			this.currentPosition++;
        +			Filter nextFilter = this.additionalFilters.get(this.currentPosition - 1);
        +			logger.debug(LogMessage.of(() -> UrlUtils.buildRequestUrl(this.firewalledRequest) + " at position "
        +					+ this.currentPosition + " of " + this.size + " in additional filter chain; firing Filter: '"
        +					+ nextFilter.getClass().getSimpleName() + "'"));
        +			nextFilter.doFilter(request, response, this);
         		}
        +
         	}
         
         	public interface FilterChainValidator {
        +
         		void validate(FilterChainProxy filterChainProxy);
        +
         	}
         
         	private static class NullFilterChainValidator implements FilterChainValidator {
        +
         		@Override
         		public void validate(FilterChainProxy filterChainProxy) {
         		}
        +
         	}
         
         }
        diff --git a/web/src/main/java/org/springframework/security/web/FilterInvocation.java b/web/src/main/java/org/springframework/security/web/FilterInvocation.java
        index 97061e1872..1062c4eedc 100644
        --- a/web/src/main/java/org/springframework/security/web/FilterInvocation.java
        +++ b/web/src/main/java/org/springframework/security/web/FilterInvocation.java
        @@ -37,6 +37,7 @@ import javax.servlet.http.HttpServletResponse;
         
         import org.springframework.http.HttpHeaders;
         import org.springframework.security.web.util.UrlUtils;
        +import org.springframework.util.Assert;
         
         /**
          * Holds objects associated with a HTTP filter.
        @@ -53,28 +54,19 @@ import org.springframework.security.web.util.UrlUtils;
          * @author Rob Winch
          */
         public class FilterInvocation {
        -	// ~ Static fields
        -	// ==================================================================================================
        +
         	static final FilterChain DUMMY_CHAIN = (req, res) -> {
         		throw new UnsupportedOperationException("Dummy filter chain");
         	};
         
        -	// ~ Instance fields
        -	// ================================================================================================
        -
         	private FilterChain chain;
        +
         	private HttpServletRequest request;
        +
         	private HttpServletResponse response;
         
        -	// ~ Constructors
        -	// ===================================================================================================
        -
        -	public FilterInvocation(ServletRequest request, ServletResponse response,
        -			FilterChain chain) {
        -		if ((request == null) || (response == null) || (chain == null)) {
        -			throw new IllegalArgumentException("Cannot pass null values to constructor");
        -		}
        -
        +	public FilterInvocation(ServletRequest request, ServletResponse response, FilterChain chain) {
        +		Assert.isTrue(request != null && response != null && chain != null, "Cannot pass null values to constructor");
         		this.request = (HttpServletRequest) request;
         		this.response = (HttpServletResponse) response;
         		this.chain = chain;
        @@ -88,25 +80,18 @@ public class FilterInvocation {
         		this(contextPath, servletPath, null, null, method);
         	}
         
        -	public FilterInvocation(String contextPath, String servletPath, String pathInfo,
        -			String query, String method) {
        +	public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) {
         		DummyRequest request = new DummyRequest();
        -		if (contextPath == null) {
        -			contextPath = "/cp";
        -		}
        +		contextPath = (contextPath != null) ? contextPath : "/cp";
         		request.setContextPath(contextPath);
         		request.setServletPath(servletPath);
        -		request.setRequestURI(
        -				contextPath + servletPath + (pathInfo == null ? "" : pathInfo));
        +		request.setRequestURI(contextPath + servletPath + ((pathInfo != null) ? pathInfo : ""));
         		request.setPathInfo(pathInfo);
         		request.setQueryString(query);
         		request.setMethod(method);
         		this.request = request;
         	}
         
        -	// ~ Methods
        -	// ========================================================================================================
        -
         	public FilterChain getChain() {
         		return this.chain;
         	}
        @@ -116,7 +101,6 @@ public class FilterInvocation {
         	 * 

        * The returned URL does not reflect the port number determined from a * {@link org.springframework.security.web.PortResolver}. - * * @return the full URL of this request */ public String getFullRequestUrl() { @@ -133,7 +117,6 @@ public class FilterInvocation { /** * Obtains the web application-specific fragment of the URL. - * * @return the URL, excluding any server name, context path or servlet path */ public String getRequestUrl() { @@ -152,188 +135,191 @@ public class FilterInvocation { public String toString() { return "FilterInvocation: URL: " + getRequestUrl(); } -} -class DummyRequest extends HttpServletRequestWrapper { - private static final HttpServletRequest UNSUPPORTED_REQUEST = (HttpServletRequest) Proxy - .newProxyInstance(DummyRequest.class.getClassLoader(), - new Class[] { HttpServletRequest.class }, - new UnsupportedOperationExceptionInvocationHandler()); + static class DummyRequest extends HttpServletRequestWrapper { - private String requestURI; - private String contextPath = ""; - private String servletPath; - private String pathInfo; - private String queryString; - private String method; - private final HttpHeaders headers = new HttpHeaders(); - private final Map parameters = new LinkedHashMap<>(); + private static final HttpServletRequest UNSUPPORTED_REQUEST = (HttpServletRequest) Proxy.newProxyInstance( + DummyRequest.class.getClassLoader(), new Class[] { HttpServletRequest.class }, + new UnsupportedOperationExceptionInvocationHandler()); - DummyRequest() { - super(UNSUPPORTED_REQUEST); - } + private String requestURI; - public String getCharacterEncoding() { - return "UTF-8"; - } + private String contextPath = ""; - public Object getAttribute(String attributeName) { - return null; - } + private String servletPath; - public void setRequestURI(String requestURI) { - this.requestURI = requestURI; - } + private String pathInfo; - public void setPathInfo(String pathInfo) { - this.pathInfo = pathInfo; - } + private String queryString; - @Override - public String getRequestURI() { - return this.requestURI; - } + private String method; - public void setContextPath(String contextPath) { - this.contextPath = contextPath; - } + private final HttpHeaders headers = new HttpHeaders(); - @Override - public String getContextPath() { - return this.contextPath; - } + private final Map parameters = new LinkedHashMap<>(); - public void setServletPath(String servletPath) { - this.servletPath = servletPath; - } - - @Override - public String getServletPath() { - return this.servletPath; - } - - public void setMethod(String method) { - this.method = method; - } - - @Override - public String getMethod() { - return this.method; - } - - @Override - public String getPathInfo() { - return this.pathInfo; - } - - @Override - public String getQueryString() { - return this.queryString; - } - - public void setQueryString(String queryString) { - this.queryString = queryString; - } - - @Override - public String getServerName() { - return null; - } - - @Override - public String getHeader(String name) { - return this.headers.getFirst(name); - } - - @Override - public Enumeration getHeaders(String name) { - return Collections.enumeration(this.headers.get(name)); - } - - @Override - public Enumeration getHeaderNames() { - return Collections.enumeration(this.headers.keySet()); - } - - @Override - public int getIntHeader(String name) { - String value = this.headers.getFirst(name); - if (value == null ) { - return -1; + DummyRequest() { + super(UNSUPPORTED_REQUEST); } - else { + + @Override + public String getCharacterEncoding() { + return "UTF-8"; + } + + @Override + public Object getAttribute(String attributeName) { + return null; + } + + void setRequestURI(String requestURI) { + this.requestURI = requestURI; + } + + void setPathInfo(String pathInfo) { + this.pathInfo = pathInfo; + } + + @Override + public String getRequestURI() { + return this.requestURI; + } + + void setContextPath(String contextPath) { + this.contextPath = contextPath; + } + + @Override + public String getContextPath() { + return this.contextPath; + } + + void setServletPath(String servletPath) { + this.servletPath = servletPath; + } + + @Override + public String getServletPath() { + return this.servletPath; + } + + void setMethod(String method) { + this.method = method; + } + + @Override + public String getMethod() { + return this.method; + } + + @Override + public String getPathInfo() { + return this.pathInfo; + } + + @Override + public String getQueryString() { + return this.queryString; + } + + void setQueryString(String queryString) { + this.queryString = queryString; + } + + @Override + public String getServerName() { + return null; + } + + @Override + public String getHeader(String name) { + return this.headers.getFirst(name); + } + + @Override + public Enumeration getHeaders(String name) { + return Collections.enumeration(this.headers.get(name)); + } + + @Override + public Enumeration getHeaderNames() { + return Collections.enumeration(this.headers.keySet()); + } + + @Override + public int getIntHeader(String name) { + String value = this.headers.getFirst(name); + if (value == null) { + return -1; + } return Integer.parseInt(value); } - } - public void addHeader(String name, String value) { - this.headers.add(name, value); - } - - @Override - public String getParameter(String name) { - String[] arr = this.parameters.get(name); - return (arr != null && arr.length > 0 ? arr[0] : null); - } - - @Override - public Map getParameterMap() { - return Collections.unmodifiableMap(this.parameters); - } - - @Override - public Enumeration getParameterNames() { - return Collections.enumeration(this.parameters.keySet()); - } - - @Override - public String[] getParameterValues(String name) { - return this.parameters.get(name); - } - - public void setParameter(String name, String... values) { - this.parameters.put(name, values); - } -} - -final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler { - private static final float JAVA_VERSION = Float.parseFloat(System.getProperty("java.class.version", "52")); - - public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { - if (method.isDefault()) { - return invokeDefaultMethod(proxy, method, args); + void addHeader(String name, String value) { + this.headers.add(name, value); } - throw new UnsupportedOperationException(method + " is not supported"); - } - private Object invokeDefaultMethod(Object proxy, Method method, Object[] args) throws Throwable { - if (isJdk8OrEarlier()) { - return invokeDefaultMethodForJdk8(proxy, method, args); + @Override + public String getParameter(String name) { + String[] array = this.parameters.get(name); + return (array != null && array.length > 0) ? array[0] : null; } - return MethodHandles.lookup() - .findSpecial( - method.getDeclaringClass(), - method.getName(), - MethodType.methodType(method.getReturnType(), new Class[0]), - method.getDeclaringClass() - ) - .bindTo(proxy) - .invokeWithArguments(args); + + @Override + public Map getParameterMap() { + return Collections.unmodifiableMap(this.parameters); + } + + @Override + public Enumeration getParameterNames() { + return Collections.enumeration(this.parameters.keySet()); + } + + @Override + public String[] getParameterValues(String name) { + return this.parameters.get(name); + } + + void setParameter(String name, String... values) { + this.parameters.put(name, values); + } + } - private Object invokeDefaultMethodForJdk8(Object proxy, Method method, Object[] args) throws Throwable { - Constructor constructor = Lookup.class.getDeclaredConstructor(Class.class); - constructor.setAccessible(true); + static final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler { + + private static final float JAVA_VERSION = Float.parseFloat(System.getProperty("java.class.version", "52")); + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if (method.isDefault()) { + return invokeDefaultMethod(proxy, method, args); + } + throw new UnsupportedOperationException(method + " is not supported"); + } + + private Object invokeDefaultMethod(Object proxy, Method method, Object[] args) throws Throwable { + if (isJdk8OrEarlier()) { + return invokeDefaultMethodForJdk8(proxy, method, args); + } + return MethodHandles.lookup() + .findSpecial(method.getDeclaringClass(), method.getName(), + MethodType.methodType(method.getReturnType(), new Class[0]), method.getDeclaringClass()) + .bindTo(proxy).invokeWithArguments(args); + } + + private Object invokeDefaultMethodForJdk8(Object proxy, Method method, Object[] args) throws Throwable { + Constructor constructor = Lookup.class.getDeclaredConstructor(Class.class); + constructor.setAccessible(true); + Class clazz = method.getDeclaringClass(); + return constructor.newInstance(clazz).in(clazz).unreflectSpecial(method, clazz).bindTo(proxy) + .invokeWithArguments(args); + } + + private boolean isJdk8OrEarlier() { + return JAVA_VERSION <= 52; + } - Class clazz = method.getDeclaringClass(); - return constructor.newInstance(clazz) - .in(clazz) - .unreflectSpecial(method, clazz) - .bindTo(proxy) - .invokeWithArguments(args); } - private boolean isJdk8OrEarlier() { - return JAVA_VERSION <= 52; - } } diff --git a/web/src/main/java/org/springframework/security/web/PortMapper.java b/web/src/main/java/org/springframework/security/web/PortMapper.java index ab21cc86a6..eafe8597c6 100644 --- a/web/src/main/java/org/springframework/security/web/PortMapper.java +++ b/web/src/main/java/org/springframework/security/web/PortMapper.java @@ -23,17 +23,13 @@ package org.springframework.security.web; * @author Ben Alex */ public interface PortMapper { - // ~ Methods - // ======================================================================================================== /** * Locates the HTTP port associated with the specified HTTPS port. *

        * Returns null if unknown. *

        - * * @param httpsPort - * * @return the HTTP port or null if unknown */ Integer lookupHttpPort(Integer httpsPort); @@ -43,10 +39,9 @@ public interface PortMapper { *

        * Returns null if unknown. *

        - * * @param httpPort - * * @return the HTTPS port or null if unknown */ Integer lookupHttpsPort(Integer httpPort); + } diff --git a/web/src/main/java/org/springframework/security/web/PortMapperImpl.java b/web/src/main/java/org/springframework/security/web/PortMapperImpl.java index f9144ecd52..947cc93c62 100644 --- a/web/src/main/java/org/springframework/security/web/PortMapperImpl.java +++ b/web/src/main/java/org/springframework/security/web/PortMapperImpl.java @@ -32,23 +32,15 @@ import org.springframework.util.Assert; * @author colin sampaleanu */ public class PortMapperImpl implements PortMapper { - // ~ Instance fields - // ================================================================================================ private final Map httpsPortMappings; - // ~ Constructors - // =================================================================================================== - public PortMapperImpl() { this.httpsPortMappings = new HashMap<>(); this.httpsPortMappings.put(80, 443); this.httpsPortMappings.put(8080, 8443); } - // ~ Methods - // ======================================================================================================== - /** * Returns the translated (Integer -> Integer) version of the original port mapping * specified via setHttpsPortMapping() @@ -57,16 +49,17 @@ public class PortMapperImpl implements PortMapper { return this.httpsPortMappings; } + @Override public Integer lookupHttpPort(Integer httpsPort) { for (Integer httpPort : this.httpsPortMappings.keySet()) { if (this.httpsPortMappings.get(httpPort).equals(httpsPort)) { return httpPort; } } - return null; } + @Override public Integer lookupHttpsPort(Integer httpPort) { return this.httpsPortMappings.get(httpPort); } @@ -84,38 +77,29 @@ public class PortMapperImpl implements PortMapper { * </map> * </property> *
        - * * @param newMappings A Map consisting of String keys and String values, where for * each entry the key is the string representation of an integer HTTP port number, and * the value is the string representation of the corresponding integer HTTPS port * number. - * * @throws IllegalArgumentException if input map does not consist of String keys and * values, each representing an integer port number in the range 1-65535 for that * mapping. */ public void setPortMappings(Map newMappings) { - Assert.notNull(newMappings, - "A valid list of HTTPS port mappings must be provided"); - + Assert.notNull(newMappings, "A valid list of HTTPS port mappings must be provided"); this.httpsPortMappings.clear(); - for (Map.Entry entry : newMappings.entrySet()) { Integer httpPort = Integer.valueOf(entry.getKey()); Integer httpsPort = Integer.valueOf(entry.getValue()); - - if ((httpPort < 1) || (httpPort > 65535) - || (httpsPort < 1) || (httpsPort > 65535)) { - throw new IllegalArgumentException( - "one or both ports out of legal range: " + httpPort + ", " - + httpsPort); - } - + Assert.isTrue(isInPortRange(httpPort) && isInPortRange(httpsPort), + () -> "one or both ports out of legal range: " + httpPort + ", " + httpsPort); this.httpsPortMappings.put(httpPort, httpsPort); } - - if (this.httpsPortMappings.size() < 1) { - throw new IllegalArgumentException("must map at least one port"); - } + Assert.isTrue(!this.httpsPortMappings.isEmpty(), "must map at least one port"); } + + private boolean isInPortRange(int port) { + return port >= 1 && port <= 65535; + } + } diff --git a/web/src/main/java/org/springframework/security/web/PortResolver.java b/web/src/main/java/org/springframework/security/web/PortResolver.java index ca2417f4ad..b83cfc976c 100644 --- a/web/src/main/java/org/springframework/security/web/PortResolver.java +++ b/web/src/main/java/org/springframework/security/web/PortResolver.java @@ -30,15 +30,12 @@ import javax.servlet.ServletRequest; * @author Ben Alex */ public interface PortResolver { - // ~ Methods - // ======================================================================================================== /** * Indicates the port the ServletRequest was received on. - * * @param request that the method should lookup the port for - * * @return the port the request was received on */ int getServerPort(ServletRequest request); + } diff --git a/web/src/main/java/org/springframework/security/web/PortResolverImpl.java b/web/src/main/java/org/springframework/security/web/PortResolverImpl.java index d77bb353a0..faa01d83c3 100644 --- a/web/src/main/java/org/springframework/security/web/PortResolverImpl.java +++ b/web/src/main/java/org/springframework/security/web/PortResolverImpl.java @@ -16,10 +16,10 @@ package org.springframework.security.web; -import org.springframework.util.Assert; - import javax.servlet.ServletRequest; +import org.springframework.util.Assert; + /** * Concrete implementation of {@link PortResolver} that obtains the port from * ServletRequest.getServerPort(). @@ -35,42 +35,34 @@ import javax.servlet.ServletRequest; * @author Ben Alex */ public class PortResolverImpl implements PortResolver { - // ~ Instance fields - // ================================================================================================ private PortMapper portMapper = new PortMapperImpl(); - // ~ Methods - // ======================================================================================================== - public PortMapper getPortMapper() { - return portMapper; + return this.portMapper; } + @Override public int getServerPort(ServletRequest request) { int serverPort = request.getServerPort(); - Integer portLookup = null; - String scheme = request.getScheme().toLowerCase(); + Integer mappedPort = getMappedPort(serverPort, scheme); + return (mappedPort != null) ? mappedPort : serverPort; + } + private Integer getMappedPort(int serverPort, String scheme) { if ("http".equals(scheme)) { - portLookup = portMapper.lookupHttpPort(serverPort); - + return this.portMapper.lookupHttpPort(serverPort); } - else if ("https".equals(scheme)) { - portLookup = portMapper.lookupHttpsPort(serverPort); + if ("https".equals(scheme)) { + return this.portMapper.lookupHttpsPort(serverPort); } - - if (portLookup != null) { - // IE 6 bug - serverPort = portLookup; - } - - return serverPort; + return null; } public void setPortMapper(PortMapper portMapper) { Assert.notNull(portMapper, "portMapper cannot be null"); this.portMapper = portMapper; } + } diff --git a/web/src/main/java/org/springframework/security/web/RedirectStrategy.java b/web/src/main/java/org/springframework/security/web/RedirectStrategy.java index 7bd0124459..8dc718a46d 100644 --- a/web/src/main/java/org/springframework/security/web/RedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/RedirectStrategy.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web; import java.io.IOException; @@ -35,6 +36,6 @@ public interface RedirectStrategy { * @param response the response to redirect * @param url the target URL to redirect to, for example "/login" */ - void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) - throws IOException; + void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException; + } diff --git a/web/src/main/java/org/springframework/security/web/SecurityFilterChain.java b/web/src/main/java/org/springframework/security/web/SecurityFilterChain.java index 446de5547c..bc919ef91b 100644 --- a/web/src/main/java/org/springframework/security/web/SecurityFilterChain.java +++ b/web/src/main/java/org/springframework/security/web/SecurityFilterChain.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web; +import java.util.List; + import javax.servlet.Filter; import javax.servlet.http.HttpServletRequest; -import java.util.*; /** * Defines a filter chain which is capable of being matched against an @@ -25,9 +27,7 @@ import java.util.*; *

        * Used to configure a {@code FilterChainProxy}. * - * * @author Luke Taylor - * * @since 3.1 */ public interface SecurityFilterChain { @@ -35,4 +35,5 @@ public interface SecurityFilterChain { boolean matches(HttpServletRequest request); List getFilters(); + } diff --git a/web/src/main/java/org/springframework/security/web/WebAttributes.java b/web/src/main/java/org/springframework/security/web/WebAttributes.java index e280fb19e4..a95700241c 100644 --- a/web/src/main/java/org/springframework/security/web/WebAttributes.java +++ b/web/src/main/java/org/springframework/security/web/WebAttributes.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web; import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; @@ -26,6 +27,7 @@ import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; * @since 3.0.3 */ public final class WebAttributes { + /** * Used to cache an {@code AccessDeniedException} in the request for rendering. * @@ -44,9 +46,13 @@ public final class WebAttributes { * Set as a request attribute to override the default * {@link WebInvocationPrivilegeEvaluator} * - * @see WebInvocationPrivilegeEvaluator * @since 3.1.3 + * @see WebInvocationPrivilegeEvaluator */ - public static final String WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE = WebAttributes.class - .getName() + ".WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE"; + public static final String WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE = WebAttributes.class.getName() + + ".WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE"; + + private WebAttributes() { + } + } diff --git a/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandler.java b/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandler.java index aa40b8deb5..8c4495ca6f 100644 --- a/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandler.java +++ b/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandler.java @@ -16,14 +16,14 @@ package org.springframework.security.web.access; -import org.springframework.security.access.AccessDeniedException; - import java.io.IOException; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.security.access.AccessDeniedException; + /** * Used by {@link ExceptionTranslationFilter} to handle an * AccessDeniedException. @@ -31,20 +31,16 @@ import javax.servlet.http.HttpServletResponse; * @author Ben Alex */ public interface AccessDeniedHandler { - // ~ Methods - // ======================================================================================================== /** * Handles an access denied failure. - * * @param request that resulted in an AccessDeniedException * @param response so that the user agent can be advised of the failure * @param accessDeniedException that caused the invocation - * * @throws IOException in the event of an IOException * @throws ServletException in the event of a ServletException */ - void handle(HttpServletRequest request, HttpServletResponse response, - AccessDeniedException accessDeniedException) throws IOException, - ServletException; + void handle(HttpServletRequest request, HttpServletResponse response, AccessDeniedException accessDeniedException) + throws IOException, ServletException; + } diff --git a/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java b/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java index 742dcb64dc..26315b9ec0 100644 --- a/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java +++ b/web/src/main/java/org/springframework/security/web/access/AccessDeniedHandlerImpl.java @@ -18,16 +18,17 @@ package org.springframework.security.web.access; import java.io.IOException; -import javax.servlet.RequestDispatcher; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.http.HttpStatus; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.web.WebAttributes; +import org.springframework.util.Assert; /** * Base implementation of {@link AccessDeniedHandler}. @@ -43,56 +44,39 @@ import org.springframework.security.web.WebAttributes; * @author Ben Alex */ public class AccessDeniedHandlerImpl implements AccessDeniedHandler { - // ~ Static fields/initializers - // ===================================================================================== protected static final Log logger = LogFactory.getLog(AccessDeniedHandlerImpl.class); - // ~ Instance fields - // ================================================================================================ - private String errorPage; - // ~ Methods - // ======================================================================================================== - + @Override public void handle(HttpServletRequest request, HttpServletResponse response, - AccessDeniedException accessDeniedException) throws IOException, - ServletException { - if (!response.isCommitted()) { - if (errorPage != null) { - // Put exception into request scope (perhaps of use to a view) - request.setAttribute(WebAttributes.ACCESS_DENIED_403, - accessDeniedException); - - // Set the 403 status code. - response.setStatus(HttpStatus.FORBIDDEN.value()); - - // forward to error page. - RequestDispatcher dispatcher = request.getRequestDispatcher(errorPage); - dispatcher.forward(request, response); - } - else { - response.sendError(HttpStatus.FORBIDDEN.value(), - HttpStatus.FORBIDDEN.getReasonPhrase()); - } + AccessDeniedException accessDeniedException) throws IOException, ServletException { + if (response.isCommitted()) { + return; } + if (this.errorPage == null) { + response.sendError(HttpStatus.FORBIDDEN.value(), HttpStatus.FORBIDDEN.getReasonPhrase()); + return; + } + // Put exception into request scope (perhaps of use to a view) + request.setAttribute(WebAttributes.ACCESS_DENIED_403, accessDeniedException); + // Set the 403 status code. + response.setStatus(HttpStatus.FORBIDDEN.value()); + // forward to error page. + request.getRequestDispatcher(this.errorPage).forward(request, response); } /** * The error page to use. Must begin with a "/" and is interpreted relative to the * current context root. - * * @param errorPage the dispatcher path to display - * * @throws IllegalArgumentException if the argument doesn't comply with the above * limitations */ public void setErrorPage(String errorPage) { - if ((errorPage != null) && !errorPage.startsWith("/")) { - throw new IllegalArgumentException("errorPage must begin with '/'"); - } - + Assert.isTrue(errorPage == null || errorPage.startsWith("/"), "errorPage must begin with '/'"); this.errorPage = errorPage; } + } diff --git a/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java b/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java index 229eeb749a..7030d29c46 100644 --- a/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java +++ b/web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java @@ -20,6 +20,8 @@ import java.util.Collection; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.ConfigAttribute; import org.springframework.security.access.intercept.AbstractSecurityInterceptor; @@ -34,44 +36,28 @@ import org.springframework.util.Assert; * @author Luke Taylor * @since 3.0 */ -public class DefaultWebInvocationPrivilegeEvaluator implements - WebInvocationPrivilegeEvaluator { - // ~ Static fields/initializers - // ===================================================================================== +public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator { - protected static final Log logger = LogFactory - .getLog(DefaultWebInvocationPrivilegeEvaluator.class); - - // ~ Instance fields - // ================================================================================================ + protected static final Log logger = LogFactory.getLog(DefaultWebInvocationPrivilegeEvaluator.class); private final AbstractSecurityInterceptor securityInterceptor; - // ~ Constructors - // =================================================================================================== - - public DefaultWebInvocationPrivilegeEvaluator( - AbstractSecurityInterceptor securityInterceptor) { + public DefaultWebInvocationPrivilegeEvaluator(AbstractSecurityInterceptor securityInterceptor) { Assert.notNull(securityInterceptor, "SecurityInterceptor cannot be null"); - Assert.isTrue( - FilterInvocation.class.equals(securityInterceptor.getSecureObjectClass()), + Assert.isTrue(FilterInvocation.class.equals(securityInterceptor.getSecureObjectClass()), "AbstractSecurityInterceptor does not support FilterInvocations"); Assert.notNull(securityInterceptor.getAccessDecisionManager(), "AbstractSecurityInterceptor must provide a non-null AccessDecisionManager"); - this.securityInterceptor = securityInterceptor; } - // ~ Methods - // ======================================================================================================== - /** * Determines whether the user represented by the supplied Authentication * object is allowed to invoke the supplied URI. - * * @param uri the URI excluding the context path (a default context path setting will * be used) */ + @Override public boolean isAllowed(String uri, Authentication authentication) { return isAllowed(null, uri, null, authentication); } @@ -85,7 +71,6 @@ public class DefaultWebInvocationPrivilegeEvaluator implements * metadata applies to a given request URI, so generally the contextPath * is unimportant unless you are using a custom * FilterInvocationSecurityMetadataSource. - * * @param uri the URI excluding the context path * @param contextPath the context path (may be null, in which case a default value * will be used). @@ -94,39 +79,26 @@ public class DefaultWebInvocationPrivilegeEvaluator implements * be used in evaluation whether access should be granted. * @return true if access is allowed, false if denied */ - public boolean isAllowed(String contextPath, String uri, String method, - Authentication authentication) { + @Override + public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) { Assert.notNull(uri, "uri parameter is required"); - - FilterInvocation fi = new FilterInvocation(contextPath, uri, method); - Collection attrs = securityInterceptor - .obtainSecurityMetadataSource().getAttributes(fi); - - if (attrs == null) { - if (securityInterceptor.isRejectPublicInvocations()) { - return false; - } - - return true; + FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method); + Collection attributes = this.securityInterceptor.obtainSecurityMetadataSource() + .getAttributes(filterInvocation); + if (attributes == null) { + return (!this.securityInterceptor.isRejectPublicInvocations()); } - if (authentication == null) { return false; } - try { - securityInterceptor.getAccessDecisionManager().decide(authentication, fi, - attrs); + this.securityInterceptor.getAccessDecisionManager().decide(authentication, filterInvocation, attributes); + return true; } - catch (AccessDeniedException unauthorized) { - if (logger.isDebugEnabled()) { - logger.debug(fi.toString() + " denied for " + authentication.toString(), - unauthorized); - } - + catch (AccessDeniedException ex) { + logger.debug(LogMessage.format("%s denied for %s", filterInvocation, authentication), ex); return false; } - - return true; } + } diff --git a/web/src/main/java/org/springframework/security/web/access/DelegatingAccessDeniedHandler.java b/web/src/main/java/org/springframework/security/web/access/DelegatingAccessDeniedHandler.java index bb4a33a052..9e5fc8d3fa 100644 --- a/web/src/main/java/org/springframework/security/web/access/DelegatingAccessDeniedHandler.java +++ b/web/src/main/java/org/springframework/security/web/access/DelegatingAccessDeniedHandler.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web.access; import java.io.IOException; @@ -36,13 +37,13 @@ import org.springframework.util.Assert; * */ public final class DelegatingAccessDeniedHandler implements AccessDeniedHandler { + private final LinkedHashMap, AccessDeniedHandler> handlers; private final AccessDeniedHandler defaultHandler; /** * Creates a new instance - * * @param handlers a map of the {@link AccessDeniedException} class to the * {@link AccessDeniedHandler} that should be used. Each is considered in the order * they are specified and only the first {@link AccessDeniedHandler} is ued. @@ -58,11 +59,10 @@ public final class DelegatingAccessDeniedHandler implements AccessDeniedHandler this.defaultHandler = defaultHandler; } + @Override public void handle(HttpServletRequest request, HttpServletResponse response, - AccessDeniedException accessDeniedException) throws IOException, - ServletException { - for (Entry, AccessDeniedHandler> entry : handlers - .entrySet()) { + AccessDeniedException accessDeniedException) throws IOException, ServletException { + for (Entry, AccessDeniedHandler> entry : this.handlers.entrySet()) { Class handlerClass = entry.getKey(); if (handlerClass.isAssignableFrom(accessDeniedException.getClass())) { AccessDeniedHandler handler = entry.getValue(); @@ -70,7 +70,7 @@ public final class DelegatingAccessDeniedHandler implements AccessDeniedHandler return; } } - defaultHandler.handle(request, response, accessDeniedException); + this.defaultHandler.handle(request, response, accessDeniedException); } } diff --git a/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java b/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java index 3e3345250c..dd0360aa87 100644 --- a/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java +++ b/web/src/main/java/org/springframework/security/web/access/ExceptionTranslationFilter.java @@ -13,8 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.security.web.access; +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.context.support.MessageSourceAccessor; +import org.springframework.core.log.LogMessage; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; @@ -30,16 +42,6 @@ import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; -import org.springframework.context.support.MessageSourceAccessor; - -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; - /** * Handles any AccessDeniedException and AuthenticationException * thrown within the filter chain. @@ -58,13 +60,15 @@ import java.io.IOException; * authenticationEntryPoint will be launched. If they are not an anonymous * user, the filter will delegate to the * {@link org.springframework.security.web.access.AccessDeniedHandler}. By default the - * filter will use {@link org.springframework.security.web.access.AccessDeniedHandlerImpl}. + * filter will use + * {@link org.springframework.security.web.access.AccessDeniedHandlerImpl}. *

        * To use this filter, it is necessary to specify the following properties: *